banner
Nagi-ovo

Nagi-ovo

Breezing
github

“速通” PPO

近端策略優化

終於到了這幾年 NLP 領域中比較火熱的 RL 演算法之一了

在 On-Policy 演算法中,採集數據用的策略和訓練的策略是相同的,這樣的問題是數據用一次後就得丟棄,然後再重新採集數據,訓練速度很慢。

PPO 背後的直覺#

PPO 的理念是通過限制每個訓練週期對策略的更改來提高策略的訓練穩定性:避免劇烈的策略更新。

Screenshot 2024-10-11 at 13.53.20

這出於兩個原因:

  • 根據這個領域的經驗,訓練中較小的策略更新更有可能收斂到最優解。
  • 策略更新中,過大的步長可能導致 “跌下懸崖”(得到不良策略),並需要很長時間恢復,甚至永遠無法回歸原始水平。

裁剪的替代目標函數#

回顧:策略目標函數#

我們的目標是通過採取梯度上升(或者梯度下降的負函數)來推動 agent 選擇那些能帶來更高獎勵的行為,並避免那些可能帶來負面效果的動作。

LPG(θ)=Et[logπθ(atst)At]L^{PG}(\theta) = \mathbb{E}_t \left[ \log \pi_\theta(a_t | s_t) * A_t \right]
  1. logπθ(atst)\log \pi_\theta(a_t | s_t):在狀態 sts_t 下選擇動作 ata_t 的對數概率,意味著我們在當前策略中採取這個動作的概率有多大。
  2. AtA_t:優勢函數(Advantage),如果 A>0A > 0,說明這個動作比當前狀態下其他可能的動作更好;反之,則較差。
    然而,經典的 PG 方法存在一個問題:策略更新步長的選擇至關重要。
  • 如果步長太小,訓練過程會非常慢;
  • 如果步長太大,訓練中的波動性太大,可能導致訓練不穩定。

於是,PPO 提出了個新方案,裁剪的替代目標函數,它通過裁剪策略變化的範圍,確保策略更新不會太激進,從而保持訓練過程的穩定性。

這個新的目標函數如下:

LCLIP(θ)=E^t[min(rt(θ)At^,clip(rt(θ),1ϵ,1+ϵ)At^)]L^{CLIP}(\theta) = \hat{\mathbb{E}}_t \left[ \min \left( r_t(\theta) \hat{A_t}, \text{clip}\left( r_t(\theta), 1 - \epsilon, 1 + \epsilon \right) \hat{A_t} \right) \right]

比率函數#

其中,關鍵的部分是比率函數 rt(θ)r_t(\theta),它表示當前策略與之前策略之間的動作概率比率:

rt(θ)=πθ(atst)πθold(atst)r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)}

比率反映了當前策略與舊策略的偏差程度:

  • 如果 rt(θ)(0,1)r_{t(\theta)}\in (0, 1) ,則說明在當前策略下,選擇該動作的概率變小。
  • 如果 rt(θ)>1r_t(\theta) > 1,說明在當前策略下,動作 $a_t$ 比之前更有可能被選擇。

未裁剪部分#

公式中的未裁剪部分為:

LCPI(θ)=E^t[πθ(atst)πθold(atst)A^t]=E^t[rt(θ)A^t]L^{CPI}(\theta) = \hat{\mathbb{E}}_t \left[ \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_{\text{old}}}(a_t | s_t)} \hat{A}_t \right] = \hat{\mathbb{E}}_t \left[ r_t(\theta) \hat{A}_t \right]

在未截斷的目標函數中, rt(θ)\ r_t(\theta) 直接乘以優勢值 A^t\hat{A}_t,如果動作 ata_t 在當前策略下比在舊策略下更加優(即優勢值 A^t>0\hat{A}_t > 0),那麼我們會推崇該動作,反之則會削弱它的影響。這是標準的策略梯度優化方向。

但如前面所提到的,沒有約束的策略更新可能會導致訓練不穩定。如果比率 rt(θ)r_t(\theta) 遠大於 1,策略更新會過大,進而導致訓練過程中難以收斂。

這時,PPO 引入了截斷策略,裁剪比率的範圍

裁剪部分#

LCLIP(θ)=E^t[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]L^{CLIP}(\theta) = \hat{\mathbb{E}}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}\left( r_t(\theta), 1 - \epsilon, 1 + \epsilon \right) \hat{A}_t \right) \right]

在這裡,我們看到 min\min 操作的引入。當比率 rt(θ)r_t(\theta) 超過了設定的閾值 [1ϵ,1+ϵ][1 - \epsilon, 1 + \epsilon] 時,裁剪操作會將比率限制在這個範圍內,從而防止策略更新過大。

裁剪比率函數為:

clip(rt(θ),1ϵ,1+ϵ)\text{clip}\left( r_t(\theta), 1 - \epsilon, 1 + \epsilon \right)

這意味著如果比率 rt(θ)r_t(\theta) 超出設定的區間(原始論文中 ϵ=0.2\epsilon = 0.2),它會被截斷在 [0.8,1.2][0.8, 1.2] 之間,從而保證策略更新的穩定性。我們取截斷後的值和未截斷值之間的最小值,這保證了最終的目標函數不會過分樂觀,而是趨向於一個更加保守的估計。

可視化#

Pasted image 20241013022437

首先記住,我們取裁剪目標和未裁剪目標之間的最小值。

情況 1 和 2:比例在範圍內#

在這兩種情況下,都沒有剪裁,策略會根據 AtA_t 的正負進行相應更新。這是 PPO 的理想狀態,一切都按照預期進行。

  • 情況 1At>0A_t > 0pt(θ)[1ϵ,1+ϵ]p_t(\theta) \in [1 - \epsilon, 1 + \epsilon]

    • 優勢函數 AtA_t 為正,意味著這個動作比預期更好。
    • pt(θ)p_t(\theta) 處於這個範圍內,說明策略變化不大,我們想要鼓勵這個動作,因此不進行剪裁。
    • 結果:目標函數為正,梯度更新會推動策略進一步偏向執行這個動作。
  • 情況 2At<0A_t < 0pt(θ)[1ϵ,1+ϵ]p_t(\theta) \in [1 - \epsilon, 1 + \epsilon]

    • 優勢函數為負,意味著這個動作比預期更差。
    • 同樣,由於比例在範圍內,不進行剪裁。我們希望減少該動作的執行。
    • 結果:目標函數為負,梯度更新會使策略遠離執行這個動作。

情況 3 和 4:比例低於範圍#

這裡比例表明當前策略比舊策略低估了這個動作的概率。會發生什麼呢?

  • 情況 3At>0A_t > 0pt(θ)<1ϵp_t(\theta) < 1 - \epsilon

    • 動作很好(優勢函數為正),但新策略認為這個動作的概率較低。
    • 我們不進行剪裁,因為我們想要 增加 這個優秀動作的概率,允許梯度強烈推動更新。
    • 結果:目標函數為正,梯度鼓勵這個動作。
  • 情況 4At<0A_t < 0pt(θ)<1ϵp_t(\theta) < 1 - \epsilon

    • 動作很差(優勢函數為負),策略已經在減少這個動作的概率。
    • 然而,我們進行剪裁,因為概率已經低於 1ϵ1 - \epsilon,繼續降低可能會過度懲罰,導致訓練不穩定。
    • 結果:目標函數被剪裁,梯度不會再更新,該動作概率保持在下限。

情況 5 和 6:比例高於範圍#

在這裡,策略對動作過於自信,這意味著新策略讓這個動作的執行概率過高。

  • 情況 5At>0A_t > 0pt(θ)>1+ϵp_t(\theta) > 1 + \epsilon

    • 動作很好(優勢函數為正),但新策略過高估計了它的執行概率。
    • 我們進行剪裁,因為我們不希望策略過度偏向這個動作。即使 AtA_t 為正,我們也需要限制策略的更新步幅。
    • 結果:目標函數被剪裁,梯度不更新,我們限制了策略的變化幅度。
  • 情況 6At<0A_t < 0pt(θ)>1+ϵp_t(\theta) > 1 + \epsilon

    • 動作不好,但策略卻讓它的執行概率變得更高。這顯然不是我們想要的。
    • 此時,比率已經超出範圍,我們不進行剪裁。目標函數為負,梯度強烈推策略遠離這個差的動作。
    • 結果:目標函數為負,梯度會使策略遠離這個動作。

為什麼在剪裁的情況下梯度為 0?#

原因在於,當比值 rt(θ)r_t(\theta) 被剪裁到 1ϵ1 - ϵ1+ϵ1 + ϵ 時,導數不再是比值 rt(θ)r_t(\theta) 乘以優勢 AtA_t 的導數,而是 (1ϵ)At(1 - ϵ)A_t(1+ϵ)At(1 + ϵ)A_t 的導數,而這兩個表達式的導數為 0。

總結#

總結一下,PPO 的目標是通過 裁剪的替代目標 限制當前策略與舊策略之間的變化範圍。我們移除了讓概率比值超出 [1ϵ,1+ϵ][1 - ϵ, 1 + ϵ] 區間的激勵,因為一旦比值超出該區間,梯度就會變為 0,策略更新就停止。

在 PPO 更新過程中,我們只在兩種情況下更新策略:

  1. 當比值 rt(θ)r_t(\theta) 落在 [1ϵ,1+ϵ][1 - ϵ, 1 + ϵ] 區間內時。
  2. 比值在區間外,但優勢函數引導比值靠近該區間。

最後複習一下,PPO 的 裁剪的替代目標損失 是由三部分組成:

  • 裁剪的替代目標函數:限制策略更新的變化範圍。
  • 價值損失函數:用來最小化值函數的均方誤差。
  • 熵獎勵:用於保持足夠的探索,以防止策略過早陷入局部最優。

這三部分結合以確保 PPO 既能穩定地更新策略,又能保持足夠的探索性。

代碼實現#

現在讓我們從代碼角度深入理解 PPO 的實現。聚焦在 cleanrl 中 ppo.py最關鍵的部分,用簡潔的方式解釋其工作原理。

1. 策略網絡與值網絡結構#

class Agent(nn.Module):
    def __init__(self, envs):
        super().__init__()
        # 評論家網絡:將狀態映射為價值(用神經網絡估計狀態的好壞程度)
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0), # 最後一層的初始化對學習穩定性很重要
        )
        # 演員網絡:將狀態映射為動作概率(輸出策略的神經網絡)
        self.actor = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01), # 較小的標準差確保初始策略近似均勻分布
        )

這是一個典型的雙網絡架構:

  • actor (策略網絡) 輸出動作的概率分布
  • critic (值網絡) 預測狀態價值
  • 兩個網絡都採用簡單的兩層 MLP 結構 (64-64)
  • 使用正交初始化 (orthogonal initialization) 來幫助訓練穩定性

2. GAE (廣義優勢估計) 的實現#

# GAE計算:反向遞推計算優勢函數和回報值
with torch.no_grad():
    next_value = agent.get_value(next_obs).reshape(1, -1)
    advantages = torch.zeros_like(rewards).to(device)
    lastgaelam = 0
    for t in reversed(range(args.num_steps)):
        # GAE(廣義優勢估計)的優雅實現
        # 可以理解為時序差分誤差的指數加權和
        delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
    returns = advantages + values  # 回報值 = 優勢函數 + 價值估計

這段代碼展示了 GAE 的遞歸計算過程:

  • 從後向前計算 TD 誤差 (delta)
  • 用指數加權的方式累積這些 TD 誤差
  • gamma 和 lambda 超參數控制著價值估計的偏差 - 方差權衡

3. PPO 的核心損失函數計算#

# PPO的核心:在防止策略變化過大的同時改進策略
_, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
ratio = (newlogprob - b_logprobs[mb_inds]).exp()  # 重要性採樣比率

# 著名的PPO-Clip目標函數
pg_loss1 = -mb_advantages * ratio
pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
pg_loss = torch.max(pg_loss1, pg_loss2).mean()  # 悲觀式(最壞情況)策略損失

# 價值函數損失同樣使用截斷來保持接近舊預測
if args.clip_vloss:
    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
    v_clipped = b_values[mb_inds] + torch.clamp(
        newvalue - b_values[mb_inds],
        -args.clip_coef,
        args.clip_coef,
    )
    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
    v_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()

# 綜合損失函數:結合策略損失、價值損失和用於探索的熵獎勵項
loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

這裡實現了前面提到的三個關鍵組件:

  • 裁剪的替代目標用 max 操作實現截斷
  • 價值函數損失同樣使用了截斷機制 (這是 OpenAI 的實現特色)
  • 熵獎勵項用於鼓勵探索

前段時間平台的 ipfs 崩了因而搁置許久,這篇發出來後 Huggingface DeepRL 系列正式完結~

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。