# zero2_gpu_example.py

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.optim import Adam

# --- 1. 模型定义 ---
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 5)

    def forward(self, x):
        return self.layer2(self.layer1(x))

# --- 2. 简化的 ZeRO-2 优化器 ---
class ZeRO_2_Optimizer:
    def __init__(self, model, optimizer_class, **optimizer_kwargs):
        if not dist.is_initialized():
            raise RuntimeError("Distributed environment is not initialized.")
        
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.device = torch.device(f"cuda:{self.rank}")
        
        # 将模型所有参数扁平化为一个一维张量
        self.flat_params = torch.cat([p.data.flatten() for p in model.parameters()])
        self.param_partitions = list(self.flat_params.chunk(self.world_size))
        
        # 每个 rank 只为自己负责的参数分片创建优化器
        # 这个虚拟参数也必须在正确的设备上
        self.param_partition_for_rank = self.param_partitions[self.rank].detach().clone().to(self.device).requires_grad_(True)
        self.base_optimizer = optimizer_class([self.param_partition_for_rank], **optimizer_kwargs)
        
        self.original_model = model

        print(f"[Rank {self.rank}] ZeRO-2 Optimizer initialized. Responsible for {self.param_partition_for_rank.numel()} parameters on device {self.device}.")
        
    def zero_grad(self):
        # 清除模型中所有参数的梯度
        for p in self.original_model.parameters():
            if p.grad is not None:
                p.grad.detach_()
                p.grad.zero_()
        
    def step(self):
        # --- ZeRO-2 的核心流程 ---

        # 1. 将所有 GPU 上的完整梯度进行 Reduce-Scatter
        with torch.no_grad():
            full_grads = torch.cat([p.grad.flatten() for p in self.original_model.parameters()])
        
        # 准备接收梯度分片的缓冲区,确保在正确的设备上
        grad_partition = torch.zeros_like(self.param_partition_for_rank, device=self.device)
        
        print(f"\n[Rank {self.rank}] Step 1: Performing Reduce-Scatter on gradients...")
        print(f"  - Full gradient shape on this rank: {full_grads.shape}")
        
        dist.reduce_scatter(grad_partition, list(full_grads.chunk(self.world_size)), op=dist.ReduceOp.SUM)
        
        grad_partition.div_(self.world_size)
        
        print(f"  - Received gradient partition of shape: {grad_partition.shape} on device {grad_partition.device}")

        # 2. 本地更新
        self.param_partition_for_rank.grad = grad_partition
        
        print(f"[Rank {self.rank}] Step 2: Performing local optimizer step...")
        self.base_optimizer.step()
        print(f"  - Local parameter partition updated.")

        # 3. 将更新后的参数分片进行 All-Gather
        print(f"[Rank {self.rank}] Step 3: Performing All-Gather on updated parameters...")
        
        # 准备一个列表,用于收集所有 rank 的参数分片,确保缓冲区在正确设备上
        updated_partitions = [torch.zeros_like(p, device=self.device) for p in self.param_partitions]
        dist.all_gather(updated_partitions, self.param_partition_for_rank)
        
        updated_flat_params = torch.cat(updated_partitions)
        print(f"  - Gathered all partitions. Full updated params shape: {updated_flat_params.shape}")

        # 4. 将最新的参数更新回原始模型
        print(f"[Rank {self.rank}] Step 4: Updating the original model with new parameters.")
        offset = 0
        for p in self.original_model.parameters():
            numel = p.numel()
            p.data.copy_(updated_flat_params[offset:offset+numel].view_as(p.data))
            offset += numel
        print(f"  - Original model updated.\n")


# --- 3. 分布式训练函数 ---
def train_worker(rank, world_size):
    # 初始化分布式环境
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    
    # 绑定进程到 GPU
    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}")
    
    # 确保每个进程有不同的随机种子
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed_all(42 + rank)
    
    # 创建模型并移动到 GPU
    model = SimpleModel().to(device)
    
    # 创建 ZeRO-2 优化器
    zero2_optimizer = ZeRO_2_Optimizer(model, Adam, lr=0.01)

    # 模拟两次训练迭代
    for i in range(2):
        print(f"--- [Rank {rank}] Iteration {i+1} ---")
        zero2_optimizer.zero_grad()
        
        # 模拟输入和前向传播(输入数据也要在 GPU 上)
        inputs = torch.randn(8, 10).to(device)
        outputs = model(inputs)
        loss = outputs.sum()
        
        # 后向传播,计算本地梯度
        loss.backward()
        print(f"[Rank {rank}] Backward pass complete. Gradients computed locally.")
        
        # 执行优化器步骤(包含 ZeRO-2 的核心通信)
        zero2_optimizer.step()
        
        # 等待所有进程都完成 step
        dist.barrier()
        
        # 验证所有 rank 上的模型参数是否一致
        with torch.no_grad():
            local_params = torch.cat([p.data.flatten() for p in model.parameters()])
            # 创建一个缓冲区来接收 rank 0 的参数
            rank0_params_buffer = torch.zeros_like(local_params)
            
            if rank == 0:
                rank0_params_buffer.copy_(local_params)
            
            # 从 rank 0 广播参数到所有其他进程
            dist.broadcast(rank0_params_buffer, src=0)
            
            # 所有进程进行比较
            assert torch.allclose(local_params, rank0_params_buffer), \
                f"[Rank {rank}] Model parameters are NOT consistent across ranks!"
        
        print(f"[Rank {rank}] Verified that model parameters are consistent across all ranks after step.")
        print("-" * 40)

    dist.destroy_process_group()


# --- 4. 启动器 ---
def main():
    # 确保您至少有2个GPU可用
    world_size = torch.cuda.device_count()
    if world_size < 2:
        print("This example requires at least 2 GPUs.")
        return
        
    print(f"Spawning {world_size} processes for distributed training...")
    mp.spawn(train_worker,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    main()
(python3.10) bash-4.4$ python testzero2.py 
/opt/conda/envs/python3.10/lib/python3.10/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
Spawning 2 processes for distributed training...
/opt/conda/envs/python3.10/lib/python3.10/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
/opt/conda/envs/python3.10/lib/python3.10/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
  import pynvml  # type: ignore[import]
[Rank 1] ZeRO-2 Optimizer initialized. Responsible for 162 parameters on device cuda:1.
--- [Rank 1] Iteration 1 ---
[Rank 0] ZeRO-2 Optimizer initialized. Responsible for 163 parameters on device cuda:0.
--- [Rank 0] Iteration 1 ---
[Rank 1] Backward pass complete. Gradients computed locally.

[Rank 1] Step 1: Performing Reduce-Scatter on gradients...
  - Full gradient shape on this rank: torch.Size([325])
notebook-76ec3aa0aa0d-worker-0:42641:42641 [1] NCCL INFO C4 stats mode none, reduce 1, send/recv 0.
[Rank 0] Backward pass complete. Gradients computed locally.

[Rank 0] Step 1: Performing Reduce-Scatter on gradients...
  - Full gradient shape on this rank: torch.Size([325])
notebook-76ec3aa0aa0d-worker-0:42640:42640 [0] NCCL INFO C4 stats mode none, reduce 1, send/recv 0.
notebook-76ec3aa0aa0d-worker-0:42640:42640 [0] NCCL INFO Bootstrap: Using eth0:33.203.57.184<0>
notebook-76ec3aa0aa0d-worker-0:42640:42640 [0] NCCL INFO ACCL_TUNING_LEVEL=2
notebook-76ec3aa0aa0d-worker-0:42640:42640 [0] NCCL INFO cudaDriverVersion 12040
notebook-76ec3aa0aa0d-worker-0:42640:42640 [0] NCCL INFO NCCL version 2.26.5.12-accl-n+cuda12.8, COMMIT_ID 2e6879b700b6cd1510bf9dc29a9db0132609d4cc, BUILD_TIME 2025-05-21 11:34:07
notebook-76ec3aa0aa0d-worker-0:42641:42641 [1] NCCL INFO cudaDriverVersion 12040
notebook-76ec3aa0aa0d-worker-0:42641:42641 [1] NCCL INFO Bootstrap: Using eth0:33.203.57.184<0>
notebook-76ec3aa0aa0d-worker-0:42641:42641 [1] NCCL INFO ACCL_TUNING_LEVEL=2
notebook-76ec3aa0aa0d-worker-0:42641:42641 [1] NCCL INFO NCCL version 2.26.5.12-accl-n+cuda12.8, COMMIT_ID 2e6879b700b6cd1510bf9dc29a9db0132609d4cc, BUILD_TIME 2025-05-21 11:34:07
notebook-76ec3aa0aa0d-worker-0:42640:42640 [0] NCCL INFO Comm config Blocking set to 1
notebook-76ec3aa0aa0d-worker-0:42640:42640 [0] NCCL INFO Comm config Traffic class set to 26215043
notebook-76ec3aa0aa0d-worker-0:42641:42641 [1] NCCL INFO Comm config Blocking set to 1
notebook-76ec3aa0aa0d-worker-0:42641:42641 [1] NCCL INFO Comm config Traffic class set to 32652
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO NET/Plugin: Could not find: libnccl-net-none.so. Using internal net plugin.
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO NET/IB : No device found.
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO NET/IB : Using [RO]; OOB eth0:33.203.57.184<0>
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO NET/Socket : Using [0]eth0:33.203.57.184<0>
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. 
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO Using network Socket
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO NET/Plugin: Could not find: libnccl-net-none.so. Using internal net plugin.
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO NET/IB : No device found.
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO NET/IB : Using [RO]; OOB eth0:33.203.57.184<0>
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO NET/Socket : Using [0]eth0:33.203.57.184<0>
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so. 
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Using network Socket
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO ncclCommInitRankConfig comm 0xb7856b0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 17f000 commId 0xe52d77c7f68d320d commHash 16513987109356122637 - Init START
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO ncclCommInitRankConfig comm 0xb7853e0 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 109000 commId 0xe52d77c7f68d320d commHash 16513987109356122637 - Init START
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO RAS client listening socket at ::1<28028>
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO RAS client listening socket at ::1<28028>
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO Bootstrap timings total 0.009070 (create 0.000026, send 0.000113, recv 0.008253, ring 0.000335, delay 0.000001)
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Bootstrap timings total 0.000956 (create 0.000022, send 0.000124, recv 0.000179, ring 0.000350, delay 0.000001)
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO gTaskid = b839cb21000063b5.
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO gTaskid = b839cb21000063b5.
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO hpn_license_register error: Error occurred during zmq_msg_recv(): Resource temporarily unavailable.

notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO MNNVL busId 0x17f000 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO hpn_license_register error: Error occurred during zmq_msg_recv(): Resource temporarily unavailable.

notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO MNNVL busId 0x109000 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO Setting affinity for GPU 1 to ffffffff,ffff0000,00000000,ffffffff,ffff0000,00000000
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Setting affinity for GPU 0 to ffffffff,ffff0000,00000000,ffffffff,ffff0000,00000000
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO comm 0xb7856b0 rank 1 nRanks 2 nNodes 1 localRanks 2 localRank 1 MNNVL 0
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO comm 0xb7853e0 rank 0 nRanks 2 nNodes 1 localRanks 2 localRank 0 MNNVL 0
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 00/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 01/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 02/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 03/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 04/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 05/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 06/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 07/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 08/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 09/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 10/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 11/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 12/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] -1/-1/-1->1->0 [2] -1/-1/-1->1->0 [3] -1/-1/-1->1->0 [4] 0/-1/-1->1->-1 [5] 0/-1/-1->1->-1 [6] 0/-1/-1->1->-1 [7] 0/-1/-1->1->-1 [8] -1/-1/-1->1->0 [9] -1/-1/-1->1->0 [10] -1/-1/-1->1->0 [11] -1/-1/-1->1->0 [12] 0/-1/-1->1->-1 [13] 0/-1/-1->1->-1 [14] 0/-1/-1->1->-1 [15] 0/-1/-1->1->-1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 13/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 14/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO P2P Chunksize set to 524288
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Channel 15/16 : 0 1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] 1/-1/-1->0->-1 [4] -1/-1/-1->0->1 [5] -1/-1/-1->0->1 [6] -1/-1/-1->0->1 [7] -1/-1/-1->0->1 [8] 1/-1/-1->0->-1 [9] 1/-1/-1->0->-1 [10] 1/-1/-1->0->-1 [11] 1/-1/-1->0->-1 [12] -1/-1/-1->0->1 [13] -1/-1/-1->0->1 [14] -1/-1/-1->0->1 [15] -1/-1/-1->0->1
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO P2P Chunksize set to 524288
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Check P2P Type intraNodeP2pSupport 1 directMode 0
notebook-76ec3aa0aa0d-worker-0:42641:42885 [1] NCCL INFO [Proxy Service] Device 1 CPU core 163
notebook-76ec3aa0aa0d-worker-0:42641:42886 [1] NCCL INFO [Proxy Service UDS] Device 1 CPU core 164
notebook-76ec3aa0aa0d-worker-0:42640:42887 [0] NCCL INFO [Proxy Service] Device 0 CPU core 78
notebook-76ec3aa0aa0d-worker-0:42640:42888 [0] NCCL INFO [Proxy Service UDS] Device 0 CPU core 79
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO 16 coll channels, 16 collnet channels, 0 nvls channels, 16 p2p channels, 16 p2p channels per peer
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO 16 coll channels, 16 collnet channels, 0 nvls channels, 16 p2p channels, 16 p2p channels per peer
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO CC Off, workFifoBytes 1048576
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO TUNER/Plugin: Could not find: libnccl-tuner.so. Using internal tuner plugin.
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO TUNER/Plugin: Could not find: libnccl-tuner.so. Using internal tuner plugin.
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO ncclCommInitRankConfig comm 0xb7856b0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 17f000 commId 0xe52d77c7f68d320d commHash 16513987109356122637 - Init COMPLETE
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO ncclCommInitRankConfig comm 0xb7853e0 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 109000 commId 0xe52d77c7f68d320d commHash 16513987109356122637 - Init COMPLETE
notebook-76ec3aa0aa0d-worker-0:42641:42835 [1] NCCL INFO Init timings - ncclCommInitRankConfig: rank 1 nranks 2 total 5.86 (kernels 0.33, alloc 0.24, bootstrap 0.01, allgathers 0.01, topo 0.24, graphs 0.00, connections 0.01, rest 5.02)
notebook-76ec3aa0aa0d-worker-0:42640:42834 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 2 total 5.86 (kernels 0.35, alloc 0.23, bootstrap 0.00, allgathers 0.00, topo 0.24, graphs 0.00, connections 0.01, rest 5.03)
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 02/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 03/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 04/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 04/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 05/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 05/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 06/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 06/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 07/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 07/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 08/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 08/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 09/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 09/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 10/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 10/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 11/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 11/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 12/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 12/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 13/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 13/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 14/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 14/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Channel 15/0 : 0[0] -> 1[1] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Channel 15/0 : 1[1] -> 0[0] via P2P/CUMEM
notebook-76ec3aa0aa0d-worker-0:42641:42890 [1] NCCL INFO Connected all rings, use ring PXN 0 GDR 1
notebook-76ec3aa0aa0d-worker-0:42640:42889 [0] NCCL INFO Connected all rings, use ring PXN 0 GDR 1
  - Received gradient partition of shape: torch.Size([163]) on device cuda:0
  - Received gradient partition of shape: torch.Size([162]) on device cuda:1
[Rank 0] Step 2: Performing local optimizer step...
[Rank 1] Step 2: Performing local optimizer step...
  - Local parameter partition updated.
[Rank 1] Step 3: Performing All-Gather on updated parameters...
  - Local parameter partition updated.
[Rank 0] Step 3: Performing All-Gather on updated parameters...
  - Gathered all partitions. Full updated params shape: torch.Size([325])
[Rank 1] Step 4: Updating the original model with new parameters.
  - Original model updated.

  - Gathered all partitions. Full updated params shape: torch.Size([325])
[Rank 0] Step 4: Updating the original model with new parameters.
  - Original model updated.

[Rank 1] Verified that model parameters are consistent across all ranks after step.
----------------------------------------
--- [Rank 1] Iteration 2 ---
[Rank 0] Verified that model parameters are consistent across all ranks after step.
----------------------------------------
--- [Rank 0] Iteration 2 ---
[Rank 1] Backward pass complete. Gradients computed locally.

[Rank 1] Step 1: Performing Reduce-Scatter on gradients...
  - Full gradient shape on this rank: torch.Size([325])
  - Received gradient partition of shape: torch.Size([162]) on device cuda:1
[Rank 1] Step 2: Performing local optimizer step...
[Rank 0] Backward pass complete. Gradients computed locally.

[Rank 0] Step 1: Performing Reduce-Scatter on gradients...
  - Full gradient shape on this rank: torch.Size([325])
  - Received gradient partition of shape: torch.Size([163]) on device cuda:0
[Rank 0] Step 2: Performing local optimizer step...
  - Local parameter partition updated.
[Rank 1] Step 3: Performing All-Gather on updated parameters...
  - Gathered all partitions. Full updated params shape: torch.Size([325])
[Rank 1] Step 4: Updating the original model with new parameters.
  - Original model updated.

  - Local parameter partition updated.
[Rank 0] Step 3: Performing All-Gather on updated parameters...
  - Gathered all partitions. Full updated params shape: torch.Size([325])
[Rank 0] Step 4: Updating the original model with new parameters.
  - Original model updated.

[Rank 1] Verified that model parameters are consistent across all ranks after step.
[Rank 0] Verified that model parameters are consistent across all ranks after step.
----------------------------------------
----------------------------------------
notebook-76ec3aa0aa0d-worker-0:42641:42893 [1] NCCL INFO comm 0xb7856b0 rank 1 nranks 2 cudaDev 1 busId 17f000 - Destroy COMPLETE
notebook-76ec3aa0aa0d-worker-0:42640:42894 [0] NCCL INFO comm 0xb7853e0 rank 0 nranks 2 cudaDev 0 busId 109000 - Destroy COMPLETE

好的,我们来一步步详细地拆解您提供的这段 ZeRO-2 核心代码的执行流程。这段代码精准地展示了 ZeRO-2 优化器在 step() 函数内部是如何协同多个 GPU 完成一次参数更新的

我们将以 world_size=2 (两个 GPU,Rank 0 和 Rank 1) 为例,逐行解读。

# 假设 backward() 已完成,每个 GPU 上都有一个完整的、但基于本地数据的梯度
# full_grads = torch.cat([p.grad.flatten() for p in self.original_model.parameters()])
# 每个 GPU 上的 full_grads 形状都一样,但内容不同。

def step(self):
    # --- 1. 梯度同步与分区:Reduce-Scatter ---
    
    # 1a. 准备接收缓冲区
    grad_partition = torch.zeros_like(self.param_partition_for_rank, device=self.device)
    
    # 1b. 执行 Reduce-Scatter
    dist.reduce_scatter(grad_partition, list(full_grads.chunk(self.world_size)), op=dist.ReduceOp.SUM)
    
    # 1c. 平均梯度
    grad_partition.div_(self.world_size)
    
    # --- 2. 本地参数更新 ---
    
    # 2a. 将同步好的梯度分片赋给优化器
    self.param_partition_for_rank.grad = grad_partition
    
    # 2b. 执行基础优化器的 step
    self.base_optimizer.step()
    
    # --- 3. 参数同步:All-Gather ---
    
    # 3a. 准备接收缓冲区
    updated_partitions = [torch.zeros_like(p, device=self.device) for p in self.param_partitions]
    
    # 3b. 执行 All-Gather
    dist.all_gather(updated_partitions, self.param_partition_for_rank)
    
    # 3c. 拼接成完整参数
    updated_flat_params = torch.cat(updated_partitions)
    
    # --- 4. 更新原始模型 ---
    
    # 4a. 将最新的完整参数拷贝回模型
    offset = 0
    for p in self.original_model.parameters():
        numel = p.numel()
        p.data.copy_(updated_flat_params[offset:offset+numel].view_as(p.data))
        offset += numel

详细执行流程分解

阶段 0: 前提

在调用 step() 之前,loss.backward() 已经执行完毕。此时,在每个 GPU上:

  • self.original_model 中的每个参数都有一个 .grad 属性。
  • full_grads 是将这些 .grad 张量拼接起来形成的一个巨大的一维张量
  • Rank 0 上的 full_grads (我们称之为 G_full_0) 和 Rank 1 上的 full_grads (我们称之为 G_full_1) 内容是不同的
阶段 1: 梯度同步与分区 (Reduce-Scatter)

这是 ZeRO-2 最关键的一步,目的是计算出全局平均梯度,并且每个 GPU 只保留自己需要的那一部分。

  • 1a. grad_partition = torch.zeros_like(...)

    • Rank 0: 创建一个全零的张量 grad_partition_0,它的形状与 Rank 0 负责的参数分片 self.param_partition_for_rank 完全相同。这个张量是用来接收即将到来的梯度分片的。
    • Rank 1: 同样,创建一个形状匹配的、全零的张量 grad_partition_1
  • 1b. dist.reduce_scatter(...)

    • 这是一个阻塞式的集合通信操作,所有 rank 必须都到达这里才能继续。
    • 输入: list(full_grads.chunk(self.world_size))
      • Rank 0 上,full_grads (G_full_0) 被切分成 [G_full_0_part0, G_full_0_part1]
      • Rank 1 上,full_grads (G_full_1) 被切分成 [G_full_1_part0, G_full_1_part1]
      • 这个 input_list 实际上是一个分布式的列表,每个 rank 提供自己的那一份。
    • 后台操作:
      1. Reduce (归约/求和): NCCL 库在后台将所有 rank 对应部分的梯度相加。
        • G_sum_part0 = G_full_0_part0 + G_full_1_part0
        • G_sum_part1 = G_full_0_part1 + G_full_1_part1
      2. Scatter (散发): 通信库将求和后的结果分发出去。
        • G_sum_part0 发送给 Rank 0
        • G_sum_part1 发送给 Rank 1
    • 输出: grad_partition
      • 操作结束后,Rank 0grad_partition_0 张量中存储的内容就是 G_sum_part0
      • Rank 1grad_partition_1 张量中存储的内容就是 G_sum_part1
    • 结果: 每个 GPU 都获得了它所负责的那部分参数的全局总梯度
  • 1c. grad_partition.div_(self.world_size)

    • Rank 0: grad_partition_0 就地除以 2,得到平均梯度分片。
    • Rank 1: grad_partition_1 就地除以 2,得到平均梯度分片。

阶段 1 结束时: 每个 GPU 的显存中只保留了一小部分梯度(即梯度分片),并且这个分片是经过全局同步和平均的。原始的 full_grads 和参数的 .grad 属性所占用的内存可以被回收了。显存优化完成。


阶段 2: 本地参数更新

这个阶段不需要任何通信,完全在每个 GPU 内部独立进行。

  • 2a. self.param_partition_for_rank.grad = grad_partition

    • 我们将刚刚计算好的梯度分片,“嫁接”到优化器所管理的那个虚拟参数 self.param_partition_for_rank.grad 属性上。
    • 这就像是在告诉 Adam 优化器:“嘿,这就是你要用来更新参数的梯度!”
  • 2b. self.base_optimizer.step()

    • 调用基础优化器(如 Adam)的 step() 方法。
    • Rank 0: Adam 优化器会使用 grad_partition_0 来更新它管理的 self.param_partition_for_rank (即参数的前半部分)。它会更新自己的 momentum 和 variance 状态(这些状态从一开始就是分片的),并计算出新的参数值。
    • Rank 1: 同样,Adam 使用 grad_partition_1 来更新参数的后半部分。

阶段 2 结束时:

  • Rank 0 上的 self.param_partition_for_rank 存储了最新的参数前半部分
  • Rank 1 上的 self.param_partition_for_rank 存储了最新的参数后半部分
  • 此时,self.original_model 在两个 GPU 上的内容是不一致的!Rank 0 的模型只有前半部分是新的,Rank 1 的模型只有后半部分是新的。

阶段 3: 参数同步 (All-Gather)

为了让所有 GPU 在下一次前向传播时都拥有一个完整的、最新的模型,我们需要将更新后的参数分片重新组合起来。

  • 3a. updated_partitions = [torch.zeros_like(...) ...]

    • 每个 GPU 都创建一个列表 updated_partitions,其中包含两个全零的张量,形状分别对应参数的第一部分和第二部分。这个列表将作为接收所有分片的“容器”。
  • 3b. dist.all_gather(...)

    • 又一个阻塞式的集合通信操作。
    • 输入: self.param_partition_for_rank
      • Rank 0 将自己更新后的前半部分参数发送出去。
      • Rank 1 将自己更新后的后半部分参数发送出去。
    • 后台操作: NCCL 确保每个 GPU 都收到来自所有其他 GPU 的数据。
    • 输出: updated_partitions
      • 操作结束后,在每一个 GPU 上,updated_partitions 列表的内容都变成了:[最新的参数前半部分, 最新的参数后半部分]
  • 3c. updated_flat_params = torch.cat(updated_partitions)

    • 在每个 GPU 内部,将列表中的两个分片拼接起来,形成一个完整的、扁平化的、最新的参数张量 updated_flat_params

阶段 3 结束时: 每个 GPU 都拥有了一份完整的、更新后的模型参数。


阶段 4: 更新原始模型

最后一步,是将这个最新的参数“写回”到我们实际使用的模型 self.original_model 中。

  • 4a. for p in self.original_model.parameters(): ...
    • 这个循环遍历原始模型的每一个参数(如 layer1.weight, layer1.bias, layer2.weight等)。
    • p.data.copy_(...): 使用 copy_ 操作,从 updated_flat_params 中切出对应的一块,并将其内容拷贝到参数 p 中。view_as(p.data) 确保了形状的正确匹配。

阶段 4 结束时: 在所有 GPU 上,self.original_model 的所有参数都被更新到了最新状态,并且内容完全一致。系统现在准备好进行下一次的迭代了。


总结

这个 step 函数的执行流程,就像一个精心编排的舞蹈:

  1. 分散收集 (Reduce-Scatter): 将分布在各地的“信息”(梯度)进行汇总和精确分发,每个人只拿到自己的任务指令。
  2. 独立工作 (Local Update): 每个人根据自己的任务指令,独立完成自己的工作(更新参数分片)。
  3. 成果汇集 (All-Gather): 将每个人完成的工作成果收集起来,组合成一个完整的最终产品(完整的模型参数)。
  4. 最终部署 (Update Model): 将最终产品分发给每个人,确保大家都有一个最新的副本。

这个流程通过在关键步骤进行高效的分布式通信,巧妙地解决了大规模模型训练中的显存瓶颈问题。

每个 GPU 都创建一个列表 updated_partitions,其中包含两个全零的张量,形状分别对应参数的第一部分和第二部分。 为什么需要两个。 可能这与all_gather有关, 他的参数量相比allreduce 怎么样的

您提了两个非常棒的问题,它们都直指分布式通信操作的本质和效率。我们来逐一深入探讨。


问题一:为什么 updated_partitions 需要包含两个(或 world_size 个)张量?

直接答案:
因为 torch.distributed.all_gather 函数的接口设计就是这样的。它的第一个参数 tensor_list 是一个输出参数,它的作用是作为容器,接收来自所有进程的数据

  • tensor_list: 一个列表,其长度必须等于 world_size
  • tensor: 当前进程想要发送出去的张量。

详细解释:
让我们看看 all_gather 的函数签名和工作流程:
dist.all_gather(tensor_list, tensor)

  1. 准备阶段: 在调用函数之前,你必须在每个进程上都创建一个 tensor_list

    • 这个列表的长度必须是 world_size。对于我们的例子(world_size=2),它必须包含 2 个张量。
    • 列表中的每个张量都应该是一个预先分配好内存的“占位符”,其形状和数据类型必须与要发送的 tensor 完全相同。
    • updated_partitions = [torch.zeros_like(p, device=self.device) for p in self.param_partitions] 这行代码正是做了这件事。它创建了一个包含两个全零张量的列表,形状分别对应参数的第一和第二分片。
  2. 执行阶段:

    • Rank 0: 将它的 self.param_partition_for_rank (参数的前半部分) 发送出去。
    • Rank 1: 将它的 self.param_partition_for_rank (参数的后半部分) 发送出去。
    • 后台的 NCCL 库进行数据交换。
  3. 完成阶段:

    • all_gather 函数会将收集到的数据填充到你提供的 tensor_list 中。
    • Rank 0 上:
      • updated_partitions[0] 将被填充上来自 Rank 0 的数据。
      • updated_partitions[1] 将被填充上来自 Rank 1 的数据。
    • Rank 1 上:
      • updated_partitions[0] 将被填充上来自 Rank 0 的数据。
      • updated_partitions[1] 将被填充上来自 Rank 1 的数据。

结论: updated_partitions 列表就像一个信箱阵列。你有 world_size 个邻居(包括你自己),所以你需要准备 world_size 个信箱。all_gather 操作结束后,你的每个信箱里都会有一封来自对应邻居的信。

所以,这个列表的长度和内容是 all_gather 这个特定 API 的要求,它为接收来自所有进程的数据提供了预定义的、大小正确的存储空间。


问题二:all_gather 的通信量与 all_reduce 相比怎么样?

这是一个衡量分布式算法效率的关键问题。让我们用一个具体的例子来比较。

场景设置:

  • GPU 数量 (world_size): N
  • 每个 GPU 上的数据大小: S (例如,一个参数分片的大小)
1. all_gather 的通信量
  • 操作: 每个 GPU 将自己的大小为 S 的数据块发送给所有其他 N-1 个 GPU。同时,它也从其他 N-1 个 GPU 接收它们的数据块。
  • 每个 GPU 的发送量: (N-1) * S
  • 每个 GPU 的接收量: (N-1) * S
  • 总通信量 (简化模型): 大约是 N * (N-1) * S。这个模型假设是点对点通信,实际的 Ring-AllGather 算法会更高效。

一个更精确的 Ring-AllGather 算法模型:
在高效的 Ring-AllGather 实现中,数据沿着一个环形链路传递。

  • 每个 GPU 需要发送 (N-1)/N * N * S = (N-1) * S 的数据。
  • 每个 GPU 需要接收 (N-1) * S 的数据。
  • 通信分为 N-1 步。在每一步,每个 GPU 发送和接收大小为 S 的数据块。
  • 总通信时间: 2 * (N-1) * (α + S/β),其中 α 是网络延迟,β 是带宽。
  • 总通信数据量: (N-1) * N * S。每个GPU发送(N-1)*S的数据,接收(N-1)*S的数据。
2. all_reduce 的通信量

all_reduce 通常也用 Ring 算法实现。

  • 操作: 分为两个阶段:一个 Reduce-Scatter 和一个 All-Gather。
    1. Reduce-Scatter 阶段: 数据沿着环形链路传递并累加。经过 N-1 步,每个 GPU 都有了最终结果的一个分片。
    2. All-Gather 阶段: 与上面的 All-Gather 类似,每个 GPU 将自己持有的结果分片广播给其他所有 GPU。也需要 N-1 步。
  • 总通信时间: 大约是 2 * (N-1) * (α + S_full/β),其中 S_full = N*S 是聚合后的总数据大小。
  • 总通信数据量: 在每个步骤中,GPU发送和接收数据。经过 2*(N-1) 步,总的通信数据量大约是 2 * (N-1) * S_full = 2 * (N-1) * N * S

等等,这看起来 all_reduce 的通信量是 all_gather 的两倍?
这里需要注意 S 的定义。

让我们重新定义场景来做一个公平的比较。

  • 总数据大小: M
  • GPU 数量: N

All-Reduce:

  • 每个 GPU 一开始有一份大小为 M 的张量。
  • 目标是所有 GPU 都得到这 N 个张量reduce(例如求和)之后的大小为 M 的结果。
  • 通信量: 2 * (N-1)/N * M。在Ring算法中,每个GPU总共发送和接收的数据量是2*(N-1)个大小为M/N的数据块。

All-Gather:

  • 每个 GPU 一开始有一份大小为 M/N 的张量(一个分片)。
  • 目标是所有 GPU 都得到拼接后的大小为 M 的完整张量。
  • 通信量: (N-1)/N * M。在Ring算法中,每个GPU总共发送和接收的数据量是(N-1)个大小为M/N的数据块。

结论:
在理想的 Ring 算法实现中,对于同样大小的最终聚合数据 M
all_reduce 的通信量大约是 all_gather 的两倍。

这是因为 all_reduce 本质上就包含了一个 all_gather 的过程(在 reduce-scatter 之后的分发阶段),所以它的通信成本更高。

回到 ZeRO-2

在 ZeRO-2 的 step() 函数中:

  • reduce_scatter: 输入是 N 个大小为 M 的完整梯度,输出是 N 个大小为 M/N 的梯度分片。其通信量大约与 all_gather 相当,即 (N-1)/N * M
  • all_gather: 输入是 N 个大小为 M/N 的参数分片,输出是 N 个大小为 M 的完整参数。其通信量也是 (N-1)/N * M

所以,ZeRO-2 的一个 step 中,总通信量大约是 2 * (N-1)/N * M,这恰好与一次 all_reduce 的通信量相当。

ZeRO-2 的巧妙之处在于:它将一次昂贵的 all_reduce 操作,拆分成了两次通信量减半的操作 (reduce_scatterall_gather),并且在这两次通信之间插入了计算步骤(本地 optimizer.step())。这不仅在梯度同步阶段节省了显存,而且总的通信开销与标准的数据并行(只使用一次 all_reduce)保持在同一水平,没有引入额外的通信负担。

reduce_scatter: 输入是 N 个大小为 M 的完整梯度,输出是 N 个大小为 M/N 的梯度分片。其通信量大约与 all_gather 相当,即 (N-1)/N * M。计算方式,假如将其他N-1 个完整的发到0号,通信量为N-1 * M, 然后scatter 到N个其他的, 通信量为什么除以M,这样算感觉不需要啊

您这个问题提得太好了!您指出了一个非常容易混淆的地方:简单直观的通信模型 vs. 高效的 Ring 算法实现

您提到的计算方式是基于一个**中心化(Centralized)朴素(Naïve)的通信模型,而我给出的 (N-1)/N * M 公式是基于高效的环形(Ring)**算法。

让我们来详细拆解这两种模型,您就会明白为什么会有这个差异,以及为什么实际应用中我们考虑的是 Ring 算法。


1. 您描述的朴素/中心化通信模型

这种模型非常直观,容易理解。

场景:

  • N 个 GPU,每个 GPU 有一个大小为 M 的梯度张量。
  • 目标是计算 reduce_scatter

您的计算思路:

  1. Reduce (归约) 阶段:

    • 选定一个 GPU 作为“主节点”,比如 GPU 0。
    • 其他 N-1 个 GPU (GPU 1, 2, …, N-1) 将它们各自完整的、大小为 M 的张量发送给 GPU 0。
    • 通信量: (N-1) * M
    • GPU 0 收到所有数据后,将它们相加,得到一个大小为 M 的总和 G_sum
  2. Scatter (散发) 阶段:

    • GPU 0 将 G_sum 切分成 N 份,每份大小为 M/N
    • GPU 0 将第 i 份数据发送给 GPU i
    • 通信量: GPU 0 需要发送 N-1M/N 大小的数据给其他 GPU,总发送量是 (N-1) * M/N

总通信量: (N-1) * M + (N-1) * M/N = (N-1) * M * (1 + 1/N)

这个模型的巨大问题:

  • 性能瓶颈: 所有的通信压力都集中在 GPU 0 上。GPU 0 的网络带宽成为了整个系统的瓶颈。
  • 网络拥塞: 所有流量都涌向一个节点,容易造成网络拥塞。
  • 效率低下: 在 GPU 0 进行计算和分发时,其他 N-1 个 GPU 都在空闲等待。

结论: 这种朴素模型在理论上可行,但在实际的大规模分布式训练中绝对不会被使用,因为它效率太低,扩展性极差。


2. 高效的 Ring 算法模型(实际使用的模型)

实际的通信库(如 NCCL)使用 Ring 算法来实现 reduce_scatterall_gatherall_reduce,以避免中心化瓶颈,让所有 GPU 同时参与工作。

Ring-Reduce-Scatter 的工作原理:

想象一下 N 个 GPU 围成一个环。

  1. 数据分块: 首先,每个 GPU 将自己的本地数据(大小为 M)切分成 N块(chunks),每个块大小为 M/N。我们把这些块编号为 chunk_0, chunk_1, ..., chunk_{N-1}

  2. N-1 步迭代通信: 算法会进行 N-1 次迭代。在k(从 k=0 到 N-2):

    • 每个 GPU i 会将它当前持有chunk_{(i-k) mod N} 发送给它的下一个邻居 (i+1) mod N
    • 同时,它会从它的上一个邻居 (i-1) mod N 接收一个 chunk
    • 关键: 接收到新的 chunk 后,它会立即将这个新来的块与自己本地对应的块相加(Reduce操作)

我们来追踪一个块的旅程:

  • 以 GPU 0 上的 chunk_0 为例。它最终的目标是累加上所有 GPU 的 chunk_0
  • 在第 0 步,GPU 1 将它的 chunk_1 发给 GPU 2,GPU 0 将它的 chunk_0 发给 GPU 1。
  • … 经过 N-1 步后,chunk_0 已经绕环一周,并且在传递过程中,每个 GPU 都把它自己的 chunk_0 版本累加了上去。当它最终回到 GPU 0 时,它就包含了所有 chunk_0 的总和。

通信量分析:

  • 在每一轮迭代中,每个 GPU 发送一个块(大小 M/N)并接收一个块(大小 M/N)。
  • 总共有 N-1 轮迭代。
  • 每个 GPU 的总发送量: (N-1) * (M/N)
  • 每个 GPU 的总接收量: (N-1) * (M/N)
  • 整个系统的总通信量 (所有 GPU 发送的数据之和): N * (N-1) * (M/N) = (N-1) * M

回到您的问题:通信量为什么是 (N-1)/N * M

这个公式 (N-1)/N * M 描述的是单个 GPU 在整个 Ring 算法过程中总共发送(或接收)的数据量

  • M: 初始时每个 GPU 上的总数据大小。
  • M/N: 每个小块的大小。
  • N-1: 通信的轮数。
  • (N-1) * (M/N) = (N-1)/N * M

这个公式完全正确,它来自于对高效 Ring 算法的分析,而不是基于那个有瓶颈的中心化模型。

总结

对比项 朴素/中心化模型 Ring 算法模型 (实际使用)
工作方式 所有数据发给一个主节点,主节点处理后再分发。 所有 GPU 围成环,分块数据在环上传递并累加。
瓶颈 主节点成为严重瓶颈 没有中心瓶颈,所有 GPU 带宽被同时利用。
单个GPU总发送量 M (对于非主节点)
(N-1)M/N (对于主节点)
(N-1) * M/N
扩展性 好,非常适合大规模 GPU 集群。
适用场景 教学,易于理解。 工业级、高性能分布式计算。

所以,当我们在讨论分布式通信操作的通信量时,我们默认讨论的是像 Ring 算法这样的高效实现。您之前的计算方式是正确的,但它分析的是一个效率较低的模型。reduce_scatter 的通信量之所以比 all_reduce 小,正是因为它在 Ring 算法的框架下,巧妙地省去了一半的通信步骤。

p.numel() 是 PyTorch 张量(Tensor)的一个非常基础且常用的方法。

它的作用是:

返回张量 p 中元素的总数量。

numel 是 “number of elements” 的缩写。


详细解释和示例

假设 p 是一个 PyTorch 张量,它的形状是 (d_1, d_2, ..., d_n)。那么 p.numel() 的计算结果就等于:

d_1 * d_2 * ... * d_n

它会忽略张量的数据类型(dtype)和所在的设备(device),只计算元素的个数。

代码示例:

import torch

# 示例 1: 二维张量 (矩阵)
matrix = torch.randn(3, 4)  # 一个 3x4 的矩阵
print(f"矩阵的形状: {matrix.shape}")
print(f"矩阵的元素总数: {matrix.numel()}") # 输出将是 3 * 4 = 12

# 示例 2: 一维张量 (向量)
vector = torch.arange(10) # 包含 0 到 9 的向量
print(f"\n向量的形状: {vector.shape}")
print(f"向量的元素总数: {vector.numel()}") # 输出将是 10

# 示例 3: 四维张量 (常用于图像批次)
# (batch_size, channels, height, width)
image_batch = torch.zeros(16, 3, 224, 224) 
print(f"\n图像批次的形状: {image_batch.shape}")
print(f"图像批次的元素总数: {image_batch.numel()}") # 输出将是 16 * 3 * 224 * 224 = 2408448

# 示例 4: 标量 (零维张量)
scalar = torch.tensor(42)
print(f"\n标量的形状: {scalar.shape}")
print(f"标量的元素总数: {scalar.numel()}") # 输出将是 1

输出:

矩阵的形状: torch.Size([3, 4])
矩阵的元素总数: 12

向量的形状: torch.Size([10])
向量的元素总数: 10

图像批次的形状: torch.Size([16, 3, 224, 224])
图像批次的元素总数: 2408448

标量的形状: torch.Size([])
标量的元素总数: 1

在您提供的 ZeRO-2 代码中的作用

让我们回到您看到的这行代码:

# for p in self.original_model.parameters():
#     numel = p.numel()
#     p.data.copy_(updated_flat_params[offset:offset+numel].view_as(p.data))
#     offset += numel

这里的 p 是模型的一个参数,例如一个形状为 (20, 10) 的权重矩阵。

  1. numel = p.numel():

    • 这行代码计算出参数 p 中包含多少个元素。对于一个 (20, 10) 的权重矩阵,numel 的值就是 20 * 10 = 200
  2. updated_flat_params[offset:offset+numel]:

    • updated_flat_params 是一个巨大的、一维的、包含了所有最新参数的张量。
    • 这行代码的作用就是从这个大张量中,根据 offset(当前处理到的位置)和 numel(当前参数的大小),切出正好属于参数 p 的那一段数据。
    • 例如,如果 offset 是 0,numel 是 200,它就会切出 updated_flat_params 的前 200 个元素。
  3. .view_as(p.data):

    • 切出来的 updated_flat_params[offset:offset+numel] 仍然是一个一维的向量。
    • .view_as(p.data) 会将这个一维向量的形状**重塑(reshape)**成与原始参数 p 完全相同的形状。例如,将 (200,) 的向量重塑成 (20, 10) 的矩阵。
  4. p.data.copy_(...):

    • 将重塑好形状的最新参数数据,拷贝回原始模型参数 p 的存储空间中。
  5. offset += numel:

    • 更新 offset,使其指向下一个参数在大张量中的起始位置,为处理下一个参数做准备。

总结: 在这个上下文中,p.numel() 是一个至关重要的工具,它使得我们能够准确地知道需要从扁平化的参数大张量中切出多长的一段来更新当前正在处理的这个特定参数 p。它是在多维参数张量和一维扁平化张量之间建立映射的关键桥梁。

好的,我们来深入解析 train_worker 函数和 mp.spawn 是如何协同工作,从而启动一个完整的多进程分布式训练的。这个组合是 PyTorch 中实现多进程并行(无论是 DDP、ZeRO 还是其他并行策略)的标准和推荐方式。


1. mp.spawn:启动器和指挥官

首先,我们来看 main 函数中的 mp.spawn 调用:

# if __name__ == "__main__":
def main():
    world_size = torch.cuda.device_count() # 假设为 2
    
    print(f"Spawning {world_size} processes for distributed training...")
    
    mp.spawn(train_worker,           # 要执行的目标函数
             args=(world_size,),     # 传递给目标函数的额外参数
             nprocs=world_size,      # 要创建的进程数量
             join=True)              # 主进程等待所有子进程结束后再退出

mp.spawn 的作用就像一个任务分发中心。它会做以下几件事情:

  1. 创建进程 (Spawning Processes):

    • 它会创建 nprocs 个(这里是 2 个)全新的、独立的 Python 进程。我们称它们为 子进程
    • 这些子进程会继承当前脚本的环境,但拥有自己独立的内存空间。
  2. 分配角色 (Assigning Ranks):

    • mp.spawn 会自动为每个创建的子进程分配一个唯一的整数 ID,这个 ID 就是我们所说的 rank
    • 第一个子进程的 rank0
    • 第二个子进程的 rank1
    • …以此类推,直到 nprocs - 1
  3. 调用目标函数 (Invoking the Target Function):

    • mp.spawn 会在每个子进程中调用你指定的目标函数,即 train_worker
    • 它会将 rank 作为第一个参数自动传递给 train_worker
    • 它还会将 args 元组中的其他参数(这里是 world_size)传递给 train_worker
    • 所以,实际上发生的调用是:
      • 进程 0 中执行: train_worker(rank=0, world_size=2)
      • 进程 1 中执行: train_worker(rank=1, world_size=2)
  4. 等待和管理 (Joining):

    • 由于 join=True,启动了所有子进程后,主进程(即执行 mp.spawn 的那个进程)会暂停下来,等待所有子进程执行完毕。
    • 如果任何一个子进程因为错误而崩溃,mp.spawn 会捕获这个异常,并终止所有其他子进程,然后在主进程中抛出 ProcessRaisedException,这就是你之前看到的那个报错。
    • 当所有子进程都正常退出后,主进程才会继续执行(或者结束)。

2. train_worker:每个进程的独立工作空间

现在,我们进入 train_worker 函数的内部,看看在每个独立的进程中都发生了什么。以 Rank 1 的进程为例(train_worker(rank=1, world_size=2))。

def train_worker(rank, world_size):
    # rank = 1, world_size = 2
  1. os.environ['MASTER_ADDR'] = 'localhost'os.environ['MASTER_PORT'] = '12345':

    • 这一步是在告诉所有进程,它们应该去哪里集合。这就像约定了一个集结点地址
    • MASTER_ADDR: 主节点的地址,localhost 表示主节点就在这台机器上。
    • MASTER_PORT: 主节点监听的端口号。
    • 所有进程都必须使用相同的地址和端口,才能找到彼此。
  2. dist.init_process_group("nccl", rank=rank, world_size=world_size):

    • 这是整个分布式魔法的核心。每个进程都会调用这个函数来加入通信组
    • backend="nccl": 指定使用 NVIDIA 的 NCCL 库进行 GPU 间的高速通信。
    • rank=rank: 进程 1 在这里告诉大家:“我是 1 号成员”。
    • world_size=world_size: 进程 1 知道:“我们这个组总共有 2 个成员”。
    • 握手过程 (Handshake):
      • Rank 0 进程(主节点)会先到达,并在 localhost:12345 上建立一个监听服务。
      • Rank 1 进程到达后,会去连接 localhost:12345
      • 当所有 world_size 个进程都连接成功后,通信组就建立起来了。init_process_group 函数才会返回,所有进程才能继续往下执行。如果有一个进程没来,所有进程都会在这里无限期等待。
  3. torch.cuda.set_device(rank):

    • 这一步至关重要,它将当前进程与一个特定的 GPU 绑定
    • 在进程 1 中,torch.cuda.set_device(1) 的意思是:“从现在开始,我这个进程所有的 CUDA 操作(如创建张量、模型计算)都默认在 GPU 1 上进行”。
    • 这确保了进程 0 在 GPU 0 上工作,进程 1 在 GPU 1 上工作,实现了真正的并行,而不是所有进程都挤在一个 GPU 上。
  4. model = SimpleModel().to(device):

    • 每个进程都会创建一个独立的、完整的模型实例
    • .to(device) 将这个模型的所有参数和缓冲区移动到当前进程绑定的 GPU 上(对于进程 1,就是 GPU 1)。
  5. zero2_optimizer = ZeRO_2_Optimizer(...):

    • 每个进程都会创建一个独立的优化器实例
    • 在优化器内部,它会根据自己的 rankworld_size,决定自己负责模型参数的哪一个分片
      • 进程 0 的优化器只管理参数的前半部分的状态。
      • 进程 1 的优化器只管理参数的后半部分的状态。
  6. 训练循环 for i in range(2)::

    • inputs = torch.randn(...).to(device): 每个进程生成自己独立的、随机的一批数据,并放到自己的 GPU 上。这是数据并行(Data Parallelism)的体现。
    • loss.backward(): 每个进程在自己的 GPU 上,用自己的数据和完整的模型副本,计算出一份本地的、完整的梯度。此时,进程 0 和进程 1 的梯度是不同的。
    • zero2_optimizer.step(): 这是我们之前详细分析过的 ZeRO-2 核心步骤。
      • 两个进程通过 reduce_scatter 交换梯度信息,计算出全局同步的梯度分片。
      • 每个进程在本地更新它负责的那部分参数。
      • 两个进程再通过 all_gather 交换更新后的参数分片,最终在每个进程中都重建出一个完全一致的、最新的模型
    • dist.barrier(): 一个同步点,确保所有进程都完成了 step() 之后再继续。
    • 验证逻辑: 通过 broadcast 从 rank 0 发送其参数,其他 rank 接收并比较,以编程方式确认 all_gather 确实成功地使所有模型保持了一致。
  7. dist.destroy_process_group():

    • 训练结束后,每个进程调用此函数,优雅地退出通信组,释放所有分布式相关的资源。

总结:一个生动的比喻

想象一下你要组织一个大型建筑项目(训练一个大模型),你有两个施工队(world_size=2)。

  1. mp.spawn (项目经理):

    • 他招募了两个施工队,分别命名为“0号队”和“1号队” (rank=0, 1)。
    • 他给两个队发了相同的完整的建筑蓝图(模型代码),并告诉他们项目的总负责人地址 (MASTER_ADDR/PORT)。
    • 他对0号队说:“你去工地的东区开工 (device=0)”,对1号队说:“你去西区开工 (device=1)”。
    • 然后项目经理就在办公室等着,直到项目完成 (join=True)。
  2. train_worker (每个施工队的工作):

    • 集合 (init_process_group): 两个队都到达了项目负责人的地址,通过对讲机确认彼此都在,建立起通信。
    • 准备工具 (zero2_optimizer):
      • 0号队发现,根据项目分工,他们只负责维护地基和一楼的施工设备(优化器状态分片)。
      • 1号队只负责维护二楼和屋顶的施工设备。
    • 开始施工 (training loop):
      • 两个队都看着完整的蓝图(模型副本)进行工作。
      • 他们使用各自工区内的砖块和材料(输入数据),并计算出完成今天任务需要对整个建筑做哪些改动(本地梯度)。
      • 开会 (zero2_optimizer.step):
        • reduce_scatter: 两个队长通过对讲机沟通,0号队长说:“我们计算出地基和一楼要这样改…”,1号队长说:“我们计算出二楼和屋顶要那样改…”。他们汇总意见,最终0号队长只记下关于地基和一楼的最终方案(梯度分片),1号队长只记下关于二楼和屋顶的方案。
        • local update: 两个队根据自己记下的最终方案,在自己的工区进行施工(更新参数分片)。
        • all_gather: 施工完毕后,两个队长再次通过对讲机,互相告知对方自己工区的最新状态。这样,每个队的蓝图副本都被更新到了最新的、完全一致的状态。
    • 项目结束 (destroy_process_group): 所有工作完成后,各队收拾工具,关闭对讲机,解散。

通过这个流程,mp.spawn 负责“搭台”,而 train_worker 负责在每个独立的舞台上“唱戏”,并通过 torch.distributed 提供的通信原语进行高效的协同,最终完成了看似复杂但逻辑清晰的分布式训练任务。

代码中是不是没体现zero2 只保存自己的优化器参数那部分代码

您提出了一个绝佳的、非常深入的问题!您完全正确——在我提供的简化代码中,优化器状态(Optimizer State)的分区是**隐式(Implicitly)**发生的,而不是通过显式的切分代码来体现的。

这种隐式行为正是这段代码设计的巧妙之处。它依赖于 PyTorch 优化器的工作原理

让我们来详细解释这部分,并用一个独立的例子来证明它。


1. PyTorch 优化器如何管理状态?

一个标准的 PyTorch 优化器(如 Adam)是有状态的。它需要为它所管理的每一个参数存储额外的信息(比如 Adam 的一阶矩 momentum 和二阶矩 variance)。

这个状态是在优化器第一次执行 step() 时,按需创建的。关键在于:

优化器只会为在初始化时传递给它的那些参数创建和管理状态。

它内部有一个 state 字典,key 是参数对象,value 是该参数对应的状态信息。如果你不把某个参数传给优化器,optimizer.state 字典里就永远不会有这个参数的条目。


2. 回顾我们的 ZeRO-2 简化代码

现在,让我们来看一下 ZeRO_2_Optimizer 初始化时的这两行关键代码:

# 在 __init__ 方法中:

# 1. 我们创建了一个只属于当前 rank 的、扁平化的参数分片
self.param_partition_for_rank = self.param_partitions[self.rank].detach().clone().to(self.device).requires_grad_(True)

# 2. 我们用这个【分片】来初始化基础优化器
self.base_optimizer = optimizer_class([self.param_partition_for_rank], **optimizer_kwargs)

这里发生了什么?

  1. 我们没有将完整的 model.parameters() 传递给 Adam 优化器。
  2. 相反,我们创建了一个全新的、更小的、独立的张量 self.param_partition_for_rank。这个张量的大小只有完整模型参数的 1/world_size
  3. 我们把这个小得多的分片张量,作为唯一的参数,传递给了 Adam 优化器的构造函数。

结果就是

  • self.base_optimizer 从始至终只知道一个参数的存在,那就是 self.param_partition_for_rank
  • self.base_optimizer.step() 第一次被调用时,它会为这个唯一的参数创建状态。
  • 因为这个参数本身就是分片的(大小只有 1/N),所以为它创建的 momentumvariance 缓冲区的大小也自然只有 1/N

因此,通过控制传递给优化器构造函数的参数,我们巧妙地实现了优化器状态的分区存储。 每个 rank 上的优化器实例,天然地只为全局参数的一小部分分配和维护状态,从而节省了大量的显存。


3. 代码证明:“眼见为实”

为了让您更直观地理解,下面是一个独立的、极简的脚本,它清晰地展示了优化器状态是如何根据传入参数的大小来创建的。

import torch
import torch.nn as nn
from torch.optim import Adam

# 创建一个简单的模型
model = nn.Sequential(
    nn.Linear(10, 20, bias=False),  # 200 个参数
    nn.Linear(20, 5, bias=False)   # 100 个参数
)
total_params = sum(p.numel() for p in model.parameters())
print(f"模型总参数量: {total_params}") # 输出: 300

# --- 场景 1: 标准优化器,管理所有参数 ---
print("\n--- 场景 1: 标准优化器 ---")
optimizer_full = Adam(model.parameters(), lr=0.01)

# 模拟一次迭代来触发状态创建
(model(torch.randn(1, 10)).sum()).backward()
optimizer_full.step()

print("优化器状态字典中的 key 数量:", len(optimizer_full.state)) # 输出: 2 (对应两个 Linear 层的 weight)

# 检查第一个参数 (200个元素) 的状态大小
p1 = list(model.parameters())[0]
state_p1 = optimizer_full.state[p1]
momentum_shape_p1 = state_p1['exp_avg'].shape
variance_shape_p1 = state_p1['exp_avg_sq'].shape
print(f"第一个参数的状态 (momentum/variance) 形状: {momentum_shape_p1} / {variance_shape_p1}") # 输出: torch.Size([20, 10])


# --- 场景 2: ZeRO-2 模拟,只管理一个分片 ---
print("\n--- 场景 2: ZeRO-2 模拟优化器 ---")

# 1. 将所有参数扁平化并切分 (模拟 world_size=2)
flat_params = torch.cat([p.data.flatten() for p in model.parameters()])
partitions = flat_params.chunk(2)

# 2. Rank 0 的优化器只管理第一个分片
partition_0 = partitions[0].clone().requires_grad_(True)
print(f"Rank 0 的参数分片大小: {partition_0.numel()}") # 输出: 150

optimizer_partitioned = Adam([partition_0], lr=0.01) # 只把分片传进去

# 3. 模拟一次迭代
partition_0.grad = torch.randn_like(partition_0) # 模拟同步好的梯度分片
optimizer_partitioned.step()

print("分区优化器状态字典中的 key 数量:", len(optimizer_partitioned.state)) # 输出: 1
# 检查这个唯一的分片参数的状态大小
state_part0 = optimizer_partitioned.state[partition_0]
momentum_shape_part0 = state_part0['exp_avg'].shape
variance_shape_part0 = state_part0['exp_avg_sq'].shape
print(f"分片参数的状态 (momentum/variance) 形状: {momentum_shape_part0} / {variance_shape_part0}") # 输出: torch.Size([150])

print("\n比较:")
print(f"标准优化器需要存储状态的参数总数: {200 + 100 = 300}")
print(f"ZeRO-2 模拟优化器需要存储状态的参数总数: {150}")

运行以上代码,您会看到清晰的对比

  • 场景1中的标准优化器,为模型的全部 300 个参数创建了状态。
  • 场景2中模拟 ZeRO-2 的优化器,只被告知要管理一个大小为 150 的参数分片,因此它也只为这 150 个参数创建了状态。

这个例子无可辩驳地证明了,我们通过巧妙地构造传递给优化器的参数列表,隐式地、但却有效地实现了优化器状态的分区,这正是 ZeRO-1 和 ZeRO-2 节省显存的核心原理之一。

哪利用优化器更新时,不需要所有通信吗

您又一次问到了分布式训练中一个极其关键且容易混淆的点!

您的直觉是完全正确的:在优化器更新参数的那个瞬间,确实不需要通信。

这句话的意思是,optimizer.step() 这个操作本身,在 ZeRO-2 的设计下,是一个**纯本地计算(Local Computation)**的过程。

让我们来详细分解一下“为什么不需要通信”,以及这在整个 ZeRO-2 流程中处于什么位置。


ZeRO-2 step() 流程的再审视:分工明确

回想一下我们之前拆解的 ZeRO-2 的 step() 流程,它被精心设计成了**“通信-计算-通信”**的三明治结构。

ZeRO_2_Optimizer.step() {
    
    // --- Part 1: 通信阶段 ---
    // Reduce-Scatter: 同步并分发梯度
    dist.reduce_scatter(...) 
    
    // ---------------------------------
    //  ↑↑↑ 通信结束,计算开始 ↓↓↓
    // ---------------------------------
    
    // --- Part 2: 纯本地计算阶段 ---
    // 这个阶段【不涉及】任何 GPU 间的通信
    
    // 1. 将同步好的梯度分片“喂”给优化器
    self.param_partition_for_rank.grad = grad_partition
    
    // 2. 调用基础优化器,执行更新
    // 这是一个纯粹的数学计算过程:
    // new_param = old_param - lr * f(grad, momentum, variance)
    self.base_optimizer.step()
    
    // ---------------------------------
    //  ↑↑↑ 计算结束,通信开始 ↓↓↓
    // ---------------------------------

    // --- Part 3: 通信阶段 ---
    // All-Gather: 同步更新后的参数
    dist.all_gather(...)

    // ... 后续将参数写回模型
}

为什么在 base_optimizer.step() 时不需要通信?

关键在于,在执行 self.base_optimizer.step() 之前,每个 GPU 已经拥有了更新自己所负责的那部分参数所需的所有信息

让我们以 GPU 0 为例,在调用 base_optimizer.step() 的前一刻,它拥有:

  1. 旧的参数值: 它持有 self.param_partition_for_rank,这代表了模型参数的前半部分的当前值。
  2. 正确的梯度: 经过 reduce_scatter 之后,grad_partition 存储了模型参数前半部分的、经过全局同步和平均的梯度。
  3. 对应的优化器状态: self.base_optimizer 内部存储了只属于这前半部分参数的 momentumvariance

所以,当 base_optimizer.step() 被调用时,GPU 0 可以独立地、自给自足地完成以下计算:

  • 根据梯度更新 momentumvariance
  • 根据新的 momentumvariance 和梯度,计算出参数的更新量。
  • 将更新量应用到 self.param_partition_for_rank 上,得到新的参数值。

这个过程对于 GPU 1 也是完全一样的,它独立地更新参数的后半部分。

因为每个 GPU 在这个阶段只处理自己的“一亩三分地”,并且已经拿到了所有必要的“生产资料”(梯度和状态),所以它完全不需要和邻居(其他GPU)商量或交换任何信息。

这与标准数据并行(DDP)的对比

这个“本地更新”的特性,与标准 DDP 形成了鲜明对比。

  • 标准 DDP:
    1. loss.backward() 计算本地梯度。
    2. 通信: All-Reduce 操作被调用,将所有 GPU 的梯度进行求和平均。操作结束后,每个 GPU 都拥有一份完整的、同步好的梯度
    3. optimizer.step(): 每个 GPU 使用这份完整的梯度,独立地更新完整的模型参数。这个 step 本身也是本地计算。

区别在于

  • 在 DDP 中,optimizer.step() 更新的是全部参数
  • 在 ZeRO-2 中,optimizer.step() 更新的是部分参数(一个分片)

虽然两者在 step() 瞬间都是本地计算,但 ZeRO-2 的设计之所以更优越,是因为它通过将计算目标缩小(从完整参数到参数分片),从而使得为这个计算准备输入(梯度和状态)的成本大大降低

总结:一句话解释

在优化器更新时不需要通信,因为 ZeRO-2 的设计确保了所有必需的通信(梯度同步)都在更新步骤之前就已经完成了,并且每个 GPU 只需更新它在本地已拥有全部信息的参数子集。

这个设计思想是分布式系统优化的一个普遍原则:将通信和计算分离,并尽可能地将大的计算任务分解成可以在本地独立完成的小任务,从而减少同步开销和资源冗余。ZeRO-2 正是这一原则在深度学习训练领域的完美应用。

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐