Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions libs/langgraph/langgraph/pregel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2834,21 +2834,6 @@ async def astream(
)
)

# set up custom stream mode
def stream_writer(c: Any) -> None:
aioloop.call_soon_threadsafe(
stream.put_nowait,
(
tuple(
get_config()[CONF][CONFIG_KEY_CHECKPOINT_NS].split(NS_SEP)[
:-1
]
),
"custom",
c,
),
)

if "custom" in stream_modes:

def stream_writer(c: Any) -> None:
Expand Down
78 changes: 78 additions & 0 deletions libs/langgraph/tests/test_get_stream_writer_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import asyncio
from typing import Annotated

import pytest
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from typing_extensions import TypedDict

from langgraph.config import get_stream_writer
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages


@tool
async def tool_a(query: str):
"""Tool A."""
writer = get_stream_writer()
writer({"type": "tool_a", "status": "start"})
await asyncio.sleep(0.1)
writer({"type": "tool_a", "status": "end"})
return "A"


@tool
async def tool_b(query: str):
"""Tool B."""
writer = get_stream_writer()
writer({"type": "tool_b", "status": "start"})
await asyncio.sleep(0.1)
writer({"type": "tool_b", "status": "end"})
return "B"


class State(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]


@pytest.mark.asyncio
async def test_async_tool_node_streaming():
tools = [tool_a, tool_b]
tool_node = ToolNode(tools)

async def agent(state: State):
return {
"messages": [
AIMessage(
content="",
tool_calls=[
{"name": "tool_a", "args": {"query": "a"}, "id": "1"},
{"name": "tool_b", "args": {"query": "b"}, "id": "2"},
],
)
]
}

workflow = StateGraph(State)
workflow.add_node("agent", agent)
workflow.add_node("tools", tool_node)
workflow.add_edge(START, "agent")
workflow.add_edge("agent", "tools")
workflow.add_edge("tools", END)

app = workflow.compile()

chunks = []
async for chunk in app.astream(
{"messages": []},
stream_mode="custom",
):
chunks.append(chunk)

# Verify we got chunks from both tools
tool_a_chunks = [c for c in chunks if c.get("type") == "tool_a"]
tool_b_chunks = [c for c in chunks if c.get("type") == "tool_b"]

assert len(tool_a_chunks) == 2
assert len(tool_b_chunks) == 2
2 changes: 1 addition & 1 deletion libs/prebuilt/tests/test_react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def get_weather(location: str) -> str:
for event in agent.stream(
{"messages": [("user", query)]}, config, stream_mode="values"
):
if "__interrupt__" not in event:
if "__interrupt__" not in event:
if messages := event.get("messages"):
message_types.append([m.type for m in messages])

Expand Down
Loading