This article organizes and expands on the teaching logic of the video, together with its explanations. If you spot any errors, feel free to point them out in the comments!
Flow Matching: Let's Redefine Generative Models from First Principles
Alright, let's talk about generative models. The goal is simple, right? We have a dataset, say, a bunch of cat images, drawn from some crazy, high-dimensional probability distribution $p_1(x)$. We want to train a model that can spit out new cat images. The goal is simple, but the implementation can get... well, quite gnarly.
You might have heard of Diffusion models. The whole idea is to start with an image, slowly add noise over hundreds of steps, and then train a massive network to reverse this process step by step. The math behind it involves score functions ($\nabla_x \log p_t(x)$), stochastic differential equations (SDEs)... it's a whole thing. The method works, and surprisingly well, but as a computer scientist, I always wonder: can we achieve the same goal in a simpler, more direct way? Is there a way to hack it?
Let's take a step back and start from scratch. First principles.
Core Problem
We have two distributions:
- $p_0$: a super simple noise distribution that we can easily sample from, such as the standard Gaussian $\mathcal{N}(0, I)$. Think of it as x0 = torch.randn_like(image).
- $p_1$: the complex, unknown distribution of our real data (like cats!). We can sample from it by loading an image from the dataset.
We want to learn a mapping that takes a sample $x_0$ from $p_0$ and transforms it into a sample $x_1$ from $p_1$.
The Diffusion approach defines a complex path of distributions $p_t(x)$ along which the data distribution slowly dissolves into noise, and then learns how to reverse it. But this intermediate distribution $p_t(x)$ is precisely where all the mathematical complexity comes from.
So, what's the simplest thing we can do?
A Naive, "High School Physics" Idea
What if we just... draw a straight line?
Seriously. Let's pick a noise sample $x_0 \sim p_0$ and a real cat image sample $x_1 \sim p_1$. What's the simplest path between them? Linear interpolation, of course:
$$x_t = (1 - t)\,x_0 + t\,x_1$$
Here, $t$ is our "time" parameter, ranging from $0$ to $1$.
- When $t = 0$, we are at the noise: $x_t = x_0$.
- When $t = 1$, we are at the cat image: $x_t = x_1$.
- At any time in between, we are at some blurry mix of the two.
Alright, simple enough. Now, if we imagine a particle moving along this straight line from $x_0$ to $x_1$ over one second, what is its velocity? Again, let's stick to high school physics and differentiate $x_t$ with respect to time $t$:
$$\frac{d x_t}{dt} = x_1 - x_0$$
Wait a second. Let that conclusion sink in for a moment.
For this simple straight-line path, the velocity of our particle at every point in time is just the constant vector $x_1 - x_0$ pointing from the start to the end. This is the simplest vector field you can imagine.
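A quick numerical sanity check of the two facts above, written as a minimal NumPy sketch. The helper name interpolate and the random stand-in vectors are ours, not from the original. The check confirms that the path hits $x_0$ at $t = 0$ and $x_1$ at $t = 1$. It also confirms that a finite-difference estimate of the velocity equals $x_1 - x_0$ at every time we probe.

```python
import numpy as np


def interpolate(x0, x1, t):
    """Straight-line path x_t = (1 - t) * x0 + t * x1 (hypothetical helper)."""
    return (1 - t) * x0 + t * x1


rng = np.random.default_rng(0)
x0 = rng.standard_normal(4)        # a "noise" sample
x1 = rng.standard_normal(4) + 3.0  # a stand-in for a "data" sample

h = 1e-4
velocities = []
for t in [0.1, 0.5, 0.9]:
    # Forward finite difference (x_{t+h} - x_t) / h approximates dx_t/dt
    v = (interpolate(x0, x1, t + h) - interpolate(x0, x1, t)) / h
    velocities.append(v)

print(np.allclose(interpolate(x0, x1, 0.0), x0))  # True: starts at the noise
print(np.allclose(interpolate(x0, x1, 1.0), x1))  # True: ends at the data
print(all(np.allclose(v, x1 - x0, atol=1e-6) for v in velocities))  # True: constant velocity
```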
This is the "Aha!" moment. What if this is all we need?
Building the Model
We have a goal! We want to learn a vector field. Let's define a neural network $v_\theta(x, t)$ that takes any point $x$ and any time $t$ as input and outputs a vector: its prediction of the velocity at that point.
How do we train it? Well, we want the network's output to match our simple target velocity $x_1 - x_0$. The most straightforward way is a Mean Squared Error (MSE) loss.
So, our entire training objective becomes:
$$\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\; x_0 \sim p_0,\; x_1 \sim p_1} \big\| v_\theta(x_t, t) - (x_1 - x_0) \big\|^2, \qquad x_t = (1 - t)\,x_0 + t\,x_1$$
Let's break down the training loop. It's comically simple:
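import torch
import torch.nn.functional as F

# Assumes model (our v_theta), optimizer, and a sample_from_dataset() helper are already defined
x1 = sample_from_dataset()          # Grab a real cat image
x0 = torch.randn_like(x1)           # Grab some Gaussian noise with the same shape
t = torch.rand(1)                   # Randomly pick a time in [0, 1)
xt = (1 - t) * x0 + t * x1          # Interpolate to get our training input point
predicted_velocity = model(xt, t)   # Let the model predict the velocity at (xt, t)
target_velocity = x1 - x0           # This is our ground truth!
loss = F.mse_loss(predicted_velocity, target_velocity)
optimizer.zero_grad()               # Clear gradients left over from the previous step
loss.backward()
optimizer.step()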
Boom. That's it. This is the core of Conditional Flow Matching. We turned a perplexing distribution-matching problem into a simple regression problem.
Why This is So Powerful: "Simulation-Free"
Notice what we didn't do. We never had to touch the complex marginal distribution $p_t(x)$. We never had to define or estimate a score function $\nabla_x \log p_t(x)$. We completely bypassed the entire SDE/PDE theoretical framework.
All we needed was the ability to sample point pairs $(x_0, x_1)$ and interpolate between them. That's why it's called simulation-free training: no step of training requires simulating (integrating) the flow. It's incredibly direct.
Generating Images (Inference)
Alright, we have trained our network to be an excellent "GPS" from noise to data. How do we generate a new cat image?
We just follow its directions!
- Start from a random noise point: x = torch.randn(...).
- Start from time $t = 0$.
- Iterate for several steps:
  a. Get the direction from our "GPS": velocity = model(x, t).
  b. Take a small step in that direction: x = x + velocity * dt.
  c. Update the time: t = t + dt.
- After enough steps (for example, when $t$ reaches 1), x will be our brand new cat image.
This process is just numerically solving an ordinary differential equation (ODE), $dx/dt = v_\theta(x, t)$. The update above is the Euler method, which you might have learned in high school. Pretty cool, right?
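The inference recipe above fits in a few lines. In this NumPy sketch, euler_sample is a hypothetical helper of ours. Instead of a trained network, we plug in the exact velocity for a single straight-line pair, $x_1 - x_0$, so we know where the points should end up.

```python
import numpy as np


def euler_sample(velocity_fn, x, steps):
    """Integrate dx/dt = velocity_fn(x, t) from t = 0 to t = 1 with `steps` Euler steps."""
    dt = 1.0 / steps
    t = 0.0
    for _ in range(steps):
        x = x + velocity_fn(x, t) * dt  # b. take a small step along the velocity
        t = t + dt                      # c. advance time
    return x


rng = np.random.default_rng(0)
x0 = rng.standard_normal(3)
x1 = np.array([1.0, -2.0, 0.5])


def straight_velocity(x, t):
    # Hypothetical "perfect model" for one pair: the constant straight-line velocity
    return x1 - x0


one_step = euler_sample(straight_velocity, x0, steps=1)
many_steps = euler_sample(straight_velocity, x0, steps=100)
print(np.allclose(one_step, x1), np.allclose(many_steps, x1))  # True True
```

Because the velocity is constant along a single straight line, Euler is exact here even with one step. A trained network only approximates an averaged field whose trajectories are generally curved, which is why real sampling uses many steps.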
Summary
So, to recap, Flow Matching gives us a whole new, simpler perspective on generative modeling. We no longer think about probability densities and scores, but about vector fields and flows. We defined a simple path from noise to data (like a straight line) and then trained a neural network to learn the velocity field that produces this path.
It turns out that this simple, intuitive idea is not just a hack; it is theoretically sound and powers some of the latest state-of-the-art models, such as Stable Diffusion 3 (SD3). It is a perfect reminder that sometimes the most profound advances come from finding a simpler abstraction for a complex problem.
Simplicity wins.
The "Formal Derivation" of Flow Matching: Why Our Simple "Hack" Works
Alright, earlier we derived the core of Flow Matching from a super simple, intuitive idea. We took a noise sample $x_0$ and a real data sample $x_1$, drew a straight line between them ($x_t = (1 - t)\,x_0 + t\,x_1$), and then said our neural network just needs to learn its velocity... which is $x_1 - x_0$. The loss function almost popped out by itself. Done.
Honestly, for a practitioner, that's already 90% of what you need to know.
But if you're like me, there might be a little voice in your head saying, "Wait... this feels too easy. Our trick is based on pairs of independently drawn sample points $(x_0, x_1)$. Why should learning these independent lines let our network capture the flow of the entire high-dimensional probability distribution $p_t(x)$? Is our simple trick a solid shortcut, or just a lucky, somewhat cute 'hack'?"
This is where the formal derivation in the paper comes into play. Its purpose is to prove that our simple, conditional objective function can indeed, almost magically, optimize the grander, scarier marginal objective.
Let's temporarily put on our mathematician hats and see how they bridge this gap.
The "Official", Purely Theoretical Problem: Marginal Flows
The "real", theoretically pure problem is this: we have a family of probability distributions $p_t(x)$, $t \in [0, 1]$, that gradually "morphs" from noise $p_0$ to data $p_1$. This continuous morphing is governed by a vector field $u_t(x)$: the velocity vector at time $t$ and position $x$.
So, the "official" objective is to train our network $v_\theta$ to match this true marginal vector field $u_t(x)$. The corresponding loss function is:
$$\mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\; x \sim p_t(x)} \big\| v_\theta(x, t) - u_t(x) \big\|^2$$
Seeing this formula, we should immediately realize: this is a disaster. It's completely intractable. We don't know $p_t(x)$ in closed form, and we have no idea what the target $u_t(x)$ is. So, as it stands, this loss function is useless.
The Bridge: Connecting "Marginal" and "Conditional"
So, researchers pulled a classic mathematical trick. They said, "Alright, the marginal field $u_t(x)$ is a beast. But can we express it as an average of a bunch of simple, conditional vector fields?"
A conditional vector field, which we call $u_t(x \mid x_1)$, is the velocity of a point $x$ at time $t$ given that we already know its final destination is the data point $x_1$.
The paper proves (this is its core theoretical insight) that the scary marginal field is the average of all the simple conditional fields, each weighted by how likely a path ending at $x_1$ is to pass through $x$ at time $t$:
$$u_t(x) = \int u_t(x \mid x_1)\, \frac{p_t(x \mid x_1)\, p_1(x_1)}{p_t(x)}\, dx_1$$
This establishes a bridge. We connect an unknown thing ($u_t(x)$) with a bunch of simpler things that we can define ($u_t(x \mid x_1)$).
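Here is a toy 1D check of the bridge formula, as a NumPy sketch under our own assumptions. The data distribution is two equally likely points at -1 and +1, the noise is $\mathcal{N}(0, 1)$, and the conditional path is the straight line. Then $p_t(x \mid x_1) = \mathcal{N}(t\,x_1, (1 - t)^2)$ and $u_t(x \mid x_1) = (x_1 - x)/(1 - t)$, which is the velocity $x_1 - x_0$ rewritten in terms of the current position $x$. The marginal velocity is then a posterior-weighted average of the two conditional velocities.

```python
import numpy as np

data_points = np.array([-1.0, 1.0])  # toy data distribution p_1 (an assumption)
data_probs = np.array([0.5, 0.5])


def gaussian_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))


def conditional_velocity(x, t, x1):
    # u_t(x | x1) for the straight-line path with Gaussian noise x0 ~ N(0, 1)
    return (x1 - x) / (1 - t)


def marginal_velocity(x, t):
    # Bridge formula: weights p_t(x | x1) p_1(x1) / p_t(x), with p_t(x | x1) = N(t * x1, (1 - t)^2)
    weights = gaussian_pdf(x, t * data_points, 1 - t) * data_probs
    weights = weights / weights.sum()  # dividing by the sum is dividing by p_t(x)
    return np.sum(weights * conditional_velocity(x, t, data_points))


t = 0.5
print(marginal_velocity(0.0, t))  # ~0 by symmetry: the two conditional pulls cancel
u = marginal_velocity(0.3, t)
cond = conditional_velocity(0.3, t, data_points)
print(cond.min() <= u <= cond.max())  # True: a weighted average of the conditional velocities
```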
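Our starting point is that "official", theoretically correct but not directly optimizable marginal flow loss. For any fixed time step $t$, it reads:
$$\mathcal{L}_t(\theta) = \mathbb{E}_{x \sim p_t(x)} \big\| v_\theta(x, t) - u_t(x) \big\|^2$$
Here, $p_t(x)$ is the marginal probability density at time $t$, and $u_t(x)$ is the true marginal vector field we want to learn. We can't obtain either of these, so this form is not computable. Our goal is to transform it into a computable form through mathematical manipulation.
Step 1: Expand the Squared Error Term
We use the algebraic identity $\|a - b\|^2 = \|a\|^2 - 2\langle a, b\rangle + \|b\|^2$ to expand the loss function:
$$\mathcal{L}_t(\theta) = \mathbb{E}_{x \sim p_t(x)} \Big[ \|v_\theta(x, t)\|^2 - 2\,\langle v_\theta(x, t), u_t(x) \rangle + \|u_t(x)\|^2 \Big]$$
Notice that during optimization, we only care about the terms that depend on our model's parameters $\theta$. The term $\|u_t(x)\|^2$ is the squared length of the true vector field. It does not depend on $\theta$, so it is a constant when computing gradients. To minimize $\mathcal{L}_t(\theta)$, we only need to minimize the remaining part:
$$\mathbb{E}_{x \sim p_t(x)} \|v_\theta(x, t)\|^2 - 2\,\mathbb{E}_{x \sim p_t(x)} \langle v_\theta(x, t), u_t(x) \rangle$$
Step 2: Rewrite the Expectation as an Integral
Using the definition of expectation, $\mathbb{E}_{x \sim p(x)}[f(x)] = \int f(x)\, p(x)\, dx$, we rewrite the above in integral form. We focus on the second part, the cross term, which contains the unknown $u_t(x)$:
$$\mathbb{E}_{x \sim p_t(x)} \langle v_\theta(x, t), u_t(x) \rangle = \int \langle v_\theta(x, t), u_t(x) \rangle\, p_t(x)\, dx$$
Step 3: Substitute the Bridge Formula Connecting "Marginal" and "Conditional"
The key here is the core equation that relates the intractable marginal term to the definable conditional terms:
$$u_t(x) = \int u_t(x \mid x_1)\, \frac{p_t(x \mid x_1)\, p_1(x_1)}{p_t(x)}\, dx_1$$
We substitute it into the cross term. The $p_t(x)$ in the denominator cancels the $p_t(x)$ outside. Because the inner product is linear, the integral over $x_1$ can also be pulled out:
$$\int \langle v_\theta(x, t), u_t(x) \rangle\, p_t(x)\, dx = \int \!\! \int \langle v_\theta(x, t), u_t(x \mid x_1) \rangle\, p_t(x \mid x_1)\, p_1(x_1)\, dx_1\, dx$$
Step 4: Change the Order of Integration (Fubini-Tonelli Theorem)
Now we have a double integral. It looks more complex, but the Fubini-Tonelli theorem lets us swap the order of integration over $x$ and $x_1$ (valid under mild integrability conditions). Swapping regroups the integrand:
$$= \int p_1(x_1) \left[ \int \langle v_\theta(x, t), u_t(x \mid x_1) \rangle\, p_t(x \mid x_1)\, dx \right] dx_1$$
Step 5: Rewrite the Integral Back into Expectation Form and Complete the Square
Look carefully at the part inside the brackets from Step 4: $\int \langle v_\theta(x, t), u_t(x \mid x_1) \rangle\, p_t(x \mid x_1)\, dx$. This is exactly an expectation under the conditional distribution $p_t(x \mid x_1)$! So we can write the inner integral as $\mathbb{E}_{x \sim p_t(x \mid x_1)} \langle v_\theta(x, t), u_t(x \mid x_1) \rangle$.
Now the entire expression is an integral against $p_1(x_1)$, so we can rewrite it as an expectation over $x_1$:
$$\mathbb{E}_{x_1 \sim p_1} \Big[ \mathbb{E}_{x \sim p_t(x \mid x_1)} \langle v_\theta(x, t), u_t(x \mid x_1) \rangle \Big]$$
This nested expectation combines into a single expectation over the joint distribution:
$$\mathbb{E}_{x_1 \sim p_1,\; x \sim p_t(x \mid x_1)} \langle v_\theta(x, t), u_t(x \mid x_1) \rangle$$
Now we substitute this transformed cross term back into $\mathcal{L}_t(\theta)$. The first term can be rewritten the same way, using $p_t(x) = \int p_t(x \mid x_1)\, p_1(x_1)\, dx_1$: $\mathbb{E}_{x \sim p_t(x)} \|v_\theta(x, t)\|^2 = \mathbb{E}_{x_1 \sim p_1,\; x \sim p_t(x \mid x_1)} \|v_\theta(x, t)\|^2$.
Thus we obtain the following, writing $C = \mathbb{E}_{x \sim p_t(x)} \|u_t(x)\|^2$ for the constant from Step 1:
$$\mathcal{L}_t(\theta) = \mathbb{E}_{x_1 \sim p_1,\; x \sim p_t(x \mid x_1)} \Big[ \|v_\theta(x, t)\|^2 - 2\,\langle v_\theta(x, t), u_t(x \mid x_1) \rangle \Big] + C$$
To get a perfect square, we add and subtract the same term $\mathbb{E}_{x_1 \sim p_1,\; x \sim p_t(x \mid x_1)} \|u_t(x \mid x_1)\|^2$:
$$\mathcal{L}_t(\theta) = \mathbb{E}_{x_1 \sim p_1,\; x \sim p_t(x \mid x_1)} \Big[ \|v_\theta(x, t)\|^2 - 2\,\langle v_\theta(x, t), u_t(x \mid x_1) \rangle + \|u_t(x \mid x_1)\|^2 \Big] - \mathbb{E}_{x_1 \sim p_1,\; x \sim p_t(x \mid x_1)} \|u_t(x \mid x_1)\|^2 + C$$
The part inside the brackets is a complete square, $\|v_\theta(x, t) - u_t(x \mid x_1)\|^2$. The last two terms do not depend on the model parameters $\theta$, so they can be ignored during optimization.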
Final Result
We have shown that minimizing the initially intractable marginal loss is equivalent to minimizing the following tractable Conditional Flow Matching objective. The two losses differ only by a constant, so they have the same gradients with respect to $\theta$:
$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\; x_1 \sim p_1,\; x \sim p_t(x \mid x_1)} \big\| v_\theta(x, t) - u_t(x \mid x_1) \big\|^2$$
At this point, we have rigorously shown the following. Define a simple conditional path (like linear interpolation) and its corresponding vector field, and optimize this simple regression loss. Doing so achieves the grand goal of optimizing the true marginal flow.
This is a huge step! We have eliminated the dependence on the marginal density $p_t(x)$. Our loss function now depends only on the conditional path density $p_t(x \mid x_1)$ and the conditional vector field $u_t(x \mid x_1)$.
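The derivation can also be checked numerically. The sketch below is our own toy setup, not from the paper. At one fixed time $t$, it replaces densities with probability tables over a handful of discrete positions. It builds $u_t(x)$ from random conditional fields via the bridge formula. It then confirms that $\mathcal{L}_{\text{CFM}} - \mathcal{L}_{\text{FM}}$ comes out the same for two arbitrary "models" $v$, so the gap is a constant that does not depend on $v$.

```python
import numpy as np

rng = np.random.default_rng(0)
n_x, n_x1, dim = 5, 3, 2  # tiny discrete world: 5 positions x, 3 data points x1, 2D velocities

p_x1 = rng.dirichlet(np.ones(n_x1))               # p_1(x1)
p_x_given_x1 = rng.dirichlet(np.ones(n_x), n_x1)  # p_t(x | x1), one row per x1
joint = p_x_given_x1 * p_x1[:, None]              # p_t(x, x1), shape (n_x1, n_x)
p_x = joint.sum(axis=0)                           # marginal p_t(x)

u_cond = rng.standard_normal((n_x1, n_x, dim))    # conditional fields u_t(x | x1)
posterior = joint / p_x                           # p_t(x1 | x), each column sums to 1
u_marg = np.einsum("ij,ijd->jd", posterior, u_cond)  # bridge: u_t(x) = sum_x1 p_t(x1|x) u_t(x|x1)


def fm_loss(v):
    # E_{x ~ p_t(x)} ||v(x) - u_t(x)||^2
    return np.sum(p_x * np.sum((v - u_marg) ** 2, axis=-1))


def cfm_loss(v):
    # E_{x1 ~ p_1, x ~ p_t(x | x1)} ||v(x) - u_t(x | x1)||^2
    return np.sum(joint * np.sum((v[None] - u_cond) ** 2, axis=-1))


v_a = rng.standard_normal((n_x, dim))  # two arbitrary "models": one velocity per position
v_b = rng.standard_normal((n_x, dim))
gap_a = cfm_loss(v_a) - fm_loss(v_a)
gap_b = cfm_loss(v_b) - fm_loss(v_b)
print(np.isclose(gap_a, gap_b))  # True: the losses differ only by a v-independent constant
print(gap_a >= 0)                # True: that constant is a variance-like term
```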
Back to Our Initial Simple Idea
So, where are we now? This formal proof tells us that as long as we can define a conditional path $p_t(x \mid x_1)$ and its corresponding vector field $u_t(x \mid x_1)$, we can train the model with the loss above.
Now we can finally bring back our initial "high school physics" level idea. We are free to choose this conditional path, so let's choose the simplest, most straightforward one:
- Define the conditional path: condition on both endpoints, the pair $(x_0, x_1)$. The derivation above goes through unchanged with $x_1$ replaced by the pair and $p_1(x_1)$ replaced by $p_0(x_0)\, p_1(x_1)$. Then the path is deterministic, a straight line: at each time $t$, all probability mass sits at the single point $x_t = (1 - t)\,x_0 + t\,x_1$, i.e. $p_t(x \mid x_0, x_1) = \delta(x - x_t)$. (In Diffusion, by contrast, the path starting from a data point is random.)
- Define the conditional vector field: as we calculated earlier, the velocity along this path is simply $u_t(x_t \mid x_0, x_1) = x_1 - x_0$.
Note
In mathematics, this special distribution, where "everything is concentrated at one point and zero elsewhere," is called the Dirac delta function (strictly speaking a generalized function, or distribution, rather than an ordinary function). So, when we choose a straight-line path, we are essentially choosing a Dirac delta as our conditional probability distribution $p_t(x \mid x_0, x_1)$.
Now, substitute these two simple definitions into the objective we just derived. The expectation over $x \sim p_t(x \mid x_0, x_1)$ becomes "take the point $x_t$ on our straight line," and the target $u_t$ becomes our simple $x_1 - x_0$.
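The step that turns the expectation into "take the point on the line" is the sifting property of the Dirac delta, $\int f(x)\,\delta(x - a)\,dx = f(a)$. Here it is written out for one fixed $t$, $x_0$, $x_1$, with $p_t(x \mid x_0, x_1) = \delta(x - x_t)$:

```latex
\mathbb{E}_{x \sim p_t(x \mid x_0, x_1)}\Big[\big\|v_\theta(x, t) - (x_1 - x_0)\big\|^2\Big]
  = \int \big\|v_\theta(x, t) - (x_1 - x_0)\big\|^2 \,\delta(x - x_t)\, dx
  = \big\|v_\theta(x_t, t) - (x_1 - x_0)\big\|^2,
\qquad x_t = (1 - t)\,x_0 + t\,x_1 .
```

Averaging this over $t$, $x_0$, and $x_1$ gives the final loss.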
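And now for the magic moment. We finally obtain:
$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1],\; x_0 \sim p_0,\; x_1 \sim p_1} \big\| v_\theta\big((1 - t)\,x_0 + t\,x_1,\; t\big) - (x_1 - x_0) \big\|^2$$
We have returned to exactly the same beautifully simple loss function that we "guessed" from the naive first-principles derivation! This is the significance of the entire formal proof!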
Summary
Alright, pretty cool. We just went through a lot of complex math (integrals, Fubini's theorem, the whole process) only to prove that our simple, intuitive "hack" was correct from the start. We have confirmed that learning the simple vector target $x_1 - x_0$ along a straight-line path is indeed an effective way to train generative models.
From Theory to torch: Coding Flow Matching
Alright, we now understand the intuitive idea and have even seen the detailed formal proof. In the end, this is just a simple regression problem. But talk is cheap, so let's look at the code.
Amazingly, the PyTorch implementation is almost a 1:1 translation of our final simple formula. No hidden complexities, no intimidating math libraries. Just pure torch.
Let's break down the most important parts of the script: the training loop and the sampling (inference) process.
Source code can be found in the implementation accompanying the video: https://github.com/dome272/Flow-Matching/blob/main/flow-matching.ipynb
Setup: Data and Model
First, the script sets up a two-dimensional checkerboard pattern. This is our small "cat image dataset." These points are our real data, $x_1 \sim p_1$.
Then, it defines a simple MLP (multi-layer perceptron). This is our neural network, our "GPS," our vector field predictor $v_\theta(x, t)$. It's a standard network that takes a batch of coordinates x and a batch of time values t and outputs a predicted velocity vector for each point. There's nothing fancy here; the magic lies in what we make it do.
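The notebook's exact data generator and architecture are not reproduced here. The following is a hypothetical PyTorch sketch of what such a setup could look like. These details are all our assumptions:

- the checkerboard layout (a 4 x 4 board on [-2, 2)^2),
- the layer sizes and SiLU activations,
- feeding t as an extra input feature.

The names sampled_points, model, and optim match the names the training loop uses.

```python
import torch
import torch.nn as nn


def sample_checkerboard(n, grid=4, seed=0):
    """Hypothetical sampler: uniform points on the 'black' squares of a grid x grid board over [-2, 2)^2."""
    g = torch.Generator().manual_seed(seed)
    kept = torch.empty(0, 2)
    while kept.size(0) < n:
        xy = torch.rand(2 * n, 2, generator=g) * 4 - 2   # uniform candidates in [-2, 2)^2
        cells = torch.floor((xy + 2) * grid / 4).long()  # integer square index per coordinate
        keep = cells.sum(dim=1) % 2 == 0                 # keep every other square
        kept = torch.cat([kept, xy[keep]])
    return kept[:n]


class VelocityMLP(nn.Module):
    """Hypothetical v_theta(x, t): a small MLP that sees the point and its time step."""

    def __init__(self, dim=2, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, t):
        # x: (batch, dim), t: (batch,) -> append t as one extra input feature
        return self.net(torch.cat([x, t[:, None]], dim=-1))


sampled_points = sample_checkerboard(10_000)
model = VelocityMLP()
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
print(sampled_points.shape)                            # torch.Size([10000, 2])
print(model(sampled_points[:8], torch.rand(8)).shape)  # torch.Size([8, 2])
```

Concatenating t to the input is the simplest way to condition on time. Larger models often use a sinusoidal time embedding instead, but nothing in the flow matching objective depends on that choice.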
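Training Loop: The Magic Happens Here
This is the core of the implementation. Let's revisit that final, elegant loss function from earlier:
$$\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t,\, x_0,\, x_1} \big\| v_\theta(x_t, t) - (x_1 - x_0) \big\|^2, \qquad x_t = (1 - t)\,x_0 + t\,x_1$$
Now, let's go through the training loop code line by line. This is where the formula comes to life.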
import torch
import tqdm

# model, optim, and sampled_points (the checkerboard data) are defined earlier in the script
data = torch.Tensor(sampled_points)
training_steps = 100_000
batch_size = 64
pbar = tqdm.tqdm(range(training_steps))
losses = []
for i in pbar:
    # 1. Sample real data x1 and noise x0
    x1 = data[torch.randint(data.size(0), (batch_size,))]
    x0 = torch.randn_like(x1)
    # 2. Define the target vector
    target = x1 - x0
    # 3. Sample random time t and create the interpolated input xt
    t = torch.rand(x1.size(0))
    xt = (1 - t[:, None]) * x0 + t[:, None] * x1
    # 4. Get the model's prediction, conditioned on the time t
    pred = model(xt, t)
    # 5. Calculate the loss and other standard boilerplate
    loss = ((target - pred)**2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad()
    pbar.set_postfix(loss=loss.item())
    losses.append(loss.item())
Let's map this directly to our formula:
- x1 = ... and x0 = ...: Sampling from our data distribution $p_1$ and noise distribution $p_0$ provides the $x_1$ and $x_0$ required by the expectation $\mathbb{E}_{t, x_0, x_1}$.
- target = x1 - x0: Right here. The core of the problem. This line computes the true vector field of our simple straight-line path. It is the target part of our loss function, $(x_1 - x_0)$.
- xt = (1 - t[:, None]) * x0 + t[:, None] * x1: This is another key part. It creates the interpolated point along the path, which is the input to the model, $x_t$. The indexing t[:, None] reshapes t from (batch_size,) to (batch_size, 1). That way each sample's time broadcasts across both coordinates.
- pred = model(xt, t): This is the forward pass, giving our network's prediction, $v_\theta(x_t, t)$.
- loss = ((target - pred)**2).mean(): This is the final step. It computes the mean squared error between target and pred, which is the $\mathbb{E}\,\|\cdot\|^2$ part of our formula. Since .mean() also averages over the two coordinates, the loss is rescaled by a constant factor, which does not change the optimum.
That's it! These five lines of crucial code are a direct line-by-line implementation of our elegant formula.
Sampling: Following the Flow 🗺️
So we have trained our model. It is now a highly skilled "GPS" that knows the velocity field. How do we generate new checkerboard patterns? We start from pure noise and follow the directions all the way.
The underlying mathematical principle is that we want to solve the ordinary differential equation (ODE):
$$\frac{d x_t}{dt} = v_\theta(x_t, t), \qquad x_0 \sim p_0, \quad t \in [0, 1]$$
The simplest way to solve it is the Euler method, which takes small discrete steps.
Tip
Since $v_\theta$ is a complex neural network, we can't solve this ODE with pen and paper. We must simulate it. The simplest way is to approximate the smooth, continuous flow with a series of small, discrete straight steps.
From the basic definition of the derivative, over a very small time step $\Delta t$, the change in position is approximately the velocity multiplied by the time step: $x_{t + \Delta t} - x_t \approx v_\theta(x_t, t)\,\Delta t$.
Therefore, to obtain our new position at time $t + \Delta t$, we just add this small change to the current position. This gives us the update rule:
$$x_{t + \Delta t} = x_t + \Delta t \cdot v_\theta(x_t, t)$$
This method of "moving a little bit in the direction of the velocity" has a famous name: it's called the Euler method (or Euler update). This is the simplest and most basic way to numerically solve ordinary differential equations. As you can see, you could almost invent it yourself from first principles.
# The sampling loop, adapted from the script (model is the trained network)
import torch

steps = 1000
dt = 1 / steps
xt = torch.randn(1000, 2)  # Start with pure noise at t=0: 1000 points in 2D
with torch.no_grad():  # No gradients are needed at inference time
    for k in range(steps):
        t = torch.full((xt.size(0),), k * dt)  # Current time t_k = k * dt, one entry per point
        pred = model(xt, t)  # Get velocity prediction
        # This is the Euler method step!
        xt = xt + dt * pred
- xt = torch.randn(...): We start from a cloud of random points, which is our initial noise.
- for k in range(steps): We simulate the flow from $t = 0$ to $t = 1$ over steps discrete steps. At step k, the points sit at time $t_k = k \cdot \Delta t$, so that is the time we pass to the model.
- pred = model(xt, t): At each step, we ask the model for the current velocity, $v_\theta(x_t, t)$.
- xt = xt + dt * pred: This is the Euler update. We take a small step in the direction predicted by the model, with step size dt = 1 / steps.
By repeating this simple update, the random points are gradually "pushed" by the learned vector field until they flow into the shape of the checkerboard data distribution.
The elegance of theory directly translates into clean, clear, and efficient code, which is delightful.
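Here is a way to watch the "pushing" happen without training anything. This toy NumPy sketch swaps the neural network for an analytically known marginal velocity field. The setup is 1D data at -1 and +1, Gaussian noise, and straight-line conditional paths, all our illustrative assumptions rather than part of the original script. The Euler loop is the same as above.

```python
import numpy as np


def marginal_velocity(x, t):
    # Analytic u_t(x) for data at -1 and +1 (equal weights) and p_t(x | x1) = N(t * x1, (1 - t)^2):
    # the posterior mean E[x1 | x_t = x] equals tanh(t * x / (1 - t)^2)
    posterior_mean = np.tanh(t * x / (1 - t) ** 2)
    return (posterior_mean - x) / (1 - t)


def euler_sample(x, steps):
    dt = 1.0 / steps
    for k in range(steps):
        x = x + dt * marginal_velocity(x, k * dt)  # t_k = k * dt never reaches 1
    return x


rng = np.random.default_rng(0)
x = euler_sample(rng.standard_normal(1000), steps=100)
near_data = np.abs(np.abs(x) - 1.0) < 0.05
print(near_data.mean())  # fraction of samples that ended up close to -1 or +1
print(np.mean(x > 0))    # roughly half of the samples go to each data point
```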
DiffusionFlow
But wait… hold your horses, let's pause for a "thought bubble" moment.
Warning
We proved that this math holds under the assumption of a straight-line path between $x_0$ (the random noise vector) and $x_1$ (the random cat image). But... is a straight line really the best, most efficient path?
From the chaotic Gaussian noise cloud to the intricate manifold of cat images, the "true" transformation is likely a wild, winding, high-dimensional journey. And our assumption of a straight line... might be a bit too crude? Because $x_0$ and $x_1$ are paired at random, the straight lines of different pairs cross each other. At a crossing point, the network sees many conflicting targets $x_1 - x_0$ and can only learn their average. So the learned vector field produces curved trajectories, even though every training path was straight. This is likely a big part of why we still need quite a few sampling steps to generate high-quality images: small Euler steps are needed to follow those curves accurately.
So, a clever "hacker" would naturally ask, "Can we simplify the problem we're trying to solve?"
Imagine... if we no longer forced a connection between two completely random points $x_0$ and $x_1$, but could find a pair of "better" starting and ending points? For example, a pair of points that are somehow "naturally" related, so that the path between them is already close to a straight line and much simpler?
Where do such point pairs come from? Simple: we can use another generative model to help us generate them! For example, a ready-made diffusion model such as a DDPM works, sampled with a deterministic sampler such as DDIM so that each noise vector maps to one fixed image. We give it a noise vector $x_0$, and after hundreds of iteration steps, it outputs a nice image $x_1$. This way, we obtain a point pair $(x_0, x_1)$ that represents the "true" path that a powerful model actually traversed.
Now we have a "teacher-student" training framework: the old, slow model provides us with these "pre-straightened" paths, and we use them to train the new, simpler flow matching model. This way, the learning task for the new model becomes much simpler.
This idea of using one model to construct a simpler learning problem for another model is incredibly powerful. Essentially, you are "distilling" a complex, winding path into a simpler, straighter one. In fact, this is exactly the core concept behind Rectified Flow (Xingchao Liu, Chengyue Gong, and Qiang Liu, 2022). Its "reflow" procedure iteratively straightens the paths until they are straight enough to almost jump from the starting point to the endpoint in one step.
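A small 1D illustration of why teacher-generated pairs are "straighter," as a NumPy sketch under stated assumptions. In 1D, trajectories of an ODE cannot cross, so the pairs produced by any deterministic ODE sampler preserve the order of the points. Pairing sorted noise with sorted data gives exactly such an order-preserving coupling. We use sorting as a stand-in for a teacher model. In high dimensions there is no sorting shortcut, and one really has to run the teacher. The helper count_crossings and the bimodal toy data are our own choices.

```python
import numpy as np


def count_crossings(x0, x1):
    # In 1D, the straight lines of pairs i and j cross exactly when their endpoints swap order
    swapped = (x0[:, None] - x0[None, :]) * (x1[:, None] - x1[None, :]) < 0
    return int(swapped.sum() // 2)  # each crossing pair is counted twice (i, j) and (j, i)


rng = np.random.default_rng(0)
n = 200
x0 = rng.standard_normal(n)                                                    # noise
x1 = np.where(rng.random(n) < 0.5, -2.0, 2.0) + 0.1 * rng.standard_normal(n)  # bimodal toy data

independent = count_crossings(x0, x1)                 # random pairing, as in plain flow matching
monotone = count_crossings(np.sort(x0), np.sort(x1))  # stand-in for a teacher's 1D ODE map
print(independent > 0, monotone == 0)                 # True True
```

With no crossings, no point in space receives conflicting targets. The averaged velocity field then matches each straight line, so few Euler steps are needed, which is the intuition behind reflow.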
This is a beautifully meta-level thought built on our initial simple "hack." It's worth savoring.