Nagi-ovo

Breezing homepage: [nagi.fun](nagi.fun)

Let's build AlphaZero

This article distills my understanding of articles, video tutorials, and code implementations such as "Sunrise: Understanding AlphaZero, MCTS, Self-Play, UCB from Scratch".

Starting from the design principles of AlphaGo, it works step by step toward building a Gomoku AI that can surpass human players, through a close look at its two core mechanisms: MCTS and self-play.

AlphaGo: From Imitation to Superiority

Evolution Path of AlphaGo

The number of possible Go positions exceeds the number of atoms in the universe, rendering traditional exhaustive search methods completely ineffective. AlphaGo addresses this issue through a phased approach: first learning human knowledge, then continuously evolving through self-play.

This evolutionary process can be divided into three stages:

  1. Imitating human experts
  2. Self-play enhancement
  3. Learning to evaluate situations

Core Components

image

Rollout Policy

The leftmost network in the figure is the rollout policy $p_\pi$, a lightweight policy used for fast rollouts: it is less accurate but much cheaper to evaluate.

SL Policy Network

The supervised learning policy network $P_\sigma$ learns to play by imitating human game records:

  • Input: Board state
  • Output: Probability distribution of the next move imitating human experts
  • Training data: 160,000 games, approximately 30 million moves

RL Policy Network

This stage resembles what happens once your skill reaches a certain level: you start reviewing your games and playing against yourself, discovering new tactics and deeper strategies. The RL network $P_\rho$ is initialized from the SL network and then improved through self-play, which lets it go beyond imitation and find strong moves that do not appear in human game records.

  • This stage generates a large amount of self-play data. The input is still the board state, and the output is a move policy improved through reinforcement learning.
  • The network plays against randomly selected earlier versions of the policy network and adjusts its parameters based on the win/loss signal, reinforcing the moves that lead to victory by maximizing:
$$\sum_{t} z(\tau) \log p(a_t \mid s_t; \rho)$$

Where:

  • $\tau$ is the game sequence $(s_0, a_0, s_1, a_1, \dots)$
  • $z(\tau)$ is the win/loss label (positive for a win, negative for a loss)
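To make the update concrete, here is a minimal PyTorch sketch of one gradient-ascent step on this objective. It assumes a toy setting that is not from the original: a single state, so the policy is simply `softmax(theta)` over three actions, and `reinforce_step` is a hypothetical helper name. With $z = +1$ the actions that were taken become more likely; with $z = -1$ they become less likely.

```python
import torch

def reinforce_step(theta, actions, z, lr=0.1):
    """One gradient-ascent step on z * sum_t log p(a_t | s_t; theta).

    Toy assumption: there is a single state, so the policy is softmax(theta)
    and `actions` holds the indices of the moves played in the game.
    """
    log_p = torch.log_softmax(theta, dim=0)
    objective = z * log_p[actions].sum()
    objective.backward()
    with torch.no_grad():
        theta += lr * theta.grad  # ascent: move toward a higher objective
        theta.grad.zero_()
    return objective.item()
```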

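Value Network

The value network $v_\theta$ learns to evaluate positions; like the policy networks, it is a CNN:

  • Training objective: minimize the squared error
$$\sum_{i} \left(z_i - v_\theta(s_i)\right)^2$$
  • $z_i$: final win/loss result of the game containing position $s_i$
  • $v_\theta(s_i)$: predicted outcome for position $s_i$, a scalar in $[-1, 1]$ that serves as a win-probability estimate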

Supplementary Explanation

About the relationship between $P_\rho$ and $v_\theta$:

  • The policy network provides specific move probabilities to guide the search. (In the published AlphaGo, the SL network $P_\sigma$ was actually used for these search priors because its more diverse move choices worked better in MCTS; $P_\rho$ was mainly used to generate the self-play games on which $v_\theta$ is trained.)
  • The value network $v_\theta$ provides position evaluations, reducing unnecessary simulations in the search tree.
  • Combining the two lets MCTS find high-win-rate paths faster and significantly raises the overall playing level.

AlphaGo's MCTS Implementation

Selection Phase

Combining exploration and exploitation:

$$a_t = \arg\max_a \left( Q(s_t, a) + u(s_t, a) \right)$$
$$u(s_t, a) \propto \frac{P(s_t, a)}{1 + N(s_t, a)}$$

Where:

  • $Q(s_t, a)$: Long-term return estimate
  • $u(s_t, a)$: Exploration bonus
  • $P(s_t, a)$: Prior probability output by the policy network
  • $N(s_t, a)$: Visit count
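A minimal NumPy sketch of this selection rule. The helper name `alphago_select` and the constant `c` are assumptions for illustration: the text only states that $u$ is proportional to $P/(1+N)$. The bonus shrinks as an edge's visit count grows, so a high-prior move is tried early but must eventually justify itself through $Q$.

```python
import numpy as np

def alphago_select(Q, P, N, c=5.0):
    """Return argmax_a Q(s,a) + u(s,a), with u(s,a) = c * P(s,a) / (1 + N(s,a)).

    c is a hypothetical proportionality constant chosen for illustration.
    """
    u = c * P / (1.0 + N)  # exploration bonus: large prior, few visits -> large bonus
    return int(np.argmax(Q + u))
```

With few visits the high-prior move wins; once it has been visited often without a good $Q$, the better-valued move takes over.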

Simulation and Evaluation

In the original MCTS algorithm, the simulation phase plays a game from the leaf node (the newly expanded node) to the end using a fast rollout policy, then obtains a reward from the win/loss outcome. This reward is passed back to the nodes in the search tree to update their value estimates (i.e., $Q(s, a)$).

This approach is simple and direct, but the rollout policy is usually random or based on simple rules, so simulation quality can be poor: rollouts capture only short-term information and do not incorporate a global strategic evaluation.

AlphaGo combines the value network $v_\theta$ with the rollout policy in the simulation phase. The value network provides a more efficient leaf evaluation that reflects the whole board, while the rollout policy captures local, short-term effects through rapid simulation.

Here $s_L$ is the leaf node and $z_L$ is the outcome of a fast rollout played from $s_L$. The hyperparameter $\lambda$ balances $v_\theta(s_L)$ and $z_L$, so the evaluation takes both global judgment and local simulation into account (the AlphaGo paper reports that $\lambda = 0.5$ worked best):

$$V(s_L) = (1 - \lambda)\, v_\theta(s_L) + \lambda z_L$$

Backpropagation

Over $n$ MCTS simulations, the visit count of each edge is accumulated, where $\mathbb{I}(s, a, i)$ is an indicator function equal to 1 if edge $(s, a)$ was traversed in the $i$-th simulation and 0 otherwise:

$$N(s, a) = \sum_{i=1}^{n} \mathbb{I}(s, a, i)$$

The Q value, i.e., the expected long-term return of taking action $a$ in state $s$, is the average evaluation of the leaf nodes $s_L^i$ reached through that edge:

$$Q(s, a) = \frac{1}{N(s, a)} \sum_{i=1}^{n} \mathbb{I}(s, a, i)\, V(s_L^i)$$
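The mixed leaf evaluation and the two backup formulas fit in a few lines of Python. This is an illustration under simplifying assumptions: `backup_statistics` is a hypothetical helper, each simulation is given as the set of edges it traversed together with its $v_\theta(s_L)$ and $z_L$, and all values are taken from one fixed player's perspective.

```python
def backup_statistics(simulations, lam=0.5):
    """simulations: list of (edges, v_leaf, z_leaf), one entry per simulation,
    where edges is the set of (s, a) pairs traversed on the way to the leaf.
    Returns dictionaries N[(s, a)] and Q[(s, a)] following the formulas above."""
    N, W = {}, {}
    for edges, v_leaf, z_leaf in simulations:
        V = (1 - lam) * v_leaf + lam * z_leaf  # mixed leaf evaluation V(s_L)
        for edge in edges:                     # indicator I(s, a, i) = 1 for these edges
            N[edge] = N.get(edge, 0) + 1
            W[edge] = W.get(edge, 0.0) + V
    Q = {edge: W[edge] / N[edge] for edge in N}  # average of V over visits
    return N, Q
```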

Summary

  1. Structural Innovation:

    • The policy network provides prior knowledge
    • The value network provides global evaluation
    • MCTS provides tactical validation
  2. Training Innovation:

    • Starting from supervised learning
    • Surpassing the teacher through reinforcement learning
    • Self-play generates new knowledge
  3. MCTS Improvements:

    • Using neural networks to guide the search
    • The Policy Network provides prior probabilities that steer exploration, while the Value Network makes leaf-node evaluations more accurate.
    • The policy priors reduce the search width, and the value network, combined with rollout evaluations, reduces the search depth; together they significantly improve overall performance.
    • Efficient exploration-exploitation balance

This design allows AlphaGo to find strong moves in a vast search space, ultimately surpassing human level.

Heuristic Search and MCTS

MCTS is like an ever-evolving explorer, seeking the best path in a decision tree.

Core Idea

What is the essence of MCTS? Simply put, it is a "learn while playing" process. Imagine you are playing a brand new board game:

  • At first, you will try various possible moves (exploration)
  • Gradually, you discover that certain moves work better (exploitation)
  • Balancing between exploring new strategies and exploiting known good strategies

This is precisely what MCTS does, but it systematizes this process mathematically. It is a rollout algorithm that guides strategy selection by accumulating value estimates from Monte Carlo simulations.

Algorithm Flow

The YouTube video "Monte Carlo Tree Search" explains the MCTS process very well.

The elegance of MCTS lies in its four simple yet powerful steps, which I will describe from two perspectives: A (intuitive) and B (more formal):

  1. Selection

    • A: Do you know how children learn? They always hover between the known and the unknown. MCTS is similar: starting from the root node, it uses the UCB (Upper Confidence Bound) formula to weigh whether to choose a known good path or explore new possibilities.
    • B: Starting from the root node, select a child of the current node according to a specific strategy. This strategy is usually based on the tree's search history and picks the most promising path. For example, at each node we take action $a$ according to the current strategy $\pi(s)$, balancing exploration and exploitation while gradually moving deeper.
  2. Expansion

    • A: Like an explorer opening up new areas on a map, when we reach a leaf node, we will expand downwards, creating new possibilities.
    • B: When the selected node still has unexplored children, we add new nodes under it based on the set of possible actions. This increases the breadth of the decision tree and gradually generates possible decision paths, so the search covers more possible state-action pairs $(s, a)$.
  3. Simulation

    • A: This is the most interesting part. Starting from the new node, we conduct a "hypothetical" game until the game ends. It's like simulating in your mind, "If I move this way, my opponent will move that way..."
    • B: Starting from the newly expanded node, perform random simulations (rollouts), sampling actions from the current strategy $\pi(s)$ in the MDP environment until a terminal state is reached. This yields a return estimate from the current node to the end of the game, giving numerical evidence for the quality of the path.
  4. Backpropagation

    • A: Finally, we return the results along the path, updating the statistics of each node. It's like saying, "Hey, I tried this path, and it worked pretty well (or not so well)."
    • B: After completing the simulation, the estimated return from that simulation is backpropagated to all the nodes traversed to update their values. This process accumulates historical return information, making future choices more accurately trend towards high-reward paths.
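The four steps fit in a short, self-contained Python sketch. To keep it runnable without a Go or Gomoku engine, it uses a toy game that is not from the original article: single-pile Nim, where players alternately take 1 to 3 stones and whoever takes the last stone wins. Selection uses the UCB1 rule, the simulation step uses uniformly random rollouts, and `Node`, `rollout`, and `mcts` are hypothetical names.

```python
import math
import random


class Node:
    """Search-tree node for single-pile Nim (take 1-3 stones; taking the last stone wins)."""

    def __init__(self, stones, parent=None, move=None):
        self.stones = stones              # stones left for the player to move
        self.parent = parent
        self.move = move                  # stones taken to reach this node
        self.children = []
        self.untried = [m for m in (1, 2, 3) if m <= stones]
        self.visits = 0
        self.wins = 0.0                   # wins for the player who made `move`


def ucb1(child, parent_visits, c):
    # Average return (exploitation) + uncertainty bonus (exploration)
    return child.wins / child.visits + c * math.sqrt(math.log(parent_visits) / child.visits)


def rollout(stones):
    """Uniformly random playout; returns 1.0 if the player to move at `stones` wins."""
    turn = 0                              # 0 = the player to move at the start
    while stones > 0:
        stones -= random.randint(1, min(3, stones))
        turn ^= 1
    return 1.0 if turn == 1 else 0.0      # whoever moved last took the last stone


def mcts(root_stones, iterations=2000, c=1.4):
    root = Node(root_stones)
    for _ in range(iterations):
        node = root
        # 1. Selection: follow UCB1 while the node is fully expanded and non-terminal
        while not node.untried and node.children:
            node = max(node.children, key=lambda ch: ucb1(ch, node.visits, c))
        # 2. Expansion: add one unexplored child
        if node.untried:
            m = node.untried.pop()
            child = Node(node.stones - m, parent=node, move=m)
            node.children.append(child)
            node = child
        # 3. Simulation: random rollout from the new node
        result = rollout(node.stones)
        # 4. Backpropagation: update statistics, flipping perspective at each level
        while node is not None:
            node.visits += 1
            node.wins += 1.0 - result     # credit the player who moved into `node`
            result = 1.0 - result         # switch to the parent's player-to-move
            node = node.parent
    # Play the most-visited move
    return max(root.children, key=lambda ch: ch.visits).move
```

From 5 stones the winning move is to take 1 (leaving a multiple of 4). With a couple of thousand iterations, the search should concentrate its visits on that move.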

UCB1: The Perfect Balance of Exploration and Exploitation

Here, we must particularly mention the UCB1 formula, which is the soul of MCTS:

$$\text{UCB1}_i = \bar{X}_i + C \sqrt{\frac{\ln N}{n_i}}$$

Let's break it down:

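  • $\bar{X}_i$ is the average return of child $i$ (exploitation term)
  • $\sqrt{\ln N / n_i}$ is the uncertainty measure (exploration term), where $N$ is the parent node's visit count and $n_i$ is child $i$'s visit count
  • $C$ is the exploration parameter (the classic UCB1 analysis uses $C = \sqrt{2}$)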

Like a good investment portfolio, it must focus on known good opportunities while remaining open to new ones (exploration-exploitation trade-off).

The video "Best Multi-Armed Bandit Strategy? (feat: UCB Method)" explains the multi-armed bandit problem and the UCB method very well, and the following example is borrowed from it:

A Foodie Trying to Understand UCB

Imagine you have just arrived in a city with 100 restaurants to choose from, and you will stay for 300 days. Each day you must pick one restaurant, and you want your meals over these 300 days to be as good as possible on average.

ε-greedy Strategy: Simple but Not Smart

This is like deciding by rolling a die:

  • 90% of the time (ε=0.1): Go to the known best restaurant (exploitation)
  • 10% of the time: Randomly try a new restaurant (exploration)

import random
import numpy as np

def epsilon_greedy(restaurants, ratings, epsilon=0.1):
    # ratings[i] is the current average rating of restaurants[i]
    if random.random() < epsilon:
        return random.choice(restaurants)  # Exploration
    else:
        return restaurants[np.argmax(ratings)]  # Exploitation

Its drawbacks:

  • Exploration is completely random, so it may revisit restaurants already known to be bad
  • The exploration ratio is fixed and does not adjust over time
  • It ignores visit counts, so a rating based on one lucky meal is trusted as much as one based on fifty

UCB Strategy: A Smarter Choice Trade-off

The UCB formula in restaurant selection means:

$$\text{Score} = \text{Average Score} + C \times \sqrt{\frac{\ln(\text{Total Visit Days})}{\text{Visits of This Restaurant}}}$$

For example, with $C = 2$, consider two restaurants on day 100:

Restaurant A:

  • Visited 50 times, average score 4.5
  • UCB score $= 4.5 + 2 \times \sqrt{\ln(100)/50} \approx 4.5 + 0.6 = 5.1$

Restaurant B:

  • Visited 5 times, average score 4.0
  • UCB score $= 4.0 + 2 \times \sqrt{\ln(100)/5} \approx 4.0 + 1.9 = 5.9$

Although Restaurant B has a lower average score, due to fewer visits and higher uncertainty, UCB gives it a higher exploration reward.

Code implementation:

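import math
import numpy as np

class Restaurant:
    def __init__(self, name):
        self.name = name
        self.total_rating = 0  # Sum of all ratings received so far
        self.visits = 0        # Number of times this restaurant was chosen
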
def ucb_choice(restaurants, total_days, c=2):
    # Ensure each restaurant is visited at least once
    for r in restaurants:
        if r.visits == 0:
            return r
            
    # Use UCB formula to choose a restaurant
    scores = []
    for r in restaurants:
        avg_rating = r.total_rating / r.visits
        exploration_term = c * math.sqrt(math.log(total_days) / r.visits)
        ucb_score = avg_rating + exploration_term
        scores.append(ucb_score)
        
    return restaurants[np.argmax(scores)]

Why Is UCB Better?

  1. Adaptive Exploration

    • Restaurants with fewer visits receive higher exploration bonuses
    • As a restaurant accumulates visits, its exploration term shrinks (the logarithm of the total days grows only slowly), so choices shift toward exploitation
  2. Balanced Time Investment

    • Avoid wasting too much time on clearly inferior restaurants
    • Reasonably distribute visit counts among similarly potential restaurants
  3. Theoretical Guarantee

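    • The regret (the gap from always choosing the best restaurant) grows only logarithmically over time
    • With 100 restaurants and 300 days, the first 100 days go to trying each restaurant once, and UCB then concentrates the remaining days on the most promising candidates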
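To compare the two strategies empirically, here is a small simulation sketch. The setup is an assumption for illustration: each restaurant has a hidden true mean rating, every meal's rating is that mean plus Gaussian noise, and `simulate` is a hypothetical helper. It returns the average rating obtained and the visit count per restaurant, so you can see how each strategy spreads its 300 days.

```python
import math
import numpy as np

def simulate(true_means, days=300, strategy="ucb", c=2.0, epsilon=0.1, seed=0):
    """Run one restaurant-choosing strategy; returns (average rating, visits per restaurant)."""
    rng = np.random.default_rng(seed)
    k = len(true_means)
    visits = np.zeros(k, dtype=int)
    totals = np.zeros(k)
    earned = 0.0
    for day in range(1, days + 1):
        if strategy == "ucb":
            if (visits == 0).any():
                choice = int(np.argmax(visits == 0))  # first restaurant not yet tried
            else:
                scores = totals / visits + c * np.sqrt(math.log(day) / visits)
                choice = int(np.argmax(scores))
        else:  # epsilon-greedy
            if rng.random() < epsilon or (visits == 0).all():
                choice = int(rng.integers(k))           # explore uniformly at random
            else:
                avg = np.where(visits > 0, totals / np.maximum(visits, 1), -np.inf)
                choice = int(np.argmax(avg))            # exploit the best-known restaurant
        rating = rng.normal(true_means[choice], 1.0)    # noisy meal rating
        visits[choice] += 1
        totals[choice] += rating
        earned += rating
    return earned / days, visits
```

Comparing the returned averages and visit distributions across a few seeds gives a feel for the trade-off described above.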

Back to MCTS:

image

Why is MCTS So Powerful?

  • Efficiently Handling Combinatorial Explosion: MCTS does not need to exhaustively search all possibilities like Minimax but focuses on the most promising branches, enabling it to effectively handle problems with huge branching factors.
  • Adaptive Search: MCTS dynamically adjusts its search strategy, allocating more resources to more promising areas, thus finding good solutions faster.
  • Balancing Exploration and Exploitation: Through the UCB formula, MCTS balances exploring new possibilities against exploiting known good choices, which reduces the risk of getting stuck in local optima.
  • No Domain Knowledge Required: MCTS does not rely on specific domain expertise but learns solely from game rules and simulation results, making it widely applicable.
  • Can Stop Anytime: MCTS is an "anytime" algorithm: it can be interrupted at any moment and return the best solution found so far, which is crucial for real-time applications.

AlphaZero: From MCTS to Self-Evolution

AlphaZero is a general reinforcement learning algorithm introduced by DeepMind after AlphaGo. It learns from scratch through self-play, without using any human game records, and ultimately surpasses professional level.

AlphaZero improves upon traditional MCTS by introducing neural networks to guide the search:

  • Policy Prior: Uses neural networks to predict the prior probabilities of each action, making the search more efficient.
  • Value Assessment: At leaf nodes, uses the value prediction from the neural network instead of random simulations, reducing computational costs.

Next, we will implement an AI that learns this way, using Gomoku as an example and following the excellent implementation in schinger/alphazero.py.

Designing the Game Environment

Before implementing AlphaZero, we need to define the game environment.

Defining the Game Interface

In AlphaZero, the neural network receives the current board state and outputs a policy vector $\boldsymbol{p}$ giving the probability of each action, as well as a scalar value $v$ that predicts the outcome for the current player.

To enable MCTS and the neural network to interact with the game universally, we need to define a consistent game interface.

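import numpy as np

# Board is a helper class defined elsewhere (not shown here); it stores the
# stones in `pieces` and applies moves via `execute_move`.

class GomokuGame: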
    def __init__(self, n=15):
        self.n = n  # Board size, default 15x15

    def getInitBoard(self):
        """
        Returns the initial board state, all positions are empty.
        """
        b = Board(self.n)
        return np.array(b.pieces)

    def getBoardSize(self):
        """
        Returns the size of the board, i.e., (n, n).
        """
        return (self.n, self.n)

    def getActionSize(self):
        """
        Returns the total number of actions, here it is n * n, as each cell can be an action.
        """
        return self.n * self.n

    def getNextState(self, board, player, action):
        """
        Executes an action and returns the next board state and the next player.

        Parameters:
        - board: Current board state
        - player: Current player (1 or -1)
        - action: Current action, 0 ~ n*n-1

        Returns:
        - (next_board, next_player): The board and next player after executing the action
        """
        b = Board(self.n)
        b.pieces = np.copy(board)
        # Convert action to coordinates (x, y)
        move = (action // self.n, action % self.n)
        b.execute_move(move, player)
        return (b.pieces, -player)

  • Unified interface: Allows our MCTS and neural network to be reused across different games; only the game-specific logic needs to be implemented.
  • Board representation: Uses a two-dimensional array to represent the board, facilitating processing and visualization.
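The class above relies on a `Board` helper and omits several methods that the MCTS and self-play code later calls (`getValidMoves`, `getGameEnded`, `getCanonicalForm`, `stringRepresentation`). Here is a self-contained sketch of the full interface on a plain NumPy array. It is an assumption, not the referenced implementation: `MiniGomoku` is a hypothetical name, `k` (stones in a row needed to win) is an added parameter, and `getGameEnded` returns `None` while the game is ongoing and 1 / -1 / 0 otherwise, matching how the search and self-play code use it.

```python
import numpy as np

class MiniGomoku:
    """Hypothetical sketch of the game interface. Cells: 0 empty, 1 / -1 players."""

    def __init__(self, n=15, k=5):
        self.n = n  # board size
        self.k = k  # stones in a row needed to win

    def getInitBoard(self):
        return np.zeros((self.n, self.n), dtype=np.int8)

    def getBoardSize(self):
        return (self.n, self.n)

    def getActionSize(self):
        return self.n * self.n

    def getNextState(self, board, player, action):
        b = np.copy(board)
        b[action // self.n, action % self.n] = player
        return b, -player

    def getValidMoves(self, board, player):
        # 1 for empty cells (legal), 0 for occupied cells
        return (board.reshape(-1) == 0).astype(np.int8)

    def getCanonicalForm(self, board, player):
        # Seen from `player`'s side, its own stones are always +1
        return board * player

    def stringRepresentation(self, board):
        return board.tobytes()  # hashable key for the MCTS dictionaries

    def _has_k_in_row(self, board, player):
        n, k = self.n, self.k
        for dx, dy in ((0, 1), (1, 0), (1, 1), (1, -1)):
            for x in range(n):
                for y in range(n):
                    ex, ey = x + (k - 1) * dx, y + (k - 1) * dy
                    if not (0 <= ex < n and 0 <= ey < n):
                        continue
                    if all(board[x + i * dx, y + i * dy] == player for i in range(k)):
                        return True
        return False

    def getGameEnded(self, board, player):
        """None if the game continues; otherwise 1 (win), -1 (loss), 0 (draw) for `player`."""
        if self._has_k_in_row(board, player):
            return 1
        if self._has_k_in_row(board, -player):
            return -1
        if not (board == 0).any():
            return 0
        return None
```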

Core Algorithm Implementation

Dual-Head Neural Network

Network Structure

In AlphaZero, a single dual-head neural network predicts both the policy and the value. It takes the current board state as input and computes two outputs:

$p$: the probability distribution output by the policy head, giving the probability of selecting each possible action in state $s$. Combined with MCTS search, this policy yields stronger decisions.
$v$: the scalar output by the value head, predicting the value of the current board state $s$ for the player to move, in $[-1, 1]$ (a higher value means a higher chance of winning).

import torch
import torch.nn as nn
import torch.nn.functional as F

class AlphaZeroNNet(nn.Module):
    def __init__(self, game, args):
        """
        Parameters:
        - game: Game object providing board size and action space size.
        - args: Contains parameters for network structure, such as number of channels, dropout rate, etc.
        """
        super(AlphaZeroNNet, self).__init__()
        self.board_x, self.board_y = game.getBoardSize()  # Board size
        self.action_size = game.getActionSize()           # Action space size
        self.args = args

        # Convolutional layer blocks
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, args.num_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(args.num_channels),
            nn.ReLU(),
            nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(args.num_channels),
            nn.ReLU(),
            nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3),
            nn.BatchNorm2d(args.num_channels),
            nn.ReLU(),
            nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3),
            nn.BatchNorm2d(args.num_channels),
            nn.ReLU(),
        )

        # Calculate the output size of the convolutional layers
        conv_output_size = self._get_conv_output_size()

        # Fully connected layer blocks
        self.fc_layers = nn.Sequential(
            nn.Linear(conv_output_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(args.dropout),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(args.dropout),
        )

        # Policy head: Outputs the log probabilities of each action
        self.policy_head = nn.Linear(512, self.action_size)

        # Value head: Outputs the value estimate of the current state
        self.value_head = nn.Linear(512, 1)

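    def _get_conv_output_size(self):
        """
        Calculates the number of features output by the convolutional layers.
        The dummy pass runs in eval mode without gradients so that it does not
        update the BatchNorm running statistics.
        """
        dummy_input = torch.zeros(1, 1, self.board_x, self.board_y)
        self.conv_layers.eval()
        with torch.no_grad():
            output_feat = self.conv_layers(dummy_input)
        self.conv_layers.train()
        n_size = output_feat.view(1, -1).size(1)
        return n_size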

    def forward(self, s):
        """
        Forward propagation function.

        Parameters:
        - s: Input board state, shape (batch_size, board_x, board_y).

        Returns:
        - log_policies: Log probabilities of the policy, shape (batch_size, action_size).
        - values: Value estimates, shape (batch_size, 1).
        """
        # Adjust input shape
        s = s.view(-1, 1, self.board_x, self.board_y)  # (batch_size, 1, board_x, board_y)

        # Convolutional layers extract features
        s = self.conv_layers(s)  # (batch_size, num_channels, new_board_x, new_board_y)

        # Flatten the output of the convolutional layers
        s = s.view(s.size(0), -1)  # (batch_size, conv_output_size)

        # Fully connected layers extract high-level features
        s = self.fc_layers(s)  # (batch_size, 512)

        # Policy head output
        policies = self.policy_head(s)  # (batch_size, action_size)
        log_policies = F.log_softmax(policies, dim=1)

        # Value head output
        values = self.value_head(s)  # (batch_size, 1)
        values = torch.tanh(values)  # Limit the value to [-1, 1]

        return log_policies, values

  • Convolutional layers extract features: multiple convolutional layers extract spatial features of the board.
  • Fully connected layers output: the features are flattened and passed through fully connected layers, then split into the policy and value heads.
  • Activation functions:
    • The policy output uses log_softmax, which makes the subsequent cross-entropy loss easy to compute.
    • The value output uses tanh, limiting the value to $[-1, 1]$.
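To see what `_get_conv_output_size` computes, the feature count can also be worked out by hand with the standard convolution output-size formula $\lfloor (W + 2P - K)/S \rfloor + 1$. This sketch assumes the layer settings shown above (3x3 kernels, stride 1, paddings 1, 1, 0, 0); the helper names are hypothetical. For a 15x15 board, the two padded convolutions keep 15x15, and the two unpadded ones shrink it to 13x13 and then 11x11.

```python
def conv2d_out(size, kernel=3, padding=0, stride=1):
    # Standard output-size formula for one spatial dimension of a convolution
    return (size + 2 * padding - kernel) // stride + 1


def alphazero_conv_features(board_x, board_y, num_channels):
    """Feature count after the four conv layers above: two padded, two unpadded 3x3 convs."""
    x, y = board_x, board_y
    for padding in (1, 1, 0, 0):
        x, y = conv2d_out(x, padding=padding), conv2d_out(y, padding=padding)
    return num_channels * x * y
```

Note that boards smaller than 5x5 would leave no spatial output with this layer stack.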

Loss Function

The network's policy and value heads are trained jointly by minimizing the following loss:

$$l = (z - v)^2 - \boldsymbol{\pi}^\top \log \boldsymbol{p} + c \|\theta\|^2$$

$(z - v)^2$: value loss, pushing the network's output $v$ toward the game result $z$ (win is +1, loss is -1, draw is 0).
$-\boldsymbol{\pi}^\top \log \boldsymbol{p}$: policy loss, the cross-entropy between the MCTS search probabilities $\boldsymbol{\pi}$ and the network's policy $\boldsymbol{p}$, pushing $\boldsymbol{p}$ toward $\boldsymbol{\pi}$.
$c \|\theta\|^2$: L2 regularization on the model parameters $\theta$ to prevent overfitting.
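A minimal PyTorch sketch of this loss, assuming the network outputs `log_p` (log-probabilities, as the `log_softmax` above produces) and `v`, and that `pi` and `z` come from the self-play samples. `alphazero_loss` and the default `c` are illustrative assumptions; in practice the L2 term is often delegated to the optimizer's `weight_decay` instead.

```python
import torch

def alphazero_loss(log_p, v, pi, z, params=(), c=1e-4):
    """l = (z - v)^2 - pi^T log p + c * ||theta||^2, averaged over the batch.

    log_p: (batch, actions) log-probabilities from the policy head
    v:     (batch, 1) or (batch,) values from the value head
    pi:    (batch, actions) MCTS search probabilities
    z:     (batch,) game outcomes in {-1, 0, 1}
    """
    value_loss = ((z - v.view(-1)) ** 2).mean()
    policy_loss = -(pi * log_p).sum(dim=1).mean()  # cross-entropy against pi
    l2 = sum((w ** 2).sum() for w in params)        # ||theta||^2
    return value_loss + policy_loss + c * l2
```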

MCTS Class

class MCTS:
    def __init__(self, game, nnet, args):
        self.game = game         # Game environment
        self.nnet = nnet         # Neural network
        self.args = args         # Parameters
        self.Qsa = {}            # Store Q values: Q(s,a)
        self.Nsa = {}            # Store edge visit counts: N(s,a)
        self.Ns = {}             # Store node visit counts: N(s)
        self.Ps = {}             # Store policy priors: P(s,a)

        self.Es = {}             # Store game end information: E(s)
        self.Vs = {}             # Store valid actions: V(s)

One highlight here is caching: the dictionaries, keyed by the board's string representation, store computation results so they are not recomputed, which improves efficiency.

Valid Action Mask

In many games, certain actions are illegal in specific states. For example, in Gomoku, if a position is already occupied, placing a piece there is illegal. Therefore, ensuring that only valid actions are considered is crucial for the correctness and efficiency of the algorithm. This is achieved through the valid action mask.

In the search method, when we reach a leaf node and need to use the neural network for prediction, we process the policy output from the neural network, using the valid mask to mask illegal actions.

if s not in self.Ps: # Not in the stored policy prior P(s,a)
    # Leaf node, expand based on the neural network output policy
    self.Ps[s], v = self.nnet.predict(canonicalBoard)
    valids = self.game.getValidMoves(canonicalBoard, 1)  # Get the valid action mask, 1 for valid
    self.Ps[s] = self.Ps[s] * valids  # Mask illegal actions
    sum_Ps_s = np.sum(self.Ps[s])
    
    if sum_Ps_s > 0:
        self.Ps[s] /= sum_Ps_s  # Normalize policy probabilities
    else:
        # The network put zero probability on every valid move: fall back to a uniform distribution over valid moves
        self.Ps[s] = self.Ps[s] + valids
        self.Ps[s] /= np.sum(self.Ps[s])

    self.Vs[s] = valids  # Store the valid action mask
    self.Ns[s] = 0       # Initialize state visit count
    return v
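The masking step can be isolated as a small NumPy function, which makes both branches easy to check. `mask_and_normalize` is a hypothetical name; the logic mirrors the snippet above.

```python
import numpy as np

def mask_and_normalize(probs, valids):
    """Zero out illegal actions and renormalize; fall back to uniform over legal moves."""
    masked = probs * valids
    total = masked.sum()
    if total > 0:
        return masked / total
    # The network put no mass on any legal move: use a uniform distribution over them
    return valids / valids.sum()
```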

PUCT

A variant of the UCB formula, specifically for MCTS.

In AlphaZero, MCTS uses the neural network's output to guide the search direction. During the search, actions are selected with the PUCT formula ("Predictor + UCT"), which introduces the prior probability $P(s, a)$ so that MCTS can leverage external knowledge from the policy network:

$$U(s, a) = Q(s, a) + c_{puct} \cdot P(s, a) \cdot \frac{\sqrt{N(s)}}{1 + N(s, a)}$$

Where:

  • $Q(s, a)$ is the average value of choosing action $a$ in state $s$.
  • $P(s, a)$ is the prior probability provided by the neural network.
  • $N(s)$ and $N(s, a)$ are the visit counts for state $s$ and edge $(s, a)$, respectively.
  • $c_{puct}$ is a constant that controls the degree of exploration. The exploration term favors actions with high prior probability but low visit counts.

cur_best = -float('inf') # Record the current maximum U value
best_act = -1            # Record the action with the maximum U value

# Iterate over all possible actions
for a in range(self.game.getActionSize()):
    if valids[a]: 
        # For valid actions, calculate the PUCT value
        if (s, a) in self.Qsa:
            # If (s, a) has been visited before, use the calculated Q value and visit count
            u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * \
                math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
        else:
            # Unvisited edge: Q defaults to 0, so only the prior-weighted exploration term counts
            u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + 1e-8)  # 1e-8 keeps u > 0 when N(s) = 0, so the prior still ranks actions

        # Choose the action with the highest U value
        if u > cur_best:
            cur_best = u
            best_act = a

Balancing Exploration and Exploitation

  • Unvisited actions have $N(s, a) = 0$, so the denominator is small and the exploration term large, which encourages the algorithm to explore new actions.

  • $Q(s, a)$ reflects the current estimate of the action's value, representing exploitation.
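Here is a vectorized NumPy sketch of the same PUCT selection. It is a simplification, not the code above: `puct_select` is a hypothetical helper, unvisited actions are represented by $Q = 0$ and $N(s, a) = 0$, and $N(s)$ is taken as the sum of the edge visit counts. The example shows an unvisited action with a large prior beating a well-visited action with a decent $Q$.

```python
import math
import numpy as np

def puct_select(Q, P, N_sa, valids, c_puct=1.0):
    """Return argmax over valid a of Q(s,a) + c_puct * P(s,a) * sqrt(N(s)) / (1 + N(s,a))."""
    N_s = N_sa.sum()                       # assumption: N(s) = sum of edge visit counts
    u = Q + c_puct * P * math.sqrt(N_s + 1e-8) / (1 + N_sa)
    u = np.where(valids > 0, u, -np.inf)   # illegal actions can never be selected
    return int(np.argmax(u))
```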

Value Inversion

From player $A$'s perspective, a higher value for player $B$ means a lower value for player $A$, and vice versa. Each node's value is therefore always expressed from the perspective of the player to move at that node, and it is negated when passed up to the parent, where the other player is to move.

In the code, this implementation is reflected in the following call:

v = -self.search(next_s)
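The sign flip is the same idea as negamax search. Here is a minimal sketch on a toy game that is an illustrative assumption rather than part of the Gomoku code: single-pile Nim, where players take 1 to 3 stones and whoever takes the last stone wins. Every value is from the perspective of the player to move, so each child's value is negated before taking the maximum.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def negamax(stones):
    """Value for the player to move: +1 (forced win) or -1 (forced loss)."""
    if stones == 0:
        return -1  # the opponent just took the last stone
    # A child's value is from the opponent's perspective, so negate it
    return max(-negamax(stones - m) for m in (1, 2, 3) if m <= stones)
```

Positions with a multiple of 4 stones come out as losses for the player to move, which matches the known solution of this game.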

Search Function

def search(self, canonicalBoard):
    """
    Executes an MCTS search.

    Parameters:
    - canonicalBoard: The current player's board state (canonical form)

    Returns:
    - v: Value assessment for the current node
    """
    s = self.game.stringRepresentation(canonicalBoard)

    if s not in self.Es:
        self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
    if self.Es[s] is not None:
        # If the game has ended, return the result
        return self.Es[s]

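    if s not in self.Ps:
        # Leaf node: expand it using the network's policy, masked to valid moves
        self.Ps[s], v = self.nnet.predict(canonicalBoard)
        valids = self.game.getValidMoves(canonicalBoard, 1)
        self.Ps[s] = self.Ps[s] * valids
        sum_Ps_s = np.sum(self.Ps[s])
        if sum_Ps_s > 0:
            self.Ps[s] /= sum_Ps_s
        else:
            # No probability mass on valid moves: use a uniform distribution over them
            self.Ps[s] = self.Ps[s] + valids
            self.Ps[s] /= np.sum(self.Ps[s])
        self.Vs[s] = valids
        self.Ns[s] = 0
        return v  # The network's value estimate replaces a rollout

    valids = self.Vs[s]

    # Select the action with the highest PUCT value
    cur_best = -float('inf')
    best_act = -1
    for a in range(self.game.getActionSize()):
        if valids[a]:
            if (s, a) in self.Qsa:
                u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * \
                    math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
            else:
                u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + 1e-8)
            if u > cur_best:
                cur_best = u
                best_act = a

    a = best_act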
    next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
    next_s = self.game.getCanonicalForm(next_s, next_player)

    # The value of the next state is inverted
    v = -self.search(next_s)

    # Update Q values and visit counts
    if (s, a) in self.Qsa:
        self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
        self.Nsa[(s, a)] += 1
    else:
        self.Qsa[(s, a)] = v
        self.Nsa[(s, a)] = 1

    self.Ns[s] += 1
    return v

  • The search recurses down the tree, simulating future play.
  • At leaf nodes, the neural network's prediction replaces a full rollout, which speeds up the search.
  • Since players alternate, the value of the next state is negated.

Self-Play

AlphaZero needs no human game data for training: all of its training data comes from self-play.

In self-play, several MCTS simulations are run before each move to obtain an improved policy $\pi(s, a)$, from which the move is chosen. The two sides alternate as black and white until the game ends. Self-play generates a large amount of training data, helping the neural network learn better policies and value estimates.

Implementing "Executing a Self-Play Game"

First, let's look at the core function of Self-Play: executeEpisode, which is responsible for executing a complete game until the end and collecting training data.

def executeEpisode(self):
    """
    Executes a self-play game and returns a list of training samples.

    Returns:
        trainExamples: A list containing multiple (canonicalBoard, pi, v) tuples.
    """
    trainExamples = []
    board = self.game.getInitBoard()  # Initialize the board
    self.curPlayer = 1                # Current player (1 or -1)
    episodeStep = 0                   # Record the number of steps

    while True: # Until the game ends
        episodeStep += 1
        canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)

        temp = int(episodeStep < self.args.tempThreshold)  # Temperature parameter

        pi = self.mcts.getActionProb(canonicalBoard, temp=temp)  # Get action probability distribution (policy) using MCTS
        sym = self.game.getSymmetries(canonicalBoard, pi)        # Data augmentation (symmetry)
        for b, p in sym:
            trainExamples.append([b, self.curPlayer, p, None])

        action = np.random.choice(len(pi), p=pi)  # Choose an action according to the policy probabilities
        board, self.curPlayer = self.game.getNextState(board, self.curPlayer, action) # Execute the action, updating the board and current player.

        r = self.game.getGameEnded(board, self.curPlayer)  # Check if the game has ended
        if r is not None:
            # Assign final value to each training sample, +1 for the winning player, -1 for the losing player, 0 for a draw
            return [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer))) for x in trainExamples]

  • Temperature parameter: Controls how exploratory move selection is. For the first tempThreshold moves it is 1, so moves are sampled in proportion to the MCTS visit counts; afterwards it is 0, so the most-visited move is played.

  • Data augmentation: In Gomoku, the rotation and flipping of the board do not change the essence of the game. Utilizing these symmetries generates equivalent boards and strategies, increasing training data to enhance the model's generalization ability.
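A sketch of what a `getSymmetries`-style helper can look like for a square board: the 4 rotations, each with and without a left-right flip, give 8 equivalent samples. The key detail is that the policy vector must be reshaped to the board and transformed identically, so each probability stays attached to its cell. `get_symmetries` is a hypothetical standalone version; it assumes the action space is exactly the $n \times n$ cells, with no pass move.

```python
import numpy as np

def get_symmetries(board, pi):
    """Return the 8 dihedral symmetries of (board, pi) as a list of (board, flat_pi)."""
    n = board.shape[0]
    pi_board = np.reshape(pi, (n, n))  # view the policy as a board-shaped array
    out = []
    for k in range(4):
        b_rot = np.rot90(board, k)
        p_rot = np.rot90(pi_board, k)  # rotate the policy exactly like the board
        out.append((b_rot, p_rot.ravel()))
        out.append((np.fliplr(b_rot), np.fliplr(p_rot).ravel()))
    return out
```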

If the game ends, training samples are generated based on the game records, and trainExamples is returned.
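The label expression `r * ((-1) ** (x[1] != self.curPlayer))` converts the result `r`, seen from the side of the player to move at the end, into each sample's own perspective. Assuming you know the absolute winner (1 or -1, or 0 for a draw), an equivalent formulation is simply `winner * player`. `assign_outcomes` is a hypothetical helper sketching this.

```python
def assign_outcomes(history, winner):
    """history: list of (board, player, pi), where player (1 or -1) was to move.
    winner: 1 or -1 for the winning player, 0 for a draw.
    Returns (board, pi, z) with z = +1 if the player to move went on to win."""
    return [(board, pi, winner * player) for board, player, pi in history]
```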

Main Loop of Self-Play

Self-play is not just about executing a single game; we need to continuously perform self-play over multiple iterations, updating the model.

def learn(self):
    """
    Conduct multiple iterations of self-play and model updates.
    """
    for i in range(1, self.args.numIters + 1):
        log.info(f"Starting Iter #{i} ...")
        # Store training samples from this iteration
        iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

        # Perform the specified number of self-plays
        for _ in tqdm(range(self.args.numEps), desc="Self Play"):
            self.mcts = MCTS(self.game, self.nnet, self.args)  # Reset MCTS
            iterationTrainExamples += self.executeEpisode()

        # Save training samples
        self.trainExamplesHistory.append(iterationTrainExamples)

        # Maintain a fixed-length history of training samples
        if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
            log.warning(f"Removing the oldest entry in trainExamples.")
            self.trainExamplesHistory.pop(0)

        # Merge and shuffle all training samples
        trainExamples = []
        for e in self.trainExamplesHistory:
            trainExamples.extend(e)
        shuffle(trainExamples)

        # Train the neural network
        self.nnet.save_checkpoint(folder=self.args.checkpoint, filename="temp.pth.tar")
        self.pnet.load_checkpoint(folder=self.args.checkpoint, filename="temp.pth.tar")
        pmcts = MCTS(self.game, self.pnet, self.args)

        self.nnet.train(trainExamples)
        nmcts = MCTS(self.game, self.nnet, self.args)
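  • If the history holds more iterations than numItersForTrainExamplesHistory, the samples from the oldest iteration are removed, so the model does not keep fitting stale data generated by much older versions of itself.
  • The current weights are saved to temp.pth.tar and loaded into self.pnet, keeping a copy of the previous model for the evaluation step.
  • The merged and shuffled samples are then used to train the model with self.nnet.train(trainExamples).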
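
The same sliding window can also be expressed with a bounded deque, shown here as a sketch. ExampleHistory is a hypothetical helper, and max_iters plays the role of numItersForTrainExamplesHistory.

```python
from collections import deque
from random import shuffle


class ExampleHistory:
    """Keep training samples from only the most recent max_iters iterations."""

    def __init__(self, max_iters):
        # A deque with maxlen silently drops the oldest iteration on append,
        # replacing the explicit length check plus pop(0).
        self.iterations = deque(maxlen=max_iters)

    def add_iteration(self, examples):
        self.iterations.append(list(examples))

    def merged(self):
        """Flatten all stored iterations into one shuffled training list."""
        samples = [ex for it in self.iterations for ex in it]
        shuffle(samples)
        return samples
```

The behavior matches the explicit check in learn(). The only difference is that eviction happens implicitly, so the warning log would have to be added separately if it is wanted.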

Filtering Mechanism#

To ensure that the model's playing strength keeps improving, AlphaGo Zero adds an evaluation (gating) step in which each newly trained model is matched against the previous one:

  • After each training round, the new model plays arenaCompare games against the previous model, with both sides choosing moves greedily (temp=0).
  • If the new model's win rate over decisive games (draws are not counted) reaches the threshold (updateThreshold, e.g., 55%), the new model is accepted and self-play continues with it. Otherwise, the previous weights are restored from temp.pth.tar, and data generation and training continue from the old model, which keeps a weaker network from replacing a stronger one.
        # Evaluate the new model
        log.info("PITTING AGAINST PREVIOUS VERSION")
        arena = game.Arena(lambda x: np.argmax(pmcts.getActionProb(x, temp=0)),
                           lambda x: np.argmax(nmcts.getActionProb(x, temp=0)),
                           self.game)
        pwins, nwins, draws = arena.playGames(self.args.arenaCompare)

        log.info(f"NEW/PREV WINS : {nwins} / {pwins} ; DRAWS : {draws}")
        if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold:
            log.info("REJECTING NEW MODEL")
            self.nnet.load_checkpoint(folder=self.args.checkpoint, filename="temp.pth.tar")
        else:
            log.info("ACCEPTING NEW MODEL")
            self.nnet.save_checkpoint(folder=self.args.checkpoint, filename="best.pth.tar")
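
The acceptance test in the if statement can be read as a small pure function, shown here as a sketch. accept_new_model is a hypothetical name, and threshold stands for self.args.updateThreshold (the 55% figure mentioned above).

```python
def accept_new_model(nwins, pwins, threshold=0.55):
    """Gating rule: accept the new model only if it wins enough decisive games.

    Draws are ignored; if every game was drawn, the new model is rejected.
    """
    decisive = nwins + pwins
    if decisive == 0:
        return False
    return nwins / decisive >= threshold
```

Because draws are excluded, a result of 20 wins, 15 losses and 65 draws counts as a win rate of 20 / 35 ≈ 57% and is accepted. In Gomoku with strong play, draws can be frequent, so this choice matters.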

Obtaining Strategy from MCTS#

During self-play, we use MCTS to select actions. The getActionProb function retrieves the strategy probability distribution from MCTS.

def getActionProb(self, canonicalBoard, temp=1):
    """
    Executes multiple MCTS searches and returns the action probability distribution.

    Parameters:
        canonicalBoard: The current normalized board state
        temp: Temperature parameter

    Returns:
        probs: Action probability distribution
    """
    # Perform the specified number of MCTS searches
    for _ in range(self.args.numMCTSSims):
        self.search(canonicalBoard) 

    s = self.game.stringRepresentation(canonicalBoard)
    # Record the visit count for each action N(s, a)
    counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 
              for a in range(self.game.getActionSize())]

    if temp == 0:
        bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten()
        bestA = np.random.choice(bestAs)
        probs = [0] * len(counts)
        probs[bestA] = 1
        return probs

    counts = [x ** (1. / temp) for x in counts]
    counts_sum = float(sum(counts))
    probs = [x / counts_sum for x in counts]
    return probs

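The visit count N(s, a) reflects how strongly the search favors action a. MCTS spends more simulations on moves it judges promising, so the most visited moves are the ones it is most confident in.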

Calculating strategy probabilities:

  • When temp = 0:
    • Select the action with the highest visit count, probability 1.
    • Ensures the strategy is deterministic, typically used in the later stages of the game or evaluation phase.
  • When temp > 0:
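    • Apply the temperature transformation to the visit counts: counts = [x ** (1. / temp) for x in counts], then normalize.
    • With temp = 1, the probabilities are proportional to the visit counts. Higher temperatures flatten the distribution toward uniform over the visited actions, which increases exploration.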
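
Written out, the rule is $\pi(a \mid s) = N(s,a)^{1/\tau} / \sum_b N(s,b)^{1/\tau}$, where $\tau$ is temp. The NumPy sketch below applies the same rule as getActionProb, but works on a plain list of counts. It also divides by the largest count before exponentiating. This does not change the normalized result, and it keeps small temperatures from overflowing. The sketch assumes at least one action was visited. The names visit_counts_to_policy and temperature_for_move, and the 15-move threshold, are illustrative assumptions and are not part of the original code.

```python
import numpy as np


def visit_counts_to_policy(counts, temp):
    """Convert MCTS visit counts N(s, a) into a policy pi(a|s).

    temp == 0 picks the most visited action (ties broken at random);
    otherwise pi(a|s) is proportional to N(s, a) ** (1 / temp).
    Assumes at least one count is positive.
    """
    counts = np.asarray(counts, dtype=np.float64)
    if temp == 0:
        best = np.flatnonzero(counts == counts.max())
        probs = np.zeros_like(counts)
        probs[np.random.choice(best)] = 1.0
        return probs
    # Dividing by the max first avoids overflow for small temp;
    # the constant factor cancels during normalization.
    scaled = (counts / counts.max()) ** (1.0 / temp)
    return scaled / scaled.sum()


def temperature_for_move(move_index, temp_threshold=15):
    """Hypothetical schedule: explore (temp=1) early, play greedily (temp=0) later."""
    return 1 if move_index < temp_threshold else 0
```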

Training Process#

Settings: 400 MCTS simulations per move, cpuct = 1.0, 50 training iterations.


At this point, the model has learned the rules of Gomoku and plays like a beginner-level AI; you can still sneak in a win along the edge of the board.


Adding 1cycle Learning Rate Adjustment#

Dynamically adjusting the learning rate can improve the model's convergence speed and generalization ability.

  • First phase (first half of the cycle): the learning rate increases linearly from the minimum value ($1.0 \times 10^{-4}$) to the maximum value ($1.0 \times 10^{-2}$).
  • Second phase (second half of the cycle): the learning rate decreases linearly back to the minimum value.
import torch.optim as optim

class NNetWrapper:
    def __init__(self, game, args):
        ...
        self.optimizer = optim.Adam(self.nnet.parameters(), lr=args.max_lr)
        
        # 1cycle learning rate parameters
        self.total_steps = args.numIters * args.epochs * (args.maxlenOfQueue // args.batch_size)
        self.current_step = 0

    def get_learning_rate(self):
        """Implement 1cycle learning rate strategy"""
        if self.current_step >= self.total_steps:
            return self.args.min_lr
        
        half_cycle = self.total_steps // 2
        
        if self.current_step <= half_cycle:
            # First phase: increase from min_lr to max_lr
            phase = self.current_step / half_cycle
            lr = self.args.min_lr + (self.args.max_lr - self.args.min_lr) * phase
        else:
            # Second phase: decrease from max_lr to min_lr
            phase = (self.current_step - half_cycle) / half_cycle
            lr = self.args.max_lr - (self.args.max_lr - self.args.min_lr) * phase
        
        return lr
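
The wrapper computes a learning rate, but the excerpt does not show where that rate is applied. The following is a minimal sketch that assumes the schedule advances once per mini-batch. It restates the same triangular schedule as a standalone function, with a guard for very small total_steps, and adds a helper that writes the value into the optimizer. one_cycle_lr and set_lr are hypothetical names. optimizer.param_groups is the standard attribute of torch.optim optimizers.

```python
def one_cycle_lr(step, total_steps, min_lr=1e-4, max_lr=1e-2):
    """Triangular 1cycle schedule: linear warm-up to max_lr, then linear decay."""
    if step >= total_steps:
        return min_lr
    half_cycle = max(1, total_steps // 2)  # guard against total_steps < 2
    if step <= half_cycle:
        phase = step / half_cycle
        return min_lr + (max_lr - min_lr) * phase
    phase = (step - half_cycle) / half_cycle
    return max_lr - (max_lr - min_lr) * phase


def set_lr(optimizer, lr):
    """Write lr into every parameter group of a torch.optim-style optimizer."""
    for group in optimizer.param_groups:
        group["lr"] = lr


# Hypothetical placement inside train(), once per mini-batch:
#     set_lr(self.optimizer, one_cycle_lr(self.current_step, self.total_steps))
#     ...backward, clip, self.optimizer.step()...
#     self.current_step += 1
```

PyTorch also ships torch.optim.lr_scheduler.OneCycleLR, which implements a 1cycle policy. By default it uses cosine annealing rather than the linear ramps used here.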

Adding Gradient Clipping#

For a refresher, you can review the PPO article.

To prevent gradient explosion and stabilize the training process, we add gradient clipping after backpropagation.
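Let $\mathbf{g}$ be the gradient vector and $\|\mathbf{g}\|$ its L2 norm. For clip_grad_norm_, $\mathbf{g}$ is the gradients of all parameters concatenated into a single vector. With a clipping threshold of 1.0:

  • If $\|\mathbf{g}\| \leq 1.0$, then $\mathbf{g}$ remains unchanged.
  • If $\|\mathbf{g}\| > 1.0$, then $\mathbf{g}$ is scaled to $\frac{1.0}{\|\mathbf{g}\|} \cdot \mathbf{g}$.

This ensures that the clipped gradient has norm 1.0:

$$\left\| \frac{1.0}{\|\mathbf{g}\|} \cdot \mathbf{g} \right\| = \frac{1.0}{\|\mathbf{g}\|} \cdot \|\mathbf{g}\| = 1.0$$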
def train(self, examples):
    ...
    for epoch in range(self.args.epochs):
        ...
        for _ in t:
            ...
            # Calculate loss and backpropagate
            total_loss.backward()

            # Gradient clipping
            if self.args.grad_clip:
                torch.nn.utils.clip_grad_norm_(self.nnet.parameters(), self.args.grad_clip)

            # Update parameters
            self.optimizer.step()
            ...

Gradient clipping caps the gradient norm at a specified threshold, so a single unusually large gradient cannot produce a destructive parameter update. Clipping before every optimizer step makes training more stable, which in turn helps the model converge more reliably.
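
To make the formula concrete, here is a NumPy sketch of norm-based clipping over several gradient arrays. clip_by_global_norm is a hypothetical helper for illustration. It treats all gradients as one vector, as torch.nn.utils.clip_grad_norm_ does. Like that function, it also returns the norm measured before clipping, which is useful for logging.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm=1.0):
    """Scale a list of gradient arrays so their combined L2 norm is <= max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    # Combined norm: sqrt of the sum of squares over every element of every array
    total_norm = float(np.sqrt(sum(float(np.sum(g * g)) for g in grads)))
    if total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm  # same factor for all arrays keeps the direction
    return [g * scale for g in grads], total_norm
```

PyTorch's version also adds a tiny epsilon to the denominator, so its clipped norm lands marginally below max_norm rather than exactly on it.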

Conclusion#

The final code and demo are available in the repository below; the model weights are hosted on Hugging Face, linked from the repository's documentation.

After these improvements, we increased the number of MCTS simulations to 2000, raised cpuct to 4.0 to encourage more exploration, and trained for 22 rounds (4 days on a single RTX 3090 Ti). The model then reached roughly intermediate strength. With 50 MCTS simulations it applies reasonable pressure in the opening, but in the middle and late game it can still overlook the opponent's winning moves.


Increasing the number of MCTS simulations at inference time may raise it to the level of a hard AI; feel free to try and see whether you can win.

