Nagi-ovo

Breezing homepage: [nagi.fun](nagi.fun)

From DQN to Policy Gradient

For Atari games like "Space Invaders," we use game frames as the states fed to the network, with a single frame consisting of 210x160 pixels. Since the images are in color (RGB), they contain 3 channels. Thus, the observation space shape is (210, 160, 3). The value of each pixel ranges from 0 to 255, resulting in a total of $256^{210 \times 160 \times 3} = 256^{100800}$ possible observations.
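
To get a feel for how large this number is, here is a quick sketch in plain Python (no Atari dependency). It counts the values per observation and, since the number itself is too large to print, the number of decimal digits it has. The variable names are only illustrative.

```python
import math

# Shape of one raw Atari frame as described above: height x width x RGB channels.
height, width, channels = 210, 160, 3
values_per_observation = height * width * channels  # 100800 pixel values

# Each value has 256 possible levels, so there are 256 ** 100800 observations.
# Count the decimal digits of that number instead of printing it.
decimal_digits = math.floor(values_per_observation * math.log10(256)) + 1

print(values_per_observation)  # 100800
print(decimal_digits)          # 242751
```

A number with roughly 240,000 digits makes it clear that a table with one row per observation is out of the question.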

Pasted image 20241004000133

A Q table of that size could never be generated or updated efficiently. Therefore, instead of tabular Q-Learning, we use Deep Q-Learning, in which a neural network approximates the Q function: given a state, the network outputs an approximate Q value for each possible action.

DQN#

Input Preprocessing and Temporal Limitations#

We certainly want to reduce state complexity to decrease the computation time required for training.

Pasted image 20241004000752

Grayscale#

Color does not provide important information here, so the three color channels (RGB) can be reduced to a single grayscale channel.

Cropping the Screen#

Areas that do not contain important information can be cropped out.

Capturing Temporal Information#

A single frame cannot provide motion information (direction, speed) for a pixel. To obtain temporal information, we stack four frames together.
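
The three preprocessing steps can be sketched on tiny nested lists, as below. This is only an illustration under stated assumptions: the helper names (`to_grayscale`, `crop`, `FrameStack`) are hypothetical, the luma weights are one common choice, and real pipelines usually rely on ready-made environment wrappers working on arrays.

```python
from collections import deque

def to_grayscale(frame):
    """Convert an RGB frame (rows of (r, g, b) tuples) to one luminance channel."""
    # ITU-R BT.601 luma weights, a common choice for RGB-to-gray conversion.
    return [[0.299 * r + 0.587 * g + 0.114 * b for (r, g, b) in row] for row in frame]

def crop(frame, top, bottom, left, right):
    """Keep rows top..bottom-1 and columns left..right-1."""
    return [row[left:right] for row in frame[top:bottom]]

class FrameStack:
    """Keep the k most recent frames; the stack is the state fed to the network."""
    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # At the start of an episode, repeat the first frame k times.
        for _ in range(self.frames.maxlen):
            self.frames.append(first_frame)
        return list(self.frames)

    def step(self, new_frame):
        # Appending to a full deque drops the oldest frame automatically.
        self.frames.append(new_frame)
        return list(self.frames)

# Tiny 4x4 RGB "frame" standing in for a 210x160 Atari frame.
rgb = [[(255, 255, 255)] * 4 for _ in range(4)]
gray = crop(to_grayscale(rgb), top=1, bottom=3, left=0, right=4)  # 2x4 after cropping
stack = FrameStack(k=4)
state = stack.reset(gray)
state = stack.step([[0.0] * 4 for _ in range(2)])
print(len(state), len(state[0]), len(state[0][0]))  # 4 2 4
```

Because the newest frame sits at the end of the stack, the network can compare it with the older frames to infer direction and speed.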

CNN#

The stacked frames are processed through three convolutional layers, aiming to capture and utilize the spatial relationships in the images. Additionally, since the frames are stacked, we can also obtain temporal information across frames.

MLP#

Finally, a fully connected layer serves as the output, providing a Q value for each possible action in that state.

Pasted image 20241004000340

class QNetwork(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),  # Flatten multi-dimensional input to one dimension
            nn.Linear(3136, 512),  # Fully connected layer, mapping 3136-dimensional input to 512-dimensional
            nn.ReLU(),
            nn.Linear(512, env.single_action_space.n),  # Output layer, corresponding to the action space dimension
        )

    def forward(self, x):
        return self.network(x / 255.0)  # Normalize input to [0,1] range
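
The `3136` in the first `nn.Linear` layer follows from the convolution arithmetic. The sketch below assumes the preprocessed frames are 84x84 (a common Atari setting that the text does not state explicitly) and applies the usual output-size formula for an unpadded convolution. `conv_out` is a hypothetical helper name.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output side length of a square convolution (floor division, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1

side = 84  # assumed input side length (not stated in the text)
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:  # the three Conv2d layers above
    side = conv_out(side, kernel, stride)

flattened = 64 * side * side  # 64 output channels in the last conv layer
print(side, flattened)  # 7 3136
```

With a different input resolution the flattened size changes, and the first linear layer would have to be adjusted accordingly.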

Input to the network: the 4-frame stack, which serves as the state.
Output: a vector of Q values, one for each possible action in that state.
Then, similar to Q-Learning, we simply use an epsilon-greedy strategy to choose which action to take.
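
A minimal sketch of epsilon-greedy selection over the network's output could look like this. The function name and the example Q values are hypothetical; a real training loop would typically also decay `epsilon` over time.

```python
import random

def epsilon_greedy(q_values, epsilon, rng=random):
    """Return a random action with probability epsilon, else the greedy action."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))  # explore
    return max(range(len(q_values)), key=lambda a: q_values[a])  # exploit

q_values = [0.1, 0.7, 0.3]  # hypothetical network output for 3 actions
print(epsilon_greedy(q_values, epsilon=0.0))  # 1 (always greedy when epsilon is 0)
```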

During the training phase, we no longer directly update the Q values of state-action pairs as in Q-learning: we optimize the weights of DQN by designing a loss function and using gradient descent.
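
As a sketch, the loss for a single transition $(s, a, r, s')$ can be written in the standard DQN form below. The notation is an assumption of this note: $\theta$ are the network weights, $\gamma$ is the discount factor, $\alpha$ is the learning rate, and $y$ is the TD target, treated as a constant when differentiating.

```latex
y = r + \gamma \max_{a'} Q(s', a'; \theta), \qquad
L(\theta) = \bigl( y - Q(s, a; \theta) \bigr)^2, \qquad
\theta \leftarrow \theta - \alpha \nabla_{\theta} L(\theta)
```

For a terminal transition the target reduces to $y = r$. In practice the loss is averaged over a mini-batch.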

Pasted image 20241004002003

Training Process#

The deep Q-learning training algorithm has two phases:

  • Sampling: Execute actions and store the observed experience tuples in a replay buffer.
  • Training: Randomly select a mini-batch of tuples and learn from that batch using gradient descent update steps.

Pasted image 20241004002352

Deep Q-learning combines a non-linear function approximator (the neural network), bootstrapping (updating targets from existing estimates rather than from actual complete returns, which introduces bias), and off-policy learning, so the training process may be unstable. Sutton and Barto call this combination the "deadly triad."

Stable Training#

To help stabilize our training, we implemented three different solutions:

  1. Experience Replay to utilize experience more efficiently.
  2. Fixed Q-Target to stabilize training.
  3. Double DQN to address the issue of overestimating Q values.

Experience Replay#

Experience replay in deep Q-learning serves two functions:

  1. Utilize training experience more efficiently. In typical online reinforcement learning, the agent interacts with the environment to gather experience (state, action, reward, and next state), learns from it (updates the neural network), and then discards it, which is highly inefficient. A replay buffer instead stores experience samples so that they can be reused during training.

Pasted image 20241004003215

The agent can learn from the same experience multiple times.

  2. Avoid forgetting previous experiences (i.e., catastrophic interference or catastrophic forgetting) and reduce the correlation between experiences. The replay buffer stores experience tuples while the agent interacts with the environment, and training then samples a mini-batch from it. This prevents the network from learning only from the most recent actions. Random sampling diversifies the experiences in each batch, which prevents overfitting to short-term states and avoids drastic fluctuations or catastrophic divergence in action values.
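
The idea can be sketched with a fixed-size queue and uniform random sampling. `SimpleReplayBuffer` and `Transition` are hypothetical names for illustration only; this is not the library `ReplayBuffer` used in the training code, which stores tensors and handles more cases.

```python
import random
from collections import deque, namedtuple

Transition = namedtuple("Transition", "state action reward next_state done")

class SimpleReplayBuffer:
    def __init__(self, capacity):
        # A bounded deque evicts the oldest transition once capacity is reached.
        self.storage = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.storage.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement breaks up temporal correlation.
        return random.sample(list(self.storage), batch_size)

    def __len__(self):
        return len(self.storage)

buffer = SimpleReplayBuffer(capacity=100)
for t in range(250):
    buffer.add(state=t, action=t % 2, reward=1.0, next_state=t + 1, done=False)
batch = buffer.sample(batch_size=32)
print(len(buffer), len(batch))  # 100 32
```

Only the 100 most recent transitions survive, and each sampled batch mixes transitions from different time steps.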

Pasted image 20241004003640

Sampling experiences and calculating loss:

rb = ReplayBuffer(
    args.buffer_size, # Size of the replay buffer, determining how much experience to store.
    envs.single_observation_space,
    envs.single_action_space,
    device,
    optimize_memory_usage=True,
    handle_timeout_termination=False,
)

if global_step > args.learning_starts:
    if global_step % args.train_frequency == 0:
        data = rb.sample(args.batch_size) # Randomly sample a batch
        with torch.no_grad():
            target_max, _ = target_network(data.next_observations).max(dim=1)
            td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
        old_val = q_network(data.observations).gather(1, data.actions).squeeze()
        loss = F.mse_loss(td_target, old_val)

Fixed Q-Target#

A key issue in Q-Learning is that the TD target (i.e., Q-Target) and the current Q Value (i.e., Q estimate) share parameters. This leads to Q targets and Q estimates changing simultaneously, like chasing a constantly moving target. A wonderful metaphor is a cowboy (Q estimate) trying to catch a moving cow (Q target). Although the cowboy gradually approaches the cow (error decreases), the target is still moving, causing significant oscillations during training.

Pasted image 20241004012634

Pasted image 20241004012646

Pasted image 20241004012412

I really like this representation 🥹

To solve this problem, we introduce a fixed Q-Target. The core idea is to introduce a separate target network that is not updated at every time step; instead, the parameters of the main network are copied to it every C steps. The target (Q-Target) therefore stays fixed over multiple time steps, and the main network is updated toward targets computed from older estimates. This significantly reduces the oscillation between targets and estimates.

Pasted image 20241004012440

As shown in the pseudocode above, the key is to use two different networks: one is the main network (used to select actions and perform updates), and the other is the target network (used to calculate Q-Target), with the main network's weights copied to the target network every C steps. This stabilizes the training process, allowing the "cowboy to more effectively chase the cow," reducing oscillations and speeding up convergence.

q_network = QNetwork(envs).to(device) # Current policy network, responsible for selecting actions and predicting Q values
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
target_network = QNetwork(envs).to(device) # Target network, calculates TD targets, providing stable learning targets.
target_network.load_state_dict(q_network.state_dict()) # Initialization: target network parameters are the same as the current policy network

Every args.target_network_frequency steps, the parameters of the main network are copied to the target network; between copies, the Q-Target stays fixed.

# With args.tau = 1.0 this is a full (hard) copy; a smaller tau would blend the weights (soft update)

if global_step % args.target_network_frequency == 0:
    for target_param, param in zip(target_network.parameters(), q_network.parameters()):
        target_param.data.copy_(args.tau * param.data + (1.0 - args.tau) * target_param.data)
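
The update rule in the loop above can be checked on plain lists of numbers. The sketch below uses a hypothetical helper `update_target` and made-up weights, and shows that `tau = 1.0` reproduces a full copy while a small `tau` only nudges the target toward the main network.

```python
def update_target(target_params, online_params, tau):
    """Blend online weights into target weights; tau=1.0 is a full copy."""
    return [tau * p + (1.0 - tau) * t for t, p in zip(target_params, online_params)]

online = [0.5, -1.0, 2.0]   # hypothetical main-network weights
target = [0.0, 0.0, 0.0]    # hypothetical target-network weights

hard = update_target(target, online, tau=1.0)   # hard update: target == online
soft = update_target(target, online, tau=0.1)   # soft update: small step toward online
print(hard)  # [0.5, -1.0, 2.0]
```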

Double DQN#

Double DQN was proposed by Hado van Hasselt specifically to address the problem of overestimating Q values.

When calculating the TD-Target in Q-Learning, a natural question is: how can we be sure that the action with the highest Q value in the next state is really the best action? The accuracy of Q values depends on which actions we have tried and which neighboring states we have explored. In the early stages of training, there is not yet enough information about the best action, so selecting actions purely by the highest Q value may lead to misjudgments.
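
A small simulation illustrates why taking the maximum over noisy estimates overestimates. It is a toy setup with assumed numbers: every action has a true value of 0, and each estimate is that value plus Gaussian noise.

```python
import random

rng = random.Random(0)
n_actions, n_trials, noise = 5, 10_000, 1.0

# Assumption for illustration: every action has true value 0, so the true max is 0.
total = 0.0
for _ in range(n_trials):
    noisy_q = [0.0 + rng.gauss(0.0, noise) for _ in range(n_actions)]
    total += max(noisy_q)  # what a max-based TD target would use

avg_max = total / n_trials
print(avg_max > 0.5)  # True: the max of noisy estimates is biased upward
```

Even though no action is actually better than another, the average maximum is clearly positive, and bootstrapping would propagate this bias into the targets.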

For example, if a non-optimal action is assigned a Q value higher than the best action, the learning process becomes complicated and difficult to converge. To address this issue, Double DQN introduces two networks to decouple action selection and the generation of Q value targets:

  1. The main network (DQN network) is used to select the best action in the next state (i.e., the action with the highest Q value).
  2. The target network (Target network) is used to calculate the target Q value generated by executing that action.

Pasted image 20241004013644

with torch.no_grad():
    # Use the main network to select the best action in the next state
    next_q_values = q_network(data.next_observations)
    next_actions = torch.argmax(next_q_values, dim=1, keepdim=True)
    # Use the target network to evaluate the Q Value of these actions
    target_q_values = target_network(data.next_observations)
    target_max = target_q_values.gather(1, next_actions).squeeze()
    # Calculate TD-Target
    td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
# Calculate current Q Value outside no_grad so that gradients flow into q_network
old_val = q_network(data.observations).gather(1, data.actions).squeeze()
loss = F.mse_loss(td_target, old_val)
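
To see the difference between the two targets on concrete numbers, here is a hand-sized sketch for a single transition. All Q values are made up for illustration.

```python
gamma = 0.99
reward, done = 1.0, 0.0

# Hypothetical Q values for the next state (3 actions) from each network.
q_online_next = [1.0, 2.5, 2.0]   # main network: picks action 1
q_target_next = [1.2, 1.5, 3.0]   # target network: its own max is action 2

# Vanilla DQN: the target network both selects and evaluates the action.
vanilla_target = reward + gamma * max(q_target_next) * (1 - done)

# Double DQN: main network selects, target network evaluates.
best_action = max(range(len(q_online_next)), key=lambda a: q_online_next[a])
double_target = reward + gamma * q_target_next[best_action] * (1 - done)

print(best_action)  # 1
print(double_target <= vanilla_target)  # True
```

The Double DQN target can never exceed the vanilla one, because the target network's value for any chosen action is at most its own maximum.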

Modern deep reinforcement learning also includes further improved techniques, such as prioritized experience replay and dueling networks, which are not covered here.

Optuna#

One of the most critical tasks in deep reinforcement learning is to find a good set of training hyperparameters. Optuna is a library that helps automate the search for the best hyperparameter combinations.

Policy Gradient#

The previous Q-Learning and DQN belong to value-based methods, which seek the optimal policy indirectly by estimating the value function. The policy $\pi$ exists only as a by-product of the estimated action values: it is generated from the value function, for example as a greedy policy that selects the action with the highest value in a given state.

With policy-based methods, we aim to optimize the policy directly, bypassing the intermediate step of learning a value function. Next, we will delve into one subset of these methods, namely policy gradient methods.

In policy-based methods, optimization is mostly on-policy, as each update only uses data (action trajectories) collected by the latest version of $\pi_{\theta}$.

Parameterized Stochastic Policy#

For example, let the neural network $\pi_{\theta}$ output a probability distribution over actions (a stochastic policy) $\pi_{\theta}(a|s)$:

Pasted image 20241006164702
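
For a discrete action space, such a stochastic policy is commonly obtained by applying a softmax to the network's raw outputs and then sampling. The sketch below uses made-up logits in place of a real network, and the helper names are hypothetical.

```python
import math
import random

def softmax(logits):
    """Turn arbitrary real scores into a probability distribution."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_action(probs, rng=random):
    """Draw one action index according to its probability."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.1]        # hypothetical network output for one state
probs = softmax(logits)         # pi_theta(a | s)
action = sample_action(probs)
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```

Every action keeps a non-zero probability, so exploration comes from the policy itself.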

The parameters $\theta$ are optimized by gradient ascent on an objective function $J(\theta)$ that measures the performance of the parameterized policy.

Advantages#

Convenient Integration#

  • The policy is estimated directly, without storing additional data (action values); this can be understood as end-to-end learning.

Can Learn Stochastic Policies#

Since the output is a probability distribution over actions, the agent can explore the state space without always following the same trajectory, eliminating the need to implement the exploration/exploitation trade-off by hand. DQN learns a deterministic policy, and we introduce randomness through tricks such as the ε-greedy strategy, but this is not an inherent feature of value function methods. A stochastic policy can also naturally handle uncertainty in states, addressing the perceptual aliasing problem.

For example, in the scenario below, the vacuum cleaner agent needs to suck up dust while avoiding harming the hamster, and the vacuum cleaner can only perceive the position of the walls. In the diagram, the two red states are called "aliased states" because the agent perceives exactly the same thing in both of them: walls above and below. The observation is therefore ambiguous, and the agent cannot tell which of the two red states it is in.

Pasted image 20241006170520

When using a deterministic policy, the vacuum cleaner will always move in the same direction (either right or left) in the red states; if that direction is wrong for the state it is actually in, it will get stuck in a loop. Even with an ε-greedy strategy, the vacuum cleaner mostly follows the greedy action, so it may still repeatedly head in the wrong direction in an aliased state, which is inefficient. A stochastic policy that moves left or right at random in the red states does not get stuck in this way.
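
The effect can be reproduced in a heavily simplified, hypothetical version of this scenario: a corridor of five cells with the dust in the middle, where cells 1 and 3 produce the same observation. The layout, names, and policies are assumptions made for illustration and do not reproduce the figure exactly.

```python
import random

GOAL = 2  # hypothetical 5-cell corridor: cells 0..4, dust in the middle

def observe(cell):
    # Cells 1 and 3 look identical to the agent (aliased states).
    return {0: "left_end", 4: "right_end"}.get(cell, "aliased")

def run(policy, start, rng, max_steps=1000):
    """Return the number of steps needed to reach the goal, or None if never reached."""
    cell = start
    for step in range(max_steps):
        if cell == GOAL:
            return step
        move = policy(observe(cell), rng)  # -1 = left, +1 = right
        cell = min(4, max(0, cell + move))
    return None

def deterministic(obs, rng):
    # Must commit to one direction for the aliased observation.
    return {"left_end": +1, "right_end": -1, "aliased": +1}[obs]

def stochastic(obs, rng):
    if obs == "aliased":
        return rng.choice([-1, +1])  # move left or right at random
    return {"left_end": +1, "right_end": -1}[obs]

rng = random.Random(0)
print(run(deterministic, start=4, rng=rng))  # None: stuck bouncing between cells 3 and 4
print(run(stochastic, start=4, rng=rng) is not None)  # True
```

The deterministic policy works from the left side but loops forever from the right side, whereas the stochastic policy reaches the dust from either side.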

Pasted image 20241006170145

Effective in High-Dimensional, Continuous Action Spaces#

Policy gradient methods are particularly effective in high-dimensional or continuous action spaces.

Autonomous vehicles may have infinitely many action choices in each state: for example, the steering wheel can turn 15°, 17.2°, or 19.4°, and there are other actions such as honking. Deep Q-Learning must calculate a Q value for each possible action, and selecting the maximum Q value in a continuous action space is itself an optimization problem.

In contrast, policy gradient methods directly output a probability distribution over actions, eliminating the need to compute and store a Q value for each action, which makes them more efficient in complex continuous action scenarios.
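
For continuous actions, one common parameterization (an assumption here, since the text does not name one) is a Gaussian policy: the network outputs a mean and a standard deviation, and the action is sampled from that distribution. The numbers below are made up.

```python
import math
import random

def sample_steering(mean_deg, std_deg, rng=random):
    """Sample a continuous steering angle from a Gaussian policy."""
    return rng.gauss(mean_deg, std_deg)

def log_prob(angle, mean_deg, std_deg):
    """Log density of the Gaussian policy at the chosen angle."""
    z = (angle - mean_deg) / std_deg
    return -0.5 * z * z - math.log(std_deg) - 0.5 * math.log(2 * math.pi)

# Hypothetical policy output for one state: mean 17.2 degrees, std 2 degrees.
mean_deg, std_deg = 17.2, 2.0
angle = sample_steering(mean_deg, std_deg)
print(log_prob(mean_deg, mean_deg, std_deg) > log_prob(25.0, mean_deg, std_deg))  # True
```

Two numbers describe a distribution over infinitely many angles, and the log probability is exactly the quantity the policy gradient differentiates.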

Better Convergence#

In value-based methods, we update the policy by taking the action with the maximum Q value using $\arg\max$. In this case, even minor changes in Q values can lead to drastic changes in action selection. For example, if the Q value for turning left during training is 0.22, and then the Q value for turning right becomes 0.23, the policy will change significantly, favoring turning right over left.

In policy gradient methods, the probabilities of actions change smoothly over time, leading to a more stable policy.
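
The contrast can be made concrete with the numbers from the example. In the sketch below, the "before" value of 0.21 for turning right is a made-up assumption, and a softmax over the two scores stands in for a smooth stochastic policy.

```python
import math

def greedy(q):
    """Index of the largest score (the argmax policy)."""
    return max(range(len(q)), key=lambda a: q[a])

def softmax(q):
    exps = [math.exp(x) for x in q]
    return [e / sum(exps) for e in exps]

before = [0.22, 0.21]  # [score(left), score(right)], hypothetical values before an update
after = [0.22, 0.23]   # score(right) rises slightly above score(left)

print(greedy(before), greedy(after))  # 0 1  (the greedy action flips completely)
shift = softmax(after)[1] - softmax(before)[1]
print(round(shift, 3))  # 0.005 (the probability of turning right barely moves)
```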

Disadvantages#

Local Optima#

Policy gradient methods often converge to a local optimum rather than the global optimum.

Low Training Efficiency#

Training proceeds in small steps, so it can be slow and inefficient.

High Variance#

The gradient estimates have high variance; the reasons and solutions will be discussed in the subsequent actor-critic section.

Specific Analysis#

Policy gradient adjusts the policy parameters after each interaction of the agent with the environment, so that the action distribution assigns more probability to good actions, i.e., those that maximize returns.

Pasted image 20241006172416

Objective Function#

Pasted image 20241006174139

Our goal is to find parameters $\theta$ that maximize the expected return:

$$\max_{\theta} J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}} [R(\tau)]$$

Since this is a maximization problem, we use gradient ascent: $\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)$.

However, the true gradient of the objective function cannot be computed because it requires calculating the probabilities of every possible trajectory, which is computationally expensive. Therefore, we hope to estimate the gradient through sample-based estimation (collecting some trajectories).

In addition, the state transition probabilities of the environment (or the state distribution) are often unknown. Even when they are known, they can be complex and non-linear, so we cannot differentiate through the dynamics of state transitions (governed by the Markov decision process) to optimize the policy.

Policy Gradient Theorem#

The complete derivation can be found in Andrej Karpathy's blog, and I previously summarized my learning here:
Introduction to Policy Gradient 6. PG Derivation

Pasted image 20241006180211
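
As a compact sketch of the standard argument (the notation $P(\tau;\theta)$ for the trajectory probability and $\mu(s_0)$ for the initial state distribution is an assumption of this note), the log-derivative trick turns the gradient into an expectation, and the environment dynamics drop out because they do not depend on $\theta$:

```latex
\begin{aligned}
\nabla_\theta J(\theta)
  &= \nabla_\theta \sum_{\tau} P(\tau;\theta)\, R(\tau)
   = \sum_{\tau} P(\tau;\theta)\, \nabla_\theta \log P(\tau;\theta)\, R(\tau)
   = \mathbb{E}_{\tau \sim \pi_\theta}\bigl[ \nabla_\theta \log P(\tau;\theta)\, R(\tau) \bigr], \\
\log P(\tau;\theta)
  &= \log \mu(s_0) + \sum_{t} \log P(s_{t+1} \mid s_t, a_t) + \sum_{t} \log \pi_\theta(a_t \mid s_t), \\
\nabla_\theta J(\theta)
  &= \mathbb{E}_{\tau \sim \pi_\theta}\Bigl[ \sum_{t} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau) \Bigr].
\end{aligned}
```

Only the policy term survives the differentiation, which is why the gradient can be estimated from sampled trajectories without knowing the transition probabilities.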

REINFORCE Algorithm (Monte Carlo REINFORCE)#

Use the estimated returns of the entire episode to update the policy parameters $\theta$.
Collect an episode $\tau$ using the policy $\pi_\theta$, and use that episode to estimate the gradient $\hat{g} = \nabla_\theta J(\theta)$.
Optimize: $\theta \leftarrow \theta + \alpha \hat{g}$

Pasted image 20241006181910

  • $\nabla_{\theta} \log \pi_{\theta}(a_t | s_t)$:
    This part is the gradient of the log probability of action $a_t$ given state $s_t$; it gives the direction in which to change $\theta$ to make that action more likely, rather than calculating a specific action value (Q value).

  • $R(\tau)$:
    Here, $R(\tau)$ is the cumulative return over the entire trajectory $\tau$, used to measure the total return after executing the policy $\pi_\theta$. A higher return increases the probability of the (state, action) combinations in the trajectory, while a lower return decreases it.

Multiple episodes (trajectories) can also be collected to estimate the gradient:
Pasted image 20241006182438
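
To make the update concrete, here is a minimal REINFORCE sketch on a hypothetical one-step problem with two actions, where only action 1 yields a reward. The "policy network" is just two logits, each episode is a single step, and one episode is used per update; a real implementation would use a neural network, automatic differentiation, and multi-step returns.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
theta = [0.0, 0.0]   # one logit per action; this is the whole "policy network"
alpha = 0.1          # learning rate

for episode in range(1000):
    probs = softmax(theta)
    action = rng.choices([0, 1], weights=probs, k=1)[0]  # sample a ~ pi_theta
    ret = 1.0 if action == 1 else 0.0                    # R(tau): only action 1 pays off
    # For a softmax policy, d log pi(a) / d theta_k = 1[k == a] - pi(k).
    grad_log_pi = [(1.0 if k == action else 0.0) - probs[k] for k in range(2)]
    # Gradient ascent step: theta <- theta + alpha * grad_log_pi * R(tau).
    theta = [t + alpha * g * ret for t, g in zip(theta, grad_log_pi)]

print(softmax(theta)[1] > 0.9)  # True: the policy has learned to prefer action 1
```

Each rewarded sample pushes probability toward the action that produced it, which is the "higher return increases the probability" behavior described above.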

References#

cleanrl codebase
