This article is my digest of articles, video tutorials, and code implementations such as Sunrise: Understanding AlphaZero, MCTS, Self-Play, UCB from Scratch.
Starting from the design principles of AlphaGo, it gradually reveals how to build a Gomoku AI that can surpass human play, through an in-depth look at the two core mechanisms: MCTS and Self-Play.
AlphaGo: From Imitation to Surpassing#
Evolution Path of AlphaGo#
The number of possible Go positions exceeds the number of atoms in the universe, rendering traditional exhaustive search methods completely ineffective. AlphaGo addresses this problem through a phased approach: first learning human knowledge, then continuously evolving through self-play.
This evolutionary process can be divided into three layers:
- Imitating human experts
- Self-play enhancement
- Learning to evaluate situations
Core Components#
Rollout Policy#
A lightweight policy network (the leftmost component in AlphaGo's architecture diagram) used for fast playouts: its accuracy is lower, but it is computationally very cheap.
SL Policy Network#
The supervised learning policy network learns to play by imitating human games:
- Input: Board state
- Output: Probability distribution of the next move imitating human experts
- Training data: 160,000 games, approximately 30 million moves
RL Policy Network#
This is similar to when your skill reaches a certain level: you start reviewing your own games and playing against yourself, discovering new tactics and deeper strategies. The RL policy network is initialized from the SL network and, through self-play, eventually surpasses human play, finding many strong strategies that humans had not discovered.
- This stage generates a large amount of self-play data; the input is still the board state, and the output is the move policy improved through reinforcement learning.
- The network plays against historical versions of the policy network and adjusts its parameters $\rho$ using the win/loss signal, reinforcing the moves that lead to victory:

$$\Delta\rho \propto \sum_t \frac{\partial \log p_\rho(a_t \mid s_t)}{\partial \rho}\, z_t$$

Where:
- $(s_t, a_t)$ is the game sequence (board states and the moves played)
- $z_t$ is the win-loss label (positive for a win, negative for a loss)
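To make this update concrete, here is a minimal REINFORCE-style sketch in PyTorch. It is not AlphaGo's original code: `policy_net`, the recorded trajectory tensors, and the outcome `z` are illustrative stand-ins.

```python
import torch

def reinforce_update(policy_net, optimizer, states, actions, z):
    """One REINFORCE-style update from a finished self-play game (illustrative sketch).

    states:  tensor of board states seen by this player, shape (T, ...)
    actions: tensor of the moves actually played, shape (T,), dtype long
    z:       +1 if this player won the game, -1 if it lost
    """
    log_probs = torch.log_softmax(policy_net(states), dim=1)        # log p(a | s) over all moves
    chosen = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)   # log p(a_t | s_t) of played moves
    loss = -(z * chosen).sum()   # ascent on z * log p  ==  descent on its negative
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Moves from won games get their probabilities pushed up, moves from lost games get pushed down, which is exactly the "reinforce what wins" signal described above.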
Value Network#
The value network learns to evaluate positions and can be a CNN:
- Training objective: minimize the mean squared error

$$L = \big(z - v_\theta(s)\big)^2$$

- $z$: the final win-loss result
- $v_\theta(s)$: the predicted winning probability
Supplementary Explanation#
About the relationship between the policy network $p$ and the value network $v_\theta$:
- The policy network provides specific move probabilities to guide the search.
- The value network provides situation evaluations, reducing unnecessary simulations in the search tree.
- The combination of both allows MCTS to not only explore high win-rate paths faster but also significantly improve overall playing strength.
AlphaGo's MCTS Implementation#
Selection Phase#
Combines exploration and exploitation:

$$a_t = \arg\max_a \Big( Q(s_t, a) + u(s_t, a) \Big), \qquad u(s, a) \propto \frac{P(s, a)}{1 + N(s, a)}$$

Where:
- $Q(s, a)$: long-term return estimate
- $u(s, a)$: exploration bonus
- $P(s, a)$: prior probability output by the policy network
- $N(s, a)$: visit count
Simulation and Evaluation#
In the original MCTS algorithm, the simulation phase plays random games from the leaf node (the newly expanded node) using a fast rollout policy until the game ends, then derives a return from the win-loss outcome. This return is propagated back to the nodes in the search tree to update their value estimates (i.e., $Q(s, a)$).
This implementation is simple and direct, but the rollout policy is usually random or based on simple rules, so simulation quality can be poor: it only provides short-term information and does not effectively incorporate a global, strategic evaluation.
AlphaGo enhances the simulation phase by combining the value network with the rollout policy: the value network provides an efficient leaf-node estimate and a global evaluation of the position, while the rollout policy captures local short-term effects through fast simulations.
A hyperparameter $\lambda$ balances $v_\theta(s_L)$ and the rollout result $z_L$, so that both local simulations and global evaluation are taken into account. The evaluation function is:

$$V(s_L) = (1 - \lambda)\, v_\theta(s_L) + \lambda\, z_L$$
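A one-line sketch of this mixed leaf evaluation; the names `value_net_estimate` and `rollout_result` are illustrative, not AlphaGo's actual code.

```python
def evaluate_leaf(value_net_estimate, rollout_result, lam=0.5):
    """AlphaGo-style mixed leaf evaluation: V(s_L) = (1 - lambda) * v_theta(s_L) + lambda * z_L."""
    return (1 - lam) * value_net_estimate + lam * rollout_result
```

With `lam = 0` the search trusts only the value network; with `lam = 1` it trusts only the rollout outcome.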
Backpropagation#
Over $n$ MCTS simulations, the node visit count is updated as (where $\mathbf{1}(s, a, i)$ is the indicator function, equal to 1 if simulation $i$ traversed edge $(s, a)$):

$$N(s, a) = \sum_{i=1}^{n} \mathbf{1}(s, a, i)$$

The Q value, i.e., the long-term expected return of executing action $a$ at node $s$, is updated as:

$$Q(s, a) = \frac{1}{N(s, a)} \sum_{i=1}^{n} \mathbf{1}(s, a, i)\, V\big(s_L^i\big)$$
Summary#
- Structural Innovation:
  - The policy network provides prior knowledge
  - The value network provides global evaluation
  - MCTS provides tactical validation
- Training Innovation:
  - Starting from supervised learning
  - Surpassing the teacher through reinforcement learning
  - Self-play generates new knowledge
- MCTS Improvements:
  - Using neural networks to guide the search
  - The policy network provides prior probabilities for exploration directions, and the value network improves the accuracy of leaf node evaluations.
  - This combination of value network and rollout evaluations not only reduces search width and depth but also significantly enhances overall performance.
  - Efficient exploration-exploitation balance
This design allows AlphaGo to find efficient solutions in a vast search space, ultimately surpassing human levels.
Heuristic Search and MCTS#
MCTS is like an ever-evolving explorer, searching for the best path in a decision tree.
Core Idea#
What is the essence of MCTS? Simply put, it is a "learn while playing" process. Imagine you are playing a brand new board game:
- At first, you try various possible moves (exploration)
- Gradually, you discover that some moves work better (exploitation)
- You strike a balance between exploring new strategies and exploiting known good strategies
This is precisely what MCTS does, but it systematizes this process mathematically. It is a rollout algorithm that guides strategy selection through accumulated Monte Carlo simulation value estimates.
Algorithm Process#
The video Monte Carlo Tree Search - YouTube explains the MCTS process very well.
The elegance of MCTS lies in its four simple yet powerful steps, which I will introduce using two perspectives, A and B:
- Selection
  - A: Do you know how children learn? They always hover between the known and the unknown. MCTS is similar: starting from the root node, it uses the UCB (Upper Confidence Bound) formula to weigh whether to choose a known good path or explore new possibilities.
  - B: Starting from the root node, a successor node is selected from the current node according to a specific strategy. This strategy is usually based on the tree's search history, selecting the most promising path. For example, at each node we choose an action $a$ according to the current strategy, balancing exploration and exploitation, and gradually descend deeper.
- Expansion
  - A: Like an explorer opening up new areas on a map, when we reach a leaf node, we expand downwards, creating new possibilities.
  - B: When the selected node still has unexplored child nodes, we expand new nodes based on the set of possible actions at this node. The purpose of this step is to increase the breadth of the decision tree, gradually generating possible decision paths. This expansion ensures that the search covers more possible state-action pairs $(s, a)$.
- Simulation
  - A: This is the most interesting part. Starting from the new node, we play a "hypothetical" game until it ends. It's like playing chess in your head: "if I move this way, my opponent will move that way...".
  - B: Starting from the newly expanded node, we run a random simulation (rollout), sampling along the current strategy in the MDP environment until a terminal state is reached. This provides a return estimate from the current node to the end, giving numerical evidence for the quality of the path.
- Backpropagation
  - A: Finally, we pass the result back along the path, updating the statistics of each node. It's like saying, "Hey, I tried this path, and it worked pretty well (or not so well)."
  - B: After the simulation finishes, its estimated return is backpropagated to all the traversed nodes to update their values. This accumulates historical return information, so that future choices trend more accurately toward high-reward paths.
UCB1: The Perfect Balance of Exploration and Exploitation#
Here, the UCB1 formula, the soul of MCTS, deserves special mention:

$$UCB1_j = \bar{X}_j + c \sqrt{\frac{\ln n}{n_j}}$$

Let's break it down:
- $\bar{X}_j$ is the average return of option $j$ (exploitation term)
- $\sqrt{\ln n / n_j}$ is the uncertainty measure (exploration term), where $n$ is the total number of trials and $n_j$ is the number of times option $j$ has been tried
- $c$ is the exploration parameter
Like a good investment portfolio, it should focus on known good opportunities while remaining open to new ones (exploration-exploitation trade-off).
The video Best Multi-Armed Bandit Strategy? (feat: UCB Method) explains the multi-armed bandit problem and the UCB method very well; I borrow that teacher's example here:
A Foodie Trying to Understand UCB#
Imagine you just arrived in a city with 100 restaurants to choose from, and you have 300 days. Each day you must pick one restaurant to dine in, hoping to maximize the average quality of your meals over these 300 days.
ε-greedy Strategy: Simple but Not Smart#
This is like deciding by rolling dice:
- 90% of the time (ε=0.1): Go to the known best restaurant (exploitation)
- 10% of the time: Randomly try a new restaurant (exploration)
import random
import numpy as np

def epsilon_greedy(restaurants, ratings, epsilon=0.1):
    if random.random() < epsilon:
        return random.choice(restaurants)        # Exploration: try a random restaurant
    else:
        return restaurants[np.argmax(ratings)]   # Exploitation: go to the best-rated one
The shortcomings of this approach:
- Exploration is completely random, possibly revisiting known bad restaurants
- The exploration ratio is fixed and does not adjust over time
- It does not consider the impact of visit counts
UCB Strategy: Smarter Choice Trade-offs#
In restaurant selection, the UCB formula reads:

$$\text{UCB}_i = \bar{X}_i + c \sqrt{\frac{\ln t}{n_i}}$$

where $\bar{X}_i$ is restaurant $i$'s average rating, $t$ is the current day, $n_i$ is how many times you have visited it, and $c$ is the exploration parameter (here $c = 2$, matching the code below).

For example, consider two restaurants on day 100:
Restaurant A:
- Visited 50 times, average rating 4.5
- UCB score = $4.5 + 2\sqrt{\ln 100 / 50} \approx 4.5 + 0.61 = 5.11$
Restaurant B:
- Visited 5 times, average rating 4.0
- UCB score = $4.0 + 2\sqrt{\ln 100 / 5} \approx 4.0 + 1.92 = 5.92$
Although Restaurant B has a lower average score, its fewer visits lead to higher uncertainty, so UCB gives it a higher exploration reward.
Code implementation:
import math
import numpy as np

class Restaurant:
    def __init__(self, name):
        self.name = name
        self.total_rating = 0   # Sum of all ratings received so far
        self.visits = 0         # Number of visits so far

def ucb_choice(restaurants, total_days, c=2):
    # Ensure each restaurant is visited at least once
    for r in restaurants:
        if r.visits == 0:
            return r
    # Use the UCB formula to choose a restaurant
    scores = []
    for r in restaurants:
        avg_rating = r.total_rating / r.visits
        exploration_term = c * math.sqrt(math.log(total_days) / r.visits)
        ucb_score = avg_rating + exploration_term
        scores.append(ucb_score)
    return restaurants[np.argmax(scores)]
Why UCB is Better?#
- Adaptive Exploration
  - Restaurants with fewer visits receive higher exploration rewards
  - As total days increase, the exploration term gradually decreases, allowing better exploitation
- Balanced Time Investment
  - It does not waste too much time on clearly inferior restaurants
  - It reasonably allocates visit counts among restaurants with similar potential
- Theoretical Guarantee
  - The regret bound (the gap from the optimal choice) grows logarithmically over time
  - 300 days of exploration is sufficient to find the best few restaurants
Returning to MCTS:
Why is MCTS So Powerful?#
- Efficiently Handles Combinatorial Explosion: MCTS does not need to exhaustively search all possibilities like Minimax but focuses on the most promising branches, allowing it to effectively handle problems with large branching factors.
- Adaptive Search: MCTS dynamically adjusts its search strategy, allocating more resources to more promising areas, thus finding good solutions faster.
- Balances Exploration and Exploitation: Through the UCB formula, MCTS strikes a balance between exploring new possibilities and exploiting known good choices, avoiding local optima.
- No Domain Knowledge Required: MCTS does not rely on specific domain expertise, learning solely from game rules and simulation results, making it widely applicable.
- Can Stop Anytime: MCTS is an "anytime" algorithm; it can be interrupted at any moment and return the current best solution, which is crucial for real-time applications.
AlphaZero: From MCTS to Self-Evolution#
AlphaZero is a general reinforcement learning algorithm launched by DeepMind after AlphaGo, capable of learning from scratch and ultimately surpassing professional levels through self-play without using human game records.
AlphaZero improves traditional MCTS by introducing neural networks to guide the search:
- Policy Priors: Using neural networks to predict prior probabilities for each action, making the search more efficient.
- Value Evaluation: At leaf nodes, using the value predictions from neural networks instead of random simulations to reduce computational costs.
Below, we will implement an AI that can learn in this way using Gomoku as an example. Here, we learn from and reference the excellent implementation of schinger/alphazero.py.
Designing the Game Environment#
Before implementing AlphaZero, we need to define the game environment.
Defining the Game Interface#
In AlphaZero, the neural network receives the current board state and outputs a policy vector representing the probabilities of each action, as well as a scalar value representing the current player's win rate prediction.
To ensure that MCTS and the neural network can universally interact with the game, we need to define a consistent game interface.
class GomokuGame:
def __init__(self, n=15):
self.n = n # Board size, default 15x15
def getInitBoard(self):
"""
Returns the initial board state, all positions are empty.
"""
b = Board(self.n)
return np.array(b.pieces)
def getBoardSize(self):
"""
Returns the size of the board, i.e., (n, n).
"""
return (self.n, self.n)
def getActionSize(self):
"""
Returns the total number of actions, here it is n * n, as each square can be an action.
"""
return self.n * self.n
def getNextState(self, board, player, action):
"""
Executes an action, returning the next board state and the next player.
Parameters:
- board: Current board state
- player: Current player (1 or -1)
- action: Current action, 0 ~ n*n-1
Returns:
- (next_board, next_player): The board and next player after executing the action
"""
b = Board(self.n)
b.pieces = np.copy(board)
# Convert action to coordinates (x, y)
move = (action // self.n, action % self.n)
b.execute_move(move, player)
return (b.pieces, -player)
- Unified interface: This allows our MCTS and neural network to be reused across different games, requiring only the implementation of the specific game logic.
- Board representation: Using a two-dimensional array to represent the board, facilitating processing and visualization.
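The MCTS and self-play code later in this article also calls a few more interface methods (getValidMoves, getCanonicalForm, getSymmetries, getGameEnded, stringRepresentation). As a rough sketch, here is what some of them might look like for Gomoku; these are illustrative versions of my own, and the referenced repository's implementations may differ (getGameEnded, which checks for five in a row, is omitted for brevity):

```python
    # (continuing class GomokuGame)

    def getValidMoves(self, board, player):
        """Binary vector of length n*n: 1 where a stone can still be placed, 0 otherwise."""
        return (np.asarray(board).reshape(-1) == 0).astype(np.uint8)

    def getCanonicalForm(self, board, player):
        """Board from the current player's point of view: own stones are always +1."""
        return player * np.asarray(board)

    def stringRepresentation(self, board):
        """Hashable key used by MCTS to cache per-state statistics."""
        return np.asarray(board).tobytes()
```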
Core Algorithm Implementation#
Dual-Head Neural Network#
Network Structure#
In AlphaZero, we use a unified dual-head neural network to predict both the policy and the value simultaneously. The network receives the current board state $s$ and computes two outputs:
• $\mathbf{p}_\theta(s)$: the probability distribution output by the policy head, representing the selection probability of each possible action in state $s$. Combined with MCTS search, this policy produces stronger decisions.
• $v_\theta(s)$: the value head outputs a scalar representing the predicted value of the current board state (the probability of winning).
import torch
import torch.nn as nn
import torch.nn.functional as F
class AlphaZeroNNet(nn.Module):
def __init__(self, game, args):
"""
Parameters:
- game: Game object providing information such as board size and action space size.
- args: Parameters containing the network structure, such as the number of channels, dropout rate, etc.
"""
super(AlphaZeroNNet, self).__init__()
self.board_x, self.board_y = game.getBoardSize() # Board size
self.action_size = game.getActionSize() # Action space size
self.args = args
# Convolutional layer blocks
self.conv_layers = nn.Sequential(
nn.Conv2d(1, args.num_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(args.num_channels),
nn.ReLU(),
nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(args.num_channels),
nn.ReLU(),
nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3),
nn.BatchNorm2d(args.num_channels),
nn.ReLU(),
nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3),
nn.BatchNorm2d(args.num_channels),
nn.ReLU(),
)
# Calculate the output size of the convolutional layers
conv_output_size = self._get_conv_output_size()
# Fully connected layer blocks
self.fc_layers = nn.Sequential(
nn.Linear(conv_output_size, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(),
nn.Dropout(args.dropout),
nn.Linear(1024, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Dropout(args.dropout),
)
# Policy head: Outputs the log probabilities of each action
self.policy_head = nn.Linear(512, self.action_size)
# Value head: Outputs the value estimate of the current state
self.value_head = nn.Linear(512, 1)
def _get_conv_output_size(self):
"""
Calculates the number of features output by the convolutional layers.
"""
dummy_input = torch.zeros(1, 1, self.board_x, self.board_y)
output_feat = self.conv_layers(dummy_input)
n_size = output_feat.view(1, -1).size(1)
return n_size
def forward(self, s):
"""
Forward propagation function.
Parameters:
- s: Input board state, shape (batch_size, board_x, board_y).
Returns:
- log_policies: Log probabilities of the policy, shape (batch_size, action_size).
- values: Value estimates, shape (batch_size, 1).
"""
# Adjust input shape
s = s.view(-1, 1, self.board_x, self.board_y) # (batch_size, 1, board_x, board_y)
# Convolutional layers extract features
s = self.conv_layers(s) # (batch_size, num_channels, new_board_x, new_board_y)
# Flatten the output of the convolutional layers
s = s.view(s.size(0), -1) # (batch_size, conv_output_size)
# Fully connected layers extract high-level features
s = self.fc_layers(s) # (batch_size, 512)
# Policy head output
policies = self.policy_head(s) # (batch_size, action_size)
log_policies = F.log_softmax(policies, dim=1)
# Value head output
values = self.value_head(s) # (batch_size, 1)
values = torch.tanh(values) # Limit the value to [-1, 1]
return log_policies, values
- Convolutional layers extract features: Through multiple convolutional layers, spatial features of the board are extracted.
- Fully connected layers output: The features are flattened and passed through fully connected layers to output policy and value separately.
- Activation functions:
  - The policy output uses `log_softmax`, facilitating subsequent cross-entropy loss calculations.
  - The value output uses `tanh`, limiting the value to $[-1, 1]$.
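As a quick sanity check, here is a hedged usage sketch that runs the network once on an empty 15x15 board. The `args` values are made-up hyperparameters, not the ones used for training, and it assumes `getInitBoard` returns an all-zero 15x15 array.

```python
from types import SimpleNamespace

game = GomokuGame(n=15)
args = SimpleNamespace(num_channels=64, dropout=0.3)   # illustrative values only
net = AlphaZeroNNet(game, args)
net.eval()   # eval mode so BatchNorm works with a single sample

board = torch.tensor(game.getInitBoard(), dtype=torch.float32)   # (15, 15), all zeros
with torch.no_grad():
    log_p, v = net(board.unsqueeze(0))   # add a batch dimension -> (1, 15, 15)
print(log_p.shape, v.item())   # torch.Size([1, 225]) and a scalar in [-1, 1]
```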
Loss Function#
The training objective is defined through the following loss function, training the network's policy and value estimation simultaneously:

$$L = (z - v)^2 - \boldsymbol{\pi}^\top \log \mathbf{p} + c\,\lVert\theta\rVert^2$$

• $(z - v)^2$: value loss, requiring the network output $v$ to be as close as possible to the game result $z$ (win is +1, loss is -1, draw is 0).
• $-\boldsymbol{\pi}^\top \log \mathbf{p}$: policy loss; through cross-entropy, the network's output policy $\mathbf{p}$ should match the improved policy $\boldsymbol{\pi}$ obtained from MCTS as closely as possible.
• $c\,\lVert\theta\rVert^2$: L2 regularization term to prevent overfitting, keeping the scale of the model parameters appropriate.
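A minimal sketch of how this loss might be computed in the training step, given the network's `log_softmax` policy output. The function name and variable names are illustrative; in practice the L2 term is usually handled via the optimizer's `weight_decay` rather than written out explicitly.

```python
import torch
import torch.nn.functional as F

def alphazero_loss(log_policies, values, target_pis, target_zs):
    """log_policies: (B, action_size) log-probabilities from the policy head
       values:       (B, 1) tanh outputs from the value head
       target_pis:   (B, action_size) MCTS visit-count policies
       target_zs:    (B,) game outcomes in {-1, 0, +1} from the player's perspective"""
    value_loss = F.mse_loss(values.view(-1), target_zs)                           # (z - v)^2
    policy_loss = -torch.sum(target_pis * log_policies) / target_pis.size(0)      # cross-entropy with pi
    return value_loss + policy_loss   # the L2 term comes from weight_decay in the optimizer
```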
MCTS Class#
class MCTS:
def __init__(self, game, nnet, args):
self.game = game # Game environment
self.nnet = nnet # Neural network
self.args = args # Parameters
self.Qsa = {} # Store Q values: Q(s,a)
self.Nsa = {} # Store edge visit counts: N(s,a)
self.Ns = {} # Store node visit counts: N(s)
self.Ps = {} # Store policy priors: P(s,a)
self.Es = {} # Store game end information: E(s)
self.Vs = {} # Store valid actions: V(s)
One highlight here is the use of a caching mechanism: using dictionaries to cache computation results, avoiding redundant calculations and improving efficiency.
Valid Action Mask#
In many games, certain actions are illegal in specific states. For example, in Gomoku, if a position is already occupied, placing a piece there is illegal. Therefore, ensuring that only valid actions are considered is crucial for the correctness and efficiency of the algorithm. This is achieved through the valid action mask.
In the `search` method, when we reach a leaf node and need the neural network's prediction, we process the policy output from the network, using the valid-move mask to mask out illegal actions.
if s not in self.Ps: # Not in the stored policy prior P(s,a)
# Leaf node, expand based on the policy output from the neural network
self.Ps[s], v = self.nnet.predict(canonicalBoard)
valids = self.game.getValidMoves(canonicalBoard, 1) # Get the mask for valid actions, 1 for valid
self.Ps[s] = self.Ps[s] * valids # Mask illegal actions
sum_Ps_s = np.sum(self.Ps[s])
if sum_Ps_s > 0:
self.Ps[s] /= sum_Ps_s # Normalize policy probabilities
else:
# All actions are masked, perform uniform distribution
self.Ps[s] = self.Ps[s] + valids
self.Ps[s] /= np.sum(self.Ps[s])
self.Vs[s] = valids # Store the mask for valid actions
self.Ns[s] = 0 # Initialize state visit count
return v
PUCT#
A variant of the UCB formula specifically for MCTS.
In AlphaZero, MCTS uses the outputs of the neural network to guide the search direction. During the search, actions are selected with the improved upper confidence bound formula PUCT (Polynomial Upper Confidence Trees), which introduces the prior probability $P(s, a)$, allowing MCTS to leverage external knowledge (the policy network) to guide the search:

$$U(s, a) = Q(s, a) + c_{\text{puct}}\, P(s, a)\, \frac{\sqrt{N(s)}}{1 + N(s, a)}$$

Where:
- $Q(s, a)$ is the average value of choosing action $a$ in state $s$.
- $P(s, a)$ is the prior probability provided by the neural network.
- $N(s)$ and $N(s, a)$ are the visit counts of state $s$ and edge $(s, a)$, respectively.
- $c_{\text{puct}}$ is a constant controlling the degree of exploration; together with the exploration term it encourages selecting actions with high prior probability but few visits.
cur_best = -float('inf') # Record the current maximum U value
best_act = -1 # Record the action with the maximum U value
# Iterate through all possible actions
for a in range(self.game.getActionSize()):
if valids[a]:
# For valid actions, calculate the PUCT value
if (s, a) in self.Qsa:
# If (s, a) has been visited before, use the calculated Q value and visit count
u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * \
math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
else:
# If it is an unvisited node, the Q value is 0, encouraging exploration
u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + 1e-8) # Give higher U value to unvisited actions, avoiding division by zero
# Select the action with the highest U value
if u > cur_best:
cur_best = u
best_act = a
Balancing Exploration and Exploitation#
- Unvisited nodes have `N(s, a) = 0`, so the denominator is small and the exploration term is large, encouraging the algorithm to explore new actions.
- `Q(s, a)` reflects the current estimate of the action's value, representing exploitation.
Value Inversion#
From player 1's perspective, a position that is more valuable for player 1 is correspondingly less valuable for player 2, and vice versa. Therefore, throughout the search and value evaluation, the value must always be expressed from the current player's point of view to keep the perspective consistent.
In the code, this implementation is reflected in the following call:
v = -self.search(next_s)
Search Function#
def search(self, canonicalBoard):
"""
Executes an MCTS search.
Parameters:
- canonicalBoard: The current player's board state (canonical form)
Returns:
- v: Value estimate for the current node
"""
s = self.game.stringRepresentation(canonicalBoard)
if s not in self.Es:
self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
if self.Es[s] is not None:
# If the game has ended, return the result
return self.Es[s]
'''
Valid Mask & Leaf Node Expansion
'''
valids = self.Vs[s]
'''
Use the PUCT formula to select the optimal action
'''
a = best_act
next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
next_s = self.game.getCanonicalForm(next_s, next_player)
# The value of the next state is inverted
v = -self.search(next_s)
# Update Q value and visit count
if (s, a) in self.Qsa:
self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
self.Nsa[(s, a)] += 1
else:
self.Qsa[(s, a)] = v
self.Nsa[(s, a)] = 1
self.Ns[s] += 1
return v
- Through recursive deep searches, it simulates future possibilities.
- At leaf nodes, it uses the neural network for predictions to speed up the search.
- Since players alternate, the value of the next state is inverted.
Self-Play#
AlphaZero requires no human game records for training; all of its data is obtained through self-play.
Self-play means that for each move, several MCTS simulations are executed to obtain the improved strategy for selecting moves. Players alternate between black and white until the game ends. Self-play generates a large amount of training data, helping the neural network learn better strategies and value estimates.
Implementing "Executing a Self-Play Game"#
First, let's look at the core function of Self-Play, `executeEpisode`, which plays an entire game to completion and collects training data.
def executeEpisode(self):
"""
Executes a self-play game, returning a list of training samples.
Returns:
trainExamples: A list containing multiple (canonicalBoard, pi, v) tuples.
"""
trainExamples = []
board = self.game.getInitBoard() # Initialize the board
self.curPlayer = 1 # Current player (1 or -1)
episodeStep = 0 # Record the step count
while True: # Until the game ends
episodeStep += 1
canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
temp = int(episodeStep < self.args.tempThreshold) # Temperature parameter
pi = self.mcts.getActionProb(canonicalBoard, temp=temp) # Use MCTS to get the action probability distribution (policy)
sym = self.game.getSymmetries(canonicalBoard, pi) # Data augmentation (symmetry)
for b, p in sym:
trainExamples.append([b, self.curPlayer, p, None])
action = np.random.choice(len(pi), p=pi) # Choose an action according to the policy probabilities
board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action) # Execute the action, updating the board and current player.
r = self.game.getGameEnded(board, self.curPlayer) # Check if the game has ended
if r is not None:
# Assign final value to each training sample, +1 for the winning player, -1 for the losing player, 0 for a draw
            return [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer))) for x in trainExamples]
- Temperature parameter: controls how exploratory the policy is. Early in the game it is set to 1 to encourage exploration; later it is set to 0 to favor exploiting the best-known moves.
- Data augmentation: in Gomoku, rotating or flipping the board does not change the nature of the position. Exploiting these symmetries generates equivalent boards and policies, increasing the training data and improving the model's generalization (see the sketch below).
When the game ends, training samples are generated from the game record and returned as `trainExamples`.
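A hedged sketch of what `getSymmetries` can look like for Gomoku: the 8 symmetries of the square board (4 rotations, each with an optional mirror) are applied to both the board and the policy vector. The exact version in the referenced repository may differ.

```python
    # (continuing class GomokuGame)

    def getSymmetries(self, board, pi):
        """Returns a list of (board, pi) pairs that are equivalent under rotation/reflection."""
        pi_board = np.reshape(pi, (self.n, self.n))   # works because action size is n*n
        symmetries = []
        for k in range(4):                  # 4 rotations
            for flip in (False, True):      # with and without mirroring
                new_b = np.rot90(board, k)
                new_pi = np.rot90(pi_board, k)
                if flip:
                    new_b = np.fliplr(new_b)
                    new_pi = np.fliplr(new_pi)
                symmetries.append((new_b, list(new_pi.ravel())))
        return symmetries
```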
Main Loop of Self-Play#
Self-play is not just about executing one game; we need to continuously engage in self-play over multiple iterations, updating the model.
def learn(self):
"""
Conduct multiple iterations of self-play and model updates.
"""
for i in range(1, self.args.numIters + 1):
log.info(f"Starting Iter #{i} ...")
# Store training samples from this iteration
iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)
# Conduct the specified number of self-plays
for _ in tqdm(range(self.args.numEps), desc="Self Play"):
self.mcts = MCTS(self.game, self.nnet, self.args) # Reset MCTS
iterationTrainExamples += self.executeEpisode()
# Save training samples
self.trainExamplesHistory.append(iterationTrainExamples)
# Maintain a fixed-length history of training samples
if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
log.warning(f"Removing the oldest entry in trainExamples.")
self.trainExamplesHistory.pop(0)
# Merge and shuffle all training samples
trainExamples = []
for e in self.trainExamplesHistory:
trainExamples.extend(e)
shuffle(trainExamples)
# Train the neural network
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename="temp.pth.tar")
self.pnet.load_checkpoint(folder=self.args.checkpoint, filename="temp.pth.tar")
pmcts = MCTS(self.game, self.pnet, self.args)
self.nnet.train(trainExamples)
nmcts = MCTS(self.game, self.nnet, self.args)
- If the history length exceeds the set threshold, the oldest sample is removed to ensure the model does not overfit to old data.
- After collecting enough training data, the model is trained with `self.nnet.train(trainExamples)`.
Filtering Mechanism#
To ensure the model's playing strength continues to improve, AlphaGo Zero introduces a filtering mechanism for new and old models:
- After each training round, the new model is pitted against the old model.
- If the new model's win rate exceeds a set threshold (e.g., 55%), the new model is accepted, and the next round of self-play continues; otherwise, the weights of the old model are restored to continue generating data and training based on the old model, avoiding meaningless degradation and overfitting issues.
# Evaluate the new model
log.info("PITTING AGAINST PREVIOUS VERSION")
arena = game.Arena(lambda x: np.argmax(pmcts.getActionProb(x, temp=0)),
lambda x: np.argmax(nmcts.getActionProb(x, temp=0)),
self.game)
pwins, nwins, draws = arena.playGames(self.args.arenaCompare)
log.info(f"NEW/PREV WINS : {nwins} / {pwins} ; DRAWS : {draws}")
if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold:
log.info("REJECTING NEW MODEL")
self.nnet.load_checkpoint(folder=self.args.checkpoint, filename="temp.pth.tar")
else:
log.info("ACCEPTING NEW MODEL")
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename="best.pth.tar")
Obtaining Strategy from MCTS#
During self-play, we use MCTS to select actions. The `getActionProb` function retrieves the probability distribution over actions from MCTS.
def getActionProb(self, canonicalBoard, temp=1):
"""
Executes multiple MCTS searches and returns the probability distribution of actions.
Parameters:
canonicalBoard: The current normalized board state
temp: Temperature parameter
Returns:
probs: Probability distribution of actions
"""
# Perform the specified number of MCTS searches
for _ in range(self.args.numMCTSSims):
self.search(canonicalBoard)
s = self.game.stringRepresentation(canonicalBoard)
# Record the visit counts for each action N(s, a)
counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0
for a in range(self.game.getActionSize())]
if temp == 0:
bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
bestA = np.random.choice(bestAs)
probs = [0] * len(counts)
probs[bestA] = 1
return probs
    counts = [x ** (1. / temp) for x in counts]
counts_sum = float(sum(counts))
probs = [x / counts_sum for x in counts]
return probs
The visit count `N(s, a)` reflects MCTS's confidence in an action.
Calculating the policy probabilities:
- When `temp == 0`:
  - The most-visited action is chosen with probability 1.
  - This makes the policy deterministic, and is typically used in the later stages of a game or during evaluation.
- When `temp > 0`:
  - The visit counts undergo a temperature transformation: `counts = [x ** (1. / temp) for x in counts]`.
  - Higher temperatures push the policy closer to a uniform distribution, enhancing exploration.
Training Process#
With 400 MCTS simulations per move, cpuct = 1.0, and 50 training iterations:
At this point, the model has learned the rules of Gomoku and plays at a beginner-AI level; you can still sneak a win along the edge of the board.
Adding 1cycle Learning Rate Adjustment#
Dynamically adjusting the learning rate can improve the model's convergence speed and generalization ability.
- First phase (first half-cycle): the learning rate gradually increases from the minimum value (`min_lr`) to the maximum value (`max_lr`).
- Second phase (second half-cycle): the learning rate gradually decreases back to the minimum value.
class NNetWrapper:
def __init__(self, game, args):
...
self.optimizer = optim.Adam(self.nnet.parameters(), lr=args.max_lr)
# 1cycle learning rate parameters
self.total_steps = args.numIters * args.epochs * (args.maxlenOfQueue // args.batch_size)
self.current_step = 0
def get_learning_rate(self):
"""Implement 1cycle learning rate strategy"""
if self.current_step >= self.total_steps:
return self.args.min_lr
half_cycle = self.total_steps // 2
if self.current_step <= half_cycle:
# First phase: Increase from min_lr to max_lr
phase = self.current_step / half_cycle
lr = self.args.min_lr + (self.args.max_lr - self.args.min_lr) * phase
else:
# Second phase: Decrease from max_lr to min_lr
phase = (self.current_step - half_cycle) / half_cycle
lr = self.args.max_lr - (self.args.max_lr - self.args.min_lr) * phase
return lr
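The computed learning rate still has to be pushed into the optimizer on every training step. A minimal sketch of how the training loop might do this, using a hypothetical helper `_apply_learning_rate` on the same wrapper class:

```python
    def _apply_learning_rate(self):
        """Set the current 1cycle learning rate on the optimizer and advance the step counter."""
        lr = self.get_learning_rate()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
        self.current_step += 1
        return lr
```

It would be called once per batch, right before `self.optimizer.step()`.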
Adding Gradient Clipping#
(You may recall this trick from PPO.) To prevent exploding gradients and stabilize training, we add gradient clipping after backpropagation.
Assume the gradient vector of the parameters is $g$, its current L2 norm is $\lVert g \rVert$, and the clipping threshold is $c$:
- If $\lVert g \rVert \le c$, then $g$ remains unchanged.
- If $\lVert g \rVert > c$, then $g$ is rescaled to $g \cdot \dfrac{c}{\lVert g \rVert}$.

This ensures that $\lVert g \rVert \le c$:
def train(self, examples):
...
for epoch in range(self.args.epochs):
...
for _ in t:
...
# Calculate loss and backpropagate
total_loss.backward()
# Gradient clipping
if self.args.grad_clip:
torch.nn.utils.clip_grad_norm_(self.nnet.parameters(), self.args.grad_clip)
# Update parameters
self.optimizer.step()
...
Gradient clipping limits the norm of the gradient to a specified threshold, preventing instability in training due to excessively large gradients. By clipping the gradients before each parameter update, we ensure the robustness of the training process and accelerate the model's convergence.
Conclusion#
The final code and demo can be found in the repository below; the model weights are linked from the Hugging Face repository referenced in its documentation.
After these improvements, the number of MCTS simulations was increased to 4000 with cpuct = 4.0, and the model was trained for 5 more iterations. Our model gradually moved up to intermediate difficulty (it handles early-game pressure, though it may overlook the opponent's winning points in the mid-to-late game) and ultimately reached the level of a hard AI.