From 19cf2cdf7e62aaae118822aba48ff17ed32226c9 Mon Sep 17 00:00:00 2001
From: mukunda katta
Date: Fri, 15 May 2026 07:55:12 -0700
Subject: [PATCH] Add streaming collection document uploads

---
 src/xai_sdk/aio/collections.py  | 100 ++++++++++++++++++++++++++++++++-
 src/xai_sdk/sync/collections.py |  95 ++++++++++++++++++++++++++++++-
 tests/aio/collections_test.py   |  23 ++++++++
 tests/sync/collections_test.py  |  22 ++++++++
 4 files changed, 236 insertions(+), 4 deletions(-)

diff --git a/src/xai_sdk/aio/collections.py b/src/xai_sdk/aio/collections.py
index ed7a3f2..f0c8816 100644
--- a/src/xai_sdk/aio/collections.py
+++ b/src/xai_sdk/aio/collections.py
@@ -1,7 +1,8 @@
 import asyncio
 import datetime
+import os
 import warnings
-from typing import Optional, Sequence, Union
+from typing import BinaryIO, Optional, Sequence, Union
 
 from opentelemetry.trace import SpanKind
 
@@ -25,7 +26,12 @@
     _hnsw_metric_to_pb,
     _order_to_pb,
 )
-from ..files import _async_chunk_file_data
+from ..files import (
+    ProgressCallback,
+    _async_chunk_file_data,
+    _async_chunk_file_from_fileobj,
+    _async_chunk_file_from_path,
+)
 from ..poll_timer import PollTimer
 from ..proto import collections_pb2, documents_pb2, shared_pb2, types_pb2
 from ..telemetry import get_tracer
@@ -388,6 +394,96 @@ async def upload_document(
             collection_id,
         )
 
+    async def upload_document_file(
+        self,
+        collection_id: str,
+        file: Union[str, os.PathLike[str], BinaryIO],
+        *,
+        filename: Optional[str] = None,
+        fields: Optional[dict[str, str]] = None,
+        wait_for_indexing: bool = False,
+        poll_interval: Optional[datetime.timedelta] = None,
+        timeout: Optional[datetime.timedelta] = None,
+        on_progress: Optional[ProgressCallback] = None,
+        expires_after: Optional[Union[datetime.timedelta, int]] = None,
+    ) -> collections_pb2.DocumentMetadata:
+        """Streams a file to xAI and adds it to a collection.
+
+        Args:
+            collection_id: The ID of the collection to upload the document to.
+            file: A path (`str` or `os.PathLike`) or binary file-like object to upload.
+            filename: Name to use for the uploaded file. Required when `file` is a
+                file-like object without a `.name` attribute.
+            fields: Additional metadata fields to store with the document.
+            wait_for_indexing: Whether to wait for the document to be indexed.
+            poll_interval: The interval to poll for when checking whether the document has been indexed.
+            timeout: The total time to wait for the document to be indexed before returning.
+            on_progress: Optional callback invoked after each chunk is uploaded.
+            expires_after: Optional time-to-live for the uploaded file.
+
+        Returns:
+            The metadata for the uploaded document.
+
+        Raises:
+            FileNotFoundError: If `file` is a path that does not exist.
+            ValueError: If `file` is an unsupported type, or a file-like object with no usable filename.
+        """
+        if isinstance(file, (str, os.PathLike)):
+            if not os.path.exists(file):
+                raise FileNotFoundError(f"File not found: {file}")
+            upload_chunks = _async_chunk_file_from_path(
+                file_path=os.fspath(file),  # normalize pathlib.Path and other PathLike objects
+                progress=on_progress,
+                expires_after=expires_after,
+            )
+        elif hasattr(file, "read"):
+            if filename is None:
+                if hasattr(file, "name") and isinstance(file.name, str):  # .name may be an int fd
+                    filename = os.path.basename(file.name)
+                else:
+                    raise ValueError("filename is required when uploading a file-like object without a .name attribute")
+            upload_chunks = _async_chunk_file_from_fileobj(
+                file_obj=file,
+                filename=filename,
+                progress=on_progress,
+                expires_after=expires_after,
+            )
+        else:
+            raise ValueError(f"Unsupported file type: {type(file)}")
+
+        with tracer.start_as_current_span(
+            name="collections.upload_document_file",
+            kind=SpanKind.CLIENT,
+            attributes={
+                "operation.name": "upload_document_file",
+                "provider.name": "xai",
+            },
+        ) as span:
+            uploaded_file = await self._files_stub.UploadFile(upload_chunks)
+            span.set_attribute("file.id", uploaded_file.id)
+            span.set_attribute("file.name", uploaded_file.filename)
+
+            await self._collections_stub.AddDocumentToCollection(
+                collections_pb2.AddDocumentToCollectionRequest(
+                    collection_id=collection_id,
+                    file_id=uploaded_file.id,
+                    fields=fields,
+                )
+            )
+
+        if wait_for_indexing:
+            return await self._wait_for_indexing(
+                collection_id,
+                uploaded_file.id,
+                poll_interval or DEFAULT_INDEXING_POLL_INTERVAL,
+                timeout or DEFAULT_INDEXING_TIMEOUT,
+            )
+
+        return await self.get_document(
+            uploaded_file.id,
+            collection_id,
+        )
+
     async def _wait_for_indexing(
         self,
         collection_id: str,
diff --git a/src/xai_sdk/sync/collections.py b/src/xai_sdk/sync/collections.py
index cfff119..d918366 100644
--- a/src/xai_sdk/sync/collections.py
+++ b/src/xai_sdk/sync/collections.py
@@ -1,7 +1,8 @@
 import datetime
+import os
 import time
 import warnings
-from typing import Optional, Sequence, Union
+from typing import BinaryIO, Optional, Sequence, Union
 
 from opentelemetry.trace import SpanKind
 
@@ -25,7 +26,7 @@
     _hnsw_metric_to_pb,
     _order_to_pb,
 )
-from ..files import _chunk_file_data
+from ..files import ProgressCallback, _chunk_file_data, _chunk_file_from_fileobj, _chunk_file_from_path
 from ..poll_timer import PollTimer
 from ..proto import collections_pb2, documents_pb2, shared_pb2, types_pb2
 from ..telemetry import get_tracer
@@ -391,6 +392,96 @@ def upload_document(
             collection_id,
         )
 
+    def upload_document_file(
+        self,
+        collection_id: str,
+        file: Union[str, os.PathLike[str], BinaryIO],
+        *,
+        filename: Optional[str] = None,
+        fields: Optional[dict[str, str]] = None,
+        wait_for_indexing: bool = False,
+        poll_interval: Optional[datetime.timedelta] = None,
+        timeout: Optional[datetime.timedelta] = None,
+        on_progress: Optional[ProgressCallback] = None,
+        expires_after: Optional[Union[datetime.timedelta, int]] = None,
+    ) -> collections_pb2.DocumentMetadata:
+        """Streams a file to xAI and adds it to a collection.
+
+        Args:
+            collection_id: The ID of the collection to upload the document to.
+            file: A path (`str` or `os.PathLike`) or binary file-like object to upload.
+            filename: Name to use for the uploaded file. Required when `file` is a
+                file-like object without a `.name` attribute.
+            fields: Additional metadata fields to store with the document.
+            wait_for_indexing: Whether to wait for the document to be indexed.
+            poll_interval: The interval to poll for when checking whether the document has been indexed.
+            timeout: The total time to wait for the document to be indexed before returning.
+            on_progress: Optional callback invoked after each chunk is uploaded.
+            expires_after: Optional time-to-live for the uploaded file.
+
+        Returns:
+            The metadata for the uploaded document.
+
+        Raises:
+            FileNotFoundError: If `file` is a path that does not exist.
+            ValueError: If `file` is an unsupported type, or a file-like object with no usable filename.
+        """
+        if isinstance(file, (str, os.PathLike)):
+            if not os.path.exists(file):
+                raise FileNotFoundError(f"File not found: {file}")
+            upload_chunks = _chunk_file_from_path(
+                file_path=os.fspath(file),  # normalize pathlib.Path and other PathLike objects
+                progress=on_progress,
+                expires_after=expires_after,
+            )
+        elif hasattr(file, "read"):
+            if filename is None:
+                if hasattr(file, "name") and isinstance(file.name, str):  # .name may be an int fd
+                    filename = os.path.basename(file.name)
+                else:
+                    raise ValueError("filename is required when uploading a file-like object without a .name attribute")
+            upload_chunks = _chunk_file_from_fileobj(
+                file_obj=file,
+                filename=filename,
+                progress=on_progress,
+                expires_after=expires_after,
+            )
+        else:
+            raise ValueError(f"Unsupported file type: {type(file)}")
+
+        with tracer.start_as_current_span(
+            name="collections.upload_document_file",
+            kind=SpanKind.CLIENT,
+            attributes={
+                "operation.name": "upload_document_file",
+                "provider.name": "xai",
+            },
+        ) as span:
+            uploaded_file = self._files_stub.UploadFile(upload_chunks)
+            span.set_attribute("file.id", uploaded_file.id)
+            span.set_attribute("file.name", uploaded_file.filename)
+
+            self._collections_stub.AddDocumentToCollection(
+                collections_pb2.AddDocumentToCollectionRequest(
+                    collection_id=collection_id,
+                    file_id=uploaded_file.id,
+                    fields=fields,
+                )
+            )
+
+        if wait_for_indexing:
+            return self._wait_for_indexing_to_complete(
+                collection_id,
+                uploaded_file.id,
+                poll_interval or DEFAULT_INDEXING_POLL_INTERVAL,
+                timeout or DEFAULT_INDEXING_TIMEOUT,
+            )
+
+        return self.get_document(
+            uploaded_file.id,
+            collection_id,
+        )
+
     def _wait_for_indexing_to_complete(
         self,
         collection_id: str,
diff --git a/tests/aio/collections_test.py b/tests/aio/collections_test.py
index 7fe5b0e..cada900 100644
--- a/tests/aio/collections_test.py
+++ b/tests/aio/collections_test.py
@@ -645,6 +645,29 @@ async def test_upload_document(client: AsyncClient):
     assert response.fields == fields
 
 
+@pytest.mark.asyncio(loop_scope="session")
+async def test_upload_document_file_from_path(client: AsyncClient, tmp_path):
+    collection_metadata = await client.collections.create(f"test-collection-{uuid.uuid4()}")
+    assert collection_metadata.collection_id is not None
+
+    data = b"Hello from a streamed file!"
+    fields = {"source": "path"}
+    file_path = tmp_path / "streamed-document.txt"
+    file_path.write_bytes(data)
+
+    document_metadata = await client.collections.upload_document_file(
+        collection_metadata.collection_id,
+        str(file_path),
+        fields=fields,
+    )
+
+    assert document_metadata.file_metadata.file_id is not None
+    assert document_metadata.file_metadata.name == file_path.name
+    assert document_metadata.file_metadata.size_bytes == len(data)
+    assert document_metadata.file_metadata.content_type == "text/plain"
+    assert document_metadata.fields == fields
+
+
 @pytest.mark.asyncio(loop_scope="session")
 async def test_add_existing_document_to_collection(client: AsyncClient):
     # Create a collection to add the document to.
diff --git a/tests/sync/collections_test.py b/tests/sync/collections_test.py
index f1a5612..7c8a77a 100644
--- a/tests/sync/collections_test.py
+++ b/tests/sync/collections_test.py
@@ -620,6 +620,28 @@ def test_upload_document(client: Client):
     assert response.fields == fields
 
 
+def test_upload_document_file_from_path(client: Client, tmp_path):
+    collection_metadata = client.collections.create(f"test-collection-{uuid.uuid4()}")
+    assert collection_metadata.collection_id is not None
+
+    data = b"Hello from a streamed file!"
+    fields = {"source": "path"}
+    file_path = tmp_path / "streamed-document.txt"
+    file_path.write_bytes(data)
+
+    document_metadata = client.collections.upload_document_file(
+        collection_metadata.collection_id,
+        str(file_path),
+        fields=fields,
+    )
+
+    assert document_metadata.file_metadata.file_id is not None
+    assert document_metadata.file_metadata.name == file_path.name
+    assert document_metadata.file_metadata.size_bytes == len(data)
+    assert document_metadata.file_metadata.content_type == "text/plain"
+    assert document_metadata.fields == fields
+
+
 def test_upload_document_without_wait_for_indexing(client: Client):
     """Test uploading a document without waiting for indexing returns immediately with PROCESSED status."""
     collection_metadata = client.collections.create(f"test-collection-{uuid.uuid4()}")