Ditching the SDEs: A Simpler Path with Flow Matching

This article mainly organizes and elaborates on the teaching logic of the video it is based on, together with my own explanations. If there are any errors, feel free to point them out in the comments!

Flow Matching: Let's Redefine Generative Models from First Principles#

Alright, let's talk about generative models. The goal is simple, right? We have a dataset, say, a bunch of cat images, which come from some crazy, high-dimensional probability distribution $p_1(x_1)$. We want to train a model that can spit out new cat images. The goal is simple, but the implementation might become... well, quite gnarly.

You might have heard of Diffusion models. The whole idea is to start with an image, slowly add noise over hundreds of steps, and then train a massive network to reverse this process step by step. The math behind it involves score functions ($\nabla_x \log p_t(x)$), stochastic differential equations (SDEs)... it's a whole thing. This method is indeed effective and surprisingly good, but as a computer scientist, I always wonder: can we achieve the same goal in a simpler, more direct way? Is there a way to hack it?

Let's take a step back and start from scratch. First principles.

Core Problem#

We have two distributions:

  1. $p_0(x_0)$: a super simple noise distribution that we can easily sample from. Think of it as x0 = torch.randn_like(image).
  2. $p_1(x_1)$: the complex, unknown distribution of our real data (like cats!). We can sample from it by loading an image from the dataset.

We want to learn a mapping that takes a sample from $p_0$ and transforms it into a sample from $p_1$.

The Diffusion approach defines a complex distribution path $p_t(x)$, allowing $p_1$ to slowly evolve into $p_0$, and then learns how to reverse it. But this intermediate distribution $p_t(x)$ is precisely where all the mathematical complexity comes from.

So, what's the simplest thing we can do?

A Naive, "High School Physics" Idea#

What if we just... draw a straight line?

Seriously. Let's pick a noise sample $x_0$ and a real cat image sample $x_1$. What's the simplest path between them? Of course, it's linear interpolation.

$$x_t = (1-t)x_0 + t x_1$$

Here, $t$ is our "time" parameter, ranging from $0$ to $1$.

  • When $t=0$, we are at the noise $x_0$.
  • When $t=1$, we are at the cat image $x_1$.
  • At any time $t$ between the two, we are in some blurred mixed state of both.

Alright, simple enough. Now, if we imagine a particle moving along this straight line from $x_0$ to $x_1$ over one second, what is its speed? Again, let's stick to high school physics and differentiate with respect to time $t$:

$$\frac{d x_t}{dt} = \frac{d}{dt}\big((1-t)x_0 + t x_1\big) = -x_0 + x_1 = x_1 - x_0$$

Wait a second. Let that conclusion sink in for a moment.


For this simple straight-line path, the speed of our particle at any point in time is just that constant vector pointing from the start to the end. This is the simplest vector field you can imagine.
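If you prefer code to calculus, here is a tiny sanity check of that claim (purely illustrative; x0 and x1 are just two arbitrary tensors standing in for noise and data):

import torch

x0 = torch.randn(2)   # a "noise" point
x1 = torch.randn(2)   # pretend this is our "cat"
dt = 1e-3

def interp(t):
    # The straight-line path x_t = (1 - t) * x0 + t * x1
    return (1 - t) * x0 + t * x1

for t in (0.1, 0.5, 0.9):
    velocity = (interp(t + dt) - interp(t)) / dt  # finite-difference speed at time t
    print(velocity, x1 - x0)                      # the same vector, no matter which t we pick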

This is the "Aha!" moment. What if this is all we need?

Building the Model#

We have a goal! We want to learn a vector field. Let's define a neural network $v_\theta(x, t)$ that takes any point $x$ and any time $t$ as input and outputs a vector, which is its prediction of the speed at that point.

How do we train it? Well, we want the network's output to match our simple target speed $x_1 - x_0$. The most straightforward way is to use Mean Squared Error loss.

So, our entire training objective becomes:

$$L = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta((1-t)x_0 + t x_1, t) - (x_1 - x_0) \|^2 \right]$$

Let's break down the training loop. It's comically simple:

x1 = sample_from_dataset()        # Grab a real cat image
x0 = torch.randn_like(x1)         # Grab some noise
t = torch.rand(1)                 # Randomly pick a time
xt = (1 - t) * x0 + t * x1        # Interpolate to get our training input point
predicted_velocity = model(xt, t) # Let the model give a speed prediction
target_velocity = x1 - x0         # This is our ground truth!
loss = mse_loss(predicted_velocity, target_velocity)
optimizer.zero_grad()             # Clear old gradients before the backward pass
loss.backward()
optimizer.step()

Boom. That's it. This is the core of Conditional Flow Matching. We turned a perplexing probability distribution matching problem into a simple regression problem.

Why This is So Powerful: "Simulation-Free"#

Notice what we didn't do. We never had to mention that complex marginal distribution $p_t(x)$. We never had to define or estimate a score function. We completely bypassed the entire SDE/PDE theoretical framework.

All we needed was the ability to extract point pairs $(x_0, x_1)$ and interpolate between them. That's why it's called simulation-free training. It's incredibly direct.

Generating Images (Inference)#

Alright, we have trained our network $v_\theta(x, t)$ to be an excellent "GPS" from noise to data. How do we generate a new cat image?

We just follow its directions!

  1. Start from a random noise point: x = torch.randn(...).
  2. Start from time $t=0$.
  3. Iterate for several steps:
    a. Get direction from our "GPS": velocity = model(x, t).
    b. Take a small step in that direction: x = x + velocity * dt.
    c. Update time: t = t + dt.
  4. After enough steps (for example, when $t$ reaches 1), x will become our brand new cat image.

This process is just solving an ordinary differential equation (ODE). It's basically the Euler method, which you might have learned in high school. Pretty cool, right?

Summary#

So, to recap, Flow Matching gives us a whole new, simpler perspective on generative modeling. We no longer think about probability densities and scores, but rather vector fields and flows. We defined a simple path from noise to data (like a straight line) and then trained a neural network to learn the velocity field that produces this path.


It turns out that this simple, intuitive idea is not just a hack; it is theoretically sound and powers some of the latest state-of-the-art models like SD3. It perfectly reminds us that sometimes, the most profound advancements come from finding a simpler abstraction for a complex problem.

Simplicity wins.

The "Formal Derivation" of Flow Matching: Why Our Simple "Hack" Works#

Alright, earlier we derived the core of Flow Matching using a super simple and intuitive idea. We took a noise sample $x_0$, a real data sample $x_1$, drew a straight line between them ($x_t = (1-t)x_0 + t x_1$), and then said our neural network $v_\theta$ just needs to learn its speed... which is $x_1 - x_0$. The loss function almost popped out by itself. Done.

Honestly, for a practitioner, that's already 90% of what you need to know.

But if you're like me, there might be a little voice in your head saying, "Wait... this feels too easy. Our trick is based on a pair of independent sample points $(x_0, x_1)$. Why should learning these independent lines allow our network to understand the flow of the entire high-dimensional probability distribution $p_t(x)$? Is our simple trick a solid shortcut, or just a lucky, somewhat cute 'hack'?"

This is where the formal derivation in the paper comes into play. Its purpose is to prove that our simple, conditional objective function can indeed magically optimize the grander, scarier marginal objective.

Let's temporarily put on our mathematician hats and see how they bridge this gap.

The "Official", Purely Theoretical Problem: Marginal Flows#

The "real", theoretically pure problem is this: we have a series of probability distributions pt(x)p_t(x) that gradually "morph" from noise p0(x)p_0(x) to data p1(x)p_1(x). This continuous morphing process is governed by a vector field ut(x)u_t(x). This ut(x)u_t(x) is the velocity vector at time tt and position xx.

So, the "official" objective is to train our network vθ(x,t)v_\theta(x, t) to match this true marginal vector field ut(x)u_t(x). The corresponding loss function should be:

Lmarginal=EtU(0,1),xpt(x)[vθ(x,t)ut(x)2]L_{marginal} = \mathbb{E}_{t \sim U(0,1), x \sim p_t(x)} \left[ \| v_\theta(x, t) - u_t(x) \|^2 \right]

Seeing this formula, we should immediately realize: this is a disaster. It's completely intractable. We can't sample from $p_t(x)$, and we have no idea what the target $u_t(x)$ is. So, for now, this loss function is useless.

The Bridge of Communication: Connecting "Marginal" and "Conditional"#

So, researchers pulled a classic mathematical trick. They said, "Alright, the marginal field $u_t(x)$ is a beast. But can we express it as an average of a bunch of simple, conditional vector fields?"

A conditional vector field, which we call $u_t(x|x_1)$, refers to the velocity of a point $x$ given that we already know its final destination is the data point $x_1$.

The paper proves (which is also its core theoretical insight) that the scary marginal field $u_t(x)$ is actually the expectation of all the simple conditional fields, weighted by "the probability that a path ending at $x_1$ passes through $x$":

$$u_t(x) = \mathbb{E}_{x_1 \sim p_1(x_1)} \left[ u_t(x|x_1) \cdot (\text{some probability term}) \right]$$

This establishes a bridge. We connect an unknown thing ($u_t(x)$) with a bunch of simpler things that we might be able to define ($u_t(x|x_1)$).
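For the curious, that "some probability term" can be written out explicitly (this is the weighting used in the paper; we will meet the equivalent identity again in Step 3 below):

$$u_t(x) = \int u_t(x|x_1)\, \frac{p_t(x|x_1)\, p_1(x_1)}{p_t(x)}\, dx_1, \qquad \text{with } p_t(x) = \int p_t(x|x_1)\, p_1(x_1)\, dx_1$$

In words: the marginal velocity at $x$ is an average of the conditional velocities, each weighted by how likely it is that the path passing through $x$ was headed for that particular $x_1$.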

Our starting point is that "official", theoretically correct but directly unoptimizable marginal flow loss function. For any time step $t$, its form is:

$$L_t(v_\theta) = \mathbb{E}_{x \sim p_t(x)} \left[ \| v_\theta(x, t) - u_t(x) \|^2 \right]$$

Here, $p_t(x)$ is the marginal probability density at time $t$, and $u_t(x)$ is the true marginal vector field we want to learn. We can't obtain either of these, so this form is uncomputable. Our goal is to transform it into a computable form through mathematical manipulation.

Step 1: Expand the Squared Error Term

We use the algebraic identity $\|A - B\|^2 = \|A\|^2 - 2A \cdot B + \|B\|^2$ to expand the loss function:

$$L_t(v_\theta) = \mathbb{E}_{x \sim p_t(x)} \left[ \|v_\theta(x,t)\|^2 - 2\, v_\theta(x,t) \cdot u_t(x) + \|u_t(x)\|^2 \right]$$

Notice that during optimization, we only care about the terms related to our model's parameters $\theta$. The term $\|u_t(x)\|^2$ is the squared length of the true vector field, which does not depend on $\theta$, and thus can be treated as a constant term when calculating gradients. To minimize $L_t(v_\theta)$, we only need to minimize the remaining part:

$$L_t'(v_\theta) = \mathbb{E}_{x \sim p_t(x)} \left[ \|v_\theta(x,t)\|^2 - 2\, v_\theta(x,t) \cdot u_t(x) \right]$$

Step 2: Rewrite the Expectation as an Integral

Using the definition of expectation $\mathbb{E}_{x \sim p(x)}[f(x)] = \int p(x)f(x)\,dx$, we rewrite the above as an integral form. We focus on the second part containing the unknown term $u_t(x)$, which is the cross term:

$$L_t'(v_\theta) = \int p_t(x)\, \|v_\theta(x,t)\|^2\, dx - 2 \int p_t(x)\, v_\theta(x,t) \cdot u_t(x)\, dx$$

Step 3: Substitute the Bridge Formula Connecting "Marginal" and "Conditional"

The key here is a core equation that relates the intractable marginal term $p_t(x) u_t(x)$ to the definable conditional terms. This equation is:

$$p_t(x)\, u_t(x) = \mathbb{E}_{x_1 \sim p_1(x_1)} \left[ p_t(x|x_1)\, u_t(x|x_1) \right] = \int p_1(x_1)\, p_t(x|x_1)\, u_t(x|x_1)\, dx_1$$

We substitute this equation into the cross term we are focusing on:

$$-2 \int p_t(x)\, v_\theta(x,t) \cdot u_t(x)\, dx = -2 \int v_\theta(x,t) \cdot \left( \int p_1(x_1)\, p_t(x|x_1)\, u_t(x|x_1)\, dx_1 \right) dx$$

Step 4: Change the Order of Integration (Fubini-Tonelli Theorem)

Now we have a double integral. This expression looks more complex, but we can use the Fubini-Tonelli theorem to change the order of integration between $dx$ and $dx_1$. This operation is valid and allows us to recombine the integrand:

$$= -2 \int p_1(x_1) \left( \int p_t(x|x_1)\, v_\theta(x,t) \cdot u_t(x|x_1)\, dx \right) dx_1$$

Step 5: Rewrite the Integral Back into Expectation Form and Complete the Square

Carefully observe the part inside the brackets from step four: $\int p_t(x|x_1)\, v_\theta(x,t) \cdot u_t(x|x_1)\, dx$. This is precisely the expectation concerning the conditional probability distribution $p_t(x|x_1)$! So we can write the inner integral as $\mathbb{E}_{x \sim p_t(\cdot|x_1)}[\dots]$.

$$= -2 \int p_1(x_1)\, \mathbb{E}_{x \sim p_t(\cdot|x_1)} \left[ v_\theta(x,t) \cdot u_t(x|x_1) \right] dx_1$$

Now looking at the entire expression, it has transformed back into an integral concerning $p_1(x_1)$, so we can rewrite it as an expectation concerning $x_1$:

$$= -2\, \mathbb{E}_{x_1 \sim p_1(x_1)} \left[ \mathbb{E}_{x \sim p_t(\cdot|x_1)} \left[ v_\theta(x,t) \cdot u_t(x|x_1) \right] \right]$$

This nested expectation can be combined into an expectation concerning the joint distribution:

$$\text{Cross term} = -2\, \mathbb{E}_{x_1 \sim p_1,\, x \sim p_t(\cdot|x_1)} \left[ v_\theta(x,t) \cdot u_t(x|x_1) \right]$$

Now we substitute this transformed cross term back into $L_t'(v_\theta)$. Simultaneously, through similar transformations, the first term can also be rewritten: $\mathbb{E}_{x \sim p_t(x)} \left[ \|v_\theta(x,t)\|^2 \right] = \mathbb{E}_{x_1, x \sim p_t(\cdot|x_1)} \left[ \|v_\theta(x,t)\|^2 \right]$.
Thus we obtain:

$$L_t'(v_\theta) = \mathbb{E}_{x_1, x \sim p_t(\cdot|x_1)} \left[ \|v_\theta(x,t)\|^2 - 2\, v_\theta(x,t) \cdot u_t(x|x_1) \right]$$

To achieve a perfect square form, we add and subtract the same term $\mathbb{E}_{x_1, x \sim p_t(\cdot|x_1)} \left[ \|u_t(x|x_1)\|^2 \right]$:

$$L_t'(v_\theta) = \mathbb{E}_{x_1, x} \left[ \|v_\theta(x,t)\|^2 - 2\, v_\theta(x,t) \cdot u_t(x|x_1) + \|u_t(x|x_1)\|^2 \right] - \mathbb{E}_{x_1, x} \left[ \|u_t(x|x_1)\|^2 \right]$$

The part inside the brackets forms a complete square. The last term we subtracted does not depend on the model parameters θ\theta, so it can be ignored during optimization.


Final Result

We have successfully proven that minimizing the initially intractable marginal loss function is equivalent to minimizing the following tractable conditional flow matching objective function (Conditional Flow Matching Objective):

$$L_{CFM}(v_\theta) = \mathbb{E}_{t, x_1, x \sim p_t(\cdot|x_1)} \left[ \| v_\theta(x, t) - u_t(x|x_1) \|^2 \right]$$

At this point, we have mathematically rigorously proven that by defining a simple conditional path (like linear interpolation) and its corresponding vector field, and optimizing this simple regression loss, we can achieve the grand goal of optimizing the true marginal flow.

This is a huge step! We have successfully eliminated the dependence on the marginal density $p_t(x)$. Our loss function now only depends on the density of the conditional path $p_t(\cdot|x_1)$ and the conditional vector field $u_t(x|x_1)$.

Back to Our Initial Simple Idea#

So, where are we now? This formal proof tells us that as long as we can define a conditional path $p_t(x|x_1)$ and its corresponding vector field $u_t(x|x_1)$, we can use the above $L_{CFM}$ loss to train the model.

Now, we can finally bring back that initially "high school physics" level simple idea. We can freely define this conditional path. So, let's choose the simplest, most straightforward definition:

  1. Define the conditional path $p_t(\cdot|x_0, x_1)$: Let the path be deterministic, a straight line. So, all the probability mass is concentrated at the point $x_t = (1-t)x_0 + t x_1$ and is 0 everywhere else. (Whereas in Diffusion, the path starting from $x_0$ is random.)
  2. Define the conditional vector field $u_t(x_t|x_0, x_1)$: As we calculated earlier, the speed along this path is simply $x_1 - x_0$.

Note

In mathematics, this special distribution, where "everything is concentrated at one point and zero elsewhere," is called the Dirac $\delta$ function (Dirac delta function). So, when we choose a straight-line path, we are essentially choosing the Dirac function as our conditional probability distribution $p_t(x|x_0, x_1)$.
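Written out in symbols (just restating the note above, conditioning on both endpoints as in our straight-line construction):

$$p_t(x \mid x_0, x_1) = \delta\!\big(x - \left[(1-t)x_0 + t\, x_1\right]\big), \qquad u_t(x \mid x_0, x_1) = x_1 - x_0$$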

Now, let's substitute these two simple definitions into the elegantly derived $L_{CFM}$ objective function we just discussed. The expectation $\mathbb{E}_{x \sim p_t(\cdot|x_1)}$ becomes "taking points $x_t$ on our straight line," and the target $u_t(x|x_1)$ becomes our simple $x_1 - x_0$.

Thus, the moment of witnessing the miracle has arrived, and we finally obtain:

$$L = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta((1-t)x_0 + t x_1, t) - (x_1 - x_0) \|^2 \right]$$

We have returned to that identical, beautifully simple loss function, which is the one we "guessed" from that naive first-principles derivation! This is the significance of the entire formal proof!

Summary#

Alright, quite cool. We just went through a lot of complex mathematical derivations—integrals, Fubini's theorem, the whole process—only to prove that our simple intuitive "hack" method was correct from the start. We have confirmed that learning that simple vector target x1x0x_1 - x_0 along a straight-line path is indeed an effective way to train generative models.

From Theory to torch: Coding Flow Matching#

Alright, we now understand the intuitive idea and have even seen the detailed formal proof. Ultimately, this is just a simple regression problem. But talk is cheap, so let's look at the code.

Amazingly, the PyTorch implementation is almost a 1:1 translation of our final simple formula. No hidden complexities, no intimidating math libraries. Just pure torch.

Let's break down the most important parts of the script: the training loop and the sampling (inference) process.

Source code can be found in the implementation accompanying the video: https://github.com/dome272/Flow-Matching/blob/main/flow-matching.ipynb

Setup: Data and Model#

First, the script sets up a two-dimensional checkerboard pattern. This is our small "cat image dataset." These points are our real data, x1x_1.
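The notebook builds its checkerboard in its own way; purely for illustration, here is one minimal way such 2-D points could be generated (the function name and coordinate range are my own assumptions, not the notebook's code):

import torch

def sample_checkerboard(n):
    # Illustrative sketch: draw points uniformly in [-4, 4]^2 and keep only
    # those that land on the "black" squares of a unit checkerboard.
    pts = torch.rand(n * 4, 2) * 8 - 4
    on_black = (pts[:, 0].floor() + pts[:, 1].floor()) % 2 == 0
    return pts[on_black][:n]

sampled_points = sample_checkerboard(100_000)  # plays the role of our "real data" x1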

Then, it defines a simple MLP (multi-layer perceptron). This is our neural network, our "GPS," our vector field predictor $v_\theta(x, t)$. It's a standard network that takes a batch of coordinates x and a batch of time values t and outputs a predicted velocity vector for each point. There's nothing fancy here; the magic lies in what we make it do.
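Something like the following would do the job (a sketch under my own assumptions about layer sizes and how t is fed in, not the notebook's exact architecture):

import torch
import torch.nn as nn

class VelocityMLP(nn.Module):
    # A small v_theta(x, t) predictor for 2-D points; all details here are illustrative.
    def __init__(self, dim=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, t):
        # Concatenate the scalar time onto each point so the network knows where
        # along the flow it currently is, then predict a velocity vector.
        return self.net(torch.cat([x, t[:, None]], dim=-1))

model = VelocityMLP()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

The only design decision that really matters is that the network sees both the position and the time; everything else is standard.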

Training Loop: The Magic Happens Here#

This is the core part of the implementation. Let's revisit that final, elegant loss function from the blog:

$$L = \mathbb{E}_{t,x_0,x_1} \left[ \left\| \underbrace{v_\theta \big( \overbrace{(1-t)x_0 + tx_1}^{\text{Input to Model}},\, t \big)}_{\text{Prediction}} - \underbrace{(x_1 - x_0)}_{\text{Target}} \right\|^2 \right]$$

Now, let's go through the training loop code line by line. This is where this formula comes into play.

data = torch.Tensor(sampled_points)
training_steps = 100_000
batch_size = 64
pbar = tqdm.tqdm(range(training_steps))
losses = []

for i in pbar:
    # 1. Sample real data x1 and noise x0
    x1 = data[torch.randint(data.size(0), (batch_size,))]
    x0 = torch.randn_like(x1)

    # 2. Define the target vector
    target = x1 - x0

    # 3. Sample random time t and create the interpolated input xt
    t = torch.rand(x1.size(0))
    xt = (1 - t[:, None]) * x0 + t[:, None] * x1

    # 4. Get the model's prediction (note that t is passed in alongside xt)
    pred = model(xt, t)

    # 5. Calculate the loss and run the standard optimization boilerplate
    loss = ((target - pred)**2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad()
    pbar.set_postfix(loss=loss.item())
    losses.append(loss.item())

Let's map this directly to our formula:

  • x1 = ... and x0 = ...: Sampling from our data distribution $p_1$ and noise distribution $p_0$ provides the required $x_1$ and $x_0$ for the expectation $\mathbb{E}$.

  • target = x1 - x0: Right here. The core of the problem. This line computes the true vector field of our simple straight-line path. It is the target part of our loss function, $(x_1 - x_0)$.

  • xt = (1 - t[:, None]) * x0 + t[:, None] * x1: This is another key part. This creates the interpolated point $x_t$ along the path. It is the input to the model, $(1-t)x_0 + tx_1$.

  • pred = model(xt, t): This is the forward pass, getting our network's prediction, $v_\theta(x_t, t)$.

  • loss = ((target - pred)**2).mean(): This is the final step. It calculates the mean squared error between target and pred. This is the $\|\cdot\|^2$ part of our formula.

That's it! These five lines of crucial code are a direct line-by-line implementation of our elegant formula.

Sampling: Following the Flow 🗺️#

So we have trained our model. It is now a highly skilled "GPS" that knows the velocity field. How do we generate new checkerboard patterns? We start from an empty space (noise) and follow the directions all the way.

The basic mathematical principle is that we want to solve the ordinary differential equation (ODE):

$$\frac{dx_t}{dt} = v_\theta(x_t, t)$$

The simplest way to solve this problem is the Euler method, which is achieved by taking small discrete steps.

Tip

Since $v_\theta$ is a complex neural network, we can't solve this problem with pen and paper. We must simulate it. The simplest way is to approximate the smooth, continuous flow with a series of small discrete straight steps.

According to the basic definition of derivatives, we know that over a very small time step $dt$, the change in position $dx_t$ is approximately equal to the velocity multiplied by the time step: $dx_t \approx v_\theta(x_t, t) \cdot dt$.

Therefore, to obtain our new position at time $t + dt$, we just need to add this small change to the current position. This gives us the update rule:

$$x_{t+dt} = x_t + v_\theta(x_t, t) \cdot dt$$

This method of "moving a little bit in the direction of the velocity" has a famous name: it's called the Euler method (or Euler update). This is the simplest and most basic way to numerically solve ordinary differential equations. As you can see, you could almost invent it yourself from first principles.

# The sampling loop from the script
xt = torch.randn(1000, 2)  # Start with pure noise at t=0
steps = 1000
for i, t in enumerate(torch.linspace(0, 1, steps), start=1):
    pred = model(xt, t.expand(xt.size(0))) # Get velocity prediction
    
    # This is the Euler method step!
    # dt = 1 / steps
    xt = xt + (1 / steps) * pred

  • xt = torch.randn(...): We start from a cloud of random points, which is our initial noise.

  • for t in torch.linspace(0, 1, steps): We simulate the flow from $t=0$ to $t=1$ over a series of discrete steps.

  • pred = model(xt, ...): At each step, we ask the model for the current velocity, $v_\theta(x_t, t)$.

  • xt = xt + (1 / steps) * pred: This is the Euler update. We take a small step in the direction predicted by the model. The step size $dt$ is 1 / steps.

By repeating this simple update, the random points are gradually "pushed" by the learned vector field until they flow into the shape of the checkerboard data distribution.

The elegance of theory directly translates into clean, clear, and efficient code, which is delightful.

DiffusionFlow#

But wait… hold your horses, let's pause for a "thought bubble" moment.

Warning

We proved that this math holds under the assumption of taking a straight-line path between $x_0$ (the random noise vector) and $x_1$ (the random cat image). But... is a straight line really the best, most efficient path?

From the chaotic state of the Gaussian noise cloud to the finely complex manifold of the cat image, the "true" transformation process is likely a wild, winding, high-dimensional journey. And our assumption of taking a straight line... might be a bit too crude? Are we asking a single neural network $v_\theta$ to learn a vector field that magically applies to all these forced, unnatural linear interpolations? This might be why we still need quite a few sampling steps to generate high-quality images—the learned vector field has to continuously correct our overly simplified path assumption.

So, a clever "hacker" would naturally ask, "Can we simplify the problem we're trying to solve?"

Imagine... if we no longer forced a connection between two completely random points $x_0$ and $x_1$, but could find a pair of "better" starting and ending points? For example, a pair of points $(z_0, z_1)$ that are somehow "naturally" related, where the path between them is already very close to a straight line, much simpler?

Where do such point pairs come from? Simple—we can use another generative model (like any ready-made DDPM) to help us generate them! We give it a noise vector $z_0$, and after hundreds of steps of iteration, it outputs a nice image $z_1$. This way, we obtain a $(z_0, z_1)$ point pair that represents the "true" path that a powerful model actually traversed.

Now we have a "teacher-student" training framework: the old, slow model provides us with these "pre-straightened" paths, and we use them to train the new, simpler flow matching model. This way, the learning task for the new model becomes much simpler.

This idea of using one model to construct a simpler learning problem for another model is incredibly powerful. Essentially, you are "distilling" a complex, winding path into a simpler, straighter one. In fact, researchers have landed on exactly the same idea; it is the core concept behind Rectified Flow (and the related DiffusionFlow-style distillation): iteratively straighten the path until it is straight enough to almost jump from the starting point to the endpoint in one step.

This is a beautifully meta-level thought built on our initial simple "hack." It's worth savoring.
