LangGraph学习笔记三(State)
一.基础笔记
1.TypeDict和BaseModel对比
2.State状态更新策略
(1).默认Reducer(覆盖更新)
class DefaultReducerState(TypedDict):
foo: int
bar: List[str]
def node_default_1(state: DefaultReducerState) -> dict:
print(state["foo"])
print(state["bar"])
return {"foo": 22}
def node_default_2(state: DefaultReducerState) -> dict:
print()
print(state["foo"])
print(state["bar"])
return {"bar": ["bye1","bye2","bye3"]}
def main():
print("1. 默认Reducer(覆盖更新)演示:\n")
builder = StateGraph(DefaultReducerState)
builder.add_node("node1", node_default_1)
builder.add_node("node2", node_default_2)
builder.add_edge(START, "node1")
builder.add_edge("node1", "node2")
builder.add_edge("node2", END)
graph = builder.compile()
result = graph.invoke(input={"foo": 1, "bar": ["hi"]})
#print(f"初始状态: {{'foo': 1, 'bar': ['hi']}}")
print(f"执行结果: {result}\n")
执行结果:
(2).add_messages Reducer(消息列表专用)
"""
LangGraph Reducer函数演示 - add_messages Reducer(消息列表专用)
"""
from typing import Annotated, List
from langchain_core.messages import HumanMessage, AIMessage
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
# 2. add_messages Reducer(消息列表专用)
class AddMessagesState(TypedDict):
"""
引入的 Annotated 类型,它允许给类型添加额外的元数据。
messages: Annotated[List, add_messages]
表示:
- messages 我的状态里只有一个字段叫 messages,类型是是 List列表类型,
- add_messages 这里的 add_messages 是一个函数,用于修改 messages 列表
每当节点返回对 messages 的“局部更新”时,
请用 add_messages 规约器把它合并到旧列表上(追加,而不是覆盖)
总结:
节点永远只 return 增量字典,不用手动把旧列表读出来再拼接。
add_messages 在后台帮你完成“追加”动作;如果换成默认 reducer,旧消息会被整份替换掉
"""
messages: Annotated[List, add_messages]
def chat_node_1(state: AddMessagesState) -> dict:
return {"messages": [("assistant", "Hello from node 1")]}
def chat_node_2(state: AddMessagesState) -> dict:
return {"messages": [("assistant", "Hello from node 2")]}
def run_demo():
print("2. add_messages Reducer(消息列表专用)演示:")
builder = StateGraph(AddMessagesState)
builder.add_node("chat1", chat_node_1)
builder.add_node("chat2", chat_node_2)
builder.add_edge(START, "chat1")
builder.add_edge(START, "chat2") # 并行执行
builder.add_edge("chat1", END)
builder.add_edge("chat2", END)
graph = builder.compile()
result = graph.invoke({"messages": [("user", "Hi there!")]})
print(f"初始状态: {{'messages': [('user', 'Hi there!')]}}")
print(f"执行结果: {result}\n")
print("*" * 60)
# 打印图的ascii可视化结构
print(graph.get_graph().print_ascii())
if __name__ == "__main__":
run_demo()
执行结果:
(3).operator.add Reducer(列表追加)
"""
LangGraph Reducer函数演示 - operator.add Reducer(列表追加)
"""
import operator
from typing import Annotated, List
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
# 3. operator.add Reducer(列表追加)
class ListAddState(TypedDict):
#data: Annotated[List[int], None] #默认覆盖
data: Annotated[List[int], operator.add] # (列表追加)
def producer_1(state: ListAddState) -> dict:
return {"data": [1, 2]}
def producer_2(state: ListAddState) -> dict:
return {"data": [3, 4]}
def run_demo():
builder = StateGraph(ListAddState)
# 注册节点
builder.add_node("producer1", producer_1)
builder.add_node("producer2", producer_2)
# 顺序执行边
builder.add_edge(START, "producer1")
builder.add_edge("producer1", "producer2")
builder.add_edge("producer2", END)
graph = builder.compile()
result = graph.invoke({"data": [0]})
print(f"初始状态: {{'data': [0]}}")
print(f"执行结果: {result}\n")
if __name__ == "__main__":
run_demo()
执行结果:
4.自定义乘法reducer规约器
def MyOperatorMul(current: float, update: float) -> float:
"""自定义乘法reducer,处理初始值为1.0"""
# 如果是第一次调用,current会是默认值0.0
if current == 0.0:
# 对于乘法,恒等元应该是1.0或者 return 1.0 * update
print(f"current:{current}")
print(f"update:{update}")
return 1.0 * update
return current * update
class MultiplyState(TypedDict):
factor: Annotated[float, MyOperatorMul]
def multiplier(state: MultiplyState) -> dict:
return {"factor": 2.0}
def run_demo():
print("使用自定义reducer解决乘法问题:")
builder = StateGraph(MultiplyState)
builder.add_node("multiplier", multiplier)
builder.add_edge(START, "multiplier")
builder.add_edge("multiplier", END)
graph = builder.compile()
result = graph.invoke({"factor": 5.0})
print(f"初始状态: {{'factor': 5.0}}")
print(f"执行结果: {result}") # 应该是 {'factor': 10.0}
print(f"解释: 5.0 * 2.0 = 10.0\n")
if __name__ == "__main__":
run_demo()
执行结果: