This article is based on Umar Jamil's course and serves as my study notes. Our goal is to align the behavior of LLMs with the outputs we expect, and RLHF is one of the best-known techniques for doing so. Its standard pipeline involves four models (which sounds memory-intensive, which is why many later methods drop some of them): the Reward, Actor, Critic, and Reference models. The model we ultimately optimize is the Actor.
LLM to RL#
Previously, my understanding of RL was that a policy tells you the probability of each action you could take in the current state. In this sense, the language model itself can be viewed as a policy: it receives a prompt (state) and outputs a probability distribution over the next token (action); sampling yields a new state (the token appended to the prompt). It is therefore a policy with an action space of size `vocab_size`, i.e., an RL agent.
So we still need something to provide rewards (in traditional RL this is usually a reward function built into the environment).
Building a "Q-A-Reward" dataset could achieve this, but humans are bad at agreeing on absolute scores; they are much better at comparing alternatives. So we shift the approach: the model generates multiple answers (A) under high temperature, and domain experts (human or AI) pick the chosen/preferred answer. This yields a preference dataset used to train a reward model that outputs numerical rewards.
Reward Model#
This RM is implemented using a pre-trained LLM like Llama.
Note
In text generation, the Transformer produces hidden states for the input prompt; we take the hidden state of the last token, project it linearly onto the vocabulary to obtain logits, and then apply softmax and a sampling strategy to select the next token.
When we want to produce a numerical reward instead of text, we can replace the linear projection onto the vocabulary with a linear layer that has a single output feature, producing one scalar score for the entire text sequence.
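To make this concrete, here is a minimal sketch of such a scalar-head model (not the actual HuggingFace implementation; the class name, pooling choice, and backbone interface are my assumptions):

```python
import torch
import torch.nn as nn

class RewardModel(nn.Module):
    """Sketch: a Transformer backbone with a 1-output linear head instead of an LM head."""
    def __init__(self, backbone: nn.Module, hidden_size: int):
        super().__init__()
        self.backbone = backbone                      # assumed: HF-style causal LM returning hidden states
        self.score_head = nn.Linear(hidden_size, 1)   # single output feature -> scalar reward

    def forward(self, input_ids, attention_mask):
        out = self.backbone(input_ids=input_ids,
                            attention_mask=attention_mask,
                            output_hidden_states=True)
        last_hidden = out.hidden_states[-1]           # (batch, seq_len, hidden)
        # pool with the hidden state of the last non-padding token of each sequence
        last_idx = attention_mask.sum(dim=1) - 1      # (batch,)
        pooled = last_hidden[torch.arange(last_hidden.size(0)), last_idx]
        return self.score_head(pooled).squeeze(-1)    # (batch,) one scalar reward per sequence
```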
Reward Model Loss#
Tip
During training, we want this model to generate high rewards for the chosen answers and low rewards for the unchosen answers.
Similar to the parameterized form of the Bradley-Terry model, the loss is:

$$\mathcal{L}_{RM} = -\log \sigma\big(r_\theta(x, y_c) - r_\theta(x, y_r)\big)$$

Here $y_c$ represents the chosen answer and $y_r$ the rejected one. When the model gives the chosen answer a higher reward than the rejected one, $\sigma(r_\theta(x, y_c) - r_\theta(x, y_r))$ approaches 1 and the loss is low; if the model gives the chosen answer a lower reward, the loss becomes very high.
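A minimal sketch of this pairwise loss, assuming we already have the scalar rewards for the chosen and rejected answers (function and tensor names are mine):

```python
import torch
import torch.nn.functional as F

def reward_pairwise_loss(r_chosen: torch.Tensor, r_rejected: torch.Tensor) -> torch.Tensor:
    # -log sigmoid(r_chosen - r_rejected): small when the chosen answer scores higher,
    # large when the model prefers the rejected answer.
    return -F.logsigmoid(r_chosen - r_rejected).mean()

# toy usage
print(reward_pairwise_loss(torch.tensor([2.0, 0.1]), torch.tensor([0.5, 1.5])))
```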
The `RewardTrainer` class in HuggingFace's trl library accepts an `AutoModelForSequenceClassification` model (which is exactly the structure described above).
Actor & Critic Model#
Trajectories#
As mentioned earlier, the core objective of reinforcement learning (RL) is to find a policy ($\pi$) that guides the agent's actions so as to achieve the maximum possible expected return.
Mathematically, we express this as finding the optimal policy $\pi^*$ that maximizes the objective function $J(\pi)$:

$$\pi^* = \arg\max_\pi J(\pi)$$

The expected return $J(\pi)$ is the average total return the agent accumulates over many possible lifetimes or episodes while following policy $\pi$.
It is computed by considering all possible trajectories ($\tau$) and weighting the total return of each trajectory by the probability of that trajectory occurring under policy $\pi$ (averaging, or integrating):

$$J(\pi) = \mathbb{E}_{\tau \sim \pi}\big[R(\tau)\big] = \int_\tau P(\tau|\pi)\, R(\tau)$$

- $\mathbb{E}_{\tau \sim \pi}[\cdot]$ denotes the expected value when trajectory $\tau$ is generated according to policy $\pi$.
- $R(\tau)$ is the total return obtained on a single trajectory $\tau$.
- $P(\tau|\pi)$ is the probability of a specific trajectory $\tau$ occurring when the agent uses policy $\pi$.
A trajectory $\tau = (s_0, a_0, s_1, a_1, \ldots)$ is the sequence of states and actions experienced by the agent, starting from the initial state. It is one possible "story" or "path" of the agent's interaction with the environment.

- $s_t$: the state at time step $t$.
- $a_t$: the action taken at time step $t$ (usually based on state $s_t$ and policy $\pi$).
We typically model the environment as stochastic: executing the same action in the same state does not always lead to exactly the same next state.
The next state $s_{t+1}$ is drawn from a probability distribution conditioned on the current state and the action taken:

$$s_{t+1} \sim P(\cdot \mid s_t, a_t)$$

Considering the stochastic state transitions and the agent's policy, we can compute the probability of an entire trajectory occurring. It is the product of the following components:

- The probability of the agent starting in the initial state $s_0$: $\rho_0(s_0)$.
- For each time step $t$ in the trajectory:
  - The probability of the environment transitioning to state $s_{t+1}$ given $s_t$ and $a_t$: $P(s_{t+1} \mid s_t, a_t)$.
  - The probability of the agent selecting action $a_t$ in state $s_t$ according to its policy: $\pi(a_t \mid s_t)$.

$$P(\tau|\pi) = \rho_0(s_0) \prod_{t=0}^{T-1} P(s_{t+1} \mid s_t, a_t)\, \pi(a_t \mid s_t)$$

(where $T$ is the length of the trajectory).
When calculating the total return for the trajectory, we almost always use discounted rewards. This means that rewards received earlier are more valuable than those received later.
Why?
- It reflects real-world scenarios (a dollar in hand today is worth more than a dollar promised tomorrow).
- It avoids the problem of infinite returns in ongoing tasks (tasks without a fixed endpoint).
- It provides mathematical convenience.
We introduce a discount factor $\gamma$, where $0 \le \gamma \le 1$. The closer $\gamma$ is to 0, the more "short-sighted" the agent is (focusing on immediate benefits); the closer $\gamma$ is to 1, the more "far-sighted" the agent is (focusing on long-term returns).
The total discounted return for the trajectory is:

$$R(\tau) = \sum_{t=0}^{T} \gamma^t\, r_t$$

- $r_t$ is the immediate reward received at time step $t$.
- $\gamma^t$ is the discount applied to the reward at time step $t$.
So, what is a trajectory for an LLM? As mentioned earlier, the model is the policy, the prompt is the state, and the next token is the action; the sequence of states and actions produced during autoregressive generation is therefore the trajectory.
Policy Gradient#
We have established the goal of reinforcement learning: to find an optimal policy that maximizes the expected return $J(\pi)$. Great. But how do we actually represent and find this policy?
Typically, especially for complex problems, we do not search over all possible policies. Instead, we define a parameterized policy, denoted $\pi_\theta$. You can think of $\theta$ as a set of "knobs" or parameters; if our policy is a neural network, then $\theta$ is the network's weights and biases.
Our current goal becomes: how to adjust these knobs to maximize our expected return?
Note
Under the parameterized policy $\pi_\theta$, the expected return over all possible trajectories is:

$$J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\big[R(\tau)\big]$$

The expected return depends on the trajectories $\tau$, and the distribution of trajectories depends on the actions chosen by our specific policy $\pi_\theta$. Changing $\theta$ changes the policy, which changes the trajectories, and thus changes the expected return.
Note
We want to maximize $J(\pi_\theta)$ by changing $\theta$. In deep learning we generally use gradient descent to minimize a loss function; here we want to maximize a function, so we use gradient ascent! It is like climbing a mountain: find the direction of steepest ascent (the gradient) and take a step in that direction.
Our policy is a neural network, and we iteratively adjust its parameters $\theta$ to increase $J$. The update rule looks very familiar (just the minus sign of gradient descent replaced with a plus sign):

$$\theta_{k+1} = \theta_k + \alpha\, \nabla_\theta J(\pi_\theta)\big|_{\theta_k}$$

- $\theta_k$: our parameters at the $k$-th iteration.
- $\alpha$: the learning rate (step size).
- $\nabla_\theta J(\pi_\theta)\big|_{\theta_k}$: the gradient of the expected return with respect to the parameters $\theta$, evaluated at the current parameters $\theta_k$. It tells us which direction in parameter space increases $J$ the most.
Important
PG Derivation
What follows is a deliberately verbose, ADHD-friendly derivation that re-introduces every index as it appears...
Step 1: restate that the object whose gradient we want is the expected return $J(\pi_\theta)$:

$$\nabla_\theta J(\pi_\theta) = \nabla_\theta\, \mathbb{E}_{\tau \sim \pi_\theta}\big[R(\tau)\big]$$

Here:

- $J(\pi_\theta)$ is the expected return.
- $\mathbb{E}[\cdot]$ denotes the expected value, computed over all possible trajectories ($\tau$). A trajectory is a series of states and actions generated by the agent interacting with the environment.
- $\tau \sim \pi_\theta$ indicates that these trajectories are generated according to our current policy $\pi_\theta$.
- $R(\tau)$ is the total return obtained from a complete trajectory (usually the discounted return).
- $\nabla_\theta$ is the gradient operator, indicating that we take the partial derivative with respect to the parameters $\theta$.
Step 2: expand the expression for the expectation.
What is the definition of expected value? For a random variable $X$ with probability distribution $P(x)$, the expectation of $f(X)$ is:

- continuous: $\mathbb{E}[f(X)] = \int P(x)\, f(x)\, dx$
- discrete: $\mathbb{E}[f(X)] = \sum_x P(x)\, f(x)$

In our case, the random variable is the return $R(\tau)$ of the trajectory, and the probability distribution is $P(\tau|\theta)$, the probability of trajectory $\tau$ occurring under policy $\pi_\theta$. The expectation can therefore be written in integral (or summation) form:

$$J(\pi_\theta) = \int_\tau P(\tau|\theta)\, R(\tau)$$

(Here the integral sign stands for summing or integrating over all possible trajectories, whichever is more general.)
Substituting this into the formula from Step 1:

$$\nabla_\theta J(\pi_\theta) = \nabla_\theta \int_\tau P(\tau|\theta)\, R(\tau)$$
Step 3: move the gradient operator inside the integral.
This requires a bit of calculus: under certain conditions (which we usually assume hold in reinforcement learning), we can exchange the order of differentiation and integration, just like $\nabla_x \int f(x, y)\, dy = \int \nabla_x f(x, y)\, dy$.
Next, note that $R(\tau)$ is the total return once a trajectory is fixed; its value does not directly depend on the policy parameters $\theta$. (The policy affects which trajectory occurs, not the return of a trajectory once it has occurred.) Therefore, the gradient only needs to act on $P(\tau|\theta)$:

$$\nabla_\theta J(\pi_\theta) = \int_\tau \nabla_\theta P(\tau|\theta)\, R(\tau)$$

This step says that the change in expected return comes from the change (caused by $\theta$) in the probability of each trajectory occurring, multiplied by that trajectory's return, summed over all trajectories.
Step 4: the log-derivative trick.
This is the most central and clever step of the whole derivation! We need an identity.

- Calculus review (chain rule and log derivative): the derivative of the natural logarithm (here $\log$ means $\ln$) satisfies $\nabla_x \log f(x) = \frac{\nabla_x f(x)}{f(x)}$.
- Rearranging slightly, we get: $\nabla_x f(x) = f(x)\, \nabla_x \log f(x)$.

Now apply this trick to the gradient, with $f$ corresponding to $P(\tau|\theta)$ and the variable corresponding to the parameters $\theta$:

$$\nabla_\theta P(\tau|\theta) = P(\tau|\theta)\, \nabla_\theta \log P(\tau|\theta)$$

Substituting this result into the integral from Step 3:

$$\nabla_\theta J(\pi_\theta) = \int_\tau P(\tau|\theta)\, \nabla_\theta \log P(\tau|\theta)\, R(\tau)$$
Step 5: return to expectation form.
Observe the result from Step 4: it has the form $\int_\tau P(\tau|\theta)\,[\cdots]$.
This matches the definition of expectation, $\mathbb{E}_{\tau \sim \pi_\theta}[f(\tau)] = \int_\tau P(\tau|\theta)\, f(\tau)$.
Here $f(\tau)$ corresponds to the entire bracketed quantity $\nabla_\theta \log P(\tau|\theta)\, R(\tau)$.
Therefore, the whole integral can be written back in expectation form:

$$\nabla_\theta J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\big[\nabla_\theta \log P(\tau|\theta)\, R(\tau)\big]$$

A significant implication: we have successfully transformed the gradient of an expectation into the expectation of some quantity (gradient times return). This form is very important because it can be approximated through sampling! We do not need to actually compute the integral over all trajectories; we just sample many trajectories $\tau$, compute the bracketed value for each, and average to obtain an approximate gradient.
Step 6: expand the gradient of the log probability (grad-log-prob).
Now we handle the term $\nabla_\theta \log P(\tau|\theta)$ inside the expectation.
Recall the trajectory $\tau = (s_0, a_0, s_1, a_1, \ldots, s_T, a_T)$ (assuming the trajectory consists of $T+1$ states and $T+1$ actions, i.e., $T$ time steps). The probability of the trajectory occurring is:

$$P(\tau|\theta) = \rho_0(s_0) \prod_{t=0}^{T-1} P(s_{t+1} \mid s_t, a_t)\, \pi_\theta(a_t \mid s_t)$$

- $\rho_0(s_0)$: the probability of the initial state $s_0$.
- $P(s_{t+1} \mid s_t, a_t)$: the environment dynamics, the probability of transitioning to state $s_{t+1}$ after executing action $a_t$ in state $s_t$.
- $\pi_\theta(a_t \mid s_t)$: the policy, the probability of selecting action $a_t$ in state $s_t$ (this is the part that depends on $\theta$).

Mathematical review (log properties): $\log(ab) = \log a + \log b$ and $\log \prod_i x_i = \sum_i \log x_i$.
Taking the logarithm of $P(\tau|\theta)$:

$$\log P(\tau|\theta) = \log \rho_0(s_0) + \sum_{t=0}^{T-1} \Big[\log P(s_{t+1} \mid s_t, a_t) + \log \pi_\theta(a_t \mid s_t)\Big]$$
Now take the gradient of the above expression with respect to $\theta$:

- Mathematical review (gradient properties): the gradient of a sum is the sum of the gradients, and the gradient only acts on terms that depend on $\theta$.
- Key points:
  - The initial state probability $\rho_0(s_0)$ typically does not depend on the policy parameters $\theta$, so $\nabla_\theta \log \rho_0(s_0) = 0$.
  - The environment dynamics $P(s_{t+1} \mid s_t, a_t)$ describe properties of the environment itself and also do not depend on $\theta$, so $\nabla_\theta \log P(s_{t+1} \mid s_t, a_t) = 0$.
  - Only the policy $\pi_\theta(a_t \mid s_t)$ depends on $\theta$.

Therefore, the expression simplifies to:

$$\nabla_\theta \log P(\tau|\theta) = \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)$$

The gradient of the log probability of the entire trajectory equals the sum of the log-probability gradients of each action in that trajectory! This greatly simplifies the computation.
Step 7: The final policy gradient theorem
Substitute the simplified result from Step 6 back into the expectation from Step 5:

$$\nabla_\theta J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\Big[\sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)\Big]$$

This is the final form of the Policy Gradient Theorem (or one of its common forms).
The gradient of the expected return with respect to $\theta$ is: sample a trajectory $\tau$, compute its total return $R(\tau)$, multiply it by the sum of the policy's log-probability gradients over all (state, action) pairs in that trajectory, and take the expectation (average) over all possible trajectories.
Clearly, the cost of enumerating all trajectories is prohibitive (imagine enumerating every possible generation with `max_token_length=100`), so we approximate the expectation with the sample mean:
Note
Monte Carlo approximation:

- Run the current policy $\pi_\theta$ and collect a set of trajectories, forming the dataset $\mathcal{D} = \{\tau_i\}_{i=1,\ldots,N}$ (let $|\mathcal{D}| = N$).
- Use the average over these samples to approximate the expected value:

$$\hat{g} = \frac{1}{|\mathcal{D}|} \sum_{\tau \in \mathcal{D}} \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)$$
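A minimal sketch of this estimator as a PyTorch loss (minimizing it performs the gradient-ascent step above), assuming the per-step log probabilities and the total return of each sampled trajectory are already available; all names are mine:

```python
import torch

def reinforce_loss(logprobs: torch.Tensor, returns: torch.Tensor) -> torch.Tensor:
    """
    logprobs: (N, T) log pi_theta(a_t | s_t) for N sampled trajectories
    returns:  (N,)   total (discounted) return R(tau) of each trajectory
    The leading minus sign turns the policy-gradient objective into a loss to minimize.
    """
    return -(logprobs.sum(dim=1) * returns).mean()
```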
Application to LM Policy#
The generation process shown in the figure gives us the log probability of each state-action pair in the sampled trajectory, so we can backpropagate to compute the gradients.
Each log-probability gradient is then multiplied by the reward from the RM and plugged into the expression above to perform a gradient-ascent update.
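In code, the per-token log probabilities can be gathered from the LM's logits roughly like this (a sketch assuming an HF-style causal LM; function and variable names are mine):

```python
import torch
import torch.nn.functional as F

def sequence_logprobs(model, input_ids, attention_mask):
    """Log probability of each token of the sequence under the current policy (the LM)."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # logits at position t predict the token at position t+1, so shift by one
    logprobs = F.log_softmax(logits[:, :-1, :], dim=-1)
    targets = input_ids[:, 1:]
    return torch.gather(logprobs, dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)  # (batch, seq_len-1)
```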
High Variance#
PG algorithms perform well on small problems, but there are some issues when used for language modeling.
Note
The central limit theorem tells us that as long as the sample size is large enough, the sample mean approaches a normal distribution, which lets us predict and analyze the data well. When the sample size is small, the sample mean fluctuates a lot: even though it tends toward a normal distribution, individual estimates can vary widely. We also know that sampling many trajectories from an LM is very expensive, which leads to high variance in our gradient estimates.
How can we reduce variance without increasing the sample size?
- Remove historical rewards: reward-to-go
It must be acknowledged that the current action cannot affect rewards already obtained in the past; those past rewards only add noise (this relates to the credit assignment problem in RL). Removing the past terms therefore avoids that noise and brings the estimated gradient closer to the true gradient. So instead of weighting every action by the return of the whole trajectory, we weight each action only by the rewards collected from the current time step onward (the "reward-to-go").
- Introduce a baseline
RL research has shown that subtracting a baseline that depends only on the state (a function of the state, or even a constant) does not bias the gradient but can reduce its variance. Here we choose the value function $V(s_t)$ as the baseline. (A small sketch combining both tricks follows after this list.)
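A minimal sketch of both tricks for a single trajectory (per-step rewards and critic values are assumed to be given; all names are mine):

```python
import torch

def rewards_to_go(rewards: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
    """rewards: (T,) per-step rewards -> (T,) discounted reward-to-go for each step."""
    rtg = torch.zeros_like(rewards)
    running = 0.0
    for t in reversed(range(rewards.shape[0])):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

# weight each action by (reward-to-go - baseline) instead of the full-trajectory return
rewards = torch.tensor([0.0, 0.0, 1.0])
values  = torch.tensor([0.4, 0.5, 0.6])     # V(s_t) predicted by the critic
weights = rewards_to_go(rewards) - values   # lower-variance per-step weights
```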
Value Function#
The value function $V^\pi(s)$ tells you the expected return of the remaining trajectory when you start from state $s$ and act according to the current policy.
Examples of value definitions in classic RL scenarios and LM scenarios:
In practice, we use the LM we are trying to optimize as initialization, adding a linear layer on top to estimate the value, so that the parameters of the Transformer layer can be used for both language modeling (projecting tokens into the vocabulary) and value estimation.
The reward-to-go mentioned earlier is known in RL as the Q function: the expected return obtained by taking action $a_t$ in the current state $s_t$ and then following the policy for the remaining steps:

$$Q^\pi(s_t, a_t) = \mathbb{E}\big[R(\tau) \mid s_t, a_t\big]$$

By introducing the value function we can take the difference between Q and V, which is called the advantage function:

$$A^\pi(s_t, a_t) = Q^\pi(s_t, a_t) - V^\pi(s_t)$$

This advantage term indicates how much better a specific action $a_t$ is compared to the average action that could be taken in state $s_t$.
In the state pointed to by the red arrow in the figure, the advantage function for moving down will be higher than that of other actions.
Multiplying the gradient by the advantage function (instead of the raw return) means we increase the log probability of actions with high advantage and decrease the log probability of actions that do worse than the average.
Note
In traditional reinforcement learning methods, the Q network and the V network are usually separate: the Q function estimates the expected total return of executing action $a$ in state $s$, while the V function estimates the value of state $s$ alone. This would require two different neural networks.
However, we have now introduced the advantage function, which is simply the difference between the Q value and the V value:

$$A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$$

By expressing $Q(s_t, a_t)$ in terms of the observed reward and the value of the next state, we only need to train one network that outputs $V$; the Q value can then be computed from the reward $r_t$ and the discount factor $\gamma$.
Thus a single neural network suffices: it predicts $V(s)$, and the Q value is computed as:

$$Q(s_t, a_t) \approx r_t + \gamma\, V(s_{t+1})$$

The advantage function then becomes:

$$A(s_t, a_t) \approx r_t + \gamma\, V(s_{t+1}) - V(s_t)$$
Advantage Sampling#
Short-step advantage estimators have high bias but low variance, while long-step advantage estimators have low bias but high variance. This trade-off is a part of reinforcement learning that requires careful selection and adjustment, depending on the stability requirements of the model and training efficiency.
An example: "A short-term memory person only remembers what happened yesterday; although not comprehensive, it is very stable; a long-term memory person can see the whole picture for the next few days but may be disturbed by more unknown factors."
GAE#
To address this bias-variance problem, we can use GAE (Generalized Advantage Estimation), which essentially is a weighted sum of all advantage terms, each multiplied by a decay factor.
Note
Now let's talk about TD error.
Online learning has a wonderful property: you do not need to wait until the end of an episode to update the policy. This is where the Temporal Difference error (TD error) comes in:

$$\delta_t = r_t + \gamma\, V(s_{t+1}) - V(s_t)$$

The key point is that the TD error is an online estimate of the advantage function. It tells you, at this moment, whether your action made the future look better than you expected. This error directly reflects the idea of advantage:

- If $\delta_t > 0$: "Hey, this action is better than I imagined!" (the advantage is positive).
- If $\delta_t < 0$: "Hmm, I thought it would be better..." (the advantage is negative).

This lets you adjust your policy gradually without waiting for an entire episode to end, which is an excellent way to improve efficiency.
The purpose of GAE is to provide a better advantage estimate for policy gradient algorithms than the raw return $R(\tau)$ or the simple one-step TD error $\delta_t$, reducing the variance of the gradient estimate and improving learning stability and efficiency.
The following formula defines the generalized advantage estimate $\hat{A}_t^{GAE}$ recursively; it does not look only at the one-step TD error but integrates TD-error information from multiple future steps:

$$\hat{A}_t^{GAE} = \delta_t + \gamma \lambda\, \hat{A}_{t+1}^{GAE}$$
This recursion is computed backwards from the end of the trajectory (episode), assuming $T$ is the last step and $\hat{A}_{T+1}^{GAE} = 0$:

* $\hat{A}_T^{GAE} = \delta_T$
* $\hat{A}_{T-1}^{GAE} = \delta_{T-1} + \gamma\lambda\, \hat{A}_T^{GAE}$
* $\hat{A}_{T-2}^{GAE} = \delta_{T-2} + \gamma\lambda\, \hat{A}_{T-1}^{GAE}$
* ...
* General (unrolled) form: $\hat{A}_t^{GAE} = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l}$ (assuming an infinite horizon, or that $\delta$ is 0 after the terminal state).
The parameter $\lambda$ ($0 \le \lambda \le 1$) is the key to GAE; it controls the bias and variance of the estimate $\hat{A}_t^{GAE}$:

- When $\lambda = 0$:
  - $\hat{A}_t^{GAE} = \delta_t$. GAE degenerates into the simple one-step TD error. This estimate has lower variance (it relies only on the information from the next step) but may have higher bias (it leans heavily on a possibly inaccurate estimate of $V$).
- When $\lambda = 1$:
  - $\hat{A}_t^{GAE} = \sum_{l=0}^{\infty} \gamma^l\, \delta_{t+l}$. It can be shown that this equals $\sum_{l=0}^{\infty} \gamma^l\, r_{t+l} - V(s_t)$, i.e., the Monte Carlo (MC) return minus the baseline. This estimate has lower bias (it uses the complete actual return from time $t$ onward) but usually has high variance (it accumulates randomness over many time steps).
- When $0 < \lambda < 1$:
  - GAE interpolates between the two extremes above. The closer $\lambda$ is to 0, the more it leans toward the low-variance, high-bias TD estimate; the closer $\lambda$ is to 1, the more it leans toward the high-variance, low-bias MC estimate.
  - By choosing an appropriate $\lambda$ (e.g., 0.97), GAE aims for a good balance between bias and variance, giving an advantage estimate that is both reasonably accurate (controlled bias) and reasonably stable (lower variance).
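A minimal GAE sketch that walks the trajectory backwards exactly as in the recursion above (names are mine; a real implementation would also handle episode boundaries and padding masks):

```python
import torch

def compute_gae(rewards, values, gamma: float = 0.99, lam: float = 0.95):
    """
    rewards: (T,)   per-step rewards
    values:  (T+1,) V(s_0..s_T), with values[T] the bootstrap value (0 if terminal)
    Returns advantages (T,) and value targets (T,).
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error delta_t
        gae = delta + gamma * lam * gae                         # A_t = delta_t + gamma*lambda*A_{t+1}
        advantages[t] = gae
    returns = advantages + values[:-1]                          # targets for the value loss
    return advantages, returns
```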
Advantage in Language Models#
As shown in the figure, the goal is to increase the log probability of the token "Shanghai" in the current state while decreasing the log probability of "chocolate," as the advantage of choosing "Shanghai" is higher than that of choosing "chocolate" (a random token).
Importance Sampling and Offline Learning#
In many cases we want to compute $\mathbb{E}_{x \sim p}[f(x)]$, but:

- It is difficult or impossible to sample directly from the target distribution $p(x)$.
- Or sampling from $p$ is inefficient, which is exactly our situation: sampling from the LM is too costly.

However, we may easily sample from another, alternative "proposal" distribution $q(x)$.
Importance sampling (IS) is a technique for estimating an expectation under a target distribution by sampling from a different one: it transforms an expectation under the distribution $p$ into the expectation of a related function under another distribution $q$:

$$\mathbb{E}_{x \sim p}\big[f(x)\big] = \mathbb{E}_{x \sim q}\Big[\frac{p(x)}{q(x)}\, f(x)\Big]$$

The key is the importance weight $\frac{p(x)}{q(x)}$. Its role is to correct the bias: for a sample $x$ drawn from $q$, if it is more likely under the target distribution ($p(x) > q(x)$) it receives a weight greater than 1; conversely, if it is less likely under $p$ ($p(x) < q(x)$) it receives a weight smaller than 1. Averaging with these weights yields a (usually unbiased or consistent) estimate of the original expectation.
Importance sampling therefore allows us to:

- Draw samples from an easy-to-sample distribution $q$.
- Estimate the original expected value via the weighted average:

$$\mathbb{E}_{x \sim p}\big[f(x)\big] \approx \frac{1}{N} \sum_{i=1}^{N} \frac{p(x_i)}{q(x_i)}\, f(x_i), \quad x_i \sim q$$
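A tiny numerical sketch: estimating $\mathbb{E}_{x \sim p}[x^2]$ for a target Gaussian $p$ while sampling only from a different Gaussian $q$ (the distributions and numbers are made up for illustration):

```python
import torch

torch.manual_seed(0)
p = torch.distributions.Normal(1.0, 1.0)     # target distribution (pretend it is hard to sample)
q = torch.distributions.Normal(0.0, 2.0)     # proposal distribution we can sample from
f = lambda x: x ** 2

x = q.sample((100_000,))
w = (p.log_prob(x) - q.log_prob(x)).exp()    # importance weights p(x)/q(x)
print((w * f(x)).mean())                     # approx E_{x~p}[x^2] = mu^2 + sigma^2 = 2.0
```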
Returning to our scenario. Previously, we obtained the on-policy policy gradient estimate:
Info
On-policy means that the policy used to collect data is the same one being trained. Since the computation requires trajectories generated by the current policy $\pi_\theta$, old data cannot be reused after each policy update, which makes sample efficiency low. (If `mini_batch_num` > 1, is the data used for the later mini-batch updates still on-policy? Strictly speaking it is not, so you could call it "semi-on-policy"; the terminology is loose.) On-policy is ultimately about whether the policy model currently being trained is the one interacting with the environment.
We would like to use data generated by the old policy $\pi_{\theta_{\text{old}}}$ to estimate the gradient of the current new policy $\pi_\theta$. That way we can reuse data and improve sample efficiency.
Recall the importance sampling (IS) principle: $\mathbb{E}_{x \sim p}[f(x)] = \mathbb{E}_{x \sim q}\big[\frac{p(x)}{q(x)} f(x)\big]$.
Mapping this onto our policy gradient (simplifying to single-step decisions):

* The target distribution $p$ corresponds to the new policy $\pi_\theta$.
* The sampling distribution $q$ corresponds to the old policy $\pi_{\theta_{\text{old}}}$.
* The importance weight is $\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$.

Applying the importance weight to each term of the on-policy gradient (at every time step $t$) gives the standard off-policy estimate:

$$\nabla_\theta J \approx \frac{1}{|\mathcal{D}|} \sum_{\tau \in \mathcal{D}} \sum_{t=0}^{T-1} \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}\, \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t$$
Now we have a way to perform full gradient-ascent optimization without sampling from the policy being optimized (the model being trained) at every step: we sample once, store the trajectories in memory/a database, optimize the policy over mini-batches, and then re-initialize the offline (sampling) policy with the new policy.
PPO Loss#
The PPO loss consists of three parts: the policy loss ($L^{POLICY}$), the value function loss ($L^{VF}$), and the entropy reward ($L^{ENTROPY}$).
1. Policy Loss ($L^{POLICY}$)#
Clipped Surrogate Objective
This is the core of PPO. You will notice that it resembles the off-policy policy gradient objective derived using importance sampling, but with a key modification.

$$L^{POLICY} = -\,\mathbb{E}_t\Big[\min\Big(r_t(\theta)\, \hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\, \hat{A}_t\Big)\Big]$$

(The quantity inside the expectation is the surrogate objective we want to maximize; the leading minus sign turns it into a loss to be minimized together with the other terms.)

- $r_t(\theta) = \dfrac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$: the importance sampling ratio. It is the probability of taking action $a_t$ in state $s_t$ under the current (online) policy $\pi_\theta$, divided by the probability of taking that action under the old (offline) policy $\pi_{\theta_{\text{old}}}$ that collected the trajectory data. This ratio corrects for the fact that the data comes from a policy slightly different from the one we are currently improving.
- $\hat{A}_t$: the estimated advantage, computed with GAE, which balances bias and variance. It tells us how much better or worse taking action $a_t$ in state $s_t$ is compared to the average action in that state (as judged by the current value function).
- The `clip` function: this is the key point of PPO. It says: if the probability ratio $r_t(\theta)$ deviates too far from 1 (either too high or too low), we "clip" it. For example, with $\epsilon = 0.2$, a ratio that tries to become 1.5 is clipped to 1.2, and one that tries to become 0.5 is clipped to 0.8. The parameter $\epsilon$ (epsilon) is a small hyperparameter (e.g., 0.1 or 0.2) that defines the clipping range $[1-\epsilon,\, 1+\epsilon]$.
- The `min` function: the objective takes the smaller of the two terms below:
  - Unclipped objective: $r_t(\theta)\, \hat{A}_t$
  - Clipped objective: $\operatorname{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\, \hat{A}_t$
Why do it this way? The goal of policy gradients is to increase the probability of actions with positive advantage and decrease the probability of actions with negative advantage. However, with importance sampling, if $r_t(\theta)$ becomes very large the updates can be huge and unstable. PPO keeps the new policy close to the old one by clipping this ratio.

- If $\hat{A}_t > 0$ (good action): we want to increase $\pi_\theta(a_t \mid s_t)$. The `min` function means that once $r_t(\theta)$ grows beyond $1+\epsilon$, the objective is capped at $(1+\epsilon)\hat{A}_t$. This prevents the policy from changing too much in a single update, even if the unclipped objective would suggest a larger step.
- If $\hat{A}_t < 0$ (bad action): we want to decrease $\pi_\theta(a_t \mid s_t)$. Once $r_t(\theta)$ shrinks below $1-\epsilon$, the clipped objective is pinned at $(1-\epsilon)\hat{A}_t$, which no longer depends on $\theta$, so there is no incentive to push the probability down any further. Conversely, when $\hat{A}_t < 0$ the product $r_t(\theta)\hat{A}_t$ becomes more negative as $r_t(\theta)$ increases, and the `min` operation keeps this unclipped (more pessimistic) term, so an update that accidentally increased the probability of a bad action is still penalized and pulled back.
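A minimal sketch of $L^{POLICY}$ in PyTorch, assuming the rollout log probabilities and GAE advantages are already available (tensor names are mine):

```python
import torch

def ppo_policy_loss(logprobs_new, logprobs_old, advantages, clip_eps: float = 0.2):
    """
    logprobs_new: (B, T) log pi_theta(a_t|s_t) under the policy being optimized
    logprobs_old: (B, T) log pi_theta_old(a_t|s_t) recorded at rollout time (no grad)
    advantages:   (B, T) GAE estimates A_hat_t
    """
    ratio = torch.exp(logprobs_new - logprobs_old)                      # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()                        # maximize surrogate => negate
```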
2. Value Function Loss ($L^{VF}$)#
This is the same value regression described earlier:

$$L^{VF} = \mathbb{E}_t\Big[\big(V(s_t) - R_t\big)^2\Big]$$

- $V(s_t)$ is the output of the value network (i.e., a linear layer added on top of the LLM that predicts the expected cumulative reward starting from state $s_t$).
- $R_t$ (the return, or target value) is the actual total discounted reward observed starting from state $s_t$ and following the current policy until the episode ends. It is the empirical target for $V(s_t)$.
- The loss is the mean squared error (MSE) between the predicted value $V(s_t)$ and the observed target $R_t$. We want the value function to predict future rewards accurately, because it is crucial for computing the advantage $\hat{A}_t$.
3. Entropy Reward ($L^{ENTROPY}$)#

$$L^{ENTROPY} = -\,H\big(\pi_\theta(\cdot \mid s_t)\big) = \sum_a \pi_\theta(a \mid s_t)\, \log \pi_\theta(a \mid s_t)$$

- Here $\pi_\theta(\cdot \mid s_t)$ is the action probability distribution output by the current policy over all possible actions in state $s_t$.
- $H$ is the entropy of this distribution. Entropy measures the randomness or uncertainty of a distribution: a uniform distribution (very random) has high entropy, while a peaked distribution (very certain about one action) has low entropy.
- The loss term is the negative entropy. When we minimize it as part of the total loss (assuming its coefficient is positive), we are actually maximizing the entropy of the policy.

Encouraging higher entropy promotes exploration: the policy stays a bit more random and tries different actions (for an LLM, different tokens) instead of collapsing prematurely onto a potentially suboptimal deterministic policy. This helps the agent discover better strategies.
Final Form #
The final PPO loss is the weighted sum of these three parts:

$$L^{PPO} = L^{POLICY} + c_1\, L^{VF} + c_2\, L^{ENTROPY}$$

- $c_1 L^{VF}$: the value function loss, weighted by $c_1$. A common value for $c_1$ is around 0.5.
- $c_2 L^{ENTROPY}$: the entropy reward (since $L^{ENTROPY} = -H$, it is really a penalty on low entropy), weighted by $c_2$. $c_2$ is usually a small positive number (e.g., 0.01), enough to encourage exploration without overwhelming the main policy objective.

The agent's parameters $\theta$ (i.e., the weights of the LLM) are updated by computing the gradient of this combined loss and performing gradient descent.
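Combining the three scalars into the quantity that is actually minimized might look like this (a sketch; the coefficient values are just the common defaults mentioned above):

```python
def ppo_total_loss(policy_loss, value_loss, entropy, c1: float = 0.5, c2: float = 0.01):
    # L_PPO = L_POLICY + c1 * L_VF + c2 * L_ENTROPY, with L_ENTROPY = -H(pi)
    return policy_loss + c1 * value_loss - c2 * entropy
```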
Reference Model#
Reward Hacking#
A major issue in RL is reward hacking, where the model may learn to always output tokens or sequences that yield good rewards but are meaningless to humans, such as repeatedly saying "thank you" to boost politeness scores. Therefore, we hope that the outputs of the aligned model (after RL post-training) are as close as possible to the original model's outputs.
Therefore, another model with frozen weights (the reference model) is kept around. When the reward model produces a reward at each step of the trajectory, that reward is penalized by the KL divergence between the log probabilities of the reference model and those of the model being optimized. This prevents the model from producing answers that drift too far from the original model, and thereby avoids the reward-hacking behavior described above.
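A sketch of how the per-token KL penalty is typically folded into the rewards in RLHF pipelines, using the simple `log pi_actor - log pi_ref` estimate of the KL term (variable names and the coefficient are mine):

```python
import torch

def kl_penalized_rewards(reward_score, logprobs_actor, logprobs_ref, kl_coef: float = 0.1):
    """
    reward_score:   (B,)   scalar reward from the RM for each full response
    logprobs_actor: (B, T) per-token log probs under the policy being trained
    logprobs_ref:   (B, T) per-token log probs under the frozen reference model
    """
    kl = logprobs_actor - logprobs_ref        # per-token KL estimate
    rewards = -kl_coef * kl                   # penalize drifting away from the reference model
    rewards[:, -1] += reward_score            # add the RM score at the final token of each response
    return rewards
```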
Code Walkthrough#
trl#
```python
class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    # ... (class attributes like transformers_parent_class) ...
```
The core purpose of this class is to bundle a standard Causal Language Model (our Actor Model, responsible for generating text policy ) with a Value Head (i.e., Critic Model, responsible for estimating state value V(s)). In PPO / Actor Critic algorithms, we need both the policy and the value function simultaneously, and this class provides a unified model structure to output both.
```python
    def __init__(self, pretrained_model, **kwargs):
        super().__init__(pretrained_model, **kwargs)        # basic setup
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)    # separate the parameters meant for the ValueHead
        # Ensure the passed model has a language-model head
        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
            raise ValueError("The model does not have a language model head...")
        # Create the ValueHead instance, which will learn to predict the state value V(s)
        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
        # Initialize the ValueHead weights
        self._init_weights(**v_head_kwargs)  # random init by default; normal-distribution init can also be specified
```
- Actor: this is our language model `pretrained_model`, which generates responses (actions $a$, i.e., a series of tokens) based on the current prompt (state $s$).
- Critic: evaluates how "good" the Actor's situation is in a given state $s$, outputting $V(s)$. This is the job of the linear layer `self.v_head`.
```python
    def forward(
        self,
        input_ids=None,        # input token IDs (state s)
        attention_mask=None,
        past_key_values=None,  # used to speed up generation
        **kwargs,
    ):
        # Force the underlying model to output hidden_states, which the ValueHead needs as input
        kwargs["output_hidden_states"] = True
        # ... (handle some details of past_key_values and PEFT; can be ignored for understanding PPO)

        # 1. The Actor (base language model) runs its forward pass
        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )

        # 2. Extract the Actor's output (for the policy update) and the Critic's input
        lm_logits = base_model_output.logits  # Actor's output: predicted distribution over the next token;
                                              # the basis for computing L_POLICY and L_ENTROPY in PPO
        last_hidden_state = base_model_output.hidden_states[-1]  # Critic's input: the LM's last hidden state,
                                                                 # i.e., the representation of the current state s
        # (Optional) the language model's own loss, usually not used directly in the RL phase
        loss = base_model_output.loss

        # (Ensure data and model are on the same device)
        if last_hidden_state.device != self.v_head.summary.weight.device:
            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)

        # 3. The Critic (ValueHead) runs its forward pass:
        # the ValueHead receives the state representation and outputs the value estimate V(s)
        value = self.v_head(last_hidden_state).squeeze(-1)  # the basis for the value loss L_VF and the advantage A_hat in PPO

        # (Ensure logits are float32 for numerical stability)
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        # Return the Actor's logits, the LM loss (may be None), and the Critic's value
        return (lm_logits, loss, value)
```
For each step of PPO-RLHF training:
- We feed the current batch of prompts (the `input_ids` sequences) into the model.
- `self.pretrained_model` (the Actor) computes (rolls out) `lm_logits`. These logits represent the probability distribution over the next tokens given the current prompt. Both the policy loss and the entropy reward in PPO are computed from this distribution $\pi_\theta(a_t \mid s_t)$.
- At the same time, we extract `last_hidden_state` from `base_model_output`. This can be seen as a vector representation of the current prompt (state $s$).
- This `last_hidden_state` is fed into `self.v_head` (the Critic), which outputs a scalar `value`. This `value` is the model's estimate of the current state's value, $V(s)$. The value function loss in PPO optimizes it to be as close as possible to the actual return, and it is also a key ingredient in computing the advantage $\hat{A}_t$, which in turn drives $L^{POLICY}$.
- The same prompt + response sequence is also fed through the Reward and Reference models (inference only) to obtain rewards and log probabilities (for the KL penalty).

Thus, a single `forward` call provides the core information needed to update both the Actor (policy) and the Critic (value function).
The training process can be understood with the help of the following diagram:
Tip
In RLHF, only the Actor needs to perform Prefill + Decode (complete Auto-Regressive Generation) during experience collection (rollout), while the other models only process existing responses to obtain log probabilities and values, performing only Prefill.
Additionally, the Actor involves both training and inference (referring to rollout), so it requires both a training engine (like Megatron, DeepSpeed, and FSDP) and a rollout engine (like SGLang and vLLM) to complete their respective tasks; the Critic reuses the internal representations from the training forward to output new value predictions, thus running within the same training engine; while the Reference and Reward models only need inference engines to obtain log probabilities and rewards.
verl#
Like OpenRLHF, verl is an excellent RLHF framework; a good introductory read is: 【AI Infra】VeRL Framework Introduction & Code Walkthrough