【RL】slime Training Pipeline
The overall flow embodies slime's core design: Ray provides distributed coordination, SGLang handles efficient inference, and Megatron handles stable training, with the two sides seamlessly integrated through carefully designed data-transfer and weight-synchronization mechanisms to form an efficient RL training loop. In short: the rollout returns samples plus rewards, these are converted into train_data, and RolloutManager.generate() hands the result to the trainer as rollout_data_ref.
The complete request-scheduling and loss-computation flow from SGLang to Megatron
1. Request initiation and data generation
RolloutManager coordinates the SGLang engines to generate data:
- RolloutManager.generate() initiates the rollout request [1]
- It calls the generate_rollout function, which produces samples through SGLang [2]
- The SGLang engines run inference, generating responses and computing rewards (a minimal request sketch follows this list)
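As the cited slime introduction notes, the generation component manages all SGLang servers behind an sgl-router and exposes an OpenAI-compatible endpoint, so custom rollout logic can simply issue HTTP requests. The sketch below is only an illustration of such a request; the router address, model name, and sampling parameters are assumptions, not values taken from slime.

```python
import requests

# Hypothetical sgl-router address; the real endpoint is whatever slime exposes.
ROUTER_URL = "http://127.0.0.1:30000"

def generate_response(prompt: str) -> str:
    """Send one OpenAI-compatible chat request through the router (illustrative)."""
    payload = {
        "model": "default",  # model name is an assumption
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.8,
        "max_tokens": 512,
    }
    resp = requests.post(f"{ROUTER_URL}/v1/chat/completions", json=payload, timeout=600)
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]
```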
Data collection and conversion:
- After generation completes, _convert_samples_to_train_data converts the samples into training format [3]
- Key fields such as tokens, response_lengths, rewards, and rollout_log_probs are extracted
- _split_train_data_by_dp then partitions the data across data-parallel ranks [4] (a partitioning sketch follows this list)
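For reference, the cited _split_train_data_by_dp either balances partitions by sequence length (args.balance_data) or assigns samples round-robin. The sketch below shows only the round-robin path, under the assumption that every field in train_data is a per-sample list; the real implementation also keeps raw_reward and total_lengths unsplit and wraps each shard in a Ray object reference.

```python
def split_by_dp_round_robin(train_data: dict[str, list], dp_size: int) -> list[dict]:
    """Minimal sketch of the round-robin split across data-parallel ranks."""
    num_samples = len(train_data["tokens"])
    # Rank i takes samples i, i + dp_size, i + 2 * dp_size, ...
    partitions = [range(i, num_samples, dp_size) for i in range(dp_size)]
    shards = []
    for partition in partitions:
        shard = {key: [values[j] for j in partition] for key, values in train_data.items()}
        shard["partition"] = list(partition)
        shards.append(shard)
    return shards

# Example: 5 samples, dp_size = 2 -> ranks receive indices [0, 2, 4] and [1, 3].
```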
2. Passing data to Megatron
Data is handed over via Ray remote calls:
- The main training loop fetches the data with ray.get(rollout_manager.generate.remote(rollout_id)) [5]
- It then calls actor_model.async_train(rollout_id, rollout_data_ref) to pass the data to the training actor [6] (a driver-loop sketch follows this list)
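Since slime keeps the training loop visible in train.py instead of hiding it behind a trainer class, the driver side of one iteration can be summarized as below. This is a simplified, hypothetical outline of the cited train.py lines; evaluation, the critic path, and offload handling are omitted. Moving the ray.get calls is what switches between synchronous and asynchronous training.

```python
import ray

def run_one_iteration(rollout_manager, actor_model, rollout_id: int) -> None:
    """Simplified driver-side sketch of one synchronous training iteration."""
    # 1. Generate rollout data with SGLang; returns per-DP-rank object references.
    rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
    # 2. Train the actor on this rollout with Megatron.
    ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
    # 3. Push the updated weights back to the SGLang engines.
    actor_model.update_weights()
```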
Megatron receives and processes the data:
- The actor's train(rollout_id, rollout_data_ref) entry point is invoked on each training rank
- _get_rollout_data fetches this data-parallel rank's shard through Ray and moves tokens, loss_masks, and rollout_log_probs to the GPU in advance
- get_data_iterator builds the data iterators and determines the number of microbatches for pipeline-parallel training
3. RL loss computation in detail
Advantage computation:
- compute_advantages_and_returns() computes advantages and returns [10]
- Different algorithms (GRPO, PPO, etc.) use different advantage estimators (a GRPO-style sketch follows this list)
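As one concrete example, a GRPO-style estimator normalizes each reward against the other rollouts generated for the same prompt. The sketch below is illustrative and not taken from slime's compute_advantages_and_returns; it assumes rewards arrive grouped as n_samples_per_prompt consecutive entries per prompt.

```python
import torch

def grpo_advantages(rewards: torch.Tensor, n_samples_per_prompt: int, eps: float = 1e-6) -> torch.Tensor:
    """Illustrative GRPO-style group-normalized advantages (assumed formulation)."""
    # Group consecutive samples belonging to the same prompt.
    grouped = rewards.view(-1, n_samples_per_prompt)
    mean = grouped.mean(dim=-1, keepdim=True)
    std = grouped.std(dim=-1, keepdim=True)
    return ((grouped - mean) / (std + eps)).flatten()
```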
Policy loss computation:
- Compute the PPO KL term: ppo_kl = old_log_probs - log_probs [11]
- compute_policy_loss() computes the clipped surrogate loss [12]
- Core formula (a self-contained sketch follows this list):
ratio = (-ppo_kl).exp()
pg_losses1 = -ratio * advantages
pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
pg_losses = torch.maximum(pg_losses1, pg_losses2)
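Putting the cited pieces together, a self-contained version of the clipped surrogate loss might look like the sketch below. It mirrors the formula above; the reduction (a plain mean instead of masked, length-aware averaging) and the eps_clip / eps_clip_high defaults are simplifying assumptions.

```python
import torch

def compute_policy_loss_sketch(
    log_probs: torch.Tensor,      # current policy log-probs per token
    old_log_probs: torch.Tensor,  # behavior policy log-probs per token
    advantages: torch.Tensor,     # per-token advantages
    eps_clip: float = 0.2,        # assumed default
    eps_clip_high: float = 0.2,   # assumed default; allows asymmetric clipping
) -> torch.Tensor:
    ppo_kl = old_log_probs - log_probs
    ratio = (-ppo_kl).exp()  # exp(log_probs - old_log_probs)
    pg_losses1 = -ratio * advantages
    pg_losses2 = -ratio.clamp(1 - eps_clip, 1 + eps_clip_high) * advantages
    pg_losses = torch.maximum(pg_losses1, pg_losses2)
    return pg_losses.mean()  # slime instead applies loss masks and per-sample means
```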
Total loss combination:
- Base loss: loss = pg_loss - args.entropy_coef * entropy_loss [13]
- Optionally add a KL loss: loss = loss + args.kl_loss_coef * kl_loss (a combined sketch follows this list)
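The sketch below combines the three terms. The k3-style approximation of the KL against the reference policy is only one common choice and is an assumption here; slime selects the estimator through args.kl_loss_type and supports an importance-ratio correction (see the cited compute_approx_kl call).

```python
import torch

def total_loss_sketch(
    pg_loss: torch.Tensor,
    entropy_loss: torch.Tensor,
    log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor,
    entropy_coef: float,
    kl_loss_coef: float,
    use_kl_loss: bool,
) -> torch.Tensor:
    loss = pg_loss - entropy_coef * entropy_loss
    if use_kl_loss:
        # k3-style approximate KL(pi || pi_ref); the real estimator is configurable.
        log_ratio = ref_log_probs - log_probs
        kl = log_ratio.exp() - 1.0 - log_ratio
        loss = loss + kl_loss_coef * kl.mean()
    return loss
```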
4. Training execution and weight update
Megatron training step:
- train_one_step() executes a pipeline-parallel training step [14]
- The forward pass computes logits and the loss
- The backward pass computes gradients
- The optimizer updates the parameters (a single-GPU schematic follows this list)
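Stripped of pipeline parallelism and microbatching, the same step reduces to the familiar zero-grad / forward / backward / optimizer cycle. The sketch below is a deliberately simplified single-GPU schematic and does not reflect the scheduling inside the cited train_one_step; the gradient-clipping value is an assumption.

```python
import torch

def train_step_schematic(model, optimizer, batch, loss_fn):
    """Single-GPU schematic of one training step (no pipeline parallelism)."""
    optimizer.zero_grad()
    logits = model(batch["tokens"])              # forward: compute logits
    loss = loss_fn(logits, batch)                # policy / entropy / KL loss as above
    loss.backward()                              # backward: compute gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # assumed value
    optimizer.step()                             # optimizer: update parameters
    return loss.detach()
```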
Weight synchronization to SGLang:
- After training, actor_model.update_weights() is called [15]
- Weights are synchronized via UpdateWeightFromTensor or UpdateWeightFromDistributed [16] (a selection sketch follows this list)
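As in the cited actor.py lines, the updater class is chosen by the --colocate flag: colocated setups hand tensors over directly on the shared GPUs, while decoupled setups broadcast weights to the remote SGLang engines. The sketch below condenses that dispatch; the placeholder classes stand in for slime's real implementations, which take more constructor arguments.

```python
class UpdateWeightFromTensor:
    """Placeholder for the colocated path (training and inference share GPUs)."""
    def __init__(self, args, model): ...

class UpdateWeightFromDistributed:
    """Placeholder for the decoupled path (weights broadcast to remote engines)."""
    def __init__(self, args, model): ...

def make_weight_updater(args, model):
    cls = UpdateWeightFromTensor if args.colocate else UpdateWeightFromDistributed
    return cls(args, model)
```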
5. Complete flow
RolloutManager.generate() → SGLang inference (samples + rewards) → _convert_samples_to_train_data → _split_train_data_by_dp → ray.get in train.py → actor_model.async_train → _get_rollout_data → compute_advantages_and_returns → policy / entropy / KL loss → train_one_step → actor_model.update_weights() → SGLang engines
Notes
The overall flow reflects slime's core design: Ray provides distributed coordination, SGLang is responsible for efficient inference, and Megatron for stable training; the two are seamlessly integrated through carefully designed data-transfer and weight-synchronization mechanisms, yielding an efficient RL training loop. [17]
Citations
File: slime/ray/rollout.py (L93-108)
def generate(self, rollout_id):
monitor_started = self.args.use_fault_tolerance and self._health_monitor.start()
start_time = time.time()
try:
data, metrics = self._get_rollout_data(rollout_id=rollout_id)
self._save_debug_rollout_data(data, rollout_id=rollout_id, evaluation=False)
_log_rollout_data(rollout_id, self.args, data, metrics, time.time() - start_time)
data = self._convert_samples_to_train_data(data)
return self._split_train_data_by_dp(data, self.train_parallel_config["dp_size"])
finally:
if monitor_started:
self._health_monitor.stop()
self.num_new_engines = init_rollout_engines(self.args, self.pg, self.all_rollout_engines)
else:
self.num_new_engines = 0
File: slime/ray/rollout.py (L215-275)
def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]):
"""
Convert inference generated samples to training data.
"""
raw_rewards, rewards = self._post_process_rewards(samples)
assert len(raw_rewards) == len(samples)
assert len(rewards) == len(samples)
train_data = {
"tokens": [sample.tokens for sample in samples],
"response_lengths": [sample.response_length for sample in samples],
# some reward model, e.g. remote rm, may return multiple rewards,
# we could use key to select the reward.
"rewards": rewards,
"raw_reward": raw_rewards,
"truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples],
"sample_indices": [sample.index for sample in samples],
}
# loss mask
# TODO: compress the loss mask
loss_masks = []
for sample in samples:
# always instantiate loss_mask if not provided
if sample.loss_mask is None:
sample.loss_mask = [1] * sample.response_length
assert (
len(sample.loss_mask) == sample.response_length
), f"loss mask length {len(sample.loss_mask)} != response length {sample.response_length}"
if sample.remove_sample:
sample.loss_mask = [0] * sample.response_length
loss_masks.append(sample.loss_mask)
train_data["loss_masks"] = loss_masks
# overwriting the raw reward
if samples[0].metadata and "raw_reward" in samples[0].metadata:
train_data["raw_reward"] = [sample.metadata["raw_reward"] for sample in samples]
# For rollout buffer
if samples[0].metadata and "round_number" in samples[0].metadata:
train_data["round_number"] = [sample.metadata["round_number"] for sample in samples]
# Add rollout log probabilities for off-policy correction
if samples[0].rollout_log_probs is not None:
train_data["rollout_log_probs"] = [sample.rollout_log_probs for sample in samples]
if samples[0].rollout_routed_experts is not None:
train_data["rollout_routed_experts"] = [sample.rollout_routed_experts for sample in samples]
if samples[0].train_metadata is not None:
train_data["metadata"] = [sample.train_metadata for sample in samples]
if samples[0].multimodal_inputs is not None:
train_data["multimodal_inputs"] = [sample.multimodal_inputs for sample in samples]
if "teacher_log_probs" in samples[0].__dict__:
train_data["teacher_log_probs"] = [sample.teacher_log_probs for sample in samples]
return train_data
File: slime/ray/rollout.py (L280-328)
def _split_train_data_by_dp(self, data, dp_size):
"""Split the train data by data parallel size."""
rollout_data = {}
if "prompt" in data:
rollout_data["prompt"] = data["prompt"]
total_lengths = [len(t) for t in data["tokens"]]
data["total_lengths"] = total_lengths
if self.args.balance_data:
partitions = get_seqlen_balanced_partitions(total_lengths, dp_size, equal_size=True)
else:
partitions = [range(i, len(total_lengths), dp_size) for i in range(dp_size)]
rollout_data_refs = []
for i in range(dp_size):
rollout_data = {}
partition = partitions[i]
rollout_data["partition"] = partition
for key in [
"tokens",
"multimodal_inputs",
"response_lengths",
"rewards",
"truncated",
"loss_masks",
"round_number",
"sample_indices",
"rollout_log_probs",
"rollout_routed_experts",
"prompt",
"teacher_log_probs",
]:
if key not in data:
continue
val = [data[key][j] for j in partition]
rollout_data[key] = val
# keys that need to be splited at train side
for key in [
"raw_reward",
"total_lengths",
]:
if key not in data:
continue
rollout_data[key] = data[key]
rollout_data_refs.append(Box(ray.put(rollout_data)))
return rollout_data_refs
File: slime/rollout/sglang_rollout.py (L330-419)
async def generate_rollout_async(
args: Namespace, rollout_id: int, data_source: Callable[[int], list[list[Sample]]]
) -> tuple[RolloutFnTrainOutput, list[list[Sample]]]:
"""An example to implement the generate_rollout function for an rule based rm rollout generation.
Args:
args: the whole args
rollout_id: int, the id of the rollout, used for deterministic data generation
data_source: the data source to fetch
Returns:
tuple[RolloutFnTrainOutput, list[list[Sample]]]:
- data: a list of groups of samples generated by the rollout, length equals `rollout_batch_size`
- aborted_samples: any partial groups collected during abort when partial_rollout is enabled
"""
assert args.rollout_global_dataset
state = GenerateState(args)
# instantiate data filters
dynamic_filter = (
load_function(args.dynamic_sampling_filter_path) if args.dynamic_sampling_filter_path is not None else None
)
metric_gatherer = _MetricGatherer()
# target_data_size is the total number of valid samples to get
target_data_size = args.rollout_batch_size
data = []
all_data = []
do_print = True
pbar = tqdm(total=target_data_size * args.n_samples_per_prompt, desc="Rollout generation")
while len(data) < target_data_size:
while state.remaining_batch_size < target_data_size:
# get samples from the buffer and submit the generation requests.
samples = data_source(args.over_sampling_batch_size)
state.submit_generate_tasks(samples)
# wait for the generation to finish
done, state.pendings = await asyncio.wait(state.pendings, return_when=asyncio.FIRST_COMPLETED)
for task in done:
group: list[Sample] = task.result()
if do_print:
sample = group[0][0] if isinstance(group[0], list) else group[0]
logger.info(
f"First rollout sample: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}",
)
do_print = False
assert len(group) == args.n_samples_per_prompt
all_data.append(group)
dynamic_filter_output = _call_dynamic_filter(dynamic_filter, args, group)
if not dynamic_filter_output.keep:
metric_gatherer.on_dynamic_filter_drop(reason=dynamic_filter_output.reason)
state.remaining_batch_size -= 1
continue
# add the samples to the data
# NOTE: here we have not stored all the unused samples back to the data buffer.
if len(data) < target_data_size:
data.append(group)
pbar.update(args.n_samples_per_prompt)
pbar.close()
sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0]
logger.info(
f"Finish rollout: {[str(sample.prompt) + sample.response]}, label: {sample.label}, reward: {sample.reward}",
)
# there are still some unfinished requests, abort them
aborted_samples = await abort(args, rollout_id)
assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index)
all_samples = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index)
# reset the global state to prevent effects on the next rollout or eval.
state.reset()
if args.rollout_sample_filter_path is not None:
filter_func = load_function(args.rollout_sample_filter_path)
filter_func(args, data)
# There can be circumstances where users want to process all samples including filtered ones.
if args.rollout_all_samples_process_path is not None:
process_func = load_function(args.rollout_all_samples_process_path)
process_func(args, all_samples, data_source)
return RolloutFnTrainOutput(samples=data, metrics=metric_gatherer.collect()), aborted_samples
File: train.py (L68-68)
rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id))
File: train.py (L76-79)
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
ray.get(critic_train_handle)
else:
ray.get(actor_model.async_train(rollout_id, rollout_data_ref))
File: train.py (L91-91)
actor_model.update_weights()
File: slime/backends/megatron_utils/actor.py (L130-137)
update_weight_cls = UpdateWeightFromTensor if self.args.colocate else UpdateWeightFromDistributed
self.weight_updater = update_weight_cls(
self.args,
self.model,
weights_getter=lambda: self.weights_backuper.get("actor"),
model_name=type(self.hf_config).__name__.lower() if self.args.model_name is None else self.args.model_name,
quantization_config=getattr(self.hf_config, "quantization_config", None),
)
File: slime/backends/megatron_utils/actor.py (L182-227)
def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
# Fetch data through ray on CPU, not sure if this will be performance bottleneck.
# Both first pp stage and the last pp stage will receive the data.
rollout_data = process_rollout_data(
self.args,
rollout_data_ref,
mpu.get_data_parallel_rank(with_context_parallel=False),
mpu.get_data_parallel_world_size(with_context_parallel=False),
)
# TODO: this is ugly, move to somewhere else?
# move tokens to GPU in advance
rollout_data["tokens"] = [
torch.tensor(t, dtype=torch.long, device=torch.cuda.current_device()) for t in rollout_data["tokens"]
]
rollout_data["loss_masks"] = [
torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"]
]
if "multimodal_inputs" in rollout_data:
# Move multimodal inputs to GPU in advance
rollout_data["multimodal_inputs"] = [
(
{key: tensor.to(device=torch.cuda.current_device()) for key, tensor in mm_dict.items()}
if mm_dict is not None
else None
)
for mm_dict in rollout_data["multimodal_inputs"]
]
if "rollout_log_probs" in rollout_data:
rollout_data["rollout_log_probs"] = [
torch.tensor(
slice_log_prob_with_cp(log_prob, total_length, response_length),
device=torch.cuda.current_device(),
dtype=torch.float32,
)
for log_prob, total_length, response_length in zip(
rollout_data["rollout_log_probs"],
rollout_data["total_lengths"],
rollout_data["response_lengths"],
strict=False,
)
]
if "rollout_routed_experts" in rollout_data:
rollout_data["rollout_routed_experts"] = [
torch.from_numpy(r) for r in rollout_data["rollout_routed_experts"]
]
return rollout_data
File: slime/backends/megatron_utils/actor.py (L330-344)
def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
if self.args.offload_train:
self.wake_up()
with timer("data_preprocess"):
rollout_data = self._get_rollout_data(rollout_data_ref)
if self.args.debug_rollout_only:
log_rollout_data(rollout_id, self.args, rollout_data)
return
if self.role == "critic":
return self.train_critic(rollout_id, rollout_data)
else:
return self.train_actor(rollout_id, rollout_data)
File: slime/backends/megatron_utils/actor.py (L375-375)
data_iterator, num_microbatches = get_data_iterator(self.args, self.model, rollout_data)
File: slime/backends/megatron_utils/actor.py (L421-421)
compute_advantages_and_returns(self.args, rollout_data)
File: slime/backends/fsdp_utils/actor.py (L624-624)
ppo_kl = old_log_probs - log_probs
File: slime/backends/fsdp_utils/actor.py (L691-706)
loss = pg_loss - self.args.entropy_coef * entropy_loss
if self.args.use_kl_loss:
ref_log_probs = torch.cat([batch["ref_log_probs"] for batch in unpacked_batches], dim=0)
importance_ratio = None
if self.args.use_unbiased_kl:
importance_ratio = torch.exp(log_probs - old_log_probs)
kl = compute_approx_kl(
log_probs,
ref_log_probs,
kl_loss_type=self.args.kl_loss_type,
importance_ratio=importance_ratio,
)
kl_loss = sum_of_sample_mean(kl, response_lengths, loss_masks)
loss = loss + self.args.kl_loss_coef * kl_loss
File: slime/backends/megatron_utils/model.py (L293-339)
def train_one_step(
args: Namespace,
rollout_id: int,
step_id: int,
data_iterator: Sequence[DataIterator],
model: Sequence[DDP],
optimizer: MegatronOptimizer,
opt_param_scheduler: OptimizerParamScheduler,
num_microbatches: int,
) -> tuple[dict[str, float], float]:
"""Execute a single pipeline-parallel training step.
Runs forward/backward over ``num_microbatches``, applies optimizer step and
one scheduler step when gradients are valid.
Args:
args (Namespace): Runtime arguments.
rollout_id (int): Rollout identifier.
step_id (int): Step index within the current rollout.
data_iterator (Sequence[DataIterator]): Iterable(s) yielding training batches.
model (Sequence[DDP]): Sequence of DDP-wrapped model chunks.
optimizer (MegatronOptimizer): Optimizer instance.
opt_param_scheduler (OptimizerParamScheduler): LR/WD scheduler.
num_microbatches (int): Number of microbatches to process.
Returns:
tuple[dict[str, float], float]: Reduced loss dictionary (last stage only)
and gradient norm for logging.
"""
args = get_args()
# Set grad to zero.
for model_chunk in model:
model_chunk.zero_grad_buffer()
optimizer.zero_grad()
if args.custom_megatron_before_train_step_hook_path:
from slime.utils.misc import load_function
custom_before_train_step_hook = load_function(args.custom_megatron_before_train_step_hook_path)
custom_before_train_step_hook(args, rollout_id, step_id, model, optimizer, opt_param_scheduler)
def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_plan: bool = False) -> tuple[
torch.Tensor,
Callable[[torch.Tensor], tuple[torch.Tensor, int, dict[str, torch.Tensor | list[str]]]],
]:
"""Forward step used by Megatron's pipeline engine during training.
File: docs/en/blogs/introducing_slime.md (L37-45)
slime views the data sampling in RL differently. We manage all SGLang servers within slime with [sgl-router](https://github.com/sgl-project/sglang/tree/main/sgl-router) and provide an interface for the data generation component, **allowing users to inject custom logic and freely interact with SGLang servers**. Unleash their creativity.

With the sgl-router, users only need to send HTTP requests to a single endpoint. By exposing this endpoint, complex agent environments can directly interact with slime through an OpenAI-compatible API — no need to modify the environment, and training-deployment consistency is preserved.
Regarding training schemes, slime uses Ray for resource management, enabling **colocated** (same GPUs) or **decoupled** (separate GPUs) setups with a single flag (`--colocate`).
And with Ray's asynchronous execution via `.remote()`, slime naturally supports asynchronous training. Changing synchronization behavior is as simple as moving the `ray.get` operation. And to make experimenting with different strategies easy, we didn't wrap the code with trainer classes, but simply exposed the training loop in entrypoint `train.py`.