目前针对update_state方法的调用都是针对最终状态,如果针对某个过去的历史节点调用此方法,就会在对应的地方开启了一个新的分支。除此之外,我们还可以利用bulk_update_state/abulk_update_state方法实现批量状态更新。

1. 以历史节点名义修改状态

假设我们有如上的一个转账流程,它会先后执行的calculate_amounttransfer这两个步骤,前者根据输入的银行账号计算对应的转账金额,后者生成一个代表转账业务的Transaction对象。对于一个已经发生的转账流程,如果发现金额计算有误,我们可以在节点calculate_amount处开启一个新的分支将金额改过来,然后重新驱动后续的流程。

如下的程序模拟了此应用场景。我们创建了上述的calculate_amounttransfer节点,并由它们构建了一个Pregel对象。调用该对象的时候向输入通道from_accountto_account输入转账双方的银行账号,并利用通道 start 驱动calculate_amount。它将转账金额1,000,000写入 通道 amount ,并由此驱动transfer节点,后者创建一个代表转账业务的Transaction对象写入对应的Channel。

from langgraph.channels import LastValue
from typing import NamedTuple
from langgraph.pregel import Pregel, NodeBuilder
from langgraph.checkpoint.memory import InMemorySaver
import json

class Transaction(NamedTuple):
    from_account: str
    to_account: str
    amount: float

calculate_amount = (
    NodeBuilder()
    .subscribe_to("start", read=False)
    .read_from("from_account", "to_account")
    .do(lambda _: 1_000_000.00)
    .write_to("amount")
)

transfer = (
    NodeBuilder()
    .subscribe_to("amount")
    .read_from("from_account", "to_account")
    .do(
        lambda args: Transaction(
            args["from_account"], args["to_account"], args["amount"]
        )
    )
    .write_to("transaction")
)

app = Pregel(
    nodes={"calculate_amount": calculate_amount, "transfer": transfer},
    channels={
        "start": LastValue(None),
        "from_account": LastValue(str),
        "to_account": LastValue(str),
        "amount": LastValue(float),
        "transaction": LastValue(Transaction),
    },
    input_channels=["start", "from_account", "to_account"],
    output_channels=["from_account", "to_account", "amount", "transaction"],
    checkpointer=InMemorySaver(),
)
config = {"configurable": {"thread_id": "tx123"}}
input = {"start": None, "from_account": "Alice", "to_account": "Bob"}
result = app.invoke(input=input, config=config)
assert result["transaction"] == Transaction("Alice", "Bob", 1000000.00)

new_config = app.update_state(
    config=list(app.get_state_history(config))[-2].config,
    values=100.00,
)
result = app.invoke(input=None, config=new_config)
assert result["transaction"] == Transaction("Alice", "Bob", 100.00)
for state in app.get_state_history(config):
    print(f"step {state.metadata['step']}: ")
    print(json.dumps(state.values))

我们输入转账双方的账号(Alice和Bob代替)调用Pregel的invoke方法,并通过断言确保最终生成了我们希望的转账业务(Transaction(“Alice”, “Bob”, 1000000.00))。现在我们需要在转账流程进行到金额计算的地方开启新的分支来将修改转账金额(100),为此我们调用get_state_history方法得到整段历史,并提取倒数第二个快照的配置(转账金额是在倒数第二个Superstep由calculate_amount提供)。

我们将此配置作为参数调用update_state方法,并将新的金额设置为values参数。由于此Superstep只涉及单一Node,所以无需指定as_node参数。状态的更新并不会驱动后续流程的自动执行,我们还需要再次调用invoke方法并从这个地方开始执行,此时我们不需要再次提供原始的输入,只需将update_state方法返回的RunnableConfig配置作为参数。再次利用断言验证生成的转账业务具有新的金额后,我们调用get_state_history获取并输出新的历史。从如下的输出结果可以看出,最后的两个Checkpoint就是我们开辟的新分支。

step 2:
{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 100.0, "transaction": ["Alice", "Bob", 100.0]}
step 1:
{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 100.0}

step 1:
{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 1000000.0, "transaction": ["Alice", "Bob", 1000000.0]}
step 0:
{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 1000000.0}
step -1:
{"start": null, "from_account": "Alice", "to_account": "Bob"}

2. 修改原始输入

如果我们发现原始输入错误了,需要在流程后续的某个Node将输入改过来怎么办呢?其实很简单,将as_node设置成__input__就可以了。就以上面这个转账流程为例,假设在计算金额的时候发现提供的两个账号弄反了,我们可以采用如下的方式直接纠正过来。

new_config = app.update_state(
    config=list(app.get_state_history(config))[-2].config,
    values={"from_account": "Bob", "to_account": "Alice", },
    as_node="__input__"
)
result = app.invoke(input=None, config=new_config)
assert result["transaction"] == Transaction("Bob", "Alice", 1_000_000.00)
for state in app.get_state_history(config):
    metadata = state.metadata    
    step=metadata["step"]
    source=metadata["source"]
    print(f"step {step}\n(source: {source})\nvalues:{json.dumps(state.values)}\n")

由于我们是通过修改原始输入的方式开启的分支,重建的这个代表新分支起点的Checkpoint的Source将是input,我们可以从输出的历史看出这一点。

step 2
(source: loop)
values:{"start": null, "from_account": "Bob", "to_account": "Alice", "amount": 1000000.0, "transaction": ["Bob", "Alice", 1000000.0]}

step 1
(source: input)
values:{"start": null, "from_account": "Bob", "to_account": "Alice", "amount": 1000000.0}

step 1
(source: loop)
values:{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 1000000.0, "transaction": ["Alice", "Bob", 1000000.0]}

step 0
(source: loop)
values:{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 1000000.0}

step -1
(source: input)
values:{"start": null, "from_account": "Alice", "to_account": "Bob"}

3. 单纯Fork一个分支

我们还可以直接在不对状态作任何更新的前提下,直接拷贝指定Checkpoint的方式开启一个分支,此时只需要将as_node参数设置为__copy__就可以了,对于这个必需的values参数,我们指定为一个空的列表代表不对状态做任何更新。

new_config = app.update_state(
    config=list(app.get_state_history(config))[-2].config,
    values=[],
    as_node="__copy__"
)
result = app.invoke(input=None, config=new_config)
assert result["transaction"] == Transaction("Alice", "Bob", 1_000_000.00)
for state in app.get_state_history(config):
    metadata = state.metadata    
    step=metadata["step"]
    source=metadata["source"]
    print(f"step {step}\n(source: {source})\nvalues:{json.dumps(state.values)}\n")

使用__copy__作为as_node的参数值的意图很明确,那就是在此处fork一个新的分支,所以创建的Checkpoint的Source就是fork,输出的历史也体现了这一点。

step 2
(source: loop)
values:{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 1000000.0, "transaction": ["Alice", "Bob", 1000000.0]}

step 1
(source: fork)
values:{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 1000000.0}

step 1
(source: loop)
values:{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 1000000.0, "transaction": ["Alice", "Bob", 1000000.0]}

step 0
(source: loop)
values:{"start": null, "from_account": "Alice", "to_account": "Bob", "amount": 1000000.0}

step -1
(source: input)
values:{"start": null, "from_account": "Alice", "to_account": "Bob"}

4. 在拷贝的基础上同时修改状态

在拷贝的同时修改状态也是可以的,而且还可以同时针对多个Node修改对应的状态,此时values参数需要设置为两层序列。外层序列代表针对不同Node的状态更新,内层指定值和Node名称,具体的格式为[[value1,node1],[value2, node2],..](使用元组也可以)。比如下面的两段代码分别实现了针对输入的更新和金额的更新。为什么不放在一起了,因为输入(__input__)不支持多Node更新

new_config = app.update_state(
    config=list(app.get_state_history(config))[-2].config,
values=[ 
    [{"from_account": "Jason", "to_account": "Jyden", },"__input__"]
    ],
    as_node="__copy__"
)
result = app.invoke(input=None, config=new_config)
assert result["transaction"] == Transaction("Jason", "Jyden", 1_000_000.00)
new_config = app.update_state(
    config=list(app.get_state_history(config))[-2].config,
    values=[
        [9999,"calculate_amount"]
        ],
    as_node="__copy__"
)
result = app.invoke(input=None, config=new_config)
assert result["transaction"] == Transaction("Alice", "Bob", 9999)

5. 批量更新

如果涉及多Node状态更新,就需要将每个更新封装成StateUpdate对象,然后进一步组合成双层序列Sequence[Sequence[StateUpdate]],并将其作为参数调用bulk_update_state/abulk_update_state方法。方法将每个Sequence[StateUpdate]对象作为一批统一写入,并为它们创建一个Checkpoint。如果序列中包含多个StateUpdate对象,每个对象必须通过as_node字段将更新状态的 “名义Node” 确定下来。

如下这个演示程序中的Pregel由四个并行执行的节点foo、bar、baz和qux组成,它们会将自身的Node名称写入与之同名的Channel。在常规执行之后,我们调用了bulk_update_state方法,并将supersteps参数指定为一个包含两组StateUpdate列表的列表,分别以节点foo/bar和baz/qux的名义修改对应Channel的值。

from langgraph.channels import LastValue
from langgraph.pregel import Pregel, NodeBuilder
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.types import StateUpdate

def build_node(node_name: str):
    return (
        NodeBuilder()
        .subscribe_to("start", read=False)
        .do(lambda _: node_name)
        .write_to(node_name)
    )

nodes = {name: build_node(name) for name in ["foo", "bar", "baz", "qux"]}
app = Pregel(
    nodes=nodes,
    channels={
        "start": LastValue(None),
        "foo": LastValue(str),
        "bar": LastValue(str),
        "baz": LastValue(str),
        "qux": LastValue(str),
    },
    input_channels=["start"],
    output_channels=["foo", "bar", "baz", "qux"],
    checkpointer=InMemorySaver(),
)
config = {"configurable": {"thread_id": "tx123"}}
result = app.invoke(input={"start": None}, config=config)

new_config = app.bulk_update_state(
    config=config,
    supersteps=[
        [
            StateUpdate(as_node="foo", values="updated_foo"),
            StateUpdate(as_node="bar", values="updated_bar"),
        ],
        [
            StateUpdate(as_node="baz", values="updated_baz"),
            StateUpdate(as_node="qux", values="updated_qux"),
        ],
    ],
)

result = app.invoke(input=None, config=new_config)
assert result == {
    "foo": "updated_foo",
    "bar": "updated_bar",
    "baz": "updated_baz",
    "qux": "updated_qux",
}

for state in app.get_state_history(config):
    metadata = state.metadata
    step = metadata["step"]
    source = metadata["source"]
    print(f"step {step}\nsource: {source}\nvalues: {state.values}")
    print()

两个StateUpdate序列对应着两个Checkpoint的创建,具体体现在如下所示的输出结果中。前一个Checkpoint包含了针对通道foo和bar的更新,针对通道baz/qux的更新体现在后一个Checkpoint中。

step 2
source: update
values: {'start': None, 'foo': 'updated_foo', 'bar': 'updated_bar', 'baz': 'updated_baz', 'qux': 'updated_qux'}

step 1
source: update
values: {'start': None, 'foo': 'updated_foo', 'bar': 'updated_bar', 'baz': 'baz', 'qux': 'qux'}

step 0
source: loop
values: {'start': None, 'foo': 'foo', 'bar': 'bar', 'baz': 'baz', 'qux': 'qux'}

step -1
source: input
values: {'start': None}
Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐