偷偷摘套内射激情视频,久久精品99国产国产精,中文字幕无线乱码人妻,中文在线中文a,性爽19p

深度強(qiáng)化學(xué)習(xí)中SAC算法:數(shù)學(xué)原理、網(wǎng)絡(luò)架構(gòu)及其PyTorch實(shí)現(xiàn)

人工智能 深度學(xué)習(xí)
深度強(qiáng)化學(xué)習(xí)是人工智能領(lǐng)域最具挑戰(zhàn)性的研究方向之一,其設(shè)計(jì)理念源于生物學(xué)習(xí)系統(tǒng)從經(jīng)驗(yàn)中優(yōu)化決策的機(jī)制。在眾多深度強(qiáng)化學(xué)習(xí)算法中,軟演員-評(píng)論家算法(Soft Actor-Critic, SAC)因其在樣本效率、探索效果和訓(xùn)練穩(wěn)定性等方面的優(yōu)異表現(xiàn)而備受關(guān)注。

深度強(qiáng)化學(xué)習(xí)是人工智能領(lǐng)域最具挑戰(zhàn)性的研究方向之一,其設(shè)計(jì)理念源于生物學(xué)習(xí)系統(tǒng)從經(jīng)驗(yàn)中優(yōu)化決策的機(jī)制。在眾多深度強(qiáng)化學(xué)習(xí)算法中,軟演員-評(píng)論家算法(Soft Actor-Critic, SAC)因其在樣本效率、探索效果和訓(xùn)練穩(wěn)定性等方面的優(yōu)異表現(xiàn)而備受關(guān)注。

傳統(tǒng)的深度強(qiáng)化學(xué)習(xí)算法往往在探索-利用權(quán)衡、訓(xùn)練穩(wěn)定性等方面面臨挑戰(zhàn)。SAC算法通過(guò)引入最大熵強(qiáng)化學(xué)習(xí)框架,在策略?xún)?yōu)化過(guò)程中自動(dòng)調(diào)節(jié)探索程度,有效解決了這些問(wèn)題。其核心創(chuàng)新在于將熵最大化作為策略?xún)?yōu)化的額外目標(biāo),在保證收斂性的同時(shí)維持策略的多樣性。

本文將系統(tǒng)闡述SAC算法的技術(shù)細(xì)節(jié),主要包括:

  1. 基于最大熵框架的SAC算法數(shù)學(xué)原理
  2. 演員網(wǎng)絡(luò)與評(píng)論家網(wǎng)絡(luò)的具體架構(gòu)設(shè)計(jì)
  3. 基于PyTorch的詳細(xì)實(shí)現(xiàn)方案
  4. 網(wǎng)絡(luò)訓(xùn)練的關(guān)鍵技術(shù)要點(diǎn)

SAC算法采用演員-評(píng)論家架構(gòu),演員網(wǎng)絡(luò)負(fù)責(zé)生成動(dòng)作策略,評(píng)論家網(wǎng)絡(luò)評(píng)估動(dòng)作價(jià)值。通過(guò)兩個(gè)網(wǎng)絡(luò)的協(xié)同優(yōu)化,實(shí)現(xiàn)策略的逐步改進(jìn)。整個(gè)訓(xùn)練過(guò)程中,演員網(wǎng)絡(luò)致力于最大化評(píng)論家網(wǎng)絡(luò)預(yù)測(cè)的Q值,同時(shí)保持適度的策略探索;評(píng)論家網(wǎng)絡(luò)則不斷優(yōu)化其Q值估計(jì)的準(zhǔn)確性。

接下來(lái),我們將從演員網(wǎng)絡(luò)的數(shù)學(xué)原理開(kāi)始,詳細(xì)分析SAC算法的各個(gè)技術(shù)組件:

演員(策略)網(wǎng)絡(luò)

演員是由參數(shù)φ確定的策略網(wǎng)絡(luò),表示為:

這是一個(gè)基于狀態(tài)輸出動(dòng)作的隨機(jī)策略。它使用神經(jīng)網(wǎng)絡(luò)估計(jì)均值和對(duì)數(shù)標(biāo)準(zhǔn)差,從而得到給定狀態(tài)下動(dòng)作的分布及其對(duì)數(shù)概率。對(duì)數(shù)概率用于熵正則化,即目標(biāo)函數(shù)中包含一個(gè)用于最大化概率分布廣度(熵)的項(xiàng),以促進(jìn)智能體的探索行為。關(guān)于熵正則化的具體內(nèi)容將在后文詳述。演員網(wǎng)絡(luò)的架構(gòu)如圖所示:

均值μ(s)和對(duì)數(shù)σ(s)用于動(dòng)作采樣:

其中N表示正態(tài)分布。但這個(gè)操作存在梯度不可微的問(wèn)題,需要通過(guò)重參數(shù)化技巧來(lái)解決。

這里d表示動(dòng)作空間維度,每個(gè)分量ε_(tái)i從標(biāo)準(zhǔn)正態(tài)分布(均值0,標(biāo)準(zhǔn)差1)中采樣。應(yīng)用重參數(shù)化技巧:

這樣就解決了梯度截?cái)鄦?wèn)題。接下來(lái)通過(guò)激活函數(shù)將x_t轉(zhuǎn)換為標(biāo)準(zhǔn)化動(dòng)作:

該轉(zhuǎn)換確保動(dòng)作被限制在[-1,1]區(qū)間內(nèi)。

動(dòng)作對(duì)數(shù)概率計(jì)算

完成動(dòng)作計(jì)算后,就可以計(jì)算獎(jiǎng)勵(lì)和預(yù)期回報(bào)。演員的損失函數(shù)中還包含熵正則化項(xiàng),用于最大化分布的廣度。計(jì)算采樣動(dòng)作??t的對(duì)數(shù)概率Log(π?)時(shí),從預(yù)tanh變換x_t開(kāi)始分析更為便利。

由于x_t來(lái)自均值μ(s)和標(biāo)準(zhǔn)差σ(s)的高斯分布,其概率密度函數(shù)(PDF)為:

其中各獨(dú)立分量x_t,i的分布為:

對(duì)兩邊取對(duì)數(shù)可簡(jiǎn)化PDF:

要將其轉(zhuǎn)換為log(π_?),需要考慮x_t到a_t的tanh變換,這可通過(guò)微分鏈?zhǔn)椒▌t實(shí)現(xiàn):

這個(gè)關(guān)系的推導(dǎo)基于概率守恒原理:兩個(gè)變量在給定區(qū)間內(nèi)的概率必須相等:

其中a_i = tanh(x_i)。將區(qū)間縮小到無(wú)窮小的dx和da:

tanh的導(dǎo)數(shù)形式為:

代入得到:

最終可得完整表達(dá)式:

至此完成了演員部分的推導(dǎo),這里有動(dòng)作又有對(duì)數(shù)概率,就可以進(jìn)行損失函數(shù)的計(jì)算。下面是這些數(shù)學(xué)表達(dá)式的PyTorch實(shí)現(xiàn):

import gymnasium as gym  
 from src.utils.logger import logger  
 from src.models.callback import PolicyGradientLossCallback  
 from pydantic import Field, BaseModel, ConfigDict  
 from typing import Dict, List  
 import numpy as np  
 import os  
 from pathlib import Path  
 import torch  
 import torch.nn as nn  
 import torch.optim as optim  
 import torch.nn.functional as F  
 from torch.distributions import Normal  
   
 '''演員網(wǎng)絡(luò):估計(jì)均值和對(duì)數(shù)標(biāo)準(zhǔn)差用于熵正則化計(jì)算'''  
   
 class Actor(nn.Module):  
     def __init__(self,state_dim,action_dim):  
         super(Actor,self).__init__()  
   
         self.net = nn.Sequential(  
             nn.Linear(state_dim, 100),  
             nn.ReLU(),  
             nn.Linear(100,100),  
             nn.ReLU()  
        )  
         self.mean_linear = nn.Linear(100, action_dim)  
         self.log_std_linear = nn.Linear(100, action_dim)  
   
     def forward(self, state):  
         x = self.net(state)  
         mean = self.mean_linear(x)  
         log_std =self.log_std_linear(x)  
         log_std = torch.clamp(log_std, min=-20, max=2)  
         return mean, log_std  
       
     def sample(self, state):  
         mean, log_std = self.forward(state)  
         std = log_std.exp()  
         normal = Normal(mean, std)  
         x_t = normal.rsample() # 重參數(shù)化技巧  
         y_t = torch.tanh(x_t)  
         action = y_t  
         log_prob = normal.log_prob(x_t)  
         log_prob -= torch.log(1-y_t.pow(2)+1e-6)  
         log_prob = log_prob.sum(dim=1, keepdim =True)  
   
         return action, log_prob

在討論損失函數(shù)定義和演員網(wǎng)絡(luò)的訓(xùn)練過(guò)程之前,需要先介紹評(píng)論家網(wǎng)絡(luò)的數(shù)學(xué)原理。

評(píng)論家網(wǎng)絡(luò)

評(píng)論家網(wǎng)絡(luò)的核心功能是估計(jì)狀態(tài)-動(dòng)作對(duì)的預(yù)期回報(bào)(Q值)。這些估計(jì)值在訓(xùn)練過(guò)程中為演員網(wǎng)絡(luò)提供指導(dǎo)。評(píng)論家網(wǎng)絡(luò)采用雙網(wǎng)絡(luò)結(jié)構(gòu),分別提供預(yù)期回報(bào)的兩個(gè)獨(dú)立估計(jì),并選取較小值作為最終估計(jì)。這種設(shè)計(jì)可以有效避免過(guò)度估計(jì)偏差,同時(shí)提升訓(xùn)練穩(wěn)定性。其結(jié)構(gòu)如圖所示:

需要說(shuō)明的是,此時(shí)的示意圖是簡(jiǎn)化版本,主要用于理解演員和評(píng)論家網(wǎng)絡(luò)的基本角色,暫不考慮訓(xùn)練穩(wěn)定性的細(xì)節(jié)。另外,"智能體"實(shí)際上是演員和評(píng)論家網(wǎng)絡(luò)的統(tǒng)稱(chēng)而非獨(dú)立實(shí)體,圖中分開(kāi)表示只是為了清晰展示結(jié)構(gòu)。假設(shè)評(píng)論家網(wǎng)絡(luò)暫不需要訓(xùn)練,因?yàn)檫@樣可以專(zhuān)注于如何利用評(píng)論家網(wǎng)絡(luò)估計(jì)的Q值來(lái)訓(xùn)練演員網(wǎng)絡(luò)。演員網(wǎng)絡(luò)的損失函數(shù)表達(dá)式為:

更常見(jiàn)的形式是:

其中ρD表示狀態(tài)分布。損失函數(shù)通過(guò)對(duì)所有動(dòng)作空間和狀態(tài)空間的熵項(xiàng)與Q值進(jìn)行積分得到。但在實(shí)際應(yīng)用中,無(wú)法直接獲取完整的狀態(tài)分布,因此ρD實(shí)際上是基于重放緩沖區(qū)樣本的經(jīng)驗(yàn)狀態(tài)分布,期望其能較好地表征整體狀態(tài)分布特征。

基于該損失函數(shù)可以通過(guò)反向傳播對(duì)演員網(wǎng)絡(luò)進(jìn)行訓(xùn)練。以下是評(píng)論家網(wǎng)絡(luò)的PyTorch實(shí)現(xiàn):

'''評(píng)論家網(wǎng)絡(luò):定義q1和q2'''  
 class Critic(nn.Module):  
     def __init__(self, state_dim, action_dim):  
         super(Critic, self).__init__()  
   
         # Q1網(wǎng)絡(luò)架構(gòu)  
         self.q1_net = nn.Sequential(  
             nn.Linear(state_dim + action_dim, 256),  
             nn.ReLU(),  
             nn.Linear(256, 256),  
             nn.ReLU(),  
             nn.Linear(256, 1),  
        )  
   
         # Q2網(wǎng)絡(luò)架構(gòu)  
         self.q2_net = nn.Sequential(  
             nn.Linear(state_dim + action_dim, 256),  
             nn.ReLU(),  
             nn.Linear(256, 256),  
             nn.ReLU(),  
             nn.Linear(256, 1),  
        )  
   
     def forward(self, state, action):  
         sa = torch.cat([state, action], dim=1)  
         q1 = self.q1_net(sa)  
         q2 = self.q2_net(sa)  
         return q1, q2

前述內(nèi)容尚未涉及評(píng)論家網(wǎng)絡(luò)自身的訓(xùn)練機(jī)制。從重放緩沖區(qū)采樣的每個(gè)數(shù)據(jù)點(diǎn)包含[s_t, s_{t+1}, a_t, R]。對(duì)于狀態(tài)-動(dòng)作對(duì)的Q值,我們可以獲得兩種不同的估計(jì)。

第一種方法是直接將a_t和s_t輸入評(píng)論家網(wǎng)絡(luò):

第二種方法是基于貝爾曼方程:

這種方法使用s_t+1、a_t+1以及執(zhí)行動(dòng)作a_t獲得的獎(jiǎng)勵(lì)來(lái)重新估計(jì)。這里使用目標(biāo)網(wǎng)絡(luò)而非第一種方法中的評(píng)論家網(wǎng)絡(luò)進(jìn)行估計(jì)。采用目標(biāo)評(píng)論家網(wǎng)絡(luò)的主要目的是解決訓(xùn)練不穩(wěn)定性問(wèn)題。如果同一個(gè)評(píng)論家網(wǎng)絡(luò)同時(shí)用于生成當(dāng)前狀態(tài)和下一狀態(tài)的Q值(用于目標(biāo)Q值),這種耦合會(huì)導(dǎo)致網(wǎng)絡(luò)更新在目標(biāo)計(jì)算的兩端產(chǎn)生不一致的傳播,從而引起訓(xùn)練不穩(wěn)定。因此引入獨(dú)立的目標(biāo)網(wǎng)絡(luò)為下一狀態(tài)的Q值提供穩(wěn)定估計(jì)。目標(biāo)網(wǎng)絡(luò)作為評(píng)論家網(wǎng)絡(luò)的緩慢更新版本,確保目標(biāo)Q值能夠平穩(wěn)演化。具體結(jié)構(gòu)如圖所示:

評(píng)論家網(wǎng)絡(luò)的損失函數(shù)定義為:

通過(guò)該損失函數(shù)可以利用反向傳播更新評(píng)論家網(wǎng)絡(luò),而目標(biāo)網(wǎng)絡(luò)則采用軟更新機(jī)制:

其中ε是一個(gè)較小的常數(shù),用于限制目標(biāo)評(píng)論家的更新幅度,從而維持訓(xùn)練穩(wěn)定性。

完整流程

以上內(nèi)容完整闡述了SAC智能體的各個(gè)組件。下圖展示了完整SAC智能體的結(jié)構(gòu)及其計(jì)算流程:

下面是一個(gè)綜合了前述演員網(wǎng)絡(luò)、評(píng)論家網(wǎng)絡(luò)及其更新機(jī)制的完整SAC智能體實(shí)現(xiàn)

'''SAC智能體的實(shí)現(xiàn):整合演員網(wǎng)絡(luò)和評(píng)論家網(wǎng)絡(luò)'''  
   
 class SACAgent:  
     def __init__(self, state_dim, action_dim, learning_rate, device):  
         self.device = device  
   
         self.actor = Actor(state_dim, action_dim).to(device)  
         self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=learning_rate)  
   
         self.critic = Critic(state_dim, action_dim).to(device)  
         self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=learning_rate)  
   
         # 目標(biāo)網(wǎng)絡(luò)初始化  
         self.critic_target = Critic(state_dim, action_dim).to(device)  
         self.critic_target.load_state_dict(self.critic.state_dict())  
   
         # 熵溫度參數(shù)  
         self.target_entropy = -action_dim  
         self.log_alpha = torch.zeros(1, requires_grad=True, device=device)  
         self.alpha_optimizer = optim.Adam([self.log_alpha], lr=learning_rate)  
   
     def select_action(self, state, evaluate=False):  
         state = torch.FloatTensor(state).to(self.device).unsqueeze(0)  
         if evaluate:  
             with torch.no_grad():  
                 mean, _ = self.actor(state)  
                 action = torch.tanh(mean)  
                 return action.cpu().numpy().flatten()  
         else:  
             with torch.no_grad():  
                 action, _ = self.actor.sample(state)  
                 return action.cpu().numpy().flatten()  
   
     def update(self, replay_buffer, batch_size=256, gamma=0.99, tau=0.005):  
         # 從經(jīng)驗(yàn)回放中采樣訓(xùn)練數(shù)據(jù)  
         batch = replay_buffer.sample_batch(batch_size)  
         state = torch.FloatTensor(batch['state']).to(self.device)  
         action = torch.FloatTensor(batch['action']).to(self.device)  
         reward = torch.FloatTensor(batch['reward']).to(self.device)  
         next_state = torch.FloatTensor(batch['next_state']).to(self.device)  
         done = torch.FloatTensor(batch['done']).to(self.device)  
   
         # 評(píng)論家網(wǎng)絡(luò)更新  
         with torch.no_grad():  
             next_action, next_log_prob = self.actor.sample(next_state)  
             q1_next, q2_next = self.critic_target(next_state, next_action)  
             q_next = torch.min(q1_next, q2_next) - torch.exp(self.log_alpha) * next_log_prob  
             target_q = reward + (1 - done) * gamma * q_next  
   
         q1_current, q2_current = self.critic(state, action)  
         critic_loss = F.mse_loss(q1_current, target_q) + F.mse_loss(q2_current, target_q)  
   
         self.critic_optimizer.zero_grad()  
         critic_loss.backward()  
         self.critic_optimizer.step()  
   
         # 演員網(wǎng)絡(luò)更新  
         action_new, log_prob = self.actor.sample(state)  
         q1_new, q2_new = self.critic(state, action_new)  
         q_new = torch.min(q1_new, q2_new)  
         actor_loss = (torch.exp(self.log_alpha) * log_prob - q_new).mean()  
   
         self.actor_optimizer.zero_grad()  
         actor_loss.backward()  
         self.actor_optimizer.step()  
   
         # 溫度參數(shù)更新  
         alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()  
   
         self.alpha_optimizer.zero_grad()  
         alpha_loss.backward()  
         self.alpha_optimizer.step()  
   
         # 目標(biāo)網(wǎng)絡(luò)軟更新  
         for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):  
             target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

總結(jié)

本文系統(tǒng)地闡述了SAC算法的數(shù)學(xué)基礎(chǔ)和實(shí)現(xiàn)細(xì)節(jié)。通過(guò)對(duì)演員網(wǎng)絡(luò)和評(píng)論家網(wǎng)絡(luò)的深入分析,我們可以看到SAC算法在以下幾個(gè)方面具有顯著優(yōu)勢(shì):

理論框架

  • 基于最大熵強(qiáng)化學(xué)習(xí)的理論基礎(chǔ)保證了算法的收斂性
  • 雙Q網(wǎng)絡(luò)設(shè)計(jì)有效降低了值函數(shù)估計(jì)的過(guò)度偏差
  • 自適應(yīng)溫度參數(shù)實(shí)現(xiàn)了探索-利用的動(dòng)態(tài)平衡

實(shí)現(xiàn)特點(diǎn)

  • 采用重參數(shù)化技巧確保了策略梯度的連續(xù)性
  • 軟更新機(jī)制提升了訓(xùn)練穩(wěn)定性
  • 基于PyTorch的向量化實(shí)現(xiàn)提高了計(jì)算效率

實(shí)踐價(jià)值

  • 算法在連續(xù)動(dòng)作空間中表現(xiàn)優(yōu)異
  • 樣本效率高,適合實(shí)際應(yīng)用場(chǎng)景
  • 訓(xùn)練過(guò)程穩(wěn)定,調(diào)參難度相對(duì)較小

未來(lái)研究可以在以下方向繼續(xù)深化:

  • 探索更高效的策略表達(dá)方式
  • 研究多智能體場(chǎng)景下的SAC算法擴(kuò)展
  • 結(jié)合遷移學(xué)習(xí)提升算法的泛化能力
  • 針對(duì)大規(guī)模狀態(tài)空間優(yōu)化網(wǎng)絡(luò)架構(gòu)

強(qiáng)化學(xué)習(xí)作為人工智能的核心研究方向之一,其理論體系和應(yīng)用場(chǎng)景都在持續(xù)發(fā)展。深入理解算法的數(shù)學(xué)原理和實(shí)現(xiàn)細(xì)節(jié),將有助于我們?cè)谶@個(gè)快速演進(jìn)的領(lǐng)域中把握技術(shù)本質(zhì),開(kāi)發(fā)更有效的解決方案。

責(zé)任編輯:華軒 來(lái)源: DeepHub IMBA
相關(guān)推薦

2019-09-29 10:42:02

人工智能機(jī)器學(xué)習(xí)技術(shù)

2024-10-12 17:14:12

2017-08-22 15:56:49

神經(jīng)網(wǎng)絡(luò)強(qiáng)化學(xué)習(xí)DQN

2023-06-25 11:30:47

可視化

2020-08-10 06:36:21

強(qiáng)化學(xué)習(xí)代碼深度學(xué)習(xí)

2022-05-31 10:45:01

深度學(xué)習(xí)防御

2023-12-03 22:08:41

深度學(xué)習(xí)人工智能

2023-03-23 16:30:53

PyTorchDDPG算法

2021-09-17 15:54:41

深度學(xué)習(xí)機(jī)器學(xué)習(xí)人工智能

2024-09-05 08:23:58

2025-05-28 02:25:00

2025-03-03 01:00:00

DeepSeekGRPO算法

2022-04-22 12:36:11

RNN神經(jīng)網(wǎng)絡(luò))機(jī)器學(xué)習(xí)

2023-01-24 17:03:13

強(qiáng)化學(xué)習(xí)算法機(jī)器人人工智能

2022-09-04 14:38:00

世界模型建模IRIS

2020-05-12 07:00:00

深度學(xué)習(xí)強(qiáng)化學(xué)習(xí)人工智能

2022-11-02 14:02:02

強(qiáng)化學(xué)習(xí)訓(xùn)練

2025-01-09 15:57:41

2019-01-15 13:14:03

機(jī)器人算法SAC

2025-03-11 01:00:00

GRPO算法模型
點(diǎn)贊
收藏

51CTO技術(shù)棧公眾號(hào)