Author: 昇腾实战派 x 哒妮滋
Follow the WeChat official account: AI模力圈

Background

WeMM is the latest-generation multimodal large language model released by WeChatCV. It supports Chinese-English bilingual dialogue over dynamic high-resolution images, is among the strongest models at the ten-billion-parameter scale on multimodal LLM leaderboards, and its overall score (Avg Rank) sits in the first tier. This article documents the complete process of adapting the WeMM multimodal model to the MindIE-LLM inference framework, focusing on the key technical challenges solved during migration: model structure analysis, weight conversion, embedding fusion, and service integration.

Model Repositories

Vision-language models

| Model | Date | Download | Notes |
| --- | --- | --- | --- |
| WeMM-Chat-2K-CN | 2024.06.27 | 🤗 HF link | 🚀🚀 Supports image input at 2K resolution |
| WeMM-Chat-CN | 2024.06.21 | 🤗 HF link | 🚀🚀 Strengthened Chinese-English bilingual dialogue |
| WeMM-1.2(VL) | 2024.06.09 | 🤗 HF link | 🚀🚀 Ranked first among ~10B-parameter models across multimodal LLM benchmarks |

Hardware and Software Environment

| Component | Version |
| --- | --- |
| Server | Atlas 800I A2 |
| MindIE | T71 |
| CANN | T63 |
| PyTorch | 2.1.0 |

Inference Flow Overview

(figure: inference flow overview)

Migration Steps

(figure: migration steps)

Analyzing the Model Structure

(figure: WeMM model structure)

Analysis: comparing against the model structures already present in the MindIE-LLM library, InternLM-Xcomposer2 (浦语·灵笔2) is the closest match. Both are LoRA-finetuned models, both ship the LoRA weights separately inside the model weights, and both use internlm2 as the underlying LLM. Only the ViT part differs, so the InternLM-Xcomposer2 framework can be reused.

Framework Integration

Create a `WeMMRunner` class that inherits from `MultimodalPARunner` and overrides the `init_processor`, `precision_save`, and `infer` methods:

class WeMMRunner(MultimodalPARunner):
    def __init__(self, **kwargs):
        self.processor = None
        super().__init__(**kwargs)
        self.adapter_id = kwargs.get("lora_adapter_id", None)

    def init_processor(self):
        self.processor = self.model.tokenizer

    def precision_save(self, precision_inputs, **kwargs):
        all_generate_text_list = precision_inputs.all_generate_text_list
        image_file_list = precision_inputs.image_file_list
        image_answer_pairs = {}
        for image_file, generate_text in zip(image_file_list, all_generate_text_list):
            image_answer_pairs[image_file] = generate_text
        image_answer_pairs = dict(sorted(image_answer_pairs.items()))
        super().precision_save(precision_inputs, answer_pairs=image_answer_pairs)

    def infer(self, mm_inputs, batch_size, max_output_length, ignore_eos, **kwargs):
        input_texts = mm_inputs.input_texts
        image_path_list = mm_inputs.image_path
        if len(input_texts) != len(image_path_list):
            raise RuntimeError("input_text length must equal input_images length")
        max_iters = None  # let the parent class decide when profiling is enabled
        if not ENV.profiling_enable:
            if self.max_batch_size > 0:
                max_iters = math.ceil(len(image_path_list) / self.max_batch_size)
            else:
                raise RuntimeError(f"max_batch_size {self.max_batch_size} should be > 0, please check")
        return super().infer(mm_inputs, batch_size, max_output_length, ignore_eos, max_iters=max_iters)

To later reuse the parent class's `run` and `prepare_request` methods in the `WeMMRunner` class, the batching-related logic is handled up front in the main function:

...
if len(texts) != image_length:
    texts.extend([texts[-1]] * (image_length - len(texts)))

remainder = image_length % args.max_batch_size
if remainder != 0:
    num_to_add = args.max_batch_size - remainder
    image_path.extend([image_path[-1]] * num_to_add)
    texts.extend([texts[-1]] * num_to_add)
...
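The padding logic above keeps the request count a multiple of max_batch_size by repeating the last sample. As a standalone sketch (function name hypothetical):

```python
def pad_batch(texts, image_paths, max_batch_size):
    """Pad texts/images by repeating the last element so that the
    number of images is a multiple of max_batch_size (sketch only)."""
    texts = list(texts)
    image_paths = list(image_paths)
    # First align the text count with the image count.
    if len(texts) != len(image_paths):
        texts.extend([texts[-1]] * (len(image_paths) - len(texts)))
    remainder = len(image_paths) % max_batch_size
    if remainder != 0:
        num_to_add = max_batch_size - remainder
        image_paths.extend([image_paths[-1]] * num_to_add)
        texts.extend([texts[-1]] * num_to_add)
    return texts, image_paths
```

This way every call into the parent class's batching path sees full batches.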

Weight Import

The weight naming structure of InternLM-Xcomposer2 is shown below:

(figure: InternLM-Xcomposer2 weight naming)

The weight naming structure of WeMM-Chat-2k-CN is shown below:

(figure: WeMM-Chat-2k-CN weight naming)

The two naming styles differ: in WeMM, the original (pre-LoRA) weight names contain an extra `original_linear` segment, and because it sits in the middle of the name it cannot be handled by passing a prefix. To avoid modifying shared-component code (following the open-closed principle), the chosen approach is a custom renaming script that strips `original_linear` from the WeMM weight names and re-saves the weights.
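At its core the rename is a plain key rewrite; a minimal sketch of the idea (the full script handles safetensors I/O, deduplication, and validation on top of this):

```python
INFIX_WEIGHT_NAME = "original_linear."


def strip_infix(state_dict):
    """Drop the 'original_linear.' infix from every weight name (sketch only)."""
    return {k.replace(INFIX_WEIGHT_NAME, ""): v for k, v in state_dict.items()}
```
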

# Copyright Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import argparse
import os
from datetime import datetime, timezone
from pathlib import Path
import shutil

from typing import List
import torch
from safetensors.torch import save_file, load_file

from atb_llm.utils.hub import weight_files
from atb_llm.utils.log import logger
from atb_llm.utils import file_utils
from atb_llm.utils.convert import _remove_duplicate_names

MAX_TOKENIZER_FILE_SIZE = 1024 * 1024 * 1024
INFIX_WEIGHT_NAME = "original_linear."


def copy_remaining_files(model_dir, dest_dir):
    model_dir = file_utils.standardize_path(model_dir, check_link=False)
    file_utils.check_path_permission(model_dir)
    if os.path.exists(dest_dir):
        dest_dir = file_utils.standardize_path(dest_dir, check_link=False)
        file_utils.check_path_permission(dest_dir)
    else:
        os.makedirs(dest_dir, exist_ok=True)
        dest_dir = file_utils.standardize_path(dest_dir, check_link=False)

    suffix = '.safetensors'
    for filename in file_utils.safe_listdir(model_dir):
        if not filename.endswith(suffix):
            src_filepath = os.path.join(model_dir, filename)
            src_filepath = file_utils.standardize_path(src_filepath, check_link=False)
            file_utils.check_file_safety(src_filepath, 'r', max_file_size=MAX_TOKENIZER_FILE_SIZE)
            dest_filepath = os.path.join(dest_dir, filename)
            dest_filepath = file_utils.standardize_path(dest_filepath, check_link=False)
            file_utils.check_file_safety(dest_filepath, 'w', max_file_size=MAX_TOKENIZER_FILE_SIZE)
            shutil.copyfile(src_filepath, dest_filepath)


def rename_safetensor_file(src_file: Path, dst_file: Path, discard_names: List[str]):
    src_file = file_utils.standardize_path(str(src_file), check_link=False)
    file_utils.check_file_safety(src_file, 'r', is_check_file_size=False)

    loaded_state_dict = load_file(src_file)
    if "state_dict" in loaded_state_dict:
        loaded_state_dict = loaded_state_dict["state_dict"]
    to_remove_dict = _remove_duplicate_names(loaded_state_dict, discard_names=discard_names)

    metadata = {"format": "pt"}
    for kept_name, to_remove_list in to_remove_dict.items():
        for to_remove in to_remove_list:
            if to_remove not in metadata:
                metadata[to_remove] = kept_name
            del loaded_state_dict[to_remove]

    renamed_loaded_state_dict = {}
    for k, v in loaded_state_dict.items():
        if INFIX_WEIGHT_NAME in k:
            k = k.replace(INFIX_WEIGHT_NAME, "")
        renamed_loaded_state_dict[k] = v.contiguous()

    os.makedirs(os.path.dirname(dst_file), exist_ok=True)
    dst_file = file_utils.standardize_path(str(dst_file), check_link=False)
    file_utils.check_file_safety(dst_file, 'w', is_check_file_size=False)
    save_file(renamed_loaded_state_dict, dst_file, metadata=metadata)

    reloaded_state_dict = load_file(dst_file)
    for k, pt_tensor in loaded_state_dict.items():
        k = k.replace(INFIX_WEIGHT_NAME, "")
        sf_tensor = reloaded_state_dict[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")


def rename_safetensor_files(src_files: List[Path], dst_files: List[Path], discard_names: List[str]):
    num_src_files = len(src_files)

    for i, (src_file, dst_file) in enumerate(zip(src_files, dst_files)):
        blacklisted_keywords = ["arguments", "args", "training"]
        if any(substring in src_file.name for substring in blacklisted_keywords):
            continue

        start_time = datetime.now(tz=timezone.utc)
        rename_safetensor_file(src_file, dst_file, discard_names)
        elapsed_time = datetime.now(tz=timezone.utc) - start_time
        logger.info(f"Rename: [{i + 1}/{num_src_files}] -- Took: {elapsed_time}")


def rename_weights(model_path, save_directory):
    local_src_files = weight_files(model_path, extension=".safetensors")
    local_dst_files = [
        Path(save_directory) / f"{s.stem}.safetensors"
        for s in local_src_files
    ]
    rename_safetensor_files(local_src_files, local_dst_files, discard_names=[])
    _ = weight_files(model_path)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--src_model_path', type=str, help="model and tokenizer path")
    parser.add_argument('--save_directory', type=str)
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()
    model_path = args.src_model_path
    save_directory = args.save_directory

    input_model_path = file_utils.standardize_path(model_path, check_link=False)
    file_utils.check_path_permission(input_model_path)

    if not os.path.exists(save_directory):
        os.makedirs(save_directory, exist_ok=True)
    input_save_directory = file_utils.standardize_path(save_directory, check_link=False)
    file_utils.check_path_permission(input_save_directory)

    rename_weights(input_model_path, input_save_directory)
    copy_remaining_files(input_model_path, input_save_directory)

Embedding Fusion

As a multimodal model, WeMM must fuse the visual features produced by the ViT with the text embeddings according to fixed rules, then feed the result into the LLM to produce the final output. This is implemented as the `prepare_inputs_labels_for_multimodal` function of the `FlashWemmForCausalLM` class in `flash_causal_wemm.py`.

    def prepare_inputs_labels_for_multimodal(
        self,
        llm,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        clip_embeddings: Optional[torch.FloatTensor] = None
    ):
        if pixel_values is None and clip_embeddings is None:
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "inputs_embeds": None,
                "labels": labels,
            }

        _labels = labels
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        input_ids = [
            cur_input_ids[cur_attention_mask] 
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [
            cur_labels[cur_attention_mask] 
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]

        new_inputs_embeds = []
        new_labels = []
        new_img_masks = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                cur_clip_emb = clip_embeddings[cur_image_idx] if clip_embeddings is not None else None
                cur_inputs_embeds_1 = llm.model.tok_embeddings(cur_input_ids)
                if cur_clip_emb is not None and cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0], cur_clip_emb[0:0]], dim=0
                    )
                elif cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat([cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
                elif cur_clip_emb is not None:
                    cur_inputs_embeds = torch.cat([cur_inputs_embeds_1, cur_clip_emb[0:0]], dim=0)
                else:
                    raise ValueError
                new_inputs_embeds.append(cur_inputs_embeds)
                new_labels.append(labels[batch_idx])
                new_img_masks.append(torch.zeros(cur_inputs_embeds.shape[0], device=cur_inputs_embeds.device).bool())
                cur_image_idx += 1
                continue

            image_token_indices = (
                [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_inputs_embeds = llm.model.tok_embeddings(torch.cat(cur_input_ids_noim))
            cur_inputs_embeds_no_im = torch.split(cur_inputs_embeds, split_sizes, dim=0)
            cur_new_inputs_embeds = []
            cur_new_labels = []
            cur_img_masks = []

            for i in range(num_images + 1):
                cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                cur_img_masks.append(
                    torch.zeros(cur_inputs_embeds_no_im[i].shape[0], device=cur_inputs_embeds_no_im[i].device).bool()
                )
                if i < num_images:
                    cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                    cur_clip_emb = clip_embeddings[cur_image_idx] if clip_embeddings is not None else None

                    cur_image_idx += 1

                    # discrete token embeddings
                    if cur_pixel_values is not None:
                        cur_new_inputs_embeds.append(cur_pixel_values)
                        cur_img_masks.append(
                            torch.ones(cur_pixel_values.shape[0], device=cur_pixel_values.device).bool()
                        )
                        cur_new_labels.append(
                            torch.full(
                                (cur_pixel_values.shape[0],),
                                IGNORE_INDEX,
                                device=cur_labels.device,
                                dtype=cur_labels.dtype,
                            )
                        )

                    # clip embeddings
                    if cur_clip_emb is not None:
                        cur_new_inputs_embeds.append(cur_clip_emb)
                        cur_img_masks.append(torch.ones(cur_clip_emb.shape[0], device=cur_clip_emb.device).bool())
                        cur_new_labels.append(
                            torch.full(
                                (cur_clip_emb.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype
                            )
                        )

            cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
            cur_new_labels = torch.cat(cur_new_labels)
            cur_img_masks = torch.cat(cur_img_masks)

            new_inputs_embeds.append(cur_new_inputs_embeds)
            new_labels.append(cur_new_labels)
            new_img_masks.append(cur_img_masks)

        # Combine them
        max_len = max(x.shape[0] for x in new_inputs_embeds)
        batch_size = len(new_inputs_embeds)

        new_inputs_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device
        )
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        new_img_masks_padded = torch.zeros((batch_size, max_len), device=new_img_masks[0].device).bool()

        for i, (cur_new_embed, cur_new_labels, cur_new_img_masks) in enumerate(
            zip(new_inputs_embeds, new_labels, new_img_masks)
        ):
            cur_new_embed = cur_new_embed[:max_len]
            cur_new_labels = cur_new_labels[:max_len]
            cur_new_img_masks = cur_new_img_masks[:max_len]

            cur_len = cur_new_embed.shape[0]
            new_inputs_embeds_padded.append(
                torch.cat(
                    (
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ),
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                new_img_masks_padded[i, :cur_len] = cur_new_img_masks

        new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        prepared_data = {
            "input_ids": None,
            "attention_mask": attention_mask,
            "inputs_embeds": new_inputs_embeds,
            "labels": new_labels,
        }
        prepared_data.update({"im_mask": new_img_masks_padded})
        return prepared_data
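Stripped of padding, masks, and labels, the fusion rule implemented above is: split input_ids at each IMAGE_TOKEN_INDEX, embed the text segments, and splice the image features in between. A framework-free sketch of just that rule (the token-to-vector lookup is simplified to a dict):

```python
IMAGE_TOKEN_INDEX = -200


def splice_embeddings(input_ids, text_embed, image_features):
    """Interleave text embeddings with image features at each
    IMAGE_TOKEN_INDEX placeholder (sketch; one image per placeholder)."""
    merged, img_iter = [], iter(image_features)
    for tok in input_ids:
        if tok == IMAGE_TOKEN_INDEX:
            merged.extend(next(img_iter))  # all tokens of this image go here
        else:
            merged.append(text_embed[tok])
    return merged
```

The real function additionally builds im_mask (marking which positions are image tokens), fills labels with IGNORE_INDEX at image positions, and pads every sample in the batch to the same length.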

Service Integration

Multimodal models are currently adapted first under examples in the MindIE-LLM repo (hereafter "the model side"). The goal of service integration is end-to-end model inference driven by URL requests against the backend (the architecture under the mindie_llm directory of MindIE-LLM, hereafter "the in-house inference backend"), which also involves the MindIE-Server component from the MindIE-Service repo (hereafter "the service side"). Mainstream third-party inference backends include Triton, TGI, VLLM, and OpenAI, each with its own URL request format.

Inference flow comparison: model side vs. service side (using Llava as an example)

Model-side inference flow

On the model side, a model is usually run by executing run_pa.py via a script; run_pa.py drives the inference flow through the PARunner class or one of its subclasses:

  1. initialize the processor;
  2. generate token_ids through functions such as prepare_prefill_token implemented in the FlashcausalXXX class, which invoke the ViT or other models that handle multimedia inputs;
  3. feed the token_ids into the model's forward, and so on.

The model side can therefore implement operations like prepare_prefill_token in the FlashcausalXXX class to be called before forward, and can pass extra parameters into forward through kwargs.

(figure: model-side inference flow)

Service-side inference flow

Unlike the model side, the service side only calls the inference backend's interfaces, so it neither runs run_pa.py nor uses the PARunner class. From the integration perspective, the flow has roughly these steps:

  1. the service side uses multiple threads to initialize multiple instances of the IbisTokenizer class under mies_tokenizer;
  2. upon receiving a request, the service side first calls the encode() interface of the IbisTokenizer class to obtain input_ids; note that what is passed into tokenize() is a single batch item;
  3. after continuous batching in the in-house backend, the resulting input_ids are fed directly into the forward of the corresponding FlashcausalXXX class.

The only way for the service side to turn input into input_ids is the encode() interface, which internally calls the model-side Router's tokenize() interface. Once input_ids are obtained, batches are assembled and sent straight to forward(). The whole flow is driven by the in-house backend; the model-side adaptation can neither observe nor change it, so extra parameters cannot be passed in through kwargs.
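The steps above can be sketched end to end with stand-in classes (all names here are illustrative; they are not the real MindIE-Server or IbisTokenizer APIs):

```python
class StubTokenizer:
    """Stands in for IbisTokenizer: encode() delegates to Router.tokenize()."""

    def encode(self, request):
        return [1, 2, 3]  # token ids for the request (fixed for the sketch)


class StubModel:
    """Stands in for the FlashcausalXXX model's forward()."""

    def forward(self, ids):
        return sum(ids)  # placeholder for real inference


def service_side_infer(request, tokenizer, backend_model):
    """Rough shape of the service-side path: encode -> batch -> forward.
    Continuous batching is trivialized to a one-element batch here."""
    input_ids = tokenizer.encode(request)
    batch = [input_ids]  # the in-house backend groups requests at this point
    return [backend_model.forward(ids) for ids in batch]
```

The key observation for the adaptation is that nothing between encode() and forward() is under model-side control.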

(figure: service-side inference flow)

1. Implementing the tokenize() interface

(figures: tokenize() implementation)

The tokenize() function needs to do the following:

  1. tokenize the input prompt with the tokenizer to obtain token_ids;
  2. iterate over the inputs, load and process the image data, compute the number of tokens the image will finally turn into after image_preprocessor, ViT, etc., and pad accordingly;
  3. store the processed pixel_value data in shared memory; note: 1) the data must be converted to a numpy.ndarray before it can be stored; 2) the dtype is not restricted, but the same dtype must be used when reading the data back later;
  4. encode the shared-memory name and the shape of the stored data; the encoding functions involved have been extracted into the common code;
  5. embed the encoded name and shape into input_ids, and return input_ids as a 1-D torch.Tensor on cpu.
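Steps 4 and 5 rely on the common-code helpers encode_shm_name_to_int64 and encode_shape_to_int64, whose internals are not shown here. Purely to illustrate the general idea (this is not the real encoding), a shape can be packed into a single integer by giving each dimension a fixed bit width:

```python
def pack_shape(shape, bits=12):
    """Pack dims into one integer, each dim taking a fixed bit width
    (illustrative only; not the real encode_shape_to_int64)."""
    packed = 0
    for d in shape:
        if d >= (1 << bits):
            raise ValueError("dimension too large for this toy encoding")
        packed = (packed << bits) | d
    return packed


def unpack_shape(packed, ndim, bits=12):
    """Inverse of pack_shape for a known number of dimensions."""
    dims = []
    for _ in range(ndim):
        dims.append(packed & ((1 << bits) - 1))
        packed >>= bits
    return tuple(reversed(dims))
```

The point is that a full shape can ride along in one int64 slot of input_ids, which is what makes step 5 possible.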

Code implementation

# Copyright Huawei Technologies Co., Ltd. 2024-2028. All rights reserved.
import math
import os
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
from PIL import Image

from atb_llm.utils.log import logger
from atb_llm.utils.log.error_code import ErrorCode
from atb_llm.utils.shm_utils import encode_shm_name_to_int64, encode_shape_to_int64, create_shm
from atb_llm.utils.multimodal_utils import safe_open_image
from atb_llm.models.base.router import BaseRouter
from atb_llm.models.base.model_utils import safe_get_tokenizer_from_pretrained
from atb_llm.models.wemm.data_preprocess_wemm import recover_navit_subimages_with_pos_emb
from .image_processor_2k import Idefics2ImageProcessor

_IMAGE = "image"
_TEXT = "text"

DEFAULT_IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_BEGIN_TOKEN = "<img>"
IMAGE_BEGIN_TOKEN_INDEX = -300
DEFAULT_IMAGE_END_TOKEN = "</img>"
IMAGE_END_TOKEN_INDEX = -400
EOS_TOKEN_ID = 92542


def process_shared_memory(pixel_values, shm_name_save_path, data_type):
    shm = create_shm(pixel_values.nbytes, shm_name_save_path)
    shared_array = np.ndarray(pixel_values.shape, dtype=data_type, buffer=shm.buf)
    shared_array[:] = pixel_values
    shm_name = encode_shm_name_to_int64(shm.name)
    shape_value = encode_shape_to_int64(pixel_values.shape)
    return shm_name, shape_value


@dataclass
class WemmRouter(BaseRouter):
    _image_processor: Any = None

    def __post_init__(self):
        super().__post_init__()
        self.tokenizer.eos_token_id = EOS_TOKEN_ID

    @property
    def image_processor(self):
        if self._image_processor is None:
            self._image_processor = self.get_image_processor()
        return self._image_processor

    def get_config(self):
        config_cls = self.get_config_cls()
        config = config_cls.from_dict(self.config_dict)
        super().check_config(config)
        return config

    def get_tokenizer(self):
        tokenizer = safe_get_tokenizer_from_pretrained(
            self.tokenizer_path,
            revision=self.revision,
            padding_side="left",
            trust_remote_code=self.trust_remote_code,
            use_fast=True,
        )
        return tokenizer

    def get_image_processor(self):
        return Idefics2ImageProcessor(self.config.image_processor)

    def tokenize(self, inputs, **kwargs):
        text = ""
        image_num = sum(1 for d in inputs if _IMAGE in d)
        shm_name_save_path = kwargs.get("shm_name_save_path", None)

        if image_num > 1:
            logger.error("Input image numbers can not be greater than 1!", ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
            raise KeyError("Input image numbers can not be greater than 1!")

        for single_input in inputs:
            if single_input.get(_TEXT, None):
                text = single_input.get(_TEXT)
                continue
            if single_input.get(_IMAGE, None):
                image_path = single_input[_IMAGE]
                if shm_name_save_path is None:
                    shm_name_save_dir = os.path.dirname(os.path.dirname(image_path))
                    shm_name_save_path = os.path.join(shm_name_save_dir, "shm_name.txt")

                image_obj = safe_open_image(Image, image_path)
                if image_obj is None:
                    raise ValueError(f"Unrecognized image path input, only local paths are supported, got {image_path}")
                image_rgb = image_obj.convert("RGB")
                image_size = self.config.image_processor["size"]
                navit980_images = self.image_processor(
                    [[image_rgb]],
                    size=image_size,
                    return_tensors="pt",
                    do_image_splitting=self.config.do_image_splitting,
                )
                image_obj.close()

                dim = navit980_images["navit_pixel_values"].shape
                patch_size = self.config.vision_config.patch_size
                visual_dim = math.ceil(dim[2] / patch_size) * math.ceil(dim[3] / patch_size)
                clip_visual_outputs_fake = torch.ones(
                    (dim[0], visual_dim, self.config.vision_config.hidden_size), dtype=torch.float16
                )
                super_image_hidden_states, _, _ = recover_navit_subimages_with_pos_emb(
                    clip_visual_outputs_fake,
                    navit980_images["pixel_attention_mask"],
                    num_sub_images=-1,
                    visual_embedding_group=16,
                    pos_hidden_size=4096,
                    thumbnail_only=True,
                )
                img_token_num = math.ceil(super_image_hidden_states.shape[1] / 4) * math.ceil(
                    super_image_hidden_states.shape[2] / 4
                )
                values_shm_name, values_shape_value = process_shared_memory(
                    navit980_images["navit_pixel_values"], shm_name_save_path, np.float32
                )
                mask_shm_name, mask_shape_value = process_shared_memory(
                    navit980_images["pixel_attention_mask"], shm_name_save_path, np.bool_
                )

            else:
                logger.error(
                    "The input field currently only supports 'image' and 'text'.",
                    ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE,
                )
                raise TypeError("The input field currently only supports 'image' and 'text'.")

        prompt = "<image>" + "\n" + text
        prompt = f"<|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"

        chunk_encode = []
        for idx, chunk in enumerate(prompt.split(DEFAULT_IMAGE_TOKEN)):
            if idx == 0:
                cur_encode = self.tokenizer.encode(chunk)
            else:
                cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
            chunk_encode.append(cur_encode)

        if len(chunk_encode) != 2:
            raise ValueError("The length of chunk_encode should be 2")

        ids = []
        for idx, cur_chunk_encode in enumerate(chunk_encode):
            ids.extend(cur_chunk_encode)
            if idx != len(chunk_encode) - 1:
                ids.append(IMAGE_BEGIN_TOKEN_INDEX)
                ids.extend([IMAGE_TOKEN_INDEX] * img_token_num)
                ids.append(IMAGE_END_TOKEN_INDEX)

        input_ids = torch.tensor(ids)
        bos_pos = torch.where(torch.eq(input_ids, IMAGE_BEGIN_TOKEN_INDEX))[0]
        if input_ids.size(0) < bos_pos + 5:
            msg = "tokenize error, input_ids length is too short."
            logger.error(msg, ErrorCode.ATB_MODELS_PARAM_OUT_OF_RANGE)
            raise ValueError(msg)
        input_ids[bos_pos + 1] = values_shm_name
        input_ids[bos_pos + 2] = values_shape_value
        input_ids[bos_pos + 3] = mask_shm_name
        input_ids[bos_pos + 4] = mask_shape_value

        return input_ids

2. Multimedia data handling in forward()

In forward(), the processing is essentially the inverse transform of tokenize():

  1. locate the insertion positions of the image data via the special tokens;
  2. decode the shared-memory name and the data shape, and read the processed data back from shared memory; note that the third argument dtype of get_data_from_shm() is the dtype used when the data was stored into shared memory, and must match it for the data to be read back correctly; this function returns a tensor;
  3. feed the recovered pixel_value into the ViT to obtain the image embedding;
  4. fuse the image embedding with the text embedding and feed the result into the LLM.

    def prepare_prefill_token_service(self, ids):
        if not torch.any(torch.eq(ids, IMAGE_BEGIN_TOKEN_INDEX)):
            inputs_embeds = self.language_model.model.tok_embeddings(ids)
            self.im_mask = torch.zeros(inputs_embeds.shape[0], 1, device=inputs_embeds.device).to(torch.float16)
            return inputs_embeds

        bos_pos = torch.where(torch.eq(ids, IMAGE_BEGIN_TOKEN_INDEX))[0]
        eos_pos = torch.where(torch.eq(ids, IMAGE_END_TOKEN_INDEX))[0]
        values_shm_name = ids[bos_pos + 1]
        values_shape_value = ids[bos_pos + 2]
        mask_shm_name = ids[bos_pos + 3]
        mask_shape_value = ids[bos_pos + 4]

        navit980_images = {}
        navit980_images["navit_pixel_values"] = get_data_from_shm(
            values_shm_name, values_shape_value, np.float32, self.device
        )
        navit980_images["pixel_attention_mask"] = get_data_from_shm(
            mask_shm_name, mask_shape_value, np.bool_, self.device
        )
        merged_visual_embeddings = merge_visual_embed(
            navit980_images, self.vision_tower, self.downsampler, self.visual_source_spliter_emb
        )

        pixel_values = None
        in_ids = []
        in_ids.extend(ids[:bos_pos].cpu().numpy().tolist())
        in_ids.append(IMAGE_TOKEN_INDEX)
        in_ids.extend(ids[eos_pos + 1 :].cpu().numpy().tolist())
        in_ids = torch.tensor(in_ids).npu().unsqueeze(0)
        mm_inputs = self.prepare_inputs_labels_for_multimodal(
            llm=self.language_model,
            input_ids=in_ids,
            pixel_values=pixel_values,
            clip_embeddings=merged_visual_embeddings,
        )

        self.im_mask = mm_inputs.get("im_mask", None)
        if self.im_mask is not None:
            self.im_mask = self.im_mask.view(1, -1)
            self.im_mask = self.im_mask.squeeze(0).unsqueeze(-1).to(torch.float16)

        inputs_embeds = mm_inputs.get("inputs_embeds", None)
        inputs_embeds = inputs_embeds.view(inputs_embeds.shape[0] * inputs_embeds.shape[1], inputs_embeds.shape[2])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        is_prefill: bool,
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_seq_len: int,
        lm_head_indices: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        self.language_model.adapter_manager = self.adapter_manager
        kwargs.update({"adapter_ids": ["wemm"]})
        if is_prefill and input_ids.dim() == 1:
            inputs_embeds = self.prepare_prefill_token_service(input_ids)
        else:
            inputs_embeds = input_ids

        return self.language_model.forward(
            inputs_embeds,
            position_ids,
            is_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_seq_len,
            lm_head_indices,
            im_mask=self.im_mask,
            **kwargs,
        )

Issues Encountered During Migration

Pre-computing the image token count

**Problem:** on the WemmRouter side, the ViT and other image-processing models cannot be called, making it hard to compute the final image token count; the initial idea was to insert a single placeholder and defer image processing to the FlashWemmForCausalLM side. However, on the model side, forward() inputs such as position_ids, kv_cache, and block_tables must all be sized and allocated according to the input_ids returned by tokenize().


**Solution:** compute the image token count ahead of time inside tokenize(), reserve the positions with special placeholder tokens, and embed two parameters, the name of the shared-memory file holding the image tensor and the data shape, into the placeholder ids, so that the service side can allocate kv_cache and other resources and the model side can later read the data back.
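The pre-computation works because the token count is pure arithmetic over shapes, which can run on the Router side without the ViT weights. A minimal sketch of the patch-grid part (patch_size and the image size below are illustrative; the real pipeline applies further downsampling on top of this grid):

```python
import math


def num_vit_tokens(height, width, patch_size):
    """Number of ViT patch tokens for an image, computed from shapes only."""
    return math.ceil(height / patch_size) * math.ceil(width / patch_size)
```

For example, a 980x980 image with patch size 14 yields a 70x70 grid, i.e. 4900 patch tokens before downsampling.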

Format conversion inside encode()

**Problem:** tokenize() placed the computed input_ids on the npu, but the service-side encode() method of the IbisTokenizer class converts a Tensor input_ids into a list, and that conversion failed.


**Solution:** set the device of the input_ids returned by tokenize() to cpu; this resolved the issue.

Parameter-passing restriction under the service path

**Problem:** unlike pure model-side inference, under the service path the service side only calls the inference backend's interfaces, so run_pa.py and the PARunner class are never involved, and extra parameters cannot be passed through the kwargs of the forward() interface. As a result the incoming lora adapter id was None, and the framework fell back to the default id "base".


**Solution:** inside the forward() interface, detect whether this is a service-side prefill, and if so add the parameter {"adapter_ids": ["wemm"]}.
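A minimal sketch of that check (helper name hypothetical; in the actual code the update happens directly inside forward(), as shown in the service integration section):

```python
def select_adapter_ids(is_prefill, input_ids_dim, kwargs):
    """When the request comes through the service path (prefill with a flat
    1-D input_ids), pin the lora adapter id explicitly (sketch of the fix)."""
    if is_prefill and input_ids_dim == 1:
        kwargs.update({"adapter_ids": ["wemm"]})
    return kwargs
```
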

If you are interested in multimodal large models, reinforcement learning, Ascend NPU deployment, or model performance optimization, follow 【AI模力圈】 for more.

We will keep publishing:

  1. multimodal model architecture deep dives
  2. reinforcement learning: algorithm theory and practice
  3. Ascend NPU migration, deployment, and pitfall retrospectives
  4. model training and inference performance optimization

Illustrated and quick-read versions are also synced to the WeChat official account / Xiaohongshu.


With "AI" in the mix, 1024 = 2048. Welcome to join the 2048 AI community.
