Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 152 additions & 55 deletions src/xai_sdk/aio/chat.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
import datetime
import json
import sys
import warnings
from typing import AsyncIterator, Optional, Sequence, TypeVar
from typing import Any, Optional, Sequence, TypeVar

from opentelemetry.trace import SpanKind
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel

from ..chat import BaseChat, BaseClient, Chunk, Response
Expand Down Expand Up @@ -81,6 +82,145 @@ async def delete_stored_completion(self, response_id: str) -> str:
T = TypeVar("T", bound=BaseModel)


class _AsyncChatStreamBase:
def __init__(self, chat: "Chat", n: int, span_name: str) -> None:
self._chat = chat
self._n = n
self._span_name = span_name
self._stream: Optional[Any] = None
self._stream_iterator: Optional[Any] = None
self._span: Optional[Any] = None
self._closed = False
self._first_chunk_received = False

@property
def grpc_call(self) -> Optional[Any]:
"""Returns the underlying gRPC unary-stream call, once iteration has started."""
return self._stream

def cancel(self) -> bool:
"""Cancels the underlying gRPC stream if it has started."""
self._closed = True
if self._stream is None:
self._finish_span()
return False

cancelled = self._stream.cancel()
self._finish_span()
return cancelled

async def close(self) -> None:
"""Closes the stream by cancelling the underlying gRPC call."""
self.cancel()

async def aclose(self) -> None:
"""Closes the stream using the standard async-generator method name."""
await self.close()

def _ensure_started(self) -> Any:
if self._closed:
raise StopAsyncIteration

if self._stream is None:
self._span = tracer.start_span(
name=f"{self._span_name} {self._chat._proto.model}",
kind=SpanKind.CLIENT,
attributes=self._chat._make_span_request_attributes(),
)
stream = self._chat._stub.GetCompletionChunk(self._chat._make_request(self._n))
self._stream = stream
self._stream_iterator = stream.__aiter__()

if self._stream_iterator is None:
raise StopAsyncIteration
return self._stream_iterator

async def _read_chunk(self, responses: Sequence[Response]) -> chat_pb2.GetChatCompletionChunk:
try:
stream_iterator = self._ensure_started()
return await stream_iterator.__anext__()
except StopAsyncIteration:
self._finish_span(responses)
raise
except BaseException:
self._finish_span(exc_info=sys.exc_info())
raise

def _mark_first_chunk_received(self) -> None:
if not self._first_chunk_received and self._span is not None:
self._span.set_attribute(
"gen_ai.completion.start_time",
datetime.datetime.now(datetime.timezone.utc).isoformat(),
)
self._first_chunk_received = True

def _finish_span(
self,
responses: Optional[Sequence[Response]] = None,
exc_info: Optional[tuple[Any, Any, Any]] = None,
) -> None:
if self._span is None:
return

if responses is not None:
self._span.set_attributes(self._chat._make_span_response_attributes(responses))

_, exc, _ = exc_info or (None, None, None)
if exc is not None:
self._span.record_exception(exc)
self._span.set_status(Status(StatusCode.ERROR, str(exc)))

self._span.end()
self._span = None


class AsyncChatStream(_AsyncChatStreamBase):
    """Cancelable async iterator for a single streamed chat completion.

    Each iteration yields the accumulating `Response` together with a `Chunk`
    wrapping the piece that was just received.
    """

    def __init__(self, chat: "Chat") -> None:
        """Creates a cancelable single-response chat stream.

        Args:
            chat: The chat to stream a completion for.
        """
        super().__init__(chat, n=1, span_name="chat.stream")

        # Start with no fixed output index when server-side tools are in use,
        # otherwise with output 0.
        if chat._uses_server_side_tools():
            self._index = None
        else:
            self._index = 0

        empty_proto = chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()])
        self._response = Response(empty_proto, self._index)

    def __aiter__(self) -> "AsyncChatStream":
        """Returns the stream itself; it is its own async iterator."""
        return self

    async def __anext__(self) -> tuple[Response, Chunk]:
        """Reads the next chunk and folds it into the accumulating response.

        Returns:
            tuple[Response, Chunk]: The accumulating response and the current chunk.

        Raises:
            StopAsyncIteration: When the stream is exhausted or has been closed.
        """
        raw_chunk = await self._read_chunk([self._response])
        self._mark_first_chunk_received()

        # Auto-detect if server added tools implicitly
        detected_index = self._chat._auto_detect_multi_output_mode(self._index, raw_chunk.outputs)
        self._index = detected_index
        self._response._index = detected_index

        self._response.process_chunk(raw_chunk)
        return self._response, Chunk(raw_chunk, detected_index)


class AsyncChatBatchStream(_AsyncChatStreamBase):
    """Cancelable async iterator for multiple streamed chat completions.

    Each iteration yields the `n` accumulating `Response` objects together with
    `n` `Chunk` views of the piece that was just received.
    """

    def __init__(self, chat: "Chat", n: int) -> None:
        """Creates a cancelable multi-response chat stream.

        Args:
            chat: The chat to stream completions for.
            n: The number of responses requested.
        """
        super().__init__(chat, n=n, span_name="chat.stream_batch")

        # One output slot per requested completion, all held in a single proto
        # that every Response below is constructed around.
        output_slots = [chat_pb2.CompletionOutput(index=slot) for slot in range(n)]
        shared_proto = chat_pb2.GetChatCompletionResponse(outputs=output_slots)
        self._responses = [Response(shared_proto, slot) for slot in range(n)]

    def __aiter__(self) -> "AsyncChatBatchStream":
        """Returns the stream itself; it is its own async iterator."""
        return self

    async def __anext__(self) -> tuple[Sequence[Response], Sequence[Chunk]]:
        """Reads the next chunk and exposes it once per requested response.

        Returns:
            tuple[Sequence[Response], Sequence[Chunk]]: The accumulating responses
            and one `Chunk` per response index for the current chunk.

        Raises:
            StopAsyncIteration: When the stream is exhausted or has been closed.
        """
        raw_chunk = await self._read_chunk(self._responses)
        self._mark_first_chunk_received()

        # The chunk is processed once, through the first response only; all
        # responses were built around the same proto object (presumably the
        # update is therefore visible to each of them - confirm in
        # Response.process_chunk).
        first_response = self._responses[0]
        first_response.process_chunk(raw_chunk)

        per_response_chunks = [Chunk(raw_chunk, slot) for slot in range(self._n)]
        return self._responses, per_response_chunks


class Chat(BaseChat):
"""Utility class for simplifying the interaction with Chat requests and responses."""

Expand Down Expand Up @@ -154,19 +294,20 @@ async def sample_batch(self, n: int) -> Sequence[Response]:
span.set_attributes(self._make_span_response_attributes(responses))
return responses

async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]:
def stream(self) -> AsyncChatStream:
"""Asynchronously streams a single chat completion response.

This method streams the model's response in chunks, yielding each chunk as it is
received. It is suitable for real-time applications where partial responses
are displayed as they arrive, such as interactive chat interfaces.

Returns:
AsyncIterator[tuple[Response, Chunk]]: An async iterator yielding tuples, each
containing:
AsyncChatStream: A cancelable async iterator yielding tuples, each containing:
- `Response`: The accumulating response object, updated with each chunk.
- `Chunk`: A `Chunk` object containing the content and metadata of the
current chunk.
Call `stream.cancel()`, `await stream.close()`, or `await stream.aclose()`
to cancel the underlying gRPC stream from another task.

Example:
>>> chat = client.chat.create(model="grok-4.20-non-reasoning")
Expand All @@ -177,34 +318,9 @@ async def stream(self) -> AsyncIterator[tuple[Response, Chunk]]:
>>> print(response.content)
"Once upon a time..." (full accumulated response)
"""
first_chunk_received = False
with tracer.start_as_current_span(
name=f"chat.stream {self._proto.model}",
kind=SpanKind.CLIENT,
attributes=self._make_span_request_attributes(),
) as span:
index = None if self._uses_server_side_tools() else 0
response = Response(chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput()]), index)
stream = self._stub.GetCompletionChunk(self._make_request(1))

async for chunk in stream:
if not first_chunk_received:
span.set_attribute(
"gen_ai.completion.start_time", datetime.datetime.now(datetime.timezone.utc).isoformat()
)
first_chunk_received = True

# Auto-detect if server added tools implicitly
index = self._auto_detect_multi_output_mode(index, chunk.outputs)
response._index = index
return AsyncChatStream(self)

response.process_chunk(chunk)
chunk_obj = Chunk(chunk, index)
yield response, chunk_obj

span.set_attributes(self._make_span_response_attributes([response]))

async def stream_batch(self, n: int) -> AsyncIterator[tuple[Sequence[Response], Sequence[Chunk]]]:
def stream_batch(self, n: int) -> AsyncChatBatchStream:
"""Asynchronously streams multiple chat completion responses.

This method streams `n` responses concurrently in a single request, yielding chunks
Expand All @@ -216,12 +332,13 @@ async def stream_batch(self, n: int) -> AsyncIterator[tuple[Sequence[Response],
n: The number of responses to generate.

Returns:
AsyncIterator[tuple[Sequence[Response], Sequence[Chunk]]]: An async iterator
yielding tuples, each containing:
AsyncChatBatchStream: A cancelable async iterator yielding tuples, each containing:
- `Sequence[Response]`: A sequence of `Response` objects, one for each of
the `n` responses, updated with each chunk.
- `Sequence[Chunk]`: A sequence of `Chunk` objects, one for each response,
containing the content and metadata of the current chunk.
Call `stream.cancel()`, `await stream.close()`, or `await stream.aclose()`
to cancel the underlying gRPC stream from another task.

Example:
>>> chat = client.chat.create(model="grok-4.20-non-reasoning")
Expand All @@ -240,27 +357,7 @@ async def stream_batch(self, n: int) -> AsyncIterator[tuple[Sequence[Response],
stacklevel=2,
)

proto = chat_pb2.GetChatCompletionResponse(outputs=[chat_pb2.CompletionOutput(index=i) for i in range(n)])
responses = [Response(proto, i) for i in range(n)]
first_chunk_received = False

with tracer.start_as_current_span(
name=f"chat.stream_batch {self._proto.model}",
kind=SpanKind.CLIENT,
attributes=self._make_span_request_attributes(),
) as span:
stream = self._stub.GetCompletionChunk(self._make_request(n))
async for chunk in stream:
if not first_chunk_received:
span.set_attribute(
"gen_ai.completion.start_time", datetime.datetime.now(datetime.timezone.utc).isoformat()
)
first_chunk_received = True

responses[0].process_chunk(chunk)
yield responses, [Chunk(chunk, i) for i in range(n)]

span.set_attributes(self._make_span_response_attributes(responses))
return AsyncChatBatchStream(self, n)

async def parse(self, shape: type[T]) -> tuple[Response, T]:
"""Asynchronously generates and parses a single chat completion response into a Pydantic model.
Expand Down
Loading