Skip to content

Commit babb1ab

Browse files
authored
Fix reconnect (#199)
* Fix reconnect * Bump version to 2.11.1
1 parent f233982 commit babb1ab

File tree

6 files changed

+186
-4
lines changed

6 files changed

+186
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dev = [
1919

2020
[project]
2121
name = "llama-index-workflows"
22-
version = "2.11.0"
22+
version = "2.11.1"
2323
description = "An event-driven, async-first, step-based way to control the execution flow of AI applications like Agents."
2424
readme = "README.md"
2525
license = "MIT"

src/workflows/client/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ async def get_workflow_events(
228228
"include_internal": incl_inter,
229229
"acquire_timeout": lock_timeout,
230230
},
231+
headers={"Connection": "keep-alive"},
232+
timeout=None,
231233
) as response:
232234
# Handle different response codes
233235
if response.status_code == 404:

src/workflows/server/server.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,9 @@ async def _run_workflow(self, request: Request) -> JSONResponse:
522522
await wrapper.task
523523
except Exception:
524524
pass
525+
# explicitly close handlers from this synchronous api so they don't linger with events
526+
# that no-one is listening for
527+
await self._close_handler(wrapper)
525528

526529
return JSONResponse(
527530
wrapper.to_response_model().model_dump(), status_code=status
@@ -1424,7 +1427,7 @@ class _WorkflowHandler:
14241427
# Dependencies for persistence
14251428
_workflow_store: AbstractWorkflowStore
14261429
_persistence_backoff: list[float]
1427-
_on_finsh: Callable[[], Awaitable[None]] | None = None
1430+
_on_finish: Callable[[], Awaitable[None]] | None = None
14281431

14291432
def _as_persistent(self) -> PersistentHandler:
14301433
"""Persist the current handler state immediately to the workflow store."""
@@ -1641,7 +1644,8 @@ async def _iter_events(self, timeout: float = 1) -> AsyncGenerator[Event, None]:
16411644
queue_get_task.cancel()
16421645
break
16431646
finally:
1644-
if self._on_finish is not None:
1647+
if self._on_finish is not None and self.run_handler.done():
1648+
# clean up the resources if the stream has been consumed
16451649
await self._on_finish()
16461650
self.consumer_mutex.release()
16471651

tests/server/test_server_endpoints.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,3 +1238,17 @@ async def _wait_done() -> dict[str, Any]:
12381238
assert data["status"] == "completed"
12391239
assert seen_handler_id["handler_id"] is not None
12401240
assert seen_handler_id["handler_id"] == handler_id
1241+
1242+
1243+
@pytest.mark.asyncio
async def test_run_sync_removes_handler_even_with_unconsumed_events(
    client: AsyncClient, server: WorkflowServer
) -> None:
    """The synchronous run endpoint must not leave its handler registered.

    The "streaming" workflow emits user events; this test never reads them,
    so any handler still tracked afterwards would be holding events nobody
    is listening for.
    """
    run_request = {"kwargs": {"count": 2}}
    response = await client.post("/workflows/streaming/run", json=run_request)

    assert response.status_code == 200
    assert response.json()["status"] == "completed"

    # No handler may remain in memory, even though its events were never consumed.
    assert not len(server._handlers)
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# SPDX-License-Identifier: MIT
2+
# Copyright (c) 2025 LlamaIndex Inc.
3+
4+
from __future__ import annotations
5+
6+
import asyncio
7+
import contextlib
8+
import socket
9+
from contextlib import closing
10+
from typing import AsyncGenerator
11+
12+
import httpx
13+
import pytest
14+
import uvicorn
15+
16+
from workflows import Workflow
17+
from workflows.events import StopEvent
18+
from workflows.server import WorkflowServer
19+
from workflows.client.client import WorkflowClient
20+
from .conftest import ExternalEvent, RequestedExternalEvent
21+
from .util import wait_for_passing
22+
23+
24+
def _get_free_port() -> int:
25+
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
26+
s.bind(("127.0.0.1", 0))
27+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
28+
return int(s.getsockname()[1])
29+
30+
31+
@pytest.fixture
async def live_server(
    simple_test_workflow: Workflow,
    streaming_workflow: Workflow,
    interactive_workflow: Workflow,
    error_workflow: Workflow,
) -> AsyncGenerator[tuple[str, WorkflowServer], None]:
    """Run a ``WorkflowServer`` behind a real uvicorn HTTP listener.

    Registers the four fixture workflows under the names "test", "streaming",
    "interactive" and "error", serves the app on a free loopback port in a
    background task, and polls ``/health`` until it answers 200.

    Yields:
        A ``(base_url, server)`` tuple for the running server.

    Raises:
        RuntimeError: If ``/health`` never returns 200 within 50 polls.
    """
    port = _get_free_port()
    workflow_server = WorkflowServer()
    workflow_server.add_workflow("test", simple_test_workflow)
    workflow_server.add_workflow("streaming", streaming_workflow)
    workflow_server.add_workflow("interactive", interactive_workflow)
    workflow_server.add_workflow("error", error_workflow)

    uv_server = uvicorn.Server(
        uvicorn.Config(
            workflow_server.app,
            host="127.0.0.1",
            port=port,
            log_level="error",
            loop="asyncio",
        )
    )

    # Serve in the background on the current event loop.
    serve_task = asyncio.create_task(uv_server.serve())

    base_url = f"http://127.0.0.1:{port}"
    async with httpx.AsyncClient(base_url=base_url, timeout=1.0) as probe:
        is_up = False
        for _attempt in range(50):  # 50 polls, 10 ms apart
            try:
                health = await probe.get("/health")
                is_up = health.status_code == 200
            except Exception:
                # Best effort: the listener may not be accepting connections yet.
                is_up = False
            if is_up:
                break
            await asyncio.sleep(0.01)

        if not is_up:
            # Never became healthy: ask uvicorn to exit and surface the failure.
            uv_server.should_exit = True
            await serve_task
            raise RuntimeError("Live server did not start in time")

    try:
        yield base_url, workflow_server
    finally:
        uv_server.should_exit = True
        try:
            await serve_task
        finally:
            # Stop the workflow server even if uvicorn's shutdown raised.
            await workflow_server.stop()
82+
83+
84+
@pytest.mark.asyncio
async def test_streaming_over_real_http(
    live_server: tuple[str, WorkflowServer],
) -> None:
    """Stream events over real HTTP, answer the prompt, and check the result.

    Starts the "interactive" workflow without waiting, reads the event stream
    until a ``RequestedExternalEvent`` arrives, replies with an
    ``ExternalEvent`` and stops streaming, then verifies that the handler
    completes with the expected result.
    """
    base_url, _server = live_server

    client = WorkflowClient(base_url=base_url)

    # 1) Start interactive workflow (no-wait)
    started = await client.run_workflow_nowait("interactive")
    handler_id = started.handler_id

    # 2) Stream until we see the RequestedExternalEvent, then respond and stop streaming
    saw_prompt = False
    async for ev in client.get_workflow_events(handler_id):
        event = ev.load_event()
        if isinstance(event, RequestedExternalEvent):
            saw_prompt = True
            sent = await client.send_event(handler_id, ExternalEvent(response="pong"))
            assert sent.status == "sent"
            break
    assert saw_prompt

    # 3) The workflow finishes asynchronously after the event is accepted, so a
    #    single immediate status read can race with completion. Poll the same
    #    way the reconnect test does.
    #    NOTE(review): assumes wait_for_passing retries the coroutine until its
    #    assertions pass, as its use in the reconnect test suggests; confirm.
    async def validate_result_response() -> None:
        data = await client.get_handler(handler_id)
        assert data.status == "completed"
        assert data.result is not None
        assert data.result.value.get("result") == "received: pong"

    await wait_for_passing(validate_result_response)
111+
112+
113+
@pytest.mark.asyncio
async def test_reconnect_stream_and_send_event_succeeds(
    live_server: tuple[str, WorkflowServer],
) -> None:
    """A client that drops the event stream can reconnect and finish the run.

    Starts the "interactive" workflow, reads the stream until the first
    ``RequestedExternalEvent`` and disconnects, then opens a second stream in
    a background task, sends the ``ExternalEvent`` reply, and checks that the
    handler completes and the second stream sees the ``StopEvent``.
    """
    base_url, _server = live_server

    client = WorkflowClient(base_url=base_url)

    # Start the interactive workflow (no-wait)
    started = await client.run_workflow_nowait("interactive")
    handler_id = started.handler_id

    # 1) Connect and read until the first RequestedExternalEvent, then disconnect
    saw_prompt = False
    async for ev in client.get_workflow_events(handler_id):
        event = ev.load_event()
        if isinstance(event, RequestedExternalEvent):
            saw_prompt = True
            break  # simulate client disconnect
    assert saw_prompt

    # 2) Reconnect to stream; this should succeed and allow streaming again
    stop_seen = asyncio.Event()

    async def _consume_again() -> None:
        async for ev in client.get_workflow_events(handler_id):
            event = ev.load_event()
            if isinstance(event, StopEvent):
                stop_seen.set()
                break

    consumer_task = asyncio.create_task(_consume_again())
    try:
        # 3) After reconnect, post the human response; this should succeed
        sent = await client.send_event(handler_id, ExternalEvent(response="pong"))
        assert sent.status == "sent"

        # 4) Wait for completion
        async def validate_result_response() -> None:
            data = await client.get_handler(handler_id)
            assert data.status == "completed"

        await wait_for_passing(validate_result_response)

        await asyncio.wait_for(stop_seen.wait(), timeout=2.0)
    finally:
        if not consumer_task.done():
            consumer_task.cancel()
        # asyncio.CancelledError derives from BaseException on Python 3.8+, so
        # suppress(Exception) alone would let it escape and mask the real
        # assertion failure that sent us into this cleanup.
        with contextlib.suppress(asyncio.CancelledError, Exception):
            await consumer_task

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)