Skip to content

Commit 0764852

Browse files
authored
feat: Maestro - add data sources support (#282)
fix: add data sources support * fix: typedDict import * fix: add missing param * fix: add test * fix: remove skip * fix: fix test
1 parent 1888bd7 commit 0764852

File tree

6 files changed

+69
-6
lines changed

6 files changed

+69
-6
lines changed

ai21/clients/common/maestro/run.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
DEFAULT_RUN_POLL_TIMEOUT,
1313
Requirement,
1414
Budget,
15+
OutputOptions,
1516
)
1617
from ai21.types import NOT_GIVEN, NotGiven
1718
from ai21.utils.typing import remove_not_given
@@ -30,6 +31,7 @@ def _create_body(
3031
context: Dict[str, Any] | NotGiven,
3132
requirements: List[Requirement] | NotGiven,
3233
budget: Budget | NotGiven,
34+
include: List[OutputOptions] | NotGiven,
3335
**kwargs,
3436
) -> dict:
3537
return remove_not_given(
@@ -41,6 +43,7 @@ def _create_body(
4143
"context": context,
4244
"requirements": requirements,
4345
"budget": budget,
46+
"include": include,
4447
**kwargs,
4548
}
4649
)
@@ -56,6 +59,7 @@ def create(
5659
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
5760
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
5861
budget: Budget | NotGiven = NOT_GIVEN,
62+
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
5963
**kwargs,
6064
) -> RunResponse:
6165
pass
@@ -79,6 +83,7 @@ def create_and_poll(
7983
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
8084
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
8185
budget: Budget | NotGiven = NOT_GIVEN,
86+
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
8287
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
8388
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
8489
**kwargs,

ai21/clients/studio/resources/maestro/run.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DEFAULT_RUN_POLL_TIMEOUT,
1717
Requirement,
1818
Budget,
19+
OutputOptions,
1920
)
2021
from ai21.types import NotGiven, NOT_GIVEN
2122

@@ -31,6 +32,7 @@ def create(
3132
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
3233
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
3334
budget: Budget | NotGiven = NOT_GIVEN,
35+
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
3436
**kwargs,
3537
) -> RunResponse:
3638
body = self._create_body(
@@ -41,6 +43,7 @@ def create(
4143
context=context,
4244
requirements=requirements,
4345
budget=budget,
46+
include=include,
4447
**kwargs,
4548
)
4649

@@ -76,6 +79,7 @@ def create_and_poll(
7679
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
7780
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
7881
budget: Budget | NotGiven = NOT_GIVEN,
82+
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
7983
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
8084
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
8185
**kwargs,
@@ -88,6 +92,7 @@ def create_and_poll(
8892
context=context,
8993
requirements=requirements,
9094
budget=budget,
95+
include=include,
9196
**kwargs,
9297
)
9398

@@ -105,6 +110,7 @@ async def create(
105110
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
106111
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
107112
budget: Budget | NotGiven = NOT_GIVEN,
113+
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
108114
**kwargs,
109115
) -> RunResponse:
110116
body = self._create_body(
@@ -115,6 +121,7 @@ async def create(
115121
context=context,
116122
requirements=requirements,
117123
budget=budget,
124+
include=include,
118125
**kwargs,
119126
)
120127

@@ -150,6 +157,7 @@ async def create_and_poll(
150157
context: Dict[str, Any] | NotGiven = NOT_GIVEN,
151158
requirements: List[Requirement] | NotGiven = NOT_GIVEN,
152159
budget: Budget | NotGiven = NOT_GIVEN,
160+
include: List[OutputOptions] | NotGiven = NOT_GIVEN,
153161
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
154162
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
155163
**kwargs,
@@ -162,6 +170,7 @@ async def create_and_poll(
162170
context=context,
163171
requirements=requirements,
164172
budget=budget,
173+
include=include,
165174
**kwargs,
166175
)
167176

ai21/models/__init__.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,16 @@
88
ConversationalRagSource,
99
)
1010
from ai21.models.responses.file_response import FileResponse
11-
from ai21.models.maestro.run import Requirement, Budget, Tool, ToolResources
12-
11+
from ai21.models.maestro.run import (
12+
Requirement,
13+
Budget,
14+
Tool,
15+
ToolResources,
16+
DataSources,
17+
FileSearchResult,
18+
WebSearchResult,
19+
OutputOptions,
20+
)
1321

1422
__all__ = [
1523
"ChatMessage",
@@ -26,4 +34,8 @@
2634
"Budget",
2735
"Tool",
2836
"ToolResources",
37+
"DataSources",
38+
"FileSearchResult",
39+
"WebSearchResult",
40+
"OutputOptions",
2941
]

ai21/models/maestro/run.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import TypedDict, Literal, List, Optional, Any, Set, Dict, Type, Union
2-
1+
from typing import Literal, List, Optional, Any, Set, Dict, Type, Union
2+
from typing_extensions import TypedDict
33
from pydantic import BaseModel
44

55
from ai21.models.ai21_base_model import AI21BaseModel
@@ -8,6 +8,7 @@
88
Role = Literal["user", "assistant"]
99
RunStatus = Literal["completed", "failed", "in_progress", "requires_action"]
1010
ToolType = Literal["file_search", "web_search"]
11+
OutputOptions = Literal["data_sources"]
1112
PrimitiveTypes = Union[Type[str], Type[int], Type[float], Type[bool]]
1213
PrimitiveLists = Type[List[PrimitiveTypes]]
1314
OutputType = Union[Type[BaseModel], PrimitiveTypes, Dict[str, Any]]
@@ -40,12 +41,32 @@ class ToolResources(TypedDict, total=False):
4041
web_search: Optional[WebSearchToolResource]
4142

4243

43-
class Requirement(TypedDict):
44+
class Requirement(TypedDict, total=False):
4445
name: str
4546
description: str
4647

4748

49+
class FileSearchResult(TypedDict, total=False):
50+
text: Optional[str]
51+
file_id: str
52+
file_name: str
53+
score: float
54+
order: int
55+
56+
57+
class WebSearchResult(TypedDict, total=False):
58+
text: str
59+
url: str
60+
score: float
61+
62+
63+
class DataSources(TypedDict, total=False):
64+
file_search: Optional[List[FileSearchResult]]
65+
web_search: Optional[List[WebSearchResult]]
66+
67+
4868
class RunResponse(AI21BaseModel):
4969
id: str
5070
status: RunStatus
5171
result: Any
72+
data_sources: Optional[DataSources] = None

tests/integration_tests/clients/studio/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def _wait_for_file_to_process(client: AI21Client, file_id: str, timeout: float =
2020
return
2121

2222
elapsed_time = time.time() - start_time
23-
time.sleep(0.5)
23+
time.sleep(2)
2424

2525
raise TimeoutError(f"Timeout: {timeout} seconds passed. File processing not completed")
2626

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import pytest
2+
3+
from ai21 import AsyncAI21Client
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_maestro__when_upload__should_return_data_sources(): # file_in_library: str):
8+
client = AsyncAI21Client()
9+
result = await client.beta.maestro.runs.create_and_poll(
10+
input="When did Einstein receive a Nobel Prize?", tools=[{"type": "file_search"}], include=["data_sources"]
11+
)
12+
assert result.status == "completed", "Expected 'completed' status"
13+
assert result.result, "Expected a non-empty answer"
14+
assert result.data_sources, "Expected data sources"
15+
assert len(result.data_sources["file_search"]) > 0, "Expected at least one file search data source"
16+
assert result.data_sources.get("web_search") is None, "Expected no web search data sources"

0 commit comments

Comments
 (0)