diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 25ea4cf89..621467a16 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -31,6 +31,7 @@ from client import AsyncLlamaStackClientHolder from configuration import configuration from constants import ( + INTERRUPTED_RESPONSE_MESSAGE, LLM_TOKEN_EVENT, LLM_TOOL_CALL_EVENT, LLM_TOOL_RESULT_EVENT, @@ -320,6 +321,110 @@ async def retrieve_response_generator( raise HTTPException(**error_response.model_dump()) from e +async def _persist_interrupted_turn( + context: ResponseGeneratorContext, + responses_params: ResponsesApiParams, + turn_summary: TurnSummary, +) -> None: + """Persist the user query and an interrupted response into the conversation. + + Called when a streaming request is cancelled so the exchange is not lost. + All errors are caught and logged to avoid masking the original + cancellation. + + Parameters: + context: The response generator context. + responses_params: The Responses API parameters. + turn_summary: TurnSummary with llm_response already set to the + interrupted message. + """ + try: + await append_turn_to_conversation( + context.client, + responses_params.conversation, + responses_params.input, + INTERRUPTED_RESPONSE_MESSAGE, + ) + except Exception: # pylint: disable=broad-except + logger.exception( + "Failed to append interrupted turn to conversation for request %s", + context.request_id, + ) + + try: + completed_at = datetime.datetime.now(datetime.UTC).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ) + store_query_results( + user_id=context.user_id, + conversation_id=context.conversation_id, + model=responses_params.model, + completed_at=completed_at, + started_at=context.started_at, + summary=turn_summary, + query_request=context.query_request, + skip_userid_check=context.skip_userid_check, + topic_summary=None, + ) + except Exception: # pylint: disable=broad-except + logger.exception( + "Failed to store interrupted query results for request %s", + context.request_id, + ) + + +def _register_interrupt_callback( + context: ResponseGeneratorContext, + responses_params: ResponsesApiParams, + turn_summary: TurnSummary, +) -> list[bool]: + """Build an interrupt callback and register the stream for cancellation. + + The callback is scheduled as a **separate** asyncio task by + ``cancel_stream`` so it executes regardless of where the + ``CancelledError`` is raised in the ASGI stack. + + A mutable one-element list is used as a shared guard so the + callback and the in-generator ``CancelledError`` handler never + both persist the same turn. + + Parameters: + context: The response generator context. + responses_params: The Responses API parameters. + turn_summary: TurnSummary populated during streaming. + + Returns: + A mutable list ``[False]`` used as a persist-done guard; the + caller should check ``guard[0]`` before persisting and set + it to ``True`` afterwards. 
+ """ + guard: list[bool] = [False] + + async def _on_interrupt() -> None: + if guard[0]: + return + guard[0] = True + turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE + await _persist_interrupted_turn(context, responses_params, turn_summary) + + current_task = asyncio.current_task() + if current_task is not None: + get_stream_interrupt_registry().register_stream( + request_id=context.request_id, + user_id=context.user_id, + task=current_task, + on_interrupt=_on_interrupt, + ) + else: + logger.warning( + "No current asyncio task for request %s; " + "stream interruption will not be available", + context.request_id, + ) + + return guard + + async def generate_response( generator: AsyncIterator[str], context: ResponseGeneratorContext, @@ -330,9 +435,9 @@ async def generate_response( Re-yields events from the generator, handles errors, and ensures persistence and token consumption after completion. When the - stream is interrupted via ``CancelledError``, all post-stream side - effects (token consumption, result storage) are skipped and the - request is deregistered from the interrupt registry. + stream is interrupted via ``CancelledError``, the user query and + an interrupted response are persisted to the conversation, but + token consumption is skipped (no usage data is available). Args: generator: The base generator to wrap @@ -343,20 +448,9 @@ async def generate_response( Yields: SSE-formatted strings from the wrapped generator """ - user_id = context.user_id - - current_task = asyncio.current_task() - if current_task is not None: - get_stream_interrupt_registry().register_stream( - request_id=context.request_id, - user_id=user_id, - task=current_task, - ) - else: - logger.warning( - "No current asyncio task for request %s; stream interruption will not be available", - context.request_id, - ) + persist_guard = _register_interrupt_callback( + context, responses_params, turn_summary + ) stream_completed = False try: @@ -390,6 +484,13 @@ async def generate_response( yield stream_http_error_event(error_response, context.query_request.media_type) except asyncio.CancelledError: logger.info("Streaming request %s interrupted by user", context.request_id) + current_task = asyncio.current_task() + if current_task is not None: + current_task.uncancel() + if not persist_guard[0]: + persist_guard[0] = True + turn_summary.llm_response = INTERRUPTED_RESPONSE_MESSAGE + await _persist_interrupted_turn(context, responses_params, turn_summary) yield stream_interrupted_event(context.request_id) finally: get_stream_interrupt_registry().deregister_stream(context.request_id) diff --git a/src/constants.py b/src/constants.py index fadaa064f..6ae17072e 100644 --- a/src/constants.py +++ b/src/constants.py @@ -6,6 +6,9 @@ UNABLE_TO_PROCESS_RESPONSE = "Unable to process this request" +# Response stored in the conversation when the user interrupts a streaming request +INTERRUPTED_RESPONSE_MESSAGE = "You interrupted this request." 
+ # Supported attachment types ATTACHMENT_TYPES = frozenset( { diff --git a/src/utils/stream_interrupts.py b/src/utils/stream_interrupts.py index cf8cea180..95f9fde94 100644 --- a/src/utils/stream_interrupts.py +++ b/src/utils/stream_interrupts.py @@ -1,7 +1,9 @@ """In-memory registry for interrupting active streaming requests.""" import asyncio -from dataclasses import dataclass +from collections.abc import Callable, Coroutine +from typing import Any +from dataclasses import dataclass, field from enum import Enum from threading import Lock from log import get_logger @@ -17,10 +19,16 @@ class ActiveStream: Attributes: user_id: Owner of the streaming request. task: Asyncio task producing the stream response. + on_interrupt: Optional async callback invoked when the stream + is cancelled, scheduled as a separate task so it runs + regardless of where the ``CancelledError`` lands. """ user_id: str task: asyncio.Task[None] + on_interrupt: Callable[[], Coroutine[Any, Any, None]] | None = field( + default=None, repr=False + ) class CancelStreamResult(str, Enum): @@ -41,7 +49,11 @@ def __init__(self) -> None: self._lock = Lock() def register_stream( - self, request_id: str, user_id: str, task: asyncio.Task[None] + self, + request_id: str, + user_id: str, + task: asyncio.Task[None], + on_interrupt: Callable[[], Coroutine[Any, Any, None]] | None = None, ) -> None: """Register an active stream task for interrupt support. @@ -49,9 +61,13 @@ def register_stream( request_id: Unique streaming request identifier. user_id: User identifier that owns the stream. task: Asyncio task associated with the stream. + on_interrupt: Optional async callback to run when the stream + is cancelled, executed in a separate task. """ with self._lock: - self._streams[request_id] = ActiveStream(user_id=user_id, task=task) + self._streams[request_id] = ActiveStream( + user_id=user_id, task=task, on_interrupt=on_interrupt + ) def cancel_stream(self, request_id: str, user_id: str) -> CancelStreamResult: """Cancel an active stream owned by user. @@ -60,6 +76,11 @@ def cancel_stream(self, request_id: str, user_id: str) -> CancelStreamResult: lock so that a concurrent ``deregister_stream`` cannot remove the entry between the ownership check and the cancel call. + When an ``on_interrupt`` callback was registered, it is + scheduled as a **separate** asyncio task after the cancel so + persistence runs regardless of where the ``CancelledError`` + is raised (inside the generator or in Starlette's send). + Parameters: request_id: Unique streaming request identifier. user_id: User identifier attempting the interruption. @@ -67,6 +88,7 @@ def cancel_stream(self, request_id: str, user_id: str) -> CancelStreamResult: Returns: CancelStreamResult: Structured cancellation result. """ + on_interrupt = None with self._lock: stream = self._streams.get(request_id) if stream is None: @@ -81,7 +103,12 @@ def cancel_stream(self, request_id: str, user_id: str) -> CancelStreamResult: if stream.task.done(): return CancelStreamResult.ALREADY_DONE stream.task.cancel() - return CancelStreamResult.CANCELLED + on_interrupt = stream.on_interrupt + + if on_interrupt is not None: + asyncio.get_running_loop().create_task(on_interrupt()) + + return CancelStreamResult.CANCELLED def deregister_stream(self, request_id: str) -> None: """Remove stream task from registry once completed/cancelled. 
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index f72794ebf..a819e1b8f 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1358,12 +1358,12 @@ async def mock_generator() -> AsyncIterator[str]: assert any("error" in item for item in result) @pytest.mark.asyncio - async def test_generate_response_cancelled_skips_side_effects( + async def test_generate_response_cancelled_persists_interrupted_turn( self, mocker: MockerFixture, isolate_stream_interrupt_registry: Any, ) -> None: - """Test cancelled stream exits without quota consumption and persistence.""" + """Test cancelled stream persists user query with interrupted response.""" async def mock_generator() -> AsyncIterator[str]: yield "data: token\n\n" @@ -1377,9 +1377,12 @@ async def mock_generator() -> AsyncIterator[str]: ) # pyright: ignore[reportCallIssue] mock_context.started_at = "2024-01-01T00:00:00Z" mock_context.skip_userid_check = False + mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) mock_responses_params = mocker.Mock(spec=ResponsesApiParams) mock_responses_params.model = "provider1/model1" + mock_responses_params.conversation = "conv_123" + mock_responses_params.input = "test" mock_turn_summary = TurnSummary() mock_turn_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5) @@ -1390,6 +1393,10 @@ async def mock_generator() -> AsyncIterator[str]: store_query_results_mock = mocker.patch( "app.endpoints.streaming_query.store_query_results" ) + append_turn_mock = mocker.patch( + "app.endpoints.streaming_query.append_turn_to_conversation", + new_callable=mocker.AsyncMock, + ) test_request_id = "123e4567-e89b-12d3-a456-426614174000" mock_context.request_id = test_request_id @@ -1407,11 +1414,183 @@ async def mock_generator() -> AsyncIterator[str]: assert any('"event": "interrupted"' in item for item in result) assert not any('"event": "end"' in item for item in result) consume_query_tokens_mock.assert_not_called() - store_query_results_mock.assert_not_called() + + append_turn_mock.assert_called_once_with( + mock_context.client, + "conv_123", + "test", + "You interrupted this request.", + ) + store_query_results_mock.assert_called_once() + call_kwargs = store_query_results_mock.call_args[1] + assert call_kwargs["user_id"] == "user_123" + assert call_kwargs["conversation_id"] == "conv_123" + assert call_kwargs["summary"].llm_response == "You interrupted this request." 
+ assert call_kwargs["topic_summary"] is None + isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( test_request_id ) + @pytest.mark.asyncio + async def test_generate_response_cancelled_stores_results_when_append_fails( + self, + mocker: MockerFixture, + isolate_stream_interrupt_registry: Any, + ) -> None: + """Test store_query_results still runs when append_turn_to_conversation fails.""" + + async def mock_generator() -> AsyncIterator[str]: + yield "data: token\n\n" + raise asyncio.CancelledError() + + mock_context = mocker.Mock(spec=ResponseGeneratorContext) + mock_context.conversation_id = "conv_123" + mock_context.user_id = "user_123" + mock_context.query_request = QueryRequest( + query="test", media_type=MEDIA_TYPE_JSON + ) # pyright: ignore[reportCallIssue] + mock_context.started_at = "2024-01-01T00:00:00Z" + mock_context.skip_userid_check = False + mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) + + mock_responses_params = mocker.Mock(spec=ResponsesApiParams) + mock_responses_params.model = "provider1/model1" + mock_responses_params.conversation = "conv_123" + mock_responses_params.input = "test" + + mock_turn_summary = TurnSummary() + + mocker.patch("app.endpoints.streaming_query.consume_query_tokens") + store_query_results_mock = mocker.patch( + "app.endpoints.streaming_query.store_query_results" + ) + mocker.patch( + "app.endpoints.streaming_query.append_turn_to_conversation", + new_callable=mocker.AsyncMock, + side_effect=RuntimeError("Llama Stack unavailable"), + ) + + test_request_id = "123e4567-e89b-12d3-a456-426614174000" + mock_context.request_id = test_request_id + + result = [] + async for item in generate_response( + mock_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, + ): + result.append(item) + + assert any('"event": "interrupted"' in item for item in result) + store_query_results_mock.assert_called_once() + isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( + test_request_id + ) + + @pytest.mark.asyncio + async def test_generate_response_task_cancel_persists_results( + self, + mocker: MockerFixture, + isolate_stream_interrupt_registry: Any, + ) -> None: + """Test that real task.cancel() persists via CancelledError handler.""" + cancel_event = asyncio.Event() + + async def slow_generator() -> AsyncIterator[str]: + yield "data: token\n\n" + await cancel_event.wait() + yield "data: should not reach\n\n" + + mock_context = mocker.Mock(spec=ResponseGeneratorContext) + mock_context.conversation_id = "conv_123" + mock_context.user_id = "user_123" + mock_context.query_request = QueryRequest( + query="test", media_type=MEDIA_TYPE_JSON + ) # pyright: ignore[reportCallIssue] + mock_context.started_at = "2024-01-01T00:00:00Z" + mock_context.skip_userid_check = False + mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) + + mock_responses_params = mocker.Mock(spec=ResponsesApiParams) + mock_responses_params.model = "provider1/model1" + mock_responses_params.conversation = "conv_123" + mock_responses_params.input = "test" + + mock_turn_summary = TurnSummary() + + mocker.patch("app.endpoints.streaming_query.consume_query_tokens") + store_query_results_mock = mocker.patch( + "app.endpoints.streaming_query.store_query_results" + ) + append_turn_mock = mocker.patch( + "app.endpoints.streaming_query.append_turn_to_conversation", + new_callable=mocker.AsyncMock, + ) + + test_request_id = "123e4567-e89b-12d3-a456-426614174000" + mock_context.request_id = test_request_id + + 
result: list[str] = [] + + async def consume_generator() -> None: + async for item in generate_response( + slow_generator(), + mock_context, + mock_responses_params, + mock_turn_summary, + ): + result.append(item) + + task = asyncio.create_task(consume_generator()) + await asyncio.sleep(0.05) + task.cancel() + await asyncio.sleep(0.05) + + assert any('"event": "interrupted"' in item for item in result) + append_turn_mock.assert_called_once() + store_query_results_mock.assert_called_once() + isolate_stream_interrupt_registry.deregister_stream.assert_called_once_with( + test_request_id + ) + + @pytest.mark.asyncio + async def test_cancel_stream_callback_persists_when_error_hits_outside_generator( + self, + ) -> None: + """Test on_interrupt callback runs via cancel_stream as a separate task.""" + registry = StreamInterruptRegistry() + test_request_id = "123e4567-e89b-12d3-a456-426614174099" + registry.deregister_stream(test_request_id) + + callback_ran = False + + async def mock_callback() -> None: + nonlocal callback_ran + callback_ran = True + + async def pending_stream() -> None: + await asyncio.sleep(10) + + task = asyncio.create_task(pending_stream()) + registry.register_stream( + test_request_id, "user_123", task, on_interrupt=mock_callback + ) + + result = registry.cancel_stream(test_request_id, "user_123") + assert result.value == "cancelled" + + # Let the scheduled callback task execute + await asyncio.sleep(0.01) + + assert callback_ran is True + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + registry.deregister_stream(test_request_id) + class TestResponseGenerator: """Tests for response_generator function."""
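
For context, below is a minimal caller-side sketch (not part of this patch) of how the registry change is expected to be driven. It assumes a FastAPI-style route, that get_stream_interrupt_registry() is exported from utils.stream_interrupts alongside StreamInterruptRegistry and CancelStreamResult, and that the user id has already been authenticated; the route path and response payload are purely illustrative. The relevant behaviour is that a successful cancel_stream() now both cancels the streaming task and schedules the registered on_interrupt callback, so the interrupted turn is persisted even when the CancelledError surfaces outside the response generator (e.g. in Starlette's send).

# Hypothetical caller sketch -- NOT part of this patch.
# Assumed: FastAPI is available, get_stream_interrupt_registry() and
# CancelStreamResult are importable from utils.stream_interrupts, and
# user_id has already been resolved by authentication middleware.
from fastapi import APIRouter, HTTPException

from utils.stream_interrupts import (  # assumed import location
    CancelStreamResult,
    get_stream_interrupt_registry,
)

router = APIRouter()


@router.post("/streaming_query/{request_id}/cancel")  # illustrative path only
async def cancel_streaming_query(request_id: str, user_id: str) -> dict[str, str]:
    """Ask the registry to cancel an active stream owned by this user."""
    result = get_stream_interrupt_registry().cancel_stream(request_id, user_id)
    if result is CancelStreamResult.CANCELLED:
        # cancel_stream() has called task.cancel() and, with this patch,
        # scheduled the on_interrupt callback that persists the interrupted turn.
        return {"request_id": request_id, "status": result.value}
    if result is CancelStreamResult.ALREADY_DONE:
        # The stream finished before the cancel arrived; nothing to persist here.
        return {"request_id": request_id, "status": result.value}
    # Any other outcome (unknown request id, ownership mismatch, ...) is rejected.
    raise HTTPException(status_code=404, detail=f"Cannot cancel stream: {result.value}")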