Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions src/xai_sdk/aio/collections.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import datetime
import os
import warnings
from typing import Optional, Sequence, Union
from typing import BinaryIO, Optional, Sequence, Union

from opentelemetry.trace import SpanKind

Expand All @@ -25,7 +26,12 @@
_hnsw_metric_to_pb,
_order_to_pb,
)
from ..files import _async_chunk_file_data
from ..files import (
ProgressCallback,
_async_chunk_file_data,
_async_chunk_file_from_fileobj,
_async_chunk_file_from_path,
)
from ..poll_timer import PollTimer
from ..proto import collections_pb2, documents_pb2, shared_pb2, types_pb2
from ..telemetry import get_tracer
Expand Down Expand Up @@ -388,6 +394,92 @@ async def upload_document(
collection_id,
)

async def upload_document_file(
    self,
    collection_id: str,
    file: Union[str, os.PathLike[str], BinaryIO],
    *,
    filename: Optional[str] = None,
    fields: Optional[dict[str, str]] = None,
    wait_for_indexing: bool = False,
    poll_interval: Optional[datetime.timedelta] = None,
    timeout: Optional[datetime.timedelta] = None,
    on_progress: Optional[ProgressCallback] = None,
    expires_after: Optional[Union[datetime.timedelta, int]] = None,
) -> collections_pb2.DocumentMetadata:
    """Streams a file to xAI and adds it to a collection.

    The file is sent in chunks rather than being passed in as a single
    in-memory ``bytes`` value.

    Args:
        collection_id: The ID of the collection to upload the document to.
        file: A filesystem path (``str`` or any ``os.PathLike`` such as
            ``pathlib.Path``) or a binary file-like object exposing ``read``.
        filename: Name to use for the uploaded file. Only consulted when `file` is a
            file-like object; required when that object has no string `.name`
            attribute. When `file` is a path, this argument is not forwarded.
        fields: Additional metadata fields to store with the document.
        wait_for_indexing: Whether to wait for the document to be indexed.
        poll_interval: The interval to poll for when checking whether the document has been indexed.
            Falls back to `DEFAULT_INDEXING_POLL_INTERVAL` when omitted.
        timeout: The total time to wait for the document to be indexed before returning.
            Falls back to `DEFAULT_INDEXING_TIMEOUT` when omitted.
        on_progress: Optional callback invoked after each chunk is uploaded.
        expires_after: Optional time-to-live for the uploaded file.

    Returns:
        The metadata for the uploaded document.

    Raises:
        FileNotFoundError: If `file` is a path that does not exist.
        IsADirectoryError: If `file` is a path that points at a directory.
        ValueError: If `file` is neither a path nor a file-like object, or if no
            filename can be determined for a file-like object.
    """
    if isinstance(file, (str, os.PathLike)):
        # Normalise path-like objects so callers are not forced to wrap them in str().
        file_path = os.fspath(file)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        # exists() is also true for directories; reject them here, before the
        # request stream is handed to the upload call.
        if os.path.isdir(file_path):
            raise IsADirectoryError(f"Expected a file but found a directory: {file_path}")
        upload_chunks = _async_chunk_file_from_path(
            file_path=file_path,
            progress=on_progress,
            expires_after=expires_after,
        )
    elif hasattr(file, "read"):
        if filename is None:
            # Fall back to the object's own name (e.g. an open file handle).
            candidate = getattr(file, "name", None)
            if not isinstance(candidate, str):
                raise ValueError("filename is required when uploading a file-like object without a .name attribute")
            filename = os.path.basename(candidate)
        upload_chunks = _async_chunk_file_from_fileobj(
            file_obj=file,
            filename=filename,
            progress=on_progress,
            expires_after=expires_after,
        )
    else:
        raise ValueError(f"Unsupported file type: {type(file)}")

    with tracer.start_as_current_span(
        name="collections.upload_document_file",
        kind=SpanKind.CLIENT,
        attributes={
            "operation.name": "upload_document_file",
            "provider.name": "xai",
        },
    ) as span:
        uploaded_file = await self._files_stub.UploadFile(upload_chunks)
        span.set_attribute("file.id", uploaded_file.id)
        span.set_attribute("file.name", uploaded_file.filename)

    # The file exists server-side now; attach it to the collection.
    await self._collections_stub.AddDocumentToCollection(
        collections_pb2.AddDocumentToCollectionRequest(
            collection_id=collection_id,
            file_id=uploaded_file.id,
            fields=fields,
        )
    )

    if wait_for_indexing:
        return await self._wait_for_indexing(
            collection_id,
            uploaded_file.id,
            poll_interval or DEFAULT_INDEXING_POLL_INTERVAL,
            timeout or DEFAULT_INDEXING_TIMEOUT,
        )

    return await self.get_document(
        uploaded_file.id,
        collection_id,
    )

async def _wait_for_indexing(
self,
collection_id: str,
Expand Down
91 changes: 89 additions & 2 deletions src/xai_sdk/sync/collections.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import datetime
import os
import time
import warnings
from typing import Optional, Sequence, Union
from typing import BinaryIO, Optional, Sequence, Union

from opentelemetry.trace import SpanKind

Expand All @@ -25,7 +26,7 @@
_hnsw_metric_to_pb,
_order_to_pb,
)
from ..files import _chunk_file_data
from ..files import ProgressCallback, _chunk_file_data, _chunk_file_from_fileobj, _chunk_file_from_path
from ..poll_timer import PollTimer
from ..proto import collections_pb2, documents_pb2, shared_pb2, types_pb2
from ..telemetry import get_tracer
Expand Down Expand Up @@ -391,6 +392,92 @@ def upload_document(
collection_id,
)

def upload_document_file(
    self,
    collection_id: str,
    file: Union[str, os.PathLike[str], BinaryIO],
    *,
    filename: Optional[str] = None,
    fields: Optional[dict[str, str]] = None,
    wait_for_indexing: bool = False,
    poll_interval: Optional[datetime.timedelta] = None,
    timeout: Optional[datetime.timedelta] = None,
    on_progress: Optional[ProgressCallback] = None,
    expires_after: Optional[Union[datetime.timedelta, int]] = None,
) -> collections_pb2.DocumentMetadata:
    """Streams a file to xAI and adds it to a collection.

    The file is sent in chunks rather than being passed in as a single
    in-memory ``bytes`` value.

    Args:
        collection_id: The ID of the collection to upload the document to.
        file: A filesystem path (``str`` or any ``os.PathLike`` such as
            ``pathlib.Path``) or a binary file-like object exposing ``read``.
        filename: Name to use for the uploaded file. Only consulted when `file` is a
            file-like object; required when that object has no string `.name`
            attribute. When `file` is a path, this argument is not forwarded.
        fields: Additional metadata fields to store with the document.
        wait_for_indexing: Whether to wait for the document to be indexed.
        poll_interval: The interval to poll for when checking whether the document has been indexed.
            Falls back to `DEFAULT_INDEXING_POLL_INTERVAL` when omitted.
        timeout: The total time to wait for the document to be indexed before returning.
            Falls back to `DEFAULT_INDEXING_TIMEOUT` when omitted.
        on_progress: Optional callback invoked after each chunk is uploaded.
        expires_after: Optional time-to-live for the uploaded file.

    Returns:
        The metadata for the uploaded document.

    Raises:
        FileNotFoundError: If `file` is a path that does not exist.
        IsADirectoryError: If `file` is a path that points at a directory.
        ValueError: If `file` is neither a path nor a file-like object, or if no
            filename can be determined for a file-like object.
    """
    if isinstance(file, (str, os.PathLike)):
        # Normalise path-like objects so callers are not forced to wrap them in str().
        file_path = os.fspath(file)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        # exists() is also true for directories; reject them here, before the
        # request stream is handed to the upload call.
        if os.path.isdir(file_path):
            raise IsADirectoryError(f"Expected a file but found a directory: {file_path}")
        upload_chunks = _chunk_file_from_path(
            file_path=file_path,
            progress=on_progress,
            expires_after=expires_after,
        )
    elif hasattr(file, "read"):
        if filename is None:
            # Fall back to the object's own name (e.g. an open file handle).
            candidate = getattr(file, "name", None)
            if not isinstance(candidate, str):
                raise ValueError("filename is required when uploading a file-like object without a .name attribute")
            filename = os.path.basename(candidate)
        upload_chunks = _chunk_file_from_fileobj(
            file_obj=file,
            filename=filename,
            progress=on_progress,
            expires_after=expires_after,
        )
    else:
        raise ValueError(f"Unsupported file type: {type(file)}")

    with tracer.start_as_current_span(
        name="collections.upload_document_file",
        kind=SpanKind.CLIENT,
        attributes={
            "operation.name": "upload_document_file",
            "provider.name": "xai",
        },
    ) as span:
        uploaded_file = self._files_stub.UploadFile(upload_chunks)
        span.set_attribute("file.id", uploaded_file.id)
        span.set_attribute("file.name", uploaded_file.filename)

    # The file exists server-side now; attach it to the collection.
    self._collections_stub.AddDocumentToCollection(
        collections_pb2.AddDocumentToCollectionRequest(
            collection_id=collection_id,
            file_id=uploaded_file.id,
            fields=fields,
        )
    )

    if wait_for_indexing:
        return self._wait_for_indexing_to_complete(
            collection_id,
            uploaded_file.id,
            poll_interval or DEFAULT_INDEXING_POLL_INTERVAL,
            timeout or DEFAULT_INDEXING_TIMEOUT,
        )

    return self.get_document(
        uploaded_file.id,
        collection_id,
    )

def _wait_for_indexing_to_complete(
self,
collection_id: str,
Expand Down
23 changes: 23 additions & 0 deletions tests/aio/collections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,29 @@ async def test_upload_document(client: AsyncClient):
assert response.fields == fields


@pytest.mark.asyncio(loop_scope="session")
async def test_upload_document_file_from_path(client: AsyncClient, tmp_path):
    """Upload a small text file by path into a new collection and check the returned metadata."""
    created = await client.collections.create(f"test-collection-{uuid.uuid4()}")
    collection_id = created.collection_id
    assert collection_id is not None

    # Write a throwaway file under pytest's tmp_path to upload by path.
    payload = b"Hello from a streamed file!"
    expected_fields = {"source": "path"}
    source = tmp_path / "streamed-document.txt"
    source.write_bytes(payload)

    result = await client.collections.upload_document_file(
        collection_id,
        str(source),
        fields=expected_fields,
    )

    # The stored file should mirror what was written locally.
    uploaded = result.file_metadata
    assert uploaded.file_id is not None
    assert uploaded.name == source.name
    assert uploaded.size_bytes == len(payload)
    assert uploaded.content_type == "text/plain"
    assert result.fields == expected_fields


@pytest.mark.asyncio(loop_scope="session")
async def test_add_existing_document_to_collection(client: AsyncClient):
# Create a collection to add the document to.
Expand Down
22 changes: 22 additions & 0 deletions tests/sync/collections_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,28 @@ def test_upload_document(client: Client):
assert response.fields == fields


def test_upload_document_file_from_path(client: Client, tmp_path):
    """Upload a small text file by path into a new collection and check the returned metadata."""
    created = client.collections.create(f"test-collection-{uuid.uuid4()}")
    collection_id = created.collection_id
    assert collection_id is not None

    # Write a throwaway file under pytest's tmp_path to upload by path.
    payload = b"Hello from a streamed file!"
    expected_fields = {"source": "path"}
    source = tmp_path / "streamed-document.txt"
    source.write_bytes(payload)

    result = client.collections.upload_document_file(
        collection_id,
        str(source),
        fields=expected_fields,
    )

    # The stored file should mirror what was written locally.
    uploaded = result.file_metadata
    assert uploaded.file_id is not None
    assert uploaded.name == source.name
    assert uploaded.size_bytes == len(payload)
    assert uploaded.content_type == "text/plain"
    assert result.fields == expected_fields


def test_upload_document_without_wait_for_indexing(client: Client):
"""Test uploading a document without waiting for indexing returns immediately with PROCESSED status."""
collection_metadata = client.collections.create(f"test-collection-{uuid.uuid4()}")
Expand Down