流匹配(Flow Matching)
Flow Matching(FM)由连续归一化流(CNFs)发展而来。首先简要介绍 CNF 的基本原理。与CNFs优化对数似然不同,Flow Matching直接优化向量场,并且不需要逆向求解ODE。ddtϕtx0utϕtx0dtdϕtx0utϕtx0))其中x0x_0x0表示初始分布的采样点,ϕtx0ϕtx0表示时刻ttt时初始点x0x_0x0达到的位置。
背景介绍
Flow Matching(FM)由连续归一化流(CNFs)发展而来。首先简要介绍 CNF 的基本原理。
连续归一化流(CNFs)
1. 运动轨迹建模
运动轨迹通过常微分方程(ODE)建模:
dzdt=f(z(t),t) \frac{dz}{dt} = f(z(t), t) dtdz=f(z(t),t)
其中z(t)z(t)z(t) 表示系统的状态,f(z(t),t)f(z(t), t)f(z(t),t) 是向量场,通过神经网络进行参数化建模。
基于该轨迹,可以将服从初始分布 p0(x)p_0(x)p0(x) 的数据转换到目标分布 pT(z)p_T(z)pT(z)。p0(x)p_0(x)p0(x) 是已知的初始分布(例如高斯分布)。pT(z)p_T(z)pT(z) 是目标分布,可能是未知的,需要通过数据样本近似。
2. 损失函数构建
1). 从目标分布中采样:
从目标分布 pT(z)p_T(z)pT(z) 中采样得到样本 zzz。如果 pT(z)p_T(z)pT(z) 未知,可以通过数据样本近似。
2). 逆向ODE求解初始分布:
通过逆向求解ODE,从目标分布 pT(z)p_T(z)pT(z) 中的样本 zzz 反推出初始分布 p0(x)p_0(x)p0(x) 中的样本点 xxx。
3). 计算概率密度:
假设 fff 在 zzz 上是一致Lipschitz连续的,并且在 ttt 上是连续的,那么对数概率密度的变化遵循以下微分方程:
∂logp(z(t))∂t=−tr(∂f∂z(t)) \frac{\partial \log p(z(t))}{\partial t} = -\mathrm{tr}\left( \frac{\partial f}{\partial z(t)} \right) ∂t∂logp(z(t))=−tr(∂z(t)∂f)
其中 tr\mathrm{tr}tr 表示矩阵的迹,∂f∂z(t)\frac{\partial f}{\partial z(t)}∂z(t)∂f 是 fff 对 z(t)z(t)z(t) 的雅可比矩阵。
从时间 000 到 TTT 积分,得到:
logpT(z)=logp0(x)−∫0Ttr(∂f∂z(s))ds \log p_T(z) = \log p_0(x) - \int_0^T \mathrm{tr}\left( \frac{\partial f}{\partial z(s)} \right) ds logpT(z)=logp0(x)−∫0Ttr(∂z(s)∂f)ds
4). 最大化似然函数:
为了拟合目标分布,最大化似然函数:
L=1N∑i=1NlogpT(zi) \mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \log p_T(z_i) L=N1i=1∑NlogpT(zi)
Flow Matching(FM)
1. 概述
与CNFs优化对数似然不同,Flow Matching直接优化向量场,并且不需要逆向求解ODE。Flow Matching 仍然使用常微分方程(ODE)建模运动轨迹,只是形式不同:
ddtϕt(x0)=ut(ϕt(x0)) \frac{d}{dt} \phi_t(x_0) = u_t(\phi_t(x_0)) dtdϕt(x0)=ut(ϕt(x0))
其中x0x_0x0 表示初始分布的采样点,ϕt(x0)\phi_t(x_0)ϕt(x0) 表示时刻 ttt 时初始点 x0x_0x0 达到的位置。
2. 向量场的参数化与损失函数
1). 向量场的参数化:
使用神经网络参数化 vt(ϕt(x0))v_t(\phi_t(x_0))vt(ϕt(x0)) 来近似真实的向量场 ut(ϕt(x0))u_t(\phi_t(x_0))ut(ϕt(x0))。
2). Flow Matching 损失函数:
LFM(θ)=Et,pt(x)∥vt(x)−ut(x)∥2 \mathcal{L}_{\mathrm{FM}}(\theta) = \mathbb{E}_{t, p_t(x)} \| v_t(x) - u_t(x)\|^2 LFM(θ)=Et,pt(x)∥vt(x)−ut(x)∥2
其中pt(x)p_t(x)pt(x) 是时刻 ttt 时的概率密度,xxx取自pt(x)p_t(x)pt(x),ut(x)u_t(x)ut(x) 是真实的向量场。
3). 条件 Flow Matching 损失函数:
由于 pt(x)p_t(x)pt(x) 和 ut(x)u_t(x)ut(x) 通常是未知的,无法直接优化损失函数。根据《FLOW MATCHING FOR GENERATIVE MODELING》定理2,优化条件损失函数与直接优化上述损失函数一致:
LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x)−ut(x∣x1)∥2 \mathcal{L}_{\mathrm{CFM}}(\theta)=\mathbb{E}_{t,q(x_1),p_t(x|x_1)}\Vert v_t(x)-u_t(x|x_1)\Vert^2 LCFM(θ)=Et,q(x1),pt(x∣x1)∥vt(x)−ut(x∣x1)∥2
其中pt(x∣x1)p_t(x|x_1)pt(x∣x1) 是条件概率密度,ut(x∣x1)u_t(x|x_1)ut(x∣x1) 是条件向量场。
3. 条件概率密度的选取
假设条件概率密度 pt(x∣x1)p_t(x|x_1)pt(x∣x1) 为高斯分布:
pt(x∣x1)=N(x ∣ μt(x1),σt(x1)2I) p_t(x|x_1)=\mathcal{N}(x\,|\,\mu_t(x_1),\sigma_t(x_1)^2I) pt(x∣x1)=N(x∣μt(x1),σt(x1)2I)
边界条件:
- 初始时刻 t=0t=0t=0 时,p0(x∣x1)=p(x)p_0(x|x_1)=p(x)p0(x∣x1)=p(x)(初始分布)。
- 终止时刻 t=1t=1t=1 时,p1(x∣x1)=N(x∣x1,σ2I)p_1(x|x_1)=\mathcal{N}(x|x_1,\sigma^2I)p1(x∣x1)=N(x∣x1,σ2I)(目标分布)。
4. 轨迹
由于生成特定概率路径的轨迹有无数种可能,此处选择简单的仿射变换作为轨迹:
ψt(x0)=σt(x1)x0+μt(x1) \psi_t(x_0)=\sigma_t(x_1)x_0+\mu_t(x_1) ψt(x0)=σt(x1)x0+μt(x1)
其中x0x_0x0 是初始点,服从标准正态分布 p(x)p(x)p(x)。
5. 重参数化
利用重参数,将对pt(x∣x1)p_t(x|x_1)pt(x∣x1)的采样转变为对p(x0)p(x_0)p(x0)的采样,重写条件Flow Matching 损失函数:
LCFM(θ)=Et,q(x1),p(x0)∥vt(ψt(x0))−ddtψt(x0)∥2. \begin{align*}\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t,q(x_1),p(x_0)} \Big\| v_t(\psi_t(x_0)) - \frac{d}{dt} \psi_t(x_0) \Big\|^2.\end{align*} LCFM(θ)=Et,q(x1),p(x0)
vt(ψt(x0))−dtdψt(x0)
2.
6. 向量场推导
条件向量场 ut(ψt(x0)∣x1)u_t(\psi_t(x_0)|x_1)ut(ψt(x0)∣x1) 是轨迹 ψt(x0)\psi_t(x_0)ψt(x0) 对时间 ttt 的导数:
ut(ψt(x0)∣x1)=ddtψt(x0) u_t(\psi_t(x_0)|x_1)=\frac{d}{dt} \psi_t(x_0) ut(ψt(x0)∣x1)=dtdψt(x0)
对 ψt(x0)\psi_t(x_0)ψt(x0) 求导,得到:
ψt′(x0)=σt′(x1)x0+μt′(x1) \psi_t'(x_0)=\sigma_t'(x_1)x_0+\mu_t'(x_1) ψt′(x0)=σt′(x1)x0+μt′(x1)
从轨迹方程中解出 x0x_0x0:
x0=ψt(x0)−μt(x1)σt(x1) x_0=\frac{\psi_t(x_0)-\mu_t(x_1)}{\sigma_t(x_1)} x0=σt(x1)ψt(x0)−μt(x1)
将 x0x_0x0 代入导数表达式,得到条件向量场:
ut(ψt(x)∣x1)=σt′(x1)σt(x1)(ψt(x)−μt(x1))+μt′(x1) u_t(\psi_t(x)|x_1)=\frac{\sigma'_t(x_1)}{\sigma_t(x_1)}\left(\psi_t(x)-\mu_t(x_1)\right)+\mu'_t(x_1) ut(ψt(x)∣x1)=σt(x1)σt′(x1)(ψt(x)−μt(x1))+μt′(x1)
举例
1. VP扩散路径
VP扩散路径(来自《SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS》中的公式29),此处为了区分,对变量名作了改变:
pτ(y∣y0)=N(y ∣ ατy0,(1−ατ2) I),ατ=e−12T(τ),T(τ)=∫0τβ(s)ds p_\tau(y|y_0)=\mathcal{N}(y\,|\,\alpha_{\tau}y_0,(1-\alpha_{\tau}^2)\,I),\alpha_{\tau}=e^{-\frac12T({\tau})},T({\tau})=\int_0^{\tau} \beta(s)ds pτ(y∣y0)=N(y∣ατy0,(1−ατ2)I),ατ=e−21T(τ),T(τ)=∫0τβ(s)ds
上式中 y0y_0y0 代表数据点,本文中 x1x_1x1 代表数据点。上式的时间是从数据点到噪点,本文的时间是噪点到数据点,因此 t=1−τt=1-\taut=1−τ,那么:
pt(x∣x1)=N(x ∣ α1−tx1,(1−α1−t2) I),αt=e−12T(t),T(t)=∫0tβ(s)ds p_t(x|x_1)=\mathcal{N}(x\,|\,\alpha_{1-t}x_1,(1-\alpha_{1-t}^2)\,I),\alpha_t=e^{-\frac12T(t)},T(t)=\int_0^t\beta(s)ds pt(x∣x1)=N(x∣α1−tx1,(1−α1−t2)I),αt=e−21T(t),T(t)=∫0tβ(s)ds
即:
μt(x1)=α1−tx1,σt(x1)=1−α1−t2 \mu_t(x_1)=\alpha_{1-t}x_1,\sigma_t(x_1)=\sqrt{1-\alpha_{1-t}^2} μt(x1)=α1−tx1,σt(x1)=1−α1−t2
代入得条件向量场:
ut(x∣x1)=α1−t′1−α1−t2(α1−tx−x1)=−T′(1−t)2[e−T(1−t)x−e−12T(1−t)x11−e−T(1−t)] u_t(x|x_1)=\frac{\alpha'_{1-t}}{1-\alpha^2_{1-t}}\left(\alpha_{1-t}x-x_1\right)=-\frac{T'(1-t)}{2}\left[\frac{e^{-T(1-t)}x-e^{-\frac{1}{2}T(1-t)}x_1}{1-e^{-T(1-t)}}\right] ut(x∣x1)=1−α1−t2α1−t′(α1−tx−x1)=−2T′(1−t)[1−e−T(1−t)e−T(1−t)x−e−21T(1−t)x1]
2. 线性变化
更简单的例子是随时间线性变化:
μt(x)=tx1,σt(x)=1−(1−σmin)t \mu_t(x)=tx_1,\sigma_t(x)=1-(1-\sigma_{\min})t μt(x)=tx1,σt(x)=1−(1−σmin)t
条件向量场:
ut(ψ(x)∣x1)=x1−(1−σmin)ψ(x)1−(1−σmin)t u_t(\psi(x)|x_1)=\frac{x_1-(1-\sigma_{\min})\psi(x)}{1-(1-\sigma_{\min})t} ut(ψ(x)∣x1)=1−(1−σmin)tx1−(1−σmin)ψ(x)
Flow Matching vs. Rectified Flow
Flow Matching 从概率论出发,通过条件期望将对全局边缘向量场的回归,严格等价地转化为对条件向量场的回归。
Rectified Flow 则从传输轨迹的几何结构出发,直接定义粒子级别的速度场,并通过迭代 rectification 使生成路径逐步逼近理想的 transport flow。
更多推荐



所有评论(0)