一、前言

        大模型的浪潮如火如荼,但做为个人开发者和小企业的我们,不知道大家有没有面临这样的困境:有限的算力预算如同杯水车薪,是该训练一个参数更多的聪明模型,还是用更多数据喂养一个见多识广的模型,往往训练一个大体量的模型,需要耗费大量的资金和时间,而作为普通用户的我们,如果想训练一个自己的模型,在我们固定的计算预算下,我们应该训练一个多大的模型参数量?并用多少数据?如何高效地分配计算资源成为模型训练的核心问题!

        扩展法则就是为了科学地回答这个问题而生的,也正是破解这一难题,为我们提供了精细化的指导思路,它们是基于大量实验得出的经验规律,用于预测模型性能损失如何随参数量N和数据量D的变化而变化,它告诉我们,盲目堆砌参数可能只是在制造昂贵的傻瓜,而恰当的数据配比能让小预算发挥大效能。理解扩展法则,意味着能用1%的资源达成80%的效果,让资源有限的团队也能在AI赛道上精准发力。这不仅是技术选择,更是生存智慧,在有限的算力资源中,找到属于我们个人或小团队的制胜策略。今天我们重点围绕两个关键的扩展法则:KM扩展法则和Chinchilla扩展法则深度解析基础释义、核心思想以及数学原理,总结两者的差异和对模型训练的重要意义。

二、核心问题和概念

在深入分析之前,我们必须明确扩展法则要解决的核心问题:

在计算预算 C 固定的前提下,如何分配模型参数量N和训练数据Token量D,才能使模型的最终性能损失L最优。

这里有几个关键概念:

1. FLOPs

1.1 核心概念

FLOPs 是浮点运算次数,它就像是衡量计算机“做了多少脑力工作”的计数器,拆解开来理解:

  • Floating Point: 浮点数
    • 简单理解就是带小数点的数字,比如 3.14, -0.001, 2.71828。计算机在处理科学计算、图形图像、人工智能这些复杂任务时,主要就是和浮点数打交道。
  • Operations: 运算
    • 指最基本的数学计算,主要是 加法、减法、乘法、除法 这四种。

所以,1个 FLOP 就代表计算机执行了一次浮点数的加法、减法、乘法或除法,FLOPs(末尾的s代表复数)就是指总的浮点运算次数。

1.2 通俗的理解

想象我们要做一道数学题,要计算一个数学公式题:y = (3.2 × 1.5) + (2.1 ÷ 0.7) - 4.0,我们一步步算一下:

  • 1. 先算 3.2 × 1.5 → 1次乘法(1 FLOP)
  • 2. 再算 2.1 ÷ 0.7 → 1次除法(1 FLOP)
  • 3. 然后算 (4.8) + (3.0) → 1次加法(1 FLOP)
  • 4. 最后算 7.8 - 4.0 → 1次减法(1 FLOP)

完成这道题,计算机总共需要执行4次浮点运算,所以它的计算量就是4FLOPs。

1.3 对大模型的含义

在训练和运行AI模型时,绝大部分工作都是大规模的矩阵和向量运算,而这些运算最终都可以分解成海量的加法和乘法。

一个具体的例子:计算一个神经元的输出

假设一个神经元有3个输入 [x1, x2, x3],对应的权重是 [w1, w2, w3],还有一个偏置项 b。
它的输出是:y = (x1*w1 + x2*w2 + x3*w3) + b

我们来数一下FLOPs:

  • x1*w1, x2*w2, x3*w3 → 3次乘法(3 FLOPs)
  • (x1*w1 + x2*w2) → 1次加法(1 FLOP)
  • (... + x3*w3) → 1次加法(1 FLOP)
  • (... + b) → 1次加法(1 FLOP)

总共:6 FLOPs。

由此可以看出,一个大语言模型有数千亿个参数(权重和偏置),每处理一个token都需要进行数百万甚至数十亿次这样的计算,这个总的FLOPs数量就会变得极其庞大。

1.4 FLOPs的重要性

FLOPs是衡量计算成本、算法效率和硬件性能的一个核心指标。

  • 衡量模型复杂度/训练成本:我们说训练一个庞大的模型需要~3.14 × 10^23 FLOPs,这个天文数字直观地告诉我们这个模型训练起来极其昂贵和耗时。它就是后面我们要进一步提到的 C ≈ 6 * N * D 公式具体计算出来的结果。
  • 衡量硬件算力:我们常听到的A100、H100显卡,它们的算力就是用 FLOPS来衡量的。比如一块H100显卡的算力可达 ~4 PFLOPS,意思是它每秒可以执行4 × 10^15次浮点运算。
  • 估算时间:有了总计算量(FLOPs)和硬件算力(FLOPS),我们就可以粗略估算任务时间:
  • 时间 ≈ 总FLOPs / 硬件FLOPS

        FLOPs就是完成一个计算任务,比如训练一个AI模型所需要完成的基础数学题的总数量,表示一个工作量单位,数量越大,意味着任务越复杂,需要的计算资源越多。它是我们理解和量化人工智能等领域巨大计算需求的基石。

2. 计算预算C ≈ 6 * N * D

计算预算通常以FLOPs衡量。对于自回归语言模型训练,一个广泛使用的近似是 C ≈ 6 * N * D,这个公式是理解模型训练成本的钥匙,它告诉我们,总计算量主要取决于模型有多大和学了多少数据。

为什么是 6 * N * D?

这是一个基于Transformer架构自回归语言模型训练的经验近似值。我们可以通过分析模型的前向传播和反向传播过程来理解它:

  • 前向传播:
    • 当模型处理一个输入token时,它需要执行矩阵乘法和激活函数等操作。
    • 在Transformer中,处理一个token所需的浮点运算次数大致与模型总参数量 N 成正比。
    • 因此,对单个token的前向传播计算量 ≈ 2 * N FLOPs。
    • 因子 2 主要来自于:矩阵乘法(占大部分)和激活函数等其它操作。
  • 反向传播:
    • 训练模型时,除了前向传播,还需要反向传播来计算梯度,从而更新权重。
    • 反向传播的计算量通常是前向传播的2倍。
    • 因为它需要计算损失函数对于每一层输入的梯度(类似于前向传播),以及对于权重和偏置的梯度。
  • 总计算量:
    • 单个token的单次训练迭代总计算量 ≈ 前向传播(2N) + 反向传播(4N) = 6N FLOPs。
    • 对于整个训练过程,模型会看到整个数据集 D 个token。
    • 因此,总训练计算预算 C ≈ 6 * N * D FLOPs。

        这是一个近似值,实际值可能因模型架构、序列长度、优化器类型等因素而在 ~2ND 到 ~10ND 之间变化,但 6ND 是一个被广泛接受和使用的可靠估算值,用于进行高阶的趋势分析和比较。

        这个公式建立了一个预算约束,如果增大了模型规模N,但保持总预算C不变,那么必须相应地减少数据量D,反之亦然。这也是今天我们要谈论解决的核心问题:如何在固定的 C 下,最优地分配 N 和 D?

3. 性能L

L 是衡量模型好坏的指标,通常是模型在预留测试集上的交叉熵损失或困惑度,在语言建模中,它几乎总是通过交叉熵损失或其派生指标困惑度来定义,损失越低,模型能力越强。

3.1 交叉熵损失

核心思想:衡量模型预测的概率分布与真实的概率分布(一个one-hot向量,代表正确的下一个词)之间的距离。

计算公式(对于一个token):

  • Cross-Entropy = - Σ (y_true * log(y_pred))
  • 由于 y_true 是one-hot向量(只有正确词的位置为1,其他为0),所以简化为:
  • Cross-Entropy = - log(y_pred_correct_word)

直观理解:模型对正确下一个词赋予的预测概率 y_pred_correct_word 越高,损失 -log(y_prob) 就越低。

  • 如果 y_pred_correct_word = 1(完美预测),则 loss = -log(1) = 0。
  • 如果 y_pred_correct_word = 0.5,则 loss = -log(0.5) ≈ 0.69。
  • 如果 y_pred_correct_word = 0.1(预测很差),则 loss = -log(0.1) ≈ 2.30。

整个数据集的损失是所有这些单个token损失的平均值。

3.2 困惑度

困惑度是交叉熵损失的指数形式,因为它更直观。

计算公式:Perplexity = exp(Cross-Entropy_Loss)

直观理解:困惑度可以理解为“模型在预测下一个词时的平均不确定性程度”或者“平均分支因子”。

  • 假设一个模型的困惑度为 10。这意味着,在预测每一个词时,模型感觉平均有10个等可能的选择。困惑度越低,模型越确定。
  • 一个完美的模型的困惑度为 1(因为它总是100%确定下一个词是什么)。
  • 如果模型只是随机猜测一个包含10,000个词的词汇表,其困惑度就是 10,000。

关系:由于 Perplexity = exp(L),最小化交叉熵损失 L 就等价于最小化困惑度。在扩展法则的研究中,通常直接使用交叉熵损失 L 作为优化目标,因为它数学性质更好(是加法性的)。

交叉熵损失和困惑度的详细说明可参考《信息论完全指南:从基础概念到在大模型中的实际应用

4. 幂律与收益递减

这是扩展法则的灵魂,揭示了性能提升的基本规律。

3.1 幂律关系

扩展法则发现,损失 L 与模型规模 N 和数据规模 D 遵循幂律关系:
L ∝ 1 / N^α
L ∝ 1 / D^β

这意味着,L 与 N^α 和 D^β 成反比。将其与不可约损失 E 结合,就得到了我们之前看到的完整公式:
L(N, D) = E + A/N^α + B/D^β

3.2 收益递减效应

幂律中的指数 α 和 β(通常远小于1)是理解收益递减的关键。

让我们通过一个例子来理解:

  • 假设 α = 0.5。
  • 当 N 从 10^6(100万)增加到 10^7(1000万)时,N^α 从 (10^6)^0.5 = 1000 增加到 (10^7)^0.5 ≈ 3162。
  • 性能提升(损失下降)的倍数是 3162 / 1000 ≈ 3.16 倍。
  • 现在,再将 N 从 10^7 增加到 10^8(1亿),N^α 从 3162 增加到 (10^8)^0.5 = 10000。
  • 性能提升的倍数是 10000 / 3162 ≈ 3.16 倍。

示例发现:

  • 在两种情况下,模型规模都增加了10倍。
  • 但性能提升的倍数却完全相同,都是约3.16倍,而不是随着规模变大而线性增加。
  • 这意味着,每当我们想让性能提升固定的倍数,我们只需要将模型规模增加一个固定的比例,例如10倍。这就是收益递减,我们需要投入指数级增长的资源,才能获得线性增长的性能。

对扩展法则的实际意义:

  • KM法则:α ≈ 0.076, β ≈ 0.103。指数非常小,意味着收益递减得非常非常慢。为了显著降低损失,我们需要极大地增加 N 或 D。同时,由于 α < β,增加 N 的收益衰减得比增加 D 更慢,因此KM推荐优先扩大 N。
  • Chinchilla法则:α ≈ β ≈ 0.38。指数更大,意味着收益递减效应比KM认为的要严重得多。同时,α 和 β 相等,意味着增加模型和增加数据对性能的贡献是平衡的。因此,Chinchilla推荐同步扩大 N 和 D。

5. 总结

预算C、性能L和幂律这三个概念构成了一个完整的逻辑链:

  • 有固定预算 C ≈ 6ND。
  • 目标是最小化损失 L(或困惑度)。
  • 通过实验发现,L 与 N 和 D 存在幂律关系,并伴有收益递减。
  • KM和Chinchilla法则的核心争论在于幂律指数 α 和 β 的具体数值,这直接导致了截然不同的最优资源分配策略。

三、KM扩展法则

1. 基础理解

核心思想: 在计算预算充足的情况下,模型参数量 N 是影响性能的最关键因素。为了达到最佳性能,应优先扩大模型规模,同时按比例适当增加数据量。

一个简单的比喻:

好比我们在组建一个研究团队来解决一个复杂问题。

  • 模型参数量就像是团队中博士的数量。博士越多,团队的集体智力、分析能力和解决复杂问题的潜力就越大。
  • 训练数据量就像是提供给这个团队的研究资料和参考文献的数量。
  • KM法则的发现是:相比于无止境地增加研究资料,优先增加博士的数量,对最终解决问题能力的提升效果更显著。
  • 当然,资料不能太少,否则博士们会“巧妇难为无米之炊”。但KM法则认为,在资源和精力有限时,我们应该把重点放在招募更多、更聪明的博士上。

2. 数学公式

KM法则将测试损失 L 建模为 N 和 D 的幂律函数:

L(N, D) = E + (A / N^α) + (B / D^β)

其中:

  • L(N, D): 这是我们最终想知道的模型损失。它取决于模型参数量 N 和训练数据Token量 D。
  • E: 不可约损失,这是理论上的最低损失,由数据本身的内在复杂度和噪音决定,就像无论多么聪明的学生,也无法完美预测一次完全随机的硬币抛掷结果。E 是一个无法通过改进模型或增加数据来超越的极限。
  • A / N^α: 模型容量损失,这部分损失是因为我们的模型不够大、不够复杂。
    • N 是模型参数量。
    • A 和 α 是常数,通过实验拟合得出,OpenAI发现 α ≈ 0.076。
    • 直观理解:当 N 增大时,A / N^α 会变小,这意味着模型越大,因容量不足导致的损失就越小。α 很小,意味着我们需要极大地增加 N,才能让这项显著下降。
  • B / D^β: 数据容量损失,这部分损失是因为我们的训练数据不够多。
    • D 是训练数据Token量。
    • B 和 β 是常数,通过实验拟合得出,OpenAI发现 β ≈ 0.103。
    • 直观理解:当 D 增大时,B / D^β 会变小。这意味着数据越多,因数据不足导致的损失就越小。β 也很小,意味着数据也需要增加很多才能显著降低这项损失。

        通过这个公式,如果我们知道了常数 E, A, B, α, β,我们就可以预测:一个拥有 N 参数、用 D 数据训练的模型,最终性能 L 大概会是多少,这为模型设计提供了很好的指导,由于 α 和 β 都很小,为了最小化损失,需要同时增大 N 和 D,但KM法则的实证结果表明,对 N 的投资回报率更高。

3. KM法则的决策指南

通过对上述公式的分析和实验验证,KM法则得出了几个改变AI研发方向的结论:

3.1 模型规模 N 的收益高于数据规模 D

  • 因为实验中拟合出的 α (0.076) 略小于 β (0.103)。这意味着,N 的指数更小,其收益随规模增长而衰减的速度更慢。换句话说,扩大 N 的“后劲”更足。
  • 在计算预算 C(约等于 6ND)固定的情况下,最优配置是大力倾斜于 N。论文给出的近似最优比例是:模型规模 N 应与 C 的约0.73次方成正比,而数据 D 仅与 C 的约0.27次方成正比。

3.2 性能平滑可预测

  • 模型性能主要取决于 N 和 D,而对模型架构、训练方式等超参数的依赖相对较小。这意味着,你可以通过“缩放”来可靠地提升性能。

3.3 在计算最优边界上,模型应该“训练不足”

  • 这是KM法则一个非常重要且反直觉的推论。它说:在固定的计算预算下,我们应该训练一个非常大的模型,但只在相对较少的数据上训练它,直到它还没有完全收敛(即“训练不足”)时就停止。
  • 为什么?因为把同样的计算预算用来训练一个更大的模型,即使训练不充分,比用来更充分地训练一个较小的模型,最终性能更好。

4. 示例分析

KM法则的核心公式:L(N, D) = E + A/N^α + B/D^β

其中:

  • L:模型损失(越低越好)
  • N:模型参数量
  • D:训练数据量(token数)
  • E, A, B, α, β:通过实验拟合的常数
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

def km_scaling_law_log(N, D, E=2.0, A=3.0, B=3.0, alpha=0.3, beta=0.3):
    """
    计算KM扩展法则预测的损失值 - 对数尺度版本
    确保输出在合理范围内
    
    参数:
    N: 模型参数量 (单位: 十亿)
    D: 训练数据量 (单位: 十亿token)
    """
    # 使用对数尺度确保数值稳定
    # 基础损失 + 模型项 + 数据项
    L = E + (A / (np.log(N + 1) ** alpha)) + (B / (np.log(D + 1) ** beta))
    return L

def safe_exp(x):
    """安全的指数函数,防止溢出"""
    return np.exp(np.clip(x, -700, 700))

# 示例1: 单个模型预测
print("=== 示例1: 单个模型性能预测 ===")
N_example = 1.0  # 10亿参数
D_example = 5.0  # 50亿token

loss = km_scaling_law_log(N_example, D_example)
print(f"模型规模: {N_example}B 参数")
print(f"训练数据: {D_example}B token")
print(f"KM法则预测损失: {loss:.4f}")
print(f"对应的困惑度: {safe_exp(loss):.2f}\n")

# 示例2: 不同规模模型的对比
print("=== 示例2: 不同规模模型对比 ===")
model_sizes = [0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]  # 从1亿到1000亿参数
fixed_data = 10.0  # 固定100亿token数据

print(f"固定训练数据: {fixed_data}B token")
print("模型规模(B)\t预测损失\t困惑度")
print("-" * 55)

for size in model_sizes:
    loss = km_scaling_law_log(size, fixed_data)
    perplexity = safe_exp(loss)
    print(f"{size:8.1f}\t{loss:.4f}\t\t{perplexity:.2f}")

# 示例3: 不同数据量的对比
print("\n=== 示例3: 不同数据量对比 ===")
data_sizes = [1.0, 5.0, 10.0, 50.0, 100.0, 500.0, 1000.0]  # 从10亿到1万亿token
fixed_model = 1.0  # 固定10亿参数

print(f"固定模型规模: {fixed_model}B 参数")
print("数据量(B)\t预测损失\t困惑度")
print("-" * 55)

for data in data_sizes:
    loss = km_scaling_law_log(fixed_model, data)
    perplexity = safe_exp(loss)
    print(f"{data:8.1f}\t{loss:.4f}\t\t{perplexity:.2f}")

# 示例4: 可视化分析
print("\n=== 示例4: 生成可视化图表 ===")

# 创建模型规模和数据的网格
N_range = np.logspace(-1, 2, 50)  # 从0.1B到100B参数
D_range = np.logspace(0, 3, 50)   # 从1B到1000B token

N_grid, D_grid = np.meshgrid(N_range, D_range)
L_grid = km_scaling_law_log(N_grid, D_grid)

# 创建可视化图表
fig = plt.figure(figsize=(16, 5))

# 子图1: 固定数据量,看模型规模的影响
ax1 = fig.add_subplot(131)
fixed_D = 10.0  # 固定10B token

losses_N = [km_scaling_law_log(N, fixed_D) for N in N_range]
ax1.semilogx(N_range, losses_N, 'b-', linewidth=3)
ax1.set_xlabel('模型参数量 (十亿)')
ax1.set_ylabel('预测损失')
ax1.set_title('模型规模对性能的影响\n(固定数据量)')
ax1.grid(True, alpha=0.3)

# 标记GPT-3规模的点
gpt3_N = 175
gpt3_loss = km_scaling_law_log(gpt3_N, fixed_D)
ax1.axvline(x=gpt3_N, color='red', linestyle='--', alpha=0.7)
ax1.plot(gpt3_N, gpt3_loss, 'ro', markersize=8)
ax1.annotate(f'GPT-3\n({gpt3_N}B)', (gpt3_N, gpt3_loss), 
            xytext=(10, 10), textcoords='offset points',
            bbox=dict(boxstyle="round,pad=0.3", fc="yellow", alpha=0.7))

# 子图2: 固定模型规模,看数据量的影响
ax2 = fig.add_subplot(132)
fixed_N = 1.0  # 固定1B参数

losses_D = [km_scaling_law_log(fixed_N, D) for D in D_range]
ax2.semilogx(D_range, losses_D, 'r-', linewidth=3)
ax2.set_xlabel('训练数据量 (十亿token)')
ax2.set_ylabel('预测损失')
ax2.set_title('数据量对性能的影响\n(固定模型规模)')
ax2.grid(True, alpha=0.3)

# 子图3: 热力图展示N和D的共同影响
ax3 = fig.add_subplot(133)
contour = ax3.contourf(np.log10(N_grid), np.log10(D_grid), L_grid, levels=20, cmap='RdYlBu_r')
ax3.set_xlabel('log10(模型参数) (B)')
ax3.set_ylabel('log10(训练数据) (B)')
ax3.set_title('KM扩展法则热力图\n颜色表示损失值')

# 添加等值线
contour_lines = ax3.contour(np.log10(N_grid), np.log10(D_grid), L_grid, 
                           levels=10, colors='black', alpha=0.5)
ax3.clabel(contour_lines, inline=True, fontsize=8)

plt.colorbar(contour, ax=ax3, label='预测损失')

plt.tight_layout()
plt.show()

# 示例5: 实际模型案例分析
print("\n=== 示例5: 实际模型性能预测 ===")

real_models = [
    {"name": "GPT-3", "N": 175, "D": 300},
    {"name": "LLaMA-2 7B", "N": 7, "D": 2000},
    {"name": "LLaMA-2 70B", "N": 70, "D": 2000},
    {"name": "PaLM", "N": 540, "D": 780},
    {"name": "Chinchilla", "N": 70, "D": 1400},
]

print("模型名称\t\t参数(B)\t数据(B)\t预测损失\t困惑度")
print("-" * 70)

for model in real_models:
    loss = km_scaling_law_log(model["N"], model["D"])
    perplexity = safe_exp(loss)
    print(f"{model['name']:12}\t{model['N']:4.0f}\t{model['D']:4.0f}\t{loss:.4f}\t\t{perplexity:.2f}")

# 示例6: 资源分配建议
print("\n=== 示例6: 资源分配策略 ===")

def analyze_resource_allocation(total_compute):
    """分析不同资源分配策略"""
    print(f"\n在总计算量 {total_compute:.1e} FLOPs 下的策略分析:")
    print("策略\t\t\t模型规模(B)\t数据量(B)\t预测损失")
    print("-" * 65)
    
    # 策略1: KM风格 (偏向大模型)
    N_km = (total_compute / 6) ** 0.7 / 1e9
    D_km = (total_compute / 6) ** 0.3 / 1e9
    loss_km = km_scaling_law_log(N_km, D_km)
    print(f"KM策略\t\t\t{N_km:6.1f}\t\t{D_km:6.1f}\t\t{loss_km:.4f}")
    
    # 策略2: Chinchilla风格 (平衡)
    N_chi = (total_compute / 6) ** 0.5 / 1e9
    D_chi = (total_compute / 6) ** 0.5 / 1e9
    loss_chi = km_scaling_law_log(N_chi, D_chi)
    print(f"Chinchilla策略\t\t{N_chi:6.1f}\t\t{D_chi:6.1f}\t\t{loss_chi:.4f}")
    
    # 策略3: 偏向大数据
    N_data = (total_compute / 6) ** 0.3 / 1e9
    D_data = (total_compute / 6) ** 0.7 / 1e9
    loss_data = km_scaling_law_log(N_data, D_data)
    print(f"数据优先策略\t\t{N_data:6.1f}\t\t{D_data:6.1f}\t\t{loss_data:.4f}")

analyze_resource_allocation(1e22)  # 分析1e22 FLOPs预算

代码详细解释

4.1 核心函数

def km_scaling_law_log(N, D, E=2.0, A=3.0, B=3.0, alpha=0.3, beta=0.3):
    """
    计算KM扩展法则预测的损失值 - 对数尺度版本
    确保输出在合理范围内
    
    参数:
    N: 模型参数量 (单位: 十亿)
    D: 训练数据量 (单位: 十亿token)
    """
    # 使用对数尺度确保数值稳定
    # 基础损失 + 模型项 + 数据项
    L = E + (A / (np.log(N + 1) ** alpha)) + (B / (np.log(D + 1) ** beta))
    return L
  • 这个函数实现了KM法则的核心公式
  • 默认参数基于OpenAI论文的近似值
  • E=2.0 代表理论上的最小损失
  • alpha=0.3 和 beta=0.3 是关键,决定了模型和数据的收益递减速度

4.2 单个预测示例

N_example = 1.0  # 10亿参数
D_example = 5.0  # 50亿token

loss = km_scaling_law_log(N_example, D_example)

这里我们预测一个10亿参数、用50亿token训练的模型的性能。

4.3 规模对比分析

model_sizes = [0.1, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]  # 从1亿到1000亿参数

通过这个循环,我们可以看到模型规模从1亿参数增长到100亿参数时,性能如何变化。

4.4 输出结果

=== 示例1: 单个模型性能预测 ===
模型规模: 1.0B 参数
训练数据: 5.0B token
KM法则预测损失: 7.8672
对应的困惑度: 2610.12

=== 示例2: 不同规模模型对比 ===
固定训练数据: 10.0B token
模型规模(B)     预测损失        困惑度
-------------------------------------------------------
     0.1        10.3803         32219.56
     0.5        8.2408          3792.44
     1.0        7.6563          2114.01
     5.0        6.8261          921.62
    10.0        6.6153          746.45
    50.0        6.2972          543.03
   100.0        6.2038          494.62

=== 示例3: 不同数据量对比 ===
固定模型规模: 1.0B 参数
数据量(B)       预测损失        困惑度
-------------------------------------------------------
     1.0        8.6974          5987.08
     5.0        7.8672          2610.12
    10.0        7.6563          2114.01
    50.0        7.3382          1537.90
   100.0        7.2448          1400.80
   500.0        7.0827          1191.19
  1000.0        7.0286          1128.50

=== 示例4: 生成可视化图表 ===

=== 示例5: 实际模型性能预测 ===
模型名称                参数(B) 数据(B) 预测损失        困惑度
----------------------------------------------------------------------
GPT-3            175     300    5.6117          273.60
LLaMA-2 7B         7    2000    6.0409          420.29
LLaMA-2 70B       70    2000    5.5744          263.58
PaLM             540     780    5.4262          227.27
Chinchilla        70    1400    5.5980          269.90

=== 示例6: 资源分配策略 ===

在总计算量 1.0e+22 FLOPs 下的策略分析:
策略                    模型规模(B)     数据量(B)       预测损失
-----------------------------------------------------------------
KM策略                  716628.6                   0.0          21.8804
Chinchilla策略            40.8            40.8          6.0413
数据优先策略               0.0          716628.6                21.8804

图例分析:

  • 左图:蓝色曲线显示,随着模型参数增加,损失稳步下降,但下降速度逐渐变缓,收益递减
  • 中图:红色曲线显示,数据量增加也能降低损失,但效果不如增大模型规模明显
  • 右图:3D曲面显示损失如何同时受模型规模和数据量的影响

四、Chinchilla扩展法则

1. 基础理解

核心思想: 对于给定的计算预算 C,模型参数量 N 和数据Token量 D 应该成比例地增长。模型不是越大越好,而是需要与足够多的数据配对。许多现有的大模型是训练不足的,减小模型规模并大幅增加数据量,可以在相同计算成本下获得更优的性能。

这个思想可以分解为三个关键点:

1.1 挑战规模至上的观点

  • 之前:KM法则认为模型规模是提升性能的最有效杠杆。
  • Chinchilla发现:KM法则基于的实验范围有限(最大模型17B),当模型规模扩大到千亿参数时,其推荐的数据量远远不够,导致模型训练不足。模型虽然参数很多,但因为没有看到足够的数据,无法充分学习,其潜力未被完全挖掘。

1.2 揭示训练不足问题

  • 比喻:这就像招募了一个智商200的天才(大模型),但只给他一本小学课本(少量数据)学习。他很快就能把课本倒背如流(训练损失很低),但并没有真正掌握广博的知识(泛化能力差)。而一个智商120的聪明人(中等模型),如果让他读完整个图书馆的藏书(海量数据),其解决实际问题的能力会远超那个天才。
  • Chinchilla证明:像GPT-3 (175B)、Gopher (280B)这样的模型,如果将其参数量减半,但将训练数据增加至原来的4倍,训练出的模型性能反而会更强。

1.3 确立平衡分配原则

  • Chinchilla的最终目标是为给定的计算预算 C,找到最优的 N 和 D 配比。
  • 其核心结论是:计算预算应该在模型规模和训练数据之间几乎均等地分配。具体来说,模型参数量 N 和训练数据token量 D 应与计算预算 C 的平方根成正比。

2. 数学公式

2.1 公式说明

与KM法则类似,Chinchilla将测试损失 L 建模为模型参数量 N 和训练数据量 D 的函数:

L(N, D) = E + A/(N^α) + B/(D^β)

其中:

  • L(N, D): 模型在未知数据上的损失(交叉熵损失)。L 越低,模型性能越好。
  • E: 不可约损失,代表了数据分布本身的内在噪音和不确定性,是任何模型都无法超越的理论下限。
  • A/(N^α): 模型容量项,代表了因模型参数有限而产生的近似误差。随着模型参数 N 增加,这项误差会减小。
  • B/(D^β): 数据容量项,代表了因训练数据有限而产生的估计误差。随着训练数据 D 增加,这项误差会减小。

关键的Chinchilla参数值:
DeepMind通过实验拟合出的参数约为:

  • α ≈ 0.38
  • β ≈ 0.38
  • A 和 B 是相应的缩放常数。

2.2 与KM法则的数学对比

特性 KM 法则 Chinchilla 法则 含义与影响
模型指数 α ~0.076 ~0.38 Chinchilla的α大了约5倍! 这意味着增加模型规模带来的性能收益衰减得快得多。模型规模的增长不再那么“划算”。
数据指数 β ~0.103 ~0.38 Chinchilla的β也大了约3.7倍! 这意味着增加数据量带来的性能收益同样衰减得很快,但其衰减速度现在与模型项持平。
指数关系 α < β α ≈ β 这是最根本的差异。 KM认为模型收益衰减更慢,故应优先扩大模型。Chinchilla发现两者衰减速度相同,故应平衡分配资源。

2.3 直观理解指数差异

α 和 β 决定了“收益递减”的速度。

  • KM的小指数:意味着需要把 N 或 D 扩大非常多,才能让 A/N^α 或 B/D^β 这一项显著减小。收益来得慢,但持续久。
  • Chinchilla的大指数:意味着 N 或 D 的初始增长能带来显著的性能提升,但很快就会碰到“收益墙”,再投入资源的回报率急剧下降。

2.4 了解 N_op 和 D_op

2.4.1 N_op 和 D_op 是什么

  • N_op: 最优模型参数量
    • 在给定的计算预算 C 下,能够使模型性能达到最好的那个模型规模。
    • 单位:通常是十亿参数
  • D_op: 最优训练数据量
    • 在给定的计算预算 C 下,能够使模型性能达到最好的那个训练数据量。
    • 单位:通常是十亿token

2.4.2 符号 ∝ 的含义

∝ 表示"正比于",所以:

  • N_op ∝ C^0.5 意思是:最优模型规模与计算预算的平方根成正比
  • D_op ∝ C^0.5 意思是:最优数据量与计算预算的平方根成正比

2.4.3 直观理解:切蛋糕的比喻

想象我们有一块固定大小的蛋糕(计算预算 C),要分给两个人:

  • 一个人叫"模型规模" (N)
  • 一个人叫"训练数据" (D)

Chinchilla法则告诉我们:应该把蛋糕平均分给这两个人!

  • N_op ∝ C^0.5 和 D_op ∝ C^0.5 就是这个平均分配规则的数学表达。

2.4.4 具体实例

场景1:小预算情况
假设计算预算 C = 1e21 FLOPs

  • N_op ∝ (1e21)^0.5 ≈ 3.16e10 参数 ≈ 31.6亿参数
  • D_op ∝ (1e21)^0.5 ≈ 3.16e10 token ≈ 31.6亿token

场景2:预算增加100倍
现在预算增加到 C = 1e23 FLOPs(增加了100倍)

  • N_op ∝ (1e23)^0.5 ≈ 3.16e11 参数 ≈ 316亿参数
  • D_op ∝ (1e23)^0.5 ≈ 3.16e11 token ≈ 316亿token

对比分析:

  • 计算预算变化:预算增加100倍    
  • 模型规模变化:模型增加10倍    
  • 数据量变化:数据增加10倍    
  • 平方根效应:预算线性增长,规模平方根增长

2.4 计算最优分配公式

基于上述性能预测公式,Chinchilla推导出了在固定计算预算 C(其中 C ≈ 6 N D)下,如何分配 N 和 D 才能使损失 L 最小化。

核心发现:最优配置是让模型容量项和数据容量项对损失的贡献大致相等。

其推导出的最优比例是:
N_op ∝ C^a
D_op ∝ C^b
其中 a = β/(α+β), b = α/(α+β)

代入Chinchilla的 α=β=0.38:

  • a = 0.38/(0.38+0.38) = 0.5
  • b = 0.38/(0.38+0.38) = 0.5

因此,最优策略为:
N_op ∝ C^0.5
D_op ∝ C^0.5

具体经验性结论:
对于一个计算预算 C,Chinchilla推荐:

  • 模型参数量 N 约等于 (C / 20)^0.5
  • 训练数据Token量 D 约等于 (C / 20)^0.5

注意:这里的常数20是考虑了模型前向和反向传播的FLOPs估算后的一个经验值,与 C ≈ 6ND 的本质思想一致。

Chinchilla法则的数学公式告诉我们:

  • 性能公式:L(N, D) = E + A/N^0.38 + B/D^0.38
    • 模型和数据的贡献是对称且平衡的。
  • 分配公式:N_op ∝ C^0.5, D_op ∝ C^0.5
    • 计算预算应该均等地分配给模型复杂度和数据量。

3. 示例分析

在固定计算预算下,模型参数量(N)和训练数据量(D)应该平衡增长,而不是像KM法则那样偏向模型规模。

核心公式:L(N, D) = E + A/N^α + B/D^β
其中 α ≈ β ≈ 0.38,这与KM法则的 α=0.076, β=0.103 形成鲜明对比。

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

def chinchilla_scaling_law(N, D, E=1.69, A=406.4, B=410.7, alpha=0.38, beta=0.38):
    """
    计算Chinchilla扩展法则预测的损失值
    
    参数:
    N: 模型参数量 (单位: 十亿)
    D: 训练数据量 (单位: 十亿token)
    E, A, B, alpha, beta: Chinchilla法则的经验参数
    
    返回:
    L: 预测的损失值
    """
    # Chinchilla核心公式 - 注意指数alpha和beta都接近0.38
    L = E + (A / (N ** alpha)) + (B / (D ** beta))
    return L

def km_scaling_law(N, D, E=1.5, A=500, B=1000, alpha=0.076, beta=0.103):
    """
    KM扩展法则用于对比
    """
    L = E + (A / (N ** alpha)) + (B / (D ** beta))
    return L

# 示例1: 单个模型预测对比
print("=== 示例1: Chinchilla vs KM 预测对比 ===")
N_example = 70  # 70亿参数
D_example = 1500  # 1.5万亿token

loss_chinchilla = chinchilla_scaling_law(N_example, D_example)
loss_km = km_scaling_law(N_example * 1000, D_example * 1000)  # 转换为百万单位

print(f"模型规模: {N_example}B 参数")
print(f"训练数据: {D_example}B token")
print(f"Chinchilla预测损失: {loss_chinchilla:.4f}")
print(f"Chinchilla预测困惑度: {np.exp(loss_chinchilla):.2f}")
print(f"KM法则预测损失: {loss_km:.4f}")
print(f"KM法则预测困惑度: {np.exp(loss_km):.2f}\n")

# 示例2: 计算最优配置对比
print("=== 示例2: 最优配置计算对比 ===")

def find_optimal_allocation(compute_budget, law_type='chinchilla'):
    """
    根据不同的扩展法则找到最优配置
    假设计算预算 C ≈ 6 * N * D
    """
    if law_type == 'chinchilla':
        # Chinchilla: 平衡分配
        alpha, beta = 0.38, 0.38
        N_optimal = (compute_budget / 6) ** 0.5  # N ∝ C^0.5
        D_optimal = (compute_budget / 6) ** 0.5  # D ∝ C^0.5
    else:  # KM法则
        alpha, beta = 0.076, 0.103
        optimal_N_ratio = alpha / (alpha + beta)
        optimal_D_ratio = beta / (alpha + beta)
        N_optimal = (compute_budget / 6) ** optimal_N_ratio  # N ∝ C^0.74
        D_optimal = (compute_budget / 6) ** optimal_D_ratio  # D ∝ C^0.26
    
    return N_optimal, D_optimal

# 测试不同计算预算下的最优配置
budgets = [1e21, 5e21, 1e22, 5e22]  # 不同的计算预算

print("计算预算(FLOPs)\t法则类型\t\t最优参数(B)\t最优数据(B)\t参/数比例")
print("-" * 85)

for budget in budgets:
    # Chinchilla最优配置
    N_chi, D_chi = find_optimal_allocation(budget, 'chinchilla')
    ratio_chi = N_chi / D_chi
    
    # KM最优配置  
    N_km, D_km = find_optimal_allocation(budget, 'km')
    ratio_km = N_km / D_km
    
    print(f"{budget:.1e}\tChinchilla\t{N_chi/1e9:8.1f}\t\t{D_chi/1e9:8.1f}\t\t{ratio_chi:.3f}")
    print(f"{budget:.1e}\tKM法则\t\t{N_km/1e9:8.1f}\t\t{D_km/1e9:8.1f}\t\t{ratio_km:.3f}")
    print("-" * 85)

# 示例3: 训练不足分析
print("\n=== 示例3: 训练不足分析 ===")

def analyze_under_training(model_size_B, compute_budget):
    """
    分析在固定计算预算下,不同数据量对性能的影响
    """
    print(f"\n分析 {model_size_B}B 参数模型在 {compute_budget:.1e} FLOPs 预算下的表现:")
    
    # Chinchilla推荐的数据量
    N_chi_opt, D_chi_opt = find_optimal_allocation(compute_budget, 'chinchilla')
    D_chi_for_model = compute_budget / (6 * model_size_B * 1e9)
    
    # KM推荐的数据量
    N_km_opt, D_km_opt = find_optimal_allocation(compute_budget, 'km') 
    D_km_for_model = compute_budget / (6 * model_size_B * 1e9)
    
    # 计算不同数据量下的损失
    data_ratios = [0.25, 0.5, 1.0, 2.0, 4.0]  # 相对于Chinchilla推荐的数据量比例
    
    print("数据比例\t实际数据(B)\tChinchilla损失\tKM损失\t\t训练状态")
    print("-" * 75)
    
    for ratio in data_ratios:
        actual_data = D_chi_for_model * ratio / 1e9  # 转换为十亿单位
        loss_chi = chinchilla_scaling_law(model_size_B, actual_data)
        loss_km = km_scaling_law(model_size_B * 1000, actual_data * 1000)
        
        status = "严重训练不足" if ratio < 0.5 else "训练不足" if ratio < 1.0 else "接近最优" if ratio <= 2.0 else "数据充足"
        
        print(f"{ratio:4.2f}\t\t{actual_data:8.1f}\t\t{loss_chi:.4f}\t\t{loss_km:.4f}\t\t{status}")

analyze_under_training(70, 1e22)  # 分析70B模型

# 示例4: 可视化对比
print("\n=== 示例4: 生成对比可视化图表 ===")

# 创建计算预算范围
compute_range = np.logspace(20, 24, 50)  # 10^20 到 10^24 FLOPs

# 计算两种法则的最优配置
N_chi_optimal = []
D_chi_optimal = []
N_km_optimal = [] 
D_km_optimal = []

for C in compute_range:
    N_chi, D_chi = find_optimal_allocation(C, 'chinchilla')
    N_km, D_km = find_optimal_allocation(C, 'km')
    
    N_chi_optimal.append(N_chi / 1e9)  # 转换为十亿单位
    D_chi_optimal.append(D_chi / 1e9)
    N_km_optimal.append(N_km / 1e9)
    D_km_optimal.append(D_km / 1e9)

# 创建可视化图表
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

# 子图1: 最优模型规模对比
ax1.loglog(compute_range, N_chi_optimal, 'r-', linewidth=3, label='Chinchilla最优')
ax1.loglog(compute_range, N_km_optimal, 'b--', linewidth=2, label='KM最优')
ax1.set_xlabel('计算预算 (FLOPs)')
ax1.set_ylabel('最优模型规模 (十亿参数)')
ax1.set_title('模型规模推荐对比\nChinchilla vs KM')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 标记具体预算点示例
sample_budget = 1e22
N_chi_sample = (sample_budget / 6) ** 0.5 / 1e9
N_km_sample = (sample_budget / 6) ** (0.076/(0.076+0.103)) / 1e9
ax1.annotate(f'在{sample_budget:.0e} FLOPs:\nChinchilla: {N_chi_sample:.0f}B\nKM: {N_km_sample:.0f}B', 
            xy=(sample_budget, N_chi_sample), xytext=(1e21, 500),
            arrowprops=dict(arrowstyle='->', color='red'),
            bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="red", alpha=0.8))

# 子图2: 最优数据量对比
ax2.loglog(compute_range, D_chi_optimal, 'r-', linewidth=3, label='Chinchilla最优')
ax2.loglog(compute_range, D_km_optimal, 'b--', linewidth=2, label='KM最优')
ax2.set_xlabel('计算预算 (FLOPs)')
ax2.set_ylabel('最优训练数据量 (十亿token)')
ax2.set_title('训练数据量推荐对比\nChinchilla vs KM')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 子图3: 参数-数据比例对比
ratio_chi = np.array(N_chi_optimal) / np.array(D_chi_optimal)
ratio_km = np.array(N_km_optimal) / np.array(D_km_optimal)

ax3.semilogx(compute_range, ratio_chi, 'r-', linewidth=3, label='Chinchilla比例')
ax3.semilogx(compute_range, ratio_km, 'b--', linewidth=2, label='KM比例')
ax3.set_xlabel('计算预算 (FLOPs)')
ax3.set_ylabel('参数/数据比例 (N/D)')
ax3.set_title('资源分配策略对比\n比例越高 = 越偏向模型规模')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 子图4: 性能对比 - 固定计算预算下的损失
fixed_budget = 1e22
model_sizes = [7, 20, 70, 200]  # 不同的模型规模 (十亿参数)

chinchilla_losses = []
km_losses = []

for size in model_sizes:
    # 在固定预算下,计算对应的数据量
    data_chi = fixed_budget / (6 * size * 1e9) / 1e9  # 十亿token单位
    data_km = fixed_budget / (6 * size * 1e9) / 1e9   # 相同计算预算
    
    loss_chi = chinchilla_scaling_law(size, data_chi)
    loss_km = km_scaling_law(size * 1000, data_km * 1000)
    
    chinchilla_losses.append(loss_chi)
    km_losses.append(loss_km)

ax4.plot(model_sizes, chinchilla_losses, 'ro-', linewidth=2, label='Chinchilla预测')
ax4.plot(model_sizes, km_losses, 'bs--', linewidth=2, label='KM预测')
ax4.set_xlabel('模型规模 (十亿参数)')
ax4.set_ylabel('预测损失')
ax4.set_title(f'固定预算 {fixed_budget:.0e} FLOPs 下\n不同模型规模的性能对比')
ax4.legend()
ax4.grid(True, alpha=0.3)

# 标记最优配置
optimal_size_chi = (fixed_budget / 6) ** 0.5 / 1e9
optimal_data_chi = (fixed_budget / 6) ** 0.5 / 1e9
optimal_loss_chi = chinchilla_scaling_law(optimal_size_chi, optimal_data_chi)

ax4.axvline(x=optimal_size_chi, color='red', linestyle=':', alpha=0.7)
ax4.annotate(f'Chinchilla最优\n{optimal_size_chi:.0f}B模型', 
            xy=(optimal_size_chi, optimal_loss_chi), 
            xytext=(optimal_size_chi+30, optimal_loss_chi+0.1),
            arrowprops=dict(arrowstyle='->', color='red'))

plt.tight_layout()
plt.show()

输出结果:

=== 示例1: Chinchilla vs KM 预测对比 ===
模型规模: 70B 参数
训练数据: 1500B token
Chinchilla预测损失: 108.0695
Chinchilla预测困惑度: 85899031069167667854303274236400488860482535424.00
KM法则预测损失: 446.7954
KM法则预测困惑度: 109845366723675280192034736636001868827496702856567587991204197607574163216094605052146384532230226621195552621200325104919263199725470312853558268331235709933397008939880728797698102035965018112.00

Chinchilla: 108, KM: 447,预测的损失值和现实偏差很大,对参数(A, B, E, α, β)需要重新校准。

=== 示例2: 最优配置计算对比 ===
计算预算(FLOPs) 法则类型                最优参数(B)     最优数据(B) 参/数比例
-------------------------------------------------------------------------------------
1.0e+21 Chinchilla          12.9                    12.9            1.000
1.0e+21 KM法则               0.4                   432.5            0.001
-------------------------------------------------------------------------------------
5.0e+21 Chinchilla          28.9                    28.9            1.000
5.0e+21 KM法则               0.8                  1092.0            0.001
-------------------------------------------------------------------------------------
1.0e+22 Chinchilla          40.8                    40.8            1.000
1.0e+22 KM法则               1.0                  1627.3            0.001
-------------------------------------------------------------------------------------
5.0e+22 Chinchilla          91.3                    91.3            1.000
5.0e+22 KM法则               2.0                  4108.2            0.000
-------------------------------------------------------------------------------------

Chinchilla法则(平衡策略):

  • 模型与数据1:1平衡:参数/数据比例始终为 1.000
  • 同步增长:在1e21 FLOPs时,推荐约13B模型配13B数据;在5e22 FLOPs时,推荐约91B模型配91B数据
  • 实践意义:这是一种“中等模型 + 中等数据”的平衡发展路径

KM法则(极端偏向策略):

  • 极度偏向数据:参数/数据比例仅为 0.001
  • 模型极小,数据极大:在1e22 FLOPs时,推荐1B模型配1627B数据(差了1600倍!)
  • 实践意义:这是一种“极小模型 + 海量数据”的极端策略

=== 示例3: 训练不足分析 ===

分析 70B 参数模型在 1.0e+22 FLOPs 预算下的表现:
数据比例        实际数据(B)     Chinchilla损失  KM损失          训练状态
-------------------------------------------------------------------------------------------------------
0.25                 6.0                291.0828             624.1760        严重训练不足
0.50                11.9                242.7980            596.0273        训练不足
1.00                23.8                205.6941            569.8181        接近最优
2.00                47.6                177.1822            545.4149        接近最优
4.00                95.2                155.2725            522.6933        数据充足

  • 重复展示了固定模型规模时,数据量的关键作用。
  • 数据量翻倍的效果:从6B数据→95B数据(16倍增长),Chinchilla损失从291→155(几乎减半)
  • KM法则的严重误判:KM法则预测的损失始终在500+的高位,完全无法反映数据增加带来的收益

五、两者的可视化对比

import numpy as np
import matplotlib.pyplot as plt

# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 1. 定义计算预算 C (以FLOPs为单位,使用对数等间距点)
# np.linspace(20, 24, 100) 生成一个从20到24的数组,包含100个等间距的点。
# 这个数组代表计算预算的对数值,范围从10^20到10^24 FLOPs,覆盖了从中等到大规模的训练预算。
log_C = np.linspace(20, 24, 100)
# 将对数坐标转换回线性坐标,得到具体的计算预算值C。
C = 10 ** log_C

# 2. 根据两种法则估算模型参数量 (N) 和训练数据量 (D)
# 注意:以下是非常简化的经验近似,用于演示两种法则在趋势上的根本差异。

# KM扩展法则风格 (倾向于更大的模型规模):
# 假设模型参数量 N 与计算预算 C 的 0.7 次方成正比。
# 假设训练数据量 D 与计算预算 C 的 0.3 次方成正比。
# 这里的比例常数 (1e8, 5e9) 是为了让曲线在图表中处于一个合适的视觉位置而任意设定的。
N_km = 1e8 * (C / 1e20) ** 0.7  # 基础参数1亿,按比例缩放
D_km = 5e9 * (C / 1e20) ** 0.3  # 基础数据50亿Token,按比例缩放

# Chinchilla扩展法则风格 (模型与数据平衡增长):
# 假设模型参数量 N 和训练数据量 D 均与计算预算 C 的 0.5 次方成正比。
# 这体现了其核心思想:对于固定的计算预算,应在N和D之间进行平衡分配。
N_chi = 5e8 * (C / 1e20) ** 0.5 # 基础参数5亿,按比例缩放
D_chi = 2e10 * (C / 1e20) ** 0.5 # 基础数据200亿Token,按比例缩放

# 3. 创建图表进行可视化
# plt.subplots(1, 2) 创建一個包含1行2列子图的图形窗口。
# figsize=(14, 5) 设置整个图形窗口的尺寸为宽14英寸、高5英寸。
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# 图表1:模型参数量 (N) 对比
# 在第一个子图(ax1)上,用蓝色实线绘制KM法则的N,用红色虚线绘制Chinchilla法则的N。
ax1.loglog(C, N_km, 'b-', linewidth=2, label='KM法则 (模型规模优先)')
ax1.loglog(C, N_chi, 'r--', linewidth=2, label='Chinchilla法则 (平衡策略)')
# 设置坐标轴标签、标题和图例。
ax1.set_xlabel('计算预算 C (FLOPs)')
ax1.set_ylabel('模型参数量 (N)')
ax1.set_title('模型规模预测对比')
ax1.legend() # 显示图例
ax1.grid(True, which="both", ls="-", alpha=0.2) # 添加网格线,便于读数

# 图表2:训练数据量 (D) 对比
# 在第二个子图(ax2)上,用同样的线型和颜色绘制两种法则的D。
ax2.loglog(C, D_km, 'b-', linewidth=2, label='KM法则')
ax2.loglog(C, D_chi, 'r--', linewidth=2, label='Chinchilla法则')
ax2.set_xlabel('计算预算 C (FLOPs)')
ax2.set_ylabel('训练数据Token量 (D)')
ax2.set_title('训练数据量预测对比')
ax2.legend()
ax2.grid(True, which="both", ls="-", alpha=0.2)

# 自动调整子图参数,使之填充整个图像区域,避免重叠。
plt.tight_layout()
# 显示图形
plt.show()

输出结果:

图例分析:

左图:模型规模预测对比

  • X轴是“计算预算 C (FLOPs)”,采用对数刻度。
  • Y轴是“模型参数量 (N)”,采用对数刻度。
  • 图中包含两条线:
    • 一条蓝色实线(KM法则):这条线非常陡峭,意味着随着计算预算的增加,KM法则建议的模型参数量会急剧增长。
    • 一条红色虚线(Chinchilla法则):这条线相比蓝线平缓得多,意味着在相同的计算预算下,Chinchilla法则推荐的模型规模远小于KM法则的推荐。

右图:训练数据量预测对比

  • X轴同样是“计算预算 C (FLOPs)”。
  • Y轴是“训练数据Token量 (D)”。
  • 图中同样包含两条线:
    • 一条蓝色实线(KM法则):这条线非常平缓,意味着KM法则认为数据量只需要随着算力缓慢增加。
    • 一条红色虚线(Chinchilla法则):这条线非常陡峭,意味着Chinchilla法则建议的数据量需要随着算力迅猛增长,其推荐量远超KM法则。

图示结论:

  • KM法则(蓝色) 是 “大模型,适量数据” 的策略。它把大部分新增的算力都投入到了扩大模型参数上。
  • Chinchilla法则(红色) 是 “适中模型,海量数据” 的策略。它认为算力应该在模型和数据之间取得平衡,甚至更倾向于为模型配备远超以往认知的数据量。

六、总结

        大模型扩展法则揭示了计算预算的最优分配原理,KM法则主张“规模至上”,认为应优先扩大模型参数,数据适量即可。而Chinchilla法则通过实验证明,许多大模型实际处于训练不足状态,提出模型与数据应平衡增长的效率优先原则。

        Chinchilla法则完成了关键范式转移,通过系统实验证明:平衡分配计算预算至模型参数量与训练数据量,才能在固定成本下实现性能最优。其核心在于将资源分配从KM的7:3倾斜调整为1:1平衡。这一转变具有深远影响:数据价值被重新评估,模型开发从盲目追求参数量转向寻求最优配比。实践中,Chinchilla法则催生了LLaMA等"小模型、大数据"的高效架构,显著降低了AI应用门槛。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐