构建AI智能体:八十七、KM与Chinchilla法则:AI模型发展的两种训练法则完全解析
摘要: 大模型训练中,如何在有限计算预算(C≈6ND)下最优分配模型参数量(N)与训练数据量(D)是关键挑战。KM扩展法则主张“模型优先”,认为增大N的收益高于D(α=0.076<β=0.103),推荐N∝C^0.73、D∝C^0.27。Chinchilla法则则通过实验发现大模型普遍训练不足,提出平衡策略(α=β≈0.38),推荐N∝D∝C^0.5,即在相同预算下减小模型规模并大幅增加数据量,可
一、前言
大模型的浪潮如火如荼,但做为个人开发者和小企业的我们,不知道大家有没有面临这样的困境:有限的算力预算如同杯水车薪,是该训练一个参数更多的聪明模型,还是用更多数据喂养一个见多识广的模型,往往训练一个大体量的模型,需要耗费大量的资金和时间,而作为普通用户的我们,如果想训练一个自己的模型,在我们固定的计算预算下,我们应该训练一个多大的模型参数量?并用多少数据?如何高效地分配计算资源成为模型训练的核心问题!
扩展法则就是为了科学地回答这个问题而生的,也正是破解这一难题,为我们提供了精细化的指导思路,它们是基于大量实验得出的经验规律,用于预测模型性能损失如何随参数量N和数据量D的变化而变化,它告诉我们,盲目堆砌参数可能只是在制造昂贵的傻瓜,而恰当的数据配比能让小预算发挥大效能。理解扩展法则,意味着能用1%的资源达成80%的效果,让资源有限的团队也能在AI赛道上精准发力。这不仅是技术选择,更是生存智慧,在有限的算力资源中,找到属于我们个人或小团队的制胜策略。今天我们重点围绕两个关键的扩展法则:KM扩展法则和Chinchilla扩展法则深度解析基础释义、核心思想以及数学原理,总结两者的差异和对模型训练的重要意义。

二、核心问题和概念
在深入分析之前,我们必须明确扩展法则要解决的核心问题:
在计算预算 C 固定的前提下,如何分配模型参数量N和训练数据Token量D,才能使模型的最终性能损失L最优。
这里有几个关键概念:
1. FLOPs
1.1 核心概念
FLOPs 是浮点运算次数,它就像是衡量计算机“做了多少脑力工作”的计数器,拆解开来理解:
- Floating Point: 浮点数
- 简单理解就是带小数点的数字,比如 3.14, -0.001, 2.71828。计算机在处理科学计算、图形图像、人工智能这些复杂任务时,主要就是和浮点数打交道。
- Operations: 运算
- 指最基本的数学计算,主要是 加法、减法、乘法、除法 这四种。
所以,1个 FLOP 就代表计算机执行了一次浮点数的加法、减法、乘法或除法,FLOPs(末尾的s代表复数)就是指总的浮点运算次数。
1.2 通俗的理解
想象我们要做一道数学题,要计算一个数学公式题:y = (3.2 × 1.5) + (2.1 ÷ 0.7) - 4.0,我们一步步算一下:
- 1. 先算 3.2 × 1.5 → 1次乘法(1 FLOP)
- 2. 再算 2.1 ÷ 0.7 → 1次除法(1 FLOP)
- 3. 然后算 (4.8) + (3.0) → 1次加法(1 FLOP)
- 4. 最后算 7.8 - 4.0 → 1次减法(1 FLOP)
完成这道题,计算机总共需要执行4次浮点运算,所以它的计算量就是4FLOPs。
1.3 对大模型的含义
在训练和运行AI模型时,绝大部分工作都是大规模的矩阵和向量运算,而这些运算最终都可以分解成海量的加法和乘法。
一个具体的例子:计算一个神经元的输出
假设一个神经元有3个输入 [x1, x2, x3],对应的权重是 [w1, w2, w3],还有一个偏置项 b。
它的输出是:y = (x1*w1 + x2*w2 + x3*w3) + b
我们来数一下FLOPs:
- x1*w1, x2*w2, x3*w3 → 3次乘法(3 FLOPs)
- (x1*w1 + x2*w2) → 1次加法(1 FLOP)
- (... + x3*w3) → 1次加法(1 FLOP)
- (... + b) → 1次加法(1 FLOP)
总共:6 FLOPs。
由此可以看出,一个大语言模型有数千亿个参数(权重和偏置),每处理一个token都需要进行数百万甚至数十亿次这样的计算,这个总的FLOPs数量就会变得极其庞大。
1.4 FLOPs的重要性
FLOPs是衡量计算成本、算法效率和硬件性能的一个核心指标。
- 衡量模型复杂度/训练成本:我们说训练一个庞大的模型需要~3.14 × 10^23 FLOPs,这个天文数字直观地告诉我们这个模型训练起来极其昂贵和耗时。它就是后面我们要进一步提到的 C ≈ 6 * N * D 公式具体计算出来的结果。
- 衡量硬件算力:我们常听到的A100、H100显卡,它们的算力就是用 FLOPS来衡量的。比如一块H100显卡的算力可达 ~4 PFLOPS,意思是它每秒可以执行4 × 10^15次浮点运算。
- 估算时间:有了总计算量(FLOPs)和硬件算力(FLOPS),我们就可以粗略估算任务时间:
- 时间 ≈ 总FLOPs / 硬件FLOPS
FLOPs就是完成一个计算任务,比如训练一个AI模型所需要完成的基础数学题的总数量,表示一个工作量单位,数量越大,意味着任务越复杂,需要的计算资源越多。它是我们理解和量化人工智能等领域巨大计算需求的基石。
2. 计算预算C ≈ 6 * N * D
计算预算通常以FLOPs衡量。对于自回归语言模型训练,一个广泛使用的近似是 C ≈ 6 * N * D,这个公式是理解模型训练成本的钥匙,它告诉我们,总计算量主要取决于模型有多大和学了多少数据。
为什么是 6 * N * D?
这是一个基于Transformer架构自回归语言模型训练的经验近似值。我们可以通过分析模型的前向传播和反向传播过程来理解它:
- 前向传播:
- 当模型处理一个输入token时,它需要执行矩阵乘法和激活函数等操作。
- 在Transformer中,处理一个token所需的浮点运算次数大致与模型总参数量 N 成正比。
- 因此,对单个token的前向传播计算量 ≈ 2 * N FLOPs。
- 因子 2 主要来自于:矩阵乘法(占大部分)和激活函数等其它操作。
- 反向传播:
- 训练模型时,除了前向传播,还需要反向传播来计算梯度,从而更新权重。
- 反向传播的计算量通常是前向传播的2倍。
- 因为它需要计算损失函数对于每一层输入的梯度(类似于前向传播),以及对于权重和偏置的梯度。
- 总计算量:
- 单个token的单次训练迭代总计算量 ≈ 前向传播(2N) + 反向传播(4N) = 6N FLOPs。
- 对于整个训练过程,模型会看到整个数据集 D 个token。
- 因此,总训练计算预算 C ≈ 6 * N * D FLOPs。
这是一个近似值,实际值可能因模型架构、序列长度、优化器类型等因素而在 ~2ND 到 ~10ND 之间变化,但 6ND 是一个被广泛接受和使用的可靠估算值,用于进行高阶的趋势分析和比较。
这个公式建立了一个预算约束,如果增大了模型规模N,但保持总预算C不变,那么必须相应地减少数据量D,反之亦然。这也是今天我们要谈论解决的核心问题:如何在固定的 C 下,最优地分配 N 和 D?
3. 性能L
L 是衡量模型好坏的指标,通常是模型在预留测试集上的交叉熵损失或困惑度,在语言建模中,它几乎总是通过交叉熵损失或其派生指标困惑度来定义,损失越低,模型能力越强。
3.1 交叉熵损失
核心思想:衡量模型预测的概率分布与真实的概率分布(一个one-hot向量,代表正确的下一个词)之间的距离。
计算公式(对于一个token):
- Cross-Entropy = - Σ (y_true * log(y_pred))
- 由于 y_true 是one-hot向量(只有正确词的位置为1,其他为0),所以简化为:
- Cross-Entropy = - log(y_pred_correct_word)
直观理解:模型对正确下一个词赋予的预测概率 y_pred_correct_word 越高,损失 -log(y_prob) 就越低。
- 如果 y_pred_correct_word = 1(完美预测),则 loss = -log(1) = 0。
- 如果 y_pred_correct_word = 0.5,则 loss = -log(0.5) ≈ 0.69。
- 如果 y_pred_correct_word = 0.1(预测很差),则 loss = -log(0.1) ≈ 2.30。
整个数据集的损失是所有这些单个token损失的平均值。
3.2 困惑度
困惑度是交叉熵损失的指数形式,因为它更直观。
计算公式:Perplexity = exp(Cross-Entropy_Loss)
直观理解:困惑度可以理解为“模型在预测下一个词时的平均不确定性程度”或者“平均分支因子”。
- 假设一个模型的困惑度为 10。这意味着,在预测每一个词时,模型感觉平均有10个等可能的选择。困惑度越低,模型越确定。
- 一个完美的模型的困惑度为 1(因为它总是100%确定下一个词是什么)。
- 如果模型只是随机猜测一个包含10,000个词的词汇表,其困惑度就是 10,000。
关系:由于 Perplexity = exp(L),最小化交叉熵损失 L 就等价于最小化困惑度。在扩展法则的研究中,通常直接使用交叉熵损失 L 作为优化目标,因为它数学性质更好(是加法性的)。
交叉熵损失和困惑度的详细说明可参考《信息论完全指南:从基础概念到在大模型中的实际应用》
4. 幂律与收益递减
这是扩展法则的灵魂,揭示了性能提升的基本规律。
3.1 幂律关系
扩展法则发现,损失 L 与模型规模 N 和数据规模 D 遵循幂律关系:
L ∝ 1 / N^α
L ∝ 1 / D^β
这意味着,L 与 N^α 和 D^β 成反比。将其与不可约损失 E 结合,就得到了我们之前看到的完整公式:
L(N, D) = E + A/N^α + B/D^β
3.2 收益递减效应
幂律中的指数 α 和 β(通常远小于1)是理解收益递减的关键。
让我们通过一个例子来理解:
- 假设 α = 0.5。
- 当 N 从 10^6(100万)增加到 10^7(1000万)时,N^α 从 (10^6)^0.5 = 1000 增加到 (10^7)^0.5 ≈ 3162。
- 性能提升(损失下降)的倍数是 3162 / 1000 ≈ 3.16 倍。
- 现在,再将 N 从 10^7 增加到 10^8(1亿),N^α 从 3162 增加到 (10^8)^0.5 = 10000。
- 性能提升的倍数是 10000 / 3162 ≈ 3.16 倍。
示例发现:
- 在两种情况下,模型规模都增加了10倍。
- 但性能提升的倍数却完全相同,都是约3.16倍,而不是随着规模变大而线性增加。
- 这意味着,每当我们想让性能提升固定的倍数,我们只需要将模型规模增加一个固定的比例,例如10倍。这就是收益递减,我们需要投入指数级增长的资源,才能获得线性增长的性能。
对扩展法则的实际意义:
- KM法则:α ≈ 0.076, β ≈ 0.103。指数非常小,意味着收益递减得非常非常慢。为了显著降低损失,我们需要极大地增加 N 或 D。同时,由于 α < β,增加 N 的收益衰减得比增加 D 更慢,因此KM推荐优先扩大 N。
- Chinchilla法则:α ≈ β ≈ 0.38。指数更大,意味着收益递减效应比KM认为的要严重得多。同时,α 和 β 相等,意味着增加模型和增加数据对性能的贡献是平衡的。因此,Chinchilla推荐同步扩大 N 和 D。
5. 总结
预算C、性能L和幂律这三个概念构成了一个完整的逻辑链:
- 有固定预算 C ≈ 6ND。
- 目标是最小化损失 L(或困惑度)。
- 通过实验发现,L 与 N 和 D 存在幂律关系,并伴有收益递减。
- KM和Chinchilla法则的核心争论在于幂律指数 α 和 β 的具体数值,这直接导致了截然不同的最优资源分配策略。
三、KM扩展法则
1. 基础理解
核心思想: 在计算预算充足的情况下,模型参数量 N 是影响性能的最关键因素。为了达到最佳性能,应优先扩大模型规模,同时按比例适当增加数据量。
一个简单的比喻:
好比我们在组建一个研究团队来解决一个复杂问题。
- 模型参数量就像是团队中博士的数量。博士越多,团队的集体智力、分析能力和解决复杂问题的潜力就越大。
- 训练数据量就像是提供给这个团队的研究资料和参考文献的数量。
- KM法则的发现是:相比于无止境地增加研究资料,优先增加博士的数量,对最终解决问题能力的提升效果更显著。
- 当然,资料不能太少,否则博士们会“巧妇难为无米之炊”。但KM法则认为,在资源和精力有限时,我们应该把重点放在招募更多、更聪明的博士上。
2. 数学公式
KM法则将测试损失 L 建模为 N 和 D 的幂律函数:
L(N, D) = E + (A / N^α) + (B / D^β)
其中:
- L(N, D): 这是我们最终想知道的模型损失。它取决于模型参数量 N 和训练数据Token量 D。
- E: 不可约损失,这是理论上的最低损失,由数据本身的内在复杂度和噪音决定,就像无论多么聪明的学生,也无法完美预测一次完全随机的硬币抛掷结果。E 是一个无法通过改进模型或增加数据来超越的极限。
- A / N^α: 模型容量损失,这部分损失是因为我们的模型不够大、不够复杂。
- N 是模型参数量。
- A 和 α 是常数,通过实验拟合得出,OpenAI发现 α ≈ 0.076。
- 直观理解:当 N 增大时,A / N^α 会变小,这意味着模型越大,因容量不足导致的损失就越小。α 很小,意味着我们需要极大地增加 N,才能让这项显著下降。
- B / D^β: 数据容量损失,这部分损失是因为我们的训练数据不够多。
- D 是训练数据Token量。
- B 和 β 是常数,通过实验拟合得出,OpenAI发现 β ≈ 0.103。
- 直观理解:当 D 增大时,B / D^β 会变小。这意味着数据越多,因数据不足导致的损失就越小。β 也很小,意味着数据也需要增加很多才能显著降低这项损失。
通过这个公式,如果我们知道了常数 E, A, B, α, β,我们就可以预测:一个拥有 N 参数、用 D 数据训练的模型,最终性能 L 大概会是多少,这为模型设计提供了很好的指导,由于 α 和 β 都很小,为了最小化损失,需要同时增大 N 和 D,但KM法则的实证结果表明,对 N 的投资回报率更高。
3. KM法则的决策指南
通过对上述公式的分析和实验验证,KM法则得出了几个改变AI研发方向的结论:
3.1 模型规模 N 的收益高于数据规模 D
- 因为实验中拟合出的 α (0.076) 略小于 β (0.103)。这意味着,N 的指数更小,其收益随规模增长而衰减的速度更慢。换句话说,扩大 N 的“后劲”更足。
- 在计算预算 C(约等于 6ND)固定的情况下,最优配置是大力倾斜于 N。论文给出的近似最优比例是:模型规模 N 应与 C 的约0.73次方成正比,而数据 D 仅与 C 的约0.27次方成正比。
3.2 性能平滑可预测
- 模型性能主要取决于 N 和 D,而对模型架构、训练方式等超参数的依赖相对较小。这意味着,你可以通过“缩放”来可靠地提升性能。
3.3 在计算最优边界上,模型应该“训练不足”
- 这是KM法则一个非常重要且反直觉的推论。它说:在固定的计算预算下,我们应该训练一个非常大的模型,但只在相对较少的数据上训练它,直到它还没有完全收敛(即“训练不足”)时就停止。
- 为什么?因为把同样的计算预算用来训练一个更大的模型,即使训练不充分,比用来更充分地训练一个较小的模型,最终性能更好。
4. 示例分析
KM法则的核心公式:L(N, D) = E + A/N^α + B/D^β
其中:
- L:模型损失(越低越好)
- N:模型参数量
- D:训练数据量(token数)
- E, A, B, α, β:通过实验拟合的常数
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
def km_scaling_law_log(N, D, E=2.0, A=3.0, B=3.0, alpha=0.3, beta=0.3):
"""
计算KM扩展法则预测的损失值 - 对数尺度版本
确保输出在合理范围内
参数:
N: 模型参数量 (单位: 十亿)
D: 训练数据量 (单位: 十亿token)
"""
# 使用对数尺度确保数值稳定
# 基础损失 + 模型项 + 数据项
L = E + (A / (np.log(N + 1) ** alpha)) + (B / (np.log(D + 1) ** beta))
return L
def safe_exp(x):
"""安全的指数函数,防止溢出"""
return np.exp(np.clip(x, -700, 700))
# 示例1: 单个模型预测
print("=== 示例1: 单个模型性能预测 ===")
N_example = 1.0 # 10亿参数
D_example = 5.0 # 50亿token
loss = km_scaling_law_log(N_example, D_example)
print(f"模型规模: {N_example}B 参数")
print(f"训练数据: {D_example}B token")
print(f"KM法则预测损失: {loss:.4f}")
print(f"对应的困惑度: {safe_exp(loss):.2f}\n")
# 示例2: 不同规模模型的对比
print("=== 示例2: 不同规模模型对比 ===")
model_sizes = [0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0] # 从1亿到1000亿参数
fixed_data = 10.0 # 固定100亿token数据
print(f"固定训练数据: {fixed_data}B token")
print("模型规模(B)\t预测损失\t困惑度")
print("-" * 55)
for size in model_sizes:
loss = km_scaling_law_log(size, fixed_data)
perplexity = safe_exp(loss)
print(f"{size:8.1f}\t{loss:.4f}\t\t{perplexity:.2f}")
# 示例3: 不同数据量的对比
print("\n=== 示例3: 不同数据量对比 ===")
data_sizes = [1.0, 5.0, 10.0, 50.0, 100.0, 500.0, 1000.0] # 从10亿到1万亿token
fixed_model = 1.0 # 固定10亿参数
print(f"固定模型规模: {fixed_model}B 参数")
print("数据量(B)\t预测损失\t困惑度")
print("-" * 55)
for data in data_sizes:
loss = km_scaling_law_log(fixed_model, data)
perplexity = safe_exp(loss)
print(f"{data:8.1f}\t{loss:.4f}\t\t{perplexity:.2f}")
# 示例4: 可视化分析
print("\n=== 示例4: 生成可视化图表 ===")
# 创建模型规模和数据的网格
N_range = np.logspace(-1, 2, 50) # 从0.1B到100B参数
D_range = np.logspace(0, 3, 50) # 从1B到1000B token
N_grid, D_grid = np.meshgrid(N_range, D_range)
L_grid = km_scaling_law_log(N_grid, D_grid)
# 创建可视化图表
fig = plt.figure(figsize=(16, 5))
# 子图1: 固定数据量,看模型规模的影响
ax1 = fig.add_subplot(131)
fixed_D = 10.0 # 固定10B token
losses_N = [km_scaling_law_log(N, fixed_D) for N in N_range]
ax1.semilogx(N_range, losses_N, 'b-', linewidth=3)
ax1.set_xlabel('模型参数量 (十亿)')
ax1.set_ylabel('预测损失')
ax1.set_title('模型规模对性能的影响\n(固定数据量)')
ax1.grid(True, alpha=0.3)
# 标记GPT-3规模的点
gpt3_N = 175
gpt3_loss = km_scaling_law_log(gpt3_N, fixed_D)
ax1.axvline(x=gpt3_N, color='red', linestyle='--', alpha=0.7)
ax1.plot(gpt3_N, gpt3_loss, 'ro', markersize=8)
ax1.annotate(f'GPT-3\n({gpt3_N}B)', (gpt3_N, gpt3_loss),
xytext=(10, 10), textcoords='offset points',
bbox=dict(boxstyle="round,pad=0.3", fc="yellow", alpha=0.7))
# 子图2: 固定模型规模,看数据量的影响
ax2 = fig.add_subplot(132)
fixed_N = 1.0 # 固定1B参数
losses_D = [km_scaling_law_log(fixed_N, D) for D in D_range]
ax2.semilogx(D_range, losses_D, 'r-', linewidth=3)
ax2.set_xlabel('训练数据量 (十亿token)')
ax2.set_ylabel('预测损失')
ax2.set_title('数据量对性能的影响\n(固定模型规模)')
ax2.grid(True, alpha=0.3)
# 子图3: 热力图展示N和D的共同影响
ax3 = fig.add_subplot(133)
contour = ax3.contourf(np.log10(N_grid), np.log10(D_grid), L_grid, levels=20, cmap='RdYlBu_r')
ax3.set_xlabel('log10(模型参数) (B)')
ax3.set_ylabel('log10(训练数据) (B)')
ax3.set_title('KM扩展法则热力图\n颜色表示损失值')
# 添加等值线
contour_lines = ax3.contour(np.log10(N_grid), np.log10(D_grid), L_grid,
levels=10, colors='black', alpha=0.5)
ax3.clabel(contour_lines, inline=True, fontsize=8)
plt.colorbar(contour, ax=ax3, label='预测损失')
plt.tight_layout()
plt.show()
# 示例5: 实际模型案例分析
print("\n=== 示例5: 实际模型性能预测 ===")
real_models = [
{"name": "GPT-3", "N": 175, "D": 300},
{"name": "LLaMA-2 7B", "N": 7, "D": 2000},
{"name": "LLaMA-2 70B", "N": 70, "D": 2000},
{"name": "PaLM", "N": 540, "D": 780},
{"name": "Chinchilla", "N": 70, "D": 1400},
]
print("模型名称\t\t参数(B)\t数据(B)\t预测损失\t困惑度")
print("-" * 70)
for model in real_models:
loss = km_scaling_law_log(model["N"], model["D"])
perplexity = safe_exp(loss)
print(f"{model['name']:12}\t{model['N']:4.0f}\t{model['D']:4.0f}\t{loss:.4f}\t\t{perplexity:.2f}")
# 示例6: 资源分配建议
print("\n=== 示例6: 资源分配策略 ===")
def analyze_resource_allocation(total_compute):
"""分析不同资源分配策略"""
print(f"\n在总计算量 {total_compute:.1e} FLOPs 下的策略分析:")
print("策略\t\t\t模型规模(B)\t数据量(B)\t预测损失")
print("-" * 65)
# 策略1: KM风格 (偏向大模型)
N_km = (total_compute / 6) ** 0.7 / 1e9
D_km = (total_compute / 6) ** 0.3 / 1e9
loss_km = km_scaling_law_log(N_km, D_km)
print(f"KM策略\t\t\t{N_km:6.1f}\t\t{D_km:6.1f}\t\t{loss_km:.4f}")
# 策略2: Chinchilla风格 (平衡)
N_chi = (total_compute / 6) ** 0.5 / 1e9
D_chi = (total_compute / 6) ** 0.5 / 1e9
loss_chi = km_scaling_law_log(N_chi, D_chi)
print(f"Chinchilla策略\t\t{N_chi:6.1f}\t\t{D_chi:6.1f}\t\t{loss_chi:.4f}")
# 策略3: 偏向大数据
N_data = (total_compute / 6) ** 0.3 / 1e9
D_data = (total_compute / 6) ** 0.7 / 1e9
loss_data = km_scaling_law_log(N_data, D_data)
print(f"数据优先策略\t\t{N_data:6.1f}\t\t{D_data:6.1f}\t\t{loss_data:.4f}")
analyze_resource_allocation(1e22) # 分析1e22 FLOPs预算
代码详细解释
4.1 核心函数
def km_scaling_law_log(N, D, E=2.0, A=3.0, B=3.0, alpha=0.3, beta=0.3):
"""
计算KM扩展法则预测的损失值 - 对数尺度版本
确保输出在合理范围内
参数:
N: 模型参数量 (单位: 十亿)
D: 训练数据量 (单位: 十亿token)
"""
# 使用对数尺度确保数值稳定
# 基础损失 + 模型项 + 数据项
L = E + (A / (np.log(N + 1) ** alpha)) + (B / (np.log(D + 1) ** beta))
return L
- 这个函数实现了KM法则的核心公式
- 默认参数基于OpenAI论文的近似值
- E=2.0 代表理论上的最小损失
- alpha=0.3 和 beta=0.3 是关键,决定了模型和数据的收益递减速度
4.2 单个预测示例
N_example = 1.0 # 10亿参数
D_example = 5.0 # 50亿token
loss = km_scaling_law_log(N_example, D_example)
这里我们预测一个10亿参数、用50亿token训练的模型的性能。
4.3 规模对比分析
model_sizes = [0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0] # 从1亿到1000亿参数
通过这个循环,我们可以看到模型规模从1亿参数增长到100亿参数时,性能如何变化。
4.4 输出结果
=== 示例1: 单个模型性能预测 ===
模型规模: 1.0B 参数
训练数据: 5.0B token
KM法则预测损失: 7.8672
对应的困惑度: 2610.12=== 示例2: 不同规模模型对比 ===
固定训练数据: 10.0B token
模型规模(B) 预测损失 困惑度
-------------------------------------------------------
0.1 10.3803 32219.56
0.5 8.2408 3792.44
1.0 7.6563 2114.01
5.0 6.8261 921.62
10.0 6.6153 746.45
50.0 6.2972 543.03
100.0 6.2038 494.62=== 示例3: 不同数据量对比 ===
固定模型规模: 1.0B 参数
数据量(B) 预测损失 困惑度
-------------------------------------------------------
1.0 8.6974 5987.08
5.0 7.8672 2610.12
10.0 7.6563 2114.01
50.0 7.3382 1537.90
100.0 7.2448 1400.80
500.0 7.0827 1191.19
1000.0 7.0286 1128.50=== 示例4: 生成可视化图表 ===
=== 示例5: 实际模型性能预测 ===
模型名称 参数(B) 数据(B) 预测损失 困惑度
----------------------------------------------------------------------
GPT-3 175 300 5.6117 273.60
LLaMA-2 7B 7 2000 6.0409 420.29
LLaMA-2 70B 70 2000 5.5744 263.58
PaLM 540 780 5.4262 227.27
Chinchilla 70 1400 5.5980 269.90=== 示例6: 资源分配策略 ===
在总计算量 1.0e+22 FLOPs 下的策略分析:
策略 模型规模(B) 数据量(B) 预测损失
-----------------------------------------------------------------
KM策略 716628.6 0.0 21.8804
Chinchilla策略 40.8 40.8 6.0413
数据优先策略 0.0 716628.6 21.8804

图例分析:
- 左图:蓝色曲线显示,随着模型参数增加,损失稳步下降,但下降速度逐渐变缓,收益递减
- 中图:红色曲线显示,数据量增加也能降低损失,但效果不如增大模型规模明显
- 右图:3D曲面显示损失如何同时受模型规模和数据量的影响
四、Chinchilla扩展法则
1. 基础理解
核心思想: 对于给定的计算预算 C,模型参数量 N 和数据Token量 D 应该成比例地增长。模型不是越大越好,而是需要与足够多的数据配对。许多现有的大模型是训练不足的,减小模型规模并大幅增加数据量,可以在相同计算成本下获得更优的性能。
这个思想可以分解为三个关键点:
1.1 挑战规模至上的观点
- 之前:KM法则认为模型规模是提升性能的最有效杠杆。
- Chinchilla发现:KM法则基于的实验范围有限(最大模型17B),当模型规模扩大到千亿参数时,其推荐的数据量远远不够,导致模型训练不足。模型虽然参数很多,但因为没有看到足够的数据,无法充分学习,其潜力未被完全挖掘。
1.2 揭示训练不足问题
- 比喻:这就像招募了一个智商200的天才(大模型),但只给他一本小学课本(少量数据)学习。他很快就能把课本倒背如流(训练损失很低),但并没有真正掌握广博的知识(泛化能力差)。而一个智商120的聪明人(中等模型),如果让他读完整个图书馆的藏书(海量数据),其解决实际问题的能力会远超那个天才。
- Chinchilla证明:像GPT-3 (175B)、Gopher (280B)这样的模型,如果将其参数量减半,但将训练数据增加至原来的4倍,训练出的模型性能反而会更强。
1.3 确立平衡分配原则
- Chinchilla的最终目标是为给定的计算预算 C,找到最优的 N 和 D 配比。
- 其核心结论是:计算预算应该在模型规模和训练数据之间几乎均等地分配。具体来说,模型参数量 N 和训练数据token量 D 应与计算预算 C 的平方根成正比。
2. 数学公式
2.1 公式说明
与KM法则类似,Chinchilla将测试损失 L 建模为模型参数量 N 和训练数据量 D 的函数:
L(N, D) = E + A/(N^α) + B/(D^β)
其中:
- L(N, D): 模型在未知数据上的损失(交叉熵损失)。L 越低,模型性能越好。
- E: 不可约损失,代表了数据分布本身的内在噪音和不确定性,是任何模型都无法超越的理论下限。
- A/(N^α): 模型容量项,代表了因模型参数有限而产生的近似误差。随着模型参数 N 增加,这项误差会减小。
- B/(D^β): 数据容量项,代表了因训练数据有限而产生的估计误差。随着训练数据 D 增加,这项误差会减小。
关键的Chinchilla参数值:
DeepMind通过实验拟合出的参数约为:
- α ≈ 0.38
- β ≈ 0.38
- A 和 B 是相应的缩放常数。
2.2 与KM法则的数学对比
| 特性 | KM 法则 | Chinchilla 法则 | 含义与影响 |
|---|---|---|---|
| 模型指数 α | ~0.076 |
~0.38 |
Chinchilla的α大了约5倍! 这意味着增加模型规模带来的性能收益衰减得快得多。模型规模的增长不再那么“划算”。 |
| 数据指数 β | ~0.103 |
~0.38 |
Chinchilla的β也大了约3.7倍! 这意味着增加数据量带来的性能收益同样衰减得很快,但其衰减速度现在与模型项持平。 |
| 指数关系 | α < β |
α ≈ β |
这是最根本的差异。 KM认为模型收益衰减更慢,故应优先扩大模型。Chinchilla发现两者衰减速度相同,故应平衡分配资源。 |
2.3 直观理解指数差异:
α 和 β 决定了“收益递减”的速度。
- KM的小指数:意味着需要把 N 或 D 扩大非常多,才能让 A/N^α 或 B/D^β 这一项显著减小。收益来得慢,但持续久。
- Chinchilla的大指数:意味着 N 或 D 的初始增长能带来显著的性能提升,但很快就会碰到“收益墙”,再投入资源的回报率急剧下降。
2.4 了解 N_op 和 D_op
2.4.1 N_op 和 D_op 是什么
- N_op: 最优模型参数量
- 在给定的计算预算 C 下,能够使模型性能达到最好的那个模型规模。
- 单位:通常是十亿参数
- D_op: 最优训练数据量
- 在给定的计算预算 C 下,能够使模型性能达到最好的那个训练数据量。
- 单位:通常是十亿token
2.4.2 符号 ∝ 的含义
∝ 表示"正比于",所以:
- N_op ∝ C^0.5 意思是:最优模型规模与计算预算的平方根成正比
- D_op ∝ C^0.5 意思是:最优数据量与计算预算的平方根成正比
2.4.3 直观理解:切蛋糕的比喻
想象我们有一块固定大小的蛋糕(计算预算 C),要分给两个人:
- 一个人叫"模型规模" (N)
- 一个人叫"训练数据" (D)
Chinchilla法则告诉我们:应该把蛋糕平均分给这两个人!
- N_op ∝ C^0.5 和 D_op ∝ C^0.5 就是这个平均分配规则的数学表达。
2.4.4 具体实例
场景1:小预算情况
假设计算预算 C = 1e21 FLOPs
- N_op ∝ (1e21)^0.5 ≈ 3.16e10 参数 ≈ 31.6亿参数
- D_op ∝ (1e21)^0.5 ≈ 3.16e10 token ≈ 31.6亿token
场景2:预算增加100倍
现在预算增加到 C = 1e23 FLOPs(增加了100倍)
- N_op ∝ (1e23)^0.5 ≈ 3.16e11 参数 ≈ 316亿参数
- D_op ∝ (1e23)^0.5 ≈ 3.16e11 token ≈ 316亿token
对比分析:
- 计算预算变化:预算增加100倍
- 模型规模变化:模型增加10倍
- 数据量变化:数据增加10倍
- 平方根效应:预算线性增长,规模平方根增长
2.4 计算最优分配公式
基于上述性能预测公式,Chinchilla推导出了在固定计算预算 C(其中 C ≈ 6 N D)下,如何分配 N 和 D 才能使损失 L 最小化。
核心发现:最优配置是让模型容量项和数据容量项对损失的贡献大致相等。
其推导出的最优比例是:
N_op ∝ C^a
D_op ∝ C^b
其中 a = β/(α+β), b = α/(α+β)
代入Chinchilla的 α=β=0.38:
- a = 0.38/(0.38+0.38) = 0.5
- b = 0.38/(0.38+0.38) = 0.5
因此,最优策略为:
N_op ∝ C^0.5
D_op ∝ C^0.5
具体经验性结论:
对于一个计算预算 C,Chinchilla推荐:
- 模型参数量 N 约等于 (C / 20)^0.5
- 训练数据Token量 D 约等于 (C / 20)^0.5
注意:这里的常数20是考虑了模型前向和反向传播的FLOPs估算后的一个经验值,与 C ≈ 6ND 的本质思想一致。
Chinchilla法则的数学公式告诉我们:
- 性能公式:L(N, D) = E + A/N^0.38 + B/D^0.38
- 模型和数据的贡献是对称且平衡的。
- 分配公式:N_op ∝ C^0.5, D_op ∝ C^0.5
- 计算预算应该均等地分配给模型复杂度和数据量。
3. 示例分析
在固定计算预算下,模型参数量(N)和训练数据量(D)应该平衡增长,而不是像KM法则那样偏向模型规模。
核心公式:L(N, D) = E + A/N^α + B/D^β
其中 α ≈ β ≈ 0.38,这与KM法则的 α=0.076, β=0.103 形成鲜明对比。
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
def chinchilla_scaling_law(N, D, E=1.69, A=406.4, B=410.7, alpha=0.38, beta=0.38):
"""
计算Chinchilla扩展法则预测的损失值
参数:
N: 模型参数量 (单位: 十亿)
D: 训练数据量 (单位: 十亿token)
E, A, B, alpha, beta: Chinchilla法则的经验参数
返回:
L: 预测的损失值
"""
# Chinchilla核心公式 - 注意指数alpha和beta都接近0.38
L = E + (A / (N ** alpha)) + (B / (D ** beta))
return L
def km_scaling_law(N, D, E=1.5, A=500, B=1000, alpha=0.076, beta=0.103):
"""
KM扩展法则用于对比
"""
L = E + (A / (N ** alpha)) + (B / (D ** beta))
return L
# 示例1: 单个模型预测对比
print("=== 示例1: Chinchilla vs KM 预测对比 ===")
N_example = 70 # 70亿参数
D_example = 1500 # 1.5万亿token
loss_chinchilla = chinchilla_scaling_law(N_example, D_example)
loss_km = km_scaling_law(N_example * 1000, D_example * 1000) # 转换为百万单位
print(f"模型规模: {N_example}B 参数")
print(f"训练数据: {D_example}B token")
print(f"Chinchilla预测损失: {loss_chinchilla:.4f}")
print(f"Chinchilla预测困惑度: {np.exp(loss_chinchilla):.2f}")
print(f"KM法则预测损失: {loss_km:.4f}")
print(f"KM法则预测困惑度: {np.exp(loss_km):.2f}\n")
# 示例2: 计算最优配置对比
print("=== 示例2: 最优配置计算对比 ===")
def find_optimal_allocation(compute_budget, law_type='chinchilla'):
"""
根据不同的扩展法则找到最优配置
假设计算预算 C ≈ 6 * N * D
"""
if law_type == 'chinchilla':
# Chinchilla: 平衡分配
alpha, beta = 0.38, 0.38
N_optimal = (compute_budget / 6) ** 0.5 # N ∝ C^0.5
D_optimal = (compute_budget / 6) ** 0.5 # D ∝ C^0.5
else: # KM法则
alpha, beta = 0.076, 0.103
optimal_N_ratio = alpha / (alpha + beta)
optimal_D_ratio = beta / (alpha + beta)
N_optimal = (compute_budget / 6) ** optimal_N_ratio # N ∝ C^0.74
D_optimal = (compute_budget / 6) ** optimal_D_ratio # D ∝ C^0.26
return N_optimal, D_optimal
# 测试不同计算预算下的最优配置
budgets = [1e21, 5e21, 1e22, 5e22] # 不同的计算预算
print("计算预算(FLOPs)\t法则类型\t\t最优参数(B)\t最优数据(B)\t参/数比例")
print("-" * 85)
for budget in budgets:
# Chinchilla最优配置
N_chi, D_chi = find_optimal_allocation(budget, 'chinchilla')
ratio_chi = N_chi / D_chi
# KM最优配置
N_km, D_km = find_optimal_allocation(budget, 'km')
ratio_km = N_km / D_km
print(f"{budget:.1e}\tChinchilla\t{N_chi/1e9:8.1f}\t\t{D_chi/1e9:8.1f}\t\t{ratio_chi:.3f}")
print(f"{budget:.1e}\tKM法则\t\t{N_km/1e9:8.1f}\t\t{D_km/1e9:8.1f}\t\t{ratio_km:.3f}")
print("-" * 85)
# 示例3: 训练不足分析
print("\n=== 示例3: 训练不足分析 ===")
def analyze_under_training(model_size_B, compute_budget):
"""
分析在固定计算预算下,不同数据量对性能的影响
"""
print(f"\n分析 {model_size_B}B 参数模型在 {compute_budget:.1e} FLOPs 预算下的表现:")
# Chinchilla推荐的数据量
N_chi_opt, D_chi_opt = find_optimal_allocation(compute_budget, 'chinchilla')
D_chi_for_model = compute_budget / (6 * model_size_B * 1e9)
# KM推荐的数据量
N_km_opt, D_km_opt = find_optimal_allocation(compute_budget, 'km')
D_km_for_model = compute_budget / (6 * model_size_B * 1e9)
# 计算不同数据量下的损失
data_ratios = [0.25, 0.5, 1.0, 2.0, 4.0] # 相对于Chinchilla推荐的数据量比例
print("数据比例\t实际数据(B)\tChinchilla损失\tKM损失\t\t训练状态")
print("-" * 75)
for ratio in data_ratios:
actual_data = D_chi_for_model * ratio / 1e9 # 转换为十亿单位
loss_chi = chinchilla_scaling_law(model_size_B, actual_data)
loss_km = km_scaling_law(model_size_B * 1000, actual_data * 1000)
status = "严重训练不足" if ratio < 0.5 else "训练不足" if ratio < 1.0 else "接近最优" if ratio <= 2.0 else "数据充足"
print(f"{ratio:4.2f}\t\t{actual_data:8.1f}\t\t{loss_chi:.4f}\t\t{loss_km:.4f}\t\t{status}")
analyze_under_training(70, 1e22) # 分析70B模型
# 示例4: 可视化对比
print("\n=== 示例4: 生成对比可视化图表 ===")
# 创建计算预算范围
compute_range = np.logspace(20, 24, 50) # 10^20 到 10^24 FLOPs
# 计算两种法则的最优配置
N_chi_optimal = []
D_chi_optimal = []
N_km_optimal = []
D_km_optimal = []
for C in compute_range:
N_chi, D_chi = find_optimal_allocation(C, 'chinchilla')
N_km, D_km = find_optimal_allocation(C, 'km')
N_chi_optimal.append(N_chi / 1e9) # 转换为十亿单位
D_chi_optimal.append(D_chi / 1e9)
N_km_optimal.append(N_km / 1e9)
D_km_optimal.append(D_km / 1e9)
# 创建可视化图表
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
# 子图1: 最优模型规模对比
ax1.loglog(compute_range, N_chi_optimal, 'r-', linewidth=3, label='Chinchilla最优')
ax1.loglog(compute_range, N_km_optimal, 'b--', linewidth=2, label='KM最优')
ax1.set_xlabel('计算预算 (FLOPs)')
ax1.set_ylabel('最优模型规模 (十亿参数)')
ax1.set_title('模型规模推荐对比\nChinchilla vs KM')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 标记具体预算点示例
sample_budget = 1e22
N_chi_sample = (sample_budget / 6) ** 0.5 / 1e9
N_km_sample = (sample_budget / 6) ** (0.076/(0.076+0.103)) / 1e9
ax1.annotate(f'在{sample_budget:.0e} FLOPs:\nChinchilla: {N_chi_sample:.0f}B\nKM: {N_km_sample:.0f}B',
xy=(sample_budget, N_chi_sample), xytext=(1e21, 500),
arrowprops=dict(arrowstyle='->', color='red'),
bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="red", alpha=0.8))
# 子图2: 最优数据量对比
ax2.loglog(compute_range, D_chi_optimal, 'r-', linewidth=3, label='Chinchilla最优')
ax2.loglog(compute_range, D_km_optimal, 'b--', linewidth=2, label='KM最优')
ax2.set_xlabel('计算预算 (FLOPs)')
ax2.set_ylabel('最优训练数据量 (十亿token)')
ax2.set_title('训练数据量推荐对比\nChinchilla vs KM')
ax2.legend()
ax2.grid(True, alpha=0.3)
# 子图3: 参数-数据比例对比
ratio_chi = np.array(N_chi_optimal) / np.array(D_chi_optimal)
ratio_km = np.array(N_km_optimal) / np.array(D_km_optimal)
ax3.semilogx(compute_range, ratio_chi, 'r-', linewidth=3, label='Chinchilla比例')
ax3.semilogx(compute_range, ratio_km, 'b--', linewidth=2, label='KM比例')
ax3.set_xlabel('计算预算 (FLOPs)')
ax3.set_ylabel('参数/数据比例 (N/D)')
ax3.set_title('资源分配策略对比\n比例越高 = 越偏向模型规模')
ax3.legend()
ax3.grid(True, alpha=0.3)
# 子图4: 性能对比 - 固定计算预算下的损失
fixed_budget = 1e22
model_sizes = [7, 20, 70, 200] # 不同的模型规模 (十亿参数)
chinchilla_losses = []
km_losses = []
for size in model_sizes:
# 在固定预算下,计算对应的数据量
data_chi = fixed_budget / (6 * size * 1e9) / 1e9 # 十亿token单位
data_km = fixed_budget / (6 * size * 1e9) / 1e9 # 相同计算预算
loss_chi = chinchilla_scaling_law(size, data_chi)
loss_km = km_scaling_law(size * 1000, data_km * 1000)
chinchilla_losses.append(loss_chi)
km_losses.append(loss_km)
ax4.plot(model_sizes, chinchilla_losses, 'ro-', linewidth=2, label='Chinchilla预测')
ax4.plot(model_sizes, km_losses, 'bs--', linewidth=2, label='KM预测')
ax4.set_xlabel('模型规模 (十亿参数)')
ax4.set_ylabel('预测损失')
ax4.set_title(f'固定预算 {fixed_budget:.0e} FLOPs 下\n不同模型规模的性能对比')
ax4.legend()
ax4.grid(True, alpha=0.3)
# 标记最优配置
optimal_size_chi = (fixed_budget / 6) ** 0.5 / 1e9
optimal_data_chi = (fixed_budget / 6) ** 0.5 / 1e9
optimal_loss_chi = chinchilla_scaling_law(optimal_size_chi, optimal_data_chi)
ax4.axvline(x=optimal_size_chi, color='red', linestyle=':', alpha=0.7)
ax4.annotate(f'Chinchilla最优\n{optimal_size_chi:.0f}B模型',
xy=(optimal_size_chi, optimal_loss_chi),
xytext=(optimal_size_chi+30, optimal_loss_chi+0.1),
arrowprops=dict(arrowstyle='->', color='red'))
plt.tight_layout()
plt.show()
输出结果:
=== 示例1: Chinchilla vs KM 预测对比 ===
模型规模: 70B 参数
训练数据: 1500B token
Chinchilla预测损失: 108.0695
Chinchilla预测困惑度: 85899031069167667854303274236400488860482535424.00
KM法则预测损失: 446.7954
KM法则预测困惑度: 109845366723675280192034736636001868827496702856567587991204197607574163216094605052146384532230226621195552621200325104919263199725470312853558268331235709933397008939880728797698102035965018112.00
Chinchilla: 108, KM: 447,预测的损失值和现实偏差很大,对参数(A, B, E, α, β)需要重新校准。
=== 示例2: 最优配置计算对比 ===
计算预算(FLOPs) 法则类型 最优参数(B) 最优数据(B) 参/数比例
-------------------------------------------------------------------------------------
1.0e+21 Chinchilla 12.9 12.9 1.000
1.0e+21 KM法则 0.4 432.5 0.001
-------------------------------------------------------------------------------------
5.0e+21 Chinchilla 28.9 28.9 1.000
5.0e+21 KM法则 0.8 1092.0 0.001
-------------------------------------------------------------------------------------
1.0e+22 Chinchilla 40.8 40.8 1.000
1.0e+22 KM法则 1.0 1627.3 0.001
-------------------------------------------------------------------------------------
5.0e+22 Chinchilla 91.3 91.3 1.000
5.0e+22 KM法则 2.0 4108.2 0.000
-------------------------------------------------------------------------------------
Chinchilla法则(平衡策略):
- 模型与数据1:1平衡:参数/数据比例始终为 1.000
- 同步增长:在1e21 FLOPs时,推荐约13B模型配13B数据;在5e22 FLOPs时,推荐约91B模型配91B数据
- 实践意义:这是一种“中等模型 + 中等数据”的平衡发展路径
KM法则(极端偏向策略):
- 极度偏向数据:参数/数据比例仅为 0.001
- 模型极小,数据极大:在1e22 FLOPs时,推荐1B模型配1627B数据(差了1600倍!)
- 实践意义:这是一种“极小模型 + 海量数据”的极端策略
=== 示例3: 训练不足分析 ===
分析 70B 参数模型在 1.0e+22 FLOPs 预算下的表现:
数据比例 实际数据(B) Chinchilla损失 KM损失 训练状态
-------------------------------------------------------------------------------------------------------
0.25 6.0 291.0828 624.1760 严重训练不足
0.50 11.9 242.7980 596.0273 训练不足
1.00 23.8 205.6941 569.8181 接近最优
2.00 47.6 177.1822 545.4149 接近最优
4.00 95.2 155.2725 522.6933 数据充足
- 重复展示了固定模型规模时,数据量的关键作用。
- 数据量翻倍的效果:从6B数据→95B数据(16倍增长),Chinchilla损失从291→155(几乎减半)
- KM法则的严重误判:KM法则预测的损失始终在500+的高位,完全无法反映数据增加带来的收益

五、两者的可视化对比
import numpy as np
import matplotlib.pyplot as plt
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 1. 定义计算预算 C (以FLOPs为单位,使用对数等间距点)
# np.linspace(20, 24, 100) 生成一个从20到24的数组,包含100个等间距的点。
# 这个数组代表计算预算的对数值,范围从10^20到10^24 FLOPs,覆盖了从中等到大规模的训练预算。
log_C = np.linspace(20, 24, 100)
# 将对数坐标转换回线性坐标,得到具体的计算预算值C。
C = 10 ** log_C
# 2. 根据两种法则估算模型参数量 (N) 和训练数据量 (D)
# 注意:以下是非常简化的经验近似,用于演示两种法则在趋势上的根本差异。
# KM扩展法则风格 (倾向于更大的模型规模):
# 假设模型参数量 N 与计算预算 C 的 0.7 次方成正比。
# 假设训练数据量 D 与计算预算 C 的 0.3 次方成正比。
# 这里的比例常数 (1e8, 5e9) 是为了让曲线在图表中处于一个合适的视觉位置而任意设定的。
N_km = 1e8 * (C / 1e20) ** 0.7 # 基础参数1亿,按比例缩放
D_km = 5e9 * (C / 1e20) ** 0.3 # 基础数据50亿Token,按比例缩放
# Chinchilla扩展法则风格 (模型与数据平衡增长):
# 假设模型参数量 N 和训练数据量 D 均与计算预算 C 的 0.5 次方成正比。
# 这体现了其核心思想:对于固定的计算预算,应在N和D之间进行平衡分配。
N_chi = 5e8 * (C / 1e20) ** 0.5 # 基础参数5亿,按比例缩放
D_chi = 2e10 * (C / 1e20) ** 0.5 # 基础数据200亿Token,按比例缩放
# 3. 创建图表进行可视化
# plt.subplots(1, 2) 创建一個包含1行2列子图的图形窗口。
# figsize=(14, 5) 设置整个图形窗口的尺寸为宽14英寸、高5英寸。
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# 图表1:模型参数量 (N) 对比
# 在第一个子图(ax1)上,用蓝色实线绘制KM法则的N,用红色虚线绘制Chinchilla法则的N。
ax1.loglog(C, N_km, 'b-', linewidth=2, label='KM法则 (模型规模优先)')
ax1.loglog(C, N_chi, 'r--', linewidth=2, label='Chinchilla法则 (平衡策略)')
# 设置坐标轴标签、标题和图例。
ax1.set_xlabel('计算预算 C (FLOPs)')
ax1.set_ylabel('模型参数量 (N)')
ax1.set_title('模型规模预测对比')
ax1.legend() # 显示图例
ax1.grid(True, which="both", ls="-", alpha=0.2) # 添加网格线,便于读数
# 图表2:训练数据量 (D) 对比
# 在第二个子图(ax2)上,用同样的线型和颜色绘制两种法则的D。
ax2.loglog(C, D_km, 'b-', linewidth=2, label='KM法则')
ax2.loglog(C, D_chi, 'r--', linewidth=2, label='Chinchilla法则')
ax2.set_xlabel('计算预算 C (FLOPs)')
ax2.set_ylabel('训练数据Token量 (D)')
ax2.set_title('训练数据量预测对比')
ax2.legend()
ax2.grid(True, which="both", ls="-", alpha=0.2)
# 自动调整子图参数,使之填充整个图像区域,避免重叠。
plt.tight_layout()
# 显示图形
plt.show()
输出结果:

图例分析:
左图:模型规模预测对比
- X轴是“计算预算 C (FLOPs)”,采用对数刻度。
- Y轴是“模型参数量 (N)”,采用对数刻度。
- 图中包含两条线:
- 一条蓝色实线(KM法则):这条线非常陡峭,意味着随着计算预算的增加,KM法则建议的模型参数量会急剧增长。
- 一条红色虚线(Chinchilla法则):这条线相比蓝线平缓得多,意味着在相同的计算预算下,Chinchilla法则推荐的模型规模远小于KM法则的推荐。
右图:训练数据量预测对比
- X轴同样是“计算预算 C (FLOPs)”。
- Y轴是“训练数据Token量 (D)”。
- 图中同样包含两条线:
- 一条蓝色实线(KM法则):这条线非常平缓,意味着KM法则认为数据量只需要随着算力缓慢增加。
- 一条红色虚线(Chinchilla法则):这条线非常陡峭,意味着Chinchilla法则建议的数据量需要随着算力迅猛增长,其推荐量远超KM法则。
图示结论:
- KM法则(蓝色) 是 “大模型,适量数据” 的策略。它把大部分新增的算力都投入到了扩大模型参数上。
- Chinchilla法则(红色) 是 “适中模型,海量数据” 的策略。它认为算力应该在模型和数据之间取得平衡,甚至更倾向于为模型配备远超以往认知的数据量。
六、总结
大模型扩展法则揭示了计算预算的最优分配原理,KM法则主张“规模至上”,认为应优先扩大模型参数,数据适量即可。而Chinchilla法则通过实验证明,许多大模型实际处于训练不足状态,提出模型与数据应平衡增长的效率优先原则。
Chinchilla法则完成了关键范式转移,通过系统实验证明:平衡分配计算预算至模型参数量与训练数据量,才能在固定成本下实现性能最优。其核心在于将资源分配从KM的7:3倾斜调整为1:1平衡。这一转变具有深远影响:数据价值被重新评估,模型开发从盲目追求参数量转向寻求最优配比。实践中,Chinchilla法则催生了LLaMA等"小模型、大数据"的高效架构,显著降低了AI应用门槛。
更多推荐

所有评论(0)