本文主要以該視頻的教學邏輯為主線,結合講解內容進行整理和闡述,如有錯誤歡迎在評論區糾正!
Flow Matching:讓我們從第一性原理重塑生成模型#
好了,我們來聊聊 generative models
。目標很簡單,right
?我們有一個數據集,比如說,一堆貓的圖片,它們來自某個瘋狂的、高維的概率分布 。我們想訓練一個模型,它能吐出新的貓咪圖片。目標簡單,但實現方法可能會變得…… 嗯,相當 gnarly
。
你可能聽說過 Diffusion models
。整個想法就是從一張圖片開始,通過幾百個步驟慢慢地加入噪聲,然後訓練一個巨大的網絡來一步步地逆轉這個過程。它背後的數學涉及到 score functions
()、隨機微分方程 (stochastic differential equations
, SDEs)……it's a whole thing
。這套方法確實有效,而且效果出奇地好,但作為一個計算機科學家,我總是會想:我們能用一種更簡單、更直接的方式達到同樣的目標嗎?有沒有辦法 hack
一下?
讓我們退一步,從零開始。第一性原理。
核心問題#
我們有兩個分布:
- :一個我們可以輕鬆採樣的、超級簡單的噪聲分布。可以想象成
x0 = torch.randn_like(image)
。 - :我們真實數據的、複雜的、未知的分布(比如貓!)。我們可以通過從數據集中加載一張圖片來從中採樣。
我們想要學習一個 mapping
,它能接收一個來自 的樣本,然後把它變成一個來自 的樣本。
Diffusion
的方式是定義一個複雜的分布路徑 ,讓 慢慢地演變成 ,然後再學習如何逆轉它。但這個中間分布 正是所有數學複雜性的來源。
那麼,我們能做的最簡單的事情是什麼呢?
一個天真的、“高中物理” 般的想法#
如果我們只是…… 畫一條直線呢?
說真的。讓我們挑選一個噪聲樣本 ,和一個真實的貓咪圖片樣本 。它們之間最簡單的路徑是什麼?當然是線性插值。
在這裡, 是我們的 “時間” 參數,從 變到 。
- 當 時,我們處於噪聲 。
- 當 時,我們處於貓咪圖片 。
- 在兩者之間的任何時間 ,我們都處於兩者的某種模糊混合狀態。
好了,很簡單。現在,如果我們想象一個粒子在一秒鐘內沿著這條直線路徑從 移動到 ,它的速度是多少?再次,讓我們堅持高中物理知識,只對時間 求導:
等一下。讓這個結論在你腦中停留一會兒。
對於這條簡單的直線路徑,我們的粒子在任何時間點的速度,都只是那個從起點指向終點的恆定向量。這是你能想象到的最簡單的 vector field
。
這就是那個 “Aha!” 時刻。如果這就是我們所需要的一切呢?
構建模型#
我們有了一个目标!我们想学习一个 vector field
。让我们定义一个神经网络 ,它接收任何点 和任何时间 作为输入,然后输出一个向量,也就是它对该点速度的预测。
我们如何训练它?嗯,我们希望网络的输出与我们的简单目标速度 相匹配。最直接的方法就是使用 Mean Squared Error
损失。
所以,我们整个的训练目标就变成了:
让我们来分解一下 training loop
。它简直是滑稽般的简单:
x1 = sample_from_dataset() # 抓一张真实的猫咪图片
x0 = torch.randn_like(x1) # 抓一些噪声
t = torch.rand(1) # 随机选一个时间
xt = (1 - t) * x0 + t * x1 # 插值得到我们的训练输入点
predicted_velocity = model(xt, t) # 让模型给出速度预测
target_velocity = x1 - x0 # 这就是我们的 ground truth!
loss = mse_loss(predicted_velocity, target_velocity)
loss.backward()
optimizer.step()
Boom
. 就是这样。这就是 Conditional Flow Matching 的核心。我们把一个令人费解的概率分布匹配问题,变成了一个简单的回归问题。
为什么这如此强大:“Simulation-Free”#
注意我们没有做什么。我们从未需要提及那个复杂的边缘分布 。我们从未需要定义或估计一个 score function
。我们完全绕过了整个 SDE/PDE 的理论体系。
我们所需要的只是能够抽取点对 并在它们之间进行插值的能力。这就是为什么它被称为 simulation-free 训练。它难以置信地直接。
生成图片 (Inference)#
好了,我们已经训练好我们的网络 ,让它成为了一个从噪声导航到数据的优秀 “GPS”。我们如何生成一张新的猫咪图片呢?
我们只需遵循它的指示!
- 从一个随机的噪声点开始:
x = torch.randn(...)
。 - 从时间 开始。
- 迭代若干步:
a. 从我们的 “GPS” 获取方向:velocity = model(x, t)
。
b. 朝那个方向迈出一小步:x = x + velocity * dt
。
c. 更新时间:t = t + dt
。 - 经过足够多的步骤(例如,当 到达 1 时),
x
就会成为我们全新的猫咪图片。
这个过程只是在求解一个常微分方程 (Ordinary Differential Equation
, ODE)。它基本上就是欧拉方法,你也可能在高中学过。Pretty cool, right?
总结#
所以,回顾一下,Flow Matching
给了我们一个全新的、更简单的视角来看待 generative modeling
。我们不再考虑概率密度和分数,而是考虑向量场和流。我们定义了一条从噪声到数据的简单路径(比如一条直线),然后训练一个神经网络来学习产生这条路径的速度场。
事实证明,这个简单的、直观的想法不仅仅是一个 hack
;它在理论上是健全的,并且为像 SD3 这样的一些最新的 state-of-the-art
模型提供了动力。它完美地提醒了我们,有时,最深刻的进步来自于为一个复杂问题找到一个更简单的 abstraction
。
Simplicity wins.
Flow Matching 的 “正规推导”:为啥我们那个简单的 “黑科技” 是可行的?#
好,前面我们用一个超级简单直观的想法推导出了 Flow Matching 的核心。我们取一个噪声样本 ,一个真实数据样本 ,在它们之间画一条直线(),然后说我们的神经网络 只需要学习它的速度就行了... 也就是 。损失函数几乎是自己蹦出来的。搞定。
老实说,对于一个实践者,这已经是你需要知道的 90% 的内容了。
但如果你和我一样,你脑海里可能会有个小声音在嘀咕:“等等... 这感觉也太轻松了吧。我们那个技巧是建立在一对独立的样本点 上的。凭什么通过学习这些独立的直线,就能让我们的网络理解整个高维概率分布 的流动呢?我们那个简单的技巧,到底是一个站得住脚的捷径,还是一个碰巧奏效的、有点可爱的‘黑科技’?”
这就是论文里那部分形式化推导的用武之地了。它的目的,就是为了证明我们那个简单的、conditional
(基于条件的)目标函数,确实能够魔法般地优化那个更宏大、更吓人的 marginal
(基于边缘分布的)目标。
让我们暂时戴上数学家的帽子,看看他们是如何填平这条鸿沟的。
那个 “官方的”、纯理论的问题:边缘流 (Marginal Flows)#
“真正的”、理论上纯粹的问题是这样的:我们有一系列概率分布 ,它从噪声 逐渐 “变形” 为数据 。这个连续的变形过程由一个 vector field
(向量场) 所支配。这个 就是在时间 、位置 处点的速度向量。
所以,“官方” 目标是训练我们的网络 来匹配这个真实的、边缘向量场 。对应的损失函数应该是:
看到这个公式,我们应该立刻意识到:这玩意儿简直是一场灾难。它完全是 intractable
(没法处理的)。我们没法从 中采样,也压根不知道目标 是什么。所以,暂时来看,这个损失函数毫无用处。
沟通的桥梁:连接 “边缘” 与 “条件”#
于是,研究者们使出了一招经典的数学招式。他们说:“好吧,边缘场 是个猛兽。但我们能不能把它表示成一大堆简单的、conditional
(条件的)向量场的平均值呢?”
一个条件向量场,我们称之为 ,指的是一个点 的速度,前提是我们已经知道它的最终目的地是数据点 。
论文证明了(这也是其核心的理论洞察),那个吓人的边缘场 ,其实就是所有简单的条件场的期望值,并由 “一条从 出发的路径会经过 的概率” 进行加权:
这就建立了一座桥梁。我们把一个未知的东西 () 和一堆我们或许可以定义出来的、更简单的东西 () 联系了起来。
我们的出发点是那个 “官方的”、理论上正确但无法直接优化的 边缘流损失函数 (Marginal Flow Matching Loss)。对于任意一个时间步 ,其形式如下:
这里的 是在时间 的边缘概率密度,而 是我们想要学习的真实边缘向量场。这两个我们都无法得到,所以这个形式是无法计算的。我们的目标是通过数学变换,把它变成一个可以计算的形式。
第一步:展开平方误差项
我们使用代数恒等式 来展开损失函数:
注意到,在优化过程中,我们只关心和我们模型的参数 相关的项。上式中的 是真实向量场的模长平方,它不依赖于 ,因此在求梯度时可以被看作一个常数项。为了最小化 ,我们只需要最小化剩余的部分即可:
第二步:将期望重写为积分
根据期望的定义 ,我们将上式重写为积分形式。我们重点关注包含未知项 的第二部分,即交叉项:
第三步:代入连接 “边缘” 与 “条件” 的桥梁公式
这里的关键在于一个核心等式,它将难以处理的边缘项 和可以定义的条件项联系起来。这个等式是:
我们将这个等式代入到我们重点关注的交叉项中:
第四步:交换积分次序 (Fubini-Tonelli 定理)
现在我们得到了一个双重积分。这个表达式看起来更复杂了,但我们可以利用 Fubini-Tonelli 定理交换 和 的积分次序。这个操作是合法的,它能让我们重新组合被积函数:
第五步:将积分重新变回期望形式并完成配方
仔细观察第四步中括号内的部分:。这正是关于条件概率分布 的期望!所以,我们可以将内部的积分写成 。
现在再看整个表达式,它又变成了关于 的积分,所以我们又可以把它写成关于 的期望:
这个嵌套的期望可以合并成一个关于联合分布的期望:
现在,我们将这个变换后的交叉项代回到 中。同时,通过类似的变换,第一项也可以被重写:。
于是我们得到:
为了得到一个完美的平方形式,我们对上式加上再减去同一个项 :
中括号里的部分正好构成了一个完全平方式。而减去的最后一项不依赖于模型参数 ,因此在优化时可以忽略。
最终结果
我们成功证明了,最小化最初那个无法处理的边缘损失函数,等价于最小化下面这个可以处理的条件流匹配目标函数 (Conditional Flow Matching Objective):
至此,我们便在数学上严格证明了,只需要定义一个简单的条件路径(如线性插值)和其对应的向量场,并优化这个简单的回归损失,就能达到优化真实边缘流的宏大目标。
这是巨大的一步!我们成功消除了对边缘密度 的依赖。我们的损失函数现在只依赖于条件路径的密度 和条件向量场 。
回到我们最初的简单想法#
那么,我们現在到哪一步了?這個形式化證明告訴我們,只要我們能定義一個條件路徑 和它對應的向量場 ,我們就能用上面的 損失來訓練模型。
現在,我們終於可以把最初那個 “高中物理” 級別的簡單想法給請回來了。我們可以自由地定義這個條件路徑。那么,我們就選一個最簡單、最不做作的定義:
- 定義條件路徑 :就讓路徑是確定性的,一條直線。所以,概率分布在 這條線上是 1,在其他任何地方都是 0。(而 Diffusion 中從 出發的路徑是隨機的)
- 定義條件向量場 :正如我們之前計算的,這條路徑的速度就是 。
Note
在數學中,這種 “全部集中於一點,別處皆為零” 的特殊分布,被稱為狄拉克 函數 (Dirac delta function)。所以,當我們選擇一條直線路徑時,我們其實是選擇了狄拉克函數作為我們的條件概率分布 。
現在,把這兩個簡單的定義代入到我們剛剛推導出的、那個看起來很高級的 目標函數中。期望 就變成了 “在我們的直線上取點 ”,目標 就變成了我們簡單的 。
於是,見證奇跡的時刻到了,我們最終得到了:
我們回到了那個完全相同、無比簡潔的損失函數,也就是我們從那個最天真的第一性原理推導中 “猜” 出來的那個!這就是整個形式化證明的意義所在!
總結#
好吧,相當酷。我們剛剛經歷了一大堆複雜的數學推導 —— 積分、富比尼定理,全套流程 —— 結果只是為了證明我們那個簡單直觀的 “取巧” 方法,從一開始就是正確的。我們已經確認:在一條直線路徑上學習那個簡單的向量目標 ,確實是訓練生成模型的一種有效方式。
從理論到 torch
:編碼流程匹配#
好了,我們已經了解了直觀的想法,甚至還看過了詳細的正式證明。歸根結底,這都是一個簡單的回歸問題。但空談無益,讓我們來看代碼。
令人驚嘆的是,PyTorch 的實現幾乎是我們最終簡單公式的 1:1 翻譯。沒有隱藏的複雜性,沒有令人害怕的數學庫。只是純粹的 torch
。
讓我們拆解一下腳本中最重要的部分:訓練循環和採樣(推理)過程。
源代码见视频配套的实现:https://github.com/dome272/Flow-Matching/blob/main/flow-matching.ipynb
設置:數據與模型#
首先,腳本設置了一個二維棋盤格圖案。這是我們的小型 “貓咪圖片數據集”。這些點是我們的真實數據,。
然後,它定義了一個簡單的 MLP(多層感知機)。這就是我們的神經網絡,我們的 “GPS”,我們的向量場預測器 。它是一個標準的網絡,接受一批坐標 x 和一批時間值 t,並為每個點輸出一個預測的速度向量。這裡沒有什麼花哨的,魔法不在於架構,而在於我們讓它做什麼。
訓練循環:魔法就在這裡發生#
這就是實現的核心部分。讓我們回顧一下博客中那最終、優美的損失函數:
現在,讓我們逐行查看訓練循環的代碼。這正是這個公式在起作用。
data = torch.Tensor(sampled_points)
training_steps = 100_000
batch_size = 64
pbar = tqdm.tqdm(range(training_steps))
losses = []
for i in pbar:
# 1. Sample real data x1 and noise x0
x1 = data[torch.randint(data.size(0), (batch_size,))]
x0 = torch.randn_like(x1)
# 2. Define the target vector
target = x1 - x0
# 3. Sample random time t and create the interpolated input xt
t = torch.rand(x1.size(0))
xt = (1 - t[:, None]) * x0 + t[:, None] * x1
# 4. Get the model's prediction
pred = model(xt, t) # also add t here
# 5. Calculate the loss and other standard boilerplate
loss = ((target - pred)**2).mean()
loss.backward()
optim.step()
optim.zero_grad()
pbar.set_postfix(loss=loss.item())
losses.append(loss.item())
讓我們將其直接映射到我們的公式上:
-
x1 = ...
andx0 = ...
:從我們的數據分布 和噪聲分布 中採樣,提供了期望 所需的 和 。 -
target = x1 - x0
:就在這裡。問題的核心。這一行計算我們簡單直線路徑的真實向量場。它是我們損失函數的目標部分,。 -
xt = (1 - t[:, None]) * x0 + t[:, None] * x1
:這是另一個關鍵部分。這是創建路徑上點 的線性插值。它是模型的輸入,。 -
pred = model(xt, t)
:這是前向傳播,獲取我們網絡的預測,。 -
loss = ((target - pred)**2).mean()
:是最後一步。它計算 target 與 pred 之間的均方誤差。這是我們公式的 部分。
就這些!這五行最重要的代碼是我們推導出的優雅公式的直接逐行實現。
採樣:Following the Flow 🗺️#
所以我們已經訓練好了我們的模型。它現在是一個具有高度技能的 “GPS”,能夠知道速度場。我們如何生成新的棋盤格圖案呢?我們從一片空曠的地方(噪聲)開始,按照指示一路前行。
其基本數學原理是我們希望求解常微分方程(ODE):
解決這個問題的最簡單方法是歐拉法,它就是通過逐步取小的離散步來實現。
Tip
由於 是一個複雜的神經網絡,我們無法用筆和紙解決這個問題。我們必須進行模擬。最簡單的方式是用一系列小的離散直線步驟來逼近平滑、連續的流動。
根據導數的基本定義,我們知道在一個很小的時間步長 內,位置的變化 大約等於速度乘以時間步長:。
因此,為了在時間 獲得我們新的位置,我們只需在當前位置上添加這個微小的變化。這為我們提供了更新規則:
這段 “沿著速度方向稍微移動一點” 的方法有一個著名的名字:它被稱為歐拉方法(或歐拉更新)。這是數值求解常微分方程最簡單、最基本的方法。正如你所看到的,你幾乎可以從第一原理自己發明出來。
# The sampling loop from the script
xt = torch.randn(1000, 2) # Start with pure noise at t=0
steps = 1000
for i, t in enumerate(torch.linspace(0, 1, steps), start=1):
pred = model(xt, t.expand(xt.size(0))) # Get velocity prediction
# This is the Euler method step!
# dt = 1 / steps
xt = xt + (1 / steps) * pred
-
xt = torch.randn(...)
:我們從一團隨機的點雲開始,即我們的初始。 -
for t in torch.linspace(0, 1, steps)
:我們在一系列離散的steps
中模擬從 到 的流動隨時間的變化。 -
pred = model(xt, ...)
:在每一步,我們向模型詢問當前速度,。 -
xt = xt + (1 / steps) * pred
:這是歐拉更新。我們將當前點 xt 朝模型 pred 預測的方向邁出微小的一步。步長 dt 為 1 /steps。
通過重複這個簡單的更新,隨機點雲逐漸被學到的向量場 “推動” 著,直到它們流入棋盤格數據分布的形狀。
理論的簡潔性直接轉化為簡潔、清晰且高效的代碼,令人感到美妙。
DiffusionFlow#
但等等…… 先別急著慶祝,讓我們暫停一下。來個 “思考氣泡” 時刻。
Warning
我們證明了這套數學在假設 (隨機噪聲向量)和 (隨機貓圖)之間走直線路徑的前提下是成立的。可是…… 直線真的就是最好、最高效的路徑嗎?
從高斯噪聲雲的混沌狀態,到貓圖像所處的那個精細複雜的流形,其 “真實” 的變換過程很可能是一段狂野、曲折、高維的旅程。而我們強行假設走直線…… 是不是有點太粗暴了?我們要求單個神經網絡 學習一個向量場,讓它神奇地適用於所有這些被強制設定的、不自然的線性插值。這或許正是為什麼我們仍然需要相當多的採樣步數才能生成高質量圖像的原因 —— 學到的向量場不得不持續糾正我們這個過於簡化的路徑假設。
於是,一個優秀的 “黑客” 接下來自然會問:“我們能不能讓要解決的問題變得更簡單?”
試想一下…… 如果我們不再強行連接兩個完全隨機的點 和 ,而是能找到一組 “更好” 的起點和終點呢?比如一對點 ,它們本身就以某種方式 “天然” 關聯,兩點之間的路徑已經非常接近直線,簡單得多?
這樣的點對從哪兒來?很簡單 —— 我們可以用另一個生成模型(比如隨便一個現成的 DDPM)來幫我們生成!我們給它一個噪聲向量 ,它經過幾百步的迭代,輸出一張不錯的圖像 。這樣,我們就得到了一个 點對,它代表了一個強大模型實際走過的 “真實” 路徑。
現在,我們就有了一個 “教師 - 學生”(teacher-student)的訓練框架:舊的、慢速的模型為我們提供這些 “預先拉直” 的路徑,而我們則用它們來訓練新的、簡單的流匹配(Flow Matching)模型。這樣一來,新模型的學習任務就變得簡單多了。
這種用一個模型為另一個模型構造更簡單學習問題的思路非常強大。本質上,你是在把一條複雜、彎曲的路徑 “蒸餾” 成一條更簡單、更直的路徑。事實上,DeepMind 等團隊也想到了完全相同的點子 —— 這正是他們提出的 Rectified Flow(校正流)或 DiffusionFlow 的核心思想:通過迭代不斷拉直路徑,直到它足夠直,幾乎一步就能從起點跳到終點。
這是一個建立在我們最初那個簡單 “取巧” 之上的、極具美感的元層級(meta-level)思想。值得細細品味。