【RL】The relationship between the ROLL critic and the train actor: the critic needs a warm-up
Short answer
Yes, the critic and actor_train are trained at the same time. In all ROLL pipeline implementations, both training steps are launched asynchronously with blocking=False, so they can execute in parallel.
Details
Training execution mode
In all of ROLL's main pipelines, the critic and actor_train training steps are executed concurrently:
- Asynchronous launch: both training steps are started with the blocking=False argument
- Parallel execution: because Ray's distributed execution framework is used, the two training steps can run in parallel on different GPUs or nodes
- Unified collection: the training metrics are later collected together via DataProto.materialize_concat
Implementation pattern
All pipelines follow the same execution pattern:
# Launch critic training (asynchronously)
critic_train_metrics_refs = self.critic.train_step(batch, blocking=False)

# After checking the critic warm-up condition, launch actor training (asynchronously)
if self.pipeline_config.critic_warmup <= global_step:
    actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
    # Collect the actor's training metrics
    actor_train_metrics = DataProto.materialize_concat(data_refs=actor_train_metrics_refs)

# Collect the critic's training metrics
critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
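The same "launch both, collect later" behaviour can be reproduced with plain Ray, independently of ROLL. The following minimal sketch uses hypothetical critic_step/actor_step remote functions (not the real ROLL worker API) purely to illustrate why the two train_step calls can overlap:

import ray

ray.init(ignore_reinit_error=True)

# Hypothetical stand-ins for critic.train_step / actor_train.train_step;
# each remote task can be scheduled on its own worker (and GPU).
@ray.remote
def critic_step(batch):
    return {"critic/loss": 0.5}

@ray.remote
def actor_step(batch):
    return {"actor/loss": 0.3}

batch = {"obs": [1, 2, 3]}

# blocking=False analogue: both calls return ObjectRefs immediately,
# so the two tasks can run concurrently.
critic_ref = critic_step.remote(batch)
actor_ref = actor_step.remote(batch)

# materialize_concat analogue: block here until both results are available.
critic_metrics, actor_metrics = ray.get([critic_ref, actor_ref])
print(critic_metrics, actor_metrics)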
Pipeline consistency
This concurrent training pattern is used consistently across all the main pipelines, including agentic_pipeline.py, rlvr_pipeline.py, rlvr_math_vlm_pipeline.py, and rlvr_vlm_pipeline.py (see the citations below).
Design advantages
This concurrent design offers the following benefits:
- Better resource utilization: the critic and the actor can be trained on different GPUs at the same time
- Shorter training time: the overhead of running the two steps serially is avoided
- Consistent synchronization: the unified metric collection keeps the training steps aligned
Notes
Note that although the training steps are launched concurrently, the actual degree of parallelism depends on the available GPU resources and the configured parallelism strategy (for example data parallelism or tensor parallelism). The critic warm-up mechanism (critic_warmup) ensures that, early in training, the critic first learns the value function before the actor's policy updates begin.
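As a minimal sketch of the warm-up gate only (the critic_warmup value and step count below are made up; the real pipelines read them from pipeline_config), the loop shows that the critic is updated every step while the actor only starts after the warm-up threshold:

critic_warmup = 3  # hypothetical value; ROLL reads this from pipeline_config.critic_warmup

for global_step in range(6):
    updated = ["critic"]                  # the critic trains at every step
    if critic_warmup <= global_step:      # same condition as in the pipelines
        updated.append("actor")           # the actor only trains once warm-up is over
    print(global_step, updated)

# Steps 0-2 update only the critic; from step 3 on, both critic and actor are updated.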
Citations
File: roll/pipeline/agentic/agentic_pipeline.py (L280-291)
critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

# implement critic warmup
if self.pipeline_config.critic_warmup <= global_step:
    # update actor
    actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
    actor_train_metrics: DataProto = DataProto.materialize_concat(data_refs=actor_train_metrics_refs)
    metrics.update(reduce_metrics(actor_train_metrics.meta_info.pop("metrics", {})))

if self.pipeline_config.adv_estimator == "gae":
    critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
    metrics.update(reduce_metrics(critic_train_metrics.meta_info.pop("metrics", {})))
File: roll/pipeline/rlvr/rlvr_pipeline.py (L703-726)
critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

with actor_train_timer:
    # implement critic warmup
    if self.pipeline_config.critic_warmup <= global_step:
        # update actor
        if self.pipeline_config.actor_train.use_dynamic_batching_in_train:
            batch, dynamic_batching_metrics = dynamic_batching_shard(
                batch,
                self.actor_train.dp_size,
                self.pipeline_config.actor_train.max_tokens_per_microbatch_in_train,
                self.pipeline_config.actor_train.sequence_length_round_in_train,
                "actor_train/train_step",
            )
            metrics_mgr.add_metrics(dynamic_batching_metrics)
        actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
        actor_train_metrics: DataProto = DataProto.materialize_concat(
            data_refs=actor_train_metrics_refs
        )
        metrics_mgr.add_reduced_metrics(actor_train_metrics.meta_info.pop("metrics", {}))

if self.pipeline_config.adv_estimator == "gae":
    critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
    metrics_mgr.add_reduced_metrics(critic_train_metrics.meta_info.pop("metrics", {}))
File: roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py (L487-501)
critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

with actor_train_timer:
    # implement critic warmup
    if not hasattr(self, "critic") or self.pipeline_config.critic_warmup <= global_step:
        # update actor
        actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
        actor_train_metrics: DataProto = DataProto.materialize_concat(
            data_refs=actor_train_metrics_refs
        )
        metrics.update(reduce_metrics(actor_train_metrics.meta_info.pop("metrics", {})))

if self.pipeline_config.adv_estimator == "gae":
    critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
    metrics.update(reduce_metrics(critic_train_metrics.meta_info.pop("metrics", {})))
File: roll/pipeline/rlvr/rlvr_vlm_pipeline.py (L664-678)
critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False)

with actor_train_timer:
    # implement critic warmup
    if self.pipeline_config.critic_warmup <= global_step:
        # update actor
        actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
        actor_train_metrics: DataProto = DataProto.materialize_concat(
            data_refs=actor_train_metrics_refs
        )
        metrics_mgr.add_reduced_metrics(actor_train_metrics.meta_info.pop("metrics", {}))

if self.pipeline_config.adv_estimator == "gae":
    critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
    metrics_mgr.add_reduced_metrics(critic_train_metrics.meta_info.pop("metrics", {}))
DataProto.materialize_concat is a static method that fetches data from Ray ObjectRefs and concatenates it into a single DataProto instance.
How it works
1. Data fetching
The function first normalizes its input into a list of ObjectRefs, then uses ray.get() to fetch the actual DataProto objects from the Ray distributed framework:
# Fetch objects from Ray
data: List["DataProto"] = ray.get(data_refs, timeout=timeout)
2. Data concatenation
Once the actual DataProto objects have been fetched, the static method DataProto.concat() is called to concatenate them.
The implementation of DataProto.concat() covers:
- Tensor concatenation: all tensor data is joined with torch.cat()
- Non-tensor data concatenation: numpy arrays are joined with custom_np_concatenate()
- Meta-information aggregation: global keys in meta_info (such as "metrics") are aggregated across ranks
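To make the tensor / non-tensor split concrete, here is a small standalone sketch (plain torch and numpy dictionaries, not real DataProto objects) of what the batch concatenation amounts to:

import numpy as np
import torch

# Two per-rank batches: tensor fields are joined with torch.cat,
# numpy (non-tensor) fields with a numpy concatenation.
rank_batches = [
    {"logits": torch.zeros(2, 4), "prompts": np.array(["a", "b"], dtype=object)},
    {"logits": torch.ones(3, 4), "prompts": np.array(["c", "d", "e"], dtype=object)},
]

logits = torch.cat([b["logits"] for b in rank_batches], dim=0)          # shape (5, 4)
prompts = np.concatenate([b["prompts"] for b in rank_batches], axis=0)  # length 5
print(logits.shape, prompts)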
3. Global key handling
Meta-information fields listed in global_keys receive special treatment:
- If the value is a dict, each sub-key is concatenated across ranks
- If the value is not a dict, the per-rank values are collected into a list
- Keys not listed in global_keys keep only the rank-0 value
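The dict case can be mimicked with a small standalone helper; merge_metrics below is a hypothetical simplification of the aggregation logic for a "metrics" key, without any of DataProto's tensor handling:

import numpy as np

def merge_metrics(per_rank_meta, key="metrics"):
    """Aggregate the `key` dict from each rank's meta_info across ranks (sketch only)."""
    values = [m[key] for m in per_rank_meta]
    merged = {}
    for sub_key in values[0]:
        sub_list = [v[sub_key] for v in values]
        if np.isscalar(sub_list[0]):
            # scalar sub-values are collected into a per-rank list
            merged[sub_key] = np.array(sub_list).tolist()
        else:
            # array-like sub-values are concatenated across ranks
            merged[sub_key] = np.concatenate(sub_list, axis=0).tolist()
    return merged

# Two ranks, each reporting a scalar loss and a per-sample reward array.
rank_meta = [
    {"metrics": {"loss": 0.5, "reward": np.array([1.0, 2.0])}},
    {"metrics": {"loss": 0.4, "reward": np.array([3.0])}},
]
print(merge_metrics(rank_meta))
# {'loss': [0.5, 0.4], 'reward': [1.0, 2.0, 3.0]}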
Usage
In the training pipelines, materialize_concat is used to collect the results of asynchronously launched steps:
# Launch training asynchronously
train_metrics_refs = self.actor_train.train_step(batch, blocking=False)

# Wait for and collect the results
train_metrics = DataProto.materialize_concat(data_refs=train_metrics_refs)
This design implements an "asynchronous launch, synchronous collection" pattern that makes full use of distributed compute resources.
Notes
materialize_concat supports a timeout, configured through the environment variable roll_RPC_TIMEOUT. It also accepts ObjectRefWrap objects, which makes it possible to collect only a subset of the results.
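For instance, the timeout could be configured like this before the collection call (the 600-second value is purely illustrative; the implementation reads the variable as an integer and passes it to ray.get as its timeout):

import os

# Illustrative value: make the internal ray.get in materialize_concat give up after 600 seconds.
os.environ["roll_RPC_TIMEOUT"] = "600"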
Citations
File: roll/distributed/scheduler/protocol.py (L647-714)
def concat(
    data: List["DataProto"],
    *,
    global_keys: Optional[Set[str]] = None,
) -> "DataProto":
    """
    Concatenate a list of DataProto objects.

    Parameters
    ----------
    data : List[DataProto]
        List of DataProto instances to be concatenated.
    global_keys : Set[str], optional
        Keys in `meta_info` that should be **aggregated across ranks**.
        - If the value is a dict, each sub-key is concatenated across ranks.
        - Otherwise, values are collected into a list.
        Keys not listed retain only the value from rank 0.

    Returns
    -------
    DataProto
        A new DataProto with concatenated tensors, non-tensor data,
        and processed meta information.
    """
    global_keys = global_keys if global_keys is not None else {"metrics"}

    # ---------- 1. Concatenate tensor / non-tensor batches ----------
    batch_lst = [d.batch for d in data if d.batch is not None]
    new_batch = torch.cat(batch_lst, dim=0) if batch_lst else None

    non_tensor_batch = list_of_dict_to_dict_of_list(
        [d.non_tensor_batch for d in data]
    )
    for k, v in non_tensor_batch.items():
        non_tensor_batch[k] = custom_np_concatenate(v)

    # ---------- 2. Aggregate meta information ----------
    merged_meta = dict(data[0].meta_info)  # start with rank-0 values
    for key in global_keys:
        if key not in merged_meta:
            continue
        values = [d.meta_info.get(key) for d in data]

        # Case 1: dict — aggregate each sub-key across ranks
        if isinstance(merged_meta[key], dict):
            sub_dict = list_of_dict_to_dict_of_list(values)
            for sub_key, sub_list in sub_dict.items():
                try:
                    if np.isscalar(sub_list[0]):
                        sub_dict[sub_key] = np.array(sub_list).tolist()
                    else:
                        sub_dict[sub_key] = np.concatenate(sub_list, axis=0).tolist()
                except Exception:
                    # fallback: keep as list
                    sub_dict[sub_key] = sub_list
            merged_meta[key] = sub_dict
        # Case 2: non-dict — collect into list
        else:
            merged_meta[key] = values

    return DataProto(
        batch=new_batch,
        non_tensor_batch=non_tensor_batch,
        meta_info=merged_meta,
    )
File: roll/distributed/scheduler/protocol.py (L817-857)
def materialize_concat(
    data_refs: Union[List[ray.ObjectRef], ray.ObjectRef, List["ObjectRefWrap"]],
    *,
    global_keys: Optional[Set[str]] = None,
) -> "DataProto":
    """
    Fetch a collection of DataProto objects from Ray ObjectRef(s) and concatenate
    them into a single DataProto instance.

    Parameters
    ----------
    data_refs : Union[List[ray.ObjectRef], ray.ObjectRef, List[ObjectRefWrap]]
        Ray object references (or ObjectRefWrap) pointing to DataProto objects.
    global_keys : Optional[Set[str]], optional
        Keys in ``meta_info`` that should be aggregated across all ranks when
        concatenating. If None, only rank-0 values are kept for all keys.

    Returns
    -------
    DataProto
        The concatenated DataProto instance.
    """
    # Normalize input to List[<reference>]
    if isinstance(data_refs, DataProto):
        data_refs = [data_refs]

    timeout = None
    if "roll_RPC_TIMEOUT" in os.environ:
        timeout = int(os.environ["roll_RPC_TIMEOUT"])

    # Fetch objects from Ray
    if isinstance(data_refs[0], ObjectRefWrap):
        data_refs: List[ObjectRefWrap]
        obj_refs = [ref.obj_ref for ref in data_refs]
        fetched = ray.get(obj_refs, timeout=timeout)
        data = [fetched[i] for i, ref in enumerate(data_refs) if ref.collected]
    else:
        data: List["DataProto"] = ray.get(data_refs, timeout=timeout)

    # Concatenate and apply global aggregation rules
    return DataProto.concat(data, global_keys=global_keys)
File: roll/pipeline/agentic/agentic_pipeline.py (L285-291)
    actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False)
    actor_train_metrics: DataProto = DataProto.materialize_concat(data_refs=actor_train_metrics_refs)
    metrics.update(reduce_metrics(actor_train_metrics.meta_info.pop("metrics", {})))

if self.pipeline_config.adv_estimator == "gae":
    critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs)
    metrics.update(reduce_metrics(critic_train_metrics.meta_info.pop("metrics", {})))