背景介绍

Flow Matching(FM)由连续归一化流(CNFs)发展而来。首先简要介绍 CNF 的基本原理。


连续归一化流(CNFs)

1. 运动轨迹建模

运动轨迹通过常微分方程(ODE)建模:
dzdt=f(z(t),t) \frac{dz}{dt} = f(z(t), t) dtdz=f(z(t),t)
其中z(t)z(t)z(t) 表示系统的状态,f(z(t),t)f(z(t), t)f(z(t),t) 是向量场,通过神经网络进行参数化建模。

基于该轨迹,可以将服从初始分布 p0(x)p_0(x)p0(x) 的数据转换到目标分布 pT(z)p_T(z)pT(z)p0(x)p_0(x)p0(x) 是已知的初始分布(例如高斯分布)。pT(z)p_T(z)pT(z) 是目标分布,可能是未知的,需要通过数据样本近似。

2. 损失函数构建

1). 从目标分布中采样
从目标分布 pT(z)p_T(z)pT(z) 中采样得到样本 zzz。如果 pT(z)p_T(z)pT(z) 未知,可以通过数据样本近似。

2). 逆向ODE求解初始分布
通过逆向求解ODE,从目标分布 pT(z)p_T(z)pT(z) 中的样本 zzz 反推出初始分布 p0(x)p_0(x)p0(x) 中的样本点 xxx

3). 计算概率密度
假设 fffzzz 上是一致Lipschitz连续的,并且在 ttt 上是连续的,那么对数概率密度的变化遵循以下微分方程:
∂log⁡p(z(t))∂t=−tr(∂f∂z(t)) \frac{\partial \log p(z(t))}{\partial t} = -\mathrm{tr}\left( \frac{\partial f}{\partial z(t)} \right) tlogp(z(t))=tr(z(t)f)
其中 tr\mathrm{tr}tr 表示矩阵的迹,∂f∂z(t)\frac{\partial f}{\partial z(t)}z(t)ffffz(t)z(t)z(t) 的雅可比矩阵。

从时间 000TTT 积分,得到:
log⁡pT(z)=log⁡p0(x)−∫0Ttr(∂f∂z(s))ds \log p_T(z) = \log p_0(x) - \int_0^T \mathrm{tr}\left( \frac{\partial f}{\partial z(s)} \right) ds logpT(z)=logp0(x)0Ttr(z(s)f)ds

4). 最大化似然函数
为了拟合目标分布,最大化似然函数:
L=1N∑i=1Nlog⁡pT(zi) \mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \log p_T(z_i) L=N1i=1NlogpT(zi)


Flow Matching(FM)

1. 概述

与CNFs优化对数似然不同,Flow Matching直接优化向量场,并且不需要逆向求解ODE。Flow Matching 仍然使用常微分方程(ODE)建模运动轨迹,只是形式不同:
ddtϕt(x0)=ut(ϕt(x0)) \frac{d}{dt} \phi_t(x_0) = u_t(\phi_t(x_0)) dtdϕt(x0)=ut(ϕt(x0))
其中x0x_0x0 表示初始分布的采样点,ϕt(x0)\phi_t(x_0)ϕt(x0) 表示时刻 ttt 时初始点 x0x_0x0 达到的位置。

2. 向量场的参数化与损失函数

1). 向量场的参数化
使用神经网络参数化 vt(ϕt(x0))v_t(\phi_t(x_0))vt(ϕt(x0)) 来近似真实的向量场 ut(ϕt(x0))u_t(\phi_t(x_0))ut(ϕt(x0))

2). Flow Matching 损失函数
LFM(θ)=Et,pt(x)∥vt(x)−ut(x)∥2 \mathcal{L}_{\mathrm{FM}}(\theta) = \mathbb{E}_{t, p_t(x)} \| v_t(x) - u_t(x)\|^2 LFM(θ)=Et,pt(x)vt(x)ut(x)2
其中pt(x)p_t(x)pt(x) 是时刻 ttt 时的概率密度,xxx取自pt(x)p_t(x)pt(x)ut(x)u_t(x)ut(x) 是真实的向量场。

3). 条件 Flow Matching 损失函数
由于 pt(x)p_t(x)pt(x)ut(x)u_t(x)ut(x) 通常是未知的,无法直接优化损失函数。根据《FLOW MATCHING FOR GENERATIVE MODELING》定理2,优化条件损失函数与直接优化上述损失函数一致:
LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x)−ut(x∣x1)∥2 \mathcal{L}_{\mathrm{CFM}}(\theta)=\mathbb{E}_{t,q(x_1),p_t(x|x_1)}\Vert v_t(x)-u_t(x|x_1)\Vert^2 LCFM(θ)=Et,q(x1),pt(xx1)vt(x)ut(xx1)2
其中pt(x∣x1)p_t(x|x_1)pt(xx1) 是条件概率密度,ut(x∣x1)u_t(x|x_1)ut(xx1) 是条件向量场。

3. 条件概率密度的选取

假设条件概率密度 pt(x∣x1)p_t(x|x_1)pt(xx1) 为高斯分布:
pt(x∣x1)=N(x ∣ μt(x1),σt(x1)2I) p_t(x|x_1)=\mathcal{N}(x\,|\,\mu_t(x_1),\sigma_t(x_1)^2I) pt(xx1)=N(xμt(x1),σt(x1)2I)

边界条件:

  • 初始时刻 t=0t=0t=0 时,p0(x∣x1)=p(x)p_0(x|x_1)=p(x)p0(xx1)=p(x)(初始分布)。
  • 终止时刻 t=1t=1t=1 时,p1(x∣x1)=N(x∣x1,σ2I)p_1(x|x_1)=\mathcal{N}(x|x_1,\sigma^2I)p1(xx1)=N(xx1,σ2I)(目标分布)。
4. 轨迹

由于生成特定概率路径的轨迹有无数种可能,此处选择简单的仿射变换作为轨迹:
ψt(x0)=σt(x1)x0+μt(x1) \psi_t(x_0)=\sigma_t(x_1)x_0+\mu_t(x_1) ψt(x0)=σt(x1)x0+μt(x1)
其中x0x_0x0 是初始点,服从标准正态分布 p(x)p(x)p(x)

5. 重参数化

利用重参数,将对pt(x∣x1)p_t(x|x_1)pt(xx1)的采样转变为对p(x0)p(x_0)p(x0)的采样,重写条件Flow Matching 损失函数:
LCFM(θ)=Et,q(x1),p(x0)∥vt(ψt(x0))−ddtψt(x0)∥2. \begin{align*}\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t,q(x_1),p(x_0)} \Big\| v_t(\psi_t(x_0)) - \frac{d}{dt} \psi_t(x_0) \Big\|^2.\end{align*} LCFM(θ)=Et,q(x1),p(x0) vt(ψt(x0))dtdψt(x0) 2.

6. 向量场推导

条件向量场 ut(ψt(x0)∣x1)u_t(\psi_t(x_0)|x_1)ut(ψt(x0)x1) 是轨迹 ψt(x0)\psi_t(x_0)ψt(x0) 对时间 ttt 的导数:
ut(ψt(x0)∣x1)=ddtψt(x0) u_t(\psi_t(x_0)|x_1)=\frac{d}{dt} \psi_t(x_0) ut(ψt(x0)x1)=dtdψt(x0)
ψt(x0)\psi_t(x_0)ψt(x0) 求导,得到:
ψt′(x0)=σt′(x1)x0+μt′(x1) \psi_t'(x_0)=\sigma_t'(x_1)x_0+\mu_t'(x_1) ψt(x0)=σt(x1)x0+μt(x1)
从轨迹方程中解出 x0x_0x0
x0=ψt(x0)−μt(x1)σt(x1) x_0=\frac{\psi_t(x_0)-\mu_t(x_1)}{\sigma_t(x_1)} x0=σt(x1)ψt(x0)μt(x1)
x0x_0x0 代入导数表达式,得到条件向量场:
ut(ψt(x)∣x1)=σt′(x1)σt(x1)(ψt(x)−μt(x1))+μt′(x1) u_t(\psi_t(x)|x_1)=\frac{\sigma'_t(x_1)}{\sigma_t(x_1)}\left(\psi_t(x)-\mu_t(x_1)\right)+\mu'_t(x_1) ut(ψt(x)x1)=σt(x1)σt(x1)(ψt(x)μt(x1))+μt(x1)

举例

1. VP扩散路径

VP扩散路径(来自《SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS》中的公式29),此处为了区分,对变量名作了改变:
pτ(y∣y0)=N(y ∣ ατy0,(1−ατ2) I),ατ=e−12T(τ),T(τ)=∫0τβ(s)ds p_\tau(y|y_0)=\mathcal{N}(y\,|\,\alpha_{\tau}y_0,(1-\alpha_{\tau}^2)\,I),\alpha_{\tau}=e^{-\frac12T({\tau})},T({\tau})=\int_0^{\tau} \beta(s)ds pτ(yy0)=N(yατy0,(1ατ2)I),ατ=e21T(τ),T(τ)=0τβ(s)ds
上式中 y0y_0y0 代表数据点,本文中 x1x_1x1 代表数据点。上式的时间是从数据点到噪点,本文的时间是噪点到数据点,因此 t=1−τt=1-\taut=1τ,那么:
pt(x∣x1)=N(x ∣ α1−tx1,(1−α1−t2) I),αt=e−12T(t),T(t)=∫0tβ(s)ds p_t(x|x_1)=\mathcal{N}(x\,|\,\alpha_{1-t}x_1,(1-\alpha_{1-t}^2)\,I),\alpha_t=e^{-\frac12T(t)},T(t)=\int_0^t\beta(s)ds pt(xx1)=N(xα1tx1,(1α1t2)I),αt=e21T(t),T(t)=0tβ(s)ds
即:
μt(x1)=α1−tx1,σt(x1)=1−α1−t2 \mu_t(x_1)=\alpha_{1-t}x_1,\sigma_t(x_1)=\sqrt{1-\alpha_{1-t}^2} μt(x1)=α1tx1,σt(x1)=1α1t2
代入得条件向量场:
ut(x∣x1)=α1−t′1−α1−t2(α1−tx−x1)=−T′(1−t)2[e−T(1−t)x−e−12T(1−t)x11−e−T(1−t)] u_t(x|x_1)=\frac{\alpha'_{1-t}}{1-\alpha^2_{1-t}}\left(\alpha_{1-t}x-x_1\right)=-\frac{T'(1-t)}{2}\left[\frac{e^{-T(1-t)}x-e^{-\frac{1}{2}T(1-t)}x_1}{1-e^{-T(1-t)}}\right] ut(xx1)=1α1t2α1t(α1txx1)=2T(1t)[1eT(1t)eT(1t)xe21T(1t)x1]

2. 线性变化

更简单的例子是随时间线性变化:
μt(x)=tx1,σt(x)=1−(1−σmin⁡)t \mu_t(x)=tx_1,\sigma_t(x)=1-(1-\sigma_{\min})t μt(x)=tx1,σt(x)=1(1σmin)t
条件向量场:
ut(ψ(x)∣x1)=x1−(1−σmin⁡)ψ(x)1−(1−σmin⁡)t u_t(\psi(x)|x_1)=\frac{x_1-(1-\sigma_{\min})\psi(x)}{1-(1-\sigma_{\min})t} ut(ψ(x)x1)=1(1σmin)tx1(1σmin)ψ(x)

Flow Matching vs. Rectified Flow

Flow Matching 从概率论出发,通过条件期望将对全局边缘向量场的回归,严格等价地转化为对条件向量场的回归。
Rectified Flow 则从传输轨迹的几何结构出发,直接定义粒子级别的速度场,并通过迭代 rectification 使生成路径逐步逼近理想的 transport flow。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐