SDK封装(Software Development Kit封装)

将复杂系统(如大模型推理引擎)的能力,通过标准化接口、工具集、文档封装成易用的软件开发包(SDK),供上层业务系统直接调用

  • 隐藏底层细节(如网络通信、模型加载、参数调优)
  • 提供一致的API(如generateText()),适配不同推理引擎(vLLM/TGI/本地模型)
  • 集成熔断、重试、监控等企业级特性

gRPC(Google Remote Procedure Call)

高性能、开源的远程过程调用(RPC)框架,由Google开发,基于Protocol Buffers(Protobuf)​ 序列化协议,支持多语言、流式通信、负载均衡

  • 二进制传输:比JSON/XML更高效(体积小、解析快),延迟降低50%+
  • 多语言支持:自动生成Java/C++/Python等客户端/服务端代码
  • 流式通信:支持单向/双向流式调用(如实时对话逐字返回)
  • 内置服务治理:负载均衡、熔断、认证(TLS/mTLS)

用gRPC实现SDK封装

  • 模型推理服务:vLLM部署在GPU节点,暴露gRPC接口(支持流式输出)
  • Java SDK:封装gRPC客户端,提供generateStream()等方法,集成熔断(Resilience4j)、监控(Prometheus)
  • Java业务系统:调用SDK实现实时对话逻辑(如WebSocket推送结果给前端)
+-------------------+       +-----------------------+       +-----------------------+
|   Java业务系统     |       |   Model SDK (Java)     |       | 模型推理服务(vLLM)   |
| (Spring Boot)    | gRPC  | (gRPC Stub + 封装)   | gRPC  | (C++/Python后端)     |
| (实时对话服务)   |◄─────►| (连接池/熔断/监控)   |◄─────►| (GPU节点,K8s部署)   |
+-------------------+       +-----------------------+       +-----------------------+

①、定义gRPC服务接口(Protobuf)

syntax = "proto3";
package interface; //包名,避免命名冲突

//推理服务定义
service InferenceService{
	//流式生成(实时对话逐字返回)
	rpc GenerateStream(GenerateRequest) returns (stream GenerateStreamResponse);
	//同步生成(非流式,批量处理用)
	rpc GenerateSync(GenerateRequest) returns (GenerateResponse);
}

// 请求参数
message GenerateRequest {
  string prompt = 1;  // 输入文本(含对话历史)
  Parameters params = 2;  // 推理参数(温度、max_tokens等)
}

// 推理参数
message Parameters {
  float temperature = 1;  // 随机性(0~1,越低越确定)
  int32 max_new_tokens = 2;  // 最大生成token数
  float top_p = 3;  // 核采样(0~1,控制多样性)
  bool stream = 4;  // 是否流式输出(SDK层固定为true)
}

// 流式响应(逐块返回)
message GenerateStreamResponse {
  string text_chunk = 1;  // 文本片段(如一个词/字)
  bool is_final = 2;  // 是否为最终结果
  Usage usage = 3;  // Token用量(可选)
}

// 同步响应(完整结果)
message GenerateResponse {
  string full_text = 1;  // 完整生成文本
  Usage usage = 2;
}

// Token用量统计
message Usage {
  int32 input_tokens = 1;  // 输入token数
  int32 output_tokens = 2;  // 输出token数
}

②、生成gRPC代码(多语言)

用protoc编译器将.proto文件生成Java/C++/Python代码(客户端/服务端骨架)

生成Java客户端代码(需安装protoc和grpc-java插件):

  • InferenceServiceGrpc:gRPC服务基础类
  • GenerateRequest/GenerateStreamResponse:请求/响应消息类
protoc --java_out=./java --grpc-java_out=./java inference_service.proto

③、实现推理服务(vLLM+gRPC服务端)

部署vLLM推理引擎,并通过gRPC暴露服务(以Python为例,vLLM支持Python API)

服务端代码(vllm_grpc_server.py)

import grpc
from concurrent import futures
import inference_service_pb2
import inference_service_pb2_grpc
from vllm import LLM, SamplingParams

# 加载模型(Mistral-7B,量化加速)
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1", quantization="awq")
sampling_params = SamplingParams(temperature=0.7, max_tokens=200)

class InferenceServiceImpl(inference_service_pb2_grpc.InferenceServiceServicer):
    def GenerateStream(self, request, context):
        # 调用vLLM生成文本(流式输出需vLLM支持,此处简化为分块返回)
        outputs = llm.generate([request.prompt], sampling_params)[0].outputs[0].text
        # 模拟流式分块(实际可按token/词拆分)
        chunks = [outputs[i:i+5] for i in range(0, len(outputs), 5)]  # 每5字符一块
        for chunk in chunks:
            yield inference_service_pb2.GenerateStreamResponse(
                text_chunk=chunk,
                is_final=(chunk == chunks[-1])
            )

# 启动gRPC服务(端口50051)
def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    inference_service_pb2_grpc.add_InferenceServiceServicer_to_server(InferenceServiceImpl(), server)
    server.add_insecure_port('[::]:50051')
    server.start()
    server.wait_for_termination()

if __name__ == "__main__":
    serve()

④、开发Java SDK(封装gRPC客户端)

用生成的Java gRPC代码,封装为易用的SDK,集成连接池、熔断、监控

SDK核心代码(LLMSDK.java)

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.stub.StreamObserver;
import inference_service.*;
import com.google.common.util.concurrent.ListenableFuture;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry;
import javax.annotation.PreDestroy;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Consumer;

public class LLMSDK {
    private final ManagedChannel channel;  // gRPC通道(连接池)
    private final InferenceServiceGrpc.InferenceServiceStub asyncStub;  // 异步调用Stub
    private final CircuitBreaker circuitBreaker;  // 熔断保护器
    private final ExecutorService executor = Executors.newCachedThreadPool();  // 线程池

    // 初始化SDK(从配置中心获取服务地址)
    public LLMSDK(String serviceAddress, CircuitBreakerRegistry cbRegistry) {
        this.channel = ManagedChannelBuilder.forTarget(serviceAddress)
                .usePlaintext()  // 生产环境用TLS
                .maxInboundMessageSize(10 * 1024 * 1024)  // 10MB最大消息
                .keepAliveTime(30, java.util.concurrent.TimeUnit.SECONDS)  // 保活
                .build();
        this.asyncStub = InferenceServiceGrpc.newStub(channel);
        this.circuitBreaker = cbRegistry.circuitBreaker("llm-inference");  // 熔断配置
    }

    // 流式调用(实时对话逐字返回)
    public void generateStream(String prompt, Parameters params, Consumer<String> chunkHandler, Runnable onComplete) {
        // 构建请求
        GenerateRequest request = GenerateRequest.newBuilder()
                .setPrompt(prompt)
                .setParams(params)
                .build();

        // 熔断保护下的gRPC调用
        Runnable grpcCall = () -> asyncStub.generateStream(request, new StreamObserver<GenerateStreamResponse>() {
            @Override
            public void onNext(GenerateStreamResponse response) {
                chunkHandler.accept(response.getTextChunk());  // 逐块回调
            }
            @Override
            public void onError(Throwable t) { /* 错误处理 */ }
            @Override
            public void onCompleted() { onComplete.run(); }  // 完成回调
        });

        // 执行熔断保护调用
        circuitBreaker.executeRunnable(grpcCall);
    }

    // 关闭资源
    @PreDestroy
    public void shutdown() {
        channel.shutdownNow();
        executor.shutdown();
    }

    // 推理参数类(对应Protobuf的Parameters)
    public static class Parameters {
        private float temperature = 0.7f;
        private int maxNewTokens = 200;
        // getter/setter省略
    }
}

⑤、业务系统集成SDK(实时对话)

业务系统调用SDK的generateStream()方法,通过WebSocket将流式结果推送给前端

业务服务代码(RealtimeDialogueService.java)

import org.springframework.stereotype.Service;
import org.springframework.web.socket.WebSocketSession;
import javax.annotation.Resource;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

@Service
public class RealtimeDialogueService {
    @Resource
    private LLMSDK llmSdk;  // 注入SDK

    // 会话ID→WebSocket连接映射
    private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<>();

    // 处理用户提问
    public void handleUserQuery(String sessionId, String userMessage, WebSocketSession session) {
        sessions.put(sessionId, session);  // 保存连接
        // 构建Prompt(含对话历史,简化示例)
        String prompt = "用户:" + userMessage + "\n助手:";
        // 配置推理参数
        LLMSDK.Parameters params = new LLMSDK.Parameters();
        params.setMaxNewTokens(200);
        // 调用SDK流式接口
        llmSdk.generateStream(
            prompt, 
            params,
            chunk -> sendToClient(sessionId, chunk),  // 逐块推送
            () -> sendToClient(sessionId, "[DONE]")   // 完成标记
        );
    }

    // 推送结果给前端(WebSocket)
    private void sendToClient(String sessionId, String chunk) {
        WebSocketSession session = sessions.get(sessionId);
        if (session != null && session.isOpen()) {
            session.sendMessage(new TextMessage(chunk));
        }
    }
}

JNI(Java Native Interface)

Java提供的跨语言调用接口,允许Java代码与C/C++等“原生代码”直接交互(如调用本地动态库.so/.dll)

  • 绕过JVM,直接调用底层优化库(如C++推理引擎)
  • 操作硬件资源(如GPU显存)、调用操作系统API
  • 需手动管理内存(Java GC与C++堆冲突)、线程安全、跨语言调试

当gRPC仍无法满足微秒级延迟(如高频交易风控),可用JNI直接调用C++推理引擎

①、C++推理引擎暴露Native方法

用C++实现推理逻辑(如调用TensorRT加速的模型),并暴露为动态链接库(.so/.dll)

#include <jni.h>
#include <string>
// 假设已集成TensorRT推理库
extern "C" {
    // JNI方法:生成文本(同步)
    JNIEXPORT jstring JNICALL Java_com_example_LLMSDK_generateSync(
        JNIEnv* env, jobject obj, jstring prompt_jstr, jfloat temperature) {
        // 1. 将Java字符串转为C++字符串
        const char* prompt_cstr = env->GetStringUTFChars(prompt_jstr, nullptr);
        std::string prompt(prompt_cstr);
        env->ReleaseStringUTFChars(prompt_jstr, prompt_cstr);

        // 2. 调用C++推理引擎(如TensorRT模型)
        std::string result = inference_engine.generate(prompt, temperature);

        // 3. 返回Java字符串
        return env->NewStringUTF(result.c_str());
    }
}

②、JavaSDK声明Native方法

用native关键字声明JNI方法,加载动态库

//Java SDK代码(LLMSDK.java)
public class LLMSDK {
    static {
        System.loadLibrary("inference_engine");  // 加载C++动态库(libinference_engine.so)
    }

    // 声明Native方法(同步调用)
    public native String generateSync(String prompt, float temperature);

    // 封装业务逻辑(如参数校验、异常处理)
    public String generate(String prompt) {
        return generateSync(prompt, 0.7f);  // 默认温度0.7
    }
}

③、C++编译,用g++编译为动态库(需指定JNI头文件路径)

Java调用:业务系统直接调用LLMSDK.generate(),JVM自动通过JNI桥接C++代码

g++ -shared -fPIC -o libinference_engine.so inference_engine.cpp -I${JAVA_HOME}/include -I${JAVA_HOME}/include/linux
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐