问题背景

SAAChatController.chat() 方法中,当前使用流式输出获取大模型响应:

return chatService.chat(chatId, model, prompt)
    .doOnNext(chunk -> {
        resultContent.get().append(chunk);
    })
    .doOnComplete(() -> {
        billingService.taskCallbackWithToken(taskId, "chat", finalModel,
            BillingService.TaskStatus.SUCCESS, fullResult, null, null, token, ScenarioType.CHAT);
    });

当前问题:

  1. 只收集了文本内容,没有获取 token 使用量
  2. BillingServicerequest.setTokens(result) 传递的是文本内容而非 token 数量
  3. 无法准确按 token 数量计费

Spring AI Token 获取机制

1. Spring AI ChatResponse 结构

Spring AI 的 ChatResponse 对象包含以下信息:

ChatResponse
├── Result
│   ├── Output (AssistantMessage)
│   └── Metadata
│       ├── id
│       ├── model
│       ├── rateLimit
│       └── usage (重要!)
│           ├── promptTokens      // 输入 token 数量
│           ├── generationTokens  // 输出 token 数量
│           └── totalTokens       // 总 token 数量
└── Metadata
    └── ...

2. 获取 Token 的两种方式

方式一:非流式调用(call())
ChatResponse response = chatClient.prompt()
    .user(prompt)
    .call()
    .chatResponse();

// 获取 token 使用情况
Usage usage = response.getMetadata().getUsage();
Long promptTokens = usage.getPromptTokens();        // 输入 token
Long generationTokens = usage.getGenerationTokens(); // 输出 token
Long totalTokens = usage.getTotalTokens();          // 总 token
方式二:流式调用(stream())

问题: 流式调用 .stream().content() 只返回 Flux<String>,不包含元数据。

解决方案: 使用 .stream().chatResponse() 而不是 .stream().content()

Flux<ChatResponse> responseFlux = chatClient.prompt()
    .user(prompt)
    .stream()
    .chatResponse();  // 注意:这里是 chatResponse() 而不是 content()

AtomicReference<Usage> finalUsage = new AtomicReference<>();

return responseFlux
    .doOnNext(chatResponse -> {
        // 收集文本内容
        String content = chatResponse.getResult().getOutput().getContent();
        resultContent.get().append(content);

        // 保存最后一次的 usage(最后一个响应包含完整的 token 统计)
        if (chatResponse.getMetadata() != null &&
            chatResponse.getMetadata().getUsage() != null) {
            finalUsage.set(chatResponse.getMetadata().getUsage());
        }
    })
    .doOnComplete(() -> {
        Usage usage = finalUsage.get();
        Long totalTokens = usage != null ? usage.getTotalTokens() : null;

        // 进行计费回调,传递 token 数量
        billingService.taskCallbackWithToken(taskId, "chat", finalModel,
            BillingService.TaskStatus.SUCCESS, fullResult,
            totalTokens, null, token, ScenarioType.CHAT);
    });

完整实现方案

方案一:修改流式输出返回 ChatResponse(推荐)⭐⭐⭐

1. 修改 SAAChatService

修改前(SAAChatService.java:156):

return promptSpec.stream().content();

修改后:

// 返回完整的 ChatResponse 流,而不仅仅是 content
return promptSpec.stream().chatResponse()
    .map(chatResponse -> {
        // 提取文本内容用于返回给前端
        String content = chatResponse.getResult().getOutput().getContent();

        // 将 ChatResponse 的元数据存储到线程上下文或其他机制
        // (这里需要设计一个机制来传递 usage 信息)

        return content;
    });

问题: 这种方式仍然只能返回 Flux<String> 给前端,元数据会丢失。

2. 更好的方案:使用 Reactor Context 传递元数据

修改 SAAChatService.java:

public Flux<String> chat(String chatId, String model, String prompt) {
    // ... 现有代码 ...

    AtomicReference<Usage> usageRef = new AtomicReference<>();

    return promptSpec.stream().chatResponse()
        .doOnNext(chatResponse -> {
            // 保存 usage 信息到原子引用
            if (chatResponse.getMetadata() != null &&
                chatResponse.getMetadata().getUsage() != null) {
                usageRef.set(chatResponse.getMetadata().getUsage());
            }
        })
        .map(chatResponse -> chatResponse.getResult().getOutput().getContent())
        .contextWrite(ctx -> ctx.put("usageRef", usageRef));
}

修改 SAAChatController.java:

@PostMapping("/chat")
public Flux<String> chat(
        HttpServletResponse response,
        @Validated @RequestBody String prompt,
        @RequestHeader(value = "model", required = false) String model,
        @RequestHeader(value = "chatId", required = false) String chatId
) {
    // ... 现有代码 ...

    // 使用 AtomicReference 来收集内容和 token 使用量
    AtomicReference<StringBuilder> resultContent = new AtomicReference<>(new StringBuilder());
    AtomicReference<Long> totalTokens = new AtomicReference<>(null);
    AtomicReference<Long> promptTokens = new AtomicReference<>(null);
    AtomicReference<Long> generationTokens = new AtomicReference<>(null);

    response.setCharacterEncoding("UTF-8");

    String finalModel = model;
    return chatService.chatWithUsage(chatId, model, prompt)  // 新方法
        .doOnNext(chunk -> {
            // chunk 是包含 content 和 usage 的包装对象
            resultContent.get().append(chunk.getContent());
            if (chunk.getUsage() != null) {
                totalTokens.set(chunk.getUsage().getTotalTokens());
                promptTokens.set(chunk.getUsage().getPromptTokens());
                generationTokens.set(chunk.getUsage().getGenerationTokens());
            }
        })
        .map(chunk -> chunk.getContent())  // 只返回文本给前端
        .doOnComplete(() -> {
            String fullResult = resultContent.get().toString();

            // 传递 token 数量进行计费
            billingService.taskCallbackWithTokenUsage(
                taskId,
                "chat",
                finalModel,
                BillingService.TaskStatus.SUCCESS,
                fullResult,
                totalTokens.get(),
                promptTokens.get(),
                generationTokens.get(),
                null,
                null,
                token,
                BillingService.ScenarioType.CHAT
            );
        })
        .doOnError(error -> {
            logger.error(error.getMessage());
            String fullResult = resultContent.get().toString();
            billingService.taskCallbackWithTokenUsage(
                taskId,
                "chat",
                finalModel,
                BillingService.TaskStatus.FAILED,
                fullResult,
                totalTokens.get(),
                promptTokens.get(),
                generationTokens.get(),
                null,
                error.getMessage(),
                token,
                BillingService.ScenarioType.CHAT
            );
        });
}

方案二:创建 DTO 对象封装响应(最清晰)⭐⭐⭐⭐⭐

1. 创建 ChatChunk DTO

创建文件:src/main/java/com/alibaba/cloud/ai/application/entity/chat/ChatChunk.java

package com.alibaba.cloud.ai.application.entity.chat;

import org.springframework.ai.chat.metadata.Usage;

/**
 * 聊天响应块,包含内容和 token 使用情况
 */
public class ChatChunk {
    private String content;
    private Usage usage;

    public ChatChunk(String content, Usage usage) {
        this.content = content;
        this.usage = usage;
    }

    public String getContent() {
        return content;
    }

    public Usage getUsage() {
        return usage;
    }

    public Long getTotalTokens() {
        return usage != null ? usage.getTotalTokens() : null;
    }

    public Long getPromptTokens() {
        return usage != null ? usage.getPromptTokens() : null;
    }

    public Long getGenerationTokens() {
        return usage != null ? usage.getGenerationTokens() : null;
    }
}
2. 修改 SAAChatService 返回 ChatChunk

修改 SAAChatService.java:

/**
 * 聊天接口,返回包含 token 使用情况的流
 */
public Flux<ChatChunk> chatWithUsage(String chatId, String model, String prompt) {
    log.debug("chat model is: {}", model);
    log.info("Chat request - chatId: {}, model: {}, prompt length: {}",
        chatId, model, prompt.length());

    ChatClient selectedClient = selectChatClient(model);

    if (Objects.equals("deepseek-r1", model) || Objects.equals("deepseek-reasoner", model)) {
        selectedClient.prompt().advisors(reasoningContentAdvisor);
    }

    var promptSpec = selectedClient.prompt()
            .user(prompt)
            .advisors(memoryAdvisor -> {
                log.debug("Setting conversation ID for memory advisor: {}", chatId);
                memoryAdvisor.param(ChatMemory.CONVERSATION_ID, chatId);
            });

    if (isDeepSeekOfficialModel(model)) {
        var openAiOptions = OpenAiChatOptions.builder()
                .model(model)
                .temperature(0.8)
                .build();
        promptSpec.options(openAiOptions);
    } else {
        var dashScopeOptions = DashScopeChatOptions.builder()
                .withModel(model)
                .withTemperature(0.8)
                .withResponseFormat(DashScopeResponseFormat.builder()
                        .type(DashScopeResponseFormat.Type.TEXT)
                        .build()
                ).build();
        promptSpec.options(dashScopeOptions);
    }

    // 返回 ChatResponse 流,并转换为 ChatChunk
    return promptSpec.stream().chatResponse()
        .map(chatResponse -> {
            String content = chatResponse.getResult().getOutput().getContent();
            Usage usage = chatResponse.getMetadata() != null ?
                         chatResponse.getMetadata().getUsage() : null;
            return new ChatChunk(content, usage);
        });
}

/**
 * 兼容旧接口,只返回文本内容
 */
public Flux<String> chat(String chatId, String model, String prompt) {
    return chatWithUsage(chatId, model, prompt)
        .map(ChatChunk::getContent);
}
3. 修改 SAAChatController

修改 SAAChatController.java:

@PostMapping("/chat")
public Flux<String> chat(
        HttpServletResponse response,
        @Validated @RequestBody String prompt,
        @RequestHeader(value = "model", required = false) String model,
        @RequestHeader(value = "chatId", required = false) String chatId
) {
    // ... 验证模型代码保持不变 ...

    if (!billingService.judgeBalance(model, BillingService.ScenarioType.CHAT)) {
        return Flux.just("账户余额不足,请充值后再试");
    }

    String taskId = billingService.generateTaskId();
    String token = StpUtil.getTokenValue();

    // 收集结果和 token 使用情况
    AtomicReference<StringBuilder> resultContent = new AtomicReference<>(new StringBuilder());
    AtomicReference<Long> totalTokens = new AtomicReference<>(0L);
    AtomicReference<Long> promptTokens = new AtomicReference<>(0L);
    AtomicReference<Long> generationTokens = new AtomicReference<>(0L);

    response.setCharacterEncoding("UTF-8");

    String finalModel = model;
    return chatService.chatWithUsage(chatId, model, prompt)
        .doOnNext(chunk -> {
            // 收集文本内容
            resultContent.get().append(chunk.getContent());

            // 更新 token 使用情况(每个 chunk 都可能包含更新的 usage)
            if (chunk.getUsage() != null) {
                if (chunk.getTotalTokens() != null) {
                    totalTokens.set(chunk.getTotalTokens());
                }
                if (chunk.getPromptTokens() != null) {
                    promptTokens.set(chunk.getPromptTokens());
                }
                if (chunk.getGenerationTokens() != null) {
                    generationTokens.set(chunk.getGenerationTokens());
                }
            }
        })
        .map(ChatChunk::getContent)  // 只返回文本内容给前端
        .doOnComplete(() -> {
            String fullResult = resultContent.get().toString();

            logger.info("Chat completed - taskId: {}, model: {}, tokens: {} (prompt: {}, generation: {})",
                taskId, finalModel, totalTokens.get(), promptTokens.get(), generationTokens.get());

            // 使用真实的 token 数量进行计费
            billingService.taskCallbackWithTokenUsage(
                taskId,
                "chat",
                finalModel,
                BillingService.TaskStatus.SUCCESS,
                fullResult,
                totalTokens.get(),
                promptTokens.get(),
                generationTokens.get(),
                null,
                null,
                token,
                BillingService.ScenarioType.CHAT
            );
        })
        .doOnError(error -> {
            logger.error("Chat failed - taskId: {}, error: {}", taskId, error.getMessage());
            String fullResult = resultContent.get().toString();
            billingService.taskCallbackWithTokenUsage(
                taskId,
                "chat",
                finalModel,
                BillingService.TaskStatus.FAILED,
                fullResult,
                totalTokens.get(),
                promptTokens.get(),
                generationTokens.get(),
                null,
                error.getMessage(),
                token,
                BillingService.ScenarioType.CHAT
            );
        });
}
4. 修改 BillingService

修改 BillingService.java,添加新方法:

/**
 * 任务回调,传递详细的 token 使用情况
 *
 * @param taskId 任务订单号
 * @param title 文件名
 * @param modelName 使用的模型
 * @param status 任务状态
 * @param result 生成的文字内容
 * @param totalTokens 总 token 数量
 * @param promptTokens 输入 token 数量
 * @param generationTokens 输出 token 数量
 * @param fileUrl 文件路径
 * @param errorMessage 错误消息
 * @param token 用户token
 * @param scenario 场景类型
 * @return true表示回调成功
 */
public boolean taskCallbackWithTokenUsage(
        String taskId,
        String title,
        String modelName,
        Integer status,
        String result,
        Long totalTokens,
        Long promptTokens,
        Long generationTokens,
        String fileUrl,
        String errorMessage,
        String token,
        String scenario) {
    try {
        if (token == null || token.isEmpty()) {
            logger.error("用户token为空,无法进行任务回调");
            return false;
        }

        String chargeCode = determineChargeCode(modelName, scenario);

        // 构造回调请求参数
        TaskCallbackReq request = new TaskCallbackReq();
        request.setPlatformCode(platformCode);
        request.setPlatformId(platformId);
        request.setTaskType(TaskType.TOKEN);
        request.setTaskId(taskId);
        request.setTitle(title);
        request.setModelName(modelName);
        request.setStatus(status);
        request.setResult(result);

        // 设置详细的 token 使用情况
        request.setTotalTokens(totalTokens);
        request.setPromptTokens(promptTokens);
        request.setGenerationTokens(generationTokens);

        request.setFileUrl(fileUrl);
        request.setErrorMessage(errorMessage);
        request.setChargeType(TaskType.TOKEN);
        request.setType(ChargeType.PER_COUNT);

        // 根据 token 数量计算实际扣费数量
        request.setQuantity(totalTokens != null ? totalTokens : 1);
        request.setChargeCode(chargeCode);

        // 设置请求头
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.set("Authorization", "Bearer " + token);

        HttpEntity<TaskCallbackReq> entity = new HttpEntity<>(request, headers);

        // 调用任务回调接口
        String url = billingServiceUrl + "/billing/callback";
        ResponseEntity<BillingResponse> response = restTemplate.exchange(
            url,
            HttpMethod.POST,
            entity,
            BillingResponse.class
        );

        if (response.getStatusCode() == HttpStatus.OK &&
            response.getBody() != null &&
            response.getBody().isSuccess()) {
            logger.info("任务回调成功: taskId={}, status={}, tokens={} (prompt={}, generation={}), 场景={}, 扣费编码={}",
                taskId, status, totalTokens, promptTokens, generationTokens, scenario, chargeCode);
            return true;
        } else {
            logger.error("任务回调接口返回异常状态: {}, taskId: {}, 响应: {}",
                response.getStatusCode(), taskId, response.getBody());
            return false;
        }

    } catch (Exception e) {
        logger.error("调用任务回调接口异常,taskId: {}, status: {}", taskId, status, e);
        return false;
    }
}

// 保留旧方法以兼容其他代码
@Deprecated
public boolean taskCallbackWithToken(String taskId, String title, String modelName,
                                    Integer status, String result, String fileUrl,
                                    String errorMessage, String token, String scenario) {
    // 调用新方法,token 数量设为 null
    return taskCallbackWithTokenUsage(taskId, title, modelName, status, result,
        null, null, null, fileUrl, errorMessage, token, scenario);
}
5. 修改 TaskCallbackReq 实体

修改文件:src/main/java/com/alibaba/cloud/ai/application/entity/req/TaskCallbackReq.java

package com.alibaba.cloud.ai.application.entity.req;

// ... 其他 imports ...

public class TaskCallbackReq {
    // ... 现有字段 ...

    /**
     * 总 token 数量
     */
    private Long totalTokens;

    /**
     * 输入 token 数量(提示词)
     */
    private Long promptTokens;

    /**
     * 输出 token 数量(生成内容)
     */
    private Long generationTokens;

    // Getters and Setters
    public Long getTotalTokens() {
        return totalTokens;
    }

    public void setTotalTokens(Long totalTokens) {
        this.totalTokens = totalTokens;
    }

    public Long getPromptTokens() {
        return promptTokens;
    }

    public void setPromptTokens(Long promptTokens) {
        this.promptTokens = promptTokens;
    }

    public Long getGenerationTokens() {
        return generationTokens;
    }

    public void setGenerationTokens(Long generationTokens) {
        this.generationTokens = generationTokens;
    }
}

方案对比

方案 优点 缺点 推荐度
方案一:Reactor Context 实现简单 上下文传递复杂 ⭐⭐
方案二:ChatChunk DTO 类型安全、清晰明了 需要新建类 ⭐⭐⭐⭐⭐

实施步骤

推荐使用方案二,步骤如下:

  1. 创建 ChatChunk DTO

    • 创建 ChatChunk.java 封装内容和 token 使用情况
  2. 修改 SAAChatService

    • 添加 chatWithUsage() 方法返回 Flux<ChatChunk>
    • 保留 chat() 方法调用 chatWithUsage() 以兼容旧代码
  3. 修改 SAAChatController

    • 使用 chatWithUsage() 获取流式响应
    • 收集 token 使用情况
    • 调用新的计费方法传递 token 数量
  4. 修改 BillingService

    • 添加 taskCallbackWithTokenUsage() 方法接收详细 token 信息
    • 保留旧方法以兼容其他代码
  5. 修改 TaskCallbackReq

    • 添加 totalTokenspromptTokensgenerationTokens 字段
  6. 同样处理 deepThinkingChat

    • 使用相同的方法修改深度思考接口

注意事项

  1. Token 更新时机

    • 流式输出中,token 统计可能在最后一个 chunk 才完整
    • 需要在每个 chunk 中检查并更新 token 使用情况
  2. 空值处理

    • 某些响应可能不包含 usage 信息
    • 需要进行 null 检查
  3. 日志记录

    • 记录详细的 token 使用情况便于调试和审计
    • 包括 prompt tokens、generation tokens 和 total tokens
  4. 兼容性

    • 保留旧方法以确保其他代码不受影响
    • 使用 @Deprecated 标注废弃方法
  5. 计费精度

    • 根据实际业务需求决定按哪个 token 数量计费
    • 通常使用 totalTokens

测试验证

1. 单元测试

@Test
void testChatWithTokenUsage() {
    List<ChatChunk> chunks = new ArrayList<>();

    chatService.chatWithUsage("test-chat-id", "qwen-plus", "你好")
        .doOnNext(chunks::add)
        .blockLast();

    // 验证收集到了 token 使用情况
    ChatChunk lastChunk = chunks.get(chunks.size() - 1);
    assertNotNull(lastChunk.getUsage());
    assertTrue(lastChunk.getTotalTokens() > 0);
}

2. 集成测试

curl -X POST http://localhost:8090/api/v1/chat \
  -H "Content-Type: application/json" \
  -H "model: qwen-plus" \
  -H "chatId: test-123" \
  -d "你好,请介绍一下你自己"

检查日志输出是否包含 token 使用情况:

Chat completed - taskId: task_xxx, model: qwen-plus, tokens: 156 (prompt: 12, generation: 144)
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐