之前写了一篇文章,由于数据集规模很大,在我电脑上运行训练Tokenizer会报内存不够。预训练和SFT训练,需要60~80个小时。因此,做了一个精简的版本,只保留10万条数据,1个小时以内可以走一遍大模型Tokenizer训练、预训练、SFT训练过程。

详细说明如下:

以下内容基于开源项目https://github.com/datawhalechina/happy-llm,并针对个人电脑进行了调整。大家可以通过happy-llm这个项目,学习大模型相关的原理,并一步一步,在自己的电脑上面预训练出一个非常迷你的“大模型”。
我的运行环境如下:
1、硬件:CPU i5-13490f,GPU 5090 32G,内存 64G DDR4,硬盘 1T M.2 SSD
2、操作系统:Ubuntu 24.04,NVIDIA Driver Version: 575.64.03      CUDA Version: 12.9
3、Python环境:conda 25.5.1,Python 3.13.5,pip依赖见最后附录1:

一、下载项目

git clone https://github.com/datawhalechina/happy-llm.git
cd happy-llm
mkdir dataset
cd dataset

二、下载数据集
首先,需要下载预训练数据集。在这里,我们使用两个开源的数据集,包含了大量的中文对话数据,可以用于训练对话生成模型。
1)出门问问序列猴子开源数据集:出门问问序列猴子通用文本数据集由来自网页、百科、博客、问答、开源代码、书籍、报刊、专利、教材、考题等多种公开可获取的数据进行汇总清洗之后而形成的大语言模型预训练语料。总量大概在 10B Token。
2)BelleGroup:350万条中文对话数据集,包含了人机对话、人人对话、人物对话等多种对话数据,可以用于训练对话生成模型。
数据集的下载全部从魔搭,这样就不用额外翻墙:

modelscope download --dataset ddzhu123/seq-monkey mobvoi_seq_monkey_general_open_corpus.jsonl.tar.bz2 --local_dir your_local_dir
tar -xvf your_local_dir/mobvoi_seq_monkey_general_open_corpus.jsonl.tar.bz2
mkdir BelleGroup
modelscope download --model fq980207/train_3.5M_CN  --local_dir ./BelleGroup


三、对数据集进行处理
由于数据集比较大,如果全部进行训练,对电脑的内存消耗非常大(T级),并且需要训练好几天。因此,我们只选择了10万条记录,用于进行训练。整个过程,在5090上不到1个小时可以跑完。当然数据集的减小,也会导致模型的效果下降:
使用 happy-llm/docs/chapter5/code/deal_dataset.py,将 data = f.readlines() 改为 data = f.readlines()[:100000]

cp ../docs/chapter5/code/deal_dataset.py ./
vi deal_dataset.py

修改后的deal_dataset.py程序如下:

import os
import json
from tqdm import tqdm
# 1 处理预训练数据
def split_text(text, chunk_size=512):
    """将文本按指定长度切分成块"""
    return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]

input_file = 'mobvoi_seq_monkey_general_open_corpus.jsonl'

with open('seq_monkey_datawhale.jsonl', 'a', encoding='utf-8') as pretrain:
    with open(input_file, 'r', encoding='utf-8') as f:
        data = f.readlines()[:100000]  # 只读取前100000行
        for line in tqdm(data, desc=f"Processing lines in {input_file}", leave=False):
            line = json.loads(line)
            text = line['text']
            chunks = split_text(text)
            for chunk in chunks:
                pretrain.write(json.dumps({'text': chunk}, ensure_ascii=False) + '\n')

# 2 处理SFT数据
def convert_message(data):
    """
    将原始数据转换为标准格式
    """
    message = [
        {"role": "system", "content": "你是一个AI助手"},
    ]
    for item in data:
        if item['from'] == 'human':
            message.append({'role': 'user', 'content': item['value']})
        elif item['from'] == 'assistant':
            message.append({'role': 'assistant', 'content': item['value']})
    return message

with open('BelleGroup_sft.jsonl', 'a', encoding='utf-8') as sft:
    with open('BelleGroup/train_3.5M_CN.json', 'r', encoding='utf-8') as f:
        data = f.readlines()[:100000]  # 只读取前100000行
        for item in tqdm(data, desc="Processing", unit="lines"):
            item = json.loads(item)
            message = convert_message(item['conversations'])
            sft.write(json.dumps(message, ensure_ascii=False) + '\n')

主要修改点:
1、在第一个处理预训练数据的部分,将 data = f.readlines() 改为 data = f.readlines()[:100000]
2、在第二个处理SFT数据的部分,同样将 data = f.readlines() 改为 data = f.readlines()[:100000]
这样修改后,程序只会处理每个输入文件的前100000行数据。

python deal_dataset.py

运行后,会生成seq_monkey_datawhale.jsonl和BelleGroup_sft.jsonl两个数据集文件,用于训练。


四、训练 Tokenizer
首先,我们需要为文本处理训练一个Tokenizer。Tokenizer的作用是将文本转换为数字序列,以便模型能够理解和处理。我们使用的数据集是已经下载的出门问问序列猴子开源数据集,这个数据集包含了大量的中文文本数据,可以用于训练Tokenizer。
由于上一步,我们已经把数据集缩小为前10万条,所以对训练所需的内存大大减少,32~64G内存就可以运行。如果你的内存不够,也可以通过扩大内存交换分区方式运行。例如,建立了一个50G的交换分区:

sudo fallocate -l 50G /swapfile
sudo chmod 600 /swapfile
sudo mkswap /swapfile
sudo swapon /swapfile
# 在另一个终端运行观察内存
watch -n 1 free -h

tokenizer训练文件如下:

使用happy-llm/docs/chapter5/code/train_tokenizer.py
将data_path = "your data path" 改为data_path = "./seq_monkey_datawhale.jsonl"

cp ../docs/chapter5/code/train_tokenizer.py ./
vi train_tokenizer.py
python train_tokenizer.py

大概训练个几分钟,就会在tokenizer_k目录生成对应的tokenizer.json、tokenizer_config.json、special_tokens_map.json。由于训练的数据规模较小,训练的tokenizer效果一般,也可以直接复制 happy-llm/docs/chapter5/code/tokenizer_k目录,已经训练好的tokenizer文件。


五、模型预训练
使用 happy-llm/docs/chapter5/code/ddp_pretrain.py 程序,命令如下:

cp ../docs/chapter5/code/ddp_pretrain.py ./
python ddp_pretrain.py --data_path seq_monkey_datawhale.jsonl --gpus 0 --batch_size 18

在5090下,batch_size 18,大概消耗28G左右显存,10万条记录25分钟左右就可以训练完毕。在base_model_215M就可以看到训练生成的基础模型文件:pretrain_1024_18_6144.pth、pretrain_1024_18_6144_step20000.pth、pretrain_1024_18_6144_step40000.pth。


六、SFT训练
使用 happy-llm/docs/chapter5/code/ddp_sft_full.py 程序,命令如下:

cp ../docs/chapter5/code/ddp_sft_full.py ./
python ddp_sft_full.py --data_path BelleGroup_sft.jsonl --gpus 0 --batch_size 18

在5090下,batch_size 18,大概消耗18G左右显存,10万条记录12分钟左右就可以训练完毕。在sft_model_215M就可以看到训练生成的基础模型文件:sft_dim1024_layers18_vocab_size6144.pth


七、模型推理测试
使用happy-llm/docs/chapter5/code/model_sample.py 程序,命令如下:

cp ../docs/chapter5/code/model_sample.py ./
python model_sample.py

就可以看到自己亲手训练的“大”模型的运行效果。由于数据量只选了10万条,回答的结果就很一般了。如果电脑内存大,而且时间充足,可以重复上面的步骤,在“三、对数据集进行处理”环节,取更多的数据量,最后训练的效果会更好。


附录1:
截至2025年7月,5090需要安装previous versions,选择cuda 12.9(我的Ubuntu24.04,针对5090的cuda就是12.9)。

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129

我的Python的依赖包环境如下:

pip list
Package                  Version
------------------------ ------------------------
accelerate               1.9.0
aiohappyeyeballs         2.6.1
aiohttp                  3.12.15
aiosignal                1.4.0
anaconda-anon-usage      0.7.1
annotated-types          0.6.0
archspec                 0.2.3
attrs                    25.3.0
boltons                  24.1.0
boto3                    1.40.1
botocore                 1.40.1
Brotli                   1.0.9
certifi                  2025.4.26
cffi                     1.17.1
charset-normalizer       3.3.2
click                    8.2.1
conda                    25.5.1
conda-anaconda-telemetry 0.1.2
conda-anaconda-tos       0.2.0
conda-content-trust      0.2.0
conda-libmamba-solver    25.4.0
conda-package-handling   2.4.0
conda_package_streaming  0.11.0
contourpy                1.3.3
cryptography             45.0.3
cycler                   0.12.1
datasets                 4.0.0
dill                     0.3.8
distro                   1.9.0
einops                   0.8.1
filelock                 3.18.0
flash_attn               2.8.2
fonttools                4.59.0
frozendict               2.4.2
frozenlist               1.7.0
fsspec                   2025.3.0
hf-xet                   1.1.5
huggingface-hub          0.34.3
idna                     3.7
Jinja2                   3.1.6
jmespath                 1.0.1
joblib                   1.5.1
jsonlines                4.0.0
jsonpatch                1.33
jsonpointer              2.1
kiwisolver               1.4.8
libmambapy               2.0.5
markdown-it-py           2.2.0
MarkupSafe               2.1.5
marshmallow              4.0.0
matplotlib               3.10.5
mdurl                    0.1.0
menuinst                 2.2.0
modelscope               1.28.0
mpmath                   1.3.0
multidict                6.6.3
multiprocess             0.70.16
networkx                 3.5
ngrok                    1.5.0
ninja                    1.11.1.4
nltk                     3.9.1
numpy                    2.3.1
nvidia-cublas-cu12       12.8.4.1
nvidia-cuda-cupti-cu12   12.8.90
nvidia-cuda-nvrtc-cu12   12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12        9.10.2.21
nvidia-cufft-cu12        11.3.3.83
nvidia-cufile-cu12       1.13.1.3
nvidia-curand-cu12       10.3.9.90
nvidia-cusolver-cu12     11.7.3.90
nvidia-cusparse-cu12     12.5.8.93
nvidia-cusparselt-cu12   0.7.1
nvidia-ml-py             12.575.51
nvidia-nccl-cu12         2.27.5
nvidia-nvjitlink-cu12    12.8.93
nvidia-nvshmem-cu12      3.3.9
nvidia-nvtx-cu12         12.8.90
packaging                24.2
pandas                   2.3.1
pillow                   11.3.0
pip                      25.1
platformdirs             4.3.7
pluggy                   1.5.0
prettytable              3.16.0
propcache                0.3.2
protobuf                 6.31.1
psutil                   7.0.0
pyarrow                  21.0.0
pycosat                  0.6.6
pycparser                2.21
pydantic                 2.10.3
pydantic_core            2.27.1
pyecharts                2.0.8
Pygments                 2.19.1
pynvml                   12.0.0
pyparsing                3.2.3
PySocks                  1.7.1
python-dateutil          2.9.0.post0
pytorch-triton           3.4.0+gitae848267
pytz                     2025.2
PyYAML                   6.0.2
regex                    2024.11.6
requests                 2.32.3
rich                     13.9.4
ruamel.yaml              0.18.10
ruamel.yaml.clib         0.2.12
s3transfer               0.13.1
safetensors              0.5.3
scikit-learn             1.7.1
scipy                    1.16.1
sentence-transformers    5.0.0
setuptools               78.1.1
simhash                  2.1.2
simplejson               3.20.1
six                      1.17.0
swankit                  0.2.4
swanlab                  0.6.8
sympy                    1.14.0
threadpoolctl            3.6.0
tiktoken                 0.9.0
tokenizers               0.21.4
torch                    2.9.0.dev20250717+cu128
torchaudio               2.8.0.dev20250717+cu128
torchvision              0.24.0.dev20250717+cu128
tqdm                     4.67.1
transformers             4.54.1
trl                      0.20.0
truststore               0.10.0
typing_extensions        4.12.2
tzdata                   2025.2
ujson                    5.10.0
urllib3                  2.3.0
wcwidth                  0.2.13
wheel                    0.45.1
wrapt                    1.17.2
xxhash                   3.5.0
yarl                     1.20.1
zstandard                0.23.0


 

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐