From RL to RLHF

This article is mainly based on Umar Jamil's course [1] and serves as my learning notes. Our goal is to align the behavior of LLMs with the outputs we expect, and RLHF is one of the best-known techniques for doing so. Its standard pipeline involves four models (this sounds memory-intensive, which is why many later methods drop some of them), but remember that a total of four are needed: the Reward, Actor, Critic, and Reference models; the model we ultimately optimize is the Actor Model.

LLM to RL#

Previously, my understanding of RL was that a policy tells you the probability of the action you should take in the current state. In this sense, a language model can itself be viewed as a policy: it receives a prompt (the state) and outputs a probability distribution over the next token (the action); sampling a token and appending it to the prompt yields a new state. It is therefore a policy with an action space of size vocab_size, and hence an RL agent.

So we still need something to provide rewards (in traditional RL this is usually a reward function built into the environment).

Creating a "Q-A-Reward" dataset can achieve this, but humans are not good at finding consensus; however, they excel at comparing advantages. Therefore, we shift our focus: the model generates multiple answers (A) under high temperature, and then we ask domain experts (who can be human or AI models) to select the chosen/preferred answer, labeling a preference dataset to train a reward model that generates numerical rewards.

Reward Model#

This RM is implemented using a pre-trained LLM like Llama.

Note

In text generation tasks, we feed the prompt through the Transformer, take the hidden state of the last token from the resulting hidden states, project it linearly onto the vocabulary to obtain logits, and then apply softmax and a sampling strategy to select the next token.

When we want to produce a numerical reward instead of text, we can replace the linear projection onto the vocabulary with a linear layer that has a single output feature (outputting a scalar), producing one score for the entire text sequence.

Screenshot 2025-04-23 at 21.19.56

Reward Model Loss#

Tip

During training, we want this model to generate high rewards for the chosen answers and low rewards for the unchosen answers.

Similar to the parameterized form of the Bradley-Terry model:

$$\text{Loss} = -\log \sigma(r(x, y_w) - r(x, y_l))$$

$y_w$ denotes the chosen answer, and $y_l$ the rejected one. Therefore, when the model gives a higher reward to the chosen answer:

$$r(x, y_w) - r(x, y_l) > 0$$

$$\sigma(r(x, y_w) - r(x, y_l)) \in (0.5, 1)$$

$$-\log \sigma(r(x, y_w) - r(x, y_l)) \in (0, 1)$$

Thus, the loss will be low, while if the model gives a low reward to the chosen answer, the loss will be very high.
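To make this concrete, here is a minimal PyTorch sketch of the pairwise loss (not the author's code); `reward_model` is assumed to map a batch of tokenized sequences to one scalar reward per sequence:

```python
import torch.nn.functional as F

def pairwise_reward_loss(reward_model, chosen_ids, rejected_ids):
    r_chosen = reward_model(chosen_ids)      # r(x, y_w), shape [batch]
    r_rejected = reward_model(rejected_ids)  # r(x, y_l), shape [batch]
    # -log sigmoid(r_w - r_l); logsigmoid is numerically more stable than log(sigmoid(.))
    return -F.logsigmoid(r_chosen - r_rejected).mean()
```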


The RewardTrainer class in HuggingFace TRL accepts an AutoModelForSequenceClassification as input (which is exactly the model structure described above).

Screenshot 2025-04-23 at 22.02.29
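A hedged sketch of how this might be wired up; the checkpoint name and the two-row dataset are placeholders, and the exact `RewardTrainer`/`RewardConfig` argument names have shifted across trl versions:

```python
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)  # scalar head
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.config.pad_token_id = tokenizer.pad_token_id  # the classification head needs a pad token for batching

# Preference pairs: each row holds a preferred ("chosen") and a dispreferred ("rejected") answer.
train_dataset = Dataset.from_dict({
    "chosen":   ["Q: capital of France? A: Paris."],
    "rejected": ["Q: capital of France? A: Chocolate."],
})

trainer = RewardTrainer(
    model=model,
    args=RewardConfig(output_dir="rm-out", per_device_train_batch_size=1),
    train_dataset=train_dataset,
    processing_class=tokenizer,  # older trl versions call this argument `tokenizer`
)
trainer.train()
```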

Actor & Critic Model#

Trajectories#

As mentioned earlier, the core objective of reinforcement learning (RL) is to find a policy ($\pi$) that guides the agent's actions to achieve the maximum possible expected return.

Mathematically, we express this as finding the optimal policy $\pi^*$ that maximizes the objective function $J(\pi)$:

$$\pi^* = \arg \max_{\pi} J(\pi)$$

The expected return $J(\pi)$ represents the average total return that the agent can accumulate over many possible lifetimes or episodes while following policy $\pi$.

The calculation method is: consider all possible trajectories ($\tau$) and weight the total return $R(\tau)$ of each trajectory by the probability $P(\tau|\pi)$ of that trajectory occurring under policy $\pi$ (then average or integrate).

$$J(\pi) = \int P(\tau|\pi) R(\tau)\, d\tau = E_{\tau \sim \pi} [R(\tau)]$$
  • $E_{\tau \sim \pi}[\cdot]$ indicates the expected value when trajectory $\tau$ is generated according to policy $\pi$.
  • $R(\tau)$ is the total return obtained on a single trajectory $\tau$.
  • $P(\tau|\pi)$ is the probability of a specific trajectory $\tau$ occurring when the agent uses policy $\pi$.

A trajectory $\tau$ is a sequence of states and actions experienced by the agent, starting from the initial state. It is one possible "story" or "path" of the agent's interaction with the environment.

$$\tau = (s_0, a_0, s_1, a_1, s_2, a_2, \dots)$$
  • $s_t$: state at time step $t$.
  • $a_t$: action taken at time step $t$ (usually based on state $s_t$ and policy $\pi$).

We typically model the environment as stochastic. This means that executing the same action $a_t$ in the same state $s_t$ does not always lead to exactly the same next state $s_{t+1}$; there is randomness involved.

The next state $s_{t+1}$ is drawn from a probability distribution conditioned on the current state $s_t$ and the action taken $a_t$:

$$s_{t+1} \sim P(\cdot \mid s_t, a_t)$$

Considering the stochastic state transitions and the agent's policy, we can calculate the probability of the entire trajectory occurring. It is obtained by multiplying the following components:

  1. The probability of the agent starting in the initial state $s_0$: $p_0(s_0)$.
  2. For each time step $t$ in the trajectory:
    • The probability of the environment transitioning to state $s_{t+1}$ given $s_t$ and $a_t$: $P(s_{t+1}|s_t, a_t)$.
    • The probability of the agent selecting action $a_t$ in state $s_t$ according to its policy: $\pi(a_t|s_t)$.

$$P(\tau|\pi) = p_0(s_0) \prod_{t=0}^{T-1} P(s_{t+1}|s_t, a_t)\, \pi(a_t|s_t)$$

(where $T$ is the length of the trajectory).

When calculating the total return $R(\tau)$ for a trajectory, we almost always use discounted rewards. This means that rewards received earlier are more valuable than those received later.

Why?

  • It reflects real-world scenarios (a dollar in hand today is worth more than a dollar promised tomorrow).
  • It avoids the problem of infinite returns in ongoing tasks (tasks without a fixed endpoint).
  • It provides mathematical convenience.

We introduce a discount factor $\gamma$, where $0 \le \gamma < 1$. The closer $\gamma$ is to 0, the more "short-sighted" the agent is (focusing more on immediate benefits); the closer $\gamma$ is to 1, the more "far-sighted" the agent is (focusing more on long-term returns).

The total discounted return for the trajectory is calculated as follows:

$$R(\tau) = \sum_{t=0}^{\infty} \gamma^t r_t$$
  • $r_t$ is the immediate reward received at time step $t$.
  • $\gamma^t$ is the discount factor applied to the reward at time step $t$.
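As a tiny numeric sketch of this sum (rewards and $\gamma$ chosen arbitrarily):

```python
def discounted_return(rewards, gamma=0.99):
    """R(tau) = sum_t gamma^t * r_t for one sampled trajectory."""
    return sum(gamma**t * r for t, r in enumerate(rewards))

print(discounted_return([1.0, 0.0, 0.5], gamma=0.9))  # 1.0 + 0.9*0.0 + 0.81*0.5 = 1.405
```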

So, what is a trajectory in an LLM? As mentioned earlier, the model is the policy, the prompt is the state, and the next token is the action; the sequence of states and actions produced during autoregressive generation constitutes the trajectory.

Screenshot 2025-04-24 at 01.32.33

Policy Gradient#

We have established the goal of reinforcement learning: to find an optimal policy $\pi^*$ that maximizes the expected return $J(\pi)$. Great. But how do we actually represent and find this policy?

Typically, especially when dealing with complex problems, we do not search over all possible policies. Instead, we define a parameterized policy, denoted $\pi_{\theta}$. You can think of $\theta$ as a set of "knobs" or parameters; if our policy is a neural network, then $\theta$ is the network's weights and biases.

Our goal now becomes: how do we adjust these knobs $\theta$ to maximize our expected return?

Note

Under the parameterized policy $\pi_{\theta}$, the expected return over all possible trajectories is:

$$J(\pi_{\theta}) = E_{\tau \sim \pi_{\theta}} [R(\tau)]$$

This means that the expected return depends on the trajectories $\tau$, and the distribution of trajectories depends on the actions chosen by our specific policy $\pi_{\theta}$. Changing $\theta$ changes the policy, which changes the trajectories, which in turn changes the expected return.

Note

We want to maximize $J(\pi_{\theta})$ by changing $\theta$. In deep learning, we generally use gradient descent to minimize a loss function; here we want to maximize a function $J$, so we use gradient ascent! It is like climbing a mountain: find the direction of steepest ascent (the gradient) and take a step in that direction.

Our policy $\pi_{\theta}$ is a neural network, and we will iteratively adjust its parameters $\theta$ to increase $J(\pi_{\theta})$. The update rule looks very familiar (just replace the minus sign of gradient descent with a plus sign):

$$\theta_{k+1} = \theta_k + \alpha \nabla_{\theta} J(\pi_{\theta})\big|_{\theta_k}$$
  • $\theta_k$: our parameters at the $k$-th iteration.
  • $\alpha$: learning rate (step size).
  • $\nabla_{\theta} J(\pi_{\theta})\big|_{\theta_k}$: the gradient of the expected return $J$ with respect to the parameters $\theta$, evaluated at the current parameters $\theta_k$. It tells us which direction in parameter space increases $J$ the most.

Screenshot 2025-04-24 at 13.51.52

Important

PG Derivation
Here I will be deliberately verbose and reintroduce all the indices; an ADHD-friendly derivation...

Step 1: restate that the object we want the gradient of is the expected return $J(\pi_{\theta})$:

$$\nabla_{\theta} J(\pi_{\theta}) = \nabla_{\theta} E_{\tau \sim \pi_{\theta}} [R(\tau)]$$

Here:

  • $J(\pi_{\theta})$ is the expected return.
  • $E_{\tau \sim \pi_{\theta}}[\cdot]$ denotes the expected value, calculated over all possible trajectories ($\tau$). A trajectory $\tau$ is a series of states and actions generated by the agent interacting with the environment: $(s_0, a_0, s_1, a_1, \dots)$.
  • $\tau \sim \pi_{\theta}$ indicates that these trajectories are generated according to our current policy $\pi_{\theta}$.
  • $R(\tau)$ is the total return obtained from a complete trajectory $\tau$ (usually the discounted return).
  • $\nabla_{\theta}$ is the gradient operator, indicating that we take partial derivatives with respect to the parameters $\theta$.

Step 2: expand the expression for the expectation.
What is the definition of expected value? For a random variable $X$, its expectation $E[X]$ can be calculated from its probability distribution $p(x)$:

  • If it is a continuous variable: $E[X] = \int p(x)\, x\, dx$
  • If it is a discrete variable: $E[X] = \sum_x p(x)\, x$

In our case, the random variable is the return of the trajectory, $R(\tau)$, and the probability distribution is the probability of the trajectory occurring, $P(\tau|\pi_{\theta})$ (the probability of trajectory $\tau$ under policy $\pi_{\theta}$). Thus, the expectation can be expressed in integral (or summation) form:

$$E_{\tau \sim \pi_{\theta}} [R(\tau)] = \int P(\tau|\pi_{\theta}) R(\tau)\, d\tau$$

(Here, the integral symbol $\int$ stands for summing or integrating over all possible trajectories, whichever is more general.)
Substituting this into the formula from Step 1:

$$\nabla_{\theta} J(\pi_{\theta}) = \nabla_{\theta} \int P(\tau|\pi_{\theta}) R(\tau)\, d\tau$$


Step 3: Move the gradient operator inside the integral

$$\nabla_{\theta} \int P(\tau|\pi_{\theta}) R(\tau)\, d\tau = \int \nabla_{\theta} [P(\tau|\pi_{\theta}) R(\tau)]\, d\tau$$

This requires some knowledge of calculus: under certain conditions (which we usually assume are satisfied in reinforcement learning), we can exchange the order of differentiation and integration, just like $\frac{d}{dx} \sum_i f_i(x) = \sum_i \frac{d}{dx} f_i(x)$.
Next, note that $R(\tau)$ is the total return once a trajectory is fixed; its value does not directly depend on the policy parameters $\theta$. (It is the policy $\pi_{\theta}$ that affects which trajectory occurs, not the return of that trajectory once it has occurred.) Therefore, the gradient $\nabla_{\theta}$ only needs to act on $P(\tau|\pi_{\theta})$:

$$= \int [\nabla_{\theta} P(\tau|\pi_{\theta})]\, R(\tau)\, d\tau$$

This step tells us that the change in expected return is due to the change in each trajectory's probability $P(\tau|\pi_{\theta})$ caused by changing the parameters $\theta$, multiplied by that trajectory's return $R(\tau)$, summed over all trajectories.


Step 4: Log-derivative trick
This is the most central and clever step in the entire derivation! We need to introduce an identity.

  • Calculus review (chain rule and log derivative): recall that the derivative of the natural logarithm $\log(x)$ (meaning $\ln(x)$) satisfies $\frac{d}{dx} \log(f(x)) = \frac{1}{f(x)} \frac{d f(x)}{dx} = \frac{f'(x)}{f(x)}$.
  • Rearranging slightly, we get: $f'(x) = f(x) \frac{d}{dx} \log(f(x))$.
    Now we apply this trick to the gradient. Let $f(x)$ correspond to $P(\tau|\pi_{\theta})$, and let the variable $x$ correspond to the parameters $\theta$. Then:

$$\nabla_{\theta} P(\tau|\pi_{\theta}) = P(\tau|\pi_{\theta})\, \nabla_{\theta} \log P(\tau|\pi_{\theta})$$

Substituting this result into the integral from Step 3:

$$\int [\nabla_{\theta} P(\tau|\pi_{\theta})]\, R(\tau)\, d\tau = \int [P(\tau|\pi_{\theta})\, \nabla_{\theta} \log P(\tau|\pi_{\theta})]\, R(\tau)\, d\tau$$


Step 5: Return to expectation form
Observe the result from Step 4:

$$\int P(\tau|\pi_{\theta})\, [\nabla_{\theta} \log P(\tau|\pi_{\theta})\, R(\tau)]\, d\tau$$

This fits the definition of an expectation: $E[f(\tau)] = \int P(\tau|\pi_{\theta}) f(\tau)\, d\tau$.
Here, $f(\tau)$ corresponds to the entire bracketed quantity $[\nabla_{\theta} \log P(\tau|\pi_{\theta})\, R(\tau)]$.
Therefore, the entire integral can be written back in expectation form:

$$= E_{\tau \sim \pi_{\theta}} [\nabla_{\theta} \log P(\tau|\pi_{\theta})\, R(\tau)]$$

Significant implication! We have transformed the gradient of an expectation, $\nabla_{\theta} E[\cdot]$, into the expectation of some quantity (a gradient multiplied by a return), $E[\nabla(\cdot) \times R]$. This form is very important because it can be approximated through sampling! We do not need to actually compute the integral over all trajectories: we just sample many trajectories $\tau$, compute the bracketed value $\nabla_{\theta} \log P(\tau|\pi_{\theta})\, R(\tau)$ for each trajectory, and then average to obtain an approximate gradient.


Step 6: Expand the gradient of log probability (Expression for grad-log-prob)
Now we need to handle the term $\nabla_{\theta} \log P(\tau|\pi_{\theta})$ inside the expectation.
Recall that the trajectory $\tau = (s_0, a_0, s_1, a_1, \dots, s_T, a_T)$ (assuming the trajectory consists of $T+1$ states and $T+1$ actions, i.e., $T$ time steps). The probability of a trajectory occurring is:

$$P(\tau|\pi_{\theta}) = p_0(s_0) \prod_{t=0}^{T} P(s_{t+1}|s_t, a_t)\, \pi_{\theta}(a_t|s_t)$$

  • $p_0(s_0)$: the probability of the initial state $s_0$.

  • $P(s_{t+1}|s_t, a_t)$: the environment dynamics, the probability of transitioning to state $s_{t+1}$ after executing action $a_t$ in state $s_t$.

  • $\pi_{\theta}(a_t|s_t)$: the policy, the probability of selecting action $a_t$ in state $s_t$ (this is the part that depends on $\theta$).

  • Mathematical review (log properties): $\log(a \times b) = \log a + \log b$ and $\log\left(\prod_{i} x_i\right) = \sum_{i} \log x_i$.
    Taking the logarithm of $P(\tau|\pi_{\theta})$:

$$\log P(\tau|\pi_{\theta}) = \log p_0(s_0) + \sum_{t=0}^{T} [\log P(s_{t+1}|s_t, a_t) + \log \pi_{\theta}(a_t|s_t)]$$

Now take the gradient $\nabla_{\theta}$ of the above expression:

  • Mathematical review (gradient properties): the sum rule $\nabla(f+g) = \nabla f + \nabla g$; the gradient $\nabla_{\theta}$ only acts on terms that depend on $\theta$.

$$\nabla_{\theta} \log P(\tau|\pi_{\theta}) = \nabla_{\theta} \log p_0(s_0) + \sum_{t=0}^{T} [\nabla_{\theta} \log P(s_{t+1}|s_t, a_t) + \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)]$$

  • Key points:
    • The initial state probability $\log p_0(s_0)$ typically does not depend on the policy parameters $\theta$, so $\nabla_{\theta} \log p_0(s_0) = 0$.
    • The environment dynamics $\log P(s_{t+1}|s_t, a_t)$ describe properties of the environment itself and also do not depend on the policy parameters $\theta$, so $\nabla_{\theta} \log P(s_{t+1}|s_t, a_t) = 0$.
    • Only the policy term $\log \pi_{\theta}(a_t|s_t)$ depends on $\theta$.
      Therefore, the expression simplifies to:

$$\nabla_{\theta} \log P(\tau|\pi_{\theta}) = \sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)$$

The gradient of the log probability of the entire trajectory equals the sum of the log-probability gradients of each action in that trajectory! This greatly simplifies the computation.


Step 7: The final policy gradient theorem
Substitute the simplified result from Step 6 back into the expectation formula from Step 5:

$$\nabla_{\theta} J(\pi_{\theta}) = E_{\tau \sim \pi_{\theta}} \left[\left(\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)\right) R(\tau)\right]$$

This is the final form of the Policy Gradient Theorem (or one of its common forms).
The gradient of the expected return $J(\pi_{\theta})$ with respect to the parameters $\theta$ equals: "sample a trajectory $\tau$, compute its total return $R(\tau)$, multiply it by the sum of the log-probability gradients of the policy over all (state, action) pairs in that trajectory, $\sum_t \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)$, and then take the expectation (average) over all possible trajectories."

Clearly, enumerating all trajectories is prohibitively expensive (for example, we would need to sample every possible generation of max_token_length=100), so we approximate the expectation with the sample mean:

Note

Monte Carlo approximation:
  • Run the current policy $\pi_{\theta}$ and collect $N$ trajectories, forming the dataset $D = \{\tau_1, ..., \tau_N\}$ (let $N = |D|$).

  • Use the average of these samples to approximate the expected value:

$$\nabla_{\theta} J(\pi_{\theta}) \approx \hat{g} = \frac{1}{|D|} \sum_{\tau \in D} \left(\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_t|s_t)\right) R(\tau)$$
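In an autodiff framework we usually do not assemble $\hat{g}$ by hand; instead we build a surrogate loss whose gradient is $-\hat{g}$ and let backpropagation do the work. A minimal PyTorch sketch (not from the course), assuming `trajectories` is a list of `(log_probs, R)` pairs already collected with the current policy:

```python
import torch

def reinforce_loss(trajectories):
    """Surrogate loss whose gradient is -g_hat, so minimizing it performs gradient ascent on J."""
    losses = []
    for log_probs, R in trajectories:        # log_probs: tensor of log pi(a_t|s_t); R: scalar return R(tau)
        losses.append(-log_probs.sum() * R)  # -(sum_t log pi(a_t|s_t)) * R(tau)
    return torch.stack(losses).mean()        # average over the |D| sampled trajectories
```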

Application to LM Policy#

By obtaining the log probabilities of each state-action pair in the sampled trajectory through the generation process shown in the figure, we can now backpropagate to calculate the gradient.

Screenshot 2025-04-24 at 15.17.06

Then, multiply each gradient by the reward from the RM and input it into the expression to perform gradient ascent optimization:

Screenshot 2025-04-24 at 15.23.14
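A sketch of how these per-token log probabilities can be obtained from a HuggingFace-style causal LM; `model`, `full_ids`, and `prompt_len` are assumed inputs (a loaded model, the concatenated prompt+response token ids, and the number of prompt tokens):

```python
import torch.nn.functional as F

def response_log_probs(model, full_ids, prompt_len):
    """Log-probability of each generated (response) token, given everything before it."""
    logits = model(full_ids).logits                       # [batch, seq_len, vocab]
    log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)  # position t predicts token t+1
    targets = full_ids[:, 1:]                             # the tokens actually produced
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return token_log_probs[:, prompt_len - 1:]            # keep only the response (action) tokens
```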

High Variance#

PG algorithms perform well on small problems, but there are some issues when used for language modeling.

Note

The central limit theorem tells us that, with a large enough sample size, the sample mean approximately follows a normal distribution, which makes estimates predictable. When the sample size is small, however, the sample mean fluctuates a lot, and a single estimate can be far off. Since sampling many trajectories from the LM is very expensive, we are stuck with small samples, which leads to high variance in the gradient estimates.

How can we reduce variance without increasing the sample size?

  1. Remove historical rewards: reward-to-go
    The current action cannot affect rewards already obtained in the past, and those past rewards only add unnecessary noise (this relates to the credit assignment problem in RL). Removing the past terms avoids that noise and brings the estimated gradient closer to the true gradient. So, instead of summing rewards from the very start, we only count the rewards of actions from the current time step onward (see the sketch after this list).

Screenshot 2025-04-24 at 15.48.52

  2. Introduce a baseline
    RL research has confirmed that subtracting a term that depends only on the state (it can even be a constant) reduces variance without biasing the gradient. Here we choose the value function $V^\pi(s)$.
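Here is the sketch referenced in item 1: a backward pass that turns per-step rewards into rewards-to-go; the baseline from item 2 then becomes a simple subtraction of $V(s_t)$.

```python
def rewards_to_go(rewards, gamma=0.99):
    """R_t = r_t + gamma*r_{t+1} + ... : each action is only credited with what comes after it."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

print(rewards_to_go([0.0, 0.0, 1.0], gamma=0.9))  # [0.81, 0.9, 1.0]
```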

Value Function#

$V^\pi(s)$ tells you the expected reward of the remaining trajectory when acting according to the current policy.

Examples of value definitions in classic RL scenarios and LM scenarios:

Screenshot 2025-04-24 at 16.01.05

In practice, we initialize from the LM we are trying to optimize and add a linear layer on top to estimate the value, so that the Transformer layers are shared between language modeling (projecting tokens onto the vocabulary) and value estimation.

Screenshot 2025-04-24 at 16.45.48

The reward-to-go mentioned earlier is called the Q function in RL: the expected return obtained by taking this action from the current state and then completing the subsequent actions according to the policy:

Screenshot 2025-04-24 at 16.54.58

By introducing the value function, we obtain the difference between Q and V, which is referred to as the Advantage function.

Screenshot 2025-04-24 at 16.56.48

This advantage term $A^\pi(s_t, a_t)$ indicates how much better this specific action is compared to the average action that can be taken in state $s$.

Screenshot 2025-04-24 at 17.04.33

In the state pointed to by the red arrow in the figure, the advantage function for moving down will be higher than that of other actions.

$$\nabla_{\theta} J(\theta) \approx \frac{1}{N}\sum_{i=1}^{N}\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{i,t} \mid s_{i,t})\, A^{\pi}(s_{i,t}, a_{i,t})$$

Multiplying the gradient by the advantage function increases the log probability of actions with high advantage and decreases the log probability of actions that yield below-average returns.

$$\begin{aligned} \hat{A}^\pi(s_t, a_t) &= Q^\pi(s_t, a_t) - V^\pi(s_t) \\ &= [r(s_t, a_t) + \gamma V^\pi(s_{t+1})] - V^\pi(s_t) \end{aligned}$$

Note

In traditional reinforcement learning methods, the Q network and the V network are usually independent: the Q function estimates the expected total return of executing action $a$ in state $s$, while the V function simply estimates the value of state $s$. This requires two different neural networks to compute the two values separately.

However, we have now introduced the advantage function $A_{\theta}(s, a)$, which is the difference between the Q value and the V value:

$$A_{\theta}(s, a) = Q_{\theta}(s, a) - V_{\theta}(s)$$

By expressing $A_{\theta}(s, a)$ as the difference between $Q_{\theta}(s, a)$ and $V_{\theta}(s)$, we find that we only need to train one network to output $V_{\theta}(s)$; the Q value can then be computed from the reward $r_t$ and the discount factor $\gamma$.

Thus, only one neural network is needed, which primarily predicts $V_{\theta}(s)$. The Q value is computed as:

$$Q_{\theta}(s_t, a) = r_t + \gamma V_{\theta}(s_{t+1})$$

The advantage function is then:

$$A_{\theta}(s_t, a) = r_t + \gamma V_{\theta}(s_{t+1}) - V_{\theta}(s_t)$$

Advantage Sampling#

Short-step advantage estimators have high bias but low variance, while long-step advantage estimators have low bias but high variance. This trade-off is a part of reinforcement learning that requires careful selection and adjustment, depending on the stability requirements of the model and training efficiency.
Screenshot 2025-05-08 at 21.53.04
An example: "A short-term memory person only remembers what happened yesterday; although not comprehensive, it is very stable; a long-term memory person can see the whole picture for the next few days but may be disturbed by more unknown factors."

GAE#

To address this bias-variance problem, we can use GAE (Generalized Advantage Estimation), which is essentially a weighted sum of the advantage estimates at all horizons, each multiplied by a decay factor.

Note

Now let's talk about TD error.
Online learning has a wonderful aspect: you do not need to wait until the end of an episode to update the policy. This is where the Temporal Difference error (TD error) comes into play:

$$\delta = r + \gamma V(s') - V(s)$$

The key here is that the TD error is actually an online estimate of the advantage function. It tells you, at this moment, whether your action makes the future state better than you expected. The error $\delta$ directly reflects the concept of advantage:

  • If $\delta > 0$: "Hey, this action is better than I imagined!" (the advantage is positive).
  • If $\delta < 0$: "Hmm, I thought it would be better..." (the advantage is negative).

This allows you to gradually adjust your policy without waiting for an entire episode to end. This is an excellent strategy for improving efficiency.

The purpose of GAE is to provide a better estimate of the advantage function $A^\pi(s,a)$ for policy gradient algorithms than the raw return $R(\tau)$ or the simple TD error $\delta_t$, in order to reduce the variance of gradient estimates and improve learning stability and efficiency.

$$\begin{aligned} \delta_t &= r_t + \gamma V^\pi(s_{t+1}) - V^\pi(s_t) \\ \hat{A}_t &= \delta_t + \gamma \lambda \hat{A}_{t+1} \end{aligned}$$

This formula defines the generalized advantage estimate $\hat{A}_t$ recursively. It does not only look at the one-step TD error $\delta_t$ but integrates TD-error information from multiple future steps.

This recursion is computed backwards from the end of the trajectory (episode), with $T$ the last step and $\hat{A}_{T+1} = 0$:
* $\hat{A}_T = \delta_T + 0 = \delta_T$
* $\hat{A}_{T-1} = \delta_{T-1} + \gamma \lambda \hat{A}_T = \delta_{T-1} + (\gamma \lambda) \delta_T$
* $\hat{A}_{T-2} = \delta_{T-2} + \gamma \lambda \hat{A}_{T-1} = \delta_{T-2} + (\gamma \lambda) \delta_{T-1} + (\gamma \lambda)^2 \delta_T$
* ...
* General form: $\hat{A}_t = \sum_{k=0}^{\infty} (\gamma \lambda)^k \delta_{t+k}$ (assuming an infinite horizon, or that $\delta$ is 0 after the terminal state).

The parameter $\lambda$ ($0 \le \lambda \le 1$) is the key to GAE; it controls the bias and variance of the estimate $\hat{A}_t$:

  • When $\lambda = 0$:
    • $\hat{A}_t = \delta_t$: GAE degenerates into the simple one-step TD error. This estimate has lower variance (it relies only on the next step's information) but may have higher bias (it depends heavily on the potentially inaccurate estimate of $V^{\pi}(s_{t+1})$).
  • When $\lambda = 1$:
    • $\hat{A}_t = \sum_{k=0}^{\infty} \gamma^k \delta_{t+k}$. It can be shown that this equals $\hat{A}_t = \left(\sum_{k=0}^{\infty} \gamma^k r_{t+k}\right) - V^{\pi}(s_t)$, i.e., the Monte Carlo (MC) return minus the baseline. This estimate has lower bias (it uses the complete actual return starting from time $t$) but usually high variance (it accumulates randomness over many time steps).
  • When $0 < \lambda < 1$:
    • GAE interpolates between the two extremes above. The closer $\lambda$ is to 0, the more it leans towards the low-variance, high-bias TD estimate; the closer $\lambda$ is to 1, the more it leans towards the high-variance, low-bias MC estimate.
    • By choosing an appropriate $\lambda$ (e.g., 0.97), GAE attempts to strike a good balance between bias and variance, yielding an advantage estimate that is both relatively accurate (controlled bias) and relatively stable (lower variance).
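The recursion above translates directly into a single backward loop over the trajectory. A minimal sketch with plain Python lists; `values` is assumed to hold $V(s_0), \dots, V(s_T)$ plus a bootstrap value for the state after the last step (0 if terminal):

```python
def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation for one trajectory; len(values) == len(rewards) + 1."""
    T = len(rewards)
    advantages = [0.0] * T
    gae = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error delta_t
        gae = delta + gamma * lam * gae                         # A_t = delta_t + gamma*lambda*A_{t+1}
        advantages[t] = gae
    return advantages
```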

Advantage in Language Models#

As shown in the figure, the goal is to increase the log probability of the token "Shanghai" in the current state while decreasing the log probability of "chocolate," as the advantage of choosing "Shanghai" is higher than that of choosing "chocolate" (a random token).

Screenshot 2025-04-24 at 23.13.57

Importance Sampling and Offline Learning#

In many cases, we may want to compute $E_{x \sim p(x)}[f(x)]$, but:

  1. It is difficult or impossible to sample $x$ directly from the target distribution $p(x)$.
  2. Or sampling from $p(x)$ is inefficient, which is our situation: sampling from the LM is too costly.

However, we may easily sample from another alternative, or proposal, distribution $q(x)$.

Importance Sampling (IS) is a technique for estimating the expectation under a target distribution by sampling from a different distribution, transforming an expectation under the probability distribution $p(x)$ into the expectation of a related function under another distribution $q(x)$:

$$\begin{aligned} E_{x \sim p(x)}[f(x)] &= \int p(x) f(x) \, dx \\ &= \int \frac{q(x)}{q(x)} p(x) f(x) \, dx \quad \text{(assuming } q(x) \neq 0 \text{)} \\ &= \int q(x) \frac{p(x)}{q(x)} f(x) \, dx \\ &= E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right] \end{aligned}$$

The key here is the introduction of the importance weight $w(x) = \frac{p(x)}{q(x)}$. This weight corrects the bias: for a sample $x_i$ drawn from $q(x)$, if its probability under the target distribution $p(x)$ is higher ($p(x_i) > q(x_i)$), it receives a weight greater than 1; conversely, if its probability under $p(x)$ is lower ($p(x_i) < q(x_i)$), it receives a weight less than 1. By averaging with these weights, we obtain a (usually unbiased or consistent) estimate of the original expectation $E_{x \sim p(x)}[f(x)]$.

Importance sampling allows us to:

  1. Draw samples $x_1, x_2, ..., x_N$ from an easy-to-sample distribution $q(x)$.
  2. Estimate the original expected value by computing the weighted average:
$$E_{x \sim p(x)}[f(x)] \approx \frac{1}{N} \sum_{i=1}^{N} \frac{p(x_i)}{q(x_i)} f(x_i)$$
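A tiny self-contained numeric sketch of this estimator, with an arbitrarily chosen target $p = \mathcal{N}(0,1)$ and proposal $q = \mathcal{N}(1,1)$:

```python
import math
import random

def normal_pdf(x, mu, sigma=1.0):
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

# Estimate E_{x~p}[x^2] (true value 1.0) using only samples from q, reweighted by p(x)/q(x).
random.seed(0)
xs = [random.gauss(1.0, 1.0) for _ in range(100_000)]        # x_i ~ q
ws = [normal_pdf(x, 0.0) / normal_pdf(x, 1.0) for x in xs]   # importance weights w(x_i)
print(sum(w * x**2 for w, x in zip(ws, xs)) / len(xs))       # approximately 1.0
```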

Returning to our scenario: previously, we obtained the on-policy policy gradient estimate:

$$\nabla_{\theta} J(\theta) \approx \frac{1}{N}\sum_{i=1}^{N}\sum_{t=0}^{T} \nabla_{\theta} \log \pi_{\theta}(a_{i,t} \mid s_{i,t})\, A^{\pi}(s_{i,t}, a_{i,t})$$

Note

On-policy means that the policy used to collect data is the same as the one being trained. Since the computation requires trajectories generated from the current policy $\pi_{\theta}$, old data cannot be directly reused after each policy update, which leads to low sample efficiency. As for mini_batch_num > 1: is the data used for those updates still on-policy? Strictly speaking it probably isn't, so it could be understood as semi-on-policy (the terminology here may not be rigorous).

On-policy emphasizes whether the current policy model can interact with the environment. [2]

We hope to utilize data generated by the old policy $\pi_{\theta_{OFFLINE}}$ to estimate the gradient of the current new policy $\pi_{\theta_{ONLINE}}$. This way, we can reuse data and improve sample efficiency.

Recall the principle of importance sampling (IS): $E_{x \sim p(x)}[f(x)] = E_{x \sim q(x)}\left[\frac{p(x)}{q(x)} f(x)\right]$.
Mapping this onto our PG (simplifying to single-step decisions):
* The target distribution $p(x)$ corresponds to the new policy $\pi_{\theta_{ONLINE}}(a|s)$.
* The sampling distribution $q(x)$ corresponds to the old policy $\pi_{\theta_{OFFLINE}}(a|s)$.
* The importance weight is $w_t = \frac{\pi_{\theta_{ONLINE}}(a_t|s_t)}{\pi_{\theta_{OFFLINE}}(a_t|s_t)}$.

Applying the importance weight to each term of the on-policy gradient (at each time step $t$), we obtain the standard off-policy estimate:

$$\nabla_{\theta_{ONLINE}} J(\theta_{ONLINE},\theta_{OFFLINE}) \approx \frac{1}{N}\sum_{i=1}^{N}\sum_{t=0}^{T} \left[ \frac{\pi_{\theta_{ONLINE}}(a_{i,t}|s_{i,t})}{\pi_{\theta_{OFFLINE}}(a_{i,t}|s_{i,t})} \nabla_{\theta_{ONLINE}} \log \pi_{\theta_{ONLINE}}(a_{i,t} \mid s_{i,t})\, A^{\pi}(s_{i,t}, a_{i,t}) \right]$$

Now we have a way to perform many gradient-ascent updates without sampling from the policy being optimized (the model being trained) at every step: sample once, save the trajectories to memory or a database, optimize the policy over mini-batches, and then re-initialize the offline (sampling) policy with the new policy.

PPO Loss#

The PPO loss mainly consists of three parts: the policy loss ($L_{POLICY}$), the value function loss ($L_{VF}$), and the entropy reward ($L_{ENTROPY}$).

1. Policy Loss ($L_{POLICY}$)#

Clipped Surrogate Objective

$$L_{POLICY} = \min \left( \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \hat{A}_t,\ \text{clip} \left( \frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)},\, 1-\epsilon,\, 1+\epsilon \right) \hat{A}_t \right)$$

This is the core of PPO. You will notice that it resembles the off-policy policy gradient objective derived using importance sampling, but with a key modification.

  • $\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$: the importance sampling ratio, denoted $r_t(\theta)$. It is the probability of taking action $a_t$ in state $s_t$ under the current (online) policy $\pi_{\theta}$, divided by the probability of taking that action under the old (offline) policy $\pi_{\theta_{old}}$ that collected the trajectory data. This ratio corrects for the fact that the data comes from a policy slightly different from the one we are currently improving.

  • $\hat{A}_t$: the estimated advantage function, computed with GAE to balance bias and variance. It tells us how much better or worse taking action $a_t$ in state $s_t$ is compared to the average action in that state (as judged by the current value function).

  • The clip function: this is the key point of PPO.
    $\text{clip}\left( r_t(\theta),\, 1-\epsilon,\, 1+\epsilon \right)$
    It essentially says: if the probability ratio $r_t(\theta)$ deviates too far from 1 (either too high or too low), we "clip" it. So if $r_t(\theta)$ tries to become 1.5 and $\epsilon$ is 0.2, it is clipped to 1.2; if it tries to become 0.5, it is clipped to 0.8.
    The parameter $\epsilon$ (epsilon) is a small hyperparameter (e.g., 0.1 or 0.2) that defines the clipping range $[1-\epsilon, 1+\epsilon]$.

  • The min function: the objective takes the smaller of the two terms below:

    1. Unclipped objective: $r_t(\theta) \hat{A}_t$
    2. Clipped objective: $\text{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\, \hat{A}_t$

    Why do it this way? The goal of policy gradients is to increase the probability of actions with positive advantage and decrease the probability of actions with negative advantage.
    However, with importance sampling, a very large $r_t(\theta)$ can cause huge, unstable updates. PPO keeps the new policy close to the old one by clipping this ratio.

    • If $\hat{A}_t > 0$ (good action): we want to increase $\pi_{\theta}(a_t|s_t)$. The min means that once $r_t(\theta)$ grows beyond $1+\epsilon$, the objective is capped at $(1+\epsilon)\hat{A}_t$. This prevents the policy from changing too much in a single update, even if the unclipped objective suggests a larger increment.
    • If $\hat{A}_t < 0$ (bad action): we want to decrease $\pi_{\theta}(a_t|s_t)$. Once $r_t(\theta)$ has already shrunk below $1-\epsilon$, the min selects the clipped term $(1-\epsilon)\hat{A}_t$, whose gradient with respect to $\theta$ is zero, so there is no incentive to push that action's probability down any further.
      More precisely, when $\hat{A}_t < 0$ the min keeps the more pessimistic (more negative) of the two values, so once the ratio leaves the $[1-\epsilon, 1+\epsilon]$ interval in the direction the advantage favors, the update gains nothing from moving further: the probability of that action is not reduced too aggressively in a single step.
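A minimal PyTorch sketch of this clipped surrogate, written as a loss to be minimized (the negated objective); all tensors are assumed to be per-token values gathered from the rollout:

```python
import torch

def ppo_policy_loss(log_probs_new, log_probs_old, advantages, eps=0.2):
    """Clipped surrogate objective, negated so that gradient descent maximizes it."""
    ratio = torch.exp(log_probs_new - log_probs_old)                 # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages
    return -torch.min(unclipped, clipped).mean()
```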

2. Value Function Loss ($L_{VF}$)#

$$L_{VF} = \frac{1}{2} \left\| V_{\theta}(s) - \left( \sum_{t'=t}^{T} \gamma^{t'-t} r_{t'} \,\Big|\, s_0 = s \right) \right\|^2_2$$

This is the value-function regression we have already described:

  • $V_{\theta}(s)$ is the output of the value network (i.e., a linear layer added on top of the LLM to predict the expected cumulative reward starting from state $s$).
  • $\sum \gamma^{t'} r_{t'}$ (referred to as $G_s$, the target value) is the actual total discounted reward observed starting from state $s$ and following the current policy until the episode ends. This is the empirical target we set for $V_{\theta}(s)$.
  • This loss function is the mean squared error (MSE) between the predicted value $V_{\theta}(s)$ and the observed target value $G_s$. We want the value function to accurately predict future rewards; this value function is crucial for calculating the advantage $\hat{A}_t$.

3. Entropy Reward ($L_{ENTROPY}$)#

$$L_{ENTROPY} = \sum_x p(x) \log p(x) = -H(p)$$
  • Here, $p(x)$ (more precisely $\pi_{\theta}(a|s)$, the probability the current policy assigns to each possible action $a$ in state $s$) is the action probability distribution output by the current policy in the given state.
  • The entropy of this distribution is $H(p) = -\sum_x p(x) \log p(x)$. Entropy measures the randomness or uncertainty of the distribution: a uniform distribution (very random) has high entropy, while a peaked distribution (very certain about one specific action) has low entropy.
  • The loss term is the negative entropy. When we minimize this $L_{ENTROPY}$ within the total loss $L_{PPO}$ (assuming $c_2$ is positive), we are actually maximizing the entropy of the policy.

Encouraging higher entropy promotes exploration, making the policy a bit more random, trying different actions (in the case of LLM, trying different tokens), rather than quickly converging to a potentially suboptimal deterministic policy. This helps the agent discover better strategies.

Final Form $L_{PPO}$#

The final PPO loss is the weighted sum of these three parts:

$$L_{PPO} = L_{POLICY} + c_1 L_{VF} + c_2 L_{ENTROPY}$$
  • $c_1 L_{VF}$: the value function loss, weighted by $c_1$. A common value for $c_1$ is around $0.5$.
  • $c_2 L_{ENTROPY}$: the entropy reward (with $c_2 > 0$ it is effectively a penalty on low entropy), weighted by $c_2$. $c_2$ is usually a small positive number (e.g., $0.01$), enough to encourage exploration without overwhelming the main policy objective.

The agent's parameters (i.e., the weights of the LLM) are updated by computing the gradient of this combined loss $L_{PPO}$ and performing gradient descent.
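Putting the three parts together under the sign conventions above (the policy term is negated so everything can be minimized), here is a sketch of the combined loss; it reuses `ppo_policy_loss` from the earlier sketch, and `values`, `returns`, and `logits` are assumed to come from the value head and the rollout:

```python
import torch
import torch.nn.functional as F

def ppo_total_loss(log_probs_new, log_probs_old, advantages,
                   values, returns, logits, c1=0.5, c2=0.01, eps=0.2):
    policy_loss = ppo_policy_loss(log_probs_new, log_probs_old, advantages, eps)  # negated L_POLICY
    value_loss = 0.5 * F.mse_loss(values, returns)                                # L_VF
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()               # H(pi_theta(.|s))
    return policy_loss + c1 * value_loss - c2 * entropy  # -c2*H == +c2*L_ENTROPY (negative entropy)
```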

Reference Model#

Reward Hacking#

A major issue in RL is reward hacking: the model may learn to always output tokens or sequences that earn high rewards but are meaningless to humans, such as repeatedly saying "thank you" to boost a politeness score. We therefore want the outputs of the aligned model (after RL post-training) to stay as close as possible to the original model's outputs.

To achieve this, another model with frozen weights (the reference model) is kept. When the reward model produces a reward at each step of the trajectory, that reward is penalized by the KL divergence between the reference model's and the optimized model's log probabilities. This prevents the model from generating answers that drift too far from the original model, avoiding the cheating behavior described above.

Screenshot 2025-05-08 at 00.43.14
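A common way to implement this in RLHF pipelines is to fold the penalty into the per-token rewards: every token is penalized by the per-token log-probability difference (a sample-based KL estimate), and the scalar reward-model score is added to the last token. A hedged sketch under that assumption; `rm_score` is the scalar from the RM and the other inputs are per-token tensors for one response:

```python
def kl_penalized_rewards(rm_score, logprobs_actor, logprobs_ref, beta=0.1):
    """Per-token rewards: KL penalty at every token, RM score added to the final token."""
    kl = logprobs_actor - logprobs_ref    # per-token log-ratio between actor and frozen reference
    rewards = -beta * kl                  # penalize drifting away from the reference model
    rewards[-1] = rewards[-1] + rm_score  # the scalar reward lands on the last generated token
    return rewards
```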

Code Walkthrough#

trl#

class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    # ... (class attributes like transformers_parent_class) ...

The core purpose of this class is to bundle a standard causal language model (our Actor model, the text-generating policy $\pi_\theta$) with a Value Head (the Critic model, which estimates the state value $V(s)$). PPO / actor-critic algorithms need both the policy and the value function at the same time, and this class provides a unified model structure that outputs both.

    def __init__(self, pretrained_model, **kwargs):
        super().__init__(pretrained_model, **kwargs) # Basic setup
        v_head_kwargs, _, _ = self._split_kwargs(kwargs) # Separate parameters for ValueHead

        # Ensure the passed model has language model output capability
        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
            raise ValueError("The model does not have a language model head...")

        # Create an instance of ValueHead, which will learn to predict the value of state V(s)
        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        # Initialize the weights of ValueHead
        self._init_weights(**v_head_kwargs) # Default random initialization, can also specify normal distribution initialization
  1. Acting as Actor: this is our language model pretrained_model, which generates responses (actions $a$, i.e., a series of tokens) based on the current prompt (state $s$).
  2. Critic: evaluates how "good" the Actor's situation is in a given state $s$, outputting $V(s)$. This is the job of the linear layer self.v_head.
    def forward(
        self,
        input_ids=None, # Input token IDs (state s)
        attention_mask=None,
        past_key_values=None, # Used to accelerate generation
        **kwargs,
    ):
        # Force the underlying model to output hidden_states, which ValueHead needs as input
        kwargs["output_hidden_states"] = True
        # ... (handle some details of past_key_values and PEFT, can be ignored for core understanding of PPO)

        # 1. Actor (base language model) performs computation
        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        # 2. Extract Actor's output (for policy update) and Critic's input
        lm_logits = base_model_output.logits # Actor's output: predicted probability distribution of the next token
        # This forms the basis for calculating L_POLICY and L_ENTROPY in PPO.

        last_hidden_state = base_model_output.hidden_states[-1] # Critic's input: the last hidden state of the LM,
        # representing the representation of the current state s.

        # (Optional) The loss of the language model itself, usually not directly used in the RL phase
        loss = base_model_output.loss

        # (Ensure data and model are on the same device)
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

        # 3. Critic (ValueHead) performs computation
        # ValueHead receives the state representation and outputs the value estimate for that state V(s)
        value = self.v_head(last_hidden_state).squeeze(-1) # This forms the basis for calculating the value loss L_VF and advantage A_hat in PPO.

        # (Ensure logits are float32 for numerical stability)
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        # Return Actor's logits, LM loss (which may be None), and Critic's value
        return (lm_logits, loss, value)

For each step of PPO-RLHF training:

  1. We input the current batch of prompts (sequence input_ids) into the model.
  2. self.pretrained_model (the Actor) computes (rolls out) lm_logits. These logits represent the probability distribution over the next tokens the model believes should be generated given the current prompt. Both the policy loss $L_{POLICY}$ and the entropy reward $L_{ENTROPY}$ in PPO are calculated from this distribution $\pi_{\theta}(a_t \mid s_t)$.
  3. Simultaneously, we extract last_hidden_state from base_model_output. This can be seen as a vector representation of the current prompt (state $s$).
  4. This last_hidden_state is fed into self.v_head (the Critic), which outputs a scalar value. This value is the model's estimate of the value of the current state $s$, $V_{\theta}(s)$. The value function loss $L_{VF}$ in PPO optimizes $V_{\theta}(s)$ to be as close as possible to the true return, and $V_{\theta}(s)$ is a key component in computing the advantage $\hat{A}_t$, which in turn guides $L_{POLICY}$.
  5. The same prompt + response sequence is also fed to the Reward and Reference models for inference, yielding rewards and log probabilities (for computing the KL penalty).

Thus, a single forward call provides us with the core information needed to update both the Actor (policy) and the Critic (value function).
The training process can be understood with the help of the following diagram:

rlhf-pipeline

Tip

In RLHF, only the Actor performs Prefill + Decode (full autoregressive generation) during experience collection (rollout); the other models only process existing responses to obtain log probabilities and values, so they run Prefill only.

Additionally, the Actor is involved in both training and inference (i.e., rollout), so it needs both a training engine (such as Megatron, DeepSpeed, or FSDP) and a rollout engine (such as SGLang or vLLM) to handle the respective tasks; the Critic reuses the internal representations from the training forward pass to output new value predictions, so it runs inside the same training engine; the Reference and Reward models only need inference engines to produce log probabilities and rewards. [3]

verl#

Like OpenRLHF, verl is an excellent RLHF framework; a good introductory read is 【AI Infra】VeRL Framework Introduction & Code Walkthrough.

Reference#
