verl - vllm_rollout_spmd.py 详解
rollout 是强化学习中非常重要的环节。我们这里研究对于多模态VQA任务,rollout 过程中究竟发生了什么。
·
1. 背景
rollout 是强化学习中非常重要的环节。我们这里研究对于多模态VQA任务,rollout 过程中究竟发生了什么。
2. 重要packages 版本
verl: 0.2.0.dev0
# 注意:基于Qwen2.5-VL 跑多模态任务,vllm需要0.8.2或以上版本, 否则会报错
vllm: 0.8.2
3. 当前关注的代码
# 注意:跑多模态任务, vllm_rollout 调用的出处如下,而非vllm_rollout.py !!!
# 这里的sp 意味着sequence parallelism
verl/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py
4. DataProto分析
# 核心方法 generate_sequences
@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
'''
DataProto包含:
batch: TensorDict (input_ids, attention_mask, position_ids)。
meta_info: 如eos_token_id、do_sample(是否采样)
'''
# rebuild vllm cache engine
if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
self.inference_engine.init_cache_engine() # 重建缓存引擎
print('prompts: ', prompts)
我们打印出来看看prompts包含什么呢?从上述代码来看,输入到 generate_sequences 的prompts是一个DataProto 类,这个是verl团队自行定义的数据管理类。
DataProto(
# batch=prompts.batch
batch=TensorDict(
# 可见 当前 bz =406, max_prompt_seq = 8192
fields={
attention_mask: Tensor(shape=torch.Size([406, 8192]), device=cuda:0, dtype=torch.int64, is_shared=True),
input_ids: Tensor(shape=torch.Size([406, 8192]), device=cuda:0, dtype=torch.int64, is_shared=True),
position_ids: Tensor(shape=torch.Size([406, 8192]), device=cuda:0, dtype=torch.int64, is_shared=True)
},
# batch_size = prompts.batch_size
batch_size=torch.Size([406]), device=cuda:0, is_shared=True),
# non_tensor_batch = prompts.non_tensor_batch
non_tensor_batch = {
# raw_prompt_ids包含 bz个没有任何padding 的最原始的 prompt_ids
# non_tensor_batch['raw_prompt_ids'] 是一个numpy array 数组,每个元素是一个List, 表示了一个prompt的原始分词结果
'raw_prompt_ids':
array([list([151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151653, 20848, 279, 2701, 6396, 15941, 11, 279, 2550, 315, 26107, 51, 3186, 3, 5973, 624, 2610, 34813, 1744, 911, 279, 32711, 1882, 438, 458, 5306, 1615, 76728, 323, 1221, 3410, 279, 1590, 4226, 13, 576, 32711, 1882, 27732, 7206, 43810, 2878, 366, 26865, 29, 690, 26865, 29, 9492, 13, 576, 1590, 4226, 27732, 7206, 304, 366, 9217, 29, 690, 9217, 29, 9492, 13, 151645, 198, 151644, 77091, 198]), list([151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151653, 16246, 279, 5128, 26107, 88, 15159, 16, 51185, 18085, 38334, 14085, 323, 26107, 88, 15159, 17, 51185, 706, 35093, 14085, 304, 279, 80715, 16184, 1849, 438, 6839, 304, 279, 7071, 11, 279, 6291, 738, 315, 279, 31205, 26107, 18085, 38334, 59, 273, 80, 3226, 517, 3859, 35093, 14085, 374, 7436, 624, 2610, 34813, 1744, 911, 279, 32711, 1882, 438, 458, 5306, 1615, 76728, 323, 1221, 3410, 279, 1590, 4226, 13, 576, 32711, 1882, 27732, 7206, 43810, 2878, 366, 26865, 29, 690, 26865, 29, 9492, 13, 576, 1590, 4226, 27732, 7206, 304, 366, 9217, 29, 690, 9217, 29, 9492, 13, 151645, 198, 151644, 77091, 198]), ...], dtype=object),
# multi_modal_data包含 bz个 Image.PIL 对象
# non_tensor_batch['multi_modal_data'] 是一个numpy array 数组,每个元素是一个Dict。每个Dict中的‘image’ key 对应 A list of Image.PIL object
'multi_modal_data': array([{'image': [<PIL.Image.Image image mode=RGB size=544x501 at 0x71CAE71E62A0>]}, {'image': [<PIL.Image.Image image mode=RGB size=178x149 at 0x71CAE71E44D0>]}, ... ], dtype=object),
# multi_modal_data包含 bz个 vision_ids 对象
# non_tensor_batch['multi_modal_inputs'] 是一个numpy array 数组,每个元素是一个Dict。每个Dict中 'pixel_values' 对应了image tokens, 'image_grid_thw'对应该图片的维度
'multi_modal_inputs': array([{'pixel_values': tensor([[ 1.9303, 1.9303, 1.9303,..., 2.1459, 2.1459, 2.1459],[ 1.9303, 1.9303, 1.9303, ..., 2.1459, 2.1459, 2.1459]]), 'image_grid_thw': tensor([[ 1, 36, 38]])}, ....], dtype=object)
},
# meta_info = prompts.meta_info
meta_info={
'eos_token_id': [151645, 151643],
'pad_token_id': 151643,
'recompute_log_prob': False,
'do_sample': False,
'validate': True
}
Comments:
- prompts.batch[‘input_ids’] 可以得到统一padding 到max_prompt_length的 prompt ids
- prompts.non_tensor_batch[‘non_tensor_batch’][‘raw_prompt_ids’] 可以得到没经过padding的prompt ids
5. 核心代码 generate_sequences 解析
@torch.no_grad()
def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
'''
DataProto包含:
batch: TensorDict (input_ids, attention_mask, position_ids)。
meta_info: 如eos_token_id、do_sample(是否采样)
'''
# rebuild vllm cache engine
if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
self.inference_engine.init_cache_engine() # 重建缓存引擎
# 获得 left-padding (默认是左填充) 过的 prompt_ids, 维度 (406, 8192)
idx = prompts.batch['input_ids']
# 获得左填充的 attention_mask, position_ids, 维度 (406, 8192)
attention_mask = prompts.batch['attention_mask']
position_ids = prompts.batch['position_ids']
# 获得 eos_token_id, Qwen2.5VL对应 [151645, 151643(这个是pad token ID)]
eos_token_id = prompts.meta_info['eos_token_id']
batch_size = idx.size(0) # batch size 是406
# 从prompts DataProto获取 non_tensor_batch 对应部分
non_tensor_batch = prompts.non_tensor_batch
# non_tensor_batch 当中没有包含 raw_prompt_ids的信息
if 'raw_prompt_ids' not in non_tensor_batch:
# 则借用 _pre_process_inputs 去掉每个instance的left padding、
non_tensor_batch['raw_prompt_ids'] = np.array(
[_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)
# 检查 non_tensor_batch['raw_prompt_ids'] 是否为bz 个
if batch_size != len(non_tensor_batch['raw_prompt_ids']):
raise RuntimeError('vllm sharding manager is not work properly.')
# 有多模态数据的情况
if 'multi_modal_data' in non_tensor_batch: # 多模态数据处理
# 用来存放 每个instance的(prompt, image)相关信息
# vllm_input 会在之后传入self.inference_engine.generate
vllm_inputs = []
for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'),
non_tensor_batch.pop('multi_modal_data')):
# vllm_inputs 中每一个instance都是一个字典,'prompt_token_ids'对应原始prompt_ids, 'multi_modal_data'对应 Image.PIL 对象
vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data})
# 没有多模态数据的情况
else:
# non_tensor_batch.pop('raw_prompt_ids') 会直接得到List of raw_prompt_ids
vllm_inputs = [{
'prompt_token_ids': raw_prompt_ids
} for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')]
# 获取 do_sample 参数
do_sample = prompts.meta_info.get('do_sample', True)
# 获取 validate 参数, 默认为False
is_validate = prompts.meta_info.get('validate', False)
if not do_sample:
kwargs = {
'best_of': 1,
'top_p': 1.0,
'top_k': -1,
'min_p': 0.0,
'temperature': 0,
'n': 1 # if greedy, only 1 response
}
elif is_validate:
# TODO: try **
kwargs = {
'top_k': self.config.val_kwargs.top_k,
'top_p': self.config.val_kwargs.top_p,
'temperature': self.config.val_kwargs.temperature,
'n': 1, # if validate, already repeat in ray_trainer
}
# users can customize different sampling_params at different run
# 采样参数动态调整
with self.update_sampling_params(**kwargs): # 采用上下文管理器,临时修改采样参数(如验证时固定temperature=0)
outputs = self.inference_engine.generate(
prompts=vllm_inputs, # 将刚刚打包好的训练数据传入 vllm
sampling_params=self.sampling_params, # 传入sampling 参数
use_tqdm=False)
# TODO(sgm): disable logprob when recompute_log_prob is enable
# if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
response = []
for output in outputs: # outputs为生成结果的列表,每个元素包含token_ids和logprobs。
# 整个batch data 逐条生成的 token_ids 整理在 response当中
for sample_id in range(len(output.outputs)):
response.append(output.outputs[sample_id].token_ids)
# 对于生成的 response 统一填充到 self.config.response_length
response = pad_2d_list_to_length(response, self.pad_token_id,
max_length=self.config.response_length).to(idx.device)
'''
处理多样本生成(n > 1)时的输入数据扩展, 确保生成的多个样本能与原始输入数据正确对齐。
self.sampling_params.n: 对于每个prompt, 进行3次独立重复响应
为什么需要扩展?
vLLM的generate返回的多样本结果会平铺为(batch_size * n, ...),因此需要同步扩展输入数据以保持对齐。
'''
if self.sampling_params.n > 1 and do_sample: # 若符合条件,则开始扩充
# 将输入张量 idx/attention_mask/position_ids 沿第0维(batch维度)重复n次, 若原始idx形状为(batch_size=2, seq_len=5),n=3,则扩展后形状为(6, 5)。
idx = _repeat_interleave(idx, self.sampling_params.n)
attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
# 扩展后的batch大小 = 原始batch大小 × 样本数n
batch_size = batch_size * self.sampling_params.n
# 处理多模态输入: 对于 non_tensor_batch['multi_modal_inputs'] 进行n倍扩展,之后存入non_tensor_batch
# 确保与文本输入一一对应。
if 'multi_modal_inputs' in non_tensor_batch.keys():
non_tensor_batch['multi_modal_inputs'] = _repeat_interleave(non_tensor_batch['multi_modal_inputs'],
self.sampling_params.n)
# 拼接prompt和response
seq = torch.cat([idx, response], dim=-1)
# 构造attention_mask和position_ids
response_length = response.size(1) # 获取生成的response的token长度
# 生成一个从 1 到 response_length 的递增序列,表示response部分每个token相对于prompt末尾的位置偏移。
# 示例:若response_length=4,生成 [1, 2, 3, 4]。
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
# 处理多维位置编码(如Qwen2VL-MRoPE)
if position_ids.dim() == 3: # qwen2vl mrope
# 某些模型(如Qwen2VL)使用多维位置编码(如3D张量),需调整delta_position_id的形状以匹配。
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
# TODO(sgm): fix position_ids on right_pad
# prompt: left pad + response: right pad
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
'''
获取response的位置编码。
prompt_end=position_ids[:, -1:] 获取prompt最后一个token的位置(如prompt长度为5,则值为4)。
+ delta_position_id: 将response的每个token位置设置为 prompt_end + 1, prompt_end + 2, ...。
'''
response_position_ids = position_ids[:, -1:] + delta_position_id
position_ids = torch.cat([position_ids, response_position_ids], dim=-1) # 将prompt和response的position_ids拼接。
'''
生成一个与response形状相同的掩码,其中 1:有效token(EOS之前的部分), 0:填充token(EOS之后的部分)
response: [token1, token2, EOS, pad, pad] → 掩码:[1, 1, 1, 0, 0]
'''
response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
'''
拼接完整 attention_mask
Prompt部分: 通常是左填充(如[0, 0, 1, 1])。
Response部分: 右填充(如[1, 1, 0, 0])。
拼接后:[0, 0, 1, 1, 1, 1, 0, 0]
'''
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)
# all the tp ranks should contain the same data here. data in all ranks are valid
# 封装输出数据
batch = TensorDict(
{
'prompts': idx, # 原始prompt token IDs
'responses': response, # 生成的response token IDs
'input_ids': seq, # 完整序列(prompt + response)here input_ids become the whole sentences
# 'old_log_probs': log_probs, # we will recompute old log prob with actor
'attention_mask': attention_mask, # 完整注意力掩码
'position_ids': position_ids # 完整位置编码
},
batch_size=batch_size)
# free vllm cache engine
if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
self.inference_engine.free_cache_engine()
return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
更多推荐
所有评论(0)