banner
Nagi-ovo

Nagi-ovo

Breezing homepage: nagi.fun
github

從 RL 來,到 RLHF 去

本文主要基於 Umar Jamil 的課程[1]^{[1]}進行學習和記錄。我們的目標是讓 LLM 的行為與我們期望的輸出相一致,RLHF 則是最著名的技術之一。其標準流程涉及四個模型(聽上去就很佔顯存,所以很多方法是去掉部分模型),這裡只需記得一共需要四個即可:Reward、Actor、Critic 和 Reference Model,我們最後優化得到的模型是這裡說的 Actor Model。

LLM to RL#

以前對 RL 的認識中,策略是告訴你在當前 State 下把你應該採取的 Action 的概率的東西,那這麼說來,語言模型本身就可以看作一個 Policy:接收一個 Prompt (state),輸出下一個 token (action) 的概率,採樣後得到一個新 state(token 被拼到 prompt 後),即相當於一個有 vocab_size 大小 Action Space 的 Policy,也是個 RL Agent。

那這麼說,還差一個提供 Reward 的東西(傳統 RL 中一般是環境內置的獎勵函數)

做一個 “Q-A-Reward” 的數據集可以實現這點,但是人類並不擅長尋找共識,但在比較優劣這點卻很擅長。所以我們把方向轉為:模型在 High Temperature 下 generate 多個 A,然後請領域專家(可以是人也可以是 AI Model)來選擇出 Chosen / Prefer 的答案,標註出一個偏好數據集,用此訓練出一個生成數值獎勵的 Reward Model。

Reward Model#

這個 RM 是用一個預訓練 LLM 如 Llama 來實現的。

Note

在文本生成任務中,我們取 prompt 輸入 Transformer 後產生的 Embedding (Hidden States) 的最後一個 (token 的) Hidden State 送入 Linear 投影到詞表中得到 logits,然後用 Softmax 和採樣策略來選擇 next token。

當我們不想生成文本而是生成數值獎勵時,可以投影到詞彙表中的 Linear 替換為一個 one output feature (輸出一個標量) 的 Linear,用來產生整個文本序列的單一評分值。

Screenshot 2025-04-23 at 21.19.56

Reward Model Loss#

Tip

訓練時,我們要讓這個模型為選擇的答案生成高獎勵,為未被選擇的答案生成低獎勵

類似 Bradley-Terry 的參數化形式:

Loss=logσ(r(x,yw)r(x,yl))Loss = -\log \sigma(r(x, y_w) - r(x, y_l))

ywy_w 表示 Chosen,yly_l 反之。因此當模型為 chosen 給出高 reward 時,

r(x,yw)r(x,yl)>0r(x, y_w) - r(x, y_l)>0

σ(r(x,yw)r(x,yl))(0.5,1)\sigma(r(x, y_w) - r(x, y_{l))}\in(0.5,1)

logσ(r(x,yw)r(x,yl))(0,1)-\log \sigma(r(x, y_w) - r(x, y_l))\in(0,1)

這樣損失會低,而模型為選擇的答案給出低 reward 時損失會很高。

image image

HuggingFace 中 RewardTrainer 類接收一個 AutoModelForSequenceClassification 輸入(即我們上面提到的模型結構)

Screenshot 2025-04-23 at 22.02.29

Actor & Critic Model#

Trajectories#

如前所述,強化學習(RL)的核心目標是找到一個策略(policy, π\pi),該策略能指導 agent 的行動,以獲得最大可能的期望回報(expected return)。

數學上,我們將其表示為找到最大化目標函數 J(π)J(\pi) 的最優策略 π\pi^*

π=argmaxπJ(π)\pi^* = \arg \max_{\pi} J(\pi)

期望回報 J(π)J(\pi) 代表了智能體遵循策略 π\pi 時,在許多可能的生命周期或回合(episodes)中,預期能累積到的平均總回報。

它的計算方法是:考慮所有可能發生的軌跡(trajectory, τ\tau),並將每個軌跡的總回報 R(τ)R(\tau) 乘以該軌跡在策略 π\pi 下發生的概率 P(τπ)P(\tau|\pi) 進行加權平均(或積分)。

J(π)=P(τπ)R(τ)=Eτπ[R(τ)]J(\pi) = \int P(\tau|\pi) R(\tau) = E_{\tau \sim \pi} [R(\tau)]
  • Eτπ[]E_{\tau \sim \pi}[\cdot] 表示當軌跡 τ\tau 是根據策略 π\pi 生成時的期望值(Expected Value)。
  • R(τ)R(\tau) 是在單個軌跡 τ\tau 上獲得的總回報(獎勵)。
  • P(τπ)P(\tau|\pi) 是當智能體使用策略 π\pi 時,特定軌跡 τ\tau 發生的概率。

軌跡 τ\tau 就是 agent 經歷的一系列狀態和動作的序列,從初始狀態開始。它是 agent 與環境交互的一種可能的 “故事” 或 “路徑”。

τ=(s0,a0,s1,a1,s2,a2,)\tau = (s_0, a_0, s_1, a_1, s_2, a_2, \dots)
  • sts_t:時間步 tt 的狀態(State)。
  • ata_t:在時間步 tt 採取的動作(Action)(通常基於狀態 sts_t 和策略 π\pi)。

我們通常將環境建模為隨機的(stochastic)。這意味著在同一個狀態 sts_t 下執行相同的動作 ata_t,並不總會導致完全相同的下一個狀態 st+1s_{t+1}。其中涉及到隨機性。

下一個狀態 st+1s_{t+1} 是從一個以當前狀態 sts_t 和採取的動作 ata_t 為條件的概率分佈中抽取的:

st+1P(st,at)s_{t+1} \sim P(\cdot | s_t, a_t)

考慮到隨機性的狀態轉移和 Agent 的策略,我們可以計算整個軌跡發生的概率。它由以下各項連乘得到:

  1. Agent 在初始狀態 s0s_0 的概率:p0(s0)p_0(s_0)
  2. 對於軌跡中的每一個時間步 tt
    • 環境在給定 sts_tata_t 的條件下轉移到狀態 st+1s_{t+1} 的概率:P(st+1st,at)P(s_{t+1}|s_t, a_t)
    • 智能體根據其策略在狀態 sts_t 選擇動作 ata_t 的概率:π(atst)\pi(a_t|s_t)
P(τπ)=p0(s0)t=0T1P(st+1st,at)π(atst)P(\tau|\pi) = p_0(s_0) \prod_{t=0}^{T-1} P(s_{t+1}|s_t, a_t) \pi(a_t|s_t)

(其中 TT 是軌跡的長度)。

在計算軌跡的總回報 R(τ)R(\tau) 時,我們幾乎總是使用折扣回報(discounted rewards)。這意味著較早收到的回報相比較晚收到的回報更有價值。

為什麼?

  • 反映了現實場景(今天到手的一美元比明天畫餅的一美元更有價值)。
  • 在持續性任務(沒有固定終點的任務)中避免了無限回報的問題。
  • 提供了數學上的便利性。

我們引入一個折扣因子(discount factor)γ\gamma,其中 0γ<10 \le \gamma < 1γ\gamma 越接近 0,Agent 越 “短視”(更關注眼前利益);γ\gamma 越接近 1,Agent 越 “有遠見”(更關注長期回報)。

軌跡的總折扣回報計算如下:

R(τ)=t=0γtrtR(\tau) = \sum_{t=0}^{\infty} \gamma^t r_t
  • rtr_t 是在時間步 tt 收到的即時回報(immediate reward)。
  • γt\gamma^t 是應用於時間步 tt 回報的折扣系數。

那在 LLM 中,軌跡是什麼呢?前面提到了,模型是策略,prompt 是狀態,next token 是 action,因此自回歸生成中的這些 s, a 序列組成了軌跡。

Screenshot 2025-04-24 at 01.32.33

Policy Gradient#

我們確定了強化學習的目標:找到一個最優策略 ππ^∗ 來最大化期望回報 J(π)J(π)。很好。但我們實際上如何表示並找到這個策略呢?

通常,尤其是在處理複雜問題時,我們不會去搜索所有可能的策略。我們會定義一個參數化策略(parameterized policy),記作
πθπ_θ。你可以把 θθ 想象成一組 “旋鈕” 或參數 —— 如果我們的策略是一個神經網絡,那 θθ 可能就是網絡的權重和偏置。

我們現在的目標就變成了:如何調整這些旋鈕 θθ 來最大化我們的期望回報?

Note

在參數為 θθ 的策略 πθπ_θ​ 下,所有可能軌跡的期望回報:

J(πθ)=Eτπθ[R(τ)]J(\pi_{\theta}) = E_{\tau \sim \pi_{\theta}} [R(\tau)]

這意味著期望回報依賴於軌跡 τ\tau,而軌跡的分佈又依賴於我們特定策略 πθ\pi_{\theta} 所選擇的動作。改變 θ\theta,就改變了策略,改變了軌跡,也就改變了期望回報。

Note

我們想通過改變 θ\theta 來最大化 J(πθ)J(\pi_{\theta})。在深度學習中一般使用梯度下降(gradient descent)來最小化一個損失函數。而在這裡,我們想要最大化一個函數JJ。所以,我們反過來使用梯度上升(gradient ascent)!這就像爬山 —— 想找到最陡峭的上升方向(也就是梯度),然後朝著那個方向邁出一步。

我們的策略 πθ\pi_{\theta} 是一個神經網絡,我們會迭代地調整其參數 θ\theta 來增加 J(πθ)J(\pi_{\theta})。這個更新規則看起來會非常熟悉(只是把梯度下降裡的減號換成了加號):

θk+1=θk+αθJ(πθ)θk\theta_{k+1} = \theta_k + \alpha \nabla_{\theta} J(\pi_{\theta})|_{\theta_k}
  • θk\theta_k:我們在第 kk 次迭代時的參數。
  • α\alpha:學習率(步長大小)。
  • θJ(πθ)θk\nabla_{\theta} J(\pi_{\theta})|_{\theta_k}:期望回報 JJ 相對於參數 θ\theta梯度(gradient),在當前參數 θk\theta_k 處計算。它告訴我們在參數空間中哪個方向能最大程度地增加 JJ

Screenshot 2025-04-24 at 13.51.52

Important

PG 推導
這裡一開始會囉嗦一點來把角標重新都介紹一遍,ADHD 友好型推導...

第一步,重申我們要求梯度的對象是期望回報 J(πθ)J(\pi_{\theta})

θJ(πθ)=θEτπθ[R(τ)]\nabla_{\theta} J(\pi_{\theta}) = \nabla_{\theta} E_{\tau \sim \pi_{\theta}} [R(\tau)]

這裡:

  • J(πθ)J(\pi_{\theta}) 就是期望回報。
  • Eτπθ[]E_{\tau \sim \pi_{\theta}} [\cdot] 表示期望值,這個期望是針對所有可能的軌跡 (trajectory, τ\tau) 來計算的。軌跡 τ\tau 是 Agent 與環境交互產生的一系列狀態和動作 (s0,a0,s1,a1,)(s_0, a_0, s_1, a_1, \dots)
  • τπθ\tau \sim \pi_{\theta} 表示這些軌跡是根據我們當前的策略 πθ\pi_{\theta} 生成的。
  • R(τ)R(\tau) 是指一條完整軌跡 τ\tau 所獲得的總回報(通常是折扣回報)。
  • θ\nabla_{\theta} 是梯度算子,表示我們要對參數 θ\theta 求偏導數。

第二步,我們來展開期望的表達式:
期望值的定義是什麼?對於一個隨機變量 XX,它的期望 E[X]E[X] 可以通過它的概率分佈 p(x)p(x) 來計算:

  • 如果是連續變量:E[X]=p(x)xdxE[X] = \int p(x) x dx
  • 如果是離散變量:E[X]=p(x)xE[X] = \sum p(x) x

在我們的例子裡,隨機變量是軌跡的回報 R(τ)R(\tau),概率分佈是軌跡發生的概率 P(τπθ)P(\tau|\pi_{\theta})(給定策略 πθ\pi_{\theta} 下軌跡 τ\tau 發生的概率)。所以,期望可以寫成積分(或求和)的形式:

Eτπθ[R(τ)]=P(τπθ)R(τ)dτE_{\tau \sim \pi_{\theta}} [R(\tau)] = \int P(\tau|\pi_{\theta}) R(\tau) d\tau

(這裡用積分符號 \int 代表對所有可能的軌跡求和或積分,更通用)。
把這個代入第一步的公式:

θJ(πθ)=θP(τπθ)R(τ)dτ\nabla_{\theta} J(\pi_{\theta}) = \nabla_{\theta} \int P(\tau|\pi_{\theta}) R(\tau) d\tau


第三步:把梯度算子移到積分號裡面

θP(τπθ)R(τ)dτ=θ[P(τπθ)R(τ)]dτ\nabla_{\theta} \int P(\tau|\pi_{\theta}) R(\tau) d\tau = \int \nabla_{\theta} [P(\tau|\pi_{\theta}) R(\tau)] d\tau

這裡需要一點微積分知識:在滿足一定條件下(通常我們假設在強化學習中是滿足的),我們可以交換求導和積分的順序。就像 ddxfi(x)=ddxfi(x)\frac{d}{dx} \sum f_i(x) = \sum \frac{d}{dx} f_i(x) 一樣。
接著,注意到 R(τ)R(\tau)一條軌跡確定後的總回報,它本身的值不直接依賴於策略參數 θ\theta。(是策略 πθ\pi_{\theta} 影響了哪條軌跡會發生,而不是這條軌跡一旦發生後它的回報值是多少)。所以,梯度 θ\nabla_{\theta} 只需要作用在 P(τπθ)P(\tau|\pi_{\theta}) 上:

=[θP(τπθ)]R(τ)dτ= \int [\nabla_{\theta} P(\tau|\pi_{\theta})] R(\tau) d\tau

這一步告訴我們,期望回報的變化,是由於參數 θ\theta 改變導致每條軌跡發生的概率 P(τπθ)P(\tau|\pi_{\theta}) 變化,再乘以該軌跡本身的回報 R(τ)R(\tau),然後對所有軌跡累加起來的效果。


第四步:對數導數技巧 (Log-derivative trick)
這是整個推導中最核心、最巧妙的一步!我們需要引入一個恆等式。

  • 高數復習 (鏈式法則與對數求導): 回憶一下自然對數 log(x)\log(x) (通常指 ln(x)\ln(x)) 的導數是 ddxlog(f(x))=1f(x)df(x)dx=f(x)f(x)\frac{d}{dx} \log(f(x)) = \frac{1}{f(x)} \frac{d f(x)}{dx} = \frac{f'(x)}{f(x)}
  • 稍微變形一下,我們就得到:f(x)=f(x)ddxlog(f(x))f'(x) = f(x) \frac{d}{dx} \log(f(x))
    現在,我們把這個技巧應用到梯度上。令 f(x)f(x) 對應 P(τπθ)P(\tau|\pi_{\theta}),自變量 xx 對應參數 θ\theta。那麼:

θP(τπθ)=P(τπθ)θlogP(τπθ)\nabla_{\theta} P(\tau|\pi_{\theta}) = P(\tau|\pi_{\theta}) \nabla_{\theta} \log P(\tau|\pi_{\theta})

把這個結果代入第三步的積分中:
[θP(τπθ)]R(τ)dτ=[P(τπθ)θlogP(τπθ)]R(τ)dτ\int [\nabla_{\theta} P(\tau|\pi_{\theta})] R(\tau) d\tau = \int [P(\tau|\pi_{\theta}) \nabla_{\theta} \log P(\tau|\pi_{\theta})] R(\tau) d\tau


第五步:重新變回期望形式
觀察第四步的結果:

P(τπθ)[θlogP(τπθ)R(τ)]dτ\int P(\tau|\pi_{\theta}) [\nabla_{\theta} \log P(\tau|\pi_{\theta}) R(\tau)] d\tau

這又符合期望的定義了! E[f(τ)]=P(τπθ)f(τ)dτE[f(\tau)] = \int P(\tau|\pi_{\theta}) f(\tau) d\tau
這裡, f(τ)f(\tau) 就對應方括號裡的全部內容 [θlogP(τπθ)R(τ)][\nabla_{\theta} \log P(\tau|\pi_{\theta}) R(\tau)]
所以,整個積分可以寫回期望的形式:

=Eτπθ[θlogP(τπθ)R(τ)]= E_{\tau \sim \pi_{\theta}} [\nabla_{\theta} \log P(\tau|\pi_{\theta}) R(\tau)]

重大意義! 我們成功地把期望的梯度 θE[]\nabla_{\theta} E[\cdot] 轉換成了某個量(梯度乘以回報)的期望 E[()×R]E[\nabla (\cdot) \times R]。這個形式非常重要,因為它可以通過採樣來近似!我們不需要真的去計算所有軌跡的積分了。只需要採樣很多軌跡 τ\tau,對於每一條軌跡,計算括號裡的值 [θlogP(τπθ)R(τ)][\nabla_{\theta} \log P(\tau|\pi_{\theta}) R(\tau)],然後求平均,就可以得到梯度的近似值!


第六步:展開對數概率的梯度 (Expression for grad-log-prob)
現在,我們需要處理一下期望裡的 θlogP(τπθ)\nabla_{\theta} \log P(\tau|\pi_{\theta}) 這一項。
回憶一下軌跡 τ=(s0,a0,s1,a1,,sT,aT)\tau = (s_0, a_0, s_1, a_1, \dots, s_T, a_T) (假設軌跡長度為 T+1 個狀態和 T+1 個動作,或者 T 個時間步)。一條軌跡發生的概率是:

P(τπθ)=p0(s0)t=0TP(st+1st,at)πθ(atst)P(\tau|\pi_{\theta}) = p_0(s_0) \prod_{t=0}^{T} P(s_{t+1}|s_t, a_t) \pi_{\theta}(a_t|s_t)

  • p0(s0)p_0(s_0):初始狀態 s0s_0 的概率。

  • P(st+1st,at)P(s_{t+1}|s_t, a_t):環境動力學 (Environment Dynamics),在狀態 sts_t 執行動作 ata_t 後轉移到狀態 st+1s_{t+1} 的概率。

  • πθ(atst)\pi_{\theta}(a_t|s_t):策略,在狀態 sts_t 選擇動作 ata_t 的概率(這部分依賴於 θ\theta)。

  • 數學復習 (對數性質): log(a×b)=loga+logb\log(a \times b) = \log a + \log b 並且 log(ixi)=ilogxi\log(\prod_{i} x_i) = \sum_{i} \log x_i
    P(τπθ)P(\tau|\pi_{\theta}) 取對數:

logP(τπθ)=logp0(s0)+t=0T[logP(st+1st,at)+logπθ(atst)]\log P(\tau|\pi_{\theta}) = \log p_0(s_0) + \sum_{t=0}^{T} [\log P(s_{t+1}|s_t, a_t) + \log \pi_{\theta}(a_t|s_t)]

現在對上式關於 θ\theta 求梯度 θ\nabla_{\theta}

  • 數學復習 (梯度性質): 梯度的加法法則 (f+g)=f+g\nabla(f+g) = \nabla f + \nabla g。梯度 θ\nabla_{\theta} 只對依賴於 θ\theta 的項有作用。

θlogP(τπθ)=θlogp0(s0)+t=0T[θlogP(st+1st,at)+θlogπθ(atst)]\nabla_{\theta} \log P(\tau|\pi_{\theta}) = \nabla_{\theta} \log p_0(s_0) + \sum_{t=0}^{T} [\nabla_{\theta} \log P(s_{t+1}|s_t, a_t) + \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)]

  • 關鍵點:
    • 初始狀態概率 logp0(s0)\log p_0(s_0) 通常不依賴於策略參數 θ\theta,所以 θlogp0(s0)=0\nabla_{\theta} \log p_0(s_0) = 0
    • 環境動力學 logP(st+1st,at)\log P(s_{t+1}|s_t, a_t) 描述的是環境本身的性質,也不依賴於策略參數 θ\theta,所以 θlogP(st+1st,at)=0\nabla_{\theta} \log P(s_{t+1}|s_t, a_t) = 0
    • 只有策略 logπθ(atst)\log \pi_{\theta}(a_t|s_t) 依賴於 θ\theta
      所以,上式簡化為:

θlogP(τπθ)=t=0Tθlogπθ(atst)\nabla_{\theta} \log P(\tau|\pi_{\theta}) = \sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)

整條軌跡的對數概率的梯度,等於這條軌跡中每一步動作的對數概率梯度之和!這大大簡化了計算。


第七步:最終的策略梯度定理
把第六步簡化後的結果代回到第五步的期望公式中:

θJ(πθ)=Eτπθ[(t=0Tθlogπθ(atst))R(τ)]\nabla_{\theta} J(\pi_{\theta}) = E_{\tau \sim \pi_{\theta}} [(\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)) R(\tau)]

這就是策略梯度定理 (Policy Gradient Theorem) 的最終形式(或者說其中一種常見形式)。
期望回報 J(πθ)J(\pi_{\theta}) 對參數 θ\theta 的梯度,等於 “採樣一條軌跡 τ\tau,計算該軌跡的總回報 R(τ)R(\tau),再乘以這條軌跡中所有 (狀態,動作) 對對應的策略對數概率梯度 θlogπθ(atst)\nabla_{\theta} \log \pi_{\theta}(a_t|s_t) 之和”,然後對所有可能的軌跡求期望(平均)。

顯然得到所有軌跡的成本是極高的,例如我們要採樣出 max_token_length=100 的所有生成結果,因此我們可以用樣本均值來近似期望:

Note

蒙特卡洛近似:* 運行當前策略 πθ\pi_{\theta},收集 NN 條軌跡,組成數據集 D={τ1,...,τN}D = \{\tau_1, ..., \tau_N\} (記 N=DN = |D|)。

  • 用這些樣本的平均值近似期望值:

θJ(πθ)g^=1DτD[(t=0Tθlogπθ(atst))R(τ)]\nabla_{\theta} J(\pi_{\theta}) \approx \hat{g} = \frac{1}{|D|} \sum_{\tau \in D} [ (\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)) R(\tau) ]

應用到 LM Policy 上#

通過圖中所示的生成流程得到這條採樣軌跡中每個 state action pair 的對數概率,現在就可以反向傳播來計算梯度,

Screenshot 2025-04-24 at 15.17.06

然後將每個梯度乘以從 RM 中 reward 送入表達式來運行梯度上升優化:

Screenshot 2025-04-24 at 15.23.14

High Variance#

PG 算法對於小問題效果較好,但是用於語言建模會有一些問題。

Note

中心極限定理告訴我們:只要樣本夠大,樣本均值就會呈正態分佈,這讓我們能更好地預測和分析數據。當樣本量較小時,樣本均值的波動會很大;即使均值趨向於正態分佈,但單次抽樣的結果可能差異很大。而我們又知道從 LM 中採樣很多軌跡的成本是很高的,這會導致估計量的高方差問題。

如何在不增加樣本量的情況下減少方差呢?

  1. 移除歷史獎勵: reward-to-go
    首先必須要承認的是,當前的 action 無法影響到過去已經獲得的獎勵,而過去的獎勵增加了不必要的噪聲,這和 RL 中的信用分配問題應該有點關係。因此去掉過去的項可以避免增加噪音,讓估計的梯度與真實梯度拉近。所以與其從零開始計算軌跡的獎勵,我們可以只考慮從當前時間步開始的動作的獎勵。

Screenshot 2025-04-24 at 15.48.52

  1. 引入 baseline
    RL 的研究已證實,引入一個依賴於 state 的項(比如計算軌跡獎勵的函數,也可以是一個常數)可以減少方差。這裡我們選擇 價值函數 Vπ(s)V^\pi(s)

Value Function#

Vπ(s)V^\pi(s) 告訴你根據當前策略進行行動,剩餘軌跡的期望獎勵是多少。

經典 RL 場景和 LM 場景中的價值定義例子:

Screenshot 2025-04-24 at 16.01.05

實際操作中,用的是我們試圖優化的那個 LM 做初始化,在其頂部再添加一個線性層來預估 value,這樣 Transformer 層的參數可以同時用於語言建模(用將 token 投射到詞表的層)和價值估計。

Screenshot 2025-04-24 at 16.45.48

前面說的 reward-to-go 在 RL 中被稱為 Q 函數,即從當前狀態開始採取這個動作,然後得到即時獎勵,按照策略完成後續行動的預期獎勵:

Screenshot 2025-04-24 at 16.54.58

再通過引入 Value 函數得到 Q 與 V 的差異,這個差異被稱為 Advantage 函數。

Screenshot 2025-04-24 at 16.56.48

這個 Aπ(st,at)A^\pi(s_{t,}a_t) 優勢項表示這個特定動作,相對於在狀態 ss 中可以採取的平均動作好多少。

Screenshot 2025-04-24 at 17.04.33

在圖中紅箭頭指向的狀態,向下移動的優勢函數將高於其它動作的優勢函數。

θ(J(θ))1Ni=1N(t=0Tθlogπθ(ai,tsi,t))Aπ(si,t,ai,t)\nabla_{\theta}(J(\theta)) \approx \frac{1}{N}\sum_{i=1}^{N}\left(\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{i,t} | s_{i,t})\right) A^{\pi}(s_{i,t}, a_{i,t})

梯度乘上優勢函數後,效果就變為了讓策略增加具有高優勢 action 的 logprob,並降低帶來低平均回報 action 的 log prob。

A^π(st,at)=Qπ(st,at)Vπ(st)=[r(st,at)+γVπ(st+1)]Vπ(st)\begin{align} \hat{A}^\pi(s_t, a_t) &= Q^\pi(s_t, a_t) - V^\pi(s_t) \\ &= [r(s_t, a_t) + \gamma V^\pi(s_{t+1})] - V^\pi(s_t) \end{align}

Note

在傳統的強化學習方法中,Q 網絡V 網絡 通常是獨立的。也就是說,Q 函數用於估計在狀態 ss 下執行動作 aa 所期望的總回報,而 V 函數 只是估計狀態 ss 的值。這樣需要兩個不同的神經網絡來分別計算這兩個值。

然而,現在我們引入了優勢函數 Aθ(s,a)A_{\theta}(s, a),其計算方式是基於 Q 值和 V 值之間的差異,即:

Aθ(s,a)=Qθ(s,a)Vθ(s)A_{\theta}(s, a) = Q_{\theta}(s, a) - V_{\theta}(s)

通過將 Aθ(s,a)A_{\theta}(s, a) 表達為 Qθ(s,a)Q_{\theta}(s, a)Vθ(s)V_{\theta}(s) 之間的差異,我們發現,我們只需要訓練一個網絡來輸出 Vθ(s)V_{\theta}(s),再通過獎勵 rtr_t 和折扣因子 γ\gamma 計算出 Q 值。

因此,只有一個神經網絡是需要的,這個網絡主要預測 Vθ(s)V_{\theta}(s)。Q 值通過以下公式計算得出:

Qθ(st,a)=rt+γVθ(st+1)Q_{\theta}(s_t, a) = r_t + \gamma \cdot V_{\theta}(s_{t+1})

優勢函數則進一步計算為:

Aθ(st,a)=rt+γVθ(st+1)Vθ(st)A_{\theta}(s_t, a) = r_t + \gamma \cdot V_{\theta}(s_{t+1}) - V_{\theta}(s_t)

優勢採樣#

短步長的優勢估計器偏差大但方差小,長步長的優勢估計器偏差小但方差大。這個權衡問題是強化學習中需要仔細選擇和調整的部分,取決於模型的穩定性要求和訓練效率。
Screenshot 2025-05-08 at 21.53.04
一個例子:“短期記憶的人只記得昨天發生的事,雖然不夠全面,但很穩定;長期記憶的人能看清未來幾天的全貌,但可能會被更多的未知因素所干擾。”

GAE#

為了解決這個偏差 - 方差問題,可以使用 GAE(廣義優勢估計),本質上就是對所有的優勢項的加權和,每項乘上一個衰減因子。

Note

現在來聊聊 TD 誤差
在線學習有一個妙處:你不需要等到最後再更新策略。於是,時序差分誤差(TD Error) 就登場了:

δ=r+γV(s)V(s)\delta = r + \gamma V(s') - V(s)

這裡的關鍵是:TD 誤差實際上是優勢函數的在線估計。它告訴你,就在此刻,你的動作是否讓未來狀態比你預期的要好。這個誤差 δ\delta 直接反映了優勢的概念:

  • 如果 δ>0\delta > 0:“嘿,這個動作比我想象中好!”(優勢為正)。
  • 如果 δ<0\delta < 0:“嗯,我本來以為會更好……”(優勢為負)。

這讓你可以逐步調整你的策略,不用等到一整 episode 結束才做改變。這對提高效率來說,簡直是絕佳策略。

GAE 的目的是在策略梯度算法中,提供一個比原始回報 R(τ)R(τ) 或簡單 TD 誤差 δtδ_t​ 更好的優勢函數 Aπ(s,a)A^π(s,a) 的估計值 $$A^t$​$,以降低梯度估計的方差,提高學習的穩定性和效率。

δt=rt+γVπ(st+1)Vπ(st)A^t=δt+γλA^t+1\begin{align} \delta_t &= r_t + \gamma V^\pi(s_{t+1}) - V^\pi(s_t) \\ \hat{A}_t &= \delta_t + \gamma \lambda \hat{A}_{t+1} \end{align}

這個公式遞歸地定義了廣義優勢估計 A^t\hat{A}_t。它不是只看一步的 TD 誤差 δt\delta_t,而是綜合了未來多步的 TD 誤差信息。

這個遞歸式從軌跡(episode)的末尾(假設 TT 是最後一步,A^T+1=0\hat{A}_{T+1}=0)向前計算:
* A^T=δT+0=δT\hat{A}_T = \delta_T + 0 = \delta_T
* A^T1=δT1+γλA^T=δT1+(γλ)δT\hat{A}_{T-1} = \delta_{T-1} + \gamma \lambda \hat{A}_T = \delta_{T-1} + (\gamma \lambda) \delta_T
* A^T2=δT2+γλA^T1=δT2+(γλ)δT1+(γλ)2δT\hat{A}_{T-2} = \delta_{T-2} + \gamma \lambda \hat{A}_{T-1} = \delta_{T-2} + (\gamma \lambda) \delta_{T-1} + (\gamma \lambda)^2 \delta_T
* ...
* 一般形式A^t=k=0(γλ)kδt+k\hat{A}_t = \sum_{k=0}^{\infty} (\gamma \lambda)^k \delta_{t+k} (假設無限步長或在終止狀態後的 δ\delta 為 0)

參數 λ\lambda0λ10 \le \lambda \le 1)是 GAE 的關鍵,它控制著估計 A^t\hat{A}_t 的偏差(bias)和方差(variance):

  • λ=0\lambda = 0 時:
    • A^t=δt\hat{A}_t = \delta_t。GAE 退化為簡單的單步 TD 誤差。這種估計方差較低(因為它只依賴於下一步的信息),但可能偏差較高(因為它嚴重依賴於可能不準確的 Vπ(st+1)V^{\pi}(s_{t+1}) 的估計)。
  • λ=1\lambda = 1 時:
    • A^t=k=0(γ)kδt+k\hat{A}_t = \sum_{k=0}^{\infty} (\gamma)^k \delta_{t+k}。經過推導可以證明,這等價於 A^t=(k=0γkrt+k)Vπ(st)\hat{A}_t = (\sum_{k=0}^{\infty} \gamma^k r_{t+k}) - V^{\pi}(s_t),也就是蒙特卡洛(Monte Carlo)回報減去基線(baseline)。這種估計偏差較低(因為它使用了從 tt 時刻開始的完整實際回報),但方差通常很高(因為它累積了多個時間步的隨機性)。
  • 0<λ<10 < \lambda < 1 時:
    • GAE 在上述兩種極端情況之間進行插值。λ\lambda 越接近 0,越偏向於低方差高偏差的 TD 估計;λ\lambda 越接近 1,越偏向於高方差低偏差的 MC 估計。
    • 通過選擇合適的 λ\lambda(例如 0.97),GAE 試圖在偏差和方差之間取得一個良好的平衡,從而得到一個既相對準確(偏差可控)又相對穩定(方差較小)的優勢估計。

語言模型的 Advantage#

如圖,目標是提高在當前狀態 “上海” token 的 logprob,降低 “巧克力” 的 logprob,因為選擇” 上海 “的優勢高於選擇 “巧克力”(胡亂說的 token )的優勢。

Screenshot 2025-04-24 at 23.13.57

重要性採樣和離線學習#

在很多情況下,我們可能想要計算 Exp(x)[f(x)]E_{x \sim p(x)}[f(x)],但是:

  1. 我們很難或者無法直接從目標分佈 p(x)p(x) 中採樣得到 xx
  2. 或者從 p(x)p(x) 採樣效率很低,語言模型中就是這個問題,LM 採樣成本太高了。

但是,我們可能可以很容易地從另一個替代的,或稱為提議(Proposal )的分佈 q(x)q(x) 中進行採樣。

重要性採樣 (Importance Sampling, IS) 是一種通過從不同的分佈採樣來估計目標分佈期望的技術,將一個在概率分佈 p(x)p(x) 下計算的期望值 Exp(x)[f(x)]E_{x \sim p(x)}[f(x)],轉換為在另一個不同的概率分佈 q(x)q(x) 下計算一個相關函數的期望值 Exq(x)[p(x)q(x)f(x)]E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right]

Exp(x)[f(x)]=p(x)f(x)dx=q(x)q(x)p(x)f(x)dx(假設 q(x)0)=q(x)p(x)q(x)f(x)dx=Exq(x)[p(x)q(x)f(x)]\begin{align} E_{x \sim p(x)}[f(x)] &= \int p(x) f(x) \, dx \\ &= \int \frac{q(x)}{q(x)} p(x) f(x) \, dx \quad \text{(假設 } q(x) \neq 0 \text{)} \\ &= \int q(x) \frac{p(x)}{q(x)} f(x) \, dx \\ &= E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right] \end{align}

這裡的關鍵是引入了重要性權重(importance weight) w(x)=p(x)q(x)w(x) = \frac{p(x)}{q(x)}。這個權重的作用是修正偏差:對於一個從 q(x)q(x) 中抽到的樣本 xix_i,如果它在目標分佈 p(x)p(x) 中出現的概率更高(p(xi)>q(xi)p(x_i) > q(x_i)),就會給它一個大於 1 的權重;反之,如果它在 p(x)p(x) 中出現的概率更低(p(xi)<q(xi)p(x_i) < q(x_i)),就會給它一個小於 1 的權重。這樣加權平均後,就能得到對原始期望 Exp(x)[f(x)]E_{x \sim p(x)}[f(x)] 的一個(通常是無偏或一致的)估計。

重要性採樣允許我們:

  1. 從容易採樣的分佈 q(x)q(x) 中抽取樣本 x1,x2,...,xNx_1, x_2, ..., x_N
  2. 通過計算加權平均來估計原始的期望值:
Exp(x)[f(x)]1Ni=1Np(xi)q(xi)f(xi)E_{x \sim p(x)}[f(x)] \approx \frac{1}{N} \sum_{i=1}^{N} \frac{p(x_i)}{q(x_i)} f(x_i)

回到我們的場景。前面我們得到了 On-Policy 的策略梯度估計:

θ(J(θ))1Ni=1N(t=0Tθlogπθ(ai,tsi,t))Aπ(si,t,ai,t)\nabla_{\theta}(J(\theta)) \approx \frac{1}{N}\sum_{i=1}^{N}\left(\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{i,t} | s_{i,t})\right) A^{\pi}(s_{i,t}, a_{i,t})

[!info]
On-Policy 的含義:即採集數據用的策略和訓練時用的策略是同一個。由於計算時需要使用從當前策略 πθ\pi_{\theta} 採樣生成的軌跡。這意味著每次策略更新後,舊數據就不能直接用了,樣本效率低。至於 mini_batch_num > 1 時,後面用於更新的數據還算是 On-Policy 嗎?嚴格意義上感覺不算,所以也可以理解為 Semi-On-Policy?(表達不一定嚴謹)。

而 On-Policy 則是強調當前策略模型能否和環境進行交互。[2]^{[2]}

我們希望利用過去由舊策略 πθOFFLINE\pi_{\theta_{OFFLINE}} 生成的數據(這些數據可能是大量存在的)來估計當前新策略 πθONLINE\pi_{\theta_{ONLINE}} 的梯度。這樣可以重複利用數據,提高樣本效率。

回憶下重要性採樣 IS 的原理: Exp(x)[f(x)]=Exq(x)[p(x)q(x)f(x)]E_{x \sim p(x)}[f(x)] = E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right]
對應到我們的 PG(簡化考慮單步決策):
* 目標分佈 p(x)p(x) 對應新策略 πθONLINE(as)\pi_{\theta_{ONLINE}}(a|s)
* 採樣分佈 q(x)q(x) 對應舊策略 πθOFFLINE(as)\pi_{\theta_{OFFLINE}}(a|s)
* 重要性權重為 wt=πθONLINE(atst)πθOFFLINE(atst)w_t = \frac{\pi_{\theta_{ONLINE}}(a_t|s_t)}{\pi_{\theta_{OFFLINE}}(a_t|s_t)}

將重要性權重應用到 On-Policy 梯度的每一項(每個時間步 tt),得到標準的 Off-Policy 估計:

θONLINE(J(θONLINE,θOFFLINE))1Ni=1Nt=0T[πθONLINE(ai,tsi,t)πθOFFLINE(ai,tsi,t)θONLINElogπθONLINE(ai,tsi,t)Aπ(si,t,ai,t)]\nabla_{\theta_{ONLINE}} (J(\theta_{ONLINE},\theta_{OFFLINE})) \approx \frac{1}{N}\sum_{i=1}^{N}\sum_{t=0}^{T} \left[ \frac{\pi_{\theta_{ONLINE}}(a_{i,t}|s_{i,t})}{\pi_{\theta_{OFFLINE}}(a_{i,t}|s_{i,t})} \nabla_{\theta_{ONLINE}} \log \pi_{\theta_{ONLINE}}(a_{i,t} | s_{i,t}) A^{\pi}(s_{i,t}, a_{i,t}) \right]

現在我們找到了可以在不每次從我們正在優化的策略(要訓練的模型)中採樣的情況下完整梯度上升優化,而是只採樣一次,將軌跡保存到內存 / 數據庫中,用 mini-batch 優化策略,然後用新策略初始化離線策略(採樣的策略)。

PPO Loss#

PPO Loss 主要由三個部分組成:策略損失(LPOLICYL_{POLICY})、價值函數損失(LVFL_{VF})和熵獎勵(LENTROPYL_{ENTROPY})。

1. 策略損失 (LPOLICYL_{POLICY})#

Clipped Surrogate Objective (裁剪的替代目標)

LPOLICY=min(πθ(atst)πθold(atst)A^t,clip(πθ(atst)πθold(atst),1ϵ,1+ϵ)A^t)L_{POLICY} = \min \left( \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \hat{A}_t, \text{clip} \left( \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1-\epsilon, 1+\epsilon \right) \hat{A}_t \right)

這是 PPO 的核心。你會注意到它和我們前面用重要性採樣推導出來的 off-policy 策略梯度目標有點像,但有一個關鍵的改動。

  • πθ(atst)πθold(atst)\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}: 這就是重要性採樣比率,稱之為 rt(θ)r_t(\theta)。它是在狀態 sts_t 下,根據當前(在線)策略 πθ\pi_{\theta} 採取動作 ata_t 的概率,除以根據收集軌跡數據時使用的(離線)策略 πθold\pi_{\theta_{old}} 採取該動作的概率。這個比率修正了數據來自一個與我們當前試圖改進的策略略有不同的策略這一事實。

  • A^t\hat{A}_t: 這是優勢函數估計值,是使用 GAE 計算出來的,有助於平衡偏差和方差。它告訴我們,在狀態 sts_t 下採取動作 ata_t 比在該狀態下採取平均動作要好多少或差多少(根據當前的價值函數判斷)。

  • clip 函數: 這就是 PPO 的關鍵點所在。
    clip(rt(θ),1ϵ,1+ϵ)\text{clip} \left( r_t(\theta), 1-\epsilon, 1+\epsilon \right)
    它基本上是說:如果概率比率 rt(θ)r_t(\theta) 偏離 1 太遠(過高或過低),我們就把它 “裁剪” 掉。所以,如果 rt(θ)r_t(\theta) 試圖變成 1.51.5,而 ϵ\epsilon0.20.2,它就會被裁剪到 1.21.2。如果它試圖變成 0.50.5,它就會被裁剪到 0.80.8
    參數 ϵ\epsilon (epsilon) 是一個小的超參數(例如 0.1 或 0.2),它定義了裁剪範圍 [1ϵ,1+ϵ][1-\epsilon, 1+\epsilon]

  • min 函數: 這個目標函數取了下面兩項中的較小者:

    1. 未裁剪的目標: rt(θ)A^tr_t(\theta) \hat{A}_t
    2. 裁剪後的目標: clip(rt(θ),1ϵ,1+ϵ)A^t\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t

    為什麼要這樣做呢? 策略梯度的目標是增加具有正優勢的動作的概率,並減少具有負優勢的動作的概率。
    然而,使用重要性採樣時,如果 rt(θ)r_t(\theta) 變得非常大,可能會導致巨大的更新和不穩定性。PPO 通過裁剪這個比率來嘗試保持新策略與舊策略的接近。

    • 如果 A^t>0\hat{A}_t > 0 (好動作): 我們希望增加 πθ(atst)\pi_{\theta}(a_t|s_t)min 函數意味著如果 rt(θ)r_t(\theta) 增長超過 1+ϵ1+\epsilon,目標函數就會被限制在 (1+ϵ)A^t(1+\epsilon)\hat{A}_t。這可以防止策略在單次更新中變化過大,即使未裁剪的目標會建議一個更大的增幅。
    • 如果 A^t<0\hat{A}_t < 0 (壞動作): 我們希望減少 πθ(atst)\pi_{\theta}(a_t|s_t)。如果 rt(θ)r_t(\theta) 縮小到 1ϵ1-\epsilon 以下,目標函數就會被限制在 (1ϵ)A^t(1-\epsilon)\hat{A}_t。(注意:當 A^t<0\hat{A}_t < 0 時,rt(θ)A^tr_t(\theta)\hat{A}_t 這一項在 rt(θ)r_t(\theta) 較小时值較大(更接近零或為正),而 clip(...)A^t\text{clip}(...) \hat{A}_t 也是在 clip(...)\text{clip}(...) 較小时值較大。這裡的 min 實際上意味著當比率超出裁剪邊界時,我們採取的是更悲觀的更新步驟,或者說,導致 log 概率變化幅度更小的那一步。)

2. 價值函數損失 (LVFL_{VF})#

LVF=12Vθ(s)(t=tTγttrts0=s)22L_{VF} = \frac{1}{2} \left\| V_{\theta}(s) - \left( \sum_{t'=t}^{T} \gamma^{t'-t} r_{t'} \Big| s_0 = s \right) \right\|^2_2

這和前面的內容完全一樣:

  • Vθ(s)V_{\theta}(s) 是價值網絡的輸出(即 LLM 頂部再加一個線性層,用來預測從狀態 ss 開始的期望累積獎勵)。
  • γtrt\sum \gamma^{t'} r_{t'} 這一項(稱之為 GsG_s 或目標價值)是從狀態 ss 開始,並遵循當前策略直到回合結束所觀察到的實際折扣獎勵總和。這是我們為 Vθ(s)V_{\theta}(s) 設定的經驗目標。
  • 這個損失函數就是預測值 Vθ(s)V_{\theta}(s) 和觀察到的目標值 GsG_s 之間的均方誤差(MSE)。我們希望價值函數能夠準確預測未來的獎勵。這個價值函數對於計算優勢 A^t\hat{A}_t至關重要。

3. 熵獎勵 (LENTROPYL_{ENTROPY})#

LENTROPY=xp(x)logp(x)L_{ENTROPY} = - \sum_x p(x) \log p(x)
  • 這裡的 p(x)p(x)(或者更準確地說是 πθ(as)\pi_{\theta}(a|s),對於給定狀態 ss 下所有可能的動作 aa)代表了當前策略在給定狀態下輸出的動作概率分佈。
  • xp(x)logp(x)\sum_x p(x) \log p(x) 這一項是這個概率分佈的熵。熵衡量的是分佈的隨機性或不確定性。均勻分佈(非常隨機)具有高熵,而尖峰分佈(對某個動作非常確定)具有低熵。
  • 損失項是熵。當我們在總損失 LPPOL_{PPO} 中最小化這個 LENTROPYL_{ENTROPY} 時(假設 c2c_2 是正的),我們實際上是在最大化策略的熵。

鼓勵更高的熵可以促進探索,會讓策略變得更隨機一些,嘗試不同的動作(在 LLM 的情況下就是嘗試不同的 token),而不是過快地收斂到一個可能是次優的確定性策略。這有助於 Agent 發現更好的策略。

最終形式 LPPOL_{PPO}#

最終的 PPO 損失是這三個部分的加權和:

LPPO=LPOLICY+c1LVF+c2LENTROPYL_{PPO} = L_{POLICY} + c_1 L_{VF} + c_2 L_{ENTROPY}
  • c1LVFc_1 L_{VF}: 價值函數損失,由 c1c_1 加權。c1c_1 的一個常見值是 0.50.5 左右。
  • c2LENTROPYc_2 L_{ENTROPY}: 熵獎勵(如果 c2>0c_2 > 0,實際上是對低熵的懲罰),由 c2c_2 加權。c2c_2 通常是一個小的正常數(例如 0.010.01),用於鼓勵探索,同時又不會壓倒主要的策略目標。

Agent 的參數(即 LLM 的權重)通過計算這個組合損失 LPPOL_{PPO} 的梯度並執行梯度下降來進行更新。

Reference Model#

Reward Hacking#

RL 的一大問題就是 reward-hacking,模型可能會學會總是輸出帶來好獎勵但對人類來說沒意義的 token 或序列,比如連續說十遍 “謝謝你” 來提升禮貌分數,所以我們希望對齊的模型(RL post-training 後)的輸出儘量和原本模型的輸出較為相近。

因此會有另一個凍結了權重的模型(ref model),我們要優化的模型在每個軌跡的每一步中通過 reward model 生成獎勵的時候,這個獎勵會減去 ref model 與優化的模型 log prob 之間的 KL 散度,作為懲罰項來防止模型生成與原始模型差異過大的答案,以此防止上面所說的模型作弊現象。

Screenshot 2025-05-08 at 00.43.14

Code walk through#

trl#

class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    # ... (class attributes like transformers_parent_class) ...

這個類的核心目的是將一個標準的因果語言模型(Causal LM)(我們的 Actor Model,負責生成文本的策略 πθπ_θ​)與一個 Value Head(即 Critic Model,負責估計狀態價值 V (s))捆綁在一起。在 PPO / Actor Critic 算法中,我們同時需要策略和價值函數,這個類就提供了一個統一的模型結構來同時輸出這兩者。

    def __init__(self, pretrained_model, **kwargs):
        super().__init__(pretrained_model, **kwargs) # 基礎設置
        v_head_kwargs, _, _ = self._split_kwargs(kwargs) # 分離出給ValueHead的參數

        # 確保傳入的是個有語言模型輸出能力的模型
        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
            raise ValueError("The model does not have a language model head...")

        # 創建ValueHead實例,它將學習預測狀態的價值 V(s)
        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        # 初始化ValueHead的權重
        self._init_weights(**v_head_kwargs) # 默認隨機初始化,也可以指定正態分佈初始化
  1. 充當 Actor: 即我們的語言模型 pretrained_model,它會根據當前 prompt(狀態 s)生成回應(動作 a,即一系列 token)。
  2. Critic: 評估 Actor 在某個狀態 s 的 “好壞”,即輸出 V(s)V(s)。這就是線性層 self.v_head 的任務。
    def forward(
        self,
        input_ids=None, # 輸入的token IDs (狀態 s)
        attention_mask=None,
        past_key_values=None, # 用於加速生成
        **kwargs,
    ):
        # 強制底層模型輸出 hidden_states,ValueHead 需要它們作為輸入
        kwargs["output_hidden_states"] = True
        # ... (處理 past_key_values 和 PEFT 的一些細節,PPO核心理解中可先忽略)

        # 1. Actor (基礎語言模型) 進行計算
        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        # 2. 提取 Actor 的輸出 (用於策略更新) 和 Critic 的輸入
        lm_logits = base_model_output.logits # Actor 的輸出:預測下一個token的概率分佈
        # 這是計算 PPO 中 L_POLICY 和 L_ENTROPY 的基礎

        last_hidden_state = base_model_output.hidden_states[-1] # Critic 的輸入:LM最後一層的隱藏狀態,
        # 代表了當前狀態 s 的表徵

        # (可選) 語言模型本身的損失,在RL階段通常不直接用
        loss = base_model_output.loss

        # (確保數據和模型在同一設備)
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

        # 3. Critic (ValueHead) 進行計算
        # ValueHead接收狀態表徵,輸出對該狀態的價值估計 V(s)
        value = self.v_head(last_hidden_state).squeeze(-1) # 這是計算 PPO 中價值損失 L_VF 和優勢 A_hat 的基礎

        # (確保 logits 是 float32,為了數值穩定性)
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        # 返回 Actor 的 logits, LM loss (可能為None), 和 Critic 的value
        return (lm_logits, loss, value)

對於 PPO-RLHF 訓練的每一步:

  1. 我們把當前的一批 prompt (序列 input_ids) 輸入模型。
  2. self.pretrained_model (Actor) 會計算(Rollout)出 lm_logits。這些 logits 代表了在當前 prompt 下,模型認為接下來應該生成哪些詞元的概率分佈。PPO 的策略損失 LPOLICYL_{POLICY}​ 和熵獎勵 LENTROPYL_{ENTROPY​} 都需要基於這個概率分佈πθ(atst)π_θ​(a_t​∣s_t​)來計算。
  3. 同時,我們從 base_model_output 中取出 last_hidden_state。這可以看作是當前 prompt (狀態 s) 的一個向量表示。
  4. 這個 last_hidden_state 被送入 self.v_head (Critic),輸出一個標量 value。這個 value 就是模型對當前狀態 s 的價值估計 Vθ(s)V_θ​(s)PPO 的價值函數損失 LVFL_{VF}​ 就是要優化這個 Vθ(s)V_θ​(s),使其儘可能接近真實的回報。並且,這個 Vθ(s)V_θ​(s)也是計算優勢函數 AtA^t​ 的關鍵組成部分,而 AtA^t​ 又會指導 LPOLICYL_{POLICY​} 的計算。
  5. 同樣的 prompt + response 序列輸入給 Reward 和 Reference model 做推理,得到 reward 和 log probs(計算 KL 懲罰)。

所以一次 forward 調用,我們就同時獲得了更新 Actor (策略) 和 Critic (價值函數) 所需的核心信息。
訓練的流程可以借下圖幫助了解:

rlhf-pipeline

Tip

在 RLHF 中,只有 Actor 在經驗收集(Rollout)時需要 Prefill + Decode(完整的 Auto-Regressive Generation),其余的模型都是在處理已有的 response 獲取 logprob 和 value 等,只做 Prefill。

此外 Actor 涉及訓練和推理(指 Rollout),因此需要 training engine(如 Megatron、DeepSpeed 和 FSDP) + rollout engine(如 SGLang 和 vLLM)兩者來各完成自己擅長的任務;Critic 推理時復用訓練的 forward 中的內部表徵來輸出新的 value 預測,因此運行在同一個 training engine 中;而 Reference 和 Reward model 都只用推理引擎來得到 logprob 和 reward 即可。[3]^{[3]}

verl#

和 OpenRLHF 等都是優秀的 RLHF 框架,一個比較好的導讀:【AI Infra】VeRL 框架入門 & 代碼帶讀

Reference#

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。