Skip to content

Commit 776efe7

Browse files
authored
validate that workflow is not cleared until events are consumed (#186)
* validate that workflow is not cleared until events are consumed * version bump * move completion check to function located with status union
1 parent 63bb2e6 commit 776efe7

File tree

5 files changed

+53
-5
lines changed

5 files changed

+53
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dev = [
1818

1919
[project]
2020
name = "llama-index-workflows"
21-
version = "2.10.1"
21+
version = "2.10.2"
2222
description = "An event-driven, async-first, step-based way to control the execution flow of AI applications like Agents."
2323
readme = "README.md"
2424
license = "MIT"

src/workflows/protocol/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
Status = Literal["running", "completed", "failed", "cancelled"]
1212

1313

14+
def is_status_completed(status: Status) -> bool:
15+
return status in {"completed", "failed", "cancelled"}
16+
17+
1418
class HandlerData(BaseModel):
1519
handler_id: str
1620
workflow_name: str

src/workflows/server/server.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
WorkflowEventsListResponse,
4545
WorkflowGraphResponse,
4646
WorkflowSchemaResponse,
47+
is_status_completed,
4748
)
4849
from workflows.server.abstract_workflow_store import (
4950
AbstractWorkflowStore,
@@ -1125,9 +1126,11 @@ async def _post_event(self, request: Request) -> JSONResponse:
11251126

11261127
# Check if handler exists
11271128
wrapper = self._handlers.get(handler_id)
1129+
if wrapper is not None and is_status_completed(wrapper.status):
1130+
raise HTTPException(detail="Workflow already completed", status_code=409)
11281131
if wrapper is None:
11291132
handler_data = await self._load_handler(handler_id)
1130-
if handler_data.status in {"completed", "failed", "cancelled"}:
1133+
if is_status_completed(handler_data.status):
11311134
raise HTTPException(
11321135
detail="Workflow already completed", status_code=409
11331136
)
@@ -1407,6 +1410,7 @@ class _WorkflowHandler:
14071410
# Dependencies for persistence
14081411
_workflow_store: AbstractWorkflowStore
14091412
_persistence_backoff: list[float]
1413+
_on_finsh: Callable[[], Awaitable[None]] | None = None
14101414

14111415
async def persist(self) -> None:
14121416
"""Persist the current handler state immediately to the workflow store."""
@@ -1536,6 +1540,7 @@ def start_streaming(self, on_finish: Callable[[], Awaitable[None]]) -> None:
15361540
async def _stream_events(self, on_finish: Callable[[], Awaitable[None]]) -> None:
15371541
"""Internal method that streams events, updates status, and persists state."""
15381542
await self.checkpoint()
1543+
self._on_finish = on_finish
15391544
async for event in self.run_handler.stream_events(expose_internal=True):
15401545
if ( # Watch for a specific internal event that signals the step is complete
15411546
isinstance(event, StepStateChanged)
@@ -1561,7 +1566,6 @@ async def _stream_events(self, on_finish: Callable[[], Awaitable[None]]) -> None
15611566
logger.error(f"Workflow run {self.handler_id} failed! {e}", exc_info=True)
15621567

15631568
await self.checkpoint()
1564-
await on_finish()
15651569

15661570
async def acquire_events_stream(
15671571
self, timeout: float = 1
@@ -1608,6 +1612,8 @@ async def _iter_events(self, timeout: float = 1) -> AsyncGenerator[Event, None]:
16081612
queue_get_task.cancel()
16091613
break
16101614
finally:
1615+
if self._on_finish is not None:
1616+
await self._on_finish()
16111617
self.consumer_mutex.release()
16121618

16131619

tests/server/test_server_endpoints.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,9 @@ async def test_post_event_context_not_available(
970970
) -> None:
971971
# Dumb test for code coverage. Inject a dummy handler with no context to trigger 500 path
972972
wrapper = SimpleNamespace(
973-
run_handler=SimpleNamespace(done=lambda: False, ctx=None), workflow_name="test"
973+
run_handler=SimpleNamespace(done=lambda: False, ctx=None),
974+
workflow_name="test",
975+
status="running",
974976
)
975977

976978
handler_id = "noctx-1"
@@ -1162,3 +1164,39 @@ async def _wait_done() -> Response:
11621164
assert isinstance(data, dict)
11631165
assert data.get("handler_id") == handler_id
11641166
assert data.get("result") is not None
1167+
1168+
1169+
@pytest.mark.asyncio
1170+
async def test_stream_events_after_completion_should_return_unconsumed_events(
1171+
client: AsyncClient,
1172+
) -> None:
1173+
# Start streaming workflow that emits 3 events and completes
1174+
start_resp = await client.post(
1175+
"/workflows/streaming/run-nowait", json={"kwargs": {"count": 3}}
1176+
)
1177+
assert start_resp.status_code == 200
1178+
handler_id = start_resp.json()["handler_id"]
1179+
1180+
# Wait for completion via results endpoint
1181+
async def _wait_done() -> Response:
1182+
r = await client.get(f"/handlers/{handler_id}")
1183+
if r.status_code == 200:
1184+
return r
1185+
raise AssertionError("not done")
1186+
1187+
await wait_for_passing(_wait_done)
1188+
1189+
# Now fetch events AFTER completion. Expect the unconsumed events to still be retrievable.
1190+
# Use NDJSON for easier parsing.
1191+
resp = await client.get(f"/events/{handler_id}?sse=false")
1192+
assert resp.status_code == 200
1193+
assert resp.headers["content-type"].startswith("application/x-ndjson")
1194+
1195+
# Collect NDJSON lines
1196+
lines: list[str] = []
1197+
async for line in resp.aiter_lines():
1198+
data = line.strip()
1199+
if data:
1200+
lines.append(data)
1201+
1202+
assert len(lines) == 4

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)