TensorRT LLM Plugin Background

TensorRT LLM Low-Level Architecture

  • The TensorRT LLM Runtime is the part covered in the previous three posts. Below the Runtime sits the TensorRT Engine, but native TensorRT operators are designed for conventional models.
  • To better adapt TensorRT to LLM workloads, TensorRT-LLM introduces the Tensorrt_llm_module, which encapsulates the core computation LLM inference needs (attention, the feed-forward network (FFN), layer normalization, and so on). It is the key carrier in the TensorRT-LLM framework that turns an LLM architecture into an executable compute graph.
  • Inside the Tensorrt_llm_module, the LLM-specific capability is exposed in the form of TensorRT Plugins; below them sit the TensorRT-LLM kernels, i.e. CUDA kernels, which exploit the GPU's parallel compute power.
  • This post explores the software architecture of that whole chain and how the code actually implements it.

TensorRT Plugin

Plugin register

  • The TensorRT Plugin is the middle layer that bridges what sits above and below it. The TensorRT-LLM repository has a plugins folder that holds the plugin modules.
  • It contains many plugin modules; I picked GPTAttentionPlugin, the one used during my debugging sessions (the model I chose is Qwen), as the example for this analysis.
  • First, the TensorRT Plugin registration flow:
// TensorRT-LLM-main\cpp\tensorrt_llm\plugins\api\tllmPlugin.cpp
void initOnLoad()
{
    auto constexpr kLoadPlugins = "TRT_LLM_LOAD_PLUGINS";
    auto const loadPlugins = std::getenv(kLoadPlugins);
    if (loadPlugins && loadPlugins[0] == '1')
    {
        initTrtLlmPlugins(gLogger);
    }
}

bool initTrtLlmPlugins(void* logger, char const* libNamespace)
{
    if (pluginsInitialized)
    {
        return true;
    }

    if (logger)
    {
        gLogger = static_cast<nvinfer1::ILogger*>(logger);
    }
    setLoggerFinder(&gGlobalLoggerFinder);

    auto registry = getPluginRegistry();

    {
        std::int32_t nbCreators;
        auto creators = getPluginCreators(nbCreators);

        for (std::int32_t i = 0; i < nbCreators; ++i)
        {
            auto const creator = creators[i];
            creator->setPluginNamespace(libNamespace);
            registry->registerCreator(*creator, libNamespace);
            if (gLogger)
            {
                auto const msg = tc::fmtstr("Registered plugin creator %s version %s in namespace %s",
                    creator->getPluginName(), creator->getPluginVersion(), libNamespace);
                gLogger->log(nvinfer1::ILogger::Severity::kVERBOSE, msg.c_str());
            }
        }
    }

    {
        std::int32_t nbCreators;
        auto creators = getCreators(nbCreators);

        for (std::int32_t i = 0; i < nbCreators; ++i)
        {
            auto const creator = creators[i];
            registry->registerCreator(*creator, libNamespace);
        }
    }

    pluginsInitialized = true;
    return true;
}
  • Once initOnLoad is called, plugin loading starts: getPluginCreators returns all the constructed PluginCreators, and registerCreator registers each creator with the TensorRT plugin registry. A plugin creator can only be used after it has been registered.
  • Why are creators so important? Because the concrete Plugin instances are created by their Creator.
  • The Plugin instance itself is invoked automatically by TensorRT.
// cpp\tensorrt_llm\plugins\gptAttentionPlugin\gptAttentionPlugin.cpp
// many constructor parameters removed for brevity
IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    PluginFieldParser p{fc->nbFields, fc->fields};
    try
    {
        auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
            p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("vision_start").value(),
            p.getScalar<int32_t>("vision_length").value(), p.getScalar<int32_t>("num_kv_heads").value(),
            static_cast<bool>(p.getScalar<int8_t>("remove_input_padding").value()),
            static_cast<AttentionMaskType>(p.getScalar<int32_t>("mask_type").value()),
            BlockSparseParams{p.getScalar<int32_t>("block_sparse_block_size").value(),
                static_cast<bool>(p.getScalar<int8_t>("block_sparse_homo_head_pattern").value()),
                p.getScalar<int32_t>("block_sparse_num_local_blocks").value(),
                p.getScalar<int32_t>("block_sparse_vertical_stride").value()}
            /* ...many more constructor arguments omitted... */);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
  • getCreators returns two special PluginCreators: EaglePrepareDrafterInputsPluginCreator and DoraPluginCreator.
  • getPluginCreators is effectively the "power switch" of the TensorRT-LLM plugin library: only after it is called are the LLM-specific optimization plugins (attention, quantization, Mamba, distributed, etc.) visible to TensorRT.
  • After all plugin creators are loaded, they are printed out through the logger. (For how an application triggers this registration before loading an engine, see the sketch below.)
    (Figure: plugin registration log)
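  • To make this concrete, here is a minimal sketch (not from the repo) of what a host application does before deserializing a TensorRT-LLM engine: register the plugin creators first, then load the engine. The header path and the single-argument call to initTrtLlmPlugins are assumptions based on the code above; alternatively, setting the TRT_LLM_LOAD_PLUGINS=1 environment variable lets initOnLoad do the same thing when the plugin library is loaded.
// Minimal sketch: register the TensorRT-LLM plugin creators before engine deserialization.
// (Header path assumed from the source layout; error handling omitted.)
#include "NvInfer.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include <fstream>
#include <iterator>
#include <vector>

nvinfer1::ICudaEngine* loadTrtLlmEngine(nvinfer1::ILogger& logger, char const* enginePath)
{
    // Without this call the registry has no GPTAttention creator and
    // deserializing an engine that contains plugin layers will fail.
    initTrtLlmPlugins(&logger);

    std::ifstream file(enginePath, std::ios::binary);
    std::vector<char> blob((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());

    // Note: the runtime must outlive the engine; kept simple here.
    auto* runtime = nvinfer1::createInferRuntime(logger);
    return runtime->deserializeCudaEngine(blob.data(), blob.size());
}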

Plugin Create

  • Creating a TensorRT Plugin essentially means defining a custom plugin. NVIDIA provides a very standard workflow for adding one: at its core you implement the plugin compute class and the plugin factory (Creator) class, and override a handful of key interfaces.
// cpp\tensorrt_llm\plugins\gptAttentionPlugin\gptAttentionPlugin.h
class GPTAttentionPlugin : public GPTAttentionPluginCommon
{
public:
    int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
    // IPluginV2 Methods
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;
    GPTAttentionPlugin* clone() const noexcept override;
};
class GPTAttentionPluginCreator : public GPTAttentionPluginCreatorCommon
{
public:
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
    nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
};

//cpp\tensorrt_llm\plugins\gptAttentionCommon\gptAttentionCommon.h
class GPTAttentionPluginCommon : public BasePlugin {}; // members omitted
class GPTAttentionPluginCreatorCommon : public BaseCreator
{
public:
    GPTAttentionPluginCreatorCommon();
    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
    template <typename T>
    T* deserializePluginImpl(char const* name, void const* serialData, size_t serialLength) noexcept;
protected:
    std::vector<nvinfer1::PluginField> mPluginAttributes;
    nvinfer1::PluginFieldCollection mFC{};
};

//cpp\tensorrt_llm\plugins\common\plugin.h
class BasePlugin : public nvinfer1::IPluginV2DynamicExt
{
public:
    void setPluginNamespace(char const* libNamespace) noexcept override
    {
        mNamespace = libNamespace;
    }

    [[nodiscard]] char const* getPluginNamespace() const noexcept override
    {
        return mNamespace.c_str();
    }

protected:
    std::string mNamespace{api::kDefaultNamespace};
};

class BaseCreator : public nvinfer1::IPluginCreator
{
public:
    void setPluginNamespace(char const* libNamespace) noexcept override
    {
        mNamespace = libNamespace;
    }

    [[nodiscard]] char const* getPluginNamespace() const noexcept override
    {
        return mNamespace.c_str();
    }

protected:
    std::string mNamespace{api::kDefaultNamespace};
};

Plugin class

  • To make the GPTAttentionPlugin inheritance relationships clearer, here is the class diagram.
  • NVIDIA has since moved on to the third-generation plugin interfaces (IPluginV3, IPluginCreatorV3One), but GPTAttention is older functionality and still uses the second-generation interfaces.
  • After GPTAttentionPluginCreator has been registered through initTrtLlmPlugins / registerCreator, GPTAttentionPluginCreator::createPlugin is invoked automatically to new a GPTAttentionPlugin (exactly when is covered in a later section).
  • Its enqueue method is then the interface invoked at inference time; the call site is the enqueueV3 interface covered in the third post, as sketched below.
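  • As a minimal sketch (assumed tensor names, no error handling) of that call site: once an engine built with GPTAttentionPlugin layers has been deserialized, IExecutionContext::enqueueV3 walks the graph and every plugin layer ends up in GPTAttentionPlugin::enqueue on the given CUDA stream.
// Minimal sketch of the enqueueV3 call site ("input_ids"/"logits" tensor names are placeholders).
#include "NvInfer.h"
#include <cuda_runtime_api.h>

void runOneStep(nvinfer1::ICudaEngine& engine, void* inputIds, void* logits, cudaStream_t stream)
{
    nvinfer1::IExecutionContext* context = engine.createExecutionContext();
    context->setTensorAddress("input_ids", inputIds); // bind engine I/O by tensor name
    context->setTensorAddress("logits", logits);
    // enqueueV3 dispatches every layer; IPluginV2Layer nodes call the plugin's enqueue(),
    // which for attention layers is GPTAttentionPlugin::enqueue.
    context->enqueueV3(stream);
    cudaStreamSynchronize(stream);
    delete context;
}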

From GPTAttentionPlugin down to CUDA

  • GPTAttentionPlugin is the attention-layer optimization plugin designed for GPT-style autoregressive large models (Qwen, LLaMA, GPT-3, etc.). Its core job is to replace native TensorRT operators and deliver low-latency, high-throughput, memory-saving attention through hardware acceleration, kernel fusion and memory optimizations.
  • Qwen's attention layer is highly compatible with GPT-style models, so the mature GPTAttentionPlugin is reused to optimize Qwen's attention computation.
  • Flow chart:
    (Figure: CUDA call flow)
  • After enqueueV3 is called, the GPTAttentionPlugin::enqueue interface is invoked; it selects an execution path based on the quantization nvinfer1::DataType.
  • enqueueImpl prepares the request context; enqueueSome then sets the parameters for the various optimization techniques, and once everything is configured:
// GPTAttentionPlugin::enqueueSome
if (is_context) // context (prefill) stage
{
    enqueueContext<T, KVCacheBuffer>(enqueue_params, stream);
}
else // generation (decode) stage
{
    enqueueGeneration<T, KVCacheBuffer>(enqueue_params, stream);
}
  • enqueueGeneration lives in AttentionOp and is responsible for the actual forward computation of the attention layer; the operation really is written in a rather convoluted way...
  • When useKVCache is enabled, invokeShiftKCache is called to process the key cache; the actual work is performed by the shiftKCache CUDA kernel running on the GPU.
// cpp\tensorrt_llm\kernels\unfusedAttentionKernels.cu

template <typename T, typename KVCacheBuffer>
void invokeShiftKCache(KVCacheBuffer const& kvCacheBuffer, KVLinearBuffer const& shiftKCacheBuffer,
    const KvCacheDataType cache_type, int const sizePerHead, int const timestep, int const batch_beam,
    int const kv_head_num, int const beam_width, int const maxKCacheLen, int const sinkTokenLen,
    float const* kScaleQuantOrig, int const* sequence_lengths, int const* input_lengths, int const rotary_embedding_dim,
    float rotary_embedding_base, RotaryScalingType const rotary_scale_type, float rotary_embedding_scale,
    int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type, cudaStream_t stream)
{
    // Block handles K tile.
    int const token_num_in_k = (timestep <= maxKCacheLen) ? timestep : maxKCacheLen;
    int const vec_size = 16u / sizeof(T);
    dim3 block((sizePerHead / vec_size + 31) / 32 * 32);
    dim3 grid(token_num_in_k, kv_head_num, batch_beam);
    size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
                || position_embedding_type == PositionEmbeddingType::kLONG_ROPE
                || position_embedding_type == PositionEmbeddingType::kROPE_M
            ? 2 * rotary_embedding_dim * sizeof(T)
            : 0);

    if (cache_type == KvCacheDataType::INT8)
    {
        shiftKCache<T, int8_t, KVCacheBuffer><<<grid, block, smem_size, stream>>>(kvCacheBuffer, shiftKCacheBuffer,
            sizePerHead, timestep, beam_width, maxKCacheLen, sinkTokenLen, kScaleQuantOrig, sequence_lengths,
            input_lengths, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
            rotary_embedding_max_positions, position_embedding_type);
    }
#ifdef ENABLE_FP8
    else if (cache_type == KvCacheDataType::FP8)
    {
        shiftKCache<T, __nv_fp8_e4m3, KVCacheBuffer><<<grid, block, smem_size, stream>>>(kvCacheBuffer,
            shiftKCacheBuffer, sizePerHead, timestep, beam_width, maxKCacheLen, sinkTokenLen, kScaleQuantOrig,
            sequence_lengths, input_lengths, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
            rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type);
    }
#endif // ENABLE_FP8
    else
    {
        shiftKCache<T, T, KVCacheBuffer><<<grid, block, smem_size, stream>>>(kvCacheBuffer, shiftKCacheBuffer,
            sizePerHead, timestep, beam_width, maxKCacheLen, sinkTokenLen, kScaleQuantOrig, sequence_lengths,
            input_lengths, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type, rotary_embedding_scale,
            rotary_embedding_max_positions, position_embedding_type);
    }
}

// The __global__ kernel that actually executes on the GPU
template <typename T, typename T_cache, typename KVCacheBuffer>
__global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCacheBuffer, int const sizePerHead,
    int const timestep, int const beam_width, int const maxKCacheLen, int const sinkTokenLen,
    float const* kScaleQuantOrig, int const* sequence_lengths, int const* input_lengths, int const rotary_embedding_dim,
    float rotary_embedding_base, RotaryScalingType const rotary_scale_type, float rotary_embedding_scale,
    int const rotary_embedding_max_positions, PositionEmbeddingType const position_embedding_type); // kernel body omitted
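  • To make the launch geometry above concrete, here is a small worked example with hypothetical values (T = half, head size 128): one thread block handles one K tile of one kv head for one batch*beam slot, and the block width is the head dimension divided by the 16-byte vector width, rounded up to a full warp.
// Worked example of the invokeShiftKCache launch arithmetic (hypothetical values).
#include <cstdio>
#include <cuda_fp16.h>

int main()
{
    int const sizePerHead = 128;                                 // head dimension (assumed)
    int const vec_size = 16 / sizeof(__half);                    // 16-byte vector loads -> 8 halves
    int const block_x = (sizePerHead / vec_size + 31) / 32 * 32; // 16 lanes rounded up to 32 threads
    std::printf("block = %d threads, grid = (token_num_in_k, kv_head_num, batch_beam)\n", block_x);
    return 0;
}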

TensorRT Plugin Creation

  • The code above shows createPlugin new-ing a GPTAttentionPlugin, so when exactly is createPlugin called?
  • It happens during the trtllm-build stage: as TensorRT builds the network layer by layer, it parses and adapts each layer on the fly, eventually covering every layer of the Qwen model and deciding whether to replace a native layer with a custom plugin.
# tensorrt_llm\builder.py
# (abridged excerpt from the network-construction step of the build flow,
#  which runs before Builder.build_engine() shown further below)
    with net_guard(network):
        # Prepare
        network.set_named_parameters(model.named_parameters())

        # Forward
        prepare_input_args = {
            "max_batch_size":
            build_config.max_batch_size,
            "max_input_len":
            build_config.max_input_len,
            "max_seq_len":
            build_config.max_seq_len,
            "use_cache":
            build_config.kv_cache_type != KVCacheType.DISABLED,
            "max_beam_width":
            build_config.max_beam_width,
            "max_num_tokens":
            build_config.max_num_tokens,
            "opt_num_tokens":
            build_config.opt_num_tokens,
            "prompt_embedding_table_size":
            build_config.max_prompt_embedding_table_size,
            "max_draft_len":
            build_config.max_draft_len,
            "speculative_decoding_draft_tokens_external":
            build_config.speculative_decoding_mode ==
            SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL,
            "gather_context_logits":
            build_config.gather_context_logits,
            "lora_target_modules":
            build_config.lora_config.lora_target_modules
        }

        if model.config.architecture == "DecoderModel" or "mllama" in model.config.architecture.lower(
        ):
            prepare_input_args["max_seq_len"] = build_config.max_seq_len
            prepare_input_args[
                "max_decoder_input_len"] = build_config.max_input_len
            prepare_input_args[
                "max_encoder_input_len"] = build_config.max_encoder_input_len

        if model.config.architecture == "WhisperEncoder":

            prepare_input_args = {
                "max_batch_size": build_config.max_batch_size,
            }

        if build_config.speculative_decoding_mode == SpeculativeDecodingMode.EAGLE:
            prepare_input_args[
                "spec_decoding_is_generation_length_variable"] = True
            assert build_config.max_batch_size <= 512, "Max batch size > 512 is not supported for EAGLE"
            assert build_config.max_draft_len <= 256, "Max draft len > 256 is not supported for EAGLE"

        if build_config.speculative_decoding_mode == SpeculativeDecodingMode.LOOKAHEAD_DECODING:
            prepare_input_args[
                "spec_decoding_is_generation_length_variable"] = True
        if model.config.architecture == "Qwen2VLForConditionalGeneration" or model.config.architecture == "Qwen2VLModel":
            prepare_input_args[
                'mrope_rotary_cos_sin_size'] = model.config.max_position_embeddings * model.config.rotary_embedding_dim
        if build_config.speculative_decoding_mode == SpeculativeDecodingMode.EAGLE and not build_config.plugin_config.use_paged_context_fmha:
            logger.warning(
                "Paged Context FMHA is required for EAGLE. Turning it on")
            build_config.plugin_config.use_paged_context_fmha = True

        inputs = model.prepare_inputs(**prepare_input_args)
        model(**inputs)
  • The concrete implementation is in the code above: when model(**inputs) runs, every layer of the model is converted into TensorRT layers and custom plugin instances, which are then wrapped as IPluginV2Layer nodes and added to the TensorRT network (see the C++ sketch below).
  • The trtllm-build command eventually calls the build_engine function of class Builder; the core job of build_serialized_network is to take the TensorRT network definition (the compute graph) together with the build configuration and compile and optimize it into a serializable TensorRT engine binary.
# tensorrt_llm\builder.py
    def build_engine(self,
                     network: Network,
                     builder_config: BuilderConfig,
                     managed_weights: dict = None) -> trt.IHostMemory:
        ...
        engine = self.trt_builder.build_serialized_network(
            network.trt_network, builder_config.trt_builder_config)
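  • To make the first bullet above concrete, here is a minimal, hypothetical C++ sketch of what each attention layer in the Python graph ultimately boils down to on the TensorRT side: look up the registered creator, build a plugin instance from a PluginFieldCollection, and wrap it into an IPluginV2Layer. The "GPTAttention"/"1"/"tensorrt_llm" strings, the field list and the example values are illustrative assumptions, not the exact repo code.
// Hypothetical sketch: how a plugin instance becomes an IPluginV2Layer node in the TensorRT network.
#include "NvInfer.h"
#include <vector>

nvinfer1::IPluginV2Layer* addAttentionLayer(
    nvinfer1::INetworkDefinition& network, nvinfer1::ITensor* const* inputs, int32_t nbInputs)
{
    // The creator was registered earlier via initTrtLlmPlugins (name/version/namespace assumed).
    auto* creator = getPluginRegistry()->getPluginCreator("GPTAttention", "1", "tensorrt_llm");

    int32_t numHeads = 32; // example value
    std::vector<nvinfer1::PluginField> fields{
        {"num_heads", &numHeads, nvinfer1::PluginFieldType::kINT32, 1},
        // ... layer_idx, num_kv_heads, mask_type and many more fields omitted
    };
    nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};

    nvinfer1::IPluginV2* plugin = creator->createPlugin("gpt_attention", &fc); // calls createPlugin above
    return network.addPluginV2(inputs, nbInputs, *plugin); // wrapped as an IPluginV2Layer node
}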

  • Once the build has replaced native TensorRT layers with GPTAttentionPlugin, then when the TensorRT engine binary is actually run, the registered plugin can be instantiated and used for inference as long as its creator has been registered, which is what gives the large model its runtime efficiency.

Summary

  • Pulling the whole flow together: a custom TensorRT Plugin is embedded into the TensorRT-LLM framework by creating the plugin during the build (model-conversion) stage, registering the creator when the inference runtime loads the model, and executing the plugin's actual code when the model handles incoming Requests.
  • This efficient and well-structured framework lets large models run with custom Plugins, custom operations and custom CUDA kernels, each optimized specifically for the model at hand.