banner
Nagi-ovo

Nagi-ovo

Breezing homepage: [nagi.fun](nagi.fun)
github

Diffusion 的直覺和數學

本文主要以該視頻的教學邏輯為主線,結合講解內容進行整理和闡述,如有錯誤歡迎在評論區糾正!

直覺部分#

理論支持#

Deep Unsupervised Learning using Nonequilibrium Thermodynamics

這篇論文奠定了擴散模型的理論基礎。作者提出了一種基於非平衡熱力學的生成模型,通過逐步添加和去除噪聲來實現數據生成。這為後續的擴散模型研究提供了重要的理論支持。

我們對圖像應用大量噪聲,然後用神經網絡去噪。如果這個神經網絡學得很好,那可以從完全隨機的噪聲開始最終得到我們訓練數據中的圖像。

image

Forward Diffusion Process:#

迭代地對圖像施加噪聲,步驟足夠多時圖像會完全變成噪聲,使用正態分布作為噪聲源:

Screenshot 2024-12-09 at 21.03.59

Reverse Diffusion Process#

從純粹的噪聲到圖像,涉及一個學習一步步去噪的神經網絡。

為什麼是逐漸去噪?作者在論文中提到 “一步直接完成去噪” 的結果很糟糕。

那麼這個網絡是什麼樣的?它又要預測什麼?

算法改進#

Denoising Diffusion Probabilistic Models

這篇論文提出了去噪擴散概率模型(DDPM),顯著提高了擴散模型的生成質量和效率。通過引入簡單的去噪網絡和優化的訓練策略,DDPM 成為了擴散模型領域的一個重要里程碑。

作者討論了神經網絡可以預測的三種目標:

  1. 預測每個時間步的噪聲均值 (predict the mean of the noise at each timestep)

    • 即預測條件分布 p(xt1xt)p(x_{t-1}|x_t) 的均值
    • 方差是 fixed 的,不可學習
  2. 預測原始圖像 (predict x0x_0 directly)

    • 直接預測原始、未被污染的圖像
    • 實驗證明這種方式效果較差
  3. 預測添加的噪聲 (predict the added noise)

    • 預測在正向過程中添加的噪聲 εε
    • 和第一種方法(預測噪聲均值)實際上是數學上等價的,只是參數化方式不同,它們可以通過簡單的變換相互轉換

論文最終選擇了預測噪聲(第三種方式)作為主要方法,因為這種方式訓練更穩定且效果更好

這裡每一步添加的噪聲量是不固定的,通過一個 Linear Schedule 來控制噪聲添加,防止訓練過程不穩定。

大概長下面這樣:
Screenshot 2024-12-09 at 21.34.56

可以看到最後最後幾個時間步都接近完全噪聲了,信息很少,此外整體來看信息摧毀得太快了,因此 OpenAI 使用了 Cosine Schedule 解決了這兩個問題:

Screenshot 2024-12-09 at 21.39.35

模型架構#

U-Net#

隨著這篇論文一起發表的是一個名叫 U-Net 的模型架構:

這個模型在中間有一個 Bottleneck(也就是參數量較小的層),用 Downsample-Block 和 Resnet-Block 來將輸入圖像投影到小分辨率,輸出時用 Upsample-Block 將其投影回初始尺寸。

image

在某些分辨率下,作者加入了 Attention-Block 並在相同分辨率空間的層之間做 Skip-Connection。模型是被涉及為針對每一個時間步的,這是通過 Transformer 中的正弦位置編碼嵌入來實現的,嵌入被投影到每個 Residual-Block 中。模型還能結合 Schedule,在不同時間步中去除不同量的噪聲來提升生成效果,後面會詳細討論。

Bottleneck 和 Autoencoder#

Bottleneck(瓶頸層)的概念最初是在無監督學習方法 "自編碼器(Autoencoder)" 中提出並廣泛使用的。作為自編碼器架構中維度最低的隱藏層,它位於編碼器和解碼器之間,構成了網絡中最窄的部分,強制網絡學習數據的壓縮表示,最小化重建誤差並起到正則化作用:Lreconstruction=XDecoder(Encoder(X))2\mathcal{L}_{\text{reconstruction}} = \|X - \text{Decoder}(\text{Encoder}(X))\|^2

image

架構改進#

OpenAI 在他們的第二篇論文 Diffusion Models Beat GANs on Image Synthesis 中通過改進架構顯著改善了整體效果:

  1. 增加網絡深度(更多層),減少寬度(每層的通道數)
  2. 增加 Attention-Block 數量
  3. 並擴大每個 Attention Block 中的 heads 數量
  4. 引入 BigGAN 風格的 Residual Block 用於上採樣和下採樣
  5. 引入 Adaptive Group Normalization (AdaGN),通過條件信息(如時間步)來動態調整歸一化的參數
  6. 用 Separate Classifier Guidance 幫助模型生成某類圖片

數學部分#

符號表#

  • XtX_t 代表在 tt 時間步的圖像,即 X0X_0 是原始圖像。可以簡記 tt 越小噪聲越少:

image

  • 噪聲的最終圖像是一個 isotropic(各方向相同) 的完全噪聲,記作 XTX_T,在最開始的研究中 T=1000T=1000,後續工作會將其減小很多:

image

  • Forward Process:q(xtxt1)q(x_t|x_{t-1}),輸入xt1x_{t-1}圖像輸出一張噪聲更多的圖像XtX_t

image

  • Backward Process:p(xt1xt)p(x_{t-1}|x_t),輸入 xtx_t 圖像用神經網絡輸出一個降噪的圖像 xt1x_t-1

image

正向過程#

image

其中,

  • 1βtxt1\sqrt{1 - \beta_t} x_{t-1} 是分布的均值(mean);βt\beta_tnoise schedule 參數,範圍在 0 到 1 之間,配合1βt\sqrt{1 - \beta_t}噪聲完成縮放,隨著時間步增加而減小,表示保留的原始信號部分

image

  • βtI\beta_t I 是分布的協方差矩陣(covariance matrix),II單位矩陣,表示協方差矩陣是對角的且各維度獨立,隨著時間步增加,添加的噪聲量增大。

現在我們只需要將這個步驟不斷迭代執行即可得到 1000 步後的結果,但其實這些可以一步完成。

Reparameterization Trick#

重參數化技巧在擴散模型和其他生成模型(如變分自編碼器,VAE)中非常重要。它的核心思想是將隨機變量的採樣過程轉化為一個確定性函數加上一個標準化的隨機變量。這種轉換使得模型可以通過梯度下降進行優化,因為它消除了採樣過程中的隨機性對梯度計算的影響

這裡通過一個簡單的例子來解釋其意義

你要實現一個扔骰子可以有兩種方式,

  • 第一種是隨機性在函數內部:
# 1. 直接掷骰子(隨機採樣)
def roll_dice():
    return random.randint(1, 6)

result = roll_dice()
  • 第二種則是讓隨機性在函數外部,函數本身是確定性的:
# 2. 將隨機性分離出來
random_number = random.random()  # 生成0到1之間的隨機數

def transformed_dice(random_number):
    # 將0-1的隨機數映射到1-6
    return math.floor(random_number * 6) + 1

result = transformed_dice(random_number)

概率論中我們學過:如果 XX 是一個隨機變量,且 X𝒩(0,1)X ∼ 𝒩(0,1),那麼有:aX+b𝒩(b,a2)aX + b ∼ 𝒩(b, a²)

因此對於正態分布 N(μ,σ2)\mathcal{N}(\mu, \sigma^2) 可以通過以下方式生成樣本:

x=μ+σϵx = \mu + \sigma \cdot \epsilon

其中 ϵN(0,1)\epsilon \sim \mathcal{N}(0, 1) 是標準正態分布。

那同理,在正態分布中,

  • 不使用重參數化
# 直接從目標正態分布採樣
x = np.random.normal(mu, sigma)
  • 使用重參數化
# 先從標準正態分布採樣
epsilon = np.random.normal(0, 1)
# 然後通過確定性變換得到目標分布
x = mu + sigma * epsilon

對應到模型訓練中涉及的梯度計算時,

不使用重參數化

def sample_direct(mu, sigma):
    return np.random.normal(mu, sigma)

# 這種情況下很難計算關於mu和sigma的梯度
# 因為隨機採樣操作阻斷了梯度傳播

使用重參數化

def sample_reparameterized(mu, sigma):
    epsilon = np.random.normal(0, 1)  # 梯度不需要通過這裡傳播
    return mu + sigma * epsilon        # 可以輕鬆計算mu和sigma的梯度

以 VAE(變分自編碼器)為例:

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = Encoder()  # 輸出mu和sigma
        self.decoder = Decoder()

    def reparameterize(self, mu, sigma):
        # 重參數化技巧
        epsilon = torch.randn_like(mu)  # 從標準正態分布採樣
        z = mu + sigma * epsilon        # 確定性變換
        return z

    def forward(self, x):
        # 編碼器輸出mu和sigma
        mu, sigma = self.encoder(x)
        
        # 使用重參數化採樣
        z = self.reparameterize(mu, sigma)
        
        # 解碼器重建輸入
        reconstruction = self.decoder(z)
        return reconstruction

吃貨視角的重參數化#

想像你在製作奶茶:

不使用重參數化

  • 直接製作一杯特定甜度的奶茶
  • 如果不好喝,你不知道是糖放多了還是放少了

使用重參數化

  1. 先準備一杯標準濃度的糖水(ϵ\epsilon
  2. 然後通過調整糖水的量(μ\mu)和稀釋程度(σ\sigma)來達到目標甜度
  3. 如果不好喝,你可以清楚地知道是糖水量還是稀釋程度需要調整(參數可優化)

image

總之,經過重參數化:

  • 梯度可以通過確定性變換傳播
  • 參數可以通過梯度下降優化
  • 隨機性被隔離,不影響梯度計算

正向數學推導#

xt1x_{t-1}xtx_t 的轉移

  • 給定 xt1x_{t-1},我們希望生成 xtx_t
  • q(xtxt1)=N(xt;1βtxt1,βtI)q(x_t \mid x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I) 使用重新參數化技巧,
Σ=βtI,σ2=βtσ=βt\begin{align*} \because \Sigma &= \beta_t I, \sigma^2 = \beta_t \\ \therefore \sigma &= \sqrt{\beta_t} \end{align*}

可將 xtx_t 表示為 xt1x_{t-1} 的確定性變換加上噪聲項:

xt=1βtxt1+βtϵ\begin{align*} x_t = \sqrt{1 - \beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon \end{align*}
  • 這裡,1βtxt1\sqrt{1 - \beta_t} x_{t-1} 是均值部分,βtϵ\sqrt{\beta_t} \epsilon 是噪聲部分。由於 ϵ\epsilon 是標準正態分布的樣本,與模型參數無關,因此在反向傳播時,梯度只需考慮 1βt\sqrt{1 - \beta_t}βt\sqrt{\beta_t} 對應的參數。這使得模型可以通過梯度下降進行有效優化。

我們用 αt\alpha_t 簡化記法 & 記錄其乘積的累計:

image

可得:

q(xtxt1)=αtxt1+1αtϵq(x_t \mid x_{t-1}) = \sqrt{\alpha_t} x_{t-1} + \sqrt{1 - \alpha_t} \epsilon

計算兩步轉移:從 xt2x_{t-2}xtx_t

xt1=αt1xt2+1αt1ϵt1xt=αt(αt1xt2+1αt1ϵt1)+1αtϵtxt=αtαt1xt2+αt(1αt1)ϵt1+1αtϵt\begin{align*} x_{t-1} &= \sqrt{\alpha_{t-1}} x_{t-2} + \sqrt{1 - \alpha_{t-1}} \epsilon_{t-1} \\ x_t &= \sqrt{\alpha_t} \left( \sqrt{\alpha_{t-1}} x_{t-2} + \sqrt{1 - \alpha_{t-1}} \epsilon_{t-1} \right) + \sqrt{1 - \alpha_t} \epsilon_t \\ x_t &= \sqrt{\alpha_t \alpha_{t-1}} x_{t-2} + \sqrt{\alpha_t (1 - \alpha_{t-1})} \epsilon_{t-1} + \sqrt{1 - \alpha_t} \epsilon_t \end{align*}

因為 ϵt1\epsilon_{t-1}ϵt\epsilon_t 是獨立的標準正態分布,合併噪聲部分為一個新的噪聲項 ϵN(0,I)\epsilon \sim \mathcal{N}(0, I)

xt=αtαt1xt2+1αtαt1ϵx_t = \sqrt{\alpha_t \alpha_{t-1}} x_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \epsilon

同理:

xt=αtαt1xt2+1αtαt1ϵxt=αtαt1αt2xt3+1αtαt1αt2ϵxt=αtαt1α1α0x0+1αtαt1α1α0ϵ通過歸納法,可以推出:xt=s=k+1tαsxk+1s=k+1tαsϵαˉt=s=1tαsk=0,xt=αˉtx0+1αˉtϵ(ϵN(0,I))\begin{align*} x_t &= \sqrt{\alpha_t \alpha_{t-1}} x_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \epsilon \\ x_t &= \sqrt{\alpha_t \alpha_{t-1} \alpha_{t-2}} x_{t-3} + \sqrt{1 - \alpha_t \alpha_{t-1} \alpha_{t-2}} \epsilon \\ x_t &= \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1 \alpha_0} x_0 + \sqrt{1 - \alpha_t \alpha_{t-1} \cdots \alpha_1 \alpha_0} \epsilon \\ 通過&歸納法,可以推出:\\ x_t &= \sqrt{\prod_{s=k+1}^t \alpha_s} x_k + \sqrt{1 - \prod_{s=k+1}^t \alpha_s} \epsilon \\ \because \bar{\alpha}_t &= \prod_{s=1}^t \alpha_s \\ \therefore 當& k=0 時,x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon \quad (\epsilon \sim \mathcal{N}(0, I)) \end{align*}

完整推導流程如下:

q(xtxt1)=N(xt;1βtxt1,βtI)=1βtxt1+βtϵ=αtxt1+1αtϵq(xtxt2)=αtαt1xt2+1αtαt1ϵq(xtxt3)=αtαt1αt2xt3+1αtαt1αt2ϵq(xtx0)=αtαt1α1α0x0+1αtαt1α1α0ϵ=αˉtx0+1αˉtϵ(ϵN(0,I))=N(xt;αˉtx0,(1αˉt)I)\begin{align} q(x_t \mid x_{t-1}) &= \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I) \\ &= \sqrt{1 - \beta_t} x_{t-1} + \sqrt{\beta_t} \epsilon \\ &= \sqrt{\alpha_t} x_{t-1} + \sqrt{1 - \alpha_t} \epsilon \\ q(x_t \mid x_{t-2}) &= \sqrt{\alpha_t \alpha_{t-1}} x_{t-2} + \sqrt{1 - \alpha_t \alpha_{t-1}} \epsilon \\ q(x_t \mid x_{t-3}) &= \sqrt{\alpha_t \alpha_{t-1} \alpha_{t-2}} x_{t-3} + \sqrt{1 - \alpha_t \alpha_{t-1} \alpha_{t-2}} \epsilon \\ q(x_t \mid x_0) &= \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1 \alpha_0} x_0 + \sqrt{1 - \alpha_t \alpha_{t-1} \cdots \alpha_1 \alpha_0} \epsilon \\ &= \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon \quad (\epsilon \sim \mathcal{N}(0, I))\\ &= \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I) \end{align}

逆向數學推導#

由於方差是固定的而不需要去學習 (見 1.3),所以我們只需要神經網絡去預測均值:

image

我們的最終目標是預測兩個時間步中的噪聲,現在從損失函數開始分析:

log(pθ(x0))-log(p_\theta(x_0))

但是這個負對數似然中,x0x_0的概率依賴前面的所有其它時間步。我們可以學習一個逼近這些條件概率的模型作為解決方案,這裡就需要用到 Varational Lower Bound (變分下界) 來得到一個更好計算的公式。

變分下界#

image

假設我們有一個無法計算的函數f(x)f(x),在我們的場景中是負對數似然,我們可以找一個始終滿足 g(x)f(x)g(x) \leq f(x) 的可計算函數 g(x)g(x) ,那麼優化 g(x)g(x) 也可以讓 f(x)f(x) 增加:
Screenshot 2024-12-12 at 20.06.47

我們這裡通過減去 KL 散度來保證這點,KL 散度是衡量兩個分布相似度的指標,始終非負

DKL(pq)=xp(x)logp(x)q(x)dxD_{KL}(p \| q) = \int_x p(x) \log \frac{p(x)}{q(x)} \, dx

減去一個始終非負的項可以保證結果始終小於原始函數,這裡用 “+” 是因為我們想要最小化損失,所以加上後始終保證原始的負對數似然大:

log(pθ(x0))log(pθ(x0))+DKL(q(x1:Tx0)pθ(x1:Tx0))-\log(p_\theta(x_0)) \leq -\log(p_\theta(x_0)) + D_{KL}(q(x_{1:T} \mid x_0) \| p_\theta(x_{1:T} \mid x_0))

這種形式下,因為負對數似然還在,下界仍然是不可計算的,我們需要得到一個更好的表達。首先將 KL 散度重寫為兩個項的對數比:

log(pθ(x0))log(pθ(x0))+DKL(q(x1:Tx0)pθ(x1:Tx0))=log(pθ(x0))+log(q(x1:Tx0)pθ(x1:Tx0))\begin{align*} -\log(p_\theta(x_0)) &\leq -\log(p_\theta(x_0)) + D_{KL}(q(x_{1:T} \mid x_0) \| p_\theta(x_{1:T} \mid x_0)) \\ &=-\log(p_\theta(x_0)) + \log \left( \frac{q(x_{1:T} \mid x_0)}{p_\theta(x_{1:T} \mid x_0)} \right) \\ \end{align*}

再對其分母應用貝葉斯法則:

pθ(x1:Tx0)=pθ(x0x1:T)pθ(x1:T)pθ(x0)p_\theta(x_{1:T} \mid x_0)= \frac{p_\theta(x_0 \mid x_{1:T}) p_\theta(x_{1:T})}{p_\theta(x_0)}

Note

貝葉斯法則: p(AB)=p(BA)p(A)p(B)p(A \mid B) = \frac{p(B \mid A) p(A)}{p(B)}

上式的分子部分 pθ(x0x1:T)pθ(x1:T)p_\theta(x_0 \mid x_{1:T}) p_\theta(x_{1:T}) 實際上是聯合概率 pθ(x0,x1:T)p_\theta(x_0, x_{1:T}),因為:

pθ(x0,x1:T)=pθ(x0x1:T)pθ(x1:T)p_\theta(x_0, x_{1:T}) = p_\theta(x_0 \mid x_{1:T}) p_\theta(x_{1:T})

通常, pθ(x0:T)p_\theta(x_{0:T}) 表示 x0x_0 和所有中間步驟 x1:Tx_{1:T} 的聯合概率,即:

pθ(x0:T)=pθ(x0,x1:T)p_\theta(x_{0:T}) = p_\theta(x_0, x_{1:T})

Note

pθ(x0:T)p_\theta(x_{0:T}) 表示從時間步 0 到 TT 的所有狀態 x0,x1,,xTx_0, x_1, \ldots, x_T 的聯合概率分布。

pθ(x0:T)=p(xT)t=1Tpθ(xt1xt) p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)

代入有:

log(q(x1:Tx0)pθ(x1:Tx0))=log(q(x1:Tx0)pθ(x0:T)pθ(x0))將分母pθ(x0:T)pθ(x0)轉化為乘法形式:1pθ(x0:T)pθ(x0)=pθ(x0)pθ(x0:T)=log(q(x1:Tx0)pθ(x0:T))+log(pθ(x0))\begin{align*} \log \left( \frac{q(x_{1:T} \mid x_0)}{p_\theta(x_{1:T} \mid x_0)} \right) = \log \left( \frac{q(x_{1:T} \mid x_0)}{\frac{p_\theta(x_{0:T})}{p_\theta(x_0)}} \right) \\ 將分母\frac{p_\theta(x_{0:T})}{p_\theta(x_0)}轉化為乘法形式:\frac{1}{\frac{p_\theta(x_{0:T})}{p_\theta(x_0)}} = \frac{p_\theta(x_0)}{p_\theta(x_{0:T})} \\ = \log \left( \frac{q(x_{1:T} \mid x_0)}{p_\theta (x_{0:T})} \right) &+ \log(p_\theta(x_0)) \end{align*}

即按照下圖中的流程得到最後形式:

Screenshot 2024-12-12 at 22.58.23

豁然開朗,煩人的兩項消掉了:

log(pθ(x0))log(pθ(x0))+log(q(x1:Tx0)pθ(x0:T))+log(pθ(x0))=log(q(x1:Tx0)pθ(x0:T))\begin{align*} -\log(p_\theta(x_0)) &\leq -\log(p_\theta(x_0)) + \log \left( \frac{q(x_{1:T} \mid x_0)}{p_\theta (x_{0:T})} \right) + \log(p_\theta(x_0)) \\ &= \log \left( \frac{q(x_{1:T} \mid x_0)}{p_\theta (x_{0:T})} \right) \end{align*}

這樣就得到了可以最小化的下限,並且式中的內容都是已知的:

  • 分子是正向過程的聯合概率分布:q(x1:Tx0)=t=1Tq(xtxt1)q(x_{1:T} \mid x_0)=\prod_{t=1}^T q(x_t \mid x_{t-1})
  • 分母是逆向過程的聯合概率分布:pθ(x0:T)=p(xT)t=1Tpθ(xt1xt)p_\theta (x_{0:T})=p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)

為了讓其有解析解,還需要幾個額外的重組步驟:

log(q(x1:Tx0)pθ(x0:T))=log(t=1Tq(xtxt1)p(xT)t=1Tpθ(xt1xt))=log(1p(xT)t=1Tq(xtxt1)t=1Tpθ(xt1xt))=log(1p(xT))+log(t=1Tq(xtxt1)t=1Tpθ(xt1xt))=log(p(xT))+log(t=1Tq(xtxt1)t=1Tpθ(xt1xt))=log(p(xT))+t=1Tlog(q(xtxt1)pθ(xt1xt))=log(p(xT))+t=2Tlog(q(xtxt1)pθ(xt1xt))+log(q(x1x0)pθ(x0x1))\begin{align} \log \left( \frac{q(x_{1:T} \mid x_0)}{p_\theta(x_{0:T})} \right) &= \log \left( \frac{\prod_{t=1}^T q(x_t \mid x_{t-1})}{p(x_T) \prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)} \right) \\ &= \log \left( \frac{1}{p(x_T)} \cdot \frac{\prod_{t=1}^T q(x_t \mid x_{t-1})}{\prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)} \right)\\ &= \log \left( \frac{1}{p(x_T)} \right) + \log \left( \frac{\prod_{t=1}^T q(x_t \mid x_{t-1})}{\prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)} \right) \\ &= -\log(p(x_T)) + \log \left( \frac{\prod_{t=1}^T q(x_t \mid x_{t-1})}{\prod_{t=1}^T p_\theta(x_{t-1} \mid x_t)} \right) \\ &=-\log(p(x_T)) + \sum_{t=1}^T \log \left( \frac{q(x_t \mid x_{t-1})}{p_\theta(x_{t-1} \mid x_t)} \right) \\ &=- \log(p(x_T)) + \sum_{t=2}^T \log \left( \frac{q(x_t \mid x_{t-1})}{p_\theta(x_{t-1} \mid x_t)} \right) + \log \left( \frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)} \right) \end{align}

根據貝葉斯法則重寫 sum 項的分子:q(xtxt1)=q(xt1xt)q(xt)q(xt1)q(x_t \mid x_{t-1})=\frac{q(x_{t-1}\mid x_t)q(x_t)}{q(x_{t-1})}

但這又回到了前面,這些項都是需要估計全部樣本導致 high variance,如給出下圖所示的 xtx_t,你很難確定上個狀態是什麼樣的:

image

改進思路則為通過直接條件化於原始數據 x0x_0

q(xt1xt,x0)q(xtx0)q(xt1x0)\Longrightarrow \frac{q(x_{t-1} \mid x_t, x_0) q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}

這樣同時給出無噪聲的圖像,侯選的 xt1x_{t-1} 就少了,方差會減小:

image

代入回原式:

=log(p(xT))+t=2Tlog(q(xt1xt,x0)q(xtx0)pθ(xt1xt)q(xt1x0))+log(q(x1x0)pθ(x0x1))=log(p(xT))+t=2Tlog(q(xt1xt,x0)pθ(xt1xt))+t=2Tlog(q(xtx0)q(xt1x0))+log(q(x1x0)pθ(x0x1))\begin{align} &= - \log(p(x_T)) + \sum_{t=2}^T \log \left( \frac{q(x_{t-1} \mid x_t, x_0) q(x_t \mid x_0)}{p_\theta(x_{t-1} \mid x_t) q(x_{t-1} \mid x_0)} \right) + \log \left( \frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)} \right) \\ &= - \log(p(x_T)) + \sum_{t=2}^T \log \left( \frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)} \right) + \sum_{t=2}^T \log \left( \frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)} \right) + \log \left( \frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)} \right) \end{align}

展開第二個 sum 項,可以發現大部分項都被化簡掉了:

image

=log(p(xT))+t=2Tlog(q(xt1xt,x0)pθ(xt1xt))+log(q(xTx0)q(x1x0))+log(q(x1x0)pθ(x0x1))\begin{align} &= - \log(p(x_T)) + \sum_{t=2}^T \log \left( \frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)} \right) + \log \left( \frac{q(x_T \mid x_0)}{q(x_{1} \mid x_0)} \right) + \log \left( \frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)} \right) \end{align}

對最後兩項應用 log rules 可以化簡一些項:

log(q(xTx0)q(x1x0))+log(q(x1x0)pθ(x0x1))=[logq(xTx0)logq(x1x0)]+[logq(x1x0)logpθ(x0x1)]=logq(xTx0)logpθ(x0x1)\begin{align*} \log \left( \frac{q(x_T \mid x_0)}{q(x_{1} \mid x_0)} \right) + \log \left( \frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)} \right)&=\left[ \log q(x_T \mid x_0) - \log q(x_{1} \mid x_0) \right] + \left[ \log q(x_1 \mid x_0) - \log p_\theta(x_0 \mid x_1) \right] \\ &=\log q(x_T \mid x_0) - \log p_\theta(x_0 \mid x_1) \end{align*}

再將化簡後的第一項移到前面,合併成一個對數得到最終解析形式:

=log(p(xT))+t=2Tlog(q(xt1xt,x0)pθ(xt1xt))+logq(xTx0)logpθ(x0x1)=log(q(xTx0)p(xT))+t=2Tlog(q(xt1xt,x0)pθ(xt1xt))log(pθ(x0x1))=DKL(q(xTx0)p(xT))+t=2TDKL(q(xt1xt,x0)pθ(xt1xt))log(pθ(x0x1))=t=2TDKL(q(xt1xt,x0)pθ(xt1xt))log(pθ(x0x1))\begin{align} &= - \log(p(x_T)) + \sum_{t=2}^T \log \left( \frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)} \right) + \log q(x_T \mid x_0)- \log p_\theta(x_0 \mid x_1) \\ &= \log(\frac{q(x_T\mid x_0)}{p(x_T)}) + \sum_{t=2}^T \log \left( \frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)} \right) - \log(p_\theta(x_0 \mid x_1)) \\ &= D_{KL}(q(x_T | x_0) \| p(x_T)) + \sum_{t=2}^T D_{KL}(q(x_{t-1} | x_t, x_0) \| p_\theta(x_{t-1} | x_t)) - \log(p_\theta(x_0 | x_1)) \\ &= \sum_{t=2}^T D_{KL}(q(x_{t-1} | x_t, x_0) \| p_\theta(x_{t-1} | x_t)) - \log(p_\theta(x_0 | x_1)) \end{align}

這個形式的第一項是可以忽略的,因為 qq 沒有可學習參數,只是加噪聲的正向過程,會收斂為正態分布,而 p(xT)p(x_T) 只是從高斯分布中隨機採樣的噪聲,因此可以確定該項 KL 散度會很小。

剩余兩項的推導結果如下(過程省略,詳見 Lilian's Blog)
Screenshot 2024-12-13 at 14.43.24

β\beta 是固定的,那麼就關注 μ\mu 的形式:

μ~t(xt,x0)=αˉt(1αˉt1)1αˉtxt+αˉt1βt1αˉtx0\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} x_0

正向過程生成的閉合形式 xt=αˉtx0+1αˉtϵx_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon 可以重寫為 x0x_0 形式:

x0=1αˉt(xt1αˉtϵ)x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon \right)

將上述 x0x_0 的表達式代入預測均值公式 μ~t\tilde{\mu}_t

μ~t=αˉt(1αˉt1)1αˉtxt+αˉt1βt1αˉt1αˉt(xt1αˉtϵ)\tilde{\mu}_t = \frac{\sqrt{\bar{\alpha}_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} \cdot \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon \right)

現在 μ\mu 不再依賴 x0x_0。繼續化簡,首先展開第二項:

αˉt1βt1αˉt1αˉt(xt1αˉtϵ)=αˉt1βtαˉt(1αˉt)xtαˉt1βt1αˉtαˉt(1αˉt)ϵ\frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} \cdot \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon \right) = \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{\sqrt{\bar{\alpha}_t} (1 - \bar{\alpha}_t)} x_t - \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t \sqrt{1 - \bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t} (1 - \bar{\alpha}_t)} \epsilon

xtx_t 項進行合併:

μ~t=(αˉt(1αˉt1)1αˉt+αˉt1βtαˉt(1αˉt))xtαˉt1βt1αˉtαˉt(1αˉt)ϵ\tilde{\mu}_t = \left( \frac{\sqrt{\bar{\alpha}_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} + \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{\sqrt{\bar{\alpha}_t} (1 - \bar{\alpha}_t)} \right) x_t - \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t \sqrt{1 - \bar{\alpha}_t}}{\sqrt{\bar{\alpha}_t} (1 - \bar{\alpha}_t)} \epsilon

xtx_t 的系數進行進一步合併和化簡,最終得到:

μ~t=1αˉt(xtβt1αˉtϵ)\tilde{\mu}_t = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon \right)

這表示我們基本上只是減去 xtx_t 生成的隨機縮放噪聲,這就是神經網絡要預測的東西。

代入後的損失函數 LtL_t 定義為一個均方誤差:

Lt=12σt21αˉt(xtβt1αˉtϵ)μθ(xt,t)2=12σt21αˉt(xtβt1αˉtϵ)1αˉt(xtβt1αˉtϵθ(xt,t))2=12σt2βtαˉt(1αˉt)(ϵϵθ(xt,t))2=βt22σt2αˉt(1αˉt)ϵϵθ(xt,t)2\begin{align*} L_t &= \frac{1}{2\sigma_t^2} \left\| \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon \right) - \mu_\theta(x_t, t) \right\|^2 \\ &= \frac{1}{2\sigma_t^2} \left\| \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon \right) - \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) \right\|^2 \\ &= \frac{1}{2\sigma_t^2} \left\| \frac{\beta_t}{\sqrt{\bar{\alpha}_t (1-\bar{\alpha}_t)}} (\epsilon - \epsilon_\theta(x_t, t)) \right\|^2 \\ &= \frac{\beta_t^2}{2\sigma_t^2 \bar{\alpha}_t (1-\bar{\alpha}_t)} \|\epsilon - \epsilon_\theta(x_t, t)\|^2 \end{align*}

最後的形式就是時間步 tt 的實際噪聲和神經網絡預測噪聲之間的均方誤差。研究人員發現忽略前面的縮放項會得到更好的採樣質量並且更容易實現。

βt22σt2αt(1α^t)ϵϵθ(xt,t)2ϵϵθ(xt,t)2\frac{\beta_t^2}{2 \sigma_t^2 \alpha_t (1 - \hat{\alpha}_t)} \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2 \longrightarrow \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2

回到原始公式

N(xt1;1αt(xtβt1αˉtϵθ(xt,t)),βt)\mathcal{N}\left(x_{t-1}; \frac{1}{\sqrt{\alpha_t}} \left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t)\right), \beta_t\right)

作者決定在最後一步的採樣中,不再添加額外的隨機噪聲使生成過程更加穩定:

Screenshot 2024-12-13 at 15.51.11

最後的形式為:

Lsimple=Et,x0,ϵ[ϵϵθ(αˉtx0+1αˉtϵ,t)2]    Et,x0,ϵ[ϵϵθ(xt,t)2]\begin{align} L_{\text{simple}} &= \mathbb{E}_{t, \mathbf{x}_0, \boldsymbol{\epsilon}} \left[ \left\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta \left( \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}, t \right) \right\|^2 \right] \\ &\implies \mathbb{E}_{t, \mathbf{x}_0, \boldsymbol{\epsilon}} \left[ \left\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta \left( \mathbf{x}_t, t \right) \right\|^2 \right] \end{align}
  • Et,x0,ϵ\mathbb{E}_{t, \mathbf{x}_0,\boldsymbol{\epsilon}} 表示對時間步 tt、原始數據 x0\mathbf{x}_0 和噪聲 ϵ\boldsymbol{\epsilon} 取期望
  • ϵ\boldsymbol{\epsilon} 是實際添加的隨機噪聲
  • ϵθ\boldsymbol{\epsilon}_\theta 是神經網絡預測的噪聲
  • αˉtx0+1αˉtϵ\sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon} 是前向過程的閉式解,表示在時間步 tt 的噪聲數據,因此可以簡化為:
    • xt\mathbf{x}_t 直接表示時間步 tt 的噪聲數據,即 αˉtx0+1αˉtϵ\sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \boldsymbol{\epsilon}
    • 整個損失函數本質上是在衡量預測噪聲與實際噪聲之間的均方誤差

其中,時間步 tt 通常是從均勻分布中採樣的(即 tUniform(1,T)t∼Uniform(1,T),T 是總的時間步數)。這種選擇確保了在訓練過程中,每個時間步都有相同的概率被選擇,從而使模型在所有時間步上都能有效地學習去噪過程。

訓練#

image

首先我們從數據集中採樣一些圖像,然後採樣 tt 和來自正態分布的噪聲,然後通過梯度下降優化目標

採樣#

首先從正態分布中採樣 xtx_t,然後用前面展示過的公式通過重參數化來採樣 xt1x_{t-1}

image

注意這裡 t=1t=1 時是不增加噪聲的,根據公式

x0=1αt(x1βt1αˉtϵθ(x1,1))x_0 = \frac{1}{\sqrt{\alpha_t}} \left( x_1 - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_1, 1) \right)

t=1t = 1 時,公式用於從 x1x_1 恢復到 x0x_0,這是去噪過程的最後一步。此時,我們希望盡可能準確地重建原始圖像。在最後一步不添加噪聲(即沒有 βtϵ\sqrt{\beta_t} \epsilon 項),可以避免在生成最終圖像時引入不必要的隨機性,從而保持圖像的清晰度和細節。

代碼實現#

推薦 知乎 Sunrise 的 MLP 簡化實現
後面有時間的話考慮做一下 Stable Diffusion 的代碼手撕,挖個坑...

參考資料#

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。