🚀 Calling a Large Language Model over gRPC: Java Client + Python Server

🎯 Background

Most mainstream business systems today build their server side on the Java ecosystem. So what are the usual options for adding large-model capabilities to them?

🔗 The Traditional Approach: HTTP

The simplest option is plain HTTP: implement the server in Python, expose a POST endpoint that accepts a question, have the Python program call the online large-model API directly (here I use ModelScope as an example), and return the result.

This approach is perfectly workable! 👍 But today I want to try something more efficient: implementing the same feature with gRPC.
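For reference, that HTTP baseline can be sketched with nothing but the Python standard library. This is only an illustration: the endpoint path `/ask`, the JSON shape, and the stubbed `call_model` function are all assumptions, and a real service would more likely use Flask or FastAPI and call the actual model API.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


def call_model(question: str) -> str:
    # Stub standing in for the real LLM call (e.g. an OpenAI-compatible API).
    return f"Echo: {question}"


class AskHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Accept {"question": "..."} on /ask and answer with {"answer": "..."}.
        if self.path != "/ask":
            self.send_error(404)
            return
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length) or b"{}")
        answer = call_model(payload.get("question", ""))
        body = json.dumps({"answer": answer}).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
```

To run it standalone, `HTTPServer(("127.0.0.1", 8000), AskHandler).serve_forever()` starts it on port 8000.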

💡 Why gRPC?

  1. Faster transfers 🚄
    Compared with a typical HTTP+JSON API, gRPC serializes messages as compact Protocol Buffers binary (carried over HTTP/2) rather than JSON text, which shrinks payloads considerably and improves transfer efficiency.

  2. Great service-to-service ergonomics 🤝
    gRPC was designed for inter-service calls from the start: invoking a function in another system feels just like calling a local function, with excellent performance.

  3. First-class streaming support 🌊
    gRPC natively supports streaming, so the client can display each token as it arrives from the model, in real time.
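The payload-size claim in point 1 is easy to see on a single message. Below is a hand-rolled encoding of `QuestionRequest` following the Protocol Buffers wire format. This is a simplified sketch for illustration only: real code uses the classes generated from the .proto file, and this helper handles just one short string field.

```python
import json


def encode_question_request(question: str) -> bytes:
    # Protobuf wire format for `string question = 1`:
    # tag byte = (field_number << 3) | wire_type, where wire type 2
    # means length-delimited; then a varint length, then the UTF-8 bytes.
    data = question.encode("utf-8")
    assert len(data) < 128  # a single-byte varint length suffices for this demo
    return bytes([(1 << 3) | 2, len(data)]) + data


proto_bytes = encode_question_request("hello")
json_bytes = json.dumps({"question": "hello"}).encode("utf-8")
print(len(proto_bytes), len(json_bytes))  # 7 vs 21 bytes
```

No field names travel on the wire, only the field number packed into the tag byte, which is where most of the savings come from.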

🛠️ Hands-On

Let's get straight to it! 💪

Step 1: Define the .proto file

First, we define the gRPC .proto file, which declares the message types and the service:

syntax = "proto3";

package aiservice;

// AI chat service
service AIChatService {
    // Simple request-response
    rpc AskQuestion(QuestionRequest) returns (QuestionResponse);

    // Server-side streaming, used for streamed AI answers
    rpc AskQuestionStream(QuestionRequest) returns (stream QuestionResponse);
}

// Question request
message QuestionRequest {
    string question = 1;
    optional bool stream = 2;  // whether to use a streaming response
}

// Answer response
message QuestionResponse {
    string answer = 1;
    string question = 2;
    optional bool finished = 3;  // whether the stream has ended
    optional int32 total_tokens = 4;  // total token count
}

Step 2: Write the client and server code

Once the .proto file is defined, we can use it to generate the Java and Python bindings and build on top of them.

Java client code
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.stub.StreamObserver;
import aiservice.AIChatServiceGrpc;
import aiservice.Aichat;
import java.util.concurrent.TimeUnit;
import java.util.Iterator;

public class JavaGrpcClient {

    private final ManagedChannel channel;
    private final AIChatServiceGrpc.AIChatServiceBlockingStub blockingStub;
    private final AIChatServiceGrpc.AIChatServiceStub asyncStub;

    public JavaGrpcClient(String host, int port) {
        // Create the gRPC channel
        this.channel = ManagedChannelBuilder.forAddress(host, port)
                .usePlaintext()  // unencrypted connection
                .build();
        // Blocking stub for simple calls
        this.blockingStub = AIChatServiceGrpc.newBlockingStub(channel);
        // Async stub for streaming
        this.asyncStub = AIChatServiceGrpc.newStub(channel);
    }

    public void shutdown() throws InterruptedException {
        channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
    }

    public String askQuestion(String question) {
        try {
            // Build the request
            Aichat.QuestionRequest request = Aichat.QuestionRequest.newBuilder()
                    .setQuestion(question)
                    .build();

            // Send the request and wait for the response
            Aichat.QuestionResponse response = blockingStub.askQuestion(request);

            // Return the answer
            return response.getAnswer();

        } catch (Exception e) {
            System.err.println("RPC failed: " + e.getMessage());
            return "Error: " + e.getMessage();
        }
    }

    public void askQuestionStream(String question) {
        try {
            // Build the streaming request
            Aichat.QuestionRequest request = Aichat.QuestionRequest.newBuilder()
                    .setQuestion(question)
                    .setStream(true)
                    .build();

            System.out.println("Question: " + question);
            System.out.println("Receiving streamed answer...");

            // Send the request and iterate over the streamed responses
            Iterator<Aichat.QuestionResponse> responses = blockingStub.askQuestionStream(request);

            StringBuilder fullAnswer = new StringBuilder();
            int chunkCount = 0;

            while (responses.hasNext()) {
                Aichat.QuestionResponse response = responses.next();
                chunkCount++;

                if (response.getFinished()) {
                    System.out.println("\n=== Streaming complete ===");
                    System.out.println("Total chunks: " + chunkCount);
                    if (response.getTotalTokens() > 0) {
                        System.out.println("Total tokens: " + response.getTotalTokens());
                    }
                    break;
                }

                // Print each chunk as it arrives
                String chunk = response.getAnswer();
                System.out.print(chunk);
                System.out.flush(); // flush immediately
                fullAnswer.append(chunk);

                // Small delay to make the streaming effect easier to observe
                try {
                    Thread.sleep(50);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }

            System.out.println("\n\nFull answer length: " + fullAnswer.length() + " characters");

        } catch (Exception e) {
            System.err.println("Streaming RPC failed: " + e.getMessage());
        }
    }

    public void askQuestionStreamAsync(String question) {
        try {
            // Build the streaming request
            Aichat.QuestionRequest request = Aichat.QuestionRequest.newBuilder()
                    .setQuestion(question)
                    .setStream(true)
                    .build();

            System.out.println("Question: " + question);
            System.out.println("Receiving streamed answer asynchronously...");

            // Asynchronous streaming call
            asyncStub.askQuestionStream(request, new StreamObserver<Aichat.QuestionResponse>() {
                private StringBuilder fullAnswer = new StringBuilder();
                private int chunkCount = 0;

                @Override
                public void onNext(Aichat.QuestionResponse response) {
                    chunkCount++;

                    if (response.getFinished()) {
                        System.out.println("\n=== Async streaming complete ===");
                        System.out.println("Total chunks: " + chunkCount);
                        if (response.getTotalTokens() > 0) {
                            System.out.println("Total tokens: " + response.getTotalTokens());
                        }
                        return;
                    }

                    // Print each chunk as it arrives
                    String chunk = response.getAnswer();
                    System.out.print(chunk);
                    System.out.flush();
                    fullAnswer.append(chunk);
                }

                @Override
                public void onError(Throwable t) {
                    System.err.println("Async streaming error: " + t.getMessage());
                }

                @Override
                public void onCompleted() {
                    System.out.println("\nAsync streaming finished. Full answer length: " + fullAnswer.length() + " characters");
                }
            });

        } catch (Exception e) {
            System.err.println("Async streaming setup failed: " + e.getMessage());
        }
    }

    public static void main(String[] args) {
        JavaGrpcClient client = new JavaGrpcClient("localhost", 50051);
        try {
            // Test the simple synchronous call
            System.out.println("=== Simple synchronous request ===");
            String question1 = "Hello, please introduce yourself in one sentence";
            String response1 = client.askQuestion(question1);
            System.out.println("Answer: " + response1);
            System.out.println();

            // Test server-side streaming
            System.out.println("=== Server-side streaming request ===");
            String question2 = "Please write a poem about spring";
            client.askQuestionStream(question2);
            System.out.println();

            // Test asynchronous streaming
            System.out.println("=== Async streaming request ===");
            String question3 = "Please explain what machine learning is";
            client.askQuestionStreamAsync(question3);

            // Crude wait for the async stream to finish
            Thread.sleep(10000);

        } catch (Exception e) {
            System.err.println("Error: " + e.getMessage());
        } finally {
            try {
                client.shutdown();
            } catch (InterruptedException e) {
                System.err.println("Error during shutdown: " + e.getMessage());
            }
        }
    }
}
Python server code
import grpc
from concurrent import futures
import time
from openai import OpenAI

# Import the generated gRPC modules
import aichat_pb2
import aichat_pb2_grpc

class AIChatServiceImpl(aichat_pb2_grpc.AIChatServiceServicer):
    def __init__(self):
        # Initialize the OpenAI-compatible client
        self.client = OpenAI(
            base_url='https://api-inference.modelscope.cn/v1',
            api_key='ms-7689bbda-xxxxxxxxxxxxxxxx',  # ModelScope token
        )

    def AskQuestion(self, request, context):
        """Handle a simple question-answer request."""
        try:
            question = request.question
            print(f"Received question: {question}")

            # Call the large-model API
            response = self.client.chat.completions.create(
                model='ZhipuAI/GLM-4.6',  # ModelScope model ID, required
                messages=[{
                    'role': 'user',
                    'content': [{
                        'type': 'text',
                        'text': question,
                    }],
                }],
                stream=False  # non-streaming: return the full answer at once
            )

            # Extract the answer text
            answer = response.choices[0].message.content
            print(f"Generated answer: {answer[:100]}...")  # print only the first 100 characters

            # Return the gRPC response
            return aichat_pb2.QuestionResponse(
                answer=answer,
                question=question
            )

        except Exception as e:
            print(f"Error occurred: {str(e)}")
            # Report the error through the gRPC status
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            return aichat_pb2.QuestionResponse()

    def AskQuestionStream(self, request, context):
        """Handle a streaming question-answer request."""
        try:
            question = request.question
            print(f"Received streaming question: {question}")

            # Call the large-model API in streaming mode
            response = self.client.chat.completions.create(
                model='ZhipuAI/GLM-4.6',  # ModelScope model ID, required
                messages=[{
                    'role': 'user',
                    'content': [{
                        'type': 'text',
                        'text': question,
                    }],
                }],
                stream=True  # enable streaming
            )

            full_answer = ""
            chunk_count = 0

            # Relay the AI answer chunk by chunk
            for chunk in response:
                if chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    full_answer += content
                    chunk_count += 1

                    # Emit one streamed response
                    yield aichat_pb2.QuestionResponse(
                        answer=content,
                        question=question,
                        finished=False
                    )

                    print(f"Streamed chunk {chunk_count}: {content[:50]}...")

            # Send the completion marker
            yield aichat_pb2.QuestionResponse(
                answer="",
                question=question,
                finished=True,
                total_tokens=len(full_answer)  # approximation: character count, not real tokens
            )

            print(f"Streaming completed. Total chunks: {chunk_count}")

        except Exception as e:
            print(f"Streaming error occurred: {str(e)}")
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(e))
            yield aichat_pb2.QuestionResponse(
                answer=f"Error: {str(e)}",
                question=request.question,
                finished=True
            )

def serve():
    """Start the gRPC server."""
    # Create the gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))

    # Register the service implementation
    aichat_pb2_grpc.add_AIChatServiceServicer_to_server(AIChatServiceImpl(), server)

    # Listen on the port
    port = '50051'
    server.add_insecure_port(f'[::]:{port}')

    print(f"Python gRPC server starting on port {port}")
    server.start()

    try:
        while True:
            time.sleep(86400)  # one day
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == '__main__':
    serve()
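The streaming handler above follows a simple protocol: partial answers with finished=False, then one final message with finished=True. The client-side assembly logic (mirroring the Java while-loop) can be sketched without gRPC at all. The `Chunk` dataclass and `fake_stream` generator below are hypothetical stand-ins for the generated `QuestionResponse` class and the real response stream:

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class Chunk:  # stand-in for aichat_pb2.QuestionResponse
    answer: str
    finished: bool = False
    total_tokens: int = 0


def fake_stream(text: str, size: int = 4) -> Iterator[Chunk]:
    # Mimic AskQuestionStream: partial answers, then a completion marker.
    for i in range(0, len(text), size):
        yield Chunk(answer=text[i:i + size])
    yield Chunk(answer="", finished=True, total_tokens=len(text))


def collect(stream: Iterator[Chunk]) -> str:
    # Mirror of the Java client's loop: append chunks until finished=True.
    parts = []
    for chunk in stream:
        if chunk.finished:
            break
        parts.append(chunk.answer)
    return "".join(parts)


print(collect(fake_stream("a poem about spring")))  # → a poem about spring
```

The key design point is the explicit completion marker: it lets the server attach end-of-stream metadata (like a token count) that per-chunk messages cannot carry.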

Step 3: Generate the gRPC plumbing

The code above is our business layer; it only wires up the calls and adds nothing extra. Next come the core gRPC communication classes it depends on, which must be produced by code generation.

Install the Python dependencies
pip install grpcio grpcio-tools openai

Then generate the Python code with:

python -m grpc_tools.protoc --python_out=. --grpc_python_out=. --proto_path=. aichat.proto

This produces the two generated modules used in the Python code above: aichat_pb2.py (the message classes) and aichat_pb2_grpc.py (the service stubs).

Generating the Java code

For Java, the code is generated with a Maven plugin, since I compile the project directly from an IDEA workspace. These are the pom dependencies:

<dependencies>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-netty</artifactId>
        <version>1.57.2</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-protobuf</artifactId>
        <version>1.57.2</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-stub</artifactId>
        <version>1.57.2</version>
    </dependency>
</dependencies>

And add the build plugin we need to the pom:

<build>
    <extensions>
        <extension>
            <groupId>kr.motd.maven</groupId>
            <artifactId>os-maven-plugin</artifactId>
            <version>1.4.1.Final</version>
        </extension>
    </extensions>
    <plugins>
        <plugin>
            <groupId>org.xolstice.maven.plugins</groupId>
            <artifactId>protobuf-maven-plugin</artifactId>
            <version>0.5.0</version>
            <configuration>
                <protocArtifact>com.google.protobuf:protoc:3.21.7:exe:${os.detected.classifier}</protocArtifact>
                <pluginId>grpc-java</pluginId>
                <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.57.2:exe:${os.detected.classifier}</pluginArtifact>
                <outputDirectory>${basedir}/src/main/java</outputDirectory>
                <clearOutputDirectory>false</clearOutputDirectory>
            </configuration>
            <executions>
                <execution>
                    <goals>
                        <goal>compile</goal>
                        <goal>compile-custom</goal>
                    </goals>
                </execution>
            </executions>
        </plugin>
    </plugins>
</build>

Running Maven's compile goal then generates the Java classes we need:

  • Aichat.java (the message classes)
  • AIChatServiceGrpc.java (the service stubs)

✅ Summary

By connecting a Java client to a Python server over gRPC, we get the best of both languages:

  • The Java client keeps the stability and performance expected of enterprise applications
  • The Python server makes it easy to integrate AI models and algorithm libraries
  • gRPC provides an efficient transport and first-class streaming support

This architecture is a great fit for traditional Java projects that need to integrate AI capabilities! 🎉
