Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/xai_sdk/aio/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..proto import batch_pb2
from ..telemetry import get_tracer
from ..types import ImageGenerationModel
from ..types.chat import ServiceTier

tracer = get_tracer(__name__)

Expand All @@ -24,7 +25,7 @@
class Client(BaseClient):
"""Asynchronous client for interacting with the `Image` API."""

def prepare(

Check failure on line 28 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (D417)

src/xai_sdk/aio/image.py:28:9: D417 Missing argument description in the docstring for `prepare`: `service_tier`

Check failure on line 28 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (D417)

src/xai_sdk/aio/image.py:28:9: D417 Missing argument description in the docstring for `prepare`: `service_tier`

Check failure on line 28 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (D417)

src/xai_sdk/aio/image.py:28:9: D417 Missing argument description in the docstring for `prepare`: `service_tier`

Check failure on line 28 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (D417)

src/xai_sdk/aio/image.py:28:9: D417 Missing argument description in the docstring for `prepare`: `service_tier`
self,
prompt: str,
model: Union[ImageGenerationModel, str],
Expand All @@ -36,6 +37,7 @@
image_format: Optional[ImageFormat] = None,
aspect_ratio: Optional[ImageAspectRatio] = None,
resolution: Optional[ImageResolution] = None,
service_tier: Optional[ServiceTier] = None,
) -> batch_pb2.BatchRequest:
"""Prepares an image generation request for batch processing.

Expand Down Expand Up @@ -96,13 +98,14 @@
image_format=image_format,
aspect_ratio=aspect_ratio,
resolution=resolution,
service_tier=service_tier,
)
return batch_pb2.BatchRequest(
image_request=request,
batch_request_id=batch_request_id or "",
)

async def sample(

Check failure on line 108 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (D417)

src/xai_sdk/aio/image.py:108:15: D417 Missing argument description in the docstring for `sample`: `service_tier`

Check failure on line 108 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (D417)

src/xai_sdk/aio/image.py:108:15: D417 Missing argument description in the docstring for `sample`: `service_tier`

Check failure on line 108 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (D417)

src/xai_sdk/aio/image.py:108:15: D417 Missing argument description in the docstring for `sample`: `service_tier`

Check failure on line 108 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (D417)

src/xai_sdk/aio/image.py:108:15: D417 Missing argument description in the docstring for `sample`: `service_tier`
self,
prompt: str,
model: Union[ImageGenerationModel, str],
Expand All @@ -113,6 +116,7 @@
image_format: Optional[ImageFormat] = None,
aspect_ratio: Optional[ImageAspectRatio] = None,
resolution: Optional[ImageResolution] = None,
service_tier: Optional[ServiceTier] = None,
) -> "ImageResponse":
"""Samples a single image asynchronously based on the provided prompt.

Expand Down Expand Up @@ -163,6 +167,7 @@
image_format=image_format,
aspect_ratio=aspect_ratio,
resolution=resolution,
service_tier=service_tier,
)
with tracer.start_as_current_span(
name=f"image.sample {model}",
Expand All @@ -174,7 +179,7 @@
span.set_attributes(_make_span_response_attributes(request, [image_response]))
return image_response

async def sample_batch(

Check failure on line 182 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (D417)

src/xai_sdk/aio/image.py:182:15: D417 Missing argument description in the docstring for `sample_batch`: `service_tier`

Check failure on line 182 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (D417)

src/xai_sdk/aio/image.py:182:15: D417 Missing argument description in the docstring for `sample_batch`: `service_tier`

Check failure on line 182 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (D417)

src/xai_sdk/aio/image.py:182:15: D417 Missing argument description in the docstring for `sample_batch`: `service_tier`

Check failure on line 182 in src/xai_sdk/aio/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (D417)

src/xai_sdk/aio/image.py:182:15: D417 Missing argument description in the docstring for `sample_batch`: `service_tier`
self,
prompt: str,
model: Union[ImageGenerationModel, str],
Expand All @@ -186,6 +191,7 @@
image_format: Optional[ImageFormat] = None,
aspect_ratio: Optional[ImageAspectRatio] = None,
resolution: Optional[ImageResolution] = None,
service_tier: Optional[ServiceTier] = None,
) -> Sequence["ImageResponse"]:
"""Samples a batch of images asynchronously based on the provided prompt.

Expand Down Expand Up @@ -238,6 +244,7 @@
image_format=image_format,
aspect_ratio=aspect_ratio,
resolution=resolution,
service_tier=service_tier,
)
with tracer.start_as_current_span(
name=f"image.sample_batch {model}",
Expand Down
13 changes: 13 additions & 0 deletions src/xai_sdk/aio/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..proto import batch_pb2, deferred_pb2, video_pb2
from ..telemetry import get_tracer
from ..types import VideoGenerationModel
from ..types.chat import ServiceTier
from ..video import (
DEFAULT_VIDEO_POLL_INTERVAL,
DEFAULT_VIDEO_TIMEOUT,
Expand All @@ -31,7 +32,7 @@
class Client(BaseClient):
"""Asynchronous client for interacting with the `Video` API."""

def prepare(

Check failure on line 35 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (D417)

src/xai_sdk/aio/video.py:35:9: D417 Missing argument description in the docstring for `prepare`: `service_tier`

Check failure on line 35 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (D417)

src/xai_sdk/aio/video.py:35:9: D417 Missing argument description in the docstring for `prepare`: `service_tier`

Check failure on line 35 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (D417)

src/xai_sdk/aio/video.py:35:9: D417 Missing argument description in the docstring for `prepare`: `service_tier`

Check failure on line 35 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (D417)

src/xai_sdk/aio/video.py:35:9: D417 Missing argument description in the docstring for `prepare`: `service_tier`
self,
prompt: str,
model: Union[VideoGenerationModel, str],
Expand All @@ -43,6 +44,7 @@
aspect_ratio: Optional[VideoAspectRatio] = None,
resolution: Optional[VideoResolution] = None,
reference_image_urls: Optional[Sequence[str]] = None,
service_tier: Optional[ServiceTier] = None,
) -> batch_pb2.BatchRequest:
"""Prepares a video generation request for batch processing.

Expand Down Expand Up @@ -103,6 +105,7 @@
aspect_ratio=aspect_ratio,
resolution=resolution,
reference_image_urls=reference_image_urls,
service_tier=service_tier,
)

return batch_pb2.BatchRequest(
Expand All @@ -110,7 +113,7 @@
batch_request_id=batch_request_id or "",
)

def prepare_extension(

Check failure on line 116 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (D417)

src/xai_sdk/aio/video.py:116:9: D417 Missing argument description in the docstring for `prepare_extension`: `service_tier`

Check failure on line 116 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (D417)

src/xai_sdk/aio/video.py:116:9: D417 Missing argument description in the docstring for `prepare_extension`: `service_tier`

Check failure on line 116 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (D417)

src/xai_sdk/aio/video.py:116:9: D417 Missing argument description in the docstring for `prepare_extension`: `service_tier`

Check failure on line 116 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (D417)

src/xai_sdk/aio/video.py:116:9: D417 Missing argument description in the docstring for `prepare_extension`: `service_tier`
self,
prompt: str,
model: Union[VideoGenerationModel, str],
Expand All @@ -118,6 +121,7 @@
*,
batch_request_id: Optional[str] = None,
duration: Optional[int] = None,
service_tier: Optional[ServiceTier] = None,
) -> batch_pb2.BatchRequest:
"""Prepares a video extension request for batch processing.

Expand All @@ -141,6 +145,7 @@
model,
video_url,
duration=duration,
service_tier=service_tier,
)

return batch_pb2.BatchRequest(
Expand All @@ -159,6 +164,7 @@
aspect_ratio: Optional[VideoAspectRatio] = None,
resolution: Optional[VideoResolution] = None,
reference_image_urls: Optional[Sequence[str]] = None,
service_tier: Optional[ServiceTier] = None,
) -> deferred_pb2.StartDeferredResponse:
"""Starts a video generation request and returns a request_id for polling."""
request = _make_generate_request(
Expand All @@ -170,6 +176,7 @@
aspect_ratio=aspect_ratio,
resolution=resolution,
reference_image_urls=reference_image_urls,
service_tier=service_tier,
)

with tracer.start_as_current_span(
Expand All @@ -184,7 +191,7 @@
request = video_pb2.GetDeferredVideoRequest(request_id=request_id)
return await self._stub.GetDeferredVideo(request)

async def generate(

Check failure on line 194 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (D417)

src/xai_sdk/aio/video.py:194:15: D417 Missing argument description in the docstring for `generate`: `service_tier`

Check failure on line 194 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (D417)

src/xai_sdk/aio/video.py:194:15: D417 Missing argument description in the docstring for `generate`: `service_tier`

Check failure on line 194 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (D417)

src/xai_sdk/aio/video.py:194:15: D417 Missing argument description in the docstring for `generate`: `service_tier`

Check failure on line 194 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (D417)

src/xai_sdk/aio/video.py:194:15: D417 Missing argument description in the docstring for `generate`: `service_tier`
self,
prompt: str,
model: Union[VideoGenerationModel, str],
Expand All @@ -195,6 +202,7 @@
aspect_ratio: Optional[VideoAspectRatio] = None,
resolution: Optional[VideoResolution] = None,
reference_image_urls: Optional[Sequence[str]] = None,
service_tier: Optional[ServiceTier] = None,
timeout: Optional[datetime.timedelta] = None,
interval: Optional[datetime.timedelta] = None,
) -> VideoResponse:
Expand Down Expand Up @@ -295,6 +303,7 @@
aspect_ratio=aspect_ratio,
resolution=resolution,
reference_image_urls=reference_image_urls,
service_tier=service_tier,
)

with tracer.start_as_current_span(
Expand Down Expand Up @@ -332,13 +341,14 @@
)
await asyncio.sleep(timer.sleep_interval_or_raise())

async def extend_start(

Check failure on line 344 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (D417)

src/xai_sdk/aio/video.py:344:15: D417 Missing argument description in the docstring for `extend_start`: `service_tier`

Check failure on line 344 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (D417)

src/xai_sdk/aio/video.py:344:15: D417 Missing argument description in the docstring for `extend_start`: `service_tier`

Check failure on line 344 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (D417)

src/xai_sdk/aio/video.py:344:15: D417 Missing argument description in the docstring for `extend_start`: `service_tier`

Check failure on line 344 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (D417)

src/xai_sdk/aio/video.py:344:15: D417 Missing argument description in the docstring for `extend_start`: `service_tier`
self,
prompt: str,
model: Union[VideoGenerationModel, str],
video_url: str,
*,
duration: Optional[int] = None,
service_tier: Optional[ServiceTier] = None,
) -> deferred_pb2.StartDeferredResponse:
"""Starts a video extension request and returns a request_id for polling.

Expand All @@ -358,6 +368,7 @@
model,
video_url,
duration=duration,
service_tier=service_tier,
)

with tracer.start_as_current_span(
Expand All @@ -367,13 +378,14 @@
):
return await self._stub.ExtendVideo(request)

async def extend(

Check failure on line 381 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (D417)

src/xai_sdk/aio/video.py:381:15: D417 Missing argument description in the docstring for `extend`: `service_tier`

Check failure on line 381 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (D417)

src/xai_sdk/aio/video.py:381:15: D417 Missing argument description in the docstring for `extend`: `service_tier`

Check failure on line 381 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (D417)

src/xai_sdk/aio/video.py:381:15: D417 Missing argument description in the docstring for `extend`: `service_tier`

Check failure on line 381 in src/xai_sdk/aio/video.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (D417)

src/xai_sdk/aio/video.py:381:15: D417 Missing argument description in the docstring for `extend`: `service_tier`
self,
prompt: str,
model: Union[VideoGenerationModel, str],
video_url: str,
*,
duration: Optional[int] = None,
service_tier: Optional[ServiceTier] = None,
timeout: Optional[datetime.timedelta] = None,
interval: Optional[datetime.timedelta] = None,
) -> VideoResponse:
Expand Down Expand Up @@ -441,6 +453,7 @@
model,
video_url,
duration=duration,
service_tier=service_tier,
)

with tracer.start_as_current_span(
Expand Down
32 changes: 32 additions & 0 deletions src/xai_sdk/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
IncludeOptionMap,
ReasoningEffort,
ResponseFormat,
ServiceTier,
ToolMode,
)

Expand All @@ -38,7 +39,7 @@
"""Creates a new client based on a gRPC channel."""
self._stub = chat_pb2_grpc.ChatStub(channel)

def create(

Check failure on line 42 in src/xai_sdk/chat.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (PLR0912)

src/xai_sdk/chat.py:42:9: PLR0912 Too many branches (14 > 12)

Check failure on line 42 in src/xai_sdk/chat.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (PLR0912)

src/xai_sdk/chat.py:42:9: PLR0912 Too many branches (14 > 12)

Check failure on line 42 in src/xai_sdk/chat.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (PLR0912)

src/xai_sdk/chat.py:42:9: PLR0912 Too many branches (14 > 12)

Check failure on line 42 in src/xai_sdk/chat.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (PLR0912)

src/xai_sdk/chat.py:42:9: PLR0912 Too many branches (14 > 12)
self,
model: Union[ChatModel, str],
*,
Expand Down Expand Up @@ -67,6 +68,7 @@
include: Optional[Sequence[Union[IncludeOption, "chat_pb2.IncludeOption"]]] = None,
agent_count: Optional[Union[AgentCount, "chat_pb2.AgentCount"]] = None,
batch_request_id: Optional[str] = None,
service_tier: Optional[Union[ServiceTier, "usage_pb2.ServiceTier"]] = None,
) -> T:
"""Creates a new chat conversation.

Expand Down Expand Up @@ -175,6 +177,9 @@
batch_request_id: An optional user-provided identifier for the batch request. **If provided, it must be
unique within the batch.**Used to identify the corresponding result when the response is returned to the
user.
service_tier: The processing tier for this request. Set to `"priority"` for higher
scheduling priority. Valid values: `"auto"` (default), `"default"`, `"priority"`.
The response includes a `service_tier` field indicating which tier was actually used.

Returns:
A new chat request bound to a client.
Expand Down Expand Up @@ -221,6 +226,12 @@
else:
agent_count_pb = agent_count

service_tier_pb: Optional[usage_pb2.ServiceTier] = None
if isinstance(service_tier, str):
service_tier_pb = _service_tier_to_proto(service_tier)
elif service_tier is not None:
service_tier_pb = service_tier

return self._make_chat(
conversation_id=conversation_id,
batch_request_id=batch_request_id,
Expand Down Expand Up @@ -248,6 +259,7 @@
max_turns=max_turns,
include=include_pb,
agent_count=agent_count_pb,
service_tier=service_tier_pb,
)

@abc.abstractmethod
Expand Down Expand Up @@ -922,6 +934,17 @@
raise ValueError(f"Invalid include option: {include_option}. Must be one of: {IncludeOptionMap.keys()}")


def _service_tier_to_proto(tier: ServiceTier) -> usage_pb2.ServiceTier:
"""Converts a `ServiceTier` literal to a proto."""
match tier:
case "priority":
return usage_pb2.ServiceTier.SERVICE_TIER_PRIORITY
case "default" | "auto":
return usage_pb2.ServiceTier.SERVICE_TIER_DEFAULT
case _:
raise ValueError(f"Invalid service tier: {tier}. Must be one of: 'auto', 'default', 'priority'.")


def _agent_count_to_proto(agent_count: int) -> chat_pb2.AgentCount:
"""Converts an `AgentCount` literal to a proto."""
if agent_count in AgentCountMap:
Expand Down Expand Up @@ -1286,6 +1309,15 @@
"""Returns the system fingerprint of this response."""
return self.proto.system_fingerprint

@property
def service_tier(self) -> str:
"""Returns the processing tier used for this request.

Returns ``"priority"`` if the request was served at the priority tier,
or ``"default"`` otherwise.
"""
return usage_pb2.ServiceTier.Name(self._proto.service_tier).removeprefix("SERVICE_TIER_").lower()

@property
def tool_calls(self) -> Sequence[chat_pb2.ToolCall]:
"""Returns the all tool calls of this response."""
Expand Down
6 changes: 6 additions & 0 deletions src/xai_sdk/image.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import base64
import warnings
from typing import Any, Optional, Sequence, Union

import grpc

from .cost import cost_usd_from_usage
from .meta import ProtoDecorator
from .proto import image_pb2, image_pb2_grpc, usage_pb2
from .types.chat import ServiceTier
from .telemetry import should_disable_sensitive_attributes
from .types import ImageAspectRatio, ImageFormat, ImageGenerationModel, ImageResolution

Check failure on line 12 in src/xai_sdk/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.10)

Ruff (I001)

src/xai_sdk/image.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 12 in src/xai_sdk/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.11)

Ruff (I001)

src/xai_sdk/image.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 12 in src/xai_sdk/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.13)

Ruff (I001)

src/xai_sdk/image.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 12 in src/xai_sdk/image.py

View workflow job for this annotation

GitHub Actions / ci-checks (3.12)

Ruff (I001)

src/xai_sdk/image.py:1:1: I001 Import block is un-sorted or un-formatted

_IMAGE_ASPECT_RATIO_MAP: dict[ImageAspectRatio, image_pb2.ImageAspectRatio] = {
"1:1": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_1_1,
Expand Down Expand Up @@ -130,6 +131,7 @@
image_format: ImageFormat | None = None,
aspect_ratio: ImageAspectRatio | None = None,
resolution: ImageResolution | None = None,
service_tier: Union[ServiceTier, "usage_pb2.ServiceTier", None] = None,
) -> image_pb2.GenerateImageRequest:
if image_url is not None and image_urls is not None:
raise ValueError("Only one of image_url or image_urls can be set for a request.")
Expand Down Expand Up @@ -163,6 +165,10 @@
request.aspect_ratio = convert_image_aspect_ratio_to_pb(aspect_ratio)
if resolution is not None:
request.resolution = convert_image_resolution_to_pb(resolution)
if service_tier is not None:
from .chat import _service_tier_to_proto

request.service_tier = _service_tier_to_proto(service_tier) if isinstance(service_tier, str) else service_tier
return request


Expand Down
192 changes: 64 additions & 128 deletions src/xai_sdk/proto/v5/chat_pb2.py

Large diffs are not rendered by default.

Loading
Loading