banner
Nagi-ovo

Nagi-ovo

Breezing homepage: [nagi.fun](nagi.fun)
github

讓我們建立 AlphaZero

本文是對於 Sunrise:從頭理解 AlphaZero,MCTS,Self-Play,UCB 等文章、視頻教程和代碼實現的消化與理解。

本文將從 AlphaGo 的設計原理出發,通過深入理解 MCTS 和 Self-Play 這兩個核心機制,逐步揭示如何構建一個能超越人類的 AI 五子棋(Gomoku) 系統。

AlphaGo:從模仿到超越#

AlphaGo 的進化路徑#

圍棋的可能局面數超過宇宙中的原子數量,使得傳統窮舉搜索方法完全失效。AlphaGo 通過一個分階段的方法解決這個問題:先學習人類知識,然後通過 self-play 不斷進化。

image

這個進化過程可以分為三層:

  1. 模仿人類高手
  2. 自我博弈提升
  3. 學習評估局勢

核心組件#

image

快速走子策略(Rollout Policy)#

最左邊的一個輕量級策略網絡,用於快速評估,準確性較低但計算效率高。

SL 策略網絡#

監督學習策略網絡 PσP_{\sigma} 通過模仿人類棋譜學習下棋:

  • 輸入:棋盤狀態
  • 輸出:模仿人類專家的下一步走法概率分布
  • 訓練數據:16 萬場對局,約 3000 萬個落子步驟

RL 策略網絡#

類似當你棋藝達到一定水平時,開始自己復盤,自己和自己對戰,發現新的戰術和更深的策略。RL 網絡 PρP_{\rho} 從 SL 網絡初始化,該網絡已經遠超人類,可以找到許多人類未曾發現的強力策略。

  • 這個階段生成了大量 self-play 的數據,輸入依然是棋盤狀態,輸出是通過強化學習改進的走棋策略
  • 與歷史版本的策略網絡 pπp_\pi 進行 self-play,網絡通過 “贏棋” 信號來調整參數,通過 Policy Gradiant 強化那些導致勝利的走法。:
tz(τ)logp(atst;ρ)\sum_{t} z(\tau) \log p(a_t \mid s_t; \rho)

其中:

  • τ\tau 是對局序列 (s0,a0,s1,a1,...)(s_0,a_0,s_1,a_1,...)
  • z(τ)z(\tau) 是勝負標籤(勝為正,負為負)

價值網絡#

價值網絡 vθv_\theta 學習評估局勢,可以是個 CNN:

  • 訓練目標:最小化均方誤差
i(zvθ(s))2\sum_{i} (z - v_\theta(s))^2
  • zz:最終勝負結果
  • vθ(s)v_\theta(s):預測的獲勝概率

補充解釋#

關於 PρP_\rhovθv_\theta 的關係

  • 策略網絡 PρP_\rho 提供具體的走子概率,用於指導搜索。
  • 價值網絡 vθv_\theta 提供局面評估,減少搜索樹中不必要的模擬。
  • 兩者配合,使 MCTS 不僅能更快地探索高勝率路徑,還能顯著提升整體的下棋水平。

AlphaGo 的 MCTS 實現#

Selection 階段#

結合了探索與利用:

at=argmaxaQ(st,a)+u(st,a)a_t = \arg\max_a Q(s_t, a) + u(s_t, a)
u(st,a)P(st,a)1+N(st,a)u(s_t, a) \propto \frac{P(s_t, a)}{1 + N(s_t, a)}

其中:

  • Q(st,a)Q(s_t, a):長期回報估計
  • u(st,a)u(s_t, a):探索獎勵
  • P(st,a)P(s_t, a):策略網絡輸出的先驗概率
  • N(st,a)N(s_t, a):訪問次數

Simulation 與評估#

在原始的 MCTS 算法中,模擬階段 (Simulation) 的作用是通過快速 rollout 策略從葉子節點(擴展的新節點)進行隨機對弈,直至遊戲結束,然後根據對弈的勝負得到一個回報。這個回報被傳遞回搜索樹中的節點,用於更新這些節點的值估計(即 Q(s,a)Q(s, a) )。

這樣的實現簡單直接,但是 rollout 策略通常是隨機或簡單的規則,模擬質量可能較差。且只能給出短期信息,無法很好地結合全局的戰略評估。

而 AlphaGo 在模擬階段結合了 Value Network vθv_\theta 和 rollout 策略。Value Network 提供了更高效的葉子節點估計和全局能力評估,rollout 策略通過快速模擬捕捉局部的短期效果。

使用超參數 λ\lambda 權衡 vθ(sL)v_\theta(s_L)zLz_L,兼顧局部模擬和全局評估。評估函數:

V(sL)=(1λ)vθ(sL)+λzL V(s_L) = (1 - \lambda)v_\theta(s_L) + \lambda z_L

Backpropagation#

n 次 MCTS 時節點訪問次數更新(I\mathbb{I}是指示函數,訪問則為 1):

N(st,a)=i=1nI(s,a,i)N(s_t, a) = \sum_{i=1}^{n} \mathbb{I}(s, a, i)

Q 值更新,即執行 a 到節點 sts_t 的長期預期回報:

Q(st,a)=1N(st,a)i=1nI(s,a,i)V(sLi)Q(s_t, a) = \frac{1}{N(s_t, a)} \sum_{i=1}^{n} \mathbb{I}(s, a, i) V(s_L^i)

總結#

  1. 結構創新

    • 策略網絡提供先驗知識
    • 價值網絡提供全局評估
    • MCTS 提供戰術驗證
  2. 訓練創新

    • 從監督學習起步
    • 通過強化學習超越教師
    • 自我博弈產生新知識
  3. MCTS 改進

    • 使用神經網絡指導搜索
    • Policy Network 提了探索方向的先驗概率,Value Network 提升了葉子節点评估的準確性。
    • 這樣結合價值網絡和 rollout 的評估,不僅減少了搜索寬度和深度,還顯著提高了整體性能。
    • 高效的探索 - 利用平衡

這種設計使 AlphaGo 能在龐大的搜索空間中找到高效的解決方案,最終超越人類水平。

啟發式搜索與 MCTS#

MCTS 就像是一個不斷進化的探索者,在決策樹中尋找最佳路徑。

核心思想#

MCTS 的本質是什麼?簡單來說,它是一個 "邊玩邊學" 的過程。想像你在玩一盤全新的棋類遊戲:

  • 開始時,你會嘗試各種可能的走法(探索)
  • 慢慢發現某些走法效果更好(利用)
  • 在探索新策略和利用已知好策略之間取得平衡

這正是 MCTS 所做的,只不過它用數學的方式來系統化這個過程。其是一種 rollout 算法,通過累積蒙特卡洛模擬的價值估計來引導策略選擇。

image

算法流程#

Monte Carlo Tree Search - YouTube這個老師對於 MCTS 的流程講的很好。

MCTS 的優雅之處在於它的四個簡單但強大的步驟,這裡我用 A、B 兩種理解方式來介紹:

  1. 選擇 (Selection)

    • A:你知道小孩子是怎麼學習的嗎?他們總是在已知和未知之間徘徊。MCTS 也是如此:從根節點開始,使用 UCB (Upper Confidence Bound) 公式來權衡是選擇已知的好路徑,還是去探索新的可能。
    • B:從根節點出發,依據特定策略從當前節點中選擇一個後續節點。該策略通常基於樹的搜索歷史,選取最具潛力的路徑。例如,我們在每個節點依據當前的策略 π(s)\pi(s) 執行動作 aa,以平衡探索與利用,逐步深入。
  2. 擴展 (Expansion)

    • A:就像探險家在地圖上開闢新的區域,當我們到達一個葉節點時,我們會向下擴展,創建新的可能性。
    • B:當選中的節點尚有未探索的子節點時,我們在此節點下依據可行動作集擴展新節點。這一過程的目的是增加決策樹的廣度,逐步生成可能的決策路徑。通過這一擴展操作,我們確保搜索涵蓋更多可能的 state action pair (s,a)(s,a)
  3. 模擬 (Simulation)

    • A:這是最有趣的部分。從新節點開始,我們進行一次 "假想" 的遊戲,直到遊戲結束。這就像下棋時在腦中推演 "如果我這樣走,對手那樣走..."。
    • B:從當前擴展的節點出發,執行隨機模擬(rollout),在 MDP 環境中沿著當前策略 π(s)\pi(s) 進行採樣直至終止狀態。此過程提供了從當前節點出發到終點的回報估計,為路徑的優劣提供數值依據。
  4. 回溯 (Backpropagation)

    • A:最後,我們把得到的結果沿路徑返回,更新每個節點的統計信息。這就像在說:"嘿,我試過這條路,效果還不錯(或不太好)"。
    • B:完成模擬後,將該模擬的估計回報回溯到經過的所有節點,以更新這些節點的價值。這一過程累積了歷史回報信息,使得未來的選擇更加精確地趨向高收益路徑。

image

UCB1:探索與利用的完美平衡#

這裡要特別提到 UCB1 公式,它是 MCTS 的靈魂:

UCB1=Xˉi+C×(ln(N)/ni)UCB1 = X̄ᵢ + C × \sqrt{(ln(N)/nᵢ)}

讓我們解構一下:

  • XˉiX̄ᵢ 是平均收益(利用項)
  • (ln(N)/ni)\sqrt{(ln(N)/nᵢ)} 是不確定性度量(探索項)
  • CC 是探索參數

就像一個優秀的投資組合,既要關注已知的好機會,也要保持對新機會的開放(探索 - 利用權衡)。

image

Best Multi-Armed Bandit Strategy? (feat: UCB Method) 這個視頻對 Multi-Armed Bandit 和 UCB Method 講的很好,我這裡借鑒這個老師用的例子:

一個吃貨嘗試理解 UCB#

想像你剛到一個城市,有 100 家餐廳可選擇,你有 300 天時間。每天你都要選一家餐廳就餐,希望在這 300 天裡平均吃到最好的美食。

ε-greedy 策略:簡單但不夠智能#

這就像是用擲骰子決定:

  • 90% 的時間 (ε=0.1):去已知最好的餐廳(利用)
  • 10% 的時間:隨機嘗試一家新餐廳(探索)
def epsilon_greedy(restaurants, ratings, epsilon=0.1):
    if random.random() < epsilon:
        return random.choice(restaurants)  # 探索
    else:
        return restaurants[np.argmax(ratings)]  # 利用

這樣的效果是:

  • 探索完全隨機,可能重複訪問已知很差的餐廳
  • 探索比例固定,不會隨時間調整
  • 不考慮訪問次數的影響

UCB 策略:更智能的選擇權衡#

UCB 公式在餐廳選擇中的含義如下:

評分=平均得分+C×ln(總訪問天數)/該餐廳訪問次數評分 = 平均得分 + C × \sqrt{ln(總訪問天數)/該餐廳訪問次數}

例如,考慮兩家餐廳在第 100 天時的情況:

A 餐廳:

  • 訪問 50 次,平均分 4.5
  • UCB 分數 = 4.5+2×ln(100)/504.5+0.6=5.14.5 + 2×\sqrt{ln(100)/50} ≈ 4.5 + 0.6 = 5.1

B 餐廳:

  • 訪問 5 次,平均分 4.0
  • UCB 分數 = 4.0+2×ln(100)/54.0+1.9=5.94.0 + 2×\sqrt{ln(100)/5} ≈ 4.0 + 1.9 = 5.9

雖然 B 餐廳平均分較低,但因為訪問次數少,不確定性高,所以 UCB 給予更高的探索獎勵

image

代碼實現:

class Restaurant:
    def __init__(self, name):
        self.name = name
        self.total_rating = 0
        self.visits = 0
        
def ucb_choice(restaurants, total_days, c=2):
    # 確保每家餐廳至少訪問一次
    for r in restaurants:
        if r.visits == 0:
            return r
            
    # 使用UCB公式選擇餐廳
    scores = []
    for r in restaurants:
        avg_rating = r.total_rating / r.visits
        exploration_term = c * math.sqrt(math.log(total_days) / r.visits)
        ucb_score = avg_rating + exploration_term
        scores.append(ucb_score)
        
    return restaurants[np.argmax(scores)]

為什麼 UCB 更好?#

  1. 自適應探索

    • 訪問次數少的餐廳獲得更高的探索獎勵
    • 隨著總天數增加,探索項會逐漸減小,以更好地進行利用
  2. 平衡時間投資

    • 不會在明顯較差的餐廳上浪費太多時間
    • 會在潛力相近的餐廳之間合理分配訪問次數
  3. 理論保證

    • Regret Bound(與最優選擇的差距)隨時間呈對數增長
    • 300 天的探索足夠找到最好的幾家餐廳

我們回到 MCTS:

image

為什麼 MCTS 如此強大?#

  • 高效處理組合爆炸: MCTS 不需要像 Minimax 窮舉搜索所有可能,而是專注於最有希望的分支,使其能夠有效處理分支因子巨大的問題。
  • 自適應搜索: MCTS 動態調整其搜索策略,將更多資源分配給更有希望的區域,從而更快地找到好的解決方案。
  • 平衡探索與利用: 通過 UCB 公式,MCTS 在探索新可能性和利用已知良好選擇之間取得平衡,避免陷入局部最優。
  • 無需領域知識: MCTS 不依賴於特定領域的專業知識,僅依靠遊戲規則和模擬結果進行學習,使其具有廣泛的適用性。
  • 可隨時停止: MCTS 是一種 “隨時” 算法,可以隨時中斷並返回當前最佳解決方案,這對於實時應用至關重要。

AlphaZero:從 MCTS 到自我進化#

AlphaZero 是 DeepMind 在 AlphaGo 之後推出的通用強化學習算法,它能夠在不使用人類棋譜的情況下,通過自我對弈(Self-Play)從零開始學習並最終超越專業水平。

AlphaZero 對傳統的 MCTS 進行了改進,引入了 神經網絡 來指導搜索:

  • 策略先驗:使用神經網絡預測每個動作的先驗概率,使 搜索更加高效
  • 價值評估:在葉節點,使用神經網絡的價值預測代替隨機模擬,降低計算成本。

下面我們以五子棋為例,實現一個可以這樣自學成才的 AI。這裡學習 & 借鑒 schinger/alphazero.py 的優秀實現。

遊戲環境的設計#

在實現 AlphaZero 之前,我們需要先定義遊戲環境。

定義遊戲接口#

在 AlphaZero 中,神經網絡接收當前的棋盤狀態,輸出一個策略向量 P\boldsymbol{P} 表示每個動作的概率,以及一個標量值 vv 表示當前玩家的勝率預測。

為使 MCTS 和神經網絡能夠通用地與遊戲交互,我們需要定義一個一致的遊戲接口。

class GomokuGame:
    def __init__(self, n=15):
        self.n = n  # 棋盤大小,默認15x15

    def getInitBoard(self):
        """
        返回初始的棋盤狀態,所有位置都為空。
        """
        b = Board(self.n)
        return np.array(b.pieces)

    def getBoardSize(self):
        """
        返回棋盤的尺寸,即 (n, n)。
        """
        return (self.n, self.n)

    def getActionSize(self):
        """
        返回動作的總數,這裡是 n * n,因為每個格子都可能是一個動作。
        """
        return self.n * self.n

    def getNextState(self, board, player, action):
        """
        執行動作,返回下一個棋盤狀態和下一個玩家。

        參數:
        - board: 當前棋盤狀態
        - player: 當前玩家(1 或 -1)
        - action: 當前動作,0 ~ n*n-1

        返回:
        - (next_board, next_player): 執行動作後的棋盤和下一玩家
        """
        b = Board(self.n)
        b.pieces = np.copy(board)
        # 將動作轉換為坐標 (x, y)
        move = (action // self.n, action % self.n)
        b.execute_move(move, player)
        return (b.pieces, -player)
  • 統一的接口:使得我們的 MCTS 和神經網絡可以在不同的遊戲中復用,只需實現遊戲的具體邏輯。
  • 棋盤表示:使用二維數組表示棋盤,便於處理和可視化。

核心算法實現#

雙頭神經網絡#

網絡結構#

在 AlphaZero 中,我們使用一個統一的雙頭神經網絡來同時預測策略(policy)和價值(value)。這個神經網絡接收當前的棋盤狀態以計算輸出:

pp :策略頭輸出的概率分布,表示在狀態 ss 下每個可能動作的選擇概率。這個策略頭與 MCTS 搜索結合,生成更強的決策。
vv :價值頭輸出一個標量,表示當前棋盤狀態 ss 的預測值(最終獲勝的概率)。

import torch
import torch.nn as nn
import torch.nn.functional as F

class AlphaZeroNNet(nn.Module):
    def __init__(self, game, args):
        """
        參數:
        - game: 遊戲對象,提供棋盤大小和動作空間大小等信息。
        - args: 包含網絡結構的參數,例如通道數、dropout 概率等。
        """
        super(AlphaZeroNNet, self).__init__()
        self.board_x, self.board_y = game.getBoardSize()  # 棋盤尺寸
        self.action_size = game.getActionSize()           # 動作空間大小
        self.args = args

        # 卷積層塊
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, args.num_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(args.num_channels),
            nn.ReLU(),
            nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(args.num_channels),
            nn.ReLU(),
            nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3),
            nn.BatchNorm2d(args.num_channels),
            nn.ReLU(),
            nn.Conv2d(args.num_channels, args.num_channels, kernel_size=3),
            nn.BatchNorm2d(args.num_channels),
            nn.ReLU(),
        )

        # 計算卷積層輸出的尺寸
        conv_output_size = self._get_conv_output_size()

        # 全連接層塊
        self.fc_layers = nn.Sequential(
            nn.Linear(conv_output_size, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(args.dropout),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(args.dropout),
        )

        # 策略頭:輸出每個動作的對數概率
        self.policy_head = nn.Linear(512, self.action_size)

        # 價值頭:輸出當前狀態的價值估計
        self.value_head = nn.Linear(512, 1)

    def _get_conv_output_size(self):
        """
        計算卷積層輸出的特徵數量。
        """
        dummy_input = torch.zeros(1, 1, self.board_x, self.board_y)
        output_feat = self.conv_layers(dummy_input)
        n_size = output_feat.view(1, -1).size(1)
        return n_size

    def forward(self, s):
        """
        前向傳播函數。

        參數:
        - s: 輸入的棋盤狀態,形狀為 (batch_size, board_x, board_y)。

        返回:
        - log_policies: 策略的對數概率,形狀為 (batch_size, action_size)。
        - values: 價值估計,形狀為 (batch_size, 1)。
        """
        # 輸入形狀調整
        s = s.view(-1, 1, self.board_x, self.board_y)  # (batch_size, 1, board_x, board_y)

        # 卷積層提取特徵
        s = self.conv_layers(s)  # (batch_size, num_channels, new_board_x, new_board_y)

        # 展平卷積層輸出
        s = s.view(s.size(0), -1)  # (batch_size, conv_output_size)

        # 全連接層提取高級特徵
        s = self.fc_layers(s)  # (batch_size, 512)

        # 策略頭輸出
        policies = self.policy_head(s)  # (batch_size, action_size)
        log_policies = F.log_softmax(policies, dim=1)

        # 價值頭輸出
        values = self.value_head(s)  # (batch_size, 1)
        values = torch.tanh(values)  # 將價值限定在 [-1, 1]

        return log_policies, values
  • 卷積層提取特徵:通過多層卷積層,提取棋盤的空間特徵。
  • 全連接層輸出:將特徵展平,通過全連接層,分別輸出策略和價值。
  • 激活函數:
    • 策略輸出使用 log_softmax,方便後續計算交叉熵損失。
    • 價值輸出使用 tanh,將值限制在 [1,1][-1, 1] 之間。

損失函數#

訓練目標通過以下損失函數定義實現同時訓練網絡的策略和價值評估能力

l=(zv)2πlogp+cθ2l = (z - v)^2 - \pi \log p + c \| \theta \|^2

(zv)2(z - v)^2 :價值損失,要求網絡輸出的 vv 儘量接近對局結果 zz(勝利為 +1,失敗為 -1,平均為 0)。
πlogp- \pi \log p:策略損失,通過交叉熵,要求網絡輸出的策略 pp 儘量匹配 MCTS 得到的最終策略 π\pi
cθ2c \| \theta \|^2 :L2 正則化項,用於防止過擬合,保持模型參數 θ\theta 的權重規模適當。

MCTS 類#

class MCTS:
    def __init__(self, game, nnet, args):
        self.game = game         # 遊戲環境
        self.nnet = nnet         # 神經網絡
        self.args = args         # 參數
        self.Qsa = {}            # 存儲 Q 值:Q(s,a)
        self.Nsa = {}            # 存儲邊的訪問次數:N(s,a)
        self.Ns = {}             # 存儲節點的訪問次數:N(s)
        self.Ps = {}             # 存儲策略先驗:P(s,a)

        self.Es = {}             # 存儲遊戲結束信息:E(s)
        self.Vs = {}             # 存儲合法動作:V(s)

這裡的一個亮點是使用了緩存機制:使用字典來緩存計算結果,避免重複計算,提高效率。

有效動作掩碼#

在很多遊戲中,某些動作在特定狀態下是非法的。比如在五子棋中,如果某個位置已經被佔據,那麼在該位置下子就是非法的。因此,確保只考慮合法動作 對於算法的正確性和效率至關重要。這是通過 有效動作掩碼(valid mask) 來實現的。

search 方法中,當我們到達葉節點並需要使用神經網絡進行預測時,我們對神經網絡輸出的策略進行處理,使用 valid mask 來屏蔽非法動作。

if s not in self.Ps: # 不在策略先驗存儲 P(s,a) 中
    # 葉節點,根據神經網絡輸出策略進行擴展
    self.Ps[s], v = self.nnet.predict(canonicalBoard)
    valids = self.game.getValidMoves(canonicalBoard, 1)  # 獲取合法動作的掩碼,1 為合法
    self.Ps[s] = self.Ps[s] * valids  # 掩碼非法動作
    sum_Ps_s = np.sum(self.Ps
載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。