Stanford CS336 Assignment 1: BPE Tokenizer Training on the TinyStories Dataset
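3. Training on the TinyStories dataset
3.1. Reading and analyzing the requirements
The problem has two parts:
- a. The vocab size is at most 10,000, and the special token "<|endoftext|>" must be added to the vocabulary. Resource requirements: training time ≤ 30 minutes (no GPUs), memory ≤ 30GB RAM. Tip: to finish training in under 2 minutes, consider using multiprocessing for pretokenization.
- b. "What part of the tokenizer training process takes the most time?"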
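Following the assignment, I will work in three steps:
- Write the training script: load the data, train, save the model, and record time and memory.
- Run and profile it: answer where the bottleneck is.
- Check the results: find the longest token.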
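3.2 Step-by-step implementation
- Inspect the training data
Before training, use the head command to look at the training data and confirm that it is basically the same kind of data as in the earlier tests.
- Key parts of the code
- Locate the project root
To read and write files reliably, a project usually starts by resolving its root path. The line below adds the parent directory of the script's folder to sys.path so that the project's modules can be imported.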
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
- Get the memory used by the running process
import os
import psutil

def get_memory_usage_mb():
    """Get current process memory usage (resident set size) in MB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024
- Persist the trained vocabulary and merge rules to disk
The vocabulary is a dict, so it is saved as JSON. For human readability, the bytes are decoded to strings in vocab.json; tokens that cannot be decoded keep their repr form.
import json
from pathlib import Path

def save_vocab_and_merges(vocab, merges, output_dir="results"):
    """Save vocabulary and merges to disk"""
    Path(output_dir).mkdir(exist_ok=True, parents=True)
    # Save vocab.json (bytes decoded as UTF-8 where possible)
    vocab_str = {}
    for idx, token_bytes in vocab.items():
        try:
            vocab_str[idx] = token_bytes.decode('utf-8')
        except UnicodeDecodeError:
            # Not valid UTF-8 on its own: keep the repr, e.g. b'\xc3'
            vocab_str[idx] = str(token_bytes)
    with open(f"{output_dir}/vocab.json", "w", encoding="utf-8") as f:
        json.dump(vocab_str, f, ensure_ascii=False, indent=2)
    # Save merges.txt (errors='ignore' drops undecodable bytes, so this file is lossy)
    with open(f"{output_dir}/merges.txt", "w", encoding="utf-8") as f:
        for p1, p2 in merges:
            f.write(f"{p1.decode('utf-8', errors='ignore')} {p2.decode('utf-8', errors='ignore')}\n")
    print(f"Artifacts saved to {output_dir}/")
- Define the training run and print statistics
import time

def run_training(input_path, vocab_size, special_tokens, output_dir):
    # Record the initial memory usage before training
    print(f"Initial Memory: {get_memory_usage_mb():.2f} MB")
    # Initialize the BPE trainer
    trainer = BPETrainer_Optimized()
    # Start training and record the time and memory usage
    start_time = time.time()
    print(f"Starting training on {input_path}...")
    vocab, merges = trainer.train(
        input_path=input_path,
        vocab_size=vocab_size,
        special_tokens=special_tokens
    )
    end_time = time.time()
    duration = end_time - start_time
    # RSS once training has finished (not the peak reached during training)
    final_memory = get_memory_usage_mb()
    print("-" * 30)
    print("Training Complete.")
    print(f"Time Taken: {duration:.2f} seconds ({duration/60:.2f} minutes)")
    print(f"Final Memory: {final_memory:.2f} MB")
    print("-" * 30)
    save_vocab_and_merges(vocab, merges, output_dir)
    # Output statistics
    print("\n=== Statistics (Problem b) ===")
    # 1. Longest token
    longest_token_bytes = max(vocab.values(), key=len)
    try:
        longest_token_str = longest_token_bytes.decode('utf-8')
    except UnicodeDecodeError:
        longest_token_str = str(longest_token_bytes)
    print(f"Longest Token: {longest_token_str!r}")
    print(f"Length in bytes: {len(longest_token_bytes)}")
    # 2. Most frequent token: BPE training does not keep frequency counts for the
    # final vocabulary unless the corpus is re-tokenized, so only the number of
    # merges is reported here.
    print(f"Total Merges: {len(merges)}")
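The merges.txt written by the save function above is convenient to read but cannot be loaded back exactly: errors='ignore' drops bytes that are not valid UTF-8, and tokens that contain spaces make the space-separated lines ambiguous. If an exact round trip is needed, one option is to store every token as hex. This is only a sketch; save_merges_hex and load_merges_hex are hypothetical names, not part of the assignment or of the script in this post.

```python
from pathlib import Path


def save_merges_hex(merges, path):
    """Write one merge per line as two hex strings, which round-trips any bytes."""
    lines = [f"{left.hex()} {right.hex()}" for left, right in merges]
    Path(path).write_text("\n".join(lines) + "\n", encoding="ascii")


def load_merges_hex(path):
    """Read a file written by save_merges_hex back into a list of bytes pairs."""
    merges = []
    for line in Path(path).read_text(encoding="ascii").splitlines():
        if not line:
            continue  # skip blank lines (e.g. an empty merge list)
        left, right = line.split(" ")
        merges.append((bytes.fromhex(left), bytes.fromhex(right)))
    return merges
```

The human-readable merges.txt can still be written alongside it for inspection.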
3.3 Running the training script
- Run the following command
uv run python scripts/train_bpe_tinystories.py \
--input_path data/TinyStoriesV2-GPT4-train.txt \
--vocab_size 10000 \
--profile
The console shows the following output:
Enabling cProfile...
Initial Memory: 23.41 MB
Starting training on data/TinyStoriesV2-GPT4-train.txt...
Starting BPE training (Optimized)...
Merge 0/9743: (b' ', b't')
Merge 100/9743: (b'ri', b'end')
...
Merge 9600/9743: (b' pain', b'ful')
Merge 9700/9743: (b'St', b'ill')
------------------------------
Training Complete.
Time Taken: 731.09 seconds (12.18 minutes)
Final Memory: 96.74 MB
------------------------------
Artifacts saved to tinystories_tokenizer/
=== Statistics (Problem b) ===
Longest Token: ' accomplishment'
Length in bytes: 15
Total Merges: 9743
Profiling data saved to training.prof
Use 'snakeviz training.prof' to visualize.
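Note that "Final Memory" above is the resident set size read once after training has finished, so it can be lower than the peak reached during training. To compare against the 30GB limit, the peak is the more relevant number. A minimal sketch using the standard-library resource module (Unix only); get_peak_memory_mb is a hypothetical helper, not part of the script above.

```python
import resource
import sys


def get_peak_memory_mb():
    """Peak resident set size of this process so far, in MB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux and in bytes on macOS
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


if __name__ == "__main__":
    buffer = bytearray(50 * 1024 * 1024)  # allocate 50 MB to move the peak
    print(f"Peak Memory: {get_peak_memory_mb():.2f} MB")
```

This covers only the current process. Once pretokenization moves to worker processes, their memory would have to be measured separately, for example with resource.RUSAGE_CHILDREN after the workers have exited.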
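3.4 Verification
- IDs [0, 255] are the 256 single-byte values.
- Run grep -ne "<|endoftext|>" tinystories_tokenizer/vocab.json to confirm that "<|endoftext|>" was added to the vocabulary. The output 258: "256": "<|endoftext|>", shows that it is in the vocabulary with idx 256, as expected. The file is small, so it can also simply be opened and read.
- IDs [257:] are the vocabulary entries learned during training.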
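The same checks can be scripted instead of done with grep. The sketch below assumes the vocab.json layout produced by the save function in this post (string IDs mapped to decoded token strings); check_vocab and load_and_check are hypothetical helper names.

```python
import json


def check_vocab(vocab, special_token="<|endoftext|>", max_size=10_000):
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    if len(vocab) > max_size:
        problems.append(f"vocab has {len(vocab)} entries, more than {max_size}")
    ids = [idx for idx, token in vocab.items() if token == special_token]
    if not ids:
        problems.append(f"{special_token!r} is missing from the vocabulary")
    elif len(ids) > 1:
        problems.append(f"{special_token!r} appears under several ids: {ids}")
    return problems


def load_and_check(path):
    """Load a vocab.json file and run check_vocab on it."""
    with open(path, encoding="utf-8") as f:
        return check_vocab(json.load(f))
```

For the run above, load_and_check("tinystories_tokenizer/vocab.json") should return an empty list if the grep result holds.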
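Viewing the profiling statistics
snakeviz is not listed in the toml file, so it has to be installed separately.
- Install snakeviz (if it is not installed yet)
uv pip install snakeviz
- Start the visualization server
uv run snakeviz training.prof
Then open the page in a browser.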


If that is inconvenient locally, Python's built-in pstats can print a text report:
# Show the top 10 functions by cumulative time
uv run python -c "import pstats; p = pstats.Stats('training.prof'); p.sort_stats('cumulative').print_stats(10)"
The 10 most expensive entries:
1368017311 function calls (1368017210 primitive calls) in 731.144 seconds
Ordered by: cumulative time
List reduced from 251 to 10 due to restriction <10>
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.029 0.029 731.144 731.144 /home/fdq/cources/cs336/assignment1-basics/scripts/train_bpe_tinystories.py:51(run_training)
1 0.372 0.372 731.066 731.066 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:18(train)
1 369.361 369.361 460.691 460.691 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:205(_pretokenize)
9746/9745 83.735 0.009 265.975 0.027 {built-in method builtins.max}
369218707 118.986 0.000 182.240 0.000 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:52(<lambda>)
372023448 63.614 0.000 63.614 0.000 {method 'get' of 'dict' objects}
536592168 44.123 0.000 44.123 0.000 {method 'group' of '_regex.Match' objects}
2761194 2.553 0.000 39.444 0.000 /home/fdq/cources/cs336/assignment1-basics/.venv/lib/python3.11/site-packages/regex/regex.py:340(finditer)
2804690 8.909 0.000 36.517 0.000 /home/fdq/cources/cs336/assignment1-basics/.venv/lib/python3.11/site-packages/regex/regex.py:449(_compile)
2761196 3.285 0.000 12.991 0.000 /home/fdq/.local/share/uv/python/cpython-3.11.12-linux-x86_64-gnu/lib/python3.11/locale.py:679(getpreferredencoding)
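The code behind the --profile flag is not shown in this post. As an assumption about how such a flag could be wired up, the sketch below wraps any call in cProfile and returns the same kind of cumulative-time report as the one above; profile_call is a hypothetical helper.

```python
import cProfile
import io
import pstats


def profile_call(func, *args, top=10, **kwargs):
    """Run func under cProfile and return (result, text report sorted by cumulative time)."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = func(*args, **kwargs)
    finally:
        profiler.disable()
    # profiler.dump_stats("training.prof") would write a file that snakeviz can open
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("cumulative").print_stats(top)
    return result, buffer.getvalue()


if __name__ == "__main__":
    total, report = profile_call(sum, range(1_000_000))
    print(report)
```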
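The profiling results show clearly where the program spends its time. This matters both for answering question (b) of the assignment and for the optimizations that follow.
1. A closer look at the profiling results
From the pstats text report and the snakeviz visualization, the conclusions are clear:
Bottleneck 1: pre-tokenization (the most expensive step)
- Evidence: the _pretokenize function accounts for 460.69 seconds (about 63% of the total time).
- Cause 1: it runs single-threaded. The regex engine is fast, but it has to work through a few GB of text, and one CPU core scanning, matching, and counting piece by piece is bound to be slow. This is why the assignment hint mentions "using multiprocessing during pretokenization".
- Cause 2: regex matching overhead. finditer and _compile are each called about 2.8 million times and group about 537 million times, which adds up to a lot of time. I thought I was compiling the pattern outside the loop; the update at the end of this post comes back to this.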
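Both causes point in the same direction: compile each pattern once, and spread the counting over several processes. The sketch below shows the general shape under simplifying assumptions. It uses the standard-library re module with a simplified stand-in pattern (the GPT-2 pattern used in the assignment needs the third-party regex package), it takes chunks that are already in memory, and count_pretokens and pretokenize_parallel are hypothetical names rather than the functions in train_bpe_optimize.py.

```python
import re
from collections import Counter
from concurrent.futures import ProcessPoolExecutor

SPECIAL_TOKEN = "<|endoftext|>"
# Compiled once at import time, so no per-call pattern lookup happens in the loops.
# Simplified stand-in for the assignment's pre-tokenization pattern.
PRETOKEN_RE = re.compile(r" ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+")
SPLIT_RE = re.compile(re.escape(SPECIAL_TOKEN))


def count_pretokens(chunk):
    """Count pre-tokens in one chunk, never matching across a special token."""
    counts = Counter()
    for document in SPLIT_RE.split(chunk):
        for match in PRETOKEN_RE.finditer(document):
            counts[match.group().encode("utf-8")] += 1
    return counts


def pretokenize_parallel(chunks, workers=4):
    """Count each chunk in a separate process and merge the partial counts."""
    total = Counter()
    with ProcessPoolExecutor(max_workers=workers) as pool:
        for partial in pool.map(count_pretokens, chunks):
            total.update(partial)
    return total


if __name__ == "__main__":
    demo = ["the cat sat<|endoftext|>the dog", " the cat ran<|endoftext|>"]
    print(pretokenize_parallel(demo, workers=2).most_common(3))
```

Each chunk must start and end on a document boundary, otherwise a pre-token could be split between two workers.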
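Bottleneck 2: finding the best pair (the max operation)
- Evidence: builtins.max takes 265.97 seconds.
- Details:
  - max is called about 10,000 times (once per merge).
  - The key point is that the lambda function is called 369 million times: see 369218707 ... <lambda> and 372023448 ... {method 'get' of 'dict' objects}.
The culprit
It comes down to this line:
max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, (b'', b''))))
When two pairs have the same frequency, the tie is broken by the lexicographic order of the pair, so pair_strings.get(x) is called once for every pair in the dictionary. Over 369 million calls this adds up to a large cumulative cost: about 63.6 seconds inside get and about 119.0 seconds in the lambda's own logic.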
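One way to cut the per-pair Python work while keeping exactly the same tie-breaking rule is to split the search into two passes: find the highest count first, then compare bytes only among the pairs that share it. The sketch below has not been benchmarked against the code in this post, and best_pair is a hypothetical name.

```python
def best_pair(pair_counts, pair_strings):
    """Return the most frequent pair, breaking ties by the pair's bytes."""
    # Pass 1: max over plain ints, with no key function called per element
    top_count = max(pair_counts.values())
    # Pass 2: collect only the pairs tied at the top count
    tied = [pair for pair, count in pair_counts.items() if count == top_count]
    # The bytes lookup now runs once per tied pair instead of once per pair
    return max(tied, key=lambda pair: pair_strings.get(pair, (b"", b"")))


if __name__ == "__main__":
    pair_strings = {(256, 97): (b"th", b"a"), (122, 97): (b"z", b"a"), (97, 98): (b"a", b"b")}
    pair_counts = {(256, 97): 5, (122, 97): 5, (97, 98): 3}
    print(best_pair(pair_counts, pair_strings))
```

This still scans the whole dictionary on every merge, so it only reduces the constant factor; avoiding the scan altogether would need a different data structure.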
2. Answering assignment question (b)
(b) Profile your code. What part of the tokenizer training process takes the most time?
Answer: "The profiling results indicate that the pre-tokenization step is the most time-consuming part, accounting for approximately 63% of the total runtime (about 460 s), due to single-threaded regex processing of the large corpus. The second-largest bottleneck is the max operation in the merge loop (about 266 s), specifically the overhead of the dictionary lookups (pair_strings.get) inside the tie-breaking lambda function."
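3. (Optional but recommended) Going for the "2 minutes" target: code optimization
The current 12 minutes is already well under the 30 minutes the assignment allows, but to get a taste of aggressive optimization (and to meet the 2-minute target from the hint), two things need to be done.
Step 1: optimize the tie-breaking in max (a quick fix)
This only requires changing one line of code, so I did it first, and it backfired.
Doing a get lookup like this inside max is too slow. Instead we can rely on Python's tuple comparison and compare the pair itself (an int tuple such as (12, 34)), which is much faster than looking up the bytes in a dictionary.
Change this line in train_bpe_optimize.py:
# Original code (slow)
# merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, (b'', b''))))
# Modified code (fast)
# Explanation: when the frequencies are equal, Python compares the key itself, (id_1, id_2).
# The comparison happens at the C level, so it is very fast, and the result is deterministic.
# Note: this breaks ties by token ID, not by bytes; the update at the end explains why that fails the tests.
merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], x))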
Profiling results:
List reduced from 251 to 10 due to restriction <10>
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.029 0.029 713.010 713.010 /home/fdq/cources/cs336/assignment1-basics/scripts/train_bpe_tinystories.py:51(run_training)
1 0.355 0.355 712.921 712.921 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:18(train)
1 364.844 364.844 452.898 452.898 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:205(_pretokenize)
9746/9745 83.417 0.009 255.732 0.026 {built-in method builtins.max}
369218707 114.033 0.000 172.315 0.000 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:52(<lambda>)
372023448 58.602 0.000 58.602 0.000 {method 'get' of 'dict' objects}
536592168 44.608 0.000 44.608 0.000 {method 'group' of '_regex.Match' objects}
2761194 2.350 0.000 35.984 0.000 /home/fdq/cources/cs336/assignment1-basics/.venv/lib/python3.11/site-packages/regex/regex.py:340(finditer)
2804690 8.114 0.000 33.361 0.000 /home/fdq/cources/cs336/assignment1-basics/.venv/lib/python3.11/site-packages/regex/regex.py:449(_compile)
2761196 3.020 0.000 11.858 0.000 /home/fdq/.local/share/uv/python/cpython-3.11.12-linux-x86_64-gnu/lib/python3.11/locale.py:679(getpreferredencoding)
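However, the change brought no improvement, and the profile shows that dict.get is still being called. Did the change not take effect? I decided to leave this alone for now and put my energy into multiprocessing for pretokenize, which is what the assignment suggests.
Next, I will focus on implementing pretokenize with multiprocessing.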
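Before the file can be handed to several processes, it has to be cut into chunks that never split a document. A minimal sketch of one way to do that: take evenly spaced byte offsets and move each one forward to the next "<|endoftext|>". chunk_boundaries is a hypothetical helper written for this post's setting, not the helper shipped with the assignment.

```python
import os


def chunk_boundaries(path, n_chunks, delimiter=b"<|endoftext|>", probe=4096):
    """Return sorted byte offsets that split the file at delimiter starts."""
    size = os.path.getsize(path)
    bounds = {0, size}
    keep = len(delimiter) - 1  # bytes carried over so a delimiter spanning two reads is found
    with open(path, "rb") as f:
        for guess in (size * i // n_chunks for i in range(1, n_chunks)):
            f.seek(guess)
            offset, tail = guess, b""
            while True:
                block = f.read(probe)
                if not block:  # reached end of file without finding a delimiter
                    break
                window = tail + block
                hit = window.find(delimiter)
                if hit != -1:
                    # window starts at file position offset - len(tail)
                    bounds.add(offset - len(tail) + hit)
                    break
                tail = window[-keep:] if keep else b""
                offset += len(block)
    return sorted(bounds)
```

Each worker can then open the file itself, seek to its start offset, read up to the next offset, and decode the bytes as UTF-8. Because every interior boundary sits at the start of an ASCII delimiter, no multi-byte character is cut in half.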
=====================================================================
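Update:
The change that "did not take effect" earlier was an editor mishap: most likely I copied the merge function, deleted it again, and never saved, so the old code was still in the file even though I could not see it. Even so, this line really cannot be used:
merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], x))
The assignment PDF states explicitly that ties go to the "lexicographically greater" pair, comparing the bytes, and gives this example: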
>>> max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")])
('BA', 'A')
So running the official tests produces the following errors:
uv run pytest tests/test_train_bpe.py -v 2>&1 | tail -20
> assert actual[key] == expected_data[key], (
f"Data for key '{key}' does not match snapshot for {test_name}"
)
E AssertionError: Data for key 'merges' does not match snapshot for test_train_bpe_special_tokens
E assert [(b'h', b'e')...', b'd'), ...] == [(b'h', b'e')...', b'd'), ...]
E
E At index 34 diff: (b' s', b'a') != (b'i', b'm')
E
E Full diff:
E [
E (
E b'h',...
E
E ...Full output truncated (3193 lines hidden), use '-vv' to show
tests/conftest.py:146: AssertionError
=========================== short test summary info ============================
FAILED tests/test_train_bpe.py::test_train_bpe - AssertionError: assert [(b' ...
FAILED tests/test_train_bpe.py::test_train_bpe_special_tokens - AssertionErro...
So the original line has to stay:
merge_pair = max(pair_counts, key=lambda x: (pair_counts[x], pair_strings.get(x, (b'', b''))))
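To see concretely why token IDs cannot stand in for the bytes, consider a merged token: it gets an ID above 255, yet its bytes can sort below those of a single-byte token. The vocabulary below is a made-up three-entry example, not taken from the trained tokenizer.

```python
# Hypothetical mini-vocabulary: 256 is a merged token whose bytes sort before b"z"
vocab = {97: b"a", 122: b"z", 256: b"th"}
pair_counts = {(256, 97): 5, (122, 97): 5}  # two pairs tied at the same count
pair_strings = {pair: (vocab[pair[0]], vocab[pair[1]]) for pair in pair_counts}

# Tie broken by token ID: 256 > 122, so (b"th", b"a") wins
by_id = max(pair_counts, key=lambda p: (pair_counts[p], p))
# Tie broken by bytes: b"z" > b"th", so (b"z", b"a") wins
by_bytes = max(pair_counts, key=lambda p: (pair_counts[p], pair_strings[p]))

print(pair_strings[by_id], pair_strings[by_bytes])
```

The two rules pick different merges as soon as a tie involves a merged token, which is consistent with the test diverging at a later merge (index 34) rather than at the very first one.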
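Also, as mentioned earlier, the pattern for the special-token string was being compiled millions of times. I moved that compilation out of the loop, and the profile taken after the change no longer shows the problem. Here are the top 10 entries; at the very least, it is no longer among the ten most expensive: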
ncalls tottime percall cumtime percall filename:lineno(function)
1 0.033 0.033 704.458 704.458 /home/fdq/cources/cs336/assignment1-basics/scripts/train_bpe_tinystories.py:51(run_training)
1 0.392 0.392 704.365 704.365 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:18(train)
1 376.948 376.948 428.006 428.006 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:159(_pretokenize)
9746/9745 83.887 0.009 271.807 0.028 {built-in method builtins.max}
369218707 120.925 0.000 187.920 0.000 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:52(<lambda>)
369218760 66.994 0.000 66.994 0.000 {method 'get' of 'dict' objects}
536592168 43.463 0.000 43.463 0.000 {method 'group' of '_regex.Match' objects}
43496 0.958 0.000 4.115 0.000 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:188(_chunk_documents_streaming)
277780 3.221 0.000 3.836 0.000 /home/fdq/cources/cs336/assignment1-basics/cs336_basics/train_bpe_optimize.py:104(_update_word_for_merge)
43495 2.032 0.000 3.033 0.000 {method 'read' of '_io.TextIOWrapper' objects}