diff --git a/py/.pysentry.toml b/py/.pysentry.toml index c536e38a1e..439f7ee2c0 100644 --- a/py/.pysentry.toml +++ b/py/.pysentry.toml @@ -1,3 +1,19 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + version = 1 [defaults] diff --git a/py/GEMINI.md b/py/GEMINI.md new file mode 100644 index 0000000000..2c7c26a6d1 --- /dev/null +++ b/py/GEMINI.md @@ -0,0 +1,94 @@ +# Python Development Guidelines + +## Code Quality & Linting +- **Run Linting**: Always run `./bin/lint` from the repo root (or `py/` directory semantics depending on the script) for all Python code changes. +- **Pass All Tests**: Ensure all unit tests pass (`uv run pytest .`). +- **Production Ready**: The objective is to produce production-grade code. +- **Shift Left**: Employ a "shift left" strategy—catch errors early. +- **Strict Typing**: Strict type checking is required. Do not use `Any` unless absolutely necessary and documented. +- **No Warning Suppression**: Avoid ignoring warnings from the type checker (`# type: ignore`) or other tools unless there is a compelling, documented reason. + +## Generated Files & Data Model +- **Do Not Edit typing.py**: `py/packages/genkit/src/genkit/core/typing.py` is an auto-generated file. **DO NOT MODIFY IT DIRECTLY.** +- **Generator/Sanitizer**: Any necessary changes to the core types must be applied to the generator script or the schema sanitizer. +- **Canonical Parity**: The data model MUST be identical to the JSON schema defined in the JavaScript (canonical) implementation. + +## API & Behavior Parity +- **JS Canonical**: The Python implementation MUST be identical in API structure and runtime behavior to the JavaScript (canonical) implementation. + +## Detailed Coding Guidelines + +### Target Environment +- **Python Version**: Target Python 3.12 or newer. +- **Environment**: Use `uv` for packaging and environment management. + +### Typing & Style +- **Syntax**: + - Use `|` for union types instead of `Union`. + - Use `| None` instead of `Optional`. + - Use lowercase `list`, `dict` for type hints (avoid `List`, `Dict`). + - Use modern generics (PEP 585, 695). + - Use the `type` keyword for type aliases. +- **Imports**: Import types like `Callable`, `Awaitable` from `collections.abc`, not `typing`. +- **Enums**: Use `StrEnum` instead of `(str, Enum)`. +- **Strictness**: Apply type hints strictly, including `-> None` for void functions. +- **Design**: + - Code against interfaces, not implementations. + - Use the adapter pattern for optional implementations. +- **Comments**: + - Use proper punctuation. + - Avoid comments explaining obvious code. + - Use `TODO: Fix this later.` format for stubs. + +### Docstrings +- **Format**: Write comprehensive Google-style docstrings for modules, classes, and functions. +- **Content**: + - **Explain Concepts**: Explain the terminology and concepts used in the code to someone unfamiliar with the code so that first timers can easily understand these ideas. + - **Visuals**: Prefer using tabular format and ascii diagrams in the docstrings to break down complex concepts or list terminology. +- **Required Sections**: + - **Overview**: One-liner description followed by rationale. + - **Key Operations**: Purpose of the component. + - **Args/Attributes**: Required for callables/classes. + - **Returns**: Required for callables. + - **Examples**: Required for user-facing API. + - **Caveats**: Known limitations or edge cases. + +### Formatting +- **Tool**: Format code using `ruff` (or `bin/fmt`). +- **Line Length**: Max 120 characters. +- **Strings**: Wrap long lines and strings appropriately. +- **Config**: Refer to `.editorconfig` or `pyproject.toml` for rules. + +### Testing +- **Framework**: Use `pytest` and `unittest`. +- **Scope**: Write comprehensive unit tests. +- **Documentation**: Add docstrings to test modules/functions explaining their scope. +- **Execution**: Run via `uv run pytest .`. +- **Porting**: Maintain 1:1 logic parity accurately if porting tests. Do not invent behavior. +- **Fixes**: Fix underlying code issues rather than special-casing tests. + +### Logging +- **Library**: Use `structlog` for structured logging. +- **Async**: Use `await logger.ainfo(...)` within coroutines. +- **Format**: Avoid f-strings for async logging; use structured key-values. + +### Licensing +Include the Apache 2.0 license header at the top of each file (update year as needed): + +```python +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +``` diff --git a/py/engdoc/contributing/coding_guidelines.md b/py/engdoc/contributing/coding_guidelines.md deleted file mode 100644 index b4adcf2348..0000000000 --- a/py/engdoc/contributing/coding_guidelines.md +++ /dev/null @@ -1,91 +0,0 @@ -## Python coding guidelines - -This is for both bots and humans. - -Target: Python >= 3.10 - -### Typing & Style - -- Use `|` for union types instead of `Union` and `| None` instead of `Optional`. -- Use lowercase `list`, `dict` for type hints (not `List`, `Dict`). -- Use modern generics (PEP 585, 695). -- Import types such as `Callable`, `Awaitable` that are deprecated in `typing` - from `collections.abc` instead. -- Apply type hints strictly, including `-> None` for functions returning nothing. -- Use the `type` keyword for type aliases. -- Use enum types like `StrEnum` instead of `(str, Enum)` for string-based enums. -- Code against interfaces, not implementations. -- Use the adapter pattern for optional implementations. -- Use proper punctuation in comments. -- Avoid comments explaining obvious code or actions. -- Add TODO comments such as `TODO: Fix this later.` when adding stub implementations. - -### Docstrings - -- Write comprehensive Google-style docstrings for modules, classes, and -functions. -- Include the following sections as needed: - - Overview (required as a one liner description with follow-up paragraphs as - rationale) - - Key operations/purpose - - Arguments/attributes (required for callables) - - Returns (required for callables) - - Examples (required for user-facing API) - - Caveats - -### Formatting - -- Format code using ruff (or `bin/fmt` or `scripts/fmt` if present). -- Max line length: 120 characters to make it easy to read code vertically. -- Refer to the `.editorconfig` or workspace-root `pyproject.toml` for - other formatting rules. -- Wrap long lines and strings appropriately. - -### Testing - -- Write comprehensive unit tests using `pytest` and `unittest`. -- Add docstrings to test modules, classes, and functions explaining their scope. -- Run tests via `uv run --directory ${PYTHON_WORKSPACE_DIR} pytest .` where - the `PYTHON_WORKSPACE_DIR` corresponds to the workspace directory. -- Fix underlying code issues rather than special-casing tests. -- If porting tests: Maintain 1:1 logic parity accurately; do not invent behavior. - -### Tooling & Environment - -- Use `uv` for packaging and environment management. -- Use `mypy` for static type checking. -- Target Python 3.12 or newer. Aim for PyPy compatibility (optional). - -### Logging - -- Use `structlog` for structured logging. -- Use `structlog`'s async API (`await logger.ainfo(...)`) within coroutines -- Avoid f-strings for async logging. - -### Porting - -- If porting from another language (e.g., JS or TypeScript), maintain 1:1 logic -parity in implementation and tests. - -### Licensing - -Include the following Apache 2.0 license header at the top of each file, -updating the year: - -```python -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 -``` diff --git a/py/packages/genkit/src/genkit/ai/__init__.py b/py/packages/genkit/src/genkit/ai/__init__.py index 108a62a9aa..c4aa86f766 100644 --- a/py/packages/genkit/src/genkit/ai/__init__.py +++ b/py/packages/genkit/src/genkit/ai/__init__.py @@ -38,7 +38,7 @@ from genkit.core.plugin import Plugin from ._aio import Genkit -from ._registry import FlowWrapper, GenkitRegistry +from ._registry import FlowWrapper, GenkitRegistry, SimpleRetrieverOptions __all__ = [ ActionKind.__name__, @@ -50,6 +50,7 @@ ToolRunContext.__name__, tool_response.__name__, FlowWrapper.__name__, + SimpleRetrieverOptions.__name__, 'GENKIT_CLIENT_HEADER', 'GENKIT_VERSION', ] diff --git a/py/packages/genkit/src/genkit/ai/_aio.py b/py/packages/genkit/src/genkit/ai/_aio.py index 11185535d7..1a2f00a896 100644 --- a/py/packages/genkit/src/genkit/ai/_aio.py +++ b/py/packages/genkit/src/genkit/ai/_aio.py @@ -16,17 +16,27 @@ """User-facing asyncio API for Genkit. -To use Genkit in your application, construct an instance of the `Genkit` -class while customizing it with any plugins. +This module provides the primary entry point for using Genkit in an asynchronous +environment. The `Genkit` class coordinates plugins, registry, and execution +of AI actions like generation, embedding, and retrieval. + +Key features provided by the `Genkit` class: +- **Generation**: Interface for unified model interaction via `generate` and `generate_stream`. +- **Flow Control**: Execution of granular steps with tracing via `run`. +- **Dynamic Extensibility**: On-the-fly creation of tools via `dynamic_tool`. +- **Observability**: Specialized methods for managing trace context and flushing telemetry. """ import asyncio import uuid -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable, Callable from pathlib import Path -from typing import TypedDict, cast +from typing import Any, TypedDict, TypeVar, cast # noqa: F401 -from genkit.aio import Channel +from opentelemetry import trace as trace_api +from opentelemetry.sdk.trace import TracerProvider + +from genkit.aio import Channel, ensure_async from genkit.blocks.document import Document from genkit.blocks.embedding import EmbedderRef from genkit.blocks.evaluator import EvaluatorRef @@ -41,9 +51,10 @@ class while customizing it with any plugins. ) from genkit.blocks.prompt import PromptConfig, load_prompt_folder, to_generate_action_options from genkit.blocks.retriever import IndexerRef, IndexerRequest, RetrieverRef -from genkit.core.action import ActionRunContext +from genkit.core.action import Action, ActionRunContext from genkit.core.action.types import ActionKind from genkit.core.plugin import Plugin +from genkit.core.tracing import run_in_new_span from genkit.core.typing import ( BaseDataPoint, Embedding, @@ -51,6 +62,8 @@ class while customizing it with any plugins. EmbedResponse, EvalRequest, EvalResponse, + Operation, + SpanMetadata, ) from genkit.types import ( DocumentData, @@ -66,6 +79,8 @@ class while customizing it with any plugins. from ._base_async import GenkitBase from ._server import ServerSpec +T = TypeVar('T') + class OutputConfigDict(TypedDict, total=False): """TypedDict for output configuration when passed as a dict.""" @@ -686,3 +701,143 @@ async def evaluate( ) ) ).response + + @staticmethod + def current_context() -> dict[str, Any] | None: + """Retrieves the current execution context for the running action. + + This allows tools and other actions to access context data (like auth + or metadata) passed through the execution chain via ContextVars. + This provides parity with the JavaScript SDK's context handling. + + Returns: + The current context dictionary, or None if not running in an action. + """ + return ActionRunContext._current_context() + + def dynamic_tool( + self, + name: str, + fn: Callable[..., object], + description: str | None = None, + metadata: dict[str, object] | None = None, + ) -> Action: + """Creates an unregistered tool action. + + This is useful for creating tools that are passed directly to generate() + without being registered in the global registry. Dynamic tools behave exactly + like registered tools but offer more flexibility for runtime-defined logic. + + Args: + name: The unique name of the tool. + fn: The function that implements the tool logic. + description: Optional human-readable description of what the tool does. + metadata: Optional dictionary of metadata about the tool. + + Returns: + An Action instance of kind TOOL, configured for dynamic execution. + """ + tool_meta = metadata.copy() if metadata else {} + tool_meta['type'] = 'tool' + tool_meta['dynamic'] = True + return Action( + kind=ActionKind.TOOL, + name=name, + fn=fn, + description=description, + metadata=tool_meta, + ) + + async def flush_tracing(self) -> None: + """Flushes all registered trace processors. + + This ensures all pending spans are exported before the application + shuts down, preventing loss of telemetry data. + """ + provider = trace_api.get_tracer_provider() + if isinstance(provider, TracerProvider): + await ensure_async(provider.force_flush)() + + async def run( + self, + name: str, + func_or_input: object, + maybe_fn: Callable[..., T | Awaitable[T]] | None = None, + metadata: dict[str, Any] | None = None, + ) -> T: + """Runs a function as a discrete step within a trace. + + This method is used to create sub-spans (steps) within a flow or other action. + Each run step is recorded separately in the trace, making it easier to + debug and monitor the internal execution of complex flows. + + It supports two call signatures: + 1. `run(name, fn)`: Runs the provided function. + 2. `run(name, input, fn)`: Passes the input to the function and records it. + + Args: + name: The descriptive name of the span/step. + func_or_input: Either the function to execute, or input data to pass + to `maybe_fn`. + maybe_fn: An optional function to execute if `func_or_input` is + provided as input data. + metadata: Optional metadata to associate with the generated trace span. + + Returns: + The result of the function execution. + """ + fn: Callable[..., T | Awaitable[T]] + input_data: Any = None + has_input = False + + if maybe_fn: + fn = maybe_fn + input_data = func_or_input + has_input = True + elif callable(func_or_input): + fn = cast(Callable[..., T | Awaitable[T]], func_or_input) + else: + raise ValueError('A function must be provided to run.') + + span_metadata = SpanMetadata(name=name, metadata=metadata) + with run_in_new_span(span_metadata, labels={'genkit:type': 'flowStep'}) as span: + try: + if has_input: + span.set_input(input_data) + result = await ensure_async(fn)(input_data) + else: + result = await ensure_async(fn)() + + span.set_output(result) + return result + except Exception: + # We catch all exceptions here to ensure they are captured by + # the trace span context manager before being re-raised. + # The GenkitSpan wrapper (run_in_new_span) handles recording + # the exception details. + raise + + async def check_operation(self, operation: Operation) -> Operation: + """Checks the status of a long-running operation. + + This method resolves the action associated with the operation and executes + it to get an updated status. + + Args: + operation: The Operation object to check. + + Returns: + An updated Operation object. + + Raises: + ValueError: If the operation doesn't specify an action or if the + action cannot be resolved. + """ + if not operation.action: + raise ValueError('Operation must have an action specified to be checked.') + + action = await self.registry.resolve_action_by_key(operation.action) + if not action: + raise ValueError(f'Action "{operation.action}" not found.') + + return (await action.arun(operation)).response diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py index e41ddc1f8e..8035479d00 100644 --- a/py/packages/genkit/src/genkit/ai/_registry.py +++ b/py/packages/genkit/src/genkit/ai/_registry.py @@ -41,10 +41,12 @@ import inspect import traceback import uuid -from collections.abc import AsyncIterator, Callable +from collections.abc import AsyncIterator, Awaitable, Callable from functools import wraps from typing import TYPE_CHECKING, Any, Generic, ParamSpec, TypeVar, cast +from genkit.aio import ensure_async + if TYPE_CHECKING: from genkit.blocks.prompt import ExecutablePrompt from genkit.blocks.resource import FlexibleResourceFn, ResourceOptions @@ -81,6 +83,7 @@ from genkit.core.tracing import run_in_new_span from genkit.core.typing import ( DocumentData, + DocumentPart, EvalFnResponse, EvalRequest, EvalResponse, @@ -89,6 +92,7 @@ Message, ModelInfo, Part, + RetrieverResponse, Score, SpanMetadata, ToolChoice, @@ -119,6 +123,77 @@ def get_func_description(func: Callable, description: str | None = None) -> str: return '' +R = TypeVar('R') + + +class SimpleRetrieverOptions(BaseModel, Generic[R]): + """Configuration options for `define_simple_retriever`. + + This class defines how items returned by a simple retriever handler are + mapped into Genkit `DocumentData` objects. + + Attributes: + name: The unique name of the retriever. + content: Specifies how to extract content from the returned items. + Can be a string key (for dict items) or a callable that transforms the item. + metadata: Specifies how to extract metadata from the returned items. + Can be a list of keys (for dict items) or a callable that transforms the item. + config_schema: Optional Pydantic schema or JSON schema for retriever configuration. + """ + + name: str + content: str | Callable[[R], str | list[DocumentPart]] | None = None + metadata: list[str] | Callable[[R], dict[str, Any]] | None = None + config_schema: type[BaseModel] | dict[str, Any] | None = None + + +def _item_to_document(item: R, options: SimpleRetrieverOptions[R]) -> DocumentData: + """Internal helper to convert a raw item to a Genkit DocumentData.""" + from genkit.blocks.document import Document + + if isinstance(item, (Document, DocumentData)): + return item + + if isinstance(item, str): + return Document.from_text(item) + + if callable(options.content): + transformed = options.content(item) + if isinstance(transformed, str): + return Document.from_text(transformed) + else: + # transformed is list[DocumentPart] + return DocumentData(content=cast(list[DocumentPart], transformed)) + + if isinstance(options.content, str) and isinstance(item, dict): + return Document.from_text(str(item[options.content])) + + if options.content is None and isinstance(item, str): + return Document.from_text(item) + + raise ValueError(f'Cannot convert item to document without content option. Item: {item}') + + +def _item_to_metadata(item: R, options: SimpleRetrieverOptions[R]) -> dict[str, Any] | None: + """Internal helper to extract metadata from a raw item for a Document.""" + if isinstance(item, str): + return None + + if isinstance(options.metadata, list) and isinstance(item, dict): + return {str(k): item[k] for k in options.metadata if k in item} + + if callable(options.metadata): + return options.metadata(item) + + if options.metadata is None and isinstance(item, dict): + out = cast(dict[str, Any], item.copy()) + if isinstance(options.content, str) and options.content in out: + del out[options.content] + return out + + return None + + class GenkitRegistry: """User-facing API for interacting with Genkit registry.""" @@ -357,6 +432,49 @@ def define_retriever( description=retriever_description, ) + def define_simple_retriever( + self, + options: SimpleRetrieverOptions[R] | str, + handler: Callable[[DocumentData, Any], list[R] | Awaitable[list[R]]], + description: str | None = None, + ) -> Action: + """Define a simple retriever action. + + A simple retriever makes it easy to map existing data into documents + that can be used for prompt augmentation. + + Args: + options: Configuration options for the retriever, or just the name. + handler: A function that queries a datastore and returns items + from which to extract documents. + description: Optional description for the retriever. + + Returns: + The registered Action for the retriever. + """ + if isinstance(options, str): + options = SimpleRetrieverOptions(name=options) + + from genkit.blocks.document import Document + + async def retriever_fn(query: Document, options_obj: Any) -> RetrieverResponse: # noqa: ANN401 + + items = await ensure_async(handler)(query, options_obj) + docs = [] + for item in items: + doc = _item_to_document(item, options) + if not isinstance(item, str): + doc.metadata = _item_to_metadata(item, options) + docs.append(doc) + return RetrieverResponse(documents=docs) + + return self.define_retriever( + name=options.name, + fn=retriever_fn, + config_schema=options.config_schema, + description=description, + ) + def define_indexer( self, name: str, diff --git a/py/packages/genkit/src/genkit/blocks/generate.py b/py/packages/genkit/src/genkit/blocks/generate.py index cff86d7f8b..bc48b3c149 100644 --- a/py/packages/genkit/src/genkit/blocks/generate.py +++ b/py/packages/genkit/src/genkit/blocks/generate.py @@ -131,6 +131,8 @@ async def generate_action( request = await action_to_generate_request(raw_request, tools, model) + logger.debug('generate request', model=model.name, request=dump_dict(request)) + prev_chunks: list[GenerateResponseChunk] = [] chunk_role: Role = cast(Role, Role.MODEL) @@ -282,6 +284,8 @@ def message_parser(msg: MessageWrapper) -> Any: # noqa: ANN401 message_parser=message_parser if formatter else None, ) + logger.debug('generate response', response=dump_dict(response)) + response.assert_valid() generated_msg = response.message diff --git a/py/packages/genkit/tests/genkit/ai/genkit_api_test.py b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py new file mode 100644 index 0000000000..9b83dcc147 --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/genkit_api_test.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 +# +# Copyright 2025 Google LLC +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the Genkit extra API methods.""" + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from genkit.ai import Genkit +from genkit.core.action import Action +from genkit.core.action.types import ActionKind +from genkit.core.typing import DocumentPart, Operation +from genkit.types import DocumentData, RetrieverRequest, RetrieverResponse, TextPart + + +@pytest.mark.asyncio +async def test_genkit_run() -> None: + """Test Genkit.run method.""" + ai = Genkit() + + def sync_fn() -> str: + return 'hello' + + async def async_fn() -> str: + return 'world' + + res1 = await ai.run('test1', sync_fn) + assert res1 == 'hello' + + res2 = await ai.run('test2', async_fn) + assert res2 == 'world' + + # Test with metadata + res3 = await ai.run('test3', sync_fn, metadata={'foo': 'bar'}) + assert res3 == 'hello' + + # Test with input overload + async def multiply(x: int) -> int: + return x * 2 + + res4 = await ai.run('multiply', 10, multiply) + assert res4 == 20 + + +@pytest.mark.asyncio +async def test_genkit_dynamic_tool() -> None: + """Test Genkit.dynamic_tool method.""" + ai = Genkit() + + def my_tool(x: int) -> int: + return x + 1 + + tool = ai.dynamic_tool('my_tool', my_tool, description='increment x') + + assert isinstance(tool, Action) + assert tool.kind == ActionKind.TOOL + assert tool.name == 'my_tool' + assert tool.description == 'increment x' + assert tool.metadata.get('type') == 'tool' + assert tool.metadata.get('dynamic') is True + + # Execution + resp = await tool.arun(5) + assert resp.response == 6 + + +@pytest.mark.asyncio +async def test_genkit_check_operation() -> None: + """Test Genkit.check_operation method.""" + ai = Genkit() + + op = Operation(id='123', done=False, action='test_action') + + mock_action = AsyncMock() + mock_action.arun.return_value = MagicMock(response=Operation(id='123', done=True, output='result')) + + # Mock registry.resolve_action_by_key + ai.registry.resolve_action_by_key = AsyncMock(return_value=mock_action) # type: ignore[assignment] + + updated_op = await ai.check_operation(op) + + assert updated_op.done is True + assert updated_op.output == 'result' + ai.registry.resolve_action_by_key.assert_called_with('test_action') # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_genkit_check_operation_no_action() -> None: + """Test Genkit.check_operation method with no action.""" + ai = Genkit() + op = Operation(id='123', done=False) # action is None + + with pytest.raises(ValueError, match='Operation must have an action specified'): + await ai.check_operation(op) + + +@pytest.mark.asyncio +async def test_genkit_check_operation_not_found() -> None: + """Test Genkit.check_operation method with action not found.""" + ai = Genkit() + op = Operation(id='123', done=False, action='missing') + ai.registry.resolve_action_by_key = AsyncMock(return_value=None) # type: ignore[assignment] + + with pytest.raises(ValueError, match='Action "missing" not found'): + await ai.check_operation(op) + + +@pytest.mark.asyncio +async def test_define_simple_retriever_legacy() -> None: + """Test define_simple_retriever with legacy handler signature.""" + ai = Genkit() + + def simple_retriever(query: DocumentData, options: Any) -> list[DocumentData]: # noqa: ANN401 + # Returns list[DocumentData] directly + + text_part: DocumentPart = DocumentPart(root=TextPart(text='doc1')) + return [DocumentData(content=[text_part])] + + retriever_action = ai.define_simple_retriever('simple', simple_retriever) + + assert retriever_action.kind == ActionKind.RETRIEVER + + # Test execution + req = RetrieverRequest(query=DocumentData(content=[])) + resp_wrapper = await retriever_action.arun(req) + response = resp_wrapper.response + + assert isinstance(response, RetrieverResponse) + assert len(response.documents) == 1 + assert response.documents[0].content[0].root.text == 'doc1' + + +@pytest.mark.asyncio +async def test_define_simple_retriever_mapped() -> None: + """Test define_simple_retriever with mapping options.""" + ai = Genkit() + + def db_handler(query: DocumentData, options: Any) -> list[dict[str, Any]]: # noqa: ANN401 + return [ + {'id': '1', 'text': 'hello', 'extra': 'data'}, + {'id': '2', 'text': 'world', 'extra': 'more'}, + ] + + from genkit.ai._registry import SimpleRetrieverOptions + + options = SimpleRetrieverOptions(name='mapped', content='text', metadata=['extra']) + + retriever_action = ai.define_simple_retriever(options, db_handler) + + req = RetrieverRequest(query=DocumentData(content=[])) + resp_wrapper = await retriever_action.arun(req) + response = resp_wrapper.response + + assert len(response.documents) == 2 + assert response.documents[0].content[0].root.text == 'hello' + assert response.documents[0].metadata == {'extra': 'data'} + assert 'id' not in response.documents[0].metadata + + +@pytest.mark.asyncio +async def test_current_context() -> None: + """Test Genkit.current_context method.""" + from genkit.core.action._action import _action_context + + # current_context is a static method + assert Genkit.current_context() is None + + context = {'auth': {'uid': '123'}} + + # Simulate being inside an action run using ActionRunContext internal mechanism + token = _action_context.set(context) + try: + assert Genkit.current_context() == context + finally: + _action_context.reset(token) + + assert Genkit.current_context() is None + + +@pytest.mark.asyncio +async def test_flush_tracing() -> None: + """Test Genkit.flush_tracing method.""" + from opentelemetry import trace as trace_api + from opentelemetry.sdk.trace import TracerProvider + + ai = Genkit() + + mock_provider = MagicMock(spec=TracerProvider) + mock_provider.force_flush = MagicMock() + + # We can't easily mock the global provider if it's already set, + # but we can check if it calls force_flush if it is a TracerProvider. + + trace_api.get_tracer_provider() + trace_api.set_tracer_provider(mock_provider) + try: + await ai.flush_tracing() + mock_provider.force_flush.assert_called_once() + finally: + # Note: set_tracer_provider can only be called once in real OTel, + # but in tests we might be using a mock. + pass diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py index 978ec19afd..7f20f08e86 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py @@ -93,6 +93,8 @@ def __init__( credentials: Credentials | None = None, debug_config: DebugConfig | None = None, http_options: HttpOptions | HttpOptionsDict | None = None, + api_version: str | None = None, + base_url: str | None = None, ) -> None: """Initializes the GoogleAI plugin. @@ -106,6 +108,8 @@ def __init__( debug_config: Configuration for debugging the client. Defaults to None. http_options: HTTP options for configuring the client's network requests. Can be an instance of HttpOptions or a dictionary. Defaults to None. + api_version: The API version to use (e.g., 'v1beta'). Defaults to None. + base_url: The base URL for the API. Defaults to None. Raises: ValueError: If `api_key` is not provided and the 'GEMINI_API_KEY' @@ -122,7 +126,7 @@ def __init__( api_key=api_key, credentials=credentials, debug_config=debug_config, - http_options=_inject_attribution_headers(http_options), + http_options=_inject_attribution_headers(http_options, base_url, api_version), ) async def init(self) -> list[Action]: @@ -301,6 +305,8 @@ def __init__( debug_config: DebugConfig | None = None, http_options: HttpOptions | HttpOptionsDict | None = None, api_key: str | None = None, + api_version: str | None = None, + base_url: str | None = None, ) -> None: """Initializes the VertexAI plugin. @@ -316,6 +322,8 @@ def __init__( api_key: The API key for authenticating with the Google AI service. If not provided, it defaults to reading from the 'GEMINI_API_KEY' environment variable. + api_version: The API version to use. Defaults to None. + base_url: The base URL for the API. Defaults to None. """ project = project if project else os.getenv(const.GCLOUD_PROJECT) location = location if location else const.DEFAULT_REGION @@ -327,7 +335,7 @@ def __init__( project=project, location=location, debug_config=debug_config, - http_options=_inject_attribution_headers(http_options), + http_options=_inject_attribution_headers(http_options, base_url, api_version), ) async def init(self) -> list[Action]: @@ -492,7 +500,11 @@ async def list_actions(self) -> list[ActionMetadata]: return actions_list -def _inject_attribution_headers(http_options: HttpOptions | HttpOptionsDict | None = None) -> HttpOptions: +def _inject_attribution_headers( + http_options: HttpOptions | HttpOptionsDict | None = None, + base_url: str | None = None, + api_version: str | None = None, +) -> HttpOptions: """Adds genkit client info to the appropriate http headers.""" # Normalize to HttpOptions instance opts: HttpOptions @@ -504,6 +516,12 @@ def _inject_attribution_headers(http_options: HttpOptions | HttpOptionsDict | No # HttpOptionsDict or other dict-like - use model_validate for proper type conversion opts = HttpOptions.model_validate(http_options) + if base_url: + opts.base_url = base_url + + if api_version: + opts.api_version = api_version + if not opts.headers: opts.headers = {} diff --git a/py/plugins/google-genai/test/google_plugin_test.py b/py/plugins/google-genai/test/google_plugin_test.py index ed575dbb8a..92fe187553 100644 --- a/py/plugins/google-genai/test/google_plugin_test.py +++ b/py/plugins/google-genai/test/google_plugin_test.py @@ -406,10 +406,10 @@ def test_init_with_api_key(self, mock_genai_client: MagicMock) -> None: vertexai=True, api_key=api_key, credentials=None, - debug_config=None, - http_options=_inject_attribution_headers(), project='project', location='us-central1', + debug_config=None, + http_options=_inject_attribution_headers(), ) self.assertIsInstance(plugin, VertexAI) self.assertTrue(plugin._vertexai) @@ -425,10 +425,10 @@ def test_init_with_credentials(self, mock_genai_client: MagicMock) -> None: vertexai=True, api_key=None, credentials=mock_credentials, - debug_config=None, - http_options=_inject_attribution_headers(), project='project', location='us-central1', + debug_config=None, + http_options=_inject_attribution_headers(), ) self.assertIsInstance(plugin, VertexAI) self.assertTrue(plugin._vertexai) @@ -449,10 +449,10 @@ def test_init_with_all(self, mock_genai_client: MagicMock) -> None: vertexai=True, api_key=api_key, credentials=mock_credentials, - debug_config=None, - http_options=_inject_attribution_headers(), project='project', location='location', + debug_config=None, + http_options=_inject_attribution_headers(), ) self.assertIsInstance(plugin, VertexAI) self.assertTrue(plugin._vertexai) diff --git a/py/pyproject.toml b/py/pyproject.toml index a021b775f8..c0cea701cd 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -64,7 +64,13 @@ dev = [ "nox-uv>=0.2.2", ] -lint = ["pypdf", "strenum>=0.4.15", "ty>=0.0.1", "ruff>=0.9", "pysentry-rs>=0.3.14"] +lint = [ + "pypdf", + "strenum>=0.4.15", + "ty>=0.0.1", + "ruff>=0.9", + "pysentry-rs>=0.3.14", +] [tool.hatch.build.targets.wheel] diff --git a/py/samples/google-genai-hello/src/main.py b/py/samples/google-genai-hello/src/main.py index 7b9d322a29..265d7f71e3 100755 --- a/py/samples/google-genai-hello/src/main.py +++ b/py/samples/google-genai-hello/src/main.py @@ -206,6 +206,44 @@ async def say_hi(name: Annotated[str, Field(default='Alice')] = 'Alice') -> str: return resp.text +@ai.flow() +async def demo_dynamic_tools( + input_val: Annotated[str, Field(default='Dynamic tools demo')] = 'Dynamic tools demo', +) -> dict: + """Demonstrates advanced Genkit features: ai.run() and ai.dynamic_tool(). + + This flow shows how to: + 1. Use `ai.run()` to create sub-spans (steps) within a flow trace. + 2. Use `ai.dynamic_tool()` to create tools on-the-fly without registration. + + To test this in the Dev UI: + 1. Select 'demo_dynamic_tools' from the flows list. + 2. Run it with the default input or provide a custom string. + 3. Click 'View trace' to see the 'process_data_step' sub-span and tool execution. + """ + + # ai.run() allows you to wrap any function in a trace span, which is visible + # in the Dev UI. It supports an optional input argument as the second parameter. + def process_data(data: str) -> str: + return f'processed: {data}' + + run_result = await ai.run('process_data_step', input_val, process_data) + + # ai.dynamic_tool() creates a tool that isn't globally registered but can be + # used immediately or passed to generate() calls. + def multiplier_fn(x: int) -> int: + return x * 10 + + dynamic_multiplier = ai.dynamic_tool('dynamic_multiplier', multiplier_fn, description='Multiplies by 10') + tool_res = await dynamic_multiplier.arun(5) + + return { + 'step_result': run_result, + 'dynamic_tool_result': tool_res.response, + 'tool_metadata': dynamic_multiplier.metadata, + } + + @ai.flow() async def embed_docs( docs: list[str] | None = None,