本文主要基于 Umar Jamil のコース[ 1 ] ^{[1]} [ 1 ] に基づいて学習と記録を行います。私たちの目標は、LLM の動作を私たちの期待する出力と一致させることであり、RLHF は最も有名な技術の一つです。その標準的なプロセスは、4 つのモデルを含みます(聞こえは良いですが、メモリを多く消費するため、多くの方法では一部のモデルを削除します)。ここでは、合計で 4 つのモデルが必要であることを覚えておいてください:Reward、Actor、Critic、Reference Model。最終的に最適化されたモデルは、ここで言及されている Actor Model です。
LLM to RL#
以前の RL に関する理解では、ポリシーは現在の状態で取るべきアクションの確率を教えてくれるものでした。そう考えると、言語モデル自体がポリシーと見なすことができます:プロンプト(状態)を受け取り、次のトークン(アクション)の確率を出力し、サンプリング後に新しい状態(トークンがプロンプトに追加された後)を得ることになります。これは、vocab_size
の大きさのアクションスペースを持つポリシーでもあり、RL エージェントでもあります。
そうなると、報酬を提供するものが必要です(従来の RL では一般的に環境に組み込まれた報酬関数です)。
「Q-A-Reward」のデータセットを作成することでこれを実現できますが、人間は合意を見つけるのが得意ではありませんが、比較の優劣を判断するのが得意です。そこで、方向性を変えて、モデルが高温度で複数の A を生成し、分野の専門家(人間でも AI モデルでも可)に選ばれた / 好ましい答えを選んでもらい、好みのデータセットをラベル付けし、それを用いて数値報酬を生成する報酬モデルを訓練します。
Reward Model#
この RM は、Llama のような事前訓練された LLM を使用して実現されます。
Note
テキスト生成タスクでは、プロンプトを入力として Transformer に渡し、生成された埋め込み(隠れ状態)の最後の(トークンの)隠れ状態を線形投影して語彙に対するロジットを得て、ソフトマックスとサンプリング戦略を使用して次のトークンを選択します。
テキストを生成するのではなく数値報酬を生成したい場合、語彙に対する線形投影を単一出力特徴(スカラーを出力)を持つ線形に置き換えて、テキストシーケンス全体の単一スコアを生成します。
Reward Model Loss#
Tip
訓練中、選択された答えに対して高い報酬を生成し、未選択の答えに対して低い報酬を生成するようにモデルを調整します。
Bradley-Terry のパラメータ化形式に似ています:
L o s s = − log σ ( r ( x , y w ) − r ( x , y l ) ) Loss = -\log \sigma(r(x, y_w) - r(x, y_l)) L oss = − log σ ( r ( x , y w ) − r ( x , y l ))
y w y_w y w は選ばれたものを表し、y l y_l y l はその逆です。したがって、モデルが選ばれたものに高い報酬を与えるとき、
r ( x , y w ) − r ( x , y l ) > 0 r(x, y_w) - r(x, y_l)>0 r ( x , y w ) − r ( x , y l ) > 0
σ ( r ( x , y w ) − r ( x , y l ) ) ∈ ( 0.5 , 1 ) \sigma(r(x, y_w) - r(x, y_{l)})\in(0.5,1) σ ( r ( x , y w ) − r ( x , y l ) ) ∈ ( 0.5 , 1 )
− log σ ( r ( x , y w ) − r ( x , y l ) ) ∈ ( 0 , 1 ) -\log \sigma(r(x, y_w) - r(x, y_l))\in(0,1) − log σ ( r ( x , y w ) − r ( x , y l )) ∈ ( 0 , 1 )
このように損失は低くなり、モデルが選択された答えに低い報酬を与えると、損失は非常に高くなります。
HuggingFace の RewardTrainer
クラスは、AutoModelForSequenceClassification
を入力として受け取ります(つまり、上記で言及したモデル構造です)。
Actor & Critic Model#
Trajectories#
前述のように、強化学習(RL)の核心的な目標は、エージェントの行動を導くポリシー (policy, π \pi π )を見つけることです。このポリシーは、最大限の期待報酬 (expected return)を得るためのものです。
数学的には、最大化目標関数 J ( π ) J(\pi) J ( π ) の最適ポリシー π ∗ \pi^* π ∗ を見つけることを表します:
Copy π ∗ = arg max π J ( π ) \pi^* = \arg \max_{\pi}
J(\pi) π ∗ = arg π max J ( π )
期待報酬 J ( π ) J(\pi) J ( π ) は、エージェントがポリシー π \pi π に従ったときに、数多くの可能なライフサイクルやエピソード(episodes)で期待される平均総報酬を表します。
その計算方法は、発生する可能性のあるすべての軌跡 (trajectory, τ \tau τ )を考慮し、各軌跡の総報酬 R ( τ ) R(\tau) R ( τ ) にその軌跡がポリシー π \pi π の下で発生する確率 P ( τ ∣ π ) P(\tau|\pi) P ( τ ∣ π ) を掛けて加重平均(または積分)します。
Copy J ( π ) = ∫ P ( τ ∣ π ) R ( τ ) = E τ ∼ π [ R ( τ ) ] J(\pi) = \int P(\tau|\pi) R(\tau) = E_{\tau \sim \pi} [R(\tau)] J ( π ) = ∫ P ( τ ∣ π ) R ( τ ) = E τ ∼ π [ R ( τ )]
E τ ∼ π [ ⋅ ] E_{\tau \sim \pi}[\cdot] E τ ∼ π [ ⋅ ] は、軌跡 τ \tau τ がポリシー π \pi π に基づいて生成されたときの期待値 (Expected Value)を示します。
R ( τ ) R(\tau) R ( τ ) は、単一の軌跡 τ \tau τ で得られる総報酬(報酬)です。
P ( τ ∣ π ) P(\tau|\pi) P ( τ ∣ π ) は、エージェントがポリシー π \pi π を使用しているときに、特定の軌跡 τ \tau τ が発生する確率です。
軌跡 τ \tau τ は、エージェントが経験する一連の状態とアクションのシーケンスであり、初期状態から始まります。これは、エージェントが環境と相互作用する可能性のある「物語」または「パス」です。
Copy τ = ( s 0 , a 0 , s 1 , a 1 , s 2 , a 2 , … ) \tau = (s_0, a_0, s_1, a_1, s_2, a_2, \dots) τ = ( s 0 , a 0 , s 1 , a 1 , s 2 , a 2 , … )
s t s_t s t :時間ステップ t t t の状態(State)。
a t a_t a t :時間ステップ t t t で取られるアクション(Action)(通常は状態 s t s_t s t とポリシー π \pi π に基づいています)。
私たちは通常、環境を確率的 (stochastic)にモデル化します。これは、同じ状態 s t s_t s t で同じアクション a t a_t a t を実行しても、必ずしも完全に 同じ次の状態 s t + 1 s_{t+1} s t + 1 にはならないことを意味します。ここにはランダム性が関与しています。
次の状態 s t + 1 s_{t+1} s t + 1 は、現在の状態 s t s_t s t と取られたアクション a t a_t a t に条件付けられた確率分布から抽出されます:
Copy s t + 1 ∼ P ( ⋅ ∣ s t , a t ) s_{t+1} \sim P(\cdot | s_t, a_t) s t + 1 ∼ P ( ⋅ ∣ s t , a t )
ランダム性のある状態遷移とエージェントのポリシーを考慮すると、全体の軌跡が発生する確率を計算できます。それは以下の項の連乗で得られます:
エージェントが初期状態 s 0 s_0 s 0 にある確率:p 0 ( s 0 ) p_0(s_0) p 0 ( s 0 ) 。
軌跡内の各 時間ステップ t t t に対して:
環境が s t s_t s t と a t a_t a t の条件下で s t + 1 s_{t+1} s t + 1 に遷移する確率:P ( s t + 1 ∣ s t , a t ) P(s_{t+1}|s_t, a_t) P ( s t + 1 ∣ s t , a t ) 。
エージェントが状態 s t s_t s t でアクション a t a_t a t を選択する確率:π ( a t ∣ s t ) \pi(a_t|s_t) π ( a t ∣ s t ) 。
Copy P ( τ ∣ π ) = p 0 ( s 0 ) ∏ t = 0 T − 1 P ( s t + 1 ∣ s t , a t ) π ( a t ∣ s t ) P(\tau|\pi) = p_0(s_0) \prod_{t=0}^{T-1} P(s_{t+1}|s_t, a_t) \pi(a_t|s_t) P ( τ ∣ π ) = p 0 ( s 0 ) t = 0 ∏ T − 1 P ( s t + 1 ∣ s t , a t ) π ( a t ∣ s t )
(ここで T T T は軌跡の長さです)。
軌跡の総報酬 R ( τ ) R(\tau) R ( τ ) を計算する際、私たちはほぼ常に割引報酬 (discounted rewards)を使用します。これは、早く受け取った報酬が遅く受け取った報酬よりも価値があることを意味します。
なぜでしょうか?
現実のシナリオを反映しています(今日手に入る 1 ドルは明日約束された 1 ドルよりも価値があります)。
継続的なタスク(固定の終点がないタスク)では無限の報酬の問題を回避します。
数学的な便利さを提供します。
私たちは割引因子 (discount factor)γ \gamma γ を導入します。ここで 0 ≤ γ < 1 0 \le \gamma < 1 0 ≤ γ < 1 です。γ \gamma γ が 0 に近いほど、エージェントは「近視的」(目の前の利益により関心を持つ)になり、γ \gamma γ が 1 に近いほど、エージェントは「先見の明がある」(長期的な報酬により関心を持つ)ようになります。
軌跡の総割引報酬は次のように計算されます:
Copy R ( τ ) = ∑ t = 0 ∞ γ t r t R(\tau) = \sum_{t=0}^{\infty} \gamma^t r_t R ( τ ) = t = 0 ∑ ∞ γ t r t
r t r_t r t は時間ステップ t t t で受け取る即時報酬(immediate reward)です。
γ t \gamma^t γ t は時間ステップ t t t の報酬に適用される割引係数です。
では、LLM において、軌跡とは何でしょうか?前述のように、モデルはポリシーであり、プロンプトは状態であり、次のトークンはアクションです。したがって、自回帰生成におけるこれらの s, a シーケンスが軌跡を構成します。
Policy Gradient#
私たちは強化学習の目標を定義しました:期待報酬 J ( π ) J(π) J ( π ) を最大化する最適ポリシー π ∗ π^∗ π ∗ を見つけることです。素晴らしい。しかし、実際にこのポリシーをどのように表現し、見つけるのでしょうか?
通常、特に複雑な問題を扱う場合、すべての可能なポリシーを探索することはありません。私たちはパラメータ化ポリシー(parameterized policy)を定義し、π θ π_θ π θ と呼びます。θ θ θ を一連の「ノブ」またはパラメータと考えることができます。もし私たちのポリシーが神経ネットワークであれば、θ θ θ はネットワークの重みとバイアスかもしれません。
私たちの目標は、これらのノブ θ θ θ を調整して期待報酬を最大化することに変わります。
Note
パラメータ θ θ θ のポリシー π θ π_θ π θ の下で、すべての可能な軌跡の期待報酬:
J ( π θ ) = E τ ∼ π θ [ R ( τ ) ] J(\pi_{\theta}) = E_{\tau \sim \pi_{\theta}} [R(\tau)] J ( π θ ) = E τ ∼ π θ [ R ( τ )]
これは、期待報酬が軌跡 τ \tau τ に依存し、軌跡の分布が私たちの特定 のポリシー π θ \pi_{\theta} π θ によって選択されたアクションに依存することを意味します。θ θ θ を変更すると、ポリシーが変わり、軌跡が変わり、期待報酬も変わります。
Note
私たちは θ θ θ を変更することで J ( π θ ) J(\pi_{\theta}) J ( π θ ) を最大化したいと考えています。深層学習では一般に勾配降下法 (gradient descent)を使用して損失関数 を最小化 します。しかし、ここでは、関数J J J を最大化 したいのです。したがって、逆に勾配上昇法 (gradient ascent)を使用します!これは山を登るようなもので、最も急な上昇方向(つまり勾配)を見つけて、その方向に一歩進むことです。
私たちのポリシー π θ \pi_{\theta} π θ は神経ネットワークであり、J ( π θ ) J(\pi_{\theta}) J ( π θ ) を増加させるためにそのパラメータ θ θ θ を反復的に調整します。この更新ルールは非常に馴染みのあるものに見えるでしょう(ただし、勾配降下法のマイナス記号をプラスに置き換えたものです):
Copy θ k + 1 = θ k + α ∇ θ J ( π θ ) ∣ θ k \theta_{k+1} = \theta_k + \alpha \nabla_{\theta} J(\pi_{\theta})|_{\theta_k} θ k + 1 = θ k + α ∇ θ J ( π θ ) ∣ θ k
θ k \theta_k θ k :第 k k k 回の反復時のパラメータです。
α \alpha α :学習率(ステップサイズ)です。
∇ θ J ( π θ ) ∣ θ k \nabla_{\theta} J(\pi_{\theta})|_{\theta_k} ∇ θ J ( π θ ) ∣ θ k :期待報酬 J J J のパラメータ θ θ θ に対する勾配 (gradient)で、現在のパラメータ θ k θ_k θ k で計算されます。これは、パラメータ空間のどの方向が J J J を最大限に増加させるかを教えてくれます。
Important
PG の導出
ここでは、最初に少し冗長になりますが、インデックスを再度紹介して ADHD に優しい導出を行います...
第一歩 、私たちが求める勾配の対象は期待報酬 J ( π θ ) J(\pi_{\theta}) J ( π θ ) であることを再確認します:
∇ θ J ( π θ ) = ∇ θ E τ ∼ π θ [ R ( τ ) ] \nabla_{\theta} J(\pi_{\theta}) = \nabla_{\theta} E_{\tau \sim \pi_{\theta}} [R(\tau)] ∇ θ J ( π θ ) = ∇ θ E τ ∼ π θ [ R ( τ )]
ここで:
J ( π θ ) J(\pi_{\theta}) J ( π θ ) は期待報酬です。
E τ ∼ π θ [ ⋅ ] E_{\tau \sim \pi_{\theta}} [\cdot] E τ ∼ π θ [ ⋅ ] は期待値であり、この期待はすべての可能な軌跡 (trajectory, τ \tau τ ) に対して計算されます。軌跡 τ \tau τ はエージェントと環境の相互作用によって生成される一連の状態とアクション ( s 0 , a 0 , s 1 , a 1 , … ) (s_0, a_0, s_1, a_1, \dots) ( s 0 , a 0 , s 1 , a 1 , … ) です。
τ ∼ π θ \tau \sim \pi_{\theta} τ ∼ π θ は、これらの軌跡が現在のポリシー π θ \pi_{\theta} π θ に基づいて生成されることを示します。
R ( τ ) R(\tau) R ( τ ) は、完全な軌跡 τ \tau τ が得られる総報酬(通常は割引報酬)を指します。
∇ θ \nabla_{\theta} ∇ θ は勾配演算子であり、パラメータ θ \theta θ に対して偏導数を求めることを示します。
第二歩 、期待値の表現を展開します:
期待値の定義は何ですか?確率変数 X X X の期待値 E [ X ] E[X] E [ X ] は、その確率分布 p ( x ) p(x) p ( x ) によって計算できます:
連続変数の場合:E [ X ] = ∫ p ( x ) x d x E[X] = \int p(x) x dx E [ X ] = ∫ p ( x ) x d x
離散変数の場合:E [ X ] = ∑ p ( x ) x E[X] = \sum p(x) x E [ X ] = ∑ p ( x ) x
私たちの例では、確率変数は軌跡の報酬 R ( τ ) R(\tau) R ( τ ) であり、確率分布は軌跡が発生する確率 P ( τ ∣ π θ ) P(\tau|\pi_{\theta}) P ( τ ∣ π θ ) (ポリシー π θ \pi_{\theta} π θ の下で軌跡 τ \tau τ が発生する確率)です。したがって、期待値は積分(または合計)の形式で表すことができます:
E τ ∼ π θ [ R ( τ ) ] = ∫ P ( τ ∣ π θ ) R ( τ ) d τ E_{\tau \sim \pi_{\theta}} [R(\tau)] = \int P(\tau|\pi_{\theta}) R(\tau) d\tau E τ ∼ π θ [ R ( τ )] = ∫ P ( τ ∣ π θ ) R ( τ ) d τ
(ここでは、積分記号 ∫ \int ∫ を使用してすべての可能な軌跡を合計または積分することを示します)。
この結果を第一歩の公式に代入します:
∇ θ J ( π θ ) = ∇ θ ∫ P ( τ ∣ π θ ) R ( τ ) d τ \nabla_{\theta} J(\pi_{\theta}) = \nabla_{\theta} \int P(\tau|\pi_{\theta}) R(\tau) d\tau ∇ θ J ( π θ ) = ∇ θ ∫ P ( τ ∣ π θ ) R ( τ ) d τ
第三歩 :勾配演算子を積分記号の中に移動します
∇ θ ∫ P ( τ ∣ π θ ) R ( τ ) d τ = ∫ ∇ θ [ P ( τ ∣ π θ ) R ( τ ) ] d τ \nabla_{\theta} \int P(\tau|\pi_{\theta}) R(\tau) d\tau = \int \nabla_{\theta} [P(\tau|\pi_{\theta}) R(\tau)] d\tau ∇ θ ∫ P ( τ ∣ π θ ) R ( τ ) d τ = ∫ ∇ θ [ P ( τ ∣ π θ ) R ( τ )] d τ
ここでは少し微積分の知識が必要です:特定の条件を満たす場合(通常、強化学習では満たされると仮定します)、私たちは微分と積分の順序を交換することができます 。d d x ∑ f i ( x ) = ∑ d d x f i ( x ) \frac{d}{dx} \sum f_i(x) = \sum \frac{d}{dx} f_i(x) d x d ∑ f i ( x ) = ∑ d x d f i ( x ) のように。
次に、R ( τ ) R(\tau) R ( τ ) は特定の軌跡が決定された後の総報酬 であり、その値はポリシーのパラメータ θ \theta θ に直接依存しません(ポリシー π θ \pi_{\theta} π θ がどの軌跡が発生するか に影響を与えるのではなく、発生した後の報酬値に影響を与えます)。したがって、勾配 ∇ θ \nabla_{\theta} ∇ θ は P ( τ ∣ π θ ) P(\tau|\pi_{\theta}) P ( τ ∣ π θ ) のみに作用する必要があります:
= ∫ [ ∇ θ P ( τ ∣ π θ ) ] R ( τ ) d τ = \int [\nabla_{\theta} P(\tau|\pi_{\theta})] R(\tau) d\tau = ∫ [ ∇ θ P ( τ ∣ π θ )] R ( τ ) d τ
このステップは、期待報酬の変化が、パラメータ θ \theta θ の変更によって各軌跡が発生する確率 P ( τ ∣ π θ ) P(\tau|\pi_{\theta}) P ( τ ∣ π θ ) が変化し、その後その軌跡自体の報酬 R ( τ ) R(\tau) R ( τ ) に掛けられ、すべての軌跡を合計した結果であることを示しています。
第四歩 :対数導数のテクニック (Log-derivative trick)
これは全体の導出の中で最も核心的で巧妙なステップです!私たちは恒等式を導入する必要があります。
高等数学の復習 (連鎖律と対数微分): 自然対数 log ( x ) \log(x) log ( x ) (通常は ln ( x ) \ln(x) ln ( x ) ) の導数は d d x log ( f ( x ) ) = 1 f ( x ) d f ( x ) d x = f ′ ( x ) f ( x ) \frac{d}{dx} \log(f(x)) = \frac{1}{f(x)} \frac{d f(x)}{dx} = \frac{f'(x)}{f(x)} d x d log ( f ( x )) = f ( x ) 1 d x df ( x ) = f ( x ) f ′ ( x ) です。
少し変形すると、f ′ ( x ) = f ( x ) d d x log ( f ( x ) ) f'(x) = f(x) \frac{d}{dx} \log(f(x)) f ′ ( x ) = f ( x ) d x d log ( f ( x )) という結果が得られます。
現在、このテクニックを勾配に適用します。f ( x ) f(x) f ( x ) を P ( τ ∣ π θ ) P(\tau|\pi_{\theta}) P ( τ ∣ π θ ) に対応させ、変数 x x x をパラメータ θ \theta θ に対応させます。すると:
∇ θ P ( τ ∣ π θ ) = P ( τ ∣ π θ ) ∇ θ log P ( τ ∣ π θ ) \nabla_{\theta} P(\tau|\pi_{\theta}) = P(\tau|\pi_{\theta}) \nabla_{\theta} \log P(\tau|\pi_{\theta}) ∇ θ P ( τ ∣ π θ ) = P ( τ ∣ π θ ) ∇ θ log P ( τ ∣ π θ )
この結果を第三歩の積分に代入します:
∫ [ ∇ θ P ( τ ∣ π θ ) ] R ( τ ) d τ = ∫ [ P ( τ ∣ π θ ) ∇ θ log P ( τ ∣ π θ ) ] R ( τ ) d τ \int [\nabla_{\theta} P(\tau|\pi_{\theta})] R(\tau) d\tau = \int [P(\tau|\pi_{\theta}) \nabla_{\theta} \log P(\tau|\pi_{\theta})] R(\tau) d\tau ∫ [ ∇ θ P ( τ ∣ π θ )] R ( τ ) d τ = ∫ [ P ( τ ∣ π θ ) ∇ θ log P ( τ ∣ π θ )] R ( τ ) d τ
第五歩 :期待値形式に戻す
第四歩の結果を観察します:
∫ P ( τ ∣ π θ ) [ ∇ θ log P ( τ ∣ π θ ) R ( τ ) ] d τ \int P(\tau|\pi_{\theta}) [\nabla_{\theta} \log P(\tau|\pi_{\theta}) R(\tau)] d\tau ∫ P ( τ ∣ π θ ) [ ∇ θ log P ( τ ∣ π θ ) R ( τ )] d τ
これは期待値の定義に一致します! E [ f ( τ ) ] = ∫ P ( τ ∣ π θ ) f ( τ ) d τ E[f(\tau)] = \int P(\tau|\pi_{\theta}) f(\tau) d\tau E [ f ( τ )] = ∫ P ( τ ∣ π θ ) f ( τ ) d τ 。
ここで、 f ( τ ) f(\tau) f ( τ ) は角括弧内のすべての内容 [ ∇ θ log P ( τ ∣ π θ ) R ( τ ) ] [\nabla_{\theta} \log P(\tau|\pi_{\theta}) R(\tau)] [ ∇ θ log P ( τ ∣ π θ ) R ( τ )] に対応します。
したがって、全体の積分は期待値の形式に戻すことができます:
= E τ ∼ π θ [ ∇ θ log P ( τ ∣ π θ ) R ( τ ) ] = E_{\tau \sim \pi_{\theta}} [\nabla_{\theta} \log P(\tau|\pi_{\theta}) R(\tau)] = E τ ∼ π θ [ ∇ θ log P ( τ ∣ π θ ) R ( τ )]
重大な意義! 私たちは成功裏に期待の勾配 ∇ θ E [ ⋅ ] \nabla_{\theta} E[\cdot] ∇ θ E [ ⋅ ] をある量(勾配と報酬の積)の期待 E [ ∇ ( ⋅ ) × R ] E[\nabla (\cdot) \times R] E [ ∇ ( ⋅ ) × R ] に変換しました。この形式は非常に重要です。なぜなら、サンプリング によって近似できるからです!私たちは実際にすべての軌跡の積分を計算する必要はありません。たくさんの軌跡 τ \tau τ をサンプリングし、各軌跡に対して角括弧内の値 [ ∇ θ log P ( τ ∣ π θ ) R ( τ ) ] [\nabla_{\theta} \log P(\tau|\pi_{\theta}) R(\tau)] [ ∇ θ log P ( τ ∣ π θ ) R ( τ )] を計算し、平均を取ることで勾配の近似値を得ることができます!
第六歩 :対数確率の勾配の展開 (Expression for grad-log-prob)
現在、期待値の中の ∇ θ log P ( τ ∣ π θ ) \nabla_{\theta} \log P(\tau|\pi_{\theta}) ∇ θ log P ( τ ∣ π θ ) の項を処理する必要があります。
軌跡 τ = ( s 0 , a 0 , s 1 , a 1 , … , s T , a T ) \tau = (s_0, a_0, s_1, a_1, \dots, s_T, a_T) τ = ( s 0 , a 0 , s 1 , a 1 , … , s T , a T ) を思い出してください(軌跡の長さは T+1 の状態と T+1 のアクション、または T の時間ステップとします)。軌跡が発生する確率は:
P ( τ ∣ π θ ) = p 0 ( s 0 ) ∏ t = 0 T P ( s t + 1 ∣ s t , a t ) π θ ( a t ∣ s t ) P(\tau|\pi_{\theta}) = p_0(s_0) \prod_{t=0}^{T} P(s_{t+1}|s_t, a_t) \pi_{\theta}(a_t|s_t) P ( τ ∣ π θ ) = p 0 ( s 0 ) ∏ t = 0 T P ( s t + 1 ∣ s t , a t ) π θ ( a t ∣ s t )
p 0 ( s 0 ) p_0(s_0) p 0 ( s 0 ) :初期状態 s 0 s_0 s 0 の確率。
P ( s t + 1 ∣ s t , a t ) P(s_{t+1}|s_t, a_t) P ( s t + 1 ∣ s t , a t ) :環境のダイナミクス(Environment Dynamics)、状態 s t s_t s t でアクション a t a_t a t を実行した後、状態 s t + 1 s_{t+1} s t + 1 に遷移する確率。
π θ ( a t ∣ s t ) \pi_{\theta}(a_t|s_t) π θ ( a t ∣ s t ) :ポリシー、状態 s t s_t s t でアクション a t a_t a t を選択する確率(この部分は θ \theta θ に依存します)。
数学の復習 (対数の性質): log ( a × b ) = log a + log b \log(a \times b) = \log a + \log b log ( a × b ) = log a + log b および log ( ∏ i x i ) = ∑ i log x i \log(\prod_{i} x_i) = \sum_{i} \log x_i log ( ∏ i x i ) = ∑ i log x i 。
P ( τ ∣ π θ ) P(\tau|\pi_{\theta}) P ( τ ∣ π θ ) の対数を取ります:
log P ( τ ∣ π θ ) = log p 0 ( s 0 ) + ∑ t = 0 T [ log P ( s t + 1 ∣ s t , a t ) + log π θ ( a t ∣ s t ) ] \log P(\tau|\pi_{\theta}) = \log p_0(s_0) + \sum_{t=0}^{T} [\log P(s_{t+1}|s_t, a_t) + \log \pi_{\theta}(a_t|s_t)] log P ( τ ∣ π θ ) = log p 0 ( s 0 ) + ∑ t = 0 T [ log P ( s t + 1 ∣ s t , a t ) + log π θ ( a t ∣ s t )]
現在、上式について θ \theta θ の勾配 ∇ θ \nabla_{\theta} ∇ θ を求めます:
数学の復習 (勾配の性質): 勾配の加法法則 ∇ ( f + g ) = ∇ f + ∇ g \nabla(f+g) = \nabla f + \nabla g ∇ ( f + g ) = ∇ f + ∇ g 。勾配 ∇ θ \nabla_{\theta} ∇ θ は θ \theta θ に依存する項にのみ作用します。
∇ θ log P ( τ ∣ π θ ) = ∇ θ log p 0 ( s 0 ) + ∑ t = 0 T [ ∇ θ log P ( s t + 1 ∣ s t , a t ) + ∇ θ log π θ ( a t ∣ s t ) ] \nabla_{\theta} \log P(\tau|\pi_{\theta}) = \nabla_{\theta} \log p_0(s_0) + \sum_{t=0}^{T} [\nabla_{\theta} \log P(s_{t+1}|s_t, a_t) + \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)] ∇ θ log P ( τ ∣ π θ ) = ∇ θ log p 0 ( s 0 ) + ∑ t = 0 T [ ∇ θ log P ( s t + 1 ∣ s t , a t ) + ∇ θ log π θ ( a t ∣ s t )]
重要な点:
初期状態の確率 log p 0 ( s 0 ) \log p_0(s_0) log p 0 ( s 0 ) は通常ポリシーのパラメータ θ \theta θ に依存しないため、∇ θ log p 0 ( s 0 ) = 0 \nabla_{\theta} \log p_0(s_0) = 0 ∇ θ log p 0 ( s 0 ) = 0 です。
環境のダイナミクス log P ( s t + 1 ∣ s t , a t ) \log P(s_{t+1}|s_t, a_t) log P ( s t + 1 ∣ s t , a t ) は環境自体の特性を記述しており、ポリシーのパラメータ θ \theta θ に依存しないため、∇ θ log P ( s t + 1 ∣ s t , a t ) = 0 \nabla_{\theta} \log P(s_{t+1}|s_t, a_t) = 0 ∇ θ log P ( s t + 1 ∣ s t , a t ) = 0 です。
ただし、ポリシー log π θ ( a t ∣ s t ) \log \pi_{\theta}(a_t|s_t) log π θ ( a t ∣ s t ) は θ \theta θ に依存します。
したがって、上式は次のように簡略化されます:
∇ θ log P ( τ ∣ π θ ) = ∑ t = 0 T ∇ θ log π θ ( a t ∣ s t ) \nabla_{\theta} \log P(\tau|\pi_{\theta}) = \sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t) ∇ θ log P ( τ ∣ π θ ) = ∑ t = 0 T ∇ θ log π θ ( a t ∣ s t )
軌跡全体の対数確率の勾配は、この軌跡内の各ステップのアクションの対数確率の勾配 の合計に等しいです!これにより計算が大幅に簡素化されます。
第七歩 :最終的なポリシー勾配定理
第六歩で簡略化された結果を第五歩の期待値公式に戻します:
∇ θ J ( π θ ) = E τ ∼ π θ [ ( ∑ t = 0 T ∇ θ log π θ ( a t ∣ s t ) ) R ( τ ) ] \nabla_{\theta} J(\pi_{\theta}) = E_{\tau \sim \pi_{\theta}} [(\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)) R(\tau)] ∇ θ J ( π θ ) = E τ ∼ π θ [( ∑ t = 0 T ∇ θ log π θ ( a t ∣ s t )) R ( τ )]
これがポリシー勾配定理 (Policy Gradient Theorem) の最終的な形式(または一般的な形式の一つ)です。
期待報酬 J ( π θ ) J(\pi_{\theta}) J ( π θ ) のパラメータ θ \theta θ に対する勾配は、「軌跡 τ \tau τ をサンプリングし、その軌跡の総報酬 R ( τ ) R(\tau) R ( τ ) を計算し、その後この軌跡内のすべての (状態,アクション) に対するポリシーの対数確率の勾配 ∇ θ log π θ ( a t ∣ s t ) \nabla_{\theta} \log \pi_{\theta}(a_t|s_t) ∇ θ log π θ ( a t ∣ s t ) の合計を掛け、すべての可能な軌跡に対して期待値(平均)を求める」ことに等しいです。
明らかに、すべての軌跡を得るコストは非常に高く、例えば max_token_length=100
のすべての生成結果をサンプリングする必要があります。したがって、サンプル平均を使用して期待値を近似できます:
Note
モンテカルロ近似:* 現在のポリシー π θ \pi_{\theta} π θ を実行し、N N N 本の軌跡を収集してデータセット D = { τ 1 , . . . , τ N } D = \{\tau_1, ..., \tau_N\} D = { τ 1 , ... , τ N } を構成します(N = ∣ D ∣ N = |D| N = ∣ D ∣ とします)。
これらのサンプルの平均値を使用して期待値を近似します:
∇ θ J ( π θ ) ≈ g ^ = 1 ∣ D ∣ ∑ τ ∈ D [ ( ∑ t = 0 T ∇ θ log π θ ( a t ∣ s t ) ) R ( τ ) ] \nabla_{\theta} J(\pi_{\theta}) \approx \hat{g} = \frac{1}{|D|} \sum_{\tau \in D} [ (\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)) R(\tau) ] ∇ θ J ( π θ ) ≈ g ^ = ∣ D ∣ 1 ∑ τ ∈ D [( ∑ t = 0 T ∇ θ log π θ ( a t ∣ s t )) R ( τ )]
LM ポリシーへの適用#
図に示されている生成プロセスを通じて、このサンプリング軌跡内の各状態アクションペアの対数確率を得ることができ、今度は逆伝播を行って勾配を計算できます。
その後、各勾配に RM からの報酬を掛けて、勾配上昇最適化を実行します:
高い分散#
PG アルゴリズムは小さな問題に対しては良好に機能しますが、言語モデリングに使用するといくつかの問題が発生します。
Note
中心極限定理は、サンプルが十分に大きければ、サンプル平均は正規分布に従うことを教えてくれます。これにより、データをより良く予測し、分析することができます。サンプルサイズが小さい場合、サンプル平均の変動は非常に大きくなります。たとえ平均が正規分布に近づいても、単一のサンプリング結果は大きく異なる可能性があります。また、LM から多くの軌跡をサンプリングするコストが非常に高いことがわかっています。これにより、推定量の高い分散の問題が発生します。
サンプルサイズを増やさずに分散を減らすにはどうすればよいでしょうか?
過去の報酬を削除する:reward-to-go
まず、現在のアクションが過去に得られた報酬に影響を与えないことを認める必要があります。過去の報酬は不要なノイズを増加させ、これは RL における信用配分問題に関連しているはずです。したがって、過去の項を削除することでノイズの増加を避け、推定された勾配を真の勾配に近づけることができます。したがって、軌跡の報酬をゼロから計算するのではなく、現在の時間ステップから始まるアクションの報酬のみを考慮することができます。
ベースラインを導入する
RL の研究は、状態に依存する項(たとえば、軌跡報酬の関数を計算すること、または定数を使用すること)を導入することで分散を減らせることを証明しています。ここでは価値関数 V π ( s ) V^\pi(s) V π ( s ) を選択します。
価値関数#
V π ( s ) V^\pi(s) V π ( s ) は、現在のポリシーに従って行動した場合の残りの軌跡の期待報酬を示します。
古典的な RL シナリオと LM シナリオにおける価値の定義の例:
実際の操作では、最適化しようとしている LM を初期化として使用し、その上に線形層を追加して価値を予測します。これにより、Transformer 層のパラメータは、トークンを語彙に投影する層を使用して言語モデリングと価値推定の両方に同時に使用できます。
前述の報酬 - to-go は RL において Q 関数と呼ばれ、現在の状態からこのアクションを取り、即時報酬を得て、ポリシーに従ってその後の行動の期待報酬を得ることを示します:
次に、価値関数を導入することで Q と V の違いを得ます。この違いはアドバンテージ関数と呼ばれます。
この A π ( s t , a t ) A^\pi(s_{t,}a_t) A π ( s t , a t ) アドバンテージ項は、この特定のアクションが状態 s s s で取れる平均アクションよりもどれだけ良いかを示します。
図中の赤い矢印が指す状態では、下に移動するアドバンテージ関数が他のアクションのアドバンテージ関数よりも高くなります。
∇ θ ( J ( θ ) ) ≈ 1 N ∑ i = 1 N ( ∑ t = 0 T ∇ θ log π θ ( a i , t ∣ s i , t ) ) A π ( s i , t , a i , t ) \nabla_{\theta}(J(\theta)) \approx \frac{1}{N}\sum_{i=1}^{N}\left(\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{i,t} | s_{i,t})\right) A^{\pi}(s_{i,t}, a_{i,t}) ∇ θ ( J ( θ )) ≈ N 1 ∑ i = 1 N ( ∑ t = 0 T ∇ θ log π θ ( a i , t ∣ s i , t ) ) A π ( s i , t , a i , t )
勾配にアドバンテージ関数を掛けることで、効果は高いアドバンテージアクションの logprob を増加させ、低い平均報酬アクションの log prob を減少させることになります。
Copy A ^ π ( s t , a t ) = Q π ( s t , a t ) − V π ( s t ) = [ r ( s t , a t ) + γ V π ( s t + 1 ) ] − V π ( s t ) \begin{align}
\hat{A}^\pi(s_t, a_t) &= Q^\pi(s_t, a_t) - V^\pi(s_t) \\
&= [r(s_t, a_t) + \gamma V^\pi(s_{t+1})] - V^\pi(s_t)
\end{align} A ^ π ( s t , a t ) = Q π ( s t , a t ) − V π ( s t ) = [ r ( s t , a t ) + γ V π ( s t + 1 )] − V π ( s t )
Note
従来の強化学習手法では、Q ネットワーク と V ネットワーク は通常独立しています。つまり、Q 関数は状態 s s s でアクション a a a を実行したときの期待される総報酬を推定するために使用され、V 関数は状態 s s s の値を推定するために使用されます。これにより、これらの 2 つの値を計算するために 2 つの異なる神経ネットワークが必要になります。
しかし、ここではアドバンテージ関数 A θ ( s , a ) A_{\theta}(s, a) A θ ( s , a ) を導入しました。その計算方法は、Q 値と V 値の間の違いに基づいています。すなわち:
A θ ( s , a ) = Q θ ( s , a ) − V θ ( s ) A_{\theta}(s, a) = Q_{\theta}(s, a) - V_{\theta}(s) A θ ( s , a ) = Q θ ( s , a ) − V θ ( s )
A θ ( s , a ) A_{\theta}(s, a) A θ ( s , a ) を Q θ ( s , a ) Q_{\theta}(s, a) Q θ ( s , a ) と V θ ( s ) V_{\theta}(s) V θ ( s ) の違いとして表現することで、私たちは V θ ( s ) V_{\theta}(s) V θ ( s ) を出力するために 1 つのネットワークを訓練するだけで済む ことがわかります。次に、報酬 r t r_t r t と割引因子 γ \gamma γ を使用して Q 値を計算します。
したがって、必要なのは 1 つの神経ネットワークだけであり、このネットワークは主に V θ ( s ) V_{\theta}(s) V θ ( s ) を予測します。Q 値は次の式で計算されます:
Q θ ( s t , a ) = r t + γ ⋅ V θ ( s t + 1 ) Q_{\theta}(s_t, a) = r_t + \gamma \cdot V_{\theta}(s_{t+1}) Q θ ( s t , a ) = r t + γ ⋅ V θ ( s t + 1 )
アドバンテージ関数はさらに次のように計算されます:
A θ ( s t , a ) = r t + γ ⋅ V θ ( s t + 1 ) − V θ ( s t ) A_{\theta}(s_t, a) = r_t + \gamma \cdot V_{\theta}(s_{t+1}) - V_{\theta}(s_t) A θ ( s t , a ) = r t + γ ⋅ V θ ( s t + 1 ) − V θ ( s t )
アドバンテージサンプリング#
短いステップのアドバンテージ推定器は偏差が大きいが分散が小さく、長いステップのアドバンテージ推定器は偏差が小さいが分散が大きい。このトレードオフの問題は、強化学習において慎重に選択し調整する必要がある部分であり、モデルの安定性要件と訓練効率に依存します。
例:「短期記憶の人は昨日起こったことしか覚えていないが、全体像は不十分だが非常に安定している。長期記憶の人は未来数日の全体像を見通すことができるが、より多くの未知の要因によって干渉される可能性がある。」
GAE#
この偏差 - 分散の問題を解決するために、GAE(一般化アドバンテージ推定)を使用できます。これは本質的にすべてのアドバンテージ項の加重和であり、各項に減衰因子を掛けます。
Note
ここで TD 誤差について話しましょう。
オンライン学習 の利点は、最後まで待たずにポリシーを更新できることです。したがって、時系列差分誤差(TD Error) が登場します:
δ = r + γ V ( s ′ ) − V ( s ) \delta = r + \gamma V(s') - V(s) δ = r + γV ( s ′ ) − V ( s )
ここでの重要な点は、TD 誤差 が実際にはアドバンテージ関数のオンライン推定 であることです。これは、まさに今 、あなたのアクションが未来の状態を期待以上に良くしたかどうかを教えてくれます。この誤差 δ \delta δ はアドバンテージの概念を直接反映しています:
δ > 0 \delta > 0 δ > 0 の場合:「このアクションは私が想像していたよりも良い!」(アドバンテージは正)。
δ < 0 \delta < 0 δ < 0 の場合:「うーん、私はもっと良いと思っていたのに……」(アドバンテージは負)。
これにより、逐次的に ポリシーを調整でき、エピソード全体が終了するのを待つ必要がありません。これは効率を向上させるための素晴らしい戦略です。
GAE の目的は、ポリシー勾配アルゴリズムにおいて、元の報酬 R ( τ ) R(τ) R ( τ ) または単純な TD 誤差 δ t δ_t δ t よりも優れたアドバンテージ関数 A π ( s , a ) A^π(s,a) A π ( s , a ) の推定値 A t A^t A t を提供し、勾配推定の分散を低下させ、学習の安定性と効率を向上させることです。
Copy δ t = r t + γ V π ( s t + 1 ) − V π ( s t ) A ^ t = δ t + γ λ A ^ t + 1 \begin{align}
\delta_t &= r_t + \gamma V^\pi(s_{t+1}) - V^\pi(s_t) \\
\hat{A}_t &= \delta_t + \gamma \lambda \hat{A}_{t+1}
\end{align} δ t A ^ t = r t + γ V π ( s t + 1 ) − V π ( s t ) = δ t + γλ A ^ t + 1
この公式は再帰的に 一般化アドバンテージ推定 A ^ t \hat{A}_t A ^ t を定義します。これは、単に 1 ステップの TD 誤差 δ t \delta_t δ t を見るのではなく、将来の複数ステップの TD 誤差情報を統合します。
この再帰式は、軌跡(エピソード)の末尾から計算されます(T T T が最後のステップであると仮定し、A ^ T + 1 = 0 \hat{A}_{T+1}=0 A ^ T + 1 = 0 とします):
* A ^ T = δ T + 0 = δ T \hat{A}_T = \delta_T + 0 = \delta_T A ^ T = δ T + 0 = δ T
* A ^ T − 1 = δ T − 1 + γ λ A ^ T = δ T − 1 + ( γ λ ) δ T \hat{A}_{T-1} = \delta_{T-1} + \gamma \lambda \hat{A}_T = \delta_{T-1} + (\gamma \lambda) \delta_T A ^ T − 1 = δ T − 1 + γλ A ^ T = δ T − 1 + ( γλ ) δ T
* A ^ T − 2 = δ T − 2 + γ λ A ^ T − 1 = δ T − 2 + ( γ λ ) δ T − 1 + ( γ λ ) 2 δ T \hat{A}_{T-2} = \delta_{T-2} + \gamma \lambda \hat{A}_{T-1} = \delta_{T-2} + (\gamma \lambda) \delta_{T-1} + (\gamma \lambda)^2 \delta_T A ^ T − 2 = δ T − 2 + γλ A ^ T − 1 = δ T − 2 + ( γλ ) δ T − 1 + ( γλ ) 2 δ T
* ...
* 一般形式 :A ^ t = ∑ k = 0 ∞ ( γ λ ) k δ t + k \hat{A}_t = \sum_{k=0}^{\infty} (\gamma \lambda)^k \delta_{t+k} A ^ t = ∑ k = 0 ∞ ( γλ ) k δ t + k (無限ステップまたは終了状態後の δ \delta δ が 0 であると仮定します)
パラメータ λ \lambda λ (0 ≤ λ ≤ 1 0 \le \lambda \le 1 0 ≤ λ ≤ 1 )は GAE の鍵であり、推定 A ^ t \hat{A}_t A ^ t の偏差(bias)と分散(variance)を制御します:
λ = 0 \lambda = 0 λ = 0 の場合:
A ^ t = δ t \hat{A}_t = \delta_t A ^ t = δ t 。GAE は単純な単一ステップ TD 誤差 に退化します。この推定は分散が低い (次のステップの情報のみに依存するため)ですが、偏差が高い (おそらく不正確な V π ( s t + 1 ) V^{\pi}(s_{t+1}) V π ( s t + 1 ) の推定に大きく依存するため)可能性があります。
λ = 1 \lambda = 1 λ = 1 の場合:
A ^ t = ∑ k = 0 ∞ ( γ ) k δ t + k \hat{A}_t = \sum_{k=0}^{\infty} (\gamma)^k \delta_{t+k} A ^ t = ∑ k = 0 ∞ ( γ ) k δ t + k 。導出により、これは A ^ t = ( ∑ k = 0 ∞ γ k r t + k ) − V π ( s t ) \hat{A}_t = (\sum_{k=0}^{\infty} \gamma^k r_{t+k}) - V^{\pi}(s_t) A ^ t = ( ∑ k = 0 ∞ γ k r t + k ) − V π ( s t ) に等しいことが証明できます。これはモンテカルロ(Monte Carlo)報酬から基準(baseline)を引いたもの です。この推定は偏差が低い (t t t の時点からの完全な実際の報酬を使用するため)ですが、分散は通常高い (複数の時間ステップのランダム性を累積するため)です。
0 < λ < 1 0 < \lambda < 1 0 < λ < 1 の場合:
GAE は上記の 2 つの極端な状況の間で補間を行います。λ \lambda λ が 0 に近いほど、低分散高偏差の TD 推定に偏ります。λ \lambda λ が 1 に近いほど、高分散低偏差の MC 推定に偏ります。
適切な λ \lambda λ (例えば 0.97)を選択することで、GAE は偏差と分散の間で良好なバランス を取ろうとし、比較的正確(偏差が制御可能)であり、かつ比較的安定(分散が小さい)なアドバンテージ推定を得ることを目指します。
言語モデルのアドバンテージ#
図のように、目標は現在の状態「上海」トークンの logprob を高め、「チョコレート」の logprob を低下させることです。なぜなら、「上海」を選択することのアドバンテージが「チョコレート」(無関係なトークン)を選択することよりも高いからです。
重要性サンプリングとオフライン学習#
多くの状況では、私たちは E x ∼ p ( x ) [ f ( x ) ] E_{x \sim p(x)}[f(x)] E x ∼ p ( x ) [ f ( x )] を計算したいと思うかもしれませんが:
目標分布 p ( x ) p(x) p ( x ) から直接サンプリングして x x x を得るのが難しいか、できない。
または、p ( x ) p(x) p ( x ) からのサンプリングが効率的でない。言語モデルではこの問題があり、LM のサンプリングコストが非常に高いです。
しかし、私たちは別の代替の、または提案(Proposal)分布 q ( x ) q(x) q ( x ) からサンプリングすることが容易であるかもしれません。
重要性サンプリング (Importance Sampling, IS) は、異なる分布からサンプリングすることによって目標分布の期待値を推定する技術です。確率分布 p ( x ) p(x) p ( x ) の下で計算された期待値 E x ∼ p ( x ) [ f ( x ) ] E_{x \sim p(x)}[f(x)] E x ∼ p ( x ) [ f ( x )] を、別の異なる確率分布 q ( x ) q(x) q ( x ) の下で関連する関数の期待値 E x ∼ q ( x ) [ p ( x ) q ( x ) f ( x ) ] E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right] E x ∼ q ( x ) [ q ( x ) p ( x ) f ( x ) ] に変換します。
Copy E x ∼ p ( x ) [ f ( x ) ] = ∫ p ( x ) f ( x ) d x = ∫ q ( x ) q ( x ) p ( x ) f ( x ) d x (仮定 q ( x ) ≠ 0 ) = ∫ q ( x ) p ( x ) q ( x ) f ( x ) d x = E x ∼ q ( x ) [ p ( x ) q ( x ) f ( x ) ] \begin{align}
E_{x \sim p(x)}[f(x)] &= \int p(x) f(x) \, dx \\
&= \int \frac{q(x)}{q(x)} p(x) f(x) \, dx \quad \text{(仮定 } q(x) \neq 0 \text{)} \\
&= \int q(x) \frac{p(x)}{q(x)} f(x) \, dx \\
&= E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right]
\end{align} E x ∼ p ( x ) [ f ( x )] = ∫ p ( x ) f ( x ) d x = ∫ q ( x ) q ( x ) p ( x ) f ( x ) d x ( 仮定 q ( x ) = 0 ) = ∫ q ( x ) q ( x ) p ( x ) f ( x ) d x = E x ∼ q ( x ) [ q ( x ) p ( x ) f ( x ) ]
ここでの重要な点は、重要性重み(importance weight) w ( x ) = p ( x ) q ( x ) w(x) = \frac{p(x)}{q(x)} w ( x ) = q ( x ) p ( x ) を導入したことです。この重みの役割は偏差を修正する ことです:q ( x ) q(x) q ( x ) からサンプリングされたサンプル x i x_i x i が目標分布 p ( x ) p(x) p ( x ) で出現する確率が高い場合(p ( x i ) > q ( x i ) p(x_i) > q(x_i) p ( x i ) > q ( x i ) )、重みは 1 より大きくなります。逆に、p ( x ) p(x) p ( x ) で出現する確率が低い場合(p ( x i ) < q ( x i ) p(x_i) < q(x_i) p ( x i ) < q ( x i ) )、重みは 1 より小さくなります。このように加重平均を取ることで、元の期待値 E x ∼ p ( x ) [ f ( x ) ] E_{x \sim p(x)}[f(x)] E x ∼ p ( x ) [ f ( x )] の(通常は無偏または一致する)推定を得ることができます。
重要性サンプリングを使用すると:
簡単にサンプリングできる分布 q ( x ) q(x) q ( x ) からサンプル x 1 , x 2 , . . . , x N x_1, x_2, ..., x_N x 1 , x 2 , ... , x N を抽出できます。
加重平均を計算することで元の期待値を推定できます:
Copy E x ∼ p ( x ) [ f ( x ) ] ≈ 1 N ∑ i = 1 N p ( x i ) q ( x i ) f ( x i ) E_{x \sim p(x)}[f(x)] \approx \frac{1}{N} \sum_{i=1}^{N} \frac{p(x_i)}{q(x_i)} f(x_i) E x ∼ p ( x ) [ f ( x )] ≈ N 1 i = 1 ∑ N q ( x i ) p ( x i ) f ( x i )
私たちのシナリオに戻りましょう。前述のように、オンポリシーのポリシー勾配推定を得ました:
Copy ∇ θ ( J ( θ ) ) ≈ 1 N ∑ i = 1 N ( ∑ t = 0 T ∇ θ log π θ ( a i , t ∣ s i , t ) ) A π ( s i , t , a i , t ) \nabla_{\theta}(J(\theta)) \approx \frac{1}{N}\sum_{i=1}^{N}\left(\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{i,t} | s_{i,t})\right) A^{\pi}(s_{i,t}, a_{i,t}) ∇ θ ( J ( θ )) ≈ N 1 i = 1 ∑ N ( t = 0 ∑ T ∇ θ log π θ ( a i , t ∣ s i , t ) ) A π ( s i , t , a i , t )
[!info]
オンポリシーの意味:データを収集するためのポリシーと訓練時に使用するポリシーが同じであること。計算時には、現在のポリシー π θ \pi_{\theta} π θ からサンプリングされた軌跡を使用する必要があります。これは、ポリシーが更新されるたびに古いデータが直接使用できなくなることを意味し、サンプル効率が低下します。ミニバッチ数が 1 より大きい場合、後で更新に使用されるデータはオンポリシーと見なされるのでしょうか?厳密にはそうではないように感じるので、セミオンポリシーとして理解することもできます(表現は必ずしも厳密ではありません)。
オンポリシーは、現在のポリシーモデルが環境と相互作用できるかどうかを強調します。[ 2 ] ^{[2]} [ 2 ]
私たちは、古いポリシー π θ O F F L I N E \pi_{\theta_{OFFLINE}} π θ OFF L I NE によって生成されたデータを利用して、現在の新しいポリシー π θ O N L I N E \pi_{\theta_{ONLINE}} π θ ON L I NE の勾配を推定したいと考えています。これにより、データを再利用し、サンプル効率を向上させることができます。
重要性サンプリング IS の原理を思い出してください: E x ∼ p ( x ) [ f ( x ) ] = E x ∼ q ( x ) [ p ( x ) q ( x ) f ( x ) ] E_{x \sim p(x)}[f(x)] = E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right] E x ∼ p ( x ) [ f ( x )] = E x ∼ q ( x ) [ q ( x ) p ( x ) f ( x ) ] 。
私たちの PG に対応させると(単一ステップの決定を簡略化して考えます):
* 目標分布 p ( x ) p(x) p ( x ) は新しいポリシー π θ O N L I N E ( a ∣ s ) \pi_{\theta_{ONLINE}}(a|s) π θ ON L I NE ( a ∣ s ) に対応します。
* サンプリング分布 q ( x ) q(x) q ( x ) は古いポリシー π θ O F F L I N E ( a ∣ s ) \pi_{\theta_{OFFLINE}}(a|s) π θ OFF L I NE ( a ∣ s ) に対応します。
* 重要性重みは w t = π θ O N L I N E ( a t ∣ s t ) π θ O F F L I N E ( a t ∣ s t ) w_t = \frac{\pi_{\theta_{ONLINE}}(a_t|s_t)}{\pi_{\theta_{OFFLINE}}(a_t|s_t)} w t = π θ OFF L I NE ( a t ∣ s t ) π θ ON L I NE ( a t ∣ s t ) です。
重要性重みをオンポリシー勾配の各項(各時間ステップ t t t )に適用すると、標準的なオフポリシー推定が得られます:
Copy ∇ θ O N L I N E ( J ( θ O N L I N E , θ O F F L I N E ) ) ≈ 1 N ∑ i = 1 N ∑ t = 0 T [ π θ O N L I N E ( a i , t ∣ s i , t ) π θ O F F L I N E ( a i , t ∣ s i , t ) ∇ θ O N L I N E log π θ O N L I N E ( a i , t ∣ s i , t ) A π ( s i , t , a i , t ) ] \nabla_{\theta_{ONLINE}} (J(\theta_{ONLINE},\theta_{OFFLINE})) \approx \frac{1}{N}\sum_{i=1}^{N}\sum_{t=0}^{T} \left[ \frac{\pi_{\theta_{ONLINE}}(a_{i,t}|s_{i,t})}{\pi_{\theta_{OFFLINE}}(a_{i,t}|s_{i,t})} \nabla_{\theta_{ONLINE}} \log \pi_{\theta_{ONLINE}}(a_{i,t} | s_{i,t}) A^{\pi}(s_{i,t}, a_{i,t}) \right] ∇ θ ON L I NE ( J ( θ ON L I NE , θ OFF L I NE )) ≈ N 1 i = 1 ∑ N t = 0 ∑ T [ π θ OFF L I NE ( a i , t ∣ s i , t ) π θ ON L I NE ( a i , t ∣ s i , t ) ∇ θ ON L I NE log π θ ON L I NE ( a i , t ∣ s i , t ) A π ( s i , t , a i , t ) ]
これにより、最適化中のポリシー(訓練するモデル)から毎回サンプリングすることなく、完全な勾配上昇最適化を行うことができるようになりました。代わりに、1 回サンプリングして軌跡をメモリ / データベースに保存し、ミニバッチでポリシーを最適化し、新しいポリシーをオフラインポリシー(サンプリングされたポリシー)で初期化できます。
PPO Loss#
PPO ロスは主に 3 つの部分から構成されます:ポリシーロス(L P O L I C Y L_{POLICY} L PO L I C Y )、価値関数ロス(L V F L_{VF} L V F )、およびエントロピー報酬(L E N T R O P Y L_{ENTROPY} L ENTROP Y )。
1. ポリシーロス (L P O L I C Y L_{POLICY} L PO L I C Y )#
裁剪された代替目標 (Clipped Surrogate Objective)
Copy L P O L I C Y = min ( π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) A ^ t , clip ( π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) , 1 − ϵ , 1 + ϵ ) A ^ t ) L_{POLICY} = \min \left( \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \hat{A}_t, \text{clip} \left( \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}, 1-\epsilon, 1+\epsilon \right) \hat{A}_t \right) L PO L I C Y = min ( π θ o l d ( a t ∣ s t ) π θ ( a t ∣ s t ) A ^ t , clip ( π θ o l d ( a t ∣ s t ) π θ ( a t ∣ s t ) , 1 − ϵ , 1 + ϵ ) A ^ t )
これは PPO の核心です。あなたはこれが、私たちが前述の重要性サンプリングから導出したオフポリシーのポリシー勾配目標に似ていることに気付くでしょうが、重要な変更があります。
π θ ( a t ∣ s t ) π θ o l d ( a t ∣ s t ) \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} π θ o l d ( a t ∣ s t ) π θ ( a t ∣ s t ) : これは重要性サンプリング比率であり、r t ( θ ) r_t(\theta) r t ( θ ) と呼ばれます。これは、状態 s t s_t s t の下で、現在の (オンライン)ポリシー π θ \pi_{\theta} π θ に基づいてアクション a t a_t a t を取る確率を、収集された軌跡データを使用していた古い (オフライン)ポリシー π θ o l d \pi_{\theta_{old}} π θ o l d に基づいてそのアクションを取る確率で割ったものです。この比率は、データが現在改善しようとしているポリシーとはわずかに異なるポリシーから来ているという事実を修正します。
A ^ t \hat{A}_t A ^ t : これはアドバンテージ関数の推定値であり、GAE を使用して計算され、偏差と分散のバランスを取るのに役立ちます。これは、状態 s t s_t s t でアクション a t a_t a t を取ることが、同じ状態で平均的なアクションを取ることよりもどれだけ良いか、または悪いかを示します(現在の価値関数に基づいて判断されます)。
clip
関数 : これが PPO の重要なポイントです。
clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) \text{clip} \left( r_t(\theta), 1-\epsilon, 1+\epsilon \right) clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ )
これは基本的に、確率比率 r t ( θ ) r_t(\theta) r t ( θ ) が 1 からあまりにも遠く離れた場合(高すぎるまたは低すぎる)、それを「裁剪」することを意味します。したがって、r t ( θ ) r_t(\theta) r t ( θ ) が 1.5 1.5 1.5 になろうとすると、ϵ \epsilon ϵ が 0.2 0.2 0.2 の場合、1.2 1.2 1.2 に裁剪されます。r t ( θ ) r_t(\theta) r t ( θ ) が 0.5 0.5 0.5 になろうとすると、0.8 0.8 0.8 に裁剪されます。
パラメータ ϵ \epsilon ϵ (epsilon) は小さなハイパーパラメータ(例えば 0.1 または 0.2)であり、裁剪範囲 [ 1 − ϵ , 1 + ϵ ] [1-\epsilon, 1+\epsilon] [ 1 − ϵ , 1 + ϵ ] を定義します。
min
関数 : この目的関数は、以下の 2 項のうち小さい方を取ります:
裁剪されていない目標: r t ( θ ) A ^ t r_t(\theta) \hat{A}_t r t ( θ ) A ^ t
裁剪された目標: clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t clip ( r t ( θ ) , 1 − ϵ , 1 + ϵ ) A ^ t
なぜこのようにするのでしょうか? ポリシー勾配の目標は、正のアドバンテージを持つアクションの確率を増加させ、負のアドバンテージを持つアクションの確率を減少させることです。
しかし、重要性サンプリングを使用すると、r t ( θ ) r_t(\theta) r t ( θ ) が非常に大きくなると、巨大な更新と不安定性を引き起こす可能性があります。PPO はこの比率を裁剪することで、新しいポリシーと古いポリシーの近接性を保とうとします。
A ^ t > 0 \hat{A}_t > 0 A ^ t > 0 (良いアクション) の場合:私たちは π θ ( a t ∣ s t ) \pi_{\theta}(a_t|s_t) π θ ( a t ∣ s t ) を増加させたいと考えています。min
関数は、r t ( θ ) r_t(\theta) r t ( θ ) が 1 + ϵ 1+\epsilon 1 + ϵ を超えて成長した場合、目標関数は ( 1 + ϵ ) A ^ t (1+\epsilon)\hat{A}_t ( 1 + ϵ ) A ^ t に制限されます。これにより、ポリシーが単一の更新で大きく変化するのを防ぎます。未裁剪の目標がより大きな増幅を提案してもです。
A ^ t < 0 \hat{A}_t < 0 A ^ t < 0 (悪いアクション) の場合:私たちは π θ ( a t ∣ s t ) \pi_{\theta}(a_t|s_t) π θ ( a t ∣ s t ) を減少させたいと考えています。r t ( θ ) r_t(\theta) r t ( θ ) が 1 − ϵ 1-\epsilon 1 − ϵ 未満に縮小した場合、目標関数は ( 1 − ϵ ) A ^ t (1-\epsilon)\hat{A}_t ( 1 − ϵ ) A ^ t に制限されます。(注意:A ^ t < 0 \hat{A}_t < 0 A ^ t < 0 の場合、r t ( θ ) A ^ t r_t(\theta)\hat{A}_t r t ( θ ) A ^ t の項は r t ( θ ) r_t(\theta) r t ( θ ) が小さいときに大きな値を持ち(ゼロに近いまたは正)、clip ( . . . ) A ^ t \text{clip}(...) \hat{A}_t clip ( ... ) A ^ t も clip ( . . . ) \text{clip}(...) clip ( ... ) が小さいときに大きな値を持ちます。ここでの min
は、比率が裁剪境界を超えた場合、目標が過度に 負になることを防ぐことを意味します(つまり、そのアクションの確率を過度に低下させることはありません)。
2. 価値関数損失 (L V F L_{VF} L V F )#
Copy L V F = 1 2 ∥ V θ ( s ) − ( ∑ t ′ = t T γ t ′ − t r t ′ ∣ s 0 = s ) ∥ 2 2 L_{VF} = \frac{1}{2} \left\| V_{\theta}(s) - \left( \sum_{t'=t}^{T} \gamma^{t'-t} r_{t'} \Big| s_0 = s \right) \right\|^2_2 L V F = 2 1 V θ ( s ) − ( t ′ = t ∑ T γ t ′ − t r t ′ s 0 = s ) 2 2
これは前述の内容と完全に同じです:
V θ ( s ) V_{\theta}(s) V θ ( s ) は価値ネットワークの出力(つまり LLM の上に追加された線形層で、状態 s s s からの期待累積報酬を予測します)。
∑ γ t ′ r t ′ \sum \gamma^{t'} r_{t'} ∑ γ t ′ r t ′ という項(G s G_s G s または目標価値と呼ばれます)は、状態 s s s から始まり、現在のポリシーに従ってエピソードが終了するまで観察された実際の割引報酬の総和です。これは V θ ( s ) V_{\theta}(s) V θ ( s ) に対する経験的目標を設定します。
この損失関数は、予測値 V θ ( s ) V_{\theta}(s) V θ ( s ) と観察された目標値 G s G_s G s との間の平均二乗誤差(MSE)です。私たちは価値関数が将来の報酬を正確に予測できるようにしたいと考えています。この価値関数は、アドバンテージ A ^ t \hat{A}_t A ^ t を計算するために重要です。
3. エントロピー報酬 (L E N T R O P Y L_{ENTROPY} L ENTROP Y )#
Copy L E N T R O P Y = − ∑ x p ( x ) log p ( x ) L_{ENTROPY} = - \sum_x p(x) \log p(x) L ENTROP Y = − x ∑ p ( x ) log p ( x )
ここでの p ( x ) p(x) p ( x ) (またはより正確には π θ ( a ∣ s ) \pi_{\theta}(a|s) π θ ( a ∣ s ) 、与えられた状態 s s s の下でのすべての可能なアクション a a a に対する)を表します。これは、与えられた状態で出力されるアクションの確率分布です。
∑ x p ( x ) log p ( x ) \sum_x p(x) \log p(x) ∑ x p ( x ) log p ( x ) という項は、この確率分布のエントロピーです。エントロピーは分布のランダム性または不確実性を測定します。均一分布(非常にランダム)は高いエントロピーを持ち、尖った分布(特定のアクションに非常に確信を持つ)は低いエントロピーを持ちます。
損失項は負 のエントロピーです。私たちが総損失 L P P O L_{PPO} L PPO の中でこの L E N T R O P Y L_{ENTROPY} L ENTROP Y を最小化する場合(c 2 c_2 c 2 が正であると仮定)、実際にはポリシーのエントロピーを最大化 しています。
より高いエントロピーを奨励することで、探索を促進し、ポリシーをよりランダムにし、異なるアクション(LLM の場合は異なるトークンを試す)を試みることができ、過度に収束することを防ぎます。これは、エージェントがより良いポリシーを発見するのに役立ちます。
最終形式 L P P O L_{PPO} L PPO #
最終的な PPO 損失は、これら 3 つの部分の加重和です:
Copy L P P O = L P O L I C Y + c 1 L V F + c 2 L E N T R O P Y L_{PPO} = L_{POLICY} + c_1 L_{VF} + c_2 L_{ENTROPY} L PPO = L PO L I C Y + c 1 L V F + c 2 L ENTROP Y
c 1 L V F c_1 L_{VF} c 1 L V F : 価値関数損失で、c 1 c_1 c 1 によって加重されます。c 1 c_1 c 1 の一般的な値は約 0.5 0.5 0.5 です。
c 2 L E N T R O P Y c_2 L_{ENTROPY} c 2 L ENTROP Y : エントロピー報酬(c 2 > 0 c_2 > 0 c 2 > 0 の場合、実際には低エントロピーに対するペナルティ)、c 2 c_2 c 2 によって加重されます。c 2 c_2 c 2 は通常、小さな正の数(例えば 0.01 0.01 0.01 )であり、探索を奨励しつつ、主要なポリシー目標を圧倒しないようにします。
エージェントのパラメータ(つまり LLM の重み)は、この組み合わせ損失 L P P O L_{PPO} L PPO の勾配を計算し、勾配降下を実行することで更新されます。
Reference Model#
Reward Hacking#
RL の大きな問題の一つはリワードハッキングであり、モデルは人間にとって意味のないトークンやシーケンスを出力することで良い報酬を得ることを学習する可能性があります。例えば、「ありがとう」を十回連続して言うことで礼儀正しさのスコアを上げるなどです。したがって、整合性のあるモデル(RL 後のトレーニング)の出力は、元のモデルの出力とできるだけ近いものにしたいと考えています。
そのため、重みを固定した別のモデル(ref model)が存在し、最適化するモデルは、各軌跡の各ステップで報酬モデルを通じて報酬を生成する際に、この報酬から最適化モデルと参照モデルの log prob の間の KL 散度を引き、モデルが元のモデルと異なる答えを生成することを防ぐためのペナルティ項として使用します。これにより、上記のモデルの不正行為を防止します。
コードウォークスルー#
trl#
Copy class AutoModelForCausalLMWithValueHead ( PreTrainedModelWrapper ):
# ... (クラス属性のような transformers_parent_class) ...
このクラスの核心的な目的は、標準の因果言語モデル(Causal LM) (私たちの Actor Model 、テキストを生成するポリシー π θ π_θ π θ )と Value Head (つまり Critic Model 、状態価値 V ( s ) V(s) V ( s ) を推定する役割)を一緒に束ねることです。PPO / Actor Critic アルゴリズムでは、ポリシーと価値関数の両方が必要であり、このクラスは両者を同時に出力するための統一されたモデル構造を提供します。
Copy def __init__ (self, pretrained_model, ** kwargs):
super (). __init__ (pretrained_model, ** kwargs) # 基本設定
v_head_kwargs, _, _ = self ._split_kwargs(kwargs) # ValueHead に渡すパラメータを分離
# 言語モデル出力能力を持つモデルであることを確認
if not any ( hasattr ( self .pretrained_model, attribute) for attribute in self .lm_head_namings):
raise ValueError ( "The model does not have a language model head..." )
# 状態の価値 V(s) を予測するための ValueHead インスタンスを作成
self .v_head = ValueHead( self .pretrained_model.config, ** v_head_kwargs)
# ValueHead の重みを初期化