Complete Request Scheduling and Loss Computation Flow from SGLang to Megatron

1. Request Initiation and Data Generation

The RolloutManager coordinates the SGLang engines to generate data

  • RolloutManager.generate() initiates the rollout request 1
  • It calls the generate_rollout function to produce samples through SGLang 2
  • The SGLang engines run inference, generating responses and computing rewards (the resulting per-sample fields are sketched below)
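
For orientation, the per-sample fields that matter downstream look roughly like the sketch below. It is a simplified illustration inferred from _convert_samples_to_train_data (cited at the bottom of this page), not slime's actual Sample class:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SampleSketch:
        """Illustrative subset of a generated sample's fields (the real
        Sample class in slime carries more, e.g. status and metadata)."""
        index: int                                       # position in the rollout dataset
        tokens: list[int]                                # prompt + response token ids
        response_length: int                             # number of response tokens
        reward: float                                    # scalar reward for the response
        loss_mask: Optional[list[int]] = None            # 1 = include token in the loss
        rollout_log_probs: Optional[list[float]] = None  # per-token logprobs from SGLang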

Data collection and conversion

  • After generation finishes, _convert_samples_to_train_data converts the samples into training format 3 (the resulting layout is sketched after this list)
  • Key fields such as tokens, response_lengths, rewards, and rollout_log_probs are extracted
  • _split_train_data_by_dp then partitions the data across data-parallel ranks 4
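
The resulting payload is a plain dict of parallel lists. A toy illustration of its layout (values invented; optional keys such as rollout_log_probs or multimodal_inputs appear only when the samples carry them):

    # Toy train_data layout mirroring _convert_samples_to_train_data,
    # with one sample whose last two tokens are the response.
    train_data = {
        "tokens": [[101, 7592, 2088, 102, 345, 678]],  # prompt + response token ids
        "response_lengths": [2],                       # number of response tokens
        "rewards": [1.0],                              # post-processed rewards
        "raw_reward": [1.0],
        "truncated": [0],                              # 1 if the response was cut off
        "sample_indices": [0],
        "loss_masks": [[1, 1]],                        # 0/1 mask over response tokens
        # optional: "rollout_log_probs": [[-0.31, -1.20]],
    }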

2. Passing Data to Megatron

Data hand-off via Ray remote calls

  • The main training loop obtains the data via ray.get(rollout_manager.generate.remote(rollout_id)) 5
  • It then calls actor_model.async_train(rollout_id, rollout_data_ref) to hand the data to the training actors 6 (the driver loop is sketched below)
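
Pieced together from the train.py excerpts cited below, the synchronous driver loop boils down to roughly the following sketch; loop bounds and the surrounding eval/critic/offload logic are omitted, and names outside the cited calls are illustrative:

    import ray

    # Simplified driver loop (sketch): generate -> train -> sync weights.
    for rollout_id in range(num_rollouts):  # illustrative loop bound
        # 1) SGLang generates and the RolloutManager returns DP-split data refs
        rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
        # 2) hand the data refs to the Megatron training actors
        ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
        # 3) push the updated weights back into the SGLang engines
        actor_model.update_weights()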

Megatron receives and processes the data

  • MegatronTrainRayActor.train() receives the training request 7
  • _get_rollout_data() processes the data and moves it to the GPU 8 (see the sketch after this list)
  • A data iterator is created to prepare for training 9
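
The GPU prep is essentially per-sample tensor conversion. A minimal sketch distilled from the cited _get_rollout_data, which additionally handles multimodal inputs, rollout log-probs with context-parallel slicing, and routed-experts data:

    import torch

    def move_rollout_to_gpu(rollout_data: dict) -> dict:
        """Convert token id lists and loss masks into per-sample CUDA tensors."""
        device = torch.cuda.current_device()
        rollout_data["tokens"] = [
            torch.tensor(t, dtype=torch.long, device=device) for t in rollout_data["tokens"]
        ]
        rollout_data["loss_masks"] = [
            torch.tensor(m, dtype=torch.int, device=device) for m in rollout_data["loss_masks"]
        ]
        return rollout_data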

3. RL Loss Computation in Detail

Advantage computation

  • compute_advantages_and_returns() computes advantages and returns 10
  • Different algorithms (GRPO, PPO, etc.) use different advantage estimators (an illustrative GRPO-style estimator is sketched below)
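
As one concrete example, a GRPO-style estimator normalizes each reward against the other responses sampled from the same prompt. The sketch below is illustrative only; the estimator slime actually applies is selected by the run's arguments inside compute_advantages_and_returns:

    import torch

    def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """Group-normalized advantages; rewards has shape
        [num_prompts, n_samples_per_prompt]."""
        mean = rewards.mean(dim=-1, keepdim=True)
        std = rewards.std(dim=-1, keepdim=True)
        return (rewards - mean) / (std + eps)

    # Example: 2 prompts with 4 sampled responses each.
    adv = grpo_advantages(torch.tensor([[1.0, 0.0, 1.0, 0.0],
                                        [1.0, 1.0, 1.0, 0.0]]))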

Policy loss computation

  • Compute the PPO KL term: ppo_kl = old_log_probs - log_probs 11
  • compute_policy_loss() computes the clipped surrogate loss 12 (a self-contained version is sketched after the formula)
  • Core formula:
    ratio = (-ppo_kl).exp()
    pg_losses1 = -ratio * advantages
    pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
    pg_losses = torch.maximum(pg_losses1, pg_losses2)
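
Wrapping the formula into a self-contained function (a sketch: the real compute_policy_loss also applies loss masks, the configured reduction, and reports clipping statistics; the eps_clip / eps_clip_high defaults here are illustrative):

    import torch

    def clipped_policy_loss(log_probs, old_log_probs, advantages,
                            eps_clip=0.2, eps_clip_high=0.2):
        """Per-token clipped surrogate loss as in the formula above."""
        ppo_kl = old_log_probs - log_probs            # per-token KL estimate (cited above)
        ratio = (-ppo_kl).exp()                       # = exp(log_probs - old_log_probs)
        pg_losses1 = -ratio * advantages
        pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
        return torch.maximum(pg_losses1, pg_losses2)  # caller masks and reduces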
    

Total loss composition

  • Base loss: loss = pg_loss - args.entropy_coef * entropy_loss 13
  • Optional KL penalty: loss = loss + args.kl_loss_coef * kl_loss (assembled as in the sketch below)
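
Assembled in code (a sketch mirroring the cited FSDP actor; the KL term itself comes from compute_approx_kl and is only added when args.use_kl_loss is set):

    def assemble_total_loss(pg_loss, entropy_loss, entropy_coef,
                            kl_loss=None, kl_loss_coef=0.0):
        """Base loss minus entropy bonus, plus an optional KL penalty."""
        loss = pg_loss - entropy_coef * entropy_loss
        if kl_loss is not None:
            loss = loss + kl_loss_coef * kl_loss
        return loss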

4. Training Execution and Weight Updates

Megatron training step

  • train_one_step() executes a pipeline-parallel training step 14 (a plain-PyTorch analogue is sketched after this list)
  • The forward pass computes logits and the loss
  • The backward pass computes gradients
  • The optimizer updates the parameters
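
As a mental model only, one step corresponds to the plain-PyTorch loop below; the real train_one_step instead drives Megatron's pipeline schedule over num_microbatches and only then applies the optimizer and LR scheduler (see the cited excerpt):

    def train_one_step_sketch(model, optimizer, scheduler, microbatches, loss_fn):
        """Plain-PyTorch analogue: accumulate grads over microbatches,
        then apply one optimizer step and one scheduler step."""
        optimizer.zero_grad()
        total_loss = 0.0
        for batch in microbatches:
            logits = model(batch["tokens"])                 # forward pass
            loss = loss_fn(logits, batch) / len(microbatches)
            loss.backward()                                 # accumulate gradients
            total_loss += loss.item()
        optimizer.step()                                    # parameter update
        scheduler.step()                                    # advance the LR schedule
        return total_loss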

Weight synchronization to SGLang

  • After training completes, actor_model.update_weights() is called 15
  • Weights are synchronized through UpdateWeightFromTensor or UpdateWeightFromDistributed 16 (selection sketched below)
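
The updater class is chosen by the --colocate flag, abridged from the cited MegatronTrainRayActor setup; colocated training shares GPUs with SGLang and can hand tensors over directly, while the decoupled path transfers weights between the training and inference processes:

    # Abridged sketch of the cited actor setup: pick the weight updater based
    # on whether training and inference are colocated on the same GPUs.
    update_weight_cls = UpdateWeightFromTensor if args.colocate else UpdateWeightFromDistributed
    weight_updater = update_weight_cls(
        args,
        model,
        weights_getter=lambda: weights_backuper.get("actor"),
        model_name=model_name,                      # derived from the HF config in the real code
        quantization_config=quantization_config,
    )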

5. Complete Flow Diagram

The end-to-end message flow between train.py, RolloutManager, the SGLang engines, RayTrainGroup, MegatronTrainRayActor, and Megatron:

  1. train.py → RolloutManager: generate.remote(rollout_id)
  2. RolloutManager → SGLang Engines: generation request
  3. SGLang Engines → RolloutManager: return samples + rewards
  4. RolloutManager: convert samples to train_data
  5. RolloutManager → train.py: return rollout_data_ref
  6. train.py → RayTrainGroup: async_train(rollout_id, data_ref)
  7. RayTrainGroup → MegatronTrainRayActor: train.remote(rollout_id, data_ref)
  8. MegatronTrainRayActor: _get_rollout_data()
  9. MegatronTrainRayActor → Megatron: execute training step
  10. Megatron: compute RL loss
  11. Megatron: update parameters
  12. Megatron → MegatronTrainRayActor: training complete
  13. train.py → RayTrainGroup: update_weights()
  14. RayTrainGroup → MegatronTrainRayActor: update_weights()
  15. MegatronTrainRayActor → SGLang Engines: synchronize new weights

Notes

The overall flow reflects slime's core design: Ray provides distributed coordination, SGLang handles efficient inference, and Megatron handles stable training. The two are seamlessly integrated through carefully designed data hand-off and weight-synchronization mechanisms, yielding an efficient RL training loop. 17


Citations

File: slime/ray/rollout.py (L93-108)

    def generate(self, rollout_id):
        monitor_started = self.args.use_fault_tolerance and self._health_monitor.start()
        start_time = time.time()
        try:
            data, metrics = self._get_rollout_data(rollout_id=rollout_id)
            self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=False)
            _log_rollout_data(rollout_id, self.args, data, metrics, time.time() - start_time)
            data = self._convert_samples_to_train_data(data)
            return self._split_train_data_by_dp(data, self.train_parallel_config["dp_size"])
        finally:
            if monitor_started:
                self._health_monitor.stop()
                self.num_new_engines = init_rollout_engines(self.args, self.pg, self.all_rollout_engines)
            else:
                self.num_new_engines = 0

File: slime/ray/rollout.py (L215-275)

    def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]):
        """
        Convert inference generated samples to training data.
        """
        raw_rewards, rewards = self._post_process_rewards(samples)

        assert len(raw_rewards) == len(samples)
        assert len(rewards) == len(samples)

        train_data = {
            "tokens": [sample.tokens for sample in samples],
            "response_lengths": [sample.response_length for sample in samples],
            # some reward model, e.g. remote rm, may return multiple rewards,
            # we could use key to select the reward.
            "rewards": rewards,
            "raw_reward": raw_rewards,
            "truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples],
            "sample_indices": [sample.index for sample in samples],
        }

        # loss mask
        # TODO: compress the loss mask
        loss_masks = []
        for sample in samples:
            # always instantiate loss_mask if not provided
            if sample.loss_mask is None:
                sample.loss_mask = [1] * sample.response_length

            assert (
                len(sample.loss_mask) == sample.response_length
            ), f"loss mask length {len(sample.loss_mask)} != response length {sample.response_length}"
            if sample.remove_sample:
                sample.loss_mask = [0] * sample.response_length
            loss_masks.append(sample.loss_mask)
        train_data["loss_masks"] = loss_masks

        # overwriting the raw reward
        if samples[0].metadata and "raw_reward" in samples[0].metadata:
            train_data["raw_reward"] = [sample.metadata["raw_reward"] for sample in samples]

        # For rollout buffer
        if samples[0].metadata and "round_number" in samples[0].metadata:
            train_data["round_number"] = [sample.metadata["round_number"] for sample in samples]

        # Add rollout log probabilities for off-policy correction
        if samples[0].rollout_log_probs is not None:
            train_data["rollout_log_probs"] = [sample.rollout_log_probs for sample in samples]

        if samples[0].rollout_routed_experts is not None:
            train_data["rollout_routed_experts"] = [sample.rollout_routed_experts for sample in samples]

        if samples[0].train_metadata is not None:
            train_data["metadata"] = [sample.train_metadata for sample in samples]

        if samples[0].multimodal_inputs is not None:
            train_data["multimodal_inputs"] = [sample.multimodal_inputs for sample in samples]

        if "teacher_log_probs" in samples[0].__dict__:
            train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]

        return train_data

File: slime/ray/rollout.py (L280-328)

    def _split_train_data_by_dp(self, data, dp_size):
        """Split the train data by data parallel size."""
        rollout_data = {}

        if "prompt" in data:
            rollout_data["prompt"] = data["prompt"]

        total_lengths = [len(t) for t in data["tokens"]]
        data["total_lengths"] = total_lengths

        if self.args.balance_data:
            partitions = get_seqlen_balanced_partitions(total_lengths, dp_size, equal_size=True)
        else:
            partitions = [range(i, len(total_lengths), dp_size) for i in range(dp_size)]

        rollout_data_refs = []

        for i in range(dp_size):
            rollout_data = {}
            partition = partitions[i]
            rollout_data["partition"] = partition
            for key in [
                "tokens",
                "multimodal_inputs",
                "response_lengths",
                "rewards",
                "truncated",
                "loss_masks",
                "round_number",
                "sample_indices",
                "rollout_log_probs",
                "rollout_routed_experts",
                "prompt",
                "teacher_log_probs",
            ]:
                if key not in data:
                    continue
                val = [data[key][j] for j in partition]
                rollout_data[key] = val
            # keys that need to be splited at train side
            for key in [
                "raw_reward",
                "total_lengths",
            ]:
                if key not in data:
                    continue
                rollout_data[key] = data[key]
            rollout_data_refs.append(Box(ray.put(rollout_data)))
        return rollout_data_refs

File: slime/rollout/sglang_rollout.py (L330-419)

async def generate_rollout_async(
    args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]]
) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]:
    """An example to implement the generate_rollout function for an rule based rm rollout generation.

    Args:
        args: the whole args
        rollout_id: int, the id of the rollout, used for deterministic data generation
        data_source: the data source to fetch

    Returns:
        tuple[RolloutFnTrainOutput, list[list[Sample]]]:
            - data: a list of groups of samples generated by the rollout, length equals `rollout_batch_size`
            - aborted_samples: any partial groups collected during abort when partial_rollout is enabled
    """
    assert args.rollout_global_dataset

    state = GenerateState(args)

    # instantiate data filters
    dynamic_filter = (
        load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None
    )

    metric_gatherer = _MetricGatherer()

    # target_data_size is the total number of valid samples to get
    target_data_size = args.rollout_batch_size

    data = []
    all_data = []
    do_print = True
    pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation")
    while len(data) < target_data_size:
        while state.remaining_batch_size < target_data_size:
            # get samples from the buffer and submit the generation requests.
            samples = data_source(args.over_sampling_batch_size)
            state.submit_generate_tasks(samples)

        # wait for the generation to finish
        done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            group: list[Sample] = task.result()

            if do_print:
                sample = group[0][0] if isinstance(group[0], list) else group[0]
                logger.info(
                    f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}",
                )
                do_print = False

            assert len(group) == args.n_samples_per_prompt
            all_data.append(group)
            dynamic_filter_output = _call_dynamic_filter(dynamic_filter, args, group)
            if not dynamic_filter_output.keep:
                metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason)
                state.remaining_batch_size -= 1
                continue

            # add the samples to the data
            # NOTE: here we have not stored all the unused samples back to the data buffer.
            if len(data) < target_data_size:
                data.append(group)
                pbar.update(args.n_samples_per_prompt)

    pbar.close()
    sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0]
    logger.info(
        f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}",
    )

    # there are still some unfinished requests, abort them
    aborted_samples = await abort(args, rollout_id)

    assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
    data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index)
    all_samples = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index)

    # reset the global state to prevent effects on the next rollout or eval.
    state.reset()
    if args.rollout_sample_filter_path is not None:
        filter_func = load_function(args.rollout_sample_filter_path)
        filter_func(args, data)

    # There can be circumstances where users want to process all samples including filtered ones.
    if args.rollout_all_samples_process_path is not None:
        process_func = load_function(args.rollout_all_samples_process_path)
        process_func(args, all_samples, data_source)

    return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples

File: train.py (L68-68)

        rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))

File: train.py (L76-79)

                ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
            ray.get(critic_train_handle)
        else:
            ray.get(actor_model.async_train(rollout_id, rollout_data_ref))

File: train.py (L91-91)

        actor_model.update_weights()

File: slime/backends/megatron_utils/actor.py (L130-137)

        update_weight_cls = UpdateWeightFromTensor if self.args.colocate else UpdateWeightFromDistributed
        self.weight_updater = update_weight_cls(
            self.args,
            self.model,
            weights_getter=lambda: self.weights_backuper.get("actor"),
            model_name=type(self.hf_config).__name__.lower() if self.args.model_name is None else self.args.model_name,
            quantization_config=getattr(self.hf_config, "quantization_config", None),
        )

File: slime/backends/megatron_utils/actor.py (L182-227)

    def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
        # Fetch data through ray on CPU, not sure if this will be performance bottleneck.
        # Both first pp stage and the last pp stage will receive the data.
        rollout_data = process_rollout_data(
            self.args,
            rollout_data_ref,
            mpu.get_data_parallel_rank(with_context_parallel=False),
            mpu.get_data_parallel_world_size(with_context_parallel=False),
        )
        # TODO: this is ugly, move to somewhere else?
        # move tokens to GPU in advance
        rollout_data["tokens"] = [
            torch.tensor(t, dtype=torch.long, device=torch.cuda.current_device()) for t in rollout_data["tokens"]
        ]
        rollout_data["loss_masks"] = [
            torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"]
        ]
        if "multimodal_inputs" in rollout_data:
            # Move multimodal inputs to GPU in advance
            rollout_data["multimodal_inputs"] = [
                (
                    {key: tensor.to(device=torch.cuda.current_device()) for key, tensor in mm_dict.items()}
                    if mm_dict is not None
                    else None
                )
                for mm_dict in rollout_data["multimodal_inputs"]
            ]
        if "rollout_log_probs" in rollout_data:
            rollout_data["rollout_log_probs"] = [
                torch.tensor(
                    slice_log_prob_with_cp(log_prob, total_length, response_length),
                    device=torch.cuda.current_device(),
                    dtype=torch.float32,
                )
                for log_prob, total_length, response_length in zip(
                    rollout_data["rollout_log_probs"],
                    rollout_data["total_lengths"],
                    rollout_data["response_lengths"],
                    strict=False,
                )
            ]
        if "rollout_routed_experts" in rollout_data:
            rollout_data["rollout_routed_experts"] = [
                torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"]
            ]
        return rollout_data

File: slime/backends/megatron_utils/actor.py (L330-344)

    def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
        if self.args.offload_train:
            self.wake_up()

        with timer("data_preprocess"):
            rollout_data = self._get_rollout_data(rollout_data_ref)
            if self.args.debug_rollout_only:
                log_rollout_data(rollout_id, self.args, rollout_data)
                return

        if self.role == "critic":
            return self.train_critic(rollout_id, rollout_data)
        else:
            return self.train_actor(rollout_id, rollout_data)

File: slime/backends/megatron_utils/actor.py (L375-375)

        data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data)

File: slime/backends/megatron_utils/actor.py (L421-421)

                compute_advantages_and_returns(self.args, rollout_data)

File: slime/backends/fsdp_utils/actor.py (L624-624)

        ppo_kl = old_log_probs - log_probs

File: slime/backends/fsdp_utils/actor.py (L691-706)

        loss = pg_loss - self.args.entropy_coef * entropy_loss

        if self.args.use_kl_loss:
            ref_log_probs = torch.cat([batch["ref_log_probs"] for batch in unpacked_batches], dim=0)
            importance_ratio = None
            if self.args.use_unbiased_kl:
                importance_ratio = torch.exp(log_probs - old_log_probs)
            kl = compute_approx_kl(
                log_probs,
                ref_log_probs,
                kl_loss_type=self.args.kl_loss_type,
                importance_ratio=importance_ratio,
            )
            kl_loss = sum_of_sample_mean(kl, response_lengths, loss_masks)

            loss = loss + self.args.kl_loss_coef * kl_loss

File: slime/backends/megatron_utils/model.py (L293-339)

def train_one_step(
    args: Namespace,
    rollout_id: int,
    step_id: int,
    data_iterator: Sequence[DataIterator],
    model: Sequence[DDP],
    optimizer: MegatronOptimizer,
    opt_param_scheduler: OptimizerParamScheduler,
    num_microbatches: int,
) -> tuple[dict[str, float], float]:
    """Execute a single pipeline-parallel training step.

    Runs forward/backward over ``num_microbatches``, applies optimizer step and
    one scheduler step when gradients are valid.

    Args:
        args (Namespace): Runtime arguments.
        rollout_id (int): Rollout identifier.
        step_id (int): Step index within the current rollout.
        data_iterator (Sequence[DataIterator]): Iterable(s) yielding training batches.
        model (Sequence[DDP]): Sequence of DDP-wrapped model chunks.
        optimizer (MegatronOptimizer): Optimizer instance.
        opt_param_scheduler (OptimizerParamScheduler): LR/WD scheduler.
        num_microbatches (int): Number of microbatches to process.

    Returns:
        tuple[dict[str, float], float]: Reduced loss dictionary (last stage only)
        and gradient norm for logging.
    """
    args = get_args()

    # Set grad to zero.
    for model_chunk in model:
        model_chunk.zero_grad_buffer()
    optimizer.zero_grad()

    if args.custom_megatron_before_train_step_hook_path:
        from slime.utils.misc import load_function

        custom_before_train_step_hook = load_function(args.custom_megatron_before_train_step_hook_path)
        custom_before_train_step_hook(args, rollout_id, step_id, model, optimizer, opt_param_scheduler)

    def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_plan: bool = False) -> tuple[
        torch.Tensor,
        Callable[[torch.Tensor], tuple[torch.Tensor, int, dict[str, torch.Tensor | list[str]]]],
    ]:
        """Forward step used by Megatron's pipeline engine during training.

File: docs/en/blogs/introducing_slime.md (L37-45)

slime views the data sampling in RL differently. We manage all SGLang servers within slime with [sgl-router](https://github.com/sgl-project/sglang/tree/main/sgl-router) and provide an interface for the data generation component, **allowing users to inject custom logic and freely interact with SGLang servers**. Unleash their creativity.

![slime architecture](/imgs/arch.png)

With the sgl-router, users only need to send HTTP requests to a single endpoint. By exposing this endpoint, complex agent environments can directly interact with slime through an OpenAI-compatible API — no need to modify the environment, and training-deployment consistency is preserved.

Regarding training schemes, slime uses Ray for resource management, enabling **colocated** (same GPUs) or **decoupled** (separate GPUs) setups with a single flag (`--colocate`).

And with Ray's asynchronous execution via `.remote()`, slime naturally supports asynchronous training. Changing synchronization behavior is as simple as moving the `ray.get` operation. And to make experimenting with different strategies easy, we didn't wrap the code with trainer classes, but simply exposed the training loop in entrypoint  `train.py`.