一、概述

为什么需要Graph

在这里插入图片描述

核心概念

在这里插入图片描述

二、快速入门

实现如下工作流:
开始节点→node1→node2→结束节点
用node2的值替换node1的值

pom.xml添加核心依赖

<dependency>
   <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-zhipuai</artifactId>
</dependency>

<dependency>
    <groupId>com.alibaba.cloud.ai</groupId>
    <artifactId>spring-ai-alibaba-graph-core</artifactId>
</dependency>

修改配置文件application.yaml

server:
  port: 8889

spring:
  application:
    name: agent-graph
  ai:
    zhipuai:
      api-key: ${ZHIPU_KEY} # 配置智谱大模型的API Key
      chat:
        options:
          model: glm-4-flash

创建状态图的配置类

在这里插入图片描述
GraphConfig.java

@Configuration
@Slf4j
public class GraphConfig {

    @Bean("quickStartGraph")
    public CompiledGraph quickStartGraph() throws GraphStateException {

        KeyStrategyFactory keyStrategyFactory = new KeyStrategyFactory() {
            @Override
            public Map<String, KeyStrategy> apply() {
            	// ReplaceStrategy为替换策略
                return Map.of("input1", new ReplaceStrategy(),
                        "input2", new ReplaceStrategy());
            }
        };
        // 定义状态图StateGraph
        StateGraph stateGraph = new StateGraph("quickStartGraph", keyStrategyFactory);

        // 添加节点
        // AsyncNodeAction.node_async为异步执行
        stateGraph.addNode("node1", AsyncNodeAction.node_async(new NodeAction() {
            @Override
            public Map<String, Object> apply(OverAllState state) throws Exception {
                log.info("node1 state: {}", state);
                return Map.of("input1", 1, "input2", 1);
            }
        }));
        stateGraph.addNode("node2", AsyncNodeAction.node_async(new NodeAction() {
            @Override
            public Map<String, Object> apply(OverAllState state) throws Exception {
                log.info("node2 state: {}", state);
                return Map.of("input1", 2, "input2", 2);
            }
        }));

        // 定义边
        stateGraph.addEdge(StateGraph.START, "node1");
        stateGraph.addEdge("node1", "node2");
        stateGraph.addEdge("node2", StateGraph.END);

        // 编译状态图
        return stateGraph.compile();

    }
}

创建一个Controller

GraphController.java

@RestController
@RequestMapping("/graph")
@Slf4j
public class GraphController {

    private final CompiledGraph compiledGraph;

    public GraphController(CompiledGraph compiledGraph) {
        this.compiledGraph = compiledGraph;
    }

    @GetMapping("/quickStartGraph")
    public String quickStartGraph() {
        Optional<OverAllState> overAllStateOptional = compiledGraph.call(Map.of());
        log.info("overAllStateOptional: {}", overAllStateOptional);
        return "OK";
    }
}

启动程序,查看效果

GET方式,http://localhost:8889/graph/quickStartGraph
在这里插入图片描述
发现input1和input2的值被成功替换为2

三、API详解

KeyStrategyFactory(键策略工厂)

在这里插入图片描述

NodeAction&AsyncNodeAction

在这里插入图片描述

stateGraph(状态图)

状态图的抽象,需要配置状态(通过KeyStrategyFactory ),节点,边。
配置好后通过compile方法编译成CompiledGraph后才可以供调用。

CompiledGraph(编译图)

CompiledGraph是StateGraph编译后的结果,CompiledGraph才能用了执行。
一般我们是把StateGraph定义好后调用其compile方法得到一个CompiledGraph放入Spring容器中然后在需要的时候从容器中注入然后再调用。
在这里插入图片描述

四、案例:开发一个英语学习小助手

需求

使用Graph开发一个英语学习小助手。
功能如下:输入一个单词,能基于这个单词造句,然后再对句子进行翻译,把造句的译文也返回。

思路分析

我们可以定义一个工作流,工作流中主要有两个节点:
SentenceConstructionNode 造句节点,拿输入的单词让LLM进行造句。
TranslationNode 翻译节点,能够把一个英文句子翻译成中文。最终把造句的结果和翻译的结果返回即可。

流程图

开始节点(输入一个单词)–>造句节点(根据给定的单词进行造句)–>翻译节点(对句子进行翻译)–>结束节点(输出造句和翻译的结果)

代码编写

定义SentenceConstructionNode造句节点

在这里插入图片描述

public class SentenceConstructionNode implements NodeAction {

    private final ChatClient chatClient;

    public SentenceConstructionNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        // 从stage中获取要造句的单词
        String word = state.value("word", "");

        // 定义提示词
        PromptTemplate promptTemplate = new PromptTemplate("你是一个英语造句专家,能够基于给定的单词进行造句。"+
                "要求只返回最终造好的句子,不要返回其他信息。给定的单词:{word}");
        promptTemplate.add("word", word);   // 替换占位符
        String prompt = promptTemplate.render();  // 渲染提示词

        // 模型调用
        String content = chatClient.prompt()
                .user(prompt)
                .call()
                .content();

        // 把句子存入stage
        return Map.of("sentence", content);
    }
}

定义TranslationNode翻译节点

public class TranslationNode implements NodeAction {

    private final ChatClient chatClient;

    public TranslationNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        // 从stage中获取要翻译的句子
        String sentence = state.value("sentence", "");

        // 定义提示词
        PromptTemplate promptTemplate = new PromptTemplate("你是一个英语翻译专家,能够把英文翻译成中文。" +
                "要求只返回翻译的中文结果,不要返回英文原句。要翻译的英文句子:{sentence}");
        promptTemplate.add("sentence", sentence);   // 替换占位符
        String prompt = promptTemplate.render();  // 渲染提示词

        // 模型调用
        String content = chatClient.prompt()
                .user(prompt)
                .call()
                .content();

        // 把翻译结果存入stage
        return Map.of("translation", content);
    }
}

定义状态图

config/GraphConfig.java,在quickStartGraph下面增加如下内容
在这里插入图片描述

@Bean("simpleGraph")
public CompiledGraph simpleGraph(ChatClient.Builder clientBuilder) throws GraphStateException {

    KeyStrategyFactory keyStrategyFactory = () -> {
            HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
            keyStrategyHashMap.put("word", new ReplaceStrategy());
            keyStrategyHashMap.put("sentence", new ReplaceStrategy());
            keyStrategyHashMap.put("translation", new ReplaceStrategy());
            return keyStrategyHashMap;
        };
    // 创建状态图
    StateGraph stateGraph = new StateGraph("simpleGraph", keyStrategyFactory);
    // 添加节点
    stateGraph.addNode("SentenceConstructionNode", AsyncNodeAction.node_async(new SentenceConstructionNode(clientBuilder)));
    stateGraph.addNode("TranslationNode", AsyncNodeAction.node_async(new TranslationNode(clientBuilder)));
    // 定义边
    stateGraph.addEdge(StateGraph.START, "SentenceConstructionNode");
    stateGraph.addEdge("SentenceConstructionNode", "TranslationNode");
    stateGraph.addEdge("TranslationNode", StateGraph.END);
    // 编译状态图,放入容器
    return stateGraph.compile();
}

新增API接口

在这里插入图片描述

@RestController
@RequestMapping("/graph")
@Slf4j
public class GraphController {

    private final CompiledGraph compiledGraph;
    private final CompiledGraph simpleGraph;

    public GraphController(@Qualifier("quickStartGraph") CompiledGraph compiledGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph) {
        this.compiledGraph = compiledGraph;
        this.simpleGraph = simpleGraph;
    }

    @GetMapping("/quickStartGraph")
    public String quickStartGraph() {
        Optional<OverAllState> overAllStateOptional = compiledGraph.call(Map.of());
        log.info("overAllStateOptional: {}", overAllStateOptional);
        return "OK";
    }

    @GetMapping("/simpleGraph")
    public Map<String, Object> simpleGraph(@RequestParam("word") String word) {
        Optional<OverAllState> overAllStateOptional = simpleGraph.call(Map.of("word", word));
        Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());
        return data;
    }
}

启动服务,访问接口

GET方法:http://localhost:8889/graph/simpleGraph?word=sky
在这里插入图片描述

五、条件边

在这里插入图片描述
在这里插入图片描述

代码结构

在这里插入图片描述

定义GenerateJokeNode生成笑话节点

public class GenerateJokeNode implements NodeAction {

    private final ChatClient chatClient;

    public GenerateJokeNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        // 从stage中获取笑话主题
        String topic = state.value("topic", "");

        // 定义提示词
        PromptTemplate promptTemplate = new PromptTemplate("你需要写一个关于指定主题的短笑话。要求返回的结果中只能包含笑话的内容" +
                "主题:{topic}");
        promptTemplate.add("topic", topic);   // 替换占位符
        String prompt = promptTemplate.render();  // 渲染提示词

        // 模型调用
        String content = chatClient.prompt()
                .user(prompt)
                .call()
                .content();

        // 把结果存入stage
        return Map.of("joke", content);
    }
}

定义EvaluateJokesNode评估笑话节点

public class EvaluateJokesNode implements NodeAction {

    private final ChatClient chatClient;

    public EvaluateJokesNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        // 从stage中获取待评估笑话
        String joke = state.value("joke", "");

        // 定义提示词
        PromptTemplate promptTemplate = new PromptTemplate("你是一个笑话评分专家,能够对笑话进行评分,基于效果的搞笑程度给出0到10分的打分。" +
                "0到3分是不够优秀,4到10分是优秀。要求结果只返回优秀或者不够优秀,不能输出其他内容。"+
                "要评分的笑话:{joke}");
        promptTemplate.add("joke", joke);   // 替换占位符
        String prompt = promptTemplate.render();  // 渲染提示词

        // 模型调用
        String content = chatClient.prompt()
                .user(prompt)
                .call()
                .content();

        // 把结果存入stage
        return Map.of("result", content.trim());
    }
}

定义EnhanceJokeQualityNode优化笑话节点

public class EnhanceJokeQualityNode implements NodeAction {

    private final ChatClient chatClient;

    public EnhanceJokeQualityNode(ChatClient.Builder builder) {
        this.chatClient = builder.build();
    }

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        // 从stage中获取待评估笑话
        String joke = state.value("joke", "");

        // 定义提示词
        PromptTemplate promptTemplate = new PromptTemplate("你是一个笑话优化专家,你能够优化笑话,让它更加搞笑" +
                "要优化的话:{joke}");
        promptTemplate.add("joke", joke);   // 替换占位符
        String prompt = promptTemplate.render();  // 渲染提示词

        // 模型调用
        String content = chatClient.prompt()
                .user(prompt)
                .call()
                .content();

        // 把结果存入stage
        return Map.of("newJoke", content);
    }
}

在GraphConfig下面定义图

@Bean("conditionalGraph")
    public CompiledGraph conditionalGraph(ChatClient.Builder clientBuilder) throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory =()-> Map.of("topic",new ReplaceStrategy());
        // 定义状态图StateGraph
        StateGraph stateGraph = new StateGraph("conditionalGraph", keyStrategyFactory);
        // 定义节点
        stateGraph.addNode("生成笑话", AsyncNodeAction.node_async(new GenerateJokeNode(clientBuilder)));
        stateGraph.addNode("评估笑话", AsyncNodeAction.node_async(new EvaluateJokesNode(clientBuilder)));
        stateGraph.addNode("优化笑话", AsyncNodeAction.node_async(new EnhanceJokeQualityNode(clientBuilder)));
        // 定义边
        stateGraph.addEdge(StateGraph.START, "生成笑话");
        stateGraph.addEdge("生成笑话", "评估笑话");
        stateGraph.addConditionalEdges("评估笑话", AsyncEdgeAction.edge_async(
                state -> state.value("result", "优秀")),
                Map.of("优秀",StateGraph.END,
                        "不够优秀", "优化笑话"));
        stateGraph.addEdge("优化笑话", StateGraph.END);

        return stateGraph.compile();
    }

在GraphController下创建接口

private final CompiledGraph compiledGraph;
    private final CompiledGraph simpleGraph;
    private final CompiledGraph conditionalGraph;

    public GraphController(@Qualifier("quickStartGraph") CompiledGraph compiledGraph,
                           @Qualifier("simpleGraph") CompiledGraph simpleGraph,
                           @Qualifier("conditionalGraph") CompiledGraph conditionalGraph) {
        this.compiledGraph = compiledGraph;
        this.simpleGraph = simpleGraph;
        this.conditionalGraph = conditionalGraph;
    }

    @GetMapping("/conditionalGraph")
    public Map<String, Object> conditionalGraph(@RequestParam("topic") String topic) {
        Optional<OverAllState> overAllStateOptional = conditionalGraph.call(Map.of("topic", topic));
        Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());
        return data;
    }

验证效果

GET方式:http://localhost:8889/graph/conditionalGraph?topic=爱情
评估结果是优秀
在这里插入图片描述
直接输出结果
在这里插入图片描述
在断点处右击,选择“Evaluate Expression”
在这里插入图片描述
篡改评估结果为"不够优秀",回车后关闭
在这里插入图片描述
修改成功
在这里插入图片描述
就会走优化节点,生成新的笑话
在这里插入图片描述

六、循环边

在这里插入图片描述

新增LoopEvaluateJokesNode循环评分节点

@Slf4j
public class LoopEvaluateJokesNode implements NodeAction {

    private final ChatClient chatClient;
    private final Integer targetScore;
    private final Integer maxLoopCount;

    public LoopEvaluateJokesNode(ChatClient.Builder builder, Integer targetScore, Integer maxLoopCount) {
        this.chatClient = builder.build();
        this.targetScore = targetScore;
        this.maxLoopCount = maxLoopCount;
    }

    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        // 从stage中获取待评估笑话
        String joke = state.value("joke", "");
        // 循环次数
        Integer loopCount = state.value("loopCount", 0);

        // 定义提示词
        PromptTemplate promptTemplate = new PromptTemplate("你是一个笑话评分专家,能够对笑话进行评分,基于效果的搞笑程度给出0到10分的打分。" +
                "要求结果只返回最后的打分,打分必须是整数,不能输出其他内容。"+
                "要评分的笑话:{joke}");
        promptTemplate.add("joke", joke);   // 替换占位符
        String prompt = promptTemplate.render();  // 渲染提示词

        // 模型调用
        String content = chatClient.prompt()
                .user(prompt)
                .call()
                .content();
        // content转为整数
        Integer score = Integer.parseInt(content.trim());

        log.info("joke: {},score: {},循环次数: {}", joke, score, loopCount);
        // 根据分数判断是否继续循环,循环最多执行5次
        String result = "loop";
        if (score >= targetScore || loopCount >= maxLoopCount) {
            result = "break";
        }
        loopCount++;
        // 把结果存入stage
        return Map.of("result", result, "loopCount", loopCount);
    }
}

在GraphConfig下面定义图

@Bean("loopGraph")
    public CompiledGraph loopGraph(ChatClient.Builder clientBuilder) throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory =()-> Map.of("topic",new ReplaceStrategy());
        // 定义状态图StateGraph
        StateGraph stateGraph = new StateGraph("loopGraph", keyStrategyFactory);
        // 定义节点
        stateGraph.addNode("生成笑话", AsyncNodeAction.node_async(new GenerateJokeNode(clientBuilder)));
        stateGraph.addNode("评估笑话", AsyncNodeAction.node_async(new LoopEvaluateJokesNode(clientBuilder, 8, 5)));
        // 定义边
        stateGraph.addEdge(StateGraph.START, "生成笑话");
        stateGraph.addEdge("生成笑话", "评估笑话");
        stateGraph.addConditionalEdges("评估笑话", AsyncEdgeAction.edge_async(
                        state -> state.value("result", "loop")),
                Map.of("loop","生成笑话",
                        "break", StateGraph.END));

        return stateGraph.compile();
    }

在GraphController下创建接口

@GetMapping("/loopGraph")
public Map<String, Object> loopGraph(@RequestParam("topic") String topic) {
    Optional<OverAllState> overAllStateOptional = loopGraph.call(Map.of("topic", topic));
    Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());
    return data;
}

测试

GET方式:http://localhost:8889/graph/loopGraph?topic=爱情
当score为8时,退出循环,输出结果
在这里插入图片描述
在这里插入图片描述

七、状态存储

我们可以把图中的状态数据进行存储。默契情况下Graph会把状态存储到内存中。

在ConfigGraph中创建状态图

@Bean("saveGraph")
public CompiledGraph saveGraph(ChatClient.Builder clientBuilder) throws GraphStateException {
    KeyStrategyFactory keyStrategyFactory = () -> Map.of();

    // 定义状态图 stateGraph
    StateGraph stateGraph = new StateGraph("saveGraph", keyStrategyFactory);

    stateGraph.addNode("对话存储",AsyncNodeAction.node_async(new NodeAction() {
                @Override
                public Map<String, Object> apply(OverAllState state) throws Exception {
                    String msg = state.value("msg", "");
                    ArrayList<Object> historyMsg = state.value("historyMsg", new ArrayList<>());
                    historyMsg.add(msg);
                    return Map.of("historyMsg", historyMsg);
                }
            })
    );

    // 定义边
    stateGraph.addEdge(StateGraph.START, "对话存储");
    stateGraph.addEdge("对话存储", StateGraph.END);

    return stateGraph.compile();
}

在GraphController中创建接口

@GetMapping("/saveGraph")
// 通过conversationId来隔离不同请求者的数据
public Map<String, Object> saveGraph(@RequestParam("msg") String msg, @RequestParam("conversationId") String conversationId) {
    RunnableConfig runnableConfig = RunnableConfig.builder()
            .threadId(conversationId)
            .build();
    Optional<OverAllState> overAllStateOptional = saveGraph.call(Map.of("msg", msg), runnableConfig);
    Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());
    return data;
}

测试

GET方式:http://localhost:8889/graph/saveGraph?msg=你好张三&conversationId=zs
第一次调用
在这里插入图片描述

第二次调用,发现前面的值存储了下来
在这里插入图片描述
修改会话ID,历史数据只有最新的一条数据
在这里插入图片描述

八、打印图

我们可以把定义好的状态图进行打印,更直观的看到当前图的情况

在图的下面添加如下代码:

// 添加PlantUML打印
GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "stateGraph");
log.info("\n===打印UML Flow===");
log.info(representation.content());
log.info("====================\n");

在这里插入图片描述
启动服务,复制如下内容
在这里插入图片描述
打开网址:http://www.plantuml.com/plantuml/
粘贴内容,就可以看到图的效果了
在这里插入图片描述

九、资料

视频:https://www.bilibili.com/video/BV1eyWbzEEnw?spm_id_from=333.788.player.switch&vd_source=0467ab39cc5ec5940fee22a0e7797575&p=45

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐