1. 背景

rollout 是强化学习中非常重要的环节。我们这里研究对于多模态VQA任务,rollout 过程中究竟发生了什么。

2. 重要packages 版本

verl: 0.2.0.dev0
# 注意:基于Qwen2.5-VL 跑多模态任务,vllm需要0.8.2或以上版本, 否则会报错
vllm:  0.8.2 

3. 当前关注的代码

# 注意:跑多模态任务, vllm_rollout 调用的出处如下,而非vllm_rollout.py !!!
# 这里的sp 意味着sequence parallelism
verl/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

4. DataProto分析

# 核心方法 generate_sequences
    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        '''
            DataProto包含:
                batch: TensorDict (input_ids, attention_mask, position_ids)。
                meta_info: 如eos_token_id、do_sample(是否采样)
        '''
        # rebuild vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.init_cache_engine() # 重建缓存引擎

        print('prompts: ', prompts)

我们打印出来看看prompts包含什么呢?从上述代码来看,输入到 generate_sequences 的prompts是一个DataProto 类,这个是verl团队自行定义的数据管理类。

DataProto(
	# batch=prompts.batch 
	batch=TensorDict(
		# 可见 当前 bz =406, max_prompt_seq = 8192
		fields={
				attention_mask: Tensor(shape=torch.Size([406, 8192]), device=cuda:0, dtype=torch.int64, is_shared=True),
				input_ids: Tensor(shape=torch.Size([406, 8192]), device=cuda:0, dtype=torch.int64, is_shared=True),
				position_ids: Tensor(shape=torch.Size([406, 8192]), device=cuda:0, dtype=torch.int64, is_shared=True)
			},
		
	# batch_size = prompts.batch_size
	batch_size=torch.Size([406]), device=cuda:0, is_shared=True),
	# non_tensor_batch = prompts.non_tensor_batch
	non_tensor_batch = {
		# raw_prompt_ids包含 bz个没有任何padding 的最原始的 prompt_ids
		# non_tensor_batch['raw_prompt_ids'] 是一个numpy array 数组,每个元素是一个List, 表示了一个prompt的原始分词结果
		'raw_prompt_ids': 
			array([list([151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151653, 20848, 279, 2701, 6396, 15941, 11, 279, 2550, 315, 26107, 51, 3186, 3, 5973, 624, 2610, 34813, 1744, 911, 279, 32711, 1882, 438, 458, 5306, 1615, 76728, 323, 1221, 3410, 279, 1590, 4226, 13, 576, 32711, 1882, 27732, 7206, 43810, 2878, 366, 26865, 29, 690, 26865, 29, 9492, 13, 576, 1590, 4226, 27732, 7206, 304, 366, 9217, 29, 690, 9217, 29, 9492, 13, 151645, 198, 151644, 77091, 198]), list([151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 151652, 151655, 151653, 16246, 279, 5128, 26107, 88, 15159, 16, 51185, 18085, 38334, 14085, 323, 26107, 88, 15159, 17, 51185, 706, 35093, 14085, 304, 279, 80715, 16184, 1849, 438, 6839, 304, 279, 7071, 11, 279, 6291, 738, 315, 279, 31205, 26107, 18085, 38334, 59, 273, 80, 3226, 517, 3859, 35093, 14085, 374, 7436, 624, 2610, 34813, 1744, 911, 279, 32711, 1882, 438, 458, 5306, 1615, 76728, 323, 1221, 3410, 279, 1590, 4226, 13, 576, 32711, 1882, 27732, 7206, 43810, 2878, 366, 26865, 29, 690, 26865, 29, 9492, 13, 576, 1590, 4226, 27732, 7206, 304, 366, 9217, 29, 690, 9217, 29, 9492, 13, 151645, 198, 151644, 77091, 198]), ...], dtype=object),
		# multi_modal_data包含 bz个 Image.PIL 对象 
		# non_tensor_batch['multi_modal_data'] 是一个numpy array 数组,每个元素是一个Dict。每个Dict中的‘image’ key 对应 A list of Image.PIL object
		'multi_modal_data': array([{'image': [<PIL.Image.Image image mode=RGB size=544x501 at 0x71CAE71E62A0>]},  {'image': [<PIL.Image.Image image mode=RGB size=178x149 at 0x71CAE71E44D0>]}, ... ], dtype=object),
		# multi_modal_data包含 bz个 vision_ids 对象 
		# non_tensor_batch['multi_modal_inputs'] 是一个numpy array 数组,每个元素是一个Dict。每个Dict中 'pixel_values' 对应了image tokens, 'image_grid_thw'对应该图片的维度
		'multi_modal_inputs': array([{'pixel_values':  tensor([[ 1.9303,  1.9303,  1.9303,..., 2.1459,  2.1459,  2.1459],[ 1.9303,  1.9303,  1.9303,  ...,  2.1459,  2.1459,  2.1459]]),  'image_grid_thw': tensor([[ 1, 36, 38]])}, ....], dtype=object)
		}, 
	# meta_info = prompts.meta_info 
	meta_info={
		'eos_token_id': [151645, 151643], 
		'pad_token_id': 151643, 
		'recompute_log_prob': False, 
		'do_sample': False, 
		'validate': True
	}	

Comments:

  • prompts.batch[‘input_ids’] 可以得到统一padding 到max_prompt_length的 prompt ids
  • prompts.non_tensor_batch[‘non_tensor_batch’][‘raw_prompt_ids’] 可以得到没经过padding的prompt ids

5. 核心代码 generate_sequences 解析

	@torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
        '''
            DataProto包含:
                batch: TensorDict (input_ids, attention_mask, position_ids)。
                meta_info: 如eos_token_id、do_sample(是否采样)
        '''
        # rebuild vllm cache engine
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.init_cache_engine() # 重建缓存引擎
		
		# 获得 left-padding (默认是左填充) 过的 prompt_ids,  维度 (406, 8192)
        idx = prompts.batch['input_ids']  
        # 获得左填充的 attention_mask, position_ids, 维度 (406, 8192)
        attention_mask = prompts.batch['attention_mask']
        position_ids = prompts.batch['position_ids'] 

        # 获得 eos_token_id, Qwen2.5VL对应  [151645, 151643(这个是pad token ID)]
        eos_token_id = prompts.meta_info['eos_token_id']
        batch_size = idx.size(0) # batch size 是406

        # 从prompts DataProto获取 non_tensor_batch 对应部分
        non_tensor_batch = prompts.non_tensor_batch
        # non_tensor_batch 当中没有包含 raw_prompt_ids的信息
        if 'raw_prompt_ids' not in non_tensor_batch:
        	# 则借用 _pre_process_inputs 去掉每个instance的left padding、
            non_tensor_batch['raw_prompt_ids'] = np.array(
                [_pre_process_inputs(self.pad_token_id, idx[i]) for i in range(batch_size)], dtype=object)

        # 检查 non_tensor_batch['raw_prompt_ids'] 是否为bz 个
        if batch_size != len(non_tensor_batch['raw_prompt_ids']):
            raise RuntimeError('vllm sharding manager is not work properly.')

        # 有多模态数据的情况
        if 'multi_modal_data' in non_tensor_batch: # 多模态数据处理
        	# 用来存放 每个instance的(prompt, image)相关信息
        	# vllm_input 会在之后传入self.inference_engine.generate
            vllm_inputs = [] 
            for raw_prompt_ids, multi_modal_data in zip(non_tensor_batch.pop('raw_prompt_ids'),
                                                        non_tensor_batch.pop('multi_modal_data')):
                # vllm_inputs 中每一个instance都是一个字典,'prompt_token_ids'对应原始prompt_ids, 'multi_modal_data'对应 Image.PIL 对象
                vllm_inputs.append({'prompt_token_ids': raw_prompt_ids, 'multi_modal_data': multi_modal_data})
        # 没有多模态数据的情况
        else:
            # non_tensor_batch.pop('raw_prompt_ids') 会直接得到List of raw_prompt_ids
            vllm_inputs = [{
                'prompt_token_ids': raw_prompt_ids
            } for raw_prompt_ids in non_tensor_batch.pop('raw_prompt_ids')]

        # 获取 do_sample 参数
        do_sample = prompts.meta_info.get('do_sample', True)
        # 获取 validate 参数, 默认为False
        is_validate = prompts.meta_info.get('validate', False)
        if not do_sample:
            kwargs = {
                'best_of': 1,
                'top_p': 1.0,
                'top_k': -1,
                'min_p': 0.0,
                'temperature': 0,
                'n': 1  # if greedy, only 1 response
            }
        elif is_validate:
            # TODO: try **
            kwargs = {
                'top_k': self.config.val_kwargs.top_k,
                'top_p': self.config.val_kwargs.top_p,
                'temperature': self.config.val_kwargs.temperature,
                'n': 1,  # if validate, already repeat in ray_trainer
            }

        # users can customize different sampling_params at different run
        # 采样参数动态调整
        with self.update_sampling_params(**kwargs): # 采用上下文管理器,临时修改采样参数(如验证时固定temperature=0)
            outputs = self.inference_engine.generate(
                prompts=vllm_inputs,  # 将刚刚打包好的训练数据传入 vllm
                sampling_params=self.sampling_params, # 传入sampling 参数
                use_tqdm=False)

            # TODO(sgm): disable logprob when recompute_log_prob is enable
            # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)

            response = []
            for output in outputs: # outputs为生成结果的列表,每个元素包含token_ids和logprobs。
                # 整个batch data 逐条生成的 token_ids 整理在 response当中
                for sample_id in range(len(output.outputs)):
                    response.append(output.outputs[sample_id].token_ids)

            # 对于生成的 response 统一填充到 self.config.response_length
            response = pad_2d_list_to_length(response, self.pad_token_id,
                                             max_length=self.config.response_length).to(idx.device)

            '''
                处理多样本生成(n > 1)时的输入数据扩展, 确保生成的多个样本能与原始输入数据正确对齐。
                self.sampling_params.n: 对于每个prompt, 进行3次独立重复响应
                为什么需要扩展?
                   vLLM的generate返回的多样本结果会平铺为(batch_size * n, ...),因此需要同步扩展输入数据以保持对齐。 
            '''
            if self.sampling_params.n > 1 and do_sample: # 若符合条件,则开始扩充
                # 将输入张量 idx/attention_mask/position_ids 沿第0维(batch维度)重复n次, 若原始idx形状为(batch_size=2, seq_len=5),n=3,则扩展后形状为(6, 5)。
                idx = _repeat_interleave(idx, self.sampling_params.n) 
                attention_mask = _repeat_interleave(attention_mask, self.sampling_params.n)
                position_ids = _repeat_interleave(position_ids, self.sampling_params.n)
                # 扩展后的batch大小 = 原始batch大小 × 样本数n
                batch_size = batch_size * self.sampling_params.n
                # 处理多模态输入: 对于 non_tensor_batch['multi_modal_inputs'] 进行n倍扩展,之后存入non_tensor_batch
                # 确保与文本输入一一对应。
                if 'multi_modal_inputs' in non_tensor_batch.keys():
                    non_tensor_batch['multi_modal_inputs'] = _repeat_interleave(non_tensor_batch['multi_modal_inputs'],
                                                                                self.sampling_params.n)
            # 拼接prompt和response
            seq = torch.cat([idx, response], dim=-1)
        
        # 构造attention_mask和position_ids
        response_length = response.size(1)  # 获取生成的response的token长度
        # 生成一个从 1 到 response_length 的递增序列,表示response部分每个token相对于prompt末尾的位置偏移。
        # 示例:若response_length=4,生成 [1, 2, 3, 4]。
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.unsqueeze(0).expand(batch_size, -1)
        # 处理多维位置编码(如Qwen2VL-MRoPE)
        if position_ids.dim() == 3:  # qwen2vl mrope
            # 某些模型(如Qwen2VL)使用多维位置编码(如3D张量),需调整delta_position_id的形状以匹配。
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
        # TODO(sgm): fix position_ids on right_pad
        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
        
        '''
            获取response的位置编码。 
         prompt_end=position_ids[:, -1:] 获取prompt最后一个token的位置(如prompt长度为5,则值为4)。
         + delta_position_id: 将response的每个token位置设置为 prompt_end + 1, prompt_end + 2, ...。
        '''
        response_position_ids = position_ids[:, -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1) # 将prompt和response的position_ids拼接。
        
        '''
            生成一个与response形状相同的掩码,其中 1:有效token(EOS之前的部分), 0:填充token(EOS之后的部分)
            response: [token1, token2, EOS, pad, pad] → 掩码:[1, 1, 1, 0, 0]
        '''
        response_attention_mask = get_eos_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype)
        '''
             拼接完整 attention_mask
             Prompt部分: 通常是左填充(如[0, 0, 1, 1])。
             Response部分: 右填充(如[1, 1, 0, 0])。
             拼接后:[0, 0, 1, 1, 1, 1, 0, 0]
        '''
        attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        #  封装输出数据
        batch = TensorDict(
            {
                'prompts': idx, # 原始prompt token IDs
                'responses': response, # 生成的response token IDs
                'input_ids': seq,   # 完整序列(prompt + response)here input_ids become the whole sentences
                # 'old_log_probs': log_probs, # we will recompute old log prob with actor
                'attention_mask': attention_mask, # 完整注意力掩码
                'position_ids': position_ids # 完整位置编码
            },
            batch_size=batch_size)

        # free vllm cache engine 
        if vllm_version in ('0.3.1', '0.4.2', '0.5.4', '0.6.3') and self.config.free_cache_engine:
            self.inference_engine.free_cache_engine()

        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐