Short Answer

Yes, the critic and actor_train are trained at the same time. In all ROLL pipeline implementations, both training steps are launched asynchronously with blocking=False and can execute in parallel. 1 2


Detailed Explanation

Training Execution Mode

In all major pipelines of the ROLL framework, critic and actor_train training runs concurrently:

  1. Asynchronous launch: both training steps are started with blocking=False 3
  2. Parallel execution: because Ray's distributed execution framework is used, the two training steps can run in parallel on different GPUs or nodes (see the sketch below)
  3. Unified collection: the training metrics are gathered later via DataProto.materialize_concat 4
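
The snippet below is a minimal, self-contained sketch (toy code, not from the ROLL repository) of this "async launch, unified collection" pattern on top of Ray: each remote call returns an ObjectRef immediately, the two tasks run concurrently on whatever resources Ray can schedule, and a single ray.get call waits for both results.

import time

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def train_step(name: str) -> dict:
    # Stand-in for one optimizer step of a worker cluster
    time.sleep(1.0)
    return {"worker": name, "loss": 0.1}

start = time.time()
critic_ref = train_step.remote("critic")        # returns an ObjectRef immediately
actor_ref = train_step.remote("actor_train")    # returns an ObjectRef immediately
critic_metrics, actor_metrics = ray.get([critic_ref, actor_ref])  # wait for both
print(f"elapsed ~{time.time() - start:.1f}s")   # ~1s rather than ~2s when the tasks run in parallel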

Implementation Pattern

All pipelines follow the same execution pattern:

# Launch critic training asynchronously (returns Ray ObjectRefs immediately)
critic_train_metrics_refs = self.critic.train_step(batch, blocking=False)

# Once the critic warmup condition is met, launch actor training asynchronously
if self.pipeline_config.critic_warmup <= global_step:
    actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
    # Wait for and collect the actor training metrics
    actor_train_metrics = DataProto.materialize_concat(data_refs=actor_train_metrics_refs)

# Wait for and collect the critic training metrics
# (in the actual pipelines this is gated on adv_estimator == "gae")
if self.pipeline_config.adv_estimator == "gae":
    critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)

Pipeline Consistency

This concurrent training pattern is consistent across all major pipelines:

  • AgenticPipeline: 1
  • RLVRPipeline: 2
  • RLVRMathVLMPipeline: 5
  • RLVRVLMPipeline: 6

Design Advantages

This concurrent design offers the following advantages:

  1. Higher resource utilization: the critic and the actor can train on different GPUs at the same time
  2. Shorter training time: avoids the overhead of running the two updates serially
  3. Consistency: unified metric collection keeps the two training steps in sync

Notes

Note that although training is launched concurrently, the actual degree of parallelism depends on the available GPU resources and the configured parallelism strategy (e.g. data parallelism, tensor parallelism). The critic warmup mechanism (critic_warmup) ensures that, early in training, the critic first learns the value function before actor policy updates begin.
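
A small self-contained toy sketch of the warmup gate (not ROLL code; the threshold corresponds to the pipeline_config.critic_warmup value seen in the citations, and because the comparison is critic_warmup <= global_step, actor updates begin at the step equal to the threshold):

critic_warmup = 3  # hypothetical value; configured via pipeline_config.critic_warmup in ROLL

for global_step in range(6):
    updated = ["critic"]                 # the critic (when used, e.g. with GAE) trains every step
    if critic_warmup <= global_step:     # the actor only starts once warmup is over
        updated.append("actor_train")
    print(global_step, updated)
# steps 0-2 update only the critic; from step 3 onward both models are updated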

Citations

File: roll/pipeline/agentic/agentic_pipeline.py (L280-291)

                            critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

                        # implement critic warmup
                        if self.pipeline_config.critic_warmup <= global_step:
                            # update actor
                            actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
                            actor_train_metrics: DataProto = DataProto.materialize_concat(data_refs=actor_train_metrics_refs)
                            metrics.update(reduce_metrics(actor_train_metrics.meta_info.pop("metrics", {})))

                        if self.pipeline_config.adv_estimator == "gae":
                            critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
                            metrics.update(reduce_metrics(critic_train_metrics.meta_info.pop("metrics", {})))

File: roll/pipeline/rlvr/rlvr_pipeline.py (L703-726)

                        critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

                    with actor_train_timer:
                        # implement critic warmup
                        if self.pipeline_config.critic_warmup <= global_step:
                            # update actor
                            if self.pipeline_config.actor_train.use_dynamic_batching_in_train:
                                batch, dynamic_batching_metrics = dynamic_batching_shard(
                                    batch,
                                    self.actor_train.dp_size,
                                    self.pipeline_config.actor_train.max_tokens_per_microbatch_in_train,
                                    self.pipeline_config.actor_train.sequence_length_round_in_train,
                                    "actor_train/train_step",
                                )
                                metrics_mgr.add_metrics(dynamic_batching_metrics)
                            actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
                            actor_train_metrics: DataProto = DataProto.materialize_concat(
                                data_refs=actor_train_metrics_refs
                            )
                            metrics_mgr.add_reduced_metrics(actor_train_metrics.meta_info.pop("metrics", {}))

                    if self.pipeline_config.adv_estimator == "gae":
                        critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
                        metrics_mgr.add_reduced_metrics(critic_train_metrics.meta_info.pop("metrics", {}))

File: roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py (L487-501)

                        critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

                    with actor_train_timer:
                        # implement critic warmup
                        if not hasattr(self, "critic") or self.pipeline_config.critic_warmup <= global_step:
                            # update actor
                            actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
                            actor_train_metrics: DataProto = DataProto.materialize_concat(
                                data_refs=actor_train_metrics_refs
                            )
                            metrics.update(reduce_metrics(actor_train_metrics.meta_info.pop("metrics", {})))

                    if self.pipeline_config.adv_estimator == "gae":
                        critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
                        metrics.update(reduce_metrics(critic_train_metrics.meta_info.pop("metrics", {})))

File: roll/pipeline/rlvr/rlvr_vlm_pipeline.py (L664-678)

                        critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

                    with actor_train_timer:
                        # implement critic warmup
                        if self.pipeline_config.critic_warmup <= global_step:
                            # update actor
                            actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
                            actor_train_metrics: DataProto = DataProto.materialize_concat(
                                data_refs=actor_train_metrics_refs
                            )
                            metrics_mgr.add_reduced_metrics(actor_train_metrics.meta_info.pop("metrics", {}))

                    if self.pipeline_config.adv_estimator == "gae":
                        critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
                        metrics_mgr.add_reduced_metrics(critic_train_metrics.meta_info.pop("metrics", {}))

DataProto.materialize_concat

DataProto.materialize_concat is a static method that fetches data from Ray ObjectRef(s) and concatenates it into a single DataProto instance. 1

Implementation Details

1. Data Fetching

The function first normalizes its input into a list of ObjectRefs, then uses ray.get() to fetch the actual DataProto objects from the Ray distributed framework: 2

# Fetch objects from Ray
data: List["DataProto"] = ray.get(data_refs, timeout=timeout)

2. Data Concatenation

Once the actual DataProto objects are fetched, the static method DataProto.concat() is called to concatenate them: 3

The implementation of DataProto.concat() consists of: 4

  • Tensor concatenation: all tensor data is joined with torch.cat() 5
  • Non-tensor data concatenation: numpy arrays are joined with custom_np_concatenate() 6
  • Meta-information aggregation: global keys in meta_info (such as "metrics") are aggregated across ranks 7

3. Global Key Handling

The meta_info fields named by global_keys receive special handling (illustrated by the sketch after this list): 8

  • If the value is a dict, each sub-key is concatenated across ranks
  • Otherwise (non-dict values), the per-rank values are collected into a list
  • Keys not listed in global_keys keep only the rank-0 value
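
The following self-contained snippet (illustrative only, not ROLL code) mirrors the aggregation rule from the quoted DataProto.concat implementation for a dict-valued global key such as "metrics": scalar sub-keys are gathered into a per-rank list, array-like sub-keys are concatenated along axis 0.

import numpy as np

# meta_info["metrics"] as reported by two ranks
rank_metrics = [
    {"loss": 0.42, "grad_norm": np.array([1.0, 2.0])},  # rank 0
    {"loss": 0.38, "grad_norm": np.array([0.5])},        # rank 1
]

merged = {}
for sub_key in rank_metrics[0]:
    values = [m[sub_key] for m in rank_metrics]
    if np.isscalar(values[0]):
        merged[sub_key] = np.array(values).tolist()                 # scalars -> [0.42, 0.38]
    else:
        merged[sub_key] = np.concatenate(values, axis=0).tolist()   # arrays  -> [1.0, 2.0, 0.5]

print(merged)  # {'loss': [0.42, 0.38], 'grad_norm': [1.0, 2.0, 0.5]}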

Usage

In the training pipelines, materialize_concat is used to collect the results of asynchronous execution:

# Launch training asynchronously
train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
# Wait for and collect the results
train_metrics = DataProto.materialize_concat(data_refs=train_metrics_refs)

This design realizes an "asynchronous launch, synchronous collection" pattern that makes full use of distributed compute resources. 9

Notes

materialize_concat supports a timeout, configured through the roll_RPC_TIMEOUT environment variable 10. It also accepts ObjectRefWrap inputs, which allow only a selected subset of the results to be collected 11.
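
A short sketch of configuring that timeout (the import path is inferred from the cited file roll/distributed/scheduler/protocol.py; per the quoted implementation, the value is read from the environment and passed to ray.get as a timeout in seconds):

import os

from roll.distributed.scheduler.protocol import DataProto

os.environ["roll_RPC_TIMEOUT"] = "600"  # fail the fetch if the refs are not ready within 600 seconds

# train_metrics_refs: the ObjectRefs returned by a non-blocking train_step call (see the usage example above)
train_metrics = DataProto.materialize_concat(data_refs=train_metrics_refs)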

Citations

File: roll/distributed/scheduler/protocol.py (L647-714)

    def concat(
            data: List["DataProto"],
            *,
            global_keys: Optional[Set[str]] = None,
    ) -> "DataProto":
        """
        Concatenate a list of DataProto objects.

        Parameters
        ----------
        data : List[DataProto]
            List of DataProto instances to be concatenated.
        global_keys : Set[str], optional
            Keys in `meta_info` that should be **aggregated across ranks**.
            - If the value is a dict, each sub-key is concatenated across ranks.
            - Otherwise, values are collected into a list.
            Keys not listed retain only the value from rank 0.

        Returns
        -------
        DataProto
            A new DataProto with concatenated tensors, non-tensor data,
            and processed meta information.
        """
        global_keys = global_keys if global_keys is not None else {"metrics"}

        # ---------- 1. Concatenate tensor / non-tensor batches ----------
        batch_lst = [d.batch for d in data if d.batch is not None]
        new_batch = torch.cat(batch_lst, dim=0) if batch_lst else None

        non_tensor_batch = list_of_dict_to_dict_of_list(
            [d.non_tensor_batch for d in data]
        )
        for k, v in non_tensor_batch.items():
            non_tensor_batch[k] = custom_np_concatenate(v)

        # ---------- 2. Aggregate meta information ----------
        merged_meta = dict(data[0].meta_info)  # start with rank-0 values

        for key in global_keys:
            if key not in merged_meta:
                continue

            values = [d.meta_info.get(key) for d in data]

            # Case 1: dict — aggregate each sub-key across ranks
            if isinstance(merged_meta[key], dict):
                sub_dict = list_of_dict_to_dict_of_list(values)
                for sub_key, sub_list in sub_dict.items():
                    try:
                        if np.isscalar(sub_list[0]):
                            sub_dict[sub_key] = np.array(sub_list).tolist()
                        else:
                            sub_dict[sub_key] = np.concatenate(sub_list, axis=0).tolist()
                    except Exception:
                        # fallback: keep as list
                        sub_dict[sub_key] = sub_list
                merged_meta[key] = sub_dict

            # Case 2: non-dict — collect into list
            else:
                merged_meta[key] = values

        return DataProto(
            batch=new_batch,
            non_tensor_batch=non_tensor_batch,
            meta_info=merged_meta,
        )

File: roll/distributed/scheduler/protocol.py (L817-857)

    def materialize_concat(
            data_refs: Union[List[ray.ObjectRef], ray.ObjectRef, List["ObjectRefWrap"]],
            *,
            global_keys: Optional[Set[str]] = None,
    ) -> "DataProto":
        """
        Fetch a collection of DataProto objects from Ray ObjectRef(s) and concatenate
        them into a single DataProto instance.

        Parameters
        ----------
        data_refs : Union[List[ray.ObjectRef], ray.ObjectRef, List[ObjectRefWrap]]
            Ray object references (or ObjectRefWrap) pointing to DataProto objects.
        global_keys : Optional[Set[str]], optional
            Keys in ``meta_info`` that should be aggregated across all ranks when
            concatenating.  If None, only rank-0 values are kept for all keys.

        Returns
        -------
        DataProto
            The concatenated DataProto instance.
        """
        # Normalize input to List[<reference>]
        if isinstance(data_refs, DataProto):
            data_refs = [data_refs]

        timeout = None
        if "roll_RPC_TIMEOUT" in os.environ:
            timeout = int(os.environ["roll_RPC_TIMEOUT"])

        # Fetch objects from Ray
        if isinstance(data_refs[0], ObjectRefWrap):
            data_refs: List[ObjectRefWrap]
            obj_refs = [ref.obj_ref for ref in data_refs]
            fetched = ray.get(obj_refs, timeout=timeout)
            data = [fetched[i] for i, ref in enumerate(data_refs) if ref.collected]
        else:
            data: List["DataProto"] = ray.get(data_refs, timeout=timeout)

        # Concatenate and apply global aggregation rules
        return DataProto.concat(data, global_keys=global_keys)

File: roll/pipeline/agentic/agentic_pipeline.py (L285-291)

                            actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
                            actor_train_metrics: DataProto = DataProto.materialize_concat(data_refs=actor_train_metrics_refs)
                            metrics.update(reduce_metrics(actor_train_metrics.meta_info.pop("metrics", {})))

                        if self.pipeline_config.adv_estimator == "gae":
                            critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
                            metrics.update(reduce_metrics(critic_train_metrics.meta_info.pop("metrics", {})))