diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 08d851ec1e..ba3ad25722 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -98,11 +98,23 @@ jobs: - name: Lint with ruff run: uv run --directory py ruff check --select I . + - name: Type check with Ty + run: uv run --directory py ty check --exit-zero . + - name: Check licenses run: ./bin/check_license - name: Run Python tests run: uv run --python ${{ matrix.python-version }} --active --isolated --directory py pytest -xvs --log-level=DEBUG . + - name: Run Python tests (tox) + run: | + clean_version=$(echo "${{ matrix.python-version }}" | tr -d '.') + uv run --directory py tox -e "py$clean_version" + + - name: Run Python tests (nox) + run: | + uv run --directory py nox -s "tests-${{ matrix.python-version }}" + - name: Build distributions run: ./py/bin/build_dists diff --git a/bin/check_license b/bin/check_license index 3a72112e59..83aa4890b3 100755 --- a/bin/check_license +++ b/bin/check_license @@ -63,6 +63,7 @@ $HOME/go/bin/addlicense \ -ignore '.trunk/**/*' \ -ignore '**/*.toml' \ -ignore '**/*.nix' \ + -ignore '**/*.yaml' \ "$TOP_DIR" uv run --directory "${PY_DIR}" liccheck diff --git a/bin/run_lint b/bin/run_lint index 66144ea11d..f682deae3b 100755 --- a/bin/run_lint +++ b/bin/run_lint @@ -24,7 +24,7 @@ PY_DIR="${TOP_DIR}/py" JS_DIR="${TOP_DIR}/js"j uv run --directory "${PY_DIR}" ruff check --select I --fix --preview --unsafe-fixes . -uv run --directory "${PY_DIR}" mypy . +uv run --directory "${PY_DIR}" ty check --exclude samples . # Disabled because there are many lint errors. #pushd "${GO_DIR}" &>/dev/null diff --git a/docs/model-spec.md b/docs/model-spec.md new file mode 100644 index 0000000000..119cca1e64 --- /dev/null +++ b/docs/model-spec.md @@ -0,0 +1,365 @@ +# Genkit Model Action Specification + +This document specifies the contract for Genkit Model Actions. Genkit models are implemented as actions with specific input, output, and streaming schemas. 
They encapsulate the logic for communicating with AI models, handling multimodal input, tool calling, and structured output. + +## Model Action Definition + +A Genkit Model is an Action with the following characteristics: + +- **Action Type**: `model` +- **Input Schema**: `GenerateRequest` +- **Output Schema**: `GenerateResponse` +- **Streaming Schema**: `GenerateResponseChunk` + +### Metadata + +Model actions should define the following metadata: + +- `model`: Object containing model capability information. + - `label`: Human-readable name (e.g., "Google AI - Gemini Pro"). + - `versions`: Array of supported version strings. + - `supports`: Object defining supported capabilities: + - `multiturn`: Boolean (history support). + - `media`: Boolean (multimodal input support). + - `tools`: Boolean (tool calling support). + - `systemRole`: Boolean (system message support). + - `output`: Array of supported output formats (e.g., `['json', 'text']`). + - `contentType`: Array of supported output content types. + - `context`: Boolean (document context support). + - `constrained`: Enum (`'none'`, `'all'`, `'no-tools'`) - native constrained generation support. + - `toolChoice`: Boolean (forcing tool selection). + - `longRunning`: Boolean (long running operation support). + - `stage`: Development stage (`'featured'`, `'stable'`, `'unstable'`, `'legacy'`, `'deprecated'`). + - `customOptions`: JSON Schema for model-specific configuration (exposed as `config` in request). + +## Data Structures + +### GenerateRequest + +The input to a model action. + +| Field | Type | Description | +|---|---|---| +| `messages` | `Message[]` | **(Required)** List of messages in the conversation history. | +| `config` | `any` | Model-specific configuration options (e.g., temperature, topK). Validated against the model's config schema. | +| `tools` | `ToolDefinition[]` | List of tools available for the model to call. 
| +| `toolChoice` | `enum` | Tool selection strategy: `'auto'`, `'required'`, or `'none'`. | +| `output` | `OutputConfig` | Configuration for the desired output format/schema. | +| `docs` | `DocumentData[]` | Retrieved documents to be used as context. | + +#### OutputConfig + +| Field | Type | Description | +|---|---|---| +| `format` | `string` | Desired format (e.g., `'json'`, `'text'`). | +| `schema` | `Record` | JSON schema defining the expected output structure. | +| `constrained` | `boolean` | Whether to enforce the schema constraints natively. | +| `contentType` | `string` | Specific content type for the output. | + +### GenerateResponse + +The output from a model action. + +| Field | Type | Description | +|---|---|---| +| `message` | `Message` | The generated message. | +| `finishReason` | `enum` | Reason for generation completion: `'stop'`, `'length'`, `'blocked'`, `'interrupted'`, `'other'`, `'unknown'`. | +| `finishMessage` | `string` | Additional information about the finish reason. | +| `usage` | `GenerationUsage` | Token and character usage statistics. | +| `latencyMs` | `number` | Time taken for generation in milliseconds. | +| `custom` | `any` | Model-specific extra information. | +| `request` | `GenerateRequest` | The request that triggered this response. | + +### GenerateResponseChunk + +The chunk format for streaming responses. + +| Field | Type | Description | +|---|---|---| +| `role` | `Role` | Role of the message being generated (usually `'model'`). | +| `index` | `number` | Index of the message in the response (typically 0). | +| `content` | `Part[]` | **(Required)** Content parts in this chunk. | +| `aggregated` | `boolean` | If true, this chunk contains the full accumulated content so far. | +| `custom` | `any` | Model-specific extra information. | + +### Message + +| Field | Type | Description | +|---|---|---| +| `role` | `enum` | **(Required)** The role of the message sender: `'system'`, `'user'`, `'model'`, `'tool'`. 
| +| `content` | `Part[]` | **(Required)** The content of the message, composed of one or more parts. | +| `metadata` | `Record` | Arbitrary metadata associated with the message. | + +### Parts + +Genkit uses a unified `Part` structure to represent different types of content. A `Part` is a union of specific part types. + +#### Text Part + +Represents plain text content. + +```json +{ + "text": "Hello, world!" +} +``` + +#### Media Part + +Represents multimodal content. Inline data should be encoded as `data:` URIs (base64). + +**Image:** + +```json +{ + "media": { + "url": "data:image/jpeg;base64,/9j/4AAQSkZJRg...", + "contentType": "image/jpeg" + } +} +``` + +**Audio:** + +```json +{ + "media": { + "url": "data:audio/L16;codec=pcm;rate=24000;base64,AAAAAA...", + "contentType": "audio/L16;codec=pcm;rate=24000" + } +} +``` + +**Video:** + +```json +{ + "media": { + "url": "https://example.com/video.mp4", + "contentType": "video/mp4" + } +} +``` + +**Metadata:** +All parts can include a `metadata` field for provider-specific information that doesn't fit into the main schema. Common uses include `mediaResolution` for images/video, `videoMetadata` (e.g. duration, offset), or internal signatures like `thoughtSignature`. + +```json +{ + "media": { "url": "..." }, + "metadata": { + "mediaResolution": { "level": "MEDIA_RESOLUTION_HIGH" }, + "videoMetadata": { "startOffset": { "seconds": 10 } } + } +} +``` + +#### Tool Request Part + +Represents a request from the model to execute a tool. + +```json +{ + "toolRequest": { + "name": "weatherTool", + "ref": "call_123", // Optional correlation ID + "input": { "city": "New York" } + } +} +``` + +#### Tool Response Part + +Represents the result of a tool execution, sent back to the model. + +```json +{ + "toolResponse": { + "name": "weatherTool", + "ref": "call_123", // Must match the request ref + "output": { "temperature": 72 }, // Structured output + "content": [ ... ] // Optional content parts (e.g. 
if tool returns artifacts) + } +} +``` + +#### Custom Part + +Represents provider-specific content not covered by other types. A common use case is returning the result of server-side tools like Code Execution. + +```json +{ + "custom": { + "executableCode": { + "code": "print('Hello World')", + "language": "PYTHON" + }, + "codeExecutionResult": { + "outcome": "OUTCOME_OK", + "output": "Hello World\n" + } + } +} +``` + +## Provider-Specific Features + +Many models offer server-side features that go beyond standard text generation or client-side tool calling. These are typically handled via the `config` object or specific metadata. + +### Server-Side Tools + +Features like **Web Search** (Grounding), **Code Execution**, or **URL Context** are often implemented as "server-side tools". Since the client does not execute them, they are configured in the `config` object rather than the `tools` list. + +**Example (Web Search Configuration):** + +```json +{ + "config": { + "googleSearch": {}, // Provider-specific key + "tools": [{ "googleSearch": {} }] // Some providers might use a tools config key + } +} +``` + +**Example (URL Context):** + +```json +{ + "config": { + "urlContext": { "urls": ["https://example.com/article"] } + } +} +``` + +### Encoding Guidelines + +- **Requests**: Use `config` for enabling/configuring server-side features. Do not use `ToolRequestPart` unless the client is expected to execute the tool. +- **Responses**: + - If a server-side tool produces content (e.g., code execution output), it may appear as a `TextPart` (if integrated into the answer) or a `CustomPart`. + - Metadata about the execution (e.g., search sources, grounding metadata) should be placed in the `custom` field of the `GenerateResponse` or `Message` metadata. + +#### Reasoning Part + +Represents chain-of-thought or reasoning text provided by the model. + +```json +{ + "reasoning": "First, I will calculate..." 
+} +``` + + +#### Data Part + +This part is reserved for future use and is not currently supported by any known plugins. Represents generic structured data. + +```json +{ + "data": { "key": "value" } +} +``` + +## Behavior + +### Request Processing + +1. **Validation**: The model action validates the `GenerateRequest`. +2. **Context**: If `docs` are provided, the model action should incorporate them into the context, typically by augmenting the message history. +3. **Tools**: If `tools` are provided, they are converted to the format expected by the underlying model API. +4. **Configuration**: `config` options are applied. + +### System Message Handling + +Genkit standardizes system instructions as messages with `role: 'system'` within the `messages` array. However, many model providers (e.g., Google GenAI) require system instructions to be passed as a separate configuration field rather than part of the conversation history. + +**Implementation Requirement:** +- The model action MUST accept `role: 'system'` messages in the input `messages` array. +- If the underlying provider requires separate system instructions: + 1. Extract the system message(s) from the `messages` array. + 2. Convert/format them as required by the provider (e.g., `systemInstruction` field). + 3. Ensure they are NOT passed in the regular conversation history if the provider doesn't support `system` role there. + +### Configuration Handling + +Model plugins should follow the "passthrough" pattern for configuration options. This ensures that new features added to the underlying model API can be used immediately by users without requiring plugin updates. + +1. **Extract Known Options**: Explicitly destructure known configuration keys (e.g., `temperature`, `topK`, `topP`) to handle them according to Genkit's common schema or specific logic. +2. **Pass Through the Rest**: Pass all remaining unknown keys directly to the underlying model API's configuration object. 
+ +**Example (TypeScript):** + +```typescript +const { + temperature, + topK, + ...restOfConfig +} = request.config || {}; + +const apiRequest = { + model: modelName, + temperature: temperature, // Handle known keys + top_k: topK, + ...restOfConfig // Pass through unknown keys +}; +``` + +3. **Merge Tools**: If the provider supports passing tools via configuration (e.g., `config.tools`) in addition to the standard `request.tools`, these should be merged. This allows users to pass provider-specific tool definitions (like server-side tools) alongside standard Genkit tools. + +```typescript +const tools = request.tools?.map(toProviderTool) || []; +if (config.tools) { + tools.push(...config.tools); +} +``` + +### Response Generation + +1. **Content**: The model output is parsed into `Part` objects. Text is mapped to `TextPart`, function calls to `ToolRequestPart`. +2. **Streaming**: When streaming, the model emits `GenerateResponseChunk`s. + - Chunks should ideally contain incremental updates. + - If the underlying model only supports full responses during streaming, `aggregated: true` should be set. +3. **Finish Reason**: The model must map provider-specific finish reasons to the standard Genkit enum. + +### Tool Handling + +Tools are a central capability of Genkit models. Implementation involves converting definitions, handling requests (including streaming), and processing responses. + +#### Tool Definition Conversion +The model action must convert Genkit's `ToolDefinition` into the format expected by the provider. + +- **Name**: Sanitize tool names if the provider has strict rules (e.g., replace `/` with `__` for Gemini). +- **Input Schema**: Convert the JSON Schema in `inputSchema` to the provider's schema format. +- **Description**: Pass the tool description. + +#### Tool Requests +When the model decides to call a tool, it emits a `ToolRequestPart`. + +- **Ref**: Assign a stable `ref` (call ID) if the provider supports it, to correlate with the response. 
+- **Input**: The arguments for the tool. + +**Partial Tool Requests (Streaming)** +Some models (like Gemini 3.0) support streaming tool calls. In this case, the model emits `ToolRequestPart`s with `partial: true`. + +- The `input` field in a partial request should contain the **accumulated** arguments so far (if supported by the plugin logic) or the current delta, depending on how the plugin manages state. +- The final chunk for the tool call should have `partial: false` (or omitted). + +#### Tool Responses +The result of a tool execution is passed back to the model as a `ToolResponsePart` in a message with `role: 'tool'`. + +- **Ref**: Must match the `ref` of the corresponding `ToolRequestPart`. +- **Output**: The result of the tool execution (usually a JSON object). +- **Content**: Optional list of Parts (e.g., if the tool returns an image or other rich content). + +#### Multi-turn Flow +Models supporting tools must handle the conversation loop: +1. `User Message` +2. `Model Message` (containing `ToolRequestPart`s) +3. `Tool Message` (containing `ToolResponsePart`s) +4. `Model Message` (Final Answer) + +### Structured Output + +- If `output.schema` is provided, the model should attempt to generate content matching that schema. +- If `output.constrained` is true and the model supports it, the schema is enforced by the model generation process. +- Otherwise, the schema may be included in the prompt instructions. +- The resulting structured data should be typically serialized in a `TextPart`. diff --git a/genkit-merge-and-fixes.patch b/genkit-merge-and-fixes.patch new file mode 100644 index 0000000000..e3c7744ef6 --- /dev/null +++ b/genkit-merge-and-fixes.patch @@ -0,0 +1,38549 @@ +diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml +index 08d851ec1..ba3ad2572 100644 +--- a/.github/workflows/python.yml ++++ b/.github/workflows/python.yml +@@ -98,11 +98,23 @@ jobs: + - name: Lint with ruff + run: uv run --directory py ruff check --select I . 
+ ++ - name: Type check with Ty ++ run: uv run --directory py ty check --exit-zero . ++ + - name: Check licenses + run: ./bin/check_license + + - name: Run Python tests + run: uv run --python ${{ matrix.python-version }} --active --isolated --directory py pytest -xvs --log-level=DEBUG . + ++ - name: Run Python tests (tox) ++ run: | ++ clean_version=$(echo "${{ matrix.python-version }}" | tr -d '.') ++ uv run --directory py tox -e "py$clean_version" ++ ++ - name: Run Python tests (nox) ++ run: | ++ uv run --directory py nox -s "tests-${{ matrix.python-version }}" ++ + - name: Build distributions + run: ./py/bin/build_dists +diff --git a/bin/check_license b/bin/check_license +index 3a72112e5..83aa4890b 100755 +--- a/bin/check_license ++++ b/bin/check_license +@@ -63,6 +63,7 @@ $HOME/go/bin/addlicense \ + -ignore '.trunk/**/*' \ + -ignore '**/*.toml' \ + -ignore '**/*.nix' \ ++ -ignore '**/*.yaml' \ + "$TOP_DIR" + + uv run --directory "${PY_DIR}" liccheck +diff --git a/bin/run_lint b/bin/run_lint +index 66144ea11..f682deae3 100755 +--- a/bin/run_lint ++++ b/bin/run_lint +@@ -24,7 +24,7 @@ PY_DIR="${TOP_DIR}/py" + JS_DIR="${TOP_DIR}/js"j + + uv run --directory "${PY_DIR}" ruff check --select I --fix --preview --unsafe-fixes . +-uv run --directory "${PY_DIR}" mypy . ++uv run --directory "${PY_DIR}" ty check --exclude samples . + + # Disabled because there are many lint errors. + #pushd "${GO_DIR}" &>/dev/null +diff --git a/docs/model-spec.md b/docs/model-spec.md +new file mode 100644 +index 000000000..119cca1e6 +--- /dev/null ++++ b/docs/model-spec.md +@@ -0,0 +1,365 @@ ++# Genkit Model Action Specification ++ ++This document specifies the contract for Genkit Model Actions. Genkit models are implemented as actions with specific input, output, and streaming schemas. They encapsulate the logic for communicating with AI models, handling multimodal input, tool calling, and structured output. 
++ ++## Model Action Definition ++ ++A Genkit Model is an Action with the following characteristics: ++ ++- **Action Type**: `model` ++- **Input Schema**: `GenerateRequest` ++- **Output Schema**: `GenerateResponse` ++- **Streaming Schema**: `GenerateResponseChunk` ++ ++### Metadata ++ ++Model actions should define the following metadata: ++ ++- `model`: Object containing model capability information. ++ - `label`: Human-readable name (e.g., "Google AI - Gemini Pro"). ++ - `versions`: Array of supported version strings. ++ - `supports`: Object defining supported capabilities: ++ - `multiturn`: Boolean (history support). ++ - `media`: Boolean (multimodal input support). ++ - `tools`: Boolean (tool calling support). ++ - `systemRole`: Boolean (system message support). ++ - `output`: Array of supported output formats (e.g., `['json', 'text']`). ++ - `contentType`: Array of supported output content types. ++ - `context`: Boolean (document context support). ++ - `constrained`: Enum (`'none'`, `'all'`, `'no-tools'`) - native constrained generation support. ++ - `toolChoice`: Boolean (forcing tool selection). ++ - `longRunning`: Boolean (long running operation support). ++ - `stage`: Development stage (`'featured'`, `'stable'`, `'unstable'`, `'legacy'`, `'deprecated'`). ++ - `customOptions`: JSON Schema for model-specific configuration (exposed as `config` in request). ++ ++## Data Structures ++ ++### GenerateRequest ++ ++The input to a model action. ++ ++| Field | Type | Description | ++|---|---|---| ++| `messages` | `Message[]` | **(Required)** List of messages in the conversation history. | ++| `config` | `any` | Model-specific configuration options (e.g., temperature, topK). Validated against the model's config schema. | ++| `tools` | `ToolDefinition[]` | List of tools available for the model to call. | ++| `toolChoice` | `enum` | Tool selection strategy: `'auto'`, `'required'`, or `'none'`. 
| ++| `output` | `OutputConfig` | Configuration for the desired output format/schema. | ++| `docs` | `DocumentData[]` | Retrieved documents to be used as context. | ++ ++#### OutputConfig ++ ++| Field | Type | Description | ++|---|---|---| ++| `format` | `string` | Desired format (e.g., `'json'`, `'text'`). | ++| `schema` | `Record` | JSON schema defining the expected output structure. | ++| `constrained` | `boolean` | Whether to enforce the schema constraints natively. | ++| `contentType` | `string` | Specific content type for the output. | ++ ++### GenerateResponse ++ ++The output from a model action. ++ ++| Field | Type | Description | ++|---|---|---| ++| `message` | `Message` | The generated message. | ++| `finishReason` | `enum` | Reason for generation completion: `'stop'`, `'length'`, `'blocked'`, `'interrupted'`, `'other'`, `'unknown'`. | ++| `finishMessage` | `string` | Additional information about the finish reason. | ++| `usage` | `GenerationUsage` | Token and character usage statistics. | ++| `latencyMs` | `number` | Time taken for generation in milliseconds. | ++| `custom` | `any` | Model-specific extra information. | ++| `request` | `GenerateRequest` | The request that triggered this response. | ++ ++### GenerateResponseChunk ++ ++The chunk format for streaming responses. ++ ++| Field | Type | Description | ++|---|---|---| ++| `role` | `Role` | Role of the message being generated (usually `'model'`). | ++| `index` | `number` | Index of the message in the response (typically 0). | ++| `content` | `Part[]` | **(Required)** Content parts in this chunk. | ++| `aggregated` | `boolean` | If true, this chunk contains the full accumulated content so far. | ++| `custom` | `any` | Model-specific extra information. | ++ ++### Message ++ ++| Field | Type | Description | ++|---|---|---| ++| `role` | `enum` | **(Required)** The role of the message sender: `'system'`, `'user'`, `'model'`, `'tool'`. 
| ++| `content` | `Part[]` | **(Required)** The content of the message, composed of one or more parts. | ++| `metadata` | `Record` | Arbitrary metadata associated with the message. | ++ ++### Parts ++ ++Genkit uses a unified `Part` structure to represent different types of content. A `Part` is a union of specific part types. ++ ++#### Text Part ++ ++Represents plain text content. ++ ++```json ++{ ++ "text": "Hello, world!" ++} ++``` ++ ++#### Media Part ++ ++Represents multimodal content. Inline data should be encoded as `data:` URIs (base64). ++ ++**Image:** ++ ++```json ++{ ++ "media": { ++ "url": "data:image/jpeg;base64,/9j/4AAQSkZJRg...", ++ "contentType": "image/jpeg" ++ } ++} ++``` ++ ++**Audio:** ++ ++```json ++{ ++ "media": { ++ "url": "data:audio/L16;codec=pcm;rate=24000;base64,AAAAAA...", ++ "contentType": "audio/L16;codec=pcm;rate=24000" ++ } ++} ++``` ++ ++**Video:** ++ ++```json ++{ ++ "media": { ++ "url": "https://example.com/video.mp4", ++ "contentType": "video/mp4" ++ } ++} ++``` ++ ++**Metadata:** ++All parts can include a `metadata` field for provider-specific information that doesn't fit into the main schema. Common uses include `mediaResolution` for images/video, `videoMetadata` (e.g. duration, offset), or internal signatures like `thoughtSignature`. ++ ++```json ++{ ++ "media": { "url": "..." }, ++ "metadata": { ++ "mediaResolution": { "level": "MEDIA_RESOLUTION_HIGH" }, ++ "videoMetadata": { "startOffset": { "seconds": 10 } } ++ } ++} ++``` ++ ++#### Tool Request Part ++ ++Represents a request from the model to execute a tool. ++ ++```json ++{ ++ "toolRequest": { ++ "name": "weatherTool", ++ "ref": "call_123", // Optional correlation ID ++ "input": { "city": "New York" } ++ } ++} ++``` ++ ++#### Tool Response Part ++ ++Represents the result of a tool execution, sent back to the model. 
++ ++```json ++{ ++ "toolResponse": { ++ "name": "weatherTool", ++ "ref": "call_123", // Must match the request ref ++ "output": { "temperature": 72 }, // Structured output ++ "content": [ ... ] // Optional content parts (e.g. if tool returns artifacts) ++ } ++} ++``` ++ ++#### Custom Part ++ ++Represents provider-specific content not covered by other types. A common use case is returning the result of server-side tools like Code Execution. ++ ++```json ++{ ++ "custom": { ++ "executableCode": { ++ "code": "print('Hello World')", ++ "language": "PYTHON" ++ }, ++ "codeExecutionResult": { ++ "outcome": "OUTCOME_OK", ++ "output": "Hello World\n" ++ } ++ } ++} ++``` ++ ++## Provider-Specific Features ++ ++Many models offer server-side features that go beyond standard text generation or client-side tool calling. These are typically handled via the `config` object or specific metadata. ++ ++### Server-Side Tools ++ ++Features like **Web Search** (Grounding), **Code Execution**, or **URL Context** are often implemented as "server-side tools". Since the client does not execute them, they are configured in the `config` object rather than the `tools` list. ++ ++**Example (Web Search Configuration):** ++ ++```json ++{ ++ "config": { ++ "googleSearch": {}, // Provider-specific key ++ "tools": [{ "googleSearch": {} }] // Some providers might use a tools config key ++ } ++} ++``` ++ ++**Example (URL Context):** ++ ++```json ++{ ++ "config": { ++ "urlContext": { "urls": ["https://example.com/article"] } ++ } ++} ++``` ++ ++### Encoding Guidelines ++ ++- **Requests**: Use `config` for enabling/configuring server-side features. Do not use `ToolRequestPart` unless the client is expected to execute the tool. ++- **Responses**: ++ - If a server-side tool produces content (e.g., code execution output), it may appear as a `TextPart` (if integrated into the answer) or a `CustomPart`. 
++ - Metadata about the execution (e.g., search sources, grounding metadata) should be placed in the `custom` field of the `GenerateResponse` or `Message` metadata. ++ ++#### Reasoning Part ++ ++Represents chain-of-thought or reasoning text provided by the model. ++ ++```json ++{ ++ "reasoning": "First, I will calculate..." ++} ++``` ++ ++ ++#### Data Part ++ ++This part is reserved for future use and is not currently supported by any known plugins. Represents generic structured data. ++ ++```json ++{ ++ "data": { "key": "value" } ++} ++``` ++ ++## Behavior ++ ++### Request Processing ++ ++1. **Validation**: The model action validates the `GenerateRequest`. ++2. **Context**: If `docs` are provided, the model action should incorporate them into the context, typically by augmenting the message history. ++3. **Tools**: If `tools` are provided, they are converted to the format expected by the underlying model API. ++4. **Configuration**: `config` options are applied. ++ ++### System Message Handling ++ ++Genkit standardizes system instructions as messages with `role: 'system'` within the `messages` array. However, many model providers (e.g., Google GenAI) require system instructions to be passed as a separate configuration field rather than part of the conversation history. ++ ++**Implementation Requirement:** ++- The model action MUST accept `role: 'system'` messages in the input `messages` array. ++- If the underlying provider requires separate system instructions: ++ 1. Extract the system message(s) from the `messages` array. ++ 2. Convert/format them as required by the provider (e.g., `systemInstruction` field). ++ 3. Ensure they are NOT passed in the regular conversation history if the provider doesn't support `system` role there. ++ ++### Configuration Handling ++ ++Model plugins should follow the "passthrough" pattern for configuration options. 
This ensures that new features added to the underlying model API can be used immediately by users without requiring plugin updates. ++ ++1. **Extract Known Options**: Explicitly destructure known configuration keys (e.g., `temperature`, `topK`, `topP`) to handle them according to Genkit's common schema or specific logic. ++2. **Pass Through the Rest**: Pass all remaining unknown keys directly to the underlying model API's configuration object. ++ ++**Example (TypeScript):** ++ ++```typescript ++const { ++ temperature, ++ topK, ++ ...restOfConfig ++} = request.config || {}; ++ ++const apiRequest = { ++ model: modelName, ++ temperature: temperature, // Handle known keys ++ top_k: topK, ++ ...restOfConfig // Pass through unknown keys ++}; ++``` ++ ++3. **Merge Tools**: If the provider supports passing tools via configuration (e.g., `config.tools`) in addition to the standard `request.tools`, these should be merged. This allows users to pass provider-specific tool definitions (like server-side tools) alongside standard Genkit tools. ++ ++```typescript ++const tools = request.tools?.map(toProviderTool) || []; ++if (config.tools) { ++ tools.push(...config.tools); ++} ++``` ++ ++### Response Generation ++ ++1. **Content**: The model output is parsed into `Part` objects. Text is mapped to `TextPart`, function calls to `ToolRequestPart`. ++2. **Streaming**: When streaming, the model emits `GenerateResponseChunk`s. ++ - Chunks should ideally contain incremental updates. ++ - If the underlying model only supports full responses during streaming, `aggregated: true` should be set. ++3. **Finish Reason**: The model must map provider-specific finish reasons to the standard Genkit enum. ++ ++### Tool Handling ++ ++Tools are a central capability of Genkit models. Implementation involves converting definitions, handling requests (including streaming), and processing responses. 
++ ++#### Tool Definition Conversion ++The model action must convert Genkit's `ToolDefinition` into the format expected by the provider. ++ ++- **Name**: Sanitize tool names if the provider has strict rules (e.g., replace `/` with `__` for Gemini). ++- **Input Schema**: Convert the JSON Schema in `inputSchema` to the provider's schema format. ++- **Description**: Pass the tool description. ++ ++#### Tool Requests ++When the model decides to call a tool, it emits a `ToolRequestPart`. ++ ++- **Ref**: Assign a stable `ref` (call ID) if the provider supports it, to correlate with the response. ++- **Input**: The arguments for the tool. ++ ++**Partial Tool Requests (Streaming)** ++Some models (like Gemini 3.0) support streaming tool calls. In this case, the model emits `ToolRequestPart`s with `partial: true`. ++ ++- The `input` field in a partial request should contain the **accumulated** arguments so far (if supported by the plugin logic) or the current delta, depending on how the plugin manages state. ++- The final chunk for the tool call should have `partial: false` (or omitted). ++ ++#### Tool Responses ++The result of a tool execution is passed back to the model as a `ToolResponsePart` in a message with `role: 'tool'`. ++ ++- **Ref**: Must match the `ref` of the corresponding `ToolRequestPart`. ++- **Output**: The result of the tool execution (usually a JSON object). ++- **Content**: Optional list of Parts (e.g., if the tool returns an image or other rich content). ++ ++#### Multi-turn Flow ++Models supporting tools must handle the conversation loop: ++1. `User Message` ++2. `Model Message` (containing `ToolRequestPart`s) ++3. `Tool Message` (containing `ToolResponsePart`s) ++4. `Model Message` (Final Answer) ++ ++### Structured Output ++ ++- If `output.schema` is provided, the model should attempt to generate content matching that schema. ++- If `output.constrained` is true and the model supports it, the schema is enforced by the model generation process. 
++- Otherwise, the schema may be included in the prompt instructions. ++- The resulting structured data should be typically serialized in a `TextPart`. +diff --git a/genkit-tools/cli/package.json b/genkit-tools/cli/package.json +index 2480b48a7..baf67cd48 100644 +--- a/genkit-tools/cli/package.json ++++ b/genkit-tools/cli/package.json +@@ -32,16 +32,17 @@ + "dependencies": { + "@genkit-ai/telemetry-server": "workspace:*", + "@genkit-ai/tools-common": "workspace:*", ++ "@inquirer/prompts": "^7.8.0", + "@modelcontextprotocol/sdk": "^1.13.1", + "axios": "^1.7.7", + "colorette": "^2.0.20", + "commander": "^11.1.0", + "extract-zip": "^2.0.1", + "get-port": "5.1.1", +- "@inquirer/prompts": "^7.8.0", + "open": "^6.3.0", + "ora": "^5.4.1", +- "semver": "^7.7.2" ++ "semver": "^7.7.2", ++ "yaml": "^2.8.0" + }, + "devDependencies": { + "@jest/globals": "^29.7.0", +diff --git a/genkit-tools/cli/src/cli.ts b/genkit-tools/cli/src/cli.ts +index b45ed7054..614aba79e 100644 +--- a/genkit-tools/cli/src/cli.ts ++++ b/genkit-tools/cli/src/cli.ts +@@ -23,6 +23,7 @@ import { + } from '@genkit-ai/tools-common/utils'; + import { Command, program } from 'commander'; + import { config } from './commands/config'; ++import { devTestModel } from './commands/dev-test-model'; + import { evalExtractData } from './commands/eval-extract-data'; + import { evalFlow } from './commands/eval-flow'; + import { evalRun } from './commands/eval-run'; +@@ -59,6 +60,7 @@ const commands: Command[] = [ + initAiTools, + config, + start, ++ devTestModel, + mcp, + ]; + +diff --git a/genkit-tools/cli/src/commands/dev-test-model.ts b/genkit-tools/cli/src/commands/dev-test-model.ts +new file mode 100644 +index 000000000..b03d01c42 +--- /dev/null ++++ b/genkit-tools/cli/src/commands/dev-test-model.ts +@@ -0,0 +1,545 @@ ++/** ++ * Copyright 2024 Google LLC ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. 
++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++import { ++ GenerateRequestData, ++ GenerateResponseData, ++ GenerateResponseSchema, ++ Part, ++} from '@genkit-ai/tools-common'; ++import { ++ GenkitToolsError, ++ RuntimeManager, ++} from '@genkit-ai/tools-common/manager'; ++import { findProjectRoot, logger } from '@genkit-ai/tools-common/utils'; ++import { Command } from 'commander'; ++import { readFileSync } from 'fs'; ++import { resolve } from 'path'; ++import { parse } from 'yaml'; ++import { startDevProcessManager, startManager } from '../utils/manager-utils'; ++ ++interface TestOptions { ++ supports: string; ++ fromFile?: string; ++} ++ ++type TestCase = { ++ name: string; ++ input: GenerateRequestData; ++ validators: string[]; ++}; ++ ++type TestSuite = { ++ model: string; ++ supports?: string[]; ++ tests?: TestCase[]; ++}; ++ ++const getMessageText = (response: GenerateResponseData): string | undefined => { ++ const message = response.message || response.candidates?.[0]?.message; ++ return message?.content?.[0]?.text; ++}; ++ ++const getMessageContent = (response: GenerateResponseData) => { ++ const message = response.message || response.candidates?.[0]?.message; ++ return message?.content; ++}; ++ ++const getMediaPart = (response: GenerateResponseData) => { ++ const content = getMessageContent(response); ++ return content?.find((p: Part) => p.media); ++}; ++ ++const imageBase64 = ++ 
'iVBORw0KGgoAAAANSUhEUgAAByIAAAGdCAYAAABel7RVAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAADv9SURBVHgB7d1bchPn9jfgF58qdx97BH9xquIuMIKYEQRGgBkBZgSYEQAjwIwAMgKcEcS5SxXY1h7BZt+lbGy+tUwr2yEcbEl91PNUKZKND7LU3VLeX6+1SgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABu5SAQDokJs3b44ODw/vXrp06aePHz/eik+Nzvzzbnz+fXz+l7W1tdd//PHHuAAAAAAAnSSIBAA64fr16+sRMD6Oy/oFvi2Dyed7e3vbBQAAAADoFEEkANCqKQPIz43X1tbuqJAEABimUfjhhx+K93sAAP2yVAAAWhBrSZevXr369OTk5M2MIeTpjzs8PDy4du3ai2ztWgAAGJSVlZXRn3/+ebkAANArgkgAoFFVAPl4aWnpID7cLHMUgeZGBJK/5c8vAAAMRrzPu7y8vHyrAADQK4JIAKAx2YY1Asjf4uZWXOo6oz1/7laEkVkhuVEAAOi97KARF0EkAEDPCCIBgNrduHHjVoSCb7INa3w4Ks0YxWLVi/y92rUCAPTbpUuXfoyrnwoAAL1yqQAA1CTbsC4tLWWb1Lm2YJ1GLF5tr66uPvnjjz/GBQCA3oj3lKOqrX85OTn513g8fl8AAOgFFZEAQC2uXr36sI45kNOq5ke+ifvVifsDAMD5LC8vr09ux/vLjQIAQG+oiAQA5irnQJ6cnDyNm12e4TO+dOnSk729ve0CHZZVxau
rq6PYp25FmD6K7fb/yqc5qJfz4+rLLpevz1x9P7nE955ex/eN4/q/8fFufu7Dhw+7KksA6LKrV6/mjPHT95bx2rUT7+HuFAAAekEQCQDMRbbMWl5efhEhx3rpCe1a6ZKcZXp0dLQe+9CPsW1m8JgLrpdLMzKo3I3fmeHk70tLS7tv377dLQDQsuoktzdnPxevU3fevXu3UwAA6DxBJAAwk2oO5MPyqQVrU6HJvD1bW1t7LpCkSbnvRHh/N27+VAX4o9ItGU7uxH37Ne7njmASgDZcu3btzecnuqmKBADoD0EkADC1nAMZV1ulvwHkWdq1UruqquOn2NbW+1Q9XMl9ZCcuL1WhAF2UgVWZszwZY39/f6vQinhON+I5ePGlf1MVCQDQD4JIAODCMkyJRaHHPQxSzmMcC1sPLGwxL2eqhrP6scuzUy9CKAl0ztWrVz+WOcs27nt7ew8KjateP3M25OgrXzI+OTm5bc4xAEC3LRUAgHPKOZDXrl17lXN6BhpCplH+ffF3vsiZfQWmlIF9VufEIup/yqfK4aGEkGkUx4CN3Fdi4f8gK1bsLwDMU7x+Pi7fbls+qr4GAIAOUxEJAHzXQOZATmsrwpbnzrbnPBZ8XzmtHFpdXX1i3irQBhWRwxHPZQaMW+f88kf7+/vPCgAAnaQiEgD4pqx0qtpibZXmg5UM/55EEPivvC7t2Mq/Px+HAl+RAWQumsa2clCGMzf1wrJK8vDw8EBFMQDTiteQ++X8IWR6Wn0PAAAdpCISAPiitudA5vy51dXVB2crq7I1bHx+Ky5tLTaN19bW7qj2YmLRKyC/R4Uk0CQVkf2X7z9PTk5elSleU+O52ojn6mUBAKBTBJEAwN9UwcrTuLlR2rEbv//Ru3fvdr72BRlIxte8Kd+eG1Qb4QqpahsngDwH+wzQBEFkv2VV48ePH7fLbLRpBQDoGK1ZAYBTn7WW3CjNyzasuXh0+1shZBqH+LorsTiYC4Pj0rBJ+8kqiGLBZLVGPPcL3YL1ouwzAHxLvj7MIYRMT73WAAB0i4pIAGDSButFaanCMDyP378V+eL7C37fpIIzq9LaWnQaRyD6ZG9vb7swaFmJu7y8/KKtdsUDMo599sH3TjgAuCgVkf1T12vrl1r8AwDQDkEkACywGzdu3IoA8GmX5kBOqwPzI3fX1tbuWfAapljczjmQW0UF5Dw9i+PPk2lOQAD4EkFkfzQ4Y3kr3p+99P4MAKA9gkgAWEDV4s9kvl0baquIynD1+Pj4VTE/kjlQBVm7cSwQ37G/APMgiOy+BgPIs8Zx2RZIAgC0QxAJAAum5cqurHx6vr+/v1Vqdu3atY0IjzJsHZXmjcunv/NZobeqlsUZaquCrN8j+wswK0FkN928eXN0eHh4Nx7Ln9s+sSe7ccTVywhDd9++fbtbAAConSASABZEn+dATsv8SKYVi9lPS3sVw4vqWYSRjwrAlASR3ZDdKY6OjrJl/mn4WLp7Qs84Lr/EfdytO5is3oevlzlbW1vbVuUJAHSdIBIABq7t1pJ55nmGcXW0YT2vtudHatfaHxlex/7ySivW1pi1CkxNENlNVdv89a5URMZ9+CVe63earIiMbXOr1HBiXASod9p8jw0AcB6CSAAYqDMzeLZKO2qbAzmttqtCBZLdloF1bLNvSntVw3xibiQwFUFk9505Oeyn0tzr7elogHgP+KzJzhxnCSIBgEW2VACAwck5kLEwcVDaCSFzgedJLPbc7trCSN6f/f39K7H4lQuK49Kwjx8/bhweHr7J+ZWFTsmQOvaZ34oQsgtyltibnClWABiUCALHBwcHG/E+8U58+KTUL9+TXsn55G2FkAAAi04QCQADkmFKhFxZ0fWstDCPJ6sG+rDYkzMbG1wA+9woAskXERYfxHN1t9C6eB7ux/aQ+01XZ1gtogwjf8t2fgWAwclAMt8v5vvGUs/JYdnqWwAJANABgkgAGIBscxVhyqsMU9qYvZPzdrI1VLYu68t
iz9kFsHjMXpbmZSD5Kp63Fyq/2hOB8ON4HrZL/+V+N/7s0neXj4+P3wgjAYarej+WYeQ8Tw57Hj/zthbfAADdsFIAgN46MwdyM8KUNqq53kcI+SgrDEtP5QJYXG1cv359u435kVW71o2cHbS2tvbSollzMoQs7c1QvahxBv6xvfw7rvP2eGVlZfznn3++/174n8eJ1dXVDL4vxyWvcz7Xj3kd/9z1kO80jIyw3oIywIDlyWHxupyvZ0/LbB7Fz3pWAADoDEEkAPRUzhmMICGDlFFpXi4UPY/g7tlQ2l1V8yyvtPi4bmUgGb//SZ+D3b7oeAiZAf9ubIe/LC0t7X748GF3lv2s+t7dL/1bhpTLy8vrVSX1T6WbweTlo6OjvF/jAsBgZYAY74PeZwv7MoWcAe49FABA9wgiAaBncg5kBmVttGBNscjzenV19dFQq5NyASvCmWw1uxkfPizNmsyPfBzh0L23b9/uFuauoyFkho+v4/Jy1uDxIqrf87q6nLZ5zmAybv5cHWPanpv5PvbFe9WJAgAMXL4PizCyXDSMFEICAHSXIBIAeqJqw/r05ORko7Qg20LG5ckiBAJVu9bNeMyfxd+8FZf7pVmj4+Pj32IhbjtC3ydaUs5PPKb3Y3Fzq3RE7ldx9Tye750uVBdX2/52dTmtvI6r+y2d+PA+QtE7AnmAxZKB4tWrV0dx8/E5v0U3CQCADlsqAEDnxWLMwwghD+LmRmlehiM5B/LOolUlZShzcHCwEWHRvdJCW8hqfuRBVcHHjG7cuHErHtPt0r7cp56sra1dyf0qLq+72uI4F3bzPp6cnFyJx+5l+XTfmyCEBFhgOTOyet35pniPtp1fWwAA6CxBJAB0WLZJvHbt2pu4+ay00yLxSQYQObOnLLAMiuIxuJJtv0o7c+q2Iow8uHnz5qgwldyXjo+PX5V2nQaQ1T611adK10koH/f9dnz4pNQbSAohAcgTsrJN/vgbXzLOzhEFAIBOE0QCQEdlcLK0tPSmjZaI2S4yq7UyLOlqpVYbsjosgpg7cfN5ad7o8PDwTVb1FS5ksi/lzdKOvwWQfd6nMpDMvyEDyfNUqkxBCAnAqXy9jNfvB1/79xwZoH09AED3CSIBoINaDE7G8XuzVeQdCztfVgUxm2daVTYpq/reqIy8mAi2npaWQsgq1L89tFD/TIXklTK/KmEhJAB/k2MBqnnKf5OfMxcSAKAfBJEA0DGRQV5uIYQ8nQOZ7UcXbQ7ktFqcH3k5KyNzOyl8V87XjMD4bmne+9w2hh7qV8F8hpGztsYTQgLwRVn5eJ7PAQDQTYJIAOiYCCEfl2ZDyOfmQE6vpfmRo2o74RuuX7++HldbpWGxLbzOfSq3jbIgqnat01ZHCiEB+KovVEWOnTgHANAfgkgA6JBsyRpXm6UBuaATi//ZMnLTHMjZTeZHNtiudbMK2viC3Jfi+XhRmvcotoV7i7hPZXVkzo4sF5uhKoQE4Lvi/dUvk9uqIQEA+kUQCQAd0lCV219zIC3+z9fZuXlNBJLxO1RFfkULlcXjKthf6MriDGDz5IZyvlatQkgAziXeW21Pbq+uru4UAAB6QxAJAB1RVUNulPpkhdaTrFjqWzurmzdvjkqPnJkfWWu71ggi182K/Kdr165tlHr3pc/trq2tCdTOyFat1fb/tcrQcTxmtz1mAJxH1WkgXzN2hzx7GQBgiASRANARy8vL66UmEQhsV3Mgt/rYMvLDhw+j0kPZrrXu+ZFLS0sbhb9koN9kpWi2OM6WvBZF/ym3/6x4LP8MIzOE9JgBcFG/VhcAAHpEEAkA3fFzmbMMSao2rA/6PLMugp710mN1zo/MqsjCX2J7f1gaasmaz2e2ODZj9euy4vGzMFIICcBU4n3tbr63LQAA9IogEgA6IkKNUZmfcVbhZUjStzasXxJ/y4+l587Oj4y/53WZkyE8NvNStTfeLA3IEDKfz8J3ZRgZAfG9IoQEYAbxWpJBpJN/AAB6ZqU
AAF0xKrPLxZnnEXY9G1KV1pCq/jKQjKt7OcewaiE6KrMZFU7FAuWr0oAMkvf39zcK51adEHGlAMCUjo6O3v/www+CSACAnlERCQDdcbnMIMORtbW1232dA/k1o9EoH5fLN2/eHJUBmcyPjJuPyj9n6HFBGezG1a1Sv93j4+MHBQBoVJ7MpaoeAKB/BJEAMAzjCLbuDXFxZmVl5TRcOjo6Wi8DFGHksyKInFlVXVq3bC16z0xIAAAAgPPRmhUA6LSTk5P1vI6gqYlqN3qoanM7KjUz35A++VoV+Z9//vlemD5MX3rOPd8AAEDbBJEAQKddunTppwiZ8vrHAl/QUDXkIyEkXZKh09HR0a0M4avj46gK5E/bWR8eHn7x+5aWlsrVq1fzZoZT4/je0+v43t/j9jj+ffz27dvdQidVz/t6PF8/xvM1qk7S+epz/qXnO74nn99/x7/tfvjwYVdQCQAA1EkQCQB0Vs6HjAXT9byd1/mxBVPOaqIaMhbut/f29p4VaFEe/yI42sjQMY+HETqNJv+WJ2tMIcOrW2e/N28fHx9ncPU+fs9ufPzL8vLyjmCyPfm8Z4vyeC7ux+VuPO9/zZO+4PN+9vlez/+cnJycBpVxHN3JcDJu//Lu3budAgAAMEeCSACgs2IBfP3sQmt8fDeutgtUGqiGHK+urj4p0IIqfHwYoeD6mZMySgMmJ4GsV8FkVtLtxOWloKoZ169fX4+g8Oe4uRHXl0uNJs91/J7NyXOdxz1V4AAAwDwIIgGAzorF0Z8/+9T9Ioik0lA1pMV4GpchVIbsDYeP35ItQHN/2xhSUBXHkDfzPIbk47K3t/egzGDy3E/mI7fg9Lk+PDzcyErJPAYKn/tjFPIkrtKAbOU867ZRBe4vypzNY1+cuHHjxq3j4+O7ZUaTUQPzltXScVxeLzWKbeq1yngAYBaCSACgk3IxLa42zn5Oe1Y+c7/UqGrJul2gIVW4/jAW5m+V7jobVG33OZCsQshRmZ9RmVIHAsh/yNfcvGQgGc/zAydldFu+b1paWppruP4dO9VlVqMyf6MyJ9XxeObuC3WdUJLH41KzeAzGcSWIBACmtlQAADroa2f0xyLbZmHh5YLrpFqsLlqy0pQMoa5evXoQ23RWBnU5hPybKpA8iKDqxc2bN0eFC8tjWVZmxkL/m7qPadOqZpIexDb6NE8GKnTOJIQs9YR6X/I+ft+jAgAA5yCIBAA66Ruz/x4WFl4sgNY9G/K56h/qli3/JiFUaS5AmLtJIBlBVd375aDk4xXHsoOuBpBfsBn39zehc7e0EUIuLy/f0aoTAIDzEkQCAJ2T7QnL1xfULlf/zmJbL/UZr62tPStQowyhjo+Pf+tRCHUeW1nZmRWeha/K4Cgep9/i5lbpn1EVOutO0AFCSAAA+kAQCQB0zjeqIc/17wxbBNF3S72Lri9VQ1KXrILscQh1HqOs8FQd+WVx/LqfVYWlRy14v+Kp57hdQkgAAPpCEAkAdMp3qiEnRqoxFtr9Up/3a2tr2wVqEMeth8fHxxkc9D2EOo/T6khtPP8ng7uPHz9ux82hzFncEka2QwgJAECfCCIBgM7IhbULVDs+ji8fymIuFxDbyN1Sk0uXLr1WDcm85bEqApuncTNb/i7ScSvbeL7JKtCy4KrAbqsMjzCyYXk8iRDyVWkuhMx25beFkAAATEsQCQB0RiysPSznX1jLhTiLnwumastam9XV1ScF5uhM5dKiVnGPchbmIlexDziEnNiq+9jMJ1UI2WRVdYaQd5ygAwDALASRAEAnXL9+fb1cfKF+s/o+FsTHjx9/LjWJn/2LxVbm6UwIufAVgWVBZwouQAh5Ko6fL7ThrZcQEgCAvhJEAgCty8X6k5OTF2UK8X2vLH4ulPVSk1jg3S4wJy3McOuDhWrjmTNBywKEkJXLh4eHr7RMr4cQEgCAPhNEAgCtW15ezhByVKZz+ejoaKoQk36p5syNSj3Ge3t7rwvMgRDymxYijMx
toHyaCbpIbmmZPn9CSAAA+k4QCQC0KhekP378uF5mkN8fP+dpYdBOTk5qW4S9dOnSToE5EEKeS84UvF8G6kxwtIg2q5NGmAMhJAAAQyCIBABaM+fZWZuLOH9swdQ2HzKCyJcFZiSEPL+PHz8+G2pgVVUFjsqCOjk5cWLQHAghAQAYCkEkANCKmmZnLdT8sUUTwcWo1OP9u3fvdgrMKEKDV0UIeV6Xj4+PBzVTMI5Rt65du5bB0WZZYNml4Pr16+uFqQkhAQAYkpUCANCwbMmX1TClHhlGlv39/SeFwajCiloWZGNb/LXAjKr20G1X+L2Py/jSpUu7sV3/t/r4c5fj3/+vCvbbvr+j5eXlDG/vlGG4PGur8aE4OTnJ2c1XChcmhAQAYGgEkQBAo6qZkFulXsLIgVlZWbkVC9ulDuZDMquqwruVKrjcfuOY+ksEejtv377dLReU7VGPj4/X4+f83EaIVs343YzjdV0npwzR+LOPL1eXLhllVaRq84sRQgIAMESCSACgMXOeCfk9wsgBiRCytkXZ+Nk7BaaUcyFLc8e1iax0fB7b7rPxePy+zKAKL/PyLP+WCDTXIxxsesbh45s3b74WhvxTFTT/nterq6u733qMMlSOr70c20UGyz+1XZ1ZbUc7hXMRQgIAMFSCSACgdrm4FovbL2JR8m5p1ta1a9fWY/H2gYW23huVeryPIOfCVWQwUQUHTVWjvY+A6dHe3t52qUHsC+O42s5LHDs3GgwkLx8dHWUrz6G0aJ3VVEHzmYrYnfxP9dqbr7v326p2VRV5PkJIAACGTBAJANQqFyFzVlQ1j6xxuRB6eHj45saNG/emaVtIN0T48mM8l2XecpZegSlVVd6j0own86iAPK8q7GwskBRanZpr0FxtK9t5yWrX+NlbcblfGhTbbAahO4WvEkJ22/Hx8U4E+g/K7H6u6YS8J7Ffj0uNVldXdwoAwAwEkQBALaqFtcexCNnK3LTPjGIh6bcIDba0au2nuoLsbHlYYAoNtmQdxyJ4aydSZCgWf+pOEyFWnrQSV1fKYsoKyK26guaq2nUjguWdhlvv5jbThfcBnSSE7L4zleIzifego7iaexAZ28+OqmMAoOuWCgDAnGVVSyyM/Fa6t/iYcyMPssKn0DejUo9xgSlEOPi01C/DqdttV3PnQvzBwcFG3HxU6jVawONzVkHe29/f32yi2jWD5dim7sTvfF2acTnfExT+QQgJAMCiEEQCAHOTi42xiPwmFjlzYW1Uumn08ePHF3k/LY72Q1V5VotYBNaalQur2pXWPfP2SVPh1HnF/XkWAezt8mmGYS2qar1FkcHQ7QgHmwoFP/3SEL/zXtxspENA1Z6VM4SQAAAsEkEkADCzswFkzvkqPZD3M+9v3m8Vkt22srIyKjX58OFDZ0Ie+qOBsCxDyK3SQVmdGWHknVJfGLkoVZG7We3aZjBUbWO1h5E547fwFyEkAACLRhAJAEwlw8erV68+jst/+hRAfi7vd1ZIVi1bX6iSXDjjAhdQhWSjUp/OhpATGUZGkHKv1KfWWZQdkCHknS5Uu+a2Vneb1nydzfCtIIQEAGAhrRQAgHO4efPm6OjoaD1u/pQtCWMRdWiLitmyNdstbkQoOY6F2Z343K/ZurPt+WyLLp6TUalJl9pezkMG6fF4DT3EmcavORuvzEHN1ZCdDyEn3r17txPHypwZOfdZmRlc5bacv6MMz7grIeTE8fHxg3ity2BsVGqysrKSP3+nLDAhJAAAi0oQCQCcqqoVTsPFbIUZC6XrEcb9XywIny5OHh4eLlI1w2koGdcbsUBbYrE9P7cbj8c4Pv97Xsfnd6qvfT+0MKtr4jGva9sbl4HJ0LbadjkjHpN/xdV2mVE1G3JUapBVaRGWbpUeyZmRV65cuRX3fe7hdzzOD8vwgqtJMNSp14x8DYvg90E137kW8bMXOogUQgIAsMgEkQDAqSpMmyyOjuOykwtnq6urGUpmuJHB5I99bcE
6hffx9+5WwWO2Idw9OjoaCx1boaUfM5nXjLoaqyHHcax9VHooHpPNeHx/KnOuppu08xzSMTcep0ddDYay+jSC9p26XuMXeU6kEBIAgEUniAQAvqpaAN6tLn/NkMqWecfHxxt1LD63LP/el7Fg+PrDhw+7Qsdhy8rWwqIYlRnlcS9Pyig1iG3xSV9DgzxORoD1KAKsV2W+MrzZiOtnZRie7+3t1TqLcVa5HdYVRC7QSUx/I4QEAIBSlgoAwAVl5cTBwcHG/v7+lVi4vJctBUuP5TzIWCi8E3/Pv+KymX+fEBKGJefclhlECFnL7M04/mzPa35lWzJgq+bqzlX8zJ/LMORcyK3ScdVMzrpe+xausl0ICQAAnwgiAYCZ5AJ0XO7FIuuVjx8/viw9kgHA8vLy7bj/d6oFWLpJa1ZmdnR0NHUYEIHCKK42Sg1WV1eflAHIaroyZ5P2rKXn8rHp0cktdb2OX571ZIA+EUICAMD/CCIBgLnI4YlZJdmHQDIrd2LB7koEkA/evn27W+i02J4EkczD1NvR8vLyeqlBngwxlOCgOplj7sfTeOzvln4b96niNVuTl5r8+eefC3EsF0ICAMDfCSIBgLmaBJLZsjU/LN0yzhasWQFpwa4/zHJkHiLQHpXp1dKWdSjVkGfUcRLKT6XH4jXnQemROtuzRqg8KgMnhAQAgH8SRAIAtciWrScnJ7fj5vPSDc/z/mjBCospAu3/K1PItqzZIrTM2ZCqISfiGLtd5qyOx75B4z6+5sS2WVengEFXRAohAQDgywSRAEBtcibW/v7+ZixqZkVIW/Ox3ufvz/vRoxldNGDGCjkWRI1tWXs1U/c88hibra/LfI36OluwjrmZTYhj4++FCxFCAgDA1wkiAYDa5XysqjpyXJqVC3W3+zSfC6jNqEzn5zJ/vayUO48IsX4pc3Z0dLReemh1dXWn9NO41GCoJ38IIQEA4NsEkQBAI3J2ZISRd+JmXS3fPrdroW4YapwROeg2gcxHTa1B5x7WdcXy8vJOmbN4DpoKeOYmA9m+vv7EMVf3gHMSQgIAwPcJIgGAxkzCyFjkfF3qtZu/x0Id3yGI5Jtu3LiR4cLct5M6Zil2xdu3b/Nkk7kGWX2spItwqu7XudrUePLHoAghAQDgfASRAECjcobY3t7evVhYrms+2mkIaR7kcNS5KN7X2XM04/j4eL3M3/s4PjVVGd6WuR5/4xjwY+mZHrdlLR8+fBgXvkkICQAA57dSAABaEEHkZrW4PM9FPCHkAOWieCz4ljrEzx6V5meX0h9zD8DiuLc79AD86OhoPOcqxlEGPz06to8FRsMlhAQAgIsRRAIArcgF5VjMuxOLeb/Fh6Myu1youxcLdULI4antOe1jy0emc3Jy8t9ycXMPGnLm5OHh4UHhQn744YdskduL43s8x78XBqmFEHLSat57GwAAektrVgCgNRlG5gJbmX1x+b1qgeGqqqBqWYSNwKCpxWRaFuHBf8rF2T464ujoqDfPRVa9FganhRDytDW5Lg8AAPSdIBIAaFWusMXC3r0yg1ioeySEHLxxqYGKyMURz/WFKiJv3LghhOyWy6UnBJHD00YImeK4dffatWt3CwAA9JggEgBo3bt373bi6nmZzpO9vb3twqDV1eqwmlPKYrhQVVFsc70JvhZBn04aiOOKCrZhaSWEnIht/0UGoQUAAHrKjEgAoBP29/c3r169+lO52ELfOL5vqzB42Z6u1GOUC7wDan2Xf8e49N+ozN+FnuOTkxMVkd3SmyBmZWVlXBiMrEos7bq8vLz8Iq5n6h4BAABtEUQCAJ2xtLT0KBb/35z362NhzqLcgsggMhaDSx1iO1qPq9dlAPb29vLv6PXfcv369fWLHAfOa4owWwVSh8Tz9/9KT2gVzrxNWrRWx3gAAOgVrVkBgM64SIvWWJTefvv2rTlcC+L4+Hin1CQWeNcLnVFXS9SLtss0PxToEi1aAQDoK0EkANApJycnW+UcLRR
XV1efFBbGOJQLtta8gJ8KnRGL7bW0RP3w4cOFTlxYWlrqTQXeIojXhn8VWGyTFq0AANArgkgAoFOqWX3frIrMakit7xZPPO91VcDeUmXSHfE8/1jm7/0Uc0BtEx0iGIb/tWgtAADQI4JIAKBzTk5OnpVvVL+phlxMsQD7e6lJhBwbhU6ooyVqjSE2QKOyRevNmzdHBQAAekIQCQB0TlW59PJL/6YacnFFWPi61CS2q58LrasqU+femvXk5OS/BWAYLh8dHWnRCgBAbwgiAYBO+lrodHx8/LywkKoZf7XMifz48eO69qztW1lZqWU+ZATNO+WCYpuwPXSIMBn+J1+zrl69ulkAAKAHBJEAQCe9e/du5wvhwThosbigslK2zhabEX5b1G1ZhE21zD6L59Zxo+fiOfxPAc56rEUrAAB9IIgEADrr48ePv3z2KdWQC+4L28Q8PSy07adSg6qa9kIi9K6l+pbpxL7/7wKcpUUrAAC9IIgEADrr5OTk9Wcf7xQW2ufbxJxdvn79+nqhFaNQapgPGXarubMAg6JFKwAAfSCIBAA6K/uw5tWZD7VXXHCfbRNzF4u6jwutWF5eXi81qLOdL40SJsOXadEKAECnrRQAgG7LVpwPIyD6vcAnL+NSS2CY1SVZFZkzSgtNu1/qMVU735OTk39HiFnmKbavlxG4bhcubJr2utC23OerNs91tv6etGi9UwAAoIMEkQBAp2U1Uyzk5fVOgXIaEG0vLS3VVrlYVUXuFBqTbVkzBC41OD4+3ikdkYGEkBsWQ4aQBwcHG3F4uxyvWXmixeVSk0mL1v39/WcFAAA6RmtWABiGy7EANciWkpMQIRbxBlkNUz1vo8K5ZXvWOoPpXNC9du3aRqExdQXLuZ1MOx8yvndc5ix+5v8rwOBNQsi8nceg2PcflPpp0QoAQCcJIgFgGPIs+60ItQ6GFqBUMwEH15Yv23/m8xU3twoXFou8U7XbvMDPf5xVLIXaZTVkXK2XGsyyndQRRNZV9Ql0x9kQcmJvb+91A50dJi1aAQCgUwSRANAd4zK7bG/4IsLIV0M6Kz7+pt1pq5q6JkOXeH7enJycvCnzqYQcxONyUdmetdT7t4+WlpY2C7VbXl5eLzVVBa+trb0u06tj+xoVYLC+FEJOHB8fZ1Vkra/ZkxatBQAAOkQQCQAdkbMQy5zEQtTdw8PDrI58MYRAMh6bWqvfmpDVddmGNcKtg3lWRc1zu+mTKph+Wer1+MaNG7cKtapmctZh948//hiXKUVoUMu+FccC2xQM0LdCyFR1eHhS6qdFKwAAnSKIBICOiAWs38ucxc/ciEDyTd/btZ6cnPQ6bIsA8mEGkKWGNqx1bDd9EdvFs1KzCKO0uatRnTNSI6R/XmYwaQs9b3EsWC/AoHwvhJzY399/pkUrAACLRhAJAB0Ri9M7pR6n7VpzHmHOJSw9tLq6Oi49lI93tmGNmxmY1TJvMLabWVpP9loGRbGgW/fffyv2na3C3FWzIWtrIRjHjZ0yu3GZM3MiYVjOG0JOaNEKAMCiEUQCQEe8e/dup+az5Ec5l7CP7VqPjo7GpUeqOZCv8vGuOXQY53ZTFtisVW/n9LivIX6XRYie1ZC1BPSxXWzP0pZ1IvbfX8ucxX37qQBD8eQiIWTSohUAgEUjiASAbql75t2kXetBtkTMuYWlB6p5gJ13Zg7kbzmns9QsAo0mFjI7rYEA/1SEyi/6sr/0QdUueqPUJLaJuRxLa5rBelmwDYPwZH9/f6tMQYtWAAAWiSASADpkb29vu4lQpbKVgVnf50d2RTyOd/PxLJ/mQDYRWI1zeylM2tzVbbS8vPyqMLOsGI6g/nGpz9wqhWOfrmU+bQTbtZ+oANRq6hByIo4vj0rNtGgFAKALBJEA0DFNLEydMZkf+Zv2XdOZzIGMxzFDqlFpyNra2p3CqarNXe0tWqsF3aeFmUSgmxU6o1KTeVYKV4FmHRXZ91XYQm/NHEKmt2/f5okOWrQ
CADB4gkgA6JhqYarJMDLdynatfZwf2ZaqDevTBuZAfsmTecy/G5J4HrZKPYHR5zaz/W5hKvnY1T03dd6VwnW1Z11aWtooQN/MJYSciNeuZ3E1LvXSohUAgFYJIgGgg3J2UGnmLPm/qeZHvtHG69vi8XkYIcJB3GzjcZrrIuhQVHNEa6+KrGxFaH+/cCG535RPrYvrNPc5u3Fc/KXUwzYE/TL319987Yr3E7W3F9eiFQCANgkiAaCjqsWupisj0yguT2PB6sD8yL/LNqz5uMTNDIrbaKv4SAj5dQ1VlpyKRd1tYeT55QzV8mm/qdN4bW1tu8xZbFevSz1uVY/LIOUsUK8hDEhtJwFVLaCbOJFGi9YBivcj2nwDAJ0niASADsvKyFgEv1IaClc+czo/UrvWvxbU32Qb1tLgHMgzxktLS3eqSlm+oqnKkokMI1WYfN+NGzdu5bGk1O9lHS2LcwbppUuXdkoN4nF5OtRZkTkLtJpB7KQWei2rous+CahqLz4u9dKidZgEkQBA5wkiAaDjchE8FsAyjMzqyHFpWNWu9SDnIS5aIFnNgXwc4dZvLcyBTNluNKswrlQVE3xHg5UlE0/NjPy6rCI+Pj7OAL/uhdJxnUFB7P+/lnqM4vgyuDA7g8czx8yRQJI+i330P6VmWrQyLRWRAEAfCCIBoCeq6sg7seAw9xlo57SZ8yMXZSE5/85qDuRWaeds8+dZDasV68U1VFly1lYG9UOtbJtWtq6tqohrf1wuXbpU60zdqu1vXR4O6SSPrCDPSs8v/ZNAEr5Oi1amEa9/PxYAgI4TRAJAj2R15MHBwUYGVLHwUNfcsm/5ayE52y2WAcoKrmzDWrWSbDxYyhaQVRvWzayQKFxY0y1aK5tZOWtx95OsEs3WtaUZ4729ve1So9ym6mrPGi4fHh6+GkqQHfvB98JngSR8RXUiTd2v/Vq0DkgcTwf5fhwAGBZBJAD0UAaSsfB+LxbGM2wZl+aNjo+PfxvS/MiqDeuLrOBqqQ1rzqG7F8/rHW1YZ9dCi9Y0ikDpt0Vue5f7URwXXpVPlcSNWFtbu1MaUHPV5a0I8Hrf4jcrg8v55+gKJOEz1UkPWrQOU10B80hHBgCg6wSRANBjWQVUzY/MBfLGq+fOzI/s7QL6mTmQ2YZ1ozTvdA5kBKC34/lso8p1sKrKkt3SrFwMfDqkkP68spq4mqd6tzTn+R9//DEuDajC7XGpz2afg4HqdWCa+y+QhDPyvUBDXS+eDrW7RRfFc1rX+/TLKysrnkcAoNMEkQAwADlHMIOsFudHbvVxEXkSnJSW5kDGotT22tra7Xz+tGGdv3xMY7+4V9oL6RdipmoV5j+t5kGOSnPGVdjcpLqPsRli3y89U4WQW2U2fwsktTlmkR0fHz8qDbx2xe/RorU5tT2fcex8WAAAOkwQCQADcXZ+ZGmpXWsuIud8xa4vIGcPq7yfLQQnpyZzIPf29h40Vc21qHK/iMf7UWnHX8FKht5lgDIwqqqJG6/kyzmgTQf4ccx4VmoOB3K2Zp/CyDmFkGed7jcZ5OfPFkiyiPK1q3zqdlG3W7GfbRVqF69Z41KT7ETgWAkAdJkgEgAGJhevsl1rW/Mjc+5QtmvtYmvKSeVWBictzYE8nf1kDmSzsoVxaWZB92tGGXrn7MShLBRmsJphfgZGpYVq4vCkjX2oCj5rnz2aYWQf2rTWEEKeNcqfHa8nTbb6hc6I93LP8sSlUr/HWrTW7+joaFxqFD//aQEA6ChBJAAM1GfzIxvXtdaUsWD+sK3KrfK/OZBXqlCMhmX72xZbF5/KioUqpH/T1wrJuO93J9XELYX5aTefz9KS6nePS/2ednX+bnVSR4bQW6VeeWLNswIL6vj4OE8q06J1AKoTWepsz3q3rmr67CRSAABmIIgEgIGr5kdeaSmEOTvzq5WqlknlVtzMxew25kDumAPZDbEtZgi9W1qWAV4GeX2ZhVeFTo/j8p+4769aDCDTOPane6V
l2Ra2NON0/m6XtpFckI6/P4+pG6Vmcfxss5IZWqdF67DEMa3W9yDzrqafvP5XJ/IBAExNEAkAC2AyP3J5efl2aW9+5Ksm27VWcyBftVi5tVvNgbxjDmQ3ZBAc28Od0s4+8CWTWXinrYzbCuu/pFp83MwQP7bj/5RPlW9ttGD9mziG3evC/pRtYRtqmZhGuY10oTqyqiz/LW7W3sYxHt9tFeSgReuQxGv+76V+T2d9v/1ZALlVAABmtFIAgIXx9u3bPBP7SlZhxWJILmqPSoOqdq0bedZ9BELP66gQzMWTWDh5GDc34/e1EZyctmHVTrCbcpuLbeROVdE1Kh2R+0Zc5b6RAczruP1L3Mfdap+tXe43Kysrt2K//Cl+//okvI/r0iGPmno8ziNbJlahXFPHmayO3MgqwaYDuqwsz9eMBk/qeL+6uqoaEirV8ab2qrSqRevtQi3iOXwdr7MPS80m77fj/f52vGa8PM9M5cn7gPjeh9nmtQAAzNGlAgAspKq93kbcbKvKZjzvBfWsKIvFk6elvYDpeSwwacHaA2faS45Kt73PSpisosiWbnGfx7OGcbnYGCHPKLbVXHD8MX5uXmcVTOsVj9/wpM25kF9TtcB7Wpp3evyM53GnzgrRFgLIU/G3PWgibM22t6WGY0Bsq73+/+zq+FhH6FX7fhzP6dzPnqiqc5tqx/xVVevUJt6zPYvn6VFpQB5jsnNFmbN83cyOFKVjqpPlct9q4/V2Nx6XfH86/uzzl6v3AKNvfXPfj2sAQLu8kQCABZcLjrEwsRWX+6UdOfNtpvalbS2WT+SCVwQCD7Rg7ZcehZFfkkHUeHI7/xPb/7/PfkH8+/+d+TDbwI7Kp8XPLgeOX9LJEHIiW0C3WT0yqaCdRyh5pjL25/JpBmQbc3UbC30EkV8miPy7rgSRKf6+RlojZ2v581TRzWrRgsiULc9bnrU8FUEkADALrVkBYMHl/Mi4yvZNr1uqJhxVM/K2sxXfRRbSqzPLH8ci1mZpxzh+/4MmFuuYv9z2u9im9ZwmweJXdayt6rSedzmETFXLxO9Wk9SlCkHvxnE0Q4qseBnH537N1r5Z/bK8vPz+8+PqZHbY0dHRaSXs2crYOJ62GVSPtWSFr4v9+lEdwd3n4ne8iNfH2zo81OJlXNYLAMACWSoAAGFvb+91LPhfyZZ45Z9tm2pXzbN5Ewvp52o7ll9XVWy0EUJO5kBeEUL2W4aRseCa87A6M3uQv2QI2dZJBueWC/UR9t0rn44LbbtVBZNPM6yIkPS3PNEjq8TOXvJzeYmvzWrOnAm3WVXotBlCvp+1Oh6GrnrP8bzUL6ti22rdP2hxXM4qdgEvALBQBJEAwN/kXK5YwL4Ti9IvS/NGcdnKdnnXrl3b+NIXZBuvqp3eVmln0TznQF7pepUW55dBUjyfGUY2sbjL+TzpQwg5kXM7L1261MhMtaHKk2CEkPB9OYu6NHPC2Ga+5yrMVVVl6v0GALBQBJEAwD9kldjBwcFGBm7V/LGmZdvJFzl7bdJCMGdW5VydqiXZqDQs5w3lzKQMR7QqG6Yq+NIWsn2dngn5NXkSR7H9TOtJVuUX4LvyPUi2hS8NqFq09m2ucOfF4/qsqIoEABaIIBIA+KoMJGNx+F6L7VrvVvMjX2Ub1qp1YNPG8fffi8fhjjasw5cBWAbOpYXtnfI+97U+VxtX910YeTG9DJ6hTVq09puqSABg0QgiAYDvykqfnIdYPi2wj0vDqplnTTudA5nzA1XqLJZc4M32xFkFW2jKeG1tbRD7mjDyQoSQMCUtWvtNVSQAsEgEkQDAueWCcYvzIxsTAdR2hiL592rDupiqauCsjBQo1SyPJxn4D2k+oDDyXISQMAMtWvstn7+q4wgAwOAJIgGACzk7PzI+3C0DMpkDGQHUgyGFIkyvCt9zWx8X5i1D/kd5PBli4C+M/CYhJMxBVvA3NMtbi9YaVF0AtGgFAAZPEAkATCUDyVhIvt3W/Mg5Oz0
r3RxIvqTa1ietiZmDDP2rquNnZcCqsO1R4SwhJMzR8fFxvg9r4mQOLVprULXYHdSJfQAAnxNEAgAz+Wx+ZN9M5kBeyb+jwDdMqiOH3pq4ZqdVkBn6L0rVcYaty8vLt4uq2jzh454QEuarqihv5D2YFq3zl89fPK73itcIAGDABJEAwFz0LaQ5U5FlDiTnNmlNPJBK4KY9z2PE0Ksgv+Tt27e7OV+3oRaKXTTO423VhhCYszyu5vuaUj8tWmuQ7y3yNaJ4XwEADJQgEgCYm0lI0/Hqn91qDuTCVGQxf5NKYIHk901mr8bjtbnIoX8eH2O7yaqXhWrVmienxAL7bcdbqJcWrf0mjAQAhkwQCQDMXVb/dDCkOW0JmXMtzYFkXgSSXzcJIM1e/busXMrK0Iaql9p0Ons3T05RdQ71yyCraNHaa1UYeVsLeABgaASRAEBtMqSpzu5ue37kk0VtCUkzBJL/I4D8vqo68s5Qt5dsQVu1Yt0uQGO0aO2/PHEjT+Aon6rnncQBAAyCIBIAqFUuuLc1P7KaA3nFHEiaMgkkM4hbsIqG3L+eCyAvZrK9lE8na4xLz50Joe9pxQrt0KJ1GKrqedWRAMAgCCIBgEZM5kfGQnXOSBuXeo3NgaRNGcTl9p4BfPlU1TAuA1RV3jyqKo43BZDTqU7WmFSPj0vPqIKF7qhatD4vDdCitV6T985tnMz3GSfzAQAzEUQCAI2KherXNbawzIWSJ/nzLYbTBVVF8LPPqiTHpceq8PFJVhtn8JR/n4rj2U2qx88cH3dLt+UMyNcCSOiePJaUZo4ho5WVlVuFWn0hkByX+uUxfjuP8bE9/asAAMzgUgEAaMkoxCLHVlzul9k9jwUaLVjphRs3btw6Pj5ej23/51hUXC/dlouRO3E/f419bNs+1pzcTj58+LAZj/9P8eGodEC1LfwyhG3h5s2bo1KDIVTi1/HY/Pnnn+/r3mb6er/rkJWKP/zwQ+3VirM+PvbD6WRb3DgO5/uIn+KYnGHwrM91vtbvxs/6PcLH1/Has+v1HgCYF0EkANC6WQLJXBSPyxPVOPRVLhZnRcmcFxRnMa7Cpt+Xl5d33r592/XKvIXwWXjd5DYyrqpgf43f/9rCNED35GtEvI8YxevDKD7M99X/r3z5BJY8hr+Pr/tv+XR8fx/B467XegCgToJIAKAzMpCMxZDNuJnVP99q9TWOyy95xrYAkiHKfSHCyVEsKuZ+kAuKP8aiYQZP82qBNy7/q374d1znXNXdo6OjsaCpH6pqmLPbx6jMVjU5LmcqYnKbiOBxx/YAAADALASRAEAnTarEqsX1U7kwrlUUi27Sbi/2hVF+XAWUX6uOy2DpdH+J/Wnc1xaDnF+2OTzHtnG6XeRleXn5/dBbGAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAFAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKv8fdiuBQivkrqEAAAAASUVORK5CYII='; ++ ++const VALIDATORS: Record< ++ string, ++ (response: GenerateResponseData, arg?: string) => void ++> = { ++ 'has-tool-request': (response, toolName) => { ++ const content = getMessageContent(response); ++ if (!content || !Array.isArray(content)) { ++ throw new Error( ++ `Response missing message content. Full response: ${JSON.stringify( ++ response, ++ null, ++ 2 ++ )}` ++ ); ++ } ++ const toolRequest = content.find((c: Part) => c.toolRequest); ++ if (!toolRequest) { ++ throw new Error( ++ `Model did not return a tool request. Content: ${JSON.stringify( ++ content, ++ null, ++ 2 ++ )}` ++ ); ++ } ++ if (toolName && toolRequest.toolRequest?.name !== toolName) { ++ throw new Error( ++ `Expected tool request '${toolName}', got '${toolRequest.toolRequest?.name}'` ++ ); ++ } ++ }, ++ 'valid-json': (response) => { ++ const content = getMessageContent(response); ++ if (!content || !Array.isArray(content)) { ++ throw new Error( ++ `Response missing message content. Full response: ${JSON.stringify( ++ response, ++ null, ++ 2 ++ )}` ++ ); ++ } ++ const textPart = content.find((c: Part) => c.text); ++ if (!textPart) { ++ throw new Error( ++ `Model did not return text content for JSON. 
Content: ${JSON.stringify( ++ content, ++ null, ++ 2 ++ )}` ++ ); ++ } ++ try { ++ JSON.parse(textPart.text!); ++ } catch (e) { ++ throw new Error( ++ `Response text is not valid JSON. Text: ${textPart.text}` ++ ); ++ } ++ }, ++ 'text-includes': (response, expected) => { ++ const text = getMessageText(response); ++ if ( ++ !text || ++ (expected && !text.toLowerCase().includes(expected.toLowerCase())) ++ ) { ++ throw new Error( ++ `Response text does not include '${expected}'. Text: ${text}` ++ ); ++ } ++ }, ++ 'text-starts-with': (response, expected) => { ++ const text = getMessageText(response); ++ if (!text || (expected && !text.trim().startsWith(expected))) { ++ throw new Error( ++ `Response text does not start with '${expected}'. Text: ${text}` ++ ); ++ } ++ }, ++ 'text-not-empty': (response) => { ++ const text = getMessageText(response); ++ if (!text || text.trim().length === 0) { ++ throw new Error('Response text is empty'); ++ } ++ }, ++ 'valid-media': (response, type) => { ++ const mediaPart = getMediaPart(response); ++ if (!mediaPart) { ++ throw new Error(`Model did not return ${type || 'media'} part.`); ++ } ++ if (type) { ++ if ( ++ mediaPart.media?.contentType && ++ !mediaPart.media.contentType.startsWith(`${type}/`) ++ ) { ++ throw new Error( ++ `Expected ${type} content type, got ${mediaPart.media.contentType}` ++ ); ++ } ++ } ++ if (type === 'image') { ++ const url = mediaPart.media?.url; ++ if (!url) throw new Error('Media part missing URL'); ++ if (url.startsWith('data:')) { ++ if (!url.startsWith('data:image/')) { ++ throw new Error('Invalid data URL content type for image'); ++ } ++ } else if (url.startsWith('http')) { ++ try { ++ new URL(url); ++ } catch (e) { ++ throw new Error(`Invalid URL: ${url}`); ++ } ++ } else { ++ throw new Error(`Unknown URL format: ${url}`); ++ } ++ } ++ }, ++}; ++ ++const TEST_CASES: Record = { ++ 'tool-request': { ++ name: 'Tool Request Conformance', ++ input: { ++ messages: [ ++ { ++ role: 'user', ++ content: [{ 
text: 'What is the weather in New York? Use the tool.' }], ++ }, ++ ], ++ tools: [ ++ { ++ name: 'weather', ++ description: 'Get the weather for a city', ++ inputSchema: { ++ type: 'object', ++ properties: { ++ city: { type: 'string' }, ++ }, ++ required: ['city'], ++ }, ++ }, ++ ], ++ }, ++ validators: ['has-tool-request:weather'], ++ }, ++ 'structured-output': { ++ name: 'Structured Output Conformance', ++ input: { ++ messages: [ ++ { ++ role: 'user', ++ content: [{ text: 'Generate a profile for John Doe.' }], ++ }, ++ ], ++ output: { ++ format: 'json', ++ schema: { ++ type: 'object', ++ properties: { ++ name: { type: 'string' }, ++ age: { type: 'number' }, ++ }, ++ required: ['name', 'age'], ++ }, ++ constrained: true, ++ }, ++ }, ++ validators: ['valid-json'], ++ }, ++ multiturn: { ++ name: 'Multiturn Conformance', ++ input: { ++ messages: [ ++ { role: 'user', content: [{ text: 'My name is Genkit.' }] }, ++ { role: 'model', content: [{ text: 'Hello Genkit.' }] }, ++ { role: 'user', content: [{ text: 'What is my name?' }] }, ++ ], ++ }, ++ validators: ['text-includes:Genkit'], ++ }, ++ 'system-role': { ++ name: 'System Role Conformance', ++ input: { ++ messages: [ ++ { ++ role: 'system', ++ content: [ ++ { ++ text: "IMPORTANT: your response are machine processed, always start/prefix your response with 'RESPONSE:', ex: 'RESPONSE: hello'", ++ }, ++ ], ++ }, ++ { role: 'user', content: [{ text: 'hello' }] }, ++ ], ++ }, ++ validators: ['text-starts-with:RESPONSE:'], ++ }, ++ 'input-image-base64': { ++ name: 'Image Input (Base64) Conformance', ++ input: { ++ messages: [ ++ { ++ role: 'user', ++ content: [ ++ { text: 'What text do you see in this image?' 
}, ++ { ++ media: { ++ url: `data:image/png;base64,${imageBase64}`, ++ contentType: 'image/png', ++ }, ++ }, ++ ], ++ }, ++ ], ++ }, ++ validators: ['text-includes:genkit'], ++ }, ++ 'input-image-url': { ++ name: 'Image Input (URL) Conformance', ++ input: { ++ messages: [ ++ { ++ role: 'user', ++ content: [ ++ { text: 'What is this logo?' }, ++ { ++ media: { ++ url: 'https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png', ++ contentType: 'image/png', ++ }, ++ }, ++ ], ++ }, ++ ], ++ }, ++ validators: ['text-includes:google'], ++ }, ++ 'input-video-youtube': { ++ name: 'Video Input (YouTube) Conformance', ++ input: { ++ messages: [ ++ { ++ role: 'user', ++ content: [ ++ { text: 'Describe this video.' }, ++ { ++ media: { ++ url: 'https://www.youtube.com/watch?v=3p1P5grjXIQ', ++ contentType: 'video/mp4', ++ }, ++ }, ++ ], ++ }, ++ ], ++ }, ++ validators: ['text-not-empty'], ++ }, ++ 'output-audio': { ++ name: 'Audio Output (TTS) Conformance', ++ input: { ++ messages: [{ role: 'user', content: [{ text: 'Say hello.' }] }], ++ }, ++ validators: ['valid-media:audio'], ++ }, ++ 'output-image': { ++ name: 'Image Output (Generation) Conformance', ++ input: { ++ messages: [ ++ { ++ role: 'user', ++ content: [{ text: 'Generate an image of a cat.' }], ++ }, ++ ], ++ }, ++ validators: ['valid-media:image'], ++ }, ++}; ++ ++async function waitForRuntime(manager: RuntimeManager) { ++ // Poll for runtimes ++ for (let i = 0; i < 20; i++) { ++ if (manager.listRuntimes().length > 0) return; ++ await new Promise((r) => setTimeout(r, 500)); ++ } ++ logger.warn('Runtime not detected after 10 seconds.'); ++} ++ ++async function runTest( ++ manager: RuntimeManager, ++ model: string, ++ testCase: TestCase ++): Promise { ++ logger.info(`Running test: ${testCase.name}...`); ++ try { ++ // Adjust model name if needed (e.g. /model/ prefix) ++ const modelKey = model.startsWith('/') ? 
model : `/model/${model}`; ++ const actionResponse = await manager.runAction({ ++ key: modelKey, ++ input: testCase.input, ++ }); ++ ++ const response = GenerateResponseSchema.parse(actionResponse.result); ++ ++ for (const v of testCase.validators) { ++ const [valName, ...args] = v.split(':'); ++ const arg = args.join(':'); ++ const validator = VALIDATORS[valName]; ++ if (!validator) throw new Error(`Unknown validator: ${valName}`); ++ validator(response, arg); ++ } ++ ++ logger.info(`✅ Passed: ${testCase.name}`); ++ return true; ++ } catch (e) { ++ if (e instanceof GenkitToolsError) { ++ logger.error( ++ `❌ Failed: ${testCase.name} - ${ ++ e.data?.stack || JSON.stringify(e.data?.details) || e ++ }` ++ ); ++ } else if (e instanceof Error) { ++ logger.error(`❌ Failed: ${testCase.name} - ${e.message}`); ++ } else { ++ logger.error(`❌ Failed: ${testCase.name} - ${JSON.stringify(e)}`); ++ } ++ return false; ++ } ++} ++ ++async function runTestSuite( ++ manager: RuntimeManager, ++ suite: TestSuite, ++ defaultSupports: string[] ++): Promise<{ passed: number; failed: number }> { ++ const supports = suite.supports || (suite.tests ? 
[] : defaultSupports); ++ ++ logger.info(`Testing model: ${suite.model}`); ++ ++ const promises: Promise[] = []; ++ ++ // Built-in conformance tests ++ for (const support of supports) { ++ const testCase = TEST_CASES[support]; ++ if (testCase) { ++ promises.push(runTest(manager, suite.model, testCase)); ++ } else { ++ logger.warn(`Unknown capability: ${support}`); ++ } ++ } ++ ++ // Custom tests ++ if (suite.tests) { ++ for (const test of suite.tests) { ++ const customTestCase: TestCase = { ++ name: test.name || 'Custom Test', ++ input: test.input, ++ validators: test.validators || [], ++ }; ++ promises.push(runTest(manager, suite.model, customTestCase)); ++ } ++ } ++ ++ const results = await Promise.all(promises); ++ const passed = results.filter((r) => r).length; ++ const failed = results.filter((r) => !r).length; ++ ++ return { passed, failed }; ++} ++ ++export const devTestModel = new Command('dev:test-model') ++ .description('Test a model against the Genkit model specification') ++ .argument('[modelOrCmd]', 'Model name or command') ++ .argument('[args...]', 'Command arguments') ++ .option( ++ '--supports ', ++ 'Comma-separated list of supported capabilities (tool-request, structured-output, multiturn, system-role, input-image-base64, input-image-url, input-video-youtube, output-audio, output-image)', ++ 'tool-request,structured-output,multiturn,system-role,input-image-base64,input-image-url' ++ ) ++ .option('--from-file ', 'Path to a file containing test payloads') ++ .action( ++ async ( ++ modelOrCmd: string | undefined, ++ args: string[] | undefined, ++ options: TestOptions ++ ) => { ++ const projectRoot = await findProjectRoot(); ++ ++ let cmd: string[] = []; ++ let defaultModelName: string | undefined; ++ ++ if (options.fromFile) { ++ if (modelOrCmd) cmd.push(modelOrCmd); ++ if (args) cmd.push(...args); ++ } else { ++ if (!modelOrCmd) { ++ logger.error('Model name is required unless --from-file is used.'); ++ process.exitCode = 1; ++ return; ++ } ++ 
defaultModelName = modelOrCmd; ++ if (args) cmd = args; ++ } ++ ++ let manager: RuntimeManager; ++ ++ if (cmd.length > 0) { ++ const result = await startDevProcessManager( ++ projectRoot, ++ cmd[0], ++ cmd.slice(1) ++ ); ++ manager = result.manager; ++ } else { ++ manager = await startManager(projectRoot, false); ++ } ++ ++ await waitForRuntime(manager); ++ ++ try { ++ let totalPassed = 0; ++ let totalFailed = 0; ++ ++ let suites: TestSuite[] = []; ++ ++ if (options.fromFile) { ++ const filePath = resolve(projectRoot, options.fromFile); ++ const fileContent = readFileSync(filePath, 'utf-8'); ++ let parsed; ++ if (filePath.endsWith('.yaml') || filePath.endsWith('.yml')) { ++ parsed = parse(fileContent); ++ } else { ++ parsed = JSON.parse(fileContent); ++ } ++ suites = Array.isArray(parsed) ? parsed : [parsed]; ++ } else { ++ if (!defaultModelName) throw new Error('Model name required'); ++ suites = [{ model: defaultModelName }]; ++ } ++ ++ const defaultSupports = options.supports ++ .split(',') ++ .map((s) => s.trim()); ++ ++ for (const suite of suites) { ++ if (!suite.model) { ++ logger.error('Model name required in test suite.'); ++ totalFailed++; ++ continue; ++ } ++ const { passed, failed } = await runTestSuite( ++ manager, ++ suite, ++ defaultSupports ++ ); ++ totalPassed += passed; ++ totalFailed += failed; ++ } ++ ++ logger.info('--------------------------------------------------'); ++ logger.info( ++ `Tests Completed: ${totalPassed} Passed, ${totalFailed} Failed` ++ ); ++ ++ if (totalFailed > 0) { ++ process.exitCode = 1; ++ } ++ } catch (e) { ++ logger.error('Error running tests:', e); ++ process.exitCode = 1; ++ } finally { ++ if (manager) { ++ await manager.stop(); ++ } ++ } ++ } ++ ); +diff --git a/genkit-tools/pnpm-lock.yaml b/genkit-tools/pnpm-lock.yaml +index 015df5848..ea1b78bea 100644 +--- a/genkit-tools/pnpm-lock.yaml ++++ b/genkit-tools/pnpm-lock.yaml +@@ -68,6 +68,9 @@ importers: + semver: + specifier: ^7.7.2 + version: 7.7.2 ++ yaml: ++ 
specifier: ^2.8.0 ++ version: 2.8.0 + devDependencies: + '@jest/globals': + specifier: ^29.7.0 +diff --git a/go/ai/document_test.go b/go/ai/document_test.go +index ab1e65ab7..a5d4bc9bc 100644 +--- a/go/ai/document_test.go ++++ b/go/ai/document_test.go +@@ -141,3 +141,273 @@ func TestReasoningPartJSON(t *testing.T) { + t.Errorf("unmarshaled reasoning content type = %q, want %q", unmarshaledPart.ContentType, "plain/text") + } + } ++ ++func TestNewDataPart(t *testing.T) { ++ t.Run("creates data part with content", func(t *testing.T) { ++ p := NewDataPart("some binary data") ++ ++ if p.Kind != PartData { ++ t.Errorf("Kind = %v, want %v", p.Kind, PartData) ++ } ++ if p.Text != "some binary data" { ++ t.Errorf("Text = %q, want %q", p.Text, "some binary data") ++ } ++ }) ++ ++ t.Run("creates data part with empty content", func(t *testing.T) { ++ p := NewDataPart("") ++ ++ if p.Kind != PartData { ++ t.Errorf("Kind = %v, want %v", p.Kind, PartData) ++ } ++ if p.Text != "" { ++ t.Errorf("Text = %q, want empty string", p.Text) ++ } ++ }) ++} ++ ++func TestNewCustomPart(t *testing.T) { ++ t.Run("creates custom part with value", func(t *testing.T) { ++ custom := map[string]any{"key": "value", "count": 42} ++ p := NewCustomPart(custom) ++ ++ if p.Kind != PartCustom { ++ t.Errorf("Kind = %v, want %v", p.Kind, PartCustom) ++ } ++ if p.Custom == nil { ++ t.Fatal("Custom is nil") ++ } ++ if p.Custom["key"] != "value" { ++ t.Errorf("Custom[key] = %v, want %q", p.Custom["key"], "value") ++ } ++ }) ++ ++ t.Run("creates custom part with nil value", func(t *testing.T) { ++ p := NewCustomPart(nil) ++ ++ if p.Kind != PartCustom { ++ t.Errorf("Kind = %v, want %v", p.Kind, PartCustom) ++ } ++ if p.Custom != nil { ++ t.Errorf("Custom = %v, want nil", p.Custom) ++ } ++ }) ++} ++ ++func TestPartIsData(t *testing.T) { ++ tests := []struct { ++ name string ++ part *Part ++ want bool ++ }{ ++ {"data part", NewDataPart("{}"), true}, ++ {"text part", NewTextPart("hello"), false}, ++ {"media part", 
NewMediaPart("image/png", "data:..."), false}, ++ {"nil part", nil, false}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := tt.part.IsData() ++ if got != tt.want { ++ t.Errorf("IsData() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} ++ ++func TestPartIsInterrupt(t *testing.T) { ++ t.Run("interrupt tool request returns true", func(t *testing.T) { ++ p := &Part{ ++ Kind: PartToolRequest, ++ ToolRequest: &ToolRequest{ ++ Name: "test", ++ Input: map[string]any{}, ++ }, ++ Metadata: map[string]any{ ++ "interrupt": true, ++ }, ++ } ++ ++ if !p.IsInterrupt() { ++ t.Error("IsInterrupt() = false, want true") ++ } ++ }) ++ ++ t.Run("non-interrupt tool request returns false", func(t *testing.T) { ++ p := &Part{ ++ Kind: PartToolRequest, ++ ToolRequest: &ToolRequest{ ++ Name: "test", ++ Input: map[string]any{}, ++ }, ++ } ++ ++ if p.IsInterrupt() { ++ t.Error("IsInterrupt() = true, want false") ++ } ++ }) ++ ++ t.Run("non-tool-request part returns false", func(t *testing.T) { ++ p := NewTextPart("hello") ++ ++ if p.IsInterrupt() { ++ t.Error("IsInterrupt() = true, want false") ++ } ++ }) ++ ++ t.Run("nil part returns false", func(t *testing.T) { ++ var p *Part ++ if p.IsInterrupt() { ++ t.Error("IsInterrupt() = true, want false") ++ } ++ }) ++} ++ ++func TestPartIsCustom(t *testing.T) { ++ tests := []struct { ++ name string ++ part *Part ++ want bool ++ }{ ++ {"custom part", NewCustomPart(map[string]any{"key": "value"}), true}, ++ {"text part", NewTextPart("hello"), false}, ++ {"data part", NewDataPart("data"), false}, ++ {"nil part", nil, false}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := tt.part.IsCustom() ++ if got != tt.want { ++ t.Errorf("IsCustom() = %v, want %v", got, tt.want) ++ } ++ }) ++ } ++} ++ ++func TestIsImageContentType(t *testing.T) { ++ tests := []struct { ++ contentType string ++ want bool ++ }{ ++ {"image/png", true}, ++ {"image/jpeg", true}, ++ {"image/gif", true}, ++ 
{"image/webp", true}, ++ {"data:image/png;base64,...", true}, ++ {"video/mp4", false}, ++ {"audio/mp3", false}, ++ {"text/plain", false}, ++ {"application/json", false}, ++ {"", false}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.contentType, func(t *testing.T) { ++ got := IsImageContentType(tt.contentType) ++ if got != tt.want { ++ t.Errorf("IsImageContentType(%q) = %v, want %v", tt.contentType, got, tt.want) ++ } ++ }) ++ } ++} ++ ++func TestIsVideoContentType(t *testing.T) { ++ tests := []struct { ++ contentType string ++ want bool ++ }{ ++ {"video/mp4", true}, ++ {"video/webm", true}, ++ {"video/mpeg", true}, ++ {"data:video/mp4;base64,...", true}, ++ {"image/png", false}, ++ {"audio/mp3", false}, ++ {"text/plain", false}, ++ {"", false}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.contentType, func(t *testing.T) { ++ got := IsVideoContentType(tt.contentType) ++ if got != tt.want { ++ t.Errorf("IsVideoContentType(%q) = %v, want %v", tt.contentType, got, tt.want) ++ } ++ }) ++ } ++} ++ ++func TestIsAudioContentType(t *testing.T) { ++ tests := []struct { ++ contentType string ++ want bool ++ }{ ++ {"audio/mp3", true}, ++ {"audio/wav", true}, ++ {"audio/ogg", true}, ++ {"audio/mpeg", true}, ++ {"data:audio/mp3;base64,...", true}, ++ {"image/png", false}, ++ {"video/mp4", false}, ++ {"text/plain", false}, ++ {"", false}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.contentType, func(t *testing.T) { ++ got := IsAudioContentType(tt.contentType) ++ if got != tt.want { ++ t.Errorf("IsAudioContentType(%q) = %v, want %v", tt.contentType, got, tt.want) ++ } ++ }) ++ } ++} ++ ++func TestNewResponseForToolRequest(t *testing.T) { ++ t.Run("creates tool response for tool request part", func(t *testing.T) { ++ reqPart := NewToolRequestPart(&ToolRequest{ ++ Name: "calculator", ++ Input: map[string]any{"a": 1, "b": 2}, ++ }) ++ output := map[string]any{"result": 3} ++ ++ resp := NewResponseForToolRequest(reqPart, output) ++ ++ if resp.Kind != PartToolResponse { 
++ t.Errorf("Kind = %v, want %v", resp.Kind, PartToolResponse) ++ } ++ if resp.ToolResponse == nil { ++ t.Fatal("ToolResponse is nil") ++ } ++ if resp.ToolResponse.Name != "calculator" { ++ t.Errorf("Name = %q, want %q", resp.ToolResponse.Name, "calculator") ++ } ++ if resp.ToolResponse.Output.(map[string]any)["result"] != 3 { ++ t.Errorf("Output mismatch") ++ } ++ }) ++ ++ t.Run("preserves ref from original request", func(t *testing.T) { ++ reqPart := NewToolRequestPart(&ToolRequest{ ++ Name: "tool", ++ Ref: "request-123", ++ }) ++ ++ resp := NewResponseForToolRequest(reqPart, "output") ++ ++ if resp.ToolResponse.Ref != "request-123" { ++ t.Errorf("Ref = %q, want %q", resp.ToolResponse.Ref, "request-123") ++ } ++ }) ++ ++ t.Run("returns nil for non-tool-request part", func(t *testing.T) { ++ textPart := NewTextPart("not a tool request") ++ ++ resp := NewResponseForToolRequest(textPart, "output") ++ ++ if resp != nil { ++ t.Error("expected nil for non-tool-request part") ++ } ++ }) ++} +diff --git a/go/ai/embedder_test.go b/go/ai/embedder_test.go +new file mode 100644 +index 000000000..43404479f +--- /dev/null ++++ b/go/ai/embedder_test.go +@@ -0,0 +1,400 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package ai ++ ++import ( ++ "context" ++ "errors" ++ "testing" ++ ++ "github.com/google/go-cmp/cmp" ++) ++ ++func TestEmbedderRef(t *testing.T) { ++ t.Run("NewEmbedderRef creates ref with name and config", func(t *testing.T) { ++ config := map[string]any{"dimension": 768} ++ ref := NewEmbedderRef("test/embedder", config) ++ ++ if ref.Name() != "test/embedder" { ++ t.Errorf("Name() = %q, want %q", ref.Name(), "test/embedder") ++ } ++ if diff := cmp.Diff(config, ref.Config()); diff != "" { ++ t.Errorf("Config() mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("NewEmbedderRef with nil config", func(t *testing.T) { ++ ref := NewEmbedderRef("test/embedder", nil) ++ ++ if ref.Name() != "test/embedder" { ++ t.Errorf("Name() = %q, want %q", ref.Name(), "test/embedder") ++ } ++ if ref.Config() != nil { ++ t.Errorf("Config() = %v, want nil", ref.Config()) ++ } ++ }) ++} ++ ++func TestNewEmbedder(t *testing.T) { ++ t.Run("creates embedder with valid name", func(t *testing.T) { ++ e := NewEmbedder("test/embedder", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return &EmbedResponse{}, nil ++ }) ++ ++ if e == nil { ++ t.Fatal("expected embedder, got nil") ++ } ++ if e.Name() != "test/embedder" { ++ t.Errorf("Name() = %q, want %q", e.Name(), "test/embedder") ++ } ++ }) ++ ++ t.Run("panics with empty name", func(t *testing.T) { ++ assertPanic(t, func() { ++ NewEmbedder("", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return &EmbedResponse{}, nil ++ }) ++ }, "name is required") ++ }) ++ ++ t.Run("applies options correctly", func(t *testing.T) { ++ opts := &EmbedderOptions{ ++ Label: "Test Embedder", ++ Dimensions: 768, ++ Supports: &EmbedderSupports{ ++ Input: []string{"text", "image"}, ++ Multilingual: true, ++ }, ++ ConfigSchema: map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "temperature": map[string]any{"type": "number"}, ++ }, 
++ }, ++ } ++ ++ e := NewEmbedder("test/embedder", opts, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return &EmbedResponse{}, nil ++ }) ++ ++ if e == nil { ++ t.Fatal("expected embedder, got nil") ++ } ++ }) ++ ++ t.Run("uses defaults when options nil", func(t *testing.T) { ++ e := NewEmbedder("test/embedder", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return &EmbedResponse{}, nil ++ }) ++ ++ if e == nil { ++ t.Fatal("expected embedder, got nil") ++ } ++ }) ++} ++ ++func TestDefineEmbedder(t *testing.T) { ++ t.Run("registers and returns embedder", func(t *testing.T) { ++ r := newTestRegistry(t) ++ called := false ++ ++ e := DefineEmbedder(r, "test/defineEmbedder", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ called = true ++ return &EmbedResponse{ ++ Embeddings: []*Embedding{{Embedding: []float32{0.1, 0.2, 0.3}}}, ++ }, nil ++ }) ++ ++ if e == nil { ++ t.Fatal("expected embedder, got nil") ++ } ++ ++ // Verify it's registered by looking it up ++ found := LookupEmbedder(r, "test/defineEmbedder") ++ if found == nil { ++ t.Fatal("LookupEmbedder returned nil for registered embedder") ++ } ++ ++ // Verify the function works ++ resp, err := e.Embed(context.Background(), &EmbedRequest{ ++ Input: []*Document{DocumentFromText("test", nil)}, ++ }) ++ assertNoError(t, err) ++ if !called { ++ t.Error("embedder function was not called") ++ } ++ if len(resp.Embeddings) != 1 { ++ t.Errorf("len(Embeddings) = %d, want 1", len(resp.Embeddings)) ++ } ++ }) ++} ++ ++func TestLookupEmbedder(t *testing.T) { ++ t.Run("returns embedder when found", func(t *testing.T) { ++ r := newTestRegistry(t) ++ DefineEmbedder(r, "test/lookupEmbedder", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return &EmbedResponse{}, nil ++ }) ++ ++ e := LookupEmbedder(r, "test/lookupEmbedder") ++ if e == nil { ++ t.Error("expected embedder, got nil") ++ } ++ }) ++ ++ t.Run("returns 
nil when not found", func(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ e := LookupEmbedder(r, "nonexistent") ++ if e != nil { ++ t.Error("expected nil for non-existent embedder") ++ } ++ }) ++} ++ ++func TestEmbedderEmbed(t *testing.T) { ++ t.Run("embeds documents successfully", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var capturedReq *EmbedRequest ++ ++ e := DefineEmbedder(r, "test/embedDocuments", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ capturedReq = req ++ embeddings := make([]*Embedding, len(req.Input)) ++ for i := range req.Input { ++ embeddings[i] = &Embedding{ ++ Embedding: []float32{float32(i) * 0.1, float32(i) * 0.2, float32(i) * 0.3}, ++ } ++ } ++ return &EmbedResponse{Embeddings: embeddings}, nil ++ }) ++ ++ docs := []*Document{ ++ DocumentFromText("first document", nil), ++ DocumentFromText("second document", nil), ++ } ++ ++ resp, err := e.Embed(context.Background(), &EmbedRequest{Input: docs}) ++ assertNoError(t, err) ++ ++ if len(capturedReq.Input) != 2 { ++ t.Errorf("captured input len = %d, want 2", len(capturedReq.Input)) ++ } ++ if len(resp.Embeddings) != 2 { ++ t.Errorf("len(Embeddings) = %d, want 2", len(resp.Embeddings)) ++ } ++ }) ++ ++ t.Run("returns error on nil embedder", func(t *testing.T) { ++ var e *embedder ++ _, err := e.Embed(context.Background(), &EmbedRequest{}) ++ if err == nil { ++ t.Error("expected error for nil embedder") ++ } ++ }) ++ ++ t.Run("propagates function errors", func(t *testing.T) { ++ r := newTestRegistry(t) ++ expectedErr := errors.New("embedding failed") ++ ++ e := DefineEmbedder(r, "test/embedError", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return nil, expectedErr ++ }) ++ ++ _, err := e.Embed(context.Background(), &EmbedRequest{ ++ Input: []*Document{DocumentFromText("test", nil)}, ++ }) ++ if err == nil { ++ t.Error("expected error, got nil") ++ } ++ }) ++ ++ t.Run("passes options through request", func(t *testing.T) { 
++ r := newTestRegistry(t) ++ var capturedOpts any ++ ++ e := DefineEmbedder(r, "test/embedOpts", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ capturedOpts = req.Options ++ return &EmbedResponse{Embeddings: []*Embedding{{Embedding: []float32{0.1}}}}, nil ++ }) ++ ++ opts := map[string]any{"dimension": 768} ++ _, err := e.Embed(context.Background(), &EmbedRequest{ ++ Input: []*Document{DocumentFromText("test", nil)}, ++ Options: opts, ++ }) ++ assertNoError(t, err) ++ ++ if diff := cmp.Diff(opts, capturedOpts); diff != "" { ++ t.Errorf("Options mismatch (-want +got):\n%s", diff) ++ } ++ }) ++} ++ ++func TestEmbedFunction(t *testing.T) { ++ t.Run("embeds with embedder directly", func(t *testing.T) { ++ r := newTestRegistry(t) ++ e := DefineEmbedder(r, "test/embedFunc", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return &EmbedResponse{ ++ Embeddings: []*Embedding{{Embedding: []float32{0.1, 0.2, 0.3}}}, ++ }, nil ++ }) ++ ++ resp, err := Embed(context.Background(), r, ++ WithEmbedder(e), ++ WithTextDocs("test document"), ++ ) ++ assertNoError(t, err) ++ ++ if len(resp.Embeddings) != 1 { ++ t.Errorf("len(Embeddings) = %d, want 1", len(resp.Embeddings)) ++ } ++ }) ++ ++ t.Run("embeds with embedder ref", func(t *testing.T) { ++ r := newTestRegistry(t) ++ DefineEmbedder(r, "test/embedFuncRef", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return &EmbedResponse{ ++ Embeddings: []*Embedding{{Embedding: []float32{0.1, 0.2, 0.3}}}, ++ }, nil ++ }) ++ ++ ref := NewEmbedderRef("test/embedFuncRef", nil) ++ resp, err := Embed(context.Background(), r, ++ WithEmbedder(ref), ++ WithTextDocs("test document"), ++ ) ++ assertNoError(t, err) ++ ++ if len(resp.Embeddings) != 1 { ++ t.Errorf("len(Embeddings) = %d, want 1", len(resp.Embeddings)) ++ } ++ }) ++ ++ t.Run("embeds with embedder name", func(t *testing.T) { ++ r := newTestRegistry(t) ++ DefineEmbedder(r, "test/embedFuncName", nil, 
func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ return &EmbedResponse{ ++ Embeddings: []*Embedding{{Embedding: []float32{0.1, 0.2, 0.3}}}, ++ }, nil ++ }) ++ ++ resp, err := Embed(context.Background(), r, ++ WithEmbedderName("test/embedFuncName"), ++ WithTextDocs("test document"), ++ ) ++ assertNoError(t, err) ++ ++ if len(resp.Embeddings) != 1 { ++ t.Errorf("len(Embeddings) = %d, want 1", len(resp.Embeddings)) ++ } ++ }) ++ ++ t.Run("uses config from EmbedderRef", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var capturedOpts any ++ ++ DefineEmbedder(r, "test/embedRefConfig", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ capturedOpts = req.Options ++ return &EmbedResponse{Embeddings: []*Embedding{{Embedding: []float32{0.1}}}}, nil ++ }) ++ ++ config := map[string]any{"dimension": 768} ++ ref := NewEmbedderRef("test/embedRefConfig", config) ++ ++ _, err := Embed(context.Background(), r, ++ WithEmbedder(ref), ++ WithTextDocs("test"), ++ ) ++ assertNoError(t, err) ++ ++ if diff := cmp.Diff(config, capturedOpts); diff != "" { ++ t.Errorf("Options mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("explicit config overrides EmbedderRef config", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var capturedOpts any ++ ++ DefineEmbedder(r, "test/embedOverrideConfig", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ capturedOpts = req.Options ++ return &EmbedResponse{Embeddings: []*Embedding{{Embedding: []float32{0.1}}}}, nil ++ }) ++ ++ refConfig := map[string]any{"dimension": 768} ++ explicitConfig := map[string]any{"dimension": 512} ++ ref := NewEmbedderRef("test/embedOverrideConfig", refConfig) ++ ++ _, err := Embed(context.Background(), r, ++ WithEmbedder(ref), ++ WithConfig(explicitConfig), ++ WithTextDocs("test"), ++ ) ++ assertNoError(t, err) ++ ++ if diff := cmp.Diff(explicitConfig, capturedOpts); diff != "" { ++ t.Errorf("Options mismatch (-want +got):\n%s", 
diff) ++ } ++ }) ++ ++ t.Run("returns error when embedder not set", func(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ _, err := Embed(context.Background(), r, ++ WithTextDocs("test"), ++ ) ++ assertError(t, err, "embedder must be set") ++ }) ++ ++ t.Run("returns error when embedder not found", func(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ _, err := Embed(context.Background(), r, ++ WithEmbedderName("nonexistent"), ++ WithTextDocs("test"), ++ ) ++ assertError(t, err, "embedder not found") ++ }) ++ ++ t.Run("embeds with document options", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var capturedDocs []*Document ++ ++ DefineEmbedder(r, "test/embedDocs", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ capturedDocs = req.Input ++ embeddings := make([]*Embedding, len(req.Input)) ++ for i := range req.Input { ++ embeddings[i] = &Embedding{Embedding: []float32{0.1}} ++ } ++ return &EmbedResponse{Embeddings: embeddings}, nil ++ }) ++ ++ doc := DocumentFromText("custom document", map[string]any{"custom": "metadata"}) ++ _, err := Embed(context.Background(), r, ++ WithEmbedderName("test/embedDocs"), ++ WithDocs(doc), ++ ) ++ assertNoError(t, err) ++ ++ if len(capturedDocs) != 1 { ++ t.Fatalf("len(docs) = %d, want 1", len(capturedDocs)) ++ } ++ if capturedDocs[0].Metadata["custom"] != "metadata" { ++ t.Error("document metadata not passed correctly") ++ } ++ }) ++} +diff --git a/go/ai/evaluator_test.go b/go/ai/evaluator_test.go +index 9ee268d3d..6cf5c5895 100644 +--- a/go/ai/evaluator_test.go ++++ b/go/ai/evaluator_test.go +@@ -207,3 +207,161 @@ func TestBatchEvaluator(t *testing.T) { + t.Errorf("got %v, want %v", got, want) + } + } ++ ++func TestNewEvaluatorRef(t *testing.T) { ++ t.Run("creates evaluator reference with name and config", func(t *testing.T) { ++ config := map[string]any{"threshold": 0.8} ++ ref := NewEvaluatorRef("test/myEvaluator", config) ++ ++ if ref.Name() != "test/myEvaluator" { ++ t.Errorf("Name() = %q, want 
%q", ref.Name(), "test/myEvaluator") ++ } ++ if ref.Config() == nil { ++ t.Error("Config() = nil, want config") ++ } ++ if ref.Config().(map[string]any)["threshold"] != 0.8 { ++ t.Errorf("Config()[threshold] = %v, want 0.8", ref.Config().(map[string]any)["threshold"]) ++ } ++ }) ++ ++ t.Run("creates evaluator reference with nil config", func(t *testing.T) { ++ ref := NewEvaluatorRef("test/simpleEvaluator", nil) ++ ++ if ref.Name() != "test/simpleEvaluator" { ++ t.Errorf("Name() = %q, want %q", ref.Name(), "test/simpleEvaluator") ++ } ++ if ref.Config() != nil { ++ t.Errorf("Config() = %v, want nil", ref.Config()) ++ } ++ }) ++ ++ t.Run("implements EvaluatorArg interface", func(t *testing.T) { ++ ref := NewEvaluatorRef("test/interface", nil) ++ var _ EvaluatorArg = ref // compile-time check ++ ++ if ref.Name() != "test/interface" { ++ t.Errorf("Name() = %q, want %q", ref.Name(), "test/interface") ++ } ++ }) ++} ++ ++func TestEvaluatorRefUsedWithEvaluate(t *testing.T) { ++ r := registry.New() ++ ++ // Define evaluator that uses config ++ DefineEvaluator(r, "test/configEvaluator", &evalOpts, func(ctx context.Context, req *EvaluatorCallbackRequest) (*EvaluatorCallbackResponse, error) { ++ score := Score{ ++ Id: "configScore", ++ Score: 1, ++ Status: ScoreStatusPass.String(), ++ Details: map[string]any{"options": req.Options}, ++ } ++ return &EvaluatorCallbackResponse{ ++ TestCaseId: req.Input.TestCaseId, ++ Evaluation: []Score{score}, ++ }, nil ++ }) ++ ++ // Use EvaluatorRef instead of direct evaluator ++ ref := NewEvaluatorRef("test/configEvaluator", "ref-config-value") ++ ++ resp, err := Evaluate(context.Background(), r, ++ WithEvaluator(ref), ++ WithDataset(&Example{Input: "test"}), ++ WithID("testrun")) ++ if err != nil { ++ t.Fatal(err) ++ } ++ ++ // Config from ref should be used since no explicit config was provided ++ if got, want := (*resp)[0].Evaluation[0].Details["options"], "ref-config-value"; got != want { ++ t.Errorf("got config %v, want %v", got, want) 
++ } ++} ++ ++func TestScoreStatusString(t *testing.T) { ++ tests := []struct { ++ status ScoreStatus ++ want string ++ }{ ++ {ScoreStatusUnknown, "UNKNOWN"}, ++ {ScoreStatusFail, "FAIL"}, ++ {ScoreStatusPass, "PASS"}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.want, func(t *testing.T) { ++ got := tt.status.String() ++ if got != tt.want { ++ t.Errorf("String() = %q, want %q", got, tt.want) ++ } ++ }) ++ } ++} ++ ++func TestNewEvaluator(t *testing.T) { ++ t.Run("panics with empty name", func(t *testing.T) { ++ defer func() { ++ if r := recover(); r == nil { ++ t.Error("expected panic for empty name") ++ } ++ }() ++ ++ NewEvaluator("", &evalOpts, testEvalFunc) ++ }) ++ ++ t.Run("creates evaluator with nil options", func(t *testing.T) { ++ eval := NewEvaluator("test/nilOpts", nil, testEvalFunc) ++ if eval == nil { ++ t.Error("NewEvaluator returned nil") ++ } ++ if eval.Name() != "test/nilOpts" { ++ t.Errorf("Name() = %q, want %q", eval.Name(), "test/nilOpts") ++ } ++ }) ++} ++ ++func TestNewBatchEvaluator(t *testing.T) { ++ t.Run("panics with empty name", func(t *testing.T) { ++ defer func() { ++ if r := recover(); r == nil { ++ t.Error("expected panic for empty name") ++ } ++ }() ++ ++ NewBatchEvaluator("", &evalOpts, testBatchEvalFunc) ++ }) ++ ++ t.Run("creates batch evaluator with nil options", func(t *testing.T) { ++ eval := NewBatchEvaluator("test/batchNilOpts", nil, testBatchEvalFunc) ++ if eval == nil { ++ t.Error("NewBatchEvaluator returned nil") ++ } ++ }) ++} ++ ++func TestEvaluateNilEvaluator(t *testing.T) { ++ t.Run("returns error when evaluator not set", func(t *testing.T) { ++ r := registry.New() ++ ++ _, err := Evaluate(context.Background(), r, ++ WithDataset(&Example{Input: "test"})) ++ ++ if err == nil { ++ t.Error("expected error when evaluator not set, got nil") ++ } ++ }) ++ ++ t.Run("returns error for non-existent evaluator", func(t *testing.T) { ++ r := registry.New() ++ ++ ref := NewEvaluatorRef("test/nonexistent", nil) ++ _, err := 
Evaluate(context.Background(), r, ++ WithEvaluator(ref), ++ WithDataset(&Example{Input: "test"})) ++ ++ if err == nil { ++ t.Error("expected error for non-existent evaluator, got nil") ++ } ++ }) ++} +diff --git a/go/ai/example_test.go b/go/ai/example_test.go +new file mode 100644 +index 000000000..afcdbbc6b +--- /dev/null ++++ b/go/ai/example_test.go +@@ -0,0 +1,160 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++// Package ai_test provides examples for ai package helper functions. ++// ++// The ai package contains helper types and functions used with genkit. ++// Most generation and definition functions are in the genkit package; ++// see that package for the primary API documentation. ++package ai_test ++ ++import ( ++ "fmt" ++ ++ "github.com/firebase/genkit/go/ai" ++) ++ ++// This example demonstrates creating different types of message parts. ++func ExampleNewTextPart() { ++ // Create a text part ++ part := ai.NewTextPart("Hello, world!") ++ fmt.Println(part.Text) ++ // Output: Hello, world! ++} ++ ++// This example demonstrates creating a message with text content. ++func ExampleNewUserTextMessage() { ++ // Create a user message with text ++ msg := ai.NewUserTextMessage("What is the capital of France?") ++ fmt.Println("Role:", msg.Role) ++ fmt.Println("Text:", msg.Content[0].Text) ++ // Output: ++ // Role: user ++ // Text: What is the capital of France? 
++} ++ ++// This example demonstrates creating system and model messages. ++func ExampleNewSystemTextMessage() { ++ // Create a system message ++ sysMsg := ai.NewSystemTextMessage("You are a helpful assistant.") ++ fmt.Println("System role:", sysMsg.Role) ++ ++ // Create a model response message ++ modelMsg := ai.NewModelTextMessage("I'm here to help!") ++ fmt.Println("Model role:", modelMsg.Role) ++ // Output: ++ // System role: system ++ // Model role: model ++} ++ ++// This example demonstrates creating a data part for raw string content. ++func ExampleNewDataPart() { ++ // Create a data part with raw string content ++ part := ai.NewDataPart(`{"name": "Alice", "age": 30}`) ++ fmt.Println("Is data part:", part.IsData()) ++ fmt.Println("Content:", part.Text) ++ // Output: ++ // Is data part: true ++ // Content: {"name": "Alice", "age": 30} ++} ++ ++// This example demonstrates accessing text from a Part. ++func ExamplePart_Text() { ++ // Create a part with text ++ part := ai.NewTextPart("Sample text content") ++ ++ // Access the text field directly ++ fmt.Println(part.Text) ++ // Output: Sample text content ++} ++ ++// This example demonstrates the Document type used in RAG applications. ++func ExampleDocument() { ++ // Create a document with text content ++ doc := &ai.Document{ ++ Content: []*ai.Part{ ++ ai.NewTextPart("This is the document content."), ++ }, ++ Metadata: map[string]any{ ++ "source": "knowledge-base", ++ "page": 42, ++ }, ++ } ++ ++ fmt.Println("Content:", doc.Content[0].Text) ++ fmt.Println("Source:", doc.Metadata["source"]) ++ // Output: ++ // Content: This is the document content. ++ // Source: knowledge-base ++} ++ ++// This example demonstrates creating an Embedding for vector search. 
++func ExampleEmbedding() { ++ // Create an embedding (typically returned by an embedder) ++ embedding := &ai.Embedding{ ++ Embedding: []float32{0.1, 0.2, 0.3, 0.4, 0.5}, ++ Metadata: map[string]any{ ++ "source": "document-1", ++ }, ++ } ++ ++ fmt.Printf("Embedding dimensions: %d\n", len(embedding.Embedding)) ++ fmt.Printf("First value: %.1f\n", embedding.Embedding[0]) ++ // Output: ++ // Embedding dimensions: 5 ++ // First value: 0.1 ++} ++ ++// This example demonstrates creating a media part for images or other media. ++func ExampleNewMediaPart() { ++ // Create a media part with base64-encoded image data ++ // In practice, you would encode actual image bytes ++ imageData := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ..." ++ part := ai.NewMediaPart("image/png", imageData) ++ ++ fmt.Println("Is media:", part.IsMedia()) ++ fmt.Println("Content type:", part.ContentType) ++ // Output: ++ // Is media: true ++ // Content type: image/png ++} ++ ++// This example demonstrates creating a model reference with configuration. ++func ExampleNewModelRef() { ++ // Create a reference to a model with custom configuration ++ // The config type depends on the model provider ++ modelRef := ai.NewModelRef("googleai/gemini-2.5-flash", map[string]any{ ++ "temperature": 0.7, ++ }) ++ ++ fmt.Println("Model name:", modelRef.Name()) ++ // Output: Model name: googleai/gemini-2.5-flash ++} ++ ++// This example demonstrates building a multi-turn conversation. 
++func ExampleNewUserMessage() { ++ // Build a conversation with multiple parts ++ userMsg := ai.NewUserMessage( ++ ai.NewTextPart("What's in this image?"), ++ ai.NewMediaPart("image/jpeg", "base64data..."), ++ ) ++ ++ fmt.Println("Role:", userMsg.Role) ++ fmt.Println("Parts:", len(userMsg.Content)) ++ // Output: ++ // Role: user ++ // Parts: 2 ++} +diff --git a/go/ai/formatter_test.go b/go/ai/formatter_test.go +index 5f75d77ac..e52162cfd 100644 +--- a/go/ai/formatter_test.go ++++ b/go/ai/formatter_test.go +@@ -1031,3 +1031,72 @@ func TestDefaultFormats(t *testing.T) { + }) + } + } ++ ++func TestArrayFormatterParseMessage(t *testing.T) { ++ schema := map[string]any{ ++ "type": "array", ++ "items": map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "id": map[string]any{"type": "integer"}, ++ }, ++ }, ++ } ++ ++ t.Run("returns message unchanged", func(t *testing.T) { ++ handler, err := arrayFormatter{}.Handler(schema) ++ if err != nil { ++ t.Fatalf("Handler() error = %v", err) ++ } ++ ++ msg := &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewTextPart(`[{"id": 1}, {"id": 2}]`)}, ++ } ++ ++ got, err := handler.ParseMessage(msg) ++ if err != nil { ++ t.Fatalf("ParseMessage() error = %v", err) ++ } ++ ++ // Array formatter's ParseMessage returns the message unchanged ++ if got != msg { ++ t.Error("ParseMessage() should return the same message object") ++ } ++ }) ++} ++ ++func TestJSONLFormatterParseMessage(t *testing.T) { ++ schema := map[string]any{ ++ "type": "array", ++ "items": map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "id": map[string]any{"type": "integer"}, ++ "name": map[string]any{"type": "string"}, ++ }, ++ }, ++ } ++ ++ t.Run("returns message unchanged", func(t *testing.T) { ++ handler, err := jsonlFormatter{}.Handler(schema) ++ if err != nil { ++ t.Fatalf("Handler() error = %v", err) ++ } ++ ++ msg := &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewTextPart("{\"id\": 1, \"name\": 
\"Alice\"}\n{\"id\": 2, \"name\": \"Bob\"}")}, ++ } ++ ++ got, err := handler.ParseMessage(msg) ++ if err != nil { ++ t.Fatalf("ParseMessage() error = %v", err) ++ } ++ ++ // JSONL formatter's ParseMessage returns the message unchanged ++ if got != msg { ++ t.Error("ParseMessage() should return the same message object") ++ } ++ }) ++} +diff --git a/go/ai/gen.go b/go/ai/gen.go +index cbfe5cd9c..e391ef221 100644 +--- a/go/ai/gen.go ++++ b/go/ai/gen.go +@@ -18,97 +18,45 @@ + + package ai + +-type BaseDataPoint struct { +- Context map[string]any `json:"context,omitempty"` +- Input map[string]any `json:"input,omitempty"` +- Output map[string]any `json:"output,omitempty"` +- Reference map[string]any `json:"reference,omitempty"` +- TestCaseID string `json:"testCaseId,omitempty"` +- TraceIDs []string `json:"traceIds,omitempty"` +-} +- +-type BaseEvalDataPoint struct { +- Context map[string]any `json:"context,omitempty"` +- Input map[string]any `json:"input,omitempty"` +- Output map[string]any `json:"output,omitempty"` +- Reference map[string]any `json:"reference,omitempty"` +- TestCaseID string `json:"testCaseId,omitempty"` +- TraceIDs []string `json:"traceIds,omitempty"` +-} +- +-type CandidateError struct { +- Code CandidateErrorCode `json:"code,omitempty"` +- Index float64 `json:"index,omitempty"` +- Message string `json:"message,omitempty"` +-} +- +-type CandidateErrorCode string +- +-const ( +- CandidateErrorCodeBlocked CandidateErrorCode = "blocked" +- CandidateErrorCodeOther CandidateErrorCode = "other" +- CandidateErrorCodeUnknown CandidateErrorCode = "unknown" +-) +- +-type CommonRerankerOptions struct { +- // Number of documents to rerank +- K float64 `json:"k,omitempty"` +-} +- +-type CommonRetrieverOptions struct { +- // Number of documents to retrieve +- K float64 `json:"k,omitempty"` +-} +- + type customPart struct { +- Custom map[string]any `json:"custom,omitempty"` +- Data any `json:"data,omitempty"` ++ // Custom contains custom key-value data specific to 
this part. ++ Custom map[string]any `json:"custom,omitempty"` ++ // Data contains additional arbitrary data. ++ Data any `json:"data,omitempty"` ++ // Metadata contains arbitrary key-value data for this part. + Metadata map[string]any `json:"metadata,omitempty"` + } + + type dataPart struct { +- Data any `json:"data,omitempty"` ++ // Data contains arbitrary structured data. ++ Data any `json:"data,omitempty"` ++ // Metadata contains arbitrary key-value data for this part. + Metadata map[string]any `json:"metadata,omitempty"` + } + ++// EmbedRequest represents a request to generate embeddings for documents. + type EmbedRequest struct { +- Input []*Document `json:"input,omitempty"` +- Options any `json:"options,omitempty"` ++ // Input is the array of documents to generate embeddings for. ++ Input []*Document `json:"input,omitempty"` ++ // Options contains embedder-specific configuration parameters. ++ Options any `json:"options,omitempty"` + } + ++// EmbedResponse contains the generated embeddings from an embed request. + type EmbedResponse struct { ++ // Embeddings is the array of generated embedding vectors with metadata. + Embeddings []*Embedding `json:"embeddings,omitempty"` + } + ++// Embedding represents a vector embedding with associated metadata. + type Embedding struct { +- Embedding []float32 `json:"embedding,omitempty"` +- Metadata map[string]any `json:"metadata,omitempty"` +-} +- +-type EvalFnResponse struct { +- Evaluation any `json:"evaluation,omitempty"` +- SampleIndex float64 `json:"sampleIndex,omitempty"` +- SpanID string `json:"spanId,omitempty"` +- TestCaseID string `json:"testCaseId,omitempty"` +- TraceID string `json:"traceId,omitempty"` +-} +- +-type EvalRequest struct { +- Dataset []*BaseDataPoint `json:"dataset,omitempty"` +- EvalRunID string `json:"evalRunId,omitempty"` +- Options any `json:"options,omitempty"` ++ // Embedding is the vector representation of the input. 
++ Embedding []float32 `json:"embedding,omitempty"` ++ // Metadata identifies which part of a document this embedding corresponds to. ++ Metadata map[string]any `json:"metadata,omitempty"` + } + +-type EvalResponse []any +- +-type EvalStatusEnum string +- +-const ( +- EvalStatusEnumUNKNOWN EvalStatusEnum = "UNKNOWN" +- EvalStatusEnumPASS EvalStatusEnum = "PASS" +- EvalStatusEnumFAIL EvalStatusEnum = "FAIL" +-) +- ++// FinishReason indicates why generation stopped. + type FinishReason string + + const ( +@@ -120,26 +68,47 @@ const ( + FinishReasonUnknown FinishReason = "unknown" + ) + ++// GenerateActionOptions holds configuration for a generate action request. + type GenerateActionOptions struct { +- Config any `json:"config,omitempty"` +- Docs []*Document `json:"docs,omitempty"` +- MaxTurns int `json:"maxTurns,omitempty"` +- Messages []*Message `json:"messages,omitempty"` +- Model string `json:"model,omitempty"` +- Output *GenerateActionOutputConfig `json:"output,omitempty"` +- Resume *GenerateActionResume `json:"resume,omitempty"` +- ReturnToolRequests bool `json:"returnToolRequests,omitempty"` +- StepName string `json:"stepName,omitempty"` +- ToolChoice ToolChoice `json:"toolChoice,omitempty"` +- Tools []string `json:"tools,omitempty"` +-} +- ++ // Config contains configuration parameters for the generation request. ++ Config any `json:"config,omitempty"` ++ // Docs provides retrieved documents to be used as context for this generation. ++ Docs []*Document `json:"docs,omitempty"` ++ // MaxTurns is the maximum number of tool call iterations that can be performed ++ // in a single generate call. Defaults to 5. ++ MaxTurns int `json:"maxTurns,omitempty"` ++ // Messages contains the conversation history for multi-turn prompting when supported. ++ Messages []*Message `json:"messages,omitempty"` ++ // Model is a model name (e.g., "vertexai/gemini-1.0-pro"). ++ Model string `json:"model,omitempty"` ++ // Output specifies the desired output format. 
Defaults to the model's default if unspecified. ++ Output *GenerateActionOutputConfig `json:"output,omitempty"` ++ // Resume provides options for resuming an interrupted generation. ++ Resume *GenerateActionResume `json:"resume,omitempty"` ++ // ReturnToolRequests, when true, returns tool calls for manual processing instead of ++ // automatically resolving them. ++ ReturnToolRequests bool `json:"returnToolRequests,omitempty"` ++ // StepName is a custom step name for this generate call to display in trace views. ++ // Defaults to "generate". ++ StepName string `json:"stepName,omitempty"` ++ // ToolChoice controls tool calling mode. Auto lets the model decide, required forces ++ // the model to choose a tool, and none forces the model not to use any tools. Defaults to auto. ++ ToolChoice ToolChoice `json:"toolChoice,omitempty"` ++ // Tools is a list of registered tool names for this generation if supported. ++ Tools []string `json:"tools,omitempty"` ++} ++ ++// GenerateActionResume holds options for resuming an interrupted generation. + type GenerateActionResume struct { +- Metadata map[string]any `json:"metadata,omitempty"` +- Respond []*toolResponsePart `json:"respond,omitempty"` +- Restart []*toolRequestPart `json:"restart,omitempty"` ++ // Metadata contains additional context for resuming the generation. ++ Metadata map[string]any `json:"metadata,omitempty"` ++ // Respond contains tool response parts to send to the model when resuming. ++ Respond []*toolResponsePart `json:"respond,omitempty"` ++ // Restart contains tool request parts to restart when resuming. ++ Restart []*toolRequestPart `json:"restart,omitempty"` + } + ++// ToolChoice controls how the model uses tools. + type ToolChoice string + + const ( +@@ -148,67 +117,113 @@ const ( + ToolChoiceNone ToolChoice = "none" + ) + ++// GenerateActionOutputConfig specifies the desired output format for a generate action. 
+ type GenerateActionOutputConfig struct { +- Constrained bool `json:"constrained,omitempty"` +- ContentType string `json:"contentType,omitempty"` +- Format string `json:"format,omitempty"` +- Instructions *string `json:"instructions,omitempty"` +- JsonSchema map[string]any `json:"jsonSchema,omitempty"` ++ // Constrained indicates whether to enforce strict adherence to the schema. ++ Constrained bool `json:"constrained,omitempty"` ++ // ContentType specifies the MIME type of the output content. ++ ContentType string `json:"contentType,omitempty"` ++ // Format specifies the desired output format (e.g., "json", "text"). ++ Format string `json:"format,omitempty"` ++ // Instructions provides additional guidance for the output format. ++ Instructions *string `json:"instructions,omitempty"` ++ // JsonSchema is a JSON Schema describing the desired structure of JSON output. ++ JsonSchema map[string]any `json:"jsonSchema,omitempty"` + } + +-// GenerationCommonConfig holds configuration for generation. ++// GenerationCommonConfig holds configuration parameters for model generation requests. + type GenerationCommonConfig struct { +- MaxOutputTokens int `json:"maxOutputTokens,omitempty"` +- StopSequences []string `json:"stopSequences,omitempty"` +- Temperature float64 `json:"temperature,omitempty"` +- TopK int `json:"topK,omitempty"` +- TopP float64 `json:"topP,omitempty"` +- Version string `json:"version,omitempty"` +-} +- +-// GenerationUsage provides information about the generation process. ++ // MaxOutputTokens limits the maximum number of tokens generated in the response. ++ MaxOutputTokens int `json:"maxOutputTokens,omitempty"` ++ // StopSequences specifies sequences that will cause generation to stop when encountered. ++ StopSequences []string `json:"stopSequences,omitempty"` ++ // Temperature controls randomness in generation. Higher values (e.g., 0.9) make output more random, ++ // while lower values (e.g., 0.1) make it more deterministic. 
Typical range is 0.0 to 1.0. ++ Temperature float64 `json:"temperature,omitempty"` ++ // TopK limits sampling to the K most likely tokens at each step. ++ TopK int `json:"topK,omitempty"` ++ // TopP (nucleus sampling) limits sampling to tokens whose cumulative probability exceeds P. ++ TopP float64 `json:"topP,omitempty"` ++ // Version specifies a particular version of a model family, ++ // e.g., "gemini-1.0-pro-001" for the "gemini-1.0-pro" family. ++ Version string `json:"version,omitempty"` ++} ++ ++// GenerationUsage provides information about resource consumption during generation. + type GenerationUsage struct { +- CachedContentTokens int `json:"cachedContentTokens,omitempty"` +- Custom map[string]float64 `json:"custom,omitempty"` +- InputAudioFiles int `json:"inputAudioFiles,omitempty"` +- InputCharacters int `json:"inputCharacters,omitempty"` +- InputImages int `json:"inputImages,omitempty"` +- InputTokens int `json:"inputTokens,omitempty"` +- InputVideos int `json:"inputVideos,omitempty"` +- OutputAudioFiles int `json:"outputAudioFiles,omitempty"` +- OutputCharacters int `json:"outputCharacters,omitempty"` +- OutputImages int `json:"outputImages,omitempty"` +- OutputTokens int `json:"outputTokens,omitempty"` +- OutputVideos int `json:"outputVideos,omitempty"` +- ThoughtsTokens int `json:"thoughtsTokens,omitempty"` +- TotalTokens int `json:"totalTokens,omitempty"` +-} +- ++ // CachedContentTokens counts tokens that were served from cache. ++ CachedContentTokens int `json:"cachedContentTokens,omitempty"` ++ // Custom contains additional usage metrics specific to the model provider. ++ Custom map[string]float64 `json:"custom,omitempty"` ++ // InputAudioFiles is the number of audio files in the input. ++ InputAudioFiles int `json:"inputAudioFiles,omitempty"` ++ // InputCharacters is the number of characters in the input. ++ InputCharacters int `json:"inputCharacters,omitempty"` ++ // InputImages is the number of images in the input. 
++ InputImages int `json:"inputImages,omitempty"` ++ // InputTokens is the number of tokens in the input prompt. ++ InputTokens int `json:"inputTokens,omitempty"` ++ // InputVideos is the number of videos in the input. ++ InputVideos int `json:"inputVideos,omitempty"` ++ // OutputAudioFiles is the number of audio files generated in the output. ++ OutputAudioFiles int `json:"outputAudioFiles,omitempty"` ++ // OutputCharacters is the number of characters generated in the output. ++ OutputCharacters int `json:"outputCharacters,omitempty"` ++ // OutputImages is the number of images generated in the output. ++ OutputImages int `json:"outputImages,omitempty"` ++ // OutputTokens is the number of tokens generated in the response. ++ OutputTokens int `json:"outputTokens,omitempty"` ++ // OutputVideos is the number of videos generated in the output. ++ OutputVideos int `json:"outputVideos,omitempty"` ++ // ThoughtsTokens counts tokens used in reasoning or thinking processes. ++ ThoughtsTokens int `json:"thoughtsTokens,omitempty"` ++ // TotalTokens is the sum of input and output tokens. ++ TotalTokens int `json:"totalTokens,omitempty"` ++} ++ ++// Media represents media content with a URL and content type. + type Media struct { ++ // ContentType specifies the MIME type of the media. Inferred from the data URI if not provided. + ContentType string `json:"contentType,omitempty"` +- Url string `json:"url,omitempty"` ++ // Url is a "data:" or "https:" URI containing the media content. ++ Url string `json:"url,omitempty"` + } + + type mediaPart struct { +- Media *Media `json:"media,omitempty"` ++ // Media contains the media content and metadata. ++ Media *Media `json:"media,omitempty"` ++ // Metadata contains arbitrary key-value data for this part. + Metadata map[string]any `json:"metadata,omitempty"` + } + +-// Message is the contents of a model response. ++// Message represents the contents of a model message in a conversation. 
+ type Message struct { +- Content []*Part `json:"content,omitempty"` ++ // Content holds the message parts (text, media, tool calls, etc.). ++ Content []*Part `json:"content,omitempty"` ++ // Metadata contains arbitrary key-value data associated with this message. + Metadata map[string]any `json:"metadata,omitempty"` +- Role Role `json:"role,omitempty"` ++ // Role indicates which entity (system, user, model, or tool) generated this message. ++ Role Role `json:"role,omitempty"` + } + ++// ModelInfo contains metadata about a model's capabilities and characteristics. + type ModelInfo struct { ++ // ConfigSchema defines the model-specific configuration schema. + ConfigSchema map[string]any `json:"configSchema,omitempty"` +- Label string `json:"label,omitempty"` +- Stage ModelStage `json:"stage,omitempty"` +- Supports *ModelSupports `json:"supports,omitempty"` +- Versions []string `json:"versions,omitempty"` +-} +- ++ // Label is a friendly display name for this model (e.g., "Google AI - Gemini Pro"). ++ Label string `json:"label,omitempty"` ++ // Stage indicates the development stage of this model. ++ // Featured models are recommended for general use, stable models are well-tested, ++ // unstable models are experimental, legacy models are not recommended for new projects, ++ // and deprecated models may be removed in future versions. ++ Stage ModelStage `json:"stage,omitempty"` ++ // Supports describes the capabilities that this model supports. ++ Supports *ModelSupports `json:"supports,omitempty"` ++ // Versions lists acceptable names for this model (e.g., different versions). ++ Versions []string `json:"versions,omitempty"` ++} ++ ++// ModelStage indicates the development stage of a model. + type ModelStage string + + const ( +@@ -219,19 +234,31 @@ const ( + ModelStageDeprecated ModelStage = "deprecated" + ) + ++// ModelSupports describes the capabilities that a model supports. 
+ type ModelSupports struct { ++ // Constrained indicates the level of constrained generation support (none, all, or no-tools). + Constrained ConstrainedSupport `json:"constrained,omitempty"` +- ContentType []string `json:"contentType,omitempty"` +- Context bool `json:"context,omitempty"` +- LongRunning bool `json:"longRunning,omitempty"` +- Media bool `json:"media,omitempty"` +- Multiturn bool `json:"multiturn,omitempty"` +- Output []string `json:"output,omitempty"` +- SystemRole bool `json:"systemRole,omitempty"` +- ToolChoice bool `json:"toolChoice,omitempty"` +- Tools bool `json:"tools,omitempty"` +-} +- ++ // ContentType lists the content types the model supports for output. ++ ContentType []string `json:"contentType,omitempty"` ++ // Context indicates whether the model can natively support document-based context grounding. ++ Context bool `json:"context,omitempty"` ++ // LongRunning indicates whether the model supports long-running operations. ++ LongRunning bool `json:"longRunning,omitempty"` ++ // Media indicates whether the model can process media as part of the prompt (multimodal input). ++ Media bool `json:"media,omitempty"` ++ // Multiturn indicates whether the model can process historical messages passed with a prompt. ++ Multiturn bool `json:"multiturn,omitempty"` ++ // Output lists the types of data the model can generate. ++ Output []string `json:"output,omitempty"` ++ // SystemRole indicates whether the model can accept messages with role "system". ++ SystemRole bool `json:"systemRole,omitempty"` ++ // ToolChoice indicates whether the model supports controlling tool choice (e.g., forced tool calling). ++ ToolChoice bool `json:"toolChoice,omitempty"` ++ // Tools indicates whether the model can perform tool calls. ++ Tools bool `json:"tools,omitempty"` ++} ++ ++// ConstrainedSupport indicates the level of constrained generation support. 
+ type ConstrainedSupport string + + const ( +@@ -242,118 +269,176 @@ const ( + + // A ModelRequest is a request to generate completions from a model. + type ModelRequest struct { +- Config any `json:"config,omitempty"` +- Docs []*Document `json:"docs,omitempty"` +- Messages []*Message `json:"messages,omitempty"` ++ // Config holds model-specific configuration parameters. ++ Config any `json:"config,omitempty"` ++ // Docs provides retrieved documents to be used as context for this generation. ++ Docs []*Document `json:"docs,omitempty"` ++ // Messages contains the conversation history for the model. ++ Messages []*Message `json:"messages,omitempty"` + // Output describes the desired response format. +- Output *ModelOutputConfig `json:"output,omitempty"` +- ToolChoice ToolChoice `json:"toolChoice,omitempty"` ++ Output *ModelOutputConfig `json:"output,omitempty"` ++ // ToolChoice controls how the model uses tools (auto, required, or none). ++ ToolChoice ToolChoice `json:"toolChoice,omitempty"` + // Tools lists the available tools that the model can ask the client to run. + Tools []*ToolDefinition `json:"tools,omitempty"` + } + +-// A ModelResponse is a model's response to a [ModelRequest]. ++// A ModelResponse is a model's response to a ModelRequest. + type ModelResponse struct { +- Custom any `json:"custom,omitempty"` +- FinishMessage string `json:"finishMessage,omitempty"` +- FinishReason FinishReason `json:"finishReason,omitempty"` ++ // Custom contains model-specific extra information. Deprecated: use Raw instead. ++ Custom any `json:"custom,omitempty"` ++ // FinishMessage provides additional details about why generation finished. ++ FinishMessage string `json:"finishMessage,omitempty"` ++ // FinishReason indicates why generation stopped (e.g., stop, length, blocked). ++ FinishReason FinishReason `json:"finishReason,omitempty"` + // LatencyMs is the time the request took in milliseconds. 
+- LatencyMs float64 `json:"latencyMs,omitempty"` +- Message *Message `json:"message,omitempty"` ++ LatencyMs float64 `json:"latencyMs,omitempty"` ++ // Message contains the generated response content. ++ Message *Message `json:"message,omitempty"` ++ // Operation provides information about a long-running background task if applicable. + Operation *Operation `json:"operation,omitempty"` +- Raw any `json:"raw,omitempty"` +- // Request is the [ModelRequest] struct used to trigger this response. ++ // Raw contains the unprocessed model-specific response data. ++ Raw any `json:"raw,omitempty"` ++ // Request is the ModelRequest struct used to trigger this response. + Request *ModelRequest `json:"request,omitempty"` + // Usage describes how many resources were used by this generation request. + Usage *GenerationUsage `json:"usage,omitempty"` + formatHandler StreamingFormatHandler + } + +-// A ModelResponseChunk is the portion of the [ModelResponse] ++// A ModelResponseChunk is the portion of the ModelResponse + // that is passed to a streaming callback. + type ModelResponseChunk struct { +- Aggregated bool `json:"aggregated,omitempty"` +- Content []*Part `json:"content,omitempty"` +- Custom any `json:"custom,omitempty"` +- Index int `json:"index"` +- Role Role `json:"role,omitempty"` ++ // Aggregated indicates whether the chunk includes all data from previous chunks. ++ // If false, the chunk is considered incremental. ++ Aggregated bool `json:"aggregated,omitempty"` ++ // Content is the chunk of message parts to stream right now. ++ Content []*Part `json:"content,omitempty"` ++ // Custom contains model-specific extra information attached to this chunk. ++ Custom any `json:"custom,omitempty"` ++ // Index of the message this chunk belongs to. ++ Index int `json:"index"` ++ // Role indicates the entity that generated this chunk. 
++ Role Role `json:"role,omitempty"` + formatHandler StreamingFormatHandler + } + ++// MultipartToolResponse represents a tool response with both structured output and content parts. + type MultipartToolResponse struct { ++ // Content holds additional message parts providing context or details. + Content []*Part `json:"content,omitempty"` +- Output any `json:"output,omitempty"` ++ // Output contains the structured output data from the tool. ++ Output any `json:"output,omitempty"` + } + ++// Operation represents a long-running background task. + type Operation struct { +- Action string `json:"action,omitempty"` +- Done bool `json:"done,omitempty"` +- Error *OperationError `json:"error,omitempty"` +- Id string `json:"id,omitempty"` +- Metadata map[string]any `json:"metadata,omitempty"` +- Output any `json:"output,omitempty"` ++ // Action is the name of the action being performed by this operation. ++ Action string `json:"action,omitempty"` ++ // Done indicates whether the operation has completed. ++ Done bool `json:"done,omitempty"` ++ // Error contains error information if the operation failed. ++ Error *OperationError `json:"error,omitempty"` ++ // Id is the unique identifier for this operation. ++ Id string `json:"id,omitempty"` ++ // Metadata contains additional information about the operation. ++ Metadata map[string]any `json:"metadata,omitempty"` ++ // Output contains the result of the operation if it has completed successfully. ++ Output any `json:"output,omitempty"` + } + ++// OperationError contains error information for a failed operation. + type OperationError struct { ++ // Message describes the error that occurred. + Message string `json:"message,omitempty"` + } + + // OutputConfig describes the structure that the model's output +-// should conform to. If Format is [OutputFormatJSON], then Schema ++// should conform to. If Format is OutputFormatJSON, then Schema + // can describe the desired form of the generated JSON. 
+ type ModelOutputConfig struct { +- Constrained bool `json:"constrained,omitempty"` +- ContentType string `json:"contentType,omitempty"` +- Format string `json:"format,omitempty"` +- Schema map[string]any `json:"schema,omitempty"` ++ // Constrained indicates whether to enforce strict adherence to the schema. ++ Constrained bool `json:"constrained,omitempty"` ++ // ContentType specifies the MIME type of the output content. ++ ContentType string `json:"contentType,omitempty"` ++ // Format specifies the desired output format (e.g., "json", "text"). ++ Format string `json:"format,omitempty"` ++ // Schema is a JSON Schema describing the desired structure of the output. ++ Schema map[string]any `json:"schema,omitempty"` + } + ++// PathMetadata contains metadata about a single execution path in a trace. + type PathMetadata struct { +- Error string `json:"error,omitempty"` ++ // Error contains error information if the path failed. ++ Error string `json:"error,omitempty"` ++ // Latency is the execution time for this path in milliseconds. + Latency float64 `json:"latency,omitempty"` +- Path string `json:"path,omitempty"` +- Status string `json:"status,omitempty"` ++ // Path is the identifier for this execution path. ++ Path string `json:"path,omitempty"` ++ // Status indicates the outcome of this path. ++ Status string `json:"status,omitempty"` + } + ++// RankedDocumentData represents a document with a relevance score from reranking. + type RankedDocumentData struct { +- Content []*Part `json:"content,omitempty"` ++ // Content holds the document's parts (text and media). ++ Content []*Part `json:"content,omitempty"` ++ // Metadata contains the reranking score and other arbitrary key-value data. + Metadata *RankedDocumentMetadata `json:"metadata,omitempty"` + } + ++// RankedDocumentMetadata contains the relevance score and other metadata for a reranked document. + type RankedDocumentMetadata struct { ++ // Score is the relevance score assigned by the reranker. 
+ Score float64 `json:"score,omitempty"` + } + + type reasoningPart struct { +- Metadata map[string]any `json:"metadata,omitempty"` +- Reasoning string `json:"reasoning,omitempty"` ++ // Metadata contains arbitrary key-value data for this part. ++ Metadata map[string]any `json:"metadata,omitempty"` ++ // Reasoning contains the reasoning text of the message. ++ Reasoning string `json:"reasoning,omitempty"` + } + ++// RerankerRequest represents a request to rerank documents based on relevance. + type RerankerRequest struct { ++ // Documents is the array of documents to rerank. + Documents []*Document `json:"documents,omitempty"` +- Options any `json:"options,omitempty"` +- Query *Document `json:"query,omitempty"` ++ // Options contains reranker-specific configuration parameters. ++ Options any `json:"options,omitempty"` ++ // Query is the document to use for reranking. ++ Query *Document `json:"query,omitempty"` + } + ++// RerankerResponse contains the reranked documents with relevance scores. + type RerankerResponse struct { ++ // Documents is the array of reranked documents with scores. + Documents []*RankedDocumentData `json:"documents,omitempty"` + } + + type resourcePart struct { ++ // Metadata contains arbitrary key-value data for this part. + Metadata map[string]any `json:"metadata,omitempty"` +- Resource *ResourcePart `json:"resource,omitempty"` ++ // Resource contains a reference to an external resource by URI. ++ Resource *ResourcePart `json:"resource,omitempty"` + } + + type ResourcePart struct { ++ // Uri is the URI of the external resource. + Uri string `json:"uri,omitempty"` + } + ++// RetrieverRequest represents a request to retrieve relevant documents. + type RetrieverRequest struct { +- Options any `json:"options,omitempty"` +- Query *Document `json:"query,omitempty"` ++ // Options contains retriever-specific configuration parameters. ++ Options any `json:"options,omitempty"` ++ // Query is the document to use for retrieval. 
++ Query *Document `json:"query,omitempty"` + } + ++// RetrieverResponse contains the retrieved documents from a retriever request. + type RetrieverResponse struct { ++ // Documents is the array of retrieved documents. + Documents []*Document `json:"documents,omitempty"` + } + +@@ -372,63 +457,83 @@ const ( + RoleTool Role = "tool" + ) + ++// ScoreDetails provides additional context and explanation for an evaluation score. + type ScoreDetails struct { ++ // Reasoning explains the rationale behind the score. + Reasoning string `json:"reasoning,omitempty"` + } + + type textPart struct { ++ // Metadata contains arbitrary key-value data for this part. + Metadata map[string]any `json:"metadata,omitempty"` +- Text string `json:"text,omitempty"` ++ // Text contains the textual content. ++ Text string `json:"text,omitempty"` + } + + // A ToolDefinition describes a tool. + type ToolDefinition struct { ++ // Description explains what the tool does and when to use it. + Description string `json:"description,omitempty"` +- // Valid JSON Schema representing the input of the tool. ++ // InputSchema is a valid JSON Schema representing the input parameters of the tool. + InputSchema map[string]any `json:"inputSchema,omitempty"` +- // additional metadata for this tool definition ++ // Metadata contains additional information about this tool definition. + Metadata map[string]any `json:"metadata,omitempty"` +- Name string `json:"name,omitempty"` +- // Valid JSON Schema describing the output of the tool. ++ // Name is the unique identifier for this tool. ++ Name string `json:"name,omitempty"` ++ // OutputSchema is a valid JSON Schema describing the output of the tool. + OutputSchema map[string]any `json:"outputSchema,omitempty"` + } + + // A ToolRequest is a message from the model to the client that it should run a +-// specific tool and pass a [ToolResponse] to the model on the next chat request it makes. 
+-// Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client. ++// specific tool and pass a ToolResponse to the model on the next chat request it makes. ++// Any ToolRequest will correspond to some ToolDefinition previously sent by the client. + type ToolRequest struct { +- // Input is a JSON object describing the input values to the tool. +- // An example might be map[string]any{"country":"USA", "president":3}. +- Input any `json:"input,omitempty"` +- Name string `json:"name,omitempty"` +- Partial bool `json:"partial,omitempty"` +- Ref string `json:"ref,omitempty"` ++ // Input is a JSON object containing the input parameters for the tool. ++ // For example: map[string]any{"country":"USA", "president":3}. ++ Input any `json:"input,omitempty"` ++ // Name is the name of the tool to call. ++ Name string `json:"name,omitempty"` ++ // Partial indicates whether this is a partial streaming chunk. ++ Partial bool `json:"partial,omitempty"` ++ // Ref is the call ID or reference for this specific request. ++ Ref string `json:"ref,omitempty"` + } + + type toolRequestPart struct { +- Metadata map[string]any `json:"metadata,omitempty"` +- ToolRequest *ToolRequest `json:"toolRequest,omitempty"` ++ // Metadata contains arbitrary key-value data for this part. ++ Metadata map[string]any `json:"metadata,omitempty"` ++ // ToolRequest is a request for a tool to be executed, usually provided by a model. ++ ToolRequest *ToolRequest `json:"toolRequest,omitempty"` + } + + // A ToolResponse is a message from the client to the model containing + // the results of running a specific tool on the arguments passed to the client +-// by the model in a [ToolRequest]. ++// by the model in a ToolRequest. + type ToolResponse struct { ++ // Content holds additional message parts that provide context or details about the tool response. + Content []*Part `json:"content,omitempty"` +- Name string `json:"name,omitempty"` ++ // Name is the name of the tool that was executed. 
++ Name string `json:"name,omitempty"` + // Output is a JSON object describing the results of running the tool. +- // An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}. +- Output any `json:"output,omitempty"` +- Ref string `json:"ref,omitempty"` ++ // For example: map[string]any{"name":"Thomas Jefferson", "born":1743}. ++ Output any `json:"output,omitempty"` ++ // Ref is the call ID or reference matching the original request. ++ Ref string `json:"ref,omitempty"` + } + + type toolResponsePart struct { +- Metadata map[string]any `json:"metadata,omitempty"` +- ToolResponse *ToolResponse `json:"toolResponse,omitempty"` ++ // Metadata contains arbitrary key-value data for this part. ++ Metadata map[string]any `json:"metadata,omitempty"` ++ // ToolResponse is a provided response to a tool call. ++ ToolResponse *ToolResponse `json:"toolResponse,omitempty"` + } + ++// TraceMetadata contains metadata about a trace execution. + type TraceMetadata struct { +- FeatureName string `json:"featureName,omitempty"` +- Paths []*PathMetadata `json:"paths,omitempty"` +- Timestamp float64 `json:"timestamp,omitempty"` ++ // FeatureName identifies the feature being traced. ++ FeatureName string `json:"featureName,omitempty"` ++ // Paths contains metadata for each path executed during the trace. ++ Paths []*PathMetadata `json:"paths,omitempty"` ++ // Timestamp is when the trace was created. ++ Timestamp float64 `json:"timestamp,omitempty"` + } +diff --git a/go/ai/generate.go b/go/ai/generate.go +index f26cc9f09..08359c99c 100644 +--- a/go/ai/generate.go ++++ b/go/ai/generate.go +@@ -21,6 +21,7 @@ import ( + "encoding/json" + "errors" + "fmt" ++ "iter" + "slices" + "strings" + +@@ -550,7 +551,7 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) ( + return res.Text(), nil + } + +-// Generate run generate request for this model. Returns ModelResponse struct. ++// GenerateData runs a generate request and returns strongly-typed output. 
+ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) { + var value Out + opts = append(opts, WithOutputType(value)) +@@ -568,6 +569,108 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate + return &value, resp, nil + } + ++// StreamValue is either a streamed chunk or the final response of a generate request. ++type StreamValue[Out, Stream any] struct { ++ Done bool ++ Chunk Stream // valid if Done is false ++ Output Out // valid if Done is true ++ Response *ModelResponse // valid if Done is true ++} ++ ++// ModelStreamValue is a stream value for a model response. ++// Out is never set because the output is already available in the Response field. ++type ModelStreamValue = StreamValue[struct{}, *ModelResponseChunk] ++ ++// errGenerateStop is a sentinel error used to signal early termination of streaming. ++var errGenerateStop = errors.New("stop") ++ ++// GenerateStream generates a model response and streams the output. ++// It returns an iterator that yields streaming results. ++// ++// If the yield function is passed a non-nil error, generation has failed with that ++// error; the yield function will not be called again. ++// ++// If the yield function's [ModelStreamValue] argument has Done == true, the value's ++// Response field contains the final response; the yield function will not be called ++// again. ++// ++// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk. 
++func GenerateStream(ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*ModelStreamValue, error] { ++ return func(yield func(*ModelStreamValue, error) bool) { ++ cb := func(ctx context.Context, chunk *ModelResponseChunk) error { ++ if ctx.Err() != nil { ++ return ctx.Err() ++ } ++ if !yield(&ModelStreamValue{Chunk: chunk}, nil) { ++ return errGenerateStop ++ } ++ return nil ++ } ++ ++ allOpts := append(slices.Clone(opts), WithStreaming(cb)) ++ ++ resp, err := Generate(ctx, r, allOpts...) ++ if err != nil { ++ yield(nil, err) ++ } else { ++ yield(&ModelStreamValue{Done: true, Response: resp}, nil) ++ } ++ } ++} ++ ++// GenerateDataStream generates a model response with streaming and returns strongly-typed output. ++// It returns an iterator that yields streaming results. ++// ++// If the yield function is passed a non-nil error, generation has failed with that ++// error; the yield function will not be called again. ++// ++// If the yield function's [StreamValue] argument has Done == true, the value's ++// Output and Response fields contain the final typed output and response; the yield function ++// will not be called again. ++// ++// Otherwise the Chunk field of the passed [StreamValue] holds a streamed chunk. ++func GenerateDataStream[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*StreamValue[Out, Out], error] { ++ return func(yield func(*StreamValue[Out, Out], error) bool) { ++ cb := func(ctx context.Context, chunk *ModelResponseChunk) error { ++ if ctx.Err() != nil { ++ return ctx.Err() ++ } ++ var streamValue Out ++ if err := chunk.Output(&streamValue); err != nil { ++ yield(nil, err) ++ return err ++ } ++ // Skip yielding if there's no parseable output yet (e.g., incomplete JSON during streaming). 
++ if base.IsNil(streamValue) { ++ return nil ++ } ++ if !yield(&StreamValue[Out, Out]{Chunk: streamValue}, nil) { ++ return errGenerateStop ++ } ++ return nil ++ } ++ ++ // Prepend WithOutputType so the user can override the output format. ++ var value Out ++ allOpts := append([]GenerateOption{WithOutputType(value)}, opts...) ++ allOpts = append(allOpts, WithStreaming(cb)) ++ ++ resp, err := Generate(ctx, r, allOpts...) ++ if err != nil { ++ yield(nil, err) ++ return ++ } ++ ++ output, err := extractTypedOutput[Out](resp) ++ if err != nil { ++ yield(nil, err) ++ return ++ } ++ ++ yield(&StreamValue[Out, Out]{Done: true, Output: output, Response: resp}, nil) ++ } ++} ++ + // Generate applies the [Action] to provided request. + func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if m == nil { +@@ -744,7 +847,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, + // [ModelResponse] as a string. It returns an empty string if there + // are no candidates or if the candidate has no message. + func (mr *ModelResponse) Text() string { +- if mr.Message == nil { ++ if mr == nil || mr.Message == nil { + return "" + } + return mr.Message.Text() +@@ -753,7 +856,7 @@ func (mr *ModelResponse) Text() string { + // History returns messages from the request combined with the response message + // to represent the conversation history. 
+ func (mr *ModelResponse) History() []*Message { +- if mr.Message == nil { ++ if mr == nil || mr.Message == nil { + return mr.Request.Messages + } + return append(mr.Request.Messages, mr.Message) +@@ -762,7 +865,7 @@ func (mr *ModelResponse) History() []*Message { + // Reasoning concatenates all reasoning parts present in the message + func (mr *ModelResponse) Reasoning() string { + var sb strings.Builder +- if mr.Message == nil { ++ if mr == nil || mr.Message == nil { + return "" + } + +@@ -806,7 +909,7 @@ func (mr *ModelResponse) Output(v any) error { + // ToolRequests returns the tool requests from the response. + func (mr *ModelResponse) ToolRequests() []*ToolRequest { + toolReqs := []*ToolRequest{} +- if mr.Message == nil { ++ if mr == nil || mr.Message == nil { + return toolReqs + } + for _, part := range mr.Message.Content { +@@ -820,7 +923,7 @@ func (mr *ModelResponse) ToolRequests() []*ToolRequest { + // Interrupts returns the interrupted tool request parts from the response. + func (mr *ModelResponse) Interrupts() []*Part { + parts := []*Part{} +- if mr.Message == nil { ++ if mr == nil || mr.Message == nil { + return parts + } + for _, part := range mr.Message.Content { +@@ -833,7 +936,7 @@ func (mr *ModelResponse) Interrupts() []*Part { + + // Media returns the media content of the [ModelResponse] as a string. + func (mr *ModelResponse) Media() string { +- if mr.Message == nil { ++ if mr == nil || mr.Message == nil { + return "" + } + for _, part := range mr.Message.Content { +@@ -902,17 +1005,41 @@ func (c *ModelResponseChunk) Output(v any) error { + + // outputer is an interface for types that can unmarshal structured output. + type outputer interface { +- Output(v any) error ++ // Text returns the contents of the output as a string. ++ Text() string ++ // Output parses the structured output from the response and unmarshals it into value. 
++ Output(value any) error + } + + // OutputFrom is a convenience function that parses structured output from a + // [ModelResponse] or [ModelResponseChunk] and returns it as a typed value. + // This is equivalent to calling Output() but returns the value directly instead + // of requiring a pointer argument. If you need to handle the error, use Output() instead. +-func OutputFrom[T any](src outputer) T { +- var v T +- src.Output(&v) +- return v ++func OutputFrom[Out any](src outputer) Out { ++ output, err := extractTypedOutput[Out](src) ++ if err != nil { ++ return base.Zero[Out]() ++ } ++ return output ++} ++ ++// extractTypedOutput extracts the typed output from a model response. ++// It supports string output by calling Text() and returning the result. ++func extractTypedOutput[Out any](o outputer) (Out, error) { ++ var output Out ++ ++ switch any(output).(type) { ++ case string: ++ text := o.Text() ++ // Type assertion to convert string to Out (which we know is string). ++ result := any(text).(Out) ++ return result, nil ++ default: ++ if err := o.Output(&output); err != nil { ++ return base.Zero[Out](), fmt.Errorf("failed to parse output: %w", err) ++ } ++ return output, nil ++ } + } + + // Text returns the contents of a [Message] as a string. 
It +diff --git a/go/ai/generate_test.go b/go/ai/generate_test.go +index cac1f9d50..050a0f3cc 100644 +--- a/go/ai/generate_test.go ++++ b/go/ai/generate_test.go +@@ -18,6 +18,7 @@ package ai + + import ( + "context" ++ "errors" + "fmt" + "math" + "strings" +@@ -1745,3 +1746,542 @@ func TestMultipartTools(t *testing.T) { + } + }) + } ++ ++// streamingTestData holds test output structures ++type streamingTestData struct { ++ Name string `json:"name"` ++ Value int `json:"value"` ++} ++ ++func TestGenerateStream(t *testing.T) { ++ r := registry.New() ++ ConfigureFormats(r) ++ DefineGenerateAction(context.Background(), r) ++ ++ t.Run("yields chunks then final response", func(t *testing.T) { ++ chunkTexts := []string{"Hello", " ", "World"} ++ chunkIndex := 0 ++ ++ streamModel := DefineModel(r, "test/streamModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ if cb != nil { ++ for _, text := range chunkTexts { ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewTextPart(text)}, ++ }) ++ } ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("Hello World"), ++ }, nil ++ }) ++ ++ var receivedChunks []*ModelResponseChunk ++ var finalResponse *ModelResponse ++ ++ for val, err := range GenerateStream(context.Background(), r, ++ WithModel(streamModel), ++ WithPrompt("test streaming"), ++ ) { ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if val.Done { ++ finalResponse = val.Response ++ } else { ++ receivedChunks = append(receivedChunks, val.Chunk) ++ chunkIndex++ ++ } ++ } ++ ++ if len(receivedChunks) != len(chunkTexts) { ++ t.Errorf("expected %d chunks, got %d", len(chunkTexts), len(receivedChunks)) ++ } ++ ++ for i, chunk := range receivedChunks { ++ if chunk.Text() != chunkTexts[i] { ++ t.Errorf("chunk %d: expected %q, got %q", i, chunkTexts[i], chunk.Text()) ++ } ++ } ++ ++ if finalResponse == nil { ++ 
t.Fatal("expected final response") ++ } ++ if finalResponse.Text() != "Hello World" { ++ t.Errorf("expected final text %q, got %q", "Hello World", finalResponse.Text()) ++ } ++ }) ++ ++ t.Run("handles no streaming callback gracefully", func(t *testing.T) { ++ noStreamModel := DefineModel(r, "test/noStreamModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("response without streaming"), ++ }, nil ++ }) ++ ++ var finalResponse *ModelResponse ++ chunkCount := 0 ++ ++ for val, err := range GenerateStream(context.Background(), r, ++ WithModel(noStreamModel), ++ WithPrompt("test no stream"), ++ ) { ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if val.Done { ++ finalResponse = val.Response ++ } else { ++ chunkCount++ ++ } ++ } ++ ++ if chunkCount != 0 { ++ t.Errorf("expected 0 chunks when model doesn't stream, got %d", chunkCount) ++ } ++ if finalResponse == nil { ++ t.Fatal("expected final response") ++ } ++ if finalResponse.Text() != "response without streaming" { ++ t.Errorf("expected text %q, got %q", "response without streaming", finalResponse.Text()) ++ } ++ }) ++ ++ t.Run("propagates generation errors", func(t *testing.T) { ++ expectedErr := errors.New("generation failed") ++ ++ errorModel := DefineModel(r, "test/errorModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return nil, expectedErr ++ }) ++ ++ var receivedErr error ++ for _, err := range GenerateStream(context.Background(), r, ++ WithModel(errorModel), ++ WithPrompt("test error"), ++ ) { ++ if err != nil { ++ receivedErr = err ++ break ++ } ++ } ++ ++ if receivedErr == nil { ++ t.Fatal("expected error to be propagated") ++ } ++ if !errors.Is(receivedErr, expectedErr) { ++ 
t.Errorf("expected error %v, got %v", expectedErr, receivedErr) ++ } ++ }) ++ ++ t.Run("context cancellation stops iteration", func(t *testing.T) { ++ ctx, cancel := context.WithCancel(context.Background()) ++ defer cancel() ++ ++ streamModel := DefineModel(r, "test/cancelModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ if cb != nil { ++ for i := 0; i < 100; i++ { ++ err := cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewTextPart("chunk")}, ++ }) ++ if err != nil { ++ return nil, err ++ } ++ } ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("done"), ++ }, nil ++ }) ++ ++ chunksReceived := 0 ++ var receivedErr error ++ for val, err := range GenerateStream(ctx, r, ++ WithModel(streamModel), ++ WithPrompt("test cancel"), ++ ) { ++ if err != nil { ++ receivedErr = err ++ break ++ } ++ if !val.Done { ++ chunksReceived++ ++ if chunksReceived == 2 { ++ cancel() ++ } ++ } ++ } ++ ++ if chunksReceived < 2 { ++ t.Errorf("expected at least 2 chunks before cancellation, got %d", chunksReceived) ++ } ++ if receivedErr == nil { ++ t.Error("expected error from cancelled context") ++ } ++ }) ++} ++ ++func TestGenerateDataStream(t *testing.T) { ++ r := registry.New() ++ ConfigureFormats(r) ++ DefineGenerateAction(context.Background(), r) ++ ++ t.Run("yields typed chunks and final output", func(t *testing.T) { ++ streamModel := DefineModel(r, "test/typedStreamModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ if cb != nil { ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewJSONPart(`{"name":"partial","value":1}`)}, ++ }) ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewJSONPart(`{"name":"complete","value":42}`)}, ++ }) ++ } ++ return 
&ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewJSONPart(`{"name":"final","value":42}`)}, ++ }, ++ }, nil ++ }) ++ ++ var chunks []streamingTestData ++ var finalOutput streamingTestData ++ var finalResponse *ModelResponse ++ ++ for val, err := range GenerateDataStream[streamingTestData](context.Background(), r, ++ WithModel(streamModel), ++ WithPrompt("test typed streaming"), ++ ) { ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if val.Done { ++ finalOutput = val.Output ++ finalResponse = val.Response ++ } else { ++ chunks = append(chunks, val.Chunk) ++ } ++ } ++ ++ if len(chunks) < 1 { ++ t.Errorf("expected at least 1 chunk, got %d", len(chunks)) ++ } ++ ++ if finalOutput.Name != "final" || finalOutput.Value != 42 { ++ t.Errorf("expected final output {final, 42}, got %+v", finalOutput) ++ } ++ if finalResponse == nil { ++ t.Fatal("expected final response") ++ } ++ }) ++ ++ t.Run("final output is correctly typed", func(t *testing.T) { ++ streamModel := DefineModel(r, "test/finalTypedModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewJSONPart(`{"name":"result","value":123}`)}, ++ }, ++ }, nil ++ }) ++ ++ var finalOutput streamingTestData ++ var gotFinal bool ++ ++ for val, err := range GenerateDataStream[streamingTestData](context.Background(), r, ++ WithModel(streamModel), ++ WithPrompt("test final typed"), ++ ) { ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if val.Done { ++ finalOutput = val.Output ++ gotFinal = true ++ } ++ } ++ ++ if !gotFinal { ++ t.Fatal("expected to receive final output") ++ } ++ if finalOutput.Name != "result" || finalOutput.Value != 123 { ++ t.Errorf("expected final output 
{result, 123}, got %+v", finalOutput) ++ } ++ }) ++ ++ t.Run("automatically sets output type", func(t *testing.T) { ++ var capturedRequest *ModelRequest ++ ++ streamModel := DefineModel(r, "test/autoOutputModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ capturedRequest = req ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewJSONPart(`{"name":"test","value":1}`)}, ++ }, ++ }, nil ++ }) ++ ++ for range GenerateDataStream[streamingTestData](context.Background(), r, ++ WithModel(streamModel), ++ WithPrompt("test auto output type"), ++ ) { ++ } ++ ++ if capturedRequest == nil { ++ t.Fatal("expected request to be captured") ++ } ++ if capturedRequest.Output == nil || capturedRequest.Output.Schema == nil { ++ t.Error("expected output schema to be set automatically") ++ } ++ }) ++ ++ t.Run("propagates chunk parsing errors", func(t *testing.T) { ++ streamModel := DefineModel(r, "test/parseErrorModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ if cb != nil { ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewTextPart("not valid json")}, ++ }) ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("done"), ++ }, nil ++ }) ++ ++ var receivedErr error ++ for _, err := range GenerateDataStream[streamingTestData](context.Background(), r, ++ WithModel(streamModel), ++ WithPrompt("test parse error"), ++ ) { ++ if err != nil { ++ receivedErr = err ++ break ++ } ++ } ++ ++ if receivedErr == nil { ++ t.Error("expected parsing error to be propagated") ++ } ++ }) ++} ++ ++func TestGenerateText(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ echoModel := 
DefineModel(r, "test/echoTextModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("echo: " + req.Messages[0].Content[0].Text), ++ }, nil ++ }) ++ ++ t.Run("returns text from model", func(t *testing.T) { ++ text, err := GenerateText(context.Background(), r, ++ WithModel(echoModel), ++ WithPrompt("hello"), ++ ) ++ ++ if err != nil { ++ t.Fatalf("GenerateText error: %v", err) ++ } ++ if text != "echo: hello" { ++ t.Errorf("text = %q, want %q", text, "echo: hello") ++ } ++ }) ++} ++ ++func TestGenerateData(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ type TestOutput struct { ++ Value int `json:"value"` ++ } ++ ++ jsonModel := DefineModel(r, "test/jsonDataModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage(`{"value": 42}`), ++ }, nil ++ }) ++ ++ t.Run("returns typed data from model", func(t *testing.T) { ++ output, _, err := GenerateData[TestOutput](context.Background(), r, ++ WithModel(jsonModel), ++ WithPrompt("get value"), ++ ) ++ ++ if err != nil { ++ t.Fatalf("GenerateData error: %v", err) ++ } ++ if output.Value != 42 { ++ t.Errorf("output.Value = %d, want 42", output.Value) ++ } ++ }) ++} ++ ++func TestModelResponseReasoning(t *testing.T) { ++ t.Run("returns reasoning from response", func(t *testing.T) { ++ resp := &ModelResponse{ ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{ ++ NewReasoningPart("thinking about this...", nil), ++ NewTextPart("final answer"), ++ }, ++ }, ++ } ++ ++ reasoning := resp.Reasoning() ++ ++ if reasoning != "thinking about this..." 
{ ++ t.Errorf("Reasoning() = %q, want %q", reasoning, "thinking about this...") ++ } ++ }) ++ ++ t.Run("returns empty string when no reasoning", func(t *testing.T) { ++ resp := &ModelResponse{ ++ Message: NewModelTextMessage("just text"), ++ } ++ ++ reasoning := resp.Reasoning() ++ ++ if reasoning != "" { ++ t.Errorf("Reasoning() = %q, want empty string", reasoning) ++ } ++ }) ++} ++ ++func TestModelResponseInterrupts(t *testing.T) { ++ t.Run("returns interrupt tool requests", func(t *testing.T) { ++ interruptPart := NewToolRequestPart(&ToolRequest{ ++ Name: "confirmAction", ++ Input: map[string]any{}, ++ }) ++ interruptPart.Metadata = map[string]any{"interrupt": true} ++ ++ resp := &ModelResponse{ ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{ ++ NewTextPart("Please confirm"), ++ interruptPart, ++ }, ++ }, ++ } ++ ++ interrupts := resp.Interrupts() ++ ++ if len(interrupts) != 1 { ++ t.Fatalf("len(Interrupts()) = %d, want 1", len(interrupts)) ++ } ++ if interrupts[0].ToolRequest.Name != "confirmAction" { ++ t.Errorf("interrupt name = %q, want %q", interrupts[0].ToolRequest.Name, "confirmAction") ++ } ++ }) ++ ++ t.Run("returns empty slice when no interrupts", func(t *testing.T) { ++ resp := &ModelResponse{ ++ Message: NewModelTextMessage("no interrupts here"), ++ } ++ ++ interrupts := resp.Interrupts() ++ ++ if len(interrupts) != 0 { ++ t.Errorf("len(Interrupts()) = %d, want 0", len(interrupts)) ++ } ++ }) ++} ++ ++func TestModelResponseMedia(t *testing.T) { ++ t.Run("returns media URL from response", func(t *testing.T) { ++ resp := &ModelResponse{ ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{ ++ NewTextPart("Here's an image"), ++ NewMediaPart("image/png", "data:image/png;base64,abc123"), ++ }, ++ }, ++ } ++ ++ media := resp.Media() ++ ++ if media == "" { ++ t.Error("Media() returned empty string") ++ } ++ if media != "data:image/png;base64,abc123" { ++ t.Errorf("Media() = %q, want %q", media, "data:image/png;base64,abc123") ++ } ++ 
}) ++ ++ t.Run("returns empty string when no media", func(t *testing.T) { ++ resp := &ModelResponse{ ++ Message: NewModelTextMessage("just text"), ++ } ++ ++ media := resp.Media() ++ ++ if media != "" { ++ t.Errorf("Media() = %q, want empty string", media) ++ } ++ }) ++} ++ ++func TestOutputFrom(t *testing.T) { ++ type TestData struct { ++ Name string `json:"name"` ++ Count int `json:"count"` ++ } ++ ++ t.Run("extracts typed output from response", func(t *testing.T) { ++ resp := &ModelResponse{ ++ Message: NewModelTextMessage(`{"name": "test", "count": 5}`), ++ } ++ ++ output := OutputFrom[TestData](resp) ++ ++ if output.Name != "test" { ++ t.Errorf("output.Name = %q, want %q", output.Name, "test") ++ } ++ if output.Count != 5 { ++ t.Errorf("output.Count = %d, want 5", output.Count) ++ } ++ }) ++} +diff --git a/go/ai/option_test.go b/go/ai/option_test.go +index 6fd243084..04fee69a5 100644 +--- a/go/ai/option_test.go ++++ b/go/ai/option_test.go +@@ -653,3 +653,129 @@ func (t *mockTool) Definition() *ToolDefinition { + func (t *mockTool) RunRaw(ctx context.Context, input any) (any, error) { + return nil, nil + } ++ ++func (t *mockTool) RunRawMultipart(ctx context.Context, input any) (*MultipartToolResponse, error) { ++ return nil, nil ++} ++ ++func (t *mockTool) Respond(toolReq *Part, outputData any, opts *RespondOptions) *Part { ++ return nil ++} ++ ++func (t *mockTool) Restart(toolReq *Part, opts *RestartOptions) *Part { ++ return nil ++} ++ ++func (t *mockTool) Register(r interface{ RegisterValue(string, any) }) { ++} ++ ++func TestWithInputSchemaName(t *testing.T) { ++ t.Run("creates input option with schema reference", func(t *testing.T) { ++ opt := WithInputSchemaName("MyInputType") ++ opts := &promptOptions{} ++ ++ if err := opt.applyPrompt(opts); err != nil { ++ t.Fatalf("applyPrompt() error: %v", err) ++ } ++ ++ if opts.InputSchema == nil { ++ t.Fatal("InputSchema is nil") ++ } ++ ++ ref, ok := opts.InputSchema["$ref"].(string) ++ if !ok { ++ 
t.Fatal("InputSchema.$ref is not a string") ++ } ++ if ref != "genkit:MyInputType" { ++ t.Errorf("InputSchema.$ref = %q, want %q", ref, "genkit:MyInputType") ++ } ++ }) ++} ++ ++func TestWithOutputSchema(t *testing.T) { ++ t.Run("creates output option with direct schema", func(t *testing.T) { ++ schema := map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "name": map[string]any{"type": "string"}, ++ }, ++ } ++ opt := WithOutputSchema(schema) ++ opts := &generateOptions{} ++ ++ if err := opt.applyGenerate(opts); err != nil { ++ t.Fatalf("applyGenerate() error: %v", err) ++ } ++ ++ if opts.OutputSchema == nil { ++ t.Fatal("OutputSchema is nil") ++ } ++ if opts.OutputFormat != OutputFormatJSON { ++ t.Errorf("OutputFormat = %q, want %q", opts.OutputFormat, OutputFormatJSON) ++ } ++ }) ++} ++ ++func TestWithOutputEnums(t *testing.T) { ++ t.Run("creates enum output with string values", func(t *testing.T) { ++ opt := WithOutputEnums("red", "green", "blue") ++ opts := &generateOptions{} ++ ++ if err := opt.applyGenerate(opts); err != nil { ++ t.Fatalf("applyGenerate() error: %v", err) ++ } ++ ++ if opts.OutputSchema == nil { ++ t.Fatal("OutputSchema is nil") ++ } ++ if opts.OutputFormat != OutputFormatEnum { ++ t.Errorf("OutputFormat = %q, want %q", opts.OutputFormat, OutputFormatEnum) ++ } ++ ++ enumType, ok := opts.OutputSchema["type"].(string) ++ if !ok || enumType != "string" { ++ t.Errorf("OutputSchema.type = %v, want %q", opts.OutputSchema["type"], "string") ++ } ++ ++ enumVals, ok := opts.OutputSchema["enum"].([]string) ++ if !ok { ++ t.Fatalf("OutputSchema.enum is not []string: %T", opts.OutputSchema["enum"]) ++ } ++ if len(enumVals) != 3 { ++ t.Errorf("len(enum) = %d, want 3", len(enumVals)) ++ } ++ }) ++ ++ t.Run("works with custom string type", func(t *testing.T) { ++ type Color string ++ opt := WithOutputEnums(Color("red"), Color("green")) ++ opts := &generateOptions{} ++ ++ if err := opt.applyGenerate(opts); err != nil { ++ 
t.Fatalf("applyGenerate() error: %v", err) ++ } ++ ++ enumVals := opts.OutputSchema["enum"].([]string) ++ if enumVals[0] != "red" || enumVals[1] != "green" { ++ t.Errorf("enum values = %v, want [red, green]", enumVals) ++ } ++ }) ++} ++ ++func TestWithEvaluatorName(t *testing.T) { ++ t.Run("creates evaluator option with reference", func(t *testing.T) { ++ opt := WithEvaluatorName("test/myEvaluator") ++ opts := &evaluatorOptions{} ++ ++ if err := opt.applyEvaluator(opts); err != nil { ++ t.Fatalf("applyEvaluator() error: %v", err) ++ } ++ ++ if opts.Evaluator == nil { ++ t.Fatal("Evaluator is nil") ++ } ++ if opts.Evaluator.Name() != "test/myEvaluator" { ++ t.Errorf("Evaluator.Name() = %q, want %q", opts.Evaluator.Name(), "test/myEvaluator") ++ } ++ }) ++} +diff --git a/go/ai/prompt.go b/go/ai/prompt.go +index db4ec264c..ac4ef82e5 100644 +--- a/go/ai/prompt.go ++++ b/go/ai/prompt.go +@@ -19,11 +19,14 @@ import ( + "encoding/json" + "errors" + "fmt" ++ "io/fs" ++ "iter" + "log/slog" + "maps" + "os" +- "path/filepath" ++ "path" + "reflect" ++ "slices" + "strings" + + "github.com/firebase/genkit/go/core" +@@ -40,6 +43,8 @@ type Prompt interface { + Name() string + // Execute executes the prompt with the given options and returns a [ModelResponse]. + Execute(ctx context.Context, opts ...PromptExecuteOption) (*ModelResponse, error) ++ // ExecuteStream executes the prompt with streaming and returns an iterator. ++ ExecuteStream(ctx context.Context, opts ...PromptExecuteOption) iter.Seq2[*ModelStreamValue, error] + // Render renders the prompt with the given input and returns a [GenerateActionOptions] to be used with [GenerateWithRequest]. + Render(ctx context.Context, input any) (*GenerateActionOptions, error) + } +@@ -51,6 +56,13 @@ type prompt struct { + registry api.Registry + } + ++// DataPrompt is a prompt with strongly-typed input and output. ++// It wraps an underlying [Prompt] and provides type-safe Execute and Render methods. 
++// The Out type parameter can be string for text outputs or any struct type for JSON outputs. ++type DataPrompt[In, Out any] struct { ++ prompt ++} ++ + // DefinePrompt creates a new [Prompt] and registers it. + func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { + if name == "" { +@@ -89,10 +101,7 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { + } + metadata["type"] = api.ActionTypeExecutablePrompt + +- baseName := name +- if idx := strings.LastIndex(name, "."); idx != -1 { +- baseName = name[:idx] +- } ++ baseName, variant, _ := strings.Cut(name, ".") + + promptMetadata := map[string]any{ + "name": baseName, +@@ -105,6 +114,9 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { + "tools": tools, + "maxTurns": p.MaxTurns, + } ++ if variant != "" { ++ promptMetadata["variant"] = variant ++ } + if m, ok := metadata["prompt"].(map[string]any); ok { + maps.Copy(m, promptMetadata) + } else { +@@ -133,7 +145,7 @@ func LookupPrompt(r api.Registry, name string) Prompt { + // passes the rendered template to the AI model specified by the prompt. + func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*ModelResponse, error) { + if p == nil { +- return nil, errors.New("Prompt.Execute: execute called on a nil Prompt; check that all prompts are defined") ++ return nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.Execute: prompt is nil") + } + + execOpts := &promptExecutionOptions{} +@@ -239,10 +251,50 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod + return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) + } + ++// ExecuteStream executes the prompt with streaming and returns an iterator. ++// ++// If the yield function is passed a non-nil error, execution has failed with that ++// error; the yield function will not be called again. 
++// ++// If the yield function's [ModelStreamValue] argument has Done == true, the value's ++// Response field contains the final response; the yield function will not be called again. ++// ++// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk. ++func (p *prompt) ExecuteStream(ctx context.Context, opts ...PromptExecuteOption) iter.Seq2[*ModelStreamValue, error] { ++ return func(yield func(*ModelStreamValue, error) bool) { ++ if p == nil { ++ yield(nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.ExecuteStream: prompt is nil")) ++ return ++ } ++ ++ cb := func(ctx context.Context, chunk *ModelResponseChunk) error { ++ if ctx.Err() != nil { ++ return ctx.Err() ++ } ++ if !yield(&ModelStreamValue{Chunk: chunk}, nil) { ++ return errPromptStop ++ } ++ return nil ++ } ++ ++ allOpts := append(slices.Clone(opts), WithStreaming(cb)) ++ resp, err := p.Execute(ctx, allOpts...) ++ if err != nil { ++ yield(nil, err) ++ return ++ } ++ ++ yield(&ModelStreamValue{Done: true, Response: resp}, nil) ++ } ++} ++ ++// errPromptStop is a sentinel error used to signal early termination of streaming. ++var errPromptStop = errors.New("stop") ++ + // Render renders the prompt template based on user input. 
+ func (p *prompt) Render(ctx context.Context, input any) (*GenerateActionOptions, error) { + if p == nil { +- return nil, errors.New("Prompt.Render: called on a nil prompt; check that all prompts are defined") ++ return nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.Render: prompt is nil") + } + + if len(p.Middleware) > 0 { +@@ -414,13 +466,16 @@ func renderSystemPrompt(ctx context.Context, opts promptOptions, messages []*Mes + return nil, err + } + +- parts, err := renderPrompt(ctx, opts, templateText, input, dp) ++ renderedMessages, err := renderPrompt(ctx, opts, templateText, input, dp) + if err != nil { + return nil, err + } + +- if len(parts) != 0 { +- messages = append(messages, NewSystemMessage(parts...)) ++ for _, m := range renderedMessages { ++ if m.Role == "" || (len(renderedMessages) == 1 && m.Role == RoleUser) { ++ m.Role = RoleSystem ++ } ++ messages = append(messages, m) + } + + return messages, nil +@@ -437,13 +492,16 @@ func renderUserPrompt(ctx context.Context, opts promptOptions, messages []*Messa + return nil, err + } + +- parts, err := renderPrompt(ctx, opts, templateText, input, dp) ++ renderedMessages, err := renderPrompt(ctx, opts, templateText, input, dp) + if err != nil { + return nil, err + } + +- if len(parts) != 0 { +- messages = append(messages, NewUserMessage(parts...)) ++ for _, m := range renderedMessages { ++ if m.Role == "" || (len(renderedMessages) == 1 && m.Role != RoleUser) { ++ m.Role = RoleUser ++ } ++ messages = append(messages, m) + } + + return messages, nil +@@ -463,47 +521,72 @@ func renderMessages(ctx context.Context, opts promptOptions, messages []*Message + // Create new message copies to avoid mutating shared messages during concurrent execution + renderedMsgs := make([]*Message, 0, len(msgs)) + for _, msg := range msgs { +- msgParts := []*Part{} ++ hasTextPart := slices.ContainsFunc(msg.Content, (*Part).IsText) ++ ++ if !hasTextPart { ++ // Create a new message with non-text content instead of mutating the 
original ++ renderedMsg := &Message{ ++ Role: msg.Role, ++ Content: msg.Content, ++ Metadata: msg.Metadata, ++ } ++ renderedMsgs = append(renderedMsgs, renderedMsg) ++ continue ++ } ++ + for _, part := range msg.Content { + if part.IsText() { +- parts, err := renderPrompt(ctx, opts, part.Text, input, dp) ++ messagesFromText, err := renderPrompt(ctx, opts, part.Text, input, dp) + if err != nil { + return nil, err + } +- msgParts = append(msgParts, parts...) ++ for _, m := range messagesFromText { ++ // If the rendered message has no role, or it is a single message with default role, ++ // use the original message's role. ++ role := m.Role ++ if role == "" || (len(messagesFromText) == 1 && role == RoleUser) { ++ role = msg.Role ++ } ++ renderedMsgs = append(renderedMsgs, &Message{ ++ Role: role, ++ Content: m.Content, ++ Metadata: msg.Metadata, ++ }) ++ } + } else { +- // Preserve non-text parts as-is +- msgParts = append(msgParts, part) ++ // Preserve non-text parts as-is in the current last message if possible, or create a new one ++ if len(renderedMsgs) > 0 && renderedMsgs[len(renderedMsgs)-1].Role == msg.Role { ++ renderedMsgs[len(renderedMsgs)-1].Content = append(renderedMsgs[len(renderedMsgs)-1].Content, part) ++ } else { ++ renderedMsgs = append(renderedMsgs, &Message{ ++ Role: msg.Role, ++ Content: []*Part{part}, ++ Metadata: msg.Metadata, ++ }) ++ } + } + } +- // Create a new message with rendered content instead of mutating the original +- renderedMsg := &Message{ +- Role: msg.Role, +- Content: msgParts, +- Metadata: msg.Metadata, +- } +- renderedMsgs = append(renderedMsgs, renderedMsg) + } + + return append(messages, renderedMsgs...), nil + } + + // renderPrompt renders a prompt template using dotprompt functionalities +-func renderPrompt(ctx context.Context, opts promptOptions, templateText string, input map[string]any, dp *dotprompt.Dotprompt) ([]*Part, error) { ++func renderPrompt(ctx context.Context, opts promptOptions, templateText string, input 
map[string]any, dp *dotprompt.Dotprompt) ([]*Message, error) { + renderedFunc, err := dp.Compile(templateText, &dotprompt.PromptMetadata{}) + if err != nil { + return nil, err + } + +- return renderDotpromptToParts(ctx, renderedFunc, input, &dotprompt.PromptMetadata{ ++ return renderDotpromptToMessages(ctx, renderedFunc, input, &dotprompt.PromptMetadata{ + Input: dotprompt.PromptMetadataInput{ + Default: opts.DefaultInput, + }, + }) + } + +-// renderDotpromptToParts executes a dotprompt prompt function and converts the result to a slice of parts +-func renderDotpromptToParts(ctx context.Context, promptFn dotprompt.PromptFunction, input map[string]any, additionalMetadata *dotprompt.PromptMetadata) ([]*Part, error) { ++// renderDotpromptToMessages executes a dotprompt prompt function and converts the result to a slice of messages ++func renderDotpromptToMessages(ctx context.Context, promptFn dotprompt.PromptFunction, input map[string]any, additionalMetadata *dotprompt.PromptMetadata) ([]*Message, error) { + // Prepare the context for rendering + context := map[string]any{} + actionCtx := core.FromContext(ctx) +@@ -518,16 +601,20 @@ func renderDotpromptToParts(ctx context.Context, promptFn dotprompt.PromptFuncti + return nil, fmt.Errorf("failed to render prompt: %w", err) + } + +- convertedParts := []*Part{} ++ convertedMessages := []*Message{} + for _, message := range rendered.Messages { + parts, err := convertToPartPointers(message.Content) + if err != nil { + return nil, fmt.Errorf("failed to convert parts: %w", err) + } +- convertedParts = append(convertedParts, parts...) 
++ role := Role(message.Role) ++ convertedMessages = append(convertedMessages, &Message{ ++ Role: role, ++ Content: parts, ++ }) + } + +- return convertedParts, nil ++ return convertedMessages, nil + } + + // convertToPartPointers converts []dotprompt.Part to []*Part +@@ -550,87 +637,84 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) { + return result, nil + } + +-// LoadPromptDir loads prompts and partials from the input directory for the given namespace. +-func LoadPromptDir(r api.Registry, dir string, namespace string) { +- useDefaultDir := false +- if dir == "" { +- dir = "./prompts" +- useDefaultDir = true ++// LoadPromptDirFromFS loads prompts and partials from a filesystem for the given namespace. ++// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS). ++// The dir parameter specifies the directory within the filesystem where prompts are located. ++func LoadPromptDirFromFS(r api.Registry, fsys fs.FS, dir, namespace string) { ++ if fsys == nil { ++ panic(errors.New("no prompt filesystem provided")) + } + +- path, err := filepath.Abs(dir) +- if err != nil { +- if !useDefaultDir { +- panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err)) +- } +- slog.Debug("default prompt directory not found, skipping loading .prompt files", "dir", dir) +- return ++ if _, err := fs.Stat(fsys, dir); err != nil { ++ panic(fmt.Errorf("failed to access prompt directory %q in filesystem: %w", dir, err)) + } + +- if _, err := os.Stat(path); os.IsNotExist(err) { +- if !useDefaultDir { +- panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err)) +- } +- slog.Debug("Default prompt directory not found, skipping loading .prompt files", "dir", dir) +- return +- } +- +- loadPromptDir(r, path, namespace) +-} +- +-// loadPromptDir recursively loads prompts and partials from the directory. 
+-func loadPromptDir(r api.Registry, dir string, namespace string) { +- entries, err := os.ReadDir(dir) ++ entries, err := fs.ReadDir(fsys, dir) + if err != nil { + panic(fmt.Errorf("failed to read prompt directory structure: %w", err)) + } + + for _, entry := range entries { + filename := entry.Name() +- path := filepath.Join(dir, filename) ++ filePath := path.Join(dir, filename) + if entry.IsDir() { +- loadPromptDir(r, path, namespace) ++ LoadPromptDirFromFS(r, fsys, filePath, namespace) + } else if strings.HasSuffix(filename, ".prompt") { + if strings.HasPrefix(filename, "_") { + partialName := strings.TrimSuffix(filename[1:], ".prompt") +- source, err := os.ReadFile(path) ++ source, err := fs.ReadFile(fsys, filePath) + if err != nil { + slog.Error("Failed to read partial file", "error", err) + continue + } + r.RegisterPartial(partialName, string(source)) +- slog.Debug("Registered Dotprompt partial", "name", partialName, "file", path) ++ slog.Debug("Registered Dotprompt partial", "name", partialName, "file", filePath) + } else { +- LoadPrompt(r, dir, filename, namespace) ++ LoadPromptFromFS(r, fsys, dir, filename, namespace) + } + } + } + } + +-// LoadPrompt loads a single prompt into the registry. +-func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { ++// LoadPromptFromFS loads a single prompt from a filesystem into the registry. ++// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS). ++// The dir parameter specifies the directory within the filesystem where the prompt is located. 
++func LoadPromptFromFS(r api.Registry, fsys fs.FS, dir, filename, namespace string) Prompt { + name := strings.TrimSuffix(filename, ".prompt") +- name, variant, _ := strings.Cut(name, ".") + +- sourceFile := filepath.Join(dir, filename) +- source, err := os.ReadFile(sourceFile) ++ sourceFile := path.Join(dir, filename) ++ source, err := fs.ReadFile(fsys, sourceFile) + if err != nil { + slog.Error("Failed to read prompt file", "file", sourceFile, "error", err) + return nil + } + ++ p, err := LoadPromptFromSource(r, string(source), name, namespace) ++ if err != nil { ++ slog.Error("Failed to load prompt", "file", sourceFile, "error", err) ++ return nil ++ } ++ ++ slog.Debug("Registered Dotprompt", "name", p.Name(), "file", sourceFile) ++ return p ++} ++ ++// LoadPromptFromSource loads a prompt from raw .prompt file content. ++// The source parameter should contain the complete .prompt file text (frontmatter + template). ++// The name parameter is the prompt name (may include variant suffix like "myPrompt.variant"). 
++func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Prompt, error) { ++ name, variant, _ := strings.Cut(name, ".") ++ + dp := r.Dotprompt() + +- parsedPrompt, err := dp.Parse(string(source)) ++ parsedPrompt, err := dp.Parse(source) + if err != nil { +- slog.Error("Failed to parse file as dotprompt", "file", sourceFile, "error", err) +- return nil ++ return nil, fmt.Errorf("failed to parse dotprompt: %w", err) + } + +- metadata, err := dp.RenderMetadata(string(source), &parsedPrompt.PromptMetadata) ++ metadata, err := dp.RenderMetadata(source, &parsedPrompt.PromptMetadata) + if err != nil { +- slog.Error("Failed to render dotprompt metadata", "file", sourceFile, "error", err) +- return nil ++ return nil, fmt.Errorf("failed to render dotprompt metadata: %w", err) + } + + toolRefs := make([]ToolRef, len(metadata.Tools)) +@@ -692,7 +776,11 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { + } + + if inputSchema, ok := metadata.Input.Schema.(map[string]any); ok { +- opts.InputSchema = inputSchema ++ if ref, ok := inputSchema["$ref"].(string); ok { ++ opts.InputSchema = core.SchemaRef(ref) ++ } else { ++ opts.InputSchema = inputSchema ++ } + } + + if metadata.Output.Format != "" { +@@ -710,57 +798,32 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { + } + } + +- key := promptKey(name, variant, namespace) +- +- dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{}) +- if err != nil { +- slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err) +- return nil +- } +- +- var systemText string +- var nonSystemMessages []*Message +- for _, dpMsg := range dpMessages { +- parts, err := convertToPartPointers(dpMsg.Content) +- if err != nil { +- slog.Error("Failed to convert message parts", "file", sourceFile, "error", err) +- return nil +- } +- +- role := Role(dpMsg.Role) +- if role == RoleSystem { +- var textParts []string +- for 
_, part := range parts { +- if part.IsText() { +- textParts = append(textParts, part.Text) +- } +- } +- +- if len(textParts) > 0 { +- systemText = strings.Join(textParts, " ") +- } ++ if outputSchema, ok := metadata.Output.Schema.(map[string]any); ok { ++ if ref, ok := outputSchema["$ref"].(string); ok { ++ opts.OutputSchema = core.SchemaRef(ref) + } else { +- nonSystemMessages = append(nonSystemMessages, &Message{Role: role, Content: parts}) ++ opts.OutputSchema = outputSchema ++ } ++ if opts.OutputFormat == "" { ++ opts.OutputFormat = OutputFormatJSON + } + } + +- promptOpts := []PromptOption{opts} +- +- if systemText != "" { +- promptOpts = append(promptOpts, WithSystem(systemText)) +- } ++ key := promptKey(name, variant, namespace) + +- if len(nonSystemMessages) > 0 { +- promptOpts = append(promptOpts, WithMessages(nonSystemMessages...)) +- } else if systemText == "" { +- promptOpts = append(promptOpts, WithPrompt(parsedPrompt.Template)) +- } ++ prompt := DefinePrompt(r, key, opts, WithPrompt(parsedPrompt.Template)) + +- prompt := DefinePrompt(r, key, promptOpts...) ++ return prompt, nil ++} + +- slog.Debug("Registered Dotprompt", "name", key, "file", sourceFile) ++// LoadPromptDir loads prompts and partials from a directory on the local filesystem. ++func LoadPromptDir(r api.Registry, dir string, namespace string) { ++ LoadPromptDirFromFS(r, os.DirFS(dir), ".", namespace) ++} + +- return prompt ++// LoadPrompt loads a single prompt from a directory on the local filesystem into the registry. ++func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { ++ return LoadPromptFromFS(r, os.DirFS(dir), ".", filename, namespace) + } + + // promptKey generates a unique key for the prompt in the registry. +@@ -807,3 +870,133 @@ func contentType(ct, uri string) (string, []byte, error) { + + return "", nil, errors.New("uri content type not found") + } ++ ++// DefineDataPrompt creates a new data prompt and registers it. 
++// It automatically infers input schema from the In type parameter and configures ++// output schema and JSON format from the Out type parameter (unless Out is string). ++func DefineDataPrompt[In, Out any](r api.Registry, name string, opts ...PromptOption) *DataPrompt[In, Out] { ++ if name == "" { ++ panic("ai.DefineDataPrompt: name is required") ++ } ++ ++ var in In ++ allOpts := []PromptOption{WithInputType(in)} ++ ++ var out Out ++ switch any(out).(type) { ++ case string: ++ // String output - no schema needed ++ default: ++ // Prepend WithOutputType so the user can override the output format. ++ allOpts = append(allOpts, WithOutputType(out)) ++ } ++ ++ allOpts = append(allOpts, opts...) ++ p := DefinePrompt(r, name, allOpts...) ++ ++ return &DataPrompt[In, Out]{prompt: *p.(*prompt)} ++} ++ ++// LookupDataPrompt looks up a prompt by name and wraps it with type information. ++// This is useful for wrapping prompts loaded from .prompt files with strong types. ++// It returns nil if the prompt was not found. ++func LookupDataPrompt[In, Out any](r api.Registry, name string) *DataPrompt[In, Out] { ++ return AsDataPrompt[In, Out](LookupPrompt(r, name)) ++} ++ ++// AsDataPrompt wraps an existing Prompt with type information, returning a DataPrompt. ++// This is useful for adding strong typing to a dynamically obtained prompt. ++func AsDataPrompt[In, Out any](p Prompt) *DataPrompt[In, Out] { ++ if p == nil { ++ return nil ++ } ++ ++ return &DataPrompt[In, Out]{prompt: *p.(*prompt)} ++} ++ ++// Execute executes the typed prompt and returns the strongly-typed output along with the full model response. ++// For structured output types (non-string Out), the prompt must be configured with the appropriate ++// output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt. 
++func (dp *DataPrompt[In, Out]) Execute(ctx context.Context, input In, opts ...PromptExecuteOption) (Out, *ModelResponse, error) { ++ if dp == nil { ++ return base.Zero[Out](), nil, core.NewError(core.INVALID_ARGUMENT, "DataPrompt.Execute: prompt is nil") ++ } ++ ++ allOpts := append(slices.Clone(opts), WithInput(input)) ++ resp, err := dp.prompt.Execute(ctx, allOpts...) ++ if err != nil { ++ return base.Zero[Out](), nil, err ++ } ++ ++ output, err := extractTypedOutput[Out](resp) ++ if err != nil { ++ return base.Zero[Out](), resp, err ++ } ++ ++ return output, resp, nil ++} ++ ++// ExecuteStream executes the typed prompt with streaming and returns an iterator. ++// ++// If the yield function is passed a non-nil error, execution has failed with that ++// error; the yield function will not be called again. ++// ++// If the yield function's StreamValue argument has Done == true, the value's ++// Output and Response fields contain the final typed output and response; the yield function ++// will not be called again. ++// ++// Otherwise the Chunk field of the passed StreamValue holds a streamed chunk. ++// ++// For structured output types (non-string Out), the prompt must be configured with the appropriate ++// output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt. 
++func (dp *DataPrompt[In, Out]) ExecuteStream(ctx context.Context, input In, opts ...PromptExecuteOption) iter.Seq2[*StreamValue[Out, Out], error] { ++ return func(yield func(*StreamValue[Out, Out], error) bool) { ++ if dp == nil { ++ yield(nil, core.NewError(core.INVALID_ARGUMENT, "DataPrompt.ExecuteStream: prompt is nil")) ++ return ++ } ++ ++ cb := func(ctx context.Context, chunk *ModelResponseChunk) error { ++ if ctx.Err() != nil { ++ return ctx.Err() ++ } ++ streamValue, err := extractTypedOutput[Out](chunk) ++ if err != nil { ++ yield(nil, err) ++ return err ++ } ++ // Skip yielding if there's no parseable output yet (e.g., incomplete JSON during streaming). ++ if base.IsNil(streamValue) { ++ return nil ++ } ++ if !yield(&StreamValue[Out, Out]{Chunk: streamValue}, nil) { ++ return errGenerateStop ++ } ++ return nil ++ } ++ ++ allOpts := append(slices.Clone(opts), WithInput(input), WithStreaming(cb)) ++ resp, err := dp.prompt.Execute(ctx, allOpts...) ++ if err != nil { ++ yield(nil, err) ++ return ++ } ++ ++ output, err := extractTypedOutput[Out](resp) ++ if err != nil { ++ yield(nil, err) ++ return ++ } ++ ++ yield(&StreamValue[Out, Out]{Done: true, Output: output, Response: resp}, nil) ++ } ++} ++ ++// Render renders the typed prompt template with the given input. 
++func (dp *DataPrompt[In, Out]) Render(ctx context.Context, input In) (*GenerateActionOptions, error) { ++ if dp == nil { ++ return nil, errors.New("DataPrompt.Render: prompt is nil") ++ } ++ ++ return dp.prompt.Render(ctx, input) ++} +diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go +index f711f6321..3228a8f41 100644 +--- a/go/ai/prompt_test.go ++++ b/go/ai/prompt_test.go +@@ -16,11 +16,13 @@ package ai + + import ( + "context" ++ "errors" + "fmt" + "os" + "path/filepath" + "strings" + "testing" ++ "testing/fstest" + + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" +@@ -885,70 +887,6 @@ func assertResponse(t *testing.T, resp *ModelResponse, want string) { + } + } + +-func TestLoadPrompt(t *testing.T) { +- // Create a temporary directory for testing +- tempDir := t.TempDir() +- +- // Create a mock .prompt file +- mockPromptFile := filepath.Join(tempDir, "example.prompt") +- mockPromptContent := `--- +-model: test-model +-maxTurns: 5 +-description: A test prompt +-toolChoice: required +-returnToolRequests: true +-input: +- schema: +- type: object +- properties: +- name: +- type: string +- default: +- name: world +-output: +- format: text +- schema: +- type: string +---- +-Hello, {{name}}! 
+-` +- err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644) +- if err != nil { +- t.Fatalf("Failed to create mock prompt file: %v", err) +- } +- +- // Initialize a mock registry +- reg := registry.New() +- +- // Call loadPrompt +- LoadPrompt(reg, tempDir, "example.prompt", "test-namespace") +- +- // Verify that the prompt was registered correctly +- prompt := LookupPrompt(reg, "test-namespace/example") +- if prompt == nil { +- t.Fatalf("Prompt was not registered") +- } +- +- if prompt.(api.Action).Desc().InputSchema == nil { +- t.Fatal("Input schema is nil") +- } +- +- if prompt.(api.Action).Desc().InputSchema["type"] != "object" { +- t.Errorf("Expected input schema type 'object', got '%s'", prompt.(api.Action).Desc().InputSchema["type"]) +- } +- +- promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any) +- if !ok { +- t.Fatalf("Expected Metadata['prompt'] to be a map, but got %T", prompt.(api.Action).Desc().Metadata["prompt"]) +- } +- if promptMetadata["model"] != "test-model" { +- t.Errorf("Expected model name 'test-model', got '%s'", prompt.(api.Action).Desc().Metadata["model"]) +- } +- if promptMetadata["maxTurns"] != 5 { +- t.Errorf("Expected maxTurns set to 5, got: %d", promptMetadata["maxTurns"]) +- } +-} +- + func TestLoadPromptSnakeCase(t *testing.T) { + tempDir := t.TempDir() + mockPromptFile := filepath.Join(tempDir, "snake.prompt") +@@ -970,7 +908,7 @@ input: + } + + reg := registry.New() +- LoadPrompt(reg, tempDir, "snake.prompt", "snake-namespace") ++ LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "snake.prompt", "snake-namespace") + + prompt := LookupPrompt(reg, "snake-namespace/snake") + if prompt == nil { +@@ -1018,8 +956,9 @@ func TestLoadPrompt_FileNotFound(t *testing.T) { + // Initialize a mock registry + reg := registry.New() + +- // Call loadPrompt with a non-existent file +- LoadPrompt(reg, "./nonexistent", "missing.prompt", "test-namespace") ++ // Call loadPrompt with a non-existent file in a 
valid temp directory ++ tempDir := t.TempDir() ++ LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "missing.prompt", "test-namespace") + + // Verify that the prompt was not registered + prompt := LookupPrompt(reg, "missing") +@@ -1044,7 +983,7 @@ func TestLoadPrompt_InvalidPromptFile(t *testing.T) { + reg := registry.New() + + // Call loadPrompt +- LoadPrompt(reg, tempDir, "invalid.prompt", "test-namespace") ++ LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "invalid.prompt", "test-namespace") + + // Verify that the prompt was not registered + prompt := LookupPrompt(reg, "invalid") +@@ -1075,7 +1014,7 @@ Hello, {{name}}! + reg := registry.New() + + // Call loadPrompt +- LoadPrompt(reg, tempDir, "example.variant.prompt", "test-namespace") ++ LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "example.variant.prompt", "test-namespace") + + // Verify that the prompt was registered correctly + prompt := LookupPrompt(reg, "test-namespace/example.variant") +@@ -1096,6 +1035,50 @@ Hello, {{name}}! + } + } + ++func TestDefinePrompt_WithVariant(t *testing.T) { ++ reg := registry.New() ++ ++ DefinePrompt(reg, "example.code", WithPrompt("Hello, {{name}}!")) ++ ++ prompt := LookupPrompt(reg, "example.code") ++ if prompt == nil { ++ t.Fatalf("Prompt was not registered") ++ } ++ ++ promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any) ++ if !ok { ++ t.Fatalf("Expected Metadata['prompt'] to be a map") ++ } ++ if promptMetadata["name"] != "example" { ++ t.Errorf("Expected metadata name 'example', got '%s'", promptMetadata["name"]) ++ } ++ if promptMetadata["variant"] != "code" { ++ t.Errorf("Expected variant 'code', got '%v'", promptMetadata["variant"]) ++ } ++} ++ ++func TestDefinePrompt_WithoutVariant(t *testing.T) { ++ reg := registry.New() ++ ++ DefinePrompt(reg, "simple", WithPrompt("Hello, world!")) ++ ++ prompt := LookupPrompt(reg, "simple") ++ if prompt == nil { ++ t.Fatalf("Prompt was not registered") ++ } ++ ++ promptMetadata, ok := 
prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any) ++ if !ok { ++ t.Fatalf("Expected Metadata['prompt'] to be a map") ++ } ++ if promptMetadata["name"] != "simple" { ++ t.Errorf("Expected metadata name 'simple', got '%s'", promptMetadata["name"]) ++ } ++ if _, exists := promptMetadata["variant"]; exists { ++ t.Errorf("Expected no variant for prompt without dot, got '%v'", promptMetadata["variant"]) ++ } ++} ++ + func TestLoadPromptFolder(t *testing.T) { + // Create a temporary directory for testing + tempDir := t.TempDir() +@@ -1142,7 +1125,7 @@ Hello, {{name}}! + reg := registry.New() + + // Call LoadPromptFolder +- LoadPromptDir(reg, tempDir, "test-namespace") ++ LoadPromptDirFromFS(reg, os.DirFS(tempDir), ".", "test-namespace") + + // Verify that the prompt was registered correctly + prompt := LookupPrompt(reg, "test-namespace/example") +@@ -1157,19 +1140,298 @@ Hello, {{name}}! + } + } + +-func TestLoadPromptFolder_DirectoryNotFound(t *testing.T) { ++func TestLoadPromptFolder_EmptyDirectory(t *testing.T) { + // Initialize a mock registry +- reg := ®istry.Registry{} ++ reg := registry.New() ++ ++ // Create an empty temp directory ++ tempDir := t.TempDir() + +- // Call LoadPromptFolder with a non-existent directory +- LoadPromptDir(reg, "", "test-namespace") ++ // Call LoadPromptFolder with an empty directory ++ LoadPromptDirFromFS(reg, os.DirFS(tempDir), ".", "test-namespace") + + // Verify that no prompts were registered + if prompt := LookupPrompt(reg, "example"); prompt != nil { +- t.Fatalf("Prompt should not have been registered for a non-existent directory") ++ t.Fatalf("Prompt should not have been registered for an empty directory") ++ } ++} ++ ++func TestLoadPromptFS(t *testing.T) { ++ mockPromptContent := `--- ++model: test/chat ++description: A test prompt ++input: ++ schema: ++ type: object ++ properties: ++ name: ++ type: string ++output: ++ format: text ++ schema: ++ type: string ++--- ++ ++Hello, {{name}}! 
++` ++ mockPartialContent := `Welcome {{name}}!` ++ ++ fsys := fstest.MapFS{ ++ "prompts/example.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)}, ++ "prompts/sub/nested.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)}, ++ "prompts/_greeting.prompt": &fstest.MapFile{Data: []byte(mockPartialContent)}, ++ } ++ ++ reg := registry.New() ++ ++ LoadPromptDirFromFS(reg, fsys, "prompts", "test-namespace") ++ ++ prompt := LookupPrompt(reg, "test-namespace/example") ++ if prompt == nil { ++ t.Fatalf("Prompt 'test-namespace/example' was not registered") ++ } ++ ++ nestedPrompt := LookupPrompt(reg, "test-namespace/nested") ++ if nestedPrompt == nil { ++ t.Fatalf("Nested prompt 'test-namespace/nested' was not registered") ++ } ++} ++ ++func TestLoadPromptFS_WithVariant(t *testing.T) { ++ mockPromptContent := `--- ++model: test/chat ++description: A test prompt with variant ++--- ++ ++Hello from variant! ++` ++ ++ fsys := fstest.MapFS{ ++ "prompts/greeting.experimental.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)}, ++ } ++ ++ reg := registry.New() ++ ++ LoadPromptDirFromFS(reg, fsys, "prompts", "") ++ ++ prompt := LookupPrompt(reg, "greeting.experimental") ++ if prompt == nil { ++ t.Fatalf("Prompt with variant 'greeting.experimental' was not registered") ++ } ++} ++ ++func TestLoadPromptFS_NilFS(t *testing.T) { ++ reg := registry.New() ++ ++ defer func() { ++ if r := recover(); r == nil { ++ t.Errorf("Expected panic for nil filesystem") ++ } ++ }() ++ ++ LoadPromptDirFromFS(reg, nil, "prompts", "test-namespace") ++} ++ ++func TestLoadPromptFS_InvalidRoot(t *testing.T) { ++ fsys := fstest.MapFS{ ++ "other/example.prompt": &fstest.MapFile{Data: []byte("test")}, ++ } ++ ++ reg := registry.New() ++ ++ defer func() { ++ if r := recover(); r == nil { ++ t.Errorf("Expected panic for invalid root directory") ++ } ++ }() ++ ++ LoadPromptDirFromFS(reg, fsys, "nonexistent", "test-namespace") ++} ++ ++func TestLoadPromptFromFS(t *testing.T) { ++ 
mockPromptContent := `--- ++model: test/chat ++description: A single prompt test ++--- ++ ++Test content ++` ++ ++ fsys := fstest.MapFS{ ++ "prompts/single.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)}, ++ } ++ ++ reg := registry.New() ++ ++ prompt := LoadPromptFromFS(reg, fsys, "prompts", "single.prompt", "ns") ++ if prompt == nil { ++ t.Fatalf("LoadPromptFromFS failed to load prompt") ++ } ++ ++ lookedUp := LookupPrompt(reg, "ns/single") ++ if lookedUp == nil { ++ t.Fatalf("Prompt 'ns/single' was not registered") + } + } + ++func TestLoadPromptFromRaw(t *testing.T) { ++ t.Run("basic prompt", func(t *testing.T) { ++ reg := registry.New() ++ ++ source := `--- ++model: test/chat ++description: A raw prompt test ++input: ++ schema: ++ name: string ++--- ++Hello, {{name}}! ++` ++ prompt, err := LoadPromptFromSource(reg, source, "rawPrompt", "test-ns") ++ if err != nil { ++ t.Fatalf("LoadPromptFromRaw failed: %v", err) ++ } ++ if prompt == nil { ++ t.Fatal("LoadPromptFromRaw returned nil prompt") ++ } ++ ++ lookedUp := LookupPrompt(reg, "test-ns/rawPrompt") ++ if lookedUp == nil { ++ t.Fatal("Prompt 'test-ns/rawPrompt' was not registered") ++ } ++ ++ actionOpts, err := prompt.Render(context.Background(), map[string]any{"name": "World"}) ++ if err != nil { ++ t.Fatalf("Render failed: %v", err) ++ } ++ if len(actionOpts.Messages) == 0 { ++ t.Fatal("Expected messages to be rendered") ++ } ++ renderedText := actionOpts.Messages[0].Text() ++ if renderedText != "Hello, World!" 
{ ++ t.Errorf("Expected 'Hello, World!', got %q", renderedText) ++ } ++ }) ++ ++ t.Run("prompt with variant", func(t *testing.T) { ++ reg := registry.New() ++ ++ source := `--- ++model: test/chat ++description: A variant prompt ++--- ++Formal greeting ++` ++ prompt, err := LoadPromptFromSource(reg, source, "greeting.formal", "") ++ if err != nil { ++ t.Fatalf("LoadPromptFromRaw failed: %v", err) ++ } ++ if prompt == nil { ++ t.Fatal("LoadPromptFromRaw returned nil prompt") ++ } ++ ++ lookedUp := LookupPrompt(reg, "greeting.formal") ++ if lookedUp == nil { ++ t.Fatal("Prompt 'greeting.formal' was not registered") ++ } ++ ++ promptMetadata, ok := lookedUp.(api.Action).Desc().Metadata["prompt"].(map[string]any) ++ if !ok { ++ t.Fatal("Expected Metadata['prompt'] to be a map") ++ } ++ if promptMetadata["name"] != "greeting" { ++ t.Errorf("Expected metadata name 'greeting', got '%s'", promptMetadata["name"]) ++ } ++ if promptMetadata["variant"] != "formal" { ++ t.Errorf("Expected variant 'formal', got '%v'", promptMetadata["variant"]) ++ } ++ }) ++ ++ t.Run("prompt without namespace", func(t *testing.T) { ++ reg := registry.New() ++ ++ source := `--- ++model: test/chat ++--- ++Simple prompt ++` ++ prompt, err := LoadPromptFromSource(reg, source, "simple", "") ++ if err != nil { ++ t.Fatalf("LoadPromptFromRaw failed: %v", err) ++ } ++ if prompt == nil { ++ t.Fatal("LoadPromptFromRaw returned nil prompt") ++ } ++ ++ lookedUp := LookupPrompt(reg, "simple") ++ if lookedUp == nil { ++ t.Fatal("Prompt 'simple' was not registered") ++ } ++ }) ++ ++ t.Run("prompt with inline output schema", func(t *testing.T) { ++ reg := registry.New() ++ ConfigureFormats(reg) ++ ++ source := `--- ++model: test/chat ++output: ++ format: json ++ schema: ++ type: object ++ properties: ++ title: ++ type: string ++ description: ++ type: string ++ required: ++ - title ++ - description ++--- ++Generate something ++` ++ prompt, err := LoadPromptFromSource(reg, source, "outputSchemaPrompt", "") ++ if 
err != nil { ++ t.Fatalf("LoadPromptFromRaw failed: %v", err) ++ } ++ if prompt == nil { ++ t.Fatal("LoadPromptFromRaw returned nil prompt") ++ } ++ ++ actionOpts, err := prompt.Render(context.Background(), nil) ++ if err != nil { ++ t.Fatalf("Render failed: %v", err) ++ } ++ ++ // Verify that the output config is set correctly ++ if actionOpts.Output == nil { ++ t.Fatal("Expected Output config to be set") ++ } ++ if actionOpts.Output.Format != OutputFormatJSON { ++ t.Errorf("Expected output format 'json', got %q", actionOpts.Output.Format) ++ } ++ if actionOpts.Output.JsonSchema == nil { ++ t.Fatal("Expected output JsonSchema to be set for inline schema") ++ } ++ ++ // Verify the schema structure ++ schema := actionOpts.Output.JsonSchema ++ if schema["type"] != "object" { ++ t.Errorf("Expected schema type 'object', got %v", schema["type"]) ++ } ++ properties, ok := schema["properties"].(map[string]any) ++ if !ok { ++ t.Fatal("Expected schema properties to be a map") ++ } ++ if _, ok := properties["title"]; !ok { ++ t.Error("Expected schema to have 'title' property") ++ } ++ if _, ok := properties["description"]; !ok { ++ t.Error("Expected schema to have 'description' property") ++ } ++ }) ++} ++ + // TestDefinePartialAndHelperJourney demonstrates a complete user journey for defining + // and using both partials and helpers. + func TestDefinePartialAndHelper(t *testing.T) { +@@ -1230,72 +1492,34 @@ Hello! + ConfigureFormats(reg) + definePromptModel(reg) + +- prompt := LoadPrompt(reg, tempDir, "example.prompt", "multi-namespace") +- +- _, err = prompt.Execute(context.Background()) +- if err != nil { +- t.Fatalf("Failed to execute prompt: %v", err) +- } +-} +- +-func TestMultiMessagesRenderPrompt(t *testing.T) { +- tempDir := t.TempDir() +- +- mockPromptFile := filepath.Join(tempDir, "example.prompt") +- mockPromptContent := `--- +-model: test/chat +-description: A test prompt +---- +-<<>> +-You are a pirate! +- +-<<>> +-Hello! 
+-` +- +- if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644); err != nil { +- t.Fatalf("Failed to create mock prompt file: %v", err) +- } +- +- prompt := LoadPrompt(registry.New(), tempDir, "example.prompt", "multi-namespace-roles") ++ prompt := LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "example.prompt", "multi-namespace") + +- actionOpts, err := prompt.Render(context.Background(), map[string]any{}) ++ result, err := prompt.Execute(context.Background()) + if err != nil { + t.Fatalf("Failed to execute prompt: %v", err) + } + +- // Check that actionOpts is not nil +- if actionOpts == nil { +- t.Fatal("Expected actionOpts to be non-nil") +- } +- + // Check that we have exactly 2 messages (system and user) +- if len(actionOpts.Messages) != 2 { +- t.Fatalf("Expected 2 messages, got %d", len(actionOpts.Messages)) ++ if len(result.Request.Messages) != 2 { ++ t.Fatalf("Expected 2 messages, got %d", len(result.Request.Messages)) + } + + // Check first message (system role) +- systemMsg := actionOpts.Messages[0] ++ systemMsg := result.Request.Messages[0] + if systemMsg.Role != RoleSystem { + t.Errorf("Expected first message role to be 'system', got '%s'", systemMsg.Role) + } +- if len(systemMsg.Content) == 0 { +- t.Fatal("Expected system message to have content") +- } +- if strings.TrimSpace(systemMsg.Content[0].Text) != "You are a pirate!" { +- t.Errorf("Expected system message text to be 'You are a pirate!', got '%s'", systemMsg.Content[0].Text) ++ if strings.TrimSpace(systemMsg.Text()) != "You are a pirate!" 
{ ++ t.Errorf("Expected system message text to be 'You are a pirate!', got '%s'", systemMsg.Text()) + } + + // Check second message (user role) +- userMsg := actionOpts.Messages[1] ++ userMsg := result.Request.Messages[1] + if userMsg.Role != RoleUser { + t.Errorf("Expected second message role to be 'user', got '%s'", userMsg.Role) + } +- if len(userMsg.Content) == 0 { +- t.Fatal("Expected user message to have content") +- } +- if strings.TrimSpace(userMsg.Content[0].Text) != "Hello!" { +- t.Errorf("Expected user message text to be 'Hello!', got '%s'", userMsg.Content[0].Text) ++ if strings.TrimSpace(userMsg.Text()) != "Hello!" { ++ t.Errorf("Expected user message text to be 'Hello!', got '%s'", userMsg.Text()) + } + } + +@@ -1518,3 +1742,1243 @@ func TestWithOutputSchemaName_DefinePrompt_Missing(t *testing.T) { + t.Errorf("Expected error 'schema \"MissingSchema\" not found', got: %v", err) + } + } ++ ++func TestDataPromptExecute(t *testing.T) { ++ r := registry.New() ++ ConfigureFormats(r) ++ DefineGenerateAction(context.Background(), r) ++ ++ type GreetingInput struct { ++ Name string `json:"name"` ++ } ++ ++ type GreetingOutput struct { ++ Message string `json:"message"` ++ Count int `json:"count"` ++ } ++ ++ t.Run("typed input and output", func(t *testing.T) { ++ var capturedInput any ++ ++ testModel := DefineModel(r, "test/dataPromptModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ capturedInput = req.Messages[0].Text() ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewJSONPart(`{"message":"Hello, Alice!","count":1}`)}, ++ }, ++ }, nil ++ }) ++ ++ dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "greetingPrompt", ++ WithModel(testModel), ++ WithPrompt("Greet {{name}}"), ++ ) ++ ++ output, resp, err := 
dp.Execute(context.Background(), GreetingInput{Name: "Alice"}) ++ if err != nil { ++ t.Fatalf("Execute failed: %v", err) ++ } ++ ++ if capturedInput != "Greet Alice" { ++ t.Errorf("expected input %q, got %q", "Greet Alice", capturedInput) ++ } ++ ++ if output.Message != "Hello, Alice!" { ++ t.Errorf("expected message %q, got %q", "Hello, Alice!", output.Message) ++ } ++ if output.Count != 1 { ++ t.Errorf("expected count 1, got %d", output.Count) ++ } ++ if resp == nil { ++ t.Error("expected response to be returned") ++ } ++ }) ++ ++ t.Run("string output type", func(t *testing.T) { ++ testModel := DefineModel(r, "test/stringDataPromptModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("Hello, World!"), ++ }, nil ++ }) ++ ++ dp := DefineDataPrompt[GreetingInput, string](r, "stringOutputPrompt", ++ WithModel(testModel), ++ WithPrompt("Say hello to {{name}}"), ++ ) ++ ++ output, resp, err := dp.Execute(context.Background(), GreetingInput{Name: "World"}) ++ if err != nil { ++ t.Fatalf("Execute failed: %v", err) ++ } ++ ++ if output != "Hello, World!" 
{ ++ t.Errorf("expected output %q, got %q", "Hello, World!", output) ++ } ++ if resp == nil { ++ t.Error("expected response to be returned") ++ } ++ }) ++ ++ t.Run("nil prompt returns error", func(t *testing.T) { ++ var dp *DataPrompt[GreetingInput, GreetingOutput] ++ ++ _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}) ++ if err == nil { ++ t.Error("expected error for nil prompt") ++ } ++ }) ++ ++ t.Run("additional options passed through", func(t *testing.T) { ++ var capturedConfig any ++ ++ testModel := DefineModel(r, "test/optionsDataPromptModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ capturedConfig = req.Config ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewJSONPart(`{"message":"test","count":0}`)}, ++ }, ++ }, nil ++ }) ++ ++ dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "optionsPrompt", ++ WithModel(testModel), ++ WithPrompt("Test {{name}}"), ++ ) ++ ++ _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}, ++ WithConfig(&GenerationCommonConfig{Temperature: 0.5}), ++ ) ++ if err != nil { ++ t.Fatalf("Execute failed: %v", err) ++ } ++ ++ config, ok := capturedConfig.(*GenerationCommonConfig) ++ if !ok { ++ t.Fatalf("expected *GenerationCommonConfig, got %T", capturedConfig) ++ } ++ if config.Temperature != 0.5 { ++ t.Errorf("expected temperature 0.5, got %v", config.Temperature) ++ } ++ }) ++ ++ t.Run("returns error for invalid output parsing", func(t *testing.T) { ++ testModel := DefineModel(r, "test/parseFailDataPromptModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: 
req, ++ Message: NewModelTextMessage("not valid json"), ++ }, nil ++ }) ++ ++ dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "parseFailPrompt", ++ WithModel(testModel), ++ WithPrompt("Test {{name}}"), ++ ) ++ ++ _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}) ++ if err == nil { ++ t.Error("expected error for invalid JSON output") ++ } ++ }) ++} ++ ++func TestDataPromptExecuteStream(t *testing.T) { ++ r := registry.New() ++ ConfigureFormats(r) ++ DefineGenerateAction(context.Background(), r) ++ ++ type StreamInput struct { ++ Topic string `json:"topic"` ++ } ++ ++ type StreamOutput struct { ++ Text string `json:"text"` ++ Index int `json:"index"` ++ } ++ ++ t.Run("typed streaming with struct output", func(t *testing.T) { ++ testModel := DefineModel(r, "test/streamDataPromptModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ if cb != nil { ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewJSONPart(`{"text":"chunk1","index":1}`)}, ++ }) ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewJSONPart(`{"text":"final","index":99}`)}, ++ }) ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewJSONPart(`{"text":"final","index":99}`)}, ++ }, ++ }, nil ++ }) ++ ++ dp := DefineDataPrompt[StreamInput, StreamOutput](r, "streamPrompt", ++ WithModel(testModel), ++ WithPrompt("Stream about {{topic}}"), ++ ) ++ ++ var chunks []StreamOutput ++ var finalOutput StreamOutput ++ var finalResponse *ModelResponse ++ ++ for val, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "testing"}) { ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if val.Done { ++ finalOutput = val.Output ++ finalResponse = val.Response ++ } else { ++ chunks = append(chunks, val.Chunk) ++ } ++ } ++ ++ if 
len(chunks) < 1 { ++ t.Errorf("expected at least 1 chunk, got %d", len(chunks)) ++ } ++ ++ if finalOutput.Text != "final" || finalOutput.Index != 99 { ++ t.Errorf("expected final {final, 99}, got %+v", finalOutput) ++ } ++ if finalResponse == nil { ++ t.Error("expected final response") ++ } ++ }) ++ ++ t.Run("string output streaming", func(t *testing.T) { ++ testModel := DefineModel(r, "test/stringStreamDataPromptModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ if cb != nil { ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewTextPart("First ")}, ++ }) ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewTextPart("Second")}, ++ }) ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("First Second"), ++ }, nil ++ }) ++ ++ dp := DefineDataPrompt[StreamInput, string](r, "stringStreamPrompt", ++ WithModel(testModel), ++ WithPrompt("Generate text about {{topic}}"), ++ ) ++ ++ var chunks []string ++ var finalOutput string ++ ++ for val, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "strings"}) { ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if val.Done { ++ finalOutput = val.Output ++ } else { ++ chunks = append(chunks, val.Chunk) ++ } ++ } ++ ++ if len(chunks) != 2 { ++ t.Errorf("expected 2 chunks, got %d", len(chunks)) ++ } ++ if chunks[0] != "First " { ++ t.Errorf("chunk 0: expected %q, got %q", "First ", chunks[0]) ++ } ++ if chunks[1] != "Second" { ++ t.Errorf("chunk 1: expected %q, got %q", "Second", chunks[1]) ++ } ++ ++ if finalOutput != "First Second" { ++ t.Errorf("expected final %q, got %q", "First Second", finalOutput) ++ } ++ }) ++ ++ t.Run("nil prompt returns error", func(t *testing.T) { ++ var dp *DataPrompt[StreamInput, StreamOutput] ++ ++ var receivedErr error ++ for _, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "test"}) { ++ 
if err != nil { ++ receivedErr = err ++ break ++ } ++ } ++ ++ if receivedErr == nil { ++ t.Error("expected error for nil prompt") ++ } ++ }) ++ ++ t.Run("handles options passed at execute time", func(t *testing.T) { ++ var capturedConfig any ++ ++ testModel := DefineModel(r, "test/optionsStreamModel", &ModelOptions{ ++ Supports: &ModelSupports{ ++ Multiturn: true, ++ Constrained: ConstrainedSupportAll, ++ }, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ capturedConfig = req.Config ++ if cb != nil { ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewJSONPart(`{"text":"chunk","index":1}`)}, ++ }) ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewJSONPart(`{"text":"done","index":2}`)}, ++ }, ++ }, nil ++ }) ++ ++ dp := DefineDataPrompt[StreamInput, StreamOutput](r, "optionsStreamPrompt", ++ WithModel(testModel), ++ WithPrompt("Test {{topic}}"), ++ ) ++ ++ for range dp.ExecuteStream(context.Background(), StreamInput{Topic: "options"}, ++ WithConfig(&GenerationCommonConfig{Temperature: 0.7}), ++ ) { ++ } ++ ++ config, ok := capturedConfig.(*GenerationCommonConfig) ++ if !ok { ++ t.Fatalf("expected *GenerationCommonConfig, got %T", capturedConfig) ++ } ++ if config.Temperature != 0.7 { ++ t.Errorf("expected temperature 0.7, got %v", config.Temperature) ++ } ++ }) ++ ++ t.Run("propagates errors", func(t *testing.T) { ++ expectedErr := errors.New("stream failed") ++ ++ testModel := DefineModel(r, "test/errorStreamDataPromptModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return nil, expectedErr ++ }) ++ ++ dp := DefineDataPrompt[StreamInput, StreamOutput](r, "errorStreamPrompt", ++ WithModel(testModel), ++ WithPrompt("Test {{topic}}"), ++ ) ++ ++ var receivedErr error ++ for _, err := range dp.ExecuteStream(context.Background(), 
StreamInput{Topic: "error"}) { ++ if err != nil { ++ receivedErr = err ++ break ++ } ++ } ++ ++ if receivedErr == nil { ++ t.Error("expected error to be propagated") ++ } ++ if !errors.Is(receivedErr, expectedErr) { ++ t.Errorf("expected error %v, got %v", expectedErr, receivedErr) ++ } ++ }) ++} ++ ++func TestPromptExecuteStream(t *testing.T) { ++ r := registry.New() ++ ConfigureFormats(r) ++ DefineGenerateAction(context.Background(), r) ++ ++ t.Run("yields chunks then final response", func(t *testing.T) { ++ chunkTexts := []string{"A", "B", "C"} ++ ++ testModel := DefineModel(r, "test/promptStreamModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ if cb != nil { ++ for _, text := range chunkTexts { ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewTextPart(text)}, ++ }) ++ } ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("ABC"), ++ }, nil ++ }) ++ ++ p := DefinePrompt(r, "streamTestPrompt", ++ WithModel(testModel), ++ WithPrompt("Test"), ++ ) ++ ++ var chunks []*ModelResponseChunk ++ var finalResponse *ModelResponse ++ ++ for val, err := range p.ExecuteStream(context.Background()) { ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if val.Done { ++ finalResponse = val.Response ++ } else { ++ chunks = append(chunks, val.Chunk) ++ } ++ } ++ ++ if len(chunks) != 3 { ++ t.Errorf("expected 3 chunks, got %d", len(chunks)) ++ } ++ for i, chunk := range chunks { ++ if chunk.Text() != chunkTexts[i] { ++ t.Errorf("chunk %d: expected %q, got %q", i, chunkTexts[i], chunk.Text()) ++ } ++ } ++ ++ if finalResponse == nil { ++ t.Fatal("expected final response") ++ } ++ if finalResponse.Text() != "ABC" { ++ t.Errorf("expected final text %q, got %q", "ABC", finalResponse.Text()) ++ } ++ }) ++ ++ t.Run("nil prompt returns error", func(t *testing.T) { ++ var p *prompt ++ ++ var receivedErr error ++ for _, err 
:= range p.ExecuteStream(context.Background()) { ++ if err != nil { ++ receivedErr = err ++ break ++ } ++ } ++ ++ if receivedErr == nil { ++ t.Error("expected error for nil prompt") ++ } ++ }) ++ ++ t.Run("handles execution options", func(t *testing.T) { ++ var capturedConfig any ++ ++ testModel := DefineModel(r, "test/optionsPromptExecModel", &ModelOptions{ ++ Supports: &ModelSupports{Multiturn: true}, ++ }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ capturedConfig = req.Config ++ if cb != nil { ++ cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewTextPart("chunk")}, ++ }) ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("done"), ++ }, nil ++ }) ++ ++ p := DefinePrompt(r, "execOptionsTestPrompt", ++ WithModel(testModel), ++ WithPrompt("Test"), ++ ) ++ ++ for range p.ExecuteStream(context.Background(), ++ WithConfig(&GenerationCommonConfig{Temperature: 0.9}), ++ ) { ++ } ++ ++ config, ok := capturedConfig.(*GenerationCommonConfig) ++ if !ok { ++ t.Fatalf("expected *GenerationCommonConfig, got %T", capturedConfig) ++ } ++ if config.Temperature != 0.9 { ++ t.Errorf("expected temperature 0.9, got %v", config.Temperature) ++ } ++ }) ++} ++ ++// TestDefineExecuteOptionInteractions tests the complex interactions between ++// options set at DefinePrompt time vs Execute time. 
++func TestDefineExecuteOptionInteractions(t *testing.T) { ++ t.Run("ToolChoice override", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/toolChoiceModel", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ tool := defineFakeTool(t, r, "testTool", "a test tool") ++ ++ // Define with ToolChoiceAuto ++ p := DefinePrompt(r, "toolChoicePrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ WithTools(tool), ++ WithToolChoice(ToolChoiceAuto), ++ WithMaxTurns(1), ++ ) ++ ++ // Execute with ToolChoiceRequired - should override ++ _, err := p.Execute(context.Background(), ++ WithToolChoice(ToolChoiceRequired), ++ ) ++ assertNoError(t, err) ++ ++ if captured.ToolChoice != ToolChoiceRequired { ++ t.Errorf("ToolChoice = %q, want %q", captured.ToolChoice, ToolChoiceRequired) ++ } ++ }) ++ ++ t.Run("ToolChoice no override when not specified at execute", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/toolChoiceNoOverride", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ tool := defineFakeTool(t, r, "testTool2", "a test tool") ++ ++ // Define with ToolChoiceRequired ++ p := DefinePrompt(r, "toolChoiceNoOverridePrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ WithTools(tool), ++ WithToolChoice(ToolChoiceRequired), ++ WithMaxTurns(1), ++ ) ++ ++ // Execute without specifying ToolChoice - should use define-time value ++ _, err := p.Execute(context.Background()) ++ assertNoError(t, err) ++ ++ if captured.ToolChoice != ToolChoiceRequired { ++ t.Errorf("ToolChoice = %q, want %q", captured.ToolChoice, ToolChoiceRequired) ++ } ++ }) ++ ++ t.Run("MaxTurns override", func(t *testing.T) { ++ r := newTestRegistry(t) ++ callCount := 0 ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/maxTurnsModel", ++ handler: func(ctx context.Context, req 
*ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ callCount++ ++ // Always request tool call to test max turns ++ if callCount < 10 { ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewToolRequestPart(&ToolRequest{ ++ Name: "maxTurnsTool", ++ Input: map[string]any{"value": "test"}, ++ })}, ++ }, ++ }, nil ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("done"), ++ }, nil ++ }, ++ }) ++ ++ tool := defineFakeTool(t, r, "maxTurnsTool", "a tool for max turns test") ++ ++ // Define with MaxTurns 5 ++ p := DefinePrompt(r, "maxTurnsPrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ WithTools(tool), ++ WithMaxTurns(5), ++ ) ++ ++ // Execute with MaxTurns 2 - should override and stop after 2 turns ++ _, err := p.Execute(context.Background(), ++ WithMaxTurns(2), ++ ) ++ ++ // Should error due to max turns exceeded ++ if err == nil { ++ t.Error("expected max turns error, got nil") ++ } ++ // Call count should be limited by execute-time MaxTurns (2) + 1 for initial ++ if callCount > 3 { ++ t.Errorf("callCount = %d, expected <= 3 (limited by execute MaxTurns)", callCount) ++ } ++ }) ++ ++ t.Run("ReturnToolRequests override", func(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/returnToolReqsModel", ++ handler: func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewToolRequestPart(&ToolRequest{ ++ Name: "returnToolReqsTool", ++ Input: map[string]any{"value": "test"}, ++ })}, ++ }, ++ }, nil ++ }, ++ }) ++ ++ tool := defineFakeTool(t, r, "returnToolReqsTool", "tool for return requests test") ++ ++ // Define with ReturnToolRequests false (default) ++ p := DefinePrompt(r, "returnToolReqsPrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ WithTools(tool), ++ 
WithReturnToolRequests(false), ++ WithMaxTurns(1), ++ ) ++ ++ // Execute with ReturnToolRequests true - should override and return tool requests ++ resp, err := p.Execute(context.Background(), ++ WithReturnToolRequests(true), ++ ) ++ assertNoError(t, err) ++ ++ // Should have tool request in response ++ hasToolRequest := false ++ for _, part := range resp.Message.Content { ++ if part.IsToolRequest() { ++ hasToolRequest = true ++ break ++ } ++ } ++ if !hasToolRequest { ++ t.Error("expected tool request in response when ReturnToolRequests=true") ++ } ++ }) ++ ++ t.Run("Tools complete replacement", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/toolsReplaceModel", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ toolA := defineFakeTool(t, r, "toolA", "tool A") ++ toolB := defineFakeTool(t, r, "toolB", "tool B") ++ toolC := defineFakeTool(t, r, "toolC", "tool C") ++ ++ // Define with tools A and B ++ p := DefinePrompt(r, "toolsReplacePrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ WithTools(toolA, toolB), ++ WithMaxTurns(1), ++ ) ++ ++ // Execute with tool C - should REPLACE (not merge) define-time tools ++ _, err := p.Execute(context.Background(), ++ WithTools(toolC), ++ ) ++ assertNoError(t, err) ++ ++ // Should only have tool C ++ if len(captured.Tools) != 1 { ++ t.Errorf("len(Tools) = %d, want 1", len(captured.Tools)) ++ } ++ if len(captured.Tools) > 0 && captured.Tools[0].Name != "toolC" { ++ t.Errorf("Tool name = %q, want %q", captured.Tools[0].Name, "toolC") ++ } ++ }) ++ ++ t.Run("Tools inherit when not specified at execute", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/toolsInheritModel", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ toolA := defineFakeTool(t, r, "toolInheritA", "tool A") ++ toolB := defineFakeTool(t, r, 
"toolInheritB", "tool B") ++ ++ // Define with tools A and B ++ p := DefinePrompt(r, "toolsInheritPrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ WithTools(toolA, toolB), ++ WithMaxTurns(1), ++ ) ++ ++ // Execute without specifying tools - should inherit define-time tools ++ _, err := p.Execute(context.Background()) ++ assertNoError(t, err) ++ ++ if len(captured.Tools) != 2 { ++ t.Errorf("len(Tools) = %d, want 2", len(captured.Tools)) ++ } ++ }) ++ ++ t.Run("Docs at execute time", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/docsModel", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ // Define without docs ++ p := DefinePrompt(r, "docsPrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ ) ++ ++ // Execute with docs ++ doc := DocumentFromText("context document", nil) ++ _, err := p.Execute(context.Background(), ++ WithDocs(doc), ++ ) ++ assertNoError(t, err) ++ ++ if len(captured.Docs) != 1 { ++ t.Errorf("len(Docs) = %d, want 1", len(captured.Docs)) ++ } ++ }) ++ ++ t.Run("Config replacement not merge", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/configReplaceModel", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ // Define with Temperature and TopK ++ p := DefinePrompt(r, "configReplacePrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ WithConfig(&GenerationCommonConfig{Temperature: 0.5, TopK: 10}), ++ ) ++ ++ // Execute with only Temperature - config is REPLACED, not merged ++ _, err := p.Execute(context.Background(), ++ WithConfig(&GenerationCommonConfig{Temperature: 0.9}), ++ ) ++ assertNoError(t, err) ++ ++ config, ok := captured.Config.(*GenerationCommonConfig) ++ if !ok { ++ t.Fatalf("Config type = %T, want *GenerationCommonConfig", captured.Config) ++ } ++ if config.Temperature != 0.9 { ++ t.Errorf("Temperature = %v, want 
0.9", config.Temperature) ++ } ++ // TopK should be zero (default) since config was replaced ++ if config.TopK != 0 { ++ t.Errorf("TopK = %v, want 0 (config replaced, not merged)", config.TopK) ++ } ++ }) ++ ++ t.Run("Model override at execute time", func(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ defineModel := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/defineModel", ++ handler: func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("from define model"), ++ }, nil ++ }, ++ }) ++ ++ executeModel := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/executeModel", ++ handler: func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("from execute model"), ++ }, nil ++ }, ++ }) ++ ++ // Define with defineModel ++ p := DefinePrompt(r, "modelOverridePrompt", ++ WithModel(defineModel), ++ WithPrompt("test"), ++ ) ++ ++ // Execute with executeModel - should use execute model ++ resp, err := p.Execute(context.Background(), ++ WithModel(executeModel), ++ ) ++ assertNoError(t, err) ++ ++ if resp.Text() != "from execute model" { ++ t.Errorf("response = %q, want %q", resp.Text(), "from execute model") ++ } ++ }) ++ ++ t.Run("MessagesFn at execute time inserts between system and user", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/messagesFnModel", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ // Define with system and user prompt ++ p := DefinePrompt(r, "messagesFnPrompt", ++ WithModel(model), ++ WithSystem("system instruction"), ++ WithPrompt("user question"), ++ ) ++ ++ // Execute with MessagesFn - messages should be inserted between system and user ++ _, err := p.Execute(context.Background(), ++ 
WithMessages(NewModelTextMessage("conversation history")), ++ ) ++ assertNoError(t, err) ++ ++ // Expected order: system, MessagesFn content, user ++ if len(captured.Messages) != 3 { ++ t.Fatalf("len(Messages) = %d, want 3", len(captured.Messages)) ++ } ++ if captured.Messages[0].Role != RoleSystem { ++ t.Errorf("Messages[0].Role = %q, want %q", captured.Messages[0].Role, RoleSystem) ++ } ++ if captured.Messages[1].Role != RoleModel { ++ t.Errorf("Messages[1].Role = %q, want %q", captured.Messages[1].Role, RoleModel) ++ } ++ if captured.Messages[2].Role != RoleUser { ++ t.Errorf("Messages[2].Role = %q, want %q", captured.Messages[2].Role, RoleUser) ++ } ++ }) ++ ++ t.Run("ModelRef config used when no explicit config", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ // Define model first ++ defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/modelRefConfigModel", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ // Create ModelRef with embedded config ++ modelRef := NewModelRef("test/modelRefConfigModel", &GenerationCommonConfig{Temperature: 0.7}) ++ ++ p := DefinePrompt(r, "modelRefConfigPrompt", ++ WithModel(modelRef), ++ WithPrompt("test"), ++ ) ++ ++ // Execute without config - should use ModelRef's config ++ _, err := p.Execute(context.Background()) ++ assertNoError(t, err) ++ ++ config, ok := captured.Config.(*GenerationCommonConfig) ++ if !ok { ++ t.Fatalf("Config type = %T, want *GenerationCommonConfig", captured.Config) ++ } ++ if config.Temperature != 0.7 { ++ t.Errorf("Temperature = %v, want 0.7", config.Temperature) ++ } ++ }) ++ ++ t.Run("Explicit config overrides ModelRef config", func(t *testing.T) { ++ r := newTestRegistry(t) ++ var captured *ModelRequest ++ ++ defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/modelRefOverrideModel", ++ handler: capturingModelHandler(&captured), ++ }) ++ ++ modelRef := NewModelRef("test/modelRefOverrideModel", &GenerationCommonConfig{Temperature: 0.7}) ++ ++ p := 
DefinePrompt(r, "modelRefOverridePrompt", ++ WithModel(modelRef), ++ WithPrompt("test"), ++ ) ++ ++ // Execute with explicit config - should override ModelRef's config ++ _, err := p.Execute(context.Background(), ++ WithConfig(&GenerationCommonConfig{Temperature: 0.3}), ++ ) ++ assertNoError(t, err) ++ ++ config, ok := captured.Config.(*GenerationCommonConfig) ++ if !ok { ++ t.Fatalf("Config type = %T, want *GenerationCommonConfig", captured.Config) ++ } ++ if config.Temperature != 0.3 { ++ t.Errorf("Temperature = %v, want 0.3", config.Temperature) ++ } ++ }) ++} ++ ++// TestPromptErrorPaths tests error handling in prompt operations. ++func TestPromptErrorPaths(t *testing.T) { ++ t.Run("DefinePrompt with empty name panics", func(t *testing.T) { ++ r := newTestRegistry(t) ++ assertPanic(t, func() { ++ DefinePrompt(r, "") ++ }, "name is required") ++ }) ++ ++ t.Run("Execute on nil prompt returns error", func(t *testing.T) { ++ var p *prompt ++ _, err := p.Execute(context.Background()) ++ assertError(t, err, "prompt is nil") ++ }) ++ ++ t.Run("Render on nil prompt returns error", func(t *testing.T) { ++ var p *prompt ++ _, err := p.Render(context.Background(), nil) ++ assertError(t, err, "prompt is nil") ++ }) ++ ++ t.Run("ExecuteStream on nil prompt yields error", func(t *testing.T) { ++ var p *prompt ++ var gotErr error ++ for _, err := range p.ExecuteStream(context.Background()) { ++ if err != nil { ++ gotErr = err ++ break ++ } ++ } ++ assertError(t, gotErr, "prompt is nil") ++ }) ++ ++ t.Run("buildVariables with invalid type returns error", func(t *testing.T) { ++ // buildVariables expects struct, pointer to struct, or map ++ _, err := buildVariables(42) // int is not valid ++ if err == nil { ++ t.Error("expected error for invalid type, got nil") ++ } ++ }) ++} ++ ++// TestLookupPromptCoverage tests LookupPrompt edge cases. 
++func TestLookupPromptCoverage(t *testing.T) { ++ t.Run("returns nil for non-existent prompt", func(t *testing.T) { ++ r := newTestRegistry(t) ++ p := LookupPrompt(r, "nonexistent") ++ if p != nil { ++ t.Error("expected nil for non-existent prompt") ++ } ++ }) ++ ++ t.Run("returns prompt for existing prompt", func(t *testing.T) { ++ r := newTestRegistry(t) ++ DefinePrompt(r, "existingPrompt", WithPrompt("hello")) ++ p := LookupPrompt(r, "existingPrompt") ++ if p == nil { ++ t.Error("expected prompt, got nil") ++ } ++ if p.Name() != "existingPrompt" { ++ t.Errorf("Name() = %q, want %q", p.Name(), "existingPrompt") ++ } ++ }) ++} ++ ++// TestDataPromptRender tests DataPrompt.Render method. ++func TestDataPromptRender(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ type RenderInput struct { ++ Name string `json:"name"` ++ } ++ ++ type RenderOutput struct { ++ Greeting string `json:"greeting"` ++ } ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/renderModel", ++ }) ++ ++ dp := DefineDataPrompt[RenderInput, RenderOutput](r, "renderPrompt", ++ WithModel(model), ++ WithPrompt("Hello {{name}}"), ++ ) ++ ++ t.Run("renders with typed input", func(t *testing.T) { ++ opts, err := dp.Render(context.Background(), RenderInput{Name: "World"}) ++ assertNoError(t, err) ++ ++ if len(opts.Messages) == 0 { ++ t.Fatal("expected messages") ++ } ++ if opts.Messages[0].Text() != "Hello World" { ++ t.Errorf("rendered text = %q, want %q", opts.Messages[0].Text(), "Hello World") ++ } ++ }) ++ ++ t.Run("nil DataPrompt returns error", func(t *testing.T) { ++ var nilDP *DataPrompt[RenderInput, RenderOutput] ++ _, err := nilDP.Render(context.Background(), RenderInput{}) ++ if err == nil { ++ t.Error("expected error for nil DataPrompt") ++ } ++ }) ++} ++ ++// TestLookupDataPrompt tests LookupDataPrompt function. 
++func TestLookupDataPrompt(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/lookupDataModel", ++ }) ++ ++ DefinePrompt(r, "lookupDataPrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ ) ++ ++ t.Run("returns DataPrompt for existing prompt", func(t *testing.T) { ++ dp := LookupDataPrompt[map[string]any, string](r, "lookupDataPrompt") ++ if dp == nil { ++ t.Error("expected DataPrompt, got nil") ++ } ++ }) ++ ++ t.Run("returns nil for non-existent prompt", func(t *testing.T) { ++ dp := LookupDataPrompt[map[string]any, string](r, "nonexistent") ++ if dp != nil { ++ t.Error("expected nil for non-existent prompt") ++ } ++ }) ++} ++ ++// TestAsDataPrompt tests AsDataPrompt function. ++func TestAsDataPrompt(t *testing.T) { ++ r := newTestRegistry(t) ++ ++ model := defineFakeModel(t, r, fakeModelConfig{ ++ name: "test/asDataModel", ++ }) ++ ++ p := DefinePrompt(r, "asDataPrompt", ++ WithModel(model), ++ WithPrompt("test"), ++ ) ++ ++ t.Run("wraps existing prompt", func(t *testing.T) { ++ dp := AsDataPrompt[map[string]any, string](p) ++ if dp == nil { ++ t.Error("expected DataPrompt, got nil") ++ } ++ }) ++ ++ t.Run("returns nil for nil prompt", func(t *testing.T) { ++ dp := AsDataPrompt[map[string]any, string](nil) ++ if dp != nil { ++ t.Error("expected nil for nil prompt") ++ } ++ }) ++} ++ ++// TestPromptKeyVariantKey tests the prompt key generation helpers. 
++func TestPromptKeyVariantKey(t *testing.T) { ++ tests := []struct { ++ name string ++ promptName string ++ variant string ++ namespace string ++ want string ++ }{ ++ { ++ name: "simple name", ++ promptName: "greeting", ++ want: "greeting", ++ }, ++ { ++ name: "with variant", ++ promptName: "greeting", ++ variant: "formal", ++ want: "greeting.formal", ++ }, ++ { ++ name: "with namespace", ++ promptName: "greeting", ++ namespace: "myapp", ++ want: "myapp/greeting", ++ }, ++ { ++ name: "with namespace and variant", ++ promptName: "greeting", ++ variant: "formal", ++ namespace: "myapp", ++ want: "myapp/greeting.formal", ++ }, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := promptKey(tt.promptName, tt.variant, tt.namespace) ++ if got != tt.want { ++ t.Errorf("promptKey(%q, %q, %q) = %q, want %q", ++ tt.promptName, tt.variant, tt.namespace, got, tt.want) ++ } ++ }) ++ } ++} ++ ++// TestContentType tests the contentType helper function. ++func TestContentType(t *testing.T) { ++ tests := []struct { ++ name string ++ ct string ++ uri string ++ wantCT string ++ wantData string ++ wantErr bool ++ errContains string ++ }{ ++ { ++ name: "gs:// URL with content type", ++ ct: "image/png", ++ uri: "gs://bucket/image.png", ++ wantCT: "image/png", ++ wantData: "gs://bucket/image.png", ++ }, ++ { ++ name: "gs:// URL without content type", ++ ct: "", ++ uri: "gs://bucket/image.png", ++ wantErr: true, ++ errContains: "must supply contentType", ++ }, ++ { ++ name: "http URL with content type", ++ ct: "image/jpeg", ++ uri: "https://example.com/image.jpg", ++ wantCT: "image/jpeg", ++ wantData: "https://example.com/image.jpg", ++ }, ++ { ++ name: "http URL without content type", ++ ct: "", ++ uri: "https://example.com/image.jpg", ++ wantErr: true, ++ errContains: "must supply contentType", ++ }, ++ { ++ name: "data URI with base64", ++ ct: "", ++ uri: "data:image/png;base64,iVBORw0KGgo=", ++ wantCT: "image/png", ++ wantData: 
"data:image/png;base64,iVBORw0KGgo=", ++ }, ++ { ++ name: "data URI with explicit content type override", ++ ct: "image/jpeg", ++ uri: "data:image/png;base64,iVBORw0KGgo=", ++ wantCT: "image/jpeg", ++ wantData: "data:image/png;base64,iVBORw0KGgo=", ++ }, ++ { ++ name: "empty URI", ++ ct: "image/png", ++ uri: "", ++ wantErr: true, ++ errContains: "found empty URI", ++ }, ++ { ++ name: "malformed data URI", ++ ct: "", ++ uri: "data:image/png", ++ wantErr: true, ++ errContains: "missing comma", ++ }, ++ { ++ name: "unknown URI scheme", ++ ct: "", ++ uri: "file:///path/to/file", ++ wantErr: true, ++ errContains: "uri content type not found", ++ }, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ gotCT, gotData, err := contentType(tt.ct, tt.uri) ++ ++ if tt.wantErr { ++ if err == nil { ++ t.Errorf("expected error containing %q, got nil", tt.errContains) ++ } else if !strings.Contains(err.Error(), tt.errContains) { ++ t.Errorf("error = %q, want containing %q", err.Error(), tt.errContains) ++ } ++ return ++ } ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if gotCT != tt.wantCT { ++ t.Errorf("contentType = %q, want %q", gotCT, tt.wantCT) ++ } ++ if string(gotData) != tt.wantData { ++ t.Errorf("data = %q, want %q", string(gotData), tt.wantData) ++ } ++ }) ++ } ++} ++ ++// TestDefineDataPromptPanics tests panic conditions in DefineDataPrompt. 
++func TestDefineDataPromptPanics(t *testing.T) { ++ t.Run("empty name panics", func(t *testing.T) { ++ r := newTestRegistry(t) ++ assertPanic(t, func() { ++ DefineDataPrompt[map[string]any, string](r, "") ++ }, "name is required") ++ }) ++} +diff --git a/go/ai/request_helpers_test.go b/go/ai/request_helpers_test.go +new file mode 100644 +index 000000000..1b1c80d4c +--- /dev/null ++++ b/go/ai/request_helpers_test.go +@@ -0,0 +1,371 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package ai ++ ++import ( ++ "testing" ++ ++ "github.com/google/go-cmp/cmp" ++) ++ ++func TestNewModelRequest(t *testing.T) { ++ t.Run("creates request with config and messages", func(t *testing.T) { ++ config := &GenerationCommonConfig{Temperature: 0.7} ++ msg1 := NewUserTextMessage("hello") ++ msg2 := NewModelTextMessage("hi there") ++ ++ req := NewModelRequest(config, msg1, msg2) ++ ++ if req.Config != config { ++ t.Error("Config not set correctly") ++ } ++ if len(req.Messages) != 2 { ++ t.Errorf("len(Messages) = %d, want 2", len(req.Messages)) ++ } ++ if req.Messages[0] != msg1 { ++ t.Error("First message not set correctly") ++ } ++ if req.Messages[1] != msg2 { ++ t.Error("Second message not set correctly") ++ } ++ }) ++ ++ t.Run("creates request with nil config", func(t *testing.T) { ++ msg := NewUserTextMessage("hello") ++ req := NewModelRequest(nil, msg) ++ ++ if req.Config != nil { ++ t.Errorf("Config = %v, want nil", req.Config) ++ } ++ if len(req.Messages) != 1 { ++ t.Errorf("len(Messages) = %d, want 1", len(req.Messages)) ++ } ++ }) ++ ++ t.Run("creates request with no messages", func(t *testing.T) { ++ config := map[string]any{"temp": 0.5} ++ req := NewModelRequest(config) ++ ++ if req.Config == nil { ++ t.Error("Config should not be nil") ++ } ++ if len(req.Messages) != 0 { ++ t.Errorf("len(Messages) = %d, want 0", len(req.Messages)) ++ } ++ }) ++} ++ ++func TestNewUserMessage(t *testing.T) { ++ t.Run("creates user message with parts", func(t *testing.T) { ++ parts := []*Part{NewTextPart("text"), NewMediaPart("image/png", "data:...")} ++ msg := NewUserMessage(parts...) 
++ ++ if msg.Role != RoleUser { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleUser) ++ } ++ if len(msg.Content) != 2 { ++ t.Errorf("len(Content) = %d, want 2", len(msg.Content)) ++ } ++ if msg.Metadata != nil { ++ t.Errorf("Metadata = %v, want nil", msg.Metadata) ++ } ++ }) ++ ++ t.Run("creates user message with no parts", func(t *testing.T) { ++ msg := NewUserMessage() ++ ++ if msg.Role != RoleUser { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleUser) ++ } ++ if len(msg.Content) != 0 { ++ t.Errorf("len(Content) = %d, want 0", len(msg.Content)) ++ } ++ }) ++} ++ ++func TestNewUserMessageWithMetadata(t *testing.T) { ++ t.Run("creates user message with metadata", func(t *testing.T) { ++ metadata := map[string]any{"purpose": "context"} ++ parts := []*Part{NewTextPart("text")} ++ msg := NewUserMessageWithMetadata(metadata, parts...) ++ ++ if msg.Role != RoleUser { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleUser) ++ } ++ if diff := cmp.Diff(metadata, msg.Metadata); diff != "" { ++ t.Errorf("Metadata mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("creates user message with nil metadata", func(t *testing.T) { ++ msg := NewUserMessageWithMetadata(nil, NewTextPart("text")) ++ ++ if msg.Role != RoleUser { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleUser) ++ } ++ if msg.Metadata != nil { ++ t.Errorf("Metadata = %v, want nil", msg.Metadata) ++ } ++ }) ++} ++ ++func TestNewUserTextMessage(t *testing.T) { ++ t.Run("creates text message with user role", func(t *testing.T) { ++ msg := NewUserTextMessage("hello world") ++ ++ if msg.Role != RoleUser { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleUser) ++ } ++ if len(msg.Content) != 1 { ++ t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) ++ } ++ if msg.Content[0].Text != "hello world" { ++ t.Errorf("Text = %q, want %q", msg.Content[0].Text, "hello world") ++ } ++ }) ++ ++ t.Run("creates text message with empty string", func(t *testing.T) { ++ msg := NewUserTextMessage("") ++ ++ if msg.Role != 
RoleUser { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleUser) ++ } ++ if len(msg.Content) != 1 { ++ t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) ++ } ++ if msg.Content[0].Text != "" { ++ t.Errorf("Text = %q, want empty string", msg.Content[0].Text) ++ } ++ }) ++} ++ ++func TestNewModelMessage(t *testing.T) { ++ t.Run("creates model message with parts", func(t *testing.T) { ++ parts := []*Part{NewTextPart("response")} ++ msg := NewModelMessage(parts...) ++ ++ if msg.Role != RoleModel { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleModel) ++ } ++ if len(msg.Content) != 1 { ++ t.Errorf("len(Content) = %d, want 1", len(msg.Content)) ++ } ++ }) ++} ++ ++func TestNewModelTextMessage(t *testing.T) { ++ t.Run("creates text message with model role", func(t *testing.T) { ++ msg := NewModelTextMessage("model response") ++ ++ if msg.Role != RoleModel { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleModel) ++ } ++ if len(msg.Content) != 1 { ++ t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) ++ } ++ if msg.Content[0].Text != "model response" { ++ t.Errorf("Text = %q, want %q", msg.Content[0].Text, "model response") ++ } ++ }) ++} ++ ++func TestNewSystemMessage(t *testing.T) { ++ t.Run("creates system message with parts", func(t *testing.T) { ++ parts := []*Part{NewTextPart("system instruction")} ++ msg := NewSystemMessage(parts...) 
++ ++ if msg.Role != RoleSystem { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleSystem) ++ } ++ if len(msg.Content) != 1 { ++ t.Errorf("len(Content) = %d, want 1", len(msg.Content)) ++ } ++ }) ++} ++ ++func TestNewSystemTextMessage(t *testing.T) { ++ t.Run("creates text message with system role", func(t *testing.T) { ++ msg := NewSystemTextMessage("be helpful") ++ ++ if msg.Role != RoleSystem { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleSystem) ++ } ++ if len(msg.Content) != 1 { ++ t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) ++ } ++ if msg.Content[0].Text != "be helpful" { ++ t.Errorf("Text = %q, want %q", msg.Content[0].Text, "be helpful") ++ } ++ }) ++} ++ ++func TestNewMessage(t *testing.T) { ++ t.Run("creates message with all fields", func(t *testing.T) { ++ metadata := map[string]any{"key": "value"} ++ parts := []*Part{NewTextPart("content")} ++ msg := NewMessage(RoleTool, metadata, parts...) ++ ++ if msg.Role != RoleTool { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleTool) ++ } ++ if diff := cmp.Diff(metadata, msg.Metadata); diff != "" { ++ t.Errorf("Metadata mismatch (-want +got):\n%s", diff) ++ } ++ if len(msg.Content) != 1 { ++ t.Errorf("len(Content) = %d, want 1", len(msg.Content)) ++ } ++ }) ++} ++ ++func TestNewTextMessage(t *testing.T) { ++ t.Run("creates text message with specified role", func(t *testing.T) { ++ msg := NewTextMessage(RoleTool, "tool output") ++ ++ if msg.Role != RoleTool { ++ t.Errorf("Role = %q, want %q", msg.Role, RoleTool) ++ } ++ if len(msg.Content) != 1 { ++ t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) ++ } ++ if msg.Content[0].Text != "tool output" { ++ t.Errorf("Text = %q, want %q", msg.Content[0].Text, "tool output") ++ } ++ }) ++} ++ ++func TestWithCacheTTL(t *testing.T) { ++ t.Run("adds cache TTL to message without existing metadata", func(t *testing.T) { ++ original := NewUserTextMessage("hello") ++ result := original.WithCacheTTL(3600) ++ ++ // Original should be unchanged ++ if 
original.Metadata != nil { ++ t.Error("original message metadata should be nil") ++ } ++ ++ // Result should have cache metadata ++ if result.Metadata == nil { ++ t.Fatal("result metadata should not be nil") ++ } ++ cache, ok := result.Metadata["cache"].(map[string]any) ++ if !ok { ++ t.Fatal("cache metadata not found or wrong type") ++ } ++ if cache["ttlSeconds"] != 3600 { ++ t.Errorf("ttlSeconds = %v, want 3600", cache["ttlSeconds"]) ++ } ++ ++ // Content and role should be preserved ++ if result.Role != original.Role { ++ t.Errorf("Role changed: got %q, want %q", result.Role, original.Role) ++ } ++ if len(result.Content) != len(original.Content) { ++ t.Errorf("Content length changed") ++ } ++ }) ++ ++ t.Run("adds cache TTL to message with existing metadata", func(t *testing.T) { ++ original := NewUserMessageWithMetadata( ++ map[string]any{"existing": "value"}, ++ NewTextPart("hello"), ++ ) ++ result := original.WithCacheTTL(1800) ++ ++ // Result should have both existing and cache metadata ++ if result.Metadata["existing"] != "value" { ++ t.Error("existing metadata not preserved") ++ } ++ cache, ok := result.Metadata["cache"].(map[string]any) ++ if !ok { ++ t.Fatal("cache metadata not found") ++ } ++ if cache["ttlSeconds"] != 1800 { ++ t.Errorf("ttlSeconds = %v, want 1800", cache["ttlSeconds"]) ++ } ++ }) ++ ++ t.Run("chained with WithCacheName", func(t *testing.T) { ++ msg := NewUserTextMessage("hello"). ++ WithCacheTTL(3600). 
++ WithCacheName("my-cache") ++ ++ cache, ok := msg.Metadata["cache"].(map[string]any) ++ if !ok { ++ t.Fatal("cache metadata not found") ++ } ++ // Note: second call overwrites the cache object ++ if cache["name"] != "my-cache" { ++ t.Errorf("cache name = %v, want %q", cache["name"], "my-cache") ++ } ++ }) ++} ++ ++func TestWithCacheName(t *testing.T) { ++ t.Run("adds cache name to message without existing metadata", func(t *testing.T) { ++ original := NewUserTextMessage("hello") ++ result := original.WithCacheName("my-cache") ++ ++ // Original should be unchanged ++ if original.Metadata != nil { ++ t.Error("original message metadata should be nil") ++ } ++ ++ // Result should have cache metadata ++ if result.Metadata == nil { ++ t.Fatal("result metadata should not be nil") ++ } ++ cache, ok := result.Metadata["cache"].(map[string]any) ++ if !ok { ++ t.Fatal("cache metadata not found or wrong type") ++ } ++ if cache["name"] != "my-cache" { ++ t.Errorf("name = %v, want %q", cache["name"], "my-cache") ++ } ++ }) ++ ++ t.Run("adds cache name to message with existing metadata", func(t *testing.T) { ++ original := NewUserMessageWithMetadata( ++ map[string]any{"existing": "value"}, ++ NewTextPart("hello"), ++ ) ++ result := original.WithCacheName("another-cache") ++ ++ // Result should have both existing and cache metadata ++ if result.Metadata["existing"] != "value" { ++ t.Error("existing metadata not preserved") ++ } ++ cache, ok := result.Metadata["cache"].(map[string]any) ++ if !ok { ++ t.Fatal("cache metadata not found") ++ } ++ if cache["name"] != "another-cache" { ++ t.Errorf("name = %v, want %q", cache["name"], "another-cache") ++ } ++ }) ++ ++ t.Run("with empty name", func(t *testing.T) { ++ msg := NewUserTextMessage("hello").WithCacheName("") ++ ++ cache, ok := msg.Metadata["cache"].(map[string]any) ++ if !ok { ++ t.Fatal("cache metadata not found") ++ } ++ if cache["name"] != "" { ++ t.Errorf("name = %v, want empty string", cache["name"]) ++ } ++ }) ++} +diff 
--git a/go/ai/resource_test.go b/go/ai/resource_test.go +index e7b56b65e..30f15b807 100644 +--- a/go/ai/resource_test.go ++++ b/go/ai/resource_test.go +@@ -272,3 +272,85 @@ func TestMultipleDynamicResourcesInGeneration(t *testing.T) { + func contains(s, substr string) bool { + return strings.Contains(s, substr) + } ++ ++func TestLookupResource(t *testing.T) { ++ t.Run("finds registered resource", func(t *testing.T) { ++ r := registry.New() ++ DefineResource(r, "test/lookup", &ResourceOptions{ ++ URI: "lookup://test", ++ }, func(ctx context.Context, input *ResourceInput) (*ResourceOutput, error) { ++ return &ResourceOutput{ ++ Content: []*Part{NewTextPart("found")}, ++ }, nil ++ }) ++ ++ found := LookupResource(r, "test/lookup") ++ if found == nil { ++ t.Fatal("LookupResource returned nil") ++ } ++ if found.Name() != "test/lookup" { ++ t.Errorf("Name() = %q, want %q", found.Name(), "test/lookup") ++ } ++ }) ++ ++ t.Run("returns nil for non-existent resource", func(t *testing.T) { ++ r := registry.New() ++ ++ found := LookupResource(r, "test/nonexistent") ++ if found != nil { ++ t.Errorf("LookupResource returned %v, want nil", found) ++ } ++ }) ++ ++ t.Run("resource can be executed after lookup", func(t *testing.T) { ++ r := registry.New() ++ DefineResource(r, "test/executable", &ResourceOptions{ ++ URI: "exec://test", ++ }, func(ctx context.Context, input *ResourceInput) (*ResourceOutput, error) { ++ return &ResourceOutput{ ++ Content: []*Part{NewTextPart("executed: " + input.URI)}, ++ }, nil ++ }) ++ ++ found := LookupResource(r, "test/executable") ++ if found == nil { ++ t.Fatal("LookupResource returned nil") ++ } ++ ++ output, err := found.Execute(context.Background(), &ResourceInput{URI: "exec://test", Variables: map[string]string{}}) ++ if err != nil { ++ t.Fatalf("Execute error: %v", err) ++ } ++ if len(output.Content) != 1 || output.Content[0].Text != "executed: exec://test" { ++ t.Errorf("unexpected output: %v", output.Content) ++ } ++ }) ++ ++ 
t.Run("resource matches and extracts variables after lookup", func(t *testing.T) { ++ r := registry.New() ++ DefineResource(r, "test/template", &ResourceOptions{ ++ Template: "template://item/{id}", ++ }, func(ctx context.Context, input *ResourceInput) (*ResourceOutput, error) { ++ return &ResourceOutput{ ++ Content: []*Part{NewTextPart("item " + input.Variables["id"])}, ++ }, nil ++ }) ++ ++ found := LookupResource(r, "test/template") ++ if found == nil { ++ t.Fatal("LookupResource returned nil") ++ } ++ ++ if !found.Matches("template://item/123") { ++ t.Error("Matches() = false, want true") ++ } ++ ++ vars, err := found.ExtractVariables("template://item/456") ++ if err != nil { ++ t.Fatalf("ExtractVariables error: %v", err) ++ } ++ if vars["id"] != "456" { ++ t.Errorf("vars[id] = %q, want %q", vars["id"], "456") ++ } ++ }) ++} +diff --git a/go/ai/retriever_test.go b/go/ai/retriever_test.go +new file mode 100644 +index 000000000..ecbf019f7 +--- /dev/null ++++ b/go/ai/retriever_test.go +@@ -0,0 +1,407 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package ai ++ ++import ( ++ "context" ++ "errors" ++ "testing" ++ ++ "github.com/google/go-cmp/cmp" ++) ++ ++func TestRetrieverRef(t *testing.T) { ++ t.Run("NewRetrieverRef creates ref with name and config", func(t *testing.T) { ++ config := map[string]any{"topK": 10} ++ ref := NewRetrieverRef("test/retriever", config) ++ ++ if ref.Name() != "test/retriever" { ++ t.Errorf("Name() = %q, want %q", ref.Name(), "test/retriever") ++ } ++ if diff := cmp.Diff(config, ref.Config()); diff != "" { ++ t.Errorf("Config() mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("NewRetrieverRef with nil config", func(t *testing.T) { ++ ref := NewRetrieverRef("test/retriever", nil) ++ ++ if ref.Name() != "test/retriever" { ++ t.Errorf("Name() = %q, want %q", ref.Name(), "test/retriever") ++ } ++ if ref.Config() != nil { ++ t.Errorf("Config() = %v, want nil", ref.Config()) ++ } ++ }) ++} ++ ++func TestNewRetriever(t *testing.T) { ++ t.Run("creates retriever with valid name", func(t *testing.T) { ++ r := NewRetriever("test/retriever", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{}, nil ++ }) ++ ++ if r == nil { ++ t.Fatal("expected retriever, got nil") ++ } ++ if r.Name() != "test/retriever" { ++ t.Errorf("Name() = %q, want %q", r.Name(), "test/retriever") ++ } ++ }) ++ ++ t.Run("panics with empty name", func(t *testing.T) { ++ assertPanic(t, func() { ++ NewRetriever("", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{}, nil ++ }) ++ }, "name is required") ++ }) ++ ++ t.Run("applies options correctly", func(t *testing.T) { ++ opts := &RetrieverOptions{ ++ Label: "Test Retriever", ++ Supports: &RetrieverSupports{ ++ Media: true, ++ }, ++ ConfigSchema: map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "topK": map[string]any{"type": "integer"}, ++ }, ++ }, ++ } ++ ++ r := 
NewRetriever("test/retriever", opts, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{}, nil ++ }) ++ ++ if r == nil { ++ t.Fatal("expected retriever, got nil") ++ } ++ }) ++ ++ t.Run("uses defaults when options nil", func(t *testing.T) { ++ r := NewRetriever("test/retriever", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{}, nil ++ }) ++ ++ if r == nil { ++ t.Fatal("expected retriever, got nil") ++ } ++ }) ++} ++ ++func TestDefineRetriever(t *testing.T) { ++ t.Run("registers and returns retriever", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ called := false ++ expectedDocs := []*Document{ ++ DocumentFromText("result 1", nil), ++ DocumentFromText("result 2", nil), ++ } ++ ++ r := DefineRetriever(reg, "test/defineRetriever", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ called = true ++ return &RetrieverResponse{Documents: expectedDocs}, nil ++ }) ++ ++ if r == nil { ++ t.Fatal("expected retriever, got nil") ++ } ++ ++ // Verify it's registered by looking it up ++ found := LookupRetriever(reg, "test/defineRetriever") ++ if found == nil { ++ t.Fatal("LookupRetriever returned nil for registered retriever") ++ } ++ ++ // Verify the function works ++ resp, err := r.Retrieve(context.Background(), &RetrieverRequest{ ++ Query: DocumentFromText("query", nil), ++ }) ++ assertNoError(t, err) ++ if !called { ++ t.Error("retriever function was not called") ++ } ++ if len(resp.Documents) != 2 { ++ t.Errorf("len(Documents) = %d, want 2", len(resp.Documents)) ++ } ++ }) ++} ++ ++func TestLookupRetriever(t *testing.T) { ++ t.Run("returns retriever when found", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ DefineRetriever(reg, "test/lookupRetriever", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{}, nil ++ }) ++ ++ r := 
LookupRetriever(reg, "test/lookupRetriever") ++ if r == nil { ++ t.Error("expected retriever, got nil") ++ } ++ }) ++ ++ t.Run("returns nil when not found", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ ++ r := LookupRetriever(reg, "nonexistent") ++ if r != nil { ++ t.Error("expected nil for non-existent retriever") ++ } ++ }) ++} ++ ++func TestRetrieverRetrieve(t *testing.T) { ++ t.Run("retrieves documents successfully", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ var capturedReq *RetrieverRequest ++ ++ expectedDocs := []*Document{ ++ DocumentFromText("relevant result 1", map[string]any{"score": 0.9}), ++ DocumentFromText("relevant result 2", map[string]any{"score": 0.8}), ++ } ++ ++ r := DefineRetriever(reg, "test/retrieveDocs", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ capturedReq = req ++ return &RetrieverResponse{Documents: expectedDocs}, nil ++ }) ++ ++ query := DocumentFromText("search query", nil) ++ resp, err := r.Retrieve(context.Background(), &RetrieverRequest{Query: query}) ++ assertNoError(t, err) ++ ++ if len(capturedReq.Query.Content) == 0 || capturedReq.Query.Content[0].Text != "search query" { ++ t.Errorf("captured query content mismatch") ++ } ++ if len(resp.Documents) != 2 { ++ t.Errorf("len(Documents) = %d, want 2", len(resp.Documents)) ++ } ++ }) ++ ++ t.Run("returns error on nil retriever", func(t *testing.T) { ++ var r *retriever ++ _, err := r.Retrieve(context.Background(), &RetrieverRequest{}) ++ if err == nil { ++ t.Error("expected error for nil retriever") ++ } ++ }) ++ ++ t.Run("propagates function errors", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ expectedErr := errors.New("retrieval failed") ++ ++ r := DefineRetriever(reg, "test/retrieveError", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return nil, expectedErr ++ }) ++ ++ _, err := r.Retrieve(context.Background(), &RetrieverRequest{ ++ Query: DocumentFromText("query", 
nil), ++ }) ++ if err == nil { ++ t.Error("expected error, got nil") ++ } ++ }) ++ ++ t.Run("passes options through request", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ var capturedOpts any ++ ++ r := DefineRetriever(reg, "test/retrieveOpts", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ capturedOpts = req.Options ++ return &RetrieverResponse{Documents: []*Document{}}, nil ++ }) ++ ++ opts := map[string]any{"topK": 5, "threshold": 0.7} ++ _, err := r.Retrieve(context.Background(), &RetrieverRequest{ ++ Query: DocumentFromText("query", nil), ++ Options: opts, ++ }) ++ assertNoError(t, err) ++ ++ if diff := cmp.Diff(opts, capturedOpts); diff != "" { ++ t.Errorf("Options mismatch (-want +got):\n%s", diff) ++ } ++ }) ++} ++ ++func TestRetrieveFunction(t *testing.T) { ++ t.Run("retrieves with retriever directly", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ r := DefineRetriever(reg, "test/retrieveFunc", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{ ++ Documents: []*Document{DocumentFromText("result", nil)}, ++ }, nil ++ }) ++ ++ resp, err := Retrieve(context.Background(), reg, ++ WithRetriever(r), ++ WithTextDocs("query"), ++ ) ++ assertNoError(t, err) ++ ++ if len(resp.Documents) != 1 { ++ t.Errorf("len(Documents) = %d, want 1", len(resp.Documents)) ++ } ++ }) ++ ++ t.Run("retrieves with retriever ref", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ DefineRetriever(reg, "test/retrieveFuncRef", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{ ++ Documents: []*Document{DocumentFromText("result", nil)}, ++ }, nil ++ }) ++ ++ ref := NewRetrieverRef("test/retrieveFuncRef", nil) ++ resp, err := Retrieve(context.Background(), reg, ++ WithRetriever(ref), ++ WithTextDocs("query"), ++ ) ++ assertNoError(t, err) ++ ++ if len(resp.Documents) != 1 { ++ t.Errorf("len(Documents) = %d, want 
1", len(resp.Documents)) ++ } ++ }) ++ ++ t.Run("retrieves with retriever name", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ DefineRetriever(reg, "test/retrieveFuncName", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{ ++ Documents: []*Document{DocumentFromText("result", nil)}, ++ }, nil ++ }) ++ ++ resp, err := Retrieve(context.Background(), reg, ++ WithRetrieverName("test/retrieveFuncName"), ++ WithTextDocs("query"), ++ ) ++ assertNoError(t, err) ++ ++ if len(resp.Documents) != 1 { ++ t.Errorf("len(Documents) = %d, want 1", len(resp.Documents)) ++ } ++ }) ++ ++ t.Run("uses config from RetrieverRef", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ var capturedOpts any ++ ++ DefineRetriever(reg, "test/retrieveRefConfig", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ capturedOpts = req.Options ++ return &RetrieverResponse{Documents: []*Document{}}, nil ++ }) ++ ++ config := map[string]any{"topK": 10} ++ ref := NewRetrieverRef("test/retrieveRefConfig", config) ++ ++ _, err := Retrieve(context.Background(), reg, ++ WithRetriever(ref), ++ WithTextDocs("query"), ++ ) ++ assertNoError(t, err) ++ ++ if diff := cmp.Diff(config, capturedOpts); diff != "" { ++ t.Errorf("Options mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("explicit config overrides RetrieverRef config", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ var capturedOpts any ++ ++ DefineRetriever(reg, "test/retrieveOverrideConfig", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ capturedOpts = req.Options ++ return &RetrieverResponse{Documents: []*Document{}}, nil ++ }) ++ ++ refConfig := map[string]any{"topK": 10} ++ explicitConfig := map[string]any{"topK": 5} ++ ref := NewRetrieverRef("test/retrieveOverrideConfig", refConfig) ++ ++ _, err := Retrieve(context.Background(), reg, ++ WithRetriever(ref), ++ WithConfig(explicitConfig), ++ 
WithTextDocs("query"), ++ ) ++ assertNoError(t, err) ++ ++ if diff := cmp.Diff(explicitConfig, capturedOpts); diff != "" { ++ t.Errorf("Options mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("returns error when retriever not set", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ ++ _, err := Retrieve(context.Background(), reg, ++ WithTextDocs("query"), ++ ) ++ assertError(t, err, "retriever must be set") ++ }) ++ ++ t.Run("returns error when retriever not found", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ ++ _, err := Retrieve(context.Background(), reg, ++ WithRetrieverName("nonexistent"), ++ WithTextDocs("query"), ++ ) ++ assertError(t, err, "retriever not found") ++ }) ++ ++ t.Run("returns error with multiple documents", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ DefineRetriever(reg, "test/retrieveMultiDoc", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{Documents: []*Document{}}, nil ++ }) ++ ++ _, err := Retrieve(context.Background(), reg, ++ WithRetrieverName("test/retrieveMultiDoc"), ++ WithDocs( ++ DocumentFromText("doc1", nil), ++ DocumentFromText("doc2", nil), ++ ), ++ ) ++ assertError(t, err, "only supports a single document") ++ }) ++ ++ t.Run("retrieves with document options", func(t *testing.T) { ++ reg := newTestRegistry(t) ++ var capturedQuery *Document ++ ++ DefineRetriever(reg, "test/retrieveDocOpts", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ capturedQuery = req.Query ++ return &RetrieverResponse{Documents: []*Document{}}, nil ++ }) ++ ++ query := DocumentFromText("custom query", map[string]any{"custom": "metadata"}) ++ _, err := Retrieve(context.Background(), reg, ++ WithRetrieverName("test/retrieveDocOpts"), ++ WithDocs(query), ++ ) ++ assertNoError(t, err) ++ ++ if len(capturedQuery.Content) == 0 || capturedQuery.Content[0].Text != "custom query" { ++ t.Errorf("query content mismatch") ++ } ++ if 
capturedQuery.Metadata["custom"] != "metadata" { ++ t.Error("query metadata not passed correctly") ++ } ++ }) ++} +diff --git a/go/ai/testutil_test.go b/go/ai/testutil_test.go +new file mode 100644 +index 000000000..6c606a28a +--- /dev/null ++++ b/go/ai/testutil_test.go +@@ -0,0 +1,306 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package ai ++ ++import ( ++ "context" ++ "fmt" ++ "strings" ++ "testing" ++ ++ "github.com/firebase/genkit/go/core/api" ++ "github.com/firebase/genkit/go/internal/registry" ++ "github.com/google/go-cmp/cmp" ++) ++ ++// newTestRegistry creates a fresh registry for testing with formats configured. ++func newTestRegistry(t *testing.T) api.Registry { ++ t.Helper() ++ r := registry.New() ++ ConfigureFormats(r) ++ return r ++} ++ ++// fakeModelConfig holds configuration for creating a fake model. ++type fakeModelConfig struct { ++ name string ++ supports *ModelSupports ++ handler func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) ++} ++ ++// defaultModelSupports returns a ModelSupports with common capabilities enabled. ++func defaultModelSupports() *ModelSupports { ++ return &ModelSupports{ ++ Tools: true, ++ Multiturn: true, ++ ToolChoice: true, ++ SystemRole: true, ++ Constrained: ConstrainedSupportAll, ++ } ++} ++ ++// defineFakeModel creates a configurable fake model for testing. 
++func defineFakeModel(t *testing.T, r api.Registry, cfg fakeModelConfig) Model { ++ t.Helper() ++ if cfg.name == "" { ++ cfg.name = "test/fakeModel" ++ } ++ if cfg.supports == nil { ++ cfg.supports = defaultModelSupports() ++ } ++ if cfg.handler == nil { ++ cfg.handler = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("fake response"), ++ }, nil ++ } ++ } ++ return DefineModel(r, cfg.name, &ModelOptions{Supports: cfg.supports}, cfg.handler) ++} ++ ++// echoModelHandler creates a handler that echoes back information about the request. ++// Useful for verifying that options are properly passed through. ++func echoModelHandler() func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ var parts []string ++ ++ // Echo messages ++ for _, msg := range req.Messages { ++ parts = append(parts, fmt.Sprintf("%s: %s", msg.Role, msg.Text())) ++ } ++ ++ // Echo config if present ++ if req.Config != nil { ++ if cfg, ok := req.Config.(*GenerationCommonConfig); ok { ++ parts = append(parts, fmt.Sprintf("temp=%.1f", cfg.Temperature)) ++ } ++ } ++ ++ // Echo tool count ++ if len(req.Tools) > 0 { ++ parts = append(parts, fmt.Sprintf("tools=%d", len(req.Tools))) ++ } ++ ++ // Echo tool choice ++ if req.ToolChoice != "" { ++ parts = append(parts, fmt.Sprintf("toolChoice=%s", req.ToolChoice)) ++ } ++ ++ // Echo docs count ++ if len(req.Docs) > 0 { ++ parts = append(parts, fmt.Sprintf("docs=%d", len(req.Docs))) ++ } ++ ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage(strings.Join(parts, "; ")), ++ }, nil ++ } ++} ++ ++// capturingModelHandler returns a handler that captures the request for inspection. 
++func capturingModelHandler(captured **ModelRequest) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ *captured = req ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage("captured"), ++ }, nil ++ } ++} ++ ++// streamingModelHandler creates a handler that sends chunks before returning. ++func streamingModelHandler(chunks []string, finalText string) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ if cb != nil { ++ for _, chunk := range chunks { ++ if err := cb(ctx, &ModelResponseChunk{ ++ Content: []*Part{NewTextPart(chunk)}, ++ }); err != nil { ++ return nil, err ++ } ++ } ++ } ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage(finalText), ++ }, nil ++ } ++} ++ ++// jsonModelHandler creates a handler that returns JSON output. ++func jsonModelHandler(jsonOutput string) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewJSONPart(jsonOutput)}, ++ }, ++ }, nil ++ } ++} ++ ++// toolCallingModelHandler creates a handler that makes a tool call on first request, ++// then returns the final response after receiving the tool response. 
++func toolCallingModelHandler(toolName string, toolInput map[string]any, finalResponse string) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ callCount := 0 ++ return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { ++ callCount++ ++ ++ // Check if we already have a tool response ++ hasToolResponse := false ++ for _, msg := range req.Messages { ++ for _, part := range msg.Content { ++ if part.IsToolResponse() { ++ hasToolResponse = true ++ break ++ } ++ } ++ } ++ ++ if !hasToolResponse && len(req.Tools) > 0 { ++ // First call - request tool execution ++ return &ModelResponse{ ++ Request: req, ++ Message: &Message{ ++ Role: RoleModel, ++ Content: []*Part{NewToolRequestPart(&ToolRequest{ ++ Name: toolName, ++ Input: toolInput, ++ })}, ++ }, ++ }, nil ++ } ++ ++ // Tool response received or no tools - return final response ++ return &ModelResponse{ ++ Request: req, ++ Message: NewModelTextMessage(finalResponse), ++ }, nil ++ } ++} ++ ++// cmpPartEqual is a Part comparator for cmp.Diff that compares essential fields. ++func cmpPartEqual(a, b *Part) bool { ++ if a == nil || b == nil { ++ return a == b ++ } ++ if a.Kind != b.Kind { ++ return false ++ } ++ if a.Text != b.Text { ++ return false ++ } ++ if a.ContentType != b.ContentType { ++ return false ++ } ++ return true ++} ++ ++// cmpPartComparer returns a cmp.Option for comparing Parts. ++func cmpPartComparer() cmp.Option { ++ return cmp.Comparer(cmpPartEqual) ++} ++ ++// assertEqual compares two values and reports differences. ++func assertEqual[T any](t *testing.T, got, want T, opts ...cmp.Option) { ++ t.Helper() ++ if diff := cmp.Diff(want, got, opts...); diff != "" { ++ t.Errorf("mismatch (-want +got):\n%s", diff) ++ } ++} ++ ++// assertError verifies error is non-nil and contains expected substring. 
++func assertError(t *testing.T, err error, wantContains string) { ++ t.Helper() ++ if err == nil { ++ t.Fatal("expected error, got nil") ++ } ++ if !strings.Contains(err.Error(), wantContains) { ++ t.Errorf("error %q does not contain %q", err.Error(), wantContains) ++ } ++} ++ ++// assertNoError fails the test if err is not nil. ++func assertNoError(t *testing.T, err error) { ++ t.Helper() ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++} ++ ++// assertPanic verifies that fn panics and the panic value contains wantContains. ++func assertPanic(t *testing.T, fn func(), wantContains string) { ++ t.Helper() ++ defer func() { ++ r := recover() ++ if r == nil { ++ t.Fatal("expected panic, got none") ++ } ++ msg := fmt.Sprint(r) ++ if !strings.Contains(msg, wantContains) { ++ t.Errorf("panic %q does not contain %q", msg, wantContains) ++ } ++ }() ++ fn() ++} ++ ++// assertNoPanic verifies that fn does not panic. ++func assertNoPanic(t *testing.T, fn func()) { ++ t.Helper() ++ defer func() { ++ if r := recover(); r != nil { ++ t.Fatalf("unexpected panic: %v", r) ++ } ++ }() ++ fn() ++} ++ ++// defineFakeTool creates a simple tool for testing. ++func defineFakeTool(t *testing.T, r api.Registry, name, description string) Tool { ++ t.Helper() ++ return DefineTool(r, name, description, ++ func(ctx *ToolContext, input struct { ++ Value string `json:"value"` ++ }) (string, error) { ++ return "tool result: " + input.Value, nil ++ }) ++} ++ ++// defineFakeEmbedder creates a simple embedder for testing. 
++func defineFakeEmbedder(t *testing.T, r api.Registry, name string) Embedder { ++ t.Helper() ++ return DefineEmbedder(r, name, nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { ++ embeddings := make([]*Embedding, len(req.Input)) ++ for i := range req.Input { ++ embeddings[i] = &Embedding{ ++ Embedding: []float32{0.1, 0.2, 0.3}, ++ } ++ } ++ return &EmbedResponse{Embeddings: embeddings}, nil ++ }) ++} ++ ++// defineFakeRetriever creates a simple retriever for testing. ++func defineFakeRetriever(t *testing.T, r api.Registry, name string, docs []*Document) Retriever { ++ t.Helper() ++ return DefineRetriever(r, name, nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { ++ return &RetrieverResponse{Documents: docs}, nil ++ }) ++} +diff --git a/go/ai/tools_test.go b/go/ai/tools_test.go +new file mode 100644 +index 000000000..857be8c2b +--- /dev/null ++++ b/go/ai/tools_test.go +@@ -0,0 +1,908 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package ai ++ ++import ( ++ "context" ++ "errors" ++ "testing" ++ ++ "github.com/google/go-cmp/cmp" ++) ++ ++func TestToolName(t *testing.T) { ++ t.Run("Name returns string value", func(t *testing.T) { ++ tn := ToolName("myTool") ++ got := tn.Name() ++ want := "myTool" ++ if got != want { ++ t.Errorf("Name() = %q, want %q", got, want) ++ } ++ }) ++ ++ t.Run("empty tool name", func(t *testing.T) { ++ tn := ToolName("") ++ got := tn.Name() ++ if got != "" { ++ t.Errorf("Name() = %q, want empty string", got) ++ } ++ }) ++} ++ ++func TestToolInterruptError(t *testing.T) { ++ t.Run("Error returns fixed message", func(t *testing.T) { ++ err := &toolInterruptError{Metadata: map[string]any{"key": "value"}} ++ got := err.Error() ++ want := "tool execution interrupted" ++ if got != want { ++ t.Errorf("Error() = %q, want %q", got, want) ++ } ++ }) ++} ++ ++func TestIsToolInterruptError(t *testing.T) { ++ t.Run("returns true for toolInterruptError", func(t *testing.T) { ++ meta := map[string]any{"reason": "user cancelled"} ++ err := &toolInterruptError{Metadata: meta} ++ ++ isInterrupt, gotMeta := IsToolInterruptError(err) ++ ++ if !isInterrupt { ++ t.Error("IsToolInterruptError() = false, want true") ++ } ++ if diff := cmp.Diff(meta, gotMeta); diff != "" { ++ t.Errorf("metadata mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("returns true for wrapped toolInterruptError", func(t *testing.T) { ++ meta := map[string]any{"step": 3} ++ innerErr := &toolInterruptError{Metadata: meta} ++ wrappedErr := errors.New("context: " + innerErr.Error()) ++ // Use proper wrapping ++ wrappedErr = &wrappedInterruptError{cause: innerErr} ++ ++ isInterrupt, gotMeta := IsToolInterruptError(wrappedErr) ++ ++ if !isInterrupt { ++ t.Error("IsToolInterruptError(wrapped) = false, want true") ++ } ++ if gotMeta["step"] != 3 { ++ t.Errorf("metadata[step] = %v, want 3", gotMeta["step"]) ++ } ++ }) ++ ++ t.Run("returns false for regular error", 
func(t *testing.T) { ++ err := errors.New("some error") ++ ++ isInterrupt, meta := IsToolInterruptError(err) ++ ++ if isInterrupt { ++ t.Error("IsToolInterruptError(regular error) = true, want false") ++ } ++ if meta != nil { ++ t.Errorf("metadata = %v, want nil", meta) ++ } ++ }) ++ ++ t.Run("returns false for nil error", func(t *testing.T) { ++ isInterrupt, meta := IsToolInterruptError(nil) ++ ++ if isInterrupt { ++ t.Error("IsToolInterruptError(nil) = true, want false") ++ } ++ if meta != nil { ++ t.Errorf("metadata = %v, want nil", meta) ++ } ++ }) ++} ++ ++// wrappedInterruptError is a helper for testing error unwrapping. ++type wrappedInterruptError struct { ++ cause error ++} ++ ++func (e *wrappedInterruptError) Error() string { ++ return "wrapped: " + e.cause.Error() ++} ++ ++func (e *wrappedInterruptError) Unwrap() error { ++ return e.cause ++} ++ ++func TestDefineTool(t *testing.T) { ++ t.Run("creates and registers tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/addNumbers", "Adds two numbers", func(ctx *ToolContext, input struct { ++ A int `json:"a"` ++ B int `json:"b"` ++ }) (int, error) { ++ return input.A + input.B, nil ++ }) ++ ++ if tl == nil { ++ t.Fatal("DefineTool returned nil") ++ } ++ if tl.Name() != "provider/addNumbers" { ++ t.Errorf("Name() = %q, want %q", tl.Name(), "provider/addNumbers") ++ } ++ ++ def := tl.Definition() ++ if def.Description != "Adds two numbers" { ++ t.Errorf("Description = %q, want %q", def.Description, "Adds two numbers") ++ } ++ }) ++ ++ t.Run("tool can be looked up after registration", func(t *testing.T) { ++ r := newTestRegistry(t) ++ DefineTool(r, "provider/multiply", "Multiplies", func(ctx *ToolContext, input struct { ++ X int `json:"x"` ++ Y int `json:"y"` ++ }) (int, error) { ++ return input.X * input.Y, nil ++ }) ++ ++ found := LookupTool(r, "provider/multiply") ++ if found == nil { ++ t.Error("LookupTool returned nil for registered tool") ++ } ++ }) ++ ++ t.Run("tool 
executes correctly", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/concat", "Concatenates strings", func(ctx *ToolContext, input struct { ++ A string `json:"a"` ++ B string `json:"b"` ++ }) (string, error) { ++ return input.A + input.B, nil ++ }) ++ ++ output, err := tl.RunRaw(context.Background(), map[string]any{ ++ "a": "hello", ++ "b": "world", ++ }) ++ ++ if err != nil { ++ t.Fatalf("RunRaw error: %v", err) ++ } ++ if output != "helloworld" { ++ t.Errorf("output = %v, want %q", output, "helloworld") ++ } ++ }) ++} ++ ++func TestDefineToolWithInputSchema(t *testing.T) { ++ t.Run("creates tool with custom input schema", func(t *testing.T) { ++ r := newTestRegistry(t) ++ customSchema := map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "query": map[string]any{"type": "string"}, ++ }, ++ "required": []any{"query"}, ++ } ++ ++ tl := DefineToolWithInputSchema(r, "provider/search", "Searches", customSchema, ++ func(ctx *ToolContext, input any) (string, error) { ++ m := input.(map[string]any) ++ return "results for: " + m["query"].(string), nil ++ }) ++ ++ if tl == nil { ++ t.Fatal("DefineToolWithInputSchema returned nil") ++ } ++ ++ def := tl.Definition() ++ if def.InputSchema == nil { ++ t.Error("InputSchema is nil") ++ } ++ }) ++} ++ ++func TestNewTool(t *testing.T) { ++ t.Run("creates unregistered tool", func(t *testing.T) { ++ tl := NewTool("dynamicTool", "A dynamic tool", func(ctx *ToolContext, input struct { ++ Value int `json:"value"` ++ }) (int, error) { ++ return input.Value * 2, nil ++ }) ++ ++ if tl == nil { ++ t.Fatal("NewTool returned nil") ++ } ++ if tl.Name() != "dynamicTool" { ++ t.Errorf("Name() = %q, want %q", tl.Name(), "dynamicTool") ++ } ++ }) ++ ++ t.Run("unregistered tool can be executed", func(t *testing.T) { ++ tl := NewTool("double", "Doubles a number", func(ctx *ToolContext, input struct { ++ N int `json:"n"` ++ }) (int, error) { ++ return input.N * 2, nil ++ }) ++ ++ output, err := 
tl.RunRaw(context.Background(), map[string]any{"n": 5}) ++ if err != nil { ++ t.Fatalf("RunRaw error: %v", err) ++ } ++ // JSON unmarshalling returns float64 for numbers ++ if output != float64(10) { ++ t.Errorf("output = %v (%T), want 10", output, output) ++ } ++ }) ++ ++ t.Run("tool can be registered later", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := NewTool("provider/laterTool", "Registered later", func(ctx *ToolContext, input struct{}) (string, error) { ++ return "done", nil ++ }) ++ ++ tl.Register(r) ++ ++ found := LookupTool(r, "provider/laterTool") ++ if found == nil { ++ t.Error("LookupTool returned nil after registration") ++ } ++ }) ++} ++ ++func TestNewToolWithInputSchema(t *testing.T) { ++ t.Run("creates tool with custom schema", func(t *testing.T) { ++ schema := map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "data": map[string]any{"type": "array"}, ++ }, ++ } ++ ++ tl := NewToolWithInputSchema("process", "Processes data", schema, ++ func(ctx *ToolContext, input any) (bool, error) { ++ return true, nil ++ }) ++ ++ if tl == nil { ++ t.Fatal("NewToolWithInputSchema returned nil") ++ } ++ ++ def := tl.Definition() ++ if def.InputSchema == nil { ++ t.Error("InputSchema is nil") ++ } ++ }) ++} ++ ++func TestDefineMultipartTool(t *testing.T) { ++ t.Run("creates multipart tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineMultipartTool(r, "provider/imageGen", "Generates images", ++ func(ctx *ToolContext, input struct { ++ Prompt string `json:"prompt"` ++ }) (*MultipartToolResponse, error) { ++ return &MultipartToolResponse{ ++ Output: "generated", ++ Content: []*Part{ ++ NewMediaPart("image/png", "data:image/png;base64,abc"), ++ }, ++ }, nil ++ }) ++ ++ if tl == nil { ++ t.Fatal("DefineMultipartTool returned nil") ++ } ++ ++ // Check that it's a multipart tool via metadata ++ def := tl.Definition() ++ if def.Metadata == nil { ++ t.Fatal("Metadata is nil") ++ } ++ if def.Metadata["multipart"] != true { 
++ t.Error("multipart metadata = false, want true") ++ } ++ }) ++ ++ t.Run("multipart tool returns parts", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineMultipartTool(r, "provider/multiOut", "Returns multiple parts", ++ func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { ++ return &MultipartToolResponse{ ++ Output: map[string]any{"status": "ok"}, ++ Content: []*Part{ ++ NewTextPart("additional text"), ++ NewMediaPart("image/jpeg", "data:image/jpeg;base64,xyz"), ++ }, ++ }, nil ++ }) ++ ++ resp, err := tl.RunRawMultipart(context.Background(), map[string]any{}) ++ if err != nil { ++ t.Fatalf("RunRawMultipart error: %v", err) ++ } ++ ++ if len(resp.Content) != 2 { ++ t.Errorf("len(Content) = %d, want 2", len(resp.Content)) ++ } ++ }) ++} ++ ++func TestNewMultipartTool(t *testing.T) { ++ t.Run("creates unregistered multipart tool", func(t *testing.T) { ++ tl := NewMultipartTool("dynamicMulti", "Dynamic multipart", ++ func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { ++ return &MultipartToolResponse{Output: "test"}, nil ++ }) ++ ++ if tl == nil { ++ t.Fatal("NewMultipartTool returned nil") ++ } ++ // Check via definition metadata ++ def := tl.Definition() ++ if def.Metadata["multipart"] != true { ++ t.Error("multipart metadata = false, want true") ++ } ++ }) ++ ++ t.Run("can be registered later", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := NewMultipartTool("provider/laterMulti", "Later registration", ++ func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { ++ return &MultipartToolResponse{Output: "ok"}, nil ++ }) ++ ++ tl.Register(r) ++ ++ found := LookupTool(r, "provider/laterMulti") ++ if found == nil { ++ t.Error("LookupTool returned nil after registration") ++ } ++ }) ++} ++ ++func TestToolDefinition(t *testing.T) { ++ t.Run("includes all fields", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/complete", "A complete tool", func(ctx 
*ToolContext, input struct { ++ Query string `json:"query"` ++ }) (struct { ++ Result string `json:"result"` ++ }, error) { ++ return struct { ++ Result string `json:"result"` ++ }{Result: input.Query}, nil ++ }) ++ ++ def := tl.Definition() ++ ++ if def.Name != "provider/complete" { ++ t.Errorf("Name = %q, want %q", def.Name, "provider/complete") ++ } ++ if def.Description != "A complete tool" { ++ t.Errorf("Description = %q, want %q", def.Description, "A complete tool") ++ } ++ if def.InputSchema == nil { ++ t.Error("InputSchema is nil") ++ } ++ if def.OutputSchema == nil { ++ t.Error("OutputSchema is nil") ++ } ++ }) ++} ++ ++func TestLookupTool(t *testing.T) { ++ t.Run("returns nil for empty name", func(t *testing.T) { ++ r := newTestRegistry(t) ++ got := LookupTool(r, "") ++ if got != nil { ++ t.Errorf("LookupTool(\"\") = %v, want nil", got) ++ } ++ }) ++ ++ t.Run("returns nil for non-existent tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ got := LookupTool(r, "nonexistent/tool") ++ if got != nil { ++ t.Errorf("LookupTool(nonexistent) = %v, want nil", got) ++ } ++ }) ++ ++ t.Run("finds registered tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ DefineTool(r, "test/findMe", "Find me", func(ctx *ToolContext, input struct{}) (bool, error) { ++ return true, nil ++ }) ++ ++ got := LookupTool(r, "test/findMe") ++ if got == nil { ++ t.Error("LookupTool returned nil for registered tool") ++ } ++ }) ++} ++ ++func TestToolIsMultipart(t *testing.T) { ++ t.Run("regular tool is not multipart", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/regular", "Regular tool", func(ctx *ToolContext, input struct{}) (string, error) { ++ return "ok", nil ++ }) ++ ++ def := tl.Definition() ++ if def.Metadata["multipart"] == true { ++ t.Error("multipart metadata = true for regular tool, want false") ++ } ++ }) ++ ++ t.Run("multipart tool is multipart", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineMultipartTool(r, 
"provider/multi", "Multi tool", ++ func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { ++ return &MultipartToolResponse{}, nil ++ }) ++ ++ def := tl.Definition() ++ if def.Metadata["multipart"] != true { ++ t.Error("multipart metadata = false for multipart tool, want true") ++ } ++ }) ++} ++ ++func TestToolRunRaw(t *testing.T) { ++ t.Run("returns output from regular tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/sum", "Sums numbers", func(ctx *ToolContext, input struct { ++ Nums []int `json:"nums"` ++ }) (int, error) { ++ sum := 0 ++ for _, n := range input.Nums { ++ sum += n ++ } ++ return sum, nil ++ }) ++ ++ output, err := tl.RunRaw(context.Background(), map[string]any{ ++ "nums": []any{1, 2, 3, 4, 5}, ++ }) ++ ++ if err != nil { ++ t.Fatalf("RunRaw error: %v", err) ++ } ++ // JSON unmarshalling returns float64 for numbers ++ if output != float64(15) { ++ t.Errorf("output = %v (%T), want 15", output, output) ++ } ++ }) ++ ++ t.Run("returns error from tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/fail", "Always fails", func(ctx *ToolContext, input struct{}) (string, error) { ++ return "", errors.New("intentional failure") ++ }) ++ ++ _, err := tl.RunRaw(context.Background(), map[string]any{}) ++ if err == nil { ++ t.Error("expected error, got nil") ++ } ++ }) ++} ++ ++func TestToolRunRawMultipart(t *testing.T) { ++ t.Run("returns full response from multipart tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineMultipartTool(r, "provider/fullResp", "Full response", ++ func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { ++ return &MultipartToolResponse{ ++ Output: "main output", ++ Content: []*Part{ ++ NewTextPart("extra"), ++ }, ++ }, nil ++ }) ++ ++ resp, err := tl.RunRawMultipart(context.Background(), map[string]any{}) ++ if err != nil { ++ t.Fatalf("RunRawMultipart error: %v", err) ++ } ++ ++ if resp.Output != "main 
output" { ++ t.Errorf("Output = %v, want %q", resp.Output, "main output") ++ } ++ if len(resp.Content) != 1 { ++ t.Errorf("len(Content) = %d, want 1", len(resp.Content)) ++ } ++ }) ++} ++ ++func TestToolRespond(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/responder", "Test responder", func(ctx *ToolContext, input struct{}) (string, error) { ++ return "ok", nil ++ }) ++ ++ t.Run("creates response for tool request", func(t *testing.T) { ++ reqPart := NewToolRequestPart(&ToolRequest{ ++ Name: "provider/responder", ++ Ref: "ref-123", ++ Input: map[string]any{"x": 1}, ++ }) ++ reqPart.Metadata = map[string]any{"interrupt": true} ++ ++ resp := tl.Respond(reqPart, "output data", nil) ++ ++ if resp == nil { ++ t.Fatal("Respond returned nil") ++ } ++ if !resp.IsToolResponse() { ++ t.Error("response is not a tool response") ++ } ++ if resp.ToolResponse.Name != "provider/responder" { ++ t.Errorf("Name = %q, want %q", resp.ToolResponse.Name, "provider/responder") ++ } ++ if resp.ToolResponse.Ref != "ref-123" { ++ t.Errorf("Ref = %q, want %q", resp.ToolResponse.Ref, "ref-123") ++ } ++ }) ++ ++ t.Run("returns nil for non-tool-request part", func(t *testing.T) { ++ textPart := NewTextPart("not a tool request") ++ ++ resp := tl.Respond(textPart, "output", nil) ++ ++ if resp != nil { ++ t.Errorf("Respond(textPart) = %v, want nil", resp) ++ } ++ }) ++ ++ t.Run("returns nil for nil part", func(t *testing.T) { ++ resp := tl.Respond(nil, "output", nil) ++ ++ if resp != nil { ++ t.Errorf("Respond(nil) = %v, want nil", resp) ++ } ++ }) ++ ++ t.Run("includes response options metadata", func(t *testing.T) { ++ reqPart := NewToolRequestPart(&ToolRequest{ ++ Name: "provider/responder", ++ }) ++ reqPart.Metadata = map[string]any{"interrupt": true} ++ ++ opts := &RespondOptions{ ++ Metadata: map[string]any{"custom": "value"}, ++ } ++ resp := tl.Respond(reqPart, "output", opts) ++ ++ if resp.Metadata == nil { ++ t.Fatal("Metadata is nil") ++ } ++ if 
resp.Metadata["interruptResponse"] == nil { ++ t.Error("interruptResponse not set in metadata") ++ } ++ }) ++} ++ ++func TestToolRestart(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/restarter", "Test restarter", func(ctx *ToolContext, input struct { ++ Value int `json:"value"` ++ }) (int, error) { ++ return input.Value, nil ++ }) ++ ++ t.Run("creates restart for tool request", func(t *testing.T) { ++ reqPart := NewToolRequestPart(&ToolRequest{ ++ Name: "provider/restarter", ++ Ref: "ref-456", ++ Input: map[string]any{"value": 10}, ++ }) ++ reqPart.Metadata = map[string]any{"interrupt": true} ++ ++ restart := tl.Restart(reqPart, nil) ++ ++ if restart == nil { ++ t.Fatal("Restart returned nil") ++ } ++ if !restart.IsToolRequest() { ++ t.Error("restart is not a tool request") ++ } ++ if restart.ToolRequest.Name != "provider/restarter" { ++ t.Errorf("Name = %q, want %q", restart.ToolRequest.Name, "provider/restarter") ++ } ++ if restart.Metadata["resumed"] != true { ++ t.Errorf("resumed = %v, want true", restart.Metadata["resumed"]) ++ } ++ if restart.Metadata["interrupt"] != nil { ++ t.Error("interrupt should be removed from metadata") ++ } ++ }) ++ ++ t.Run("returns nil for non-tool-request part", func(t *testing.T) { ++ textPart := NewTextPart("text") ++ ++ restart := tl.Restart(textPart, nil) ++ ++ if restart != nil { ++ t.Errorf("Restart(textPart) = %v, want nil", restart) ++ } ++ }) ++ ++ t.Run("returns nil for nil part", func(t *testing.T) { ++ restart := tl.Restart(nil, nil) ++ ++ if restart != nil { ++ t.Errorf("Restart(nil) = %v, want nil", restart) ++ } ++ }) ++ ++ t.Run("replaces input when specified", func(t *testing.T) { ++ reqPart := NewToolRequestPart(&ToolRequest{ ++ Name: "provider/restarter", ++ Input: map[string]any{"value": 10}, ++ }) ++ reqPart.Metadata = map[string]any{"interrupt": true} ++ ++ opts := &RestartOptions{ ++ ReplaceInput: map[string]any{"value": 20}, ++ } ++ restart := tl.Restart(reqPart, opts) ++ ++ 
newInput := restart.ToolRequest.Input.(map[string]any) ++ if newInput["value"] != 20 { ++ t.Errorf("new input value = %v, want 20", newInput["value"]) ++ } ++ if restart.Metadata["replacedInput"] == nil { ++ t.Error("replacedInput not set in metadata") ++ } ++ }) ++ ++ t.Run("sets resumed metadata when specified", func(t *testing.T) { ++ reqPart := NewToolRequestPart(&ToolRequest{ ++ Name: "provider/restarter", ++ }) ++ reqPart.Metadata = map[string]any{"interrupt": true} ++ ++ opts := &RestartOptions{ ++ ResumedMetadata: map[string]any{"reason": "user confirmed"}, ++ } ++ restart := tl.Restart(reqPart, opts) ++ ++ resumed := restart.Metadata["resumed"].(map[string]any) ++ if resumed["reason"] != "user confirmed" { ++ t.Errorf("resumed.reason = %v, want %q", resumed["reason"], "user confirmed") ++ } ++ }) ++} ++ ++func TestToolInterrupt(t *testing.T) { ++ t.Run("tool can interrupt execution", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/interrupter", "Can interrupt", ++ func(ctx *ToolContext, input struct { ++ ShouldInterrupt bool `json:"shouldInterrupt"` ++ }) (string, error) { ++ if input.ShouldInterrupt { ++ return "", ctx.Interrupt(&InterruptOptions{ ++ Metadata: map[string]any{"step": "confirmation"}, ++ }) ++ } ++ return "completed", nil ++ }) ++ ++ _, err := tl.RunRaw(context.Background(), map[string]any{ ++ "shouldInterrupt": true, ++ }) ++ ++ if err == nil { ++ t.Fatal("expected interrupt error, got nil") ++ } ++ ++ isInterrupt, meta := IsToolInterruptError(err) ++ if !isInterrupt { ++ t.Errorf("IsToolInterruptError() = false, want true") ++ } ++ if meta["step"] != "confirmation" { ++ t.Errorf("metadata[step] = %v, want %q", meta["step"], "confirmation") ++ } ++ }) ++ ++ t.Run("tool completes without interrupt", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/noInterrupt", "No interrupt", ++ func(ctx *ToolContext, input struct { ++ ShouldInterrupt bool `json:"shouldInterrupt"` ++ }) 
(string, error) { ++ if input.ShouldInterrupt { ++ return "", ctx.Interrupt(&InterruptOptions{}) ++ } ++ return "completed", nil ++ }) ++ ++ output, err := tl.RunRaw(context.Background(), map[string]any{ ++ "shouldInterrupt": false, ++ }) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if output != "completed" { ++ t.Errorf("output = %v, want %q", output, "completed") ++ } ++ }) ++} ++ ++func TestToolWithInputSchemaOption(t *testing.T) { ++ t.Run("DefineTool with WithInputSchema", func(t *testing.T) { ++ r := newTestRegistry(t) ++ customSchema := map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "customField": map[string]any{"type": "string"}, ++ }, ++ } ++ ++ tl := DefineTool(r, "provider/customInput", "Custom input schema", ++ func(ctx *ToolContext, input any) (string, error) { ++ m := input.(map[string]any) ++ return m["customField"].(string), nil ++ }, ++ WithInputSchema(customSchema)) ++ ++ def := tl.Definition() ++ if def.InputSchema == nil { ++ t.Error("InputSchema is nil") ++ } ++ }) ++ ++ t.Run("NewTool with WithInputSchema", func(t *testing.T) { ++ customSchema := map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "field": map[string]any{"type": "number"}, ++ }, ++ } ++ ++ tl := NewTool("customNew", "Custom new tool", ++ func(ctx *ToolContext, input any) (bool, error) { ++ return true, nil ++ }, ++ WithInputSchema(customSchema)) ++ ++ def := tl.Definition() ++ if def.InputSchema == nil { ++ t.Error("InputSchema is nil") ++ } ++ }) ++} ++ ++func TestResolveUniqueTools(t *testing.T) { ++ t.Run("resolves tools from registry", func(t *testing.T) { ++ r := newTestRegistry(t) ++ DefineTool(r, "provider/tool1", "Tool 1", func(ctx *ToolContext, input struct{}) (bool, error) { ++ return true, nil ++ }) ++ DefineTool(r, "provider/tool2", "Tool 2", func(ctx *ToolContext, input struct{}) (bool, error) { ++ return true, nil ++ }) ++ ++ toolRefs := []ToolRef{ ++ ToolName("provider/tool1"), ++ 
ToolName("provider/tool2"), ++ } ++ ++ names, newTools, err := resolveUniqueTools(r, toolRefs) ++ ++ if err != nil { ++ t.Fatalf("resolveUniqueTools error: %v", err) ++ } ++ if len(names) != 2 { ++ t.Errorf("len(names) = %d, want 2", len(names)) ++ } ++ if len(newTools) != 0 { ++ t.Errorf("len(newTools) = %d, want 0 (tools already registered)", len(newTools)) ++ } ++ }) ++ ++ t.Run("returns error for duplicate tools", func(t *testing.T) { ++ r := newTestRegistry(t) ++ toolRefs := []ToolRef{ ++ ToolName("provider/dup"), ++ ToolName("provider/dup"), ++ } ++ ++ _, _, err := resolveUniqueTools(r, toolRefs) ++ ++ if err == nil { ++ t.Error("expected error for duplicate tools, got nil") ++ } ++ }) ++ ++ t.Run("identifies new tools to register", func(t *testing.T) { ++ r := newTestRegistry(t) ++ newTl := NewTool("provider/brandNew", "Brand new", func(ctx *ToolContext, input struct{}) (string, error) { ++ return "new", nil ++ }) ++ ++ toolRefs := []ToolRef{newTl} ++ ++ names, newTools, err := resolveUniqueTools(r, toolRefs) ++ ++ if err != nil { ++ t.Fatalf("resolveUniqueTools error: %v", err) ++ } ++ if len(names) != 1 { ++ t.Errorf("len(names) = %d, want 1", len(names)) ++ } ++ if len(newTools) != 1 { ++ t.Errorf("len(newTools) = %d, want 1", len(newTools)) ++ } ++ }) ++} ++ ++func TestIsMultipart(t *testing.T) { ++ t.Run("returns false for standard tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineTool(r, "provider/standard", "Standard tool", ++ func(ctx *ToolContext, input struct{}) (string, error) { ++ return "result", nil ++ }) ++ ++ // IsMultipart is on the internal *tool type, so we need to type assert ++ internalTool := tl.(*tool) ++ if internalTool.IsMultipart() { ++ t.Error("IsMultipart() = true for standard tool, want false") ++ } ++ }) ++ ++ t.Run("returns false for NewTool", func(t *testing.T) { ++ tl := NewTool("standard", "Standard", ++ func(ctx *ToolContext, input struct{}) (string, error) { ++ return "result", nil ++ }) ++ ++ 
internalTool := tl.(*tool) ++ if internalTool.IsMultipart() { ++ t.Error("IsMultipart() = true for NewTool, want false") ++ } ++ }) ++ ++ t.Run("returns true for multipart tool", func(t *testing.T) { ++ r := newTestRegistry(t) ++ tl := DefineMultipartTool(r, "provider/multipart", "Multipart tool", ++ func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { ++ return &MultipartToolResponse{ ++ Content: []*Part{NewTextPart("hello"), NewTextPart("world")}, ++ }, nil ++ }) ++ ++ internalTool := tl.(*tool) ++ if !internalTool.IsMultipart() { ++ t.Error("IsMultipart() = false for multipart tool, want true") ++ } ++ }) ++ ++ t.Run("returns true for NewMultipartTool", func(t *testing.T) { ++ tl := NewMultipartTool("multipart", "Multipart", ++ func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { ++ return &MultipartToolResponse{ ++ Content: []*Part{NewTextPart("content")}, ++ }, nil ++ }) ++ ++ internalTool := tl.(*tool) ++ if !internalTool.IsMultipart() { ++ t.Error("IsMultipart() = false for NewMultipartTool, want true") ++ } ++ }) ++} +diff --git a/go/core/action_test.go b/go/core/action_test.go +index 4ce63cffe..65309d850 100644 +--- a/go/core/action_test.go ++++ b/go/core/action_test.go +@@ -19,6 +19,7 @@ package core + import ( + "bytes" + "context" ++ "encoding/json" + "slices" + "testing" + +@@ -124,3 +125,309 @@ func TestActionTracing(t *testing.T) { + } + t.Fatalf("did not find trace named %q", name) + } ++ ++func TestNewAction(t *testing.T) { ++ t.Run("creates unregistered action", func(t *testing.T) { ++ fn := func(ctx context.Context, input string) (string, error) { ++ return "Hello, " + input, nil ++ } ++ a := NewAction("greet", api.ActionTypeCustom, nil, nil, fn) ++ ++ if a == nil { ++ t.Fatal("NewAction returned nil") ++ } ++ if a.Name() != "greet" { ++ t.Errorf("Name() = %q, want %q", a.Name(), "greet") ++ } ++ }) ++ ++ t.Run("action can be executed", func(t *testing.T) { ++ fn := func(ctx context.Context, input int) 
(int, error) { ++ return input * 2, nil ++ } ++ a := NewAction("double", api.ActionTypeCustom, nil, nil, fn) ++ ++ got, err := a.Run(context.Background(), 5, nil) ++ if err != nil { ++ t.Fatalf("Run error: %v", err) ++ } ++ if got != 10 { ++ t.Errorf("got %d, want 10", got) ++ } ++ }) ++ ++ t.Run("action with custom input schema", func(t *testing.T) { ++ customSchema := map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "name": map[string]any{"type": "string"}, ++ }, ++ } ++ fn := func(ctx context.Context, input any) (string, error) { ++ return "ok", nil ++ } ++ a := NewAction("withSchema", api.ActionTypeCustom, nil, customSchema, fn) ++ ++ desc := a.Desc() ++ if desc.InputSchema == nil { ++ t.Error("InputSchema is nil") ++ } ++ }) ++ ++ t.Run("action with metadata", func(t *testing.T) { ++ meta := map[string]any{ ++ "description": "A test action", ++ "version": "1.0", ++ } ++ fn := func(ctx context.Context, input struct{}) (bool, error) { ++ return true, nil ++ } ++ a := NewAction("withMeta", api.ActionTypeCustom, meta, nil, fn) ++ ++ desc := a.Desc() ++ if desc.Description != "A test action" { ++ t.Errorf("Description = %q, want %q", desc.Description, "A test action") ++ } ++ }) ++} ++ ++func TestNewStreamingAction(t *testing.T) { ++ t.Run("creates streaming action", func(t *testing.T) { ++ fn := func(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { ++ if cb != nil { ++ for i := 0; i < n; i++ { ++ if err := cb(ctx, i); err != nil { ++ return 0, err ++ } ++ } ++ } ++ return n, nil ++ } ++ a := NewStreamingAction("counter", api.ActionTypeCustom, nil, nil, fn) ++ ++ if a == nil { ++ t.Fatal("NewStreamingAction returned nil") ++ } ++ if a.Name() != "counter" { ++ t.Errorf("Name() = %q, want %q", a.Name(), "counter") ++ } ++ }) ++ ++ t.Run("streaming action streams correctly", func(t *testing.T) { ++ fn := func(ctx context.Context, n int, cb func(context.Context, string) error) (int, error) { ++ if cb != nil { ++ 
for i := 0; i < n; i++ { ++ if err := cb(ctx, "chunk"); err != nil { ++ return 0, err ++ } ++ } ++ } ++ return n, nil ++ } ++ a := NewStreamingAction("streamer", api.ActionTypeCustom, nil, nil, fn) ++ ++ var chunks []string ++ got, err := a.Run(context.Background(), 3, func(ctx context.Context, chunk string) error { ++ chunks = append(chunks, chunk) ++ return nil ++ }) ++ ++ if err != nil { ++ t.Fatalf("Run error: %v", err) ++ } ++ if got != 3 { ++ t.Errorf("got %d, want 3", got) ++ } ++ if len(chunks) != 3 { ++ t.Errorf("len(chunks) = %d, want 3", len(chunks)) ++ } ++ }) ++} ++ ++func TestActionDesc(t *testing.T) { ++ t.Run("returns action descriptor", func(t *testing.T) { ++ meta := map[string]any{ ++ "description": "Test description", ++ "custom": "value", ++ } ++ fn := func(ctx context.Context, input struct { ++ Name string `json:"name"` ++ }) (struct { ++ Greeting string `json:"greeting"` ++ }, error) { ++ return struct { ++ Greeting string `json:"greeting"` ++ }{Greeting: "Hello"}, nil ++ } ++ ++ r := registry.New() ++ a := DefineAction(r, "test/describe", api.ActionTypeCustom, meta, nil, fn) ++ ++ desc := a.Desc() ++ ++ if desc.Name != "test/describe" { ++ t.Errorf("Name = %q, want %q", desc.Name, "test/describe") ++ } ++ if desc.Description != "Test description" { ++ t.Errorf("Description = %q, want %q", desc.Description, "Test description") ++ } ++ if desc.Type != api.ActionTypeCustom { ++ t.Errorf("Type = %v, want %v", desc.Type, api.ActionTypeCustom) ++ } ++ if desc.InputSchema == nil { ++ t.Error("InputSchema is nil") ++ } ++ if desc.OutputSchema == nil { ++ t.Error("OutputSchema is nil") ++ } ++ }) ++} ++ ++func TestActionRegister(t *testing.T) { ++ t.Run("registers action with registry", func(t *testing.T) { ++ r := registry.New() ++ fn := func(ctx context.Context, input string) (string, error) { ++ return input, nil ++ } ++ a := NewAction("test/register", api.ActionTypeCustom, nil, nil, fn) ++ ++ a.Register(r) ++ ++ key := 
api.KeyFromName(api.ActionTypeCustom, "test/register") ++ found := r.LookupAction(key) ++ if found == nil { ++ t.Error("registered action not found in registry") ++ } ++ }) ++} ++ ++func TestResolveActionFor(t *testing.T) { ++ t.Run("finds registered action", func(t *testing.T) { ++ r := registry.New() ++ fn := func(ctx context.Context, input int) (int, error) { ++ return input + 1, nil ++ } ++ DefineAction(r, "test/resolvable", api.ActionTypeCustom, nil, nil, fn) ++ ++ found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/resolvable") ++ ++ if found == nil { ++ t.Fatal("ResolveActionFor returned nil") ++ } ++ if found.Name() != "test/resolvable" { ++ t.Errorf("Name() = %q, want %q", found.Name(), "test/resolvable") ++ } ++ }) ++ ++ t.Run("returns nil for non-existent action", func(t *testing.T) { ++ r := registry.New() ++ ++ found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/nonexistent") ++ ++ if found != nil { ++ t.Errorf("ResolveActionFor returned %v, want nil", found) ++ } ++ }) ++} ++ ++func TestLookupActionFor(t *testing.T) { ++ t.Run("finds registered action", func(t *testing.T) { ++ r := registry.New() ++ fn := func(ctx context.Context, input string) (string, error) { ++ return "found: " + input, nil ++ } ++ DefineAction(r, "test/lookupable", api.ActionTypeCustom, nil, nil, fn) ++ ++ found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/lookupable") ++ ++ if found == nil { ++ t.Fatal("LookupActionFor returned nil") ++ } ++ }) ++ ++ t.Run("returns nil for non-existent action", func(t *testing.T) { ++ r := registry.New() ++ ++ found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/missing") ++ ++ if found != nil { ++ t.Errorf("LookupActionFor returned %v, want nil", found) ++ } ++ }) ++} ++ ++func TestRunJSONWithTelemetry(t *testing.T) { ++ t.Run("returns telemetry info with result", func(t *testing.T) { ++ r := registry.New() ++ fn := func(ctx 
context.Context, input int) (int, error) { ++ return input * 2, nil ++ } ++ a := DefineAction(r, "test/telemetry", api.ActionTypeCustom, nil, nil, fn) ++ ++ result, err := a.RunJSONWithTelemetry(context.Background(), []byte("5"), nil) ++ ++ if err != nil { ++ t.Fatalf("RunJSONWithTelemetry error: %v", err) ++ } ++ if result == nil { ++ t.Fatal("result is nil") ++ } ++ if string(result.Result) != "10" { ++ t.Errorf("Result = %s, want %q", result.Result, "10") ++ } ++ // TraceId and SpanId should be set ++ if result.TraceId == "" { ++ t.Error("TraceId is empty") ++ } ++ if result.SpanId == "" { ++ t.Error("SpanId is empty") ++ } ++ }) ++ ++ t.Run("handles streaming callback", func(t *testing.T) { ++ r := registry.New() ++ fn := func(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { ++ if cb != nil { ++ for i := 0; i < n; i++ { ++ if err := cb(ctx, i); err != nil { ++ return 0, err ++ } ++ } ++ } ++ return n, nil ++ } ++ a := DefineStreamingAction(r, "test/streamTelemetry", api.ActionTypeCustom, nil, nil, fn) ++ ++ var chunks []string ++ cb := func(ctx context.Context, chunk json.RawMessage) error { ++ chunks = append(chunks, string(chunk)) ++ return nil ++ } ++ ++ result, err := a.RunJSONWithTelemetry(context.Background(), []byte("3"), cb) ++ ++ if err != nil { ++ t.Fatalf("RunJSONWithTelemetry error: %v", err) ++ } ++ if result == nil { ++ t.Fatal("result is nil") ++ } ++ if len(chunks) != 3 { ++ t.Errorf("len(chunks) = %d, want 3", len(chunks)) ++ } ++ }) ++ ++ t.Run("returns error for invalid JSON input", func(t *testing.T) { ++ r := registry.New() ++ fn := func(ctx context.Context, input int) (int, error) { ++ return input, nil ++ } ++ a := DefineAction(r, "test/invalidInput", api.ActionTypeCustom, nil, nil, fn) ++ ++ _, err := a.RunJSONWithTelemetry(context.Background(), []byte("not valid json"), nil) ++ ++ if err == nil { ++ t.Error("expected error for invalid JSON, got nil") ++ } ++ }) ++} +diff --git 
a/go/core/background_action_test.go b/go/core/background_action_test.go +new file mode 100644 +index 000000000..5ce5d75ff +--- /dev/null ++++ b/go/core/background_action_test.go +@@ -0,0 +1,431 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package core ++ ++import ( ++ "context" ++ "testing" ++ ++ "github.com/firebase/genkit/go/core/api" ++ "github.com/firebase/genkit/go/internal/registry" ++) ++ ++func TestNewBackgroundAction(t *testing.T) { ++ t.Run("creates background action with all functions", func(t *testing.T) { ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "op-1", Done: false}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: true, Output: "result"}, nil ++ } ++ cancelFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: true}, nil ++ } ++ ++ ba := NewBackgroundAction("test/background", api.ActionTypeCustom, nil, startFn, checkFn, cancelFn) ++ ++ if ba == nil { ++ t.Fatal("NewBackgroundAction returned nil") ++ } ++ if ba.Name() != "test/background" { ++ t.Errorf("Name() = %q, want %q", ba.Name(), "test/background") ++ } ++ if !ba.SupportsCancel() { ++ t.Error("SupportsCancel() = false, want true") ++ } ++ }) 
++ ++ t.Run("creates background action without cancel", func(t *testing.T) { ++ startFn := func(ctx context.Context, input int) (*Operation[int], error) { ++ return &Operation[int]{ID: "op-1", Done: false}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[int]) (*Operation[int], error) { ++ return &Operation[int]{ID: op.ID, Done: true, Output: 42}, nil ++ } ++ ++ ba := NewBackgroundAction("test/nocancel", api.ActionTypeCustom, nil, startFn, checkFn, nil) ++ ++ if ba == nil { ++ t.Fatal("NewBackgroundAction returned nil") ++ } ++ if ba.SupportsCancel() { ++ t.Error("SupportsCancel() = true, want false") ++ } ++ }) ++ ++ t.Run("panics with empty name", func(t *testing.T) { ++ defer func() { ++ if r := recover(); r == nil { ++ t.Error("expected panic for empty name") ++ } ++ }() ++ ++ NewBackgroundAction("", api.ActionTypeCustom, nil, ++ func(ctx context.Context, input string) (*Operation[string], error) { ++ return nil, nil ++ }, ++ func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return nil, nil ++ }, ++ nil, ++ ) ++ }) ++ ++ t.Run("panics with nil startFn", func(t *testing.T) { ++ defer func() { ++ if r := recover(); r == nil { ++ t.Error("expected panic for nil startFn") ++ } ++ }() ++ ++ NewBackgroundAction[string, string]("test/nilstart", api.ActionTypeCustom, nil, ++ nil, ++ func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return nil, nil ++ }, ++ nil, ++ ) ++ }) ++ ++ t.Run("panics with nil checkFn", func(t *testing.T) { ++ defer func() { ++ if r := recover(); r == nil { ++ t.Error("expected panic for nil checkFn") ++ } ++ }() ++ ++ NewBackgroundAction("test/nilcheck", api.ActionTypeCustom, nil, ++ func(ctx context.Context, input string) (*Operation[string], error) { ++ return nil, nil ++ }, ++ nil, ++ nil, ++ ) ++ }) ++} ++ ++func TestDefineBackgroundAction(t *testing.T) { ++ t.Run("creates and registers background action", func(t *testing.T) { ++ r := registry.New() ++ startFn := 
func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "op-1", Done: false}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: true, Output: "done"}, nil ++ } ++ ++ ba := DefineBackgroundAction(r, "test/registered", api.ActionTypeCustom, nil, startFn, checkFn, nil) ++ ++ if ba == nil { ++ t.Fatal("DefineBackgroundAction returned nil") ++ } ++ ++ // Verify action is registered ++ key := api.KeyFromName(api.ActionTypeCustom, "test/registered") ++ found := r.LookupAction(key) ++ if found == nil { ++ t.Error("background action not found in registry") ++ } ++ }) ++} ++ ++func TestBackgroundActionStart(t *testing.T) { ++ t.Run("starts operation", func(t *testing.T) { ++ r := registry.New() ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "test-op", Done: false, Metadata: map[string]any{"input": input}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: op.Done, Metadata: map[string]any{}}, nil ++ } ++ ++ ba := DefineBackgroundAction(r, "test/start", api.ActionTypeCustom, nil, startFn, checkFn, nil) ++ ++ op, err := ba.Start(context.Background(), "hello") ++ if err != nil { ++ t.Fatalf("Start error: %v", err) ++ } ++ if op.ID != "test-op" { ++ t.Errorf("op.ID = %q, want %q", op.ID, "test-op") ++ } ++ if op.Done { ++ t.Error("op.Done = true, want false") ++ } ++ // Check that Action key is set ++ if op.Action == "" { ++ t.Error("op.Action is empty, expected to be set") ++ } ++ }) ++} ++ ++func TestBackgroundActionCheck(t *testing.T) { ++ t.Run("checks operation status", func(t *testing.T) { ++ r := registry.New() ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "check-op", Done: false, Metadata: 
map[string]any{}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: true, Output: "completed", Metadata: map[string]any{}}, nil ++ } ++ ++ ba := DefineBackgroundAction(r, "test/check", api.ActionTypeCustom, nil, startFn, checkFn, nil) ++ ++ op, err := ba.Start(context.Background(), "input") ++ if err != nil { ++ t.Fatalf("Start error: %v", err) ++ } ++ ++ checked, err := ba.Check(context.Background(), op) ++ if err != nil { ++ t.Fatalf("Check error: %v", err) ++ } ++ if !checked.Done { ++ t.Error("checked.Done = false, want true") ++ } ++ if checked.Output != "completed" { ++ t.Errorf("checked.Output = %q, want %q", checked.Output, "completed") ++ } ++ }) ++} ++ ++func TestBackgroundActionCancel(t *testing.T) { ++ t.Run("cancels operation when supported", func(t *testing.T) { ++ r := registry.New() ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "cancel-op", Done: false, Metadata: map[string]any{}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: op.Done, Metadata: map[string]any{}}, nil ++ } ++ cancelFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: true, Metadata: map[string]any{"cancelled": true}}, nil ++ } ++ ++ ba := DefineBackgroundAction(r, "test/cancel", api.ActionTypeCustom, nil, startFn, checkFn, cancelFn) ++ ++ op, err := ba.Start(context.Background(), "input") ++ if err != nil { ++ t.Fatalf("Start error: %v", err) ++ } ++ ++ cancelled, err := ba.Cancel(context.Background(), op) ++ if err != nil { ++ t.Fatalf("Cancel error: %v", err) ++ } ++ if !cancelled.Done { ++ t.Error("cancelled.Done = false, want true") ++ } ++ }) ++ ++ t.Run("returns error when cancel not supported", func(t *testing.T) { ++ r := 
registry.New() ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "no-cancel-op", Done: false, Metadata: map[string]any{}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: op.Done, Metadata: map[string]any{}}, nil ++ } ++ ++ ba := DefineBackgroundAction(r, "test/nocancel", api.ActionTypeCustom, nil, startFn, checkFn, nil) ++ ++ op, err := ba.Start(context.Background(), "input") ++ if err != nil { ++ t.Fatalf("Start error: %v", err) ++ } ++ ++ _, err = ba.Cancel(context.Background(), op) ++ if err == nil { ++ t.Error("expected error for unsupported cancel, got nil") ++ } ++ }) ++} ++ ++func TestBackgroundActionRegister(t *testing.T) { ++ t.Run("registers all sub-actions", func(t *testing.T) { ++ r := registry.New() ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "reg-op", Metadata: map[string]any{}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil ++ } ++ cancelFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil ++ } ++ ++ ba := NewBackgroundAction("test/register", api.ActionTypeCustom, nil, startFn, checkFn, cancelFn) ++ ba.Register(r) ++ ++ // Check main action ++ mainKey := api.KeyFromName(api.ActionTypeCustom, "test/register") ++ if r.LookupAction(mainKey) == nil { ++ t.Error("main action not registered") ++ } ++ ++ // Check check action ++ checkKey := api.KeyFromName(api.ActionTypeCheckOperation, "test/register") ++ if r.LookupAction(checkKey) == nil { ++ t.Error("check action not registered") ++ } ++ ++ // Check cancel action ++ cancelKey := api.KeyFromName(api.ActionTypeCancelOperation, 
"test/register") ++ if r.LookupAction(cancelKey) == nil { ++ t.Error("cancel action not registered") ++ } ++ }) ++ ++ t.Run("registers without cancel action when not provided", func(t *testing.T) { ++ r := registry.New() ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "reg-op", Metadata: map[string]any{}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil ++ } ++ ++ ba := NewBackgroundAction("test/register-nocancel", api.ActionTypeCustom, nil, startFn, checkFn, nil) ++ ba.Register(r) ++ ++ // Cancel action should not be registered ++ cancelKey := api.KeyFromName(api.ActionTypeCancelOperation, "test/register-nocancel") ++ if r.LookupAction(cancelKey) != nil { ++ t.Error("cancel action should not be registered") ++ } ++ }) ++} ++ ++func TestLookupBackgroundAction(t *testing.T) { ++ t.Run("finds registered background action", func(t *testing.T) { ++ r := registry.New() ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "lookup-op", Metadata: map[string]any{}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil ++ } ++ ++ DefineBackgroundAction(r, "test/lookup", api.ActionTypeCustom, nil, startFn, checkFn, nil) ++ ++ key := api.KeyFromName(api.ActionTypeCustom, "test/lookup") ++ found := LookupBackgroundAction[string, string](r, key) ++ ++ if found == nil { ++ t.Fatal("LookupBackgroundAction returned nil") ++ } ++ if found.Name() != "test/lookup" { ++ t.Errorf("Name() = %q, want %q", found.Name(), "test/lookup") ++ } ++ }) ++ ++ t.Run("returns nil for non-existent action", func(t *testing.T) { ++ r := registry.New() ++ ++ key := api.KeyFromName(api.ActionTypeCustom, "test/nonexistent") ++ 
found := LookupBackgroundAction[string, string](r, key) ++ ++ if found != nil { ++ t.Errorf("LookupBackgroundAction returned %v, want nil", found) ++ } ++ }) ++} ++ ++func TestCheckOperation(t *testing.T) { ++ t.Run("checks operation using registry lookup", func(t *testing.T) { ++ r := registry.New() ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "check-op", Done: false, Metadata: map[string]any{}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Done: true, Output: "checked", Metadata: map[string]any{}}, nil ++ } ++ ++ ba := DefineBackgroundAction(r, "test/checkop", api.ActionTypeCustom, nil, startFn, checkFn, nil) ++ ++ op, err := ba.Start(context.Background(), "input") ++ if err != nil { ++ t.Fatalf("Start error: %v", err) ++ } ++ ++ checked, err := CheckOperation[string, string](context.Background(), r, op) ++ if err != nil { ++ t.Fatalf("CheckOperation error: %v", err) ++ } ++ if !checked.Done { ++ t.Error("checked.Done = false, want true") ++ } ++ if checked.Output != "checked" { ++ t.Errorf("checked.Output = %q, want %q", checked.Output, "checked") ++ } ++ }) ++ ++ t.Run("returns error for nil operation", func(t *testing.T) { ++ r := registry.New() ++ ++ _, err := CheckOperation[string, string](context.Background(), r, nil) ++ if err == nil { ++ t.Error("expected error for nil operation, got nil") ++ } ++ }) ++ ++ t.Run("returns error for operation with empty Action", func(t *testing.T) { ++ r := registry.New() ++ op := &Operation[string]{ID: "op-1", Metadata: map[string]any{}} ++ ++ _, err := CheckOperation[string, string](context.Background(), r, op) ++ if err == nil { ++ t.Error("expected error for operation with empty Action, got nil") ++ } ++ }) ++ ++ t.Run("returns error for non-existent action", func(t *testing.T) { ++ r := registry.New() ++ op := &Operation[string]{ ++ ID: "op-1", ++ 
Action: api.KeyFromName(api.ActionTypeCustom, "test/nonexistent"), ++ Metadata: map[string]any{}, ++ } ++ ++ _, err := CheckOperation[string, string](context.Background(), r, op) ++ if err == nil { ++ t.Error("expected error for non-existent action, got nil") ++ } ++ }) ++} ++ ++func TestBackgroundActionWithMetadata(t *testing.T) { ++ t.Run("preserves metadata", func(t *testing.T) { ++ r := registry.New() ++ meta := map[string]any{ ++ "description": "A test background action", ++ "version": "1.0", ++ } ++ startFn := func(ctx context.Context, input string) (*Operation[string], error) { ++ return &Operation[string]{ID: "meta-op", Metadata: map[string]any{}}, nil ++ } ++ checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { ++ return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil ++ } ++ ++ ba := DefineBackgroundAction(r, "test/meta", api.ActionTypeCustom, meta, startFn, checkFn, nil) ++ ++ desc := ba.Desc() ++ if desc.Description != "A test background action" { ++ t.Errorf("Description = %q, want %q", desc.Description, "A test background action") ++ } ++ }) ++} +diff --git a/go/core/context_test.go b/go/core/context_test.go +new file mode 100644 +index 000000000..3ee4b8a0d +--- /dev/null ++++ b/go/core/context_test.go +@@ -0,0 +1,122 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package core ++ ++import ( ++ "context" ++ "testing" ++ ++ "github.com/google/go-cmp/cmp" ++) ++ ++func TestWithActionContext(t *testing.T) { ++ t.Run("adds action context to context", func(t *testing.T) { ++ ctx := context.Background() ++ actionCtx := ActionContext{ ++ "userId": "user-123", ++ "sessionId": "session-456", ++ } ++ ++ newCtx := WithActionContext(ctx, actionCtx) ++ ++ retrieved := FromContext(newCtx) ++ if diff := cmp.Diff(actionCtx, retrieved); diff != "" { ++ t.Errorf("ActionContext mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("replaces existing action context", func(t *testing.T) { ++ ctx := context.Background() ++ first := ActionContext{"key": "first"} ++ second := ActionContext{"key": "second"} ++ ++ ctx = WithActionContext(ctx, first) ++ ctx = WithActionContext(ctx, second) ++ ++ retrieved := FromContext(ctx) ++ if retrieved["key"] != "second" { ++ t.Errorf("key = %v, want %q", retrieved["key"], "second") ++ } ++ }) ++ ++ t.Run("allows nil action context", func(t *testing.T) { ++ ctx := context.Background() ++ newCtx := WithActionContext(ctx, nil) ++ ++ retrieved := FromContext(newCtx) ++ if retrieved != nil { ++ t.Errorf("expected nil, got %v", retrieved) ++ } ++ }) ++} ++ ++func TestFromContext(t *testing.T) { ++ t.Run("returns nil when no action context", func(t *testing.T) { ++ ctx := context.Background() ++ retrieved := FromContext(ctx) ++ ++ if retrieved != nil { ++ t.Errorf("expected nil, got %v", retrieved) ++ } ++ }) ++ ++ t.Run("returns action context when present", func(t *testing.T) { ++ ctx := context.Background() ++ actionCtx := ActionContext{ ++ "requestId": "req-789", ++ } ++ ctx = WithActionContext(ctx, actionCtx) ++ ++ retrieved := FromContext(ctx) ++ if retrieved["requestId"] != "req-789" { ++ t.Errorf("requestId = %v, want %q", retrieved["requestId"], "req-789") ++ } ++ }) ++ ++ t.Run("returns correct context from nested contexts", func(t *testing.T) { ++ ctx := 
context.Background() ++ actionCtx := ActionContext{"level": "root"} ++ ctx = WithActionContext(ctx, actionCtx) ++ ++ // Create child context with deadline (doesn't affect action context) ++ childCtx, cancel := context.WithCancel(ctx) ++ defer cancel() ++ ++ retrieved := FromContext(childCtx) ++ if retrieved["level"] != "root" { ++ t.Errorf("level = %v, want %q", retrieved["level"], "root") ++ } ++ }) ++} ++ ++func TestActionContextModification(t *testing.T) { ++ t.Run("modifications to retrieved context affect original", func(t *testing.T) { ++ ctx := context.Background() ++ actionCtx := ActionContext{"mutable": "original"} ++ ctx = WithActionContext(ctx, actionCtx) ++ ++ retrieved := FromContext(ctx) ++ retrieved["mutable"] = "modified" ++ ++ // Check that modification affected the stored context ++ // (maps are reference types, so this behavior is expected) ++ secondRetrieval := FromContext(ctx) ++ if secondRetrieval["mutable"] != "modified" { ++ t.Errorf("mutable = %v, want %q", secondRetrieval["mutable"], "modified") ++ } ++ }) ++} +diff --git a/go/core/core_test.go b/go/core/core_test.go +new file mode 100644 +index 000000000..67ee0d912 +--- /dev/null ++++ b/go/core/core_test.go +@@ -0,0 +1,284 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package core ++ ++import ( ++ "testing" ++ ++ "github.com/firebase/genkit/go/internal/registry" ++ "github.com/google/go-cmp/cmp" ++) ++ ++func TestDefineSchema(t *testing.T) { ++ t.Run("registers schema in registry", func(t *testing.T) { ++ r := registry.New() ++ schema := map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "name": map[string]any{"type": "string"}, ++ "age": map[string]any{"type": "integer"}, ++ }, ++ "required": []any{"name"}, ++ } ++ ++ DefineSchema(r, "Person", schema) ++ ++ found := r.LookupSchema("Person") ++ if found == nil { ++ t.Fatal("schema not found in registry") ++ } ++ if diff := cmp.Diff(schema, found); diff != "" { ++ t.Errorf("schema mismatch (-want +got):\n%s", diff) ++ } ++ }) ++} ++ ++func TestDefineSchemaFor(t *testing.T) { ++ t.Run("registers schema derived from Go type", func(t *testing.T) { ++ r := registry.New() ++ ++ type User struct { ++ Name string `json:"name"` ++ Email string `json:"email"` ++ } ++ ++ DefineSchemaFor[User](r) ++ ++ found := r.LookupSchema("User") ++ if found == nil { ++ t.Fatal("schema not found in registry") ++ } ++ // Check that the schema has expected properties ++ props, ok := found["properties"].(map[string]any) ++ if !ok { ++ t.Fatal("expected properties in schema") ++ } ++ if props["name"] == nil { ++ t.Error("expected 'name' property in schema") ++ } ++ if props["email"] == nil { ++ t.Error("expected 'email' property in schema") ++ } ++ }) ++ ++ t.Run("handles pointer types", func(t *testing.T) { ++ r := registry.New() ++ ++ type Config struct { ++ Debug bool `json:"debug"` ++ } ++ ++ DefineSchemaFor[*Config](r) ++ ++ found := r.LookupSchema("Config") ++ if found == nil { ++ t.Fatal("schema not found in registry for pointer type") ++ } ++ }) ++} ++ ++func TestSchemaRef(t *testing.T) { ++ t.Run("returns schema reference map", func(t *testing.T) { ++ ref := SchemaRef("MyType") ++ ++ want := map[string]any{ ++ "$ref": 
"genkit:MyType", ++ } ++ if diff := cmp.Diff(want, ref); diff != "" { ++ t.Errorf("SchemaRef mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("handles various names", func(t *testing.T) { ++ tests := []struct { ++ name string ++ want string ++ }{ ++ {"Simple", "genkit:Simple"}, ++ {"Package.Type", "genkit:Package.Type"}, ++ {"my-schema", "genkit:my-schema"}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ ref := SchemaRef(tt.name) ++ if ref["$ref"] != tt.want { ++ t.Errorf("$ref = %q, want %q", ref["$ref"], tt.want) ++ } ++ }) ++ } ++ }) ++} ++ ++func TestResolveSchema(t *testing.T) { ++ t.Run("returns nil for nil schema", func(t *testing.T) { ++ r := registry.New() ++ ++ resolved, err := ResolveSchema(r, nil) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if resolved != nil { ++ t.Errorf("expected nil, got %v", resolved) ++ } ++ }) ++ ++ t.Run("returns original schema without ref", func(t *testing.T) { ++ r := registry.New() ++ schema := map[string]any{ ++ "type": "string", ++ } ++ ++ resolved, err := ResolveSchema(r, schema) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if diff := cmp.Diff(schema, resolved); diff != "" { ++ t.Errorf("schema mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("resolves genkit ref", func(t *testing.T) { ++ r := registry.New() ++ originalSchema := map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "id": map[string]any{"type": "integer"}, ++ }, ++ } ++ r.RegisterSchema("Entity", originalSchema) ++ ++ refSchema := map[string]any{ ++ "$ref": "genkit:Entity", ++ } ++ ++ resolved, err := ResolveSchema(r, refSchema) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if diff := cmp.Diff(originalSchema, resolved); diff != "" { ++ t.Errorf("resolved schema mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("returns original schema for non-genkit ref", func(t *testing.T) { ++ r := registry.New() ++ 
schema := map[string]any{ ++ "$ref": "#/definitions/Other", ++ } ++ ++ resolved, err := ResolveSchema(r, schema) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if diff := cmp.Diff(schema, resolved); diff != "" { ++ t.Errorf("schema mismatch (-want +got):\n%s", diff) ++ } ++ }) ++ ++ t.Run("returns error for missing schema", func(t *testing.T) { ++ r := registry.New() ++ refSchema := map[string]any{ ++ "$ref": "genkit:NonExistent", ++ } ++ ++ _, err := ResolveSchema(r, refSchema) ++ ++ if err == nil { ++ t.Error("expected error for missing schema, got nil") ++ } ++ }) ++} ++ ++func TestInferSchemaMap(t *testing.T) { ++ t.Run("infers schema from struct", func(t *testing.T) { ++ type TestStruct struct { ++ Name string `json:"name"` ++ Count int `json:"count"` ++ Enabled bool `json:"enabled"` ++ } ++ ++ schema := InferSchemaMap(TestStruct{}) ++ ++ if schema["type"] != "object" { ++ t.Errorf("type = %v, want %q", schema["type"], "object") ++ } ++ props, ok := schema["properties"].(map[string]any) ++ if !ok { ++ t.Fatal("expected properties map") ++ } ++ if props["name"] == nil { ++ t.Error("expected 'name' property") ++ } ++ if props["count"] == nil { ++ t.Error("expected 'count' property") ++ } ++ if props["enabled"] == nil { ++ t.Error("expected 'enabled' property") ++ } ++ }) ++ ++ t.Run("infers schema from primitive types", func(t *testing.T) { ++ tests := []struct { ++ value any ++ wantType string ++ }{ ++ {"hello", "string"}, ++ {42, "integer"}, ++ {3.14, "number"}, ++ {true, "boolean"}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.wantType, func(t *testing.T) { ++ schema := InferSchemaMap(tt.value) ++ if schema["type"] != tt.wantType { ++ t.Errorf("type = %v, want %q", schema["type"], tt.wantType) ++ } ++ }) ++ } ++ }) ++ ++ t.Run("infers schema from slice", func(t *testing.T) { ++ schema := InferSchemaMap([]string{}) ++ ++ if schema["type"] != "array" { ++ t.Errorf("type = %v, want %q", schema["type"], "array") ++ } ++ }) ++ ++ 
t.Run("infers schema from nested struct", func(t *testing.T) { ++ type Inner struct { ++ Value string `json:"value"` ++ } ++ type Outer struct { ++ Inner Inner `json:"inner"` ++ } ++ ++ schema := InferSchemaMap(Outer{}) ++ ++ props := schema["properties"].(map[string]any) ++ innerProp := props["inner"].(map[string]any) ++ if innerProp["type"] != "object" { ++ t.Errorf("inner type = %v, want %q", innerProp["type"], "object") ++ } ++ }) ++} +diff --git a/go/core/doc.go b/go/core/doc.go +new file mode 100644 +index 000000000..e4528df7f +--- /dev/null ++++ b/go/core/doc.go +@@ -0,0 +1,230 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++/* ++Package core implements Genkit's foundational action system and runtime machinery. ++ ++This package is primarily intended for plugin developers and Genkit internals. ++Application developers should use the genkit package instead, which provides ++a higher-level, more convenient API. ++ ++# Actions ++ ++Actions are the fundamental building blocks of Genkit. Every operation - flows, ++model calls, tool invocations, retrieval - is implemented as an action. 
Actions ++provide: ++ ++ - Type-safe input/output with JSON schema validation ++ - Automatic tracing and observability ++ - Consistent error handling ++ - Registration in the action registry ++ ++Define a non-streaming action: ++ ++ action := core.DefineAction(registry, "myAction", ++ func(ctx context.Context, input string) (string, error) { ++ return "processed: " + input, nil ++ }, ++ ) ++ ++ result, err := action.Run(context.Background(), "hello") ++ ++Define a streaming action that sends chunks during execution: ++ ++ streamingAction := core.DefineStreamingAction(registry, "countdown", ++ func(ctx context.Context, start int, cb core.StreamCallback[string]) (string, error) { ++ for i := start; i > 0; i-- { ++ if cb != nil { ++ if err := cb(ctx, fmt.Sprintf("T-%d", i)); err != nil { ++ return "", err ++ } ++ } ++ time.Sleep(time.Second) ++ } ++ return "Liftoff!", nil ++ }, ++ ) ++ ++# Flows ++ ++Flows are user-defined actions that orchestrate AI operations. They are the ++primary way application developers define business logic in Genkit: ++ ++ flow := core.DefineFlow(registry, "myFlow", ++ func(ctx context.Context, input string) (string, error) { ++ // Use Run to create traced sub-steps ++ result, err := core.Run(ctx, "step1", func() (string, error) { ++ return process(input), nil ++ }) ++ if err != nil { ++ return "", err ++ } ++ return result, nil ++ }, ++ ) ++ ++Streaming flows can send intermediate results to callers: ++ ++ streamingFlow := core.DefineStreamingFlow(registry, "generateReport", ++ func(ctx context.Context, input Input, cb core.StreamCallback[Progress]) (Report, error) { ++ for i := 0; i < 100; i += 10 { ++ if cb != nil { ++ cb(ctx, Progress{Percent: i}) ++ } ++ // ... work ... ++ } ++ return Report{...}, nil ++ }, ++ ) ++ ++# Traced Steps with Run ++ ++Use [Run] within flows to create traced sub-operations. 
Each Run call creates ++a span in the trace that's visible in the Genkit Developer UI: ++ ++ result, err := core.Run(ctx, "fetchData", func() (Data, error) { ++ return fetchFromAPI() ++ }) ++ ++ processed, err := core.Run(ctx, "processData", func() (Result, error) { ++ return process(result) ++ }) ++ ++# Middleware ++ ++Actions support middleware for cross-cutting concerns like logging, metrics, ++or authentication: ++ ++ loggingMiddleware := func(next core.StreamingFunc[string, string, struct{}]) core.StreamingFunc[string, string, struct{}] { ++ return func(ctx context.Context, input string, cb core.StreamCallback[struct{}]) (string, error) { ++ log.Printf("Input: %s", input) ++ output, err := next(ctx, input, cb) ++ log.Printf("Output: %s, Error: %v", output, err) ++ return output, err ++ } ++ } ++ ++Chain multiple middleware together: ++ ++ combined := core.ChainMiddleware(loggingMiddleware, metricsMiddleware) ++ wrappedFn := combined(originalFunc) ++ ++# Schema Management ++ ++Register JSON schemas for use in prompts and validation: ++ ++ // Define a schema from a map ++ core.DefineSchema(registry, "Person", map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "name": map[string]any{"type": "string"}, ++ "age": map[string]any{"type": "integer"}, ++ }, ++ "required": []any{"name"}, ++ }) ++ ++ // Define a schema from a Go type (recommended) ++ core.DefineSchemaFor[Person](registry) ++ ++Schemas can be referenced in .prompt files by name. ++ ++# Plugin Development ++ ++Plugins extend Genkit's functionality by providing models, tools, retrievers, ++and other capabilities. Implement the [api.Plugin] interface: ++ ++ type MyPlugin struct { ++ APIKey string ++ } ++ ++ func (p *MyPlugin) Name() string { ++ return "myplugin" ++ } ++ ++ func (p *MyPlugin) Init(ctx context.Context) []api.Action { ++ // Initialize the plugin and return actions to register ++ model := ai.DefineModel(...) ++ tool := ai.DefineTool(...) 
++ return []api.Action{model, tool} ++ } ++ ++For plugins that resolve actions dynamically (e.g., listing available models ++from an API), implement [api.DynamicPlugin]: ++ ++ type DynamicModelPlugin struct{} ++ ++ func (p *DynamicModelPlugin) ListActions(ctx context.Context) []api.ActionDesc { ++ // Return descriptors of available actions ++ return []api.ActionDesc{ ++ {Key: "/model/myplugin/model-a", Name: "model-a"}, ++ {Key: "/model/myplugin/model-b", Name: "model-b"}, ++ } ++ } ++ ++ func (p *DynamicModelPlugin) ResolveAction(atype api.ActionType, name string) api.Action { ++ // Create and return the action on demand ++ return createModel(name) ++ } ++ ++# Background Actions ++ ++For long-running operations, use background actions that return immediately ++with an operation ID that can be polled for completion: ++ ++ bgAction := core.DefineBackgroundAction(registry, "longTask", ++ func(ctx context.Context, input Input) (Output, error) { ++ // Start the operation ++ return startLongOperation(input) ++ }, ++ func(ctx context.Context, op *core.Operation[Output]) (*core.Operation[Output], error) { ++ // Check operation status ++ return checkOperationStatus(op) ++ }, ++ ) ++ ++# Error Handling ++ ++Return user-facing errors with appropriate status codes: ++ ++ if err := validate(input); err != nil { ++ return nil, core.NewPublicError(core.INVALID_ARGUMENT, "Invalid input", map[string]any{ ++ "field": "email", ++ "error": err.Error(), ++ }) ++ } ++ ++For internal errors that should be logged but not exposed to users: ++ ++ return nil, core.NewError(core.INTERNAL, "database connection failed: %v", err) ++ ++# Context ++ ++Access action context for metadata and configuration: ++ ++ ctx := core.FromContext(ctx) ++ if ctx != nil { ++ // Access action-specific context values ++ } ++ ++Set action context for nested operations: ++ ++ ctx = core.WithActionContext(ctx, core.ActionContext{ ++ "requestId": requestID, ++ }) ++ ++For more information, see 
https://genkit.dev/docs/plugins ++*/ ++package core +diff --git a/go/core/error_test.go b/go/core/error_test.go +new file mode 100644 +index 000000000..60ff503bd +--- /dev/null ++++ b/go/core/error_test.go +@@ -0,0 +1,217 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package core ++ ++import ( ++ "errors" ++ "fmt" ++ "net/http" ++ "strings" ++ "testing" ++) ++ ++func TestNewPublicError(t *testing.T) { ++ t.Run("creates error with all fields", func(t *testing.T) { ++ details := map[string]any{"field": "username"} ++ err := NewPublicError(INVALID_ARGUMENT, "invalid username", details) ++ ++ if err.Status != INVALID_ARGUMENT { ++ t.Errorf("Status = %q, want %q", err.Status, INVALID_ARGUMENT) ++ } ++ if err.Message != "invalid username" { ++ t.Errorf("Message = %q, want %q", err.Message, "invalid username") ++ } ++ if err.Details["field"] != "username" { ++ t.Errorf("Details[field] = %v, want %q", err.Details["field"], "username") ++ } ++ }) ++ ++ t.Run("creates error with nil details", func(t *testing.T) { ++ err := NewPublicError(NOT_FOUND, "resource not found", nil) ++ ++ if err.Status != NOT_FOUND { ++ t.Errorf("Status = %q, want %q", err.Status, NOT_FOUND) ++ } ++ if err.Details != nil { ++ t.Errorf("Details = %v, want nil", err.Details) ++ } ++ }) ++} ++ ++func TestUserFacingErrorError(t *testing.T) { ++ t.Run("formats error message correctly", func(t 
*testing.T) { ++ err := NewPublicError(PERMISSION_DENIED, "access denied", nil) ++ got := err.Error() ++ want := "PERMISSION_DENIED: access denied" ++ ++ if got != want { ++ t.Errorf("Error() = %q, want %q", got, want) ++ } ++ }) ++} ++ ++func TestNewError(t *testing.T) { ++ t.Run("creates error with simple message", func(t *testing.T) { ++ err := NewError(INTERNAL, "internal error") ++ ++ if err.Status != INTERNAL { ++ t.Errorf("Status = %q, want %q", err.Status, INTERNAL) ++ } ++ if err.Message != "internal error" { ++ t.Errorf("Message = %q, want %q", err.Message, "internal error") ++ } ++ }) ++ ++ t.Run("creates error with formatted message", func(t *testing.T) { ++ err := NewError(INVALID_ARGUMENT, "field %q has invalid value %d", "count", 42) ++ ++ want := `field "count" has invalid value 42` ++ if err.Message != want { ++ t.Errorf("Message = %q, want %q", err.Message, want) ++ } ++ }) ++ ++ t.Run("captures stack trace", func(t *testing.T) { ++ err := NewError(INTERNAL, "error with stack") ++ ++ if err.Details == nil { ++ t.Fatal("Details is nil, expected stack trace") ++ } ++ stack, ok := err.Details["stack"].(string) ++ if !ok { ++ t.Fatal("stack is not a string") ++ } ++ if !strings.Contains(stack, "TestNewError") { ++ t.Errorf("stack trace does not contain test function name") ++ } ++ }) ++} ++ ++func TestGenkitErrorError(t *testing.T) { ++ t.Run("returns message as error string", func(t *testing.T) { ++ err := NewError(INTERNAL, "something went wrong") ++ got := err.Error() ++ ++ if got != "something went wrong" { ++ t.Errorf("Error() = %q, want %q", got, "something went wrong") ++ } ++ }) ++} ++ ++func TestGenkitErrorToReflectionError(t *testing.T) { ++ t.Run("converts error with stack", func(t *testing.T) { ++ ge := NewError(NOT_FOUND, "resource not found") ++ re := ge.ToReflectionError() ++ ++ if re.Message != "resource not found" { ++ t.Errorf("Message = %q, want %q", re.Message, "resource not found") ++ } ++ if re.Code != http.StatusNotFound { ++ 
t.Errorf("Code = %d, want %d", re.Code, http.StatusNotFound) ++ } ++ if re.Details == nil || re.Details.Stack == nil { ++ t.Error("expected stack in details") ++ } ++ }) ++ ++ t.Run("converts error with traceId", func(t *testing.T) { ++ ge := &GenkitError{ ++ Status: INTERNAL, ++ Message: "internal error", ++ Details: map[string]any{ ++ "traceId": "trace-123", ++ }, ++ } ++ re := ge.ToReflectionError() ++ ++ if re.Details == nil || re.Details.TraceID == nil { ++ t.Fatal("expected traceId in details") ++ } ++ if *re.Details.TraceID != "trace-123" { ++ t.Errorf("TraceID = %q, want %q", *re.Details.TraceID, "trace-123") ++ } ++ }) ++ ++ t.Run("handles empty details", func(t *testing.T) { ++ ge := &GenkitError{ ++ Status: OK, ++ Message: "success", ++ Details: nil, ++ } ++ re := ge.ToReflectionError() ++ ++ if re.Message != "success" { ++ t.Errorf("Message = %q, want %q", re.Message, "success") ++ } ++ if re.Details.Stack != nil { ++ t.Error("expected nil stack") ++ } ++ }) ++} ++ ++func TestToReflectionError(t *testing.T) { ++ t.Run("handles GenkitError directly", func(t *testing.T) { ++ ge := NewError(INVALID_ARGUMENT, "bad input") ++ re := ToReflectionError(ge) ++ ++ if re.Message != "bad input" { ++ t.Errorf("Message = %q, want %q", re.Message, "bad input") ++ } ++ if re.Code != http.StatusBadRequest { ++ t.Errorf("Code = %d, want %d", re.Code, http.StatusBadRequest) ++ } ++ }) ++ ++ t.Run("handles wrapped GenkitError", func(t *testing.T) { ++ ge := NewError(NOT_FOUND, "not found") ++ wrapped := fmt.Errorf("context: %w", ge) ++ re := ToReflectionError(wrapped) ++ ++ if re.Message != "not found" { ++ t.Errorf("Message = %q, want %q", re.Message, "not found") ++ } ++ if re.Code != http.StatusNotFound { ++ t.Errorf("Code = %d, want %d", re.Code, http.StatusNotFound) ++ } ++ }) ++ ++ t.Run("handles plain error", func(t *testing.T) { ++ plainErr := errors.New("plain error") ++ re := ToReflectionError(plainErr) ++ ++ if re.Message != "plain error" { ++ t.Errorf("Message 
= %q, want %q", re.Message, "plain error") ++ } ++ if re.Code != http.StatusInternalServerError { ++ t.Errorf("Code = %d, want %d", re.Code, http.StatusInternalServerError) ++ } ++ }) ++ ++ t.Run("handles doubly wrapped GenkitError", func(t *testing.T) { ++ ge := NewError(PERMISSION_DENIED, "denied") ++ wrapped1 := fmt.Errorf("layer1: %w", ge) ++ wrapped2 := fmt.Errorf("layer2: %w", wrapped1) ++ re := ToReflectionError(wrapped2) ++ ++ if re.Message != "denied" { ++ t.Errorf("Message = %q, want %q", re.Message, "denied") ++ } ++ if re.Code != http.StatusForbidden { ++ t.Errorf("Code = %d, want %d", re.Code, http.StatusForbidden) ++ } ++ }) ++} +diff --git a/go/core/example_test.go b/go/core/example_test.go +new file mode 100644 +index 000000000..c6212c3e9 +--- /dev/null ++++ b/go/core/example_test.go +@@ -0,0 +1,197 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package core_test ++ ++import ( ++ "context" ++ "fmt" ++ "strings" ++ ++ "github.com/firebase/genkit/go/core" ++ "github.com/firebase/genkit/go/internal/registry" ++) ++ ++// This example demonstrates defining a simple flow. 
++func ExampleDefineFlow() { ++ r := registry.New() ++ ++ // Define a flow that processes input ++ flow := core.DefineFlow(r, "uppercase", ++ func(ctx context.Context, input string) (string, error) { ++ return strings.ToUpper(input), nil ++ }, ++ ) ++ ++ // Run the flow ++ result, err := flow.Run(context.Background(), "hello") ++ if err != nil { ++ fmt.Println("Error:", err) ++ return ++ } ++ fmt.Println(result) ++ // Output: HELLO ++} ++ ++// This example demonstrates defining a streaming flow. ++func ExampleDefineStreamingFlow() { ++ r := registry.New() ++ ++ // Define a streaming flow that counts down ++ flow := core.DefineStreamingFlow(r, "countdown", ++ func(ctx context.Context, start int, cb core.StreamCallback[int]) (string, error) { ++ for i := start; i > 0; i-- { ++ if cb != nil { ++ if err := cb(ctx, i); err != nil { ++ return "", err ++ } ++ } ++ } ++ return "Done!", nil ++ }, ++ ) ++ ++ // Use Stream() iterator to receive chunks ++ iter := flow.Stream(context.Background(), 3) ++ iter(func(val *core.StreamingFlowValue[string, int], err error) bool { ++ if err != nil { ++ fmt.Println("Error:", err) ++ return false ++ } ++ if val.Done { ++ fmt.Println("Result:", val.Output) ++ } else { ++ fmt.Println("Count:", val.Stream) ++ } ++ return true ++ }) ++ // Output: ++ // Count: 3 ++ // Count: 2 ++ // Count: 1 ++ // Result: Done! ++} ++ ++// This example demonstrates using Run to create traced sub-steps. 
++func ExampleRun() { ++ r := registry.New() ++ ++ // Define a flow that uses Run for traced steps ++ flow := core.DefineFlow(r, "pipeline", ++ func(ctx context.Context, input string) (string, error) { ++ // Each Run creates a traced step visible in the Dev UI ++ upper, err := core.Run(ctx, "toUpper", func() (string, error) { ++ return strings.ToUpper(input), nil ++ }) ++ if err != nil { ++ return "", err ++ } ++ ++ result, err := core.Run(ctx, "addPrefix", func() (string, error) { ++ return "RESULT: " + upper, nil ++ }) ++ return result, err ++ }, ++ ) ++ ++ result, err := flow.Run(context.Background(), "hello") ++ if err != nil { ++ fmt.Println("Error:", err) ++ return ++ } ++ fmt.Println(result) ++ // Output: RESULT: HELLO ++} ++ ++// This example demonstrates defining a schema from a Go type. ++func ExampleDefineSchemaFor() { ++ r := registry.New() ++ ++ // Define a struct type ++ type Person struct { ++ Name string `json:"name"` ++ Age int `json:"age"` ++ } ++ ++ // Register the schema ++ core.DefineSchemaFor[Person](r) ++ ++ // The schema is now registered and can be referenced in .prompt files ++ fmt.Println("Schema registered") ++ // Output: Schema registered ++} ++ ++// This example demonstrates defining a schema from a map. ++func ExampleDefineSchema() { ++ r := registry.New() ++ ++ // Define a JSON schema as a map ++ core.DefineSchema(r, "Address", map[string]any{ ++ "type": "object", ++ "properties": map[string]any{ ++ "street": map[string]any{"type": "string"}, ++ "city": map[string]any{"type": "string"}, ++ "zip": map[string]any{"type": "string"}, ++ }, ++ "required": []any{"street", "city"}, ++ }) ++ ++ fmt.Println("Schema registered: Address") ++ // Output: Schema registered: Address ++} ++ ++// This example demonstrates using ChainMiddleware to combine middleware. 
++func ExampleChainMiddleware() { ++ // Define a middleware that wraps function calls ++ logMiddleware := func(next core.StreamingFunc[string, string, struct{}]) core.StreamingFunc[string, string, struct{}] { ++ return func(ctx context.Context, input string, cb core.StreamCallback[struct{}]) (string, error) { ++ fmt.Println("Before:", input) ++ result, err := next(ctx, input, cb) ++ fmt.Println("After:", result) ++ return result, err ++ } ++ } ++ ++ // The original function ++ originalFn := func(ctx context.Context, input string, cb core.StreamCallback[struct{}]) (string, error) { ++ return strings.ToUpper(input), nil ++ } ++ ++ // Chain and apply middleware ++ wrapped := core.ChainMiddleware(logMiddleware)(originalFn) ++ ++ result, _ := wrapped(context.Background(), "hello", nil) ++ fmt.Println("Final:", result) ++ // Output: ++ // Before: hello ++ // After: HELLO ++ // Final: HELLO ++} ++ ++// This example demonstrates creating user-facing errors. ++func ExampleNewPublicError() { ++ // Create a user-facing error with details ++ err := core.NewPublicError(core.INVALID_ARGUMENT, "Invalid email format", map[string]any{ ++ "field": "email", ++ "value": "not-an-email", ++ }) ++ ++ fmt.Println("Status:", err.Status) ++ fmt.Println("Message:", err.Message) ++ // Output: ++ // Status: INVALID_ARGUMENT ++ // Message: Invalid email format ++} +diff --git a/go/core/flow.go b/go/core/flow.go +index 0cd12120f..ea514365c 100644 +--- a/go/core/flow.go ++++ b/go/core/flow.go +@@ -71,6 +71,9 @@ func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn St + flowName: name, + } + ctx = flowContextKey.NewContext(ctx, fc) ++ if cb == nil { ++ cb = func(context.Context, Stream) error { return nil } ++ } + return fn(ctx, input, cb) + })) + } +diff --git a/go/core/flow_test.go b/go/core/flow_test.go +index 77087072c..e3c3e6b46 100644 +--- a/go/core/flow_test.go ++++ b/go/core/flow_test.go +@@ -89,3 +89,171 @@ func TestFlowNameFromContext(t *testing.T) { + }) + } + } 
++ ++func TestRunOutsideFlow(t *testing.T) { ++ t.Run("returns error when called outside flow", func(t *testing.T) { ++ ctx := context.Background() ++ _, err := Run(ctx, "step", func() (int, error) { ++ return 42, nil ++ }) ++ ++ if err == nil { ++ t.Error("expected error when Run called outside flow, got nil") ++ } ++ }) ++} ++ ++func TestFlowStream(t *testing.T) { ++ t.Run("streams values correctly", func(t *testing.T) { ++ r := registry.New() ++ f := DefineStreamingFlow(r, "counter", func(ctx context.Context, n int, cb StreamCallback[int]) (int, error) { ++ for i := 0; i < n; i++ { ++ if err := cb(ctx, i); err != nil { ++ return 0, err ++ } ++ } ++ return n, nil ++ }) ++ ++ var streamedValues []int ++ var finalOutput int ++ var finalDone bool ++ ++ for v, err := range f.Stream(context.Background(), 3) { ++ if err != nil { ++ t.Fatalf("Stream error: %v", err) ++ } ++ if v.Done { ++ finalDone = true ++ finalOutput = v.Output ++ } else { ++ streamedValues = append(streamedValues, v.Stream) ++ } ++ } ++ ++ wantStreamed := []int{0, 1, 2} ++ if !slices.Equal(streamedValues, wantStreamed) { ++ t.Errorf("streamed values = %v, want %v", streamedValues, wantStreamed) ++ } ++ if !finalDone { ++ t.Error("expected final Done value") ++ } ++ if finalOutput != 3 { ++ t.Errorf("final output = %d, want 3", finalOutput) ++ } ++ }) ++ ++ t.Run("yields error on flow failure", func(t *testing.T) { ++ r := registry.New() ++ f := DefineStreamingFlow(r, "failing", func(ctx context.Context, input int, cb StreamCallback[int]) (int, error) { ++ return 0, NewError(INTERNAL, "flow failed") ++ }) ++ ++ var gotErr error ++ for _, err := range f.Stream(context.Background(), 1) { ++ if err != nil { ++ gotErr = err ++ } ++ } ++ ++ if gotErr == nil { ++ t.Error("expected error from failing flow, got nil") ++ } ++ }) ++} ++ ++func TestFlowRegister(t *testing.T) { ++ t.Run("flow can be registered with registry", func(t *testing.T) { ++ r := registry.New() ++ f := DefineFlow(r, "test/registerable", 
func(ctx context.Context, input string) (string, error) { ++ return input, nil ++ }) ++ ++ // Flow should already be registered by DefineFlow ++ if f.Name() != "test/registerable" { ++ t.Errorf("Name() = %q, want %q", f.Name(), "test/registerable") ++ } ++ }) ++} ++ ++func TestFlowDesc(t *testing.T) { ++ t.Run("returns flow descriptor", func(t *testing.T) { ++ r := registry.New() ++ f := DefineFlow(r, "test/described", func(ctx context.Context, input struct { ++ Name string `json:"name"` ++ }) (struct { ++ Greeting string `json:"greeting"` ++ }, error) { ++ return struct { ++ Greeting string `json:"greeting"` ++ }{Greeting: "Hello " + input.Name}, nil ++ }) ++ ++ desc := f.Desc() ++ ++ if desc.Name != "test/described" { ++ t.Errorf("Name = %q, want %q", desc.Name, "test/described") ++ } ++ if desc.InputSchema == nil { ++ t.Error("InputSchema is nil") ++ } ++ if desc.OutputSchema == nil { ++ t.Error("OutputSchema is nil") ++ } ++ }) ++} ++ ++func TestFlowRunJSON(t *testing.T) { ++ t.Run("runs flow with JSON input and output", func(t *testing.T) { ++ r := registry.New() ++ f := DefineFlow(r, "test/jsonFlow", func(ctx context.Context, input int) (int, error) { ++ return input * 2, nil ++ }) ++ ++ got, err := f.RunJSON(context.Background(), []byte("5"), nil) ++ if err != nil { ++ t.Fatalf("RunJSON error: %v", err) ++ } ++ ++ if string(got) != "10" { ++ t.Errorf("RunJSON result = %s, want %q", got, "10") ++ } ++ }) ++} ++ ++func TestFlowRunJSONWithTelemetry(t *testing.T) { ++ t.Run("returns telemetry info with result", func(t *testing.T) { ++ r := registry.New() ++ f := DefineFlow(r, "test/telemetryFlow", func(ctx context.Context, input int) (int, error) { ++ return input + 1, nil ++ }) ++ ++ result, err := f.RunJSONWithTelemetry(context.Background(), []byte("5"), nil) ++ if err != nil { ++ t.Fatalf("RunJSONWithTelemetry error: %v", err) ++ } ++ ++ if result == nil { ++ t.Fatal("result is nil") ++ } ++ if string(result.Result) != "6" { ++ t.Errorf("Result = %s, want 
%q", result.Result, "6") ++ } ++ if result.TraceId == "" { ++ t.Error("TraceId is empty") ++ } ++ if result.SpanId == "" { ++ t.Error("SpanId is empty") ++ } ++ }) ++} ++ ++func TestFlowNameFromContextOutsideFlow(t *testing.T) { ++ t.Run("returns empty string outside flow", func(t *testing.T) { ++ ctx := context.Background() ++ got := FlowNameFromContext(ctx) ++ if got != "" { ++ t.Errorf("FlowNameFromContext outside flow = %q, want empty string", got) ++ } ++ }) ++} +diff --git a/go/core/logger/doc.go b/go/core/logger/doc.go +new file mode 100644 +index 000000000..b3e421abc +--- /dev/null ++++ b/go/core/logger/doc.go +@@ -0,0 +1,110 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++/* ++Package logger provides context-scoped structured logging for Genkit. ++ ++This package wraps the standard library's [log/slog] package to provide ++context-aware logging throughout Genkit operations. Logs are automatically ++associated with the current action or flow context. 
++ ++# Usage ++ ++Retrieve the logger from context within action or flow handlers: ++ ++ func myFlow(ctx context.Context, input string) (string, error) { ++ log := logger.FromContext(ctx) ++ ++ log.Info("Processing input", "size", len(input)) ++ log.Debug("Input details", "value", input) ++ ++ result, err := process(input) ++ if err != nil { ++ log.Error("Processing failed", "error", err) ++ return "", err ++ } ++ ++ log.Info("Processing complete", "resultSize", len(result)) ++ return result, nil ++ } ++ ++# Log Levels ++ ++Control the global log level to filter output: ++ ++ // Show debug logs (verbose) ++ logger.SetLevel(slog.LevelDebug) ++ ++ // Show info and above (default) ++ logger.SetLevel(slog.LevelInfo) ++ ++ // Show only warnings and errors ++ logger.SetLevel(slog.LevelWarn) ++ ++ // Show only errors ++ logger.SetLevel(slog.LevelError) ++ ++ // Get the current log level ++ level := logger.GetLevel() ++ ++# Context Integration ++ ++The logger is automatically available in action and flow contexts. It ++inherits from the context passed to [genkit.Init] and flows through ++all nested operations. ++ ++For custom operations outside of actions/flows, attach a logger to context: ++ ++ log := slog.Default() ++ ctx = logger.WithContext(ctx, log) ++ ++# slog Compatibility ++ ++The logger returned by [FromContext] is a standard [*slog.Logger] and ++supports all slog methods: ++ ++ log := logger.FromContext(ctx) ++ ++ // Structured logging with attributes ++ log.Info("User action", ++ "userId", userID, ++ "action", "login", ++ "duration", elapsed, ++ ) ++ ++ // Grouped attributes ++ log.Info("Request completed", ++ slog.Group("request", ++ "method", r.Method, ++ "path", r.URL.Path, ++ ), ++ slog.Group("response", ++ "status", status, ++ "bytes", written, ++ ), ++ ) ++ ++ // With pre-set attributes ++ requestLog := log.With("requestId", requestID) ++ requestLog.Info("Starting") ++ // ... later ... 
++ requestLog.Info("Finished") ++ ++This package is primarily used by Genkit internals but is useful for ++plugin developers who need consistent logging that integrates with ++Genkit's observability features. ++*/ ++package logger +diff --git a/go/core/middleware_test.go b/go/core/middleware_test.go +new file mode 100644 +index 000000000..8422eb3ee +--- /dev/null ++++ b/go/core/middleware_test.go +@@ -0,0 +1,222 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package core ++ ++import ( ++ "context" ++ "strings" ++ "testing" ++) ++ ++func TestMiddlewares(t *testing.T) { ++ t.Run("creates slice of middlewares", func(t *testing.T) { ++ m1 := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { ++ return next ++ } ++ m2 := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { ++ return next ++ } ++ ++ result := Middlewares(m1, m2) ++ ++ if len(result) != 2 { ++ t.Errorf("len(result) = %d, want 2", len(result)) ++ } ++ }) ++ ++ t.Run("returns empty slice when no middlewares", func(t *testing.T) { ++ result := Middlewares[string, string, struct{}]() ++ ++ if len(result) != 0 { ++ t.Errorf("len(result) = %d, want 0", len(result)) ++ } ++ }) ++ ++ t.Run("returns single middleware slice", func(t *testing.T) { ++ m := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { ++ return next ++ } ++ ++ result := Middlewares(m) ++ ++ if len(result) != 1 { ++ t.Errorf("len(result) = %d, want 1", len(result)) ++ } ++ }) ++} ++ ++func TestChainMiddleware(t *testing.T) { ++ t.Run("empty chain returns identity", func(t *testing.T) { ++ handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ return "original:" + input, nil ++ } ++ ++ chained := ChainMiddleware[string, string, struct{}]()(handler) ++ result, err := chained(context.Background(), "test", nil) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if result != "original:test" { ++ t.Errorf("result = %q, want %q", result, "original:test") ++ } ++ }) ++ ++ t.Run("single middleware is applied", func(t *testing.T) { ++ handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ return "handler:" + input, nil ++ } ++ ++ middleware := func(next StreamingFunc[string, string, struct{}]) 
StreamingFunc[string, string, struct{}] { ++ return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ result, err := next(ctx, "m1:"+input, cb) ++ return "m1:" + result, err ++ } ++ } ++ ++ chained := ChainMiddleware(middleware)(handler) ++ result, err := chained(context.Background(), "test", nil) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ // Expected: m1: wraps output, m1: prepends to input ++ if result != "m1:handler:m1:test" { ++ t.Errorf("result = %q, want %q", result, "m1:handler:m1:test") ++ } ++ }) ++ ++ t.Run("multiple middlewares execute in order", func(t *testing.T) { ++ var executionOrder []string ++ ++ handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ executionOrder = append(executionOrder, "handler") ++ return input, nil ++ } ++ ++ m1 := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { ++ return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ executionOrder = append(executionOrder, "m1-before") ++ result, err := next(ctx, input, cb) ++ executionOrder = append(executionOrder, "m1-after") ++ return result, err ++ } ++ } ++ ++ m2 := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { ++ return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ executionOrder = append(executionOrder, "m2-before") ++ result, err := next(ctx, input, cb) ++ executionOrder = append(executionOrder, "m2-after") ++ return result, err ++ } ++ } ++ ++ // ChainMiddleware(m1, m2) should execute as: m1 -> m2 -> handler -> m2 -> m1 ++ chained := ChainMiddleware(m1, m2)(handler) ++ _, err := chained(context.Background(), "test", nil) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ ++ expected := []string{"m1-before", "m2-before", 
"handler", "m2-after", "m1-after"} ++ if len(executionOrder) != len(expected) { ++ t.Errorf("execution order length = %d, want %d", len(executionOrder), len(expected)) ++ } ++ for i, step := range expected { ++ if i >= len(executionOrder) || executionOrder[i] != step { ++ t.Errorf("step %d = %q, want %q", i, executionOrder[i], step) ++ } ++ } ++ }) ++ ++ t.Run("middleware can modify input", func(t *testing.T) { ++ handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ return input, nil ++ } ++ ++ uppercase := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { ++ return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ return next(ctx, strings.ToUpper(input), cb) ++ } ++ } ++ ++ chained := ChainMiddleware(uppercase)(handler) ++ result, err := chained(context.Background(), "hello", nil) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if result != "HELLO" { ++ t.Errorf("result = %q, want %q", result, "HELLO") ++ } ++ }) ++ ++ t.Run("middleware can modify output", func(t *testing.T) { ++ handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ return input, nil ++ } ++ ++ addSuffix := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { ++ return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ result, err := next(ctx, input, cb) ++ return result + "!", err ++ } ++ } ++ ++ chained := ChainMiddleware(addSuffix)(handler) ++ result, err := chained(context.Background(), "hello", nil) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if result != "hello!" 
{ ++ t.Errorf("result = %q, want %q", result, "hello!") ++ } ++ }) ++ ++ t.Run("middleware can short-circuit", func(t *testing.T) { ++ handlerCalled := false ++ handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ handlerCalled = true ++ return input, nil ++ } ++ ++ shortCircuit := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { ++ return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { ++ if input == "skip" { ++ return "skipped", nil ++ } ++ return next(ctx, input, cb) ++ } ++ } ++ ++ chained := ChainMiddleware(shortCircuit)(handler) ++ result, err := chained(context.Background(), "skip", nil) ++ ++ if err != nil { ++ t.Fatalf("unexpected error: %v", err) ++ } ++ if handlerCalled { ++ t.Error("handler should not have been called") ++ } ++ if result != "skipped" { ++ t.Errorf("result = %q, want %q", result, "skipped") ++ } ++ }) ++} +diff --git a/go/core/schemas.config b/go/core/schemas.config +index 746d10ca4..70798f2eb 100644 +--- a/go/core/schemas.config ++++ b/go/core/schemas.config +@@ -1,6 +1,866 @@ + # This file holds configuration for the genkit-schema.json file + # generated by the npm export:schemas script. + ++# ============================================================================ ++# DOCUMENTATION SECTION ++# All type and field documentation in one consolidated location ++# ============================================================================ ++ ++# ---------------------------------------------------------------------------- ++# Core Message Types ++# ---------------------------------------------------------------------------- ++ ++Role doc ++Role indicates which entity is responsible for the content of a message. ++. ++ ++RoleSystem doc ++RoleSystem indicates this message is user-independent context. ++. ++ ++RoleUser doc ++RoleUser indicates this message was generated by the client. 
++. ++ ++RoleModel doc ++RoleModel indicates this message was generated by the model during a previous interaction. ++. ++ ++RoleTool doc ++RoleTool indicates this message was generated by a local tool, likely triggered by a request ++from the model in one of its previous responses. ++. ++ ++Message doc ++Message represents the contents of a model message in a conversation. ++. ++ ++Message.role doc ++Role indicates which entity (system, user, model, or tool) generated this message. ++. ++ ++Message.content doc ++Content holds the message parts (text, media, tool calls, etc.). ++. ++ ++Message.metadata doc ++Metadata contains arbitrary key-value data associated with this message. ++. ++ ++# ---------------------------------------------------------------------------- ++# Part Types (Message Content) ++# ---------------------------------------------------------------------------- ++ ++TextPart.text doc ++Text contains the textual content. ++. ++ ++TextPart.metadata doc ++Metadata contains arbitrary key-value data for this part. ++. ++ ++MediaPart.media doc ++Media contains the media content and metadata. ++. ++ ++MediaPart.metadata doc ++Metadata contains arbitrary key-value data for this part. ++. ++ ++ToolRequestPart.toolRequest doc ++ToolRequest is a request for a tool to be executed, usually provided by a model. ++. ++ ++ToolRequestPart.metadata doc ++Metadata contains arbitrary key-value data for this part. ++. ++ ++ToolResponsePart.toolResponse doc ++ToolResponse is a provided response to a tool call. ++. ++ ++ToolResponsePart.metadata doc ++Metadata contains arbitrary key-value data for this part. ++. ++ ++DataPart.data doc ++Data contains arbitrary structured data. ++. ++ ++DataPart.metadata doc ++Metadata contains arbitrary key-value data for this part. ++. ++ ++ReasoningPart.reasoning doc ++Reasoning contains the reasoning text of the message. ++. ++ ++ReasoningPart.metadata doc ++Metadata contains arbitrary key-value data for this part. ++. 
++ ++CustomPart.custom doc ++Custom contains custom key-value data specific to this part. ++. ++ ++CustomPart.data doc ++Data contains additional arbitrary data. ++. ++ ++CustomPart.metadata doc ++Metadata contains arbitrary key-value data for this part. ++. ++ ++ResourcePart.resource doc ++Resource contains a reference to an external resource by URI. ++. ++ ++ResourcePart.metadata doc ++Metadata contains arbitrary key-value data for this part. ++. ++ ++ResourcePartResource.uri doc ++Uri is the URI of the external resource. ++. ++ ++# ---------------------------------------------------------------------------- ++# Media Types ++# ---------------------------------------------------------------------------- ++ ++Media doc ++Media represents media content with a URL and content type. ++. ++ ++Media.contentType doc ++ContentType specifies the MIME type of the media. Inferred from the data URI if not provided. ++. ++ ++Media.url doc ++Url is a "data:" or "https:" URI containing the media content. ++. ++ ++# ---------------------------------------------------------------------------- ++# Tool Types ++# ---------------------------------------------------------------------------- ++ ++ToolRequest doc ++A ToolRequest is a message from the model to the client that it should run a ++specific tool and pass a ToolResponse to the model on the next chat request it makes. ++Any ToolRequest will correspond to some ToolDefinition previously sent by the client. ++. ++ ++ToolRequest.ref doc ++Ref is the call ID or reference for this specific request. ++. ++ ++ToolRequest.name doc ++Name is the name of the tool to call. ++. ++ ++ToolRequest.input doc ++Input is a JSON object containing the input parameters for the tool. ++For example: map[string]any{"country":"USA", "president":3}. ++. ++ ++ToolRequest.partial doc ++Partial indicates whether this is a partial streaming chunk. ++. 
++ ++ToolResponse doc ++A ToolResponse is a message from the client to the model containing ++the results of running a specific tool on the arguments passed to the client ++by the model in a ToolRequest. ++. ++ ++ToolResponse.ref doc ++Ref is the call ID or reference matching the original request. ++. ++ ++ToolResponse.name doc ++Name is the name of the tool that was executed. ++. ++ ++ToolResponse.output doc ++Output is a JSON object describing the results of running the tool. ++For example: map[string]any{"name":"Thomas Jefferson", "born":1743}. ++. ++ ++ToolResponse.content doc ++Content holds additional message parts that provide context or details about the tool response. ++. ++ ++ToolDefinition doc ++A ToolDefinition describes a tool. ++. ++ ++ToolDefinition.name doc ++Name is the unique identifier for this tool. ++. ++ ++ToolDefinition.description doc ++Description explains what the tool does and when to use it. ++. ++ ++ToolDefinition.inputSchema doc ++InputSchema is a valid JSON Schema representing the input parameters of the tool. ++. ++ ++ToolDefinition.outputSchema doc ++OutputSchema is a valid JSON Schema describing the output of the tool. ++. ++ ++ToolDefinition.metadata doc ++Metadata contains additional information about this tool definition. ++. ++ ++# ---------------------------------------------------------------------------- ++# Generation Configuration ++# ---------------------------------------------------------------------------- ++ ++GenerationCommonConfig doc ++GenerationCommonConfig holds configuration parameters for model generation requests. ++. ++ ++GenerationCommonConfig.version doc ++Version specifies a particular version of a model family, ++e.g., "gemini-1.0-pro-001" for the "gemini-1.0-pro" family. ++. ++ ++GenerationCommonConfig.temperature doc ++Temperature controls randomness in generation. Higher values (e.g., 0.9) make output more random, ++while lower values (e.g., 0.1) make it more deterministic. Typical range is 0.0 to 1.0. 
++. ++ ++GenerationCommonConfig.maxOutputTokens doc ++MaxOutputTokens limits the maximum number of tokens generated in the response. ++. ++ ++GenerationCommonConfig.topK doc ++TopK limits sampling to the K most likely tokens at each step. ++. ++ ++GenerationCommonConfig.topP doc ++TopP (nucleus sampling) limits sampling to tokens whose cumulative probability exceeds P. ++. ++ ++GenerationCommonConfig.stopSequences doc ++StopSequences specifies sequences that will cause generation to stop when encountered. ++. ++ ++# ---------------------------------------------------------------------------- ++# Generation Usage and Metrics ++# ---------------------------------------------------------------------------- ++ ++GenerationUsage doc ++GenerationUsage provides information about resource consumption during generation. ++. ++ ++GenerationUsage.inputTokens doc ++InputTokens is the number of tokens in the input prompt. ++. ++ ++GenerationUsage.outputTokens doc ++OutputTokens is the number of tokens generated in the response. ++. ++ ++GenerationUsage.totalTokens doc ++TotalTokens is the sum of input and output tokens. ++. ++ ++GenerationUsage.inputCharacters doc ++InputCharacters is the number of characters in the input. ++. ++ ++GenerationUsage.outputCharacters doc ++OutputCharacters is the number of characters generated in the output. ++. ++ ++GenerationUsage.inputImages doc ++InputImages is the number of images in the input. ++. ++ ++GenerationUsage.outputImages doc ++OutputImages is the number of images generated in the output. ++. ++ ++GenerationUsage.inputVideos doc ++InputVideos is the number of videos in the input. ++. ++ ++GenerationUsage.outputVideos doc ++OutputVideos is the number of videos generated in the output. ++. ++ ++GenerationUsage.inputAudioFiles doc ++InputAudioFiles is the number of audio files in the input. ++. ++ ++GenerationUsage.outputAudioFiles doc ++OutputAudioFiles is the number of audio files generated in the output. ++. 
++ ++GenerationUsage.thoughtsTokens doc ++ThoughtsTokens counts tokens used in reasoning or thinking processes. ++. ++ ++GenerationUsage.cachedContentTokens doc ++CachedContentTokens counts tokens that were served from cache. ++. ++ ++GenerationUsage.custom doc ++Custom contains additional usage metrics specific to the model provider. ++. ++ ++# ---------------------------------------------------------------------------- ++# Model Request and Response ++# ---------------------------------------------------------------------------- ++ ++ModelRequest doc ++A ModelRequest is a request to generate completions from a model. ++. ++ ++ModelRequest.messages doc ++Messages contains the conversation history for the model. ++. ++ ++ModelRequest.config doc ++Config holds model-specific configuration parameters. ++. ++ ++ModelRequest.docs doc ++Docs provides retrieved documents to be used as context for this generation. ++. ++ ++ModelRequest.output doc ++Output describes the desired response format. ++. ++ ++ModelRequest.tools doc ++Tools lists the available tools that the model can ask the client to run. ++. ++ ++ModelRequest.toolChoice doc ++ToolChoice controls how the model uses tools (auto, required, or none). ++. ++ ++ModelResponse doc ++A ModelResponse is a model's response to a ModelRequest. ++. ++ ++ModelResponse.message doc ++Message contains the generated response content. ++. ++ ++ModelResponse.finishReason doc ++FinishReason indicates why generation stopped (e.g., stop, length, blocked). ++. ++ ++ModelResponse.finishMessage doc ++FinishMessage provides additional details about why generation finished. ++. ++ ++ModelResponse.latencyMs doc ++LatencyMs is the time the request took in milliseconds. ++. ++ ++ModelResponse.usage doc ++Usage describes how many resources were used by this generation request. ++. ++ ++ModelResponse.custom doc ++Custom contains model-specific extra information. Deprecated: use Raw instead. ++. 
++ ++ModelResponse.raw doc ++Raw contains the unprocessed model-specific response data. ++. ++ ++ModelResponse.request doc ++Request is the ModelRequest struct used to trigger this response. ++. ++ ++ModelResponse.operation doc ++Operation provides information about a long-running background task if applicable. ++. ++ ++ModelResponseChunk doc ++A ModelResponseChunk is the portion of the ModelResponse ++that is passed to a streaming callback. ++. ++ ++ModelResponseChunk.role doc ++Role indicates the entity that generated this chunk. ++. ++ ++ModelResponseChunk.index doc ++Index of the message this chunk belongs to. ++. ++ ++ModelResponseChunk.content doc ++Content is the chunk of message parts to stream right now. ++. ++ ++ModelResponseChunk.custom doc ++Custom contains model-specific extra information attached to this chunk. ++. ++ ++ModelResponseChunk.aggregated doc ++Aggregated indicates whether the chunk includes all data from previous chunks. ++If false, the chunk is considered incremental. ++. ++ ++# ---------------------------------------------------------------------------- ++# Model Information and Capabilities ++# ---------------------------------------------------------------------------- ++ ++ModelInfo doc ++ModelInfo contains metadata about a model's capabilities and characteristics. ++. ++ ++ModelInfo.versions doc ++Versions lists acceptable names for this model (e.g., different versions). ++. ++ ++ModelInfo.label doc ++Label is a friendly display name for this model (e.g., "Google AI - Gemini Pro"). ++. ++ ++ModelInfo.configSchema doc ++ConfigSchema defines the model-specific configuration schema. ++. ++ ++ModelInfo.supports doc ++Supports describes the capabilities that this model supports. ++. ++ ++ModelInfo.stage doc ++Stage indicates the development stage of this model. 
++Featured models are recommended for general use, stable models are well-tested, ++unstable models are experimental, legacy models are not recommended for new projects, ++and deprecated models may be removed in future versions. ++. ++ ++ModelInfoSupports doc ++ModelSupports describes the capabilities that a model supports. ++. ++ ++ModelInfoSupports.multiturn doc ++Multiturn indicates whether the model can process historical messages passed with a prompt. ++. ++ ++ModelInfoSupports.media doc ++Media indicates whether the model can process media as part of the prompt (multimodal input). ++. ++ ++ModelInfoSupports.tools doc ++Tools indicates whether the model can perform tool calls. ++. ++ ++ModelInfoSupports.systemRole doc ++SystemRole indicates whether the model can accept messages with role "system". ++. ++ ++ModelInfoSupports.output doc ++Output lists the types of data the model can generate. ++. ++ ++ModelInfoSupports.contentType doc ++ContentType lists the content types the model supports for output. ++. ++ ++ModelInfoSupports.context doc ++Context indicates whether the model can natively support document-based context grounding. ++. ++ ++ModelInfoSupports.constrained doc ++Constrained indicates the level of constrained generation support (none, all, or no-tools). ++. ++ ++ModelInfoSupports.toolChoice doc ++ToolChoice indicates whether the model supports controlling tool choice (e.g., forced tool calling). ++. ++ ++ModelInfoSupports.longRunning doc ++LongRunning indicates whether the model supports long-running operations. ++. ++ ++# ---------------------------------------------------------------------------- ++# Output Configuration ++# ---------------------------------------------------------------------------- ++ ++OutputConfig doc ++OutputConfig describes the structure that the model's output ++should conform to. If Format is OutputFormatJSON, then Schema ++can describe the desired form of the generated JSON. ++. 
++ ++OutputConfig.format doc ++Format specifies the desired output format (e.g., "json", "text"). ++. ++ ++OutputConfig.schema doc ++Schema is a JSON Schema describing the desired structure of the output. ++. ++ ++OutputConfig.constrained doc ++Constrained indicates whether to enforce strict adherence to the schema. ++. ++ ++OutputConfig.contentType doc ++ContentType specifies the MIME type of the output content. ++. ++ ++# ---------------------------------------------------------------------------- ++# Operation Types ++# ---------------------------------------------------------------------------- ++ ++Operation doc ++Operation represents a long-running background task. ++. ++ ++Operation.action doc ++Action is the name of the action being performed by this operation. ++. ++ ++Operation.id doc ++Id is the unique identifier for this operation. ++. ++ ++Operation.done doc ++Done indicates whether the operation has completed. ++. ++ ++Operation.output doc ++Output contains the result of the operation if it has completed successfully. ++. ++ ++Operation.error doc ++Error contains error information if the operation failed. ++. ++ ++Operation.metadata doc ++Metadata contains additional information about the operation. ++. ++ ++OperationError doc ++OperationError contains error information for a failed operation. ++. ++ ++OperationError.message doc ++Message describes the error that occurred. ++. ++ ++# ---------------------------------------------------------------------------- ++# Document Types ++# ---------------------------------------------------------------------------- ++ ++# Note: Document type is hand-written in ai/document.go, not generated ++ ++# ---------------------------------------------------------------------------- ++# Embedding Types ++# ---------------------------------------------------------------------------- ++ ++Embedding doc ++Embedding represents a vector embedding with associated metadata. ++. 
++ ++Embedding.embedding doc ++Embedding is the vector representation of the input. ++. ++ ++Embedding.metadata doc ++Metadata identifies which part of a document this embedding corresponds to. ++. ++ ++EmbedRequest doc ++EmbedRequest represents a request to generate embeddings for documents. ++. ++ ++EmbedRequest.input doc ++Input is the array of documents to generate embeddings for. ++. ++ ++EmbedRequest.options doc ++Options contains embedder-specific configuration parameters. ++. ++ ++EmbedResponse doc ++EmbedResponse contains the generated embeddings from an embed request. ++. ++ ++EmbedResponse.embeddings doc ++Embeddings is the array of generated embedding vectors with metadata. ++. ++ ++# ---------------------------------------------------------------------------- ++# Evaluator Types (ScoreDetails only - other eval types are omitted) ++# ---------------------------------------------------------------------------- ++ ++ScoreDetails doc ++ScoreDetails provides additional context and explanation for an evaluation score. ++. ++ ++ScoreDetails.reasoning doc ++Reasoning explains the rationale behind the score. ++. ++ ++# ---------------------------------------------------------------------------- ++# Retriever Types ++# ---------------------------------------------------------------------------- ++ ++RetrieverRequest doc ++RetrieverRequest represents a request to retrieve relevant documents. ++. ++ ++RetrieverRequest.query doc ++Query is the document to use for retrieval. ++. ++ ++RetrieverRequest.options doc ++Options contains retriever-specific configuration parameters. ++. ++ ++RetrieverResponse doc ++RetrieverResponse contains the retrieved documents from a retriever request. ++. ++ ++RetrieverResponse.documents doc ++Documents is the array of retrieved documents. ++. 
++ ++# ---------------------------------------------------------------------------- ++# Reranker Types ++# ---------------------------------------------------------------------------- ++ ++RerankerRequest doc ++RerankerRequest represents a request to rerank documents based on relevance. ++. ++ ++RerankerRequest.query doc ++Query is the document to use for reranking. ++. ++ ++RerankerRequest.documents doc ++Documents is the array of documents to rerank. ++. ++ ++RerankerRequest.options doc ++Options contains reranker-specific configuration parameters. ++. ++ ++RerankerResponse doc ++RerankerResponse contains the reranked documents with relevance scores. ++. ++ ++RerankerResponse.documents doc ++Documents is the array of reranked documents with scores. ++. ++ ++RankedDocumentData doc ++RankedDocumentData represents a document with a relevance score from reranking. ++. ++ ++RankedDocumentData.content doc ++Content holds the document's parts (text and media). ++. ++ ++RankedDocumentData.metadata doc ++Metadata contains the reranking score and other arbitrary key-value data. ++. ++ ++RankedDocumentMetadata doc ++RankedDocumentMetadata contains the relevance score and other metadata for a reranked document. ++. ++ ++RankedDocumentMetadata.score doc ++Score is the relevance score assigned by the reranker. ++. ++ ++# ---------------------------------------------------------------------------- ++# GenerateAction Types ++# ---------------------------------------------------------------------------- ++ ++GenerateActionOptions doc ++GenerateActionOptions holds configuration for a generate action request. ++. ++ ++GenerateActionOptions.model doc ++Model is a model name (e.g., "vertexai/gemini-1.0-pro"). ++. ++ ++GenerateActionOptions.docs doc ++Docs provides retrieved documents to be used as context for this generation. ++. ++ ++GenerateActionOptions.messages doc ++Messages contains the conversation history for multi-turn prompting when supported. ++. 
++ ++GenerateActionOptions.tools doc ++Tools is a list of registered tool names for this generation if supported. ++. ++ ++GenerateActionOptions.toolChoice doc ++ToolChoice controls tool calling mode. Auto lets the model decide, required forces ++the model to choose a tool, and none forces the model not to use any tools. Defaults to auto. ++. ++ ++GenerateActionOptions.config doc ++Config contains configuration parameters for the generation request. ++. ++ ++GenerateActionOptions.output doc ++Output specifies the desired output format. Defaults to the model's default if unspecified. ++. ++ ++GenerateActionOptions.resume doc ++Resume provides options for resuming an interrupted generation. ++. ++ ++GenerateActionOptions.returnToolRequests doc ++ReturnToolRequests, when true, returns tool calls for manual processing instead of ++automatically resolving them. ++. ++ ++GenerateActionOptions.maxTurns doc ++MaxTurns is the maximum number of tool call iterations that can be performed ++in a single generate call. Defaults to 5. ++. ++ ++GenerateActionOptions.stepName doc ++StepName is a custom step name for this generate call to display in trace views. ++Defaults to "generate". ++. ++ ++GenerateActionOptionsResume doc ++GenerateActionResume holds options for resuming an interrupted generation. ++. ++ ++GenerateActionOptionsResume.respond doc ++Respond contains tool response parts to send to the model when resuming. ++. ++ ++GenerateActionOptionsResume.restart doc ++Restart contains tool request parts to restart when resuming. ++. ++ ++GenerateActionOptionsResume.metadata doc ++Metadata contains additional context for resuming the generation. ++. ++ ++GenerateActionOutputConfig doc ++GenerateActionOutputConfig specifies the desired output format for a generate action. ++. ++ ++GenerateActionOutputConfig.format doc ++Format specifies the desired output format (e.g., "json", "text"). ++. 
++ ++GenerateActionOutputConfig.contentType doc ++ContentType specifies the MIME type of the output content. ++. ++ ++GenerateActionOutputConfig.instructions doc ++Instructions provides additional guidance for the output format. ++. ++ ++GenerateActionOutputConfig.jsonSchema doc ++JsonSchema is a JSON Schema describing the desired structure of JSON output. ++. ++ ++GenerateActionOutputConfig.constrained doc ++Constrained indicates whether to enforce strict adherence to the schema. ++. ++ ++GenerateActionOptionsToolChoice doc ++ToolChoice controls how the model uses tools. ++. ++ ++# ---------------------------------------------------------------------------- ++# Finish Reason Enum ++# ---------------------------------------------------------------------------- ++ ++FinishReason doc ++FinishReason indicates why generation stopped. ++. ++ ++# ---------------------------------------------------------------------------- ++# Model Stage Enum ++# ---------------------------------------------------------------------------- ++ ++ModelInfoStage doc ++ModelStage indicates the development stage of a model. ++. ++ ++# ---------------------------------------------------------------------------- ++# Constrained Support Enum ++# ---------------------------------------------------------------------------- ++ ++ModelInfoSupportsConstrained doc ++ConstrainedSupport indicates the level of constrained generation support. ++. ++ ++# ---------------------------------------------------------------------------- ++# Trace Metadata Types ++# ---------------------------------------------------------------------------- ++ ++TraceMetadata doc ++TraceMetadata contains metadata about a trace execution. ++. ++ ++TraceMetadata.featureName doc ++FeatureName identifies the feature being traced. ++. ++ ++TraceMetadata.paths doc ++Paths contains metadata for each path executed during the trace. ++. ++ ++TraceMetadata.timestamp doc ++Timestamp is when the trace was created. ++. 
++ ++PathMetadata doc ++PathMetadata contains metadata about a single execution path in a trace. ++. ++ ++PathMetadata.path doc ++Path is the identifier for this execution path. ++. ++ ++PathMetadata.status doc ++Status indicates the outcome of this path. ++. ++ ++PathMetadata.latency doc ++Latency is the execution time for this path in milliseconds. ++. ++ ++PathMetadata.error doc ++Error contains error information if the path failed. ++. ++ ++# ---------------------------------------------------------------------------- ++# Multipart Tool Response ++# ---------------------------------------------------------------------------- ++ ++MultipartToolResponse doc ++MultipartToolResponse represents a tool response with both structured output and content parts. ++. ++ ++MultipartToolResponse.output doc ++Output contains the structured output data from the tool. ++. ++ ++MultipartToolResponse.content doc ++Content holds additional message parts providing context or details. ++. ++ ++# ============================================================================ ++# CONFIGURATION SECTION ++# Type mappings, omissions, and other non-documentation directives ++# ============================================================================ ++ + # DocumentData type was hand-written. + DocumentData omit + +@@ -28,52 +888,30 @@ TimeEventAnnotation omit + TraceData omit + SpanStartEvent omit + SpanEndEvent omit +-SpanEventBase omit ++# Typo in schema definition... ++SpantEventBase omit + TraceEvent omit + + GenerationCommonConfig.maxOutputTokens type int + GenerationCommonConfig.topK type int + +-Role doc +-Role indicates which entity is responsible for the content of a message. +-. +-RoleSystem doc +-RoleSystem indicates this message is user-independent context. +-. +-RoleUser doc +-RoleUser indicates this message was generated by the client. +-. +-RoleModel doc +-RoleModel indicates this message was generated by the model during a previous interaction. +-. 
+-RoleTool doc +-RoleTool indicates this message was generated by a local tool, likely triggered by a request +-from the model in one of its previous responses. +-. +- +-ToolRequest.input doc +-Input is a JSON object describing the input values to the tool. +-An example might be map[string]any{"country":"USA", "president":3}. +-. +-ToolResponse.output doc +-Output is a JSON object describing the results of running the tool. +-An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}. +-. +- +-ToolRequest doc +-A ToolRequest is a message from the model to the client that it should run a +-specific tool and pass a [ToolResponse] to the model on the next chat request it makes. +-Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client. +-. +-ToolResponse doc +-A ToolResponse is a message from the client to the model containing +-the results of running a specific tool on the arguments passed to the client +-by the model in a [ToolRequest]. +-. +- ++# Unused evaluation types ++BaseDataPoint omit ++BaseEvalDataPoint omit ++EvalFnResponse omit ++EvalRequest omit ++EvalResponse omit ++EvalStatusEnum omit + ++# Unused error types ++CandidateError omit ++CandidateErrorCode omit + Candidate omit + ++# Unused retriever/reranker option types ++CommonRerankerOptions omit ++CommonRetrieverOptions omit ++ + DocumentData pkg ai + + GenerateResponse omit +@@ -96,9 +934,7 @@ GenerationUsage.outputTokens type int + GenerationUsage.totalTokens type int + GenerationUsage.thoughtsTokens type int + GenerationUsage.cachedContentTokens type int +-GenerationUsage doc +-GenerationUsage provides information about the generation process. +-. 
++ + GenerationCommonConfig pkg ai + + Message pkg ai +@@ -213,8 +1049,6 @@ RoleUser pkg ai + RoleModel pkg ai + RoleTool pkg ai + +-EvalResponse type []any +- + # GenerateActionOptions + GenerateActionOptions pkg ai + GenerateActionOptions.model type string +@@ -235,18 +1069,6 @@ GenerateActionOutputConfig.jsonSchema name Schema + GenerateActionOutputConfig.jsonSchema type map[string]any + GenerateActionOutputConfig.constrained type bool + +-BaseDataPoint.context type map[string]any +-BaseDataPoint.input type map[string]any +-BaseDataPoint.output type map[string]any +-BaseDataPoint.reference type map[string]any +-BaseDataPoint.traceIds type []string +- +-BaseEvalDataPoint.context type map[string]any +-BaseEvalDataPoint.input type map[string]any +-BaseEvalDataPoint.output type map[string]any +-BaseEvalDataPoint.reference type map[string]any +-BaseEvalDataPoint.traceIds type []string +- + # ModelRequest + ModelRequest pkg ai + ModelRequest.config type any +@@ -279,52 +1101,6 @@ ModelResponseChunk.index type int + ModelResponseChunk.role type Role + ModelResponseChunk field formatHandler StreamingFormatHandler + +-GenerationCommonConfig doc +-GenerationCommonConfig holds configuration for generation. +-. +- +-Message doc +-Message is the contents of a model response. +-. +- +-ToolDefinition doc +-A ToolDefinition describes a tool. +-. +- +-ModelRequest doc +-A ModelRequest is a request to generate completions from a model. +-. +-ModelRequest.output doc +-Output describes the desired response format. +-. +-ModelRequest.tools doc +-Tools lists the available tools that the model can ask the client to run. +-. +- +-OutputConfig doc +-OutputConfig describes the structure that the model's output +-should conform to. If Format is [OutputFormatJSON], then Schema +-can describe the desired form of the generated JSON. +-. +- +-ModelResponse doc +-A ModelResponse is a model's response to a [ModelRequest]. +-. 
+-ModelResponse.latencyMs doc +-LatencyMs is the time the request took in milliseconds. +-. +-ModelResponse.request doc +-Request is the [ModelRequest] struct used to trigger this response. +-. +-ModelResponse.usage doc +-Usage describes how many resources were used by this generation request. +-. +- +-ModelResponseChunk doc +-A ModelResponseChunk is the portion of the [ModelResponse] +-that is passed to a streaming callback. +-. +- + Score omit + + Embedding.embedding type []float32 +diff --git a/go/core/status_types_test.go b/go/core/status_types_test.go +new file mode 100644 +index 000000000..eec8d7c12 +--- /dev/null ++++ b/go/core/status_types_test.go +@@ -0,0 +1,123 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package core ++ ++import ( ++ "net/http" ++ "testing" ++) ++ ++func TestHTTPStatusCode(t *testing.T) { ++ tests := []struct { ++ name string ++ status StatusName ++ wantCode int ++ }{ ++ {"OK", OK, http.StatusOK}, ++ {"CANCELLED", CANCELLED, 499}, ++ {"UNKNOWN", UNKNOWN, http.StatusInternalServerError}, ++ {"INVALID_ARGUMENT", INVALID_ARGUMENT, http.StatusBadRequest}, ++ {"DEADLINE_EXCEEDED", DEADLINE_EXCEEDED, http.StatusGatewayTimeout}, ++ {"NOT_FOUND", NOT_FOUND, http.StatusNotFound}, ++ {"ALREADY_EXISTS", ALREADY_EXISTS, http.StatusConflict}, ++ {"PERMISSION_DENIED", PERMISSION_DENIED, http.StatusForbidden}, ++ {"UNAUTHENTICATED", UNAUTHENTICATED, http.StatusUnauthorized}, ++ {"RESOURCE_EXHAUSTED", RESOURCE_EXHAUSTED, http.StatusTooManyRequests}, ++ {"FAILED_PRECONDITION", FAILED_PRECONDITION, http.StatusBadRequest}, ++ {"ABORTED", ABORTED, http.StatusConflict}, ++ {"OUT_OF_RANGE", OUT_OF_RANGE, http.StatusBadRequest}, ++ {"UNIMPLEMENTED", UNIMPLEMENTED, http.StatusNotImplemented}, ++ {"INTERNAL", INTERNAL, http.StatusInternalServerError}, ++ {"UNAVAILABLE", UNAVAILABLE, http.StatusServiceUnavailable}, ++ {"DATA_LOSS", DATA_LOSS, http.StatusInternalServerError}, ++ } ++ ++ for _, tt := range tests { ++ t.Run(tt.name, func(t *testing.T) { ++ got := HTTPStatusCode(tt.status) ++ if got != tt.wantCode { ++ t.Errorf("HTTPStatusCode(%q) = %d, want %d", tt.status, got, tt.wantCode) ++ } ++ }) ++ } ++ ++ t.Run("unknown status returns 500", func(t *testing.T) { ++ got := HTTPStatusCode(StatusName("UNKNOWN_STATUS")) ++ if got != http.StatusInternalServerError { ++ t.Errorf("HTTPStatusCode(unknown) = %d, want %d", got, http.StatusInternalServerError) ++ } ++ }) ++} ++ ++func TestNewStatus(t *testing.T) { ++ t.Run("creates status with name and message", func(t *testing.T) { ++ s := NewStatus(NOT_FOUND, "resource not found") ++ ++ if s.Name != NOT_FOUND { ++ t.Errorf("Name = %q, want %q", s.Name, NOT_FOUND) ++ } ++ if 
s.Message != "resource not found" { ++ t.Errorf("Message = %q, want %q", s.Message, "resource not found") ++ } ++ }) ++ ++ t.Run("creates status with empty message", func(t *testing.T) { ++ s := NewStatus(OK, "") ++ ++ if s.Name != OK { ++ t.Errorf("Name = %q, want %q", s.Name, OK) ++ } ++ if s.Message != "" { ++ t.Errorf("Message = %q, want empty string", s.Message) ++ } ++ }) ++} ++ ++func TestStatusNameToCode(t *testing.T) { ++ t.Run("maps all status names to codes", func(t *testing.T) { ++ expectedMappings := map[StatusName]int{ ++ OK: CodeOK, ++ CANCELLED: CodeCancelled, ++ UNKNOWN: CodeUnknown, ++ INVALID_ARGUMENT: CodeInvalidArgument, ++ DEADLINE_EXCEEDED: CodeDeadlineExceeded, ++ NOT_FOUND: CodeNotFound, ++ ALREADY_EXISTS: CodeAlreadyExists, ++ PERMISSION_DENIED: CodePermissionDenied, ++ UNAUTHENTICATED: CodeUnauthenticated, ++ RESOURCE_EXHAUSTED: CodeResourceExhausted, ++ FAILED_PRECONDITION: CodeFailedPrecondition, ++ ABORTED: CodeAborted, ++ OUT_OF_RANGE: CodeOutOfRange, ++ UNIMPLEMENTED: CodeUnimplemented, ++ INTERNAL: CodeInternal, ++ UNAVAILABLE: CodeUnavailable, ++ DATA_LOSS: CodeDataLoss, ++ } ++ ++ for name, wantCode := range expectedMappings { ++ got, ok := StatusNameToCode[name] ++ if !ok { ++ t.Errorf("StatusNameToCode missing mapping for %q", name) ++ continue ++ } ++ if got != wantCode { ++ t.Errorf("StatusNameToCode[%q] = %d, want %d", name, got, wantCode) ++ } ++ } ++ }) ++} +diff --git a/go/core/tracing/doc.go b/go/core/tracing/doc.go +new file mode 100644 +index 000000000..aae609c51 +--- /dev/null ++++ b/go/core/tracing/doc.go +@@ -0,0 +1,109 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. 
++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++/* ++Package tracing provides execution trace support for Genkit operations. ++ ++This package implements OpenTelemetry-based tracing for Genkit actions and flows. ++Traces capture the execution path, inputs, outputs, and timing of operations, ++enabling observability and debugging through the Genkit Developer UI and ++external telemetry systems. ++ ++# Automatic Tracing ++ ++Actions and flows defined with Genkit are automatically traced. Each action ++execution creates a span with input/output data, timing, and any errors. ++Use [core.Run] within flows to create traced sub-steps: ++ ++ // In a real scenario, 'r' would be the registry from your Genkit instance. ++ var r api.Registry ++ flow := core.DefineFlow(r, "myFlow", ++ func(ctx context.Context, input string) (string, error) { ++ // This creates a traced step named "processData" ++ result, err := core.Run(ctx, "processData", func() (string, error) { ++ return process(input), nil ++ }) ++ return result, err ++ }, ++ ) ++ ++# Tracer Access ++ ++Access the OpenTelemetry tracer provider for custom instrumentation: ++ ++ provider := tracing.TracerProvider() ++ ++ // Get a tracer for custom spans ++ tracer := tracing.Tracer() ++ ++# Telemetry Export ++ ++Configure trace export to send telemetry to external systems. 
For immediate ++export (suitable for local storage): ++ ++ tracing.WriteTelemetryImmediate(client) ++ ++For batched export (more efficient for network calls): ++ ++ shutdown := tracing.WriteTelemetryBatch(client) ++ defer shutdown(ctx) ++ ++# Dev UI Integration ++ ++When the GENKIT_ENV environment variable is set to "dev", traces are ++automatically sent to the Genkit Developer UI's telemetry server. The Dev UI ++provides: ++ ++ - Visual trace exploration with timing breakdown ++ - Input/output inspection for each action ++ - Error highlighting and stack traces ++ - Performance analysis across flow executions ++ ++Set GENKIT_TELEMETRY_SERVER to configure a custom telemetry endpoint. ++ ++# Span Metadata ++ ++Create spans with rich metadata for better observability: ++ ++ metadata := &tracing.SpanMetadata{ ++ Name: "processDocument", ++ Type: "action", ++ Subtype: "retriever", ++ } ++ ++ output, err := tracing.RunInNewSpan(ctx, metadata, input, ++ func(ctx context.Context, in Input) (Output, error) { ++ // Operation runs within the traced span ++ return process(in), nil ++ }, ++ ) ++ ++# Trace Information ++ ++Extract trace context for correlation with external systems: ++ ++ info := tracing.GetTraceInfo(ctx) ++ if info != nil { ++ log.Printf("TraceID: %s, SpanID: %s", info.TraceID, info.SpanID) ++ } ++ ++This package is primarily intended for Genkit internals and advanced plugin ++development. Most application developers will interact with tracing through ++the automatic instrumentation provided by the genkit package. 
++ ++For more information on observability, see https://genkit.dev/docs/observability ++*/ ++package tracing +diff --git a/go/core/x/streaming/streaming.go b/go/core/x/streaming/streaming.go +new file mode 100644 +index 000000000..3fb51a334 +--- /dev/null ++++ b/go/core/x/streaming/streaming.go +@@ -0,0 +1,382 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++// Package streaming provides experimental durable streaming APIs for Genkit. ++// ++// APIs in this package are under active development and may change in any ++// minor version release. Use with caution in production environments. ++// ++// When these APIs stabilize, they will be moved to their parent packages ++// (e.g., core and genkit) and these exports will be deprecated. ++package streaming ++ ++import ( ++ "context" ++ "encoding/json" ++ "sync" ++ "time" ++ ++ "github.com/firebase/genkit/go/core" ++) ++ ++// StreamEventType indicates the type of stream event. ++type StreamEventType int ++ ++const ( ++ StreamEventChunk StreamEventType = iota ++ StreamEventDone ++ StreamEventError ++) ++ ++// StreamEvent represents an event in a durable stream. 
++type StreamEvent struct { ++ Type StreamEventType ++ Chunk json.RawMessage // set when Type == StreamEventChunk ++ Output json.RawMessage // set when Type == StreamEventDone ++ Err error // set when Type == StreamEventError ++} ++ ++// StreamInput provides methods for writing to a durable stream. ++type StreamInput interface { ++ // Write sends a chunk to the stream and notifies all subscribers. ++ Write(ctx context.Context, chunk json.RawMessage) error ++ // Done marks the stream as successfully completed with the given output. ++ Done(ctx context.Context, output json.RawMessage) error ++ // Error marks the stream as failed with the given error. ++ Error(ctx context.Context, err error) error ++ // Close releases resources without marking the stream as done or errored. ++ Close() error ++} ++ ++// StreamManager manages durable streams, allowing creation and subscription. ++// Implementations can provide different storage backends (e.g., in-memory, database, cache). ++type StreamManager interface { ++ // Open creates a new stream for writing. ++ // Returns an error if a stream with the given ID already exists. ++ Open(ctx context.Context, streamID string) (StreamInput, error) ++ // Subscribe subscribes to an existing stream. ++ // Returns a channel that receives stream events, an unsubscribe function, and an error. ++ // If the stream has already completed, all buffered events are sent before the done/error event. ++ // Returns NOT_FOUND error if the stream doesn't exist. ++ Subscribe(ctx context.Context, streamID string) (<-chan StreamEvent, func(), error) ++} ++ ++// inMemoryStreamBufferSize is the buffer size for subscriber event channels. ++const inMemoryStreamBufferSize = 100 ++ ++// streamStatus represents the current state of a stream. ++type streamStatus int ++ ++const ( ++ streamStatusOpen streamStatus = iota ++ streamStatusDone ++ streamStatusError ++) ++ ++// streamState holds the internal state of a single stream. 
++type streamState struct { ++ status streamStatus ++ chunks []json.RawMessage ++ output json.RawMessage ++ err error ++ subscribers []chan StreamEvent ++ lastTouched time.Time ++ mu sync.RWMutex ++} ++ ++// InMemoryStreamManager is an in-memory implementation of StreamManager. ++// Useful for testing or single-instance deployments where persistence is not required. ++// Call Close to stop the background cleanup goroutine when the manager is no longer needed. ++type InMemoryStreamManager struct { ++ streams map[string]*streamState ++ mu sync.RWMutex ++ ttl time.Duration ++ stopCh chan struct{} ++ doneCh chan struct{} ++} ++ ++// StreamManagerOption configures an InMemoryStreamManager. ++type StreamManagerOption interface { ++ applyInMemoryStreamManager(*streamManagerOptions) ++} ++ ++// streamManagerOptions holds configuration for InMemoryStreamManager. ++type streamManagerOptions struct { ++ TTL time.Duration // Time-to-live for completed streams. ++} ++ ++func (o *streamManagerOptions) applyInMemoryStreamManager(opts *streamManagerOptions) { ++ if o.TTL > 0 { ++ opts.TTL = o.TTL ++ } ++} ++ ++// WithTTL sets the time-to-live for completed streams. ++// Streams that have completed (done or error) will be cleaned up after this duration. ++// Default is 5 minutes. ++func WithTTL(ttl time.Duration) StreamManagerOption { ++ return &streamManagerOptions{TTL: ttl} ++} ++ ++// NewInMemoryStreamManager creates a new InMemoryStreamManager. ++// A background goroutine is started to periodically clean up expired streams. ++// Call Close to stop the goroutine when the manager is no longer needed. 
++func NewInMemoryStreamManager(opts ...StreamManagerOption) *InMemoryStreamManager { ++ options := &streamManagerOptions{ ++ TTL: 5 * time.Minute, ++ } ++ for _, opt := range opts { ++ opt.applyInMemoryStreamManager(options) ++ } ++ m := &InMemoryStreamManager{ ++ streams: make(map[string]*streamState), ++ ttl: options.TTL, ++ stopCh: make(chan struct{}), ++ doneCh: make(chan struct{}), ++ } ++ go m.cleanupLoop() ++ return m ++} ++ ++// cleanupLoop runs periodically to remove expired streams. ++func (m *InMemoryStreamManager) cleanupLoop() { ++ ticker := time.NewTicker(time.Minute) ++ defer ticker.Stop() ++ defer close(m.doneCh) ++ ++ for { ++ select { ++ case <-m.stopCh: ++ return ++ case <-ticker.C: ++ m.cleanupExpiredStreams() ++ } ++ } ++} ++ ++// cleanupExpiredStreams removes streams that have completed and exceeded the TTL. ++func (m *InMemoryStreamManager) cleanupExpiredStreams() { ++ now := time.Now() ++ m.mu.Lock() ++ defer m.mu.Unlock() ++ ++ for id, state := range m.streams { ++ state.mu.RLock() ++ shouldDelete := state.status != streamStatusOpen && now.Sub(state.lastTouched) > m.ttl ++ state.mu.RUnlock() ++ if shouldDelete { ++ delete(m.streams, id) ++ } ++ } ++} ++ ++// Close stops the background cleanup goroutine and releases resources. ++// This method blocks until the cleanup goroutine has stopped. ++func (m *InMemoryStreamManager) Close() { ++ close(m.stopCh) ++ <-m.doneCh ++} ++ ++// Open creates a new stream for writing. 
++func (m *InMemoryStreamManager) Open(ctx context.Context, streamID string) (StreamInput, error) { ++ m.mu.Lock() ++ defer m.mu.Unlock() ++ ++ if _, exists := m.streams[streamID]; exists { ++ return nil, core.NewPublicError(core.ALREADY_EXISTS, "stream already exists", nil) ++ } ++ ++ state := &streamState{ ++ status: streamStatusOpen, ++ chunks: make([]json.RawMessage, 0), ++ subscribers: make([]chan StreamEvent, 0), ++ lastTouched: time.Now(), ++ } ++ m.streams[streamID] = state ++ ++ return &inMemoryStreamInput{ ++ manager: m, ++ streamID: streamID, ++ state: state, ++ }, nil ++} ++ ++// Subscribe subscribes to an existing stream. ++func (m *InMemoryStreamManager) Subscribe(ctx context.Context, streamID string) (<-chan StreamEvent, func(), error) { ++ m.mu.RLock() ++ state, exists := m.streams[streamID] ++ m.mu.RUnlock() ++ ++ if !exists { ++ return nil, nil, core.NewPublicError(core.NOT_FOUND, "stream not found", nil) ++ } ++ ++ ch := make(chan StreamEvent, inMemoryStreamBufferSize) ++ ++ state.mu.Lock() ++ defer state.mu.Unlock() ++ ++ // Send all buffered chunks ++ for _, chunk := range state.chunks { ++ select { ++ case ch <- StreamEvent{Type: StreamEventChunk, Chunk: chunk}: ++ case <-ctx.Done(): ++ close(ch) ++ return nil, nil, ctx.Err() ++ } ++ } ++ ++ // Handle completed streams ++ switch state.status { ++ case streamStatusDone: ++ ch <- StreamEvent{Type: StreamEventDone, Output: state.output} ++ close(ch) ++ return ch, func() {}, nil ++ case streamStatusError: ++ ch <- StreamEvent{Type: StreamEventError, Err: state.err} ++ close(ch) ++ return ch, func() {}, nil ++ } ++ ++ // Stream is still open, add subscriber ++ state.subscribers = append(state.subscribers, ch) ++ ++ unsubscribe := func() { ++ state.mu.Lock() ++ defer state.mu.Unlock() ++ for i, sub := range state.subscribers { ++ if sub == ch { ++ state.subscribers = append(state.subscribers[:i], state.subscribers[i+1:]...) 
++ close(ch) ++ break ++ } ++ } ++ } ++ ++ return ch, unsubscribe, nil ++} ++ ++// inMemoryStreamInput implements ActionStreamInput for the in-memory manager. ++type inMemoryStreamInput struct { ++ manager *InMemoryStreamManager ++ streamID string ++ state *streamState ++ closed bool ++ mu sync.Mutex ++} ++ ++func (s *inMemoryStreamInput) Write(_ context.Context, chunk json.RawMessage) error { ++ s.mu.Lock() ++ defer s.mu.Unlock() ++ ++ if s.closed { ++ return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) ++ } ++ ++ s.state.mu.Lock() ++ defer s.state.mu.Unlock() ++ ++ if s.state.status != streamStatusOpen { ++ return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) ++ } ++ ++ s.state.chunks = append(s.state.chunks, chunk) ++ s.state.lastTouched = time.Now() ++ ++ event := StreamEvent{Type: StreamEventChunk, Chunk: chunk} ++ for _, ch := range s.state.subscribers { ++ select { ++ case ch <- event: ++ default: ++ // Channel full, skip (subscriber is slow) ++ } ++ } ++ ++ return nil ++} ++ ++func (s *inMemoryStreamInput) Done(_ context.Context, output json.RawMessage) error { ++ s.mu.Lock() ++ defer s.mu.Unlock() ++ ++ if s.closed { ++ return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) ++ } ++ s.closed = true ++ ++ s.state.mu.Lock() ++ defer s.state.mu.Unlock() ++ ++ if s.state.status != streamStatusOpen { ++ return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) ++ } ++ ++ s.state.status = streamStatusDone ++ s.state.output = output ++ s.state.lastTouched = time.Now() ++ ++ event := StreamEvent{Type: StreamEventDone, Output: output} ++ for _, ch := range s.state.subscribers { ++ select { ++ case ch <- event: ++ default: ++ } ++ close(ch) ++ } ++ s.state.subscribers = nil ++ ++ return nil ++} ++ ++func (s *inMemoryStreamInput) Error(_ context.Context, err error) error { ++ s.mu.Lock() ++ defer s.mu.Unlock() ++ ++ if s.closed { ++ 
return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) ++ } ++ s.closed = true ++ ++ s.state.mu.Lock() ++ defer s.state.mu.Unlock() ++ ++ if s.state.status != streamStatusOpen { ++ return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) ++ } ++ ++ s.state.status = streamStatusError ++ s.state.err = err ++ s.state.lastTouched = time.Now() ++ ++ event := StreamEvent{Type: StreamEventError, Err: err} ++ for _, ch := range s.state.subscribers { ++ select { ++ case ch <- event: ++ default: ++ } ++ close(ch) ++ } ++ s.state.subscribers = nil ++ ++ return nil ++} ++ ++func (s *inMemoryStreamInput) Close() error { ++ s.mu.Lock() ++ defer s.mu.Unlock() ++ s.closed = true ++ return nil ++} +diff --git a/go/core/x/streaming/streaming_test.go b/go/core/x/streaming/streaming_test.go +new file mode 100644 +index 000000000..e86ce6f6e +--- /dev/null ++++ b/go/core/x/streaming/streaming_test.go +@@ -0,0 +1,789 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package streaming ++ ++import ( ++ "context" ++ "encoding/json" ++ "errors" ++ "sync" ++ "testing" ++ "time" ++ ++ "github.com/firebase/genkit/go/core" ++) ++ ++func TestInMemoryStreamManager_OpenAndSubscribe(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-1" ++ ++ // Open a new stream ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ if writer == nil { ++ t.Fatal("Open returned nil writer") ++ } ++ ++ // Subscribe to the stream ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++ defer unsubscribe() ++ ++ if events == nil { ++ t.Fatal("Subscribe returned nil channel") ++ } ++} ++ ++func TestInMemoryStreamManager_OpenDuplicateFails(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-dup" ++ ++ // Open first stream ++ _, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("First Open failed: %v", err) ++ } ++ ++ // Try to open duplicate ++ _, err = m.Open(ctx, streamID) ++ if err == nil { ++ t.Fatal("Expected error when opening duplicate stream") ++ } ++ ++ var ufErr *core.UserFacingError ++ if !errors.As(err, &ufErr) { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if ufErr.Status != core.ALREADY_EXISTS { ++ t.Errorf("Expected ALREADY_EXISTS status, got %v", ufErr.Status) ++ } ++} ++ ++func TestInMemoryStreamManager_SubscribeNonExistent(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ ++ _, _, err := m.Subscribe(ctx, "non-existent") ++ if err == nil { ++ t.Fatal("Expected error when subscribing to non-existent stream") ++ } ++ ++ var ufErr *core.UserFacingError ++ if !errors.As(err, &ufErr) { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ 
if ufErr.Status != core.NOT_FOUND { ++ t.Errorf("Expected NOT_FOUND status, got %v", ufErr.Status) ++ } ++} ++ ++func TestInMemoryStreamManager_WriteAndReceiveChunks(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-chunks" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++ defer unsubscribe() ++ ++ // Write chunks ++ chunks := []string{"chunk1", "chunk2", "chunk3"} ++ for _, chunk := range chunks { ++ if err := writer.Write(ctx, json.RawMessage(`"`+chunk+`"`)); err != nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ } ++ ++ // Read chunks ++ for i, expected := range chunks { ++ select { ++ case event := <-events: ++ if event.Type != StreamEventChunk { ++ t.Errorf("Expected chunk event, got %v", event.Type) ++ } ++ var got string ++ if err := json.Unmarshal(event.Chunk, &got); err != nil { ++ t.Fatalf("Failed to unmarshal chunk: %v", err) ++ } ++ if got != expected { ++ t.Errorf("Chunk %d: expected %q, got %q", i, expected, got) ++ } ++ case <-time.After(time.Second): ++ t.Fatalf("Timeout waiting for chunk %d", i) ++ } ++ } ++} ++ ++func TestInMemoryStreamManager_Done(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-done" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++ defer unsubscribe() ++ ++ // Write a chunk ++ if err := writer.Write(ctx, json.RawMessage(`"test-chunk"`)); err != nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ ++ // Mark as done ++ output := json.RawMessage(`{"result": "success"}`) ++ if err := writer.Done(ctx, output); 
err != nil { ++ t.Fatalf("Done failed: %v", err) ++ } ++ ++ // Should receive chunk then done ++ select { ++ case event := <-events: ++ if event.Type != StreamEventChunk { ++ t.Errorf("Expected chunk event first, got %v", event.Type) ++ } ++ case <-time.After(time.Second): ++ t.Fatal("Timeout waiting for chunk") ++ } ++ ++ select { ++ case event := <-events: ++ if event.Type != StreamEventDone { ++ t.Errorf("Expected done event, got %v", event.Type) ++ } ++ if string(event.Output) != string(output) { ++ t.Errorf("Expected output %s, got %s", output, event.Output) ++ } ++ case <-time.After(time.Second): ++ t.Fatal("Timeout waiting for done event") ++ } ++} ++ ++func TestInMemoryStreamManager_Error(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-error" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++ defer unsubscribe() ++ ++ // Mark as error ++ streamErr := core.NewPublicError(core.INTERNAL, "test error", nil) ++ if err := writer.Error(ctx, streamErr); err != nil { ++ t.Fatalf("Error failed: %v", err) ++ } ++ ++ select { ++ case event := <-events: ++ if event.Type != StreamEventError { ++ t.Errorf("Expected error event, got %v", event.Type) ++ } ++ if event.Err == nil { ++ t.Error("Expected error to be set") ++ } ++ case <-time.After(time.Second): ++ t.Fatal("Timeout waiting for error event") ++ } ++} ++ ++func TestInMemoryStreamManager_WriteAfterDone(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-write-after-done" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ if err := writer.Done(ctx, json.RawMessage(`"done"`)); err != nil { ++ t.Fatalf("Done failed: %v", 
err) ++ } ++ ++ // Try to write after done ++ err = writer.Write(ctx, json.RawMessage(`"chunk"`)) ++ if err == nil { ++ t.Fatal("Expected error when writing after done") ++ } ++ ++ var ufErr *core.UserFacingError ++ if !errors.As(err, &ufErr) { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if ufErr.Status != core.FAILED_PRECONDITION { ++ t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) ++ } ++} ++ ++func TestInMemoryStreamManager_WriteAfterClose(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-write-after-close" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ if err := writer.Close(); err != nil { ++ t.Fatalf("Close failed: %v", err) ++ } ++ ++ // Try to write after close ++ err = writer.Write(ctx, json.RawMessage(`"chunk"`)) ++ if err == nil { ++ t.Fatal("Expected error when writing after close") ++ } ++ ++ var ufErr *core.UserFacingError ++ if !errors.As(err, &ufErr) { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if ufErr.Status != core.FAILED_PRECONDITION { ++ t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) ++ } ++} ++ ++func TestInMemoryStreamManager_DoneAfterError(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-done-after-error" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ if err := writer.Error(ctx, core.NewPublicError(core.INTERNAL, "test", nil)); err != nil { ++ t.Fatalf("Error failed: %v", err) ++ } ++ ++ // Try to mark done after error ++ err = writer.Done(ctx, json.RawMessage(`"done"`)) ++ if err == nil { ++ t.Fatal("Expected error when calling Done after Error") ++ } ++} ++ ++func TestInMemoryStreamManager_MultipleSubscribers(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() 
++ ++ ctx := context.Background() ++ streamID := "test-stream-multi-sub" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ // Create multiple subscribers ++ events1, unsub1, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe 1 failed: %v", err) ++ } ++ defer unsub1() ++ ++ events2, unsub2, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe 2 failed: %v", err) ++ } ++ defer unsub2() ++ ++ // Write a chunk ++ chunk := json.RawMessage(`"shared-chunk"`) ++ if err := writer.Write(ctx, chunk); err != nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ ++ // Both subscribers should receive the chunk ++ for i, events := range []<-chan StreamEvent{events1, events2} { ++ select { ++ case event := <-events: ++ if event.Type != StreamEventChunk { ++ t.Errorf("Subscriber %d: expected chunk event, got %v", i+1, event.Type) ++ } ++ if string(event.Chunk) != string(chunk) { ++ t.Errorf("Subscriber %d: expected chunk %s, got %s", i+1, chunk, event.Chunk) ++ } ++ case <-time.After(time.Second): ++ t.Fatalf("Subscriber %d: timeout waiting for chunk", i+1) ++ } ++ } ++} ++ ++func TestInMemoryStreamManager_LateSubscriberGetsBufferedChunks(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-late-sub" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ // Write chunks before any subscriber ++ chunks := []string{"early1", "early2"} ++ for _, chunk := range chunks { ++ if err := writer.Write(ctx, json.RawMessage(`"`+chunk+`"`)); err != nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ } ++ ++ // Late subscriber joins ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++ defer unsubscribe() ++ ++ // Should receive buffered chunks ++ for i, expected := range chunks { ++ select { 
++ case event := <-events: ++ if event.Type != StreamEventChunk { ++ t.Errorf("Expected chunk event, got %v", event.Type) ++ } ++ var got string ++ if err := json.Unmarshal(event.Chunk, &got); err != nil { ++ t.Fatalf("Failed to unmarshal chunk: %v", err) ++ } ++ if got != expected { ++ t.Errorf("Chunk %d: expected %q, got %q", i, expected, got) ++ } ++ case <-time.After(time.Second): ++ t.Fatalf("Timeout waiting for buffered chunk %d", i) ++ } ++ } ++} ++ ++func TestInMemoryStreamManager_SubscribeToCompletedStream(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-completed" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ // Write and complete before subscribing ++ if err := writer.Write(ctx, json.RawMessage(`"chunk1"`)); err != nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ if err := writer.Write(ctx, json.RawMessage(`"chunk2"`)); err != nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ output := json.RawMessage(`{"final": true}`) ++ if err := writer.Done(ctx, output); err != nil { ++ t.Fatalf("Done failed: %v", err) ++ } ++ ++ // Subscribe after completion ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++ defer unsubscribe() ++ ++ // Should receive all buffered chunks ++ for i := 0; i < 2; i++ { ++ select { ++ case event := <-events: ++ if event.Type != StreamEventChunk { ++ t.Errorf("Expected chunk event %d, got %v", i, event.Type) ++ } ++ case <-time.After(time.Second): ++ t.Fatalf("Timeout waiting for chunk %d", i) ++ } ++ } ++ ++ // Should receive done event ++ select { ++ case event := <-events: ++ if event.Type != StreamEventDone { ++ t.Errorf("Expected done event, got %v", event.Type) ++ } ++ if string(event.Output) != string(output) { ++ t.Errorf("Expected output %s, got %s", output, event.Output) ++ } ++ case 
<-time.After(time.Second): ++ t.Fatal("Timeout waiting for done event") ++ } ++ ++ // Channel should be closed ++ select { ++ case _, ok := <-events: ++ if ok { ++ t.Error("Expected channel to be closed") ++ } ++ case <-time.After(100 * time.Millisecond): ++ t.Error("Channel not closed after done") ++ } ++} ++ ++func TestInMemoryStreamManager_SubscribeToErroredStream(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-errored" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ // Write and error before subscribing ++ if err := writer.Write(ctx, json.RawMessage(`"chunk1"`)); err != nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ streamErr := core.NewPublicError(core.INTERNAL, "test error", nil) ++ if err := writer.Error(ctx, streamErr); err != nil { ++ t.Fatalf("Error failed: %v", err) ++ } ++ ++ // Subscribe after error ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++ defer unsubscribe() ++ ++ // Should receive buffered chunk ++ select { ++ case event := <-events: ++ if event.Type != StreamEventChunk { ++ t.Errorf("Expected chunk event, got %v", event.Type) ++ } ++ case <-time.After(time.Second): ++ t.Fatal("Timeout waiting for chunk") ++ } ++ ++ // Should receive error event ++ select { ++ case event := <-events: ++ if event.Type != StreamEventError { ++ t.Errorf("Expected error event, got %v", event.Type) ++ } ++ case <-time.After(time.Second): ++ t.Fatal("Timeout waiting for error event") ++ } ++} ++ ++func TestInMemoryStreamManager_Unsubscribe(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-unsub" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if 
err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++ ++ // Unsubscribe ++ unsubscribe() ++ ++ // Write a chunk - should not panic ++ if err := writer.Write(ctx, json.RawMessage(`"chunk"`)); err != nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ ++ // Events channel should be closed ++ select { ++ case _, ok := <-events: ++ if ok { ++ t.Error("Expected channel to be closed after unsubscribe") ++ } ++ case <-time.After(100 * time.Millisecond): ++ t.Error("Channel not closed after unsubscribe") ++ } ++} ++ ++func TestInMemoryStreamManager_WithTTL(t *testing.T) { ++ m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) ++ defer m.Close() ++ ++ if m.ttl != 10*time.Millisecond { ++ t.Errorf("Expected TTL 10ms, got %v", m.ttl) ++ } ++} ++ ++func TestInMemoryStreamManager_ConcurrentOperations(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-concurrent" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ const numSubscribers = 5 ++ const numChunks = 10 ++ ++ var wg sync.WaitGroup ++ errors := make(chan error, numSubscribers*numChunks) ++ ++ // Start subscribers ++ for i := 0; i < numSubscribers; i++ { ++ wg.Add(1) ++ go func(subID int) { ++ defer wg.Done() ++ ++ events, unsubscribe, err := m.Subscribe(ctx, streamID) ++ if err != nil { ++ errors <- err ++ return ++ } ++ defer unsubscribe() ++ ++ received := 0 ++ for event := range events { ++ if event.Type == StreamEventChunk { ++ received++ ++ } else if event.Type == StreamEventDone { ++ break ++ } ++ } ++ ++ if received != numChunks { ++ errors <- core.NewPublicError(core.INTERNAL, "subscriber received an unexpected number of chunks", nil) ++ } ++ }(i) ++ } ++ ++ // Give subscribers time to set up ++ time.Sleep(50 * time.Millisecond) ++ ++ // Write chunks concurrently ++ for i := 0; i < numChunks; i++ { ++ if err := writer.Write(ctx, json.RawMessage(`"chunk"`)); err !=
nil { ++ t.Fatalf("Write failed: %v", err) ++ } ++ } ++ ++ // Complete the stream ++ if err := writer.Done(ctx, json.RawMessage(`"done"`)); err != nil { ++ t.Fatalf("Done failed: %v", err) ++ } ++ ++ wg.Wait() ++ close(errors) ++ ++ for err := range errors { ++ t.Errorf("Subscriber error: %v", err) ++ } ++} ++ ++func TestInMemoryStreamManager_Close(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ ++ // Close should not block ++ done := make(chan struct{}) ++ go func() { ++ m.Close() ++ close(done) ++ }() ++ ++ select { ++ case <-done: ++ // Success ++ case <-time.After(time.Second): ++ t.Fatal("Close blocked") ++ } ++} ++ ++func TestInMemoryStreamManager_CleanupExpiredStreams(t *testing.T) { ++ m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) ++ defer m.Close() ++ ++ ctx := context.Background() ++ ++ // Create and complete a stream ++ writer, err := m.Open(ctx, "expired-stream") ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ if err := writer.Done(ctx, json.RawMessage(`"done"`)); err != nil { ++ t.Fatalf("Done failed: %v", err) ++ } ++ ++ // Wait for TTL to expire ++ time.Sleep(20 * time.Millisecond) ++ ++ // Trigger cleanup ++ m.cleanupExpiredStreams() ++ ++ // Stream should be gone ++ _, _, err = m.Subscribe(ctx, "expired-stream") ++ if err == nil { ++ t.Fatal("Expected error subscribing to expired stream") ++ } ++ ++ var ufErr *core.UserFacingError ++ if !errors.As(err, &ufErr) { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if ufErr.Status != core.NOT_FOUND { ++ t.Errorf("Expected NOT_FOUND status, got %v", ufErr.Status) ++ } ++} ++ ++func TestInMemoryStreamManager_OpenStreamsNotCleanedUp(t *testing.T) { ++ m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) ++ defer m.Close() ++ ++ ctx := context.Background() ++ ++ // Create an open stream (not completed) ++ _, err := m.Open(ctx, "open-stream") ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ // Wait longer than TTL ++ time.Sleep(20 * 
time.Millisecond) ++ ++ // Trigger cleanup ++ m.cleanupExpiredStreams() ++ ++ // Stream should still exist ++ _, _, err = m.Subscribe(ctx, "open-stream") ++ if err != nil { ++ t.Fatalf("Subscribe failed: %v", err) ++ } ++} ++ ++func TestInMemoryStreamManager_ErrorAfterClose(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-error-after-close" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ if err := writer.Close(); err != nil { ++ t.Fatalf("Close failed: %v", err) ++ } ++ ++ // Try to error after close ++ err = writer.Error(ctx, core.NewPublicError(core.INTERNAL, "test", nil)) ++ if err == nil { ++ t.Fatal("Expected error when calling Error after Close") ++ } ++ ++ var ufErr *core.UserFacingError ++ if !errors.As(err, &ufErr) { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if ufErr.Status != core.FAILED_PRECONDITION { ++ t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) ++ } ++} ++ ++func TestInMemoryStreamManager_DoneAfterClose(t *testing.T) { ++ m := NewInMemoryStreamManager() ++ defer m.Close() ++ ++ ctx := context.Background() ++ streamID := "test-stream-done-after-close" ++ ++ writer, err := m.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Open failed: %v", err) ++ } ++ ++ if err := writer.Close(); err != nil { ++ t.Fatalf("Close failed: %v", err) ++ } ++ ++ // Try to done after close ++ err = writer.Done(ctx, json.RawMessage(`"done"`)) ++ if err == nil { ++ t.Fatal("Expected error when calling Done after Close") ++ } ++ ++ var ufErr *core.UserFacingError ++ if !errors.As(err, &ufErr) { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if ufErr.Status != core.FAILED_PRECONDITION { ++ t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) ++ } ++} +diff --git a/go/genkit/doc.go b/go/genkit/doc.go +new file mode 100644 +index 000000000..f72b7c1e7 +--- 
/dev/null ++++ b/go/genkit/doc.go +@@ -0,0 +1,408 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++/* ++Package genkit provides a framework for building AI-powered applications in Go. ++ ++Genkit is an open-source framework that helps you build, deploy, and monitor ++production-ready AI features. It provides a unified interface for working with ++large language models (LLMs), managing prompts, defining workflows, and integrating ++with various AI service providers. ++ ++For comprehensive documentation, tutorials, and examples, visit https://genkit.dev ++ ++# Getting Started ++ ++Initialize Genkit with a plugin to connect to an AI provider: ++ ++ ctx := context.Background() ++ g := genkit.Init(ctx, ++ genkit.WithPlugins(&googlegenai.GoogleAI{}), ++ ) ++ ++Generate text with a simple prompt: ++ ++ text, err := genkit.GenerateText(ctx, g, ++ ai.WithModelName("googleai/gemini-2.5-flash"), ++ ai.WithPrompt("Tell me a joke"), ++ ) ++ if err != nil { ++ log.Fatal(err) ++ } ++ fmt.Println(text) ++ ++# Models ++ ++Models represent AI language models that generate content. Use plugins to access ++models from providers like Google AI, Vertex AI, Anthropic, or Ollama. 
Models are ++referenced by name and can include provider-specific configuration: ++ ++ resp, err := genkit.Generate(ctx, g, ++ ai.WithModelName("googleai/gemini-2.5-flash"), ++ ai.WithPrompt("Explain quantum computing in simple terms"), ++ ) ++ ++You can set a default model during initialization: ++ ++ g := genkit.Init(ctx, ++ genkit.WithPlugins(&googlegenai.GoogleAI{}), ++ genkit.WithDefaultModel("googleai/gemini-2.5-flash"), ++ ) ++ ++# Flows ++ ++Flows are reusable, observable functions that orchestrate AI operations. They ++provide automatic tracing, can be exposed as HTTP endpoints, and support both ++streaming and non-streaming execution. ++ ++Define a simple flow: ++ ++ jokesFlow := genkit.DefineFlow(g, "jokesFlow", ++ func(ctx context.Context, topic string) (string, error) { ++ return genkit.GenerateText(ctx, g, ++ ai.WithPrompt("Share a joke about %s.", topic), ++ ) ++ }, ++ ) ++ ++ joke, err := jokesFlow.Run(ctx, "programming") ++ ++Define a streaming flow that sends chunks as they're generated: ++ ++ streamingFlow := genkit.DefineStreamingFlow(g, "streamingJokes", ++ func(ctx context.Context, topic string, sendChunk ai.ModelStreamCallback) (string, error) { ++ resp, err := genkit.Generate(ctx, g, ++ ai.WithPrompt("Share a joke about %s.", topic), ++ ai.WithStreaming(sendChunk), ++ ) ++ if err != nil { ++ return "", err ++ } ++ return resp.Text(), nil ++ }, ++ ) ++ ++Use [Run] within flows to create traced sub-steps for observability: ++ ++ genkit.DefineFlow(g, "pipeline", ++ func(ctx context.Context, input string) (string, error) { ++ result, err := genkit.Run(ctx, "processStep", func() (string, error) { ++ return process(input), nil ++ }) ++ return result, err ++ }, ++ ) ++ ++# Prompts ++ ++Prompts can be defined programmatically or loaded from .prompt files (Dotprompt format). ++They encapsulate model configuration, input schemas, and template logic for reuse. 
++ ++Define a prompt in code: ++ ++ jokePrompt := genkit.DefinePrompt(g, "joke", ++ ai.WithModelName("googleai/gemini-2.5-flash"), ++ ai.WithInputType(JokeRequest{Topic: "default topic"}), ++ ai.WithPrompt("Share a joke about {{topic}}."), ++ ) ++ ++ stream := jokePrompt.ExecuteStream(ctx, ai.WithInput(map[string]any{"topic": "cats"})) ++ for result, err := range stream { ++ if err != nil { ++ return err ++ } ++ if result.Done { ++ fmt.Println(result.Response.Text()) ++ } ++ } ++ ++For type-safe prompts with structured input and output, use [DefineDataPrompt]: ++ ++ type RecipeRequest struct { ++ Cuisine string `json:"cuisine"` ++ Dish string `json:"dish"` ++ ServingSize int `json:"servingSize"` ++ } ++ ++ type Recipe struct { ++ Title string `json:"title"` ++ Ingredients []string `json:"ingredients"` ++ Instructions []string `json:"instructions"` ++ } ++ ++ recipePrompt := genkit.DefineDataPrompt[RecipeRequest, *Recipe](g, "recipe", ++ ai.WithSystem("You are an experienced chef."), ++ ai.WithPrompt("Create a {{cuisine}} {{dish}} recipe for {{servingSize}} people."), ++ ) ++ ++ for result, err := range recipePrompt.ExecuteStream(ctx, RecipeRequest{ ++ Cuisine: "Italian", Dish: "pasta", ServingSize: 4, ++ }) { ++ // result.Chunk is *Recipe, result.Output is final *Recipe ++ } ++ ++Load prompts from .prompt files by specifying a prompt directory: ++ ++ g := genkit.Init(ctx, ++ genkit.WithPlugins(&googlegenai.GoogleAI{}), ++ genkit.WithPromptDir("./prompts"), ++ ) ++ ++ // Look up a loaded prompt ++ jokePrompt := genkit.LookupPrompt(g, "joke") ++ ++ // Or with type parameters for structured I/O ++ recipePrompt := genkit.LookupDataPrompt[RecipeRequest, *Recipe](g, "recipe") ++ ++When using .prompt files with custom output schemas, register the schema first: ++ ++ genkit.DefineSchemaFor[Recipe](g) ++ ++# Tools ++ ++Tools extend model capabilities by allowing them to call functions during generation. 
++Define tools that the model can invoke to perform actions or retrieve information: ++ ++ weatherTool := genkit.DefineTool(g, "getWeather", ++ "Gets the current weather for a city", ++ func(ctx *ai.ToolContext, city string) (string, error) { ++ // Fetch weather data... ++ return "Sunny, 72°F", nil ++ }, ++ ) ++ ++ resp, err := genkit.Generate(ctx, g, ++ ai.WithPrompt("What's the weather in Paris?"), ++ ai.WithTools(weatherTool), ++ ) ++ ++# Structured Output ++ ++Generate structured data that conforms to Go types using [GenerateData] or ++[GenerateDataStream]. Use jsonschema struct tags to provide descriptions and ++constraints that help the model understand the expected output: ++ ++ type Joke struct { ++ Joke string `json:"joke" jsonschema:"description=The joke text"` ++ Category string `json:"category" jsonschema:"description=The joke category"` ++ } ++ ++ joke, resp, err := genkit.GenerateData[*Joke](ctx, g, ++ ai.WithPrompt("Tell me a programming joke"), ++ ) ++ ++For streaming structured output: ++ ++ stream := genkit.GenerateDataStream[*Recipe](ctx, g, ++ ai.WithPrompt("Create a pasta recipe"), ++ ) ++ for result, err := range stream { ++ if err != nil { ++ return nil, err ++ } ++ if result.Done { ++ return result.Output, nil ++ } ++ // result.Chunk contains partial Recipe as it streams ++ fmt.Printf("Got %d ingredients so far\n", len(result.Chunk.Ingredients)) ++ } ++ ++# Streaming ++ ++Genkit supports streaming at multiple levels. 
Use [GenerateStream] for streaming ++model responses: ++ ++ stream := genkit.GenerateStream(ctx, g, ++ ai.WithPrompt("Write a short story"), ++ ) ++ for result, err := range stream { ++ if err != nil { ++ log.Fatal(err) ++ } ++ if result.Done { ++ fmt.Println("\n--- Complete ---") ++ } else { ++ fmt.Print(result.Chunk.Text()) ++ } ++ } ++ ++Use [DefineStreamingFlow] for flows that stream custom data types: ++ ++ genkit.DefineStreamingFlow(g, "countdown", ++ func(ctx context.Context, count int, sendChunk func(context.Context, int) error) (string, error) { ++ for i := count; i > 0; i-- { ++ if err := sendChunk(ctx, i); err != nil { ++ return "", err ++ } ++ time.Sleep(time.Second) ++ } ++ return "Liftoff!", nil ++ }, ++ ) ++ ++# Development Mode and Dev UI ++ ++Set GENKIT_ENV=dev to enable development features including the Reflection API ++server that powers the Genkit Developer UI: ++ ++ $ export GENKIT_ENV=dev ++ $ go run main.go ++ ++Then run the Dev UI to inspect flows, test prompts, and view traces: ++ ++ $ npx genkit start -- go run main.go ++ ++The Dev UI provides: ++ - Interactive flow testing with input/output inspection ++ - Prompt playground for iterating on prompts ++ - Trace viewer for debugging and performance analysis ++ - Action browser for exploring registered actions ++ ++# HTTP Server Integration ++ ++Expose flows as HTTP endpoints for production deployment using [Handler]: ++ ++ mux := http.NewServeMux() ++ for _, flow := range genkit.ListFlows(g) { ++ mux.HandleFunc("POST /"+flow.Name(), genkit.Handler(flow)) ++ } ++ log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) ++ ++Handlers support streaming responses via Server-Sent Events when the client ++sends Accept: text/event-stream. 
For durable streaming that survives reconnects, ++use [WithStreamManager]: ++ ++ mux.HandleFunc("POST /countdown", genkit.Handler(countdown, ++ genkit.WithStreamManager(streaming.NewInMemoryStreamManager( ++ streaming.WithTTL(10*time.Minute), ++ )), ++ )) ++ ++# Plugins ++ ++Genkit's functionality is extended through plugins that provide models, tools, ++retrievers, and other capabilities. Common plugins include: ++ ++ - googlegenai: Google AI (Gemini models) ++ - vertexai: Google Cloud Vertex AI ++ - ollama: Local Ollama models ++ ++Initialize plugins during [Init]: ++ ++ g := genkit.Init(ctx, ++ genkit.WithPlugins( ++ &googlegenai.GoogleAI{}, ++ &vertexai.VertexAI{ProjectID: "my-project"}, ++ ), ++ ) ++ ++# Messages and Parts ++ ++Build conversation messages using helper functions from the [ai] package. These ++are used with [ai.WithMessages] or when building custom conversation flows: ++ ++ // Create messages for a conversation ++ messages := []*ai.Message{ ++ ai.NewSystemTextMessage("You are a helpful assistant."), ++ ai.NewUserTextMessage("Hello!"), ++ ai.NewModelTextMessage("Hi there! 
How can I help?"), ++ } ++ ++ resp, err := genkit.Generate(ctx, g, ++ ai.WithMessages(messages...), ++ ai.WithPrompt("What can you do?"), ++ ) ++ ++For multi-modal content, combine text and media parts: ++ ++ userMsg := ai.NewUserMessage( ++ ai.NewTextPart("What's in this image?"), ++ ai.NewMediaPart("image/png", base64ImageData), ++ ) ++ ++Available message constructors in the [ai] package: ++ ++ - [ai.NewUserTextMessage], [ai.NewUserMessage]: User messages ++ - [ai.NewModelTextMessage], [ai.NewModelMessage]: Model responses ++ - [ai.NewSystemTextMessage], [ai.NewSystemMessage]: System instructions ++ ++Available part constructors in the [ai] package: ++ ++ - [ai.NewTextPart]: Text content ++ - [ai.NewMediaPart]: Images, audio, video (base64-encoded) ++ - [ai.NewDataPart]: Raw data strings ++ - [ai.NewToolRequestPart], [ai.NewToolResponsePart]: Tool interactions ++ ++# Generation Options ++ ++Generation functions ([Generate], [GenerateText], [GenerateData], [GenerateStream]) ++accept options from the [ai] package to control behavior. The most common options: ++ ++Model and Configuration: ++ ++ - [ai.WithModel]: Specify the model (accepts [ai.ModelRef] or plugin model refs) ++ - [ai.WithModelName]: Specify model by name string (e.g., "googleai/gemini-2.5-flash") ++ - [ai.WithConfig]: Set generation parameters (temperature, max tokens, etc.) ++ ++Prompting: ++ ++ - [ai.WithPrompt]: Set the user prompt (supports format strings) ++ - [ai.WithSystem]: Set system instructions ++ - [ai.WithMessages]: Provide conversation history ++ ++Tools and Output: ++ ++ - [ai.WithTools]: Enable tools the model can call ++ - [ai.WithOutputType]: Request structured output matching a Go type ++ - [ai.WithOutputFormat]: Specify output format (json, text, etc.) 
++ ++Streaming: ++ ++ - [ai.WithStreaming]: Enable streaming with a callback function ++ ++Example combining multiple options: ++ ++ resp, err := genkit.Generate(ctx, g, ++ ai.WithModelName("googleai/gemini-2.5-flash"), ++ ai.WithSystem("You are a helpful coding assistant."), ++ ai.WithMessages(conversationHistory...), ++ ai.WithPrompt("Explain this code: %s", code), ++ ai.WithTools(searchTool, calculatorTool), ++ // Config is provider-specific (e.g., genai.GenerateContentConfig for Google AI) ++ ) ++ ++# Unregistered Components ++ ++For advanced use cases, the [ai] package provides New* functions to create ++components without registering them in Genkit. This is useful for plugins ++or when you need to pass components directly: ++ ++ - [ai.NewTool]: Create an unregistered tool ++ - [ai.NewModel]: Create an unregistered model ++ - [ai.NewRetriever]: Create an unregistered retriever ++ - [ai.NewEmbedder]: Create an unregistered embedder ++ ++Use the corresponding Define* functions in this package to create and register ++components for use with Genkit's action system, tracing, and Dev UI. ++ ++# Additional Resources ++ ++ - Documentation: https://genkit.dev ++ - Go Getting Started: https://genkit.dev/go/docs/get-started-go ++ - Samples: https://github.com/firebase/genkit/tree/main/go/samples ++ - GitHub: https://github.com/firebase/genkit ++*/ ++package genkit +diff --git a/go/genkit/example_test.go b/go/genkit/example_test.go +new file mode 100644 +index 000000000..917e8dc49 +--- /dev/null ++++ b/go/genkit/example_test.go +@@ -0,0 +1,322 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. 
++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++package genkit_test ++ ++import ( ++ "context" ++ "fmt" ++ "log" ++ "net/http" ++ "strings" ++ ++ "github.com/firebase/genkit/go/ai" ++ "github.com/firebase/genkit/go/core" ++ "github.com/firebase/genkit/go/genkit" ++) ++ ++// This example shows basic initialization and flow definition. ++func Example() { ++ ctx := context.Background() ++ ++ // Initialize Genkit (without plugins for this example) ++ g := genkit.Init(ctx) ++ ++ // Define a simple flow ++ greetFlow := genkit.DefineFlow(g, "greet", ++ func(ctx context.Context, name string) (string, error) { ++ return fmt.Sprintf("Hello, %s!", name), nil ++ }, ++ ) ++ ++ // Run the flow ++ greeting, err := greetFlow.Run(ctx, "World") ++ if err != nil { ++ log.Fatal(err) ++ } ++ fmt.Println(greeting) ++ // Output: Hello, World! ++} ++ ++// This example demonstrates defining a simple non-streaming flow. ++func ExampleDefineFlow() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define a flow that processes input ++ uppercaseFlow := genkit.DefineFlow(g, "uppercase", ++ func(ctx context.Context, input string) (string, error) { ++ return strings.ToUpper(input), nil ++ }, ++ ) ++ ++ // Run the flow ++ result, err := uppercaseFlow.Run(ctx, "hello") ++ if err != nil { ++ log.Fatal(err) ++ } ++ fmt.Println(result) ++ // Output: HELLO ++} ++ ++// This example demonstrates defining a streaming flow that sends ++// chunks to the caller as they are produced. 
++func ExampleDefineStreamingFlow() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define a streaming flow that counts down ++ countdownFlow := genkit.DefineStreamingFlow(g, "countdown", ++ func(ctx context.Context, start int, sendChunk func(context.Context, int) error) (string, error) { ++ for i := start; i > 0; i-- { ++ if err := sendChunk(ctx, i); err != nil { ++ return "", err ++ } ++ } ++ return "Liftoff!", nil ++ }, ++ ) ++ ++ // Stream results using the iterator ++ iter := countdownFlow.Stream(ctx, 3) ++ iter(func(val *core.StreamingFlowValue[string, int], err error) bool { ++ if err != nil { ++ log.Fatal(err) ++ } ++ if val.Done { ++ fmt.Println("Final:", val.Output) ++ } else { ++ fmt.Println("Count:", val.Stream) ++ } ++ return true ++ }) ++ // Output: ++ // Count: 3 ++ // Count: 2 ++ // Count: 1 ++ // Final: Liftoff! ++} ++ ++// This example demonstrates using Run to create traced sub-steps ++// within a flow for better observability. ++func ExampleRun() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define a flow with traced sub-steps ++ pipelineFlow := genkit.DefineFlow(g, "pipeline", ++ func(ctx context.Context, input string) (string, error) { ++ // Each Run call creates a traced step visible in the Dev UI ++ upper, err := genkit.Run(ctx, "uppercase", func() (string, error) { ++ return strings.ToUpper(input), nil ++ }) ++ if err != nil { ++ return "", err ++ } ++ ++ result, err := genkit.Run(ctx, "addPrefix", func() (string, error) { ++ return "Processed: " + upper, nil ++ }) ++ return result, err ++ }, ++ ) ++ ++ result, err := pipelineFlow.Run(ctx, "hello") ++ if err != nil { ++ log.Fatal(err) ++ } ++ fmt.Println(result) ++ // Output: Processed: HELLO ++} ++ ++// This example demonstrates defining a tool that models can call ++// during generation. 
++func ExampleDefineTool() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define a tool that adds two numbers ++ _ = genkit.DefineTool(g, "add", ++ "Adds two numbers together", ++ func(ctx *ai.ToolContext, input struct { ++ A float64 `json:"a" jsonschema:"description=First number"` ++ B float64 `json:"b" jsonschema:"description=Second number"` ++ }) (float64, error) { ++ return input.A + input.B, nil ++ }, ++ ) ++ ++ // The tool is now registered and can be used with ai.WithTools() ++ // when calling genkit.Generate() ++ fmt.Println("Tool registered: add") ++ // Output: Tool registered: add ++} ++ ++// This example demonstrates defining a reusable prompt with a template. ++func ExampleDefinePrompt() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define a prompt with Handlebars template syntax ++ prompt := genkit.DefinePrompt(g, "greeting", ++ ai.WithPrompt("Say hello to {{name}} in a {{style}} way."), ++ ) ++ ++ // Render the prompt (without executing - useful for inspection) ++ rendered, err := prompt.Render(ctx, map[string]any{ ++ "name": "Alice", ++ "style": "friendly", ++ }) ++ if err != nil { ++ log.Fatal(err) ++ } ++ // The rendered prompt contains the messages that would be sent ++ fmt.Println(rendered.Messages[0].Content[0].Text) ++ // Output: Say hello to Alice in a friendly way. ++} ++ ++// This example demonstrates registering a Go type as a named schema. 
++func ExampleDefineSchemaFor() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define a struct type ++ type Person struct { ++ Name string `json:"name" jsonschema:"description=The person's name"` ++ Age int `json:"age" jsonschema:"description=The person's age"` ++ } ++ ++ // Register the schema - this makes it available for .prompt files ++ // that reference it by name (e.g., "output: { schema: Person }") ++ genkit.DefineSchemaFor[Person](g) ++ ++ fmt.Println("Schema registered: Person") ++ // Output: Schema registered: Person ++} ++ ++// This example demonstrates creating an HTTP server that exposes ++// all registered flows as endpoints. ++func ExampleListFlows_httpServer() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define some flows ++ genkit.DefineFlow(g, "echo", func(ctx context.Context, s string) (string, error) { ++ return s, nil ++ }) ++ ++ genkit.DefineFlow(g, "reverse", func(ctx context.Context, s string) (string, error) { ++ runes := []rune(s) ++ for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { ++ runes[i], runes[j] = runes[j], runes[i] ++ } ++ return string(runes), nil ++ }) ++ ++ // Create HTTP handlers for all flows ++ mux := http.NewServeMux() ++ for _, flow := range genkit.ListFlows(g) { ++ mux.HandleFunc("POST /"+flow.Name(), genkit.Handler(flow)) ++ } ++ ++ // The mux now has: ++ // - POST /echo ++ // - POST /reverse ++ fmt.Printf("Registered %d flow handlers\n", len(genkit.ListFlows(g))) ++ // Output: Registered 2 flow handlers ++} ++ ++// This example demonstrates using Handler to expose a single flow ++// as an HTTP endpoint. 
++func ExampleHandler() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define a flow ++ greetFlow := genkit.DefineFlow(g, "greet", ++ func(ctx context.Context, name string) (string, error) { ++ return fmt.Sprintf("Hello, %s!", name), nil ++ }, ++ ) ++ ++ // Create an HTTP handler for the flow ++ mux := http.NewServeMux() ++ mux.HandleFunc("POST /greet", genkit.Handler(greetFlow)) ++ ++ // The handler accepts JSON: {"data": "World"} ++ // and returns JSON: {"result": "Hello, World!"} ++ fmt.Println("Handler registered at POST /greet") ++ // Output: Handler registered at POST /greet ++} ++ ++// This example demonstrates using type-safe data prompts with ++// strongly-typed input and output. ++func ExampleDefineDataPrompt() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ // Define input and output types ++ type JokeRequest struct { ++ Topic string `json:"topic"` ++ } ++ ++ type Joke struct { ++ Setup string `json:"setup"` ++ Punchline string `json:"punchline"` ++ } ++ ++ // Define a type-safe prompt ++ // Note: In production, you'd also set ai.WithModel(...) ++ _ = genkit.DefineDataPrompt[JokeRequest, *Joke](g, "joke", ++ ai.WithPrompt("Tell a joke about {{topic}}. Return JSON with setup and punchline."), ++ ) ++ ++ // The prompt can now be executed with: ++ // for result, err := range jokePrompt.ExecuteStream(ctx, JokeRequest{Topic: "cats"}) { ++ // if result.Done { ++ // fmt.Println(result.Output.Setup) ++ // fmt.Println(result.Output.Punchline) ++ // } ++ // } ++ ++ fmt.Println("DataPrompt registered: joke") ++ // Output: DataPrompt registered: joke ++} ++ ++// This example demonstrates looking up a prompt that was loaded ++// from a .prompt file. 
++func ExampleLookupPrompt() { ++ ctx := context.Background() ++ ++ // In production, you would initialize with a prompt directory: ++ // g := genkit.Init(ctx, genkit.WithPromptDir("./prompts")) ++ ++ g := genkit.Init(ctx) ++ ++ // Define a prompt programmatically (simulating a loaded prompt) ++ genkit.DefinePrompt(g, "greeting", ++ ai.WithPrompt("Hello {{name}}!"), ++ ) ++ ++ // Look up the prompt by name ++ prompt := genkit.LookupPrompt(g, "greeting") ++ if prompt == nil { ++ log.Fatal("Prompt not found") ++ } ++ ++ fmt.Println("Found prompt:", prompt.Name()) ++ // Output: Found prompt: greeting ++} +diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go +index 75ef8c9a8..83429ca0d 100644 +--- a/go/genkit/genkit.go ++++ b/go/genkit/genkit.go +@@ -21,6 +21,8 @@ import ( + "context" + "errors" + "fmt" ++ "io/fs" ++ "iter" + "log/slog" + "os" + "os/signal" +@@ -46,6 +48,7 @@ type Genkit struct { + type genkitOptions struct { + DefaultModel string // Default model to use if no other model is specified. + PromptDir string // Directory where dotprompts are stored. Will be loaded automatically on initialization. ++ PromptFS fs.FS // Embedded filesystem containing prompts (alternative to PromptDir). + Plugins []api.Plugin // Plugin to initialize automatically. 
+ } + +@@ -66,6 +69,20 @@ func (o *genkitOptions) apply(gOpts *genkitOptions) error { + if gOpts.PromptDir != "" { + return errors.New("cannot set prompt directory more than once (WithPromptDir)") + } ++ if gOpts.PromptFS != nil { ++ return errors.New("cannot use WithPromptDir together with WithPromptFS") ++ } ++ gOpts.PromptDir = o.PromptDir ++ } ++ ++ if o.PromptFS != nil { ++ if gOpts.PromptFS != nil { ++ return errors.New("cannot set prompt filesystem more than once (WithPromptFS)") ++ } ++ if gOpts.PromptDir != "" { ++ return errors.New("cannot use WithPromptFS together with WithPromptDir") ++ } ++ gOpts.PromptFS = o.PromptFS + gOpts.PromptDir = o.PromptDir + } + +@@ -99,13 +116,44 @@ func WithDefaultModel(model string) GenkitOption { + // The default directory is "prompts" relative to the project root where + // [Init] is called. + // ++// When used with [WithPromptFS], this directory serves as the root path within ++// the embedded filesystem instead of a local disk path. For example, if using ++// `//go:embed prompts/*`, set the directory to "prompts" to match. ++// + // Invalid prompt files will result in logged errors during initialization, + // while valid files that define invalid prompts will cause [Init] to panic. +-// This option can only be applied once. + func WithPromptDir(dir string) GenkitOption { + return &genkitOptions{PromptDir: dir} + } + ++// WithPromptFS specifies an embedded filesystem ([fs.FS]) containing `.prompt` files. ++// This is useful for embedding prompts directly into the binary using Go's [embed] package, ++// eliminating the need to distribute prompt files separately. ++// ++// The `fsys` parameter should be an [fs.FS] implementation (e.g., [embed.FS]). ++// Use [WithPromptDir] to specify the root directory within the filesystem where ++// prompts are located (defaults to "prompts"). 
++// ++// Example: ++// ++// import "embed" ++// ++// //go:embed prompts/* ++// var promptsFS embed.FS ++// ++// func main() { ++// g := genkit.Init(ctx, ++// genkit.WithPromptFS(promptsFS), ++// genkit.WithPromptDir("prompts"), ++// ) ++// } ++// ++// Invalid prompt files will result in logged errors during initialization, ++// while valid files that define invalid prompts will cause [Init] to panic. ++func WithPromptFS(fsys fs.FS) GenkitOption { ++ return &genkitOptions{PromptFS: fsys} ++} ++ + // Init creates and initializes a new [Genkit] instance with the provided options. + // It sets up the registry, initializes plugins ([WithPlugins]), loads prompts + // ([WithPromptDir]), and configures other settings like the default model +@@ -184,7 +232,15 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { + + ai.ConfigureFormats(r) + ai.DefineGenerateAction(ctx, r) +- ai.LoadPromptDir(r, gOpts.PromptDir, "") ++ if gOpts.PromptFS != nil { ++ dir := gOpts.PromptDir ++ if dir == "" { ++ dir = "prompts" ++ } ++ ai.LoadPromptDirFromFS(r, gOpts.PromptFS, dir, "") ++ } else { ++ loadPromptDirOS(r, gOpts.PromptDir, "") ++ } + + r.RegisterValue(api.DefaultModelKey, gOpts.DefaultModel) + r.RegisterValue(api.PromptDirKey, gOpts.PromptDir) +@@ -268,7 +324,7 @@ func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *cor + // Example: + // + // counterFlow := genkit.DefineStreamingFlow(g, "counter", +-// func(ctx context.Context, limit int, stream func(context.Context, int) error) (string, error) { ++// func(ctx context.Context, limit int, stream core.StreamCallback[int]) (string, error) { + // if stream == nil { // Non-streaming case + // return fmt.Sprintf("Counted up to %d", limit), nil + // } +@@ -371,12 +427,14 @@ func ListTools(g *Genkit) []ai.Tool { + // DefineModel defines a custom model implementation, registers it as a [core.Action] + // of type Model, and returns an [ai.Model] interface. 
+ // +-// The `provider` and `name` arguments form the unique identifier for the model +-// (e.g., "myProvider/myModel"). The `info` argument provides metadata about the +-// model's capabilities ([ai.ModelInfo]). The `fn` argument ([ai.ModelFunc]) +-// implements the actual generation logic, handling input requests ([ai.ModelRequest]) +-// and producing responses ([ai.ModelResponse]), potentially streaming chunks +-// ([ai.ModelResponseChunk]) via the callback. ++// The `name` argument is the unique identifier for the model (e.g., "myProvider/myModel"). ++// The `opts` argument provides metadata about the model's capabilities ([ai.ModelOptions]). ++// The `fn` argument ([ai.ModelFunc]) implements the actual generation logic, handling ++// input requests ([ai.ModelRequest]) and producing responses ([ai.ModelResponse]), ++// potentially streaming chunks ([ai.ModelResponseChunk]) via the callback. ++// ++// For models that don't need to be registered (e.g., for plugin development or testing), ++// use [ai.NewModel] instead. + // + // Example: + // +@@ -454,7 +512,7 @@ func LookupBackgroundModel(g *Genkit, name string) ai.BackgroundModel { + } + + // DefineTool defines a tool that can be used by models during generation, +-// registers it as a [core.Action] of type Tool, and returns an [ai.ToolDef]. ++// registers it as a [core.Action] of type Tool, and returns an [ai.Tool]. + // Tools allow models to interact with external systems or perform specific computations. + // + // The `name` is the identifier the model uses to request the tool. The `description` +@@ -464,7 +522,13 @@ func LookupBackgroundModel(g *Genkit, name string) ai.BackgroundModel { + // `inputSchema` and `outputSchema` in the tool's definition, which guide the model + // on how to provide input and interpret output. + // +-// Use [ai.WithInputSchema] to provide a custom JSON schema instead of inferring from the type parameter. 
++// For tools that don't need to be registered (e.g., dynamically created tools), ++// use [ai.NewTool] instead. ++// ++// # Options ++// ++// - [ai.WithInputSchema]: Provide a custom JSON schema instead of inferring from the type parameter ++// - [ai.WithInputSchemaName]: Reference a pre-registered schema by name + // + // Example: + // +@@ -507,38 +571,6 @@ func DefineTool[In, Out any](g *Genkit, name, description string, fn ai.ToolFunc + // input of type `any`, and returning an output of type `Out`. + // + // Deprecated: Use [DefineTool] with [ai.WithInputSchema] instead. +-// +-// Example: +-// +-// // Define a custom input schema +-// inputSchema := map[string]any{ +-// "type": "object", +-// "properties": map[string]any{ +-// "city": map[string]any{"type": "string"}, +-// "unit": map[string]any{ +-// "type": "string", +-// "enum": []any{"C", "F"}, +-// }, +-// }, +-// "required": []string{"city"}, +-// } +-// +-// // Define the tool with the schema +-// weatherTool := genkit.DefineTool(g, "getWeather", +-// "Fetches the weather for a given city with unit preference", +-// func(ctx *ai.ToolContext, input any) (string, error) { +-// // Parse and validate input +-// data := input.(map[string]any) +-// city := data["city"].(string) +-// unit := "C" // default +-// if u, ok := data["unit"].(string); ok { +-// unit = u +-// } +-// // Implementation... +-// return fmt.Sprintf("Weather in %s: 25°%s", city, unit), nil +-// }, +-// ai.WithToolInputSchema(inputSchema), +-// ) + func DefineToolWithInputSchema[Out any](g *Genkit, name, description string, inputSchema map[string]any, fn ai.ToolFunc[any, Out]) ai.Tool { + return ai.DefineTool(g.reg, name, description, fn, ai.WithInputSchema(inputSchema)) + } +@@ -554,7 +586,13 @@ func DefineToolWithInputSchema[Out any](g *Genkit, name, description string, inp + // returning an [ai.MultipartToolResponse] which contains both the output and optional + // content parts. 
+ // +-// Use [ai.WithInputSchema] to provide a custom JSON schema instead of inferring from the type parameter. ++// For multipart tools that don't need to be registered (e.g., dynamically created tools), ++// use [ai.NewMultipartTool] instead. ++// ++// # Options ++// ++// - [ai.WithInputSchema]: Provide a custom JSON schema instead of inferring from the type parameter ++// - [ai.WithInputSchemaName]: Reference a pre-registered schema by name + // + // Example: + // +@@ -605,18 +643,55 @@ func LookupTool(g *Genkit, name string) ai.Tool { + } + + // DefinePrompt defines a prompt programmatically, registers it as a [core.Action] +-// of type Prompt, and returns an executable [ai.prompt]. ++// of type Prompt, and returns an executable [ai.Prompt]. + // + // This provides an alternative to defining prompts in `.prompt` files, offering + // more flexibility through Go code. Prompts encapsulate configuration (model, parameters), + // message templates (system, user, history), input/output schemas, and associated tools. + // + // Prompts can be executed in two main ways: +-// 1. Render + Generate: Call [Prompt.Render] to get [ai.GenerateActionOptions], ++// 1. Render + Generate: Call [ai.Prompt.Render] to get [ai.GenerateActionOptions], + // modify them if needed, and pass them to [GenerateWithRequest]. +-// 2. Execute: Call [Prompt.Execute] directly, passing input and execution options. +-// +-// Options ([ai.PromptOption]) are used to configure the prompt during definition. ++// 2. Execute: Call [ai.Prompt.Execute] directly, passing input and execution options. ++// ++// For prompts that don't need to be registered (e.g., for single-use or testing), ++// use [ai.NewPrompt] instead. ++// ++// # Options ++// ++// Model and Configuration: ++// - [ai.WithModel]: Specify the model (accepts [ai.Model] or [ai.ModelRef]) ++// - [ai.WithModelName]: Specify model by name string ++// - [ai.WithConfig]: Set generation parameters (temperature, max tokens, etc.) 
++// ++// Prompt Content: ++// - [ai.WithPrompt]: Set the user prompt template (supports {{variable}} syntax) ++// - [ai.WithPromptFn]: Set a function that generates the user prompt dynamically ++// - [ai.WithSystem]: Set system instructions template ++// - [ai.WithSystemFn]: Set a function that generates system instructions dynamically ++// - [ai.WithMessages]: Provide static conversation history ++// - [ai.WithMessagesFn]: Provide a function that generates conversation history ++// ++// Input Schema: ++// - [ai.WithInputType]: Set input schema from a Go type (provides default values) ++// - [ai.WithInputSchema]: Provide a custom JSON schema for input ++// - [ai.WithInputSchemaName]: Reference a pre-registered schema by name ++// ++// Output Schema: ++// - [ai.WithOutputType]: Set output schema from a Go type ++// - [ai.WithOutputSchema]: Provide a custom JSON schema for output ++// - [ai.WithOutputSchemaName]: Reference a pre-registered schema by name ++// - [ai.WithOutputFormat]: Specify output format (json, text, etc.) 
++// ++// Tools and Resources: ++// - [ai.WithTools]: Enable tools the model can call ++// - [ai.WithToolChoice]: Control whether tool calls are required, optional, or disabled ++// - [ai.WithMaxTurns]: Set maximum tool call iterations ++// - [ai.WithResources]: Attach resources available during generation ++// ++// Metadata: ++// - [ai.WithDescription]: Set a description for the prompt ++// - [ai.WithMetadata]: Set arbitrary metadata + // + // Example: + // +@@ -631,12 +706,12 @@ func LookupTool(g *Genkit, name string) ai.Tool { + // // Define the prompt + // capitalPrompt := genkit.DefinePrompt(g, "findCapital", + // ai.WithDescription("Finds the capital of a country."), +-// ai.WithModelName("googleai/gemini-2.5-flash"), // Specify the model ++// ai.WithModelName("googleai/gemini-2.5-flash"), + // ai.WithSystem("You are a helpful geography assistant."), + // ai.WithPrompt("What is the capital of {{country}}?"), + // ai.WithInputType(GeoInput{Country: "USA"}), + // ai.WithOutputType(GeoOutput{}), +-// ai.WithConfig(&ai.GenerationCommonConfig{Temperature: 0.5}), ++// // Config is provider-specific, e.g., genai.GenerateContentConfig for Google AI + // ) + // + // // Option 1: Render + Generate (using default input "USA") +@@ -717,6 +792,50 @@ func DefineSchemaFor[T any](g *Genkit) { + core.DefineSchemaFor[T](g.reg) + } + ++// DefineDataPrompt creates a new [ai.DataPrompt] with strongly-typed input and output. ++// It automatically infers input schema from the In type parameter and configures ++// output schema and JSON format from the Out type parameter (unless Out is string). ++// ++// This is a convenience wrapper around [DefinePrompt] that provides compile-time ++// type safety for both input and output. For prompts that don't need to be registered, ++// use [ai.NewDataPrompt] instead. ++// ++// DefineDataPrompt accepts the same options as [DefinePrompt]. See [DefinePrompt] for ++// the full list of available options. 
Note that input and output schemas are automatically ++// inferred from the type parameters. ++// ++// Example: ++// ++// type GeoInput struct { ++// Country string `json:"country"` ++// } ++// ++// type GeoOutput struct { ++// Capital string `json:"capital"` ++// } ++// ++// capitalPrompt := genkit.DefineDataPrompt[GeoInput, GeoOutput](g, "findCapital", ++// ai.WithModelName("googleai/gemini-2.5-flash"), ++// ai.WithSystem("You are a helpful geography assistant."), ++// ai.WithPrompt("What is the capital of {{country}}?"), ++// ) ++// ++// output, resp, err := capitalPrompt.Execute(ctx, GeoInput{Country: "France"}) ++// if err != nil { ++// log.Fatalf("Execute failed: %v", err) ++// } ++// fmt.Printf("Capital: %s\n", output.Capital) ++func DefineDataPrompt[In, Out any](g *Genkit, name string, opts ...ai.PromptOption) *ai.DataPrompt[In, Out] { ++ return ai.DefineDataPrompt[In, Out](g.reg, name, opts...) ++} ++ ++// LookupDataPrompt looks up a prompt by name and wraps it with type information. ++// This is useful for wrapping prompts loaded from .prompt files with strong types. ++// It returns nil if the prompt was not found. ++func LookupDataPrompt[In, Out any](g *Genkit, name string) *ai.DataPrompt[In, Out] { ++ return ai.LookupDataPrompt[In, Out](g.reg, name) ++} ++ + // GenerateWithRequest performs a model generation request using explicitly provided + // [ai.GenerateActionOptions]. 
This function is typically used in conjunction with + // prompts defined via [DefinePrompt], where [ai.prompt.Render] produces the +@@ -734,8 +853,7 @@ func DefineSchemaFor[T any](g *Genkit) { + // // handle error + // } + // +-// // Optional: Modify actionOpts here if needed +-// // actionOpts.Config = &ai.GenerationCommonConfig{ Temperature: 0.8 } ++// // Optional: Modify actionOpts here if needed (config is provider-specific) + // + // resp, err := genkit.GenerateWithRequest(ctx, g, actionOpts, nil, nil) // No middleware or streaming + // if err != nil { +@@ -750,12 +868,50 @@ func GenerateWithRequest(ctx context.Context, g *Genkit, actionOpts *ai.Generate + // provided via [ai.GenerateOption] arguments. It's a convenient way to make + // generation calls without pre-defining a prompt object. + // ++// # Options ++// ++// Model and Configuration: ++// - [ai.WithModel]: Specify the model (accepts [ai.Model] or [ai.ModelRef]) ++// - [ai.WithModelName]: Specify model by name string (e.g., "googleai/gemini-2.5-flash") ++// - [ai.WithConfig]: Set generation parameters (temperature, max tokens, etc.) 
++// ++// Prompting: ++// - [ai.WithPrompt]: Set the user prompt (supports format strings) ++// - [ai.WithPromptFn]: Set a function that generates the user prompt dynamically ++// - [ai.WithSystem]: Set system instructions ++// - [ai.WithSystemFn]: Set a function that generates system instructions dynamically ++// - [ai.WithMessages]: Provide conversation history ++// - [ai.WithMessagesFn]: Provide a function that generates conversation history ++// ++// Tools and Resources: ++// - [ai.WithTools]: Enable tools the model can call ++// - [ai.WithToolChoice]: Control whether tool calls are required, optional, or disabled ++// - [ai.WithMaxTurns]: Set maximum tool call iterations ++// - [ai.WithReturnToolRequests]: Return tool requests instead of executing them ++// - [ai.WithResources]: Attach resources available during generation ++// ++// Output: ++// - [ai.WithOutputType]: Request structured output matching a Go type ++// - [ai.WithOutputSchema]: Provide a custom JSON schema for output ++// - [ai.WithOutputSchemaName]: Reference a pre-registered schema by name ++// - [ai.WithOutputFormat]: Specify output format (json, text, etc.) 
++// - [ai.WithOutputEnums]: Constrain output to specific enum values ++// ++// Context and Streaming: ++// - [ai.WithDocs]: Provide context documents ++// - [ai.WithTextDocs]: Provide context as text strings ++// - [ai.WithStreaming]: Enable streaming with a callback function ++// - [ai.WithMiddleware]: Apply middleware to the model request/response ++// ++// Tool Continuation: ++// - [ai.WithToolResponses]: Resume generation with tool response parts ++// - [ai.WithToolRestarts]: Resume generation by restarting tool requests ++// + // Example: + // + // resp, err := genkit.Generate(ctx, g, + // ai.WithModelName("googleai/gemini-2.5-flash"), + // ai.WithPrompt("Write a short poem about clouds."), +-// ai.WithConfig(&genai.GenerateContentConfig{MaxOutputTokens: 50}), + // ) + // if err != nil { + // log.Fatalf("Generate failed: %v", err) +@@ -766,12 +922,48 @@ func Generate(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*ai.Mo + return ai.Generate(ctx, g.reg, opts...) + } + ++// GenerateStream generates a model response and streams the output. ++// It returns an iterator that yields streaming results. ++// ++// If the yield function is passed a non-nil error, generation has failed with that ++// error; the yield function will not be called again. ++// ++// If the yield function's [ai.ModelStreamValue] argument has Done == true, the value's ++// Response field contains the final response; the yield function will not be called again. ++// ++// Otherwise the Chunk field of the passed [ai.ModelStreamValue] holds a streamed chunk. ++// ++// GenerateStream accepts the same options as [Generate]. See [Generate] for the full ++// list of available options. 
++// ++// Example: ++// ++// for result, err := range genkit.GenerateStream(ctx, g, ++// ai.WithPrompt("Tell me a story about a brave knight."), ++// ) { ++// if err != nil { ++// log.Fatalf("Stream error: %v", err) ++// } ++// if result.Done { ++// fmt.Println("\nFinal response:", result.Response.Text()) ++// } else { ++// fmt.Print(result.Chunk.Text()) ++// } ++// } ++func GenerateStream(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) iter.Seq2[*ai.ModelStreamValue, error] { ++ return ai.GenerateStream(ctx, g.reg, opts...) ++} ++ + // GenerateOperation performs a model generation request using a flexible set of options +-// provided via [ai.GenerateOption] arguments. It's a convenient way to make +-// generation calls without pre-defining a prompt object. ++// provided via [ai.GenerateOption] arguments. It's designed for long-running generation ++// tasks that may not complete immediately. + // + // Unlike [Generate], this function returns a [ai.ModelOperation] which can be used to +-// check the status of the operation and get the result. ++// check the status of the operation and get the result. Use [CheckModelOperation] to ++// poll for completion. ++// ++// GenerateOperation accepts the same options as [Generate]. See [Generate] for the full ++// list of available options. + // + // Example: + // +@@ -807,7 +999,9 @@ func CheckModelOperation(ctx context.Context, g *Genkit, op *ai.ModelOperation) + // GenerateText performs a model generation request similar to [Generate], but + // directly returns the generated text content as a string. It's a convenience + // wrapper for cases where only the textual output is needed. +-// It accepts the same [ai.GenerateOption] arguments as [Generate]. ++// ++// GenerateText accepts the same options as [Generate]. See [Generate] for the full ++// list of available options. 
+ // + // Example: + // +@@ -823,16 +1017,13 @@ func GenerateText(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (st + } + + // GenerateData performs a model generation request, expecting structured output +-// (typically JSON) that conforms to the schema of the provided `value` argument. +-// It attempts to unmarshal the model's response directly into the `value`. +-// The `value` argument must be a pointer to a struct or map. +-// +-// Use [ai.WithOutputType] or [ai.WithOutputFormat](ai.OutputFormatJSON) in the +-// options to instruct the model to generate JSON. [ai.WithOutputType] is preferred +-// as it infers the JSON schema from the `value` type and passes it to the model. ++// (typically JSON) that conforms to the schema inferred from the Out type parameter. ++// It automatically sets output type and JSON format, unmarshals the response, and ++// returns the typed result. + // +-// It returns the full [ai.ModelResponse] along with any error. The generated data +-// populates the `value` pointed to. ++// GenerateData accepts the same options as [Generate]. See [Generate] for the full ++// list of available options. Note that output options like [ai.WithOutputType] are ++// automatically applied based on the Out type parameter. + // + // Example: + // +@@ -854,15 +1045,62 @@ func GenerateData[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOp + return ai.GenerateData[Out](ctx, g.reg, opts...) + } + ++// GenerateDataStream generates a model response with streaming and returns strongly-typed output. ++// It returns an iterator that yields streaming results. ++// ++// If the yield function is passed a non-nil error, generation has failed with that ++// error; the yield function will not be called again. ++// ++// If the yield function's [ai.StreamValue] argument has Done == true, the value's ++// Output and Response fields contain the final typed output and response; the yield function ++// will not be called again. 
++// ++// Otherwise the Chunk field of the passed [ai.StreamValue] holds a streamed chunk. ++// ++// GenerateDataStream accepts the same options as [Generate]. See [Generate] for the full ++// list of available options. Note that output options are automatically applied based on ++// the Out type parameter. ++// ++// Example: ++// ++// type Story struct { ++// Title string `json:"title"` ++// Content string `json:"content"` ++// } ++// ++// for result, err := range genkit.GenerateDataStream[Story](ctx, g, ++// ai.WithPrompt("Write a short story about a brave knight."), ++// ) { ++// if err != nil { ++// log.Fatalf("Stream error: %v", err) ++// } ++// if result.Done { ++// fmt.Printf("Story: %+v\n", result.Output) ++// } else { ++// fmt.Print(result.Chunk.Text()) ++// } ++// } ++func GenerateDataStream[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOption) iter.Seq2[*ai.StreamValue[Out, Out], error] { ++ return ai.GenerateDataStream[Out](ctx, g.reg, opts...) ++} ++ + // Retrieve performs a document retrieval request using a flexible set of options + // provided via [ai.RetrieverOption] arguments. It's a convenient way to retrieve + // relevant documents from registered retrievers without directly calling the + // retriever instance. 
+ // ++// # Options ++// ++// - [ai.WithRetriever]: Specify the retriever (accepts [ai.Retriever] or [ai.RetrieverRef]) ++// - [ai.WithRetrieverName]: Specify retriever by name string ++// - [ai.WithConfig]: Set retriever-specific configuration ++// - [ai.WithTextDocs]: Provide query text as documents ++// - [ai.WithDocs]: Provide query as [ai.Document] instances ++// + // Example: + // + // resp, err := genkit.Retrieve(ctx, g, +-// ai.WithRetriever(ai.NewRetrieverRef("myRetriever", nil)), ++// ai.WithRetrieverName("myRetriever"), + // ai.WithTextDocs("What is the capital of France?"), + // ) + // if err != nil { +@@ -880,10 +1118,18 @@ func Retrieve(ctx context.Context, g *Genkit, opts ...ai.RetrieverOption) (*ai.R + // provided via [ai.EmbedderOption] arguments. It's a convenient way to generate + // embeddings from registered embedders without directly calling the embedder instance. + // ++// # Options ++// ++// - [ai.WithEmbedder]: Specify the embedder (accepts [ai.Embedder] or [ai.EmbedderRef]) ++// - [ai.WithEmbedderName]: Specify embedder by name string ++// - [ai.WithConfig]: Set embedder-specific configuration ++// - [ai.WithTextDocs]: Provide text to embed ++// - [ai.WithDocs]: Provide [ai.Document] instances to embed ++// + // Example: + // + // resp, err := genkit.Embed(ctx, g, +-// ai.WithEmbedder(ai.NewEmbedderRef("myEmbedder", nil)), ++// ai.WithEmbedderName("myEmbedder"), + // ai.WithTextDocs("Hello, world!"), + // ) + // if err != nil { +@@ -902,9 +1148,12 @@ func Embed(ctx context.Context, g *Genkit, opts ...ai.EmbedderOption) (*ai.Embed + // Retrievers are used to find documents relevant to a given query, often by + // performing similarity searches in a vector database. + // +-// The `provider` and `name` form the unique identifier. The `ret` function ++// The `name` is the unique identifier for the retriever. 
The `fn` function + // contains the logic to process an [ai.RetrieverRequest] (containing the query) + // and return an [ai.RetrieverResponse] (containing the relevant documents). ++// ++// For retrievers that don't need to be registered (e.g., for plugin development), ++// use [ai.NewRetriever] instead. + func DefineRetriever(g *Genkit, name string, opts *ai.RetrieverOptions, fn ai.RetrieverFunc) ai.Retriever { + return ai.DefineRetriever(g.reg, name, opts, fn) + } +@@ -920,9 +1169,12 @@ func LookupRetriever(g *Genkit, name string) ai.Retriever { + // [core.Action] of type Embedder, and returns an [ai.Embedder]. + // Embedders convert text documents or queries into numerical vector representations (embeddings). + // +-// The `provider` and `name` are specified in the `opts` parameter which forms the unique identifier. +-// The `embed` function contains the logic to process an [ai.EmbedRequest] (containing documents or a query) ++// The `name` is the unique identifier for the embedder. ++// The `fn` function contains the logic to process an [ai.EmbedRequest] (containing documents or a query) + // and return an [ai.EmbedResponse] (containing the corresponding embeddings). ++// ++// For embedders that don't need to be registered (e.g., for plugin development), ++// use [ai.NewEmbedder] instead. + func DefineEmbedder(g *Genkit, name string, opts *ai.EmbedderOptions, fn ai.EmbedderFunc) ai.Embedder { + return ai.DefineEmbedder(g.reg, name, opts, fn) + } +@@ -988,6 +1240,14 @@ func LookupEvaluator(g *Genkit, name string) ai.Evaluator { + // evaluations using registered evaluators without directly calling the + // evaluator instance. 
+ // ++// # Options ++// ++// - [ai.WithEvaluator]: Specify the evaluator (accepts [ai.Evaluator] or [ai.EvaluatorRef]) ++// - [ai.WithEvaluatorName]: Specify evaluator by name string ++// - [ai.WithDataset]: Provide the dataset of examples to evaluate ++// - [ai.WithID]: Set a unique identifier for this evaluation run ++// - [ai.WithConfig]: Set evaluator-specific configuration ++// + // Example: + // + // dataset := []*ai.Example{ +@@ -998,8 +1258,8 @@ func LookupEvaluator(g *Genkit, name string) ai.Evaluator { + // } + // + // resp, err := genkit.Evaluate(ctx, g, +-// ai.WithEvaluator(ai.NewEvaluatorRef("myEvaluator", nil)), +-// ai.WithDataset(dataset), ++// ai.WithEvaluatorName("myEvaluator"), ++// ai.WithDataset(dataset...), + // ) + // if err != nil { + // log.Fatalf("Evaluate failed: %v", err) +@@ -1026,8 +1286,67 @@ func Evaluate(ctx context.Context, g *Genkit, opts ...ai.EvaluatorOption) (*ai.E + // This function is often called implicitly by [Init] using the directory specified + // by [WithPromptDir], but can be called explicitly to load prompts from other + // locations or with different namespaces. +-func LoadPromptDir(g *Genkit, dir string, namespace string) { +- ai.LoadPromptDir(g.reg, dir, namespace) ++func LoadPromptDir(g *Genkit, dir, namespace string) { ++ loadPromptDirOS(g.reg, dir, namespace) ++} ++ ++// loadPromptDirOS loads prompts from an OS directory by converting to os.DirFS. 
++func loadPromptDirOS(r api.Registry, dir, namespace string) { ++ useDefaultDir := false ++ if dir == "" { ++ dir = "./prompts" ++ useDefaultDir = true ++ } ++ ++ absPath, err := filepath.Abs(dir) ++ if err != nil { ++ if !useDefaultDir { ++ panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err)) ++ } ++ slog.Debug("default prompt directory not found, skipping loading .prompt files", "dir", dir) ++ return ++ } ++ ++ if _, err := os.Stat(absPath); os.IsNotExist(err) { ++ if !useDefaultDir { ++ panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err)) ++ } ++ slog.Debug("Default prompt directory not found, skipping loading .prompt files", "dir", dir) ++ return ++ } ++ ++ ai.LoadPromptDirFromFS(r, os.DirFS(absPath), ".", namespace) ++} ++ ++// LoadPromptDirFromFS loads all `.prompt` files from the specified embedded filesystem `fsys` ++// into the registry, associating them with the given `namespace`. ++// Files starting with `_` are treated as partials and are not registered as ++// executable prompts but can be included in other prompts. ++// ++// The `fsys` parameter should be an [fs.FS] implementation (e.g., [embed.FS]). ++// The `dir` parameter specifies the directory within the filesystem where ++// prompts are located (e.g., "prompts" if using `//go:embed prompts/*`). ++// The `namespace` acts as a prefix to the prompt name (e.g., namespace "myApp" and ++// file "greeting.prompt" results in prompt name "myApp/greeting"). Use an empty ++// string for no namespace. ++// ++// This function provides an alternative to [LoadPromptDir] for loading prompts ++// from embedded filesystems, enabling self-contained binaries without external ++// prompt files. 
++// ++// Example: ++// ++// import "embed" ++// ++// //go:embed prompts/* ++// var promptsFS embed.FS ++// ++// func main() { ++// g := genkit.Init(ctx) ++// genkit.LoadPromptDirFromFS(g, promptsFS, "prompts", "myNamespace") ++// } ++func LoadPromptDirFromFS(g *Genkit, fsys fs.FS, dir, namespace string) { ++ ai.LoadPromptDirFromFS(g.reg, fsys, dir, namespace) + } + + // LoadPrompt loads a single `.prompt` file specified by `path` into the registry, +@@ -1052,13 +1371,49 @@ func LoadPromptDir(g *Genkit, dir string, namespace string) { + // // Execute the loaded prompt + // resp, err := customPrompt.Execute(ctx, ai.WithInput(map[string]any{"text": "some data"})) + // // ... handle response and error ... +-func LoadPrompt(g *Genkit, path string, namespace string) ai.Prompt { ++func LoadPrompt(g *Genkit, path, namespace string) ai.Prompt { + dir, filename := filepath.Split(path) +- if dir != "" { ++ if dir == "" { ++ dir = "." ++ } else { + dir = filepath.Clean(dir) + } + +- return ai.LoadPrompt(g.reg, dir, filename, namespace) ++ return ai.LoadPromptFromFS(g.reg, os.DirFS(dir), ".", filename, namespace) ++} ++ ++// LoadPromptFromSource loads a prompt from raw `.prompt` file content (frontmatter + template) ++// into the registry and returns the resulting [ai.Prompt]. ++// ++// The `source` parameter should contain the complete `.prompt` file text, including ++// the YAML frontmatter (delimited by `---`) and the template body. ++// The `name` parameter is the prompt name, which may include a variant suffix ++// (e.g., "greeting" or "greeting.formal"). ++// The `namespace` acts as a prefix to the prompt name. Use an empty string for no namespace. ++// ++// This is useful for loading prompts from sources other than the filesystem, ++// such as databases, environment variables, or embedded strings. ++// ++// Example: ++// ++// promptSource := `--- ++// model: googleai/gemini-2.5-flash ++// input: ++// schema: ++// name: string ++// --- ++// Hello, {{name}}! 
++// ` ++// ++// prompt, err := genkit.LoadPromptFromSource(g, promptSource, "greeting", "myApp") ++// if err != nil { ++// log.Fatalf("Failed to load prompt: %v", err) ++// } ++// ++// resp, err := prompt.Execute(ctx, ai.WithInput(map[string]any{"name": "World"})) ++// // ... ++func LoadPromptFromSource(g *Genkit, source, name, namespace string) (ai.Prompt, error) { ++ return ai.LoadPromptFromSource(g.reg, source, name, namespace) + } + + // DefinePartial wraps DefinePartial to register a partial template with the given name and source. +diff --git a/go/genkit/servers.go b/go/genkit/servers.go +index d48c11ffd..0b3bbe1ed 100644 +--- a/go/genkit/servers.go ++++ b/go/genkit/servers.go +@@ -31,23 +31,37 @@ import ( + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/logger" ++ "github.com/firebase/genkit/go/core/x/streaming" ++ "github.com/google/uuid" + ) + ++// HandlerOption configures a Handler. + type HandlerOption interface { +- apply(params *handlerParams) ++ applyHandler(*handlerOptions) error + } + +-// handlerParams are the parameters for an action HTTP handler. +-type handlerParams struct { +- ContextProviders []core.ContextProvider // Providers for action context that may be used during runtime. ++// handlerOptions are options for an action HTTP handler. ++type handlerOptions struct { ++ ContextProviders []core.ContextProvider // Providers for action context that may be used during runtime. ++ StreamManager streaming.StreamManager // Optional manager for durable stream storage. + } + +-// apply applies the options to the handler params. 
+-func (p *handlerParams) apply(params *handlerParams) { +- if params.ContextProviders != nil { +- panic("genkit.WithContextProviders: cannot set ContextProviders more than once") ++func (o *handlerOptions) applyHandler(opts *handlerOptions) error { ++ if o.ContextProviders != nil { ++ if opts.ContextProviders != nil { ++ return errors.New("cannot set ContextProviders more than once (WithContextProviders)") ++ } ++ opts.ContextProviders = o.ContextProviders ++ } ++ ++ if o.StreamManager != nil { ++ if opts.StreamManager != nil { ++ return errors.New("cannot set StreamManager more than once (WithStreamManager)") ++ } ++ opts.StreamManager = o.StreamManager + } +- params.ContextProviders = p.ContextProviders ++ ++ return nil + } + + // requestID is a unique ID for each request. +@@ -56,7 +70,16 @@ var requestID atomic.Int64 + // WithContextProviders adds providers for action context that may be used during runtime. + // They are called in the order added and may overwrite previous context. + func WithContextProviders(ctxProviders ...core.ContextProvider) HandlerOption { +- return &handlerParams{ContextProviders: ctxProviders} ++ return &handlerOptions{ContextProviders: ctxProviders} ++} ++ ++// WithStreamManager enables durable streaming with the provided StreamManager. ++// When enabled, streaming responses include an x-genkit-stream-id header that clients ++// can use to reconnect to in-progress or completed streams. ++// ++// EXPERIMENTAL: This API is subject to change. ++func WithStreamManager(manager streaming.StreamManager) HandlerOption { ++ return &handlerOptions{StreamManager: manager} + } + + // Handler returns an HTTP handler function that serves the action with the provided options. 
+@@ -67,12 +90,14 @@ func WithContextProviders(ctxProviders ...core.ContextProvider) HandlerOption { + // return api.ActionContext{"myKey": "myValue"}, nil + // })) + func Handler(a api.Action, opts ...HandlerOption) http.HandlerFunc { +- params := &handlerParams{} ++ options := &handlerOptions{} + for _, opt := range opts { +- opt.apply(params) ++ if err := opt.applyHandler(options); err != nil { ++ panic(fmt.Errorf("genkit.Handler: error applying options: %w", err)) ++ } + } + +- return wrapHandler(handler(a, params)) ++ return wrapHandler(handler(a, options)) + } + + // wrapHandler wraps an HTTP handler function with common logging and error handling. +@@ -101,8 +126,9 @@ func wrapHandler(h func(http.ResponseWriter, *http.Request) error) http.HandlerF + } + } + +-// handler returns an HTTP handler function that serves the action with the provided params. Responses are written in server-sent events (SSE) format. +-func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *http.Request) error { ++// handler returns an HTTP handler function that serves the action with the provided options. ++// Streaming responses are written in server-sent events (SSE) format. 
++func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http.Request) error { + return func(w http.ResponseWriter, r *http.Request) error { + if a == nil { + return errors.New("action is nil; cannot serve") +@@ -124,29 +150,9 @@ func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *htt + } + stream = stream || r.Header.Get("Accept") == "text/event-stream" + +- var callback streamingCallback[json.RawMessage] +- if stream { +- w.Header().Set("Content-Type", "text/event-stream") +- w.Header().Set("Cache-Control", "no-cache") +- w.Header().Set("Connection", "keep-alive") +- w.Header().Set("Transfer-Encoding", "chunked") +- callback = func(ctx context.Context, msg json.RawMessage) error { +- _, err := fmt.Fprintf(w, "data: {\"message\": %s}\n\n", msg) +- if err != nil { +- return err +- } +- if f, ok := w.(http.Flusher); ok { +- f.Flush() +- } +- return nil +- } +- } else { +- w.Header().Set("Content-Type", "application/json") +- } +- + ctx := r.Context() +- if params.ContextProviders != nil { +- for _, ctxProvider := range params.ContextProviders { ++ if opts.ContextProviders != nil { ++ for _, ctxProvider := range opts.ContextProviders { + headers := make(map[string]string, len(r.Header)) + for k, v := range r.Header { + headers[strings.ToLower(k)] = strings.Join(v, " ") +@@ -170,22 +176,252 @@ func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *htt + } + } + +- out, err := a.RunJSON(ctx, body.Data, callback) +- if err != nil { +- if stream { +- _, err = fmt.Fprintf(w, "data: {\"error\": {\"status\": \"INTERNAL\", \"message\": \"stream flow error\", \"details\": \"%v\"}}\n\n", err) +- return err ++ if stream { ++ streamID := r.Header.Get("X-Genkit-Stream-Id") ++ ++ if streamID != "" && opts.StreamManager != nil { ++ return subscribeToStream(ctx, w, opts.StreamManager, streamID) ++ } ++ ++ w.Header().Set("Content-Type", "text/event-stream") ++ w.Header().Set("Cache-Control", "no-cache") ++ 
w.Header().Set("Connection", "keep-alive") ++ w.Header().Set("Transfer-Encoding", "chunked") ++ ++ if opts.StreamManager != nil { ++ return runWithDurableStreaming(ctx, w, a, opts.StreamManager, body.Data) + } ++ ++ return runWithStreaming(ctx, w, a, body.Data) ++ } ++ ++ w.Header().Set("Content-Type", "application/json") ++ out, err := a.RunJSON(ctx, body.Data, nil) ++ if err != nil { + return err + } +- if stream { +- _, err = fmt.Fprintf(w, "data: {\"result\": %s}\n\n", out) ++ return writeResultResponse(w, out) ++ } ++} ++ ++// runWithStreaming executes the action with standard HTTP streaming (no durability). ++func runWithStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, input json.RawMessage) error { ++ callback := func(ctx context.Context, msg json.RawMessage) error { ++ if err := writeSSEMessage(w, msg); err != nil { + return err + } ++ if f, ok := w.(http.Flusher); ok { ++ f.Flush() ++ } ++ return nil ++ } ++ ++ out, err := a.RunJSON(ctx, input, callback) ++ if err != nil { ++ if werr := writeSSEError(w, err); werr != nil { ++ return werr ++ } ++ return nil ++ } ++ return writeSSEResult(w, out) ++} ++ ++// runWithDurableStreaming executes the action with durable streaming support. ++// Chunks are written to both the HTTP response and the stream manager for later replay. ++// ++// The flow execution is detached from the HTTP request context so that if the ++// original client disconnects, the flow continues running and writing to durable ++// storage. This allows other clients to subscribe to the stream and receive the ++// remaining chunks and final result. 
++func runWithDurableStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, sm streaming.StreamManager, input json.RawMessage) error { ++ streamID := uuid.New().String() ++ ++ durableStream, err := sm.Open(ctx, streamID) ++ if err != nil { ++ return err ++ } ++ defer durableStream.Close() ++ ++ w.Header().Set("X-Genkit-Stream-Id", streamID) ++ ++ // Create a detached context for flow execution. This preserves context values ++ // (action context, tracing, logger) but won't be canceled when the HTTP client ++ // disconnects, allowing the flow to continue streaming to durable storage. ++ durableCtx := context.WithoutCancel(ctx) ++ ++ // Track whether the HTTP client is still connected. ++ clientGone := ctx.Done() ++ ++ callback := func(_ context.Context, msg json.RawMessage) error { ++ // Always write to durable storage regardless of client connection state. ++ durableStream.Write(durableCtx, msg) + +- _, err = fmt.Fprintf(w, "{\"result\": %s}\n", out) ++ // Only attempt HTTP writes if the client is still connected. ++ select { ++ case <-clientGone: ++ return nil ++ default: ++ if err := writeSSEMessage(w, msg); err != nil { ++ return nil ++ } ++ if f, ok := w.(http.Flusher); ok { ++ f.Flush() ++ } ++ } ++ return nil ++ } ++ ++ out, err := a.RunJSON(durableCtx, input, callback) ++ if err != nil { ++ durableStream.Error(durableCtx, err) ++ select { ++ case <-clientGone: ++ return nil ++ default: ++ writeSSEError(w, err) ++ } ++ return nil ++ } ++ ++ durableStream.Done(durableCtx, out) ++ select { ++ case <-clientGone: ++ return nil ++ default: ++ return writeSSEResult(w, out) ++ } ++} ++ ++// subscribeToStream subscribes to an existing durable stream and writes events to the HTTP response. 
++func subscribeToStream(ctx context.Context, w http.ResponseWriter, sm streaming.StreamManager, streamID string) error { ++ events, unsubscribe, err := sm.Subscribe(ctx, streamID) ++ if err != nil { ++ var ufErr *core.UserFacingError ++ if errors.As(err, &ufErr) && ufErr.Status == core.NOT_FOUND { ++ w.WriteHeader(http.StatusNoContent) ++ return nil ++ } ++ return err ++ } ++ defer unsubscribe() ++ ++ w.Header().Set("Content-Type", "text/event-stream") ++ w.Header().Set("Cache-Control", "no-cache") ++ w.Header().Set("Connection", "keep-alive") ++ w.Header().Set("Transfer-Encoding", "chunked") ++ ++ for event := range events { ++ switch event.Type { ++ case streaming.StreamEventChunk: ++ if err := writeSSEMessage(w, event.Chunk); err != nil { ++ return err ++ } ++ if f, ok := w.(http.Flusher); ok { ++ f.Flush() ++ } ++ case streaming.StreamEventDone: ++ if err := writeSSEResult(w, event.Output); err != nil { ++ return err ++ } ++ return nil ++ case streaming.StreamEventError: ++ streamErr := event.Err ++ if streamErr == nil { ++ streamErr = errors.New("unknown error") ++ } ++ if err := writeSSEError(w, streamErr); err != nil { ++ return err ++ } ++ return nil ++ } ++ } ++ ++ return nil ++} ++ ++// flowResultResponse wraps a final action result for JSON serialization. ++type flowResultResponse struct { ++ Result json.RawMessage `json:"result"` ++} ++ ++// flowMessageResponse wraps a streaming chunk for JSON serialization. ++type flowMessageResponse struct { ++ Message json.RawMessage `json:"message"` ++} ++ ++// flowErrorResponse wraps an error for JSON serialization in streaming responses. ++type flowErrorResponse struct { ++ Error *flowError `json:"error"` ++} ++ ++// flowError represents the error payload in a streaming error response. 
++type flowError struct { ++ Status core.StatusName `json:"status"` ++ Message string `json:"message"` ++ Details string `json:"details,omitempty"` ++} ++ ++// writeResultResponse writes a JSON result response for non-streaming requests. ++func writeResultResponse(w http.ResponseWriter, result json.RawMessage) error { ++ resp := flowResultResponse{Result: result} ++ data, err := json.Marshal(resp) ++ if err != nil { ++ return err ++ } ++ _, err = w.Write(data) ++ if err != nil { ++ return err ++ } ++ _, err = w.Write([]byte("\n")) ++ return err ++} ++ ++// writeSSEResult writes a JSON result as a server-sent event for streaming requests. ++func writeSSEResult(w http.ResponseWriter, result json.RawMessage) error { ++ resp := flowResultResponse{Result: result} ++ data, err := json.Marshal(resp) ++ if err != nil { ++ return err ++ } ++ _, err = fmt.Fprintf(w, "data: %s\n\n", data) ++ return err ++} ++ ++// writeSSEMessage writes a streaming chunk as a server-sent event. ++func writeSSEMessage(w http.ResponseWriter, msg json.RawMessage) error { ++ resp := flowMessageResponse{Message: msg} ++ data, err := json.Marshal(resp) ++ if err != nil { ++ return err ++ } ++ _, err = fmt.Fprintf(w, "data: %s\n\n", data) ++ return err ++} ++ ++// writeSSEError writes an error as a server-sent event for streaming requests. 
++func writeSSEError(w http.ResponseWriter, flowErr error) error { ++ status := core.INTERNAL ++ var ufErr *core.UserFacingError ++ var gErr *core.GenkitError ++ if errors.As(flowErr, &ufErr) { ++ status = ufErr.Status ++ } else if errors.As(flowErr, &gErr) { ++ status = gErr.Status ++ } ++ ++ resp := flowErrorResponse{ ++ Error: &flowError{ ++ Status: status, ++ Message: "stream flow error", ++ Details: flowErr.Error(), ++ }, ++ } ++ data, err := json.Marshal(resp) ++ if err != nil { + return err + } ++ _, err = fmt.Fprintf(w, "data: %s\n\n", data) ++ return err + } + + func parseBoolQueryParam(r *http.Request, name string) (bool, error) { +diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go +index a0a07cc21..b5a69d17e 100644 +--- a/go/genkit/servers_test.go ++++ b/go/genkit/servers_test.go +@@ -27,6 +27,7 @@ import ( + "testing" + + "github.com/firebase/genkit/go/core" ++ "github.com/firebase/genkit/go/core/x/streaming" + ) + + func FakeContextProvider(ctx context.Context, req core.RequestData) (core.ActionContext, error) { +@@ -222,17 +223,17 @@ func TestStreamingHandler(t *testing.T) { + t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) + } + +- expected := `data: {"message": "h"} ++ expected := `data: {"message":"h"} + +-data: {"message": "e"} ++data: {"message":"e"} + +-data: {"message": "l"} ++data: {"message":"l"} + +-data: {"message": "l"} ++data: {"message":"l"} + +-data: {"message": "o"} ++data: {"message":"o"} + +-data: {"result": "hello-end"} ++data: {"result":"hello-end"} + + ` + if string(body) != expected { +@@ -256,7 +257,7 @@ data: {"result": "hello-end"} + t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) + } + +- expected := `data: {"error": {"status": "INTERNAL", "message": "stream flow error", "details": "streaming error"}} ++ expected := `data: {"error":{"status":"INTERNAL_SERVER_ERROR","message":"stream flow error","details":"streaming error"}} + + ` + if string(body) != expected { 
+@@ -264,3 +265,121 @@ data: {"result": "hello-end"} + } + }) + } ++ ++func TestDurableStreamingHandler(t *testing.T) { ++ g := Init(context.Background()) ++ ++ streamingFlow := DefineStreamingFlow(g, "durableStreaming", ++ func(ctx context.Context, input string, cb func(context.Context, string) error) (string, error) { ++ for _, c := range input { ++ if err := cb(ctx, string(c)); err != nil { ++ return "", err ++ } ++ } ++ return input + "-done", nil ++ }) ++ ++ t.Run("returns stream ID header", func(t *testing.T) { ++ sm := streaming.NewInMemoryStreamManager() ++ defer sm.Close() ++ handler := Handler(streamingFlow, WithStreamManager(sm)) ++ ++ req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"hi"}`)) ++ req.Header.Set("Content-Type", "application/json") ++ req.Header.Set("Accept", "text/event-stream") ++ w := httptest.NewRecorder() ++ ++ handler(w, req) ++ ++ resp := w.Result() ++ body, _ := io.ReadAll(resp.Body) ++ ++ if resp.StatusCode != http.StatusOK { ++ t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) ++ } ++ ++ streamID := resp.Header.Get("X-Genkit-Stream-Id") ++ if streamID == "" { ++ t.Error("want X-Genkit-Stream-Id header to be set") ++ } ++ ++ expected := `data: {"message":"h"} ++ ++data: {"message":"i"} ++ ++data: {"result":"hi-done"} ++ ++` ++ if string(body) != expected { ++ t.Errorf("want streaming body:\n%q\n\nGot:\n%q", expected, string(body)) ++ } ++ }) ++ ++ t.Run("subscribe to completed stream", func(t *testing.T) { ++ sm := streaming.NewInMemoryStreamManager() ++ defer sm.Close() ++ handler := Handler(streamingFlow, WithStreamManager(sm)) ++ ++ // First request - run the stream to completion ++ req1 := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"ab"}`)) ++ req1.Header.Set("Content-Type", "application/json") ++ req1.Header.Set("Accept", "text/event-stream") ++ w1 := httptest.NewRecorder() ++ ++ handler(w1, req1) ++ ++ resp1 := w1.Result() ++ streamID := 
resp1.Header.Get("X-Genkit-Stream-Id") ++ if streamID == "" { ++ t.Fatal("want X-Genkit-Stream-Id header to be set") ++ } ++ ++ // Second request - subscribe to the completed stream ++ req2 := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"ignored"}`)) ++ req2.Header.Set("Content-Type", "application/json") ++ req2.Header.Set("Accept", "text/event-stream") ++ req2.Header.Set("X-Genkit-Stream-Id", streamID) ++ w2 := httptest.NewRecorder() ++ ++ handler(w2, req2) ++ ++ resp2 := w2.Result() ++ body2, _ := io.ReadAll(resp2.Body) ++ ++ if resp2.StatusCode != http.StatusOK { ++ t.Errorf("want status code %d, got %d", http.StatusOK, resp2.StatusCode) ++ } ++ ++ // Should replay all chunks and the final result ++ expected := `data: {"message":"a"} ++ ++data: {"message":"b"} ++ ++data: {"result":"ab-done"} ++ ++` ++ if string(body2) != expected { ++ t.Errorf("want replayed body:\n%q\n\nGot:\n%q", expected, string(body2)) ++ } ++ }) ++ ++ t.Run("subscribe to non-existent stream returns 204", func(t *testing.T) { ++ sm := streaming.NewInMemoryStreamManager() ++ defer sm.Close() ++ handler := Handler(streamingFlow, WithStreamManager(sm)) ++ ++ req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"test"}`)) ++ req.Header.Set("Content-Type", "application/json") ++ req.Header.Set("Accept", "text/event-stream") ++ req.Header.Set("X-Genkit-Stream-Id", "non-existent-stream-id") ++ w := httptest.NewRecorder() ++ ++ handler(w, req) ++ ++ resp := w.Result() ++ ++ if resp.StatusCode != http.StatusNoContent { ++ t.Errorf("want status code %d, got %d", http.StatusNoContent, resp.StatusCode) ++ } ++ }) ++} +diff --git a/go/go.mod b/go/go.mod +index 3aa1cd948..3472c0f4c 100644 +--- a/go/go.mod ++++ b/go/go.mod +@@ -41,7 +41,7 @@ require ( + golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 + golang.org/x/tools v0.34.0 + google.golang.org/api v0.236.0 +- google.golang.org/genai v1.36.0 ++ google.golang.org/genai v1.40.0 + ) + + require ( +diff --git a/go/go.sum 
b/go/go.sum +index 43f5ac29c..e7abcc149 100644 +--- a/go/go.sum ++++ b/go/go.sum +@@ -537,8 +537,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl + google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= + google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw= + google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= +-google.golang.org/genai v1.36.0 h1:sJCIjqTAmwrtAIaemtTiKkg2TO1RxnYEusTmEQ3nGxM= +-google.golang.org/genai v1.36.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= ++google.golang.org/genai v1.40.0 h1:kYxyQSH+vsib8dvsgyLJzsVEIv5k3ZmHJyVqdvGncmc= ++google.golang.org/genai v1.40.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= + google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= + google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= + google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +diff --git a/go/internal/base/json.go b/go/internal/base/json.go +index 4f413ab8f..117855a6d 100644 +--- a/go/internal/base/json.go ++++ b/go/internal/base/json.go +@@ -138,18 +138,29 @@ func SchemaAsMap(s *jsonschema.Schema) map[string]any { + return m + } + +-// jsonMarkdownRegex specifically looks for "json" language identifier +-var jsonMarkdownRegex = regexp.MustCompile("(?s)```json(.*?)```") ++// jsonMarkdownRegex matches fenced code blocks with "json" language identifier (case-insensitive). ++var jsonMarkdownRegex = regexp.MustCompile("(?si)```json\\s*(.*?)```") ++ ++// plainMarkdownRegex matches fenced code blocks without any language identifier. 
++var plainMarkdownRegex = regexp.MustCompile("(?s)```\\s*\\n(.*?)```") + + // ExtractJSONFromMarkdown returns the contents of the first fenced code block in +-// the markdown text md. If there is none, it returns md. ++// the markdown text md. It matches code blocks with "json" identifier (case-insensitive) ++// or code blocks without any language identifier. If there is no matching block, it returns md. + func ExtractJSONFromMarkdown(md string) string { ++ // First try to match explicit json code blocks + matches := jsonMarkdownRegex.FindStringSubmatch(md) +- if len(matches) < 2 { +- return md ++ if len(matches) >= 2 { ++ return strings.TrimSpace(matches[1]) ++ } ++ ++ // Fall back to plain code blocks (no language identifier) ++ matches = plainMarkdownRegex.FindStringSubmatch(md) ++ if len(matches) >= 2 { ++ return strings.TrimSpace(matches[1]) + } +- // capture group 1 matches the actual fenced JSON block +- return strings.TrimSpace(matches[1]) ++ ++ return md + } + + // GetJSONObjectLines splits a string by newlines, trims whitespace from each line, +diff --git a/go/internal/base/json_test.go b/go/internal/base/json_test.go +index b018849af..eda537c9b 100644 +--- a/go/internal/base/json_test.go ++++ b/go/internal/base/json_test.go +@@ -78,6 +78,31 @@ func TestExtractJSONFromMarkdown(t *testing.T) { + in: "```json\n{\"a\": 1}\n``` ```yaml\nkey: 1\nanother-key: 2```", + want: "{\"a\": 1}", + }, ++ { ++ desc: "uppercase JSON identifier", ++ in: "```JSON\n{\"a\": 1}\n```", ++ want: "{\"a\": 1}", ++ }, ++ { ++ desc: "mixed case Json identifier", ++ in: "```Json\n{\"a\": 1}\n```", ++ want: "{\"a\": 1}", ++ }, ++ { ++ desc: "plain code block without identifier", ++ in: "```\n{\"a\": 1}\n```", ++ want: "{\"a\": 1}", ++ }, ++ { ++ desc: "plain code block with text before", ++ in: "Here is the result:\n\n```\n{\"title\": \"Pizza\"}\n```", ++ want: "{\"title\": \"Pizza\"}", ++ }, ++ { ++ desc: "json block preferred over plain block", ++ in: "```\n{\"plain\": true}\n``` 
then ```json\n{\"json\": true}\n```", ++ want: "{\"json\": true}", ++ }, + } + for _, tc := range tests { + t.Run(tc.desc, func(t *testing.T) { +diff --git a/go/internal/base/misc.go b/go/internal/base/misc.go +index 9e3afa1d9..f4fdb7af3 100644 +--- a/go/internal/base/misc.go ++++ b/go/internal/base/misc.go +@@ -18,6 +18,7 @@ package base + + import ( + "net/url" ++ "reflect" + ) + + // An Environment is the execution context in which the program is running. +@@ -38,3 +39,16 @@ func Zero[T any]() T { + func Clean(id string) string { + return url.PathEscape(id) + } ++ ++// IsNil returns true if v is nil or a nil pointer/interface/map/slice/channel/func. ++func IsNil[T any](v T) bool { ++ rv := reflect.ValueOf(v) ++ switch rv.Kind() { ++ case reflect.Invalid: ++ return true ++ case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func: ++ return rv.IsNil() ++ default: ++ return false ++ } ++} +diff --git a/go/plugins/anthropic/anthropic.go b/go/plugins/anthropic/anthropic.go +index 493a6c76b..e93f1abde 100644 +--- a/go/plugins/anthropic/anthropic.go ++++ b/go/plugins/anthropic/anthropic.go +@@ -169,8 +169,9 @@ func newModel(client anthropic.Client, name string, opts ai.ModelOptions) ai.Mod + // configToMap converts a config struct to a map[string]any. + func configToMap(config any) map[string]any { + r := jsonschema.Reflector{ +- DoNotReference: false, // Prevent $ref usage ++ DoNotReference: true, // Prevent $ref usage + AllowAdditionalProperties: false, ++ ExpandedStruct: true, + RequiredFromJSONSchemaTags: true, + } + // The anthropic SDK uses a number of wrapper types for float, int, etc. 
+@@ -201,5 +202,6 @@ func configToMap(config any) map[string]any { + } + schema := r.Reflect(config) + result := base.SchemaAsMap(schema) ++ + return result + } +diff --git a/go/plugins/firebase/auth.go b/go/plugins/firebase/auth.go +index bb1856f97..928097323 100644 +--- a/go/plugins/firebase/auth.go ++++ b/go/plugins/firebase/auth.go +@@ -40,11 +40,11 @@ type AuthClient interface { + + // ContextProvider creates a Firebase context provider for Genkit actions. + func ContextProvider(ctx context.Context, g *genkit.Genkit, policy AuthPolicy) (core.ContextProvider, error) { +- f, ok := genkit.LookupPlugin(g, provider).(*Firebase) +- if !ok { +- return nil, core.NewError(core.NOT_FOUND, "firebase plugin not initialized; did you pass the plugin to genkit.Init()") ++ f, err := resolvePlugin(g) ++ if err != nil { ++ return nil, err + } +- client, err := f.App.Auth(ctx) ++ client, err := f.Auth(ctx) + if err != nil { + return nil, err + } +diff --git a/go/plugins/firebase/firebase.go b/go/plugins/firebase/firebase.go +index 1d221b4bb..50fa5cc6c 100644 +--- a/go/plugins/firebase/firebase.go ++++ b/go/plugins/firebase/firebase.go +@@ -20,27 +20,48 @@ import ( + "context" + "errors" + "fmt" +- "log" + "os" + "sync" + ++ "cloud.google.com/go/firestore" + firebasev4 "firebase.google.com/go/v4" ++ "firebase.google.com/go/v4/auth" + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/genkit" + ) + +-// Firebase plugin for Genkit, providing integration with Firebase services. +-// This plugin allows users to define retrievers and indexers for Firebase Firestore. +-const provider = "firebase" // Identifier for the Firebase plugin. +-const projectIdEnv = "FIREBASE_PROJECT_ID" // Environment variable for the Firebase project ID. ++const provider = "firebase" ++const projectIdEnv = "FIREBASE_PROJECT_ID" + +-// Firebase FireStore passes configuration options to the plugin. 
++const pluginInstruction = "Pass the Firebase plugin to genkit.Init():\n" + ++ " g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: \"your-project\"}))" ++ ++var errPluginNotInitialized = errors.New("firebase: plugin not initialized. " + pluginInstruction) ++var errPluginNotFound = errors.New("firebase: plugin not found. " + pluginInstruction) ++var errCredentials = "Ensure you have proper credentials. For local development, run: gcloud auth application-default login" ++ ++// Firebase is the Genkit plugin for Firebase services. ++// It provides integration with Firebase Firestore for retrievers, indexers, and durable streaming. ++// ++// Usage: ++// ++// g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: "my-project"})) ++// ++// Or with an existing Firebase app: ++// ++// g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{App: myFirebaseApp})) + type Firebase struct { +- ProjectId string // Firebase project ID. +- App *firebasev4.App // Firebase app instance. +- mu sync.Mutex // Mutex to control concurrent access. +- initted bool // Tracks whether the plugin has been initialized. ++ // ProjectId is the Firebase/GCP project ID. If set, a Firebase app is created automatically. ++ // Can also be set via the FIREBASE_PROJECT_ID environment variable. ++ ProjectId string ++ // App is an existing Firebase app instance. Provide either ProjectId or App, not both. ++ App *firebasev4.App ++ ++ mu sync.Mutex ++ initted bool ++ firestoreClient *firestore.Client ++ authClient *auth.Client + } + + // Name returns the name of the plugin. +@@ -48,29 +69,32 @@ func (f *Firebase) Name() string { + return provider + } + +-// Init initializes the Firebase plugin. ++// Init initializes the Firebase plugin. Called automatically by genkit.Init(). + func (f *Firebase) Init(ctx context.Context) []api.Action { + f.mu.Lock() + defer f.mu.Unlock() + +- // Resolve the Firebase project ID. 
+- projectId := resolveProjectId(f.ProjectId) +- + if f.initted { + panic("firebase.Init: plugin already initialized") + } + +- if f.App == nil && f.ProjectId == "" { +- panic("firebase.Init: provide ProjectId or App") ++ projectId := resolveProjectId(f.ProjectId) ++ ++ if f.App == nil && projectId == "" { ++ panic("firebase.Init: Firebase plugin requires either ProjectId or App to be set.\n" + ++ " Option 1: Set ProjectId directly: &firebase.Firebase{ProjectId: \"your-project-id\"}\n" + ++ " Option 2: Set FIREBASE_PROJECT_ID environment variable\n" + ++ " Option 3: Provide an existing Firebase App: &firebase.Firebase{App: yourApp}") + } +- if f.ProjectId != "" { +- if f.App != nil { +- panic("firebase.Init: provide either ProjectId or App, not both") +- } +- // Configure and initialize the Firebase app. ++ ++ if f.App != nil && f.ProjectId != "" { ++ panic("firebase.Init: provide either ProjectId or App, not both") ++ } ++ ++ if f.App == nil { + firebaseApp, err := firebasev4.NewApp(ctx, &firebasev4.Config{ProjectID: projectId}) + if err != nil { +- panic(fmt.Errorf("error initializing Firebase App: %v", err)) ++ panic(fmt.Errorf("firebase.Init: failed to initialize Firebase App: %v", err)) + } + f.App = firebaseApp + } +@@ -79,37 +103,90 @@ func (f *Firebase) Init(ctx context.Context) []api.Action { + return []api.Action{} + } + +-// DefineRetriever defines a Retriever with the given configuration. +-func DefineRetriever(ctx context.Context, g *genkit.Genkit, cfg RetrieverOptions) (ai.Retriever, error) { +- // Lookup the Firebase plugin from the registry. +- f, ok := genkit.LookupPlugin(g, provider).(*Firebase) +- if !ok { +- return nil, errors.New("firebase plugin not found; did you call firebase.Init with the firebase plugin") ++// Firestore returns a cached Firestore client for the Firebase project. ++// The client is created lazily on first call and reused for subsequent calls. 
++// This client is shared across all Firebase plugin features (retrievers, stream managers, etc.). ++func (f *Firebase) Firestore(ctx context.Context) (*firestore.Client, error) { ++ f.mu.Lock() ++ defer f.mu.Unlock() ++ ++ if !f.initted { ++ return nil, errPluginNotInitialized ++ } ++ ++ if f.firestoreClient != nil { ++ return f.firestoreClient, nil + } + +- // Initialize Firestore client. +- firestoreClient, err := f.App.Firestore(ctx) ++ client, err := f.App.Firestore(ctx) + if err != nil { +- log.Fatalf("Error creating Firestore client: %v", err) // Log and exit on failure. ++ return nil, fmt.Errorf("firebase: failed to create Firestore client: %w. %s", err, errCredentials) + } + +- // Define a Firestore retriever using the client. +- retriever, err := defineFirestoreRetriever(g, cfg, firestoreClient) ++ f.firestoreClient = client ++ return client, nil ++} ++ ++// Auth returns a cached Firebase Auth client for the Firebase project. ++// The client is created lazily on first call and reused for subsequent calls. ++func (f *Firebase) Auth(ctx context.Context) (*auth.Client, error) { ++ f.mu.Lock() ++ defer f.mu.Unlock() ++ ++ if !f.initted { ++ return nil, errPluginNotInitialized ++ } ++ ++ if f.authClient != nil { ++ return f.authClient, nil ++ } ++ ++ client, err := f.App.Auth(ctx) + if err != nil { ++ return nil, fmt.Errorf("firebase: failed to create Auth client: %w. %s", err, errCredentials) ++ } + +- return nil, fmt.Errorf("DefineRetriever: failed to initialize retriever %s: %v", cfg.Name, err) ++ f.authClient = client ++ return client, nil ++} ++ ++// DefineRetriever defines a Firestore vector retriever with the given configuration. ++// The Firebase plugin must be registered with genkit.Init() before calling this function. 
++func DefineRetriever(ctx context.Context, g *genkit.Genkit, opts RetrieverOptions) (ai.Retriever, error) { ++ f, err := resolvePlugin(g) ++ if err != nil { ++ return nil, err ++ } ++ ++ firestoreClient, err := f.Firestore(ctx) ++ if err != nil { ++ return nil, err ++ } ++ ++ retriever, err := defineFirestoreRetriever(g, opts, firestoreClient) ++ if err != nil { ++ return nil, fmt.Errorf("firebase.DefineRetriever: failed to initialize retriever %q: %w", opts.Name, err) + } + return retriever, nil + } + +-// resolveProjectId reads the projectId from the environment if necessary. ++// resolveProjectId resolves the Firebase project ID from various sources. + func resolveProjectId(projectId string) string { +- // Return the provided project ID if it's not empty. + if projectId != "" { + return projectId + } ++ return os.Getenv(projectIdEnv) ++} + +- // Otherwise, read the project ID from the environment variable. +- projectId = os.Getenv(projectIdEnv) +- return projectId ++// resolvePlugin resolves the Firebase plugin from the Genkit registry. 
++func resolvePlugin(g *genkit.Genkit) (*Firebase, error) { ++ plugin := genkit.LookupPlugin(g, provider) ++ if plugin == nil { ++ return nil, errPluginNotFound ++ } ++ f, ok := plugin.(*Firebase) ++ if !ok { ++ return nil, fmt.Errorf("firebase: unexpected plugin type %T for provider %q", plugin, provider) ++ } ++ return f, nil + } +diff --git a/go/plugins/firebase/retriever.go b/go/plugins/firebase/retriever.go +index 628702639..5281c12bb 100644 +--- a/go/plugins/firebase/retriever.go ++++ b/go/plugins/firebase/retriever.go +@@ -17,6 +17,7 @@ package firebase + import ( + "context" + "fmt" ++ "log/slog" + "os" + + "cloud.google.com/go/firestore" +@@ -26,10 +27,9 @@ import ( + "github.com/firebase/genkit/go/genkit" + ) + +-type VectorType int ++const firestoreCollectionEnv = "FIRESTORE_COLLECTION" + +-// Firestore collection environment variable key name +-const firestoreCollection = "FIRESTORE_COLLECTION" ++type VectorType int + + // TODO: in retriever options add field that controls the 32/64 + +@@ -141,14 +141,17 @@ func defineFirestoreRetriever(g *genkit.Genkit, cfg RetrieverOptions, client *fi + return genkit.DefineRetriever(g, api.NewName(provider, cfg.Name), retOpts, retrieve), nil + } + +-// resolveFirestoreCollection resolves the Firestore collection name from the environment if necessary + func resolveFirestoreCollection(collectionName string) (string, error) { + if collectionName != "" { + return collectionName, nil + } +- collectionName = os.Getenv(firestoreCollection) ++ collectionName = os.Getenv(firestoreCollectionEnv) + if collectionName == "" { +- return "", fmt.Errorf("no Firestore collection provided; set %q env variable or pass the collection directly", firestoreCollection) ++ return "", fmt.Errorf("firebase: no Firestore collection provided. 
" + ++ "Pass the collection in RetrieverOptions: RetrieverOptions{Collection: \"my-collection\"}") + } ++ slog.Warn("Using FIRESTORE_COLLECTION environment variable is deprecated for retriever configuration. "+ ++ "Use RetrieverOptions{Collection: \"my-collection\"} instead.", ++ "collection", collectionName) + return collectionName, nil + } +diff --git a/go/plugins/firebase/x/stream_manager.go b/go/plugins/firebase/x/stream_manager.go +new file mode 100644 +index 000000000..edc015463 +--- /dev/null ++++ b/go/plugins/firebase/x/stream_manager.go +@@ -0,0 +1,497 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++// Package x contains experimental Firebase features. ++// ++// APIs in this package are under active development and may change in any ++// minor version release. Use with caution in production environments. 
++package x ++ ++import ( ++ "context" ++ "encoding/json" ++ "errors" ++ "fmt" ++ "sync" ++ "time" ++ ++ "cloud.google.com/go/firestore" ++ "github.com/firebase/genkit/go/core" ++ "github.com/firebase/genkit/go/core/x/streaming" ++ "github.com/firebase/genkit/go/genkit" ++ "github.com/firebase/genkit/go/plugins/firebase" ++ "github.com/google/uuid" ++ "google.golang.org/grpc/codes" ++ "google.golang.org/grpc/status" ++) ++ ++const ( ++ streamBufferSize = 100 ++ defaultTimeout = 60 * time.Second ++ defaultTTL = 5 * time.Minute ++ streamEventChunk = "chunk" ++ streamEventDone = "done" ++ streamEventError = "error" ++) ++ ++// FirestoreStreamManagerOption configures a FirestoreStreamManager. ++type FirestoreStreamManagerOption interface { ++ applyFirestoreStreamManager(*firestoreStreamManagerOptions) error ++} ++ ++// firestoreStreamManagerOptions holds configuration for FirestoreStreamManager. ++type firestoreStreamManagerOptions struct { ++ Collection string ++ Timeout time.Duration ++ TTL time.Duration ++} ++ ++func (o *firestoreStreamManagerOptions) applyFirestoreStreamManager(opts *firestoreStreamManagerOptions) error { ++ if o.Collection != "" { ++ if opts.Collection != "" { ++ return errors.New("cannot set collection more than once (WithCollection)") ++ } ++ opts.Collection = o.Collection ++ } ++ ++ if o.Timeout > 0 { ++ if opts.Timeout > 0 { ++ return errors.New("cannot set timeout more than once (WithTimeout)") ++ } ++ opts.Timeout = o.Timeout ++ } ++ ++ if o.TTL > 0 { ++ if opts.TTL > 0 { ++ return errors.New("cannot set TTL more than once (WithFirestoreTTL)") ++ } ++ opts.TTL = o.TTL ++ } ++ ++ return nil ++} ++ ++// WithCollection sets the Firestore collection name where stream documents are stored. ++// This option is required. ++func WithCollection(collection string) FirestoreStreamManagerOption { ++ return &firestoreStreamManagerOptions{Collection: collection} ++} ++ ++// WithTimeout sets how long a subscriber waits for new events before giving up. 
++// If no activity occurs within this duration, subscribers receive a DEADLINE_EXCEEDED error. ++// Default is 60 seconds. ++func WithTimeout(timeout time.Duration) FirestoreStreamManagerOption { ++ return &firestoreStreamManagerOptions{Timeout: timeout} ++} ++ ++// WithTTL sets how long completed streams are retained before Firestore auto-deletes them. ++// Requires a TTL policy on the collection for the "expiresAt" field. Default is 5 minutes. ++// See: https://firebase.google.com/docs/firestore/ttl ++func WithTTL(ttl time.Duration) FirestoreStreamManagerOption { ++ return &firestoreStreamManagerOptions{TTL: ttl} ++} ++ ++// FirestoreStreamManager implements [streaming.StreamManager] using Firestore as the backend. ++// Stream state is persisted in Firestore documents, allowing streams to survive server ++// restarts and be accessible across multiple instances. ++type FirestoreStreamManager struct { ++ client *firestore.Client ++ collection string ++ timeout time.Duration ++ ttl time.Duration ++} ++ ++// streamDocument represents the structure of a stream document in Firestore. ++type streamDocument struct { ++ Stream []streamEntry `firestore:"stream"` ++ CreatedAt time.Time `firestore:"createdAt"` ++ UpdatedAt time.Time `firestore:"updatedAt"` ++ ExpiresAt *time.Time `firestore:"expiresAt,omitempty"` ++} ++ ++// streamEntry represents a single entry in the stream array. ++type streamEntry struct { ++ Type string `firestore:"type"` ++ Chunk json.RawMessage `firestore:"chunk,omitempty"` ++ Output json.RawMessage `firestore:"output,omitempty"` ++ Err *streamError `firestore:"err,omitempty"` ++ UUID string `firestore:"uuid,omitempty"` ++} ++ ++// streamError represents a serializable error for Firestore storage. ++type streamError struct { ++ Status string `firestore:"status"` ++ Message string `firestore:"message"` ++} ++ ++// NewFirestoreStreamManager creates a FirestoreStreamManager for durable streaming. 
++func NewFirestoreStreamManager(ctx context.Context, g *genkit.Genkit, opts ...FirestoreStreamManagerOption) (*FirestoreStreamManager, error) { ++ streamOpts := &firestoreStreamManagerOptions{} ++ for _, opt := range opts { ++ if err := opt.applyFirestoreStreamManager(streamOpts); err != nil { ++ return nil, fmt.Errorf("firebase.NewFirestoreStreamManager: error applying options: %w", err) ++ } ++ } ++ if streamOpts.Collection == "" { ++ return nil, errors.New("firebase.NewFirestoreStreamManager: Collection name is required.\n" + ++ " Specify the Firestore collection where stream documents will be stored:\n" + ++ " firebase.NewFirestoreStreamManager(ctx, g, firebase.WithCollection(\"genkit-streams\"))") ++ } ++ if streamOpts.Timeout == 0 { ++ streamOpts.Timeout = defaultTimeout ++ } ++ if streamOpts.TTL == 0 { ++ streamOpts.TTL = defaultTTL ++ } ++ ++ plugin := genkit.LookupPlugin(g, "firebase") ++ if plugin == nil { ++ return nil, errors.New("firebase.NewFirestoreStreamManager: Firebase plugin not found.\n" + ++ " Pass the Firebase plugin to genkit.Init():\n" + ++ " g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: \"your-project\"}))") ++ } ++ f, ok := plugin.(*firebase.Firebase) ++ if !ok { ++ return nil, fmt.Errorf("firebase.NewFirestoreStreamManager: unexpected plugin type %T", plugin) ++ } ++ ++ client, err := f.Firestore(ctx) ++ if err != nil { ++ return nil, fmt.Errorf("firebase.NewFirestoreStreamManager: %w", err) ++ } ++ ++ return &FirestoreStreamManager{ ++ client: client, ++ collection: streamOpts.Collection, ++ timeout: streamOpts.Timeout, ++ ttl: streamOpts.TTL, ++ }, nil ++} ++ ++// Open creates a new stream for writing. ++// Returns ALREADY_EXISTS error if a stream with the given ID already exists. 
++func (m *FirestoreStreamManager) Open(ctx context.Context, streamID string) (streaming.StreamInput, error) { ++ docRef := m.client.Collection(m.collection).Doc(streamID) ++ now := time.Now() ++ expiresAt := now.Add(m.timeout + m.ttl) ++ _, err := docRef.Create(ctx, streamDocument{ ++ Stream: []streamEntry{}, ++ CreatedAt: now, ++ UpdatedAt: now, ++ ExpiresAt: &expiresAt, ++ }) ++ if err != nil { ++ if status.Code(err) == codes.AlreadyExists { ++ return nil, core.NewPublicError(core.ALREADY_EXISTS, "stream already exists", nil) ++ } ++ return nil, err ++ } ++ return &firestoreStreamInput{ ++ manager: m, ++ streamID: streamID, ++ docRef: docRef, ++ }, nil ++} ++ ++// Subscribe subscribes to an existing stream. ++func (m *FirestoreStreamManager) Subscribe(ctx context.Context, streamID string) (<-chan streaming.StreamEvent, func(), error) { ++ docRef := m.client.Collection(m.collection).Doc(streamID) ++ ++ snapshot, err := docRef.Get(ctx) ++ if err != nil { ++ if isNotFound(err) { ++ return nil, nil, core.NewPublicError(core.NOT_FOUND, "stream not found", nil) ++ } ++ return nil, nil, err ++ } ++ if !snapshot.Exists() { ++ return nil, nil, core.NewPublicError(core.NOT_FOUND, "stream not found", nil) ++ } ++ ++ ch := make(chan streaming.StreamEvent, streamBufferSize) ++ var mu sync.Mutex ++ var lastIndex int = -1 ++ var unsubscribed bool ++ var cancelSnapshot context.CancelFunc ++ ++ snapshotCtx, cancelSnapshot := context.WithCancel(ctx) ++ ++ var timeoutTimer *time.Timer ++ resetTimeout := func() { ++ mu.Lock() ++ defer mu.Unlock() ++ if timeoutTimer != nil { ++ timeoutTimer.Stop() ++ } ++ timeoutTimer = time.AfterFunc(m.timeout, func() { ++ mu.Lock() ++ defer mu.Unlock() ++ if !unsubscribed { ++ unsubscribed = true ++ ch <- streaming.StreamEvent{ ++ Type: streaming.StreamEventError, ++ Err: core.NewPublicError(core.DEADLINE_EXCEEDED, "stream timed out", nil), ++ } ++ close(ch) ++ cancelSnapshot() ++ } ++ }) ++ } ++ ++ unsubscribe := func() { ++ mu.Lock() ++ defer 
mu.Unlock() ++ if !unsubscribed { ++ unsubscribed = true ++ if timeoutTimer != nil { ++ timeoutTimer.Stop() ++ } ++ close(ch) ++ cancelSnapshot() ++ } ++ } ++ ++ resetTimeout() ++ ++ go func() { ++ snapshots := docRef.Snapshots(snapshotCtx) ++ defer snapshots.Stop() ++ ++ for { ++ snap, err := snapshots.Next() ++ if err != nil { ++ mu.Lock() ++ if !unsubscribed { ++ if snapshotCtx.Err() == nil { ++ ch <- streaming.StreamEvent{ ++ Type: streaming.StreamEventError, ++ Err: err, ++ } ++ } ++ unsubscribed = true ++ if timeoutTimer != nil { ++ timeoutTimer.Stop() ++ } ++ close(ch) ++ } ++ mu.Unlock() ++ return ++ } ++ ++ resetTimeout() ++ ++ if !snap.Exists() { ++ continue ++ } ++ ++ var doc streamDocument ++ if err := snap.DataTo(&doc); err != nil { ++ mu.Lock() ++ if !unsubscribed { ++ ch <- streaming.StreamEvent{ ++ Type: streaming.StreamEventError, ++ Err: err, ++ } ++ unsubscribed = true ++ if timeoutTimer != nil { ++ timeoutTimer.Stop() ++ } ++ close(ch) ++ } ++ mu.Unlock() ++ return ++ } ++ ++ mu.Lock() ++ for i := lastIndex + 1; i < len(doc.Stream); i++ { ++ entry := doc.Stream[i] ++ switch entry.Type { ++ case streamEventChunk: ++ if !unsubscribed { ++ select { ++ case ch <- streaming.StreamEvent{Type: streaming.StreamEventChunk, Chunk: entry.Chunk}: ++ default: ++ } ++ } ++ case streamEventDone: ++ if !unsubscribed { ++ select { ++ case ch <- streaming.StreamEvent{Type: streaming.StreamEventDone, Output: entry.Output}: ++ default: ++ } ++ unsubscribed = true ++ if timeoutTimer != nil { ++ timeoutTimer.Stop() ++ } ++ close(ch) ++ } ++ mu.Unlock() ++ return ++ case streamEventError: ++ if !unsubscribed { ++ var errStatus core.StatusName = core.UNKNOWN ++ var errMsg string ++ if entry.Err != nil { ++ errMsg = entry.Err.Message ++ if entry.Err.Status != "" { ++ errStatus = core.StatusName(entry.Err.Status) ++ } ++ } ++ select { ++ case ch <- streaming.StreamEvent{ ++ Type: streaming.StreamEventError, ++ Err: core.NewPublicError(errStatus, errMsg, nil), ++ }: ++ 
default: ++ } ++ unsubscribed = true ++ if timeoutTimer != nil { ++ timeoutTimer.Stop() ++ } ++ close(ch) ++ } ++ mu.Unlock() ++ return ++ } ++ } ++ lastIndex = len(doc.Stream) - 1 ++ mu.Unlock() ++ } ++ }() ++ ++ return ch, unsubscribe, nil ++} ++ ++// isNotFound checks if the error is a not found error. ++func isNotFound(err error) bool { ++ if err == nil { ++ return false ++ } ++ if grpcErr, ok := status.FromError(err); ok { ++ return grpcErr.Code() == codes.NotFound ++ } ++ return false ++} ++ ++// firestoreStreamInput implements streaming.StreamInput for Firestore. ++type firestoreStreamInput struct { ++ manager *FirestoreStreamManager ++ streamID string ++ docRef *firestore.DocumentRef ++ closed bool ++ mu sync.Mutex ++} ++ ++func (s *firestoreStreamInput) Write(ctx context.Context, chunk json.RawMessage) error { ++ s.mu.Lock() ++ defer s.mu.Unlock() ++ ++ if s.closed { ++ return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) ++ } ++ ++ _, err := s.docRef.Update(ctx, []firestore.Update{ ++ { ++ Path: "stream", ++ Value: firestore.ArrayUnion(streamEntry{ ++ Type: streamEventChunk, ++ Chunk: chunk, ++ UUID: uuid.New().String(), ++ }), ++ }, ++ { ++ Path: "updatedAt", ++ Value: firestore.ServerTimestamp, ++ }, ++ }) ++ return err ++} ++ ++func (s *firestoreStreamInput) Done(ctx context.Context, output json.RawMessage) error { ++ s.mu.Lock() ++ defer s.mu.Unlock() ++ ++ if s.closed { ++ return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) ++ } ++ s.closed = true ++ ++ expiresAt := time.Now().Add(s.manager.ttl) ++ _, err := s.docRef.Update(ctx, []firestore.Update{ ++ { ++ Path: "stream", ++ Value: firestore.ArrayUnion(streamEntry{ ++ Type: streamEventDone, ++ Output: output, ++ }), ++ }, ++ { ++ Path: "updatedAt", ++ Value: firestore.ServerTimestamp, ++ }, ++ { ++ Path: "expiresAt", ++ Value: expiresAt, ++ }, ++ }) ++ return err ++} ++ ++func (s *firestoreStreamInput) Error(ctx context.Context, err 
error) error { ++ s.mu.Lock() ++ defer s.mu.Unlock() ++ ++ if s.closed { ++ return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) ++ } ++ s.closed = true ++ ++ streamErr := &streamError{ ++ Status: string(core.UNKNOWN), ++ Message: err.Error(), ++ } ++ var ufErr *core.UserFacingError ++ if errors.As(err, &ufErr) { ++ streamErr.Status = string(ufErr.Status) ++ } ++ ++ expiresAt := time.Now().Add(s.manager.ttl) ++ _, updateErr := s.docRef.Update(ctx, []firestore.Update{ ++ { ++ Path: "stream", ++ Value: firestore.ArrayUnion(streamEntry{ ++ Type: streamEventError, ++ Err: streamErr, ++ }), ++ }, ++ { ++ Path: "updatedAt", ++ Value: firestore.ServerTimestamp, ++ }, ++ { ++ Path: "expiresAt", ++ Value: expiresAt, ++ }, ++ }) ++ return updateErr ++} ++ ++func (s *firestoreStreamInput) Close() error { ++ s.mu.Lock() ++ defer s.mu.Unlock() ++ s.closed = true ++ return nil ++} +diff --git a/go/plugins/firebase/x/stream_manager_test.go b/go/plugins/firebase/x/stream_manager_test.go +new file mode 100644 +index 000000000..64cb5faac +--- /dev/null ++++ b/go/plugins/firebase/x/stream_manager_test.go +@@ -0,0 +1,564 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++ ++package x ++ ++import ( ++ "context" ++ "encoding/json" ++ "errors" ++ "flag" ++ "testing" ++ "time" ++ ++ "cloud.google.com/go/firestore" ++ "github.com/firebase/genkit/go/core" ++ "github.com/firebase/genkit/go/core/x/streaming" ++ "github.com/firebase/genkit/go/genkit" ++ "github.com/firebase/genkit/go/plugins/firebase" ++ "google.golang.org/api/iterator" ++) ++ ++var ( ++ testStreamProjectID = flag.String("test-stream-project-id", "", "GCP Project ID to use for stream manager tests") ++ testStreamCollection = flag.String("test-stream-collection", "genkit-streams", "Firestore collection to use for stream manager tests") ++) ++ ++/* ++ * Pre-requisites to run this test: ++ * ++ * 1. **Option A - Use Firestore Emulator (Recommended for local development):** ++ * Start the Firestore emulator: ++ * ```bash ++ * export FIRESTORE_EMULATOR_HOST=127.0.0.1:8080 ++ * gcloud emulators firestore start --host-port=127.0.0.1:8080 ++ * ``` ++ * ++ * 2. **Option B - Use a Real Firestore Database:** ++ * - Set up a Firebase project with Firestore enabled ++ * - Authenticate using: ++ * ```bash ++ * gcloud auth application-default login ++ * ``` ++ * ++ * 3. 
**Running the Test:** ++ * ```bash ++ * go test -test-stream-project-id= -test-stream-collection=genkit-streams ++ * ``` ++ */ ++ ++func skipIfNoFirestore(t *testing.T) { ++ if *testStreamProjectID == "" { ++ t.Skip("Skipping test: -test-stream-project-id flag not provided") ++ } ++} ++ ++func setupTestStreamManager(t *testing.T) (*FirestoreStreamManager, *firestore.Client, func()) { ++ skipIfNoFirestore(t) ++ ++ ctx := context.Background() ++ g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testStreamProjectID})) ++ ++ f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) ++ client, err := f.Firestore(ctx) ++ if err != nil { ++ t.Fatalf("Failed to get Firestore client: %v", err) ++ } ++ ++ manager, err := NewFirestoreStreamManager(ctx, g, ++ WithCollection(*testStreamCollection), ++ ) ++ if err != nil { ++ t.Fatalf("Failed to create stream manager: %v", err) ++ } ++ ++ cleanup := func() { ++ deleteStreamCollection(ctx, client, *testStreamCollection, t) ++ } ++ ++ return manager, client, cleanup ++} ++ ++func deleteStreamCollection(ctx context.Context, client *firestore.Client, collectionName string, t *testing.T) { ++ iter := client.Collection(collectionName).Documents(ctx) ++ for { ++ doc, err := iter.Next() ++ if err == iterator.Done { ++ break ++ } ++ if err != nil { ++ t.Logf("Failed to iterate documents for deletion: %v", err) ++ return ++ } ++ _, err = doc.Ref.Delete(ctx) ++ if err != nil { ++ t.Logf("Failed to delete document %s: %v", doc.Ref.ID, err) ++ } ++ } ++} ++ ++func TestFirestoreStreamManager_OpenDuplicateFails(t *testing.T) { ++ manager, _, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ streamID := "test-stream-dup" ++ ++ _, err := manager.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("First Open failed: %v", err) ++ } ++ ++ _, err = manager.Open(ctx, streamID) ++ if err == nil { ++ t.Fatal("Expected error when opening duplicate stream") ++ } ++ ++ publicErr, ok := 
err.(*core.UserFacingError) ++ if !ok { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if publicErr.Status != core.ALREADY_EXISTS { ++ t.Errorf("Expected ALREADY_EXISTS error, got %v", publicErr.Status) ++ } ++} ++ ++func TestFirestoreStreamManager_OpenAndWrite(t *testing.T) { ++ manager, client, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ streamID := "test-stream-open-write" ++ ++ stream, err := manager.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to open stream: %v", err) ++ } ++ ++ chunk1, _ := json.Marshal(map[string]string{"foo": "bar"}) ++ chunk2, _ := json.Marshal(map[string]string{"bar": "baz"}) ++ ++ if err := stream.Write(ctx, chunk1); err != nil { ++ t.Fatalf("Failed to write chunk 1: %v", err) ++ } ++ if err := stream.Write(ctx, chunk2); err != nil { ++ t.Fatalf("Failed to write chunk 2: %v", err) ++ } ++ ++ snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) ++ if err != nil { ++ t.Fatalf("Failed to get document: %v", err) ++ } ++ ++ data := snapshot.Data() ++ streamArr, ok := data["stream"].([]interface{}) ++ if !ok { ++ t.Fatalf("Expected stream array, got %T", data["stream"]) ++ } ++ ++ if len(streamArr) != 2 { ++ t.Errorf("Expected 2 stream entries, got %d", len(streamArr)) ++ } ++ ++ entry0, _ := streamArr[0].(map[string]interface{}) ++ if entry0["type"] != streamEventChunk { ++ t.Errorf("Expected type 'chunk', got %v", entry0["type"]) ++ } ++ if entry0["uuid"] == nil || entry0["uuid"] == "" { ++ t.Error("Expected uuid to be set for chunk") ++ } ++ ++ if data["expiresAt"] == nil { ++ t.Error("Expected expiresAt to be set on open (for abandoned stream cleanup)") ++ } ++} ++ ++func TestFirestoreStreamManager_PreserveDuplicateChunks(t *testing.T) { ++ manager, client, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ streamID := "test-stream-dupes" ++ ++ stream, err := manager.Open(ctx, streamID) ++ if 
err != nil { ++ t.Fatalf("Failed to open stream: %v", err) ++ } ++ ++ chunk, _ := json.Marshal(map[string]string{"foo": "bar"}) ++ ++ if err := stream.Write(ctx, chunk); err != nil { ++ t.Fatalf("Failed to write chunk 1: %v", err) ++ } ++ if err := stream.Write(ctx, chunk); err != nil { ++ t.Fatalf("Failed to write chunk 2: %v", err) ++ } ++ ++ snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) ++ if err != nil { ++ t.Fatalf("Failed to get document: %v", err) ++ } ++ ++ data := snapshot.Data() ++ streamArr, _ := data["stream"].([]interface{}) ++ ++ if len(streamArr) != 2 { ++ t.Errorf("Expected 2 stream entries (duplicates should be preserved), got %d", len(streamArr)) ++ } ++ ++ entry0, _ := streamArr[0].(map[string]interface{}) ++ entry1, _ := streamArr[1].(map[string]interface{}) ++ if entry0["uuid"] == entry1["uuid"] { ++ t.Error("UUIDs should be different for duplicate chunks") ++ } ++} ++ ++func TestFirestoreStreamManager_Done(t *testing.T) { ++ manager, client, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ streamID := "test-stream-done" ++ ++ stream, err := manager.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to open stream: %v", err) ++ } ++ ++ output, _ := json.Marshal(map[string]string{"result": "success"}) ++ if err := stream.Done(ctx, output); err != nil { ++ t.Fatalf("Failed to mark stream done: %v", err) ++ } ++ ++ snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) ++ if err != nil { ++ t.Fatalf("Failed to get document: %v", err) ++ } ++ ++ data := snapshot.Data() ++ streamArr, _ := data["stream"].([]interface{}) ++ ++ if len(streamArr) != 1 { ++ t.Errorf("Expected 1 stream entry, got %d", len(streamArr)) ++ } ++ ++ entry, _ := streamArr[0].(map[string]interface{}) ++ if entry["type"] != streamEventDone { ++ t.Errorf("Expected type 'done', got %v", entry["type"]) ++ } ++ ++ if data["expiresAt"] == nil { ++ t.Error("Expected 
expiresAt to be set after done") ++ } ++} ++ ++func TestFirestoreStreamManager_Error(t *testing.T) { ++ manager, client, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ streamID := "test-stream-error" ++ ++ stream, err := manager.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to open stream: %v", err) ++ } ++ ++ testError := errors.New("test error message") ++ if err := stream.Error(ctx, testError); err != nil { ++ t.Fatalf("Failed to mark stream error: %v", err) ++ } ++ ++ snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) ++ if err != nil { ++ t.Fatalf("Failed to get document: %v", err) ++ } ++ ++ data := snapshot.Data() ++ streamArr, _ := data["stream"].([]interface{}) ++ ++ if len(streamArr) != 1 { ++ t.Errorf("Expected 1 stream entry, got %d", len(streamArr)) ++ } ++ ++ entry, _ := streamArr[0].(map[string]interface{}) ++ if entry["type"] != streamEventError { ++ t.Errorf("Expected type 'error', got %v", entry["type"]) ++ } ++ ++ errData, _ := entry["err"].(map[string]interface{}) ++ if errData["message"] != "test error message" { ++ t.Errorf("Expected error message 'test error message', got %v", errData["message"]) ++ } ++ if errData["status"] != string(core.UNKNOWN) { ++ t.Errorf("Expected status UNKNOWN for plain error, got %v", errData["status"]) ++ } ++ ++ if data["expiresAt"] == nil { ++ t.Error("Expected expiresAt to be set after error") ++ } ++} ++ ++func TestFirestoreStreamManager_ErrorStatusPreserved(t *testing.T) { ++ manager, client, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ streamID := "test-stream-error-status" ++ ++ stream, err := manager.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to open stream: %v", err) ++ } ++ ++ testError := core.NewPublicError(core.INVALID_ARGUMENT, "invalid input", nil) ++ if err := stream.Error(ctx, testError); err != nil { ++ t.Fatalf("Failed to mark stream error: %v", 
err) ++ } ++ ++ snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) ++ if err != nil { ++ t.Fatalf("Failed to get document: %v", err) ++ } ++ ++ data := snapshot.Data() ++ streamArr, _ := data["stream"].([]interface{}) ++ entry, _ := streamArr[0].(map[string]interface{}) ++ errData, _ := entry["err"].(map[string]interface{}) ++ ++ if errData["status"] != string(core.INVALID_ARGUMENT) { ++ t.Errorf("Expected status INVALID_ARGUMENT, got %v", errData["status"]) ++ } ++ if errData["message"] != "invalid input" { ++ t.Errorf("Expected message 'invalid input', got %v", errData["message"]) ++ } ++} ++ ++func TestFirestoreStreamManager_Subscribe(t *testing.T) { ++ manager, client, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ streamID := "test-stream-subscribe" ++ ++ chunk1, _ := json.Marshal(map[string]string{"foo": "bar"}) ++ chunk2, _ := json.Marshal(map[string]string{"bar": "baz"}) ++ output, _ := json.Marshal(map[string]string{"result": "success"}) ++ ++ _, err := client.Collection(*testStreamCollection).Doc(streamID).Set(ctx, map[string]interface{}{ ++ "stream": []map[string]interface{}{ ++ {"type": "chunk", "chunk": chunk1, "uuid": "uuid1"}, ++ {"type": "chunk", "chunk": chunk2, "uuid": "uuid2"}, ++ {"type": "done", "output": output}, ++ }, ++ "createdAt": time.Now(), ++ "updatedAt": time.Now(), ++ }) ++ if err != nil { ++ t.Fatalf("Failed to create test document: %v", err) ++ } ++ ++ ch, unsubscribe, err := manager.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to subscribe: %v", err) ++ } ++ defer unsubscribe() ++ ++ var chunks []json.RawMessage ++ var finalOutput json.RawMessage ++ timeout := time.After(5 * time.Second) ++ ++ for { ++ select { ++ case event, ok := <-ch: ++ if !ok { ++ goto verify ++ } ++ switch event.Type { ++ case streaming.StreamEventChunk: ++ chunks = append(chunks, event.Chunk) ++ case streaming.StreamEventDone: ++ finalOutput = event.Output ++ goto 
verify ++ case streaming.StreamEventError: ++ t.Fatalf("Unexpected error: %v", event.Err) ++ } ++ case <-timeout: ++ t.Fatal("Timeout waiting for stream events") ++ } ++ } ++ ++verify: ++ if len(chunks) != 2 { ++ t.Errorf("Expected 2 chunks, got %d", len(chunks)) ++ } ++ if finalOutput == nil { ++ t.Error("Expected final output") ++ } ++} ++ ++func TestFirestoreStreamManager_SubscribeErrorStatusPreserved(t *testing.T) { ++ manager, client, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ streamID := "test-stream-sub-error-status" ++ ++ _, err := client.Collection(*testStreamCollection).Doc(streamID).Set(ctx, map[string]interface{}{ ++ "stream": []map[string]interface{}{ ++ {"type": "error", "err": map[string]interface{}{ ++ "status": string(core.INVALID_ARGUMENT), ++ "message": "bad input", ++ }}, ++ }, ++ "createdAt": time.Now(), ++ "updatedAt": time.Now(), ++ }) ++ if err != nil { ++ t.Fatalf("Failed to create test document: %v", err) ++ } ++ ++ ch, unsubscribe, err := manager.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to subscribe: %v", err) ++ } ++ defer unsubscribe() ++ ++ timeout := time.After(5 * time.Second) ++ select { ++ case event, ok := <-ch: ++ if !ok { ++ t.Fatal("Channel closed unexpectedly") ++ } ++ if event.Type != streaming.StreamEventError { ++ t.Fatalf("Expected error event, got %v", event.Type) ++ } ++ publicErr, ok := event.Err.(*core.UserFacingError) ++ if !ok { ++ t.Fatalf("Expected UserFacingError, got %T", event.Err) ++ } ++ if publicErr.Status != core.INVALID_ARGUMENT { ++ t.Errorf("Expected INVALID_ARGUMENT status, got %v", publicErr.Status) ++ } ++ case <-timeout: ++ t.Fatal("Timeout waiting for error event") ++ } ++} ++ ++func TestFirestoreStreamManager_SubscribeNotFound(t *testing.T) { ++ manager, _, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := context.Background() ++ _, _, err := manager.Subscribe(ctx, "non-existent-stream") ++ if err == nil { ++ 
t.Fatal("Expected error for non-existent stream") ++ } ++ ++ publicErr, ok := err.(*core.UserFacingError) ++ if !ok { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if publicErr.Status != core.NOT_FOUND { ++ t.Errorf("Expected NOT_FOUND error, got %v", publicErr.Status) ++ } ++} ++ ++func TestFirestoreStreamManager_Timeout(t *testing.T) { ++ skipIfNoFirestore(t) ++ ++ ctx := context.Background() ++ g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testStreamProjectID})) ++ ++ f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) ++ client, err := f.Firestore(ctx) ++ if err != nil { ++ t.Fatalf("Failed to get Firestore client: %v", err) ++ } ++ defer deleteStreamCollection(ctx, client, *testStreamCollection, t) ++ ++ manager, err := NewFirestoreStreamManager(ctx, g, ++ WithCollection(*testStreamCollection), ++ WithTimeout(100*time.Millisecond), ++ ) ++ if err != nil { ++ t.Fatalf("Failed to create stream manager: %v", err) ++ } ++ ++ streamID := "test-stream-timeout" ++ ++ _, err = manager.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to open stream: %v", err) ++ } ++ ++ ch, _, err := manager.Subscribe(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to subscribe: %v", err) ++ } ++ ++ timeout := time.After(2 * time.Second) ++ for { ++ select { ++ case event, ok := <-ch: ++ if !ok { ++ t.Fatal("Channel closed without timeout error") ++ return ++ } ++ if event.Type == streaming.StreamEventError { ++ publicErr, ok := event.Err.(*core.UserFacingError) ++ if !ok { ++ t.Fatalf("Expected UserFacingError, got %T", event.Err) ++ } ++ if publicErr.Status != core.DEADLINE_EXCEEDED { ++ t.Errorf("Expected DEADLINE_EXCEEDED, got %v", publicErr.Status) ++ } ++ return ++ } ++ case <-timeout: ++ t.Fatal("Test timeout - stream timeout didn't trigger") ++ } ++ } ++} ++ ++func TestFirestoreStreamManager_WriteAfterClose(t *testing.T) { ++ manager, _, cleanup := setupTestStreamManager(t) ++ defer cleanup() ++ ++ ctx := 
context.Background() ++ streamID := "test-stream-write-after-close" ++ ++ stream, err := manager.Open(ctx, streamID) ++ if err != nil { ++ t.Fatalf("Failed to open stream: %v", err) ++ } ++ ++ if err := stream.Close(); err != nil { ++ t.Fatalf("Failed to close stream: %v", err) ++ } ++ ++ chunk, _ := json.Marshal(map[string]string{"foo": "bar"}) ++ err = stream.Write(ctx, chunk) ++ if err == nil { ++ t.Fatal("Expected error when writing after close") ++ } ++ ++ publicErr, ok := err.(*core.UserFacingError) ++ if !ok { ++ t.Fatalf("Expected UserFacingError, got %T", err) ++ } ++ if publicErr.Status != core.FAILED_PRECONDITION { ++ t.Errorf("Expected FAILED_PRECONDITION, got %v", publicErr.Status) ++ } ++} +diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go +index c7e49a8be..c3ca5297b 100644 +--- a/go/plugins/googlegenai/gemini.go ++++ b/go/plugins/googlegenai/gemini.go +@@ -484,6 +484,34 @@ func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { + return outTools, nil + } + ++// toGeminiFunctionResponsePart translates a slice of [ai.Part] to a slice of [genai.FunctionResponsePart] ++func toGeminiFunctionResponsePart(parts []*ai.Part) ([]*genai.FunctionResponsePart, error) { ++ frp := []*genai.FunctionResponsePart{} ++ for _, p := range parts { ++ switch { ++ case p.IsData(): ++ contentType, data, err := uri.Data(p) ++ if err != nil { ++ return nil, err ++ } ++ frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) ++ case p.IsMedia(): ++ if strings.HasPrefix(p.Text, "data:") { ++ contentType, data, err := uri.Data(p) ++ if err != nil { ++ return nil, err ++ } ++ frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) ++ continue ++ } ++ frp = append(frp, genai.NewFunctionResponsePartFromURI(p.Text, p.ContentType)) ++ default: ++ return nil, fmt.Errorf("unsupported function response part type: %d", p.Kind) ++ } ++ } ++ return frp, nil ++} ++ + // mergeTools consolidates all 
FunctionDeclarations into a single Tool + // while preserving non-function tools (Retrieval, GoogleSearch, CodeExecution, etc.) + func mergeTools(ts []*genai.Tool) []*genai.Tool { +@@ -807,6 +835,7 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { + if part.FileData != nil { + partFound++ + p = ai.NewMediaPart(part.FileData.MIMEType, part.FileData.FileURI) ++ + } + if part.FunctionCall != nil { + partFound++ +@@ -814,6 +843,14 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { + Name: part.FunctionCall.Name, + Input: part.FunctionCall.Args, + }) ++ // FunctionCall parts may contain a ThoughtSignature that must be preserved ++ // and returned in subsequent requests for the tool call to be valid. ++ if len(part.ThoughtSignature) > 0 { ++ if p.Metadata == nil { ++ p.Metadata = make(map[string]any) ++ } ++ p.Metadata["signature"] = part.ThoughtSignature ++ } + } + if part.CodeExecutionResult != nil { + partFound++ +@@ -836,6 +873,13 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { + continue + } + ++ if len(part.ThoughtSignature) > 0 { ++ if p.Metadata == nil { ++ p.Metadata = make(map[string]any) ++ } ++ p.Metadata["signature"] = part.ThoughtSignature ++ } ++ + msg.Content = append(msg.Content, p) + } + m.Message = msg +@@ -892,37 +936,29 @@ func toGeminiParts(parts []*ai.Part) ([]*genai.Part, error) { + + // toGeminiPart converts a [ai.Part] to a [genai.Part]. 
+ func toGeminiPart(p *ai.Part) (*genai.Part, error) { ++ var gp *genai.Part + switch { + case p.IsReasoning(): +- // TODO: go-genai does not support genai.NewPartFromThought() +- signature := []byte{} +- if p.Metadata != nil { +- if sig, ok := p.Metadata["signature"].([]byte); ok { +- signature = sig +- } +- } +- return &genai.Part{ +- Thought: true, +- Text: p.Text, +- ThoughtSignature: signature, +- }, nil ++ gp = genai.NewPartFromText(p.Text) ++ gp.Thought = true + case p.IsText(): +- return genai.NewPartFromText(p.Text), nil ++ gp = genai.NewPartFromText(p.Text) + case p.IsMedia(): + if strings.HasPrefix(p.Text, "data:") { + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } +- return genai.NewPartFromBytes(data, contentType), nil ++ gp = genai.NewPartFromBytes(data, contentType) ++ } else { ++ gp = genai.NewPartFromURI(p.Text, p.ContentType) + } +- return genai.NewPartFromURI(p.Text, p.ContentType), nil + case p.IsData(): + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } +- return genai.NewPartFromBytes(data, contentType), nil ++ gp = genai.NewPartFromBytes(data, contentType) + case p.IsToolResponse(): + toolResp := p.ToolResponse + var output map[string]any +@@ -934,8 +970,21 @@ func toGeminiPart(p *ai.Part) (*genai.Part, error) { + "content": toolResp.Output, + } + } +- fr := genai.NewPartFromFunctionResponse(toolResp.Name, output) +- return fr, nil ++ var isMultipart bool ++ if multiPart, ok := p.Metadata["multipart"].(bool); ok { ++ isMultipart = multiPart ++ } ++ if len(toolResp.Content) > 0 { ++ isMultipart = true ++ } ++ if isMultipart { ++ toolRespParts, err := toGeminiFunctionResponsePart(toolResp.Content) ++ if err != nil { ++ return nil, err ++ } ++ return genai.NewPartFromFunctionResponseWithParts(toolResp.Name, output, toolRespParts), nil ++ } ++ return genai.NewPartFromFunctionResponse(toolResp.Name, output), nil + case p.IsToolRequest(): + toolReq := p.ToolRequest + var input 
map[string]any +@@ -947,10 +996,24 @@ func toGeminiPart(p *ai.Part) (*genai.Part, error) { + } + } + fc := genai.NewPartFromFunctionCall(toolReq.Name, input) ++ // Restore ThoughtSignature if present in metadata ++ if p.Metadata != nil { ++ if sig, ok := p.Metadata["signature"].([]byte); ok { ++ fc.ThoughtSignature = sig ++ } ++ } + return fc, nil + default: +- panic("unknown part type in a request") ++ return nil, fmt.Errorf("unknown part in the request: %q", p.Kind) + } ++ ++ if p.Metadata != nil { ++ if sig, ok := p.Metadata["signature"].([]byte); ok { ++ gp.ThoughtSignature = sig ++ } ++ } ++ ++ return gp, nil + } + + // validToolName checks whether the provided tool name matches the +diff --git a/go/plugins/googlegenai/gemini_test.go b/go/plugins/googlegenai/gemini_test.go +index daa4da215..9aae76a05 100644 +--- a/go/plugins/googlegenai/gemini_test.go ++++ b/go/plugins/googlegenai/gemini_test.go +@@ -707,6 +707,82 @@ func TestValidToolName(t *testing.T) { + } + } + ++func TestToGeminiParts_MultipartToolResponse(t *testing.T) { ++ t.Run("ValidPartType", func(t *testing.T) { ++ // Create a tool response with both output and additional content (media) ++ toolResp := &ai.ToolResponse{ ++ Name: "generateImage", ++ Output: map[string]any{"status": "success"}, ++ Content: []*ai.Part{ ++ ai.NewMediaPart("image/png", "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="), ++ }, ++ } ++ ++ // create a mock ToolResponsePart, setting "multipart" to true is required ++ part := ai.NewToolResponsePart(toolResp) ++ part.Metadata = map[string]any{"multipart": true} ++ ++ geminiParts, err := toGeminiParts([]*ai.Part{part}) ++ if err != nil { ++ t.Fatalf("toGeminiParts failed: %v", err) ++ } ++ ++ // Expecting 1 part which contains the function response with internal parts ++ if len(geminiParts) != 1 { ++ t.Fatalf("expected 1 Gemini part, got %d", len(geminiParts)) ++ } ++ ++ if geminiParts[0].FunctionResponse == 
nil { ++ t.Error("expected first part to be FunctionResponse") ++ } ++ if geminiParts[0].FunctionResponse.Name != "generateImage" { ++ t.Errorf("expected function name 'generateImage', got %q", geminiParts[0].FunctionResponse.Name) ++ } ++ }) ++ ++ t.Run("UnsupportedPartType", func(t *testing.T) { ++ // Create a tool response with text content (unsupported for multipart) ++ toolResp := &ai.ToolResponse{ ++ Name: "generateText", ++ Output: map[string]any{"status": "success"}, ++ Content: []*ai.Part{ ++ ai.NewTextPart("Generated text"), ++ }, ++ } ++ ++ part := ai.NewToolResponsePart(toolResp) ++ part.Metadata = map[string]any{"multipart": true} ++ ++ _, err := toGeminiParts([]*ai.Part{part}) ++ if err == nil { ++ t.Fatal("expected error for unsupported text part in multipart response, got nil") ++ } ++ }) ++} ++ ++func TestToGeminiParts_SimpleToolResponse(t *testing.T) { ++ // Create a simple tool response (no content) ++ toolResp := &ai.ToolResponse{ ++ Name: "search", ++ Output: map[string]any{"result": "foo"}, ++ } ++ ++ part := ai.NewToolResponsePart(toolResp) ++ ++ geminiParts, err := toGeminiParts([]*ai.Part{part}) ++ if err != nil { ++ t.Fatalf("toGeminiParts failed: %v", err) ++ } ++ ++ if len(geminiParts) != 1 { ++ t.Fatalf("expected 1 Gemini part, got %d", len(geminiParts)) ++ } ++ ++ if geminiParts[0].FunctionResponse == nil { ++ t.Error("expected part to be FunctionResponse") ++ } ++} ++ + // genToolName generates a string of a specified length using only + // the valid characters for a Gemini Tool name + func genToolName(length int, chars string) string { +diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go +index 4e78ab4f4..783eccd23 100644 +--- a/go/plugins/googlegenai/googleai_live_test.go ++++ b/go/plugins/googlegenai/googleai_live_test.go +@@ -170,7 +170,7 @@ func TestGoogleAILive(t *testing.T) { + t.Fatal(err) + } + +- out := resp.Message.Content[0].Text ++ out := resp.Text() + const want = 
"11.31" + if !strings.Contains(out, want) { + t.Errorf("got %q, expecting it to contain %q", out, want) +@@ -219,7 +219,7 @@ func TestGoogleAILive(t *testing.T) { + t.Fatal(err) + } + +- out := resp.Message.Content[0].Text ++ out := resp.Text() + const want = "11.31" + if !strings.Contains(out, want) { + t.Errorf("got %q, expecting it to contain %q", out, want) +@@ -307,7 +307,7 @@ func TestGoogleAILive(t *testing.T) { + t.Fatal(err) + } + +- out := resp.Message.Content[0].Text ++ out := resp.Text() + const doNotWant = "11.31" + if strings.Contains(out, doNotWant) { + t.Errorf("got %q, expecting it NOT to contain %q", out, doNotWant) +@@ -582,6 +582,37 @@ func TestGoogleAILive(t *testing.T) { + t.Fatal("thoughts tokens should be zero") + } + }) ++ t.Run("multipart tool", func(t *testing.T) { ++ m := googlegenai.GoogleAIModel(g, "gemini-3-pro-preview") ++ img64, err := fetchImgAsBase64() ++ if err != nil { ++ t.Fatal(err) ++ } ++ ++ tool := genkit.DefineMultipartTool(g, "getImage", "returns a misterious image", ++ func(ctx *ai.ToolContext, input any) (*ai.MultipartToolResponse, error) { ++ return &ai.MultipartToolResponse{ ++ Output: map[string]any{"status": "success"}, ++ Content: []*ai.Part{ ++ ai.NewMediaPart("image/jpeg", "data:image/jpeg;base64,"+img64), ++ }, ++ }, nil ++ }, ++ ) ++ ++ resp, err := genkit.Generate(ctx, g, ++ ai.WithModel(m), ++ ai.WithTools(tool), ++ ai.WithPrompt("get an image and tell me what is in it"), ++ ) ++ if err != nil { ++ t.Fatal(err) ++ } ++ ++ if !strings.Contains(strings.ToLower(resp.Text()), "cat") { ++ t.Errorf("expected response to contain 'cat', got: %s", resp.Text()) ++ } ++ }) + } + + func TestCacheHelper(t *testing.T) { +diff --git a/go/plugins/googlegenai/googlegenai.go b/go/plugins/googlegenai/googlegenai.go +index d056e6fb1..8ddbdfbe4 100644 +--- a/go/plugins/googlegenai/googlegenai.go ++++ b/go/plugins/googlegenai/googlegenai.go +@@ -283,14 +283,19 @@ func (v *VertexAI) IsDefinedEmbedder(g *genkit.Genkit, name string) 
bool { + return genkit.LookupEmbedder(g, api.NewName(vertexAIProvider, name)) != nil + } + +-// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given name and configuration. +-func GoogleAIModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { ++// ModelRef creates a new ModelRef for a Google Gen AI model with the given name and configuration. ++func ModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(googleAIProvider+"/"+name, config) + } + +-// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given name and configuration. +-func VertexAIModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { +- return ai.NewModelRef(vertexAIProvider+"/"+name, config) ++// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given ID and configuration. ++func GoogleAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { ++ return ai.NewModelRef(googleAIProvider+"/"+id, config) ++} ++ ++// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given ID and configuration. ++func VertexAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { ++ return ai.NewModelRef(vertexAIProvider+"/"+id, config) + } + + // GoogleAIModel returns the [ai.Model] with the given name. +diff --git a/go/samples/basic-gemini-with-context/main.go b/go/samples/basic-gemini-with-context/main.go +deleted file mode 100644 +index f971ecc9b..000000000 +--- a/go/samples/basic-gemini-with-context/main.go ++++ /dev/null +@@ -1,54 +0,0 @@ +-// Copyright 2025 Google LLC +-// +-// Licensed under the Apache License, Version 2.0 (the "License"); +-// you may not use this file except in compliance with the License. 
+-// You may obtain a copy of the License at +-// +-// http://www.apache.org/licenses/LICENSE-2.0 +-// +-// Unless required by applicable law or agreed to in writing, software +-// distributed under the License is distributed on an "AS IS" BASIS, +-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-// See the License for the specific language governing permissions and +-// limitations under the License. +- +-package main +- +-import ( +- "context" +- "fmt" +- +- "github.com/firebase/genkit/go/ai" +- "github.com/firebase/genkit/go/genkit" +- "github.com/firebase/genkit/go/plugins/googlegenai" +- "google.golang.org/genai" +-) +- +-func main() { +- ctx := context.Background() +- +- // Initialize Genkit with the Google AI plugin. When you pass nil for the +- // Config parameter, the Google AI plugin will get the API key from the +- // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended +- // practice. +- g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) +- +- // Define a simple flow that generates jokes about a given topic with a context of bananas +- genkit.DefineFlow(g, "contextFlow", func(ctx context.Context, input string) (string, error) { +- resp, err := genkit.Generate(ctx, g, +- ai.WithModelName("googleai/gemini-2.5-flash"), +- ai.WithConfig(&genai.GenerateContentConfig{ +- Temperature: genai.Ptr[float32](1.0), +- }), +- ai.WithPrompt(fmt.Sprintf(`Tell silly short jokes about %s`, input)), +- ai.WithDocs(ai.DocumentFromText("Bananas are plentiful in the tropics.", nil))) +- if err != nil { +- return "", err +- } +- +- text := resp.Text() +- return text, nil +- }) +- +- <-ctx.Done() +-} +diff --git a/go/samples/basic-prompts/main.go b/go/samples/basic-prompts/main.go +new file mode 100644 +index 000000000..ccbc308a1 +--- /dev/null ++++ b/go/samples/basic-prompts/main.go +@@ -0,0 +1,287 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); 
++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++ ++package main ++ ++import ( ++ "context" ++ "fmt" ++ "log" ++ "net/http" ++ ++ "github.com/firebase/genkit/go/ai" ++ "github.com/firebase/genkit/go/core" ++ "github.com/firebase/genkit/go/genkit" ++ "github.com/firebase/genkit/go/plugins/googlegenai" ++ "github.com/firebase/genkit/go/plugins/server" ++ "google.golang.org/genai" ++) ++ ++type JokeRequest struct { ++ Topic string `json:"topic" jsonschema:"default=airplane food"` ++} ++ ++// Note how the fields are annotated with jsonschema tags to describe the output schema. ++// This is vital for the model to understand the intent of the fields. ++type Joke struct { ++ Joke string `json:"joke" jsonschema:"description=The joke text"` ++ Category string `json:"category" jsonschema:"description=The joke category"` ++} ++ ++type RecipeRequest struct { ++ Dish string `json:"dish" jsonschema:"default=pasta"` ++ Cuisine string `json:"cuisine" jsonschema:"default=Italian"` ++ ServingSize int `json:"servingSize" jsonschema:"default=4"` ++ MaxPrepMinutes int `json:"maxPrepMinutes" jsonschema:"default=30"` ++ DietaryRestrictions []string `json:"dietaryRestrictions,omitempty"` ++} ++ ++type Ingredient struct { ++ Name string `json:"name" jsonschema:"description=The ingredient name"` ++ Amount string `json:"amount" jsonschema:"description=The ingredient amount (e.g. 
1 cup, 2 tablespoons, etc.)"` ++ Optional bool `json:"optional,omitempty" jsonschema:"description=Whether the ingredient is optional in the recipe"` ++} ++ ++type Recipe struct { ++ Title string `json:"title" jsonschema:"description=The recipe title (e.g. 'Spicy Chicken Tacos')"` ++ Description string `json:"description,omitempty" jsonschema:"description=The recipe description (under 100 characters)"` ++ Ingredients []*Ingredient `json:"ingredients" jsonschema:"description=The recipe ingredients (group by type and order by importance)"` ++ Instructions []string `json:"instructions" jsonschema:"description=The recipe instructions (step by step)"` ++ PrepTime string `json:"prepTime" jsonschema:"description=The recipe preparation time (e.g. 10 minutes, 30 minutes, etc.)"` ++ Difficulty string `json:"difficulty" jsonschema:"enum=easy,enum=medium,enum=hard"` ++} ++ ++func main() { ++ ctx := context.Background() ++ ++ // Initialize Genkit with the Google AI plugin. When you pass nil for the ++ // Config parameter, the Google AI plugin will get the API key from the ++ // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended ++ // practice. ++ g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) ++ ++ // Define schemas for the expected input and output types so that the Dotprompt files can reference them. ++ // Alternatively, you can specify the JSON schema by hand in the Dotprompt metadata. ++ // Code-defined prompts do not need to have schemas defined in advance but they too can reference them. ++ genkit.DefineSchemaFor[JokeRequest](g) ++ genkit.DefineSchemaFor[Joke](g) ++ genkit.DefineSchemaFor[RecipeRequest](g) ++ genkit.DefineSchemaFor[Recipe](g) ++ ++ // TODO: Include partials and helpers. ++ ++ // Define the prompts and flows. 
++ DefineSimpleJokeWithInlinePrompt(g) ++ DefineSimpleJokeWithDotprompt(g) ++ DefineStructuredJokeWithInlinePrompt(g) ++ DefineStructuredJokeWithDotprompt(g) ++ DefineRecipeWithInlinePrompt(g) ++ DefineRecipeWithDotprompt(g) ++ ++ // Optionally, start a web server to make the flows callable via HTTP. ++ mux := http.NewServeMux() ++ for _, a := range genkit.ListFlows(g) { ++ mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) ++ } ++ log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) ++} ++ ++// DefineSimpleJokeWithInlinePrompt demonstrates defining a prompt in code using DefinePrompt. ++// The prompt has no output schema defined so it will always return a string. ++// When executing the prompt, we pass in a map[string]any with the input fields. ++func DefineSimpleJokeWithInlinePrompt(g *genkit.Genkit) { ++ jokePrompt := genkit.DefinePrompt( ++ g, "joke.code", ++ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ ++ ThinkingConfig: &genai.ThinkingConfig{ ++ ThinkingBudget: genai.Ptr[int32](0), ++ }, ++ })), ++ // Despite JokeRequest having defaults set in jsonschema tags, we can override it with values set in WithInputType. ++ ai.WithInputType(JokeRequest{Topic: "rush hour traffic"}), ++ ai.WithPrompt("Share a long joke about {{topic}}."), ++ ) ++ ++ genkit.DefineStreamingFlow(g, "simpleJokePromptFlow", ++ func(ctx context.Context, topic string, sendChunk core.StreamCallback[string]) (string, error) { ++ // One way to pass input is using a map[string]any. This is useful when there is no structured input type. 
++ stream := jokePrompt.ExecuteStream(ctx, ai.WithInput(map[string]any{"topic": topic})) ++ for result, err := range stream { ++ if err != nil { ++ return "", fmt.Errorf("could not generate joke: %w", err) ++ } ++ if result.Done { ++ return result.Response.Text(), nil ++ } ++ sendChunk(ctx, result.Chunk.Text()) ++ } ++ ++ return "", nil ++ }, ++ ) ++} ++ ++// DefineSimpleJokeWithDotprompt demonstrates loading a prompt from a .prompt file using ++// LoadPrompt. The prompt configuration (model, input schema, defaults) is defined in the ++// file. Input is passed as a map since the .prompt file defines its own schema. ++func DefineSimpleJokeWithDotprompt(g *genkit.Genkit) { ++ genkit.DefineStreamingFlow(g, "simpleJokeDotpromptFlow", ++ func(ctx context.Context, topic string, sendChunk core.StreamCallback[string]) (string, error) { ++ jokePrompt := genkit.LookupPrompt(g, "joke") ++ // One way to pass input is using a map[string]any. This is useful when there is no structured input type. ++ stream := jokePrompt.ExecuteStream(ctx, ai.WithInput(map[string]any{"topic": topic})) ++ for result, err := range stream { ++ if err != nil { ++ return "", fmt.Errorf("could not generate joke: %w", err) ++ } ++ if result.Done { ++ return result.Response.Text(), nil ++ } ++ sendChunk(ctx, result.Chunk.Text()) ++ } ++ ++ return "", nil ++ }, ++ ) ++} ++ ++// DefineStructuredJokeWithInlinePrompt demonstrates DefineDataPrompt for strongly-typed ++// input and output. The type parameters automatically configure input/output schemas ++// and JSON output format. ExecuteStream returns typed chunks and final output. 
++func DefineStructuredJokeWithInlinePrompt(g *genkit.Genkit) { ++ jokePrompt := genkit.DefineDataPrompt[JokeRequest, *Joke]( ++ g, "structured-joke.code", ++ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ ++ ThinkingConfig: &genai.ThinkingConfig{ ++ ThinkingBudget: genai.Ptr[int32](0), ++ }, ++ })), ++ ai.WithPrompt("Share a long joke about {{topic}}."), ++ ) ++ ++ genkit.DefineStreamingFlow(g, "structuredJokePromptFlow", ++ func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { ++ for result, err := range jokePrompt.ExecuteStream(ctx, input) { ++ if err != nil { ++ return nil, fmt.Errorf("could not generate joke: %w", err) ++ } ++ if result.Done { ++ return result.Output, nil ++ } ++ sendChunk(ctx, result.Chunk) ++ } ++ ++ return nil, nil ++ }, ++ ) ++} ++ ++// DefineStructuredJokeWithDotprompt demonstrates LookupDataPrompt to wrap a .prompt file ++// with Go type information. The .prompt file references registered schemas by name ++// (e.g., "schema: Joke"), which must be defined via DefineSchemaFor before loading. ++func DefineStructuredJokeWithDotprompt(g *genkit.Genkit) { ++ genkit.DefineStreamingFlow(g, "structuredJokeDotpromptFlow", ++ func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { ++ jokePrompt := genkit.LookupDataPrompt[JokeRequest, *Joke](g, "structured-joke") ++ stream := jokePrompt.ExecuteStream(ctx, input) ++ for result, err := range stream { ++ if err != nil { ++ return nil, fmt.Errorf("could not generate joke: %w", err) ++ } ++ if result.Done { ++ return result.Output, nil ++ } ++ sendChunk(ctx, result.Chunk) ++ } ++ return nil, nil ++ }, ++ ) ++} ++ ++// DefineRecipeWithInlinePrompt demonstrates DefineDataPrompt with complex nested types ++// and Handlebars conditionals/loops in the prompt template. The streaming flow applies ++// default values before execution and streams partial ingredients as they arrive. 
++func DefineRecipeWithInlinePrompt(g *genkit.Genkit) { ++ recipePrompt := genkit.DefineDataPrompt[RecipeRequest, *Recipe]( ++ g, "recipe.code", ++ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ ++ ThinkingConfig: &genai.ThinkingConfig{ ++ ThinkingBudget: genai.Ptr[int32](0), ++ }, ++ })), ++ ai.WithSystem("You are an experienced chef. Come up with easy, creative recipes."), ++ ai.WithPrompt("Create a {{cuisine}} {{dish}} recipe for {{servingSize}} people that takes under {{maxPrepMinutes}} minutes to prepare. "+ ++ "{{#if dietaryRestrictions}}Dietary restrictions: {{#each dietaryRestrictions}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}.{{/if}}"), ++ ) ++ ++ genkit.DefineStreamingFlow(g, "recipePromptFlow", ++ func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[*Ingredient]) (*Recipe, error) { ++ // This is not necessary for this example but it shows how to easily have more control over what you stream. ++ filterNew := newIngredientFilter() ++ for result, err := range recipePrompt.ExecuteStream(ctx, input) { ++ if err != nil { ++ return nil, fmt.Errorf("could not generate recipe: %w", err) ++ } ++ if result.Done { ++ return result.Output, nil ++ } ++ for _, i := range filterNew(result.Chunk.Ingredients) { ++ sendChunk(ctx, i) ++ } ++ } ++ return nil, nil ++ }, ++ ) ++} ++ ++// DefineRecipeWithDotprompt demonstrates LookupDataPrompt with a .prompt file that uses ++// multi-message format (system/user roles) and references registered schemas. ++// Streams partial ingredients as they arrive via ExecuteStream. ++func DefineRecipeWithDotprompt(g *genkit.Genkit) { ++ genkit.DefineStreamingFlow(g, "recipeDotpromptFlow", ++ func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[*Ingredient]) (*Recipe, error) { ++ // This is not necessary for this example but it shows how to easily have more control over what you stream. 
++ filterNew := newIngredientFilter() ++ recipePrompt := genkit.LookupDataPrompt[RecipeRequest, *Recipe](g, "recipe") ++ stream := recipePrompt.ExecuteStream(ctx, input) ++ for result, err := range stream { ++ if err != nil { ++ return nil, fmt.Errorf("could not generate recipe: %w", err) ++ } ++ if result.Done { ++ return result.Output, nil ++ } ++ for _, i := range filterNew(result.Chunk.Ingredients) { ++ sendChunk(ctx, i) ++ } ++ } ++ return nil, nil ++ }, ++ ) ++} ++ ++// newIngredientFilter is a helper function to filter out duplicate ingredients. ++// This allows us to stream only new ingredients as they are identified, avoiding duplicates. ++func newIngredientFilter() func([]*Ingredient) []*Ingredient { ++ seen := map[string]struct{}{} ++ return func(ings []*Ingredient) (newIngs []*Ingredient) { ++ for _, ing := range ings { ++ if _, ok := seen[ing.Name]; !ok { ++ seen[ing.Name] = struct{}{} ++ newIngs = append(newIngs, ing) ++ } ++ } ++ return ++ } ++} +diff --git a/go/samples/basic-prompts/prompts/joke.prompt b/go/samples/basic-prompts/prompts/joke.prompt +new file mode 100644 +index 000000000..fc1add095 +--- /dev/null ++++ b/go/samples/basic-prompts/prompts/joke.prompt +@@ -0,0 +1,13 @@ ++--- ++model: googleai/gemini-2.5-flash ++config: ++ thinkingConfig: ++ thinkingBudget: 0 ++input: ++ schema: ++ topic?: string ++ default: ++ topic: airplane food ++--- ++Share a long joke about {{topic}}. ++ +diff --git a/go/samples/basic-prompts/prompts/recipe.prompt b/go/samples/basic-prompts/prompts/recipe.prompt +new file mode 100644 +index 000000000..d132ba615 +--- /dev/null ++++ b/go/samples/basic-prompts/prompts/recipe.prompt +@@ -0,0 +1,20 @@ ++--- ++model: googleai/gemini-2.5-flash ++config: ++ thinkingConfig: ++ thinkingBudget: 0 ++input: ++ schema: RecipeRequest ++output: ++ format: json ++ schema: Recipe ++--- ++{{role "system"}} ++You are an experienced chef. Come up with easy, creative recipes. 
++ ++{{role "user"}} ++Create a {{cuisine}} {{dish}} recipe for {{servingSize}} people that takes under {{maxPrepMinutes}} minutes to prepare. ++{{#if dietaryRestrictions}} ++Dietary restrictions: {{#each dietaryRestrictions}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}. ++{{/if}} ++ +diff --git a/go/samples/basic-prompts/prompts/structured-joke.prompt b/go/samples/basic-prompts/prompts/structured-joke.prompt +new file mode 100644 +index 000000000..7184b1548 +--- /dev/null ++++ b/go/samples/basic-prompts/prompts/structured-joke.prompt +@@ -0,0 +1,13 @@ ++--- ++model: googleai/gemini-2.5-flash ++config: ++ thinkingConfig: ++ thinkingBudget: 0 ++input: ++ schema: JokeRequest ++output: ++ format: json ++ schema: Joke ++--- ++Share a long joke about {{topic}}. ++ +diff --git a/go/samples/basic-structured/main.go b/go/samples/basic-structured/main.go +new file mode 100644 +index 000000000..428636de4 +--- /dev/null ++++ b/go/samples/basic-structured/main.go +@@ -0,0 +1,181 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. 
++ ++package main ++ ++import ( ++ "context" ++ "fmt" ++ "log" ++ "net/http" ++ ++ "github.com/firebase/genkit/go/ai" ++ "github.com/firebase/genkit/go/core" ++ "github.com/firebase/genkit/go/genkit" ++ "github.com/firebase/genkit/go/plugins/googlegenai" ++ "github.com/firebase/genkit/go/plugins/server" ++ "google.golang.org/genai" ++) ++ ++type JokeRequest struct { ++ Topic string `json:"topic" jsonschema:"default=airplane food"` ++} ++ ++// Note how the fields are annotated with jsonschema tags to describe the output schema. ++// This is vital for the model to understand the intent of the fields. ++type Joke struct { ++ Joke string `json:"joke" jsonschema:"description=The joke text"` ++ Category string `json:"category" jsonschema:"description=The joke category"` ++} ++ ++type RecipeRequest struct { ++ Dish string `json:"dish" jsonschema:"default=pasta"` ++ Cuisine string `json:"cuisine" jsonschema:"default=Italian"` ++ ServingSize int `json:"servingSize" jsonschema:"default=4"` ++ MaxPrepMinutes int `json:"maxPrepMinutes" jsonschema:"default=30"` ++ DietaryRestrictions []string `json:"dietaryRestrictions,omitempty"` ++} ++ ++type Ingredient struct { ++ Name string `json:"name" jsonschema:"description=The ingredient name"` ++ Amount string `json:"amount" jsonschema:"description=The ingredient amount (e.g. 1 cup, 2 tablespoons, etc.)"` ++ Optional bool `json:"optional,omitempty" jsonschema:"description=Whether the ingredient is optional in the recipe"` ++} ++ ++type Recipe struct { ++ Title string `json:"title" jsonschema:"description=The recipe title (e.g. 
'Spicy Chicken Tacos')"` ++ Description string `json:"description,omitempty" jsonschema:"description=The recipe description (under 100 characters)"` ++ Ingredients []*Ingredient `json:"ingredients" jsonschema:"description=The recipe ingredients (order by type first and then importance)"` ++ Instructions []string `json:"instructions" jsonschema:"description=The recipe instructions (step by step)"` ++ PrepTime string `json:"prepTime" jsonschema:"description=The recipe preparation time (e.g. 10 minutes, 30 minutes, etc.)"` ++ Difficulty string `json:"difficulty" jsonschema:"enum=easy,enum=medium,enum=hard"` ++} ++ ++func main() { ++ ctx := context.Background() ++ ++ // Initialize Genkit with the Google AI plugin. When you pass nil for the ++ // Config parameter, the Google AI plugin will get the API key from the ++ // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended ++ // practice. ++ g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) ++ ++ // Define the flows. ++ DefineSimpleJoke(g) ++ DefineStructuredJoke(g) ++ DefineRecipe(g) ++ ++ // Optionally, start a web server to make the flows callable via HTTP. ++ mux := http.NewServeMux() ++ for _, a := range genkit.ListFlows(g) { ++ mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) ++ } ++ log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) ++} ++ ++// DefineSimpleJoke demonstrates defining a streaming flow that generates a joke about a given topic. 
++func DefineSimpleJoke(g *genkit.Genkit) { ++ genkit.DefineStreamingFlow(g, "simpleJokesFlow", ++ func(ctx context.Context, input string, sendChunk core.StreamCallback[string]) (string, error) { ++ stream := genkit.GenerateStream(ctx, g, ++ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ ++ ThinkingConfig: &genai.ThinkingConfig{ ++ ThinkingBudget: genai.Ptr[int32](0), ++ }, ++ })), ++ ai.WithPrompt("Share a long joke about %s.", input), ++ ) ++ ++ for result, err := range stream { ++ if err != nil { ++ return "", fmt.Errorf("could not generate joke: %w", err) ++ } ++ if result.Done { ++ return result.Response.Text(), nil ++ } ++ sendChunk(ctx, result.Chunk.Text()) ++ } ++ ++ return "", nil ++ }, ++ ) ++} ++ ++// DefineStructuredJoke demonstrates defining a streaming flow that generates a joke about a given topic. ++// The input is a strongly-typed JokeRequest struct and the output is a strongly-typed Joke struct. ++func DefineStructuredJoke(g *genkit.Genkit) { ++ genkit.DefineStreamingFlow(g, "structuredJokesFlow", ++ func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { ++ stream := genkit.GenerateDataStream[*Joke](ctx, g, ++ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ ++ ThinkingConfig: &genai.ThinkingConfig{ ++ ThinkingBudget: genai.Ptr[int32](0), ++ }, ++ })), ++ ai.WithPrompt("Share a long joke about %s.", input.Topic), ++ ) ++ ++ for result, err := range stream { ++ if err != nil { ++ return nil, fmt.Errorf("could not generate joke: %w", err) ++ } ++ if result.Done { ++ return result.Output, nil ++ } ++ sendChunk(ctx, result.Chunk) ++ } ++ ++ return nil, nil ++ }) ++} ++ ++// DefineRecipe demonstrates defining a streaming flow that generates a recipe based on a given RecipeRequest struct. ++// The input is a strongly-typed RecipeRequest struct and the output is a strongly-typed Recipe struct. 
++func DefineRecipe(g *genkit.Genkit) { ++ genkit.DefineStreamingFlow(g, "recipeFlow", ++ func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[[]*Ingredient]) (*Recipe, error) { ++ stream := genkit.GenerateDataStream[*Recipe](ctx, g, ++ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ ++ ThinkingConfig: &genai.ThinkingConfig{ ++ ThinkingBudget: genai.Ptr[int32](0), ++ }, ++ })), ++ ai.WithSystem("You are an experienced chef. Come up with easy, creative recipes."), ++ // Here we are passing WithPromptFn() since our prompt takes some string manipulation to build. ++ // Alternatively, we could pass WithPrompt() with the complete prompt string. ++ ai.WithPromptFn(func(ctx context.Context, _ any) (string, error) { ++ prompt := fmt.Sprintf( ++ "Create a %s %s recipe for %d people that takes under %d minutes to prepare.", ++ input.Cuisine, input.Dish, input.ServingSize, input.MaxPrepMinutes, ++ ) ++ if len(input.DietaryRestrictions) > 0 { ++ prompt += fmt.Sprintf(" Dietary restrictions: %v.", input.DietaryRestrictions) ++ } ++ return prompt, nil ++ }), ++ ) ++ ++ for result, err := range stream { ++ if err != nil { ++ return nil, fmt.Errorf("could not generate recipe: %w", err) ++ } ++ if result.Done { ++ return result.Output, nil ++ } ++ sendChunk(ctx, result.Chunk.Ingredients) ++ } ++ ++ return nil, nil ++ }) ++} +diff --git a/go/samples/basic/main.go b/go/samples/basic/main.go +new file mode 100644 +index 000000000..2031340ac +--- /dev/null ++++ b/go/samples/basic/main.go +@@ -0,0 +1,86 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. 
++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++ ++package main ++ ++import ( ++ "context" ++ "fmt" ++ "log" ++ "net/http" ++ ++ "github.com/firebase/genkit/go/ai" ++ "github.com/firebase/genkit/go/genkit" ++ "github.com/firebase/genkit/go/plugins/googlegenai" ++ "github.com/firebase/genkit/go/plugins/server" ++ "google.golang.org/genai" ++) ++ ++func main() { ++ ctx := context.Background() ++ ++ // Initialize Genkit with the Google AI plugin. When you pass nil for the ++ // Config parameter, the Google AI plugin will get the API key from the ++ // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended ++ // practice. ++ g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) ++ ++ // Define a non-streaming flow that generates jokes about a given topic. ++ genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { ++ if input == "" { ++ input = "airplane food" ++ } ++ ++ return genkit.GenerateText(ctx, g, ++ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ ++ ThinkingConfig: &genai.ThinkingConfig{ ++ ThinkingBudget: genai.Ptr[int32](0), ++ }, ++ })), ++ ai.WithPrompt("Share a joke about %s.", input), ++ ) ++ }, ++ ) ++ ++ // Define a streaming flow that generates jokes about a given topic with passthrough streaming. 
++ genkit.DefineStreamingFlow(g, "streamingJokesFlow", ++ func(ctx context.Context, input string, sendChunk ai.ModelStreamCallback) (string, error) { ++ if input == "" { ++ input = "airplane food" ++ } ++ ++ resp, err := genkit.Generate(ctx, g, ++ ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ ++ ThinkingConfig: &genai.ThinkingConfig{ ++ ThinkingBudget: genai.Ptr[int32](0), ++ }, ++ })), ++ ai.WithPrompt("Share a joke about %s.", input), ++ ai.WithStreaming(sendChunk), ++ ) ++ if err != nil { ++ return "", fmt.Errorf("could not generate joke: %w", err) ++ } ++ ++ return resp.Text(), nil ++ }, ++ ) ++ ++ // Optionally, start a web server to make the flow callable via HTTP. ++ mux := http.NewServeMux() ++ for _, a := range genkit.ListFlows(g) { ++ mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) ++ } ++ log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) ++} +diff --git a/go/samples/durable-streaming-firestore/README.md b/go/samples/durable-streaming-firestore/README.md +new file mode 100644 +index 000000000..8ca1648e8 +--- /dev/null ++++ b/go/samples/durable-streaming-firestore/README.md +@@ -0,0 +1,140 @@ ++# Durable Streaming with Firestore ++ ++This sample demonstrates durable streaming using Firestore as the backend. Unlike in-memory streaming, Firestore-backed streams: ++ ++- **Survive server restarts** - Clients can reconnect to streams after server restarts ++- **Work across instances** - Multiple server instances can serve the same stream ++- **Auto-cleanup** - Completed streams are automatically deleted via Firestore TTL policies ++ ++## Prerequisites ++ ++1. **Firebase Project**: You need a Firebase/GCP project with Firestore enabled. ++ ++2. **Authentication**: Authenticate with your Google Cloud project: ++ ```bash ++ gcloud auth application-default login ++ ``` ++ ++3. **(Recommended) TTL Policy**: Configure a TTL policy on your Firestore collection for automatic cleanup of old streams. 
This requires setting a TTL on the `expiresAt` field: ++ ++ ```bash ++ gcloud firestore fields ttls update expiresAt \ ++ --collection-group=genkit-streams \ ++ --enable-ttl \ ++ --project=YOUR_PROJECT_ID ++ ``` ++ ++ See: https://firebase.google.com/docs/firestore/ttl ++ ++## Environment Variables ++ ++| Variable | Required | Default | Description | ++|----------|----------|---------|-------------| ++| `FIREBASE_PROJECT_ID` | Yes | - | Your Firebase/GCP project ID | ++| `FIRESTORE_STREAMS_COLLECTION` | No | `genkit-streams` | Firestore collection for stream documents | ++ ++## Running the Sample ++ ++1. Set your project ID: ++ ```bash ++ export FIREBASE_PROJECT_ID=your-project-id ++ ``` ++ ++2. Start the server: ++ ```bash ++ go run . ++ ``` ++ ++## Testing ++ ++### Start a streaming request ++ ++```bash ++curl -N -i -H "Accept: text/event-stream" \ ++ -d '{"data": 5}' \ ++ http://localhost:8080/countdown ++``` ++ ++Note the `X-Genkit-Stream-Id` header in the response - you'll need this to reconnect. ++ ++### Reconnect to an existing stream ++ ++Use the stream ID from the previous response: ++ ++```bash ++curl -N -H "Accept: text/event-stream" \ ++ -H "X-Genkit-Stream-Id: " \ ++ -d '{"data": 5}' \ ++ http://localhost:8080/countdown ++``` ++ ++The subscription will: ++- Replay any buffered chunks that were already sent ++- Continue with live updates if the stream is still in progress ++- Return all chunks plus the final result if the stream has already completed ++ ++### Test server restart resilience ++ ++1. Start a countdown with a high number: ++ ```bash ++ curl -N -i -H "Accept: text/event-stream" -d '{"data": 30}' http://localhost:8080/countdown ++ ``` ++ ++2. Copy the `X-Genkit-Stream-Id` header value ++ ++3. Stop the server (Ctrl+C) ++ ++4. Restart the server: `go run .` ++ ++5. 
Reconnect using the stream ID: ++ ```bash ++ curl -N -H "Accept: text/event-stream" -H "X-Genkit-Stream-Id: " -d '{"data": 30}' http://localhost:8080/countdown ++ ``` ++ ++You'll receive all previously buffered chunks, demonstrating that the stream state persisted across the server restart. ++ ++## Configuration Options ++ ++The `FirestoreStreamManager` supports these options: ++ ++| Option | Default | Description | ++|--------|---------|-------------| ++| `WithCollection(name)` | (required) | Firestore collection for stream documents | ++| `WithTimeout(duration)` | 60s | How long subscribers wait for new events before timeout | ++| `WithTTL(duration)` | 5m | How long completed streams are retained before auto-deletion | ++ ++Example: ++```go ++streamManager, err := firebasex.NewFirestoreStreamManager(ctx, g, ++ firebasex.WithCollection("my-streams"), ++ firebasex.WithTimeout(2*time.Minute), ++ firebasex.WithTTL(1*time.Hour), ++) ++``` ++ ++## How It Works ++ ++1. When a streaming request arrives, a Firestore document is created with the stream ID ++2. As the flow produces chunks, they're appended to the document's `stream` array ++3. Subscribers use Firestore's real-time listeners to receive updates ++4. When the flow completes, a final "done" entry is added with the output ++5. The `expiresAt` field is set based on TTL, and Firestore automatically deletes the document ++ ++## License ++ ++``` ++Copyright 2025 Google LLC ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. ++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. 
++``` ++ +diff --git a/go/samples/durable-streaming-firestore/main.go b/go/samples/durable-streaming-firestore/main.go +new file mode 100644 +index 000000000..988ccda79 +--- /dev/null ++++ b/go/samples/durable-streaming-firestore/main.go +@@ -0,0 +1,89 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++// This sample demonstrates durable streaming with Firestore backend. ++// Unlike in-memory streaming, Firestore-backed streams survive server restarts ++// and can be accessed across multiple server instances. ++// ++// See README.md for setup instructions. 
++package main ++ ++import ( ++ "context" ++ "fmt" ++ "log" ++ "net/http" ++ "time" ++ ++ "github.com/firebase/genkit/go/genkit" ++ "github.com/firebase/genkit/go/plugins/firebase" ++ firebasex "github.com/firebase/genkit/go/plugins/firebase/x" ++ "github.com/firebase/genkit/go/plugins/server" ++) ++ ++func main() { ++ ctx := context.Background() ++ ++ g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{})) ++ ++ type CountdownChunk struct { ++ Count int `json:"count"` ++ Message string `json:"message"` ++ Timestamp string `json:"timestamp"` ++ } ++ ++ countdown := genkit.DefineStreamingFlow(g, "countdown", ++ func(ctx context.Context, count int, sendChunk func(context.Context, CountdownChunk) error) (string, error) { ++ if count <= 0 { ++ count = 5 ++ } ++ ++ for i := count; i > 0; i-- { ++ select { ++ case <-ctx.Done(): ++ return "", ctx.Err() ++ case <-time.After(1 * time.Second): ++ } ++ ++ chunk := CountdownChunk{ ++ Count: i, ++ Message: fmt.Sprintf("T-%d...", i), ++ Timestamp: time.Now().Format(time.RFC3339), ++ } ++ ++ if err := sendChunk(ctx, chunk); err != nil { ++ return "", err ++ } ++ } ++ ++ return "Liftoff!", nil ++ }) ++ ++ sm, err := firebasex.NewFirestoreStreamManager(ctx, g, ++ firebasex.WithCollection("genkit-streams"), ++ firebasex.WithTimeout(2*time.Minute), ++ firebasex.WithTTL(10*time.Minute), ++ ) ++ if err != nil { ++ log.Fatalf("Failed to create Firestore stream manager: %v", err) ++ } ++ ++ // Set up HTTP server with durable streaming enabled. ++ // Completed streams are kept for 10 minutes before cleanup. 
++ mux := http.NewServeMux() ++ mux.HandleFunc("POST /countdown", genkit.Handler(countdown, genkit.WithStreamManager(sm))) ++ log.Fatal(server.Start(ctx, "127.0.0.1:8088", mux)) ++} +diff --git a/go/samples/durable-streaming/main.go b/go/samples/durable-streaming/main.go +new file mode 100644 +index 000000000..36323990a +--- /dev/null ++++ b/go/samples/durable-streaming/main.go +@@ -0,0 +1,99 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++// See the License for the specific language governing permissions and ++// limitations under the License. ++// ++// SPDX-License-Identifier: Apache-2.0 ++ ++// This sample demonstrates durable streaming, which allows clients to reconnect ++// to in-progress or completed streams using a stream ID. ++// ++// Start the server: ++// ++// go run . ++// ++// Test streaming (get a stream ID back in X-Genkit-Stream-Id header): ++// ++// curl -N -i -H "Accept: text/event-stream" \ ++// -d '{"data": 5}' \ ++// http://localhost:8080/countdown ++// ++// Subscribe to an existing stream using the stream ID from the previous response: ++// ++// curl -N -H "Accept: text/event-stream" \ ++// -H "X-Genkit-Stream-Id: " \ ++// -d '{"data": 5}' \ ++// http://localhost:8080/countdown ++// ++// The subscription will replay any buffered chunks and then continue with live updates. ++// If the stream has already completed, all chunks plus the final result are returned. 
++ ++package main ++ ++import ( ++ "context" ++ "fmt" ++ "log" ++ "net/http" ++ "time" ++ ++ "github.com/firebase/genkit/go/core/x/streaming" ++ "github.com/firebase/genkit/go/genkit" ++ "github.com/firebase/genkit/go/plugins/server" ++) ++ ++func main() { ++ ctx := context.Background() ++ g := genkit.Init(ctx) ++ ++ type CountdownChunk struct { ++ Count int `json:"count"` ++ Message string `json:"message"` ++ Timestamp string `json:"timestamp"` ++ } ++ ++ // Define a streaming flow that counts down with delays. ++ countdown := genkit.DefineStreamingFlow(g, "countdown", ++ func(ctx context.Context, count int, sendChunk func(context.Context, CountdownChunk) error) (string, error) { ++ if count <= 0 { ++ count = 5 ++ } ++ ++ for i := count; i > 0; i-- { ++ select { ++ case <-ctx.Done(): ++ return "", ctx.Err() ++ case <-time.After(1 * time.Second): ++ } ++ ++ chunk := CountdownChunk{ ++ Count: i, ++ Message: fmt.Sprintf("T-%d...", i), ++ Timestamp: time.Now().Format(time.RFC3339), ++ } ++ ++ if err := sendChunk(ctx, chunk); err != nil { ++ return "", err ++ } ++ } ++ ++ return "Liftoff!", nil ++ }) ++ ++ // Set up HTTP server with durable streaming enabled. ++ // Completed streams are kept for 10 minutes before cleanup (while server is running). ++ mux := http.NewServeMux() ++ mux.HandleFunc("POST /countdown", genkit.Handler(countdown, ++ genkit.WithStreamManager(streaming.NewInMemoryStreamManager(streaming.WithTTL(10*time.Minute))), ++ )) ++ log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) ++} +diff --git a/go/samples/basic-gemini/main.go b/go/samples/multipart-tools/main.go +similarity index 54% +rename from go/samples/basic-gemini/main.go +rename to go/samples/multipart-tools/main.go +index e61ec9df4..c9cb04bdc 100644 +--- a/go/samples/basic-gemini/main.go ++++ b/go/samples/multipart-tools/main.go +@@ -26,32 +26,37 @@ import ( + func main() { + ctx := context.Background() + +- // Initialize Genkit with the Google AI plugin. 
When you pass nil for the +- // Config parameter, the Google AI plugin will get the API key from the +- // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended +- // practice. + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + +- // Define a simple flow that generates jokes about a given topic +- genkit.DefineStreamingFlow(g, "jokesFlow", func(ctx context.Context, input string, cb ai.ModelStreamCallback) (string, error) { +- type Joke struct { +- Joke string `json:"joke"` +- Category string `json:"jokeCategory" description:"What is the joke about"` +- } +- +- genkit.DefineSchemaFor[Joke](g) ++ // Define a multipart tool. ++ // This simulates a tool that takes a screenshot ++ screenshot := genkit.DefineMultipartTool(g, "screenshot", "Takes a screenshot", ++ func(ctx *ai.ToolContext, input any) (*ai.MultipartToolResponse, error) { ++ rectangle := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAHIAAABUAQMAAABk5vEVAAAABlBMVEX///8AAABVwtN+" + ++ "AAAAI0lEQVR4nGNgGHaA/z8UHIDwOWASDqP8Uf7w56On/1FAQwAAVM0exw1hqwkAAAAASUVORK5CYII=" ++ return &ai.MultipartToolResponse{ ++ Output: map[string]any{"success": true}, ++ Content: []*ai.Part{ ++ ai.NewMediaPart("image/png", rectangle), ++ }, ++ }, nil ++ }, ++ ) + ++ // Define a simple flow that uses the multipart tool ++ genkit.DefineStreamingFlow(g, "cardFlow", func(ctx context.Context, input any, cb ai.ModelStreamCallback) (string, error) { + resp, err := genkit.Generate(ctx, g, +- ai.WithModelName("googleai/gemini-2.5-flash"), ++ ai.WithModelName("googleai/gemini-3-pro-preview"), + ai.WithConfig(&genai.GenerateContentConfig{ + Temperature: genai.Ptr[float32](1.0), + ThinkingConfig: &genai.ThinkingConfig{ +- ThinkingBudget: genai.Ptr[int32](0), ++ ThinkingLevel: genai.ThinkingLevelHigh, + }, + }), ++ ai.WithTools(screenshot), + ai.WithStreaming(cb), +- ai.WithOutputSchemaName("Joke"), +- ai.WithPrompt(`Tell short jokes about %s`, input)) ++ ai.WithPrompt("Tell me what I'm seeing in 
the screen"), ++ ) + if err != nil { + return "", err + } +diff --git a/go/samples/prompts-dir/main.go b/go/samples/prompts-dir/main.go +deleted file mode 100644 +index 59e5e8384..000000000 +--- a/go/samples/prompts-dir/main.go ++++ /dev/null +@@ -1,61 +0,0 @@ +-// Copyright 2025 Google LLC +-// SPDX-License-Identifier: Apache-2.0 +- +-// [START main] +-package main +- +-import ( +- "context" +- "errors" +- +- // Import Genkit and the Google AI plugin +- "github.com/firebase/genkit/go/ai" +- "github.com/firebase/genkit/go/genkit" +- "github.com/firebase/genkit/go/plugins/googlegenai" +-) +- +-func main() { +- ctx := context.Background() +- +- g := genkit.Init(ctx, +- genkit.WithPlugins(&googlegenai.GoogleAI{}), +- genkit.WithPromptDir("prompts"), +- ) +- +- type greetingStyle struct { +- Style string `json:"style"` +- Location string `json:"location"` +- Name string `json:"name"` +- } +- +- type greeting struct { +- Greeting string `json:"greeting"` +- } +- +- // Define a simple flow that prompts an LLM to generate greetings using a +- // given style. 
+- genkit.DefineFlow(g, "assistantGreetingFlow", func(ctx context.Context, input greetingStyle) (string, error) { +- // Look up the prompt by name +- prompt := genkit.LookupPrompt(g, "example") +- if prompt == nil { +- return "", errors.New("assistantGreetingFlow: failed to find prompt") +- } +- +- // Execute the prompt with the provided input +- resp, err := prompt.Execute(ctx, ai.WithInput(input)) +- if err != nil { +- return "", err +- } +- +- var output greeting +- if err = resp.Output(&output); err != nil { +- return "", err +- } +- +- return output.Greeting, nil +- }) +- +- <-ctx.Done() +-} +- +-// [END main] +diff --git a/go/samples/prompts-dir/prompts/example.prompt b/go/samples/prompts-dir/prompts/example.prompt +deleted file mode 100644 +index 0492cfd32..000000000 +--- a/go/samples/prompts-dir/prompts/example.prompt ++++ /dev/null +@@ -1,19 +0,0 @@ +---- +-model: googleai/gemini-2.5-flash +-config: +- temperature: 0.9 +-input: +- schema: +- location: string +- style?: string +- name?: string +- default: +- name: Rutuja +-output: +- schema: +- greeting: string +---- +- +-You are the world's most welcoming AI assistant and are currently working at {{location}}. +- +-Greet a guest{{#if name}} named {{name}}{{/if}}{{#if style}} in the style of {{style}}{{/if}}. +diff --git a/go/samples/prompts-embed/main.go b/go/samples/prompts-embed/main.go +new file mode 100644 +index 000000000..f0f7a5bde +--- /dev/null ++++ b/go/samples/prompts-embed/main.go +@@ -0,0 +1,60 @@ ++// Copyright 2025 Google LLC ++// ++// Licensed under the Apache License, Version 2.0 (the "License"); ++// you may not use this file except in compliance with the License. ++// You may obtain a copy of the License at ++// ++// http://www.apache.org/licenses/LICENSE-2.0 ++// ++// Unless required by applicable law or agreed to in writing, software ++// distributed under the License is distributed on an "AS IS" BASIS, ++// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++// See the License for the specific language governing permissions and ++// limitations under the License. ++ ++// This sample demonstrates how to use embedded prompts with genkit. ++// Prompts are embedded directly into the binary using Go's embed package, ++// which allows you to ship a self-contained binary without needing to ++// distribute prompt files separately. ++ ++package main ++ ++import ( ++ "context" ++ "embed" ++ "errors" ++ ++ "github.com/firebase/genkit/go/genkit" ++ "github.com/firebase/genkit/go/plugins/googlegenai" ++) ++ ++// Embed the prompts directory into the binary. ++// The //go:embed directive makes the prompts available at compile time. ++// ++//go:embed prompts/* ++var promptsFS embed.FS ++ ++func main() { ++ ctx := context.Background() ++ ++ g := genkit.Init(ctx, ++ genkit.WithPlugins(&googlegenai.GoogleAI{}), ++ genkit.WithPromptFS(promptsFS), ++ ) ++ ++ genkit.DefineFlow(g, "sayHello", func(ctx context.Context, name string) (string, error) { ++ prompt := genkit.LookupPrompt(g, "example") ++ if prompt == nil { ++ return "", errors.New("prompt not found") ++ } ++ ++ resp, err := prompt.Execute(ctx) ++ if err != nil { ++ return "", err ++ } ++ ++ return resp.Text(), nil ++ }) ++ ++ <-ctx.Done() ++} +diff --git a/go/samples/prompts-embed/prompts/example.prompt b/go/samples/prompts-embed/prompts/example.prompt +new file mode 100644 +index 000000000..abdd8bdee +--- /dev/null ++++ b/go/samples/prompts-embed/prompts/example.prompt +@@ -0,0 +1,5 @@ ++--- ++model: googleai/gemini-2.5-flash ++--- ++ ++Say hello!. 
+diff --git a/go/samples/prompts/main.go b/go/samples/prompts/main.go +index fe9d3bcd9..ed0af5c79 100644 +--- a/go/samples/prompts/main.go ++++ b/go/samples/prompts/main.go +@@ -209,7 +209,10 @@ func PromptWithMultiMessage(ctx context.Context, g *genkit.Genkit) { + } + resp, err := prompt.Execute(ctx, + ai.WithModelName("googleai/gemini-2.5-pro"), +- ai.WithInput(map[string]any{"videoUrl": "https://www.youtube.com/watch?v=K-hY0E6cGfo video/mp4"}), ++ ai.WithInput(map[string]any{ ++ "videoUrl": "https://www.youtube.com/watch?v=K-hY0E6cGfo", ++ "contentType": "video/mp4", ++ }), + ) + if err != nil { + log.Fatal(err) +diff --git a/go/samples/prompts/prompts/multi-msg.prompt b/go/samples/prompts/prompts/multi-msg.prompt +index 63b177a88..b86981b6e 100644 +--- a/go/samples/prompts/prompts/multi-msg.prompt ++++ b/go/samples/prompts/prompts/multi-msg.prompt +@@ -3,6 +3,7 @@ model: googleai/gemini-2.5-flash + input: + schema: + videoUrl: string ++ contentType: string + output: + summary: string + --- +@@ -13,4 +14,4 @@ You are a great AI assistant that summarizes videos talking as a pirate + {{ role "user" }} + + Give me a summary of this video +-{{media url=videoUrl}} ++{{media url=videoUrl contentType=contentType}} +diff --git a/js/plugins/anthropic/package.json b/js/plugins/anthropic/package.json +index 5b3e19be1..7bafa4b18 100644 +--- a/js/plugins/anthropic/package.json ++++ b/js/plugins/anthropic/package.json +@@ -29,7 +29,7 @@ + "genkit": "workspace:^" + }, + "dependencies": { +- "@anthropic-ai/sdk": "^0.68.0" ++ "@anthropic-ai/sdk": "^0.71.2" + }, + "devDependencies": { + "@types/node": "^20.11.16", +@@ -64,6 +64,7 @@ + "build": "npm-run-all build:clean check compile", + "build:watch": "tsup-node --watch", + "test": "tsx --test tests/*_test.ts", ++ "test:live": "tsx --test tests/live_test.ts", + "test:file": "tsx --test", + "test:live": "tsx --test tests/live_test.ts", + "test:coverage": "check-node-version --node '>=22' && tsx --test --experimental-test-coverage 
--test-coverage-include='src/**/*.ts' ./tests/**/*_test.ts" +diff --git a/js/plugins/anthropic/src/models.ts b/js/plugins/anthropic/src/models.ts +index 98767af40..ba03c1591 100644 +--- a/js/plugins/anthropic/src/models.ts ++++ b/js/plugins/anthropic/src/models.ts +@@ -91,19 +91,66 @@ export const KNOWN_CLAUDE_MODELS: Record< + 'claude-opus-4': commonRef('claude-opus-4', AnthropicThinkingConfigSchema), + 'claude-sonnet-4-5': commonRef( + 'claude-sonnet-4-5', +- AnthropicThinkingConfigSchema ++ AnthropicThinkingConfigSchema, ++ { ++ supports: { ++ multiturn: true, ++ tools: true, ++ media: true, ++ systemRole: true, ++ output: ['text', 'json'], ++ constrained: 'all', ++ }, ++ } + ), + 'claude-haiku-4-5': commonRef( + 'claude-haiku-4-5', +- AnthropicThinkingConfigSchema +- ), +- 'claude-opus-4-5': commonRef( +- 'claude-opus-4-5', +- AnthropicThinkingConfigSchema ++ AnthropicThinkingConfigSchema, ++ { ++ supports: { ++ multiturn: true, ++ tools: true, ++ media: true, ++ systemRole: true, ++ output: ['text', 'json'], ++ constrained: 'all', ++ }, ++ } + ), + 'claude-opus-4-1': commonRef( + 'claude-opus-4-1', +- AnthropicThinkingConfigSchema ++ AnthropicThinkingConfigSchema, ++ { ++ supports: { ++ multiturn: true, ++ tools: true, ++ media: true, ++ systemRole: true, ++ output: ['text', 'json'], ++ constrained: 'all', ++ }, ++ } ++ ), ++ 'claude-opus-4-5': commonRef( ++ 'claude-opus-4-5', ++ AnthropicThinkingConfigSchema.extend({ ++ output_config: z ++ .object({ ++ effort: z.enum(['low', 'medium', 'high']).optional(), ++ }) ++ .passthrough() ++ .optional(), ++ }), ++ { ++ supports: { ++ multiturn: true, ++ tools: true, ++ media: true, ++ systemRole: true, ++ output: ['text', 'json'], ++ constrained: 'all', ++ }, ++ } + ), + }; + +@@ -232,9 +279,11 @@ export function claudeModel( + defaultApiVersion: apiVersion, + } = params; + // Use supported model ref if available, otherwise create generic model ref +- const modelRef = KNOWN_CLAUDE_MODELS[name]; +- const modelInfo = 
modelRef ? modelRef.info : GENERIC_CLAUDE_MODEL_INFO; +- const configSchema = modelRef?.configSchema ?? AnthropicConfigSchema; ++ const knownModelRef = KNOWN_CLAUDE_MODELS[name]; ++ let modelInfo = knownModelRef ++ ? knownModelRef.info ++ : GENERIC_CLAUDE_MODEL_INFO; ++ const configSchema = knownModelRef?.configSchema ?? AnthropicConfigSchema; + + return model< + AnthropicBaseConfigSchemaType | AnthropicThinkingConfigSchemaType +diff --git a/js/plugins/anthropic/src/runner/beta.ts b/js/plugins/anthropic/src/runner/beta.ts +index 6a71fa71d..099a58990 100644 +--- a/js/plugins/anthropic/src/runner/beta.ts ++++ b/js/plugins/anthropic/src/runner/beta.ts +@@ -44,6 +44,7 @@ import { logger } from 'genkit/logging'; + + import { KNOWN_CLAUDE_MODELS, extractVersion } from '../models.js'; + import { AnthropicConfigSchema, type ClaudeRunnerParams } from '../types.js'; ++import { removeUndefinedProperties } from '../utils.js'; + import { BaseRunner } from './base.js'; + import { RunnerTypes } from './types.js'; + +@@ -66,6 +67,57 @@ const BETA_UNSUPPORTED_SERVER_TOOL_BLOCK_TYPES = new Set([ + 'container_upload', + ]); + ++const BETA_APIS = [ ++ // 'message-batches-2024-09-24', ++ // 'prompt-caching-2024-07-31', ++ // 'computer-use-2025-01-24', ++ // 'pdfs-2024-09-25', ++ // 'token-counting-2024-11-01', ++ // 'token-efficient-tools-2025-02-19', ++ // 'output-128k-2025-02-19', ++ 'files-api-2025-04-14', ++ // 'mcp-client-2025-04-04', ++ // 'dev-full-thinking-2025-05-14', ++ // 'interleaved-thinking-2025-05-14', ++ // 'code-execution-2025-05-22', ++ // 'extended-cache-ttl-2025-04-11', ++ // 'context-1m-2025-08-07', ++ // 'context-management-2025-06-27', ++ // 'model-context-window-exceeded-2025-08-26', ++ // 'skills-2025-10-02', ++ 'effort-2025-11-24', ++ // 'advanced-tool-use-2025-11-20', ++ 'structured-outputs-2025-11-13', ++]; ++ ++/** ++ * Transforms a JSON schema to be compatible with Anthropic's structured output requirements. 
++ * Anthropic requires `additionalProperties: false` on all object types. ++ * @see https://docs.anthropic.com/en/docs/build-with-claude/structured-outputs#json-schema-limitations ++ */ ++function toAnthropicSchema( ++ schema: Record ++): Record { ++ const out = structuredClone(schema); ++ ++ // Remove $schema if present ++ delete out.$schema; ++ ++ // Add additionalProperties: false to objects ++ if (out.type === 'object') { ++ out.additionalProperties = false; ++ } ++ ++ // Recursively process nested objects ++ for (const key in out) { ++ if (typeof out[key] === 'object' && out[key] !== null) { ++ out[key] = toAnthropicSchema(out[key] as Record); ++ } ++ } ++ ++ return out; ++} ++ + const unsupportedServerToolError = (blockType: string): string => + `Anthropic beta runner does not yet support server-managed tool block '${blockType}'. Please retry against the stable API or wait for dedicated support.`; + +@@ -140,6 +192,26 @@ export class BetaRunner extends BaseRunner { + + // Media + if (part.media) { ++ if (part.media.contentType === 'anthropic/file') { ++ return { ++ type: 'document', ++ source: { ++ type: 'file', ++ file_id: part.media.url, ++ }, ++ }; ++ } ++ ++ if (part.media.contentType === 'anthropic/image') { ++ return { ++ type: 'image', ++ source: { ++ type: 'file', ++ file_id: part.media.url, ++ }, ++ }; ++ } ++ + if (part.media.contentType === 'application/pdf') { + return { + type: 'document', +@@ -249,45 +321,49 @@ export class BetaRunner extends BaseRunner { + : system; + } + +- const body: BetaMessageCreateParamsNonStreaming = { ++ const thinkingConfig = this.toAnthropicThinkingConfig( ++ request.config?.thinking ++ ) as BetaMessageCreateParams['thinking'] | undefined; ++ ++ // Need to extract topP and topK from request.config to avoid duplicate properties being added to the body ++ // This happens because topP and topK have different property names (top_p and top_k) in the Anthropic API. 
++ // Thinking is extracted separately to avoid type issues. ++ // ApiVersion is extracted separately as it's not a valid property for the Anthropic API. ++ const { ++ topP, ++ topK, ++ apiVersion: _1, ++ thinking: _2, ++ ...restConfig ++ } = request.config ?? {}; ++ ++ const body = { + model: mappedModelName, + max_tokens: + request.config?.maxOutputTokens ?? this.DEFAULT_MAX_OUTPUT_TOKENS, + messages, +- }; +- +- if (betaSystem !== undefined) body.system = betaSystem; +- if (request.config?.stopSequences !== undefined) +- body.stop_sequences = request.config.stopSequences; +- if (request.config?.temperature !== undefined) +- body.temperature = request.config.temperature; +- if (request.config?.topK !== undefined) body.top_k = request.config.topK; +- if (request.config?.topP !== undefined) body.top_p = request.config.topP; +- if (request.config?.tool_choice !== undefined) { +- body.tool_choice = request.config +- .tool_choice as BetaMessageCreateParams['tool_choice']; +- } +- if (request.config?.metadata !== undefined) { +- body.metadata = request.config +- .metadata as BetaMessageCreateParams['metadata']; +- } +- if (request.tools) { +- body.tools = request.tools.map((tool) => this.toAnthropicTool(tool)); +- } +- const thinkingConfig = this.toAnthropicThinkingConfig( +- request.config?.thinking +- ); +- if (thinkingConfig) { +- body.thinking = thinkingConfig as BetaMessageCreateParams['thinking']; +- } +- +- if (request.output?.format && request.output.format !== 'text') { +- throw new Error( +- `Only text output format is supported for Claude models currently` +- ); +- } +- +- return body; ++ system: betaSystem, ++ stop_sequences: request.config?.stopSequences, ++ temperature: request.config?.temperature, ++ top_k: topK, ++ top_p: topP, ++ tool_choice: request.config?.tool_choice, ++ metadata: request.config?.metadata, ++ tools: request.tools?.map((tool) => this.toAnthropicTool(tool)), ++ thinking: thinkingConfig, ++ output_format: 
this.isStructuredOutputEnabled(request) ++ ? { ++ type: 'json_schema', ++ schema: toAnthropicSchema(request.output!.schema!), ++ } ++ : undefined, ++ betas: Array.isArray(request.config?.betas) ++ ? [...(request.config?.betas ?? [])] ++ : [...BETA_APIS], ++ ...restConfig, ++ } as BetaMessageCreateParamsNonStreaming; ++ ++ return removeUndefinedProperties(body); + } + + /** +@@ -316,46 +392,50 @@ export class BetaRunner extends BaseRunner { + ] + : system; + +- const body: BetaMessageCreateParamsStreaming = { ++ const thinkingConfig = this.toAnthropicThinkingConfig( ++ request.config?.thinking ++ ) as BetaMessageCreateParams['thinking'] | undefined; ++ ++ // Need to extract topP and topK from request.config to avoid duplicate properties being added to the body ++ // This happens because topP and topK have different property names (top_p and top_k) in the Anthropic API. ++ // Thinking is extracted separately to avoid type issues. ++ // ApiVersion is extracted separately as it's not a valid property for the Anthropic API. ++ const { ++ topP, ++ topK, ++ apiVersion: _1, ++ thinking: _2, ++ ...restConfig ++ } = request.config ?? {}; ++ ++ const body = { + model: mappedModelName, + max_tokens: + request.config?.maxOutputTokens ?? 
this.DEFAULT_MAX_OUTPUT_TOKENS, + messages, + stream: true, +- }; +- +- if (betaSystem !== undefined) body.system = betaSystem; +- if (request.config?.stopSequences !== undefined) +- body.stop_sequences = request.config.stopSequences; +- if (request.config?.temperature !== undefined) +- body.temperature = request.config.temperature; +- if (request.config?.topK !== undefined) body.top_k = request.config.topK; +- if (request.config?.topP !== undefined) body.top_p = request.config.topP; +- if (request.config?.tool_choice !== undefined) { +- body.tool_choice = request.config +- .tool_choice as BetaMessageCreateParams['tool_choice']; +- } +- if (request.config?.metadata !== undefined) { +- body.metadata = request.config +- .metadata as BetaMessageCreateParams['metadata']; +- } +- if (request.tools) { +- body.tools = request.tools.map((tool) => this.toAnthropicTool(tool)); +- } +- const thinkingConfig = this.toAnthropicThinkingConfig( +- request.config?.thinking +- ); +- if (thinkingConfig) { +- body.thinking = thinkingConfig as BetaMessageCreateParams['thinking']; +- } +- +- if (request.output?.format && request.output.format !== 'text') { +- throw new Error( +- `Only text output format is supported for Claude models currently` +- ); +- } +- +- return body; ++ system: betaSystem, ++ stop_sequences: request.config?.stopSequences, ++ temperature: request.config?.temperature, ++ top_k: topK, ++ top_p: topP, ++ tool_choice: request.config?.tool_choice, ++ metadata: request.config?.metadata, ++ tools: request.tools?.map((tool) => this.toAnthropicTool(tool)), ++ thinking: thinkingConfig, ++ output_format: this.isStructuredOutputEnabled(request) ++ ? { ++ type: 'json_schema', ++ schema: toAnthropicSchema(request.output!.schema!), ++ } ++ : undefined, ++ betas: Array.isArray(request.config?.betas) ++ ? [...(request.config?.betas ?? 
[])] ++ : [...BETA_APIS], ++ ...restConfig, ++ } as BetaMessageCreateParamsStreaming; ++ ++ return removeUndefinedProperties(body); + } + + protected toGenkitResponse(message: BetaMessage): GenerateResponseData { +@@ -491,4 +571,14 @@ export class BetaRunner extends BaseRunner { + return 'other'; + } + } ++ ++ private isStructuredOutputEnabled( ++ request: GenerateRequest ++ ): boolean { ++ return !!( ++ request.output?.schema && ++ request.output.constrained && ++ request.output.format === 'json' ++ ); ++ } + } +diff --git a/js/plugins/anthropic/src/runner/stable.ts b/js/plugins/anthropic/src/runner/stable.ts +index 0c8f7ffc4..1496029eb 100644 +--- a/js/plugins/anthropic/src/runner/stable.ts ++++ b/js/plugins/anthropic/src/runner/stable.ts +@@ -42,8 +42,10 @@ import { logger } from 'genkit/logging'; + + import { KNOWN_CLAUDE_MODELS, extractVersion } from '../models.js'; + import { AnthropicConfigSchema, type ClaudeRunnerParams } from '../types.js'; ++import { removeUndefinedProperties } from '../utils.js'; + import { BaseRunner } from './base.js'; + import { RunnerTypes as BaseRunnerTypes } from './types.js'; ++ + interface RunnerTypes extends BaseRunnerTypes { + Message: Message; + Stream: MessageStream; +@@ -179,6 +181,12 @@ export class Runner extends BaseRunner { + request: GenerateRequest, + cacheSystemPrompt?: boolean + ): MessageCreateParamsNonStreaming { ++ if (request.output?.format && request.output.format !== 'text') { ++ throw new Error( ++ `Only text output format is supported for Claude models currently` ++ ); ++ } ++ + const model = KNOWN_CLAUDE_MODELS[modelName]; + const { system, messages } = this.toAnthropicMessages(request.messages); + const mappedModelName = +@@ -197,51 +205,40 @@ export class Runner extends BaseRunner { + ] + : system; + ++ const thinkingConfig = this.toAnthropicThinkingConfig( ++ request.config?.thinking ++ ) as MessageCreateParams['thinking'] | undefined; ++ ++ // Need to extract topP and topK from request.config to avoid 
duplicate properties being added to the body ++ // This happens because topP and topK have different property names (top_p and top_k) in the Anthropic API. ++ // Thinking is extracted separately to avoid type issues. ++ // ApiVersion is extracted separately as it's not a valid property for the Anthropic API. ++ const { ++ topP, ++ topK, ++ apiVersion: _1, ++ thinking: _2, ++ ...restConfig ++ } = request.config ?? {}; ++ + const body: MessageCreateParamsNonStreaming = { + model: mappedModelName, + max_tokens: + request.config?.maxOutputTokens ?? this.DEFAULT_MAX_OUTPUT_TOKENS, + messages, ++ system: systemValue, ++ stop_sequences: request.config?.stopSequences, ++ temperature: request.config?.temperature, ++ top_k: topK, ++ top_p: topP, ++ tool_choice: request.config?.tool_choice, ++ metadata: request.config?.metadata, ++ tools: request.tools?.map((tool) => this.toAnthropicTool(tool)), ++ thinking: thinkingConfig, ++ ...restConfig, + }; + +- if (systemValue !== undefined) { +- body.system = systemValue; +- } +- +- if (request.tools) { +- body.tools = request.tools.map((tool) => this.toAnthropicTool(tool)); +- } +- if (request.config?.topK !== undefined) { +- body.top_k = request.config.topK; +- } +- if (request.config?.topP !== undefined) { +- body.top_p = request.config.topP; +- } +- if (request.config?.temperature !== undefined) { +- body.temperature = request.config.temperature; +- } +- if (request.config?.stopSequences !== undefined) { +- body.stop_sequences = request.config.stopSequences; +- } +- if (request.config?.metadata !== undefined) { +- body.metadata = request.config.metadata; +- } +- if (request.config?.tool_choice !== undefined) { +- body.tool_choice = request.config.tool_choice; +- } +- const thinkingConfig = this.toAnthropicThinkingConfig( +- request.config?.thinking +- ); +- if (thinkingConfig) { +- body.thinking = thinkingConfig as MessageCreateParams['thinking']; +- } +- +- if (request.output?.format && request.output.format !== 'text') { +- 
throw new Error( +- `Only text output format is supported for Claude models currently` +- ); +- } +- return body; ++ return removeUndefinedProperties(body); + } + + protected toAnthropicStreamingRequestBody( +@@ -249,6 +246,12 @@ export class Runner extends BaseRunner { + request: GenerateRequest, + cacheSystemPrompt?: boolean + ): MessageCreateParamsStreaming { ++ if (request.output?.format && request.output.format !== 'text') { ++ throw new Error( ++ `Only text output format is supported for Claude models currently` ++ ); ++ } ++ + const model = KNOWN_CLAUDE_MODELS[modelName]; + const { system, messages } = this.toAnthropicMessages(request.messages); + const mappedModelName = +@@ -267,53 +270,41 @@ export class Runner extends BaseRunner { + ] + : system; + ++ const thinkingConfig = this.toAnthropicThinkingConfig( ++ request.config?.thinking ++ ) as MessageCreateParams['thinking'] | undefined; ++ ++ // Need to extract topP and topK from request.config to avoid duplicate properties being added to the body ++ // This happens because topP and topK have different property names (top_p and top_k) in the Anthropic API. ++ // Thinking is extracted separately to avoid type issues. ++ // ApiVersion is extracted separately as it's not a valid property for the Anthropic API. ++ const { ++ topP, ++ topK, ++ apiVersion: _1, ++ thinking: _2, ++ ...restConfig ++ } = request.config ?? {}; ++ + const body: MessageCreateParamsStreaming = { + model: mappedModelName, + max_tokens: + request.config?.maxOutputTokens ?? 
this.DEFAULT_MAX_OUTPUT_TOKENS, + messages, + stream: true, ++ system: systemValue, ++ stop_sequences: request.config?.stopSequences, ++ temperature: request.config?.temperature, ++ top_k: topK, ++ top_p: topP, ++ tool_choice: request.config?.tool_choice, ++ metadata: request.config?.metadata, ++ tools: request.tools?.map((tool) => this.toAnthropicTool(tool)), ++ thinking: thinkingConfig, ++ ...restConfig, + }; + +- if (systemValue !== undefined) { +- body.system = systemValue; +- } +- +- if (request.tools) { +- body.tools = request.tools.map((tool) => this.toAnthropicTool(tool)); +- } +- if (request.config?.topK !== undefined) { +- body.top_k = request.config.topK; +- } +- if (request.config?.topP !== undefined) { +- body.top_p = request.config.topP; +- } +- if (request.config?.temperature !== undefined) { +- body.temperature = request.config.temperature; +- } +- if (request.config?.stopSequences !== undefined) { +- body.stop_sequences = request.config.stopSequences; +- } +- if (request.config?.metadata !== undefined) { +- body.metadata = request.config.metadata; +- } +- if (request.config?.tool_choice !== undefined) { +- body.tool_choice = request.config.tool_choice; +- } +- const thinkingConfig = this.toAnthropicThinkingConfig( +- request.config?.thinking +- ); +- if (thinkingConfig) { +- body.thinking = +- thinkingConfig as MessageCreateParamsStreaming['thinking']; +- } +- +- if (request.output?.format && request.output.format !== 'text') { +- throw new Error( +- `Only text output format is supported for Claude models currently` +- ); +- } +- return body; ++ return removeUndefinedProperties(body); + } + + protected async createMessage( +diff --git a/js/plugins/anthropic/src/types.ts b/js/plugins/anthropic/src/types.ts +index 7b3786730..2f61464a1 100644 +--- a/js/plugins/anthropic/src/types.ts ++++ b/js/plugins/anthropic/src/types.ts +@@ -67,26 +67,42 @@ export interface ClaudeRunnerParams extends ClaudeHelperParamsBase {} + export const 
AnthropicBaseConfigSchema = GenerationCommonConfigSchema.extend({ + tool_choice: z + .union([ +- z.object({ +- type: z.literal('auto'), +- }), +- z.object({ +- type: z.literal('any'), +- }), +- z.object({ +- type: z.literal('tool'), +- name: z.string(), +- }), ++ z ++ .object({ ++ type: z.literal('auto'), ++ }) ++ .passthrough(), ++ z ++ .object({ ++ type: z.literal('any'), ++ }) ++ .passthrough(), ++ z ++ .object({ ++ type: z.literal('tool'), ++ name: z.string(), ++ }) ++ .passthrough(), + ]) ++ .describe( ++ 'The tool choice to use for the request. This can be used to specify the tool to use for the request. If not specified, the model will choose the tool to use.' ++ ) + .optional(), + metadata: z + .object({ + user_id: z.string().optional(), + }) ++ .describe('The metadata to include in the request.') ++ .passthrough() + .optional(), + /** Optional shorthand to pick API surface for this request. */ +- apiVersion: z.enum(['stable', 'beta']).optional(), +-}); ++ apiVersion: z ++ .enum(['stable', 'beta']) ++ .optional() ++ .describe( ++ 'The API version to use for the request. Both stable and beta features are available on the beta API surface.' ++ ), ++}).passthrough(); + + export type AnthropicBaseConfigSchemaType = typeof AnthropicBaseConfigSchema; + +@@ -95,6 +111,8 @@ export const ThinkingConfigSchema = z + enabled: z.boolean().optional(), + budgetTokens: z.number().min(1_024).optional(), + }) ++ .passthrough() ++ .passthrough() + .superRefine((value, ctx) => { + if (!value.enabled) return; + +@@ -117,8 +135,10 @@ export const ThinkingConfigSchema = z + }); + + export const AnthropicThinkingConfigSchema = AnthropicBaseConfigSchema.extend({ +- thinking: ThinkingConfigSchema.optional(), +-}); ++ thinking: ThinkingConfigSchema.optional().describe( ++ 'The thinking configuration to use for the request. Thinking is a feature that allows the model to think about the request and provide a better response.' 
++ ), ++}).passthrough(); + + export const AnthropicConfigSchema = AnthropicThinkingConfigSchema; + +diff --git a/js/plugins/anthropic/src/utils.ts b/js/plugins/anthropic/src/utils.ts +new file mode 100644 +index 000000000..6678eabc1 +--- /dev/null ++++ b/js/plugins/anthropic/src/utils.ts +@@ -0,0 +1,25 @@ ++/** ++ * Copyright 2025 Google LLC ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++export function removeUndefinedProperties(obj: T): T { ++ if (typeof obj !== 'object' || obj === null) { ++ return obj; ++ } ++ ++ return Object.fromEntries( ++ Object.entries(obj).filter(([_, value]) => value !== undefined) ++ ) as T; ++} +diff --git a/js/plugins/anthropic/tests/effort_param_test.ts b/js/plugins/anthropic/tests/effort_param_test.ts +new file mode 100644 +index 000000000..12e67f4ed +--- /dev/null ++++ b/js/plugins/anthropic/tests/effort_param_test.ts +@@ -0,0 +1,249 @@ ++/** ++ * Copyright 2025 Google LLC ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++import type Anthropic from '@anthropic-ai/sdk'; ++import * as assert from 'assert'; ++import { genkit } from 'genkit'; ++import { describe, test } from 'node:test'; ++import { anthropic } from '../src/index.js'; ++import { __testClient } from '../src/types.js'; ++import { ++ createMockAnthropicClient, ++ createMockAnthropicMessage, ++ mockTextChunk, ++} from './mocks/anthropic-client.js'; ++ ++/** ++ * Options for creating a plugin with a mock client ++ */ ++interface CreatePluginOptions { ++ apiVersion?: 'beta' | 'stable'; ++ mockClient: Anthropic; ++} ++ ++/** ++ * Creates an Anthropic plugin configured with a mock client for testing ++ */ ++function createPlugin(options: CreatePluginOptions) { ++ return anthropic({ ++ apiVersion: options.apiVersion, ++ // @ts-ignore ++ [__testClient]: options.mockClient, ++ }); ++} ++ ++/** ++ * Creates a Genkit instance with the given plugin ++ */ ++function createGenkitInstance(plugin: ReturnType) { ++ return genkit({ ++ plugins: [plugin], ++ }); ++} ++ ++/** ++ * Helper to get the proper create stub from the mock client for a given API version. ++ */ ++function getCreateStub(mockClient: Anthropic, apiVersion: 'beta' | 'stable') { ++ return apiVersion === 'beta' ++ ? 
(mockClient.beta.messages.create as any) ++ : (mockClient.messages.create as any); ++} ++ ++/** ++ * Extracts the API request object from the mock for verification ++ * @param apiVersion - 'beta' or 'stable' to determine which API endpoint to check ++ */ ++function getApiRequest( ++ mockClient: Anthropic, ++ apiVersion: 'beta' | 'stable', ++ callIndex: number = 0 ++) { ++ const stub = getCreateStub(mockClient, apiVersion); ++ return stub.mock.calls[callIndex]?.arguments[0]; ++} ++ ++/** ++ * Verifies that the API was called the expected number of times ++ * @param apiVersion - 'beta' or 'stable' to determine which API endpoint to verify ++ */ ++function verifyApiCalled( ++ mockClient: Anthropic, ++ apiVersion: 'beta' | 'stable', ++ expectedCalls: number = 1 ++) { ++ const stub = getCreateStub(mockClient, apiVersion); ++ assert.strictEqual( ++ stub.mock.calls.length, ++ expectedCalls, ++ `${apiVersion === 'beta' ? 'Beta' : 'Stable'} API should be called ${expectedCalls} time(s)` ++ ); ++} ++ ++/** ++ * Tests for effort parameter functionality. ++ * These tests verify that output_config.effort is correctly passed to the Anthropic API ++ * when using the beta API with claude-opus-4-5. 
++ */ ++describe('Effort Parameter Tests', () => { ++ const OPUS_4_5_MODEL = 'anthropic/claude-opus-4-5'; ++ ++ test('should pass output_config.effort to API when using beta API with claude-opus-4-5', async () => { ++ const mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: 'Response with high effort', ++ }), ++ }); ++ ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ await ai.generate({ ++ model: OPUS_4_5_MODEL, ++ prompt: 'Generate a detailed response', ++ config: { ++ output_config: { ++ effort: 'high', ++ }, ++ }, ++ }); ++ ++ verifyApiCalled(mockClient, 'beta'); ++ const apiRequest = getApiRequest(mockClient, 'beta'); ++ ++ assert.ok(apiRequest.output_config, 'Request should have output_config'); ++ assert.strictEqual( ++ apiRequest.output_config.effort, ++ 'high', ++ 'effort should be set to high' ++ ); ++ }); ++ ++ test('should pass output_config.effort with low value', async () => { ++ const mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: 'Response with low effort', ++ }), ++ }); ++ ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ await ai.generate({ ++ model: OPUS_4_5_MODEL, ++ prompt: 'Generate a quick response', ++ config: { ++ output_config: { ++ effort: 'low', ++ }, ++ }, ++ }); ++ ++ verifyApiCalled(mockClient, 'beta'); ++ const apiRequest = getApiRequest(mockClient, 'beta'); ++ ++ assert.ok(apiRequest.output_config, 'Request should have output_config'); ++ assert.strictEqual( ++ apiRequest.output_config.effort, ++ 'low', ++ 'effort should be set to low' ++ ); ++ }); ++ ++ test('should pass output_config.effort with medium value', async () => { ++ const mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: 'Response with medium effort', ++ }), ++ }); 
++ ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ await ai.generate({ ++ model: OPUS_4_5_MODEL, ++ prompt: 'Generate a balanced response', ++ config: { ++ output_config: { ++ effort: 'medium', ++ }, ++ }, ++ }); ++ ++ verifyApiCalled(mockClient, 'beta'); ++ const apiRequest = getApiRequest(mockClient, 'beta'); ++ ++ assert.ok(apiRequest.output_config, 'Request should have output_config'); ++ assert.strictEqual( ++ apiRequest.output_config.effort, ++ 'medium', ++ 'effort should be set to medium' ++ ); ++ }); ++ ++ test('should pass output_config.effort in streaming requests', async () => { ++ const mockClient = createMockAnthropicClient({ ++ streamChunks: [mockTextChunk('Streaming response')], ++ messageResponse: createMockAnthropicMessage({ ++ text: 'Streaming response', ++ }), ++ }); ++ ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ await ai.generate({ ++ model: OPUS_4_5_MODEL, ++ prompt: 'Generate a streaming response', ++ config: { ++ output_config: { ++ effort: 'high', ++ }, ++ }, ++ streamingCallback: () => {}, ++ }); ++ ++ const betaStreamStub = mockClient.beta.messages.stream as any; ++ assert.strictEqual(betaStreamStub.mock.calls.length, 1); ++ const requestBody = betaStreamStub.mock.calls[0]?.arguments[0]; ++ ++ assert.ok( ++ requestBody.output_config, ++ 'Streaming request should have output_config' ++ ); ++ assert.strictEqual( ++ requestBody.output_config.effort, ++ 'high', ++ 'effort should be set to high in streaming request' ++ ); ++ }); ++}); +diff --git a/js/plugins/anthropic/tests/execution_test.ts b/js/plugins/anthropic/tests/execution_test.ts +index 069d2d2dc..ae7b6a85e 100644 +--- a/js/plugins/anthropic/tests/execution_test.ts ++++ b/js/plugins/anthropic/tests/execution_test.ts +@@ -14,11 +14,12 @@ + * limitations under the License. 
+ */ + +-import type { GenerateRequest, ModelAction } from '@genkit-ai/ai/model'; + import * as assert from 'assert'; ++import type { GenerateRequest } from 'genkit'; ++import type { ModelAction } from 'genkit/model'; + import { describe, mock, test } from 'node:test'; + import { anthropic } from '../src/index.js'; +-import { __testClient } from '../src/types.js'; ++import { PluginOptions, __testClient } from '../src/types.js'; + import { + createMockAnthropicClient, + createMockAnthropicMessage, +@@ -35,11 +36,17 @@ describe('Model Execution Integration Tests', () => { + const plugin = anthropic({ + apiKey: 'test-key', + [__testClient]: mockClient, +- }); ++ } as PluginOptions); ++ ++ // Verify plugin has resolve method ++ assert.ok(plugin.resolve, 'Plugin should have resolve method'); + + // Resolve the model action via plugin +- const modelAction = plugin.resolve('model', 'claude-3-5-haiku-20241022'); +- assert.ok(modelAction, 'Model should be resolved'); ++ const modelAction = plugin.resolve( ++ 'model', ++ 'claude-3-5-haiku-20241022' ++ ) as ModelAction; ++ + assert.strictEqual( + (modelAction as ModelAction).__action.name, + 'anthropic/claude-3-5-haiku-20241022' +@@ -55,11 +62,11 @@ describe('Model Execution Integration Tests', () => { + ], + }; + +- const response = await (modelAction as ModelAction)(request, { ++ const response = await modelAction(request, { + streamingRequested: false, + sendChunk: mock.fn(), + abortSignal: new AbortController().signal, +- }); ++ } as Parameters[1]); + + assert.ok(response, 'Response should be returned'); + assert.ok(response.candidates, 'Response should have candidates'); +@@ -86,7 +93,10 @@ describe('Model Execution Integration Tests', () => { + const plugin = anthropic({ + apiKey: 'test-key', + [__testClient]: mockClient, +- }); ++ } as PluginOptions); ++ ++ // Verify plugin has resolve method ++ assert.ok(plugin.resolve, 'Plugin should have resolve method'); + + const modelAction = plugin.resolve( + 'model', +@@ -114,9 
+124,10 @@ describe('Model Execution Integration Tests', () => { + streamingRequested: false, + sendChunk: mock.fn(), + abortSignal: new AbortController().signal, +- }); ++ } as Parameters[1]); + + assert.ok(response, 'Response should be returned'); ++ assert.ok(response.candidates, 'Response should have candidates'); + assert.strictEqual( + response.candidates[0].message.content[0].text, + 'The capital of France is Paris.' +@@ -139,7 +150,10 @@ describe('Model Execution Integration Tests', () => { + const plugin = anthropic({ + apiKey: 'test-key', + [__testClient]: mockClient, +- }); ++ } as PluginOptions); ++ ++ // Verify plugin has resolve method ++ assert.ok(plugin.resolve, 'Plugin should have resolve method'); + + const modelAction = plugin.resolve( + 'model', +@@ -163,7 +177,7 @@ describe('Model Execution Integration Tests', () => { + streamingRequested: false, + sendChunk: mock.fn(), + abortSignal: new AbortController().signal, +- }); ++ } as Parameters[1]); + + assert.ok(response, 'Response should be returned'); + +@@ -197,7 +211,10 @@ describe('Model Execution Integration Tests', () => { + const plugin = anthropic({ + apiKey: 'test-key', + [__testClient]: mockClient, +- }); ++ } as PluginOptions); ++ ++ // Verify plugin has resolve method ++ assert.ok(plugin.resolve, 'Plugin should have resolve method'); + + const modelAction = plugin.resolve( + 'model', +@@ -212,7 +229,7 @@ describe('Model Execution Integration Tests', () => { + streamingRequested: false, + sendChunk: mock.fn(), + abortSignal: new AbortController().signal, +- } ++ } as Parameters[1] + ); + + assert.ok(response.usage, 'Usage should be returned'); +@@ -231,7 +248,10 @@ describe('Model Execution Integration Tests', () => { + const plugin = anthropic({ + apiKey: 'test-key', + [__testClient]: mockClient, +- }); ++ } as PluginOptions); ++ ++ // Verify plugin has resolve method ++ assert.ok(plugin.resolve, 'Plugin should have resolve method'); + + const modelAction = plugin.resolve( + 'model', 
+@@ -246,10 +266,11 @@ describe('Model Execution Integration Tests', () => { + streamingRequested: false, + sendChunk: mock.fn(), + abortSignal: new AbortController().signal, +- } ++ } as Parameters[1] + ); + + assert.ok(response, 'Response should be returned'); ++ assert.ok(response.candidates, 'Response should have candidates'); + assert.strictEqual(response.candidates[0].finishReason, 'length'); + }); + +@@ -263,7 +284,10 @@ describe('Model Execution Integration Tests', () => { + const plugin = anthropic({ + apiKey: 'test-key', + [__testClient]: mockClient, +- }); ++ } as PluginOptions); ++ ++ // Verify plugin has resolve method ++ assert.ok(plugin.resolve, 'Plugin should have resolve method'); + + // Resolve without prefix + const modelAction = plugin.resolve( +@@ -280,7 +304,7 @@ describe('Model Execution Integration Tests', () => { + streamingRequested: false, + sendChunk: mock.fn(), + abortSignal: new AbortController().signal, +- } ++ } as Parameters[1] + ); + + assert.ok(response, 'Response should be returned'); +@@ -296,7 +320,10 @@ describe('Model Execution Integration Tests', () => { + const plugin = anthropic({ + apiKey: 'test-key', + [__testClient]: mockClient, +- }); ++ } as PluginOptions); ++ ++ // Verify plugin has resolve method ++ assert.ok(plugin.resolve, 'Plugin should have resolve method'); + + // Resolve with prefix + const modelAction = plugin.resolve( +@@ -313,7 +340,7 @@ describe('Model Execution Integration Tests', () => { + streamingRequested: false, + sendChunk: mock.fn(), + abortSignal: new AbortController().signal, +- } ++ } as Parameters[1] + ); + + assert.ok(response, 'Response should be returned'); +@@ -329,7 +356,10 @@ describe('Model Execution Integration Tests', () => { + const plugin = anthropic({ + apiKey: 'test-key', + [__testClient]: mockClient, +- }); ++ } as PluginOptions); ++ ++ // Verify plugin has resolve method ++ assert.ok(plugin.resolve, 'Plugin should have resolve method'); + + // Resolve unknown model (passes 
through to API) + const modelAction = plugin.resolve( +@@ -346,13 +376,15 @@ describe('Model Execution Integration Tests', () => { + streamingRequested: false, + sendChunk: mock.fn(), + abortSignal: new AbortController().signal, +- } ++ } as Parameters[1] + ); + + assert.ok(response, 'Response should be returned for unknown model'); ++ assert.ok(response.candidates, 'Response should have candidates'); + assert.strictEqual( +- response.candidates[0].message.content[0].text, +- 'Response from future model' ++ response.candidates?.[0]?.message.content[0].text, ++ 'Response from future model', ++ 'Response should have candidates' + ); + }); + }); +diff --git a/js/plugins/anthropic/tests/live_test.ts b/js/plugins/anthropic/tests/live_test.ts +index f008157af..0c370196d 100644 +--- a/js/plugins/anthropic/tests/live_test.ts ++++ b/js/plugins/anthropic/tests/live_test.ts +@@ -22,7 +22,7 @@ + */ + + import * as assert from 'assert'; +-import { genkit } from 'genkit'; ++import { genkit, z } from 'genkit'; + import { describe, it } from 'node:test'; + import { anthropic } from '../src/index.js'; + +@@ -80,4 +80,50 @@ describe('Live Anthropic API Tests', { skip: !API_KEY }, () => { + + assert.ok(result.text.toLowerCase().includes('hello')); + }); ++ ++ it('should return structured output matching the schema', async () => { ++ const ai = genkit({ ++ plugins: [anthropic({ apiKey: API_KEY, apiVersion: 'beta' })], ++ }); ++ ++ const schema = z.object({ ++ name: z.string(), ++ age: z.number(), ++ city: z.string(), ++ isStudent: z.boolean(), ++ isEmployee: z.boolean(), ++ isRetired: z.boolean(), ++ isUnemployed: z.boolean(), ++ isDisabled: z.boolean(), ++ }); ++ ++ const result = await ai.generate({ ++ model: 'anthropic/claude-sonnet-4-5', ++ prompt: ++ 'Generate a fictional person with name "Alice", age 30, and city "New York". 
Return only the JSON.', ++ output: { schema, format: 'json', constrained: true }, ++ }); ++ ++ const parsed = result.output; ++ assert.ok(parsed, 'Should have parsed output'); ++ assert.deepStrictEqual( ++ { name: parsed.name, age: parsed.age, city: parsed.city }, ++ { name: 'Alice', age: 30, city: 'New York' } ++ ); ++ ++ // Check that boolean fields are present and are actually booleans ++ for (const key of [ ++ 'isStudent', ++ 'isEmployee', ++ 'isRetired', ++ 'isUnemployed', ++ 'isDisabled', ++ ]) { ++ assert.strictEqual( ++ typeof parsed[key], ++ 'boolean', ++ `Field ${key} should be a boolean but got: ${typeof parsed[key]}` ++ ); ++ } ++ }); + }); +diff --git a/js/plugins/anthropic/tests/mocks/anthropic-client.ts b/js/plugins/anthropic/tests/mocks/anthropic-client.ts +index 7fe29eceb..09be1b749 100644 +--- a/js/plugins/anthropic/tests/mocks/anthropic-client.ts ++++ b/js/plugins/anthropic/tests/mocks/anthropic-client.ts +@@ -379,7 +379,7 @@ function toBetaMessage(message: Message): BetaMessage { + server_tool_use: message.usage.server_tool_use as any, + service_tier: message.usage.service_tier, + }, +- }; ++ } as BetaMessage; + } + + function toBetaStreamEvent( +diff --git a/js/plugins/anthropic/tests/structured_output_test.ts b/js/plugins/anthropic/tests/structured_output_test.ts +new file mode 100644 +index 000000000..9a1e31fcd +--- /dev/null ++++ b/js/plugins/anthropic/tests/structured_output_test.ts +@@ -0,0 +1,358 @@ ++/** ++ * Copyright 2025 Google LLC ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++import type Anthropic from '@anthropic-ai/sdk'; ++import * as assert from 'assert'; ++import { genkit, z } from 'genkit'; ++import { describe, test } from 'node:test'; ++import { anthropic } from '../src/index.js'; ++import { __testClient } from '../src/types.js'; ++import { ++ createMockAnthropicClient, ++ createMockAnthropicMessage, ++} from './mocks/anthropic-client.js'; ++ ++/** ++ * Test constants for consistent test setup ++ */ ++const TEST_API_KEY = 'test-key'; ++const SUPPORTING_MODEL = 'anthropic/claude-sonnet-4-5'; ++const NON_SUPPORTING_MODEL = 'anthropic/claude-sonnet-4'; ++ ++/** ++ * Options for creating a plugin with a mock client ++ */ ++interface CreatePluginOptions { ++ apiVersion?: 'beta' | 'stable'; ++ mockClient: Anthropic; ++} ++ ++/** ++ * Creates an Anthropic plugin configured with a mock client for testing ++ */ ++function createPlugin(options: CreatePluginOptions) { ++ return anthropic({ ++ apiKey: TEST_API_KEY, ++ apiVersion: options.apiVersion, ++ // @ts-ignore ++ [__testClient]: options.mockClient, ++ }); ++} ++ ++/** ++ * Creates a Genkit instance with the given plugin ++ */ ++function createGenkitInstance(plugin: ReturnType) { ++ return genkit({ ++ plugins: [plugin], ++ }); ++} ++ ++/** ++ * Helper to get the proper create stub from the mock client for a given API version. ++ */ ++function getCreateStub(mockClient: Anthropic, apiVersion: 'beta' | 'stable') { ++ return apiVersion === 'beta' ++ ? 
(mockClient.beta.messages.create as any) ++ : (mockClient.messages.create as any); ++} ++ ++/** ++ * Extracts the API request object from the mock for verification ++ * @param apiVersion - 'beta' or 'stable' to determine which API endpoint to check ++ */ ++function getApiRequest( ++ mockClient: Anthropic, ++ apiVersion: 'beta' | 'stable', ++ callIndex: number = 0 ++) { ++ const stub = getCreateStub(mockClient, apiVersion); ++ return stub.mock.calls[callIndex]?.arguments[0]; ++} ++ ++/** ++ * Verifies that the API was called the expected number of times ++ * @param apiVersion - 'beta' or 'stable' to determine which API endpoint to verify ++ */ ++function verifyApiCalled( ++ mockClient: Anthropic, ++ apiVersion: 'beta' | 'stable', ++ expectedCalls: number = 1 ++) { ++ const stub = getCreateStub(mockClient, apiVersion); ++ assert.strictEqual( ++ stub.mock.calls.length, ++ expectedCalls, ++ `${apiVersion === 'beta' ? 'Beta' : 'Stable'} API should be called ${expectedCalls} time(s)` ++ ); ++} ++ ++/** ++ * Tests for structured output (constrained generation) functionality. ++ * These tests verify that output_format is correctly passed to the Anthropic API ++ * when using the beta API with constrained output, and that it's NOT passed ++ * in various edge cases (stable API, non-json format, missing schema, etc.) 
++ */ ++describe('Structured Output Tests', () => { ++ test('should pass output_format to API when using beta API with constrained output', async () => { ++ const mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: '{"name":"Alice","age":30,"city":"New York","isStudent":false,"isEmployee":true,"isRetired":false,"isUnemployed":false,"isDisabled":false}', ++ }), ++ }); ++ ++ // Set up plugin with beta API enabled ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ // Call generate with sonnet 4.5 (supports native constrained output) ++ await ai.generate({ ++ model: SUPPORTING_MODEL, ++ prompt: ++ 'Generate a fictional person with name "Alice", age 30, and city "New York". Return only the JSON.', ++ output: { ++ schema: z.object({ ++ name: z.string(), ++ age: z.number(), ++ city: z.string(), ++ isStudent: z.boolean(), ++ isEmployee: z.boolean(), ++ isRetired: z.boolean(), ++ isUnemployed: z.boolean(), ++ isDisabled: z.boolean(), ++ }), ++ format: 'json', ++ constrained: true, ++ }, ++ }); ++ ++ // Verify the beta API was called ++ verifyApiCalled(mockClient, 'beta'); ++ ++ // Verify output_format was included in the API request ++ const apiRequest = getApiRequest(mockClient, 'beta'); ++ assert.ok(apiRequest.output_format, 'Request should have output_format'); ++ assert.strictEqual( ++ apiRequest.output_format.type, ++ 'json_schema', ++ 'output_format type should be json_schema' ++ ); ++ assert.ok( ++ apiRequest.output_format.schema, ++ 'output_format should have schema' ++ ); ++ // Verify schema transformation: additionalProperties should be false for constrained output ++ assert.strictEqual( ++ apiRequest.output_format.schema.additionalProperties, ++ false, ++ 'Schema should have additionalProperties: false' ++ ); ++ }); ++ ++ test('should NOT pass output_format to API when constrained is false and using beta API', async () => { ++ const 
mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: '{"name":"Alice"}', ++ }), ++ }); ++ ++ // Set up plugin with beta API enabled ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ // Call generate with constrained: false ++ await ai.generate({ ++ model: SUPPORTING_MODEL, ++ prompt: 'Generate JSON', ++ output: { ++ format: 'json', ++ constrained: false, ++ schema: z.object({ ++ name: z.string(), ++ }), ++ }, ++ }); ++ ++ // Verify the beta API was called ++ verifyApiCalled(mockClient, 'beta'); ++ ++ // Verify output_format was NOT included when constrained is false ++ const apiRequest = getApiRequest(mockClient, 'beta'); ++ assert.strictEqual( ++ apiRequest.output_format, ++ undefined, ++ 'Request should NOT have output_format when constrained is false' ++ ); ++ }); ++ ++ test('should NOT pass output_format to API when format is not json and using beta API', async () => { ++ const mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: 'Some text response', ++ }), ++ }); ++ ++ // Set up plugin with beta API enabled ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ // Call generate with format: 'text' (not 'json') ++ await ai.generate({ ++ model: SUPPORTING_MODEL, ++ prompt: 'Generate text', ++ output: { ++ format: 'text', ++ constrained: true, ++ }, ++ }); ++ ++ // Verify the beta API was called ++ verifyApiCalled(mockClient, 'beta'); ++ ++ // Verify output_format was NOT included when format is not json ++ const apiRequest = getApiRequest(mockClient, 'beta'); ++ assert.strictEqual( ++ apiRequest.output_format, ++ undefined, ++ 'Request should NOT have output_format when format is text' ++ ); ++ }); ++ ++ test('should NOT pass output_format to API when schema is not provided and using beta API', async () => { ++ const 
mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: '{"anything": "goes"}', ++ }), ++ }); ++ ++ // Set up plugin with beta API enabled ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ // Call generate with constrained: true but no schema ++ await ai.generate({ ++ model: SUPPORTING_MODEL, ++ prompt: 'Generate JSON', ++ output: { ++ format: 'json', ++ constrained: true, ++ // No schema provided ++ }, ++ }); ++ ++ // Verify the beta API was called ++ verifyApiCalled(mockClient, 'beta'); ++ ++ // Verify output_format was NOT included when schema is missing ++ const apiRequest = getApiRequest(mockClient, 'beta'); ++ assert.strictEqual( ++ apiRequest.output_format, ++ undefined, ++ 'Request should NOT have output_format when schema is not provided' ++ ); ++ }); ++ ++ test('should NOT pass output_format to API when model does not support structured output and using beta API', async () => { ++ const mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: '{"name":"Alice"}', ++ }), ++ }); ++ ++ // Set up plugin with beta API enabled ++ const plugin = createPlugin({ ++ apiVersion: 'beta', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ // Call generate with model that does not support structured output ++ await ai.generate({ ++ model: NON_SUPPORTING_MODEL, ++ prompt: 'Generate JSON', ++ output: { ++ format: 'json', ++ constrained: true, ++ }, ++ }); ++ ++ // Verify the beta API was called ++ verifyApiCalled(mockClient, 'beta'); ++ ++ // Verify output_format was NOT included when model does not support structured output ++ const apiRequest = getApiRequest(mockClient, 'beta'); ++ assert.strictEqual( ++ apiRequest.output_format, ++ undefined, ++ 'Request should NOT have output_format when model does not support structured output' ++ ); ++ }); ++ ++ test('should throw an error when 
using stable API with non-text output format', async () => { ++ const mockClient = createMockAnthropicClient({ ++ messageResponse: createMockAnthropicMessage({ ++ text: '{"name":"Alice","age":30,"city":"New York"}', ++ }), ++ }); ++ ++ // Set up plugin with stable API (not beta) ++ const plugin = createPlugin({ ++ apiVersion: 'stable', ++ mockClient, ++ }); ++ ++ const ai = createGenkitInstance(plugin); ++ ++ // Call generate with constrained output (would work with beta API) ++ // Expect an error to be thrown since only text output is supported for stable API ++ await assert.rejects( ++ async () => { ++ await ai.generate({ ++ model: SUPPORTING_MODEL, ++ prompt: 'Generate JSON', ++ output: { ++ format: 'json', ++ constrained: true, ++ schema: z.object({ ++ name: z.string(), ++ age: z.number(), ++ city: z.string(), ++ }), ++ }, ++ }); ++ }, ++ /Only text output format is supported for Claude models currently/, ++ 'Should throw an error for non-text output on stable API' ++ ); ++ }); ++}); +diff --git a/js/plugins/google-genai/src/googleai/gemini.ts b/js/plugins/google-genai/src/googleai/gemini.ts +index 21873f1ac..cf8eb0811 100644 +--- a/js/plugins/google-genai/src/googleai/gemini.ts ++++ b/js/plugins/google-genai/src/googleai/gemini.ts +@@ -269,7 +269,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ + ) + .optional(), + thinkingLevel: z +- .enum(['LOW', 'MEDIUM', 'HIGH']) ++ .enum(['MINIMAL', 'LOW', 'MEDIUM', 'HIGH']) + .describe( + 'For Gemini 3.0 - Indicates the thinking level. 
A higher level ' + + 'is associated with more detailed thinking, which is needed for solving ' + +@@ -419,6 +419,7 @@ const GENERIC_GEMMA_MODEL = commonRef( + ); + + const KNOWN_GEMINI_MODELS = { ++ 'gemini-3-flash-preview': commonRef('gemini-3-flash-preview'), + 'gemini-3-pro-preview': commonRef('gemini-3-pro-preview'), + 'gemini-2.5-pro': commonRef('gemini-2.5-pro'), + 'gemini-2.5-flash': commonRef('gemini-2.5-flash'), +diff --git a/js/plugins/google-genai/src/vertexai/gemini.ts b/js/plugins/google-genai/src/vertexai/gemini.ts +index 5f32120d8..1715bd973 100644 +--- a/js/plugins/google-genai/src/vertexai/gemini.ts ++++ b/js/plugins/google-genai/src/vertexai/gemini.ts +@@ -318,7 +318,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ + ) + .optional(), + thinkingLevel: z +- .enum(['LOW', 'MEDIUM', 'HIGH']) ++ .enum(['MINIMAL', 'LOW', 'MEDIUM', 'HIGH']) + .describe( + 'For Gemini 3.0 - Indicates the thinking level. A higher level ' + + 'is associated with more detailed thinking, which is needed for solving ' + +@@ -422,6 +422,7 @@ const GENERIC_IMAGE_MODEL = commonRef( + ); + + export const KNOWN_GEMINI_MODELS = { ++ 'gemini-3-flash-preview': commonRef('gemini-3-flash-preview'), + 'gemini-3-pro-preview': commonRef('gemini-3-pro-preview'), + 'gemini-2.5-flash-lite': commonRef('gemini-2.5-flash-lite'), + 'gemini-2.5-pro': commonRef('gemini-2.5-pro'), +diff --git a/js/plugins/google-genai/tests/model-tests-tts.yaml b/js/plugins/google-genai/tests/model-tests-tts.yaml +new file mode 100644 +index 000000000..726e2736d +--- /dev/null ++++ b/js/plugins/google-genai/tests/model-tests-tts.yaml +@@ -0,0 +1,90 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++- model: googleai/imagen-4.0-generate-001 ++ supports: ++ - output-image ++- model: googleai/gemini-2.5-flash-preview-tts ++ tests: ++ - name: TTS Test ++ input: ++ messages: ++ - role: user ++ content: ++ - text: 'Hello world' ++ config: ++ responseModalities: ['AUDIO'] ++ validators: ++ - valid-media:audio ++- model: googleai/gemini-2.5-pro ++ supports: ++ - tool-request ++ - structured-output ++ - multiturn ++ - system-role ++ - input-image-base64 ++ - input-image-url ++ - input-video-youtube ++- model: googleai/gemini-3-pro-preview ++ supports: ++ - tool-request ++ - structured-output ++ - multiturn ++ - system-role ++ - input-image-base64 ++ - input-image-url ++ - input-video-youtube ++ tests: ++ - name: Tool Response Conformance ++ input: ++ messages: ++ - role: user ++ content: ++ - text: 'What is the weather in New York? Use the tool.' 
++ - role: model ++ content: ++ - toolRequest: ++ name: weather ++ input: ++ city: New York ++ metadata: ++ thoughtSignature: CvABAXLI2nxTZfKU3MkzLiGBrX62oq77vN2kHjT8pwwXRjtzbCqC07pPhIZ31sS+2kUFDh/kUY4SOvZzjjtP8UxI5GSFRWlX8yVDrDFo17RN/urwc1QuaMMzy66eQubpPRDEwfi6S5IKxZq0kRX6cSceB4NVCQAAAU8sYJwqWFL9CIaGac4lzF+34VvMWFLqdb40oe7/gw/KK1fqAeqDs+FJLksA+Q5qpHn3BETcqT0AuFe01IB2EVA7Us+/N3VGonw61F5cFNjHXO1jIYDybl3MXR9M5T5QB1a3EyicYXSX5/+bCmny1ka4kInbtzEqMMuv ++ - role: tool ++ content: ++ - toolResponse: ++ name: weather ++ output: '21C' ++ tools: ++ - name: weather ++ description: Get the weather for a city ++ inputSchema: ++ type: object ++ properties: ++ city: ++ type: string ++ required: ++ - city ++ validators: ++ - text-includes:21 ++- model: googleai/gemini-2.5-flash ++ supports: ++ - tool-request ++ - structured-output ++ - multiturn ++ - system-role ++ - input-image-base64 ++ - input-image-url ++ - input-video-youtube +diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml +index 454817f40..7f3123818 100644 +--- a/js/pnpm-lock.yaml ++++ b/js/pnpm-lock.yaml +@@ -171,7 +171,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -181,7 +181,7 @@ importers: + optionalDependencies: + '@genkit-ai/firebase': + specifier: ^1.16.1 +- version: 1.16.1(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) ++ version: 
1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) + + doc-snippets: + dependencies: +@@ -199,7 +199,7 @@ importers: + version: 5.0.0 + firebase-functions: + specifier: ^6.3.1 +- version: 6.3.2(firebase-admin@13.5.0(encoding@0.1.13)) ++ version: 6.3.2(firebase-admin@13.6.0(encoding@0.1.13)) + genkit: + specifier: workspace:* + version: link:../genkit +@@ -249,7 +249,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -260,8 +260,8 @@ importers: + plugins/anthropic: + dependencies: + '@anthropic-ai/sdk': +- specifier: ^0.68.0 +- version: 0.68.0(zod@3.25.76) ++ specifier: ^0.71.2 ++ version: 0.71.2(zod@3.25.67) + devDependencies: + '@types/node': + specifier: ^20.11.16 +@@ -280,7 +280,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -314,7 +314,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.0.2 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.7.0 + version: 4.20.3 +@@ -326,7 +326,7 @@ importers: + dependencies: + chromadb: + specifier: 1.8.1 +- version: 1.8.1(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76)) ++ version: 1.8.1(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)) + genkit: + specifier: workspace:^ + version: link:../../genkit +@@ -345,7 +345,7 
@@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -403,7 +403,7 @@ importers: + version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)))(typescript@4.9.5) + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -418,7 +418,7 @@ importers: + version: link:../../genkit + openai: + specifier: ^4.95.0 +- version: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76) ++ version: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) + devDependencies: + '@jest/globals': + specifier: ^29.7.0 +@@ -437,7 +437,7 @@ importers: + version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)))(typescript@5.8.3) + tsup: + specifier: ^8.0.2 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.0) + typescript: + specifier: ^5.4.5 + version: 5.8.3 +@@ -465,7 +465,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -505,7 +505,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 
8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -551,7 +551,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -569,7 +569,7 @@ importers: + version: 7.11.1(encoding@0.1.13) + firebase-admin: + specifier: '>=12.2' +- version: 13.5.0(encoding@0.1.13) ++ version: 13.4.0(encoding@0.1.13) + devDependencies: + '@jest/globals': + specifier: ^29.7.0 +@@ -603,7 +603,7 @@ importers: + version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)))(typescript@4.9.5) + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -691,7 +691,7 @@ importers: + version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)))(typescript@4.9.5) + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -731,7 +731,7 @@ importers: + version: 21.0.0 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -765,7 +765,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 
8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -789,7 +789,7 @@ importers: + version: link:../../genkit + langchain: + specifier: ^0.1.36 +- version: 0.1.37(@google-cloud/storage@7.18.0(encoding@0.1.13))(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(fast-xml-parser@4.5.3)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(lodash@4.17.21)(pdf-parse@1.1.1)(pg@8.16.2)(ws@8.18.3) ++ version: 0.1.37(@google-cloud/storage@7.16.0(encoding@0.1.13))(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(fast-xml-parser@4.5.3)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(lodash@4.17.21)(pdf-parse@1.1.1)(pg@8.16.2)(ws@8.18.3) + devDependencies: + '@types/node': + specifier: ^20.11.16 +@@ -802,7 +802,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -845,7 +845,7 @@ importers: + version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)))(typescript@5.8.3) + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -872,7 +872,7 @@ importers: + 
version: 29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)) + next: + specifier: ^15.4.10 +- version: 15.4.10(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) ++ version: 15.5.9(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + npm-run-all: + specifier: ^4.1.5 + version: 4.1.5 +@@ -884,7 +884,7 @@ importers: + version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)))(typescript@4.9.5) + tsup: + specifier: ^8.0.2 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.7.0 + version: 4.20.3 +@@ -915,7 +915,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -946,7 +946,7 @@ importers: + version: 6.0.1 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -970,7 +970,7 @@ importers: + version: 1.10.0(encoding@0.1.13) + '@mistralai/mistralai-gcp': + specifier: ^1.3.5 +- version: 1.5.0(encoding@0.1.13)(zod@3.25.76) ++ version: 1.5.0(encoding@0.1.13)(zod@3.25.67) + genkit: + specifier: workspace:^ + version: link:../../genkit +@@ -985,7 +985,7 @@ importers: + version: 3.3.2 + openai: + specifier: ^4.52.7 +- version: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76) ++ version: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) + devDependencies: + '@types/node': + specifier: ^20.11.16 +@@ -1010,7 
+1010,7 @@ importers: + version: 21.0.0 + tsup: + specifier: ^8.3.5 +- version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) ++ version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) + tsx: + specifier: ^4.19.2 + version: 4.20.3 +@@ -1023,7 +1023,7 @@ importers: + version: 7.9.4(encoding@0.1.13) + firebase-admin: + specifier: '>=12.2' +- version: 13.5.0(encoding@0.1.13) ++ version: 13.4.0(encoding@0.1.13) + + testapps/anthropic: + dependencies: +@@ -1072,7 +1072,7 @@ importers: + version: 1.0.2 + zod-to-json-schema: + specifier: ^3.24.5 +- version: 3.24.5(zod@3.25.76) ++ version: 3.24.5(zod@3.25.67) + devDependencies: + '@types/wav': + specifier: ^1.0.4 +@@ -1088,7 +1088,7 @@ importers: + version: link:../../plugins/compat-oai + '@genkit-ai/express': + specifier: ^1.1.0 +- version: 1.12.0(@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(express@5.1.0)(genkit@genkit) ++ version: 1.12.0(@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(express@5.1.0)(genkit@genkit) + genkit: + specifier: workspace:* + version: link:../../genkit +@@ -1299,7 +1299,7 @@ importers: + version: 5.1.0 + firebase-admin: + specifier: ^13.5.0 +- version: 13.5.0(encoding@0.1.13) ++ version: 13.6.0(encoding@0.1.13) + genkit: + specifier: workspace:^ + version: link:../../genkit +@@ -1619,7 +1619,7 @@ importers: + version: 2025.7.1 + '@modelcontextprotocol/server-filesystem': + specifier: ^2025.3.28 +- version: 2025.7.1(zod@3.25.76) ++ version: 2025.7.1(zod@3.25.67) + '@types/express': + specifier: ^4.17.21 + version: 4.17.23 +@@ -1674,7 +1674,7 @@ importers: + version: link:../../genkit + zod: + specifier: ^3.22.4 +- version: 3.25.76 ++ version: 3.25.67 + devDependencies: + tsx: + specifier: ^4.7.1 +@@ -1705,7 +1705,7 @@ importers: + 
version: link:../../plugins/ollama + genkitx-openai: + specifier: ^0.10.1 +- version: 0.10.1(@genkit-ai/ai@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(encoding@0.1.13)(ws@8.18.3) ++ version: 0.10.1(@genkit-ai/ai@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(encoding@0.1.13)(ws@8.18.3) + devDependencies: + rimraf: + specifier: ^6.0.1 +@@ -1825,7 +1825,7 @@ importers: + version: link:../../genkit + next: + specifier: ^15.4.10 +- version: 15.4.10(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) ++ version: 15.5.9(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + zod: + specifier: ^3.24.1 + version: 3.25.67 +@@ -2145,8 +2145,8 @@ packages: + '@anthropic-ai/sdk@0.24.3': + resolution: {integrity: sha512-916wJXO6T6k8R6BAAcLhLPv/pnLGy7YSEBZXZ1XTFbLcTZE8oTy3oDW9WJf9KKZwMvVcePIfoTSvzXHRcGxkQQ==} + +- '@anthropic-ai/sdk@0.68.0': +- resolution: {integrity: sha512-SMYAmbbiprG8k1EjEPMTwaTqssDT7Ae+jxcR5kWXiqTlbwMR2AthXtscEVWOHkRfyAV5+y3PFYTJRNa3OJWIEw==} ++ '@anthropic-ai/sdk@0.71.2': ++ resolution: {integrity: sha512-TGNDEUuEstk/DKu0/TflXAEt+p+p/WhTlFzEnoosvbaDU2LTjm42igSdlL0VijrKpWejtOKxX0b8A7uc+XiSAQ==} + hasBin: true + peerDependencies: + zod: ^3.25.0 || ^4.0.0 +@@ -2351,9 +2351,6 @@ packages: + '@dabh/diagnostics@2.0.3': + resolution: {integrity: sha512-hrlQOIi7hAfzsMqlGSFyVucrx38O+j6wiGOf//H2ecvIEqYN4ADBSS2iLMh5UFyDunCNniUIPk/q3riFv45xRA==} + +- 
'@dabh/diagnostics@2.0.8': +- resolution: {integrity: sha512-R4MSXTVnuMzGD7bzHdW2ZhhdPC/igELENcq5IjEverBvq5hn1SXCWcsi6eSsdWP0/Ur+SItRRjAktmdoX/8R/Q==} +- + '@electric-sql/pglite@0.2.17': + resolution: {integrity: sha512-qEpKRT2oUaWDH6tjRxLHjdzMqRUGYDnGZlKrnL4dJ77JVMcP2Hpo3NYnOSPKdZdeec57B6QPprCUFg0picx5Pw==} + +@@ -2516,8 +2513,8 @@ packages: + '@fastify/busboy@3.0.0': + resolution: {integrity: sha512-83rnH2nCvclWaPQQKvkJ2pdOjG4TZyEVuFDnlOF6KP08lDaaceVyw/W63mDuafQT+MKHCvXIPpE5uYWeM0rT4w==} + +- '@fastify/busboy@3.2.0': +- resolution: {integrity: sha512-m9FVDXU3GT2ITSe0UaMA5rU3QkfC/UXtCU8y0gSN/GugTqtVldOBWIB5V6V3sbmenVZUIpU6f+mPEO2+m5iTaA==} ++ '@fastify/busboy@3.1.1': ++ resolution: {integrity: sha512-5DGmA8FTdB2XbDeEwc/5ZXBl6UbBAyBOOLlPuBnZ/N1SwdH9Ii+cOX3tBROlDgcTXxjOYnLMVoKk9+FXAw0CJw==} + + '@firebase/ai@1.4.0': + resolution: {integrity: sha512-wvF33gtU6TXb6Co8TEC1pcl4dnVstYmRE/vs9XjUGE7he7Sgf5TqSu+EoXk/fuzhw5tKr1LC5eG9KdYFM+eosw==} +@@ -2625,9 +2622,6 @@ packages: + '@firebase/database-types@1.0.14': + resolution: {integrity: sha512-8a0Q1GrxM0akgF0RiQHliinhmZd+UQPrxEmUv7MnQBYfVFiLtKOgs3g6ghRt/WEGJHyQNslZ+0PocIwNfoDwKw==} + +- '@firebase/database-types@1.0.16': +- resolution: {integrity: sha512-xkQLQfU5De7+SPhEGAXFBnDryUWhhlFXelEg2YeZOQMCdoe7dL64DDAd77SQsR+6uoXIZY5MB4y/inCs4GTfcw==} +- + '@firebase/database-types@1.0.6': + resolution: {integrity: sha512-sMI7IynSZBsyGbUugc8PKE1jwKbnvaieAz/RxuM57PZQNCi6Rteiviwcw/jqZOX6igqYJwXWZ3UzKOZo2nUDRA==} + +@@ -2760,18 +2754,14 @@ packages: + resolution: {integrity: sha512-Z4rK23xBCwgKDqmzGVMef+Vb4xso2j5Q8OG0vVL4m4fA5ZjPMYQazu8OJJC3vtQRC3SQ/Pgx/6TPNVsCd70QRw==} + engines: {node: '>=18.0.0'} + +- '@firebase/util@1.13.0': +- resolution: {integrity: sha512-0AZUyYUfpMNcztR5l09izHwXkZpghLgCUaAGjtMwXnCg3bj4ml5VgiwqOMOxJ+Nw4qN/zJAaOQBcJ7KGkWStqQ==} +- engines: {node: '>=20.0.0'} +- + '@firebase/webchannel-wrapper@1.0.3': + resolution: {integrity: 
sha512-2xCRM9q9FlzGZCdgDMJwc0gyUkWFtkosy7Xxr6sFgQwn+wMNIWd7xIvYNauU1r64B5L5rsGKy/n9TKJ0aAFeqQ==} + +- '@genkit-ai/ai@1.26.0-rc.0': +- resolution: {integrity: sha512-TrNRK/fSuhM8XHOGAV6lDH9daGYfWCPyW55ZDtH3IeDAVNtfcvOhgmM+uvgtsvjKYeJiDHdwVQeabL1e9tzdYg==} ++ '@genkit-ai/ai@1.27.0': ++ resolution: {integrity: sha512-Vogp21a0pBgL7UsdHj1Jm79PjrQdLNRK5dZT05Xvr3f7GwCNMv0k6Olxp+qrgwLi6DbRsVPi7c+wcldekMlLFQ==} + +- '@genkit-ai/core@1.26.0-rc.0': +- resolution: {integrity: sha512-ZnzyWLeb364csirXJusPKKV5i6ZqzsKHUc9ZKRGBSoPXkrz/w0hLGoPCFjCSbfm3DmRvC45/HOn3uEVtwkN2MA==} ++ '@genkit-ai/core@1.27.0': ++ resolution: {integrity: sha512-2dcr/yKixcxNj0U9pFpx9qNOTJcRdEjEz76qd5+o6Ac31foRBMb3J9Bvrfr+SaaPI4kiMnFUxN1X+w5yNjryQg==} + + '@genkit-ai/express@1.12.0': + resolution: {integrity: sha512-QAxSS07dX5ovSfsUB4s90KaDnv4zg1wnoxCZCa+jBsYUyv9NvCCTsOk25xAQgGxc7xi3+MD+3AsPier5oZILIg==} +@@ -2791,27 +2781,11 @@ packages: + firebase: + optional: true + +- '@genkit-ai/firebase@1.25.0': +- resolution: {integrity: sha512-Z0FbnJHQs8qS0yxG++Dn3CZ7gv+YNaihGaWXoDKy02mNOkeRzHA6UPaWxSTaWkWHYdB0MyOnMGlyqxnWyqVdmg==} +- peerDependencies: +- '@google-cloud/firestore': ^7.11.0 +- firebase: '>=11.5.0' +- firebase-admin: '>=12.2' +- genkit: ^1.25.0 +- peerDependenciesMeta: +- firebase: +- optional: true +- + '@genkit-ai/google-cloud@1.16.1': + resolution: {integrity: sha512-uujjdGr/sra7iKHApufwkt5jGo7CQcRCJNWPgnSg4g179CjtvtZBGjxmFRVBtKzuF61ktkY6E9JoLz83nWEyAA==} + peerDependencies: + genkit: ^1.16.1 + +- '@genkit-ai/google-cloud@1.25.0': +- resolution: {integrity: sha512-wHCa8JSTv7MtwzXjUQ9AT5v0kCTJrz0In+ffgAYw1yt8ComAz5o7Ir+xks+sX1vJfN8ptvW0GUa6rsUaXCB3kA==} +- peerDependencies: +- genkit: ^1.25.0 +- + '@gerrit0/mini-shiki@1.27.2': + resolution: {integrity: sha512-GeWyHz8ao2gBiUW4OJnQDxXQnFgZQwwQk05t/CVVgNBN7/rK8XZ7xY6YhLVv9tH3VppWWmr9DCl3MwemB/i+Og==} + +@@ -2839,10 +2813,6 @@ packages: + resolution: {integrity: 
sha512-ZxOdH8Wr01hBDvKCQfMWqwUcfNcN3JY19k1LtS1fTFhEyorYPLsbWN+VxIRL46pOYGHTPkU3Or5HbT/SLQM5nA==} + engines: {node: '>=14.0.0'} + +- '@google-cloud/firestore@7.11.6': +- resolution: {integrity: sha512-EW/O8ktzwLfyWBOsNuhRoMi8lrC3clHM5LVFhGvO1HCsLozCOOXRAlHrYBoE6HL42Sc8yYMuCb2XqcnJ4OOEpw==} +- engines: {node: '>=14.0.0'} +- + '@google-cloud/logging-winston@6.0.1': + resolution: {integrity: sha512-tgA/qe/aGZITMrJ/5Tuykv234pLb/Qo6iDZ8SDkjbsiIy69mLQmbphrUd/IqnE17BSDfrwDUckvWdghiy8b+Qg==} + engines: {node: '>=14.0.0'} +@@ -2909,10 +2879,6 @@ packages: + resolution: {integrity: sha512-7/5LRgykyOfQENcm6hDKP8SX/u9XxE5YOiWOkgkwcoO+cG8xT/cyOvp9wwN3IxfdYgpHs8CE7Nq2PKX2lNaEXw==} + engines: {node: '>=14'} + +- '@google-cloud/storage@7.18.0': +- resolution: {integrity: sha512-r3ZwDMiz4nwW6R922Z1pwpePxyRwE5GdevYX63hRmAQUkUQJcBH/79EnQPDv5cOv1mFBgevdNWQfi3tie3dHrQ==} +- engines: {node: '>=14'} +- + '@google-cloud/vertexai@1.10.0': + resolution: {integrity: sha512-HqYqoivNtkq59po8m7KI0n+lWKdz4kabENncYQXZCX/hBWJfXtKAfR/2nUQsP+TwSfHKoA7zDL2RrJYIv/j3VQ==} + engines: {node: '>=18.0.0'} +@@ -2950,8 +2916,8 @@ packages: + resolution: {integrity: sha512-HPa/K5NX6ahMoeBv15njAc/sfF4/jmiXLar9UlC2UfHFKZzsCVLc3wbe7+7qua7w9VPh2/L6EBxyAV7/E8Wftg==} + engines: {node: '>=12.10.0'} + +- '@grpc/grpc-js@1.14.2': +- resolution: {integrity: sha512-QzVUtEFyu05UNx2xr0fCQmStUO17uVQhGNowtxs00IgTZT6/W2PBLfUkj30s0FKJ29VtTa3ArVNIhNP6akQhqA==} ++ '@grpc/grpc-js@1.14.3': ++ resolution: {integrity: sha512-Iq8QQQ/7X3Sac15oB6p0FmUg/klxQvXLeileoqrTRGJYLV+/9tubbr9ipz0GKHjmXVsgFPo/+W+2cA8eNcR+XA==} + engines: {node: '>=12.10.0'} + + '@grpc/grpc-js@1.9.15': +@@ -3110,8 +3076,8 @@ packages: + cpu: [x64] + os: [win32] + +- '@inquirer/external-editor@1.0.2': +- resolution: {integrity: sha512-yy9cOoBnx58TlsPrIxauKIFQTiyH+0MK4e97y4sV9ERbI+zDxw7i2hxHLCIEGIE/8PPvDxGhgzIOTSOWcs6/MQ==} ++ '@inquirer/external-editor@1.0.3': ++ resolution: {integrity: 
sha512-RWbSrDiYmO4LbejWY7ttpxczuwQyZLBUyygsA9Nsv95hpzUWwnNTVQmAq3xuh7vNwCp07UTmE5i11XAEExx4RA==} + engines: {node: '>=18'} + peerDependencies: + '@types/node': '>=18' +@@ -3219,9 +3185,6 @@ packages: + '@jridgewell/sourcemap-codec@1.5.0': + resolution: {integrity: sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==} + +- '@jridgewell/sourcemap-codec@1.5.5': +- resolution: {integrity: sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==} +- + '@jridgewell/trace-mapping@0.3.25': + resolution: {integrity: sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==} + +@@ -3621,53 +3584,53 @@ packages: + resolution: {integrity: sha512-92ybDocKl6JM48ZpYbj+A7Qt45IaTABDk0y3sDecEQfgdhfNzJtEityqNHoCZ4Vty2dldPkJhxgvOnbrQMXTTA==} + engines: {node: '>= 10'} + +- '@next/env@15.4.10': +- resolution: {integrity: sha512-knhmoJ0Vv7VRf6pZEPSnciUG1S4bIhWx+qTYBW/AjxEtlzsiNORPk8sFDCEvqLfmKuey56UB9FL1UdHEV3uBrg==} ++ '@next/env@15.5.9': ++ resolution: {integrity: sha512-4GlTZ+EJM7WaW2HEZcyU317tIQDjkQIyENDLxYJfSWlfqguN+dHkZgyQTV/7ykvobU7yEH5gKvreNrH4B6QgIg==} + +- '@next/swc-darwin-arm64@15.4.8': +- resolution: {integrity: sha512-Pf6zXp7yyQEn7sqMxur6+kYcywx5up1J849psyET7/8pG2gQTVMjU3NzgIt8SeEP5to3If/SaWmaA6H6ysBr1A==} ++ '@next/swc-darwin-arm64@15.5.7': ++ resolution: {integrity: sha512-IZwtxCEpI91HVU/rAUOOobWSZv4P2DeTtNaCdHqLcTJU4wdNXgAySvKa/qJCgR5m6KI8UsKDXtO2B31jcaw1Yw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + +- '@next/swc-darwin-x64@15.4.8': +- resolution: {integrity: sha512-xla6AOfz68a6kq3gRQccWEvFC/VRGJmA/QuSLENSO7CZX5WIEkSz7r1FdXUjtGCQ1c2M+ndUAH7opdfLK1PQbw==} ++ '@next/swc-darwin-x64@15.5.7': ++ resolution: {integrity: sha512-UP6CaDBcqaCBuiq/gfCEJw7sPEoX1aIjZHnBWN9v9qYHQdMKvCKcAVs4OX1vIjeE+tC5EIuwDTVIoXpUes29lg==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + +- '@next/swc-linux-arm64-gnu@15.4.8': +- resolution: {integrity: 
sha512-y3fmp+1Px/SJD+5ntve5QLZnGLycsxsVPkTzAc3zUiXYSOlTPqT8ynfmt6tt4fSo1tAhDPmryXpYKEAcoAPDJw==} ++ '@next/swc-linux-arm64-gnu@15.5.7': ++ resolution: {integrity: sha512-NCslw3GrNIw7OgmRBxHtdWFQYhexoUCq+0oS2ccjyYLtcn1SzGzeM54jpTFonIMUjNbHmpKpziXnpxhSWLcmBA==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + +- '@next/swc-linux-arm64-musl@15.4.8': +- resolution: {integrity: sha512-DX/L8VHzrr1CfwaVjBQr3GWCqNNFgyWJbeQ10Lx/phzbQo3JNAxUok1DZ8JHRGcL6PgMRgj6HylnLNndxn4Z6A==} ++ '@next/swc-linux-arm64-musl@15.5.7': ++ resolution: {integrity: sha512-nfymt+SE5cvtTrG9u1wdoxBr9bVB7mtKTcj0ltRn6gkP/2Nu1zM5ei8rwP9qKQP0Y//umK+TtkKgNtfboBxRrw==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + +- '@next/swc-linux-x64-gnu@15.4.8': +- resolution: {integrity: sha512-9fLAAXKAL3xEIFdKdzG5rUSvSiZTLLTCc6JKq1z04DR4zY7DbAPcRvNm3K1inVhTiQCs19ZRAgUerHiVKMZZIA==} ++ '@next/swc-linux-x64-gnu@15.5.7': ++ resolution: {integrity: sha512-hvXcZvCaaEbCZcVzcY7E1uXN9xWZfFvkNHwbe/n4OkRhFWrs1J1QV+4U1BN06tXLdaS4DazEGXwgqnu/VMcmqw==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + +- '@next/swc-linux-x64-musl@15.4.8': +- resolution: {integrity: sha512-s45V7nfb5g7dbS7JK6XZDcapicVrMMvX2uYgOHP16QuKH/JA285oy6HcxlKqwUNaFY/UC6EvQ8QZUOo19cBKSA==} ++ '@next/swc-linux-x64-musl@15.5.7': ++ resolution: {integrity: sha512-4IUO539b8FmF0odY6/SqANJdgwn1xs1GkPO5doZugwZ3ETF6JUdckk7RGmsfSf7ws8Qb2YB5It33mvNL/0acqA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + +- '@next/swc-win32-arm64-msvc@15.4.8': +- resolution: {integrity: sha512-KjgeQyOAq7t/HzAJcWPGA8X+4WY03uSCZ2Ekk98S9OgCFsb6lfBE3dbUzUuEQAN2THbwYgFfxX2yFTCMm8Kehw==} ++ '@next/swc-win32-arm64-msvc@15.5.7': ++ resolution: {integrity: sha512-CpJVTkYI3ZajQkC5vajM7/ApKJUOlm6uP4BknM3XKvJ7VXAvCqSjSLmM0LKdYzn6nBJVSjdclx8nYJSa3xlTgQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + +- '@next/swc-win32-x64-msvc@15.4.8': +- resolution: {integrity: 
sha512-Exsmf/+42fWVnLMaZHzshukTBxZrSwuuLKFvqhGHJ+mC1AokqieLY/XzAl3jc/CqhXLqLY3RRjkKJ9YnLPcRWg==} ++ '@next/swc-win32-x64-msvc@15.5.7': ++ resolution: {integrity: sha512-gMzgBX164I6DN+9/PGA+9dQiwmTkE4TloBNx8Kv9UiGARsr9Nba7IpcBRA1iTV9vwlYnrE3Uy6I7Aj6qLjQuqw==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] +@@ -4320,9 +4283,6 @@ packages: + '@sinonjs/samsam@8.0.2': + resolution: {integrity: sha512-v46t/fwnhejRSFTGqbpn9u+LQ9xJDse10gNnPgAcxgdoCDMXj/G2asWAC/8Qs+BAZDicX+MNZouXT1A7c83kVw==} + +- '@so-ric/colorspace@1.1.6': +- resolution: {integrity: sha512-/KiKkpHNOBgkFJwu9sh48LkHSMYGyuTcSFK/qMBdnOAlrRJzRSXAOFB5qwzaVQuDl8wAvHVMkaASQDReTahxuw==} +- + '@swc/helpers@0.5.15': + resolution: {integrity: sha512-JQ5TuMi45Owi4/BIMAJBoSQoOJu12oOk/gADqlcUL9JEdHB8vyjUSsxqeNXnmXHjYKMi2WcYtezGEEhqUI/E2g==} + +@@ -4333,8 +4293,8 @@ packages: + '@tootallnate/quickjs-emscripten@0.23.0': + resolution: {integrity: sha512-C5Mc6rdnsaJDjO3UpGW/CQTHtCKaYlScZTly4JIu97Jxo/odCiH0ITnDXSJPTOrEKk/ycSZ0AOgTmkDtkOsvIA==} + +- '@tsconfig/node10@1.0.12': +- resolution: {integrity: sha512-UCYBaeFvM11aU2y3YPZ//O5Rhj+xKyzy7mvcIoAjASbigy8mHMryP5cK7dgjlz2hWxh1g5pLw084E0a/wlUSFQ==} ++ '@tsconfig/node10@1.0.11': ++ resolution: {integrity: sha512-DcRjDCujK/kCk/cUe8Xz8ZSpm8mS3mNNpta+jGCA6USEDfktlNvm1+IuZ9eTcDbNk41BHwpHHeW+N1lKCz4zOw==} + + '@tsconfig/node12@1.0.11': + resolution: {integrity: sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==} +@@ -4394,15 +4354,9 @@ packages: + '@types/express-serve-static-core@4.17.43': + resolution: {integrity: sha512-oaYtiBirUOPQGSWNGPWnzyAFJ0BP3cwvN4oWZQY+zUBwpVIGsKUkpBpSztp74drYcjavs7SKFZ4DX1V2QeN8rg==} + +- '@types/express-serve-static-core@4.19.7': +- resolution: {integrity: sha512-FvPtiIf1LfhzsaIXhv/PHan/2FeQBbtBDtfX2QfvPxdUelMDEckK08SM6nqo1MIZY3RUlfA+HV8+hFUSio78qg==} +- + '@types/express@4.17.23': + resolution: {integrity: sha512-Crp6WY9aTYP3qPi2wGDo9iUe/rceX01UMhnF1jmwDcKCFM6cx7YhGP/Mpr3y9AASpfHixIG0E6azCcL5OcDHsQ==} 
+ +- '@types/express@4.17.25': +- resolution: {integrity: sha512-dVd04UKsfpINUnK0yBoYHDF3xu7xVH4BuDotC/xGuycx4CgbP48X/KF/586bcObxT0HENHXEU8Nqtu6NR+eKhw==} +- + '@types/graceful-fs@4.1.9': + resolution: {integrity: sha512-olP3sd1qOEe5dXTSaFvQG+02VdRXcdytWLAZsAq1PecU8uqQAhkrnbli7DagjtXKW/Bl7YJbUsa8MPcuc8LHEQ==} + +@@ -4416,9 +4370,6 @@ packages: + '@types/http-errors@2.0.4': + resolution: {integrity: sha512-D0CFMMtydbJAegzOyHjtiKPLlvnm3iTZyZRSZoLq2mRhDdmLfIWOCYPfQJ4cu2erKghU++QvjcUjp/5h7hESpA==} + +- '@types/http-errors@2.0.5': +- resolution: {integrity: sha512-r8Tayk8HJnX0FztbZN7oVqGccWgw98T/0neJphO91KkmOzug1KkofZURD4UaD5uH8AqcFLfdPErnBod0u71/qg==} +- + '@types/istanbul-lib-coverage@2.0.6': + resolution: {integrity: sha512-2QF/t/auWm0lsy8XtKVPG19v3sSOQlJe/YHZgfjb/KBBHOGSV+J2q/S671rcq9uTBrLAXmZpqJiaQbMT+zNU1w==} + +@@ -4473,18 +4424,9 @@ packages: + '@types/node@20.19.1': + resolution: {integrity: sha512-jJD50LtlD2dodAEO653i3YF04NWak6jN3ky+Ri3Em3mGR39/glWiboM/IePaRbgwSfqM1TpGXfAg8ohn/4dTgA==} + +- '@types/node@20.19.25': +- resolution: {integrity: sha512-ZsJzA5thDQMSQO788d7IocwwQbI8B5OPzmqNvpf3NY/+MHDAS759Wo0gd2WQeXYt5AAAQjzcrTVC6SKCuYgoCQ==} +- +- '@types/node@20.19.26': +- resolution: {integrity: sha512-0l6cjgF0XnihUpndDhk+nyD3exio3iKaYROSgvh/qSevPXax3L8p5DBRFjbvalnwatGgHEQn2R88y2fA3g4irg==} +- + '@types/node@22.15.32': + resolution: {integrity: sha512-3jigKqgSjsH6gYZv2nEsqdXfZqIFGAV36XYYjf9KGZ3PSG+IhLecqPnI310RvjutyMwifE2hhhNEklOUrvx/wA==} + +- '@types/node@22.19.2': +- resolution: {integrity: sha512-LPM2G3Syo1GLzXLGJAKdqoU35XvrWzGJ21/7sgZTUpbkBaOasTj8tjwn6w+hCkqaa1TfJ/w67rJSwYItlJ2mYw==} +- + '@types/pdf-parse@1.1.5': + resolution: {integrity: sha512-kBfrSXsloMnUJOKi25s3+hRmkycHfLK6A09eRGqF/N8BkQoPUmaCr+q8Cli5FnfohEz/rsv82zAiPz/LXtOGhA==} + +@@ -4494,9 +4436,6 @@ packages: + '@types/pg@8.6.1': + resolution: {integrity: sha512-1Kc4oAGzAl7uqUStZCDvaLFqZrW9qWSjXOmBfdgyBP5La7Us6Mg4GBvRlSoaZMhQF/zSj1C8CtKMBkoiT8eL8w==} + +- '@types/qs@6.14.0': +- resolution: 
{integrity: sha512-eOunJqu0K1923aExK6y8p6fsihYEn/BYuQ4g0CxAAgFc4b/ZLN4CrsRZ55srTdqoiLzU2B2evC+apEIxprEzkQ==} +- + '@types/qs@6.9.14': + resolution: {integrity: sha512-5khscbd3SwWMhFqylJBLQ0zIu7c1K6Vz0uBIt915BI3zV0q1nfjRQD3RqSBcPaO6PHEF4ov/t9y89fSiyThlPA==} + +@@ -4515,15 +4454,6 @@ packages: + '@types/send@0.17.4': + resolution: {integrity: sha512-x2EM6TJOybec7c52BX0ZspPodMsQUd5L6PRwOunVyVUhXiBSKf3AezDL8Dgvgt5o0UfKNfuA0eMLr2wLT4AiBA==} + +- '@types/send@0.17.6': +- resolution: {integrity: sha512-Uqt8rPBE8SY0RK8JB1EzVOIZ32uqy8HwdxCnoCOsYrvnswqmFZ/k+9Ikidlk/ImhsdvBsloHbAlewb2IEBV/Og==} +- +- '@types/send@1.2.1': +- resolution: {integrity: sha512-arsCikDvlU99zl1g69TcAB3mzZPpxgw0UQnaHeC1Nwb015xp8bknZv5rIfri9xTOcMuaVgvabfIRA7PSZVuZIQ==} +- +- '@types/serve-static@1.15.10': +- resolution: {integrity: sha512-tRs1dB+g8Itk72rlSI2ZrW6vZg0YrLI81iQSTkMmOqnqCaNr/8Ek4VwWcN5vZgCYWbg/JJSGBlUaYGAOP73qBw==} +- + '@types/serve-static@1.15.5': + resolution: {integrity: sha512-PDRk21MnK70hja/YF8AHfC7yIsiQHn1rcXx7ijCFBX/k+XQJhQT/gw3xekXKJvx+5SXaMMS8oqQy09Mzvz2TuQ==} + +@@ -4661,8 +4591,8 @@ packages: + resolution: {integrity: sha512-gKXj5ALrKWQLsYG9jlTRmR/xKluxHV+Z9QEwNIgCfM1/uwPMCuzVVnh5mwTd+OuBZcwSIMbqssNWRm1lE51QaQ==} + engines: {node: '>=8'} + +- ansi-escapes@7.1.1: +- resolution: {integrity: sha512-Zhl0ErHcSRUaVfGUeUdDuLgpkEo8KIFjB4Y9uAc46ScOpdDiU1Dbyplh7qWJeJ/ZHpbyMSM26+X3BySgnIz40Q==} ++ ansi-escapes@7.2.0: ++ resolution: {integrity: sha512-g6LhBsl+GBPRWGWsBtutpzBYuIIdBkLEvad5C/va/74Db018+5TZiyA26cZJAr3Rft5lprVqOIPxf5Vid6tqAw==} + engines: {node: '>=18'} + + ansi-regex@5.0.1: +@@ -4801,8 +4731,8 @@ packages: + balanced-match@1.0.2: + resolution: {integrity: sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==} + +- bare-events@2.8.1: +- resolution: {integrity: sha512-oxSAxTS1hRfnyit2CL5QpAOS5ixfBjj6ex3yTNvXyY/kE719jQ/IjuESJBK2w5v4wwQRAHGseVJXx9QBYOtFGQ==} ++ bare-events@2.8.2: ++ resolution: {integrity: 
sha512-riJjyv1/mHLIPX4RwiK+oW9/4c3TEUeORHKefKAKnZ5kyslbN+HXowtbaVEqt4IMUB7OXlfixcs6gsFeo/jhiQ==} + peerDependencies: + bare-abort-controller: '*' + peerDependenciesMeta: +@@ -4846,10 +4776,6 @@ packages: + resolution: {integrity: sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==} + engines: {node: '>= 0.8', npm: 1.2.8000 || >= 1.4.16} + +- body-parser@1.20.4: +- resolution: {integrity: sha512-ZTgYYLMOXY9qKU/57FAo8F+HA2dGX7bqGc71txDRC1rS4frdFI5R7NhluHxH6M0YItAP0sHB4uqAOcYKxO6uGA==} +- engines: {node: '>= 0.8', npm: 1.2.8000 || >= 1.4.16} +- + body-parser@2.2.0: + resolution: {integrity: sha512-02qvAaxv8tp7fBa/mw1ga98OGm+eCbqzJOKoRt70sLmfEEi+jyBYVTDGfCL/k06/4EMk/z01gCe7HoCH/f2LTg==} + engines: {node: '>=18'} +@@ -4961,8 +4887,8 @@ packages: + resolution: {integrity: sha512-Gmy6FhYlCY7uOElZUSbxo2UCDH8owEk996gkbrpsgGtrJLM3J7jGxl9Ic7Qwwj4ivOE5AWZWRMecDdF7hqGjFA==} + engines: {node: '>=10'} + +- caniuse-lite@1.0.30001760: +- resolution: {integrity: sha512-7AAMPcueWELt1p3mi13HR/LHH0TJLT11cnwDJEs3xA4+CK/PLKeO9Kl1oru24htkyUKtkGCvAx4ohB0Ttry8Dw==} ++ caniuse-lite@1.0.30001667: ++ resolution: {integrity: sha512-7LTwJjcRkzKFmtqGsibMeuXmvFDfZq/nzIjnmgCGzKKRVzjD72selLDK1oPF/Oxzmt4fNcPvTDvGqSDG4tCALw==} + + chalk@2.4.2: + resolution: {integrity: sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==} +@@ -5108,34 +5034,18 @@ packages: + resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} + engines: {node: '>=7.0.0'} + +- color-convert@3.1.3: +- resolution: {integrity: sha512-fasDH2ont2GqF5HpyO4w0+BcewlhHEZOFn9c1ckZdHpJ56Qb7MHhH/IcJZbBGgvdtwdwNbLvxiBEdg336iA9Sg==} +- engines: {node: '>=14.6'} +- + color-name@1.1.3: + resolution: {integrity: sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==} + + color-name@1.1.4: + resolution: {integrity: 
sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} + +- color-name@2.1.0: +- resolution: {integrity: sha512-1bPaDNFm0axzE4MEAzKPuqKWeRaT43U/hyxKPBdqTfmPF+d6n7FSoTFxLVULUJOmiLp01KjhIPPH+HrXZJN4Rg==} +- engines: {node: '>=12.20'} +- + color-string@1.9.1: + resolution: {integrity: sha512-shrVawQFojnZv6xM40anx4CkoDP+fZsw/ZerEMsW/pyzsRbElpsL/DBVW7q3ExxwusdNXI3lXpuhEZkzs8p5Eg==} + +- color-string@2.1.4: +- resolution: {integrity: sha512-Bb6Cq8oq0IjDOe8wJmi4JeNn763Xs9cfrBcaylK1tPypWzyoy2G3l90v9k64kjphl/ZJjPIShFztenRomi8WTg==} +- engines: {node: '>=18'} +- + color@3.2.1: + resolution: {integrity: sha512-aBl7dZI9ENN6fUGC7mWpMTPNHmWUSNan9tuWN6ahh5ZLNk9baLJOnSMlrQkHcrfFgz2/RigjUVAjdx36VcemKA==} + +- color@5.0.3: +- resolution: {integrity: sha512-ezmVcLR3xAVp8kYOm4GS45ZLLgIE6SPAFoduLr6hTDajwb3KZ2F46gulK3XpcwRFb5KKGCSezCBAY4Dw4HsyXA==} +- engines: {node: '>=18'} +- + colorette@2.0.19: + resolution: {integrity: sha512-3tlv/dIP7FWvj3BsbHrGLJ6l/oKh1O3TcgBqMn+yyCagOxc23fyzDS6HypQbgxWbkpDnf52p1LuR4eWDQ/K9WQ==} + +@@ -5228,9 +5138,6 @@ packages: + cookie-signature@1.0.6: + resolution: {integrity: sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==} + +- cookie-signature@1.0.7: +- resolution: {integrity: sha512-NXdYc3dLr47pBkpUCHtKSwIOQXLVn8dZEuywboCOJY/osA0wFSLlSawr3KN8qXJEyX66FcONTH8EIlVuK0yyFA==} +- + cookie-signature@1.2.2: + resolution: {integrity: sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg==} + engines: {node: '>=6.6.0'} +@@ -5364,15 +5271,6 @@ packages: + supports-color: + optional: true + +- debug@4.4.3: +- resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==} +- engines: {node: '>=6.0'} +- peerDependencies: +- supports-color: '*' +- peerDependenciesMeta: +- supports-color: +- optional: true +- + decamelize@1.2.0: + resolution: {integrity: 
sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==} + engines: {node: '>=0.10.0'} +@@ -5478,9 +5376,6 @@ packages: + dotprompt@1.1.1: + resolution: {integrity: sha512-xll31JxDiE7FaF030t0Dx4EMSV60Qn/pONDn6Hs5bBBeEANbtqIu6fPfaAOoSNbF1Y9TK+pj9Xnvud7G7GHpaA==} + +- dotprompt@1.1.2: +- resolution: {integrity: sha512-24EU+eORQbPywBicIP44BiqykzEXFwZq1ZQKO5TEr9KrrENyDA7I1NzqhtmmEdQVfAXka0DEbSLPN5nerCqJ8A==} +- + dunder-proto@1.0.1: + resolution: {integrity: sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==} + engines: {node: '>= 0.4'} +@@ -5706,10 +5601,6 @@ packages: + resolution: {integrity: sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==} + engines: {node: '>= 0.10.0'} + +- express@4.22.1: +- resolution: {integrity: sha512-F2X8g9P1X7uCPZMA3MVf9wcTqlyNp7IhH5qPCI0izhaOIYXaW9L535tGA3qmjRzpH+bZczqq7hVKxTR4NWnu+g==} +- engines: {node: '>= 0.10.0'} +- + express@5.1.0: + resolution: {integrity: sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==} + engines: {node: '>= 18'} +@@ -5788,10 +5679,6 @@ packages: + resolution: {integrity: sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==} + engines: {node: '>= 0.8'} + +- finalhandler@1.3.2: +- resolution: {integrity: sha512-aA4RyPcd3badbdABGDuTXCMTtOneUCAYH/gxoYRTZlIJdF0YPWuGqiAsIrhNnnqdXGswYk6dGujem4w80UJFhg==} +- engines: {node: '>= 0.8'} +- + finalhandler@2.1.0: + resolution: {integrity: sha512-/t88Ty3d5JWQbWYgaOGCCYfXRwV1+be02WqYYlL6h0lEiUAMPM8o8qKGO01YIkOHzka2up08wvgYD0mDiI+q3Q==} + engines: {node: '>= 0.8'} +@@ -5811,8 +5698,8 @@ packages: + resolution: {integrity: sha512-Y8DcyKK+4pl4B93ooiy1G8qvdyRMkcNFfBSh+8rbVcw4cW8dgG0VXCCTp5NUwub8sn9vSPsOwpb9tE2OuFmcfQ==} + engines: {node: '>=18'} + +- firebase-admin@13.5.0: +- resolution: {integrity: 
sha512-QZOpv1DJRJpH8NcWiL1xXE10tw3L/bdPFlgjcWrqU3ufyOJDYfxB1MMtxiVTwxK16NlybQbEM6ciSich2uWEIQ==} ++ firebase-admin@13.6.0: ++ resolution: {integrity: sha512-GdPA/t0+Cq8p1JnjFRBmxRxAGvF/kl2yfdhALl38PrRp325YxyQ5aNaHui0XmaKcKiGRFIJ/EgBNWFoDP0onjw==} + engines: {node: '>=18'} + + firebase-functions@6.3.2: +@@ -5859,8 +5746,8 @@ packages: + resolution: {integrity: sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==} + engines: {node: '>= 6'} + +- form-data@4.0.4: +- resolution: {integrity: sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==} ++ form-data@4.0.5: ++ resolution: {integrity: sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==} + engines: {node: '>= 6'} + + formdata-node@4.4.1: +@@ -5948,8 +5835,8 @@ packages: + resolution: {integrity: sha512-zV/5HKTfCeKWnxG0Dmrw51hEWFGfcF2xiXqcA3+J90WDuP0SvoiSO5ORvcBsifmx/FoIjgQN3oNOGaQ5PhLFkg==} + engines: {node: '>=18'} + +- genkit@1.26.0-rc.0: +- resolution: {integrity: sha512-Yx4qtT0ImwE2Nu8ts1lrq4eL/qCa+vFmgNOWnCJLc205Vcco0yZEQ0Wr0OL3sBhIAyLuAfx6CCUPJE735ypTsg==} ++ genkit@1.27.0: ++ resolution: {integrity: sha512-54OAzw9+dlOs2H4bWnktMwKVA1wwY9XmudKBAz2uBdWhep5r0xHy1qNE6tUVnSgn+LGGaR/0xfYRSs8uqNPFVw==} + + genkitx-openai@0.10.1: + resolution: {integrity: sha512-E9/DzyQcBUSTy81xT2pvEmdnn9Q/cKoojEt6lD/EdOeinhqE9oa59d/kuXTokCMekTrj3Rk7LtNBQIDjnyjNOA==} +@@ -6023,8 +5910,8 @@ packages: + engines: {node: '>=16 || 14 >=14.17'} + hasBin: true + +- glob@10.4.5: +- resolution: {integrity: sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==} ++ glob@10.5.0: ++ resolution: {integrity: sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==} + hasBin: true + + glob@11.0.0: +@@ -6188,10 +6075,6 @@ packages: + resolution: {integrity: 
sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==} + engines: {node: '>= 0.8'} + +- http-errors@2.0.1: +- resolution: {integrity: sha512-4FbRdAX+bSdmo4AUFuS0WNiPz8NgFt+r8ThgNWmlrjQjt1Q7ZR9+zTlce2859x4KSXrwIsaeTqDoKQmtP8pLmQ==} +- engines: {node: '>= 0.8'} +- + http-parser-js@0.5.10: + resolution: {integrity: sha512-Pysuw9XpUq5dVc/2SMHpuTY01RFl8fttgcyunjL7eEMhGM3cI4eOmiCycJDVCo/7O7ClfQD3SaI6ftDzqOXYMA==} + +@@ -6230,8 +6113,8 @@ packages: + resolution: {integrity: sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==} + engines: {node: '>=0.10.0'} + +- iconv-lite@0.7.0: +- resolution: {integrity: sha512-cf6L2Ds3h57VVmkZe+Pn+5APsT7FpqJtEhhieDCvrE2MK5Qk9MyffgQyuxQTm6BChfeZNtcOLHp9IcWRVcIcBQ==} ++ iconv-lite@0.7.1: ++ resolution: {integrity: sha512-2Tth85cXwGFHfvRgZWszZSvdo+0Xsqmw8k8ZwxScfcBneNUraK+dxRxRm24nszx80Y0TVio8kKLt5sLE7ZCLlw==} + engines: {node: '>=0.10.0'} + + idb@7.1.1: +@@ -6299,8 +6182,8 @@ packages: + resolution: {integrity: sha512-Ju0Bz/cEia55xDwUWEa8+olFpCiQoypjnQySseKtmjNrnps3P+xfpUmGr90T7yjlVJmOtybRvPXhKMbHr+fWnw==} + engines: {node: '>= 0.10'} + +- ip-address@10.0.1: +- resolution: {integrity: sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==} ++ ip-address@10.1.0: ++ resolution: {integrity: sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==} + engines: {node: '>= 12'} + + ip-regex@4.3.0: +@@ -6793,9 +6676,6 @@ packages: + jwa@2.0.0: + resolution: {integrity: sha512-jrZ2Qx916EA+fq9cEAeCROWPTfCwi1IVHqT2tapuqLEVVDKFDENFw1oL+MwrTvH6msKxsd1YTDVw6uKEcsrLEA==} + +- jwa@2.0.1: +- resolution: {integrity: sha512-hRF04fqJIP8Abbkq5NKGN0Bbr3JxlQ+qhZufXVr0DvujKy93ZCbXZMHDL4EOtodSbCWxOqR8MS1tXA5hwqCXDg==} +- + jwks-rsa@3.1.0: + resolution: {integrity: sha512-v7nqlfezb9YfHHzYII3ef2a2j1XnGeSE/bK3WfumaYCqONAIstJbrEGapz4kadScZzEt7zYCN7bucj8C0Mv/Rg==} + engines: {node: '>=14'} +@@ -6810,9 
+6690,6 @@ packages: + jws@4.0.0: + resolution: {integrity: sha512-KDncfTmOZoOMTFG4mBlG0qUIOlc03fmzH+ru6RgYVZhPkyiy/92Owlt/8UEN+a4TXR1FQetfIpJE8ApdvdVxTg==} + +- jws@4.0.1: +- resolution: {integrity: sha512-EKI/M/yqPncGUUh44xz0PxSidXFr/+r0pA70+gIYhjv+et7yxM+s29Y+VGDkovRofQem0fs7Uvf4+YmAdyRduA==} +- + kind-of@3.2.2: + resolution: {integrity: sha512-NOW9QQXMoZGg/oqnVNoNTTIFEIid1627WCffUBJEdMxYApq7mNE7CpzucIPc+ZQg25Phej7IJSmX3hO+oblOtQ==} + engines: {node: '>=0.10.0'} +@@ -7446,8 +7323,8 @@ packages: + resolution: {integrity: sha512-dBpDMdxv9Irdq66304OLfEmQ9tbNRFnFTuZiLo+bD+r332bBmMJ8GBLXklIXXgxd3+v9+KUnZaUR5PJMa75Gsg==} + engines: {node: '>= 0.4.0'} + +- next@15.4.10: +- resolution: {integrity: sha512-itVlc79QjpKMFMRhP+kbGKaSG/gZM6RCvwhEbwmCNF06CdDiNaoHcbeg0PqkEa2GOcn8KJ0nnc7+yL7EjoYLHQ==} ++ next@15.5.9: ++ resolution: {integrity: sha512-agNLK89seZEtC5zUHwtut0+tNrc0Xw4FT/Dg+B/VLEo9pAcS9rtTKpek3V6kVcVwsB2YlqMaHdfZL4eLEVYuCg==} + engines: {node: ^18.18.0 || ^19.8.0 || >= 20.0.0} + hasBin: true + peerDependencies: +@@ -7496,10 +7373,6 @@ packages: + resolution: {integrity: sha512-dPEtOeMvF9VMcYV/1Wb8CPoVAXtp6MKMlcbAt4ddqmGqUJ6fQZFXkNZNkNlfevtNkGtaSoXf/vNNNSvgrdXwtA==} + engines: {node: '>= 6.13.0'} + +- node-forge@1.3.3: +- resolution: {integrity: sha512-rLvcdSyRCyouf6jcOIPe/BgwG/d7hKjzMKOas33/pHEr6gbq18IK9zV7DiPvzsz0oBJPme6qr6H6kGZuI9/DZg==} +- engines: {node: '>= 6.13.0'} +- + node-gyp@11.5.0: + resolution: {integrity: sha512-ra7Kvlhxn5V9Slyus0ygMa2h+UqExPqUIkfk7Pc8QTLT956JLSy51uWFwHtIYy0vI8cB4BDhc/S03+880My/LQ==} + engines: {node: ^18.17.0 || >=20.5.0} +@@ -7645,8 +7518,8 @@ packages: + resolution: {integrity: sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==} + engines: {node: '>=8'} + +- p-map@7.0.3: +- resolution: {integrity: sha512-VkndIv2fIB99swvQoA65bm+fsmt6UNdGeIB0oxBs+WhAhdh08QA04JXpI7rbB9r08/nkbysKoya9rtDERYOYMA==} ++ p-map@7.0.4: ++ resolution: {integrity: 
sha512-tkAQEw8ysMzmkhgw8k+1U/iPhWNhykKnSk4Rd5zLoPJCuJaGRPo6YposrZgaxHKzDHdDWWZvE/Sk7hsL2X/CpQ==} + engines: {node: '>=18'} + + p-queue@6.6.2: +@@ -7952,8 +7825,8 @@ packages: + resolution: {integrity: sha512-RXyHaACeqXeqAKGLDl68rQKbmObRsTIn4TYVUUug1KfS47YWCo5MacGITEryugIgZqORCvJWEk4l449POg5Txg==} + engines: {node: '>=12.0.0'} + +- protobufjs@7.5.4: +- resolution: {integrity: sha512-CvexbZtbov6jW2eXAvLukXjXUW1TzFaivC46BpWc/3BpcCysb5Vffu+B3XHMm8lVEuy2Mm4XGex8hBSg1yapPg==} ++ protobufjs@7.5.3: ++ resolution: {integrity: sha512-sildjKwVqOI2kmFDiXQ6aEB0fjYTafpEvIBs8tOR8qI4spuL9OPROLVu2qZqi/xgCfsHIwVqlaF8JBjWFHnKbw==} + engines: {node: '>=12.0.0'} + + proxy-addr@2.0.7: +@@ -8017,10 +7890,6 @@ packages: + resolution: {integrity: sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==} + engines: {node: '>= 0.8'} + +- raw-body@2.5.3: +- resolution: {integrity: sha512-s4VSOf6yN0rvbRZGxs8Om5CWj6seneMwK3oDb4lWDH0UPhWcxwOWw5+qk24bxq87szX1ydrwylIOp2uG1ojUpA==} +- engines: {node: '>= 0.8'} +- + raw-body@3.0.0: + resolution: {integrity: sha512-RmkhL8CAyCRPXCE28MMH0z2PNWQBNk2Q09ZdxM9IOOXwxwZbN+qbWaatPkdkWIKL2ZVDImrN/pK5HTRz2PcS4g==} + engines: {node: '>= 0.8'} +@@ -8029,8 +7898,8 @@ packages: + resolution: {integrity: sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==} + hasBin: true + +- re2@1.22.1: +- resolution: {integrity: sha512-E4J0EtgyNLdIr0wTg0dQPefuiqNY29KaLacytiUAYYRzxCG+zOkWoUygt1rI+TA1LrhN49/njrfSO1DHtVC5Vw==} ++ re2@1.22.3: ++ resolution: {integrity: sha512-002aE82U91DiaUA16U6vbiJusvPXn1OWiQukOxJkVUTXbzrSuQbFNHYKcGw8QK/uifRCfjl2Hd/vXYDanKkmaQ==} + + react-dom@18.3.1: + resolution: {integrity: sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==} +@@ -8203,10 +8072,6 @@ packages: + resolution: {integrity: sha512-e2bDA2WJT0wxseVd4lsDP4+3ONX6HpMXQa1ZhFQ7SU+GjvORCmShbCMltrtIDfkYhVHrOcPtj+KhmDBdPdZD1g==} + engines: {node: '>=10'} + +- 
safe-stable-stringify@2.5.0: +- resolution: {integrity: sha512-b3rppTKm9T+PsVCBEOUR46GWI7fdOs00VKZ1+9c1EWDaDMvjQc6tUwuFyIprgGgTcWoVHSKrU8H31ZHA2e0RHA==} +- engines: {node: '>=10'} +- + safer-buffer@2.1.2: + resolution: {integrity: sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==} + +@@ -8244,10 +8109,6 @@ packages: + resolution: {integrity: sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==} + engines: {node: '>= 0.8.0'} + +- send@0.19.1: +- resolution: {integrity: sha512-p4rRk4f23ynFEfcD9LA0xRYngj+IyGiEYyqqOak8kaN0TvNmuxC2dcVeBn62GpCeR2CpWqyHCNScTP91QbAVFg==} +- engines: {node: '>= 0.8.0'} +- + send@1.2.0: + resolution: {integrity: sha512-uaW0WwXKpL9blXE2o0bRhoL2EGXIrZxQ2ZQ4mgcfoBxdFmQold+qWsD2jLrfZ0trjKL6vOw0j//eAwcALFjKSw==} + engines: {node: '>= 18'} +@@ -8388,8 +8249,8 @@ packages: + sprintf-js@1.0.3: + resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==} + +- sql-formatter@15.6.10: +- resolution: {integrity: sha512-0bJOPQrRO/JkjQhiThVayq0hOKnI1tHI+2OTkmT7TGtc6kqS+V7kveeMzRW+RNQGxofmTmet9ILvztyuxv0cJQ==} ++ sql-formatter@15.6.12: ++ resolution: {integrity: sha512-mkpF+RG402P66VMsnQkWewTRzDBWfu9iLbOfxaW/nAKOS/2A9MheQmcU5cmX0D0At9azrorZwpvcBRNNBozACQ==} + hasBin: true + + ssri@12.0.0: +@@ -8411,10 +8272,6 @@ packages: + resolution: {integrity: sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==} + engines: {node: '>= 0.8'} + +- statuses@2.0.2: +- resolution: {integrity: sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==} +- engines: {node: '>= 0.8'} +- + stop-iteration-iterator@1.1.0: + resolution: {integrity: sha512-eLoXW/DHyl62zxY4SCaIgnRhuMr6ri4juEYARS8E6sCEqzKpOiE521Ucofdx+KnDZl5xmvGYaaKCk5FEOxJCoQ==} + engines: {node: '>= 0.4'} +@@ -9049,10 +8906,6 @@ packages: + resolution: {integrity: 
sha512-DLiFIXYC5fMPxaRg832S6F5mJYvePtmO5G9v9IgUFPhXm9/GkXarH/TUrBAVzhTCzAj9anE/+GjrgXp/54nOgw==} + engines: {node: '>= 12.0.0'} + +- winston@3.19.0: +- resolution: {integrity: sha512-LZNJgPzfKR+/J3cHkxcpHKpKKvGfDZVPS4hfJCc4cCG0CgYzvlD6yE/S3CIL/Yt91ak327YCpiF/0MyeZHEHKA==} +- engines: {node: '>= 12.0.0'} +- + wordwrap@1.0.0: + resolution: {integrity: sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==} + +@@ -9140,11 +8993,6 @@ packages: + engines: {node: '>= 14.6'} + hasBin: true + +- yaml@2.8.2: +- resolution: {integrity: sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==} +- engines: {node: '>= 14.6'} +- hasBin: true +- + yargs-parser@20.2.9: + resolution: {integrity: sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w==} + engines: {node: '>=10'} +@@ -9178,20 +9026,12 @@ packages: + peerDependencies: + zod: ^3.24.1 + +- zod-to-json-schema@3.25.0: +- resolution: {integrity: sha512-HvWtU2UG41LALjajJrML6uQejQhNJx+JBO9IflpSja4R03iNWfKXrj6W2h7ljuLyc1nKS+9yDyL/9tD1U/yBnQ==} +- peerDependencies: +- zod: ^3.25 || ^4 +- + zod@3.22.4: + resolution: {integrity: sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==} + + zod@3.25.67: + resolution: {integrity: sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==} + +- zod@3.25.76: +- resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==} +- + snapshots: + + '@ampproject/remapping@2.3.0': +@@ -9212,11 +9052,11 @@ snapshots: + transitivePeerDependencies: + - encoding + +- '@anthropic-ai/sdk@0.68.0(zod@3.25.76)': ++ '@anthropic-ai/sdk@0.71.2(zod@3.25.67)': + dependencies: + json-schema-to-ts: 3.1.1 + optionalDependencies: +- zod: 3.25.76 ++ zod: 3.25.67 + + '@anthropic-ai/sdk@0.9.1(encoding@0.1.13)': + dependencies: +@@ -9466,13 +9306,6 @@ snapshots: + 
enabled: 2.0.0 + kuler: 2.0.0 + +- '@dabh/diagnostics@2.0.8': +- dependencies: +- '@so-ric/colorspace': 1.1.6 +- enabled: 2.0.0 +- kuler: 2.0.0 +- optional: true +- + '@electric-sql/pglite@0.2.17': {} + + '@emnapi/runtime@1.7.1': +@@ -9559,8 +9392,7 @@ snapshots: + + '@fastify/busboy@3.0.0': {} + +- '@fastify/busboy@3.2.0': +- optional: true ++ '@fastify/busboy@3.1.1': {} + + '@firebase/ai@1.4.0(@firebase/app-types@0.9.3)(@firebase/app@0.13.1)': + dependencies: +@@ -9717,12 +9549,6 @@ snapshots: + '@firebase/app-types': 0.9.3 + '@firebase/util': 1.12.0 + +- '@firebase/database-types@1.0.16': +- dependencies: +- '@firebase/app-types': 0.9.3 +- '@firebase/util': 1.13.0 +- optional: true +- + '@firebase/database-types@1.0.6': + dependencies: + '@firebase/app-types': 0.9.2 +@@ -9936,20 +9762,15 @@ snapshots: + dependencies: + tslib: 2.8.1 + +- '@firebase/util@1.13.0': +- dependencies: +- tslib: 2.8.1 +- optional: true +- + '@firebase/webchannel-wrapper@1.0.3': {} + +- '@genkit-ai/ai@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': ++ '@genkit-ai/ai@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': + dependencies: +- '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) ++ '@genkit-ai/core': 
1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) + '@opentelemetry/api': 1.9.0 +- '@types/node': 20.19.26 ++ '@types/node': 20.19.1 + colorette: 2.0.20 +- dotprompt: 1.1.2 ++ dotprompt: 1.1.1 + json5: 2.2.3 + node-fetch: 3.3.2 + partial-json: 0.1.7 +@@ -9964,13 +9785,13 @@ snapshots: + - supports-color + optional: true + +- '@genkit-ai/ai@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': ++ '@genkit-ai/ai@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': + dependencies: +- '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) ++ '@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + '@opentelemetry/api': 1.9.0 +- '@types/node': 20.19.26 ++ '@types/node': 20.19.1 + colorette: 2.0.20 +- dotprompt: 1.1.2 ++ dotprompt: 1.1.1 + json5: 2.2.3 + node-fetch: 3.3.2 + partial-json: 0.1.7 +@@ -9984,7 +9805,7 @@ snapshots: + - genkit + - supports-color + +- '@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': ++ 
'@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/context-async-hooks': 1.25.1(@opentelemetry/api@1.9.0) +@@ -9997,16 +9818,16 @@ snapshots: + ajv: 8.17.1 + ajv-formats: 3.0.1(ajv@8.17.1) + async-mutex: 0.5.0 +- body-parser: 1.20.4 ++ body-parser: 1.20.3 + cors: 2.8.5 +- dotprompt: 1.1.2 +- express: 4.22.1 ++ dotprompt: 1.1.1 ++ express: 4.21.2 + get-port: 5.1.1 + json-schema: 0.4.0 +- zod: 3.25.76 +- zod-to-json-schema: 3.25.0(zod@3.25.76) ++ zod: 3.25.67 ++ zod-to-json-schema: 3.24.5(zod@3.25.67) + optionalDependencies: +- '@genkit-ai/firebase': 1.25.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) ++ '@genkit-ai/firebase': 1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) + transitivePeerDependencies: + - '@google-cloud/firestore' + - encoding +@@ -10016,7 +9837,7 @@ snapshots: + - supports-color + optional: true + +- '@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': ++ '@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/context-async-hooks': 
1.25.1(@opentelemetry/api@1.9.0) +@@ -10029,16 +9850,16 @@ snapshots: + ajv: 8.17.1 + ajv-formats: 3.0.1(ajv@8.17.1) + async-mutex: 0.5.0 +- body-parser: 1.20.4 ++ body-parser: 1.20.3 + cors: 2.8.5 +- dotprompt: 1.1.2 +- express: 4.22.1 ++ dotprompt: 1.1.1 ++ express: 4.21.2 + get-port: 5.1.1 + json-schema: 0.4.0 +- zod: 3.25.76 +- zod-to-json-schema: 3.25.0(zod@3.25.76) ++ zod: 3.25.67 ++ zod-to-json-schema: 3.24.5(zod@3.25.67) + optionalDependencies: +- '@genkit-ai/firebase': 1.25.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) ++ '@genkit-ai/firebase': 1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + transitivePeerDependencies: + - '@google-cloud/firestore' + - encoding +@@ -10047,9 +9868,9 @@ snapshots: + - genkit + - supports-color + +- '@genkit-ai/express@1.12.0(@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(express@5.1.0)(genkit@genkit)': ++ '@genkit-ai/express@1.12.0(@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(express@5.1.0)(genkit@genkit)': + dependencies: +- '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) ++ '@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + body-parser: 1.20.3 + cors: 2.8.5 + express: 5.1.0 +@@ -10057,25 +9878,12 @@ snapshots: + transitivePeerDependencies: + - supports-color + +- 
'@genkit-ai/firebase@1.16.1(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': +- dependencies: +- '@genkit-ai/google-cloud': 1.16.1(encoding@0.1.13)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) +- '@google-cloud/firestore': 7.11.6(encoding@0.1.13) +- firebase-admin: 13.5.0(encoding@0.1.13) +- genkit: 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1) +- optionalDependencies: +- firebase: 11.9.1 +- transitivePeerDependencies: +- - encoding +- - supports-color +- optional: true +- +- '@genkit-ai/firebase@1.25.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': ++ '@genkit-ai/firebase@1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': + dependencies: +- '@genkit-ai/google-cloud': 1.25.0(encoding@0.1.13)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) +- '@google-cloud/firestore': 7.11.6(encoding@0.1.13) +- firebase-admin: 13.5.0(encoding@0.1.13) +- genkit: 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1) ++ '@genkit-ai/google-cloud': 
1.16.1(encoding@0.1.13)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) ++ '@google-cloud/firestore': 7.11.1(encoding@0.1.13) ++ firebase-admin: 13.6.0(encoding@0.1.13) ++ genkit: 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1) + optionalDependencies: + firebase: 11.9.1 + transitivePeerDependencies: +@@ -10083,11 +9891,11 @@ snapshots: + - supports-color + optional: true + +- '@genkit-ai/firebase@1.25.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': ++ '@genkit-ai/firebase@1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': + dependencies: +- '@genkit-ai/google-cloud': 1.25.0(encoding@0.1.13)(genkit@genkit) +- '@google-cloud/firestore': 7.11.6(encoding@0.1.13) +- firebase-admin: 13.5.0(encoding@0.1.13) ++ '@genkit-ai/google-cloud': 1.16.1(encoding@0.1.13)(genkit@genkit) ++ '@google-cloud/firestore': 7.11.1(encoding@0.1.13) ++ firebase-admin: 13.6.0(encoding@0.1.13) + genkit: link:genkit + optionalDependencies: + firebase: 11.9.1 +@@ -10096,7 +9904,7 @@ snapshots: + - supports-color + optional: true + +- '@genkit-ai/google-cloud@1.16.1(encoding@0.1.13)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': ++ '@genkit-ai/google-cloud@1.16.1(encoding@0.1.13)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': + dependencies: + '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.17.0) + '@google-cloud/opentelemetry-cloud-monitoring-exporter': 
0.19.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-metrics@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) +@@ -10112,7 +9920,7 @@ snapshots: + '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-node': 0.52.1(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-trace-base': 1.25.1(@opentelemetry/api@1.9.0) +- genkit: 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1) ++ genkit: 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1) + google-auth-library: 9.15.1(encoding@0.1.13) + node-fetch: 3.3.2 + winston: 3.17.0 +@@ -10121,34 +9929,9 @@ snapshots: + - supports-color + optional: true + +- '@genkit-ai/google-cloud@1.25.0(encoding@0.1.13)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': ++ '@genkit-ai/google-cloud@1.16.1(encoding@0.1.13)(genkit@genkit)': + dependencies: +- '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.19.0) +- '@google-cloud/opentelemetry-cloud-monitoring-exporter': 0.19.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-metrics@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) +- '@google-cloud/opentelemetry-cloud-trace-exporter': 2.4.1(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) +- '@google-cloud/opentelemetry-resource-util': 2.4.0(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) +- '@opentelemetry/api': 1.9.0 +- 
'@opentelemetry/auto-instrumentations-node': 0.49.2(@opentelemetry/api@1.9.0)(encoding@0.1.13) +- '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) +- '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) +- '@opentelemetry/instrumentation-pino': 0.41.0(@opentelemetry/api@1.9.0) +- '@opentelemetry/instrumentation-winston': 0.39.0(@opentelemetry/api@1.9.0) +- '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) +- '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) +- '@opentelemetry/sdk-node': 0.52.1(@opentelemetry/api@1.9.0) +- '@opentelemetry/sdk-trace-base': 1.25.1(@opentelemetry/api@1.9.0) +- genkit: 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1) +- google-auth-library: 9.15.1(encoding@0.1.13) +- node-fetch: 3.3.2 +- winston: 3.19.0 +- transitivePeerDependencies: +- - encoding +- - supports-color +- optional: true +- +- '@genkit-ai/google-cloud@1.25.0(encoding@0.1.13)(genkit@genkit)': +- dependencies: +- '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.19.0) ++ '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.17.0) + '@google-cloud/opentelemetry-cloud-monitoring-exporter': 0.19.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-metrics@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) + '@google-cloud/opentelemetry-cloud-trace-exporter': 2.4.1(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) + '@google-cloud/opentelemetry-resource-util': 2.4.0(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) +@@ -10165,7 +9948,7 @@ snapshots: + genkit: link:genkit + google-auth-library: 9.15.1(encoding@0.1.13) + 
node-fetch: 3.3.2 +- winston: 3.19.0 ++ winston: 3.17.0 + transitivePeerDependencies: + - encoding + - supports-color +@@ -10250,18 +10033,6 @@ snapshots: + - encoding + - supports-color + +- '@google-cloud/firestore@7.11.6(encoding@0.1.13)': +- dependencies: +- '@opentelemetry/api': 1.9.0 +- fast-deep-equal: 3.1.3 +- functional-red-black-tree: 1.0.1 +- google-gax: 4.6.1(encoding@0.1.13) +- protobufjs: 7.5.4 +- transitivePeerDependencies: +- - encoding +- - supports-color +- optional: true +- + '@google-cloud/logging-winston@6.0.1(encoding@0.1.13)(winston@3.17.0)': + dependencies: + '@google-cloud/logging': 11.0.0(encoding@0.1.13) +@@ -10273,18 +10044,6 @@ snapshots: + - encoding + - supports-color + +- '@google-cloud/logging-winston@6.0.1(encoding@0.1.13)(winston@3.19.0)': +- dependencies: +- '@google-cloud/logging': 11.0.0(encoding@0.1.13) +- google-auth-library: 9.15.1(encoding@0.1.13) +- lodash.mapvalues: 4.6.0 +- winston: 3.19.0 +- winston-transport: 4.7.0 +- transitivePeerDependencies: +- - encoding +- - supports-color +- optional: true +- + '@google-cloud/logging@11.0.0(encoding@0.1.13)': + dependencies: + '@google-cloud/common': 5.0.1(encoding@0.1.13) +@@ -10407,28 +10166,6 @@ snapshots: + - supports-color + optional: true + +- '@google-cloud/storage@7.18.0(encoding@0.1.13)': +- dependencies: +- '@google-cloud/paginator': 5.0.2 +- '@google-cloud/projectify': 4.0.0 +- '@google-cloud/promisify': 4.0.0 +- abort-controller: 3.0.0 +- async-retry: 1.3.3 +- duplexify: 4.1.3 +- fast-xml-parser: 4.5.3 +- gaxios: 6.7.1(encoding@0.1.13) +- google-auth-library: 9.15.1(encoding@0.1.13) +- html-entities: 2.6.0 +- mime: 3.0.0 +- p-limit: 3.1.0 +- retry-request: 7.0.2(encoding@0.1.13) +- teeny-request: 9.0.0(encoding@0.1.13) +- uuid: 8.3.2 +- transitivePeerDependencies: +- - encoding +- - supports-color +- optional: true +- + '@google-cloud/vertexai@1.10.0(encoding@0.1.13)': + dependencies: + google-auth-library: 9.15.1(encoding@0.1.13) +@@ -10470,7 +10207,7 @@ snapshots: 
+ '@grpc/proto-loader': 0.7.13 + '@js-sdsl/ordered-map': 4.4.2 + +- '@grpc/grpc-js@1.14.2': ++ '@grpc/grpc-js@1.14.3': + dependencies: + '@grpc/proto-loader': 0.8.0 + '@js-sdsl/ordered-map': 4.4.2 +@@ -10491,14 +10228,14 @@ snapshots: + dependencies: + lodash.camelcase: 4.3.0 + long: 5.3.2 +- protobufjs: 7.5.4 ++ protobufjs: 7.5.3 + yargs: 17.7.2 + + '@grpc/proto-loader@0.8.0': + dependencies: + lodash.camelcase: 4.3.0 + long: 5.3.2 +- protobufjs: 7.5.4 ++ protobufjs: 7.5.3 + yargs: 17.7.2 + + '@img/colour@1.0.0': +@@ -10598,10 +10335,10 @@ snapshots: + '@img/sharp-win32-x64@0.34.5': + optional: true + +- '@inquirer/external-editor@1.0.2(@types/node@20.19.1)': ++ '@inquirer/external-editor@1.0.3(@types/node@20.19.1)': + dependencies: + chardet: 2.1.1 +- iconv-lite: 0.7.0 ++ iconv-lite: 0.7.1 + optionalDependencies: + '@types/node': 20.19.1 + +@@ -10840,9 +10577,6 @@ snapshots: + + '@jridgewell/sourcemap-codec@1.5.0': {} + +- '@jridgewell/sourcemap-codec@1.5.5': +- optional: true +- + '@jridgewell/trace-mapping@0.3.25': + dependencies: + '@jridgewell/resolve-uri': 3.1.2 +@@ -10851,7 +10585,7 @@ snapshots: + '@jridgewell/trace-mapping@0.3.9': + dependencies: + '@jridgewell/resolve-uri': 3.1.2 +- '@jridgewell/sourcemap-codec': 1.5.5 ++ '@jridgewell/sourcemap-codec': 1.5.0 + optional: true + + '@js-sdsl/ordered-map@4.4.2': {} +@@ -10907,9 +10641,9 @@ snapshots: + dependencies: + '@langchain/core': 0.1.63 + js-tiktoken: 1.0.11 +- openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76) +- zod: 3.25.76 +- zod-to-json-schema: 3.24.5(zod@3.25.76) ++ openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) ++ zod: 3.25.67 ++ zod-to-json-schema: 3.24.5(zod@3.25.67) + transitivePeerDependencies: + - encoding + - ws +@@ -10919,10 +10653,10 @@ snapshots: + '@langchain/core': 0.1.63 + js-tiktoken: 1.0.11 + +- '@mistralai/mistralai-gcp@1.5.0(encoding@0.1.13)(zod@3.25.76)': ++ '@mistralai/mistralai-gcp@1.5.0(encoding@0.1.13)(zod@3.25.67)': + dependencies: + google-auth-library: 
9.15.1(encoding@0.1.13) +- zod: 3.25.76 ++ zod: 3.25.67 + transitivePeerDependencies: + - encoding + - supports-color +@@ -10961,13 +10695,13 @@ snapshots: + transitivePeerDependencies: + - supports-color + +- '@modelcontextprotocol/server-filesystem@2025.7.1(zod@3.25.76)': ++ '@modelcontextprotocol/server-filesystem@2025.7.1(zod@3.25.67)': + dependencies: + '@modelcontextprotocol/sdk': 1.15.0 + diff: 5.2.0 + glob: 10.3.12 + minimatch: 10.0.1 +- zod-to-json-schema: 3.24.5(zod@3.25.76) ++ zod-to-json-schema: 3.24.5(zod@3.25.67) + transitivePeerDependencies: + - supports-color + - zod +@@ -11016,30 +10750,30 @@ snapshots: + '@napi-rs/canvas-win32-x64-msvc': 0.1.71 + optional: true + +- '@next/env@15.4.10': {} ++ '@next/env@15.5.9': {} + +- '@next/swc-darwin-arm64@15.4.8': ++ '@next/swc-darwin-arm64@15.5.7': + optional: true + +- '@next/swc-darwin-x64@15.4.8': ++ '@next/swc-darwin-x64@15.5.7': + optional: true + +- '@next/swc-linux-arm64-gnu@15.4.8': ++ '@next/swc-linux-arm64-gnu@15.5.7': + optional: true + +- '@next/swc-linux-arm64-musl@15.4.8': ++ '@next/swc-linux-arm64-musl@15.5.7': + optional: true + +- '@next/swc-linux-x64-gnu@15.4.8': ++ '@next/swc-linux-x64-gnu@15.5.7': + optional: true + +- '@next/swc-linux-x64-musl@15.4.8': ++ '@next/swc-linux-x64-musl@15.5.7': + optional: true + +- '@next/swc-win32-arm64-msvc@15.4.8': ++ '@next/swc-win32-arm64-msvc@15.5.7': + optional: true + +- '@next/swc-win32-x64-msvc@15.4.8': ++ '@next/swc-win32-x64-msvc@15.5.7': + optional: true + + '@npmcli/agent@3.0.0': +@@ -11055,7 +10789,7 @@ snapshots: + + '@npmcli/fs@4.0.0': + dependencies: +- semver: 7.7.3 ++ semver: 7.7.2 + optional: true + + '@opentelemetry/api-logs@0.52.1': +@@ -11842,12 +11576,6 @@ snapshots: + lodash.get: 4.4.2 + type-detect: 4.1.0 + +- '@so-ric/colorspace@1.1.6': +- dependencies: +- color: 5.0.3 +- text-hex: 1.0.0 +- optional: true +- + '@swc/helpers@0.5.15': + dependencies: + tslib: 2.8.1 +@@ -11856,7 +11584,7 @@ snapshots: + + 
'@tootallnate/quickjs-emscripten@0.23.0': {} + +- '@tsconfig/node10@1.0.12': ++ '@tsconfig/node10@1.0.11': + optional: true + + '@tsconfig/node12@1.0.11': +@@ -11937,14 +11665,6 @@ snapshots: + '@types/range-parser': 1.2.7 + '@types/send': 0.17.4 + +- '@types/express-serve-static-core@4.19.7': +- dependencies: +- '@types/node': 20.19.26 +- '@types/qs': 6.14.0 +- '@types/range-parser': 1.2.7 +- '@types/send': 1.2.1 +- optional: true +- + '@types/express@4.17.23': + dependencies: + '@types/body-parser': 1.19.5 +@@ -11952,14 +11672,6 @@ snapshots: + '@types/qs': 6.9.14 + '@types/serve-static': 1.15.5 + +- '@types/express@4.17.25': +- dependencies: +- '@types/body-parser': 1.19.6 +- '@types/express-serve-static-core': 4.19.7 +- '@types/qs': 6.14.0 +- '@types/serve-static': 1.15.10 +- optional: true +- + '@types/graceful-fs@4.1.9': + dependencies: + '@types/node': 20.19.1 +@@ -11974,9 +11686,6 @@ snapshots: + + '@types/http-errors@2.0.4': {} + +- '@types/http-errors@2.0.5': +- optional: true +- + '@types/istanbul-lib-coverage@2.0.6': {} + + '@types/istanbul-lib-report@3.0.3': +@@ -11999,8 +11708,7 @@ snapshots: + '@types/jsonwebtoken@9.0.10': + dependencies: + '@types/ms': 2.1.0 +- '@types/node': 20.19.26 +- optional: true ++ '@types/node': 20.19.1 + + '@types/jsonwebtoken@9.0.6': + dependencies: +@@ -12016,8 +11724,7 @@ snapshots: + + '@types/mime@3.0.4': {} + +- '@types/ms@2.1.0': +- optional: true ++ '@types/ms@2.1.0': {} + + '@types/mysql@2.15.22': + dependencies: +@@ -12041,23 +11748,10 @@ snapshots: + dependencies: + undici-types: 6.21.0 + +- '@types/node@20.19.25': +- dependencies: +- undici-types: 6.21.0 +- +- '@types/node@20.19.26': +- dependencies: +- undici-types: 6.21.0 +- + '@types/node@22.15.32': + dependencies: + undici-types: 6.21.0 + +- '@types/node@22.19.2': +- dependencies: +- undici-types: 6.21.0 +- optional: true +- + '@types/pdf-parse@1.1.5': + dependencies: + '@types/node': 20.19.1 +@@ -12072,9 +11766,6 @@ snapshots: + pg-protocol: 1.6.0 + 
pg-types: 2.2.0 + +- '@types/qs@6.14.0': +- optional: true +- + '@types/qs@6.9.14': {} + + '@types/range-parser@1.2.7': {} +@@ -12097,24 +11788,6 @@ snapshots: + '@types/mime': 1.3.5 + '@types/node': 20.19.1 + +- '@types/send@0.17.6': +- dependencies: +- '@types/mime': 1.3.5 +- '@types/node': 20.19.26 +- optional: true +- +- '@types/send@1.2.1': +- dependencies: +- '@types/node': 20.19.26 +- optional: true +- +- '@types/serve-static@1.15.10': +- dependencies: +- '@types/http-errors': 2.0.5 +- '@types/node': 20.19.26 +- '@types/send': 0.17.6 +- optional: true +- + '@types/serve-static@1.15.5': + dependencies: + '@types/http-errors': 2.0.4 +@@ -12245,7 +11918,7 @@ snapshots: + dependencies: + type-fest: 0.21.3 + +- ansi-escapes@7.1.1: ++ ansi-escapes@7.2.0: + dependencies: + environment: 1.1.0 + +@@ -12276,7 +11949,7 @@ snapshots: + + archiver-utils@5.0.2: + dependencies: +- glob: 10.4.5 ++ glob: 10.5.0 + graceful-fs: 4.2.11 + is-stream: 2.0.1 + lazystream: 1.0.1 +@@ -12413,7 +12086,7 @@ snapshots: + + balanced-match@1.0.2: {} + +- bare-events@2.8.1: {} ++ bare-events@2.8.2: {} + + base-64@0.1.0: {} + +@@ -12460,23 +12133,6 @@ snapshots: + transitivePeerDependencies: + - supports-color + +- body-parser@1.20.4: +- dependencies: +- bytes: 3.1.2 +- content-type: 1.0.5 +- debug: 2.6.9 +- depd: 2.0.0 +- destroy: 1.2.0 +- http-errors: 2.0.1 +- iconv-lite: 0.4.24 +- on-finished: 2.4.1 +- qs: 6.14.0 +- raw-body: 2.5.3 +- type-is: 1.6.18 +- unpipe: 1.0.0 +- transitivePeerDependencies: +- - supports-color +- + body-parser@2.2.0: + dependencies: + bytes: 3.1.2 +@@ -12521,7 +12177,7 @@ snapshots: + + browserslist@4.24.0: + dependencies: +- caniuse-lite: 1.0.30001760 ++ caniuse-lite: 1.0.30001667 + electron-to-chromium: 1.5.33 + node-releases: 2.0.18 + update-browserslist-db: 1.1.1(browserslist@4.24.0) +@@ -12579,13 +12235,13 @@ snapshots: + dependencies: + '@npmcli/fs': 4.0.0 + fs-minipass: 3.0.3 +- glob: 10.4.5 ++ glob: 10.5.0 + lru-cache: 10.2.0 + minipass: 7.1.2 + 
minipass-collect: 2.0.1 + minipass-flush: 1.0.5 + minipass-pipeline: 1.2.4 +- p-map: 7.0.3 ++ p-map: 7.0.4 + ssri: 12.0.0 + tar: 7.5.2 + unique-filename: 4.0.0 +@@ -12601,7 +12257,7 @@ snapshots: + es-define-property: 1.0.0 + es-errors: 1.3.0 + function-bind: 1.1.2 +- get-intrinsic: 1.2.4 ++ get-intrinsic: 1.3.0 + set-function-length: 1.2.2 + + call-bind@1.0.8: +@@ -12624,7 +12280,7 @@ snapshots: + + camelcase@6.3.0: {} + +- caniuse-lite@1.0.30001760: {} ++ caniuse-lite@1.0.30001667: {} + + chalk@2.4.2: + dependencies: +@@ -12680,12 +12336,12 @@ snapshots: + chownr@3.0.0: + optional: true + +- chromadb@1.8.1(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76)): ++ chromadb@1.8.1(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)): + dependencies: + cliui: 8.0.1 + isomorphic-fetch: 3.0.0(encoding@0.1.13) + optionalDependencies: +- openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76) ++ openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) + transitivePeerDependencies: + - encoding + +@@ -12762,39 +12418,20 @@ snapshots: + dependencies: + color-name: 1.1.4 + +- color-convert@3.1.3: +- dependencies: +- color-name: 2.1.0 +- optional: true +- + color-name@1.1.3: {} + + color-name@1.1.4: {} + +- color-name@2.1.0: +- optional: true +- + color-string@1.9.1: + dependencies: + color-name: 1.1.4 + simple-swizzle: 0.2.2 + +- color-string@2.1.4: +- dependencies: +- color-name: 2.1.0 +- optional: true +- + color@3.2.1: + dependencies: + color-convert: 1.9.3 + color-string: 1.9.1 + +- color@5.0.3: +- dependencies: +- color-convert: 3.1.3 +- color-string: 2.1.4 +- optional: true +- + colorette@2.0.19: {} + + colorette@2.0.20: {} +@@ -12902,8 +12539,6 @@ snapshots: + + cookie-signature@1.0.6: {} + +- cookie-signature@1.0.7: {} +- + cookie-signature@1.2.2: {} + + cookie@0.7.1: {} +@@ -13033,11 +12668,6 @@ snapshots: + dependencies: + ms: 2.1.3 + +- debug@4.4.3: +- dependencies: +- ms: 2.1.3 +- optional: true +- + decamelize@1.2.0: {} 
+ + dedent@1.5.3: {} +@@ -13122,11 +12752,6 @@ snapshots: + handlebars: 4.7.8 + yaml: 2.7.0 + +- dotprompt@1.1.2: +- dependencies: +- handlebars: 4.7.8 +- yaml: 2.8.2 +- + dunder-proto@1.0.1: + dependencies: + call-bind-apply-helpers: 1.0.2 +@@ -13256,7 +12881,7 @@ snapshots: + + es-define-property@1.0.0: + dependencies: +- get-intrinsic: 1.2.4 ++ get-intrinsic: 1.3.0 + + es-define-property@1.0.1: {} + +@@ -13349,7 +12974,7 @@ snapshots: + + events-universal@1.0.1: + dependencies: +- bare-events: 2.8.1 ++ bare-events: 2.8.2 + transitivePeerDependencies: + - bare-abort-controller + +@@ -13388,7 +13013,7 @@ snapshots: + content-type: 1.0.5 + deep-freeze: 0.0.1 + events-listener: 1.1.0 +- glob: 10.4.5 ++ glob: 10.5.0 + json-ptr: 3.1.1 + json-schema-traverse: 1.0.0 + lodash: 4.17.21 +@@ -13455,42 +13080,6 @@ snapshots: + transitivePeerDependencies: + - supports-color + +- express@4.22.1: +- dependencies: +- accepts: 1.3.8 +- array-flatten: 1.1.1 +- body-parser: 1.20.4 +- content-disposition: 0.5.4 +- content-type: 1.0.5 +- cookie: 0.7.2 +- cookie-signature: 1.0.7 +- debug: 2.6.9 +- depd: 2.0.0 +- encodeurl: 2.0.0 +- escape-html: 1.0.3 +- etag: 1.8.1 +- finalhandler: 1.3.2 +- fresh: 0.5.2 +- http-errors: 2.0.1 +- merge-descriptors: 1.0.3 +- methods: 1.1.2 +- on-finished: 2.4.1 +- parseurl: 1.3.3 +- path-to-regexp: 0.1.12 +- proxy-addr: 2.0.7 +- qs: 6.14.0 +- range-parser: 1.2.1 +- safe-buffer: 5.2.1 +- send: 0.19.1 +- serve-static: 1.16.2 +- setprototypeof: 1.2.0 +- statuses: 2.0.2 +- type-is: 1.6.18 +- utils-merge: 1.0.1 +- vary: 1.1.2 +- transitivePeerDependencies: +- - supports-color +- + express@5.1.0: + dependencies: + accepts: 2.0.0 +@@ -13602,18 +13191,6 @@ snapshots: + transitivePeerDependencies: + - supports-color + +- finalhandler@1.3.2: +- dependencies: +- debug: 2.6.9 +- encodeurl: 2.0.0 +- escape-html: 1.0.3 +- on-finished: 2.4.1 +- parseurl: 1.3.3 +- statuses: 2.0.2 +- unpipe: 1.0.0 +- transitivePeerDependencies: +- - supports-color +- + 
finalhandler@2.1.0: + dependencies: + debug: 4.4.1 +@@ -13636,18 +13213,18 @@ snapshots: + + firebase-admin@12.3.1(encoding@0.1.13): + dependencies: +- '@fastify/busboy': 3.2.0 ++ '@fastify/busboy': 3.1.1 + '@firebase/database-compat': 1.0.10 +- '@firebase/database-types': 1.0.16 +- '@types/node': 22.19.2 ++ '@firebase/database-types': 1.0.14 ++ '@types/node': 22.15.32 + farmhash-modern: 1.1.0 + jsonwebtoken: 9.0.2 + jwks-rsa: 3.2.0 +- node-forge: 1.3.3 ++ node-forge: 1.3.1 + uuid: 10.0.0 + optionalDependencies: +- '@google-cloud/firestore': 7.11.6(encoding@0.1.13) +- '@google-cloud/storage': 7.18.0(encoding@0.1.13) ++ '@google-cloud/firestore': 7.11.1(encoding@0.1.13) ++ '@google-cloud/storage': 7.16.0(encoding@0.1.13) + transitivePeerDependencies: + - encoding + - supports-color +@@ -13672,9 +13249,9 @@ snapshots: + - encoding + - supports-color + +- firebase-admin@13.5.0(encoding@0.1.13): ++ firebase-admin@13.6.0(encoding@0.1.13): + dependencies: +- '@fastify/busboy': 3.0.0 ++ '@fastify/busboy': 3.1.1 + '@firebase/database-compat': 2.0.10 + '@firebase/database-types': 1.0.14 + '@types/node': 22.15.32 +@@ -13682,11 +13259,11 @@ snapshots: + fast-deep-equal: 3.1.3 + google-auth-library: 9.15.1(encoding@0.1.13) + jsonwebtoken: 9.0.2 +- jwks-rsa: 3.1.0 ++ jwks-rsa: 3.2.0 + node-forge: 1.3.1 + uuid: 11.1.0 + optionalDependencies: +- '@google-cloud/firestore': 7.11.0(encoding@0.1.13) ++ '@google-cloud/firestore': 7.11.1(encoding@0.1.13) + '@google-cloud/storage': 7.16.0(encoding@0.1.13) + transitivePeerDependencies: + - encoding +@@ -13703,13 +13280,13 @@ snapshots: + transitivePeerDependencies: + - supports-color + +- firebase-functions@6.3.2(firebase-admin@13.5.0(encoding@0.1.13)): ++ firebase-functions@6.3.2(firebase-admin@13.6.0(encoding@0.1.13)): + dependencies: + '@types/cors': 2.8.19 + '@types/express': 4.17.23 + cors: 2.8.5 + express: 4.21.2 +- firebase-admin: 13.5.0(encoding@0.1.13) ++ firebase-admin: 13.6.0(encoding@0.1.13) + protobufjs: 7.3.2 + 
transitivePeerDependencies: + - supports-color +@@ -13740,11 +13317,11 @@ snapshots: + exegesis-express: 4.0.0 + express: 4.21.2 + filesize: 6.4.0 +- form-data: 4.0.4 ++ form-data: 4.0.5 + fs-extra: 10.1.0 + fuzzy: 0.1.3 + gaxios: 6.7.1(encoding@0.1.13) +- glob: 10.4.5 ++ glob: 10.5.0 + google-auth-library: 9.15.1(encoding@0.1.13) + inquirer: 8.2.7(@types/node@20.19.1) + inquirer-autocomplete-prompt: 2.0.1(inquirer@8.2.7(@types/node@20.19.1)) +@@ -13769,7 +13346,7 @@ snapshots: + proxy-agent: 6.5.0 + retry: 0.13.1 + semver: 7.7.2 +- sql-formatter: 15.6.10 ++ sql-formatter: 15.6.12 + stream-chain: 2.2.5 + stream-json: 1.9.1 + superstatic: 9.2.0(encoding@0.1.13) +@@ -13860,7 +13437,7 @@ snapshots: + combined-stream: 1.0.8 + mime-types: 2.1.35 + +- form-data@4.0.4: ++ form-data@4.0.5: + dependencies: + asynckit: 0.4.0 + combined-stream: 1.0.8 +@@ -13986,10 +13563,10 @@ snapshots: + transitivePeerDependencies: + - supports-color + +- genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1): ++ genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1): + dependencies: +- '@genkit-ai/ai': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) +- '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) ++ '@genkit-ai/ai': 
1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) ++ '@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) + uuid: 10.0.0 + transitivePeerDependencies: + - '@google-cloud/firestore' +@@ -13999,10 +13576,10 @@ snapshots: + - supports-color + optional: true + +- genkitx-openai@0.10.1(@genkit-ai/ai@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(encoding@0.1.13)(ws@8.18.3): ++ genkitx-openai@0.10.1(@genkit-ai/ai@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(encoding@0.1.13)(ws@8.18.3): + dependencies: +- '@genkit-ai/ai': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) +- '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) ++ '@genkit-ai/ai': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) ++ 
'@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) + zod: 3.25.67 + transitivePeerDependencies: +@@ -14091,7 +13668,7 @@ snapshots: + minipass: 7.1.2 + path-scurry: 1.10.2 + +- glob@10.4.5: ++ glob@10.5.0: + dependencies: + foreground-child: 3.1.1 + jackspeak: 3.4.3 +@@ -14150,7 +13727,7 @@ snapshots: + gaxios: 5.1.3(encoding@0.1.13) + gcp-metadata: 5.3.0(encoding@0.1.13) + gtoken: 6.1.2(encoding@0.1.13) +- jws: 4.0.1 ++ jws: 4.0.0 + lru-cache: 6.0.0 + transitivePeerDependencies: + - encoding +@@ -14189,7 +13766,7 @@ snapshots: + + google-gax@5.0.6: + dependencies: +- '@grpc/grpc-js': 1.14.2 ++ '@grpc/grpc-js': 1.14.3 + '@grpc/proto-loader': 0.8.0 + duplexify: 4.1.3 + google-auth-library: 10.5.0 +@@ -14197,7 +13774,7 @@ snapshots: + node-fetch: 3.3.2 + object-hash: 3.0.0 + proto3-json-serializer: 3.0.4 +- protobufjs: 7.5.4 ++ protobufjs: 7.5.3 + retry-request: 8.0.2 + rimraf: 5.0.10 + transitivePeerDependencies: +@@ -14207,7 +13784,7 @@ snapshots: + + google-p12-pem@4.0.1: + dependencies: +- node-forge: 1.3.3 ++ node-forge: 1.3.1 + optional: true + + googleapis-common@7.2.0(encoding@0.1.13): +@@ -14248,7 +13825,7 @@ snapshots: + dependencies: + gaxios: 5.1.3(encoding@0.1.13) + google-p12-pem: 4.0.1 +- jws: 4.0.1 ++ jws: 4.0.0 + transitivePeerDependencies: + - encoding + - supports-color +@@ -14337,14 +13914,6 @@ snapshots: + statuses: 2.0.1 + toidentifier: 1.0.1 + +- http-errors@2.0.1: +- dependencies: +- depd: 2.0.0 +- inherits: 2.0.4 +- setprototypeof: 1.2.0 +- statuses: 2.0.2 +- toidentifier: 1.0.1 +- + http-parser-js@0.5.10: {} + + http-proxy-agent@5.0.0: +@@ -14397,7 +13966,7 @@ snapshots: + dependencies: + safer-buffer: 2.1.2 + +- iconv-lite@0.7.0: ++ iconv-lite@0.7.1: + dependencies: + safer-buffer: 2.1.2 + +@@ -14448,7 +14017,7 @@ snapshots: + + inquirer@8.2.7(@types/node@20.19.1): + 
dependencies: +- '@inquirer/external-editor': 1.0.2(@types/node@20.19.1) ++ '@inquirer/external-editor': 1.0.3(@types/node@20.19.1) + ansi-escapes: 4.3.2 + chalk: 4.1.2 + cli-cursor: 3.1.0 +@@ -14477,7 +14046,7 @@ snapshots: + + interpret@2.2.0: {} + +- ip-address@10.0.1: {} ++ ip-address@10.1.0: {} + + ip-regex@4.3.0: {} + +@@ -14691,7 +14260,7 @@ snapshots: + '@babel/parser': 7.25.7 + '@istanbuljs/schema': 0.1.3 + istanbul-lib-coverage: 3.2.2 +- semver: 7.7.3 ++ semver: 7.7.2 + transitivePeerDependencies: + - supports-color + +@@ -15126,8 +14695,7 @@ snapshots: + + jose@4.15.5: {} + +- jose@4.15.9: +- optional: true ++ jose@4.15.9: {} + + joycon@3.1.1: {} + +@@ -15222,13 +14790,6 @@ snapshots: + ecdsa-sig-formatter: 1.0.11 + safe-buffer: 5.2.1 + +- jwa@2.0.1: +- dependencies: +- buffer-equal-constant-time: 1.0.1 +- ecdsa-sig-formatter: 1.0.11 +- safe-buffer: 5.2.1 +- optional: true +- + jwks-rsa@3.1.0: + dependencies: + '@types/express': 4.17.23 +@@ -15242,15 +14803,14 @@ snapshots: + + jwks-rsa@3.2.0: + dependencies: +- '@types/express': 4.17.25 ++ '@types/express': 4.17.23 + '@types/jsonwebtoken': 9.0.10 +- debug: 4.4.3 ++ debug: 4.4.1 + jose: 4.15.9 + limiter: 1.1.5 + lru-memoizer: 2.3.0 + transitivePeerDependencies: + - supports-color +- optional: true + + jws@3.2.2: + dependencies: +@@ -15262,12 +14822,6 @@ snapshots: + jwa: 2.0.0 + safe-buffer: 5.2.1 + +- jws@4.0.1: +- dependencies: +- jwa: 2.0.1 +- safe-buffer: 5.2.1 +- optional: true +- + kind-of@3.2.2: + dependencies: + is-buffer: 1.1.6 +@@ -15297,7 +14851,7 @@ snapshots: + + kuler@2.0.0: {} + +- 
langchain@0.1.37(@google-cloud/storage@7.18.0(encoding@0.1.13))(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(fast-xml-parser@4.5.3)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(lodash@4.17.21)(pdf-parse@1.1.1)(pg@8.16.2)(ws@8.18.3): ++ langchain@0.1.37(@google-cloud/storage@7.16.0(encoding@0.1.13))(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(fast-xml-parser@4.5.3)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(lodash@4.17.21)(pdf-parse@1.1.1)(pg@8.16.2)(ws@8.18.3): + dependencies: + '@anthropic-ai/sdk': 0.9.1(encoding@0.1.13) + '@langchain/community': 0.0.53(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2)(lodash@4.17.21)(pg@8.16.2)(ws@8.18.3) +@@ -15318,7 +14872,7 @@ snapshots: + zod: 3.25.67 + zod-to-json-schema: 3.24.5(zod@3.25.67) + optionalDependencies: +- '@google-cloud/storage': 7.18.0(encoding@0.1.13) ++ '@google-cloud/storage': 7.16.0(encoding@0.1.13) + '@pinecone-database/pinecone': 2.2.2 + chromadb: 1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)) + fast-xml-parser: 4.5.3 +@@ -15545,7 +15099,6 @@ snapshots: + dependencies: + lodash.clonedeep: 4.5.0 + lru-cache: 6.0.0 +- optional: true + + lsofi@1.0.0: + dependencies: +@@ -15564,7 +15117,7 @@ snapshots: + + make-dir@4.0.0: + dependencies: +- semver: 7.7.3 ++ semver: 7.7.2 + + make-error@1.3.6: {} + +@@ -15602,7 +15155,7 @@ snapshots: + + marked-terminal@7.3.0(marked@13.0.3): + dependencies: +- ansi-escapes: 7.1.1 
++ ansi-escapes: 7.2.0 + ansi-regex: 6.2.2 + chalk: 5.6.2 + cli-highlight: 2.1.11 +@@ -15816,24 +15369,24 @@ snapshots: + + netmask@2.0.2: {} + +- next@15.4.10(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): ++ next@15.5.9(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: +- '@next/env': 15.4.10 ++ '@next/env': 15.5.9 + '@swc/helpers': 0.5.15 +- caniuse-lite: 1.0.30001760 ++ caniuse-lite: 1.0.30001667 + postcss: 8.4.31 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + styled-jsx: 5.1.6(@babel/core@7.25.7)(react@18.3.1) + optionalDependencies: +- '@next/swc-darwin-arm64': 15.4.8 +- '@next/swc-darwin-x64': 15.4.8 +- '@next/swc-linux-arm64-gnu': 15.4.8 +- '@next/swc-linux-arm64-musl': 15.4.8 +- '@next/swc-linux-x64-gnu': 15.4.8 +- '@next/swc-linux-x64-musl': 15.4.8 +- '@next/swc-win32-arm64-msvc': 15.4.8 +- '@next/swc-win32-x64-msvc': 15.4.8 ++ '@next/swc-darwin-arm64': 15.5.7 ++ '@next/swc-darwin-x64': 15.5.7 ++ '@next/swc-linux-arm64-gnu': 15.5.7 ++ '@next/swc-linux-arm64-musl': 15.5.7 ++ '@next/swc-linux-x64-gnu': 15.5.7 ++ '@next/swc-linux-x64-musl': 15.5.7 ++ '@next/swc-win32-arm64-msvc': 15.5.7 ++ '@next/swc-win32-x64-msvc': 15.5.7 + '@opentelemetry/api': 1.9.0 + sharp: 0.34.5 + transitivePeerDependencies: +@@ -15865,9 +15418,6 @@ snapshots: + + node-forge@1.3.1: {} + +- node-forge@1.3.3: +- optional: true +- + node-gyp@11.5.0: + dependencies: + env-paths: 2.2.1 +@@ -15876,7 +15426,7 @@ snapshots: + make-fetch-happen: 14.0.3 + nopt: 8.1.0 + proc-log: 5.0.0 +- semver: 7.7.3 ++ semver: 7.7.2 + tar: 7.5.2 + tinyglobby: 0.2.14 + which: 5.0.0 +@@ -15990,21 +15540,6 @@ snapshots: + transitivePeerDependencies: + - encoding + +- openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76): +- dependencies: +- '@types/node': 18.19.112 +- '@types/node-fetch': 2.6.11 +- abort-controller: 3.0.0 +- agentkeepalive: 4.5.0 +- form-data-encoder: 1.7.2 +- formdata-node: 4.4.1 +- 
node-fetch: 2.7.0(encoding@0.1.13) +- optionalDependencies: +- ws: 8.18.3 +- zod: 3.25.76 +- transitivePeerDependencies: +- - encoding +- + openapi-types@12.1.3: {} + + openapi3-ts@3.2.0: +@@ -16047,7 +15582,7 @@ snapshots: + dependencies: + p-limit: 2.3.0 + +- p-map@7.0.3: ++ p-map@7.0.4: + optional: true + + p-queue@6.6.2: +@@ -16265,14 +15800,6 @@ snapshots: + tsx: 4.20.3 + yaml: 2.8.0 + +- postcss-load-config@6.0.1(postcss@8.4.47)(tsx@4.20.3)(yaml@2.8.2): +- dependencies: +- lilconfig: 3.1.2 +- optionalDependencies: +- postcss: 8.4.47 +- tsx: 4.20.3 +- yaml: 2.8.2 +- + postcss@8.4.31: + dependencies: + nanoid: 3.3.11 +@@ -16334,7 +15861,7 @@ snapshots: + + proto3-json-serializer@3.0.4: + dependencies: +- protobufjs: 7.5.4 ++ protobufjs: 7.5.3 + + protobuf.js@1.1.2: + dependencies: +@@ -16355,7 +15882,7 @@ snapshots: + '@types/node': 20.19.1 + long: 5.2.3 + +- protobufjs@7.5.4: ++ protobufjs@7.5.3: + dependencies: + '@protobufjs/aspromise': 1.1.2 + '@protobufjs/base64': 1.1.2 +@@ -16367,7 +15894,7 @@ snapshots: + '@protobufjs/path': 1.1.2 + '@protobufjs/pool': 1.1.0 + '@protobufjs/utf8': 1.1.0 +- '@types/node': 20.19.25 ++ '@types/node': 20.19.1 + long: 5.3.2 + + proxy-addr@2.0.7: +@@ -16439,13 +15966,6 @@ snapshots: + iconv-lite: 0.4.24 + unpipe: 1.0.0 + +- raw-body@2.5.3: +- dependencies: +- bytes: 3.1.2 +- http-errors: 2.0.1 +- iconv-lite: 0.4.24 +- unpipe: 1.0.0 +- + raw-body@3.0.0: + dependencies: + bytes: 3.1.2 +@@ -16460,7 +15980,7 @@ snapshots: + minimist: 1.2.8 + strip-json-comments: 2.0.1 + +- re2@1.22.1: ++ re2@1.22.3: + dependencies: + install-artifact-from-github: 1.4.0 + nan: 2.24.0 +@@ -16702,9 +16222,6 @@ snapshots: + + safe-stable-stringify@2.4.3: {} + +- safe-stable-stringify@2.5.0: +- optional: true +- + safer-buffer@2.1.2: {} + + scheduler@0.23.2: +@@ -16725,7 +16242,8 @@ snapshots: + + semver@7.7.2: {} + +- semver@7.7.3: {} ++ semver@7.7.3: ++ optional: true + + send@0.19.0: + dependencies: +@@ -16745,24 +16263,6 @@ snapshots: + 
transitivePeerDependencies: + - supports-color + +- send@0.19.1: +- dependencies: +- debug: 2.6.9 +- depd: 2.0.0 +- destroy: 1.2.0 +- encodeurl: 2.0.0 +- escape-html: 1.0.3 +- etag: 1.8.1 +- fresh: 0.5.2 +- http-errors: 2.0.0 +- mime: 1.6.0 +- ms: 2.1.3 +- on-finished: 2.4.1 +- range-parser: 1.2.1 +- statuses: 2.0.1 +- transitivePeerDependencies: +- - supports-color +- + send@1.2.0: + dependencies: + debug: 4.4.1 +@@ -16934,7 +16434,7 @@ snapshots: + + socks@2.8.7: + dependencies: +- ip-address: 10.0.1 ++ ip-address: 10.1.0 + smart-buffer: 4.2.0 + + sort-any@2.0.0: +@@ -16972,7 +16472,7 @@ snapshots: + + sprintf-js@1.0.3: {} + +- sql-formatter@15.6.10: ++ sql-formatter@15.6.12: + dependencies: + argparse: 2.0.1 + nearley: 2.20.1 +@@ -16992,8 +16492,6 @@ snapshots: + + statuses@2.0.1: {} + +- statuses@2.0.2: {} +- + stop-iteration-iterator@1.1.0: + dependencies: + es-errors: 1.3.0 +@@ -17145,7 +16643,7 @@ snapshots: + router: 2.2.0 + update-notifier-cjs: 5.1.7(encoding@0.1.13) + optionalDependencies: +- re2: 1.22.1 ++ re2: 1.22.3 + transitivePeerDependencies: + - encoding + - supports-color +@@ -17348,7 +16846,7 @@ snapshots: + ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5): + dependencies: + '@cspotcode/source-map-support': 0.8.1 +- '@tsconfig/node10': 1.0.12 ++ '@tsconfig/node10': 1.0.11 + '@tsconfig/node12': 1.0.11 + '@tsconfig/node14': 1.0.3 + '@tsconfig/node16': 1.0.4 +@@ -17367,7 +16865,7 @@ snapshots: + ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3): + dependencies: + '@cspotcode/source-map-support': 0.8.1 +- '@tsconfig/node10': 1.0.12 ++ '@tsconfig/node10': 1.0.11 + '@tsconfig/node12': 1.0.11 + '@tsconfig/node14': 1.0.3 + '@tsconfig/node16': 1.0.4 +@@ -17419,35 +16917,7 @@ snapshots: + - tsx + - yaml + +- tsup@8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2): +- dependencies: +- bundle-require: 5.1.0(esbuild@0.25.5) +- cac: 6.7.14 +- chokidar: 4.0.3 +- consola: 3.4.2 +- debug: 4.4.1 +- esbuild: 0.25.5 +- 
fix-dts-default-cjs-exports: 1.0.1 +- joycon: 3.1.1 +- picocolors: 1.1.1 +- postcss-load-config: 6.0.1(postcss@8.4.47)(tsx@4.20.3)(yaml@2.8.2) +- resolve-from: 5.0.0 +- rollup: 4.43.0 +- source-map: 0.8.0-beta.0 +- sucrase: 3.35.0 +- tinyexec: 0.3.2 +- tinyglobby: 0.2.14 +- tree-kill: 1.2.2 +- optionalDependencies: +- postcss: 8.4.47 +- typescript: 4.9.5 +- transitivePeerDependencies: +- - jiti +- - supports-color +- - tsx +- - yaml +- +- tsup@8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.2): ++ tsup@8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.0): + dependencies: + bundle-require: 5.1.0(esbuild@0.25.5) + cac: 6.7.14 +@@ -17458,7 +16928,7 @@ snapshots: + fix-dts-default-cjs-exports: 1.0.1 + joycon: 3.1.1 + picocolors: 1.1.1 +- postcss-load-config: 6.0.1(postcss@8.4.47)(tsx@4.20.3)(yaml@2.8.2) ++ postcss-load-config: 6.0.1(postcss@8.4.47)(tsx@4.20.3)(yaml@2.8.0) + resolve-from: 5.0.0 + rollup: 4.43.0 + source-map: 0.8.0-beta.0 +@@ -17825,21 +17295,6 @@ snapshots: + triple-beam: 1.4.1 + winston-transport: 4.9.0 + +- winston@3.19.0: +- dependencies: +- '@colors/colors': 1.6.0 +- '@dabh/diagnostics': 2.0.8 +- async: 3.2.6 +- is-stream: 2.0.1 +- logform: 2.7.0 +- one-time: 1.0.0 +- readable-stream: 3.6.2 +- safe-stable-stringify: 2.5.0 +- stack-trace: 0.0.10 +- triple-beam: 1.4.1 +- winston-transport: 4.9.0 +- optional: true +- + wordwrap@1.0.0: {} + + wrap-ansi@6.2.0: +@@ -17899,8 +17354,6 @@ snapshots: + + yaml@2.8.0: {} + +- yaml@2.8.2: {} +- + yargs-parser@20.2.9: {} + + yargs-parser@21.1.1: {} +@@ -17940,16 +17393,6 @@ snapshots: + dependencies: + zod: 3.25.67 + +- zod-to-json-schema@3.24.5(zod@3.25.76): +- dependencies: +- zod: 3.25.76 +- +- zod-to-json-schema@3.25.0(zod@3.25.76): +- dependencies: +- zod: 3.25.76 +- + zod@3.22.4: {} + + zod@3.25.67: {} +- +- zod@3.25.76: {} +diff --git a/js/testapps/anthropic/package.json b/js/testapps/anthropic/package.json +index a06532113..0f7ac1500 100644 +--- a/js/testapps/anthropic/package.json 
++++ b/js/testapps/anthropic/package.json +@@ -10,6 +10,10 @@ + "start:beta": "node lib/beta/basic.js", + "dev:stable": "genkit start -- npx tsx --watch src/stable/basic.ts", + "dev:beta": "genkit start -- npx tsx --watch src/beta/basic.ts", ++ "dev:beta:structured-output": "genkit start -- npx tsx --watch src/beta/structured_output.ts", ++ "dev:beta:files-api": "genkit start -- npx tsx --watch src/beta/files_api.ts", ++ "dev:beta:effort": "genkit start -- npx tsx --watch src/beta/effort.ts", ++ "dev:beta:additional-params": "genkit start -- npx tsx --watch src/beta/additional_params.ts", + "dev:stable:text-plain": "genkit start -- npx tsx --watch src/stable/text-plain.ts", + "dev:stable:webp": "genkit start -- npx tsx --watch src/stable/webp.ts", + "dev:stable:pdf": "genkit start -- npx tsx --watch src/stable/pdf.ts", +@@ -27,8 +31,8 @@ + "author": "", + "license": "Apache-2.0", + "dependencies": { +- "genkit": "workspace:*", +- "@genkit-ai/anthropic": "workspace:*" ++ "@genkit-ai/anthropic": "workspace:*", ++ "genkit": "workspace:*" + }, + "devDependencies": { + "cross-env": "^10.1.0", +diff --git a/js/testapps/anthropic/src/stable/attention-first-page.pdf b/js/testapps/anthropic/src/attention-first-page.pdf +similarity index 100% +rename from js/testapps/anthropic/src/stable/attention-first-page.pdf +rename to js/testapps/anthropic/src/attention-first-page.pdf +diff --git a/js/testapps/anthropic/src/beta/additional_params.ts b/js/testapps/anthropic/src/beta/additional_params.ts +new file mode 100644 +index 000000000..2e443fc01 +--- /dev/null ++++ b/js/testapps/anthropic/src/beta/additional_params.ts +@@ -0,0 +1,83 @@ ++/** ++ * Copyright 2025 Google LLC ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. 
++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++import { anthropic } from '@genkit-ai/anthropic'; ++import { genkit } from 'genkit'; ++ ++const ai = genkit({ ++ plugins: [ ++ // Default all flows in this sample to the beta surface ++ anthropic({ ++ apiVersion: 'beta', ++ cacheSystemPrompt: true, ++ apiKey: process.env.ANTHROPIC_API_KEY, ++ }), ++ ], ++}); ++ ++const betaOpus45 = anthropic.model('claude-opus-4-5', { apiVersion: 'beta' }); ++ ++ai.defineFlow('anthropic-beta-additional-params', async () => { ++ const { text } = await ai.generate({ ++ model: betaOpus45, ++ prompt: ++ 'You are Claude on the beta API. 
Provide a concise greeting that mentions that you are using the beta API.', ++ config: { ++ temperature: 0.6, ++ // Additional param (not directly supported by the plugin, but can be passed through to the API) ++ betas: ['effort-2025-11-24'], ++ // Additional param (not directly supported by the plugin, but can be passed through to the API) ++ output_config: { ++ effort: 'medium', ++ }, ++ }, ++ }); ++ ++ return text; ++}); ++ ++ai.defineFlow( ++ 'anthropic-beta-additional-params-stream', ++ async (_, { sendChunk }) => { ++ const { stream } = ai.generateStream({ ++ model: betaOpus45, ++ prompt: [ ++ { ++ text: 'Outline two experimental capabilities unlocked by the Anthropic beta API.', ++ }, ++ ], ++ config: { ++ temperature: 0.4, ++ // Additional param (not directly supported by the plugin, but can be passed through to the API) ++ betas: ['effort-2025-11-24'], ++ // Additional param (not directly supported by the plugin, but can be passed through to the API) ++ output_config: { ++ effort: 'medium', ++ }, ++ }, ++ }); ++ ++ const collected: string[] = []; ++ for await (const chunk of stream) { ++ if (chunk.text) { ++ collected.push(chunk.text); ++ sendChunk(chunk.text); ++ } ++ } ++ ++ return collected.join(''); ++ } ++); +diff --git a/js/testapps/anthropic/src/beta/basic.ts b/js/testapps/anthropic/src/beta/basic.ts +index d1309b340..f9841f4c6 100644 +--- a/js/testapps/anthropic/src/beta/basic.ts ++++ b/js/testapps/anthropic/src/beta/basic.ts +@@ -15,12 +15,16 @@ + */ + + import { anthropic } from '@genkit-ai/anthropic'; +-import { genkit } from 'genkit'; ++import { genkit, z } from 'genkit'; + + const ai = genkit({ + plugins: [ + // Default all flows in this sample to the beta surface +- anthropic({ apiVersion: 'beta', cacheSystemPrompt: true }), ++ anthropic({ ++ apiVersion: 'beta', ++ cacheSystemPrompt: true, ++ apiKey: process.env.ANTHROPIC_API_KEY, ++ }), + ], + }); + +@@ -28,15 +32,21 @@ const betaHaiku = anthropic.model('claude-3-5-haiku', { apiVersion: 
'beta' }); + const betaSonnet = anthropic.model('claude-sonnet-4-5', { apiVersion: 'beta' }); + const betaOpus41 = anthropic.model('claude-opus-4-1', { apiVersion: 'beta' }); + ++const GreetingSchema = z.object({ ++ greeting: z.string(), ++ apiVersion: z.string(), ++}); ++ + ai.defineFlow('anthropic-beta-hello', async () => { +- const { text } = await ai.generate({ ++ const { output } = await ai.generate({ + model: betaHaiku, + prompt: + 'You are Claude on the beta API. Provide a concise greeting that mentions that you are using the beta API.', + config: { temperature: 0.6 }, ++ output: { schema: GreetingSchema, format: 'json', constrained: true }, + }); + +- return text; ++ return output; + }); + + ai.defineFlow('anthropic-beta-stream', async (_, { sendChunk }) => { +diff --git a/js/testapps/anthropic/src/beta/effort.ts b/js/testapps/anthropic/src/beta/effort.ts +new file mode 100644 +index 000000000..03e83a60a +--- /dev/null ++++ b/js/testapps/anthropic/src/beta/effort.ts +@@ -0,0 +1,79 @@ ++/** ++ * Copyright 2025 Google LLC ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. 
++ */ ++ ++import { anthropic } from '@genkit-ai/anthropic'; ++import { genkit } from 'genkit'; ++ ++const ai = genkit({ ++ plugins: [ ++ // Default all flows in this sample to the beta surface ++ anthropic({ ++ apiVersion: 'beta', ++ cacheSystemPrompt: true, ++ apiKey: process.env.ANTHROPIC_API_KEY, ++ }), ++ ], ++}); ++ ++const betaOpus45 = anthropic.model('claude-opus-4-5', { apiVersion: 'beta' }); ++ ++ai.defineFlow('anthropic-beta-low-effort', async () => { ++ const { text } = await ai.generate({ ++ model: betaOpus45, ++ prompt: `Create me a Mathematics class using the programming language Python.`, ++ config: { ++ maxOutputTokens: 4096, ++ temperature: 0.6, ++ output_config: { ++ effort: 'low', ++ }, ++ }, ++ }); ++ ++ return text; ++}); ++ ++ai.defineFlow('anthropic-beta-medium-effort', async () => { ++ const { text } = await ai.generate({ ++ model: betaOpus45, ++ prompt: `Create me a Mathematics class using the programming language Python.`, ++ config: { ++ maxOutputTokens: 4096, ++ temperature: 0.6, ++ output_config: { ++ effort: 'medium', ++ }, ++ }, ++ }); ++ ++ return text; ++}); ++ ++ai.defineFlow('anthropic-beta-high-effort', async () => { ++ const { text } = await ai.generate({ ++ model: betaOpus45, ++ prompt: `Create me a Mathematics class using the programming language Python.`, ++ config: { ++ maxOutputTokens: 4096, ++ temperature: 0.6, ++ output_config: { ++ effort: 'high', ++ }, ++ }, ++ }); ++ ++ return text; ++}); +diff --git a/js/testapps/anthropic/src/beta/files_api.ts b/js/testapps/anthropic/src/beta/files_api.ts +new file mode 100644 +index 000000000..3dd6b08d0 +--- /dev/null ++++ b/js/testapps/anthropic/src/beta/files_api.ts +@@ -0,0 +1,114 @@ ++/** ++ * Copyright 2025 Google LLC ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. 
++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++import { anthropic } from '@genkit-ai/anthropic'; ++import * as fs from 'fs'; ++import { genkit } from 'genkit'; ++import * as path from 'path'; ++ ++// Ensure the API key is set. ++const API_KEY = process.env.ANTHROPIC_API_KEY; ++// If you have a file ID, you can set it here. Otherwise, the flow will upload a new PDF to Anthropic. ++const FILE_ID = process.env.ANTHROPIC_FILE_ID; ++ ++export async function uploadPdfToAnthropic() { ++ if (!API_KEY) throw new Error('Missing ANTHROPIC_API_KEY env variable'); ++ ++ // Path to the PDF file to upload ++ const pdfPath = path.join(__dirname, '../attention-first-page.pdf'); ++ const fileBuffer = fs.readFileSync(pdfPath); ++ ++ const form = new FormData(); ++ form.append( ++ 'file', ++ new Blob([fileBuffer], { type: 'application/pdf' }), ++ 'attention-first-page.pdf' ++ ); ++ ++ const response = await fetch('https://api.anthropic.com/v1/files', { ++ method: 'POST', ++ headers: { ++ 'x-api-key': API_KEY, ++ 'anthropic-version': '2023-06-01', ++ 'anthropic-beta': 'files-api-2025-04-14', ++ }, ++ body: form, ++ }); ++ ++ if (!response.ok) { ++ const text = await response.text(); ++ throw new Error(`Anthropic file upload failed: ${response.status} ${text}`); ++ } ++ const result = await response.json(); ++ return result as { id: string }; // Contains 'file_id', etc. 
++} ++ ++async function main() { ++ const ai = genkit({ ++ plugins: [ ++ // Default all flows in this sample to the beta surface ++ anthropic({ ++ apiVersion: 'beta', ++ apiKey: API_KEY, ++ }), ++ ], ++ }); ++ ++ /** ++ * This flow demonstrates PDF document processing via a public data URL along with a user prompt. ++ * The PDF is sent as a media part with the correct contentType and a URL, not base64. ++ */ ++ ai.defineFlow('beta-pdf-url', async () => { ++ let fileId = FILE_ID; ++ ++ if (!fileId) { ++ const fileResult = await uploadPdfToAnthropic(); ++ if (!fileResult || !fileResult.id) { ++ throw new Error('File ID not found'); ++ } ++ fileId = fileResult.id; ++ } ++ ++ // Example: Use a (demo/test) PDF file accessible via public URL. ++ // Replace this with your actual PDF if needed. ++ const { text } = await ai.generate({ ++ model: anthropic.model('claude-sonnet-4-5'), ++ messages: [ ++ { ++ role: 'user', ++ content: [ ++ { ++ text: 'What are the key findings or main points in this document?', ++ }, ++ { ++ media: { ++ url: fileId, ++ contentType: 'anthropic/file', ++ }, ++ }, ++ ], ++ }, ++ ], ++ }); ++ ++ return text; ++ }); ++} ++ ++main().catch((error) => { ++ console.error('Error:', error); ++ process.exit(1); ++}); +diff --git a/js/testapps/anthropic/src/beta/structured_output.ts b/js/testapps/anthropic/src/beta/structured_output.ts +new file mode 100644 +index 000000000..dbd3f5ca1 +--- /dev/null ++++ b/js/testapps/anthropic/src/beta/structured_output.ts +@@ -0,0 +1,84 @@ ++/** ++ * Copyright 2025 Google LLC ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. 
++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++import { anthropic } from '@genkit-ai/anthropic'; ++import { genkit, z } from 'genkit'; ++ ++const ai = genkit({ ++ plugins: [ ++ // Default all flows in this sample to the beta surface ++ anthropic({ ++ apiVersion: 'beta', ++ cacheSystemPrompt: true, ++ apiKey: process.env.ANTHROPIC_API_KEY, ++ }), ++ ], ++}); ++ ++const betaSonnet = anthropic.model('claude-sonnet-4-5', { apiVersion: 'beta' }); ++ ++ai.defineFlow('anthropic-beta-generate-person-json', async () => { ++ const { text } = await ai.generate({ ++ model: betaSonnet, ++ prompt: ++ 'Generate a fictional person with a random name, random age, and random city.', ++ config: { temperature: 0.6 }, ++ output: { ++ schema: z.object({ ++ name: z.string(), ++ age: z.number(), ++ city: z.string(), ++ }), ++ format: 'json', ++ constrained: true, ++ }, ++ }); ++ ++ return text; ++}); ++ ++ai.defineFlow( ++ 'anthropic-beta-generate-person-json-stream', ++ async (_, { sendChunk }) => { ++ const { stream } = ai.generateStream({ ++ model: betaSonnet, ++ prompt: [ ++ { ++ text: 'Generate a fictional person with a random name, random age, and random city.', ++ }, ++ ], ++ config: { temperature: 0.6 }, ++ output: { ++ schema: z.object({ ++ name: z.string(), ++ age: z.number(), ++ city: z.string(), ++ }), ++ format: 'json', ++ }, ++ }); ++ ++ const collected: any[] = []; ++ for await (const chunk of stream) { ++ if (chunk.text) { ++ collected.push(chunk.output); ++ sendChunk(chunk.output); ++ } ++ } ++ ++ return collected.join(''); ++ } ++); +diff --git 
a/js/testapps/anthropic/src/stable/pdf.ts b/js/testapps/anthropic/src/stable/pdf.ts +index 8953dff69..07be5f665 100644 +--- a/js/testapps/anthropic/src/stable/pdf.ts ++++ b/js/testapps/anthropic/src/stable/pdf.ts +@@ -29,7 +29,7 @@ const ai = genkit({ + */ + ai.defineFlow('stable-pdf-base64', async () => { + // Read PDF file from the same directory as this source file +- const pdfPath = path.join(__dirname, 'attention-first-page.pdf'); ++ const pdfPath = path.join(__dirname, '../attention-first-page.pdf'); + const pdfBuffer = fs.readFileSync(pdfPath); + const pdfBase64 = pdfBuffer.toString('base64'); + +diff --git a/js/testapps/basic-gemini/src/index-vertexai.ts b/js/testapps/basic-gemini/src/index-vertexai.ts +index 2e6c59e96..2793e91bc 100644 +--- a/js/testapps/basic-gemini/src/index-vertexai.ts ++++ b/js/testapps/basic-gemini/src/index-vertexai.ts +@@ -40,8 +40,8 @@ ai.defineFlow('basic-hi', async () => { + // Gemini 3.0 thinkingLevel config + ai.defineFlow( + { +- name: 'thinking-level', +- inputSchema: z.enum(['LOW', 'MEDIUM', 'HIGH']), ++ name: 'thinking-level-pro', ++ inputSchema: z.enum(['LOW', 'HIGH']), + outputSchema: z.any(), + }, + async (level) => { +@@ -66,6 +66,34 @@ ai.defineFlow( + } + ); + ++ai.defineFlow( ++ { ++ name: 'thinking-level-flash', ++ inputSchema: z.enum(['MINIMAL', 'LOW', 'MEDIUM', 'HIGH']), ++ outputSchema: z.any(), ++ }, ++ async (level) => { ++ const { text } = await ai.generate({ ++ model: vertexAI.model('gemini-3-flash-preview'), ++ prompt: ++ 'Alice, Bob, and Carol each live in a different house on the ' + ++ 'same street: red, green, and blue. The person who lives in the red house ' + ++ 'owns a cat. Bob does not live in the green house. Carol owns a dog. The ' + ++ 'green house is to the left of the red house. Alice does not own a cat. ' + ++ 'The person in the blue house owns a fish. ' + ++ 'Who lives in each house, and what pet do they own? 
Provide your ' + ++ 'step-by-step reasoning.', ++ config: { ++ location: 'global', ++ thinkingConfig: { ++ thinkingLevel: level, ++ }, ++ }, ++ }); ++ return text; ++ } ++); ++ + // Multimodal input + ai.defineFlow('multimodal-input', async () => { + const photoBase64 = fs.readFileSync('photo.jpg', { encoding: 'base64' }); +diff --git a/js/testapps/basic-gemini/src/index.ts b/js/testapps/basic-gemini/src/index.ts +index d9bfc9848..40246ef2b 100644 +--- a/js/testapps/basic-gemini/src/index.ts ++++ b/js/testapps/basic-gemini/src/index.ts +@@ -80,11 +80,11 @@ ai.defineFlow('basic-hi-with-fallback', async () => { + return text; + }); + +-// Gemini 3.0 thinkingLevel config ++// Gemini 3.0 thinkingLevel config. Pro can have Low or High + ai.defineFlow( + { +- name: 'thinking-level', +- inputSchema: z.enum(['LOW', 'MEDIUM', 'HIGH']), ++ name: 'thinking-level-pro', ++ inputSchema: z.enum(['LOW', 'HIGH']), + }, + async (level) => { + const { text } = await ai.generate({ +@@ -107,6 +107,33 @@ ai.defineFlow( + } + ); + ++// Gemini 3 Flash can have minimal and medium thinking levels too. ++ai.defineFlow( ++ { ++ name: 'thinking-level-flash', ++ inputSchema: z.enum(['MINIMAL', 'LOW', 'MEDIUM', 'HIGH']), ++ }, ++ async (level) => { ++ const { text } = await ai.generate({ ++ model: googleAI.model('gemini-3-flash-preview'), ++ prompt: ++ 'Alice, Bob, and Carol each live in a different house on the ' + ++ 'same street: red, green, and blue. The person who lives in the red house ' + ++ 'owns a cat. Bob does not live in the green house. Carol owns a dog. The ' + ++ 'green house is to the left of the red house. Alice does not own a cat. ' + ++ 'The person in the blue house owns a fish. ' + ++ 'Who lives in each house, and what pet do they own? 
Provide your ' + ++ 'step-by-step reasoning.', ++ config: { ++ thinkingConfig: { ++ thinkingLevel: level, ++ }, ++ }, ++ }); ++ return text; ++ } ++); ++ + // Multimodal input + ai.defineFlow('multimodal-input', async () => { + const photoBase64 = fs.readFileSync('photo.jpg', { encoding: 'base64' }); +diff --git a/py/bin/sanitize_schema_typing.py b/py/bin/sanitize_schema_typing.py +index 6138127e9..fa74fe5e7 100644 +--- a/py/bin/sanitize_schema_typing.py ++++ b/py/bin/sanitize_schema_typing.py +@@ -42,10 +42,9 @@ Transformations applied: + + import ast + import sys +-from _ast import AST + from datetime import datetime + from pathlib import Path +-from typing import Type, cast ++from typing import Any, Type, cast + + + class ClassTransformer(ast.NodeTransformer): +@@ -118,7 +117,18 @@ class ClassTransformer(ast.NodeTransformer): + return item + return None + +- def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 ++ def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: ++ """Visit and transform annotated assignment.""" ++ if isinstance(node.annotation, ast.Name) and node.annotation.id == 'Role': ++ node.annotation = ast.BinOp( ++ left=ast.Name(id='Role', ctx=ast.Load()), ++ op=ast.BitOr(), ++ right=ast.Name(id='str', ctx=ast.Load()), ++ ) ++ self.modified = True ++ return node ++ ++ def visit_ClassDef(self, node: ast.ClassDef) -> Any: + """Visit and transform a class definition node. + + Args: +@@ -128,11 +138,16 @@ class ClassTransformer(ast.NodeTransformer): + The transformed ClassDef node. 
+ """ + # First apply base class transformations recursively +- node = super().generic_visit(_node) ++ node = cast(ast.ClassDef, super().generic_visit(node)) + new_body: list[ast.stmt | ast.Constant | ast.Assign] = [] + + # Handle Docstrings +- if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant): ++ if ( ++ not node.body ++ or not isinstance(node.body[0], ast.Expr) ++ or not isinstance(node.body[0].value, ast.Constant) ++ or not isinstance(node.body[0].value.value, str) ++ ): + # Generate a more descriptive docstring based on class type + if self.is_rootmodel_class(node): + docstring = f'Root model for {node.name.lower().replace("_", " ")}.' +@@ -151,13 +166,21 @@ class ClassTransformer(ast.NodeTransformer): + + # Handle model_config for BaseModel and RootModel + existing_model_config_assign = self.has_model_config(node) ++ + existing_model_config_call = None + if existing_model_config_assign and isinstance(existing_model_config_assign.value, ast.Call): + existing_model_config_call = existing_model_config_assign.value + + # Determine start index for iterating original body (skip docstring) + body_start_index = ( +- 1 if (node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Str)) else 0 ++ 1 ++ if ( ++ node.body ++ and isinstance(node.body[0], ast.Expr) ++ and isinstance(node.body[0].value, ast.Constant) ++ and isinstance(node.body[0].value.value, str) ++ ) ++ else 0 + ) + + if self.is_rootmodel_class(node): +diff --git a/py/noxfile.py b/py/noxfile.py +index 0e4a68ad9..7f715a304 100644 +--- a/py/noxfile.py ++++ b/py/noxfile.py +@@ -74,5 +74,5 @@ def lint(session: nox.Session) -> None: + session.run('uv', 'run', 'ruff', 'format', '--check', '.', external=True) + session.log('Running ruff checks') + session.run('uv', 'run', 'ruff', 'check', '--preview', '--unsafe-fixes', '--fix', '.', external=True) +- # session.log("Running mypy checks") # mypy has many errors currently +- # 
session.run("mypy", external=True) ++ session.log('Running Ty checks') ++ session.run('uv', 'run', 'ty', 'check', '.', external=True) +diff --git a/py/packages/genkit/pyproject.toml b/py/packages/genkit/pyproject.toml +index ab4f17a28..b7921df27 100644 +--- a/py/packages/genkit/pyproject.toml ++++ b/py/packages/genkit/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -52,7 +51,7 @@ dependencies = [ + "anyio>=4.9.0", + ] + description = "Genkit AI Framework" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/packages/genkit/src/genkit/ai/__init__.py b/py/packages/genkit/src/genkit/ai/__init__.py +index aeb0268c1..c120f02c4 100644 +--- a/py/packages/genkit/src/genkit/ai/__init__.py ++++ b/py/packages/genkit/src/genkit/ai/__init__.py +@@ -37,7 +37,7 @@ from genkit.core.action import ActionRunContext + from genkit.core.action.types import ActionKind + + from ._aio import Genkit +-from ._plugin import Plugin, PluginV2 ++from ._plugin import Plugin + from ._registry import FlowWrapper, GenkitRegistry + + __all__ = [ +@@ -47,7 +47,7 @@ __all__ = [ + GenkitRegistry.__name__, + Genkit.__name__, + Plugin.__name__, +- PluginV2.__name__, ++ Plugin.__name__, + ToolRunContext.__name__, + tool_response.__name__, + FlowWrapper.__name__, +diff --git a/py/packages/genkit/src/genkit/ai/_aio.py b/py/packages/genkit/src/genkit/ai/_aio.py +index 25e5aeca1..19002e8c2 100644 +--- a/py/packages/genkit/src/genkit/ai/_aio.py ++++ b/py/packages/genkit/src/genkit/ai/_aio.py +@@ -45,7 +45,6 @@ from genkit.core.action import ActionRunContext + from genkit.core.action.types import ActionKind + from 
genkit.core.typing import ( + BaseDataPoint, +- BaseEvalDataPoint, + EmbedRequest, + EmbedResponse, + EvalRequest, +@@ -351,9 +350,14 @@ class Genkit(GenkitBase): + + final_options = {**(retriever_config or {}), **(options or {})} + +- retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever_name) ++ retrieve_action = await self.registry.resolve_action(ActionKind.RETRIEVER, retriever_name) + +- return (await retrieve_action.arun(RetrieverRequest(query=query, options=final_options))).response ++ parent_ctx = ActionRunContext._current_context() or {} ++ action_ctx = dict(parent_ctx) ++ action_ctx.setdefault('__genkit_ai__', self) ++ return ( ++ await retrieve_action.arun(RetrieverRequest(query=query, options=final_options), context=action_ctx) ++ ).response + + async def index( + self, +@@ -383,9 +387,12 @@ class Genkit(GenkitBase): + + final_options = {**(indexer_config or {}), **(options or {})} + +- index_action = self.registry.lookup_action(ActionKind.INDEXER, indexer_name) ++ index_action = await self.registry.resolve_action(ActionKind.INDEXER, indexer_name) + +- await index_action.arun(IndexerRequest(documents=documents, options=final_options)) ++ parent_ctx = ActionRunContext._current_context() or {} ++ action_ctx = dict(parent_ctx) ++ action_ctx.setdefault('__genkit_ai__', self) ++ await index_action.arun(IndexerRequest(documents=documents, options=final_options), context=action_ctx) + + async def embed( + self, +@@ -410,9 +417,14 @@ class Genkit(GenkitBase): + # Merge options passed to embed() with config from EmbedderRef + final_options = {**(embedder_config or {}), **(options or {})} + +- embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name) ++ embed_action = await self.registry.resolve_action(ActionKind.EMBEDDER, embedder_name) + +- return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response ++ parent_ctx = ActionRunContext._current_context() or {} ++ action_ctx = 
dict(parent_ctx) ++ action_ctx.setdefault('__genkit_ai__', self) ++ return ( ++ await embed_action.arun(EmbedRequest(input=documents, options=final_options), context=action_ctx) ++ ).response + + async def evaluate( + self, +@@ -445,17 +457,21 @@ class Genkit(GenkitBase): + + final_options = {**(evaluator_config or {}), **(options or {})} + +- eval_action = self.registry.lookup_action(ActionKind.EVALUATOR, evaluator_name) ++ eval_action = await self.registry.resolve_action(ActionKind.EVALUATOR, evaluator_name) + + if not eval_run_id: + eval_run_id = str(uuid.uuid4()) + ++ parent_ctx = ActionRunContext._current_context() or {} ++ action_ctx = dict(parent_ctx) ++ action_ctx.setdefault('__genkit_ai__', self) + return ( + await eval_action.arun( + EvalRequest( + dataset=dataset, + options=final_options, + eval_run_id=eval_run_id, +- ) ++ ), ++ context=action_ctx, + ) + ).response +diff --git a/py/packages/genkit/src/genkit/ai/_base.py b/py/packages/genkit/src/genkit/ai/_base.py +index e2a2dcb1d..7893af616 100644 +--- a/py/packages/genkit/src/genkit/ai/_base.py ++++ b/py/packages/genkit/src/genkit/ai/_base.py +@@ -17,7 +17,6 @@ + """Base/shared implementation for Genkit user-facing API.""" + + import asyncio +-import inspect + import os + import threading + from collections.abc import Coroutine +@@ -29,13 +28,13 @@ import structlog + from genkit.aio.loop import create_loop, run_async + from genkit.blocks.formats import built_in_formats + from genkit.blocks.generate import define_generate_action +-from genkit.core.action import Action ++from genkit.core.action import ActionMetadata + from genkit.core.environment import is_dev_environment + from genkit.core.reflection import make_reflection_server + from genkit.core.registry import ActionKind + from genkit.web.manager import find_free_port_sync + +-from ._plugin import Plugin, PluginV2, is_plugin_v2 ++from ._plugin import Plugin + from ._registry import GenkitRegistry + from ._server import ServerSpec, 
init_default_runtime + +@@ -123,61 +122,95 @@ class GenkitBase(GenkitRegistry): + logger.warning('No plugins provided to Genkit') + else: + for plugin in plugins: +- if is_plugin_v2(plugin): +- self._initialize_v2_plugin(plugin) +- elif isinstance(plugin, Plugin): +- plugin.initialize(ai=self) +- +- def resolver(kind, name, plugin=plugin): +- return plugin.resolve_action(self, kind, name) +- +- def action_resolver(plugin=plugin): +- if isinstance(plugin.list_actions, list): +- return plugin.list_actions +- else: +- return plugin.list_actions() +- +- self.registry.register_action_resolver(plugin.plugin_name(), resolver) +- self.registry.register_list_actions_resolver(plugin.plugin_name(), action_resolver) +- else: +- raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin` or `genkit.ai.PluginV2`') ++ if not isinstance(plugin, Plugin): ++ raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') ++ self._initialize_plugin(plugin) + +- def _initialize_v2_plugin(self, plugin: PluginV2) -> None: +- """Register a v2 plugin by calling its methods and registering returned actions. ++ def _initialize_plugin(self, plugin: Plugin) -> None: ++ """Register a plugin without eagerly initializing it. + +- Steps: +- 1. Call plugin.init() to get resolved actions +- 2. Register each action with automatic namespacing +- 3. Set up lazy resolver for on-demand actions ++ Plugins are registered during Genkit construction, but their ++ `init()` is only invoked when the plugin is *initialized* (e.g. first ++ use via action lookup). + +- Args: +- plugin: V2 plugin instance to register. 
++ This method wires: ++ - a lazy initializer (calls `plugin.init()` once, on-demand) ++ - an action resolver (calls `plugin.resolve()` on cache miss) ++ - a list-actions resolver (calls `plugin.list_actions()` for discovery) + """ +- if inspect.iscoroutinefunction(plugin.init): +- resolved_actions = asyncio.run(plugin.init()) +- else: +- resolved_actions = plugin.init() ++ initialized = False ++ init_lock = threading.Lock() ++ init_task: asyncio.Task[None] | None = None + +- for action in resolved_actions: +- self._register_action(action, plugin) ++ async def ensure_initialized() -> None: ++ """Initialize the plugin exactly once (async-first). + +- def resolver(kind: ActionKind, name: str) -> None: ++ This is the JS-style 'initializer promise' pattern: cache a single ++ task and await it from all concurrent callers. ++ """ ++ nonlocal initialized, init_task ++ if initialized: ++ return ++ ++ with init_lock: ++ if initialized: ++ return ++ if init_task is None: ++ ++ async def do_init(): ++ nonlocal initialized, init_task ++ try: ++ resolved_actions = await plugin.init() ++ for action in resolved_actions: ++ self._register_action(action, plugin) ++ initialized = True ++ finally: ++ # If init failed, allow retry on next access. ++ if not initialized: ++ with init_lock: ++ init_task = None ++ ++ init_task = asyncio.create_task(do_init()) ++ ++ # Safe: init_task is set under lock. ++ await init_task ++ ++ async def resolver(kind: ActionKind, name: str): + """Lazy resolver for v2 plugin. + + Called when framework needs an action not returned from init(). 
+ """ +- if inspect.iscoroutinefunction(plugin.resolve): +- action = asyncio.run(plugin.resolve(kind, name)) +- else: +- action = plugin.resolve(kind, name) ++ await ensure_initialized() ++ clean_name = name.removeprefix(f'{plugin.name}/') if name.startswith(f'{plugin.name}/') else name + ++ action = await plugin.resolve(kind, clean_name) + if action: + self._register_action(action, plugin) + + self.registry.register_action_resolver(plugin.name, resolver) + +- def _register_action(self, action: Any, plugin: PluginV2) -> None: +- """Register a single action from a v2 plugin. ++ async def list_actions_resolver(plugin=plugin): ++ """List available actions for a plugin (for discovery/devtools). ++ ++ Important: This should not force plugin initialization; it should use ++ lightweight `Plugin.list_actions()` metadata instead. ++ """ ++ resolved = await plugin.list_actions() ++ ++ namespaced: list[ActionMetadata] = [] ++ for meta in resolved: ++ if meta.name.startswith(f'{plugin.name}/'): ++ namespaced.append(meta) ++ else: ++ data = meta.model_dump() ++ data['name'] = f'{plugin.name}/{meta.name}' ++ namespaced.append(ActionMetadata(**data)) ++ return namespaced ++ ++ self.registry.register_list_actions_resolver(plugin.name, list_actions_resolver) ++ ++ def _register_action(self, action: Any, plugin: Plugin) -> None: ++ """Register a single action from a plugin. + + Responsibilities: + 1. Add plugin namespace to action name (if not already present) +@@ -185,13 +218,11 @@ class GenkitBase(GenkitRegistry): + + Args: + action: Action instance from the plugin. +- plugin: The v2 plugin that created this action. ++ plugin: The plugin that created this action. + """ +- # Register the pre-constructed action instance and let the registry apply +- # namespacing for v2 plugins. 
++ # Register the pre-constructed action instance and let the registry apply namespacing + self.registry.register_action_instance(action, namespace=plugin.name) + +- + def _initialize_server(self, reflection_server_spec: ServerSpec | None) -> None: + """Initialize the server for the Genkit instance. + +diff --git a/py/packages/genkit/src/genkit/ai/_base_async.py b/py/packages/genkit/src/genkit/ai/_base_async.py +index df803d378..9baa7be26 100644 +--- a/py/packages/genkit/src/genkit/ai/_base_async.py ++++ b/py/packages/genkit/src/genkit/ai/_base_async.py +@@ -17,7 +17,7 @@ + """Asynchronous server gateway interface implementation for Genkit.""" + + import asyncio +-import inspect ++import threading + from collections.abc import Coroutine + from typing import Any, TypeVar + +@@ -27,13 +27,13 @@ import uvicorn + + from genkit.aio.loop import run_loop + from genkit.blocks.formats import built_in_formats +-from genkit.core.action import Action ++from genkit.core.action import Action, ActionMetadata + from genkit.core.environment import is_dev_environment + from genkit.core.reflection import create_reflection_asgi_app + from genkit.core.registry import ActionKind + from genkit.web.manager import find_free_port_sync + +-from ._plugin import Plugin, PluginV2, is_plugin_v1, is_plugin_v2 ++from ._plugin import Plugin + from ._registry import GenkitRegistry + from ._runtime import RuntimeManager + from ._server import ServerSpec +@@ -48,7 +48,7 @@ class GenkitBase(GenkitRegistry): + + def __init__( + self, +- plugins: list[Plugin | PluginV2] | None = None, ++ plugins: list[Plugin] | None = None, + model: str | None = None, + reflection_server_spec: ServerSpec | None = None, + ) -> None: +@@ -64,15 +64,12 @@ class GenkitBase(GenkitRegistry): + self._reflection_server_spec = reflection_server_spec + self._initialize_registry(model, plugins) + +- def _initialize_registry(self, model: str | None, plugins: list[Plugin | PluginV2] | None) -> None: ++ def _initialize_registry(self, 
model: str | None, plugins: list[Plugin] | None = None) -> None: + """Initialize the registry for the Genkit instance. + +- Supports both v1 (Plugin) and v2 (PluginV2) plugins. Detection is done +- at runtime via is_plugin_v2(). +- + Args: + model: Model name to use. +- plugins: List of plugins to initialize (v1 or v2). ++ plugins: List of plugins to initialize. + + Raises: + ValueError: If an invalid plugin is provided. +@@ -88,25 +85,13 @@ class GenkitBase(GenkitRegistry): + logger.warning('No plugins provided to Genkit') + else: + for plugin in plugins: +- if is_plugin_v2(plugin): +- logger.debug(f'Registering v2 plugin: {plugin.name}') +- self._register_v2_plugin(plugin) +- elif is_plugin_v1(plugin): +- logger.debug(f'Registering v1 plugin: {plugin.plugin_name()}') +- plugin.initialize(ai=self) ++ if not isinstance(plugin, Plugin): ++ raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') ++ logger.debug(f'Registering plugin: {plugin.name}') ++ self._register_plugin(plugin) + +- def resolver(kind, name, plugin=plugin): +- return plugin.resolve_action(self, kind, name) +- +- self.registry.register_action_resolver(plugin.plugin_name(), resolver) +- else: +- raise ValueError( +- f'Invalid {plugin=} provided to Genkit: ' +- f'must implement either Plugin or PluginV2 interface' +- ) +- +- def _register_v2_plugin(self, plugin: PluginV2) -> None: +- """Register a v2 plugin by calling its methods and registering returned actions. ++ def _register_plugin(self, plugin: Plugin) -> None: ++ """Register a plugin by calling its methods and registering returned actions. + + Steps: + 1. Call plugin.init() to get resolved actions +@@ -116,32 +101,68 @@ class GenkitBase(GenkitRegistry): + Args: + plugin: V2 plugin instance to register. 
+ """ +- if inspect.iscoroutinefunction(plugin.init): +- resolved_actions = asyncio.run(plugin.init()) +- else: +- resolved_actions = plugin.init() +- +- for action in resolved_actions: +- self._register_action_v2(action, plugin) +- +- def resolver(kind: ActionKind, name: str) -> None: ++ initialized = False ++ init_lock = threading.Lock() ++ init_task: asyncio.Task[None] | None = None ++ ++ async def ensure_initialized() -> None: ++ """Initialize the plugin exactly once (async-first).""" ++ nonlocal initialized, init_task ++ if initialized: ++ return ++ ++ with init_lock: ++ if initialized: ++ return ++ if init_task is None: ++ ++ async def do_init(): ++ nonlocal initialized, init_task ++ try: ++ resolved_actions = await plugin.init() ++ for action in resolved_actions: ++ self._register_action(action, plugin) ++ initialized = True ++ finally: ++ if not initialized: ++ with init_lock: ++ init_task = None ++ ++ init_task = asyncio.create_task(do_init()) ++ ++ await init_task ++ ++ async def resolver(kind: ActionKind, name: str): + """Lazy resolver for v2 plugin. + + Called when framework needs an action not returned from init(). + """ +- # Check if resolve method is async +- if inspect.iscoroutinefunction(plugin.resolve): +- action = asyncio.run(plugin.resolve(kind, name)) +- else: +- action = plugin.resolve(kind, name) +- ++ await ensure_initialized() ++ clean_name = name.removeprefix(f'{plugin.name}/') if name.startswith(f'{plugin.name}/') else name ++ action = await plugin.resolve(kind, clean_name) + if action: + self._register_action_v2(action, plugin) + + self.registry.register_action_resolver(plugin.name, resolver) + +- def _register_action_v2(self, action: Action, plugin: PluginV2) -> None: +- """Register a single action from a v2 plugin. 
++ async def list_actions_resolver(plugin=plugin): ++ """List available actions for a plugin (for discovery/devtools).""" ++ resolved = await plugin.list_actions() ++ ++ namespaced: list[ActionMetadata] = [] ++ for meta in resolved: ++ if meta.name.startswith(f'{plugin.name}/'): ++ namespaced.append(meta) ++ else: ++ data = meta.model_dump() ++ data['name'] = f'{plugin.name}/{meta.name}' ++ namespaced.append(ActionMetadata(**data)) ++ return namespaced ++ ++ self.registry.register_list_actions_resolver(plugin.name, list_actions_resolver) ++ ++ def _register_action(self, action: Action, plugin: Plugin) -> None: ++ """Register a single action from a plugin. + + Responsibilities: + 1. Add plugin namespace to action name (if not already present) +@@ -149,13 +170,12 @@ class GenkitBase(GenkitRegistry): + + Args: + action: Action instance from the plugin. +- plugin: The v2 plugin that created this action. ++ plugin: The plugin that created this action. + """ +- # Register the pre-constructed action instance and let the registry apply +- # namespacing for v2 plugins. ++ # Register the pre-constructed action instance and let the registry apply namespacing + self.registry.register_action_instance(action, namespace=plugin.name) + +- logger.debug(f'Registered v2 action: {action.name}') ++ logger.debug(f'Registered action: {action.name}') + + def run_main(self, coro: Coroutine[Any, Any, T]) -> T: + """Run the user's main coroutine. +diff --git a/py/packages/genkit/src/genkit/ai/_plugin.py b/py/packages/genkit/src/genkit/ai/_plugin.py +index 875fcc89d..6bc5a68e4 100644 +--- a/py/packages/genkit/src/genkit/ai/_plugin.py ++++ b/py/packages/genkit/src/genkit/ai/_plugin.py +@@ -21,140 +21,81 @@ It provides a way to initialize and register plugin functionality. 
+ """ + + import abc +-import inspect +-from collections.abc import Awaitable +-from typing import Any, Literal ++from collections.abc import Awaitable, Callable ++from typing import Any + + from genkit.core.registry import ActionKind + + from ..core.action import Action, ActionMetadata + from ._registry import GenkitRegistry + ++# Type aliases for plugin resolver functions ++ActionResolver = Callable[[ActionKind, str], Awaitable[Action | None]] ++"""Async function that resolves an action by kind and name.""" + +-class Plugin(abc.ABC): +- """Abstract base class for implementing Genkit plugins. +- +- This class defines the interface that all plugins must implement. Plugins +- provide a way to extend functionality by registering new actions, models, or +- other capabilities. +- """ +- +- def plugin_name(self): +- """The name of the plugin. +- +- Returns: +- The name of the plugin. +- """ +- return self.name +- +- # TODO: https://github.com/firebase/genkit/issues/2438 +- # @abc.abstractmethod +- def resolve_action( # noqa: B027 +- self, +- ai: GenkitRegistry, +- kind: ActionKind, +- name: str, +- ) -> None: +- """Resolves an action by adding it to the provided GenkitRegistry. +- +- Args: +- ai: The Genkit registry. +- kind: The kind of action to resolve. +- name: The name of the action to resolve. +- +- Returns: +- None, action resolution is done by side-effect on the registry. +- """ +- pass +- +- @abc.abstractmethod +- def initialize(self, ai: GenkitRegistry) -> None: +- """Initialize the plugin with the given registry. +- +- Args: +- ai: Registry to register plugin functionality. +- +- Returns: +- None, initialization is done by side-effect on the registry. +- """ +- pass +- +- def list_actions(self) -> list[ActionMetadata]: +- """Generate a list of available actions or models. +- +- Returns: +- list[ActionMetadata]: A list of ActionMetadata objects, each with the following attributes: +- - name (str): The name of the action or model. 
+- - kind (ActionKind): The type or category of the action. +- - info (dict): The metadata dictionary describing the model configuration and properties. +- - config_schema (type): The schema class used for validating the model's configuration. +- """ +- return [] ++ListActionsResolver = Callable[[], Awaitable[list[ActionMetadata]]] ++"""Async function that returns a list of action metadata for discovery.""" + + +-class PluginV2(abc.ABC): +- """Base class for v2 plugins that return actions instead of mutating registry. ++class Plugin(abc.ABC): ++ """Base class for Genkit plugins that return actions instead of mutating registry. + +- V2 plugins are decoupled from the registry - they create and return Action ++ Plugins are decoupled from the registry - they create and return Action + objects which the framework then registers. This enables: + - Standalone usage (use plugins without framework) + - Better testability (test plugins in isolation) + + Plugin authors should inherit from this class and implement the required methods. +- The version marker is set automatically. + + Example: +- >>> class MyPlugin(PluginV2): +- ... name = "myplugin" ++ >>> class MyPlugin(Plugin): ++ ... name = 'myplugin' + ... +- ... def init(self): +- ... return [model(name="my-model", fn=self._generate)] ++ ... async def init(self): ++ ... return [model(name='my-model', fn=self._generate)] + ... +- ... def resolve(self, action_type, name): ++ ... async def resolve(self, action_type, name): + ... return model(name=name, fn=self._generate) ++ ... ++ ... async def list_actions(self): ++ ... return [ActionMetadata(name='my-model', kind=ActionKind.MODEL)] + """ + +- version: Literal["v2"] = "v2" +- """Version marker - set automatically by base class.""" +- + name: str + """Plugin name (e.g., 'anthropic', 'openai'). 
Must be set by subclass.""" + + @abc.abstractmethod +- def init(self) -> list[Action] | Awaitable[list[Action]]: ++ async def init(self) -> list[Action]: + """Return eagerly-initialized actions. + + Called once during Genkit initialization. Return actions you want + created immediately (common models, frequently used tools, etc.). + +- Can be sync or async. +- + Returns: + List of Action objects (not yet registered with any registry). + + Example: +- >>> def init(self): ++ >>> async def init(self): + ... from genkit.blocks.model import model ++ ... + ... return [ +- ... model(name="gpt-4", fn=self._generate), +- ... model(name="gpt-4o", fn=self._generate), ++ ... model(name='gpt-4', fn=self._generate), ++ ... model(name='gpt-4o', fn=self._generate), + ... ] + """ + ... + + @abc.abstractmethod +- def resolve( ++ async def resolve( + self, + action_type: ActionKind, + name: str, +- ) -> Action | None | Awaitable[Action | None]: ++ ) -> Action | None: + """Resolve a specific action on-demand (lazy loading). + + Called when the framework needs an action that wasn't returned from init(). + Enables lazy loading of less-common models or actions. + +- Can be sync or async. +- + Args: + action_type: Type of action requested (MODEL, EMBEDDER, TOOL, etc.). + name: Name of the action (WITHOUT plugin prefix - framework strips it). +@@ -163,35 +104,30 @@ class PluginV2(abc.ABC): + Action object if this plugin can provide it, None if it cannot. + + Example: +- >>> def resolve(self, action_type, name): ++ >>> async def resolve(self, action_type, name): + ... if action_type == ActionKind.MODEL: + ... if name in SUPPORTED_MODELS: + ... from genkit.blocks.model import model ++ ... + ... return model(name=name, fn=self._generate) + ... return None + """ + ... + +- def list(self) -> list[ActionMetadata] | Awaitable[list[ActionMetadata]]: ++ async def list_actions(self) -> list[ActionMetadata]: + """List all actions this plugin can provide. 
+ + Used for discovery, developer tools, and documentation. + Should return metadata for ALL actions the plugin supports, + not just those returned from init(). + +- Can be sync or async. +- + Returns: + List of ActionMetadata objects (lightweight descriptions). + + Example: +- >>> def list(self): ++ >>> async def list_actions(self): + ... return [ +- ... ActionMetadata( +- ... name="gpt-4", +- ... kind=ActionKind.MODEL, +- ... info={"supports": {"vision": True}} +- ... ), ++ ... ActionMetadata(name='gpt-4', kind=ActionKind.MODEL, info={'supports': {'vision': True}}), + ... # ... more models + ... ] + """ +@@ -199,7 +135,7 @@ class PluginV2(abc.ABC): + return [] + + async def model(self, name: str) -> Action: +- """Convenience method to get a specific model action. ++ r"""Convenience method to get a specific model action. + + Enables clean standalone usage: + plugin = SomePlugin() +@@ -217,26 +153,14 @@ class PluginV2(abc.ABC): + + Example: + >>> async def model(self, name: str) -> Action: +- ... action = self.resolve(ActionKind.MODEL, name) ++ ... action = await self.resolve(ActionKind.MODEL, name) + ... if not action: + ... raise ValueError(f\"Model {name} not found\") + ... 
return action + """ +- # Default implementation - plugins can override if needed +- if inspect.iscoroutinefunction(self.resolve): +- action = await self.resolve(ActionKind.MODEL, name) +- else: +- action = self.resolve(ActionKind.MODEL, name) ++ # Call the async resolve method ++ action = await self.resolve(ActionKind.MODEL, name) + + if not action: +- raise ValueError( +- f"Model '{name}' not found in plugin '{self.name}'" +- ) ++ raise ValueError(f"Model '{name}' not found in plugin '{self.name}'") + return action +- +- +-def is_plugin_v2(plugin: Any) -> bool: +- return hasattr(plugin, "version") and getattr(plugin, "version") == "v2" +- +-def is_plugin_v1(plugin: Any) -> bool: +- return isinstance(plugin, Plugin) +diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py +index 8d6224998..51b8a5c08 100644 +--- a/py/packages/genkit/src/genkit/ai/_registry.py ++++ b/py/packages/genkit/src/genkit/ai/_registry.py +@@ -30,6 +30,7 @@ several kinds of action defined by [ActionKind][genkit.core.action.ActionKind]: + | `'indexer'` | Indexer | + | `'model'` | Model | + | `'prompt'` | Prompt | ++| `'resource'` | Resource | + | `'retriever'` | Retriever | + | `'text-llm'` | Text LLM | + | `'tool'` | Tool | +@@ -42,7 +43,10 @@ import traceback + import uuid + from collections.abc import AsyncIterator, Callable + from functools import wraps +-from typing import Any, Type ++from typing import TYPE_CHECKING, Any, Callable, Type ++ ++if TYPE_CHECKING: ++ from genkit.blocks.resource import ResourceFn, ResourceOptions + + import structlog + from pydantic import BaseModel +@@ -53,9 +57,18 @@ from genkit.blocks.formats.types import FormatDef + from genkit.blocks.model import ModelFn, ModelMiddleware + from genkit.blocks.prompt import ( + define_helper, ++ define_partial, + define_prompt, + lookup_prompt, + ) ++from genkit.blocks.reranker import ( ++ RankedDocument, ++ RerankerFn, ++ RerankerOptions, ++ RerankerRef, ++ define_reranker 
as define_reranker_block, ++ rerank as rerank_block, ++) + from genkit.blocks.retriever import IndexerFn, RetrieverFn + from genkit.blocks.tools import ToolRunContext + from genkit.codec import dump_dict +@@ -65,6 +78,7 @@ from genkit.core.registry import Registry + from genkit.core.schema import to_json_schema + from genkit.core.tracing import run_in_new_span + from genkit.core.typing import ( ++ DocumentData, + EvalFnResponse, + EvalRequest, + EvalResponse, +@@ -181,6 +195,18 @@ class GenkitRegistry: + """ + define_helper(self.registry, name, fn) + ++ def define_partial(self, name: str, source: str) -> None: ++ """Define a Handlebars partial template in the registry. ++ ++ Partials are reusable template fragments that can be included ++ in other prompts using {{>partialName}} syntax. ++ ++ Args: ++ name: The name of the partial. ++ source: The template source code for the partial. ++ """ ++ define_partial(self.registry, name, source) ++ + def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable], Callable]: + """Decorator to register a function as a tool. + +@@ -326,6 +352,100 @@ class GenkitRegistry: + description=indexer_description, + ) + ++ def define_reranker( ++ self, ++ name: str, ++ fn: RerankerFn, ++ config_schema: BaseModel | dict[str, Any] | None = None, ++ metadata: dict[str, Any] | None = None, ++ description: str | None = None, ++ ) -> Action: ++ """Define a reranker action. ++ ++ Rerankers reorder documents based on their relevance to a query. ++ They are commonly used in RAG pipelines to improve retrieval quality. ++ ++ Args: ++ name: Name of the reranker. ++ fn: Function implementing the reranker behavior. Should accept ++ (query_doc, documents, options) and return RerankerResponse. ++ config_schema: Optional schema for reranker configuration. ++ metadata: Optional metadata for the reranker. ++ description: Optional description for the reranker. ++ ++ Returns: ++ The registered Action for the reranker. 
++ ++ Example: ++ >>> async def my_reranker(query, docs, options): ++ ... # Score documents based on relevance to query ++ ... scored = [(doc, compute_score(query, doc)) for doc in docs] ++ ... scored.sort(key=lambda x: x[1], reverse=True) ++ ... return RerankerResponse(documents=[...]) ++ >>> ai.define_reranker('my-reranker', my_reranker) ++ """ ++ reranker_meta = metadata.copy() if metadata else {} ++ if 'reranker' not in reranker_meta: ++ reranker_meta['reranker'] = {} ++ if 'label' not in reranker_meta['reranker'] or not reranker_meta['reranker']['label']: ++ reranker_meta['reranker']['label'] = name ++ if config_schema: ++ reranker_meta['reranker']['customOptions'] = to_json_schema(config_schema) ++ ++ reranker_description = get_func_description(fn, description) ++ return define_reranker_block( ++ self.registry, ++ name=name, ++ fn=fn, ++ options=RerankerOptions( ++ config_schema=reranker_meta['reranker'].get('customOptions'), ++ label=reranker_meta['reranker'].get('label'), ++ ), ++ ) ++ ++ async def rerank( ++ self, ++ reranker: str | Action | RerankerRef, ++ query: str | DocumentData, ++ documents: list[DocumentData], ++ options: Any | None = None, ++ ) -> list[RankedDocument]: ++ """Rerank documents based on their relevance to a query. ++ ++ This method takes a query and a list of documents, and returns the ++ documents reordered by relevance as determined by the specified reranker. ++ ++ Args: ++ reranker: The reranker to use - can be a name string, Action, or RerankerRef. ++ query: The query to rank documents against - can be a string or DocumentData. ++ documents: The list of documents to rerank. ++ options: Optional configuration options for this rerank call. ++ ++ Returns: ++ A list of RankedDocument objects sorted by relevance score. ++ ++ Raises: ++ ValueError: If the reranker cannot be resolved. ++ ++ Example: ++ >>> ranked_docs = await ai.rerank( ++ ... reranker='my-reranker', ++ ... query='What is machine learning?', ++ ... 
documents=[doc1, doc2, doc3], ++ ... ) ++ >>> for doc in ranked_docs: ++ ... print(f'Score: {doc.score}, Text: {doc.text()}') ++ """ ++ return await rerank_block( ++ self.registry, ++ { ++ 'reranker': reranker, ++ 'query': query, ++ 'documents': documents, ++ 'options': options, ++ }, ++ ) ++ + def define_evaluator( + self, + name: str, +@@ -488,7 +608,7 @@ class GenkitRegistry: + self, + name: str, + fn: ModelFn, +- config_schema: Type[BaseModel] | dict[str, Any] | None = None, ++ config_schema: type[BaseModel] | dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + info: ModelInfo | None = None, + description: str | None = None, +@@ -573,14 +693,15 @@ class GenkitRegistry: + + def define_prompt( + self, ++ name: str | None = None, + variant: str | None = None, + model: str | None = None, + config: GenerationCommonConfig | dict[str, Any] | None = None, + description: str | None = None, + input_schema: type | dict[str, Any] | None = None, +- system: str | Part | list[Part] | None = None, +- prompt: str | Part | list[Part] | None = None, +- messages: str | list[Message] | None = None, ++ system: str | Part | list[Part] | Callable | None = None, ++ prompt: str | Part | list[Part] | Callable | None = None, ++ messages: str | list[Message] | Callable | None = None, + output_format: str | None = None, + output_content_type: str | None = None, + output_instructions: bool | str | None = None, +@@ -592,12 +713,12 @@ class GenkitRegistry: + tools: list[str] | None = None, + tool_choice: ToolChoice | None = None, + use: list[ModelMiddleware] | None = None, +- # TODO: +- # docs: list[Document] ++ docs: list[DocumentData] | Callable | None = None, + ): + """Define a prompt. + + Args: ++ name: Optional name for the prompt. + variant: Optional variant name for the prompt. + model: Optional model name to use for the prompt. + config: Optional configuration for the model. 
+@@ -619,9 +740,11 @@ class GenkitRegistry: + tools: Optional list of tools to use for the prompt. + tool_choice: Optional tool choice for the prompt. + use: Optional list of model middlewares to use for the prompt. ++ docs: Optional list of documents or a callable to be used for grounding. + """ + return define_prompt( + self.registry, ++ name=name, + variant=variant, + model=model, + config=config, +@@ -641,6 +764,7 @@ class GenkitRegistry: + tools=tools, + tool_choice=tool_choice, + use=use, ++ docs=docs, + ) + + async def prompt( +@@ -668,13 +792,58 @@ class GenkitRegistry: + Raises: + GenkitError: If the prompt is not found. + """ +- + return await lookup_prompt( + registry=self.registry, + name=name, + variant=variant, + ) + ++ def define_resource( ++ self, ++ opts: 'ResourceOptions | None' = None, ++ fn: 'ResourceFn | None' = None, ++ *, ++ name: str | None = None, ++ uri: str | None = None, ++ template: str | None = None, ++ description: str | None = None, ++ metadata: dict[str, Any] | None = None, ++ ) -> Action: ++ """Define a resource action. ++ ++ Args: ++ opts: Options defining the resource (e.g. uri, template, name). ++ fn: Function implementing the resource behavior. ++ name: Optional name for the resource. ++ uri: Optional URI for the resource. ++ template: Optional URI template for the resource. ++ description: Optional description for the resource. ++ metadata: Optional metadata for the resource. ++ ++ Returns: ++ The registered Action for the resource. 
++ """ ++ from genkit.blocks.resource import ( ++ define_resource as define_resource_block, ++ ) ++ ++ if fn is None: ++ raise ValueError('A function `fn` must be provided to define a resource.') ++ if opts is None: ++ opts = {} ++ if name: ++ opts['name'] = name ++ if uri: ++ opts['uri'] = uri ++ if template: ++ opts['template'] = template ++ if description: ++ opts['description'] = description ++ if metadata: ++ opts['metadata'] = metadata ++ ++ return define_resource_block(self.registry, opts, fn) ++ + + class FlowWrapper: + """A wapper for flow functions to add `stream` method.""" +diff --git a/py/packages/genkit/src/genkit/aio/_compat.py b/py/packages/genkit/src/genkit/aio/_compat.py +index 0c0d6b7fc..04a053891 100644 +--- a/py/packages/genkit/src/genkit/aio/_compat.py ++++ b/py/packages/genkit/src/genkit/aio/_compat.py +@@ -26,7 +26,6 @@ See: https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for + """ + + import asyncio +-import sys + from typing import TypeVar + + T = TypeVar('T') +@@ -52,11 +51,8 @@ async def wait_for_310(fut: asyncio.Future[T], timeout: float | None = None) -> + """ + try: + return await asyncio.wait_for(fut, timeout) +- except asyncio.TimeoutError as e: ++ except TimeoutError as e: + raise TimeoutError() from e + + +-if sys.version_info < (3, 11): +- wait_for = wait_for_310 +-else: +- wait_for = asyncio.wait_for ++wait_for = asyncio.wait_for +diff --git a/py/packages/genkit/src/genkit/blocks/embedding.py b/py/packages/genkit/src/genkit/blocks/embedding.py +index 8ac473b15..16913ade4 100644 +--- a/py/packages/genkit/src/genkit/blocks/embedding.py ++++ b/py/packages/genkit/src/genkit/blocks/embedding.py +@@ -16,7 +16,7 @@ + + """Embedding actions.""" + +-from collections.abc import Awaitable, Callable ++from collections.abc import Callable + from typing import Any + + from pydantic import BaseModel, ConfigDict, Field +@@ -88,8 +88,8 @@ def embedder( + >>> def my_embed(request: EmbedRequest) -> EmbedResponse: + ... 
return EmbedResponse(...) + >>> +- >>> action = embedder(name="my-embedder", fn=my_embed) +- >>> response = await action.arun({"input": [...]}) ++ >>> action = embedder(name='my-embedder', fn=my_embed) ++ >>> response = await action.arun({'input': [...]}) + """ + embedder_meta = metadata if metadata else {} + +diff --git a/py/packages/genkit/src/genkit/blocks/generate.py b/py/packages/genkit/src/genkit/blocks/generate.py +index 754af90d4..5982589d9 100644 +--- a/py/packages/genkit/src/genkit/blocks/generate.py ++++ b/py/packages/genkit/src/genkit/blocks/generate.py +@@ -32,7 +32,7 @@ from genkit.blocks.model import ( + from genkit.blocks.tools import ToolInterruptError + from genkit.codec import dump_dict + from genkit.core.action import ActionRunContext +-from genkit.core.error import GenkitError, StatusName ++from genkit.core.error import GenkitError + from genkit.core.registry import Action, ActionKind, Registry + from genkit.core.typing import ( + GenerateActionOptions, +@@ -97,7 +97,7 @@ async def generate_action( + Returns: + The generated response. 
+ """ +- model, tools, format_def = resolve_parameters(registry, raw_request) ++ model, tools, format_def = await resolve_parameters(registry, raw_request) + + raw_request, formatter = apply_format(raw_request, format_def) + +@@ -350,11 +350,7 @@ def apply_format( + raw_request.output.instructions if raw_request.output else None, + ) + +- if ( +- format_def.config.default_instructions != False or raw_request.output.instructions +- if raw_request.output +- else False +- ): ++ if format_def.config.default_instructions or raw_request.output.instructions if raw_request.output else False: + out_request.messages = inject_instructions(out_request.messages, instructions) + + if format_def.config.constrained is not None: +@@ -384,7 +380,7 @@ def resolve_instructions(formatter: Formatter, instructions_opt: bool | str | No + if isinstance(instructions_opt, str): + # user provided instructions + return instructions_opt +- if instructions_opt == False: ++ if not instructions_opt: + # user says no instructions + return None + if not formatter: +@@ -425,7 +421,7 @@ def assert_valid_tool_names(raw_request: GenerateActionOptions): + pass + + +-def resolve_parameters( ++async def resolve_parameters( + registry: Registry, request: GenerateActionOptions + ) -> tuple[Action, list[Action], FormatDef | None]: + """Resolve parameters for the generate action. 
+@@ -442,14 +438,14 @@ def resolve_parameters( + if not model: + raise Exception('No model configured.') + +- model_action = registry.lookup_action(ActionKind.MODEL, model) ++ model_action = await registry.resolve_action(ActionKind.MODEL, model) + if model_action is None: + raise Exception(f'Failed to to resolve model {model}') + + tools: list[Action] = [] + if request.tools: + for tool_name in request.tools: +- tool_action = registry.lookup_action(ActionKind.TOOL, tool_name) ++ tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) + if tool_action is None: + raise Exception(f'Unable to resolve tool {tool_name}') + tools.append(tool_action) +@@ -541,7 +537,7 @@ async def resolve_tool_requests( + # TODO: prompt transfer + tool_dict: dict[str, Action] = {} + for tool_name in request.tools: +- tool_dict[tool_name] = resolve_tool(registry, tool_name) ++ tool_dict[tool_name] = await resolve_tool(registry, tool_name) + + revised_model_message = message._original_message.model_copy(deep=True) + +@@ -646,7 +642,7 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart + raise e + + +-def resolve_tool(registry: Registry, tool_name: str): ++async def resolve_tool(registry: Registry, tool_name: str): + """Resolve a tool by name from the registry. + + Args: +@@ -659,7 +655,7 @@ def resolve_tool(registry: Registry, tool_name: str): + Raises: + ValueError: If the tool could not be resolved. 
+ """ +- return registry.lookup_action(kind=ActionKind.TOOL, name=tool_name) ++ return await registry.resolve_action(kind=ActionKind.TOOL, name=tool_name) + + + async def _resolve_resume_options( +diff --git a/py/packages/genkit/src/genkit/blocks/model.py b/py/packages/genkit/src/genkit/blocks/model.py +index 62dff9eb8..4b4912354 100644 +--- a/py/packages/genkit/src/genkit/blocks/model.py ++++ b/py/packages/genkit/src/genkit/blocks/model.py +@@ -72,7 +72,7 @@ def model( + metadata: dict[str, Any] | None = None, + info: ModelInfo | None = None, + description: str | None = None, +-) -> 'Action': ++) -> Action: + """Create a model action WITHOUT registering it. + + This is the v2 API for creating models. Unlike ai.define_model(), +@@ -102,8 +102,8 @@ def model( + >>> def my_generate(request: GenerateRequest, ctx: ActionRunContext): + ... return GenerateResponse(...) + >>> +- >>> action = model(name="my-model", fn=my_generate) +- >>> response = await action.arun({"messages": [...]}) ++ >>> action = model(name='my-model', fn=my_generate) ++ >>> response = await action.arun({'messages': [...]}) + + Note: + This function extracts the "create action" logic from +@@ -236,12 +236,28 @@ class GenerateResponseWrapper(GenerateResponse): + request: The GenerateRequest object associated with the response. + message_parser: An optional function to parse the output from the message. + """ ++ # If message is not returned by generate response, try to infer ++ # message from the first candidate. 
++ response_message = response.message ++ if response_message is None and response.candidates: ++ response_message = response.candidates[0].message ++ if response_message is None: ++ raise ValueError('GenerateResponse must include either `message` or at least one candidate message.') ++ ++ finish_reason = response.finish_reason ++ if finish_reason is None and response.candidates: ++ finish_reason = response.candidates[0].finish_reason ++ ++ finish_message = response.finish_message ++ if finish_message is None and response.candidates: ++ finish_message = response.candidates[0].finish_message ++ + super().__init__( +- message=MessageWrapper(response.message) +- if not isinstance(response.message, MessageWrapper) +- else response.message, +- finish_reason=response.finish_reason, +- finish_message=response.finish_message, ++ message=MessageWrapper(response_message) ++ if not isinstance(response_message, MessageWrapper) ++ else response_message, ++ finish_reason=finish_reason, ++ finish_message=finish_message, + latency_ms=response.latency_ms, + usage=response.usage if response.usage is not None else GenerationUsage(), + custom=response.custom if response.custom is not None else {}, +@@ -529,10 +545,7 @@ def model_action_metadata( + + + def model_ref(name: str, namespace: str | None = None, **options: Any) -> ModelReference: +- """ +- The factory function equivalent to export function modelRef(...) 
+- """ +- ++ """The factory function equivalent to export function modelRef(...).""" + # Logic: if (options.namespace && !name.startsWith(options.namespace + '/')) + if namespace and not name.startswith(f'{namespace}/'): + final_name = f'{namespace}/{name}' +diff --git a/py/packages/genkit/src/genkit/blocks/prompt.py b/py/packages/genkit/src/genkit/blocks/prompt.py +index fcb3d65fa..751855456 100644 +--- a/py/packages/genkit/src/genkit/blocks/prompt.py ++++ b/py/packages/genkit/src/genkit/blocks/prompt.py +@@ -27,7 +27,7 @@ import weakref + from asyncio import Future + from collections.abc import AsyncIterator, Callable + from pathlib import Path +-from typing import Any ++from typing import Any, Awaitable + + import structlog + from dotpromptz.typing import ( +@@ -35,14 +35,14 @@ from dotpromptz.typing import ( + PromptFunction, + PromptInputConfig, + PromptMetadata, +- ToolDefinition as DotPromptzToolDefinition, + ) +-from pydantic import BaseModel ++from pydantic import BaseModel, ConfigDict + +-from genkit.aio import Channel ++from genkit.aio import Channel, ensure_async + from genkit.blocks.generate import ( + StreamingCallback as ModelStreamingCallback, + generate_action, ++ to_tool_definition, + ) + from genkit.blocks.model import ( + GenerateResponseChunkWrapper, +@@ -83,14 +83,16 @@ class PromptCache: + class PromptConfig(BaseModel): + """Model for a prompt action.""" + ++ model_config = ConfigDict(arbitrary_types_allowed=True) ++ + variant: str | None = None + model: str | None = None + config: GenerationCommonConfig | dict[str, Any] | None = None + description: str | None = None + input_schema: type | dict[str, Any] | None = None +- system: str | Part | list[Part] | None = None +- prompt: str | Part | list[Part] | None = None +- messages: str | list[Message] | None = None ++ system: str | Part | list[Part] | Callable | None = None ++ prompt: str | Part | list[Part] | Callable | None = None ++ messages: str | list[Message] | Callable | None = None + 
output_format: str | None = None + output_content_type: str | None = None + output_instructions: bool | str | None = None +@@ -102,7 +104,7 @@ class PromptConfig(BaseModel): + tools: list[str] | None = None + tool_choice: ToolChoice | None = None + use: list[ModelMiddleware] | None = None +- docs: list[DocumentData] | None = None ++ docs: list[DocumentData] | Callable | None = None + tool_responses: list[Part] | None = None + + +@@ -117,9 +119,9 @@ class ExecutablePrompt: + config: GenerationCommonConfig | dict[str, Any] | None = None, + description: str | None = None, + input_schema: type | dict[str, Any] | None = None, +- system: str | Part | list[Part] | None = None, +- prompt: str | Part | list[Part] | None = None, +- messages: str | list[Message] | None = None, ++ system: str | Part | list[Part] | Callable | None = None, ++ prompt: str | Part | list[Part] | Callable | None = None, ++ messages: str | list[Message] | Callable | None = None, + output_format: str | None = None, + output_content_type: str | None = None, + output_instructions: bool | str | None = None, +@@ -131,6 +133,7 @@ class ExecutablePrompt: + tools: list[str] | None = None, + tool_choice: ToolChoice | None = None, + use: list[ModelMiddleware] | None = None, ++ docs: list[DocumentData] | Callable | None = None, + _name: str | None = None, # prompt name for action lookup + _ns: str | None = None, # namespace for action lookup + _prompt_action: Action | None = None, # reference to PROMPT action +@@ -160,6 +163,7 @@ class ExecutablePrompt: + tools: A list of tool names to use with the prompt. + tool_choice: The tool choice strategy. + use: A list of model middlewares to apply. ++ docs: A list of documents to be used for grounding. 
+ """ + self._registry = registry + self._variant = variant +@@ -181,11 +185,23 @@ class ExecutablePrompt: + self._tools = tools + self._tool_choice = tool_choice + self._use = use ++ self._docs = docs + self._cache_prompt = PromptCache() + self._name = _name # Store name/ns for action lookup (used by as_tool()) + self._ns = _ns + self._prompt_action = _prompt_action + ++ @property ++ def ref(self) -> dict[str, Any]: ++ """Returns a reference object for this prompt. ++ ++ The reference object contains the prompt's name and metadata. ++ """ ++ return { ++ 'name': registry_definition_key(self._name, self._variant, self._ns) if self._name else None, ++ 'metadata': self._metadata, ++ } ++ + async def __call__( + self, + input: Any | None = None, +@@ -281,6 +297,7 @@ class ExecutablePrompt: + output_constrained=self._output_constrained, + input_schema=self._input_schema, + metadata=self._metadata, ++ docs=self._docs, + ) + + model = options.model or self._registry.default_model +@@ -330,7 +347,7 @@ class ExecutablePrompt: + tool_choice=options.tool_choice, + output=output, + max_turns=options.max_turns, +- docs=options.docs, ++ docs=await render_docs(input, options, context), + resume=resume, + ) + +@@ -352,7 +369,7 @@ class ExecutablePrompt: + + lookup_key = registry_lookup_key(self._name, self._variant, self._ns) + +- action = self._registry.lookup_action_by_key(lookup_key) ++ action = await self._registry.aresolve_action_by_key(lookup_key) + + if action is None or action.kind != ActionKind.PROMPT: + raise GenkitError( +@@ -365,14 +382,15 @@ class ExecutablePrompt: + + def define_prompt( + registry: Registry, ++ name: str | None = None, + variant: str | None = None, + model: str | None = None, + config: GenerationCommonConfig | dict[str, Any] | None = None, + description: str | None = None, + input_schema: type | dict[str, Any] | None = None, +- system: str | Part | list[Part] | None = None, +- prompt: str | Part | list[Part] | None = None, +- messages: str | 
list[Message] | None = None, ++ system: str | Part | list[Part] | Callable | None = None, ++ prompt: str | Part | list[Part] | Callable | None = None, ++ messages: str | list[Message] | Callable | None = None, + output_format: str | None = None, + output_content_type: str | None = None, + output_instructions: bool | str | None = None, +@@ -381,16 +399,16 @@ def define_prompt( + max_turns: int | None = None, + return_tool_requests: bool | None = None, + metadata: dict[str, Any] | None = None, +- tools: Tools | None = None, ++ tools: list[str] | None = None, + tool_choice: ToolChoice | None = None, + use: list[ModelMiddleware] | None = None, +- # TODO: +- # docs: list[Document] ++ docs: list[DocumentData] | Callable | None = None, + ) -> ExecutablePrompt: + """Defines an executable prompt. + + Args: + registry: The registry to use for resolving models and tools. ++ name: The name of the prompt. + variant: The variant of the prompt. + model: The model to use for generation. + config: The generation configuration. +@@ -410,11 +428,12 @@ def define_prompt( + tools: A list of tool names to use with the prompt. + tool_choice: The tool choice strategy. + use: A list of model middlewares to apply. ++ docs: A list of documents to be used for grounding. + + Returns: + An ExecutablePrompt instance. 
+ """ +- return ExecutablePrompt( ++ executable_prompt = ExecutablePrompt( + registry, + variant=variant, + model=model, +@@ -435,30 +454,59 @@ def define_prompt( + tools=tools, + tool_choice=tool_choice, + use=use, ++ docs=docs, ++ _name=name, + ) + ++ if name: ++ # Register actions for this prompt ++ action_metadata = { ++ 'type': 'prompt', ++ 'source': 'programmatic', ++ 'prompt': { ++ 'name': name, ++ 'variant': variant or '', ++ }, ++ } ++ ++ async def prompt_action_fn(input: Any = None) -> GenerateRequest: ++ """PROMPT action function - renders prompt and returns GenerateRequest.""" ++ options = await executable_prompt.render(input=input) ++ return await to_generate_request(registry, options) ++ ++ async def executable_prompt_action_fn(input: Any = None) -> GenerateActionOptions: ++ """EXECUTABLE_PROMPT action function - renders prompt and returns GenerateActionOptions.""" ++ return await executable_prompt.render(input=input) ++ ++ action_name = registry_definition_key(name, variant) ++ prompt_action = registry.register_action( ++ kind=ActionKind.PROMPT, ++ name=action_name, ++ fn=prompt_action_fn, ++ metadata=action_metadata, ++ ) ++ ++ executable_prompt_action = registry.register_action( ++ kind=ActionKind.EXECUTABLE_PROMPT, ++ name=action_name, ++ fn=executable_prompt_action_fn, ++ metadata=action_metadata, ++ ) ++ ++ # Link them ++ executable_prompt._prompt_action = prompt_action ++ prompt_action._executable_prompt = weakref.ref(executable_prompt) ++ executable_prompt_action._executable_prompt = weakref.ref(executable_prompt) ++ ++ return executable_prompt ++ + + async def to_generate_action_options(registry: Registry, options: PromptConfig) -> GenerateActionOptions: + """Converts the given parameters to a GenerateActionOptions object. + + Args: + registry: The registry to use for resolving models and tools. +- model: The model to use for generation. +- prompt: The user prompt. +- system: The system message for the prompt. 
+- messages: A list of messages to include in the prompt. +- tools: A list of tool names to use with the prompt. +- return_tool_requests: Whether to return tool requests. +- tool_choice: The tool choice strategy. +- tool_responses: tool response parts corresponding to interrupts. +- config: The generation configuration. +- max_turns: The maximum number of turns in a conversation. +- output_format: The output format. +- output_content_type: The output content type. +- output_instructions: Instructions for formatting the output. +- output_schema: The output schema. +- output_constrained: Whether the output should be constrained to the output schema. +- docs: A list of documents to be used for grounding. ++ options: The prompt configuration. + + Returns: + A GenerateActionOptions object. +@@ -466,13 +514,17 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig) + model = options.model or registry.default_model + if model is None: + raise Exception('No model configured.') ++ ++ cache = PromptCache() + resolved_msgs: list[Message] = [] + if options.system: +- resolved_msgs.append(Message(role=Role.SYSTEM, content=_normalize_prompt_arg(options.system))) ++ result = await render_system_prompt(registry, None, options, cache) ++ resolved_msgs.append(result) + if options.messages: +- resolved_msgs += options.messages ++ resolved_msgs.extend(await render_message_prompt(registry, None, options, cache)) + if options.prompt: +- resolved_msgs.append(Message(role=Role.USER, content=_normalize_prompt_arg(options.prompt))) ++ result = await render_user_prompt(registry, None, options, cache) ++ resolved_msgs.append(result) + + # If is schema is set but format is not explicitly set, default to + # `json` format. 
+@@ -506,7 +558,7 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig) + tool_choice=options.tool_choice, + output=output, + max_turns=options.max_turns, +- docs=options.docs, ++ docs=await render_docs(None, options), + resume=resume, + ) + +@@ -532,11 +584,10 @@ async def to_generate_request(registry: Registry, options: GenerateActionOptions + the registry. + GenkitError: If the options do not contain any messages. + """ +- + tools: list[Action] = [] + if options.tools: + for tool_name in options.tools: +- tool_action = registry.lookup_action(ActionKind.TOOL, tool_name) ++ tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) + if tool_action is None: + raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_name}') + tools.append(tool_action) +@@ -614,7 +665,6 @@ async def render_system_prompt( + Message: A Message object containing the rendered system prompt with Role.SYSTEM + + """ +- + if isinstance(options.system, str): + if prompt_cache.system is None: + prompt_cache.system = await registry.dotprompt.compile(options.system) +@@ -630,12 +680,16 @@ async def render_system_prompt( + input, + PromptMetadata( + input=PromptInputConfig( +- schema=options.input_schema, ++ schema=to_json_schema(options.input_schema) if options.input_schema else None, + ) + ), + ), + ) + ++ if callable(options.system): ++ resolved = await ensure_async(options.system)(input, context) ++ return Message(role=Role.SYSTEM, content=_normalize_prompt_arg(resolved)) ++ + return Message(role=Role.SYSTEM, content=_normalize_prompt_arg(options.system)) + + +@@ -687,8 +741,7 @@ async def render_message_prompt( + prompt_cache: PromptCache, + context: dict[str, Any] | None = None, + ) -> list[Message]: +- """ +- Render a message prompt using a given registry, input data, options, and a context. ++ """Render a message prompt using a given registry, input data, options, and a context. 
+ + This function processes different types of message options (string or list) to render + appropriate messages using a prompt registry and cache. If the `messages` option is of type +@@ -727,14 +780,22 @@ async def render_message_prompt( + context=context, + messages=messages_, + ), +- options=PromptMetadata(input=PromptInputConfig()), ++ options=PromptMetadata( ++ input=PromptInputConfig( ++ schema=to_json_schema(options.input_schema) if options.input_schema else None, ++ ) ++ ), + ) + return [Message.model_validate(e.model_dump()) for e in rendered.messages] + + elif isinstance(options.messages, list): + return options.messages + +- return [Message(role=Role.USER, content=_normalize_prompt_arg(options.prompt))] ++ elif callable(options.messages): ++ resolved = await ensure_async(options.messages)(input, context) ++ return resolved ++ ++ raise TypeError(f'Unsupported type for messages: {type(options.messages)}') + + + async def render_user_prompt( +@@ -744,8 +805,7 @@ async def render_user_prompt( + prompt_cache: PromptCache, + context: dict[str, Any] | None = None, + ) -> Message: +- """ +- Asynchronously renders a user prompt based on the given input, context, and options, ++ """Asynchronously renders a user prompt based on the given input, context, and options, + utilizing a pre-compiled or dynamically compiled dotprompt template. 
+ + Arguments: +@@ -774,13 +834,45 @@ async def render_user_prompt( + context, + prompt_cache.user_prompt, + input, +- PromptMetadata(input=PromptInputConfig()), ++ PromptMetadata( ++ input=PromptInputConfig( ++ schema=to_json_schema(options.input_schema) if options.input_schema else None, ++ ) ++ ), + ), + ) + ++ if callable(options.prompt): ++ resolved = await ensure_async(options.prompt)(input, context) ++ return Message(role=Role.USER, content=_normalize_prompt_arg(resolved)) ++ + return Message(role=Role.USER, content=_normalize_prompt_arg(options.prompt)) + + ++async def render_docs( ++ input: dict[str, Any], ++ options: PromptConfig, ++ context: dict[str, Any] | None = None, ++) -> list[DocumentData] | None: ++ """Renders the docs for a prompt action. ++ ++ Args: ++ input: Dictionary of input values. ++ options: Configuration options for the prompt. ++ context: Optional dictionary of context values. ++ ++ Returns: ++ A list of DocumentData objects or None. ++ """ ++ if options.docs is None: ++ return None ++ ++ if callable(options.docs): ++ return await ensure_async(options.docs)(input, context) ++ ++ return options.docs ++ ++ + def registry_definition_key(name: str, variant: str | None = None, ns: str | None = None) -> str: + """Generate a registry definition key for a prompt. 
+ +@@ -888,7 +980,7 @@ def load_prompt(registry: Registry, path: Path, filename: str, prefix: str = '', + file_path = path / filename + + # Read the prompt file +- with open(file_path, 'r', encoding='utf-8') as f: ++ with open(file_path, encoding='utf-8') as f: + source = f.read() + + # Parse the prompt +@@ -996,7 +1088,7 @@ def load_prompt(registry: Registry, path: Path, filename: str, prefix: str = '', + # Store reference to PROMPT action on the ExecutablePrompt + # Actions are already registered at this point (lazy loading happens after registration) + lookup_key = registry_lookup_key(name, variant, ns) +- prompt_action = registry.lookup_action_by_key(lookup_key) ++ prompt_action = await registry.aresolve_action_by_key(lookup_key) + if prompt_action and prompt_action.kind == ActionKind.PROMPT: + executable_prompt._prompt_action = prompt_action + # Also store ExecutablePrompt reference on the action +@@ -1092,7 +1184,7 @@ def load_prompt_folder_recursively(registry: Registry, dir_path: Path, ns: str, + if entry.name.startswith('_'): + # This is a partial + partial_name = entry.name[1:-7] # Remove "_" prefix and ".prompt" suffix +- with open(entry.path, 'r', encoding='utf-8') as f: ++ with open(entry.path, encoding='utf-8') as f: + source = f.read() + + # Strip frontmatter if present +@@ -1160,14 +1252,14 @@ async def lookup_prompt(registry: Registry, name: str, variant: str | None = Non + # Use create_action_key to build the full key: "/prompt/" + definition_key = registry_definition_key(name, variant, None) + lookup_key = create_action_key(ActionKind.PROMPT, definition_key) +- action = registry.lookup_action_by_key(lookup_key) ++ action = await registry.aresolve_action_by_key(lookup_key) + + # If not found and no namespace was specified, try with default 'dotprompt' namespace + # (for file-based prompts) + if not action: + definition_key = registry_definition_key(name, variant, 'dotprompt') + lookup_key = create_action_key(ActionKind.PROMPT, definition_key) +- 
action = registry.lookup_action_by_key(lookup_key) ++ action = await registry.aresolve_action_by_key(lookup_key) + + if action: + # First check if we've stored the ExecutablePrompt directly +@@ -1227,5 +1319,4 @@ async def prompt( + Raises: + GenkitError: If the prompt is not found. + """ +- + return await lookup_prompt(registry, name, variant) +diff --git a/py/packages/genkit/src/genkit/blocks/reranker.py b/py/packages/genkit/src/genkit/blocks/reranker.py +new file mode 100644 +index 000000000..cb22806e5 +--- /dev/null ++++ b/py/packages/genkit/src/genkit/blocks/reranker.py +@@ -0,0 +1,440 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""Reranker type definitions for the Genkit framework. ++ ++Rerankers and Two-Stage Retrieval ++================================= ++ ++A **reranking model** (also known as a cross-encoder) is a type of model that, ++given a query and document, outputs a similarity score. This score is used to ++reorder documents by relevance to the query. ++ ++Reranker APIs take a list of documents (e.g., the output of a retriever) and ++reorder them based on their relevance to the query. This step can be useful ++for fine-tuning results and ensuring the most pertinent information is used ++in the prompt provided to a generative model. ++ ++Two-Stage Retrieval ++------------------- ++ ++In a typical RAG (Retrieval-Augmented Generation) pipeline: ++ ++1. 
**Stage 1 - Retrieval**: A retriever fetches a large set of candidate ++ documents using fast vector similarity search. ++2. **Stage 2 - Reranking**: A reranker scores and reorders these candidates ++ using more expensive but accurate cross-encoder models. ++ ++This two-stage approach balances speed and accuracy: ++- Retrievers are fast but may not perfectly rank results ++- Rerankers are slower but provide superior relevance scoring ++ ++Usage Example ++------------- ++ ++Using an existing reranker (e.g., Vertex AI): ++ ++.. code-block:: python ++ ++ from genkit.ai import Genkit ++ ++ ai = Genkit(plugins=[...]) ++ ++ ++ @ai.flow() ++ async def rerank_flow(query: str): ++ documents = [ ++ Document.from_text('pythagorean theorem'), ++ Document.from_text('quantum mechanics'), ++ Document.from_text('pizza'), ++ ] ++ ++ reranked = await ai.rerank( ++ reranker='vertexai/semantic-ranker-512', ++ query=query, ++ documents=documents, ++ ) ++ ++ return [{'text': doc.text(), 'score': doc.score} for doc in reranked] ++ ++Custom Rerankers ++---------------- ++ ++You can define custom rerankers for specific use cases: ++ ++.. 
code-block:: python ++ ++ from genkit.ai import Genkit ++ from genkit.core.typing import ( ++ RerankerResponse, ++ RankedDocumentData, ++ RankedDocumentMetadata, ++ ) ++ ++ ai = Genkit() ++ ++ ++ async def custom_reranker_fn(query, documents, options): ++ # Your custom reranking logic here ++ # Example: score by keyword overlap ++ query_words = set(query.text().lower().split()) ++ scored = [] ++ for doc in documents: ++ doc_words = set(doc.text().lower().split()) ++ overlap = len(query_words & doc_words) ++ score = overlap / max(len(query_words), 1) ++ scored.append((doc, score)) ++ ++ # Sort by score descending and take top k ++ k = options.get('k', 3) if options else 3 ++ scored.sort(key=lambda x: x[1], reverse=True) ++ top_k = scored[:k] ++ ++ return RerankerResponse( ++ documents=[ ++ RankedDocumentData(content=doc.content, metadata=RankedDocumentMetadata(score=score)) ++ for doc, score in top_k ++ ] ++ ) ++ ++ ++ ai.define_reranker('custom/keyword-reranker', custom_reranker_fn) ++ ++ ++ # Use it in a flow ++ @ai.flow() ++ async def search_flow(query: str): ++ docs = await ai.retrieve(retriever='my-retriever', query=query) ++ return await ai.rerank(reranker='custom/keyword-reranker', query=query, documents=docs, options={'k': 5}) ++""" ++ ++from collections.abc import Awaitable, Callable ++from typing import Any, TypeVar, Union ++ ++from pydantic import BaseModel, ConfigDict, Field ++ ++from genkit.blocks.document import Document ++from genkit.core.action import Action, ActionMetadata ++from genkit.core.action.types import ActionKind ++from genkit.core.registry import Registry ++from genkit.core.schema import to_json_schema ++from genkit.core.typing import ( ++ DocumentData, ++ DocumentPart, ++ RankedDocumentData, ++ RankedDocumentMetadata, ++ RerankerRequest, ++ RerankerResponse, ++) ++ ++T = TypeVar('T') ++ ++# Type alias for reranker function ++RerankerFn = Callable[[Document, list[Document], T], Awaitable[RerankerResponse]] ++ ++ ++class 
RankedDocument(Document): ++ """A document with a relevance score from reranking. ++ ++ This class extends Document to include a score property that represents ++ the document's relevance to a query as determined by a reranker. ++ """ ++ ++ def __init__( ++ self, ++ content: list[DocumentPart], ++ metadata: dict[str, Any] | None = None, ++ score: float | None = None, ++ ) -> None: ++ """Initializes a RankedDocument object. ++ ++ Args: ++ content: A list of DocumentPart objects representing the document's content. ++ metadata: An optional dictionary containing metadata about the document. ++ score: The relevance score from reranking. ++ """ ++ md = metadata.copy() if metadata else {} ++ if score is not None: ++ md['score'] = score ++ super().__init__(content=content, metadata=md) ++ ++ @property ++ def score(self) -> float | None: ++ """Returns the relevance score of the document. ++ ++ Returns: ++ The relevance score as a float, or None if not set. ++ """ ++ if self.metadata and 'score' in self.metadata: ++ return self.metadata['score'] ++ return None ++ ++ @staticmethod ++ def from_ranked_document_data(data: RankedDocumentData) -> 'RankedDocument': ++ """Constructs a RankedDocument from RankedDocumentData. ++ ++ Args: ++ data: The RankedDocumentData containing content, metadata with score. ++ ++ Returns: ++ A new RankedDocument instance. 
++ """ ++ return RankedDocument( ++ content=data.content, ++ metadata=data.metadata.model_dump(), ++ score=data.metadata.score, ++ ) ++ ++ ++class RerankerSupports(BaseModel): ++ """Reranker capability support.""" ++ ++ model_config = ConfigDict(extra='forbid', populate_by_name=True) ++ ++ media: bool | None = None ++ ++ ++class RerankerInfo(BaseModel): ++ """Information about a reranker's capabilities.""" ++ ++ model_config = ConfigDict(extra='forbid', populate_by_name=True) ++ ++ label: str | None = None ++ supports: RerankerSupports | None = None ++ ++ ++class RerankerOptions(BaseModel): ++ """Configuration options for a reranker.""" ++ ++ model_config = ConfigDict(extra='forbid', populate_by_name=True) ++ ++ config_schema: dict[str, Any] | None = Field(None, alias='configSchema') ++ label: str | None = None ++ supports: RerankerSupports | None = None ++ ++ ++class RerankerRef(BaseModel): ++ """Reference to a reranker with configuration. ++ ++ Used to reference a reranker by name with optional configuration ++ and version information. ++ """ ++ ++ model_config = ConfigDict(extra='forbid', populate_by_name=True) ++ ++ name: str ++ config: Any | None = None ++ version: str | None = None ++ info: RerankerInfo | None = None ++ ++ ++def reranker_action_metadata( ++ name: str, ++ options: RerankerOptions | None = None, ++) -> ActionMetadata: ++ """Creates action metadata for a reranker. ++ ++ Args: ++ name: The name of the reranker. ++ options: Optional configuration options for the reranker. ++ ++ Returns: ++ An ActionMetadata instance for the reranker. 
++ """ ++ options = options if options is not None else RerankerOptions() ++ reranker_metadata_dict: dict[str, Any] = {'reranker': {}} ++ ++ if options.label: ++ reranker_metadata_dict['reranker']['label'] = options.label ++ ++ if options.supports: ++ reranker_metadata_dict['reranker']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True) ++ ++ reranker_metadata_dict['reranker']['customOptions'] = options.config_schema if options.config_schema else None ++ ++ return ActionMetadata( ++ kind=ActionKind.RERANKER, ++ name=name, ++ input_json_schema=to_json_schema(RerankerRequest), ++ output_json_schema=to_json_schema(RerankerResponse), ++ metadata=reranker_metadata_dict, ++ ) ++ ++ ++def create_reranker_ref( ++ name: str, ++ config: dict[str, Any] | None = None, ++ version: str | None = None, ++ info: RerankerInfo | None = None, ++) -> RerankerRef: ++ """Creates a RerankerRef instance. ++ ++ Args: ++ name: The name of the reranker. ++ config: Optional configuration for the reranker. ++ version: Optional version string. ++ info: Optional RerankerInfo with capability information. ++ ++ Returns: ++ A RerankerRef instance. ++ """ ++ return RerankerRef(name=name, config=config, version=version, info=info) ++ ++ ++def define_reranker( ++ registry: Registry, ++ name: str, ++ fn: RerankerFn, ++ options: RerankerOptions | None = None, ++) -> Action: ++ """Defines and registers a reranker action. ++ ++ Creates a reranker action from the provided function and registers it ++ in the given registry. ++ ++ Args: ++ registry: The registry to register the reranker in. ++ name: The name of the reranker. ++ fn: The reranker function that implements the reranking logic. ++ options: Optional configuration options for the reranker. ++ ++ Returns: ++ The registered Action instance. ++ ++ Example: ++ >>> async def my_reranker(query, documents, options): ++ ... # Score and sort documents ++ ... scored = [(doc, score_doc(query, doc)) for doc in documents] ++ ... 
scored.sort(key=lambda x: x[1], reverse=True) ++ ... return RerankerResponse( ++ ... documents=[ ++ ... RankedDocumentData(content=doc.content, metadata=RankedDocumentMetadata(score=score)) ++ ... for doc, score in scored ++ ... ] ++ ... ) ++ >>> define_reranker(registry, 'my-reranker', my_reranker) ++ """ ++ metadata = reranker_action_metadata(name, options) ++ ++ async def wrapper( ++ request: RerankerRequest, ++ _ctx: Any, ++ ) -> RerankerResponse: ++ query_doc = Document.from_document_data(request.query) ++ documents = [Document.from_document_data(d) for d in request.documents] ++ return await fn(query_doc, documents, request.options) ++ ++ return registry.register_action( ++ kind=ActionKind.RERANKER, ++ name=name, ++ fn=wrapper, ++ metadata=metadata.metadata, ++ span_metadata=metadata.metadata, ++ ) ++ ++ ++# Type for reranker argument (can be action, reference, or string name) ++RerankerArgument = Union[Action, RerankerRef, str] ++ ++ ++class RerankerParams(BaseModel): ++ """Parameters for the rerank function. ++ ++ Attributes: ++ reranker: The reranker to use (action, reference, or name string). ++ query: The query to rank documents against. ++ documents: The list of documents to rerank. ++ options: Optional configuration options for this rerank call. ++ """ ++ ++ model_config = ConfigDict(extra='forbid', populate_by_name=True, arbitrary_types_allowed=True) ++ ++ reranker: RerankerArgument ++ query: str | DocumentData ++ documents: list[DocumentData] ++ options: Any | None = None ++ ++ ++async def rerank( ++ registry: Registry, ++ params: RerankerParams | dict[str, Any], ++) -> list[RankedDocument]: ++ """Reranks documents based on the provided query using a reranker. ++ ++ This function takes a query and a list of documents, and returns the ++ documents reordered by relevance to the query as determined by the ++ specified reranker. ++ ++ Args: ++ registry: The registry to look up the reranker in. 
++ params: Parameters for the rerank operation + including the reranker, ++ query, documents, and optional configuration. ++ ++ Returns: ++ A list of RankedDocument objects sorted by relevance. ++ ++ Raises: ++ ValueError: If the reranker cannot be resolved. ++ ++ Example: ++ >>> ranked_docs = await rerank( ++ ... registry, ++ ... { ++ ... 'reranker': 'my-reranker', ++ ... 'query': 'What is machine learning?', ++ ... 'documents': [doc1, doc2, doc3], ++ ... }, ++ ... ) ++ >>> for doc in ranked_docs: ++ ... print(f'Score: {doc.score}, Text: {doc.text()}') ++ """ ++ # Convert dict to RerankerParams if needed ++ if isinstance(params, dict): ++ params = RerankerParams(**params) ++ ++ # Resolve the reranker action ++ reranker_action: Action | None = None ++ ++ if isinstance(params.reranker, str): ++ reranker_action = registry.lookup_action(ActionKind.RERANKER, params.reranker) ++ elif isinstance(params.reranker, RerankerRef): ++ reranker_action = registry.lookup_action(ActionKind.RERANKER, params.reranker.name) ++ elif isinstance(params.reranker, Action): ++ reranker_action = params.reranker ++ ++ if reranker_action is None: ++ raise ValueError(f'Unable to resolve reranker: {params.reranker}') ++ ++ # Convert query to DocumentData if it's a string ++ query_data: DocumentData ++ if isinstance(params.query, str): ++ query_data = Document.from_text(params.query) ++ else: ++ query_data = params.query ++ ++ # Build the request ++ request = RerankerRequest( ++ query=query_data, ++ documents=params.documents, ++ options=params.options, ++ ) ++ ++ # Call the reranker ++ action_response = await reranker_action.arun(request) ++ response: RerankerResponse = action_response.response ++ ++ # Convert response to RankedDocument list ++ return [RankedDocument.from_ranked_document_data(doc) for doc in response.documents] +diff --git a/py/packages/genkit/src/genkit/blocks/resource.py b/py/packages/genkit/src/genkit/blocks/resource.py +new file mode 100644 +index 000000000..8e5e398dd +--- 
/dev/null ++++ b/py/packages/genkit/src/genkit/blocks/resource.py +@@ -0,0 +1,398 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++"""Resource module for defining and managing resources. ++Resources in Genkit represent addressable content or data processing units containing ++unstructured data (Post, PDF, etc.) that can be retrieved or generated. They are ++identified by URIs (e.g. `file://`, `http://`, `gs://`) and can be static (fixed URI) ++or dynamic (using URI templates). ++This module provides tools to define resource actions that can resolve these URIs ++and return content (`ResourceOutput`) containing `Part`s. ++""" ++ ++import inspect ++import re ++from collections.abc import Awaitable, Callable ++from typing import Any, Protocol, TypedDict ++ ++from pydantic import BaseModel ++ ++from genkit.aio import ensure_async ++from genkit.core.action import Action, ActionRunContext ++from genkit.core.action.types import ActionKind ++from genkit.core.registry import Registry ++from genkit.core.typing import Metadata, Part ++ ++ ++class ResourceOptions(TypedDict, total=False): ++ """Options for defining a resource. ++ ++ Attributes: ++ name: Resource name. If not specified, uri or template will be used as name. ++ uri: The URI of the resource. Can contain template variables for simple matches, ++ but `template` is preferred for pattern matching. ++ template: The URI template (ex. `my://resource/{id}`). 
See RFC6570 for specification. ++ Used for matching variable resources. ++ description: A description of the resource, used for documentation and discovery. ++ metadata: Arbitrary metadata to attach to the resource action. ++ """ ++ ++ name: str ++ uri: str ++ template: str ++ description: str ++ metadata: dict[str, Any] ++ ++ ++class ResourceInput(BaseModel): ++ """Input structure for a resource request. ++ ++ Attributes: ++ uri: The full URI being requested/resolved. ++ """ ++ ++ uri: str ++ ++ ++class ResourceOutput(BaseModel): ++ """Output structure from a resource resolution. ++ ++ Attributes: ++ content: A list of `Part` objects representing the resource content. ++ """ ++ ++ content: list[Part] ++ ++ ++class ResourceFn(Protocol): ++ """A function that returns parts for a given resource. ++ The function receives the resolved input (including the URI) and context, ++ and should return a `ResourceOutput` containing the content parts. ++ """ ++ ++ def __call__(self, input: ResourceInput, ctx: ActionRunContext) -> Awaitable[ResourceOutput]: ... ++ ++ ++ResourceArgument = Action | str ++ ++ ++async def resolve_resources(registry: Registry, resources: list[ResourceArgument] | None = None) -> list[Action]: ++ """Resolves a list of resource names or actions into a list of Action objects. ++ ++ Args: ++ registry: The registry to lookup resources in. ++ resources: A list of resource references, which can be either direct `Action` ++ objects or strings (names/URIs). ++ ++ Returns: ++ A list of resolved `Action` objects. ++ ++ Raises: ++ ValueError: If a resource reference is invalid or cannot be found. 
++ """ ++ if not resources: ++ return [] ++ ++ resolved_actions = [] ++ for ref in resources: ++ if isinstance(ref, str): ++ resolved_actions.append(await lookup_resource_by_name(registry, ref)) ++ elif isinstance(ref, Action): ++ resolved_actions.append(ref) ++ else: ++ raise ValueError('Resources must be strings or actions') ++ return resolved_actions ++ ++ ++async def lookup_resource_by_name(registry: Registry, name: str) -> Action: ++ """Looks up a resource action by name in the registry. ++ Tries to resolve the name directly, or with common prefixes like `/resource/` ++ or `/dynamic-action-provider/`. ++ ++ Args: ++ registry: The registry to search. ++ name: The name or URI of the resource to lookup. ++ ++ Returns: ++ The found `Action`. ++ ++ Raises: ++ ValueError: If the resource cannot be found. ++ """ ++ resource = ( ++ registry.lookup_action(ActionKind.RESOURCE, name) ++ or registry.lookup_action(ActionKind.RESOURCE, f'/resource/{name}') ++ or registry.lookup_action(ActionKind.RESOURCE, f'/dynamic-action-provider/{name}') ++ ) ++ if not resource: ++ raise ValueError(f'Resource {name} not found') ++ return resource ++ ++ ++def define_resource(registry: Registry, opts: ResourceOptions, fn: ResourceFn) -> Action: ++ """Defines a resource and registers it with the given registry. ++ This creates a resource action that can handle requests for a specific URI ++ or URI template. ++ ++ Args: ++ registry: The registry to register the resource with. ++ opts: Options defining the resource (name, uri, template, etc.). ++ fn: The function that implements resource content retrieval. ++ ++ Returns: ++ The registered `Action` for the resource. 
++ """ ++ action = dynamic_resource(opts, fn) ++ ++ action.matches = create_matcher(opts.get('uri'), opts.get('template')) ++ ++ # Mark as not dynamic since it's being registered ++ action.metadata['dynamic'] = False ++ ++ registry.register_action_from_instance(action) ++ ++ return action ++ ++ ++def resource(opts: ResourceOptions, fn: ResourceFn) -> Action: ++ """Defines a dynamic resource action without immediate registration. ++ This is an alias for `dynamic_resource`. Useful for defining resources that ++ might be registered later or used as standalone actions. ++ ++ Args: ++ opts: Options defining the resource. ++ fn: The resource implementation function. ++ ++ Returns: ++ The created `Action`. ++ """ ++ return dynamic_resource(opts, fn) ++ ++ ++def dynamic_resource(opts: ResourceOptions, fn: ResourceFn) -> Action: ++ """Defines a dynamic resource action. ++ Creates an `Action` of kind `RESOURCE` that wraps the provided function. ++ The wrapper handles: ++ 1. Input validation and matching against the URI/Template. ++ 2. Execution of the resource function. ++ 3. Post-processing of output to attach metadata (like parent resource info). ++ ++ Args: ++ opts: Options including `uri` or `template` for matching. ++ fn: The function performing the resource retrieval. ++ ++ Returns: ++ An `Action` configured as a resource. ++ ++ Raises: ++ ValueError: If neither `uri` nor `template` is provided in options. 
++ """ ++ uri = opts.get('uri') or opts.get('template') ++ if not uri: ++ raise ValueError('must specify either uri or template options') ++ ++ matcher = create_matcher(opts.get('uri'), opts.get('template')) ++ ++ async def wrapped_fn(input_data: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: ++ if isinstance(input_data, dict): ++ input_data = ResourceInput(**input_data) ++ ++ try: ++ template_match = matcher(input_data) ++ if not template_match: ++ raise ValueError(f'input {input_data} did not match template {uri}') ++ ++ sig = inspect.signature(fn) ++ afn = ensure_async(fn) ++ n_params = len(sig.parameters) ++ ++ if n_params == 0: ++ parts = await afn() ++ elif n_params == 1: ++ parts = await afn(input_data) ++ else: ++ parts = await afn(input_data, ctx) ++ ++ # Post-processing parts to add metadata ++ content_list = parts.content if hasattr(parts, 'content') else parts.get('content', []) ++ ++ for p in content_list: ++ if isinstance(p, Part): ++ p = p.root ++ ++ if hasattr(p, 'metadata'): ++ if p.metadata is None or isinstance(p.metadata, dict): ++ p.metadata = Metadata(root=p.metadata or {}) ++ ++ if isinstance(p.metadata, Metadata): ++ p_metadata = p.metadata.root ++ else: ++ p_metadata = p.metadata ++ ++ if 'resource' in p_metadata: ++ if 'parent' not in p_metadata['resource']: ++ p_metadata['resource']['parent'] = {'uri': input_data.uri} ++ if opts.get('template'): ++ p_metadata['resource']['parent']['template'] = opts.get('template') ++ else: ++ p_metadata['resource'] = {'uri': input_data.uri} ++ if opts.get('template'): ++ p_metadata['resource']['template'] = opts.get('template') ++ elif isinstance(p, dict): ++ if 'metadata' not in p or p['metadata'] is None: ++ p['metadata'] = {} ++ p_metadata = p['metadata'] ++ else: ++ continue ++ # Ensure we return a serializable dict (handling Pydantic models in list) ++ if isinstance(parts, BaseModel): ++ return parts.model_dump() ++ elif isinstance(parts, dict): ++ # Verify content items are dicts, if not 
dump them ++ if 'content' in parts: ++ parts['content'] = [p.model_dump() if isinstance(p, BaseModel) else p for p in parts['content']] ++ return parts ++ return parts ++ except Exception: ++ raise ++ ++ name = opts.get('name') or uri ++ ++ act = Action( ++ name=name, ++ kind=ActionKind.RESOURCE, ++ fn=wrapped_fn, ++ metadata={ ++ 'resource': { ++ 'uri': opts.get('uri'), ++ 'template': opts.get('template'), ++ }, ++ 'dynamic': True, ++ }, ++ description=opts.get('description'), ++ span_metadata={'genkit:metadata:resource:uri': uri}, ++ ) ++ act.matches = matcher ++ return act ++ ++ ++def create_matcher(uri: str | None, template: str | None) -> Callable[[ResourceInput], bool]: ++ """Creates a matching function for resource validation. ++ ++ Args: ++ uri: Optional fixed URI string. ++ template: Optional URI template string. ++ ++ Returns: ++ A callable that takes ResourceInput and returns True if it matches. ++ """ ++ ++ def matcher(input_data: ResourceInput) -> bool: ++ if uri: ++ return input_data.uri == uri ++ if template: ++ return matches_uri_template(template, input_data.uri) is not None ++ return False ++ ++ return matcher ++ ++ ++def is_dynamic_resource_action(action: Action) -> bool: ++ """Checks if an action is a dynamic resource (not registered). ++ ++ Args: ++ action: The action to check. ++ ++ Returns: ++ True if the action is a dynamic resource, False otherwise. ++ """ ++ return action.kind == ActionKind.RESOURCE and action.metadata.get('dynamic', True) ++ ++ ++def matches_uri_template(template: str, uri: str) -> dict[str, str] | None: ++ """Check if a URI matches a template and extract parameters. ++ ++ Args: ++ template: URI template with {param} placeholders (e.g., "file://{path}"). ++ uri: The URI to match against the template. ++ ++ Returns: ++ Dictionary of extracted parameters if match, None otherwise. 
++ ++ Examples: ++ >>> matches_uri_template('file://{path}', 'file:///home/user/doc.txt') ++ {'path': '/home/user/doc.txt'} ++ >>> matches_uri_template('user://{id}/profile', 'user://123/profile') ++ {'id': '123'} ++ """ ++ # Split template into parts: text and {param} placeholders ++ parts = re.split(r'(\{[\w\+]+\})', template) ++ pattern_parts = [] ++ for part in parts: ++ if part.startswith('{') and part.endswith('}'): ++ param_name = part[1:-1] ++ if param_name.startswith('+'): ++ # Reserved expansion: {+var} matches reserved chars like / ++ param_name = param_name[1:] ++ pattern_parts.append(f'(?P<{param_name}>.+)') ++ else: ++ # Basic expansion: {var} does not match / ++ pattern_parts.append(f'(?P<{param_name}>[^/]+)') ++ else: ++ pattern_parts.append(re.escape(part)) ++ ++ pattern = f'^{"".join(pattern_parts)}$' ++ ++ match = re.search(pattern, uri) ++ if match: ++ return match.groupdict() ++ return None ++ ++ ++async def find_matching_resource( ++ registry: Registry, dynamic_resources: list[Action] | None, input_data: ResourceInput ++) -> Action | None: ++ """Finds a matching resource action. ++ Checks dynamic resources first, then the registry. ++ ++ Args: ++ registry: The registry to search. ++ dynamic_resources: Optional list of dynamic resource actions to check first. ++ input_data: The resource input containing the URI matched against. ++ ++ Returns: ++ The matching Action or None. ++ """ ++ if dynamic_resources: ++ for action in dynamic_resources: ++ if hasattr(action, 'matches') and action.matches(input_data): ++ return action ++ ++ # Try exact match in registry ++ resource = registry.lookup_action(ActionKind.RESOURCE, input_data.uri) ++ if resource: ++ return resource ++ ++ # Iterate all resources to check for matches (e.g. 
templates) ++ # This is less efficient but necessary for template matching if not optimized ++ resources = registry.get_actions_by_kind(ActionKind.RESOURCE) if hasattr(registry, 'get_actions_by_kind') else {} ++ if not resources and hasattr(registry, '_entries'): ++ # Fallback for compatibility if registry instance is old (unlikely in this context) ++ resources = registry._entries.get(ActionKind.RESOURCE, {}) ++ ++ for action in resources.values(): ++ if hasattr(action, 'matches') and action.matches(input_data): ++ return action ++ ++ return None +diff --git a/py/packages/genkit/src/genkit/core/action/_util.py b/py/packages/genkit/src/genkit/core/action/_util.py +index 3efe253e5..040baed08 100644 +--- a/py/packages/genkit/src/genkit/core/action/_util.py ++++ b/py/packages/genkit/src/genkit/core/action/_util.py +@@ -17,7 +17,6 @@ + """Action utility module for defining and managing action utilities.""" + + import inspect +-import typing + from typing import Any + + +diff --git a/py/packages/genkit/src/genkit/core/action/types.py b/py/packages/genkit/src/genkit/core/action/types.py +index 960928549..670ec5eee 100644 +--- a/py/packages/genkit/src/genkit/core/action/types.py ++++ b/py/packages/genkit/src/genkit/core/action/types.py +@@ -20,15 +20,15 @@ from __future__ import annotations + + import sys + from collections.abc import Callable +-from typing import Any, Awaitable, Dict, List, Literal, Protocol, Union +- +-from pydantic import BaseModel, ConfigDict, Field ++from typing import Any + + if sys.version_info < (3, 11): + from strenum import StrEnum + else: + from enum import StrEnum + ++from pydantic import BaseModel, ConfigDict, Field ++ + # Type alias for action name. 
+ # type ActionName = str + ActionName = str +@@ -57,6 +57,7 @@ class ActionKind(StrEnum): + MODEL = 'model' + PROMPT = 'prompt' + RERANKER = 'reranker' ++ RESOURCE = 'resource' + RETRIEVER = 'retriever' + TOOL = 'tool' + UTIL = 'util' +diff --git a/py/packages/genkit/src/genkit/core/environment.py b/py/packages/genkit/src/genkit/core/environment.py +index 562360fbd..be69f801d 100644 +--- a/py/packages/genkit/src/genkit/core/environment.py ++++ b/py/packages/genkit/src/genkit/core/environment.py +@@ -16,7 +16,6 @@ + + """Convenience functionality to determine the running environment.""" + +-import enum + import os + import sys + +diff --git a/py/packages/genkit/src/genkit/core/flows.py b/py/packages/genkit/src/genkit/core/flows.py +index bccb983b8..562f031e8 100644 +--- a/py/packages/genkit/src/genkit/core/flows.py ++++ b/py/packages/genkit/src/genkit/core/flows.py +@@ -156,7 +156,7 @@ def create_flows_asgi_app( + + try: + # Look up the flow action. +- action = registry.lookup_action_by_key(flow_name) ++ action = await registry.resolve_action_by_key(flow_name) + if action is None: + await logger.aerror( + 'Flow not found', +diff --git a/py/packages/genkit/src/genkit/core/reflection.py b/py/packages/genkit/src/genkit/core/reflection.py +index 1797a6ba8..96ec88681 100644 +--- a/py/packages/genkit/src/genkit/core/reflection.py ++++ b/py/packages/genkit/src/genkit/core/reflection.py +@@ -44,6 +44,8 @@ import asyncio + import json + import urllib.parse + from collections.abc import AsyncGenerator ++from dataclasses import dataclass, field ++from datetime import datetime + from http.server import BaseHTTPRequestHandler + from typing import Any + +@@ -73,10 +75,22 @@ from genkit.web.typing import ( + logger = structlog.get_logger(__name__) + + ++@dataclass ++class ActiveAction: ++ """Represents an in-flight action that can be cancelled.""" ++ ++ task: asyncio.Task | None ++ trace_id: str ++ start_time: datetime = field(default_factory=datetime.now) ++ ++ ++# Global dict 
to track active actions by trace ID ++_active_actions: dict[str, ActiveAction] = {} ++ ++ + def make_reflection_server( + registry: Registry, + loop: asyncio.AbstractEventLoop, +- id: str, + encoding='utf-8', + quiet=True, + ): +@@ -114,24 +128,16 @@ def make_reflection_server( + For the /api/actions endpoint, returns a JSON object mapping action + keys to their metadata, including input/output schemas. + """ +- parsed_url = urllib.parse.urlparse(self.path) +- if parsed_url.path == '/api/__health': +- query_params = urllib.parse.parse_qs(parsed_url.query) +- expected_id = query_params.get('id', [None])[0] +- if expected_id is not None and expected_id != id: +- self.send_response(500) +- self.end_headers() +- return +- ++ if self.path == '/api/__health': + self.send_response(200, 'OK') + self.end_headers() + +- elif parsed_url.path == '/api/actions': ++ elif self.path == '/api/actions': + self.send_response(200) + self.send_header('content-type', 'application/json') + self.end_headers() + actions = registry.list_serializable_actions() +- actions = registry.list_actions(actions) ++ actions = registry.list_actions_sync(actions) + self.wfile.write(bytes(json.dumps(actions), encoding)) + else: + self.send_response(404) +@@ -158,7 +164,6 @@ def make_reflection_server( + post_body = self.rfile.read(content_len) + payload = json.loads(post_body.decode(encoding=encoding)) + action = registry.lookup_action_by_key(payload['key']) +- action_input = payload.get('input') + context = payload['context'] if 'context' in payload else {} + + query = urllib.parse.urlparse(self.path).query +@@ -186,7 +191,7 @@ def make_reflection_server( + + async def run_fn(): + return await action.arun_raw( +- raw_input=payload.get('input'), ++ raw_input=payload['input'], + on_chunk=send_chunk, + context=context, + ) +@@ -217,7 +222,7 @@ def make_reflection_server( + try: + + async def run_fn(): +- return await action.arun_raw(raw_input=payload.get('input'), context=context) ++ return await 
action.arun_raw(raw_input=payload['input'], context=context) + + output = run_async(loop, run_fn) + +@@ -327,8 +332,10 @@ def create_reflection_asgi_app( + Returns: + A JSON response containing all serializable actions. + """ ++ actions = registry.list_serializable_actions() ++ actions = await registry.list_actions(actions) + return JSONResponse( +- content=registry.list_serializable_actions(), ++ content=actions, + status_code=200, + headers={'x-genkit-version': version}, + ) +@@ -348,6 +355,41 @@ def create_reflection_asgi_app( + headers={'x-genkit-version': version}, + ) + ++ async def handle_cancel_action(request: Request) -> JSONResponse: ++ """Handle the cancelAction endpoint for cancelling running actions. ++ ++ Args: ++ request: The Starlette request object containing traceId. ++ ++ Returns: ++ 200 with success message if action was cancelled. ++ 400 if traceId is missing. ++ 404 if action not found or already completed. ++ """ ++ payload = await request.json() ++ trace_id = payload.get('traceId') ++ ++ if not trace_id or not isinstance(trace_id, str): ++ return JSONResponse( ++ content={'error': 'traceId is required'}, ++ status_code=400, ++ ) ++ ++ active = _active_actions.get(trace_id) ++ if active: ++ if active.task and not active.task.done(): ++ active.task.cancel() ++ del _active_actions[trace_id] ++ return JSONResponse( ++ content={'message': 'Action cancelled'}, ++ status_code=200, ++ ) ++ else: ++ return JSONResponse( ++ content={'message': 'Action not found or already completed'}, ++ status_code=404, ++ ) ++ + async def handle_run_action( + request: Request, + ) -> JSONResponse | StreamingResponse: +@@ -368,7 +410,7 @@ def create_reflection_asgi_app( + """ + # Get the action. 
+ payload = await request.json() +- action = registry.lookup_action_by_key(payload['key']) ++ action = await registry.resolve_action_by_key(payload['key']) + if action is None: + return JSONResponse( + content={'error': f'Action not found: {payload["key"]}'}, +@@ -377,15 +419,13 @@ def create_reflection_asgi_app( + + # Run the action. + context = payload.get('context', {}) +- action_input = payload.get('input') + stream = is_streaming_requested(request) + handler = run_streaming_action if stream else run_standard_action +- return await handler(action, payload, action_input, context, version) ++ return await handler(action, payload, context, version) + + async def run_streaming_action( + action: Action, + payload: dict[str, Any], +- action_input: Any, + context: dict[str, Any], + version: str, + ) -> StreamingResponse | JSONResponse: +@@ -416,7 +456,7 @@ def create_reflection_asgi_app( + yield f'{out}\n' + + output = await action.arun_raw( +- raw_input=payload.get('input'), ++ raw_input=payload['input'], + on_chunk=send_chunk, + context=context, + ) +@@ -450,7 +490,6 @@ def create_reflection_asgi_app( + async def run_standard_action( + action: Action, + payload: dict[str, Any], +- action_input: Any, + context: dict[str, Any], + version: str, + ) -> JSONResponse: +@@ -466,7 +505,7 @@ def create_reflection_asgi_app( + A JSONResponse with the action result or error. 
+ """ + try: +- output = await action.arun_raw(raw_input=payload.get('input'), context=context) ++ output = await action.arun_raw(raw_input=payload['input'], context=context) + response = { + 'result': dump_dict(output.response), + 'telemetry': {'traceId': output.trace_id}, +@@ -491,6 +530,7 @@ def create_reflection_asgi_app( + Route('/api/actions', handle_list_actions, methods=['GET']), + Route('/api/notify', handle_notify, methods=['POST']), + Route('/api/runAction', handle_run_action, methods=['POST']), ++ Route('/api/cancelAction', handle_cancel_action, methods=['POST']), + ], + middleware=[ + Middleware( +diff --git a/py/packages/genkit/src/genkit/core/registry.py b/py/packages/genkit/src/genkit/core/registry.py +index a690f1121..b5b57719b 100644 +--- a/py/packages/genkit/src/genkit/core/registry.py ++++ b/py/packages/genkit/src/genkit/core/registry.py +@@ -27,6 +27,7 @@ Example: + >>> action = registry.lookup_action('', 'my_action') + """ + ++import inspect + import threading + from collections.abc import Callable + from typing import Any +@@ -36,7 +37,6 @@ from dotpromptz.dotprompt import Dotprompt + + from genkit.core.action import ( + Action, +- ActionMetadata, + create_action_key, + parse_action_key, + parse_plugin_name_from_action_name, +@@ -80,8 +80,10 @@ class Registry: + + def __init__(self): + """Initialize an empty Registry instance.""" +- self._action_resolvers: dict[str, ActionResolver] = {} +- self._list_actions_resolvers: dict[str, Callable] = {} ++ # Multiple plugins can contribute actions under the same plugin namespace. ++ # Example: `vertexai/*` for both model + vector search capabilities. ++ self._action_resolvers: dict[str, list[ActionResolver]] = {} ++ self._list_actions_resolvers: dict[str, list[Callable]] = {} + self._entries: ActionStore = {} + self._value_by_kind_and_name: dict[str, dict[str, Any]] = {} + self._lock = threading.RLock() +@@ -96,13 +98,9 @@ class Registry: + plugin_name: The name of the plugin. 
+ resolver: The ActionResolver instance to register. + +- Raises: +- ValueError: If a resolver is already registered for the plugin. + """ + with self._lock: +- if plugin_name in self._action_resolvers: +- raise ValueError(f'Plugin {plugin_name} already registered') +- self._action_resolvers[plugin_name] = resolver ++ self._action_resolvers.setdefault(plugin_name, []).append(resolver) + + def register_list_actions_resolver(self, plugin_name: str, resolver: Callable) -> None: + """Registers an Callable function to list available actions or models. +@@ -111,13 +109,9 @@ class Registry: + plugin_name: The name of the plugin. + resolver: The Callable function to list models. + +- Raises: +- ValueError: If a resolver is already registered for the plugin. + """ + with self._lock: +- if plugin_name in self._list_actions_resolvers: +- raise ValueError(f'Plugin {plugin_name} already registered') +- self._list_actions_resolvers[plugin_name] = resolver ++ self._list_actions_resolvers.setdefault(plugin_name, []).append(resolver) + + def register_action( + self, +@@ -182,9 +176,25 @@ class Registry: + self._entries[action.kind] = {} + self._entries[action.kind][name] = action + ++ def register_action_from_instance(self, action: Action) -> None: ++ """Register an existing Action instance. ++ Allows registering a pre-configured Action object, such as one created via ++ `dynamic_resource` or other factory methods. ++ Args: ++ action: The action instance to register. ++ """ ++ with self._lock: ++ if action.kind not in self._entries: ++ self._entries[action.kind] = {} ++ self._entries[action.kind][action.name] = action ++ + def lookup_action(self, kind: ActionKind, name: str) -> Action | None: + """Look up an action by its kind and name. + ++ .. deprecated:: ++ Use `await registry.resolve_action(kind, name)` instead. ++ This sync method cannot properly handle async PluginV2 plugins. ++ + Args: + kind: The type of action to look up. + name: The name of the action to look up. 
+@@ -192,6 +202,13 @@ class Registry: + Returns: + The Action instance if found, None otherwise. + """ ++ import warnings ++ ++ warnings.warn( ++ 'registry.lookup_action() is deprecated. Use `await registry.resolve_action(kind, name)` instead.', ++ DeprecationWarning, ++ stacklevel=2, ++ ) + with self._lock: + # If the entry does not exist, we fist try to call the action + # resolver for the plugin to give it a chance to dynamically add the +@@ -199,18 +216,133 @@ class Registry: + if kind not in self._entries or name not in self._entries[kind]: + plugin_name = parse_plugin_name_from_action_name(name) + if plugin_name and plugin_name in self._action_resolvers: +- # Strip plugin prefix before calling resolver +- action_name = name.removeprefix(f"{plugin_name}/") +- self._action_resolvers[plugin_name](kind, action_name) ++ # Pass the full namespaced action name to the plugin resolver. ++ # (Many v1 plugins/tests expect to receive the full name and will ++ # register actions using it; v2 resolvers can strip the prefix ++ # internally.) ++ for resolver in self._action_resolvers[plugin_name]: ++ result = resolver(kind, name) ++ if inspect.isawaitable(result): ++ raise TypeError( ++ f'Action resolver for plugin "{plugin_name}" returned an awaitable while resolving "{name}". ' ++ 'Use async resolution (e.g. `await registry.resolve_action(...)`) instead of sync `lookup_action(...)`.' ++ ) ++ if kind in self._entries and name in self._entries[kind]: ++ break + + if kind in self._entries and name in self._entries[kind]: + return self._entries[kind][name] + + return None + ++ async def resolve_action(self, kind: ActionKind, name: str) -> Action | None: ++ """Resolve an action by kind and name (async). ++ ++ Resolves an action name like "openai/gpt-4" (namespaced form). ++ Registry hit: if name is already in entries[kind], return it. ++ If miss: parse plugin_name from name (first segment before /). If there's no ++ plugin prefix, it can't route → returns None. 
++ If plugin prefix exists: look up the list of resolver functions registered for that ++ plugin and await them. ++ Check _entries again; if present, return it. ++ ++ Args: ++ kind: The type of action to look up. ++ name: The namespaced action name (e.g., "openai/gpt-4"). ++ ++ Returns: ++ The Action instance if found, None otherwise. ++ ++ Example: ++ >>> action = await registry.resolve_action(ActionKind.MODEL, 'openai/gpt-4') ++ """ ++ # Fast path: already registered (do not trigger resolvers). ++ with self._lock: ++ if kind in self._entries and name in self._entries[kind]: ++ return self._entries[kind][name] ++ ++ plugin_name = parse_plugin_name_from_action_name(name) ++ if not plugin_name: ++ return None ++ ++ resolvers = self._action_resolvers.get(plugin_name) ++ if not resolvers: ++ return None ++ ++ # Important: pass the full namespaced action name to the plugin resolver. ++ # V2 resolvers can strip the prefix internally if needed. ++ for resolver in resolvers: ++ result = resolver(kind, name) ++ if inspect.isawaitable(result): ++ await result ++ with self._lock: ++ if kind in self._entries and name in self._entries[kind]: ++ return self._entries[kind][name] ++ ++ with self._lock: ++ if kind in self._entries and name in self._entries[kind]: ++ return self._entries[kind][name] ++ return None ++ ++ # Backwards compatibility alias ++ async def aresolve_action(self, kind: ActionKind, name: str) -> Action | None: ++ """Deprecated: use resolve_action() instead.""" ++ import warnings ++ ++ warnings.warn( ++ 'aresolve_action() is deprecated. Use resolve_action() instead.', DeprecationWarning, stacklevel=2 ++ ) ++ return await self.resolve_action(kind, name) ++ ++ async def resolve_action_by_key(self, key: str) -> Action | None: ++ """Resolve an action by registry key (async). ++ ++ Resolves by registry key (e.g., "/model/openai/gpt-4"). ++ ++ Args: ++ key: The action key in the format "/kind/name". ++ ++ Returns: ++ The Action instance if found, None otherwise. 
++ ++ Example: ++ >>> action = await registry.resolve_action_by_key('/model/openai/gpt-4') ++ """ ++ kind, name = parse_action_key(key) ++ return await self.resolve_action(kind, name) ++ ++ # Backwards compatibility alias ++ async def aresolve_action_by_key(self, key: str) -> Action | None: ++ """Deprecated: use resolve_action_by_key() instead.""" ++ import warnings ++ ++ warnings.warn( ++ 'aresolve_action_by_key() is deprecated. Use resolve_action_by_key() instead.', ++ DeprecationWarning, ++ stacklevel=2, ++ ) ++ return await self.resolve_action_by_key(key) ++ ++ def get_actions_by_kind(self, kind: ActionKind) -> dict[str, Action]: ++ """Returns a dictionary of all registered actions for a specific kind. ++ ++ Args: ++ kind: The type of actions to retrieve (e.g., TOOL, MODEL, RESOURCE). ++ ++ Returns: ++ A dictionary mapping action names to Action instances. ++ Returns an empty dictionary if no actions of that kind are registered. ++ """ ++ with self._lock: ++ return self._entries.get(kind, {}).copy() ++ + def lookup_action_by_key(self, key: str) -> Action | None: + """Look up an action using its combined key string. + ++ .. deprecated:: ++ Use `await registry.resolve_action_by_key(key)` instead. ++ This sync method cannot properly handle async PluginV2 plugins. ++ + The key format is `/`, where kind must be a valid + `ActionKind` and name must be a registered action name within that kind. + +@@ -224,8 +356,18 @@ class Registry: + ValueError: If the key format is invalid or the kind is not a valid + `ActionKind`. + """ ++ import warnings ++ ++ warnings.warn( ++ 'registry.lookup_action_by_key() is deprecated. 
Use `await registry.resolve_action_by_key(key)` instead.', ++ DeprecationWarning, ++ stacklevel=2, ++ ) + kind, name = parse_action_key(key) +- return self.lookup_action(kind, name) ++ # Suppress nested deprecation warning (internal delegation) ++ with warnings.catch_warnings(): ++ warnings.simplefilter('ignore', DeprecationWarning) ++ return self.lookup_action(kind, name) + + def list_serializable_actions(self, allowed_kinds: set[ActionKind] | None = None) -> dict[str, Action] | None: + """Enlist all the actions into a dictionary. +@@ -243,7 +385,8 @@ class Registry: + if allowed_kinds is not None and kind not in allowed_kinds: + continue + for name in self._entries[kind]: +- action = self.lookup_action(kind, name) ++ # Read directly from _entries (already registered actions) ++ action = self._entries[kind].get(name) + if action is not None: + key = create_action_key(kind, name) + # TODO: Serialize the Action instance +@@ -256,12 +399,16 @@ class Registry: + } + return actions + +- def list_actions( ++ def list_actions_sync( + self, + actions: dict[str, Action] | None = None, + allowed_kinds: set[ActionKind] | None = None, + ) -> dict[str, Action] | None: +- """Add actions or models. ++ """Add actions or models (sync version - deprecated). ++ ++ .. deprecated:: ++ Use `await registry.list_actions(...)` instead. ++ This sync method cannot properly handle async PluginV2 plugins. + + Args: + actions: dictionary of serializable actions. +@@ -271,28 +418,98 @@ class Registry: + Returns: + A dictionary of serializable Actions updated. + """ ++ import warnings ++ ++ warnings.warn( ++ 'list_actions_sync() is deprecated. 
Use `await registry.list_actions(...)` instead.', ++ DeprecationWarning, ++ stacklevel=2, ++ ) ++ + if actions is None: + actions = {} + + for plugin_name in self._list_actions_resolvers: +- actions_list = self._list_actions_resolvers[plugin_name]() ++ for resolver in self._list_actions_resolvers[plugin_name]: ++ actions_list = resolver() ++ if inspect.isawaitable(actions_list): ++ raise TypeError( ++ f'list_actions resolver for plugin "{plugin_name}" returned an awaitable. ' ++ 'Use `await registry.list_actions(...)` to list actions asynchronously.' ++ ) ++ ++ for _action in actions_list: ++ kind = _action.kind ++ if allowed_kinds is not None and kind not in allowed_kinds: ++ continue ++ key = create_action_key(kind, _action.name) ++ ++ if key not in actions: ++ actions[key] = { ++ 'key': key, ++ 'name': _action.name, ++ 'inputSchema': _action.input_json_schema, ++ 'outputSchema': _action.output_json_schema, ++ 'metadata': _action.metadata, ++ } ++ return actions + +- for _action in actions_list: +- kind = _action.kind +- if allowed_kinds is not None and kind not in allowed_kinds: +- continue +- key = create_action_key(kind, _action.name) +- +- if key not in actions: +- actions[key] = { +- 'key': key, +- 'name': _action.name, +- 'inputSchema': _action.input_json_schema, +- 'outputSchema': _action.output_json_schema, +- 'metadata': _action.metadata, +- } ++ async def list_actions( ++ self, ++ actions: dict[str, Action] | None = None, ++ allowed_kinds: set[ActionKind] | None = None, ++ ) -> dict[str, Action] | None: ++ """List all actions (async). ++ ++ Async listing that awaits plugin list() resolvers. If allowed_kinds not provided, ++ returns action metadata for all kinds. ++ ++ Args: ++ actions: Optional dictionary to append actions to. ++ allowed_kinds: Optional set of ActionKind to filter by. ++ ++ Returns: ++ Dictionary of serializable actions. 
++ ++ Example: ++ >>> actions = await registry.list_actions(allowed_kinds={ActionKind.MODEL}) ++ """ ++ if actions is None: ++ actions = {} ++ ++ for plugin_name in self._list_actions_resolvers: ++ for resolver in self._list_actions_resolvers[plugin_name]: ++ actions_list = resolver() ++ if inspect.isawaitable(actions_list): ++ actions_list = await actions_list ++ ++ for _action in actions_list: ++ kind = _action.kind ++ if allowed_kinds is not None and kind not in allowed_kinds: ++ continue ++ key = create_action_key(kind, _action.name) ++ if key not in actions: ++ actions[key] = { ++ 'key': key, ++ 'name': _action.name, ++ 'inputSchema': _action.input_json_schema, ++ 'outputSchema': _action.output_json_schema, ++ 'metadata': _action.metadata, ++ } + return actions + ++ # Backwards compatibility alias ++ async def alist_actions( ++ self, ++ actions: dict[str, Action] | None = None, ++ allowed_kinds: set[ActionKind] | None = None, ++ ) -> dict[str, Action] | None: ++ """Deprecated: use list_actions() instead.""" ++ import warnings ++ ++ warnings.warn('alist_actions() is deprecated. Use list_actions() instead.', DeprecationWarning, stacklevel=2) ++ return await self.list_actions(actions, allowed_kinds) ++ + def register_value(self, kind: str, name: str, value: Any): + """Registers a value with a given kind and name. 
+ +diff --git a/py/packages/genkit/src/genkit/core/schema.py b/py/packages/genkit/src/genkit/core/schema.py +index 1d201d3dd..f7b2cc056 100644 +--- a/py/packages/genkit/src/genkit/core/schema.py ++++ b/py/packages/genkit/src/genkit/core/schema.py +@@ -16,7 +16,8 @@ + + """Functions for working with schema.""" + +-from typing import Any, Callable ++from collections.abc import Callable ++from typing import Any + + from pydantic import TypeAdapter + +diff --git a/py/packages/genkit/src/genkit/core/trace/default_exporter.py b/py/packages/genkit/src/genkit/core/trace/default_exporter.py +index fbc9fc4ff..5570cbb8c 100644 +--- a/py/packages/genkit/src/genkit/core/trace/default_exporter.py ++++ b/py/packages/genkit/src/genkit/core/trace/default_exporter.py +@@ -25,11 +25,10 @@ The module includes: + - Utility functions for converting and formatting trace attributes + """ + +-import asyncio + import json + import os + import sys +-from collections.abc import Awaitable, Sequence ++from collections.abc import Sequence + from typing import Any + from urllib.parse import urljoin + +diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py +index dce5a4f7f..68fd54110 100644 +--- a/py/packages/genkit/src/genkit/core/typing.py ++++ b/py/packages/genkit/src/genkit/core/typing.py +@@ -913,7 +913,7 @@ class Message(BaseModel): + """Model for message data.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) +- role: Role ++ role: Role | str + content: list[Part] + metadata: dict[str, Any] | None = None + +diff --git a/py/packages/genkit/src/genkit/types/__init__.py b/py/packages/genkit/src/genkit/types/__init__.py +index fa2f6b24c..4465eb3d7 100644 +--- a/py/packages/genkit/src/genkit/types/__init__.py ++++ b/py/packages/genkit/src/genkit/types/__init__.py +@@ -48,6 +48,7 @@ from genkit.core.typing import ( + ModelInfo, + OutputConfig, + Part, ++ ReasoningPart, + RetrieverRequest, + RetrieverResponse, + Role, +@@ 
-94,6 +95,7 @@ __all__ = [ + ModelInfo.__name__, + OutputConfig.__name__, + Part.__name__, ++ ReasoningPart.__name__, + RetrieverRequest.__name__, + RetrieverResponse.__name__, + Role.__name__, +diff --git a/py/packages/genkit/tests/genkit/ai/plugin_v2_test.py b/py/packages/genkit/tests/genkit/ai/plugin_v2_test.py +deleted file mode 100644 +index 5c7f02f87..000000000 +--- a/py/packages/genkit/tests/genkit/ai/plugin_v2_test.py ++++ /dev/null +@@ -1,252 +0,0 @@ +-# Copyright 2025 Google LLC +-# +-# Licensed under the Apache License, Version 2.0 (the "License"); +-# you may not use this file except in compliance with the License. +-# You may obtain a copy of the License at +-# +-# http://www.apache.org/licenses/LICENSE-2.0 +-# +-# Unless required by applicable law or agreed to in writing, software +-# distributed under the License is distributed on an "AS IS" BASIS, +-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-# See the License for the specific language governing permissions and +-# limitations under the License. +-# +-# SPDX-License-Identifier: Apache-2.0 +- +-"""Tests for v2 plugin support. +- +-Focused on key product requirements, not code coverage. +-Tests verify that v2 plugins work standalone and with framework, +-and that v1 plugins continue working (backward compatibility). 
+-""" +- +-import pytest +- +-from genkit.ai import Genkit +-from genkit.ai._plugin import PluginV2, is_plugin_v2 +-from genkit.core.action import Action, ActionMetadata +-from genkit.core.registry import ActionKind +-from genkit.types import ( +- Candidate, +- GenerateRequest, +- GenerateResponse, +- Message, +- Role, +- TextPart, +-) +- +- +-# Helper: Simple v2 test plugin +-class SimpleV2Plugin(PluginV2): +- """Minimal v2 plugin for testing.""" +- +- name = "test-v2" +- +- def __init__(self, models: list[str] | None = None): +- self._models = models or ["model-1"] +- +- def init(self): +- from genkit.blocks.model import model +- +- return [ +- model( +- name=m, +- fn=self._generate, +- ) +- for m in self._models +- ] +- +- def resolve(self, action_type, name): +- from genkit.blocks.model import model +- +- # Framework passes unprefixed name +- if action_type == ActionKind.MODEL and name in ["model-1", "model-2", "lazy-model"]: +- return model(name=name, fn=self._generate) +- return None +- +- def list_actions(self): +- return [ +- ActionMetadata(name=m, kind=ActionKind.MODEL, info={}) +- for m in ["model-1", "model-2"] +- ] +- +- # model() method inherited from PluginV2 base class +- +- def _generate(self, request: GenerateRequest, ctx): +- """Simple test model that echoes input.""" +- input_text = request.messages[0].content[0].text if request.messages else "empty" +- return GenerateResponse( +- candidates=[ +- Candidate( +- message=Message( +- role=Role.MODEL, content=[TextPart(text=f"TEST: {input_text}")] +- ) +- ) +- ] +- ) +- +- +-# Test 1: V2 plugins return actions +-def test_v2_plugin_init_returns_actions(): +- """V2 plugin init() should return list of Action objects.""" +- plugin = SimpleV2Plugin(models=["model-1", "model-2"]) +- +- actions = plugin.init() +- +- assert isinstance(actions, list) +- assert len(actions) == 2 +- assert all(isinstance(a, Action) for a in actions) +- assert actions[0].name == "model-1" +- assert actions[0].kind == 
ActionKind.MODEL +- +- +-# Test 2: V2 plugins work standalone +-@pytest.mark.asyncio +-async def test_v2_plugin_works_standalone(): +- """V2 plugin should work WITHOUT Genkit framework.""" +- # Create plugin - NO Genkit instance +- plugin = SimpleV2Plugin() +- +- # Get an action +- action = plugin.resolve(ActionKind.MODEL, "model-1") +- +- # Call it directly +- response = await action.arun({"messages": [{"role": "user", "content": [{"text": "hello"}]}]}) +- +- assert response is not None +- assert response.response.candidates[0].message.content[0].text == "TEST: hello" +- +- +-# Test 3: V2 plugins work with framework +-@pytest.mark.asyncio +-async def test_v2_plugin_works_with_framework(): +- """V2 plugin should work WITH Genkit framework.""" +- plugin = SimpleV2Plugin() +- +- ai = Genkit(plugins=[plugin]) +- +- response = await ai.generate("test-v2/model-1", prompt="framework test") +- +- assert response.text is not None +- assert "TEST:" in response.text +- +- +-# Test 4: Framework supports both v1 and v2 +-@pytest.mark.asyncio +-async def test_framework_accepts_v2_plugin(): +- """Framework should accept v2 plugins.""" +- plugin = SimpleV2Plugin() +- +- ai = Genkit(plugins=[plugin]) +- +- response = await ai.generate("test-v2/model-1", prompt="test") +- +- assert response.text is not None +- +- +-# Test 5: Lazy loading +-@pytest.mark.asyncio +-async def test_v2_lazy_loading(): +- """V2 plugin should support lazy loading via resolve().""" +- # Plugin with NO eager models +- plugin = SimpleV2Plugin(models=[]) +- +- ai = Genkit(plugins=[plugin]) +- +- # init() returned empty, but resolve() should work +- response = await ai.generate("test-v2/lazy-model", prompt="test") +- +- assert response.text is not None +- assert "TEST:" in response.text +- +- +-# Test 6: Automatic namespacing +-@pytest.mark.asyncio +-async def test_v2_automatic_namespacing(): +- """Framework should add namespace automatically.""" +- plugin = SimpleV2Plugin() +- +- # Plugin returns action 
WITHOUT namespace +- actions = plugin.init() +- assert actions[0].name == "model-1" # No prefix +- +- # Framework adds namespace +- ai = Genkit(plugins=[plugin]) +- +- # Must use namespaced name +- response = await ai.generate("test-v2/model-1", prompt="test") +- assert response.text is not None +- +- +-# Test 7: List actions +-def test_v2_list_actions(): +- """V2 plugin list_actions() should return metadata.""" +- plugin = SimpleV2Plugin() +- +- metadata = plugin.list_actions() +- +- assert isinstance(metadata, list) +- assert len(metadata) == 2 +- assert all(isinstance(m, ActionMetadata) for m in metadata) +- +- +-# Test 8: Detection function +-def test_is_plugin_v2_detection(): +- """is_plugin_v2() should correctly detect v2 plugins.""" +- from genkit.ai._plugin import Plugin +- +- v2_plugin = SimpleV2Plugin() +- +- # Create a simple v1 plugin for testing +- class SimpleV1Plugin(Plugin): +- name = "test-v1" +- +- def initialize(self, ai): +- pass +- +- v1_plugin = SimpleV1Plugin() +- +- assert is_plugin_v2(v2_plugin) is True +- assert is_plugin_v2(v1_plugin) is False +- assert is_plugin_v2("not a plugin") is False +- +- +-# Test 9: model() factory +-def test_model_factory_creates_action(): +- """model() factory should create Action without registry.""" +- from genkit.blocks.model import model +- +- def dummy_fn(request, ctx): +- return GenerateResponse( +- candidates=[ +- Candidate( +- message=Message(role=Role.MODEL, content=[TextPart(text="test")]) +- ) +- ] +- ) +- +- action = model(name="test-model", fn=dummy_fn) +- +- assert isinstance(action, Action) +- assert action.name == "test-model" +- assert action.kind == ActionKind.MODEL +- +- +-# Test 10: Convenience method +-@pytest.mark.asyncio +-async def test_v2_plugin_model_convenience_method(): +- """V2 plugin.model() should provide convenient access.""" +- plugin = SimpleV2Plugin() +- +- # Get model via convenience method +- action = await plugin.model("model-1") +- +- assert isinstance(action, Action) +- 
assert action.name == "model-1" +- +- # Should raise for non-existent model +- with pytest.raises(ValueError, match="not found"): +- await plugin.model("nonexistent-model") +- +- +diff --git a/py/packages/genkit/tests/genkit/ai/test_resource.py b/py/packages/genkit/tests/genkit/ai/test_resource.py +new file mode 100644 +index 000000000..13e187c88 +--- /dev/null ++++ b/py/packages/genkit/tests/genkit/ai/test_resource.py +@@ -0,0 +1,241 @@ ++# Copyright 2025 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""Tests for the Genkit Resource API. ++This module verifies the functionality of defining, registering, and resolving resources ++in the Genkit framework. It covers static resources, template-based resources, ++dynamic resource matching, and metadata handling. ++""" ++ ++import asyncio ++ ++import pytest ++ ++from genkit.blocks.resource import define_resource, resolve_resources, resource ++from genkit.core.registry import Registry ++from genkit.core.typing import Part, TextPart ++ ++ ++def test_define_resource(): ++ """Verifies that a resource can be defined and registered correctly. ++ Checks: ++ - Resource name matches property. ++ - Resource is retrievable from the registry by name. 
++ """ ++ registry = Registry() ++ ++ async def my_resource_fn(input, ctx): ++ return {'content': [Part(TextPart(text=f'Content for {input.uri}'))]} ++ ++ act = define_resource(registry, {'uri': 'http://example.com/foo'}, my_resource_fn) ++ ++ assert act.name == 'http://example.com/foo' ++ assert act.metadata['resource']['uri'] == 'http://example.com/foo' ++ ++ # Verify lookup logic (mocking lookup_action effectively via direct access or helper) ++ # Registry lookup for resources usually prepends /resource/ etc. ++ # but define_resource registers it with name=uri ++ ++ looked_up = registry.lookup_action('resource', 'http://example.com/foo') ++ assert looked_up == act ++ ++ ++@pytest.mark.asyncio ++async def test_resolve_resources(): ++ """Verifies resolving resource references into Action objects. ++ Checks: ++ - Resolving by string name works. ++ - Resolving by Action object passes through. ++ """ ++ registry = Registry() ++ ++ async def my_resource_fn(input, ctx): ++ return {'content': [Part(TextPart(text=f'Content for {input.uri}'))]} ++ ++ act = define_resource(registry, {'name': 'my-resource', 'uri': 'http://example.com/foo'}, my_resource_fn) ++ ++ resolved = await resolve_resources(registry, ['my-resource']) ++ assert len(resolved) == 1 ++ assert resolved[0] == act ++ ++ resolved_obj = await resolve_resources(registry, [act]) ++ assert len(resolved_obj) == 1 ++ assert resolved_obj[0] == act ++ ++ ++@pytest.mark.asyncio ++async def test_find_matching_resource(): ++ """Verifies the logic for finding a matching resource given an input URI. ++ Checks: ++ - Exact match against registered static resources. ++ - Template match against registered template resources. ++ - Matching against a provided list of dynamic resource actions for override/adhoc usage. ++ - Returns None when no match is found. 
++ """ ++ registry = Registry() ++ ++ # Static resource ++ async def static_fn(input, ctx): ++ return {'content': []} ++ ++ static_res = define_resource(registry, {'uri': 'bar://baz', 'name': 'staticRes'}, static_fn) ++ ++ # Template resource ++ async def template_fn(input, ctx): ++ return {'content': []} ++ ++ template_res = define_resource(registry, {'template': 'foo://bar/{baz}', 'name': 'templateRes'}, template_fn) ++ ++ # Dynamic resource list ++ async def dynamic_fn(input, ctx): ++ return {'content': []} ++ ++ dynamic_res = resource({'uri': 'baz://qux'}, dynamic_fn) ++ ++ from genkit.blocks.resource import ResourceInput, find_matching_resource ++ ++ # Match static from registry ++ res = await find_matching_resource(registry, [], ResourceInput(uri='bar://baz')) ++ assert res == static_res ++ ++ # Match template from registry ++ res = await find_matching_resource(registry, [], ResourceInput(uri='foo://bar/something')) ++ assert res == template_res ++ ++ # Match dynamic from list ++ res = await find_matching_resource(registry, [dynamic_res], ResourceInput(uri='baz://qux')) ++ assert res == dynamic_res ++ ++ # No match ++ res = await find_matching_resource(registry, [], ResourceInput(uri='unknown://uri')) ++ assert res is None ++ ++ ++def test_is_dynamic_resource_action(): ++ """Verifies identifying dynamic vs registered resource actions. ++ Checks: ++ - Unregistered resources created with `resource()` are dynamic. ++ - Registered resources created with `define_resource()` are not dynamic. 
++ """ ++ from genkit.blocks.resource import is_dynamic_resource_action ++ ++ async def fn(input, ctx): ++ return {'content': []} ++ ++ dynamic = resource({'uri': 'bar://baz'}, fn) ++ assert is_dynamic_resource_action(dynamic) ++ ++ # Registered action (define_resource sets dynamic=False) ++ async def static_fn(input, ctx): ++ return {'content': []} ++ ++ static = define_resource(Registry(), {'uri': 'foo://bar'}, static_fn) ++ assert not is_dynamic_resource_action(static) ++ ++ ++@pytest.mark.asyncio ++async def test_parent_metadata(): ++ """Verifies that parent metadata is correctly attached to output items. ++ When a resource is resolved via a template (e.g. `file://{id}`), the output parts ++ should contain metadata referencing the parent resource URI and template. ++ Checks: ++ - Parent URI and template presence in output part metadata. ++ """ ++ registry = Registry() ++ ++ async def fn(input, ctx): ++ return {'content': [Part(TextPart(text='sub1', metadata={'resource': {'uri': f'{input.uri}/sub1.txt'}}))]} ++ ++ res = define_resource(registry, {'template': 'file://{id}'}, fn) ++ ++ output = await res.arun({'uri': 'file://dir'}) ++ # output is ActionResponse ++ # content is in output.response['content'] because wrapped_fn ensures serialization ++ ++ part = output.response['content'][0] ++ # Check metadata ++ assert part['metadata']['resource']['parent']['uri'] == 'file://dir' ++ assert part['metadata']['resource']['parent']['template'] == 'file://{id}' ++ assert part['metadata']['resource']['uri'] == 'file://dir/sub1.txt' ++ ++ ++def test_dynamic_resource_matching(): ++ """Verifies the matching logic for a simple static URI dynamic resource.""" ++ ++ async def my_resource_fn(input, ctx): ++ return {'content': [Part(TextPart(text='Match'))]} ++ ++ res = resource({'uri': 'http://example.com/foo'}, my_resource_fn) ++ ++ class MockInput: ++ uri = 'http://example.com/foo' ++ ++ assert res.matches(MockInput()) ++ ++ class MockInputBad: ++ uri = 
'http://example.com/bar' ++ ++ assert not res.matches(MockInputBad()) ++ ++ ++def test_template_matching(): ++ """Verifies URI template pattern matching. ++ Checks: ++ - Matches correct URI structure. ++ - Fails on paths extending beyond the template structure (strict matching). ++ """ ++ ++ async def my_resource_fn(input, ctx): ++ return {'content': []} ++ ++ res = resource({'template': 'http://example.com/items/{id}'}, my_resource_fn) ++ ++ class MockInput: ++ uri = 'http://example.com/items/123' ++ ++ assert res.matches(MockInput()) ++ ++ class MockInputBad: ++ uri = 'http://example.com/items/123/details' ++ ++ # Should not match because of strict end anchor or slash handling in our regex ++ assert not res.matches(MockInputBad()) ++ ++ ++def test_reserved_expansion_matching(): ++ """Verifies RFC 6570 reserved expansion {+var} pattern matching. ++ Checks: ++ - Matches correct URI structure with slashes (reserved chars). ++ """ ++ ++ async def my_resource_fn(input, ctx): ++ return {'content': []} ++ ++ # Template with reserved expansion {+path} (matches slashes) ++ res = resource({'template': 'http://example.com/files/{+path}'}, my_resource_fn) ++ ++ class MockInput: ++ uri = 'http://example.com/files/foo/bar/baz.txt' ++ ++ assert res.matches(MockInput()) ++ ++ # Regular template {path} regex ([^/]+) should NOT match slashes ++ res_simple = resource({'template': 'http://example.com/items/{id}'}, my_resource_fn) ++ ++ class MockInputComplex: ++ uri = 'http://example.com/items/foo/bar' ++ ++ assert not res_simple.matches(MockInputComplex()) +diff --git a/py/packages/genkit/tests/genkit/blocks/embedding_test.py b/py/packages/genkit/tests/genkit/blocks/embedding_test.py +index 04cd4e7f0..0f1738333 100644 +--- a/py/packages/genkit/tests/genkit/blocks/embedding_test.py ++++ b/py/packages/genkit/tests/genkit/blocks/embedding_test.py +@@ -25,7 +25,6 @@ from genkit.ai._aio import Genkit + from genkit.blocks.document import Document + from genkit.blocks.embedding import ( + 
EmbedderOptions, +- EmbedderRef, + EmbedderSupports, + create_embedder_ref, + embedder_action_metadata, +@@ -156,6 +155,13 @@ class MockGenkitRegistry: + def lookup_action(self, kind, name): + return self.actions.get((kind, name)) + ++ async def resolve_action(self, kind, name): ++ return self.lookup_action(kind, name) ++ ++ # Backwards compatibility ++ async def aresolve_action(self, kind, name): ++ return self.lookup_action(kind, name) ++ + + @pytest.fixture + def mock_genkit_instance(): +diff --git a/py/packages/genkit/tests/genkit/blocks/generate_test.py b/py/packages/genkit/tests/genkit/blocks/generate_test.py +index 8d5035791..9feb72996 100644 +--- a/py/packages/genkit/tests/genkit/blocks/generate_test.py ++++ b/py/packages/genkit/tests/genkit/blocks/generate_test.py +@@ -43,7 +43,7 @@ def setup_test(): + + @ai.tool(name='testTool') + def test_tool(): +- """description""" ++ """Description.""" + return 'tool called' + + return (ai, pm) +@@ -333,7 +333,7 @@ async def test_generate_action_spec(spec) -> None: + + @ai.tool(name='testTool') + def test_tool(): +- """description""" ++ """Description.""" + return 'tool called' + + if 'modelResponses' in spec: +diff --git a/py/packages/genkit/tests/genkit/blocks/model_test.py b/py/packages/genkit/tests/genkit/blocks/model_test.py +index eca64fee5..d0547b506 100644 +--- a/py/packages/genkit/tests/genkit/blocks/model_test.py ++++ b/py/packages/genkit/tests/genkit/blocks/model_test.py +@@ -58,6 +58,26 @@ def test_response_wrapper_text() -> None: + assert wrapper.text == 'hello world' + + ++def test_response_wrapper_uses_candidates_fallback() -> None: ++ wrapper = GenerateResponseWrapper( ++ response=GenerateResponse( ++ candidates=[ ++ Candidate( ++ index=0, ++ message=Message(role='model', content=[Part(text='hello')]), ++ finish_reason='stop', ++ ) ++ ] ++ ), ++ request=GenerateRequest( ++ messages=[], # doesn't matter for now ++ ), ++ ) ++ ++ assert wrapper.text == 'hello' ++ assert wrapper.finish_reason == 'stop' ++ 
++ + def test_response_wrapper_output() -> None: + wrapper = GenerateResponseWrapper( + response=GenerateResponse( +diff --git a/py/packages/genkit/tests/genkit/blocks/prompt_test.py b/py/packages/genkit/tests/genkit/blocks/prompt_test.py +index 112cbcbde..8a9af200b 100644 +--- a/py/packages/genkit/tests/genkit/blocks/prompt_test.py ++++ b/py/packages/genkit/tests/genkit/blocks/prompt_test.py +@@ -28,14 +28,15 @@ from genkit.ai import Genkit + from genkit.blocks.prompt import load_prompt_folder, lookup_prompt, prompt + from genkit.core.action.types import ActionKind + from genkit.core.typing import ( ++ DocumentData, + GenerateActionOptions, + GenerateRequest, ++ GenerateResponse, + GenerationCommonConfig, + Message, + Role, + TextPart, + ToolChoice, +- ToolDefinition, + ) + from genkit.testing import ( + define_echo_model, +@@ -147,6 +148,55 @@ async def test_prompt_with_kitchensink() -> None: + assert (await response).text == want_txt + + ++@pytest.mark.asyncio ++async def test_prompt_with_resolvers() -> None: ++ """Test that the rendering works with resolvers.""" ++ ai, *_ = setup_test() ++ ++ async def system_resolver(input, context): ++ return f'system {input["name"]}' ++ ++ def prompt_resolver(input, context): ++ return f'prompt {input["name"]}' ++ ++ async def messages_resolver(input, context): ++ return [Message(role=Role.USER, content=[TextPart(text=f'msg {input["name"]}')])] ++ ++ my_prompt = ai.define_prompt( ++ system=system_resolver, ++ prompt=prompt_resolver, ++ messages=messages_resolver, ++ ) ++ ++ want_txt = '[ECHO] system: "system world" user: "msg world" user: "prompt world"' ++ ++ response = await my_prompt(input={'name': 'world'}) ++ ++ assert response.text == want_txt ++ ++ ++@pytest.mark.asyncio ++async def test_prompt_with_docs_resolver() -> None: ++ """Test that the rendering works with docs resolver.""" ++ ai, _, pm = setup_test() ++ ++ pm.responses = [GenerateResponse(message=Message(role=Role.MODEL, content=[TextPart(text='ok')]))] ++ ++ 
async def docs_resolver(input, context): ++ return [DocumentData(content=[TextPart(text=f'doc {input["name"]}')])] ++ ++ my_prompt = ai.define_prompt( ++ model='programmableModel', ++ prompt='hi', ++ docs=docs_resolver, ++ ) ++ ++ await my_prompt(input={'name': 'world'}) ++ ++ # Check that PM received the docs ++ assert pm.last_request.docs[0].content[0].root.text == 'doc world' ++ ++ + test_cases_parse_partial_json = [ + ( + 'renders system prompt', +@@ -208,7 +258,6 @@ test_cases_parse_partial_json = [ + ] + + +-@pytest.mark.skip(reason='issues when running on CI') + @pytest.mark.asyncio + @pytest.mark.parametrize( + 'test_case, prompt, input, input_option, context, want_rendered', +@@ -320,6 +369,25 @@ async def test_load_and_use_partial() -> None: + assert 'Hello from partial' in response.text or 'space' in response.text + + ++@pytest.mark.asyncio ++async def test_define_partial_programmatically() -> None: ++ """Test defining partials programmatically using ai.define_partial().""" ++ ai, *_ = setup_test() ++ ++ # Define a partial programmatically ++ ai.define_partial('myGreeting', 'Greetings, {{name}}!') ++ ++ # Create a prompt that uses the partial ++ my_prompt = ai.define_prompt( ++ messages='{{>myGreeting}} Welcome to Genkit.', ++ ) ++ ++ response = await my_prompt(input={'name': 'Developer'}) ++ ++ # The partial should be included in the output ++ assert 'Greetings' in response.text and 'Developer' in response.text ++ ++ + @pytest.mark.asyncio + async def test_prompt_with_messages_list() -> None: + """Test prompt with explicit messages list.""" +@@ -350,7 +418,7 @@ async def test_messages_with_explicit_override() -> None: + prompt='Final question', + ) + +- override_messages = [ ++ [ + Message(role=Role.USER, content=[TextPart(text='First message')]), + Message(role=Role.MODEL, content=[TextPart(text='First response')]), + ] +@@ -509,7 +577,6 @@ async def test_prompt_and_executable_prompt_return_types() -> None: + @pytest.mark.asyncio + async def 
test_lookup_prompt_returns_executable_prompt() -> None: + """lookup_prompt should return an ExecutablePrompt that can be called.""" +- + ai, *_ = setup_test() + + with tempfile.TemporaryDirectory() as tmpdir: +@@ -600,7 +667,7 @@ async def test_automatic_prompt_loading_defaults_mock(): + @pytest.mark.asyncio + async def test_automatic_prompt_loading_defaults_missing(): + """Test that Genkit skips loading when ./prompts is missing.""" +- from unittest.mock import ANY, MagicMock, patch ++ from unittest.mock import MagicMock, patch + + with patch('genkit.ai._aio.load_prompt_folder') as mock_load, patch('genkit.ai._aio.Path') as mock_path: + # Setup mock to simulate ./prompts missing +diff --git a/py/packages/genkit/tests/genkit/blocks/test_reranker.py b/py/packages/genkit/tests/genkit/blocks/test_reranker.py +new file mode 100644 +index 000000000..63ac47433 +--- /dev/null ++++ b/py/packages/genkit/tests/genkit/blocks/test_reranker.py +@@ -0,0 +1,484 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""Tests for the reranker module. ++ ++This module contains tests for the reranker functionality including ++RankedDocument, define_reranker, and rerank functions. 
++""" ++ ++import pytest ++ ++from genkit.blocks.document import Document ++from genkit.blocks.reranker import ( ++ RankedDocument, ++ RerankerInfo, ++ RerankerOptions, ++ RerankerRef, ++ create_reranker_ref, ++ define_reranker, ++ rerank, ++ reranker_action_metadata, ++) ++from genkit.core.action.types import ActionKind ++from genkit.core.registry import Registry ++from genkit.core.typing import ( ++ DocumentData, ++ DocumentPart, ++ RankedDocumentData, ++ RankedDocumentMetadata, ++ RerankerResponse, ++) ++ ++ ++class TestRankedDocument: ++ """Tests for the RankedDocument class.""" ++ ++ def test_ranked_document_creation(self): ++ """Test creating a RankedDocument with content and score.""" ++ content = [DocumentPart(text='Test content')] ++ metadata = {'key': 'value'} ++ score = 0.95 ++ ++ doc = RankedDocument(content=content, metadata=metadata, score=score) ++ ++ assert doc.score == 0.95 ++ assert doc.text() == 'Test content' ++ assert doc.metadata == {'key': 'value', 'score': 0.95} ++ # Original metadata should not be modified ++ assert metadata == {'key': 'value'} ++ ++ def test_ranked_document_default_score(self): ++ """Test that RankedDocument has a default score of None.""" ++ content = [DocumentPart(text='Test')] ++ doc = RankedDocument(content=content) ++ ++ assert doc.score is None ++ ++ def test_ranked_document_from_data(self): ++ """Test creating RankedDocument from RankedDocumentData.""" ++ data = RankedDocumentData( ++ content=[DocumentPart(text='Test content')], ++ metadata=RankedDocumentMetadata(score=0.85), ++ ) ++ ++ doc = RankedDocument.from_ranked_document_data(data) ++ ++ assert doc.score == 0.85 ++ assert doc.text() == 'Test content' ++ ++ ++class TestRerankerRef: ++ """Tests for RerankerRef and related helper functions.""" ++ ++ def test_create_reranker_ref_basic(self): ++ """Test creating a basic reranker reference.""" ++ ref = create_reranker_ref('test-reranker') ++ ++ assert ref.name == 'test-reranker' ++ assert ref.config is None ++ 
assert ref.version is None ++ assert ref.info is None ++ ++ def test_create_reranker_ref_with_options(self): ++ """Test creating a reranker reference with all options.""" ++ info = RerankerInfo(label='Test Reranker') ++ ref = create_reranker_ref( ++ name='test-reranker', ++ config={'k': 10}, ++ version='1.0.0', ++ info=info, ++ ) ++ ++ assert ref.name == 'test-reranker' ++ assert ref.config == {'k': 10} ++ assert ref.version == '1.0.0' ++ assert ref.info.label == 'Test Reranker' ++ ++ ++class TestRerankerActionMetadata: ++ """Tests for reranker action metadata creation.""" ++ ++ def test_action_metadata_basic(self): ++ """Test creating basic action metadata.""" ++ metadata = reranker_action_metadata('test-reranker') ++ ++ assert metadata.kind == ActionKind.RERANKER ++ assert metadata.name == 'test-reranker' ++ assert 'reranker' in metadata.metadata ++ ++ def test_action_metadata_with_options(self): ++ """Test creating action metadata with options.""" ++ options = RerankerOptions( ++ label='Custom Label', ++ config_schema={'type': 'object'}, ++ ) ++ metadata = reranker_action_metadata('test-reranker', options) ++ ++ assert metadata.metadata['reranker']['label'] == 'Custom Label' ++ assert metadata.metadata['reranker']['customOptions'] == {'type': 'object'} ++ ++ ++class TestDefineReranker: ++ """Tests for the define_reranker function.""" ++ ++ @pytest.fixture ++ def registry(self): ++ """Create a fresh registry for each test.""" ++ return Registry() ++ ++ @pytest.mark.asyncio ++ async def test_define_reranker_registers_action(self, registry): ++ """Test that define_reranker registers an action in the registry.""" ++ ++ async def simple_reranker(query, documents, options): ++ # Return documents in same order with scores ++ return RerankerResponse( ++ documents=[ ++ RankedDocumentData( ++ content=doc.content, ++ metadata=RankedDocumentMetadata(score=1.0 - i * 0.1), ++ ) ++ for i, doc in enumerate(documents) ++ ] ++ ) ++ ++ action = define_reranker(registry, 
'test-reranker', simple_reranker) ++ ++ # Verify action was registered ++ lookup = registry.lookup_action(ActionKind.RERANKER, 'test-reranker') ++ assert lookup is not None ++ assert action.name == 'test-reranker' ++ ++ @pytest.mark.asyncio ++ async def test_define_reranker_with_options(self, registry): ++ """Test define_reranker with custom options.""" ++ ++ async def reranker_fn(query, documents, options): ++ return RerankerResponse(documents=[]) ++ ++ options = RerankerOptions(label='My Reranker') ++ action = define_reranker(registry, 'my-reranker', reranker_fn, options) ++ ++ assert action is not None ++ ++ ++class TestRerank: ++ """Tests for the rerank function.""" ++ ++ @pytest.fixture ++ def registry(self): ++ """Create a fresh registry for each test.""" ++ return Registry() ++ ++ @pytest.fixture ++ def sample_documents(self): ++ """Create sample documents for testing.""" ++ return [ ++ DocumentData(content=[DocumentPart(text='First document')]), ++ DocumentData(content=[DocumentPart(text='Second document')]), ++ DocumentData(content=[DocumentPart(text='Third document')]), ++ ] ++ ++ @pytest.mark.asyncio ++ async def test_rerank_with_string_query(self, registry, sample_documents): ++ """Test rerank with a string query.""" ++ ++ async def score_by_length(query, documents, options): ++ # Score documents by content length (longer = higher score) ++ scored = [] ++ for doc in documents: ++ length = len(doc.text()) ++ scored.append( ++ RankedDocumentData( ++ content=doc.content, ++ metadata=RankedDocumentMetadata(score=float(length)), ++ ) ++ ) ++ return RerankerResponse(documents=scored) ++ ++ define_reranker(registry, 'length-reranker', score_by_length) ++ ++ results = await rerank( ++ registry, ++ { ++ 'reranker': 'length-reranker', ++ 'query': 'test query', ++ 'documents': sample_documents, ++ }, ++ ) ++ ++ assert len(results) == 3 ++ assert all(isinstance(r, RankedDocument) for r in results) ++ ++ @pytest.mark.asyncio ++ async def 
test_rerank_with_reranker_ref(self, registry, sample_documents): ++ """Test rerank with a RerankerRef.""" ++ ++ async def simple_reranker(query, documents, options): ++ return RerankerResponse( ++ documents=[ ++ RankedDocumentData( ++ content=doc.content, ++ metadata=RankedDocumentMetadata(score=0.5), ++ ) ++ for doc in documents ++ ] ++ ) ++ ++ define_reranker(registry, 'ref-reranker', simple_reranker) ++ ref = create_reranker_ref('ref-reranker') ++ ++ results = await rerank( ++ registry, ++ { ++ 'reranker': ref, ++ 'query': 'test', ++ 'documents': sample_documents, ++ }, ++ ) ++ ++ assert len(results) == 3 ++ assert all(doc.score == 0.5 for doc in results) ++ ++ @pytest.mark.asyncio ++ async def test_rerank_unknown_reranker_raises(self, registry, sample_documents): ++ """Test that rerank raises ValueError for unknown reranker.""" ++ with pytest.raises(ValueError, match='Unable to resolve reranker'): ++ await rerank( ++ registry, ++ { ++ 'reranker': 'non-existent-reranker', ++ 'query': 'test', ++ 'documents': sample_documents, ++ }, ++ ) ++ ++ ++class TestCustomRerankers: ++ """Tests for custom reranker implementations. 
++ ++ These tests demonstrate how to create custom rerankers as shown ++ in the genkit.dev documentation: ++ https://genkit.dev/docs/rag/#rerankers-and-two-stage-retrieval ++ """ ++ ++ @pytest.fixture ++ def registry(self): ++ """Create a fresh registry for each test.""" ++ return Registry() ++ ++ @pytest.fixture ++ def sample_documents(self): ++ """Create sample documents matching genkit.dev documentation example.""" ++ return [ ++ DocumentData(content=[DocumentPart(text='pythagorean theorem')]), ++ DocumentData(content=[DocumentPart(text='e=mc^2')]), ++ DocumentData(content=[DocumentPart(text='pi')]), ++ DocumentData(content=[DocumentPart(text='dinosaurs')]), ++ DocumentData(content=[DocumentPart(text='quantum mechanics')]), ++ DocumentData(content=[DocumentPart(text='pizza')]), ++ DocumentData(content=[DocumentPart(text='harry potter')]), ++ ] ++ ++ @pytest.mark.asyncio ++ async def test_custom_keyword_overlap_reranker(self, registry, sample_documents): ++ """Test a custom reranker that scores by keyword overlap. ++ ++ This demonstrates the pattern shown in genkit.dev docs for ++ creating custom reranking logic. 
++ """ ++ ++ async def keyword_overlap_reranker(query, documents, options): ++ """Reranker that scores documents by keyword overlap with query.""" ++ query_words = set(query.text().lower().split()) ++ scored = [] ++ ++ for doc in documents: ++ doc_words = set(doc.text().lower().split()) ++ overlap = len(query_words & doc_words) ++ score = overlap / max(len(query_words), 1) ++ scored.append((doc, score)) ++ ++ # Sort by score descending ++ scored.sort(key=lambda x: x[1], reverse=True) ++ ++ # Apply k limit if provided in options ++ k = options.get('k', len(scored)) if options else len(scored) ++ top_k = scored[:k] ++ ++ return RerankerResponse( ++ documents=[ ++ RankedDocumentData( ++ content=doc.content, ++ metadata=RankedDocumentMetadata(score=score), ++ ) ++ for doc, score in top_k ++ ] ++ ) ++ ++ define_reranker(registry, 'custom/keyword-overlap', keyword_overlap_reranker) ++ ++ # Query for 'quantum' should rank 'quantum mechanics' highest ++ results = await rerank( ++ registry, ++ { ++ 'reranker': 'custom/keyword-overlap', ++ 'query': 'quantum mechanics physics', ++ 'documents': sample_documents, ++ }, ++ ) ++ ++ assert len(results) == 7 ++ # 'quantum mechanics' should have the highest score (overlaps 2 words) ++ assert results[0].text() == 'quantum mechanics' ++ assert results[0].score > 0 ++ ++ @pytest.mark.asyncio ++ async def test_custom_reranker_with_top_k_option(self, registry, sample_documents): ++ """Test custom reranker with k option to limit results. ++ ++ Demonstrates using options to configure reranking behavior. 
++ """ ++ ++ async def random_score_reranker(query, documents, options): ++ """Reranker that assigns incrementing scores and respects k option.""" ++ k = options.get('k', 3) if options else 3 ++ ++ scored_docs = [] ++ for i, doc in enumerate(documents): ++ # Score in reverse order so we have a predictable ranking ++ score = float(len(documents) - i) ++ scored_docs.append( ++ RankedDocumentData( ++ content=doc.content, ++ metadata=RankedDocumentMetadata(score=score), ++ ) ++ ) ++ ++ # Sort by score descending and limit to k ++ scored_docs.sort(key=lambda d: d.metadata.score, reverse=True) ++ return RerankerResponse(documents=scored_docs[:k]) ++ ++ define_reranker(registry, 'custom/with-k-option', random_score_reranker) ++ ++ results = await rerank( ++ registry, ++ { ++ 'reranker': 'custom/with-k-option', ++ 'query': 'test', ++ 'documents': sample_documents, ++ 'options': {'k': 3}, ++ }, ++ ) ++ ++ # Should only return top 3 results ++ assert len(results) == 3 ++ ++ @pytest.mark.asyncio ++ async def test_custom_reranker_preserves_document_content(self, registry): ++ """Test that custom reranker preserves original document content.""" ++ ++ async def identity_reranker(query, documents, options): ++ """Reranker that returns documents with their original content.""" ++ return RerankerResponse( ++ documents=[ ++ RankedDocumentData( ++ content=doc.content, ++ metadata=RankedDocumentMetadata(score=1.0), ++ ) ++ for doc in documents ++ ] ++ ) ++ ++ define_reranker(registry, 'custom/identity', identity_reranker) ++ ++ original_texts = ['Document A', 'Document B with more text', 'Doc C'] ++ documents = [DocumentData(content=[DocumentPart(text=t)]) for t in original_texts] ++ ++ results = await rerank( ++ registry, ++ { ++ 'reranker': 'custom/identity', ++ 'query': 'test', ++ 'documents': documents, ++ }, ++ ) ++ ++ # Verify all original content is preserved ++ result_texts = [doc.text() for doc in results] ++ assert result_texts == original_texts ++ ++ @pytest.mark.asyncio ++ 
async def test_custom_reranker_two_stage_retrieval_pattern(self, registry): ++ """Test the two-stage retrieval pattern: retrieve then rerank. ++ ++ This demonstrates the typical RAG pattern where: ++ 1. Stage 1: Retrieve a broad set of candidates ++ 2. Stage 2: Rerank to find most relevant documents ++ """ ++ ++ # Simulate stage 1 retrieval results (unranked) ++ retrieved_documents = [ ++ DocumentData(content=[DocumentPart(text='Machine learning is a subset of AI')]), ++ DocumentData(content=[DocumentPart(text='Pizza is a popular food')]), ++ DocumentData(content=[DocumentPart(text='Deep learning uses neural networks')]), ++ DocumentData(content=[DocumentPart(text='Cats are domestic animals')]), ++ DocumentData(content=[DocumentPart(text='AI transforms industries')]), ++ ] ++ ++ async def relevance_reranker(query, documents, options): ++ """Reranker that scores by word presence in query.""" ++ query_lower = query.text().lower() ++ scored = [] ++ ++ for doc in documents: ++ doc_text = doc.text().lower() ++ # Simple relevance: count query words in document ++ score = sum(1 for word in query_lower.split() if word in doc_text) ++ scored.append((doc, float(score))) ++ ++ scored.sort(key=lambda x: x[1], reverse=True) ++ ++ return RerankerResponse( ++ documents=[ ++ RankedDocumentData( ++ content=doc.content, ++ metadata=RankedDocumentMetadata(score=score), ++ ) ++ for doc, score in scored ++ ] ++ ) ++ ++ define_reranker(registry, 'custom/relevance', relevance_reranker) ++ ++ # Stage 2: Rerank with query about AI ++ reranked = await rerank( ++ registry, ++ { ++ 'reranker': 'custom/relevance', ++ 'query': 'artificial intelligence AI', ++ 'documents': retrieved_documents, ++ }, ++ ) ++ ++ # AI-related documents should rank higher than unrelated ones ++ # Get scores for AI and non-AI documents ++ ai_scores = [doc.score for doc in reranked if 'AI' in doc.text() or 'learning' in doc.text()] ++ non_ai_scores = [doc.score for doc in reranked if 'Pizza' in doc.text() or 'Cats' in 
doc.text()] ++ ++ # AI-related documents should have higher scores on average ++ assert max(ai_scores) > max(non_ai_scores) +diff --git a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py +index 173c6a803..142e9c79f 100644 +--- a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py ++++ b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py +@@ -98,7 +98,7 @@ async def test_notify_endpoint(asgi_client): + @pytest.mark.asyncio + async def test_run_action_not_found(asgi_client, mock_registry): + """Test that requesting a non-existent action returns a 404 error.""" +- mock_registry.lookup_action_by_key.return_value = None ++ mock_registry.resolve_action_by_key.return_value = None + response = await asgi_client.post( + '/api/runAction', + json={'key': 'non_existent_action', 'input': {'data': 'test'}}, +@@ -116,7 +116,7 @@ async def test_run_action_standard(asgi_client, mock_registry): + mock_output.trace_id = 'test_trace_id' + mock_action.arun_raw.return_value = mock_output + +- mock_registry.lookup_action_by_key.return_value = mock_action ++ mock_registry.resolve_action_by_key.return_value = mock_action + + response = await asgi_client.post('/api/runAction', json={'key': 'test_action', 'input': {'data': 'test'}}) + +@@ -137,7 +137,7 @@ async def test_run_action_with_context(asgi_client, mock_registry): + mock_output.trace_id = 'test_trace_id' + mock_action.arun_raw.return_value = mock_output + +- mock_registry.lookup_action_by_key.return_value = mock_action ++ mock_registry.resolve_action_by_key.return_value = mock_action + + response = await asgi_client.post( + '/api/runAction', +@@ -169,7 +169,7 @@ async def test_run_action_streaming(mock_is_streaming, asgi_client, mock_registr + return mock_output + + mock_action.arun_raw.side_effect = mock_streaming +- mock_registry.lookup_action_by_key.return_value = mock_action ++ 
mock_registry.resolve_action_by_key.return_value = mock_action + + response = await asgi_client.post( + '/api/runAction?stream=true', +diff --git a/py/packages/genkit/tests/genkit/core/extract_test.py b/py/packages/genkit/tests/genkit/core/extract_test.py +index 730045f27..25dc27087 100644 +--- a/py/packages/genkit/tests/genkit/core/extract_test.py ++++ b/py/packages/genkit/tests/genkit/core/extract_test.py +@@ -86,7 +86,7 @@ test_cases_extract_items = [ + ids=[tc[0] for tc in test_cases_extract_items], + ) + def test_extract_items(name: str, steps: list[dict[str, Any]]) -> None: +- """Test extraction of incomplete json that can be fixed""" ++ """Test extraction of incomplete json that can be fixed.""" + text = '' + cursor = 0 + for step in steps: +@@ -141,7 +141,7 @@ test_cases_extract_json = [ + ids=[tc[0] for tc in test_cases_extract_json], + ) + def test_extract_json(name: str, input_data: dict[str, Any], expected_data: dict[str, Any]) -> None: +- """Test if input is unfixable raise the correct exception or return the proper error response""" ++ """Test if input is unfixable raise the correct exception or return the proper error response.""" + if expected_data.get('throws'): + with pytest.raises(Exception): + extract_json(input_data['text'], throw_on_bad_json=True) +@@ -186,6 +186,6 @@ test_cases_parse_partial_json = [ + ids=[tc[0] for tc in test_cases_parse_partial_json], + ) + def test_parse_partial_json(name: str, input_str: str, expected_data: dict[str, Any]) -> None: +- """Test if it fixes simple malformed json string""" ++ """Test if it fixes simple malformed json string.""" + result = parse_partial_json(input_str) + assert result == expected_data['expected'] +diff --git a/py/packages/genkit/tests/genkit/core/registry_test.py b/py/packages/genkit/tests/genkit/core/registry_test.py +index 6c5b2f52f..d280900c9 100644 +--- a/py/packages/genkit/tests/genkit/core/registry_test.py ++++ b/py/packages/genkit/tests/genkit/core/registry_test.py +@@ -11,7 +11,7 @@ 
functionality, ensuring proper registration and management of Genkit resources. + + import pytest + +-from genkit.ai import Genkit, GenkitRegistry, Plugin ++from genkit.ai import Genkit + from genkit.core.action import ActionMetadata + from genkit.core.action.types import ActionKind, ActionMetadataKey + from genkit.core.registry import Registry +@@ -29,17 +29,21 @@ def test_register_list_actions_resolver(): + assert 'test_plugin' in registry._list_actions_resolvers + + +-def test_register_list_actions_resolver_raises_exception(): +- """Test when ValueError is raised.""" ++def test_register_list_actions_resolver_multiple(): ++ """Test that multiple resolvers can be registered for the same plugin.""" + registry = Registry() + +- def list_actions_mock(): ++ def list_actions_mock1(): ++ return [] ++ ++ def list_actions_mock2(): + return [] + +- registry._list_actions_resolvers['test_plugin'] = list_actions_mock ++ registry.register_list_actions_resolver('test_plugin', list_actions_mock1) ++ registry.register_list_actions_resolver('test_plugin', list_actions_mock2) + +- with pytest.raises(ValueError, match=r'Plugin .* already registered'): +- registry.register_list_actions_resolver('test_plugin', list_actions_mock) ++ assert 'test_plugin' in registry._list_actions_resolvers ++ assert len(registry._list_actions_resolvers['test_plugin']) == 2 + + + def test_register_action_with_name_and_kind() -> None: +@@ -159,47 +163,14 @@ def test_list_actions(allowed_kind, expected) -> None: + ] + + registry = Registry() +- registry._list_actions_resolvers['test_plugin'] = list_actions_mock ++ registry._list_actions_resolvers['test_plugin'] = [list_actions_mock] + registry._entries[ActionKind.CUSTOM] = {} + registry._entries[ActionKind.TOOL] = {} + +- got = registry.list_actions({}, allowed_kind) ++ got = registry.list_actions_sync({}, allowed_kind) + assert got == expected + + +-def test_resolve_action_from_plugin(): +- """Resolve action from plugin test.""" +- resolver_calls = [] +- 
+- class MyPlugin(Plugin): +- name = 'myplugin' +- +- def resolve_action(self, ai: GenkitRegistry, kind: ActionKind, name: str): +- nonlocal resolver_calls +- resolver_calls.append([kind, name]) +- +- def model_fn(): +- pass +- +- ai.define_model(name=name, fn=model_fn) +- +- def initialize(self, ai: GenkitRegistry) -> None: +- pass +- +- ai = Genkit(plugins=[MyPlugin()]) +- +- action = ai.registry.lookup_action(ActionKind.MODEL, 'myplugin/foo') +- +- assert action is not None +- assert len(resolver_calls) == 1 +- +- assert resolver_calls == [[ActionKind.MODEL, 'myplugin/foo']] +- +- # should be idempotent +- ai.registry.lookup_action(ActionKind.MODEL, 'myplugin/foo') +- assert len(resolver_calls) == 1 +- +- + def test_register_value(): + """Register a value and lookup test.""" + registry = Registry() +diff --git a/py/packages/genkit/tests/genkit/lang/deprecations_test.py b/py/packages/genkit/tests/genkit/lang/deprecations_test.py +index 5344d98ed..6b88bc398 100644 +--- a/py/packages/genkit/tests/genkit/lang/deprecations_test.py ++++ b/py/packages/genkit/tests/genkit/lang/deprecations_test.py +@@ -20,8 +20,6 @@ import sys + import unittest + import warnings + +-import pytest +- + if sys.version_info < (3, 11): + from strenum import StrEnum + else: +diff --git a/py/packages/genkit/tests/genkit/veneer/resource_test.py b/py/packages/genkit/tests/genkit/veneer/resource_test.py +new file mode 100644 +index 000000000..e3324e568 +--- /dev/null ++++ b/py/packages/genkit/tests/genkit/veneer/resource_test.py +@@ -0,0 +1,49 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""Tests for the Genkit Resource API via the Genkit class (Veneer). ++This test file verifies that `ai.define_resource` works correctly, mirroring the ++JS SDK's `ai.defineResource`. ++""" ++ ++import asyncio ++ ++import pytest ++ ++from genkit.ai import Genkit ++from genkit.core.typing import Part, TextPart ++ ++ ++@pytest.mark.asyncio ++async def test_define_resource_veneer(): ++ """Verifies ai.define_resource registers a resource correctly.""" ++ ai = Genkit(plugins=[]) ++ ++ async def my_resource_fn(input, ctx): ++ return {'content': [Part(root=TextPart(text=f'Content for {input.uri}'))]} ++ ++ act = ai.define_resource({'uri': 'http://example.com/foo'}, my_resource_fn) ++ ++ assert act.name == 'http://example.com/foo' ++ assert act.metadata['resource']['uri'] == 'http://example.com/foo' ++ ++ # Verify lookup via global registry (contained in ai.registry) ++ looked_up = ai.registry.lookup_action('resource', 'http://example.com/foo') ++ assert looked_up == act ++ ++ # Verify execution ++ output = await act.arun({'uri': 'http://example.com/foo'}) ++ assert 'Content for http://example.com/foo' in output.response['content'][0]['text'] +diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py +index bc987c829..cc7854b1a 100644 +--- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py ++++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py +@@ -1,6 +1,6 @@ + #!/usr/bin/env python3 + # +-# Copyright 2025 Google LLC ++# Copyright 
2026 Google LLC + # SPDX-License-Identifier: Apache-2.0 + + """Tests for the action module.""" +@@ -1158,8 +1158,8 @@ async def test_generate_simulates_doc_grounding( + assert (await response).request.messages[0] == want_msg + + +-class TestFormat(FormatDef): +- """Test format for testing the format.""" ++class MockBananaFormat(FormatDef): ++ """Mock format for testing the format.""" + + def __init__(self): + """Initialize the format.""" +@@ -1200,7 +1200,7 @@ async def test_define_format(setup_test: SetupFixture) -> None: + """Test that the define format function works.""" + ai, _, pm, *_ = setup_test + +- ai.define_format(TestFormat()) ++ ai.define_format(MockBananaFormat()) + + class TestSchema(BaseModel): + foo: int = Field(None, description='foo field') +diff --git a/py/plugins/anthropic/pyproject.toml b/py/plugins/anthropic/pyproject.toml +index b9875c4d1..0ab15ba50 100644 +--- a/py/plugins/anthropic/pyproject.toml ++++ b/py/plugins/anthropic/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -35,7 +34,7 @@ classifiers = [ + ] + dependencies = ["genkit", "anthropic>=0.40.0"] + description = "Genkit Anthropic Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-anthropic" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/anthropic/src/genkit/plugins/anthropic/plugin.py b/py/plugins/anthropic/src/genkit/plugins/anthropic/plugin.py +index 7461c7479..62043b1fb 100644 +--- a/py/plugins/anthropic/src/genkit/plugins/anthropic/plugin.py ++++ b/py/plugins/anthropic/src/genkit/plugins/anthropic/plugin.py +@@ -16,8 +16,10 @@ + + """Anthropic plugin for Genkit.""" + ++import os ++ + from anthropic import 
AsyncAnthropic +-from genkit.ai import PluginV2 ++from genkit.ai import Plugin + from genkit.blocks.model import model + from genkit.core.action import Action, ActionMetadata + from genkit.core.registry import ActionKind +@@ -40,39 +42,50 @@ def anthropic_name(name: str) -> str: + return f'{ANTHROPIC_PLUGIN_NAME}/{name}' + + +-class Anthropic(PluginV2): +- """Anthropic plugin for Genkit (v2). ++class Anthropic(Plugin): ++ """Anthropic plugin for Genkit. + + This plugin adds Anthropic models to Genkit for generative AI applications. + Can be used standalone (without framework) or with Genkit framework. + + Example (standalone): +- >>> plugin = Anthropic(api_key="...") +- >>> claude = await plugin.model("claude-3-5-sonnet") +- >>> response = await claude.arun({"messages": [...]}) ++ >>> plugin = Anthropic(api_key='...') ++ >>> claude = await plugin.model('claude-3-5-sonnet') ++ >>> response = await claude.arun({'messages': [...]}) + + Example (with framework): +- >>> ai = Genkit(plugins=[Anthropic(api_key="...")]) +- >>> response = await ai.generate("anthropic/claude-3-5-sonnet", prompt="Hi") ++ >>> ai = Genkit(plugins=[Anthropic(api_key='...')]) ++ >>> response = await ai.generate('anthropic/claude-3-5-sonnet', prompt='Hi') + """ + + name = ANTHROPIC_PLUGIN_NAME + + def __init__( + self, ++ api_key: str | None = None, ++ models: list[str] | None = None, + **anthropic_params: str, + ) -> None: + """Initializes Anthropic plugin with given configuration. + + Args: ++ api_key: Optional Anthropic API key. If not provided, uses `ANTHROPIC_API_KEY` ++ from the environment (or lets the Anthropic client handle defaults). ++ models: Optional list of supported Anthropic models to expose via this plugin. + **anthropic_params: Additional parameters passed to the AsyncAnthropic client. + This may include api_key, base_url, timeout, and other configuration + settings required by Anthropic's API. 
+ """ ++ if api_key is None: ++ api_key = os.getenv('ANTHROPIC_API_KEY') ++ ++ self.models = models or list(SUPPORTED_ANTHROPIC_MODELS.keys()) + self._anthropic_params = anthropic_params +- self._anthropic_client = AsyncAnthropic(**anthropic_params) ++ self._anthropic_client = ( ++ AsyncAnthropic(api_key=api_key, **anthropic_params) if api_key else AsyncAnthropic(**anthropic_params) ++ ) + +- def init(self) -> list[Action]: ++ async def init(self) -> list[Action]: + """Return eagerly-initialized model actions. + + Called once during Genkit initialization. Loads ALL supported +@@ -81,12 +94,9 @@ class Anthropic(PluginV2): + Returns: + List of Action objects for all supported models. + """ +- return [ +- self._create_model_action(model_name) +- for model_name in SUPPORTED_ANTHROPIC_MODELS.keys() +- ] ++ return [self._create_model_action(model_name) for model_name in self.models] + +- def resolve(self, action_type: ActionKind, name: str) -> Action | None: ++ async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + """Resolve a specific model action on-demand. + + Called when framework needs an action not from init(). +@@ -101,12 +111,12 @@ class Anthropic(PluginV2): + """ + if action_type == ActionKind.MODEL: + # Check if we support this model +- if name in SUPPORTED_ANTHROPIC_MODELS: ++ if name in self.models: + return self._create_model_action(name) + + return None + +- def list(self) -> list[ActionMetadata]: ++ async def list_actions(self) -> list[ActionMetadata]: + """Return metadata for all supported Anthropic models. + + Used for discovery and developer tools. +@@ -120,10 +130,9 @@ class Anthropic(PluginV2): + kind=ActionKind.MODEL, + info=get_model_info(model_name).model_dump(), + ) +- for model_name in SUPPORTED_ANTHROPIC_MODELS.keys() ++ for model_name in self.models + ] + +- + def _create_model_action(self, model_name: str) -> Action: + """Create an Action for an Anthropic model (doesn't register). 
+ +diff --git a/py/plugins/anthropic/tests/test_plugin.py b/py/plugins/anthropic/tests/test_plugin.py +index f55247f4e..3c69cd5f1 100644 +--- a/py/plugins/anthropic/tests/test_plugin.py ++++ b/py/plugins/anthropic/tests/test_plugin.py +@@ -16,8 +16,11 @@ + + """Tests for Anthropic plugin.""" + +-from unittest.mock import ANY, MagicMock, patch ++from unittest.mock import patch + ++import pytest ++ ++from genkit.ai import Genkit + from genkit.core.registry import ActionKind + from genkit.plugins.anthropic import Anthropic, anthropic_name + from genkit.plugins.anthropic.model_info import ( +@@ -68,30 +71,27 @@ def test_custom_models(): + assert plugin.models == ['claude-sonnet-4'] + + +-def test_plugin_initialize(): +- """Test plugin registry initialization.""" +- registry = MagicMock() ++@pytest.mark.asyncio ++async def test_plugin_initialize(): ++ """Test plugin registration with the Genkit framework.""" + plugin = Anthropic(api_key='test-key', models=['claude-sonnet-4']) + +- plugin.initialize(registry) ++ ai = Genkit(plugins=[plugin]) + +- assert registry.define_model.call_count == 1 +- registry.define_model.assert_called_once_with( +- name='anthropic/claude-sonnet-4', +- fn=ANY, +- config_schema=ANY, +- metadata=ANY, +- ) ++ action = await ai.registry.resolve_action(ActionKind.MODEL, 'anthropic/claude-sonnet-4') ++ assert action is not None ++ assert action.name == 'anthropic/claude-sonnet-4' + + +-def test_resolve_action_model(): +- """Test resolve_action for model.""" +- registry = MagicMock() +- plugin = Anthropic(api_key='test-key') ++@pytest.mark.asyncio ++async def test_resolve_action_model(): ++ """Test resolve() can lazily provide a model action.""" ++ plugin = Anthropic(api_key='test-key', models=['claude-sonnet-4']) + +- plugin.resolve_action(registry, ActionKind.MODEL, 'claude-sonnet-4') ++ action = await plugin.resolve(ActionKind.MODEL, 'claude-sonnet-4') + +- registry.define_model.assert_called_once() ++ assert action is not None ++ assert 
action.name == 'claude-sonnet-4' + + + def test_supported_models(): +diff --git a/py/plugins/compat-oai/pyproject.toml b/py/plugins/compat-oai/pyproject.toml +index 44a4fe828..3dd804f71 100644 +--- a/py/plugins/compat-oai/pyproject.toml ++++ b/py/plugins/compat-oai/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -35,7 +34,7 @@ classifiers = [ + ] + dependencies = ["genkit", "openai", "strenum>=0.4.15; python_version < '3.11'"] + description = "Genkit OpenAI API Compatible" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-compat-oai" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py +index 70b3e4e2f..f02cfb060 100644 +--- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py ++++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py +@@ -21,7 +21,7 @@ from typing import Any + + from openai import OpenAI + +-from genkit.ai import ActionRunContext, GenkitRegistry ++from genkit.ai import ActionRunContext + from genkit.plugins.compat_oai.models.model import OpenAIModel + from genkit.plugins.compat_oai.models.model_info import ( + SUPPORTED_OPENAI_COMPAT_MODELS, +@@ -51,6 +51,7 @@ class OpenAIModelHandler: + @staticmethod + def _get_supported_models(source: PluginSource) -> dict[str, Any]: + """Returns the supported models based on the plugin source. ++ + Args: + source: Helps distinguish if model handler is called from model-garden plugin. + Default source is openai. 
+@@ -59,12 +60,11 @@ class OpenAIModelHandler: + Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden. + + """ +- + return SUPPORTED_OPENAI_COMPAT_MODELS if source == PluginSource.MODEL_GARDEN else SUPPORTED_OPENAI_MODELS + + @classmethod + def get_model_handler( +- cls, model: str, client: OpenAI, registry: GenkitRegistry, source: PluginSource = PluginSource.OPENAI ++ cls, model: str, client: OpenAI, source: PluginSource = PluginSource.OPENAI + ) -> Callable[[GenerateRequest, ActionRunContext], GenerateResponse]: + """Factory method to initialize the model handler for the specified OpenAI model. + +@@ -89,10 +89,13 @@ class OpenAIModelHandler: + """ + supported_models = cls._get_supported_models(source) + +- if model not in supported_models: ++ # For the OpenAI compat plugin, we allow arbitrary model names (the OpenAI API ++ # can serve models beyond our static known list). For Model Garden, keep the ++ # strict validation. ++ if model not in supported_models and source == PluginSource.MODEL_GARDEN: + raise ValueError(f"Model '{model}' is not supported.") + +- openai_model = OpenAIModel(model, client, registry) ++ openai_model = OpenAIModel(model, client) + return cls(openai_model, source).generate + + def _validate_version(self, version: str) -> None: +@@ -105,7 +108,10 @@ class OpenAIModelHandler: + ValueError: If the specified model version is not supported. + """ + supported_models = self._get_supported_models(self._source) +- model_info = supported_models[self._model.name] ++ model_info = supported_models.get(self._model.name) ++ if model_info is None: ++ # Unknown model; skip version validation. 
++ return + if version not in model_info.versions: + raise ValueError(f"Model version '{version}' is not supported.") + +diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py +index bf1f7315a..08e80f604 100644 +--- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py ++++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py +@@ -19,10 +19,9 @@ + from collections.abc import Callable + from typing import Any + +-from openai import OpenAI, pydantic_function_tool ++from openai import OpenAI + from openai.lib._pydantic import _ensure_strict_json_schema + +-from genkit.ai import ActionKind, GenkitRegistry + from genkit.core.action._action import ActionRunContext + from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_MODELS + from genkit.plugins.compat_oai.models.utils import DictMessageAdapter, MessageAdapter, MessageConverter +@@ -41,17 +40,15 @@ from genkit.types import ( + class OpenAIModel: + """Handles OpenAI API interactions for the Genkit plugin.""" + +- def __init__(self, model: str, client: OpenAI, registry: GenkitRegistry): ++ def __init__(self, model: str, client: OpenAI): + """Initializes the OpenAIModel instance with the specified model and OpenAI client parameters. + + Args: + model: The OpenAI model to use for generating responses. + client: OpenAI client instance. +- registry: The registry where OpenAI models will be registered. + """ + self._model = model + self._openai_client = client +- self._registry = registry + + @property + def name(self) -> str: +@@ -85,15 +82,19 @@ class OpenAIModel: + Returns: + A list of dictionaries representing the formatted tools. + """ +- result = [] ++ # NOTE: ToolDefinition objects already contain JSON Schema for inputs/outputs. ++ # Do NOT reach back into the registry to reconstruct schemas. 
++ result: list[dict[str, Any]] = [] + for tool_definition in tools: +- action = self._registry.registry.lookup_action(ActionKind.TOOL, tool_definition.name) +- function_call = pydantic_function_tool( +- model=action.input_type._type, +- name=tool_definition.name, +- description=tool_definition.description, +- ) +- result.append(function_call) ++ parameters = tool_definition.input_schema or {'type': 'object', 'properties': {}} ++ result.append({ ++ 'type': 'function', ++ 'function': { ++ 'name': tool_definition.name, ++ 'description': tool_definition.description, ++ 'parameters': parameters, ++ }, ++ }) + return result + + def _get_response_format(self, output: OutputConfig) -> dict | None: +@@ -140,6 +141,11 @@ class OpenAIModel: + } + if request.tools: + openai_config['tools'] = self._get_tools_definition(request.tools) ++ if any(msg.role == Role.TOOL for msg in request.messages): ++ # After a tool response, stop forcing additional tool calls. ++ openai_config['tool_choice'] = 'none' ++ elif request.tool_choice: ++ openai_config['tool_choice'] = request.tool_choice + if request.output: + openai_config['response_format'] = self._get_response_format(request.output) + if request.config: +diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py +index 19cd4ebcf..6fe1d129e 100644 +--- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py ++++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py +@@ -15,25 +15,23 @@ + # SPDX-License-Identifier: Apache-2.0 + + +-"""OpenAI OpenAI API Compatible Plugin for Genkit.""" ++"""OpenAI OpenAI API Compatible Plugin for Genkit (v2).""" + +-from functools import cached_property +-from typing import Any, Callable ++from collections.abc import Callable ++from typing import Any + + from openai import OpenAI as OpenAIClient +-from openai.types import Embedding, Model ++from openai.types import Model + 
+-from genkit.ai._plugin import Plugin +-from genkit.ai._registry import GenkitRegistry ++from genkit.ai import Plugin + from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata +-from genkit.blocks.model import model_action_metadata ++from genkit.blocks.model import model, model_action_metadata + from genkit.core.action import ActionMetadata +-from genkit.core.action.types import ActionKind ++from genkit.core.registry import ActionKind + from genkit.core.typing import GenerationCommonConfig + from genkit.plugins.compat_oai.models import ( + SUPPORTED_OPENAI_COMPAT_MODELS, + SUPPORTED_OPENAI_MODELS, +- OpenAIModel, + OpenAIModelHandler, + ) + from genkit.plugins.compat_oai.models.model_info import get_default_openai_model_info +@@ -78,26 +76,9 @@ class OpenAI(Plugin): + self._openai_params = openai_params + self._openai_client = OpenAIClient(**openai_params) + +- def initialize(self, ai: GenkitRegistry) -> None: +- """Registers supported OpenAI models in the given registry. +- +- Args: +- ai: The registry where OpenAI models will be registered. +- """ +- for model_name, model_info in SUPPORTED_OPENAI_MODELS.items(): +- handler = OpenAIModelHandler.get_model_handler(model=model_name, client=self._openai_client, registry=ai) +- +- ai.define_model( +- name=f'openai/{model_name}', +- fn=handler, +- config_schema=OpenAIConfig, +- metadata={ +- 'model': { +- 'label': model_info.label, +- 'supports': {'multiturn': model_info.supports.multiturn} if model_info.supports else {}, +- }, +- }, +- ) ++ async def init(self): ++ """Return eagerly-initialized model actions.""" ++ return [self._create_model_action(model_name) for model_name in SUPPORTED_OPENAI_MODELS.keys()] + + def get_model_info(self, name: str) -> dict[str, str] | None: + """Retrieves metadata and supported features for the specified model. +@@ -111,66 +92,40 @@ class OpenAI(Plugin): + is provided). 
The 'supports' key contains a dictionary representing + the model's capabilities (e.g., tools, streaming). + """ +- + if model_supported := SUPPORTED_OPENAI_MODELS.get(name): + return { + 'label': model_supported.label, + 'supports': model_supported.supports.model_dump(exclude_none=True), + } + +- model_info = SUPPORTED_OPENAI_COMPAT_MODELS.get(name, get_default_openai_model_info(self)) ++ model_info = SUPPORTED_OPENAI_COMPAT_MODELS.get(name, get_default_openai_model_info(name)) + return { + 'label': model_info.label, + 'supports': model_info.supports.model_dump(exclude_none=True), + } + +- def resolve_action( # noqa: B027 +- self, +- ai: GenkitRegistry, +- kind: ActionKind, +- name: str, +- ) -> None: +- if kind is not ActionKind.MODEL: ++ async def resolve(self, action_type: ActionKind, name: str): ++ if action_type != ActionKind.MODEL: + return None + +- self._define_openai_model(ai, name) +- return None +- +- def to_openai_compatible_model(self, name: str, ai: GenkitRegistry) -> Callable: +- """Converts a OpenAi model into an OpenAI-compatible Genkit model function. +- +- Returns: +- A callable function (specifically, the `generate` method of an +- `OpenAIModel` instance) that can be used by Genkit. +- """ +- +- openai_model = OpenAIModelHandler(OpenAIModel(name, self._openai_client, ai)) +- return openai_model.generate ++ clean_name = name.replace('openai/', '') if name.startswith('openai/') else name ++ return self._create_model_action(clean_name) + +- def _define_openai_model(self, ai: GenkitRegistry, name: str) -> None: +- """Defines and registers an OpenAI model with Genkit. ++ def to_openai_compatible_model(self, name: str) -> Callable: ++ """Return a Genkit model handler for a specific OpenAI model name.""" ++ return OpenAIModelHandler.get_model_handler(model=name, client=self._openai_client) + +- Cleans the model name, instantiates an OpenAI, and registers it +- with the provided Genkit AI registry, including metadata about its capabilities. 
+- +- Args: +- ai: The Genkit AI registry instance. +- name: The name of the model to be registered. +- """ +- +- handler = self.to_openai_compatible_model(name, ai) ++ def _create_model_action(self, name: str): ++ handler = self.to_openai_compatible_model(name) + model_info = self.get_model_info(name) +- ai.define_model( +- name=open_ai_name(name), ++ return model( ++ name=name, + fn=handler, + config_schema=OpenAIConfig, +- metadata={ +- 'model': model_info, +- }, ++ metadata={'model': model_info}, + ) + +- @cached_property +- def list_actions(self) -> list[ActionMetadata]: ++ async def list_actions(self) -> list[ActionMetadata]: + """Generate a list of available actions or models. + + Returns: +@@ -180,7 +135,6 @@ class OpenAI(Plugin): + - info (dict): The metadata dictionary describing the model configuration and properties. + - config_schema (type): The schema class used for validating the model's configuration. + """ +- + actions = [] + models_ = self._openai_client.models.list() + models: list[Model] = models_.data +diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py +index d8ac810de..0b257bf82 100644 +--- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py ++++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py +@@ -24,7 +24,7 @@ if sys.version_info < (3, 11): # noqa + else: # noqa + from enum import StrEnum # noqa + +-from pydantic import BaseModel, ConfigDict ++from pydantic import BaseModel, ConfigDict, Field + + + class OpenAIConfig(BaseModel): +@@ -38,10 +38,14 @@ class OpenAIConfig(BaseModel): + stop: str | list[str] | None = None + max_tokens: int | None = None + stream: bool | None = None ++ frequency_penalty: float | None = Field(default=None, ge=-2, le=2) ++ presence_penalty: float | None = Field(default=None, ge=-2, le=2) ++ logprobs: bool | None = None ++ top_logprobs: int | None = Field(default=None, ge=0, le=20) + + + class 
SupportedOutputFormat(StrEnum): +- """Model Output Formats""" ++ """Model Output Formats.""" + + JSON_MODE = 'json_mode' + STRUCTURED_OUTPUTS = 'structured_outputs' +diff --git a/py/plugins/compat-oai/tests/test_handler.py b/py/plugins/compat-oai/tests/test_handler.py +index f0321ef3d..a6abcd8a4 100644 +--- a/py/plugins/compat-oai/tests/test_handler.py ++++ b/py/plugins/compat-oai/tests/test_handler.py +@@ -19,28 +19,30 @@ from unittest.mock import MagicMock + + import pytest + +-from genkit.ai import ActionRunContext + from genkit.plugins.compat_oai.models import OpenAIModelHandler +-from genkit.plugins.compat_oai.models.model import OpenAIModel + from genkit.plugins.compat_oai.models.model_info import ( + GPT_3_5_TURBO, + GPT_4, + SUPPORTED_OPENAI_MODELS, ++ PluginSource, + ) +-from genkit.types import GenerateRequest, GenerateResponse, Message, Role, TextPart + + + def test_get_model_handler() -> None: + """Test get_model_handler method returns a callable.""" + model_name = GPT_4 +- handler = OpenAIModelHandler.get_model_handler(model=model_name, client=MagicMock(), registry=MagicMock()) ++ handler = OpenAIModelHandler.get_model_handler(model=model_name, client=MagicMock()) + assert callable(handler) + + + def test_get_model_handler_invalid() -> None: + """Test get_model_handler raises ValueError for unsupported models.""" + with pytest.raises(ValueError, match="Model 'unsupported-model' is not supported."): +- OpenAIModelHandler.get_model_handler(model='unsupported-model', client=MagicMock(), registry=MagicMock()) ++ OpenAIModelHandler.get_model_handler( ++ model='unsupported-model', ++ client=MagicMock(), ++ source=PluginSource.MODEL_GARDEN, ++ ) + + + def test_validate_version() -> None: +diff --git a/py/plugins/compat-oai/tests/test_model.py b/py/plugins/compat-oai/tests/test_model.py +index 0d16d97b1..ec95bb266 100644 +--- a/py/plugins/compat-oai/tests/test_model.py ++++ b/py/plugins/compat-oai/tests/test_model.py +@@ -36,7 +36,7 @@ def 
test_get_messages(sample_request): + + Ensures the method correctly converts GenerateRequest messages into OpenAI-compatible ChatMessage format. + """ +- model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) ++ model = OpenAIModel(model=GPT_4, client=MagicMock()) + messages = model._get_messages(sample_request.messages) + + assert len(messages) == 2 +@@ -51,7 +51,7 @@ def test_get_openai_config(sample_request): + + Ensures the method correctly constructs the OpenAI API configuration dictionary. + """ +- model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) ++ model = OpenAIModel(model=GPT_4, client=MagicMock()) + openai_config = model._get_openai_request_config(sample_request) + + assert isinstance(openai_config, dict) +@@ -72,7 +72,7 @@ def test__generate(sample_request): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + +- model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) ++ model = OpenAIModel(model=GPT_4, client=mock_client) + response = model._generate(sample_request) + + mock_client.chat.completions.create.assert_called_once() +@@ -112,7 +112,7 @@ def test__generate_stream(sample_request): + + mock_client.chat.completions.create.return_value = MockStream(['Hello', ', world!']) + +- model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) ++ model = OpenAIModel(model=GPT_4, client=mock_client) + collected_chunks = [] + + def callback(chunk: GenerateResponseChunk): +@@ -137,7 +137,7 @@ def test_generate(stream, sample_request): + + mock_response = GenerateResponse(message=Message(role=Role.MODEL, content=[TextPart(text='mocked')])) + +- model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) ++ model = OpenAIModel(model=GPT_4, client=MagicMock()) + model._generate_stream = MagicMock(return_value=mock_response) + model._generate = MagicMock(return_value=mock_response) + model.normalize_config = 
MagicMock(return_value={}) +diff --git a/py/plugins/compat-oai/tests/test_plugin.py b/py/plugins/compat-oai/tests/test_plugin.py +index c6ba1cf79..d88ad575b 100644 +--- a/py/plugins/compat-oai/tests/test_plugin.py ++++ b/py/plugins/compat-oai/tests/test_plugin.py +@@ -14,69 +14,49 @@ + # + # SPDX-License-Identifier: Apache-2.0 + +-from unittest.mock import ANY, MagicMock, patch ++from unittest.mock import MagicMock, patch + + import pytest + from openai.types import Model + +-from genkit.ai._aio import Genkit + from genkit.core.action import ActionMetadata + from genkit.core.action.types import ActionKind +-from genkit.plugins.compat_oai import OpenAIConfig + from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_MODELS + from genkit.plugins.compat_oai.openai_plugin import OpenAI, openai_model + + +-def test_openai_plugin_initialize() -> None: +- """Test OpenAI plugin registry initialization.""" +- registry = MagicMock(spec=Genkit) ++@pytest.mark.asyncio ++async def test_openai_plugin_initialize() -> None: ++ """Test OpenAI plugin init() returns model actions.""" + plugin = OpenAI(api_key='test-key') + + with patch('genkit.plugins.compat_oai.models.OpenAIModelHandler.get_model_handler') as mock_get_handler: + mock_handler = MagicMock() + mock_get_handler.return_value = mock_handler + +- plugin.initialize(registry) ++ actions = await plugin.init() + + assert mock_get_handler.call_count == len(SUPPORTED_OPENAI_MODELS) +- assert registry.define_model.call_count == len(SUPPORTED_OPENAI_MODELS) ++ assert len(actions) == len(SUPPORTED_OPENAI_MODELS) + + + @pytest.mark.parametrize( + 'kind, name', + [(ActionKind.MODEL, 'gpt-3.5-turbo')], + ) +-def test_openai_plugin_resolve_action(kind, name): +- """Unit Tests for resolve action method.""" ++@pytest.mark.asyncio ++async def test_openai_plugin_resolve_action(kind, name): ++ """Unit Tests for resolve method.""" + plugin = OpenAI(api_key='test-key') +- registry = MagicMock(spec=Genkit) +- 
plugin.resolve_action(registry, kind, name) +- +- model_info = SUPPORTED_OPENAI_MODELS[name] +- +- registry.define_model.assert_called_once_with( +- name=f'openai/{name}', +- fn=ANY, +- config_schema=OpenAIConfig, +- metadata={ +- 'model': { +- 'label': model_info.label, +- 'supports': { +- 'media': False, +- 'multiturn': True, +- 'output': [ +- 'json_mode', +- 'text', +- ], +- 'system_role': True, +- 'tools': True, +- }, +- }, +- }, +- ) +- +- +-def test_openai_plugin_list_actions() -> None: ++ action = await plugin.resolve(kind, name) ++ assert action is not None ++ assert action.kind == ActionKind.MODEL ++ assert action.name == name ++ assert action.metadata is not None ++ ++ ++@pytest.mark.asyncio ++async def test_openai_plugin_list_actions() -> None: + entries = [ + Model(id='gpt-4-0613', created=1686588896, object='model', owned_by='openai'), + Model(id='gpt-4', created=1687882411, object='model', owned_by='openai'), +@@ -94,9 +74,7 @@ def test_openai_plugin_list_actions() -> None: + + plugin._openai_client = mock_client + +- actions: list[ActionMetadata] = plugin.list_actions +- mock_client.models.list.assert_called_once() +- _ = plugin.list_actions ++ actions: list[ActionMetadata] = await plugin.list_actions() + mock_client.models.list.assert_called_once() + + assert len(actions) == len(entries) +@@ -108,14 +86,14 @@ def test_openai_plugin_list_actions() -> None: + 'kind, name', + [(ActionKind.MODEL, 'model_doesnt_exist')], + ) +-def test_openai_plugin_resolve_action_not_found(kind, name): +- """Unit Tests for resolve action method.""" +- ++@pytest.mark.asyncio ++async def test_openai_plugin_resolve_action_not_found(kind, name): ++ """Unknown models are still resolvable (compat plugin).""" + plugin = OpenAI(api_key='test-key') +- registry = MagicMock(spec=Genkit) +- plugin.resolve_action(registry, kind, name) +- +- registry.define_model.assert_called_once() ++ action = await plugin.resolve(kind, name) ++ assert action is not None ++ assert action.kind == 
ActionKind.MODEL ++ assert action.name == name + + + def test_openai_model_function() -> None: +diff --git a/py/plugins/compat-oai/tests/test_tool_calling.py b/py/plugins/compat-oai/tests/test_tool_calling.py +index 6aaf581df..288f4b3c4 100644 +--- a/py/plugins/compat-oai/tests/test_tool_calling.py ++++ b/py/plugins/compat-oai/tests/test_tool_calling.py +@@ -56,7 +56,7 @@ def test_generate_with_tool_calls_executes_tools(sample_request: GenerateRequest + second_response, + ] + +- model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) ++ model = OpenAIModel(model=GPT_4, client=mock_client) + + response = model._generate(sample_request) + +@@ -79,9 +79,7 @@ def test_generate_with_tool_calls_executes_tools(sample_request: GenerateRequest + + + def test_generate_stream_with_tool_calls(sample_request): +- """ +- Test generate_stream processes tool calls streamed in chunks correctly. +- """ ++ """Test generate_stream processes tool calls streamed in chunks correctly.""" + mock_client = MagicMock() + + class MockToolCall: +@@ -127,7 +125,7 @@ def test_generate_stream_with_tool_calls(sample_request): + + mock_client.chat.completions.create.return_value = MockStream() + +- model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) ++ model = OpenAIModel(model=GPT_4, client=mock_client) + collected_chunks = [] + + def callback(chunk: GenerateResponseChunk): +diff --git a/py/plugins/deepseek/pyproject.toml b/py/plugins/deepseek/pyproject.toml +new file mode 100644 +index 000000000..d4cbe0ef8 +--- /dev/null ++++ b/py/plugins/deepseek/pyproject.toml +@@ -0,0 +1,47 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++[project] ++authors = [{ name = "Google" }] ++classifiers = [ ++ "Development Status :: 3 - Alpha", ++ "Environment :: Console", ++ "Environment :: Web Environment", ++ "Intended Audience :: Developers", ++ "Operating System :: OS Independent", ++ "Programming Language :: Python", ++ "Programming Language :: Python :: 3 :: Only", ++ "Programming Language :: Python :: 3.10", ++ "Programming Language :: Python :: 3.11", ++ "Programming Language :: Python :: 3.12", ++ "Programming Language :: Python :: 3.13", ++ "Programming Language :: Python :: 3.14", ++ "Topic :: Scientific/Engineering :: Artificial Intelligence", ++ "Topic :: Software Development :: Libraries", ++] ++dependencies = ["genkit", "genkit-plugin-compat-oai", "openai>=1.0.0"] ++description = "Genkit DeepSeek Plugin" ++license = "Apache-2.0" ++name = "genkit-plugin-deepseek" ++requires-python = ">=3.10" ++version = "0.1.0" ++ ++[build-system] ++build-backend = "hatchling.build" ++requires = ["hatchling"] ++ ++[tool.hatch.build.targets.wheel] ++packages = ["src/genkit", "src/genkit/plugins"] +diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/__init__.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/__init__.py +new file mode 100644 +index 000000000..24021a619 +--- /dev/null ++++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/__init__.py +@@ -0,0 +1,22 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""DeepSeek plugin for Genkit.""" ++ ++from .models import deepseek_name ++from .plugin import DeepSeek ++ ++__all__ = ['DeepSeek', 'deepseek_name'] +diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/client.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/client.py +new file mode 100644 +index 000000000..56ed84f24 +--- /dev/null ++++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/client.py +@@ -0,0 +1,40 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""DeepSeek API client.""" ++ ++from openai import OpenAI as _OpenAI ++ ++ ++class DeepSeekClient: ++ """DeepSeek API client initialization.""" ++ ++ def __new__(cls, **deepseek_params) -> _OpenAI: ++ """Initialize the DeepSeek client. ++ ++ Args: ++ **deepseek_params: Client configuration parameters including: ++ - api_key: DeepSeek API key. ++ - base_url: API base URL (defaults to https://api.deepseek.com). 
++ - Additional OpenAI client parameters. ++ ++ Returns: ++ Configured OpenAI client instance. ++ """ ++ api_key = deepseek_params.pop('api_key') ++ base_url = deepseek_params.pop('base_url', 'https://api.deepseek.com') ++ ++ return _OpenAI(api_key=api_key, base_url=base_url, **deepseek_params) +diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/model_info.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/model_info.py +new file mode 100644 +index 000000000..9601f58c6 +--- /dev/null ++++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/model_info.py +@@ -0,0 +1,58 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""DeepSeek model information and metadata.""" ++ ++from genkit.types import ModelInfo, Supports ++ ++__all__ = ['SUPPORTED_DEEPSEEK_MODELS', 'get_default_model_info'] ++ ++# Model capabilities matching JS implementation ++_DEEPSEEK_SUPPORTS = Supports( ++ multiturn=True, ++ tools=True, ++ media=False, ++ system_role=True, ++ output=['text', 'json'], ++) ++ ++SUPPORTED_DEEPSEEK_MODELS: dict[str, ModelInfo] = { ++ 'deepseek-reasoner': ModelInfo( ++ label='DeepSeek - Reasoner', ++ versions=['deepseek-reasoner'], ++ supports=_DEEPSEEK_SUPPORTS, ++ ), ++ 'deepseek-chat': ModelInfo( ++ label='DeepSeek - Chat', ++ versions=['deepseek-chat'], ++ supports=_DEEPSEEK_SUPPORTS, ++ ), ++} ++ ++ ++def get_default_model_info(name: str) -> ModelInfo: ++ """Get default model information for unknown DeepSeek models. ++ ++ Args: ++ name: Model name. ++ ++ Returns: ++ Default ModelInfo with standard DeepSeek capabilities. ++ """ ++ return ModelInfo( ++ label=f'DeepSeek - {name}', ++ supports=_DEEPSEEK_SUPPORTS, ++ ) +diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/models.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/models.py +new file mode 100644 +index 000000000..65a5b76cb +--- /dev/null ++++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/models.py +@@ -0,0 +1,124 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""DeepSeek model integration for Genkit.""" ++ ++from collections.abc import Callable ++from typing import Any ++ ++from genkit.ai import GenkitRegistry ++from genkit.plugins.compat_oai.models.model import OpenAIModel ++from genkit.plugins.compat_oai.typing import OpenAIConfig ++from genkit.plugins.deepseek.client import DeepSeekClient ++from genkit.plugins.deepseek.model_info import ( ++ SUPPORTED_DEEPSEEK_MODELS, ++ get_default_model_info, ++) ++ ++DEEPSEEK_PLUGIN_NAME = 'deepseek' ++ ++ ++def deepseek_name(name: str) -> str: ++ """Create a DeepSeek action name. ++ ++ Args: ++ name: Base name for the action. ++ ++ Returns: ++ The fully qualified DeepSeek action name. ++ """ ++ return f'{DEEPSEEK_PLUGIN_NAME}/{name}' ++ ++ ++class DeepSeekModel: ++ """Manages DeepSeek model integration for Genkit. ++ ++ This class provides integration with DeepSeek's OpenAI-compatible API, ++ allowing DeepSeek models to be exposed as Genkit models. It handles ++ client initialization, model information retrieval, and dynamic model ++ definition within the Genkit registry. ++ ++ Follows the Model Garden pattern for implementation consistency. ++ """ ++ ++ def __init__( ++ self, ++ model: str, ++ api_key: str, ++ registry: GenkitRegistry, ++ **deepseek_params, ++ ) -> None: ++ """Initialize the DeepSeek instance. ++ ++ Args: ++ model: The name of the specific DeepSeek model (e.g., 'deepseek-chat'). ++ api_key: DeepSeek API key for authentication. ++ registry: An instance of GenkitRegistry to register the model. ++ **deepseek_params: Additional parameters for the DeepSeek client. ++ """ ++ self.name = model ++ self.ai = registry ++ client_params = {'api_key': api_key, **deepseek_params} ++ self.client = DeepSeekClient(**client_params) ++ ++ def get_model_info(self) -> dict[str, Any] | None: ++ """Retrieve metadata and supported features for the specified model. 
++ ++ This method looks up the model's information from a predefined list ++ of supported DeepSeek models or provides default information. ++ ++ Returns: ++ A dictionary containing the model's 'name' and 'supports' features. ++ The 'supports' key contains a dictionary representing the model's ++ capabilities (e.g., tools, streaming). ++ """ ++ model_info = SUPPORTED_DEEPSEEK_MODELS.get(self.name, get_default_model_info(self.name)) ++ return { ++ 'name': model_info.label, ++ 'supports': model_info.supports.model_dump(), ++ } ++ ++ def to_deepseek_model(self) -> Callable: ++ """Convert the DeepSeek model into a Genkit-compatible model function. ++ ++ This method wraps the underlying DeepSeek client and its generation ++ logic into a callable that adheres to the OpenAI model interface ++ expected by Genkit. ++ ++ Returns: ++ A callable function (the generate method of an OpenAIModel instance) ++ that can be used by Genkit. ++ """ ++ deepseek_model = OpenAIModel(self.name, self.client, self.ai) ++ return deepseek_model.generate ++ ++ def define_model(self) -> None: ++ """Define and register the DeepSeek model with the Genkit registry. ++ ++ This method orchestrates the retrieval of model metadata and the ++ creation of the generation function, then registers this model ++ within the Genkit framework using self.ai.define_model. 
++ """ ++ model_info = self.get_model_info() ++ generate_fn = self.to_deepseek_model() ++ self.ai.define_model( ++ name=deepseek_name(self.name), ++ fn=generate_fn, ++ config_schema=OpenAIConfig, ++ metadata={ ++ 'model': model_info, ++ }, ++ ) +diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/plugin.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/plugin.py +new file mode 100644 +index 000000000..2943838c8 +--- /dev/null ++++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/plugin.py +@@ -0,0 +1,140 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""DeepSeek Plugin for Genkit.""" ++ ++import os ++from functools import cached_property ++ ++from genkit.ai import GenkitRegistry, Plugin ++from genkit.blocks.model import model_action_metadata ++from genkit.core.action import ActionMetadata ++from genkit.core.action.types import ActionKind ++from genkit.core.error import GenkitError ++from genkit.plugins.compat_oai.typing import OpenAIConfig ++from genkit.plugins.deepseek.model_info import SUPPORTED_DEEPSEEK_MODELS ++from genkit.plugins.deepseek.models import DEEPSEEK_PLUGIN_NAME, DeepSeekModel, deepseek_name ++ ++ ++class DeepSeek(Plugin): ++ """DeepSeek plugin for Genkit. ++ ++ This plugin provides integration with DeepSeek's OpenAI-compatible API, ++ enabling the use of DeepSeek models within the Genkit framework. 
++ """ ++ ++ name = DEEPSEEK_PLUGIN_NAME ++ ++ def __init__( ++ self, ++ api_key: str | None = None, ++ models: list[str] | None = None, ++ **deepseek_params, ++ ) -> None: ++ """Initialize the plugin and set up its configuration. ++ ++ Args: ++ api_key: The DeepSeek API key. If not provided, it attempts to load ++ from the DEEPSEEK_API_KEY environment variable. ++ models: An optional list of model names to register with the plugin. ++ If None, all supported models will be registered. ++ **deepseek_params: Additional parameters for the DeepSeek client. ++ ++ Raises: ++ GenkitError: If no API key is provided via parameter or environment. ++ """ ++ self.api_key = api_key if api_key is not None else os.getenv('DEEPSEEK_API_KEY') ++ ++ if not self.api_key: ++ raise GenkitError(message='Please provide api_key or set DEEPSEEK_API_KEY environment variable.') ++ ++ self.models = models ++ self.deepseek_params = deepseek_params ++ ++ def initialize(self, ai: GenkitRegistry) -> None: ++ """Initialize the plugin by registering specified models. ++ ++ Args: ++ ai: The Genkit registry where models will be registered. ++ """ ++ models = self.models ++ if models is None: ++ models = list(SUPPORTED_DEEPSEEK_MODELS.keys()) ++ ++ for model in models: ++ deepseek_model = DeepSeekModel( ++ model=model, ++ api_key=self.api_key, ++ registry=ai, ++ **self.deepseek_params, ++ ) ++ deepseek_model.define_model() ++ ++ def resolve_action( ++ self, ++ ai: GenkitRegistry, ++ kind: ActionKind, ++ name: str, ++ ) -> None: ++ """Resolve and register an action dynamically. ++ ++ Args: ++ ai: The Genkit registry. ++ kind: The kind of action to resolve. ++ name: The name of the action to resolve. ++ """ ++ if kind == ActionKind.MODEL: ++ self._resolve_model(ai=ai, name=name) ++ ++ def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: ++ """Resolve and define a DeepSeek model within the Genkit registry. 
++ ++ This internal method handles the logic for registering DeepSeek models ++ dynamically based on the provided name. It extracts a clean name, ++ instantiates the DeepSeek class, and registers it with the registry. ++ ++ Args: ++ ai: The Genkit AI registry instance to define the model in. ++ name: The name of the model to resolve. This name might include a ++ prefix indicating it's from the DeepSeek plugin. ++ """ ++ clean_name = name.replace(DEEPSEEK_PLUGIN_NAME + '/', '') if name.startswith(DEEPSEEK_PLUGIN_NAME) else name ++ ++ deepseek_model = DeepSeekModel( ++ model=clean_name, ++ api_key=self.api_key, ++ registry=ai, ++ **self.deepseek_params, ++ ) ++ deepseek_model.define_model() ++ ++ @cached_property ++ def list_actions(self) -> list[ActionMetadata]: ++ """Generate a list of available DeepSeek models. ++ ++ Returns: ++ list[ActionMetadata]: A list of ActionMetadata objects for each ++ supported DeepSeek model, including name, metadata, and config schema. ++ """ ++ actions_list = [] ++ for model, model_info in SUPPORTED_DEEPSEEK_MODELS.items(): ++ actions_list.append( ++ model_action_metadata( ++ name=deepseek_name(model), info=model_info.model_dump(), config_schema=OpenAIConfig ++ ) ++ ) ++ ++ return actions_list +diff --git a/py/plugins/deepseek/tests/test_deepseek_plugin.py b/py/plugins/deepseek/tests/test_deepseek_plugin.py +new file mode 100644 +index 000000000..150d1d23e +--- /dev/null ++++ b/py/plugins/deepseek/tests/test_deepseek_plugin.py +@@ -0,0 +1,185 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""Tests for DeepSeek plugin.""" ++ ++import os ++from unittest.mock import MagicMock, patch ++ ++import pytest ++ ++from genkit.core.error import GenkitError ++from genkit.core.registry import ActionKind ++from genkit.plugins.deepseek import DeepSeek, deepseek_name ++ ++ ++def test_deepseek_name(): ++ """Test name helper function.""" ++ assert deepseek_name('deepseek-chat') == 'deepseek/deepseek-chat' ++ assert deepseek_name('deepseek-reasoner') == 'deepseek/deepseek-reasoner' ++ ++ ++def test_plugin_initialization_with_api_key(): ++ """Test plugin initializes with API key.""" ++ plugin = DeepSeek(api_key='test-key') ++ assert plugin.name == 'deepseek' ++ assert plugin.api_key == 'test-key' ++ ++ ++def test_plugin_initialization_from_env(): ++ """Test plugin reads API key from environment.""" ++ with patch.dict(os.environ, {'DEEPSEEK_API_KEY': 'env-key'}): ++ plugin = DeepSeek() ++ assert plugin.api_key == 'env-key' ++ ++ ++def test_plugin_initialization_without_api_key(): ++ """Test plugin raises error without API key.""" ++ with patch.dict(os.environ, {}, clear=True): ++ with pytest.raises(GenkitError) as exc_info: ++ DeepSeek() ++ assert 'DEEPSEEK_API_KEY' in str(exc_info.value) ++ ++ ++@patch('genkit.plugins.deepseek.models.DeepSeekClient') ++def test_plugin_initialize(mock_client): ++ """Test plugin registers models during initialization.""" ++ plugin = DeepSeek(api_key='test-key', models=['deepseek-chat']) ++ mock_registry = MagicMock() ++ ++ plugin.initialize(mock_registry) ++ ++ # Should call define_model for the specified model ++ mock_registry.define_model.assert_called_once() ++ ++ ++@patch('genkit.plugins.deepseek.models.DeepSeekClient') ++def test_plugin_resolve_action(mock_client): ++ """Test plugin resolves models dynamically.""" ++ plugin = DeepSeek(api_key='test-key', models=[]) ++ 
mock_registry = MagicMock() ++ ++ plugin.resolve_action(mock_registry, ActionKind.MODEL, 'deepseek/deepseek-chat') ++ ++ # Should register the requested model ++ mock_registry.define_model.assert_called_once() ++ ++ ++def test_plugin_list_actions(): ++ """Test plugin lists available models.""" ++ plugin = DeepSeek(api_key='test-key') ++ actions = plugin.list_actions ++ ++ assert len(actions) == 2 ++ action_names = [action.name for action in actions] ++ assert 'deepseek/deepseek-reasoner' in action_names ++ assert 'deepseek/deepseek-chat' in action_names ++ ++ ++@patch('genkit.plugins.deepseek.models.DeepSeekClient') ++def test_plugin_with_custom_params(mock_client): ++ """Test plugin accepts custom parameters.""" ++ plugin = DeepSeek( ++ api_key='test-key', ++ models=['deepseek-chat'], ++ timeout=60, ++ max_retries=3, ++ ) ++ ++ assert plugin.deepseek_params['timeout'] == 60 ++ assert plugin.deepseek_params['max_retries'] == 3 ++ ++ ++@patch('genkit.plugins.deepseek.models.DeepSeekClient') ++def test_plugin_initialize_no_models(mock_client): ++ """Test plugin registers all supported models when models is None.""" ++ from genkit.plugins.deepseek.model_info import SUPPORTED_DEEPSEEK_MODELS ++ ++ plugin = DeepSeek(api_key='test-key') ++ mock_registry = MagicMock() ++ ++ # When models is None, all supported models should be registered ++ plugin.initialize(mock_registry) ++ ++ assert mock_registry.define_model.call_count == len(SUPPORTED_DEEPSEEK_MODELS) ++ ++ ++def test_plugin_resolve_action_non_model_kind(): ++ """Test resolve_action does nothing for non-MODEL kinds.""" ++ plugin = DeepSeek(api_key='test-key') ++ mock_registry = MagicMock() ++ ++ # Using PROMPT kind to test the case where kind != MODEL ++ plugin.resolve_action(mock_registry, ActionKind.PROMPT, 'some-prompt') ++ ++ # Should not attempt to register anything ++ mock_registry.define_model.assert_not_called() ++ ++ ++@patch('genkit.plugins.deepseek.models.DeepSeekClient') ++def 
test_plugin_resolve_action_without_prefix(mock_client): ++ """Test plugin resolves models without plugin prefix.""" ++ plugin = DeepSeek(api_key='test-key', models=[]) ++ mock_registry = MagicMock() ++ ++ # Pass name without 'deepseek/' prefix ++ plugin.resolve_action(mock_registry, ActionKind.MODEL, 'deepseek-chat') ++ ++ mock_registry.define_model.assert_called_once() ++ ++ ++@patch('genkit.plugins.deepseek.client.DeepSeekClient.__new__') ++def test_deepseek_client_initialization(mock_new): ++ """Test DeepSeekClient creates OpenAI client with correct params.""" ++ from genkit.plugins.deepseek.client import DeepSeekClient ++ ++ # Set up mock to return a fake client ++ mock_client_instance = MagicMock() ++ mock_new.return_value = mock_client_instance ++ ++ # Create a DeepSeekClient ++ result = DeepSeekClient(api_key='test-key', timeout=30) ++ ++ # Verify __new__ was called with correct parameters ++ mock_new.assert_called_once() ++ ++ ++def test_deepseek_client_with_custom_base_url(): ++ """Test DeepSeekClient accepts custom base_url.""" ++ from openai import OpenAI ++ ++ from genkit.plugins.deepseek.client import DeepSeekClient ++ ++ with patch.object(OpenAI, '__init__', return_value=None) as mock_init: ++ DeepSeekClient(api_key='test-key', base_url='https://custom.api.deepseek.com') ++ mock_init.assert_called_once_with( ++ api_key='test-key', ++ base_url='https://custom.api.deepseek.com', ++ ) ++ ++ ++def test_deepseek_client_default_base_url(): ++ """Test DeepSeekClient uses default base_url when not provided.""" ++ from openai import OpenAI ++ ++ from genkit.plugins.deepseek.client import DeepSeekClient ++ ++ with patch.object(OpenAI, '__init__', return_value=None) as mock_init: ++ DeepSeekClient(api_key='test-key') ++ mock_init.assert_called_once_with( ++ api_key='test-key', ++ base_url='https://api.deepseek.com', ++ ) +diff --git a/py/plugins/deepseek/tests/test_model_info.py b/py/plugins/deepseek/tests/test_model_info.py +new file mode 100644 +index 
000000000..dd61b137b +--- /dev/null ++++ b/py/plugins/deepseek/tests/test_model_info.py +@@ -0,0 +1,55 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""Tests for DeepSeek model information.""" ++ ++import pytest ++ ++from genkit.plugins.deepseek.model_info import SUPPORTED_DEEPSEEK_MODELS, get_default_model_info ++ ++ ++def test_supported_models_exist(): ++ """Test that supported models are defined.""" ++ assert 'deepseek-reasoner' in SUPPORTED_DEEPSEEK_MODELS ++ assert 'deepseek-chat' in SUPPORTED_DEEPSEEK_MODELS ++ ++ ++def test_model_order(): ++ """Test models are in correct order (matching JS).""" ++ keys = list(SUPPORTED_DEEPSEEK_MODELS.keys()) ++ assert keys[0] == 'deepseek-reasoner' ++ assert keys[1] == 'deepseek-chat' ++ ++ ++def test_model_info_structure(): ++ """Test model info has required fields.""" ++ for model_name, model_info in SUPPORTED_DEEPSEEK_MODELS.items(): ++ assert model_info.label ++ assert model_info.supports ++ assert model_info.supports.multiturn is True ++ assert model_info.supports.tools is True ++ assert model_info.supports.media is False ++ assert model_info.supports.system_role is True ++ assert 'text' in model_info.supports.output ++ assert 'json' in model_info.supports.output ++ ++ ++def test_get_default_model_info(): ++ """Test getting default info for unknown models.""" ++ info = get_default_model_info('deepseek-future-model') ++ assert 
'deepseek-future-model' in info.label ++ assert info.supports.multiturn is True ++ assert info.supports.tools is True +diff --git a/py/plugins/dev-local-vectorstore/pyproject.toml b/py/plugins/dev-local-vectorstore/pyproject.toml +index 7302053da..fe965b239 100644 +--- a/py/plugins/dev-local-vectorstore/pyproject.toml ++++ b/py/plugins/dev-local-vectorstore/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "Genkit Local Vector Store Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-dev-local-vectorstore" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/indexer.py b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/indexer.py +index 18c36482c..8f8379095 100644 +--- a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/indexer.py ++++ b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/indexer.py +@@ -22,7 +22,7 @@ from hashlib import md5 + from genkit.blocks.document import Document + from genkit.blocks.retriever import IndexerRequest + from genkit.codec import dump_json +-from genkit.types import DocumentData, Embedding ++from genkit.types import Embedding + + from .constant import DbValue + from .local_vector_store_api import ( +diff --git a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py +index 21cea5831..22df23c63 100644 +--- 
a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py ++++ b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py +@@ -18,9 +18,16 @@ + + from typing import Any + +-from genkit.ai import GenkitRegistry, Plugin +-from genkit.core.action import Action +-from genkit.types import Docs ++from genkit.ai import Plugin ++from genkit.blocks.retriever import ( ++ IndexerOptions, ++ RetrieverOptions, ++ indexer_action_metadata, ++ retriever_action_metadata, ++) ++from genkit.core.action import Action, ActionMetadata ++from genkit.core.action.types import ActionKind ++from genkit.core.schema import to_json_schema + + from .indexer import ( + DevLocalVectorStoreIndexer, +@@ -44,62 +51,83 @@ class DevLocalVectorStore(Plugin): + self.embedder = embedder + self.embedder_options = embedder_options + +- def initialize(self, ai: GenkitRegistry) -> None: +- """Initialize the plugin by registering actions with the registry. +- +- This method registers the Local Vector Store actions with the provided +- registry, making them available for use in the Genkit framework. +- +- Args: +- ai: The registry to register actions with. +- +- Returns: +- None +- """ +- self._configure_dev_local_retriever(ai=ai) +- self._configure_dev_local_indexer(ai=ai) +- +- def _configure_dev_local_retriever(self, ai: GenkitRegistry) -> Action: +- """Registers Local Vector Store retriever for provided parameters. +- +- Args: +- ai: The registry to register retriever with. +- params: Parameters to register retriever with. 
+- +- Returns: +- registered Action instance +- """ +- retriever = DevLocalVectorStoreRetriever( +- ai=ai, +- index_name=self.index_name, +- embedder=self.embedder, +- embedder_options=self.embedder_options, +- ) +- +- return ai.define_retriever( +- name=self.index_name, +- config_schema=RetrieverOptionsSchema, +- fn=retriever.retrieve, +- ) +- +- def _configure_dev_local_indexer(self, ai: GenkitRegistry) -> Action: +- """Registers Local Vector Store indexer for provided parameters. +- +- Args: +- ai: The registry to register indexer with. +- params: Parameters to register indexer with. +- +- Returns: +- registered Action instance +- """ +- indexer = DevLocalVectorStoreIndexer( +- ai=ai, +- index_name=self.index_name, +- embedder=self.embedder, +- embedder_options=self.embedder_options, +- ) +- +- return ai.define_indexer( +- name=self.index_name, +- fn=indexer.index, +- ) ++ async def init(self) -> list[Action]: ++ return [ ++ self._create_retriever_action(), ++ self._create_indexer_action(), ++ ] ++ ++ async def resolve(self, action_type: ActionKind, name: str) -> Action | None: ++ if name != self.index_name: ++ return None ++ if action_type == ActionKind.RETRIEVER: ++ return self._create_retriever_action() ++ if action_type == ActionKind.INDEXER: ++ return self._create_indexer_action() ++ return None ++ ++ async def list_actions(self) -> list[ActionMetadata]: ++ return [ ++ retriever_action_metadata( ++ name=self.index_name, ++ options=RetrieverOptions( ++ label=self.index_name, ++ config_schema=to_json_schema(RetrieverOptionsSchema), ++ ), ++ ), ++ indexer_action_metadata( ++ name=self.index_name, ++ options=IndexerOptions( ++ label=self.index_name, ++ ), ++ ), ++ ] ++ ++ def _create_retriever_action(self) -> Action: ++ metadata: dict[str, Any] = { ++ 'retriever': { ++ 'label': self.index_name, ++ 'customOptions': to_json_schema(RetrieverOptionsSchema), ++ } ++ } ++ ++ async def retrieve(request, ctx): ++ ai = (ctx.context or {}).get('__genkit_ai__') ++ if ai 
is None: ++ raise ValueError( ++ 'DevLocalVectorStore retriever requires a Genkit instance in action context. ' ++ 'Use it via `await ai.retrieve(...)`.' ++ ) ++ retriever = DevLocalVectorStoreRetriever( ++ ai=ai, ++ index_name=self.index_name, ++ embedder=self.embedder, ++ embedder_options=self.embedder_options, ++ ) ++ return await retriever.retrieve(request, ctx) ++ ++ return Action(kind=ActionKind.RETRIEVER, name=self.index_name, fn=retrieve, metadata=metadata) ++ ++ def _create_indexer_action(self) -> Action: ++ metadata: dict[str, Any] = { ++ 'indexer': { ++ 'label': self.index_name, ++ } ++ } ++ ++ async def index(request, ctx): ++ ai = (ctx.context or {}).get('__genkit_ai__') ++ if ai is None: ++ raise ValueError( ++ 'DevLocalVectorStore indexer requires a Genkit instance in action context. ' ++ 'Use it via `await ai.index(...)`.' ++ ) ++ indexer = DevLocalVectorStoreIndexer( ++ ai=ai, ++ index_name=self.index_name, ++ embedder=self.embedder, ++ embedder_options=self.embedder_options, ++ ) ++ return await indexer.index(request) ++ ++ return Action(kind=ActionKind.INDEXER, name=self.index_name, fn=index, metadata=metadata) +diff --git a/py/plugins/dev-local-vectorstore/tests/test_dev_local_vectorstore_plugin_v2.py b/py/plugins/dev-local-vectorstore/tests/test_dev_local_vectorstore_plugin_v2.py +new file mode 100644 +index 000000000..609161cda +--- /dev/null ++++ b/py/plugins/dev-local-vectorstore/tests/test_dev_local_vectorstore_plugin_v2.py +@@ -0,0 +1,48 @@ ++# Copyright 2025 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++from __future__ import annotations ++ ++import pytest ++ ++from genkit.core.action.types import ActionKind ++from genkit.plugins.dev_local_vectorstore import DevLocalVectorStore ++ ++ ++@pytest.mark.asyncio ++async def test_init_returns_retriever_and_indexer_actions(): ++ plugin = DevLocalVectorStore( ++ name='films', ++ embedder='vertexai/text-embedding-004', ++ ) ++ ++ actions = await plugin.init() ++ ++ assert {a.kind for a in actions} == {ActionKind.RETRIEVER, ActionKind.INDEXER} ++ assert {a.name for a in actions} == {'films'} ++ ++ ++@pytest.mark.asyncio ++async def test_list_returns_action_metadata(): ++ plugin = DevLocalVectorStore( ++ name='films', ++ embedder='vertexai/text-embedding-004', ++ ) ++ ++ metas = await plugin.list_actions() ++ ++ assert {m.kind for m in metas} == {ActionKind.RETRIEVER, ActionKind.INDEXER} ++ assert {m.name for m in metas} == {'films'} +diff --git a/py/plugins/evaluators/pyproject.toml b/py/plugins/evaluators/pyproject.toml +index 257fad1fb..c7b9385d1 100644 +--- a/py/plugins/evaluators/pyproject.toml ++++ b/py/plugins/evaluators/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +@@ -38,7 +37,7 @@ dependencies = [ + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "Genkit Evaluators Plugin for RAGAS" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-evaluators" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/evaluators/src/genkit/plugins/evaluators/plugin_api.py 
b/py/plugins/evaluators/src/genkit/plugins/evaluators/plugin_api.py +index e8da849b9..a9b07c37a 100644 +--- a/py/plugins/evaluators/src/genkit/plugins/evaluators/plugin_api.py ++++ b/py/plugins/evaluators/src/genkit/plugins/evaluators/plugin_api.py +@@ -22,9 +22,12 @@ from collections.abc import Callable + from typing import Any + + import jsonata +-from dotpromptz.typing import DataArgument + +-from genkit.ai import Genkit, Plugin ++from genkit.ai import Plugin ++from genkit.core.action import Action, ActionMetadata ++from genkit.core.action.types import ActionKind ++from genkit.core.schema import to_json_schema ++from genkit.core.typing import EvalRequest, EvalResponse + from genkit.plugins.evaluators.constant import ( + AnswerRelevancyResponseSchema, + GenkitMetricType, +@@ -80,25 +83,108 @@ class GenkitEvaluators(Plugin): + params = PluginOptions(root=params) + self.params = params + +- def initialize(self, ai: Genkit) -> None: +- """Initialize the plugin by registering actions with the registry.""" ++ async def init(self) -> list[Action]: ++ return [self._create_evaluator_action(param) for param in self.params.root] ++ ++ async def resolve(self, action_type: ActionKind, name: str) -> Action | None: ++ if action_type != ActionKind.EVALUATOR: ++ return None ++ for param in self.params.root: ++ metric_name, _, _ = self._metric_descriptor(param) ++ if name == metric_name: ++ return self._create_evaluator_action(param) ++ return None ++ ++ async def list_actions(self) -> list[ActionMetadata]: ++ metas: list[ActionMetadata] = [] + for param in self.params.root: +- self._configure_evaluator(ai=ai, param=param) ++ metric_name, display_name, definition = self._metric_descriptor(param) ++ metas.append( ++ ActionMetadata( ++ kind=ActionKind.EVALUATOR, ++ name=metric_name, ++ input_json_schema=to_json_schema(EvalRequest), ++ output_json_schema=to_json_schema(EvalResponse), ++ metadata={ ++ 'evaluator': { ++ 'label': metric_name, ++ 'displayName': display_name, ++ 
'definition': definition, ++ 'isBilled': bool(param.judge), ++ } ++ }, ++ ) ++ ) ++ return metas + +- def _configure_evaluator(self, ai: Genkit, param: MetricConfig): +- """Validates and configures supported evaluators.""" ++ def _metric_descriptor(self, param: MetricConfig) -> tuple[str, str, str]: + metric_type = param.metric_type + match metric_type: + case GenkitMetricType.ANSWER_RELEVANCY: ++ return ( ++ str(metric_type).lower(), ++ 'Answer Relevancy', ++ 'Assesses how pertinent the generated answer is to the given prompt', ++ ) ++ case GenkitMetricType.FAITHFULNESS: ++ return ( ++ str(metric_type).lower(), ++ 'Faithfulness', ++ 'Measures the factual consistency of the generated answer against the given context', ++ ) ++ case GenkitMetricType.MALICIOUSNESS: ++ return ( ++ str(metric_type).lower(), ++ 'Maliciousness', ++ 'Measures whether the generated output intends to deceive, harm, or exploit', ++ ) ++ case GenkitMetricType.REGEX: ++ return ( ++ str(metric_type).lower(), ++ 'RegExp', ++ 'Tests output against the regexp provided as reference', ++ ) ++ case GenkitMetricType.DEEP_EQUAL: ++ return ( ++ str(metric_type).lower(), ++ 'Deep Equals', ++ 'Tests equality of output against the provided reference', ++ ) ++ case GenkitMetricType.JSONATA: ++ return ( ++ str(metric_type).lower(), ++ 'JSONata', ++ 'Tests JSONata expression (provided in reference) against output', ++ ) ++ case _: ++ raise ValueError(f'Unsupported metric type: {metric_type}') ++ ++ def _create_evaluator_action(self, param: MetricConfig) -> Action: ++ metric_name, display_name, definition = self._metric_descriptor(param) ++ metadata = { ++ 'evaluator': { ++ 'label': metric_name, ++ 'displayName': display_name, ++ 'definition': definition, ++ 'isBilled': bool(param.judge), ++ } ++ } ++ ++ metric_type = param.metric_type ++ ++ # Cache for prompts (loaded on first use) - scoped per-action to avoid cross-test coupling. 
++ _faithfulness_prompts: dict[str, Any] = {} + +- async def _relevancy_eval(datapoint: BaseEvalDataPoint, options: Any | None): ++ async def eval_one(datapoint: BaseEvalDataPoint, options: Any | None, ai) -> EvalFnResponse: ++ match metric_type: ++ case GenkitMetricType.ANSWER_RELEVANCY: + assert datapoint.output is not None, 'output is required' + output_string = ( + datapoint.output if isinstance(datapoint.output, str) else json.dumps(datapoint.output) + ) + input_string = datapoint.input if isinstance(datapoint.input, str) else json.dumps(datapoint.input) + prompt_function = await load_prompt_file(_get_prompt_path('faithfulness_long_form.prompt')) +- context = ' '.join(json.dumps(e) for e in datapoint.context) ++ context = ' '.join(json.dumps(e) for e in (datapoint.context or [])) + prompt = await render_text( + prompt_function, {'input': input_string, 'output': output_string, 'context': context} + ) +@@ -106,24 +192,17 @@ class GenkitEvaluators(Plugin): + response = await ai.generate( + model=param.judge.name, + prompt=prompt, +- config=param.config, ++ config=param.judge_config, + output_schema=AnswerRelevancyResponseSchema, + ) +- # TODO: embedding comparison between the input and the result of the llm +- status = EvalStatusEnum.PASS_ if response.output else EvalStatusEnum.FAIL +- return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) + +- ai.define_evaluator( +- name=evaluators_name(str(GenkitMetricType.ANSWER_RELEVANCY).lower()), +- display_name='Answer Relevancy', +- definition='Assesses how pertinent the generated answer is to the given prompt', +- fn=_relevancy_eval, +- ) +- case GenkitMetricType.FAITHFULNESS: +- # Cache for prompts (loaded on first use) +- _faithfulness_prompts = {} ++ out = response.output ++ answered = out.get('answered') if isinstance(out, dict) else (out.answered if out else False) ++ score = bool(answered) ++ status = EvalStatusEnum.PASS_ if score else EvalStatusEnum.FAIL ++ return 
fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) + +- async def _faithfulness_eval(datapoint: BaseEvalDataPoint, options: Any | None): ++ case GenkitMetricType.FAITHFULNESS: + assert datapoint.output is not None, 'output is required' + output_string = ( + datapoint.output if isinstance(datapoint.output, str) else json.dumps(datapoint.output) +@@ -131,7 +210,6 @@ class GenkitEvaluators(Plugin): + input_string = datapoint.input if isinstance(datapoint.input, str) else json.dumps(datapoint.input) + context_list = [(json.dumps(e) if not isinstance(e, str) else e) for e in (datapoint.context or [])] + +- # Lazy load and cache prompts + if 'longform' not in _faithfulness_prompts: + _faithfulness_prompts['longform'] = await load_prompt_file( + _get_prompt_path('faithfulness_long_form.prompt') +@@ -141,7 +219,6 @@ class GenkitEvaluators(Plugin): + _get_prompt_path('faithfulness_nli.prompt') + ) + +- # Step 1: Extract statements + prompt = await render_text( + _faithfulness_prompts['longform'], {'question': input_string, 'answer': output_string} + ) +@@ -159,7 +236,6 @@ class GenkitEvaluators(Plugin): + if not statements: + raise ValueError('No statements returned') + +- # Step 2: NLI Check + all_statements = '\n'.join([f'statement: {s}' for s in statements]) + all_context = '\n'.join(context_list) + prompt = await render_text( +@@ -174,68 +250,51 @@ class GenkitEvaluators(Plugin): + ) + + nli_output = nli_response.output +- if isinstance(nli_output, dict): +- responses = nli_output.get('responses', []) +- else: +- responses = nli_output.responses if nli_output else [] +- ++ responses = ( ++ nli_output.get('responses', []) ++ if isinstance(nli_output, dict) ++ else (nli_output.responses if nli_output else []) ++ ) + if not responses: + raise ValueError('Evaluator response empty') + +- # Handle both dict and object responses + faithful_count = sum( + 1 for r in responses if (r.get('verdict') if isinstance(r, dict) else r.verdict) + ) + 
score_val = faithful_count / len(responses) + reasoning = '; '.join([r.get('reason', '') if isinstance(r, dict) else r.reason for r in responses]) + status = EvalStatusEnum.PASS_ if score_val > 0.5 else EvalStatusEnum.FAIL +- + return fill_scores( + datapoint, + Score(score=score_val, status=status, details={'reasoning': reasoning}), + param.status_override_fn, + ) + +- ai.define_evaluator( +- name=evaluators_name(str(GenkitMetricType.FAITHFULNESS).lower()), +- display_name='Faithfulness', +- definition='Measures the factual consistency of the generated answer against the given context', +- fn=_faithfulness_eval, +- ) +- +- case GenkitMetricType.MALICIOUSNESS: +- +- async def _maliciousness_eval(datapoint: BaseEvalDataPoint, options: Any | None): ++ case GenkitMetricType.MALICIOUSNESS: + assert datapoint.output is not None, 'output is required' + output_string = ( + datapoint.output if isinstance(datapoint.output, str) else json.dumps(datapoint.output) + ) + input_string = datapoint.input if isinstance(datapoint.input, str) else json.dumps(datapoint.input) + prompt_function = await load_prompt_file(_get_prompt_path('maliciousness.prompt')) +- context = ' '.join(json.dumps(e) for e in datapoint.context) ++ context = ' '.join(json.dumps(e) for e in (datapoint.context or [])) + prompt = await render_text( + prompt_function, {'input': input_string, 'output': output_string, 'context': context} + ) + +- score = await ai.generate( ++ response = await ai.generate( + model=param.judge.name, + prompt=prompt, +- config=param.config, ++ config=param.judge_config, + output_schema=MaliciousnessResponseSchema, + ) ++ out = response.output ++ verdict = out.get('verdict') if isinstance(out, dict) else (out.verdict if out else False) ++ score = bool(verdict) + status = EvalStatusEnum.PASS_ if score else EvalStatusEnum.FAIL + return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) + +- ai.define_evaluator( +- 
name=evaluators_name(str(GenkitMetricType.MALICIOUSNESS).lower()), +- display_name='Maliciousness', +- definition='Measures whether the generated output intends to deceive, harm, or exploit', +- fn=_maliciousness_eval, +- ) +- # +- case GenkitMetricType.REGEX: +- +- async def _regex_eval(datapoint: BaseEvalDataPoint, options: Any | None): ++ case GenkitMetricType.REGEX: + assert datapoint.output is not None, 'output is required' + assert datapoint.reference is not None, 'reference is required' + assert isinstance(datapoint.reference, str), 'reference must be of string (regex)' +@@ -243,39 +302,20 @@ class GenkitEvaluators(Plugin): + datapoint.output if isinstance(datapoint.output, str) else json.dumps(datapoint.output) + ) + pattern = re.compile(datapoint.reference) +- score = False if pattern.search(output_string) is None else True ++ score = pattern.search(output_string) is not None + status = EvalStatusEnum.PASS_ if score else EvalStatusEnum.FAIL + return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) + +- ai.define_evaluator( +- name=evaluators_name(str(GenkitMetricType.REGEX).lower()), +- display_name='RegExp', +- definition='Tests output against the regexp provided as reference', +- fn=_regex_eval, +- ) +- +- case GenkitMetricType.DEEP_EQUAL: +- +- async def _deep_equal_eval(datapoint: BaseEvalDataPoint, options: Any | None): ++ case GenkitMetricType.DEEP_EQUAL: + assert datapoint.reference is not None, 'reference is required' + assert datapoint.output is not None, 'output is required' +- score = False +- if type(datapoint.output) is type(datapoint.reference): +- if datapoint.output == datapoint.reference: +- score = True ++ score = ( ++ type(datapoint.output) is type(datapoint.reference) and datapoint.output == datapoint.reference ++ ) + status = EvalStatusEnum.PASS_ if score else EvalStatusEnum.FAIL + return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) + +- ai.define_evaluator( +- 
name=evaluators_name(str(GenkitMetricType.DEEP_EQUAL).lower()), +- display_name='Deep Equals', +- definition="""Tests equality of output against the provided reference""", +- fn=_deep_equal_eval, +- ) +- +- case GenkitMetricType.JSONATA: +- +- async def _jsonata_eval(datapoint: BaseEvalDataPoint, options: Any | None): ++ case GenkitMetricType.JSONATA: + assert datapoint.output is not None, 'output is required' + assert datapoint.reference is not None, 'reference is required' + assert isinstance(datapoint.reference, str), 'reference must be of string (jsonata)' +@@ -284,9 +324,33 @@ class GenkitEvaluators(Plugin): + status = EvalStatusEnum.PASS_ if bool(score) else EvalStatusEnum.FAIL + return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) + +- ai.define_evaluator( +- name=evaluators_name(str(GenkitMetricType.JSONATA).lower()), +- display_name='JSONata', +- definition="""Tests JSONata expression (provided in reference) against output""", +- fn=_jsonata_eval, ++ case _: ++ raise ValueError(f'Unsupported metric type: {metric_type}') ++ ++ async def eval_stepper(req: EvalRequest, ctx): ++ ai = (ctx.context or {}).get('__genkit_ai__') ++ if ai is None: ++ raise ValueError( ++ 'GenkitEvaluators requires a Genkit instance in action context. Use `await ai.evaluate(...)`.' + ) ++ ++ responses: list[EvalFnResponse] = [] ++ for datapoint in req.dataset: ++ if datapoint.test_case_id is None: ++ # Keep behavior consistent with core evaluator runner. 
++ datapoint.test_case_id = 'unknown' ++ try: ++ responses.append(await eval_one(datapoint, req.options, ai)) ++ except Exception as e: ++ responses.append( ++ EvalFnResponse( ++ test_case_id=datapoint.test_case_id, ++ evaluation=Score( ++ error=f'Evaluation of test case {datapoint.test_case_id} failed: \n{str(e)}', ++ status=EvalStatusEnum.FAIL, ++ ), ++ ) ++ ) ++ return EvalResponse(root=responses) ++ ++ return Action(kind=ActionKind.EVALUATOR, name=metric_name, fn=eval_stepper, metadata=metadata) +diff --git a/py/plugins/evaluators/src/genkit/plugins/metrics/helper.py b/py/plugins/evaluators/src/genkit/plugins/metrics/helper.py +index 43c60bf3d..2bd91b4de 100644 +--- a/py/plugins/evaluators/src/genkit/plugins/metrics/helper.py ++++ b/py/plugins/evaluators/src/genkit/plugins/metrics/helper.py +@@ -25,7 +25,7 @@ dp = Dotprompt() + + + async def load_prompt_file(path: str) -> PromptFunction: +- with open(path, 'r') as f: ++ with open(path) as f: + result = await dp.compile(f.read()) + + return result +diff --git a/py/plugins/evaluators/tests/test_evaluators_plugin_v2.py b/py/plugins/evaluators/tests/test_evaluators_plugin_v2.py +new file mode 100644 +index 000000000..7ab7f0f9e +--- /dev/null ++++ b/py/plugins/evaluators/tests/test_evaluators_plugin_v2.py +@@ -0,0 +1,53 @@ ++# Copyright 2025 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++from __future__ import annotations ++ ++import pytest ++ ++from genkit.core.action.types import ActionKind ++from genkit.plugins.evaluators.constant import GenkitMetricType, MetricConfig ++from genkit.plugins.evaluators.plugin_api import GenkitEvaluators ++ ++ ++@pytest.mark.asyncio ++async def test_init_returns_evaluator_actions(): ++ plugin = GenkitEvaluators( ++ params=[ ++ MetricConfig(metric_type=GenkitMetricType.REGEX), ++ MetricConfig(metric_type=GenkitMetricType.DEEP_EQUAL), ++ ] ++ ) ++ ++ actions = await plugin.init() ++ ++ assert {a.kind for a in actions} == {ActionKind.EVALUATOR} ++ assert {a.name for a in actions} == {str(GenkitMetricType.REGEX).lower(), str(GenkitMetricType.DEEP_EQUAL).lower()} ++ ++ ++@pytest.mark.asyncio ++async def test_list_returns_action_metadata(): ++ plugin = GenkitEvaluators( ++ params=[ ++ MetricConfig(metric_type=GenkitMetricType.REGEX), ++ ] ++ ) ++ ++ metas = await plugin.list_actions() ++ ++ assert len(metas) == 1 ++ assert metas[0].kind == ActionKind.EVALUATOR ++ assert metas[0].name == str(GenkitMetricType.REGEX).lower() +diff --git a/py/plugins/firebase/pyproject.toml b/py/plugins/firebase/pyproject.toml +index 05bb7c57f..fd747dbb3 100644 +--- a/py/plugins/firebase/pyproject.toml ++++ b/py/plugins/firebase/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -39,7 +38,7 @@ dependencies = [ + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "Genkit Firebase Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-firebase" + readme = "README.md" + requires-python = ">=3.10" +diff --git 
a/py/plugins/firebase/src/genkit/plugins/firebase/firestore.py b/py/plugins/firebase/src/genkit/plugins/firebase/firestore.py +index 8552fe4a6..074348d02 100644 +--- a/py/plugins/firebase/src/genkit/plugins/firebase/firestore.py ++++ b/py/plugins/firebase/src/genkit/plugins/firebase/firestore.py +@@ -20,7 +20,10 @@ from typing import Any + from google.cloud.firestore_v1 import DocumentSnapshot + from google.cloud.firestore_v1.base_vector_query import DistanceMeasure + +-from genkit.ai import GenkitRegistry, Plugin ++from genkit.ai import Plugin ++from genkit.blocks.retriever import RetrieverOptions, retriever_action_metadata ++from genkit.core.action import Action, ActionMetadata ++from genkit.core.action.types import ActionKind + from genkit.plugins.firebase.retriever import FirestoreRetriever + + from .constant import MetadataTransformFn +@@ -40,19 +43,9 @@ def firestore_action_name(name: str) -> str: + + + class FirestoreVectorStore(Plugin): +- """Firestore retriever plugin. ++ """Firestore retriever plugin (PluginV2).""" + +- Args: +- name: name if the retriever. +- collection: The name of the Firestore collection to query. +- vector_field: The name of the field containing the vector embeddings. +- content_field: The name of the field containing the document content, you wish to return. +- embedder: The embedder to use with this retriever. +- embedder_options: Optional configuration to pass to the embedder. +- distance_measure: The distance measure to use when comparing vectors. Defaults to 'COSINE'. +- firestore_client: The Firestore database instance from which to query. +- metadata_fields: Optional list of metadata fields to include. +- """ ++ name: str = 'firestore' + + def __init__( + self, +@@ -79,7 +72,7 @@ class FirestoreVectorStore(Plugin): + firestore_client: The Firestore database instance from which to query. + metadata_fields: Optional list of metadata fields to include. 
+ """ +- self.name = name ++ self.store_name = name + self.firestore_client = firestore_client + self.collection = collection + self.vector_field = vector_field +@@ -89,31 +82,57 @@ class FirestoreVectorStore(Plugin): + self.distance_measure = distance_measure + self.metadata_fields = metadata_fields + +- def initialize(self, ai: GenkitRegistry) -> None: +- """Initialize firestore plugin. +- +- Register actions with the registry making them available for use in the Genkit framework. +- +- Args: +- ai: The registry to register actions with. +- +- Returns: +- None +- """ +- retriever = FirestoreRetriever( +- ai=ai, +- name=self.name, +- firestore_client=self.firestore_client, +- collection=self.collection, +- vector_field=self.vector_field, +- content_field=self.content_field, +- embedder=self.embedder, +- embedder_options=self.embedder_options, +- distance_measure=self.distance_measure, +- metadata_fields=self.metadata_fields, +- ) +- +- return ai.define_retriever( +- name=firestore_action_name(self.name), +- fn=retriever.retrieve, ++ async def init(self) -> list[Action]: ++ return [self._create_retriever_action()] ++ ++ async def resolve(self, action_type: ActionKind, name: str) -> Action | None: ++ if action_type != ActionKind.RETRIEVER: ++ return None ++ if name != self.store_name: ++ return None ++ return self._create_retriever_action() ++ ++ async def list_actions(self) -> list[ActionMetadata]: ++ return [ ++ retriever_action_metadata( ++ name=self.store_name, ++ options=RetrieverOptions( ++ label=self.store_name, ++ ), ++ ) ++ ] ++ ++ def _create_retriever_action(self) -> Action: ++ metadata: dict[str, Any] = { ++ 'retriever': { ++ 'label': self.store_name, ++ } ++ } ++ ++ async def retrieve(request, ctx): ++ ai = (ctx.context or {}).get('__genkit_ai__') ++ if ai is None: ++ raise ValueError( ++ 'FirestoreVectorStore retriever requires a Genkit instance in action context. ' ++ 'Use it via `await ai.retrieve(...)`.' 
++ ) ++ retriever = FirestoreRetriever( ++ ai=ai, ++ name=self.store_name, ++ firestore_client=self.firestore_client, ++ collection=self.collection, ++ vector_field=self.vector_field, ++ content_field=self.content_field, ++ embedder=self.embedder, ++ embedder_options=self.embedder_options, ++ distance_measure=self.distance_measure, ++ metadata_fields=self.metadata_fields, ++ ) ++ return await retriever.retrieve(request, ctx) ++ ++ return Action( ++ kind=ActionKind.RETRIEVER, ++ name=self.store_name, ++ fn=retrieve, ++ metadata=metadata, + ) +diff --git a/py/plugins/firebase/src/genkit/plugins/firebase/tests/test_firestore_vectorstore_plugin.py b/py/plugins/firebase/src/genkit/plugins/firebase/tests/test_firestore_vectorstore_plugin.py +new file mode 100644 +index 000000000..df4345fce +--- /dev/null ++++ b/py/plugins/firebase/src/genkit/plugins/firebase/tests/test_firestore_vectorstore_plugin.py +@@ -0,0 +1,62 @@ ++# Copyright 2025 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++from __future__ import annotations ++ ++from unittest.mock import MagicMock ++ ++import pytest ++from google.cloud.firestore_v1.base_vector_query import DistanceMeasure ++ ++from genkit.core.action.types import ActionKind ++from genkit.plugins.firebase.firestore import FirestoreVectorStore ++ ++ ++@pytest.mark.asyncio ++async def test_init_returns_retriever_action(): ++ plugin = FirestoreVectorStore( ++ name='kb', ++ firestore_client=MagicMock(), ++ collection='docs', ++ vector_field='embedding', ++ content_field='text', ++ embedder='vertexai/text-embedding-004', ++ distance_measure=DistanceMeasure.COSINE, ++ ) ++ ++ actions = await plugin.init() ++ ++ assert len(actions) == 1 ++ assert actions[0].kind == ActionKind.RETRIEVER ++ assert actions[0].name == 'kb' ++ ++ ++@pytest.mark.asyncio ++async def test_list_returns_metadata(): ++ plugin = FirestoreVectorStore( ++ name='kb', ++ firestore_client=MagicMock(), ++ collection='docs', ++ vector_field='embedding', ++ content_field='text', ++ embedder='vertexai/text-embedding-004', ++ ) ++ ++ metas = await plugin.list_actions() ++ ++ assert len(metas) == 1 ++ assert metas[0].kind == ActionKind.RETRIEVER ++ assert metas[0].name == 'kb' +diff --git a/py/plugins/flask/pyproject.toml b/py/plugins/flask/pyproject.toml +index 43bb01a88..0b43ad401 100644 +--- a/py/plugins/flask/pyproject.toml ++++ b/py/plugins/flask/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -41,7 +40,7 @@ dependencies = [ + "flask", + ] + description = "Genkit Firebase Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-flask" + readme = "README.md" + requires-python = 
">=3.10" +diff --git a/py/plugins/flask/tests/flask_test.py b/py/plugins/flask/tests/flask_test.py +index 1308e9dee..e314644df 100644 +--- a/py/plugins/flask/tests/flask_test.py ++++ b/py/plugins/flask/tests/flask_test.py +@@ -65,7 +65,7 @@ def test_streaming(): + headers={'Authorization': 'Pavel', 'content-Type': 'application/json', 'accept': 'text/event-stream'}, + ) + +- assert response.is_streamed == True ++ assert response.is_streamed + + chunks = [] + for chunk in response.response: +diff --git a/py/plugins/google-cloud/pyproject.toml b/py/plugins/google-cloud/pyproject.toml +index 2a58f3e6f..43985cab8 100644 +--- a/py/plugins/google-cloud/pyproject.toml ++++ b/py/plugins/google-cloud/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "Genkit Google Cloud Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-google-cloud" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/google-genai/pyproject.toml b/py/plugins/google-genai/pyproject.toml +index bc4df8a95..e98788f12 100644 +--- a/py/plugins/google-genai/pyproject.toml ++++ b/py/plugins/google-genai/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -41,7 +40,7 @@ dependencies = [ + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "Genkit 
Google GenAI Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-google-genai" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +index 437cf9506..c22d6a546 100644 +--- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py ++++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +@@ -20,15 +20,14 @@ from functools import cached_property + from google import genai + from google.auth.credentials import Credentials + from google.genai.client import DebugConfig +-from google.genai.types import EmbedContentConfig, GenerateImagesConfigOrDict, HttpOptions, HttpOptionsDict ++from google.genai.types import HttpOptions, HttpOptionsDict + + import genkit.plugins.google_genai.constants as const +-from genkit.ai import GENKIT_CLIENT_HEADER, GenkitRegistry, Plugin +-from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata +-from genkit.blocks.model import model_action_metadata +-from genkit.core.action import ActionMetadata ++from genkit.ai import GENKIT_CLIENT_HEADER, Plugin ++from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder, embedder_action_metadata ++from genkit.blocks.model import model, model_action_metadata ++from genkit.core.action import Action, ActionMetadata + from genkit.core.registry import ActionKind +-from genkit.core.schema import to_json_schema + from genkit.plugins.google_genai.models.embedder import ( + Embedder, + GeminiEmbeddingModels, +@@ -37,7 +36,6 @@ from genkit.plugins.google_genai.models.embedder import ( + ) + from genkit.plugins.google_genai.models.gemini import ( + SUPPORTED_MODELS, +- GeminiConfigSchema, + GeminiModel, + GoogleAIGeminiVersion, + VertexAIGeminiVersion, +@@ -128,107 +126,65 @@ class GoogleAI(Plugin): + 
http_options=_inject_attribution_headers(http_options), + ) + +- def initialize(self, ai: GenkitRegistry) -> None: +- """Initialize the plugin by registering actions in the registry. ++ async def init(self) -> list[Action]: ++ actions: list[Action] = [] + +- Args: +- ai: the action registry. +- """ + for version in GoogleAIGeminiVersion: +- gemini_model = GeminiModel(version, self._client, ai) +- ai.define_model( +- name=googleai_name(version), +- fn=gemini_model.generate, +- metadata=gemini_model.metadata, +- # config_schema=GeminiConfigSchema, ++ gemini_model = GeminiModel(version, self._client) ++ actions.append( ++ model( ++ name=str(version), ++ fn=gemini_model.generate, ++ metadata=gemini_model.metadata, ++ # config_schema=GeminiConfigSchema, ++ ) + ) + + for version in GeminiEmbeddingModels: +- embedder = Embedder(version=version, client=self._client) ++ embedder_impl = Embedder(version=version, client=self._client) + embedder_info = default_embedder_info(version) +- ai.define_embedder( +- name=googleai_name(version), +- fn=embedder.generate, ++ actions.append( ++ embedder( ++ name=str(version), ++ fn=embedder_impl.generate, ++ options=EmbedderOptions( ++ label=embedder_info.get('label'), ++ dimensions=embedder_info.get('dimensions'), ++ supports=EmbedderSupports(**embedder_info['supports']) ++ if embedder_info.get('supports') ++ else None, ++ ), ++ ) ++ ) ++ ++ return actions ++ ++ async def resolve(self, action_type: ActionKind, name: str) -> Action | None: ++ if action_type == ActionKind.MODEL: ++ model_ref = google_model_info(name) ++ SUPPORTED_MODELS[name] = model_ref ++ gemini_model = GeminiModel(name, self._client) ++ return model( ++ name=name, ++ fn=gemini_model.generate, ++ metadata=gemini_model.metadata, ++ ) ++ if action_type == ActionKind.EMBEDDER: ++ embedder_impl = Embedder(version=name, client=self._client) ++ embedder_info = default_embedder_info(name) ++ return embedder( ++ name=name, ++ fn=embedder_impl.generate, + options=EmbedderOptions( 
+ label=embedder_info.get('label'), + dimensions=embedder_info.get('dimensions'), + supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, + ), + ) +- +- def resolve_action( +- self, +- ai: GenkitRegistry, +- kind: ActionKind, +- name: str, +- ) -> None: +- """Resolves and action. +- +- Args: +- ai: The Genkit registry. +- kind: The kind of action to resolve. +- name: The name of the action to resolve. +- """ +- if kind == ActionKind.MODEL: +- self._resolve_model(ai, name) +- elif kind == ActionKind.EMBEDDER: +- self._resolve_embedder(ai, name) +- +- def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: +- """Resolves and defines a Google AI model within the Genkit registry. +- +- This internal method handles the logic for registering different types of +- Google AI models (e.g., Gemini text models) based on the provided name. +- It extracts a clean name, determines the model type, instantiates the +- appropriate model class, and registers it with the Genkit AI registry. +- +- Args: +- ai: The Genkit AI registry instance to define the model in. +- name: The name of the model to resolve. This name might include a +- prefix indicating it's from a specific plugin (e.g., 'googleai/gemini-pro'). +- """ +- _clean_name = name.replace(GOOGLEAI_PLUGIN_NAME + '/', '') if name.startswith(GOOGLEAI_PLUGIN_NAME) else name +- model_ref = google_model_info(_clean_name) +- +- SUPPORTED_MODELS[_clean_name] = model_ref +- +- gemini_model = GeminiModel(_clean_name, self._client, ai) +- +- ai.define_model( +- name=googleai_name(_clean_name), +- fn=gemini_model.generate, +- metadata=gemini_model.metadata, +- # config_schema=GeminiConfigSchema, +- ) +- +- def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: +- """Resolves and defines a Google AI embedder within the Genkit registry. +- +- This internal method handles the logic for registering Google AI embedder +- models. 
It extracts a clean name, instantiates the embedder class, and +- registers it with the Genkit AI registry. +- +- Args: +- ai: The Genkit AI registry instance to define the embedder in. +- name: The name of the embedder to resolve. This name might include a +- prefix indicating it's from a specific plugin (e.g., 'googleai/embedding-001'). +- """ +- _clean_name = name.replace(GOOGLEAI_PLUGIN_NAME + '/', '') if name.startswith(GOOGLEAI_PLUGIN_NAME) else name +- embedder = Embedder(version=_clean_name, client=self._client) +- +- embedder_info = default_embedder_info(_clean_name) +- ai.define_embedder( +- name=googleai_name(_clean_name), +- fn=embedder.generate, +- options=EmbedderOptions( +- label=embedder_info.get('label'), +- dimensions=embedder_info.get('dimensions'), +- supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, +- ), +- ) ++ return None + + @cached_property +- def list_actions(self) -> list[ActionMetadata]: ++ def _list_actions_cache(self) -> list[ActionMetadata]: + """Generate a list of available actions or models. + + Returns: +@@ -264,6 +220,9 @@ class GoogleAI(Plugin): + + return actions_list + ++ async def list_actions(self) -> list[ActionMetadata]: ++ return list(self._list_actions_cache) ++ + + class VertexAI(Plugin): + """Vertex AI plugin for Genkit. +@@ -315,125 +274,81 @@ class VertexAI(Plugin): + http_options=_inject_attribution_headers(http_options), + ) + +- def initialize(self, ai: GenkitRegistry) -> None: +- """Initialize the plugin by registering actions with the registry. +- +- This method registers the Vertex AI model actions with the provided +- registry, making them available for use in the Genkit framework. ++ async def init(self) -> list[Action]: ++ actions: list[Action] = [] + +- Args: +- ai: the action registry. 
+- """ + for version in VertexAIGeminiVersion: +- gemini_model = GeminiModel(version, self._client, ai) +- ai.define_model( +- name=vertexai_name(version), +- fn=gemini_model.generate, +- metadata=gemini_model.metadata, +- # config_schema=GeminiConfigSchema, ++ gemini_model = GeminiModel(version, self._client) ++ actions.append( ++ model( ++ name=str(version), ++ fn=gemini_model.generate, ++ metadata=gemini_model.metadata, ++ ) + ) + + for version in VertexEmbeddingModels: +- embedder = Embedder(version=version, client=self._client) ++ embedder_impl = Embedder(version=version, client=self._client) + embedder_info = default_embedder_info(version) +- ai.define_embedder( +- name=vertexai_name(version), +- fn=embedder.generate, +- options=EmbedderOptions( +- label=embedder_info.get('label'), +- dimensions=embedder_info.get('dimensions'), +- supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, +- ), ++ actions.append( ++ embedder( ++ name=str(version), ++ fn=embedder_impl.generate, ++ options=EmbedderOptions( ++ label=embedder_info.get('label'), ++ dimensions=embedder_info.get('dimensions'), ++ supports=EmbedderSupports(**embedder_info['supports']) ++ if embedder_info.get('supports') ++ else None, ++ ), ++ ) + ) + + for version in ImagenVersion: + imagen_model = ImagenModel(version, self._client) +- ai.define_model( +- name=vertexai_name(version), +- fn=imagen_model.generate, +- metadata=imagen_model.metadata, ++ actions.append( ++ model( ++ name=str(version), ++ fn=imagen_model.generate, ++ metadata=imagen_model.metadata, ++ ) + ) + +- def resolve_action( +- self, +- ai: GenkitRegistry, +- kind: ActionKind, +- name: str, +- ) -> None: +- """Resolves and action. +- +- Args: +- ai: The Genkit registry. +- kind: The kind of action to resolve. +- name: The name of the action to resolve. 
+- """ +- if kind == ActionKind.MODEL: +- self._resolve_model(ai, name) +- elif kind == ActionKind.EMBEDDER: +- self._resolve_embedder(ai, name) +- +- def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: +- """Resolves and defines a Vertex AI model within the Genkit registry. +- +- This internal method handles the logic for registering different types of +- Vertex AI models (e.g., Gemini text models, Imagen image models) based on +- the provided name. It extracts a clean name, determines the model type, +- instantiates the appropriate model class, and registers it with the Genkit +- AI registry. +- +- Args: +- ai: The Genkit AI registry instance to define the model in. +- name: The name of the model to resolve. This name might include a +- prefix indicating it's from a specific plugin (e.g., 'vertexai/gemini-pro'). +- """ +- _clean_name = name.replace(VERTEXAI_PLUGIN_NAME + '/', '') if name.startswith(VERTEXAI_PLUGIN_NAME) else name +- +- if _clean_name.lower().startswith('image'): +- model_ref = vertexai_image_model_info(_clean_name) +- model = ImagenModel(_clean_name, self._client) +- IMAGE_SUPPORTED_MODELS[_clean_name] = model_ref +- # config_schema = GenerateImagesConfigOrDict +- else: +- model_ref = google_model_info(_clean_name) +- model = GeminiModel(_clean_name, self._client, ai) +- SUPPORTED_MODELS[_clean_name] = model_ref +- # config_schema = GeminiConfigSchema +- +- ai.define_model( +- name=vertexai_name(_clean_name), +- fn=model.generate, +- metadata=model.metadata, +- # config_schema=config_schema, +- ) +- +- def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: +- """Resolves and defines a Vertex AI embedder within the Genkit registry. 
++ return actions ++ ++ async def resolve(self, action_type: ActionKind, name: str) -> Action | None: ++ if action_type == ActionKind.MODEL: ++ if name.lower().startswith('image'): ++ model_ref = vertexai_image_model_info(name) ++ model_impl = ImagenModel(name, self._client) ++ IMAGE_SUPPORTED_MODELS[name] = model_ref ++ else: ++ model_ref = google_model_info(name) ++ model_impl = GeminiModel(name, self._client) ++ SUPPORTED_MODELS[name] = model_ref ++ return model( ++ name=name, ++ fn=model_impl.generate, ++ metadata=model_impl.metadata, ++ ) + +- This internal method handles the logic for registering Google AI embedder +- models. It extracts a clean name, instantiates the embedder class, and +- registers it with the Genkit AI registry. ++ if action_type == ActionKind.EMBEDDER: ++ embedder_impl = Embedder(version=name, client=self._client) ++ embedder_info = default_embedder_info(name) ++ return embedder( ++ name=name, ++ fn=embedder_impl.generate, ++ options=EmbedderOptions( ++ label=embedder_info.get('label'), ++ dimensions=embedder_info.get('dimensions'), ++ supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, ++ ), ++ ) + +- Args: +- ai: The Genkit AI registry instance to define the embedder in. +- name: The name of the embedder to resolve. This name might include a +- prefix indicating it's from a specific plugin (e.g., 'vertexai/embedding-001'). 
+- """ +- _clean_name = name.replace(VERTEXAI_PLUGIN_NAME + '/', '') if name.startswith(VERTEXAI_PLUGIN_NAME) else name +- embedder = Embedder(version=_clean_name, client=self._client) +- +- embedder_info = default_embedder_info(_clean_name) +- ai.define_embedder( +- name=vertexai_name(_clean_name), +- fn=embedder.generate, +- options=EmbedderOptions( +- label=embedder_info.get('label'), +- dimensions=embedder_info.get('dimensions'), +- supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, +- ), +- ) ++ return None + + @cached_property +- def list_actions(self) -> list[ActionMetadata]: ++ def _list_actions_cache(self) -> list[ActionMetadata]: + """Generate a list of available actions or models. + + Returns: +@@ -469,6 +384,9 @@ class VertexAI(Plugin): + + return actions_list + ++ async def list_actions(self) -> list[ActionMetadata]: ++ return list(self._list_actions_cache) ++ + + def _inject_attribution_headers(http_options: HttpOptions | dict | None = None): + """Adds genkit client info to the appropriate http headers.""" +diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/__init__.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/__init__.py +index 19add86cb..1bd71c307 100644 +--- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/__init__.py ++++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/__init__.py +@@ -13,4 +13,3 @@ + # limitations under the License. 
+ # + # SPDX-License-Identifier: Apache-2.0 +- +diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/utils.py +index 405a88bb6..aedda4642 100644 +--- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/utils.py ++++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/utils.py +@@ -30,7 +30,7 @@ logger = structlog.getLogger(__name__) + + + def generate_cache_key(request: GenerateRequest) -> str: +- """Generates context cache key by hashing the given request instance ++ """Generates context cache key by hashing the given request instance. + + Args: + request: `GenerateRequest` instance to hash +@@ -42,7 +42,7 @@ def generate_cache_key(request: GenerateRequest) -> str: + + + def validate_context_cache_request(request: GenerateRequest, model_name: str) -> bool: +- """Verifies that the context cache request could be processed for the request ++ """Verifies that the context cache request could be processed for the request. 
+ + Args: + request: `GenerateRequest` instance to check +diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +index 12f09c170..914eed55a 100644 +--- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py ++++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +@@ -149,9 +149,7 @@ from google import genai + from google.genai import types as genai_types # type: ignore + + from genkit.ai import ( +- ActionKind, + ActionRunContext, +- GenkitRegistry, + ) + from genkit.blocks.model import get_basic_usage_stats + from genkit.codec import dump_dict, dump_json +@@ -181,6 +179,25 @@ class GeminiConfigSchema(genai_types.GenerateContentConfig): + """Gemini Config Schema.""" + + code_execution: bool | None = None ++ response_modalities: list[str] | None = None ++ ++ ++class GeminiTtsConfigSchema(GeminiConfigSchema): ++ """Gemini TTS Config Schema.""" ++ ++ speech_config: dict[str, Any] | None = None ++ ++ ++class GeminiImageConfigSchema(GeminiConfigSchema): ++ """Gemini Image Config Schema.""" ++ ++ image_config: dict[str, Any] | None = None ++ ++ ++class GemmaConfigSchema(GeminiConfigSchema): ++ """Gemma Config Schema.""" ++ ++ temperature: float | None = None + + + GEMINI_1_5_PRO = ModelInfo( +@@ -341,6 +358,57 @@ GEMINI_2_5_FLASH_PREVIEW_04_17 = ModelInfo( + ), + ) + ++GENERIC_GEMINI_MODEL = ModelInfo( ++ label='Google AI - Gemini', ++ supports=Supports( ++ multiturn=True, ++ media=True, ++ tools=True, ++ tool_choice=True, ++ system_role=True, ++ constrained='no-tools', ++ output=['text', 'json'], ++ ), ++) ++ ++GENERIC_TTS_MODEL = ModelInfo( ++ label='Google AI - Gemini TTS', ++ supports=Supports( ++ multiturn=False, ++ media=False, ++ tools=False, ++ tool_choice=False, ++ system_role=False, ++ constrained='no-tools', ++ ), ++) ++ ++GENERIC_IMAGE_MODEL = ModelInfo( ++ label='Google AI - Gemini Image', ++ 
supports=Supports( ++ multiturn=True, ++ media=True, ++ tools=True, ++ tool_choice=True, ++ system_role=True, ++ constrained='no-tools', ++ output=['text'], ++ ), ++) ++ ++GENERIC_GEMMA_MODEL = ModelInfo( ++ label='Google AI - Gemma', ++ supports=Supports( ++ multiturn=True, ++ media=True, ++ tools=True, ++ tool_choice=True, ++ system_role=True, ++ constrained='no-tools', ++ output=['text', 'json'], ++ ), ++) ++ + + Deprecations = deprecated_enum_metafactory({ + 'GEMINI_1_0_PRO': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), +@@ -368,6 +436,21 @@ class VertexAIGeminiVersion(StrEnum, metaclass=Deprecations): + | `gemini-2.5-pro-exp-03-25` | Gemini 2.5 Pro Exp 03-25 | Supported | + | `gemini-2.5-pro-preview-03-25` | Gemini 2.5 Pro Preview 03-25 | Supported | + | `gemini-2.5-pro-preview-05-06` | Gemini 2.5 Pro Preview 05-06 | Supported | ++ | `gemini-3-flash-preview` | Gemini 3 Flash Preview | Supported | ++ | `gemini-3-pro-preview` | Gemini 3 Pro Preview | Supported | ++ | `gemini-2.5-pro` | Gemini 2.5 Pro | Supported | ++ | `gemini-2.5-flash` | Gemini 2.5 Flash | Supported | ++ | `gemini-2.5-flash-lite` | Gemini 2.5 Flash Lite | Supported | ++ | `gemini-2.5-flash-preview-tts` | Gemini 2.5 Flash Preview TTS | Supported | ++ | `gemini-2.5-pro-preview-tts` | Gemini 2.5 Pro Preview TTS | Supported | ++ | `gemini-3-pro-image-preview` | Gemini 3 Pro Image Preview | Supported | ++ | `gemini-2.5-flash-image-preview` | Gemini 2.5 Flash Image Preview | Supported | ++ | `gemini-2.5-flash-image` | Gemini 2.5 Flash Image | Supported | ++ | `gemma-3-12b-it` | Gemma 3 12B IT | Supported | ++ | `gemma-3-1b-it` | Gemma 3 1B IT | Supported | ++ | `gemma-3-27b-it` | Gemma 3 27B IT | Supported | ++ | `gemma-3-4b-it` | Gemma 3 4B IT | Supported | ++ | `gemma-3n-e4b-it` | Gemma 3n E4B IT | Supported | + """ + + GEMINI_1_5_FLASH = 'gemini-1.5-flash' +@@ -381,6 +464,21 @@ class VertexAIGeminiVersion(StrEnum, metaclass=Deprecations): + 
GEMINI_2_5_PRO_EXP_03_25 = 'gemini-2.5-pro-exp-03-25' + GEMINI_2_5_PRO_PREVIEW_03_25 = 'gemini-2.5-pro-preview-03-25' + GEMINI_2_5_PRO_PREVIEW_05_06 = 'gemini-2.5-pro-preview-05-06' ++ GEMINI_3_FLASH_PREVIEW = 'gemini-3-flash-preview' ++ GEMINI_3_PRO_PREVIEW = 'gemini-3-pro-preview' ++ GEMINI_2_5_PRO = 'gemini-2.5-pro' ++ GEMINI_2_5_FLASH = 'gemini-2.5-flash' ++ GEMINI_2_5_FLASH_LITE = 'gemini-2.5-flash-lite' ++ GEMINI_2_5_FLASH_PREVIEW_TTS = 'gemini-2.5-flash-preview-tts' ++ GEMINI_2_5_PRO_PREVIEW_TTS = 'gemini-2.5-pro-preview-tts' ++ GEMINI_3_PRO_IMAGE_PREVIEW = 'gemini-3-pro-image-preview' ++ GEMINI_2_5_FLASH_IMAGE_PREVIEW = 'gemini-2.5-flash-image-preview' ++ GEMINI_2_5_FLASH_IMAGE = 'gemini-2.5-flash-image' ++ GEMMA_3_12B_IT = 'gemma-3-12b-it' ++ GEMMA_3_1B_IT = 'gemma-3-1b-it' ++ GEMMA_3_27B_IT = 'gemma-3-27b-it' ++ GEMMA_3_4B_IT = 'gemma-3-4b-it' ++ GEMMA_3N_E4B_IT = 'gemma-3n-e4b-it' + + + class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): +@@ -401,6 +499,21 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): + | `gemini-2.5-pro-exp-03-25` | Gemini 2.5 Pro Exp 03-25 | Supported | + | `gemini-2.5-pro-preview-03-25` | Gemini 2.5 Pro Preview 03-25 | Supported | + | `gemini-2.5-pro-preview-05-06` | Gemini 2.5 Pro Preview 05-06 | Supported | ++ | `gemini-3-flash-preview` | Gemini 3 Flash Preview | Supported | ++ | `gemini-3-pro-preview` | Gemini 3 Pro Preview | Supported | ++ | `gemini-2.5-pro` | Gemini 2.5 Pro | Supported | ++ | `gemini-2.5-flash` | Gemini 2.5 Flash | Supported | ++ | `gemini-2.5-flash-lite` | Gemini 2.5 Flash Lite | Supported | ++ | `gemini-2.5-flash-preview-tts` | Gemini 2.5 Flash Preview TTS | Supported | ++ | `gemini-2.5-pro-preview-tts` | Gemini 2.5 Pro Preview TTS | Supported | ++ | `gemini-3-pro-image-preview` | Gemini 3 Pro Image Preview | Supported | ++ | `gemini-2.5-flash-image-preview` | Gemini 2.5 Flash Image Preview | Supported | ++ | `gemini-2.5-flash-image` | Gemini 2.5 Flash Image | Supported | ++ | 
`gemma-3-12b-it` | Gemma 3 12B IT | Supported | ++ | `gemma-3-1b-it` | Gemma 3 1B IT | Supported | ++ | `gemma-3-27b-it` | Gemma 3 27B IT | Supported | ++ | `gemma-3-4b-it` | Gemma 3 4B IT | Supported | ++ | `gemma-3n-e4b-it` | Gemma 3n E4B IT | Supported | + """ + + GEMINI_1_5_FLASH = 'gemini-1.5-flash' +@@ -414,6 +527,21 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): + GEMINI_2_5_PRO_EXP_03_25 = 'gemini-2.5-pro-exp-03-25' + GEMINI_2_5_PRO_PREVIEW_03_25 = 'gemini-2.5-pro-preview-03-25' + GEMINI_2_5_PRO_PREVIEW_05_06 = 'gemini-2.5-pro-preview-05-06' ++ GEMINI_3_FLASH_PREVIEW = 'gemini-3-flash-preview' ++ GEMINI_3_PRO_PREVIEW = 'gemini-3-pro-preview' ++ GEMINI_2_5_PRO = 'gemini-2.5-pro' ++ GEMINI_2_5_FLASH = 'gemini-2.5-flash' ++ GEMINI_2_5_FLASH_LITE = 'gemini-2.5-flash-lite' ++ GEMINI_2_5_FLASH_PREVIEW_TTS = 'gemini-2.5-flash-preview-tts' ++ GEMINI_2_5_PRO_PREVIEW_TTS = 'gemini-2.5-pro-preview-tts' ++ GEMINI_3_PRO_IMAGE_PREVIEW = 'gemini-3-pro-image-preview' ++ GEMINI_2_5_FLASH_IMAGE_PREVIEW = 'gemini-2.5-flash-image-preview' ++ GEMINI_2_5_FLASH_IMAGE = 'gemini-2.5-flash-image' ++ GEMMA_3_12B_IT = 'gemma-3-12b-it' ++ GEMMA_3_1B_IT = 'gemma-3-1b-it' ++ GEMMA_3_27B_IT = 'gemma-3-27b-it' ++ GEMMA_3_4B_IT = 'gemma-3-4b-it' ++ GEMMA_3N_E4B_IT = 'gemma-3n-e4b-it' + + + SUPPORTED_MODELS = { +@@ -428,6 +556,21 @@ SUPPORTED_MODELS = { + GoogleAIGeminiVersion.GEMINI_2_5_PRO_EXP_03_25: GEMINI_2_5_PRO_EXP_03_25, + GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_03_25: GEMINI_2_5_PRO_PREVIEW_03_25, + GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_05_06: GEMINI_2_5_PRO_PREVIEW_05_06, ++ GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW: GENERIC_GEMINI_MODEL, ++ GoogleAIGeminiVersion.GEMINI_3_PRO_PREVIEW: GENERIC_GEMINI_MODEL, ++ GoogleAIGeminiVersion.GEMINI_2_5_PRO: GENERIC_GEMINI_MODEL, ++ GoogleAIGeminiVersion.GEMINI_2_5_FLASH: GENERIC_GEMINI_MODEL, ++ GoogleAIGeminiVersion.GEMINI_2_5_FLASH_LITE: GENERIC_GEMINI_MODEL, ++ 
GoogleAIGeminiVersion.GEMINI_2_5_FLASH_PREVIEW_TTS: GENERIC_TTS_MODEL, ++ GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_TTS: GENERIC_TTS_MODEL, ++ GoogleAIGeminiVersion.GEMINI_3_PRO_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, ++ GoogleAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, ++ GoogleAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE: GENERIC_IMAGE_MODEL, ++ GoogleAIGeminiVersion.GEMMA_3_12B_IT: GENERIC_GEMMA_MODEL, ++ GoogleAIGeminiVersion.GEMMA_3_1B_IT: GENERIC_GEMMA_MODEL, ++ GoogleAIGeminiVersion.GEMMA_3_27B_IT: GENERIC_GEMMA_MODEL, ++ GoogleAIGeminiVersion.GEMMA_3_4B_IT: GENERIC_GEMMA_MODEL, ++ GoogleAIGeminiVersion.GEMMA_3N_E4B_IT: GENERIC_GEMMA_MODEL, + VertexAIGeminiVersion.GEMINI_1_5_FLASH: GEMINI_1_5_FLASH, + VertexAIGeminiVersion.GEMINI_1_5_FLASH_8B: GEMINI_1_5_FLASH_8B, + VertexAIGeminiVersion.GEMINI_1_5_PRO: GEMINI_1_5_PRO, +@@ -439,6 +582,21 @@ SUPPORTED_MODELS = { + VertexAIGeminiVersion.GEMINI_2_5_PRO_EXP_03_25: GEMINI_2_5_PRO_EXP_03_25, + VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_03_25: GEMINI_2_5_PRO_PREVIEW_03_25, + VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_05_06: GEMINI_2_5_PRO_PREVIEW_05_06, ++ VertexAIGeminiVersion.GEMINI_3_FLASH_PREVIEW: GENERIC_GEMINI_MODEL, ++ VertexAIGeminiVersion.GEMINI_3_PRO_PREVIEW: GENERIC_GEMINI_MODEL, ++ VertexAIGeminiVersion.GEMINI_2_5_PRO: GENERIC_GEMINI_MODEL, ++ VertexAIGeminiVersion.GEMINI_2_5_FLASH: GENERIC_GEMINI_MODEL, ++ VertexAIGeminiVersion.GEMINI_2_5_FLASH_LITE: GENERIC_GEMINI_MODEL, ++ VertexAIGeminiVersion.GEMINI_2_5_FLASH_PREVIEW_TTS: GENERIC_TTS_MODEL, ++ VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_TTS: GENERIC_TTS_MODEL, ++ VertexAIGeminiVersion.GEMINI_3_PRO_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, ++ VertexAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, ++ VertexAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE: GENERIC_IMAGE_MODEL, ++ VertexAIGeminiVersion.GEMMA_3_12B_IT: GENERIC_GEMMA_MODEL, ++ VertexAIGeminiVersion.GEMMA_3_1B_IT: GENERIC_GEMMA_MODEL, ++ 
VertexAIGeminiVersion.GEMMA_3_27B_IT: GENERIC_GEMMA_MODEL, ++ VertexAIGeminiVersion.GEMMA_3_4B_IT: GENERIC_GEMMA_MODEL, ++ VertexAIGeminiVersion.GEMMA_3N_E4B_IT: GENERIC_GEMMA_MODEL, + } + + +@@ -479,18 +637,15 @@ class GeminiModel: + self, + version: str | GoogleAIGeminiVersion | VertexAIGeminiVersion, + client: genai.Client, +- registry: GenkitRegistry, + ): + """Initialize Gemini model. + + Args: + version: Gemini version + client: Google AI client +- registry: Genkit registry + """ + self._version = version + self._client = client +- self._registry = registry + + def _get_tools(self, request: GenerateRequest) -> list[genai_types.Tool]: + """Generates VertexAI Gemini compatible tool definitions. +@@ -522,7 +677,7 @@ class GeminiModel: + name=tool.name, + description=tool.description, + parameters=params, +- response=tool.output_schema, ++ response=self._convert_schema_property(tool.output_schema) if tool.output_schema else None, + ) + return genai_types.Tool(function_declarations=[function]) + +@@ -588,32 +743,6 @@ class GeminiModel: + + return schema + +- def _call_tool(self, call: genai_types.FunctionCall) -> genai_types.Content: +- """Calls tool's function from the registry. 
+- +- Args: +- call: FunctionCall from Gemini response +- +- Returns: +- Gemini message content to add to the message +- """ +- tool_function = self._registry.registry.lookup_action(ActionKind.TOOL, call.name) +- if tool_function is None: +- raise LookupError(f'Tool {call.name} not found') +- +- args = tool_function.input_type.validate_python(call.args) +- tool_answer = tool_function.run(args) +- return genai_types.Content( +- parts=[ +- genai_types.Part.from_function_response( +- name=call.name, +- response={ +- 'content': tool_answer.response, +- }, +- ) +- ] +- ) +- + async def _retrieve_cached_content( + self, request: GenerateRequest, model_name: str, cache_config: dict, contents: list[genai_types.Content] + ) -> genai_types.CachedContent: +@@ -825,6 +954,8 @@ class GeminiModel: + cache = None + + for msg in request.messages: ++ if msg.role == Role.SYSTEM: ++ continue + content_parts: list[genai_types.Part] = [] + for p in msg.content: + content_parts.append(PartConverter.to_gemini(p)) +@@ -838,6 +969,9 @@ class GeminiModel: + contents=request_contents, + ) + ++ if not request_contents: ++ request_contents.append(genai_types.Content(parts=[genai_types.Part(text=' ')], role='user')) ++ + return request_contents, cache + + def _contents_from_response(self, response: genai_types.GenerateContentResponse) -> list: +@@ -853,8 +987,8 @@ class GeminiModel: + if response.candidates: + for candidate in response.candidates: + if candidate.content: +- for part in candidate.content.parts: +- content.append(PartConverter.from_gemini(part=part)) ++ for i, part in enumerate(candidate.content.parts): ++ content.append(PartConverter.from_gemini(part=part, ref=str(i))) + + return content + +@@ -915,7 +1049,6 @@ class GeminiModel: + for msg in system_messages: + for p in msg.content: + system_parts.append(PartConverter.to_gemini(p)) +- request.messages.remove(msg) + cfg.system_instruction = genai.types.Content(parts=system_parts) + + return cfg +diff --git 
a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +index 7cb69c03b..ec2ab72ed 100644 +--- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py ++++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +@@ -14,6 +14,7 @@ + # + # SPDX-License-Identifier: Apache-2.0 + import base64 ++from typing import Any + + from google import genai + +@@ -22,6 +23,7 @@ from genkit.types import ( + Media, + MediaPart, + Part, ++ ReasoningPart, + TextPart, + ToolRequest, + ToolRequestPart, +@@ -71,30 +73,43 @@ class PartConverter: + A `genai.types.Part` object representing the converted content. + """ + if isinstance(part.root, TextPart): +- return genai.types.Part(text=part.root.text) ++ return genai.types.Part(text=part.root.text or ' ') + if isinstance(part.root, ToolRequestPart): + return genai.types.Part( + function_call=genai.types.FunctionCall( +- id=part.root.tool_request.ref, +- name=part.root.tool_request.name, ++ # Gemini throws on '/' in tool name ++ name=part.root.tool_request.name.replace('/', '__'), + args=part.root.tool_request.input, +- ) ++ ), ++ thought_signature=cls._extract_thought_signature(part.root.metadata), ++ ) ++ if isinstance(part.root, ReasoningPart): ++ return genai.types.Part( ++ thought=True, ++ text=part.root.reasoning, ++ thought_signature=cls._extract_thought_signature(part.root.metadata), + ) + if isinstance(part.root, ToolResponsePart): + return genai.types.Part( + function_response=genai.types.FunctionResponse( + id=part.root.tool_response.ref, +- name=part.root.tool_response.name, +- response={'output': part.root.tool_response.output}, ++ name=part.root.tool_response.name.replace('/', '__'), ++ response=part.root.tool_response.output, + ) + ) + if isinstance(part.root, MediaPart): + url = part.root.media.url + if not url.startswith(cls.DATA): + raise ValueError(f'Unsupported media URL for inline_data: 
{url}') +- data = base64.b64decode(url.split(',', 1)[1]) ++ ++ # Extract mime type and data from data:mime_type;base64,data ++ metadata, data_str = url.split(',', 1) ++ mime_type = part.root.media.content_type or metadata.split(':', 1)[1].split(';', 1)[0] ++ data = base64.b64decode(data_str) ++ + return genai.types.Part( + inline_data=genai.types.Blob( ++ mime_type=mime_type, + data=data, + ) + ) +@@ -131,7 +146,7 @@ class PartConverter: + ) + + @classmethod +- def from_gemini(cls, part: genai.types.Part) -> Part: ++ def from_gemini(cls, part: genai.types.Part, ref: str | None = None) -> Part: + """Maps a Gemini Part back to a Genkit Part. + + This method inspects the type of the Gemini Part and converts it into +@@ -140,26 +155,41 @@ class PartConverter: + + Args: + part: The `genai.types.Part` object to convert. ++ ref: The tool call reference ID. + + Returns: + A Genkit `Part` object representing the converted content. + """ ++ if part.thought: ++ return Part( ++ root=ReasoningPart( ++ reasoning=part.text or '', ++ metadata=cls._encode_thought_signature(part.thought_signature), ++ ) ++ ) + if part.text: +- return Part(text=part.text) ++ return Part(root=TextPart(text=part.text)) + if part.function_call: + return Part( +- toolRequest=ToolRequest( +- ref=part.function_call.id, +- name=part.function_call.name, +- input=part.function_call.args, ++ root=ToolRequestPart( ++ tool_request=ToolRequest( ++ ref=ref or getattr(part.function_call, 'id', None), ++ # restore slashes ++ name=part.function_call.name.replace('__', '/'), ++ input=part.function_call.args, ++ ), ++ metadata=cls._encode_thought_signature(part.thought_signature), + ) + ) + if part.function_response: + return Part( +- toolResponse=ToolResponse( +- ref=part.function_call.id, +- name=part.function_response.name, +- output=part.function_response.response, ++ root=ToolResponsePart( ++ tool_response=ToolResponse( ++ ref=getattr(part.function_response, 'id', None), ++ # restore slashes ++ 
name=part.function_response.name.replace('__', '/'), ++ output=part.function_response.response, ++ ) + ) + ) + if part.inline_data: +@@ -188,3 +218,18 @@ class PartConverter: + } + } + ) ++ ++ @classmethod ++ def _extract_thought_signature(cls, metadata: Any) -> bytes | None: ++ """Extracts and decodes the thought signature from metadata.""" ++ thought_sig = metadata.root.get('thoughtSignature') if metadata else None ++ if isinstance(thought_sig, str): ++ return base64.b64decode(thought_sig) ++ return None ++ ++ @classmethod ++ def _encode_thought_signature(cls, thought_signature: bytes | None) -> dict[str, str] | None: ++ """Encodes the thought signature into metadata format.""" ++ if thought_signature: ++ return {'thoughtSignature': base64.b64encode(thought_signature).decode('utf-8')} ++ return None +diff --git a/py/plugins/google-genai/test/models/test_googlegenai_gemini.py b/py/plugins/google-genai/test/models/test_googlegenai_gemini.py +index 50c05b2b9..8f6aafe95 100644 +--- a/py/plugins/google-genai/test/models/test_googlegenai_gemini.py ++++ b/py/plugins/google-genai/test/models/test_googlegenai_gemini.py +@@ -16,7 +16,7 @@ + + import sys + import urllib.request +-from unittest.mock import ANY, AsyncMock, MagicMock, patch ++from unittest.mock import AsyncMock, MagicMock, patch + + if sys.version_info < (3, 11): # noqa + from strenum import StrEnum # noqa +diff --git a/py/plugins/google-genai/test/models/test_googlegenai_imagen.py b/py/plugins/google-genai/test/models/test_googlegenai_imagen.py +index c88438919..eda45a130 100644 +--- a/py/plugins/google-genai/test/models/test_googlegenai_imagen.py ++++ b/py/plugins/google-genai/test/models/test_googlegenai_imagen.py +@@ -14,7 +14,6 @@ + # + # SPDX-License-Identifier: Apache-2.0 + +-import base64 + import urllib.request + + import pytest +diff --git a/py/plugins/google-genai/test/test_google_plugin.py b/py/plugins/google-genai/test/test_google_plugin.py +index e589d86ee..bb4aa6a5d 100644 +--- 
a/py/plugins/google-genai/test/test_google_plugin.py ++++ b/py/plugins/google-genai/test/test_google_plugin.py +@@ -24,14 +24,13 @@ from unittest.mock import MagicMock, patch, ANY + + from google.auth.credentials import Credentials + from pydantic import BaseModel +-from google.genai.types import EmbedContentConfig, GenerateImagesConfigOrDict, HttpOptions ++from google.genai.types import HttpOptions + + import pytest + from genkit.ai import Genkit, GENKIT_CLIENT_HEADER + from genkit.blocks.embedding import embedder_action_metadata, EmbedderOptions, EmbedderSupports + from genkit.blocks.model import model_action_metadata + from genkit.core.registry import ActionKind +-from genkit.core.schema import to_json_schema + from genkit.plugins.google_genai import ( + GoogleAI, + VertexAI, +@@ -46,7 +45,6 @@ from genkit.plugins.google_genai.models.embedder import ( + ) + from genkit.plugins.google_genai.models.gemini import ( + DEFAULT_SUPPORTS_MODEL, +- GeminiConfigSchema, + SUPPORTED_MODELS, + GoogleAIGeminiVersion, + VertexAIGeminiVersion, +@@ -131,23 +129,24 @@ class TestGoogleAIInit(unittest.TestCase): + GoogleAI() + + +-def test_googleai_initialize(): +- """Unit tests for GoogleAI.initialize method.""" ++@pytest.mark.asyncio ++async def test_googleai_initialize(): ++ """Unit tests for GoogleAI.init method (V2).""" + api_key = 'test_api_key' + plugin = GoogleAI(api_key=api_key) +- ai_mock = MagicMock(spec=Genkit) + +- plugin.initialize(ai_mock) ++ actions = await plugin.init() + +- assert ai_mock.define_model.call_count == len(GoogleAIGeminiVersion) +- assert ai_mock.define_embedder.call_count == len(GeminiEmbeddingModels) ++ # Check we got actions for all models and embedders ++ model_actions = [a for a in actions if a.kind == ActionKind.MODEL] ++ embedder_actions = [a for a in actions if a.kind == ActionKind.EMBEDDER] + ++ assert len(model_actions) == len(GoogleAIGeminiVersion) ++ assert len(embedder_actions) == len(GeminiEmbeddingModels) ++ ++ # Check model names are 
correct + for version in GoogleAIGeminiVersion: +- ai_mock.define_model.assert_any_call( +- name=googleai_name(version), +- fn=ANY, +- metadata=ANY, +- ) ++ assert any(a.name == str(version) for a in model_actions) + + for version in GeminiEmbeddingModels: + ai_mock.define_embedder.assert_any_call( +diff --git a/py/plugins/google-genai/tests/test_google_genai_plugin_v2.py b/py/plugins/google-genai/tests/test_google_genai_plugin_v2.py +new file mode 100644 +index 000000000..66fe33f11 +--- /dev/null ++++ b/py/plugins/google-genai/tests/test_google_genai_plugin_v2.py +@@ -0,0 +1,56 @@ ++# Copyright 2025 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++from __future__ import annotations ++ ++from unittest.mock import MagicMock ++ ++import pytest ++ ++from genkit.core.action.types import ActionKind ++from genkit.plugins.google_genai.google import GoogleAI, VertexAI ++ ++ ++@pytest.mark.asyncio ++async def test_googleai_list_is_async(): ++ plugin = object.__new__(GoogleAI) ++ plugin._client = MagicMock() ++ plugin._client.models.list.return_value = [] ++ ++ metas = await plugin.list_actions() ++ assert isinstance(metas, list) ++ ++ ++@pytest.mark.asyncio ++async def test_vertexai_list_is_async(): ++ plugin = object.__new__(VertexAI) ++ plugin._client = MagicMock() ++ plugin._client.models.list.return_value = [] ++ ++ metas = await plugin.list_actions() ++ assert isinstance(metas, list) ++ ++ ++@pytest.mark.asyncio ++async def test_googleai_resolve_model_returns_action(): ++ plugin = object.__new__(GoogleAI) ++ plugin._client = MagicMock() ++ plugin._client.models.list.return_value = [] ++ ++ action = await plugin.resolve(ActionKind.MODEL, 'gemini-1.5-pro') ++ assert action is not None ++ assert action.kind == ActionKind.MODEL ++ assert action.name == 'gemini-1.5-pro' +diff --git a/py/plugins/mcp/README.md b/py/plugins/mcp/README.md +new file mode 100644 +index 000000000..1ad726219 +--- /dev/null ++++ b/py/plugins/mcp/README.md +@@ -0,0 +1,3 @@ ++# Genkit MCP Plugin ++ ++Integrate Model Context Protocol (MCP) with Genkit. +diff --git a/py/plugins/mcp/examples/client/simple_client.py b/py/plugins/mcp/examples/client/simple_client.py +new file mode 100644 +index 000000000..9512f12a5 +--- /dev/null ++++ b/py/plugins/mcp/examples/client/simple_client.py +@@ -0,0 +1,53 @@ ++# Copyright 2025 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++import asyncio ++ ++from genkit.ai import Genkit ++from genkit.plugins.mcp import McpServerConfig, create_mcp_client ++ ++try: ++ from genkit.plugins.google_genai import GoogleAI ++except ImportError: ++ GoogleAI = None ++ ++ ++# Simple client example connecting to 'everything' server using npx ++async def main(): ++ # Define the client plugin ++ everything_client = create_mcp_client( ++ name='everything', config=McpServerConfig(command='npx', args=['-y', '@modelcontextprotocol/server-everything']) ++ ) ++ ++ plugins = [everything_client] ++ if GoogleAI: ++ plugins.append(GoogleAI()) ++ ++ ai = Genkit(plugins=plugins) ++ ++ await everything_client.connect() ++ ++ print('Connected! Listing tools...') ++ ++ tools = await everything_client.list_tools() ++ for t in tools: ++ print(f'- {t.name}: {t.description}') ++ ++ await everything_client.close() ++ ++ ++if __name__ == '__main__': ++ asyncio.run(main()) +diff --git a/py/plugins/mcp/examples/server/prompts/port_code.prompt b/py/plugins/mcp/examples/server/prompts/port_code.prompt +new file mode 100644 +index 000000000..77e8501b3 +--- /dev/null ++++ b/py/plugins/mcp/examples/server/prompts/port_code.prompt +@@ -0,0 +1,13 @@ ++--- ++input: ++ schema: ++ code: string, the source code to port from one language to another ++ fromLang?: string, the original language of the source code (e.g. js, python) ++ toLang: string, the destination language of the source code (e.g. 
python, js) ++--- ++ ++You are assisting the user in translating code between two programming languages. Given the code below, translate it into {{toLang}}. ++ ++```{{#if fromLang}}{{fromLang}}{{/if}} ++{{code}} ++``` +diff --git a/py/plugins/mcp/examples/server/simple_server.py b/py/plugins/mcp/examples/server/simple_server.py +new file mode 100644 +index 000000000..2405c7429 +--- /dev/null ++++ b/py/plugins/mcp/examples/server/simple_server.py +@@ -0,0 +1,63 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++import asyncio ++ ++from pydantic import BaseModel, Field ++ ++from genkit.ai import Genkit ++from genkit.plugins.mcp import McpServerOptions, create_mcp_server ++ ++ ++# Define input model ++class AddInput(BaseModel): ++ a: int = Field(..., description='First number') ++ b: int = Field(..., description='Second number') ++ ++ ++import os ++ ++ ++def main(): ++ # Load prompts from the 'prompts' directory relative to this script ++ script_dir = os.path.dirname(os.path.abspath(__file__)) ++ prompts_dir = os.path.join(script_dir, 'prompts') ++ ++ ai = Genkit(prompt_dir=prompts_dir) ++ ++ @ai.tool(name='add', description='add two numbers together') ++ def add(input: AddInput): ++ return input.a + input.b ++ ++ # Genkit Python prompt definition (simplified) ++ # Note: In Python, prompts are typically loaded from files via prompt_dir ++ # This inline definition is for demonstration purposes ++ happy_prompt = ai.define_prompt( ++ input_schema={'action': str}, ++ prompt="If you're happy and you know it, {{action}}.", ++ ) ++ ++ # Create and start MCP server ++ # Note: create_mcp_server returns McpServer instance. ++ # In JS example: .start() is called. ++ server = create_mcp_server(ai, McpServerOptions(name='example_server', version='0.0.1')) ++ ++ print('Starting MCP server on stdio...') ++ asyncio.run(server.start()) ++ ++ ++if __name__ == '__main__': ++ main() +diff --git a/py/plugins/mcp/pyproject.toml b/py/plugins/mcp/pyproject.toml +new file mode 100644 +index 000000000..6ea44f68e +--- /dev/null ++++ b/py/plugins/mcp/pyproject.toml +@@ -0,0 +1,48 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++[project] ++authors = [{ name = "Google" }] ++classifiers = [ ++ "Development Status :: 3 - Alpha", ++ "Environment :: Console", ++ "Environment :: Web Environment", ++ "Intended Audience :: Developers", ++ "Operating System :: OS Independent", ++ "Programming Language :: Python", ++ "Programming Language :: Python :: 3 :: Only", ++ "Programming Language :: Python :: 3.10", ++ "Programming Language :: Python :: 3.11", ++ "Programming Language :: Python :: 3.12", ++ "Programming Language :: Python :: 3.13", ++ "Programming Language :: Python :: 3.14", ++ "Topic :: Scientific/Engineering :: Artificial Intelligence", ++ "Topic :: Software Development :: Libraries", ++] ++dependencies = ["genkit", "mcp"] ++description = "Genkit MCP Plugin" ++license = "Apache-2.0" ++name = "genkit-plugins-mcp" ++readme = "README.md" ++requires-python = ">=3.10" ++version = "0.1.0" ++ ++[build-system] ++build-backend = "hatchling.build" ++requires = ["hatchling"] ++ ++[tool.hatch.build.targets.wheel] ++packages = ["src"] +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/__init__.py b/py/plugins/mcp/src/genkit/plugins/mcp/__init__.py +new file mode 100644 +index 000000000..7e48a29a7 +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/__init__.py +@@ -0,0 +1,40 @@ ++""" ++Copyright 2026 Google LLC ++ ++Licensed under the Apache License, Version 2.0 (the "License"); ++you may not use this file except in compliance with the License. 
++You may obtain a copy of the License at ++ ++ http://www.apache.org/licenses/LICENSE-2.0 ++ ++Unless required by applicable law or agreed to in writing, software ++distributed under the License is distributed on an "AS IS" BASIS, ++WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++See the License for the specific language governing permissions and ++limitations under the License. ++""" ++ ++from .client.client import ( ++ McpClient, ++ McpServerConfig, ++ create_mcp_client, ++) ++from .client.host import McpHost, create_mcp_host ++from .server import McpServer, McpServerOptions, create_mcp_server ++ ++ ++def package_name() -> str: ++ return 'genkit.plugins.mcp' ++ ++ ++__all__ = [ ++ 'McpClient', ++ 'McpHost', ++ 'McpServerConfig', ++ 'create_mcp_client', ++ 'create_mcp_host', ++ 'McpServer', ++ 'McpServerOptions', ++ 'create_mcp_server', ++ 'package_name', ++] +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/client/__init__.py b/py/plugins/mcp/src/genkit/plugins/mcp/client/__init__.py +new file mode 100644 +index 000000000..fe4b8ffe1 +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/client/__init__.py +@@ -0,0 +1,16 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py b/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py +new file mode 100644 +index 000000000..4ef9e715a +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py +@@ -0,0 +1,208 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++import asyncio ++import uuid ++from typing import Any, Callable, Dict, List, Optional, Union ++ ++import structlog ++from pydantic import BaseModel ++ ++from genkit.ai import Genkit ++from genkit.ai._plugin import Plugin ++from genkit.ai._registry import GenkitRegistry ++from genkit.core.action.types import ActionKind ++from mcp import ClientSession, StdioServerParameters ++from mcp.client.sse import sse_client ++from mcp.client.stdio import stdio_client ++from mcp.types import CallToolResult, Prompt, Resource, Tool ++ ++logger = structlog.get_logger(__name__) ++ ++ ++class McpServerConfig(BaseModel): ++ command: Optional[str] = None ++ args: Optional[List[str]] = None ++ env: Optional[Dict[str, str]] = None ++ url: Optional[str] = None ++ disabled: bool = False ++ ++ ++class McpClient(Plugin): ++ """Client for connecting to a single MCP server.""" ++ ++ def __init__(self, name: str, config: McpServerConfig, server_name: Optional[str] = None): ++ self.name = name ++ self.config = config ++ self.server_name = server_name 
or name ++ self.session: Optional[ClientSession] = None ++ self._exit_stack = None ++ self._session_context = None ++ self.ai: Optional[GenkitRegistry] = None ++ ++ def plugin_name(self) -> str: ++ return self.name ++ ++ def initialize(self, ai: GenkitRegistry) -> None: ++ self.ai = ai ++ ++ def resolve_action(self, ai: GenkitRegistry, kind: ActionKind, name: str) -> None: ++ # MCP tools are dynamic and currently registered upon connection/Discovery. ++ # This hook allows lazy resolution if we implement it. ++ pass ++ ++ async def connect(self): ++ """Connects to the MCP server.""" ++ if self.config.disabled: ++ logger.info(f'MCP server {self.server_name} is disabled.') ++ return ++ ++ try: ++ if self.config.command: ++ server_params = StdioServerParameters( ++ command=self.config.command, args=self.config.args or [], env=self.config.env ++ ) ++ # stdio_client returns (read, write) streams ++ stdio_context = stdio_client(server_params) ++ read, write = await stdio_context.__aenter__() ++ self._exit_stack = stdio_context ++ ++ # Create and initialize session ++ session_context = ClientSession(read, write) ++ self.session = await session_context.__aenter__() ++ self._session_context = session_context ++ ++ elif self.config.url: ++ # TODO: Verify SSE client usage in mcp python SDK ++ sse_context = sse_client(self.config.url) ++ read, write = await sse_context.__aenter__() ++ self._exit_stack = sse_context ++ ++ session_context = ClientSession(read, write) ++ self.session = await session_context.__aenter__() ++ self._session_context = session_context ++ ++ await self.session.initialize() ++ logger.info(f'Connected to MCP server: {self.server_name}') ++ ++ except Exception as e: ++ logger.error(f'Failed to connect to MCP server {self.server_name}: {e}') ++ self.config.disabled = True ++ # Clean up on error ++ await self.close() ++ raise e ++ ++ async def close(self): ++ """Closes the connection.""" ++ if hasattr(self, '_session_context') and self._session_context: ++ 
try: ++ await self._session_context.__aexit__(None, None, None) ++ except Exception as e: ++ logger.debug(f'Error closing session: {e}') ++ if self._exit_stack: ++ try: ++ await self._exit_stack.__aexit__(None, None, None) ++ except Exception as e: ++ logger.debug(f'Error closing transport: {e}') ++ ++ async def list_tools(self) -> List[Tool]: ++ if not self.session: ++ return [] ++ result = await self.session.list_tools() ++ return result.tools ++ ++ async def call_tool(self, tool_name: str, arguments: dict) -> Any: ++ if not self.session: ++ raise RuntimeError('MCP client is not connected') ++ result: CallToolResult = await self.session.call_tool(tool_name, arguments) ++ # Process result similarly to JS SDK ++ if result.isError: ++ raise RuntimeError(f'Tool execution failed: {result.content}') ++ ++ # Simple text extraction for now ++ texts = [c.text for c in result.content if c.type == 'text'] ++ return ''.join(texts) ++ ++ async def list_prompts(self) -> List[Prompt]: ++ if not self.session: ++ return [] ++ result = await self.session.list_prompts() ++ return result.prompts ++ ++ async def get_prompt(self, name: str, arguments: Optional[dict] = None) -> Any: ++ if not self.session: ++ raise RuntimeError('MCP client is not connected') ++ return await self.session.get_prompt(name, arguments) ++ ++ async def list_resources(self) -> List[Resource]: ++ if not self.session: ++ return [] ++ result = await self.session.list_resources() ++ return result.resources ++ ++ async def read_resource(self, uri: str) -> Any: ++ if not self.session: ++ raise RuntimeError('MCP client is not connected') ++ return await self.session.read_resource(uri) ++ ++ async def register_tools(self, ai: Optional[Genkit] = None): ++ """Registers all tools from connected client to Genkit.""" ++ registry = ai.registry if ai else (self.ai.registry if self.ai else None) ++ if not registry: ++ logger.warning('No Genkit registry available to register tools.') ++ return ++ ++ if not self.session: ++ 
return ++ ++ try: ++ tools = await self.list_tools() ++ for tool in tools: ++ # Create a wrapper function for the tool ++ # We need to capture tool and client in closure ++ async def tool_wrapper(args: Any = None, _tool_name=tool.name): ++ # args might be Pydantic model or dict. Genkit passes dict usually? ++ # TODO: Validate args against schema if needed ++ arguments = args ++ if hasattr(args, 'model_dump'): ++ arguments = args.model_dump() ++ return await self.call_tool(_tool_name, arguments or {}) ++ ++ # Use metadata to store MCP specific info ++ metadata = {'mcp': {'_meta': tool._meta}} if hasattr(tool, '_meta') else {} ++ ++ # Define the tool in Genkit registry ++ registry.register_action( ++ kind=ActionKind.TOOL, ++ name=f'{self.server_name}/{tool.name}', ++ fn=tool_wrapper, ++ description=tool.description, ++ metadata=metadata, ++ # TODO: json_schema conversion from tool.inputSchema ++ ) ++ logger.debug(f'Registered MCP tool: {self.server_name}/{tool.name}') ++ except Exception as e: ++ logger.error(f'Error registering tools for {self.server_name}: {e}') ++ ++ async def get_active_tools(self) -> List[Any]: ++ """Returns all active tools.""" ++ if not self.session: ++ return [] ++ return await self.list_tools() ++ ++ ++def create_mcp_client(config: McpServerConfig, name: str = 'mcp-client') -> McpClient: ++ return McpClient(name, config) +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py b/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py +new file mode 100644 +index 000000000..cd0a4691d +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py +@@ -0,0 +1,64 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++from typing import Dict, List, Optional ++ ++from genkit.ai import Genkit ++ ++from .client import McpClient, McpServerConfig ++ ++ ++class McpHost: ++ """Host for managing multiple MCP clients.""" ++ ++ def __init__(self, clients: Dict[str, McpServerConfig]): ++ self.clients_config = clients ++ self.clients: Dict[str, McpClient] = {name: McpClient(name, config) for name, config in clients.items()} ++ ++ async def start(self): ++ """Starts all enabled MCP clients.""" ++ for client in self.clients.values(): ++ if not client.config.disabled: ++ await client.connect() ++ ++ async def close(self): ++ """Closes all MCP clients.""" ++ for client in self.clients.values(): ++ await client.close() ++ ++ async def register_tools(self, ai: Genkit): ++ """Registers all tools from connected clients to Genkit.""" ++ for client in self.clients.values(): ++ if client.session: ++ await client.register_tools(ai) ++ ++ async def enable(self, name: str): ++ """Enables and connects an MCP client.""" ++ if name in self.clients: ++ client = self.clients[name] ++ client.config.disabled = False ++ await client.connect() ++ ++ async def disable(self, name: str): ++ """Disables and closes an MCP client.""" ++ if name in self.clients: ++ client = self.clients[name] ++ client.config.disabled = True ++ await client.close() ++ ++ ++def create_mcp_host(configs: Dict[str, McpServerConfig]) -> McpHost: ++ return McpHost(configs) +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/index.py 
b/py/plugins/mcp/src/genkit/plugins/mcp/index.py +new file mode 100644 +index 000000000..4f859e2fe +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/index.py +@@ -0,0 +1,40 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++""" ++MCP Plugin Index ++ ++This module serves as the main entry point for the MCP plugin, ++similar to js/plugins/mcp/src/index.ts. ++ ++In Python, the actual exports are handled by the parent __init__.py, ++but this file exists for structural parity with the JS SDK. ++""" ++ ++from .client.client import McpClient, McpServerConfig, create_mcp_client ++from .client.host import McpHost, create_mcp_host ++from .server import McpServer, McpServerOptions, create_mcp_server ++ ++__all__ = [ ++ 'McpClient', ++ 'McpHost', ++ 'McpServerConfig', ++ 'create_mcp_client', ++ 'create_mcp_host', ++ 'McpServer', ++ 'McpServerOptions', ++ 'create_mcp_server', ++] +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/server.py b/py/plugins/mcp/src/genkit/plugins/mcp/server.py +new file mode 100644 +index 000000000..3d313ccda +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/server.py +@@ -0,0 +1,463 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# distributed under the License. ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""MCP Server implementation for exposing Genkit actions via Model Context Protocol.""" ++ ++import asyncio ++from typing import Any, Optional ++ ++import structlog ++from pydantic import BaseModel ++ ++from genkit.ai import Genkit ++from genkit.blocks.resource import matches_uri_template ++from genkit.core.action._key import parse_action_key ++from genkit.core.action.types import ActionKind ++from genkit.core.error import GenkitError ++from genkit.core.schema import to_json_schema ++from mcp.server import Server ++from mcp.server.stdio import stdio_server ++from mcp.types import ( ++ CallToolRequest, ++ CallToolResult, ++ GetPromptRequest, ++ GetPromptResult, ++ ListPromptsRequest, ++ ListPromptsResult, ++ ListResourcesRequest, ++ ListResourcesResult, ++ ListResourceTemplatesRequest, ++ ListResourceTemplatesResult, ++ ListToolsRequest, ++ ListToolsResult, ++ Prompt, ++ ReadResourceRequest, ++ ReadResourceResult, ++ Resource, ++ ResourceTemplate, ++ Tool, ++) ++ ++from .util import ( ++ to_mcp_prompt_arguments, ++ to_mcp_prompt_message, ++ to_mcp_resource_contents, ++ to_mcp_tool_result, ++) ++ ++logger = structlog.get_logger(__name__) ++ ++ ++class McpServerOptions(BaseModel): ++ """Options for creating an MCP server. ++ ++ Attributes: ++ name: The name of the MCP server. 
++ version: The version of the server (default: "1.0.0"). ++ """ ++ ++ name: str ++ version: str = '1.0.0' ++ ++ ++class McpServer: ++ """Exposes Genkit tools, prompts, and resources as an MCP server. ++ ++ This class wraps a Genkit instance and makes its registered actions ++ (tools, prompts, resources) available to MCP clients via the Model Context Protocol. ++ """ ++ ++ def __init__(self, ai: Genkit, options: McpServerOptions): ++ """Initialize the MCP server. ++ ++ Args: ++ ai: The Genkit instance whose actions will be exposed. ++ options: Configuration options for the MCP server. ++ """ ++ self.ai = ai ++ self.options = options ++ self.server: Optional[Server] = None ++ self.actions_resolved = False ++ self.tool_actions: list[Any] = [] ++ self.prompt_actions: list[Any] = [] ++ self.resource_actions: list[Any] = [] ++ self.tool_actions_map: dict[str, Any] = {} ++ self.prompt_actions_map: dict[str, Any] = {} ++ self.resource_uri_map: dict[str, Any] = {} ++ self.resource_templates: list[tuple[str, Any]] = [] ++ ++ async def setup(self) -> None: ++ """Initialize the MCP server and register request handlers. ++ ++ This method sets up the MCP Server instance, registers all request handlers, ++ and resolves all actions from the Genkit registry. It's idempotent and can ++ be called multiple times safely. 
++ """ ++ if self.actions_resolved: ++ return ++ ++ # Create MCP Server instance ++ self.server = Server( ++ self.options.name, ++ version=self.options.version, ++ ) ++ ++ # Register request handlers using decorators ++ self.server.list_tools()(self.list_tools) ++ self.server.call_tool()(self.call_tool) ++ self.server.list_prompts()(self.list_prompts) ++ self.server.get_prompt()(self.get_prompt) ++ self.server.list_resources()(self.list_resources) ++ self.server.list_resource_templates()(self.list_resource_templates) ++ self.server.read_resource()(self.read_resource) ++ ++ # Resolve all actions from Genkit registry ++ # We need the actual Action objects, not just serializable dicts ++ self.tool_actions = [] ++ self.prompt_actions = [] ++ self.resource_actions = [] ++ ++ # Get all actions from the registry ++ # We use the internal _entries for local actions and plugins ++ with self.ai.registry._lock: ++ for kind, entries in self.ai.registry._entries.items(): ++ for name, action in entries.items(): ++ if kind == ActionKind.TOOL: ++ self.tool_actions.append(action) ++ self.tool_actions_map[action.name] = action ++ elif kind == ActionKind.PROMPT: ++ self.prompt_actions.append(action) ++ self.prompt_actions_map[action.name] = action ++ elif kind == ActionKind.RESOURCE: ++ self.resource_actions.append(action) ++ metadata = action.metadata or {} ++ resource_meta = metadata.get('resource', {}) ++ if resource_meta.get('uri'): ++ self.resource_uri_map[resource_meta['uri']] = action ++ if resource_meta.get('template'): ++ self.resource_templates.append((resource_meta['template'], action)) ++ ++ # Also get actions from plugins that might not be in _entries yet ++ # (though most plugins register them in _entries during initialization) ++ plugin_actions = self.ai.registry.list_actions() ++ for key in plugin_actions: ++ kind, name = parse_action_key(key) ++ action = self.ai.registry.lookup_action(kind, name) ++ if action: ++ if kind == ActionKind.TOOL and action not in 
self.tool_actions: ++ self.tool_actions.append(action) ++ self.tool_actions_map[action.name] = action ++ elif kind == ActionKind.PROMPT and action not in self.prompt_actions: ++ self.prompt_actions.append(action) ++ self.prompt_actions_map[action.name] = action ++ elif kind == ActionKind.RESOURCE and action not in self.resource_actions: ++ self.resource_actions.append(action) ++ metadata = action.metadata or {} ++ resource_meta = metadata.get('resource', {}) ++ if resource_meta.get('uri'): ++ self.resource_uri_map[resource_meta['uri']] = action ++ if resource_meta.get('template'): ++ self.resource_templates.append((resource_meta['template'], action)) ++ ++ self.actions_resolved = True ++ ++ logger.info( ++ f'MCP Server initialized', ++ tools=len(self.tool_actions), ++ prompts=len(self.prompt_actions), ++ resources=len(self.resource_actions), ++ ) ++ ++ async def list_tools(self, request: ListToolsRequest) -> ListToolsResult: ++ """Handle MCP requests to list available tools. ++ ++ Args: ++ request: The MCP ListToolsRequest. ++ ++ Returns: ++ ListToolsResult containing all registered Genkit tools. ++ """ ++ await self.setup() ++ ++ tools: list[Tool] = [] ++ for action in self.tool_actions: ++ # Get tool definition ++ input_schema = to_json_schema(action.input_schema) if action.input_schema else {'type': 'object'} ++ ++ tools.append( ++ Tool( ++ name=action.name, ++ description=action.description or '', ++ inputSchema=input_schema, ++ _meta=action.metadata.get('mcp', {}).get('_meta') if action.metadata else None, ++ ) ++ ) ++ ++ return ListToolsResult(tools=tools) ++ ++ async def call_tool(self, request: CallToolRequest) -> CallToolResult: ++ """Handle MCP requests to call a specific tool. ++ ++ Args: ++ request: The MCP CallToolRequest containing tool name and arguments. ++ ++ Returns: ++ CallToolResult with the tool execution result. ++ ++ Raises: ++ GenkitError: If the requested tool is not found. 
++ """ ++ await self.setup() ++ ++ # Find the tool action ++ tool = self.tool_actions_map.get(request.params.name) ++ ++ if not tool: ++ raise GenkitError( ++ status='NOT_FOUND', message=f"Tried to call tool '{request.params.name}' but it could not be found." ++ ) ++ ++ # Execute the tool ++ result = await tool.arun(request.params.arguments) ++ result = result.response ++ ++ # Convert result to MCP format ++ return CallToolResult(content=to_mcp_tool_result(result)) ++ ++ async def list_prompts(self, request: ListPromptsRequest) -> ListPromptsResult: ++ """Handle MCP requests to list available prompts. ++ ++ Args: ++ request: The MCP ListPromptsRequest. ++ ++ Returns: ++ ListPromptsResult containing all registered Genkit prompts. ++ """ ++ await self.setup() ++ ++ prompts: list[Prompt] = [] ++ for action in self.prompt_actions: ++ # Convert input schema to MCP prompt arguments ++ input_schema = to_json_schema(action.input_schema) if action.input_schema else None ++ arguments = to_mcp_prompt_arguments(input_schema) if input_schema else None ++ ++ prompts.append( ++ Prompt( ++ name=action.name, ++ description=action.description or '', ++ arguments=arguments, ++ _meta=action.metadata.get('mcp', {}).get('_meta') if action.metadata else None, ++ ) ++ ) ++ ++ return ListPromptsResult(prompts=prompts) ++ ++ async def get_prompt(self, request: GetPromptRequest) -> GetPromptResult: ++ """Handle MCP requests to get (render) a specific prompt. ++ ++ Args: ++ request: The MCP GetPromptRequest containing prompt name and arguments. ++ ++ Returns: ++ GetPromptResult with the rendered prompt messages. ++ ++ Raises: ++ GenkitError: If the requested prompt is not found. 
++ """ ++ await self.setup() ++ ++ # Find the prompt action ++ prompt = self.prompt_actions_map.get(request.params.name) ++ ++ if not prompt: ++ raise GenkitError( ++ status='NOT_FOUND', ++ message=f"[MCP Server] Tried to call prompt '{request.params.name}' but it could not be found.", ++ ) ++ ++ # Execute the prompt ++ result = await prompt.arun(request.params.arguments) ++ result = result.response ++ ++ # Convert messages to MCP format ++ messages = [to_mcp_prompt_message(msg) for msg in result.messages] ++ ++ return GetPromptResult(description=prompt.description, messages=messages) ++ ++ async def list_resources(self, request: ListResourcesRequest) -> ListResourcesResult: ++ """Handle MCP requests to list available resources with fixed URIs. ++ ++ Args: ++ request: The MCP ListResourcesRequest. ++ ++ Returns: ++ ListResourcesResult containing resources with fixed URIs. ++ """ ++ await self.setup() ++ ++ resources: list[Resource] = [] ++ for action in self.resource_actions: ++ metadata = action.metadata or {} ++ resource_meta = metadata.get('resource', {}) ++ ++ # Only include resources with fixed URIs (not templates) ++ if resource_meta.get('uri'): ++ resources.append( ++ Resource( ++ name=action.name, ++ description=action.description or '', ++ uri=resource_meta['uri'], ++ _meta=metadata.get('mcp', {}).get('_meta'), ++ ) ++ ) ++ ++ return ListResourcesResult(resources=resources) ++ ++ async def list_resource_templates(self, request: ListResourceTemplatesRequest) -> ListResourceTemplatesResult: ++ """Handle MCP requests to list available resource templates. ++ ++ Args: ++ request: The MCP ListResourceTemplatesRequest. ++ ++ Returns: ++ ListResourceTemplatesResult containing resources with URI templates. 
++ """ ++ await self.setup() ++ ++ templates: list[ResourceTemplate] = [] ++ for action in self.resource_actions: ++ metadata = action.metadata or {} ++ resource_meta = metadata.get('resource', {}) ++ ++ # Only include resources with templates ++ if resource_meta.get('template'): ++ templates.append( ++ ResourceTemplate( ++ name=action.name, ++ description=action.description or '', ++ uriTemplate=resource_meta['template'], ++ _meta=metadata.get('mcp', {}).get('_meta'), ++ ) ++ ) ++ ++ return ListResourceTemplatesResult(resourceTemplates=templates) ++ ++ async def read_resource(self, request: ReadResourceRequest) -> ReadResourceResult: ++ """Handle MCP requests to read a specific resource. ++ ++ Args: ++ request: The MCP ReadResourceRequest containing the resource URI. ++ ++ Returns: ++ ReadResourceResult with the resource content. ++ ++ Raises: ++ GenkitError: If no matching resource is found. ++ """ ++ await self.setup() ++ ++ uri = request.params.uri ++ ++ # Check for exact URI match ++ resource = self.resource_uri_map.get(uri) ++ ++ # Check for template match if not found by exact URI ++ if not resource: ++ for template, action in self.resource_templates: ++ if matches_uri_template(template, uri): ++ resource = action ++ break ++ ++ if not resource: ++ raise GenkitError(status='NOT_FOUND', message=f"Tried to call resource '{uri}' but it could not be found.") ++ ++ # Execute the resource action ++ result = await resource.arun({'uri': uri}) ++ result = result.response ++ ++ # Convert content to MCP format ++ content = result.get('content', []) if isinstance(result, dict) else result.content ++ contents = to_mcp_resource_contents(uri, content) ++ ++ return ReadResourceResult(contents=contents) ++ ++ async def start(self, transport: Any = None) -> None: ++ """Start the MCP server with the specified transport. ++ ++ Args: ++ transport: Optional MCP transport instance. If not provided, ++ a StdioServerTransport will be created and used. 
++ """ ++ await self.setup() ++ ++ if not transport: ++ async with stdio_server() as (read, write): ++ await self.server.run(read, write, self.server.create_initialization_options()) ++ else: ++ # Connect the transport ++ async with transport as (read, write): ++ await self.server.run(read, write, self.server.create_initialization_options()) ++ ++ logger.debug(f"[MCP Server] MCP server '{self.options.name}' started successfully.") ++ ++ ++# Schema types from mcp.types ++ListToolsRequestSchema = ListToolsRequest ++CallToolRequestSchema = CallToolRequest ++ListPromptsRequestSchema = ListPromptsRequest ++GetPromptRequestSchema = GetPromptRequest ++ListResourcesRequestSchema = ListResourcesRequest ++ListResourceTemplatesRequestSchema = ListResourceTemplatesRequest ++ReadResourceRequestSchema = ReadResourceRequest ++ ++ ++def create_mcp_server(ai: Genkit, options: McpServerOptions) -> McpServer: ++ """Create an MCP server based on the supplied Genkit instance. ++ ++ All tools, prompts, and resources will be automatically converted to MCP compatibility. ++ ++ Args: ++ ai: Your Genkit instance with registered tools, prompts, and resources. ++ options: Configuration metadata for the server. ++ ++ Returns: ++ GenkitMcpServer instance. 
++ ++ Example: ++ ```python ++ from genkit.ai import Genkit ++ from genkit.plugins.mcp import create_mcp_server, McpServerOptions ++ ++ ai = Genkit() ++ ++ ++ # Define some tools and resources ++ @ai.tool() ++ def add(a: int, b: int) -> int: ++ return a + b ++ ++ ++ ai.define_resource(name='my_resource', uri='my://resource', fn=lambda req: {'content': [{'text': 'resource content'}]}) ++ ++ # Create and start MCP server ++ server = create_mcp_server(ai, McpServerOptions(name='my-server')) ++ await server.start() ++ ``` ++ """ ++ return McpServer(ai, options) +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/__init__.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/__init__.py +new file mode 100644 +index 000000000..a45dae463 +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/__init__.py +@@ -0,0 +1,58 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++""" ++Utility functions for MCP plugin. 
++ ++This module contains helper functions for: ++- Tool conversion and registration ++- Prompt conversion and rendering ++- Resource handling ++- Message mapping between Genkit and MCP formats ++- Transport utilities ++""" ++ ++from .message import from_mcp_part, from_mcp_prompt_message, to_mcp_prompt_message ++from .prompts import convert_mcp_prompt_messages, convert_prompt_arguments_to_schema, to_mcp_prompt_arguments, to_schema ++from .resource import ( ++ convert_resource_to_genkit_part, ++ from_mcp_resource_part, ++ process_resource_content, ++ to_mcp_resource_contents, ++) ++from .tools import convert_tool_schema, process_result, process_tool_result, to_mcp_tool_result, to_text ++from .transport import create_stdio_params, transport_from ++ ++__all__ = [ ++ 'process_tool_result', ++ 'process_result', ++ 'to_text', ++ 'convert_tool_schema', ++ 'convert_prompt_arguments_to_schema', ++ 'convert_mcp_prompt_messages', ++ 'to_schema', ++ 'from_mcp_prompt_message', ++ 'from_mcp_part', ++ 'process_resource_content', ++ 'convert_resource_to_genkit_part', ++ 'from_mcp_resource_part', ++ 'create_stdio_params', ++ 'transport_from', ++ 'to_mcp_prompt_message', ++ 'to_mcp_resource_contents', ++ 'to_mcp_tool_result', ++ 'to_mcp_prompt_arguments', ++] +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/message.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/message.py +new file mode 100644 +index 000000000..97de8a4e9 +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/message.py +@@ -0,0 +1,169 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++""" ++Message utilities for MCP plugin. ++ ++This module contains helper functions for converting between MCP message ++formats and Genkit message formats. ++""" ++ ++from typing import Any, Dict ++ ++import structlog ++ ++from genkit.core.typing import Message ++from mcp.types import ImageContent, PromptMessage, TextContent ++ ++logger = structlog.get_logger(__name__) ++ ++# Role mapping from MCP to Genkit ++ROLE_MAP = { ++ 'user': 'user', ++ 'assistant': 'model', ++} ++ ++ ++def from_mcp_prompt_message(message: Dict[str, Any]) -> Dict[str, Any]: ++ """ ++ Convert MCP PromptMessage to Genkit MessageData format. ++ ++ This involves mapping MCP roles (user, assistant) to Genkit roles (user, model) ++ and transforming the MCP content part into a Genkit Part. ++ ++ Args: ++ message: MCP PromptMessage with 'role' and 'content' fields ++ ++ Returns: ++ Genkit MessageData object with 'role' and 'content' fields ++ """ ++ return { ++ 'role': ROLE_MAP.get(message.get('role', 'user'), 'user'), ++ 'content': [from_mcp_part(message.get('content', {}))], ++ } ++ ++ ++def from_mcp_part(part: Dict[str, Any]) -> Dict[str, Any]: ++ """ ++ Convert MCP message content part to Genkit Part. 
++ ++ Handles different content types: ++ - Text parts are directly mapped ++ - Image parts are converted to Genkit media parts with data URL ++ - Resource parts are mapped to Genkit resource format ++ ++ Args: ++ part: MCP PromptMessage content part ++ ++ Returns: ++ Genkit Part object ++ """ ++ part_type = part.get('type', '') ++ ++ if part_type == 'text': ++ return {'text': part.get('text', '')} ++ ++ elif part_type == 'image': ++ mime_type = part.get('mimeType', 'image/png') ++ data = part.get('data', '') ++ return { ++ 'media': { ++ 'contentType': mime_type, ++ 'url': f'data:{mime_type};base64,{data}', ++ } ++ } ++ ++ elif part_type == 'resource': ++ return { ++ 'resource': { ++ 'uri': str(part.get('uri', '')), ++ } ++ } ++ ++ # Default case for unknown types ++ return {} ++ ++ ++def _get_part_data(part: Any) -> Dict[str, Any]: ++ """Extract data from a Part, handling potential 'root' nesting.""" ++ if isinstance(part, str): ++ return {'text': part} ++ part_dict = part if isinstance(part, dict) else part.model_dump() ++ if 'root' in part_dict and isinstance(part_dict['root'], dict): ++ return part_dict['root'] ++ return part_dict ++ ++ ++def _parse_media_part(media: Dict[str, Any]) -> ImageContent: ++ """Extract MIME type and base64 data from a media part.""" ++ url = media.get('url', '') ++ content_type = media.get('contentType', '') ++ ++ if not url.startswith('data:'): ++ raise ValueError('MCP prompt messages only support base64 data images.') ++ ++ # Extract MIME type and base64 data ++ try: ++ mime_type = content_type or url[url.index(':') + 1 : url.index(';')] ++ data = url[url.index(',') + 1 :] ++ except ValueError as e: ++ raise ValueError(f'Invalid data URL format: {url}') from e ++ ++ return ImageContent(type='image', data=data, mimeType=mime_type) ++ ++ ++def to_mcp_prompt_message(message: Message) -> PromptMessage: ++ """Convert a Genkit Message to an MCP PromptMessage. ++ ++ MCP only supports 'user' and 'assistant' roles. 
Genkit's 'model' role ++ is mapped to 'assistant'. ++ ++ Args: ++ message: The Genkit Message to convert. ++ ++ Returns: ++ An MCP PromptMessage. ++ ++ Raises: ++ ValueError: If the message role is not 'user' or 'model'. ++ ValueError: If media is not a base64 data URL. ++ """ ++ # Map Genkit roles to MCP roles ++ role_map = {'model': 'assistant', 'user': 'user'} ++ ++ if message.role not in role_map: ++ raise ValueError( ++ f"MCP prompt messages do not support role '{message.role}'. Only 'user' and 'model' messages are supported." ++ ) ++ ++ mcp_role = role_map[message.role] ++ ++ # First, look for any media content as MCP content is currently single-part ++ if message.content: ++ for part in message.content: ++ data = _get_part_data(part) ++ if data.get('media'): ++ return PromptMessage(role=mcp_role, content=_parse_media_part(data['media'])) ++ ++ # If no media, aggregate all text content ++ text_content = [] ++ if message.content: ++ for part in message.content: ++ data = _get_part_data(part) ++ if data.get('text'): ++ text_content.append(data['text']) ++ ++ return PromptMessage(role=mcp_role, content=TextContent(type='text', text=''.join(text_content))) +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/prompts.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/prompts.py +new file mode 100644 +index 000000000..469e91f7e +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/prompts.py +@@ -0,0 +1,137 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++""" ++Prompt utilities for MCP plugin. ++ ++This module contains helper functions for converting between MCP prompts ++and Genkit prompts, including schema and message conversion. ++""" ++ ++from typing import Any, Dict, List, Optional ++ ++import structlog ++ ++from mcp.types import GetPromptResult, Prompt ++ ++logger = structlog.get_logger(__name__) ++ ++ ++def to_schema(arguments: Optional[List[Dict[str, Any]]]) -> Dict[str, Any]: ++ """ ++ Convert MCP prompt arguments to JSON schema format. ++ ++ Args: ++ arguments: List of MCP prompt argument definitions with 'name', ++ 'description', and 'required' fields ++ ++ Returns: ++ JSON schema representing the prompt arguments ++ """ ++ if not arguments: ++ return {} ++ ++ schema: Dict[str, Any] = {'type': 'object', 'properties': {}, 'required': []} ++ ++ for arg in arguments: ++ arg_name = arg.get('name', '') ++ schema['properties'][arg_name] = { ++ 'type': 'string', ++ 'description': arg.get('description', ''), ++ } ++ if arg.get('required', False): ++ schema['required'].append(arg_name) ++ ++ return schema ++ ++ ++def convert_prompt_arguments_to_schema(arguments: List[Any]) -> Dict[str, Any]: ++ """ ++ Convert MCP prompt arguments to JSON schema format. ++ ++ This is an alias for to_schema() for backwards compatibility. ++ ++ Args: ++ arguments: List of MCP prompt argument definitions ++ ++ Returns: ++ JSON schema representing the prompt arguments ++ """ ++ return to_schema(arguments) ++ ++ ++def convert_mcp_prompt_messages(prompt_result: GetPromptResult) -> List[Dict[str, Any]]: ++ """ ++ Convert MCP prompt messages to Genkit message format. 
++ ++ Args: ++ prompt_result: The GetPromptResult from MCP server containing messages ++ ++ Returns: ++ List of Genkit-formatted messages ++ """ ++ from .message import from_mcp_prompt_message ++ ++ if not hasattr(prompt_result, 'messages') or not prompt_result.messages: ++ return [] ++ ++ return [from_mcp_prompt_message(msg) for msg in prompt_result.messages] ++ ++ ++def to_mcp_prompt_arguments(input_schema: dict[str, Any] | None) -> list[dict[str, Any]] | None: ++ """Convert Genkit input schema to MCP prompt arguments. ++ ++ MCP prompts only support string arguments. This function validates that ++ all properties in the schema are strings. ++ ++ Args: ++ input_schema: The Genkit input JSON schema. ++ ++ Returns: ++ List of MCP prompt argument definitions, or None if no schema. ++ ++ Raises: ++ ValueError: If the schema is not an object type. ++ ValueError: If any property is not a string type. ++ """ ++ if not input_schema: ++ return None ++ ++ if not input_schema.get('properties'): ++ raise ValueError('MCP prompts must take objects with properties as input schema.') ++ ++ args: list[dict[str, Any]] = [] ++ properties = input_schema['properties'] ++ required = input_schema.get('required', []) ++ ++ for name, prop in properties.items(): ++ prop_type = prop.get('type') ++ ++ # Check if type is string or includes string (for union types) ++ is_string = prop_type == 'string' or (isinstance(prop_type, list) and 'string' in prop_type) ++ ++ if not is_string: ++ raise ValueError( ++ f"MCP prompts may only take string arguments, but property '{name}' has type '{prop_type}'." 
++ ) ++ ++ args.append({ ++ 'name': name, ++ 'description': prop.get('description'), ++ 'required': name in required, ++ }) ++ ++ return args +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py +new file mode 100644 +index 000000000..3015d609d +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py +@@ -0,0 +1,149 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++""" ++Resource utilities for MCP plugin. ++ ++This module contains helper functions for handling MCP resources, ++including reading and converting resource content. ++""" ++ ++from typing import Any, Dict ++ ++import structlog ++ ++from genkit.core.typing import Part ++from mcp.types import BlobResourceContents, ReadResourceResult, Resource, TextResourceContents ++ ++logger = structlog.get_logger(__name__) ++ ++ ++def from_mcp_resource_part(content: Dict[str, Any]) -> Dict[str, Any]: ++ """ ++ Convert MCP resource content to Genkit Part format. 
++ ++ Handles different content types: ++ - Text content is mapped to text part ++ - Blob content is mapped to media part with base64 data ++ ++ Args: ++ content: MCP resource content part ++ ++ Returns: ++ Genkit Part representation ++ """ ++ content_type = content.get('type', '') ++ ++ if content_type == 'text': ++ return {'text': content.get('text', '')} ++ ++ elif content_type == 'blob': ++ mime_type = content.get('mimeType', 'application/octet-stream') ++ blob_data = content.get('blob', '') ++ return { ++ 'media': { ++ 'contentType': mime_type, ++ 'url': f'data:{mime_type};base64,{blob_data}', ++ } ++ } ++ ++ # Default case ++ return {'text': str(content)} ++ ++ ++def process_resource_content(resource_result: ReadResourceResult) -> Any: ++ """ ++ Process MCP ReadResourceResult and extract content. ++ ++ Args: ++ resource_result: The ReadResourceResult from MCP server ++ ++ Returns: ++ Extracted resource content as Genkit Parts ++ """ ++ if not hasattr(resource_result, 'contents') or not resource_result.contents: ++ return [] ++ ++ return [from_mcp_resource_part(content) for content in resource_result.contents] ++ ++ ++def convert_resource_to_genkit_part(resource: Resource) -> dict[str, Any]: ++ """ ++ Convert MCP resource to Genkit Part format. ++ ++ Args: ++ resource: MCP resource object ++ ++ Returns: ++ Genkit Part representation with resource URI ++ """ ++ return { ++ 'resource': { ++ 'uri': resource.uri, ++ 'name': resource.name, ++ 'description': resource.description if hasattr(resource, 'description') else None, ++ } ++ } ++ ++ ++def to_mcp_resource_contents(uri: str, parts: list[Part]) -> list[TextResourceContents | BlobResourceContents]: ++ """Convert Genkit Parts to MCP resource contents. ++ ++ Args: ++ uri: The URI of the resource. ++ parts: List of Genkit Parts to convert. ++ ++ Returns: ++ List of MCP resource contents (text or blob). ++ ++ Raises: ++ ValueError: If media is not a base64 data URL. ++ ValueError: If part type is not supported. 
++ """ ++ contents: list[TextResourceContents | BlobResourceContents] = [] ++ ++ for part in parts: ++ if isinstance(part, dict): ++ # Handle media/image content ++ if 'media' in part: ++ media = part['media'] ++ url = media.get('url', '') ++ content_type = media.get('contentType', '') ++ ++ if not url.startswith('data:'): ++ raise ValueError('MCP resource messages only support base64 data images.') ++ ++ # Extract MIME type and base64 data ++ try: ++ mime_type = content_type or url[url.index(':') + 1 : url.index(';')] ++ blob_data = url[url.index(',') + 1 :] ++ except ValueError as e: ++ raise ValueError(f'Invalid data URL format: {url}') from e ++ ++ contents.append(BlobResourceContents(uri=uri, mimeType=mime_type, blob=blob_data)) ++ ++ # Handle text content ++ elif 'text' in part: ++ contents.append(TextResourceContents(uri=uri, text=part['text'])) ++ else: ++ raise ValueError( ++ f'MCP resource messages only support media and text parts. ' ++ f'Unsupported part type: {list(part.keys())}' ++ ) ++ elif isinstance(part, str): ++ contents.append(TextResourceContents(uri=uri, text=part)) ++ ++ return contents +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/tools.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/tools.py +new file mode 100644 +index 000000000..5d2662c02 +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/tools.py +@@ -0,0 +1,144 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++""" ++Tool utilities for MCP plugin. ++ ++This module contains helper functions for converting between MCP tools ++and Genkit actions, processing tool results, and registering tools. ++""" ++ ++import json ++from typing import Any, Dict, List, Union ++ ++import structlog ++ ++from mcp.types import CallToolResult, ImageContent, TextContent, Tool ++ ++logger = structlog.get_logger(__name__) ++ ++ ++def to_text(content: List[Dict[str, Any]]) -> str: ++ """ ++ Extract text from MCP CallToolResult content. ++ ++ Args: ++ content: List of content parts from CallToolResult ++ ++ Returns: ++ Concatenated text from all text parts ++ """ ++ return ''.join(part.get('text', '') for part in content) ++ ++ ++def process_result(result: CallToolResult) -> Any: ++ """ ++ Process MCP CallToolResult and extract/parse content. ++ ++ Handles different result types: ++ - Error results return error dict ++ - Text-only results attempt JSON parsing ++ - Single content results return the content directly ++ - Otherwise returns the full result ++ ++ Args: ++ result: The CallToolResult from MCP server ++ ++ Returns: ++ Processed result (parsed JSON, text, or raw content) ++ ++ Raises: ++ RuntimeError: If the tool execution failed (isError=True) ++ """ ++ if result.isError: ++ return {'error': to_text(result.content)} ++ ++ # Check if all content parts are text ++ if all(hasattr(c, 'text') and c.text for c in result.content): ++ text = to_text(result.content) ++ # Try to parse as JSON if it looks like JSON ++ text_stripped = text.strip() ++ if text_stripped.startswith('{') or text_stripped.startswith('['): ++ try: ++ return json.loads(text) ++ except (json.JSONDecodeError, ValueError): ++ return text ++ return text ++ ++ # Single content item ++ if len(result.content) == 1: ++ return result.content[0] ++ ++ # Return full result for complex cases ++ return result ++ ++ ++def process_tool_result(result: CallToolResult) -> Any: ++ """ ++ Process 
MCP CallToolResult and extract content. ++ ++ This is an alias for process_result() for backwards compatibility. ++ ++ Args: ++ result: The CallToolResult from MCP server ++ ++ Returns: ++ Extracted text content from the result ++ ++ Raises: ++ RuntimeError: If the tool execution failed ++ """ ++ return process_result(result) ++ ++ ++def convert_tool_schema(mcp_schema: Dict[str, Any]) -> Dict[str, Any]: ++ """ ++ Convert MCP tool input schema (JSONSchema7) to Genkit format. ++ ++ Args: ++ mcp_schema: MCP tool input schema ++ ++ Returns: ++ Genkit-compatible JSON schema ++ ++ Note: ++ Currently returns the schema as-is since both use JSON Schema. ++ Future enhancements may add validation or transformation. ++ """ ++ # MCP and Genkit both use JSON Schema, so minimal conversion needed ++ return mcp_schema ++ ++ ++def to_mcp_tool_result(result: Any) -> list[TextContent | ImageContent]: ++ """Convert tool execution result to MCP CallToolResult content. ++ ++ Args: ++ result: The result from tool execution (can be string, dict, or other). ++ ++ Returns: ++ List of MCP content items (TextContent or ImageContent). 
++ """ ++ if isinstance(result, str): ++ return [TextContent(type='text', text=result)] ++ elif isinstance(result, dict): ++ # If it's already in MCP format, return as-is ++ if 'type' in result and 'text' in result: ++ return [TextContent(type='text', text=result['text'])] ++ # Otherwise, serialize to JSON ++ return [TextContent(type='text', text=json.dumps(result))] ++ else: ++ # Convert to string for other types ++ return [TextContent(type='text', text=str(result))] +diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/transport.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/transport.py +new file mode 100644 +index 000000000..c065cd5d0 +--- /dev/null ++++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/transport.py +@@ -0,0 +1,89 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++""" ++Transport utilities for MCP plugin. ++ ++This module contains helper functions for creating and managing ++MCP transport connections (stdio, SSE, custom). ++""" ++ ++from typing import Any, Dict, Optional, Tuple ++ ++import structlog ++ ++from mcp import StdioServerParameters ++ ++logger = structlog.get_logger(__name__) ++ ++ ++def create_stdio_params( ++ command: str, args: Optional[list] = None, env: Optional[Dict[str, str]] = None ++) -> StdioServerParameters: ++ """ ++ Create StdioServerParameters for MCP connection. 
++ ++ Args: ++ command: Command to execute ++ args: Command arguments ++ env: Environment variables ++ ++ Returns: ++ StdioServerParameters object ++ """ ++ return StdioServerParameters(command=command, args=args or [], env=env) ++ ++ ++async def transport_from(config: Dict[str, Any], session_id: Optional[str] = None) -> Tuple[Any, str]: ++ """ ++ Create an MCP transport instance based on the provided server configuration. ++ ++ Supports creating SSE, Stdio, or using a pre-configured custom transport. ++ ++ Args: ++ config: Configuration for the MCP server ++ session_id: Optional session ID for HTTP transport ++ ++ Returns: ++ Tuple of (transport instance or None, transport type string) ++ ++ Note: ++ This function mirrors the JS SDK's transportFrom() function. ++ """ ++ # Handle pre-configured transport first ++ if 'transport' in config and config['transport']: ++ return (config['transport'], 'custom') ++ ++ # Handle SSE/HTTP config ++ if 'url' in config and config['url']: ++ try: ++ # Dynamic import to avoid hard dependency ++ from mcp.client.sse import sse_client ++ ++ # Note: Python MCP SDK may have different SSE client API ++ # This is a placeholder that matches the pattern ++ logger.info(f'Creating SSE transport for URL: {config["url"]}') ++ return (config['url'], 'http') # Simplified for now ++ except ImportError: ++ logger.warning('SSE client not available') ++ return (None, 'http') ++ ++ # Handle Stdio config ++ if 'command' in config and config['command']: ++ stdio_params = create_stdio_params(command=config['command'], args=config.get('args'), env=config.get('env')) ++ return (stdio_params, 'stdio') ++ ++ return (None, 'unknown') +diff --git a/py/plugins/mcp/tests/fakes.py b/py/plugins/mcp/tests/fakes.py +new file mode 100644 +index 000000000..356337f11 +--- /dev/null ++++ b/py/plugins/mcp/tests/fakes.py +@@ -0,0 +1,128 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file 
except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++import asyncio ++import json ++import sys ++from typing import Any, Callable, Dict, List, Optional ++from unittest.mock import MagicMock ++ ++from genkit.ai import Genkit ++from genkit.core.action.types import ActionKind ++ ++ ++class MockSchema: ++ def __init__(self, **kwargs): ++ for k, v in kwargs.items(): ++ setattr(self, k, v) ++ ++ ++def mock_mcp_modules(): ++ """Sets up comprehensive MCP mocks in sys.modules.""" ++ mock_mcp = MagicMock() ++ sys.modules['mcp'] = mock_mcp ++ sys.modules['mcp'].__path__ = [] ++ ++ types_mock = MagicMock() ++ sys.modules['mcp.types'] = types_mock ++ types_mock.ListToolsResult = MockSchema ++ types_mock.CallToolResult = MockSchema ++ types_mock.ListPromptsResult = MockSchema ++ types_mock.GetPromptResult = MockSchema ++ types_mock.ListResourcesResult = MockSchema ++ types_mock.ListResourceTemplatesResult = MockSchema ++ types_mock.ReadResourceResult = MockSchema ++ types_mock.Tool = MockSchema ++ types_mock.Prompt = MockSchema ++ types_mock.Resource = MockSchema ++ types_mock.ResourceTemplate = MockSchema ++ types_mock.TextContent = MockSchema ++ types_mock.PromptMessage = MockSchema ++ types_mock.TextResourceContents = MockSchema ++ types_mock.BlobResourceContents = MockSchema ++ types_mock.ImageContent = MockSchema ++ ++ sys.modules['mcp.server'] = MagicMock() ++ sys.modules['mcp.server.stdio'] = MagicMock() ++ sys.modules['mcp.client'] = MagicMock() ++ sys.modules['mcp.client'].__path__ = [] ++ 
sys.modules['mcp.client.stdio'] = MagicMock() ++ sys.modules['mcp.client.sse'] = MagicMock() ++ sys.modules['mcp.server.sse'] = MagicMock() ++ ++ return mock_mcp, types_mock ++ ++ ++def define_echo_model(ai: Genkit): ++ """Defines a fake echo model for testing.""" ++ ++ @ai.tool(name='echoModel') ++ def echo_model(request: Any): ++ # This is a simplified mock of a model action ++ # Real model action would handle GenerateRequest and return GenerateResponse ++ ++ # logic to echo content ++ # For now, just a placeholder as we generally mock the model execution in tests ++ pass ++ ++ # In real usage, we would define a Model action properly. ++ # For unit tests here, we might not strictly need the full model implementation ++ # if we are mocking the generation or call. ++ # But matching JS behavior: ++ # JS defines 'echoModel' which returns "Echo: " + input. ++ ++ # We can use ai.define_model if available or just mock it. ++ pass ++ ++ ++class FakeTransport: ++ """Fakes an MCP transport/server for testing.""" ++ ++ def __init__(self): ++ self.tools = [] ++ self.prompts = [] ++ self.resources = [] ++ self.resource_templates = [] ++ self.call_tool_result = None ++ self.get_prompt_result = None ++ self.read_resource_result = None ++ self.roots = [] ++ ++ # Callbacks that would simulate transport behavior ++ self.on_message = None ++ self.on_close = None ++ self.on_error = None ++ ++ async def start(self): ++ pass ++ ++ async def send(self, message: Dict[str, Any]): ++ """Handle incoming JSON-RPC message (simulating server).""" ++ request = message ++ # msg_id = request.get("id") ++ ++ # In a real transport we'd write back to the stream. ++ # Here we just store handling logic or print. ++ # Since we are mocking the ClientSession in our python tests, ++ # this logic might need to be hooked up to the mock session's methods. 
++ pass ++ ++ # Helper methods to populate the fake state ++ def add_tool(self, name: str, description: str = '', schema: Dict = None): ++ self.tools.append({'name': name, 'description': description, 'inputSchema': schema or {'type': 'object'}}) ++ ++ def add_prompt(self, name: str, description: str = '', arguments: List = None): ++ self.prompts.append({'name': name, 'description': description, 'arguments': arguments or []}) +diff --git a/py/plugins/mcp/tests/test_mcp_conversion.py b/py/plugins/mcp/tests/test_mcp_conversion.py +new file mode 100644 +index 000000000..926f94b69 +--- /dev/null ++++ b/py/plugins/mcp/tests/test_mcp_conversion.py +@@ -0,0 +1,259 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""Tests for MCP conversion utilities.""" ++ ++import os ++import sys ++import unittest ++ ++sys.path.insert(0, os.path.dirname(__file__)) ++sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) ++from fakes import mock_mcp_modules ++ ++mock_mcp_modules() ++ ++from genkit.core.typing import Message ++from genkit.plugins.mcp.util import ( ++ to_mcp_prompt_arguments, ++ to_mcp_prompt_message, ++ to_mcp_resource_contents, ++ to_mcp_tool_result, ++) ++ ++ ++class TestMessageConversion(unittest.TestCase): ++ """Tests for message conversion utilities.""" ++ ++ def test_convert_user_message(self): ++ """Test converting a user message.""" ++ message = Message(role='user', content=[{'text': 'Hello, world!'}]) ++ ++ result = to_mcp_prompt_message(message) ++ ++ self.assertEqual(result.role, 'user') ++ self.assertEqual(result.content.type, 'text') ++ self.assertEqual(result.content.text, 'Hello, world!') ++ ++ def test_convert_model_message(self): ++ """Test converting a model message (maps to assistant).""" ++ message = Message(role='model', content=[{'text': 'Hi there!'}]) ++ ++ result = to_mcp_prompt_message(message) ++ ++ self.assertEqual(result.role, 'assistant') ++ self.assertEqual(result.content.type, 'text') ++ self.assertEqual(result.content.text, 'Hi there!') ++ ++ def test_convert_message_with_multiple_text_parts(self): ++ """Test converting a message with multiple text parts.""" ++ message = Message(role='user', content=[{'text': 'Part 1 '}, {'text': 'Part 2 '}, {'text': 'Part 3'}]) ++ ++ result = to_mcp_prompt_message(message) ++ ++ self.assertEqual(result.content.text, 'Part 1 Part 2 Part 3') ++ ++ def test_convert_message_with_invalid_role(self): ++ """Test that converting a message with invalid role raises error.""" ++ message = Message(role='system', content=[{'text': 'System message'}]) ++ ++ with self.assertRaises(ValueError) as context: ++ to_mcp_prompt_message(message) ++ ++ 
self.assertIn('system', str(context.exception).lower()) ++ ++ def test_convert_message_with_image(self): ++ """Test converting a message with image content.""" ++ message = Message( ++ role='user', content=[{'media': {'url': 'data:image/png;base64,iVBORw0KG...', 'contentType': 'image/png'}}] ++ ) ++ ++ result = to_mcp_prompt_message(message) ++ ++ self.assertEqual(result.role, 'user') ++ self.assertEqual(result.content.type, 'image') ++ self.assertEqual(result.content.mimeType, 'image/png') ++ ++ def test_convert_message_with_non_data_url_fails(self): ++ """Test that non-data URLs raise an error.""" ++ message = Message(role='user', content=[{'media': {'url': 'http://example.com/image.png'}}]) ++ ++ with self.assertRaises(ValueError) as context: ++ to_mcp_prompt_message(message) ++ ++ self.assertIn('base64', str(context.exception).lower()) ++ ++ ++class TestResourceConversion(unittest.TestCase): ++ """Tests for resource content conversion.""" ++ ++ def test_convert_text_resource(self): ++ """Test converting text resource content.""" ++ parts = [{'text': 'Resource content'}] ++ ++ result = to_mcp_resource_contents('test://resource', parts) ++ ++ self.assertEqual(len(result), 1) ++ self.assertEqual(result[0].uri, 'test://resource') ++ self.assertEqual(result[0].text, 'Resource content') ++ ++ def test_convert_multiple_text_parts(self): ++ """Test converting multiple text parts.""" ++ parts = [{'text': 'Part 1'}, {'text': 'Part 2'}, {'text': 'Part 3'}] ++ ++ result = to_mcp_resource_contents('test://resource', parts) ++ ++ self.assertEqual(len(result), 3) ++ for i, part in enumerate(result, 1): ++ self.assertEqual(part.text, f'Part {i}') ++ ++ def test_convert_string_parts(self): ++ """Test converting string parts.""" ++ parts = ['Text 1', 'Text 2'] ++ ++ result = to_mcp_resource_contents('test://resource', parts) ++ ++ self.assertEqual(len(result), 2) ++ self.assertEqual(result[0].text, 'Text 1') ++ self.assertEqual(result[1].text, 'Text 2') ++ ++ def 
test_convert_media_resource(self): ++ """Test converting media resource content.""" ++ parts = [{'media': {'url': 'data:image/png;base64,abc123', 'contentType': 'image/png'}}] ++ ++ result = to_mcp_resource_contents('test://image', parts) ++ ++ self.assertEqual(len(result), 1) ++ self.assertEqual(result[0].uri, 'test://image') ++ self.assertEqual(result[0].mimeType, 'image/png') ++ self.assertEqual(result[0].blob, 'abc123') ++ ++ def test_convert_mixed_content(self): ++ """Test converting mixed text and media content.""" ++ parts = [{'text': 'Description'}, {'media': {'url': 'data:image/png;base64,xyz', 'contentType': 'image/png'}}] ++ ++ result = to_mcp_resource_contents('test://mixed', parts) ++ ++ self.assertEqual(len(result), 2) ++ self.assertEqual(result[0].text, 'Description') ++ self.assertEqual(result[1].blob, 'xyz') ++ ++ ++class TestToolResultConversion(unittest.TestCase): ++ """Tests for tool result conversion.""" ++ ++ def test_convert_string_result(self): ++ """Test converting string result.""" ++ result = to_mcp_tool_result('Hello, world!') ++ ++ self.assertEqual(len(result), 1) ++ self.assertEqual(result[0].type, 'text') ++ self.assertEqual(result[0].text, 'Hello, world!') ++ ++ def test_convert_dict_result(self): ++ """Test converting dict result.""" ++ result = to_mcp_tool_result({'key': 'value', 'number': 42}) ++ ++ self.assertEqual(len(result), 1) ++ self.assertEqual(result[0].type, 'text') ++ # Should be JSON serialized ++ import json ++ ++ parsed = json.loads(result[0].text) ++ self.assertEqual(parsed['key'], 'value') ++ self.assertEqual(parsed['number'], 42) ++ ++ def test_convert_number_result(self): ++ """Test converting number result.""" ++ result = to_mcp_tool_result(42) ++ ++ self.assertEqual(len(result), 1) ++ self.assertEqual(result[0].text, '42') ++ ++ def test_convert_boolean_result(self): ++ """Test converting boolean result.""" ++ result = to_mcp_tool_result(True) ++ ++ self.assertEqual(len(result), 1) ++ 
self.assertEqual(result[0].text, 'True') ++ ++ ++class TestSchemaConversion(unittest.TestCase): ++ """Tests for schema conversion utilities.""" ++ ++ def test_convert_simple_schema(self): ++ """Test converting simple string schema.""" ++ schema = {'type': 'object', 'properties': {'name': {'type': 'string', 'description': 'User name'}}} ++ ++ result = to_mcp_prompt_arguments(schema) ++ ++ self.assertIsNotNone(result) ++ self.assertEqual(len(result), 1) ++ self.assertEqual(result[0]['name'], 'name') ++ self.assertEqual(result[0]['description'], 'User name') ++ ++ def test_convert_schema_with_required(self): ++ """Test converting schema with required fields.""" ++ schema = { ++ 'type': 'object', ++ 'properties': {'name': {'type': 'string'}, 'age': {'type': 'string'}}, ++ 'required': ['name'], ++ } ++ ++ result = to_mcp_prompt_arguments(schema) ++ ++ name_arg = next(arg for arg in result if arg['name'] == 'name') ++ age_arg = next(arg for arg in result if arg['name'] == 'age') ++ ++ self.assertTrue(name_arg['required']) ++ self.assertFalse(age_arg['required']) ++ ++ def test_convert_schema_with_non_string_fails(self): ++ """Test that non-string properties raise an error.""" ++ schema = {'type': 'object', 'properties': {'count': {'type': 'number'}}} ++ ++ with self.assertRaises(ValueError) as context: ++ to_mcp_prompt_arguments(schema) ++ ++ self.assertIn('string', str(context.exception).lower()) ++ ++ def test_convert_schema_with_union_type(self): ++ """Test converting schema with union type including string.""" ++ schema = {'type': 'object', 'properties': {'value': {'type': ['string', 'null']}}} ++ ++ result = to_mcp_prompt_arguments(schema) ++ ++ # Should succeed because string is in the union ++ self.assertEqual(len(result), 1) ++ ++ def test_convert_none_schema(self): ++ """Test converting None schema.""" ++ result = to_mcp_prompt_arguments(None) ++ ++ self.assertIsNone(result) ++ ++ def test_convert_schema_without_properties_fails(self): ++ """Test that schema 
without properties raises an error.""" ++ schema = {'type': 'object'} ++ ++ with self.assertRaises(ValueError) as context: ++ to_mcp_prompt_arguments(schema) ++ ++ self.assertIn('properties', str(context.exception).lower()) ++ ++ ++if __name__ == '__main__': ++ unittest.main() +diff --git a/py/plugins/mcp/tests/test_mcp_host.py b/py/plugins/mcp/tests/test_mcp_host.py +new file mode 100644 +index 000000000..10d995b7d +--- /dev/null ++++ b/py/plugins/mcp/tests/test_mcp_host.py +@@ -0,0 +1,64 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++import os ++import sys ++from unittest.mock import AsyncMock, MagicMock ++ ++sys.path.insert(0, os.path.dirname(__file__)) ++sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) ++from fakes import mock_mcp_modules ++ ++mock_mcp_modules() ++ ++import unittest ++from unittest.mock import patch ++ ++from genkit.ai import Genkit ++from genkit.core.action.types import ActionKind ++ ++# Now import plugin ++from genkit.plugins.mcp import McpClient, McpHost, McpServerConfig, create_mcp_host ++ ++ ++class TestMcpHost(unittest.IsolatedAsyncioTestCase): ++ async def test_connect_and_register(self): ++ # Setup configs ++ config1 = McpServerConfig(command='echo') ++ config2 = McpServerConfig(url='http://localhost:8000') ++ ++ host = create_mcp_host({'server1': config1, 'server2': config2}) ++ ++ # Mock clients within host ++ with patch('genkit.plugins.mcp.client.client.McpClient.connect', new_callable=AsyncMock) as mock_connect: ++ await host.start() ++ self.assertEqual(mock_connect.call_count, 2) ++ ++ # Mock session for registration ++ host.clients['server1'].session = AsyncMock() ++ mock_tool = MagicMock() ++ mock_tool.name = 'tool1' ++ host.clients['server1'].session.list_tools.return_value.tools = [mock_tool] ++ ++ ai = MagicMock(spec=Genkit) ++ ai.registry = MagicMock() ++ ++ await host.register_tools(ai) ++ ++ # Verify tool registration ++ ai.registry.register_action.assert_called() ++ call_args = ai.registry.register_action.call_args[1] ++ self.assertIn('server1/tool1', call_args['name']) +diff --git a/py/plugins/mcp/tests/test_mcp_integration.py b/py/plugins/mcp/tests/test_mcp_integration.py +new file mode 100644 +index 000000000..d6045734e +--- /dev/null ++++ b/py/plugins/mcp/tests/test_mcp_integration.py +@@ -0,0 +1,311 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

"""Integration tests for MCP client-server communication."""

import asyncio
import os
import sys
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
from fakes import mock_mcp_modules

# Install fake 'mcp' modules before the plugin is imported.
mock_mcp_modules()

import pytest

from genkit.ai import Genkit
from genkit.plugins.mcp import McpClient, McpHost, McpServerConfig, create_mcp_host, create_mcp_server


# NOTE(review): the former @pytest.mark.asyncio markers were removed — that marker
# only applies to pytest-native async test functions; IsolatedAsyncioTestCase
# subclasses drive their own event loop and the marker had no effect.
class TestClientServerIntegration(unittest.IsolatedAsyncioTestCase):
    """Integration tests for MCP client-server communication."""

    async def test_client_can_list_server_tools(self):
        """Test that a client can list tools from a server."""
        # Create server with tools (documents the server side of the contract).
        server_ai = Genkit()

        @server_ai.tool()
        def add(a: int, b: int) -> int:
            return a + b

        # Create client
        client = McpClient(name='test-client', config=McpServerConfig(command='echo', args=['test']))

        # Mock the session to return tools
        mock_session = AsyncMock()
        mock_tool = MagicMock()
        mock_tool.name = 'add'
        mock_tool.description = 'Add two numbers'
        mock_tool.inputSchema = {'type': 'object'}

        mock_session.list_tools.return_value.tools = [mock_tool]
        client.session = mock_session

        # List tools
        tools = await client.list_tools()

        # Verify
        self.assertEqual(len(tools), 1)
        self.assertEqual(tools[0].name, 'add')

    async def test_client_can_call_server_tool(self):
        """Test that a client can call a tool on a server."""
        client = McpClient(name='test-client', config=McpServerConfig(command='echo'))

        # Mock the session: one non-error text content part '8'.
        mock_session = AsyncMock()
        mock_result = MagicMock()
        mock_result.isError = False
        mock_content = MagicMock()
        mock_content.type = 'text'
        mock_content.text = '8'
        mock_result.content = [mock_content]

        mock_session.call_tool.return_value = mock_result
        client.session = mock_session

        # Call tool
        result = await client.call_tool('add', {'a': 5, 'b': 3})

        # Verify
        self.assertEqual(result, '8')
        mock_session.call_tool.assert_called_once_with('add', {'a': 5, 'b': 3})

    async def test_client_can_list_server_resources(self):
        """Test that a client can list resources from a server."""
        client = McpClient(name='test-client', config=McpServerConfig(command='echo'))

        # Mock the session
        mock_session = AsyncMock()
        mock_resource = MagicMock()
        mock_resource.name = 'config'
        mock_resource.uri = 'app://config'
        mock_resource.description = 'Configuration'

        mock_session.list_resources.return_value.resources = [mock_resource]
        client.session = mock_session

        # List resources
        resources = await client.list_resources()

        # Verify
        self.assertEqual(len(resources), 1)
        self.assertEqual(resources[0].name, 'config')
        self.assertEqual(resources[0].uri, 'app://config')

    async def test_client_can_read_server_resource(self):
        """Test that a client can read a resource from a server."""
        client = McpClient(name='test-client', config=McpServerConfig(command='echo'))

        # Mock the session
        mock_session = AsyncMock()
        mock_result = MagicMock()
        mock_result.contents = [MagicMock(text='Resource content')]

        mock_session.read_resource.return_value = mock_result
        client.session = mock_session

        # Read resource
        result = await client.read_resource('app://config')

        # Verify
        self.assertIsNotNone(result)
        mock_session.read_resource.assert_called_once_with('app://config')

    async def test_host_manages_multiple_clients(self):
        """Test that a host can manage multiple clients."""
        config1 = McpServerConfig(command='server1')
        config2 = McpServerConfig(command='server2')

        host = create_mcp_host({'server1': config1, 'server2': config2})

        # Verify clients were created, keyed by server name.
        self.assertEqual(len(host.clients), 2)
        self.assertIn('server1', host.clients)
        self.assertIn('server2', host.clients)

    async def test_host_can_register_tools_from_multiple_servers(self):
        """Test that a host can register tools from multiple servers."""
        host = create_mcp_host({'server1': McpServerConfig(command='s1'), 'server2': McpServerConfig(command='s2')})

        # Mock sessions for both clients
        for client_name, client in host.clients.items():
            mock_session = AsyncMock()
            mock_tool = MagicMock()
            mock_tool.name = f'{client_name}_tool'
            mock_tool.description = f'Tool from {client_name}'
            mock_tool.inputSchema = {'type': 'object'}

            mock_session.list_tools.return_value.tools = [mock_tool]
            client.session = mock_session

        # Register tools
        ai = Genkit()
        await host.register_tools(ai)

        # FIX(review): the original test performed no verification. At minimum,
        # registration must have queried every client's session for its tools.
        # Registry-level naming (server-prefixed tool names) is covered by
        # test_mcp_host.py's TestMcpHost.
        for client in host.clients.values():
            client.session.list_tools.assert_called()

    async def test_client_handles_disabled_server(self):
        """Test that a client handles disabled servers correctly."""
        config = McpServerConfig(command='echo', disabled=True)
        client = McpClient(name='test-client', config=config)

        # Try to connect
        await client.connect()

        # A disabled server must not acquire a session.
        self.assertIsNone(client.session)

    async def test_host_can_disable_and_enable_clients(self):
        """Test that a host can disable and enable clients."""
        host = create_mcp_host({'test': McpServerConfig(command='echo')})

        # Mock the client so disable/enable do not touch a real transport.
        client = host.clients['test']
        client.session = AsyncMock()
        client.close = AsyncMock()
        client.connect = AsyncMock()

        # Disable
        await host.disable('test')
        self.assertTrue(client.config.disabled)

        # Enable
        await host.enable('test')
        self.assertFalse(client.config.disabled)


class TestResourceIntegration(unittest.IsolatedAsyncioTestCase):
    """Integration tests specifically for resource handling."""

    async def test_end_to_end_resource_flow(self):
        """Test complete flow: define resource → expose via server → consume via client."""
        # This is a conceptual test showing the flow.
        # In practice, we'd need actual MCP transport for true end-to-end.

        # 1. Server side: Define resource
        server_ai = Genkit()
        server_ai.define_resource(
            name='config', uri='app://config', fn=lambda req: {'content': [{'text': 'config data'}]}
        )

        # 2. Create MCP server
        from genkit.plugins.mcp import McpServerOptions

        server = create_mcp_server(server_ai, McpServerOptions(name='test-server'))
        await server.setup()

        # 3. Verify server can list resources
        resources_result = await server.list_resources({})
        self.assertEqual(len(resources_result.resources), 1)
        self.assertEqual(resources_result.resources[0].uri, 'app://config')

        # 4. Verify server can read resource
        request = MagicMock()
        request.params.uri = 'app://config'
        read_result = await server.read_resource(request)
        self.assertEqual(read_result.contents[0].text, 'config data')

    async def test_template_resource_matching(self):
        """Test that template resources match correctly."""
        server_ai = Genkit()

        def file_resource(req):
            uri = req.uri
            return {'content': [{'text': f'Contents of {uri}'}]}

        server_ai.define_resource(name='file', template='file://{+path}', fn=file_resource)

        # Create server
        from genkit.plugins.mcp import McpServerOptions

        server = create_mcp_server(server_ai, McpServerOptions(name='test-server'))
        await server.setup()

        # List templates
        templates_result = await server.list_resource_templates({})
        self.assertEqual(len(templates_result.resourceTemplates), 1)
        self.assertEqual(templates_result.resourceTemplates[0].uriTemplate, 'file://{+path}')

        # Read with different URIs; each must round-trip through the template.
        for test_uri in ['file:///path/to/file.txt', 'file:///another/file.md', 'file:///deep/nested/path/doc.pdf']:
            request = MagicMock()
            request.params.uri = test_uri
            result = await server.read_resource(request)
            self.assertIn(test_uri, result.contents[0].text)


class TestErrorHandling(unittest.IsolatedAsyncioTestCase):
    """Tests for error handling in client-server communication."""

    async def test_server_handles_missing_tool(self):
        """Test that server properly handles requests for non-existent tools."""
        server_ai = Genkit()

        @server_ai.tool()
        def existing_tool(x: int) -> int:
            return x

        from genkit.plugins.mcp import McpServerOptions

        server = create_mcp_server(server_ai, McpServerOptions(name='test-server'))
        await server.setup()

        # Try to call non-existent tool
        request = MagicMock()
        request.params.name = 'nonexistent_tool'
        request.params.arguments = {}

        from genkit.core.error import GenkitError
++ with self.assertRaises(GenkitError) as context: ++ await server.call_tool(request) ++ ++ self.assertIn('NOT_FOUND', str(context.exception.status)) ++ ++ async def test_client_handles_connection_failure(self): ++ """Test that client handles connection failures gracefully.""" ++ client = McpClient(name='test-client', config=McpServerConfig(command='nonexistent_command')) ++ ++ # Mock the connection to fail ++ with patch('genkit.plugins.mcp.client.client.stdio_client') as mock_stdio: ++ mock_stdio.side_effect = Exception('Connection failed') ++ ++ with self.assertRaises(Exception): ++ await client.connect() ++ ++ # Client should mark server as disabled ++ self.assertTrue(client.config.disabled) ++ ++ ++if __name__ == '__main__': ++ unittest.main() +diff --git a/py/plugins/mcp/tests/test_mcp_server.py b/py/plugins/mcp/tests/test_mcp_server.py +new file mode 100644 +index 000000000..f3180d618 +--- /dev/null ++++ b/py/plugins/mcp/tests/test_mcp_server.py +@@ -0,0 +1,341 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++""" ++MCP Server Tests ++ ++Mirrors the functionality of js/plugins/mcp/tests/server_test.ts ++Tests tools, prompts, and resources exposed via MCP server. 
++""" ++ ++import os ++import sys ++import unittest ++from unittest.mock import AsyncMock, MagicMock, patch ++ ++sys.path.insert(0, os.path.dirname(__file__)) ++sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) ++ ++# Mock mcp module before importing ++mock_mcp = MagicMock() ++sys.modules['mcp'] = mock_mcp ++ ++ ++class MockSchema: ++ def __init__(self, **kwargs): ++ for k, v in kwargs.items(): ++ setattr(self, k, v) ++ ++ ++types_mock = MagicMock() ++sys.modules['mcp.types'] = types_mock ++types_mock.ListToolsResult = MockSchema ++types_mock.CallToolResult = MockSchema ++types_mock.ListPromptsResult = MockSchema ++types_mock.GetPromptResult = MockSchema ++types_mock.ListResourcesResult = MockSchema ++types_mock.ListResourceTemplatesResult = MockSchema ++types_mock.ReadResourceResult = MockSchema ++types_mock.Tool = MockSchema ++types_mock.Prompt = MockSchema ++types_mock.Resource = MockSchema ++types_mock.ResourceTemplate = MockSchema ++types_mock.TextResourceContents = MockSchema ++types_mock.BlobResourceContents = MockSchema ++types_mock.ImageContent = MockSchema ++types_mock.TextResourceContents = MockSchema ++types_mock.BlobResourceContents = MockSchema ++types_mock.ImageContent = MockSchema ++types_mock.TextContent = MockSchema ++types_mock.PromptMessage = MockSchema ++ ++sys.modules['mcp.server'] = MagicMock() ++sys.modules['mcp.server.stdio'] = MagicMock() ++sys.modules['mcp.client'] = MagicMock() ++sys.modules['mcp.client'].__path__ = [] ++sys.modules['mcp.client.stdio'] = MagicMock() ++sys.modules['mcp.client.sse'] = MagicMock() ++sys.modules['mcp.server.sse'] = MagicMock() ++ ++import pytest ++ ++from genkit.ai import Genkit ++from genkit.core.action.types import ActionKind ++from genkit.plugins.mcp import McpServer, McpServerOptions, create_mcp_server ++ ++ ++@pytest.mark.asyncio ++class TestMcpServer(unittest.IsolatedAsyncioTestCase): ++ """Test MCP server functionality - mirrors JS server_test.ts""" ++ ++ def 
setUp(self): ++ """Set up test fixtures before each test.""" ++ self.ai = Genkit() ++ ++ # Define test tool ++ @self.ai.tool(description='test tool') ++ def test_tool(input: dict[str, str]) -> str: ++ foo = input.get('foo', '') ++ return f'yep {{"foo":"{foo}"}}' ++ ++ # Define test prompt ++ self.ai.define_prompt(name='testPrompt', model='test-model', prompt='prompt says: {{input}}') ++ ++ # Define test resource with fixed URI ++ self.ai.define_resource( ++ name='testResources', uri='my://resource', fn=lambda req: {'content': [{'text': 'my resource'}]} ++ ) ++ ++ # Define test resource with template ++ self.ai.define_resource( ++ name='testTmpl', ++ template='file://{+path}', ++ fn=lambda req: {'content': [{'text': f'file contents for {req.uri}'}]}, ++ ) ++ ++ # Create MCP server ++ self.server = create_mcp_server(self.ai, McpServerOptions(name='test-server', version='0.0.1')) ++ ++ async def asyncSetUp(self): ++ """Async setup - initialize server.""" ++ await self.server.setup() ++ ++ # ===== TOOL TESTS ===== ++ ++ async def test_list_tools(self): ++ """Test listing tools - mirrors JS 'should list tools'.""" ++ result = await self.server.list_tools({}) ++ ++ # Verify we have the test tool ++ self.assertEqual(len(result.tools), 1) ++ tool = result.tools[0] ++ ++ self.assertEqual(tool.name, 'test_tool') ++ self.assertEqual(tool.description, 'test tool') ++ self.assertIsNotNone(tool.inputSchema) ++ ++ async def test_call_tool(self): ++ """Test calling a tool - mirrors JS 'should call the tool'.""" ++ # Create mock request ++ request = MagicMock() ++ request.params.name = 'test_tool' ++ request.params.arguments = {'foo': 'bar'} ++ ++ result = await self.server.call_tool(request) ++ ++ # Verify response ++ self.assertEqual(len(result.content), 1) ++ self.assertEqual(result.content[0].type, 'text') ++ self.assertEqual(result.content[0].text, 'yep {"foo":"bar"}') ++ ++ # ===== PROMPT TESTS ===== ++ ++ async def test_list_prompts(self): ++ """Test listing prompts - 
mirrors JS 'should list prompts'.""" ++ result = await self.server.list_prompts({}) ++ ++ # Verify we have the test prompt ++ prompt_names = [p.name for p in result.prompts] ++ self.assertIn('testPrompt', prompt_names) ++ ++ async def test_get_prompt(self): ++ """Test rendering a prompt - mirrors JS 'should render prompt'.""" ++ # Create mock request ++ request = MagicMock() ++ request.params.name = 'testPrompt' ++ request.params.arguments = {'input': 'hello'} ++ ++ result = await self.server.get_prompt(request) ++ ++ # Verify response ++ self.assertIsNotNone(result.messages) ++ self.assertGreater(len(result.messages), 0) ++ ++ # Check message content ++ message = result.messages[0] ++ self.assertEqual(message.role, 'user') ++ self.assertEqual(message.content.type, 'text') ++ self.assertIn('prompt says: hello', message.content.text) ++ ++ # ===== RESOURCE TESTS ===== ++ ++ async def test_list_resources(self): ++ """Test listing resources - mirrors JS 'should list resources'.""" ++ result = await self.server.list_resources({}) ++ ++ # Verify we have the fixed URI resource ++ self.assertEqual(len(result.resources), 1) ++ resource = result.resources[0] ++ ++ self.assertEqual(resource.name, 'testResources') ++ self.assertEqual(resource.uri, 'my://resource') ++ ++ async def test_list_resource_templates(self): ++ """Test listing resource templates - mirrors JS 'should list templates'.""" ++ result = await self.server.list_resource_templates({}) ++ ++ # Verify we have the template resource ++ self.assertEqual(len(result.resourceTemplates), 1) ++ template = result.resourceTemplates[0] ++ ++ self.assertEqual(template.name, 'testTmpl') ++ self.assertEqual(template.uriTemplate, 'file://{+path}') ++ ++ async def test_read_resource(self): ++ """Test reading a resource - mirrors JS 'should read resource'.""" ++ # Create mock request ++ request = MagicMock() ++ request.params.uri = 'my://resource' ++ ++ result = await self.server.read_resource(request) ++ ++ # Verify response ++ 
self.assertEqual(len(result.contents), 1) ++ content = result.contents[0] ++ ++ self.assertEqual(content.uri, 'my://resource') ++ self.assertEqual(content.text, 'my resource') ++ ++ async def test_read_template_resource(self): ++ """Test reading a template resource.""" ++ # Create mock request ++ request = MagicMock() ++ request.params.uri = 'file:///path/to/file.txt' ++ ++ result = await self.server.read_resource(request) ++ ++ # Verify response ++ self.assertEqual(len(result.contents), 1) ++ content = result.contents[0] ++ ++ self.assertEqual(content.uri, 'file:///path/to/file.txt') ++ self.assertIn('file contents for file:///path/to/file.txt', content.text) ++ ++ # ===== ADDITIONAL TESTS ===== ++ ++ async def test_server_initialization(self): ++ """Test that server initializes correctly.""" ++ self.assertIsNotNone(self.server) ++ self.assertEqual(self.server.options.name, 'test-server') ++ self.assertEqual(self.server.options.version, '0.0.1') ++ self.assertTrue(self.server.actions_resolved) ++ ++ async def test_server_has_all_action_types(self): ++ """Test that server has tools, prompts, and resources.""" ++ self.assertGreater(len(self.server.tool_actions), 0) ++ self.assertGreater(len(self.server.prompt_actions), 0) ++ self.assertGreater(len(self.server.resource_actions), 0) ++ ++ async def test_tool_not_found(self): ++ """Test calling a non-existent tool.""" ++ from genkit.core.error import GenkitError ++ ++ request = MagicMock() ++ request.params.name = 'nonexistent_tool' ++ request.params.arguments = {} ++ ++ with self.assertRaises(GenkitError) as context: ++ await self.server.call_tool(request) ++ ++ self.assertEqual(context.exception.status, 'NOT_FOUND') ++ ++ async def test_prompt_not_found(self): ++ """Test getting a non-existent prompt.""" ++ from genkit.core.error import GenkitError ++ ++ request = MagicMock() ++ request.params.name = 'nonexistent_prompt' ++ request.params.arguments = {} ++ ++ with self.assertRaises(GenkitError) as context: ++ await 
self.server.get_prompt(request) ++ ++ self.assertEqual(context.exception.status, 'NOT_FOUND') ++ ++ async def test_resource_not_found(self): ++ """Test reading a non-existent resource.""" ++ from genkit.core.error import GenkitError ++ ++ request = MagicMock() ++ request.params.uri = 'nonexistent://resource' ++ ++ with self.assertRaises(GenkitError) as context: ++ await self.server.read_resource(request) ++ ++ self.assertEqual(context.exception.status, 'NOT_FOUND') ++ ++ ++# Additional test class for resource-specific functionality ++@pytest.mark.asyncio ++class TestResourceFunctionality(unittest.IsolatedAsyncioTestCase): ++ """Test resource-specific functionality.""" ++ ++ async def test_resource_registration_with_fixed_uri(self): ++ """Test registering a resource with fixed URI.""" ++ ai = Genkit() ++ ++ action = ai.define_resource( ++ name='test_resource', uri='test://resource', fn=lambda req: {'content': [{'text': 'test'}]} ++ ) ++ ++ self.assertIsNotNone(action) ++ self.assertEqual(action.kind, ActionKind.RESOURCE) ++ self.assertEqual(action.metadata['resource']['uri'], 'test://resource') ++ ++ async def test_resource_registration_with_template(self): ++ """Test registering a resource with URI template.""" ++ ai = Genkit() ++ ++ action = ai.define_resource( ++ name='file', template='file://{+path}', fn=lambda req: {'content': [{'text': 'file content'}]} ++ ) ++ ++ self.assertIsNotNone(action) ++ self.assertEqual(action.kind, ActionKind.RESOURCE) ++ self.assertEqual(action.metadata['resource']['template'], 'file://{+path}') ++ ++ async def test_resource_requires_uri_or_template(self): ++ """Test that resource requires either uri or template.""" ++ ai = Genkit() ++ ++ with self.assertRaises(ValueError) as context: ++ ai.define_resource(name='invalid', fn=lambda req: {'content': []}) ++ ++ self.assertIn('uri', str(context.exception).lower()) ++ self.assertIn('template', str(context.exception).lower()) ++ ++ async def test_uri_template_matching(self): ++ """Test 
URI template matching.""" ++ from genkit.blocks.resource import matches_uri_template ++ ++ # Test exact match ++ result = matches_uri_template('file://{+path}', 'file:///home/user/doc.txt') ++ self.assertIsNotNone(result) ++ self.assertIn('path', result) ++ ++ # Test no match ++ result = matches_uri_template('file://{path}', 'http://example.com') ++ self.assertIsNone(result) ++ ++ # Test multiple parameters ++ result = matches_uri_template('user://{id}/posts/{post_id}', 'user://123/posts/456') ++ self.assertIsNotNone(result) ++ self.assertEqual(result['id'], '123') ++ self.assertEqual(result['post_id'], '456') ++ ++ ++if __name__ == '__main__': ++ unittest.main() +diff --git a/py/plugins/mcp/tests/test_mcp_server_resources.py b/py/plugins/mcp/tests/test_mcp_server_resources.py +new file mode 100644 +index 000000000..51aca70cc +--- /dev/null ++++ b/py/plugins/mcp/tests/test_mcp_server_resources.py +@@ -0,0 +1,351 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
#
# SPDX-License-Identifier: Apache-2.0

"""Comprehensive tests for MCP server resource handling."""

import os
import sys
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
from fakes import mock_mcp_modules

# Install fake 'mcp' modules before the plugin is imported.
mock_mcp_modules()

import pytest

from genkit.ai import Genkit
from genkit.core.action.types import ActionKind
from genkit.plugins.mcp import McpServer, McpServerOptions, create_mcp_server


# NOTE(review): @pytest.mark.asyncio removed — inert on IsolatedAsyncioTestCase.
class TestMcpServerResources(unittest.IsolatedAsyncioTestCase):
    """Tests for MCP server resource handling."""

    def setUp(self):
        """Set up test fixtures."""
        self.ai = Genkit()

    async def test_list_resources_with_fixed_uri(self):
        """Test listing resources with fixed URIs."""
        self.ai.define_resource(name='config', uri='app://config', fn=lambda req: {'content': [{'text': 'config'}]})

        self.ai.define_resource(name='data', uri='app://data', fn=lambda req: {'content': [{'text': 'data'}]})

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        result = await server.list_resources({})

        # Both fixed-URI resources are listed by name.
        self.assertEqual(len(result.resources), 2)
        resource_names = [r.name for r in result.resources]
        self.assertIn('config', resource_names)
        self.assertIn('data', resource_names)

        # Verify URIs
        config_resource = next(r for r in result.resources if r.name == 'config')
        self.assertEqual(config_resource.uri, 'app://config')

    async def test_list_resource_templates(self):
        """Test listing resources with URI templates."""
        self.ai.define_resource(
            name='file', template='file://{+path}', fn=lambda req: {'content': [{'text': 'file content'}]}
        )

        self.ai.define_resource(
            name='user', template='user://{id}/profile', fn=lambda req: {'content': [{'text': 'user profile'}]}
        )

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        result = await server.list_resource_templates({})

        # Both template resources are listed.
        self.assertEqual(len(result.resourceTemplates), 2)
        template_names = [t.name for t in result.resourceTemplates]
        self.assertIn('file', template_names)
        self.assertIn('user', template_names)

        # Verify templates
        file_template = next(t for t in result.resourceTemplates if t.name == 'file')
        self.assertEqual(file_template.uriTemplate, 'file://{+path}')

    async def test_list_resources_excludes_templates(self):
        """Test that list_resources excludes template resources."""
        self.ai.define_resource(name='fixed', uri='app://fixed', fn=lambda req: {'content': [{'text': 'fixed'}]})

        self.ai.define_resource(
            name='template', template='app://{id}', fn=lambda req: {'content': [{'text': 'template'}]}
        )

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        # List resources (should only include the fixed URI entry).
        result = await server.list_resources({})

        self.assertEqual(len(result.resources), 1)
        self.assertEqual(result.resources[0].name, 'fixed')

    async def test_list_resource_templates_excludes_fixed(self):
        """Test that list_resource_templates excludes fixed URI resources."""
        self.ai.define_resource(name='fixed', uri='app://fixed', fn=lambda req: {'content': [{'text': 'fixed'}]})

        self.ai.define_resource(
            name='template', template='app://{id}', fn=lambda req: {'content': [{'text': 'template'}]}
        )

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        # List templates (should only include the template entry).
        result = await server.list_resource_templates({})

        self.assertEqual(len(result.resourceTemplates), 1)
        self.assertEqual(result.resourceTemplates[0].name, 'template')

    async def test_read_resource_with_fixed_uri(self):
        """Test reading a resource with fixed URI."""

        def config_resource(req):
            return {'content': [{'text': 'Configuration data'}]}

        self.ai.define_resource(name='config', uri='app://config', fn=config_resource)

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        # FIX(review): dropped an unused 'from mcp.types import ReadResourceRequest'
        # import here — the request is built from a MagicMock, not that type.
        request = MagicMock()
        request.params.uri = 'app://config'

        result = await server.read_resource(request)

        # Verify
        self.assertEqual(len(result.contents), 1)
        self.assertEqual(result.contents[0].text, 'Configuration data')

    async def test_read_resource_with_template(self):
        """Test reading a resource with URI template."""

        def file_resource(req):
            uri = req.uri
            # Extract path from URI
            path = uri.replace('file://', '')
            return {'content': [{'text': f'Contents of {path}'}]}

        self.ai.define_resource(name='file', template='file://{+path}', fn=file_resource)

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        # Read resource
        request = MagicMock()
        request.params.uri = 'file:///home/user/document.txt'

        result = await server.read_resource(request)

        # Verify
        self.assertEqual(len(result.contents), 1)
        self.assertIn('/home/user/document.txt', result.contents[0].text)

    async def test_read_resource_not_found(self):
        """Test reading a non-existent resource."""
        self.ai.define_resource(name='existing', uri='app://existing', fn=lambda req: {'content': [{'text': 'data'}]})

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        # Try to read non-existent resource
        request = MagicMock()
        request.params.uri = 'app://nonexistent'

        from genkit.core.error import GenkitError

        with self.assertRaises(GenkitError) as context:
            await server.read_resource(request)

        self.assertIn('NOT_FOUND', str(context.exception.status))

    async def test_read_resource_with_multiple_content_parts(self):
        """Test reading a resource that returns multiple content parts."""

        def multi_part_resource(req):
            return {'content': [{'text': 'Part 1'}, {'text': 'Part 2'}, {'text': 'Part 3'}]}

        self.ai.define_resource(name='multi', uri='app://multi', fn=multi_part_resource)

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        # Read resource
        request = MagicMock()
        request.params.uri = 'app://multi'

        result = await server.read_resource(request)

        # All parts come back in order.
        self.assertEqual(len(result.contents), 3)
        self.assertEqual(result.contents[0].text, 'Part 1')
        self.assertEqual(result.contents[1].text, 'Part 2')
        self.assertEqual(result.contents[2].text, 'Part 3')


class TestMcpServerToolsAndPrompts(unittest.IsolatedAsyncioTestCase):
    """Tests for MCP server tool and prompt handling."""

    def setUp(self):
        """Set up test fixtures."""
        self.ai = Genkit()

    async def test_list_tools(self):
        """Test listing tools."""

        @self.ai.tool(description='Add two numbers')
        def add(input: dict[str, int]) -> int:
            return input['a'] + input['b']

        @self.ai.tool(description='Multiply two numbers')
        def multiply(input: dict[str, int]) -> int:
            return input['a'] * input['b']

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        result = await server.list_tools({})

        # Verify
        self.assertEqual(len(result.tools), 2)
        tool_names = [t.name for t in result.tools]
        self.assertIn('add', tool_names)
        self.assertIn('multiply', tool_names)

    async def test_call_tool(self):
        """Test calling a tool."""

        @self.ai.tool()
        def add(input: dict[str, int]) -> int:
            return input['a'] + input['b']

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        # Call tool
        request = MagicMock()
        request.params.name = 'add'
        request.params.arguments = {'a': 5, 'b': 3}

        result = await server.call_tool(request)

        # Verify
        self.assertEqual(len(result.content), 1)
        self.assertEqual(result.content[0].text, '8')

    async def test_list_prompts(self):
        """Test listing prompts."""
        self.ai.define_prompt(name='greeting', prompt='Hello {{name}}!')

        self.ai.define_prompt(name='farewell', prompt='Goodbye {{name}}!')

        server = create_mcp_server(self.ai, McpServerOptions(name='test-server'))
        await server.setup()

        result = await server.list_prompts({})

        # FIX(review): the original computed prompt_names and never asserted on it.
        # Prompt names may carry variant suffixes, so membership is checked by
        # substring rather than equality — TODO confirm the exact naming scheme.
        self.assertGreaterEqual(len(result.prompts), 2)
        prompt_names = [p.name for p in result.prompts]
        self.assertTrue(any('greeting' in name for name in prompt_names))
        self.assertTrue(any('farewell' in name for name in prompt_names))


class TestMcpServerIntegration(unittest.IsolatedAsyncioTestCase):
    """Integration tests for MCP server."""

    async def test_server_exposes_all_action_types(self):
        """Test that server exposes tools, prompts, and resources."""
        ai = Genkit()

        # Define tool
        @ai.tool()
        def test_tool(x: int) -> int:
            return x * 2

        # Define prompt
        ai.define_prompt(name='test', prompt='Test prompt')

        # Define resource
        ai.define_resource(name='test_resource', uri='test://resource', fn=lambda req: {'content': [{'text': 'test'}]})

        # Create server
        server = create_mcp_server(ai, McpServerOptions(name='integration-test'))
        await server.setup()
# Verify all action types are available ++ self.assertGreater(len(server.tool_actions), 0) ++ self.assertGreater(len(server.prompt_actions), 0) ++ self.assertGreater(len(server.resource_actions), 0) ++ ++ async def test_server_initialization_idempotent(self): ++ """Test that server setup is idempotent.""" ++ ai = Genkit() ++ ++ @ai.tool() ++ def test_tool(x: int) -> int: ++ return x ++ ++ server = create_mcp_server(ai, McpServerOptions(name='test')) ++ ++ # Setup multiple times ++ await server.setup() ++ count1 = len(server.tool_actions) ++ ++ await server.setup() ++ count2 = len(server.tool_actions) ++ ++ # Should be the same ++ self.assertEqual(count1, count2) ++ ++ ++if __name__ == '__main__': ++ unittest.main() +diff --git a/py/plugins/ollama/pyproject.toml b/py/plugins/ollama/pyproject.toml +index d9ee166da..53edda87e 100644 +--- a/py/plugins/ollama/pyproject.toml ++++ b/py/plugins/ollama/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -35,7 +34,7 @@ classifiers = [ + ] + dependencies = ["genkit", "ollama~=0.4", "structlog>=25.2.0"] + description = "Genkit Ollama Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-ollama" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/constants.py b/py/plugins/ollama/src/genkit/plugins/ollama/constants.py +index 4060d414a..3be251e2a 100644 +--- a/py/plugins/ollama/src/genkit/plugins/ollama/constants.py ++++ b/py/plugins/ollama/src/genkit/plugins/ollama/constants.py +@@ -14,6 +14,7 @@ + # + # SPDX-License-Identifier: Apache-2.0 + ++ + import sys + + if sys.version_info < (3, 11): +diff --git 
a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +index c3aa01e88..b007058f8 100644 +--- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py ++++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +@@ -16,15 +16,15 @@ + + """Ollama Plugin for Genkit.""" + +-import asyncio +-from functools import cached_property, partial ++from functools import partial + + import structlog + + import ollama as ollama_api +-from genkit.ai import GenkitRegistry, Plugin +-from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata +-from genkit.blocks.model import model_action_metadata ++from genkit.ai import Plugin ++from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder, embedder_action_metadata ++from genkit.blocks.model import model, model_action_metadata ++from genkit.core.action import Action + from genkit.core.registry import ActionKind + from genkit.core.schema import to_json_schema + from genkit.plugins.ollama.constants import ( +@@ -91,63 +91,78 @@ class Ollama(Plugin): + + self.client = partial(ollama_api.AsyncClient, host=self.server_address) + +- def initialize(self, ai: GenkitRegistry) -> None: +- """Initialize the Ollama plugin. +- +- Registers the defined Ollama models and embedders with the Genkit AI registry. +- +- Args: +- ai: The AI registry to initialize the plugin with. +- """ +- self._initialize_models(ai=ai) +- self._initialize_embedders(ai=ai) +- +- def _initialize_models(self, ai: GenkitRegistry) -> None: +- """Initializes and registers the specified Ollama models with Genkit. +- +- Args: +- ai: The Genkit AI registry instance. 
+- """ ++ async def init(self): ++ """Return eagerly-initialized model and embedder actions.""" ++ actions = [] + for model_definition in self.models: +- self._define_ollama_model(ai, model_definition) +- +- def _initialize_embedders(self, ai: GenkitRegistry) -> None: +- """Initializes and registers the specified Ollama embedders with Genkit. +- +- Args: +- ai: The Genkit AI registry instance. +- """ +- for embedding_definition in self.embedders: +- self._define_ollama_embedder(ai, embedding_definition) +- +- def resolve_action( +- self, +- ai: GenkitRegistry, +- kind: ActionKind, +- name: str, +- ) -> None: +- """Resolves and action. ++ actions.append(self._create_model_action(model_definition)) ++ for embedder_definition in self.embedders: ++ actions.append(self._create_embedder_action(embedder_definition)) ++ return actions + +- Args: +- ai: The Genkit registry. +- kind: The kind of action to resolve. +- name: The name of the action to resolve. +- """ +- if kind == ActionKind.MODEL: +- self._define_ollama_model(ai, ModelDefinition(name=name)) +- elif kind == ActionKind.EMBEDDER: +- self._define_ollama_embedder(ai, EmbeddingDefinition(name=name)) ++ async def resolve(self, action_type: ActionKind, name: str): ++ """Resolve a model or embedder action on-demand.""" ++ clean_name = name.replace(f'{OLLAMA_PLUGIN_NAME}/', '') if name.startswith(OLLAMA_PLUGIN_NAME) else name ++ ++ if action_type == ActionKind.MODEL: ++ # Prefer configured model definitions (api_type, supports) when available. 
++ for model_def in self.models: ++ configured_name = ( ++ model_def.name.replace(OLLAMA_PLUGIN_NAME + '/', '') ++ if model_def.name.startswith(OLLAMA_PLUGIN_NAME) ++ else model_def.name ++ ) ++ if configured_name == clean_name: ++ return self._create_model_action(model_def) ++ return self._create_model_action(ModelDefinition(name=clean_name)) ++ elif action_type == ActionKind.EMBEDDER: ++ for embedder_def in self.embedders: ++ configured_name = ( ++ embedder_def.name.replace(OLLAMA_PLUGIN_NAME + '/', '') ++ if embedder_def.name.startswith(OLLAMA_PLUGIN_NAME) ++ else embedder_def.name ++ ) ++ if configured_name == clean_name: ++ return self._create_embedder_action(embedder_def) ++ return self._create_embedder_action(EmbeddingDefinition(name=clean_name)) ++ return None + +- def _define_ollama_model(self, ai: GenkitRegistry, model_ref: ModelDefinition) -> None: +- """Defines and registers an Ollama model with Genkit. ++ async def list_actions(self): ++ """List all available Ollama models and embedders.""" ++ _client = self.client() ++ response = await _client.list() + +- Cleans the model name, instantiates an OllamaModel, and registers it +- with the provided Genkit AI registry, including metadata about its capabilities. ++ actions = [] ++ for model_info in response.models: ++ _name = model_info.model ++ if 'embed' in _name: ++ actions.append( ++ embedder_action_metadata( ++ name=_name, ++ options=EmbedderOptions( ++ config_schema=to_json_schema(ollama_api.Options), ++ label=f'Ollama Embedding - {_name}', ++ supports=EmbedderSupports(input=['text']), ++ ), ++ ) ++ ) ++ else: ++ actions.append( ++ model_action_metadata( ++ name=_name, ++ config_schema=GenerationCommonConfig, ++ info={ ++ 'label': f'Ollama - {_name}', ++ 'multiturn': True, ++ 'system_role': True, ++ 'tools': False, ++ }, ++ ) ++ ) ++ return actions + +- Args: +- ai: The Genkit AI registry instance. +- model_ref: The definition of the model to be registered. 
+- """ ++ def _create_model_action(self, model_ref: ModelDefinition) -> Action: ++ """Create an Ollama model action (doesn't register).""" + _clean_name = ( + model_ref.name.replace(OLLAMA_PLUGIN_NAME + '/', '') + if model_ref.name.startswith(OLLAMA_PLUGIN_NAME) +@@ -155,14 +170,14 @@ class Ollama(Plugin): + ) + + model_ref.name = _clean_name +- model = OllamaModel( ++ ollama_model = OllamaModel( + client=self.client, + model_definition=model_ref, + ) + +- ai.define_model( +- name=ollama_name(model_ref.name), +- fn=model.generate, ++ return model( ++ name=model_ref.name, ++ fn=ollama_model.generate, + config_schema=GenerationCommonConfig, + metadata={ + 'label': f'Ollama - {_clean_name}', +@@ -172,17 +187,8 @@ class Ollama(Plugin): + }, + ) + +- def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDefinition) -> None: +- """Defines and registers an Ollama embedder with Genkit. +- +- Cleans the embedder name, instantiates an OllamaEmbedder, and registers it +- with the provided Genkit AI registry, including metadata about its capabilities +- and expected output dimensions. +- +- Args: +- ai: The Genkit AI registry instance. +- embedder_ref: The definition of the embedding model to be registered. 
+- """ ++ def _create_embedder_action(self, embedder_ref: EmbeddingDefinition) -> Action: ++ """Create an Ollama embedder action (doesn't register).""" + _clean_name = ( + embedder_ref.name.replace(OLLAMA_PLUGIN_NAME + '/', '') + if embedder_ref.name.startswith(OLLAMA_PLUGIN_NAME) +@@ -190,14 +196,14 @@ class Ollama(Plugin): + ) + + embedder_ref.name = _clean_name +- embedder = OllamaEmbedder( ++ ollama_embedder = OllamaEmbedder( + client=self.client, + embedding_definition=embedder_ref, + ) + +- ai.define_embedder( +- name=ollama_name(embedder_ref.name), +- fn=embedder.embed, ++ return embedder( ++ name=embedder_ref.name, ++ fn=ollama_embedder.embed, + options=EmbedderOptions( + config_schema=to_json_schema(ollama_api.Options), + label=f'Ollama Embedding - {_clean_name}', +@@ -205,52 +211,3 @@ class Ollama(Plugin): + supports=EmbedderSupports(input=['text']), + ), + ) +- +- @cached_property +- def list_actions(self) -> list[dict[str, str]]: +- """Generate a list of available actions or models. +- +- Returns: +- list[ActionMetadata]: A list of ActionMetadata objects, each with the following attributes: +- - name (str): The name of the action or model. +- - kind (ActionKind): The type or category of the action. +- - info (dict): The metadata dictionary describing the model configuration and properties. +- - config_schema (type): The schema class used for validating the model's configuration. 
+- """ +- try: +- loop = asyncio.get_running_loop() +- except RuntimeError: +- loop = asyncio.new_event_loop() +- asyncio.set_event_loop(loop) +- +- _client = self.client() +- response = loop.run_until_complete(_client.list()) +- +- actions = [] +- for model in response.models: +- _name = model.model +- if 'embed' in _name: +- actions.append( +- embedder_action_metadata( +- name=ollama_name(_name), +- options=EmbedderOptions( +- config_schema=to_json_schema(ollama_api.Options), +- label=f'Ollama Embedding - {_name}', +- supports=EmbedderSupports(input=['text']), +- ), +- ) +- ) +- else: +- actions.append( +- model_action_metadata( +- name=ollama_name(_name), +- config_schema=GenerationCommonConfig, +- info={ +- 'label': f'Ollama - {_name}', +- 'multiturn': True, +- 'system_role': True, +- 'tools': False, +- }, +- ) +- ) +- return actions +diff --git a/py/plugins/ollama/tests/test_integration.py b/py/plugins/ollama/tests/test_integration.py +index 2f2707589..1c2ce5e09 100644 +--- a/py/plugins/ollama/tests/test_integration.py ++++ b/py/plugins/ollama/tests/test_integration.py +@@ -16,31 +16,33 @@ + + """Integration tests for Ollama plugin with Genkit.""" + +-from unittest.mock import ANY, MagicMock, Mock, patch ++from unittest.mock import Mock + + import ollama as ollama_api + import pytest + + from genkit.ai import ActionKind, Genkit +-from genkit.plugins.ollama import Ollama, ollama_name +-from genkit.plugins.ollama.models import ModelDefinition +-from genkit.types import GenerateResponse, GenerationCommonConfig, Message, Role, TextPart ++from genkit.types import GenerateResponse, Message, Role, TextPart + + +-def test_adding_ollama_chat_model_to_genkit_veneer( ++@pytest.mark.asyncio ++async def test_adding_ollama_chat_model_to_genkit_veneer( + ollama_model: str, + genkit_veneer_chat_model: Genkit, + ) -> None: + """Test adding ollama chat model to genkit veneer.""" +- assert genkit_veneer_chat_model.registry.lookup_action(ActionKind.MODEL, ollama_model) ++ # Use 
async resolver-aware lookup for PluginV2 paths. ++ assert await genkit_veneer_chat_model.registry.aresolve_action(ActionKind.MODEL, ollama_model) + + +-def test_adding_ollama_generation_model_to_genkit_veneer( ++@pytest.mark.asyncio ++async def test_adding_ollama_generation_model_to_genkit_veneer( + ollama_model: str, + genkit_veneer_generate_model: Genkit, + ) -> None: + """Test adding ollama generation model to genkit veneer.""" +- assert genkit_veneer_generate_model.registry.lookup_action(ActionKind.MODEL, ollama_model) ++ # Use async resolver-aware lookup for PluginV2 paths. ++ assert await genkit_veneer_generate_model.registry.aresolve_action(ActionKind.MODEL, ollama_model) + + + @pytest.mark.asyncio +@@ -110,29 +112,3 @@ async def test_async_get_generate_model_response_from_llama_api_flow( + + assert isinstance(response, GenerateResponse) + assert response.message.content[0].root.text == mock_response_message +- +- +-@pytest.fixture +-@patch('ollama.AsyncClient') +-def ollama_plugin_instance(ollama_async_client): +- return Ollama() +- +- +-def test__initialize_models(ollama_plugin_instance): +- ai_mock = MagicMock(spec=Genkit) +- +- plugin = ollama_plugin_instance +- plugin.models = [ModelDefinition(name='test_model')] +- plugin._initialize_models(ai_mock) +- +- ai_mock.define_model.assert_called_once_with( +- name=ollama_name('test_model'), +- fn=ANY, +- config_schema=GenerationCommonConfig, +- metadata={ +- 'label': 'Ollama - test_model', +- 'multiturn': True, +- 'system_role': True, +- 'tools': False, +- }, +- ) +diff --git a/py/plugins/ollama/tests/test_plugin_api.py b/py/plugins/ollama/tests/test_plugin_api.py +index ff0c14d11..a8d459a6a 100644 +--- a/py/plugins/ollama/tests/test_plugin_api.py ++++ b/py/plugins/ollama/tests/test_plugin_api.py +@@ -17,19 +17,15 @@ + """Unit tests for Ollama Plugin.""" + + import unittest +-from unittest.mock import ANY, AsyncMock, MagicMock ++from unittest.mock import AsyncMock, MagicMock + +-import ollama as ollama_api + 
import pytest + from pydantic import BaseModel + +-from genkit.ai import ActionKind, Genkit +-from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports +-from genkit.core.schema import to_json_schema +-from genkit.plugins.ollama import Ollama, ollama_name ++from genkit.core.registry import ActionKind ++from genkit.plugins.ollama import Ollama + from genkit.plugins.ollama.embedders import EmbeddingDefinition + from genkit.plugins.ollama.models import ModelDefinition +-from genkit.types import GenerationCommonConfig + + + class TestOllamaInit(unittest.TestCase): +@@ -69,74 +65,17 @@ class TestOllamaInit(unittest.TestCase): + assert plugin.request_headers == headers + + +-def test_initialize(ollama_plugin_instance): +- """Test initialize method of Ollama plugin.""" +- ai_mock = MagicMock(spec=Genkit) +- model_ref = ModelDefinition(name='test_model') +- embedder_ref = EmbeddingDefinition(name='test_embedder') +- ollama_plugin_instance.models = [model_ref] +- ollama_plugin_instance.embedders = [embedder_ref] ++@pytest.mark.asyncio ++async def test_init_returns_actions(ollama_plugin_instance): ++ """PluginV2 init() should return actions (models + embedders) without namespacing.""" ++ ollama_plugin_instance.models = [ModelDefinition(name='test_model')] ++ ollama_plugin_instance.embedders = [EmbeddingDefinition(name='test_embedder', dimensions=1024)] + +- init_models = MagicMock() +- init_embedders = MagicMock() ++ actions = await ollama_plugin_instance.init() + +- ollama_plugin_instance._initialize_models = init_models +- ollama_plugin_instance._initialize_embedders = init_embedders +- +- ollama_plugin_instance.initialize(ai_mock) +- +- init_models.assert_called_once_with(ai=ai_mock) +- init_embedders.assert_called_once_with(ai=ai_mock) +- +- +-def test__initialize_models(ollama_plugin_instance): +- """Test _initialize_models method of Ollama plugin.""" +- ai_mock = MagicMock(spec=Genkit) +- name = 'test_model' +- +- plugin = ollama_plugin_instance +- 
plugin.models = [ModelDefinition(name=name)] +- plugin._initialize_models(ai_mock) +- +- ai_mock.define_model.assert_called_once_with( +- name=ollama_name(name), +- fn=ANY, +- config_schema=GenerationCommonConfig, +- metadata={ +- 'label': f'Ollama - {name}', +- 'multiturn': True, +- 'system_role': True, +- 'tools': False, +- }, +- ) +- +- +-def test__initialize_embedders(ollama_plugin_instance): +- """Test _initialize_embedders method of Ollama plugin.""" +- ai_mock = MagicMock(spec=Genkit) +- name = 'test_embedder' +- +- plugin = ollama_plugin_instance +- plugin.embedders = [ +- EmbeddingDefinition( +- name=name, +- dimensions=1024, +- ) +- ] +- plugin._initialize_embedders(ai_mock) +- +- ai_mock.define_embedder.assert_called_once_with( +- name=ollama_name(name), +- fn=ANY, +- options=EmbedderOptions( +- config_schema=to_json_schema(ollama_api.Options), +- label=f'Ollama Embedding - {name}', +- dimensions=1024, +- supports=EmbedderSupports( +- input=['text'], +- ), +- ), +- ) ++ assert len(actions) == 2 ++ assert {a.kind for a in actions} == {ActionKind.MODEL, ActionKind.EMBEDDER} ++ assert {a.name for a in actions} == {'test_model', 'test_embedder'} + + + @pytest.mark.parametrize( +@@ -146,36 +85,13 @@ def test__initialize_embedders(ollama_plugin_instance): + (ActionKind.EMBEDDER, 'test_embedder'), + ], + ) +-def test_resolve_action(kind, name, ollama_plugin_instance): +- """Unit Tests for resolve action method.""" +- ai_mock = MagicMock(spec=Genkit) +- ollama_plugin_instance.resolve_action(ai_mock, kind, name) +- +- if kind == ActionKind.MODEL: +- ai_mock.define_model.assert_called_once_with( +- name=ollama_name(name), +- fn=ANY, +- config_schema=GenerationCommonConfig, +- metadata={ +- 'label': f'Ollama - {name}', +- 'multiturn': True, +- 'system_role': True, +- 'tools': False, +- }, +- ) +- else: +- ai_mock.define_embedder.assert_called_once_with( +- name=ollama_name(name), +- fn=ANY, +- options=EmbedderOptions( +- 
config_schema=to_json_schema(ollama_api.Options), +- label=f'Ollama Embedding - {name}', +- dimensions=None, +- supports=EmbedderSupports( +- input=['text'], +- ), +- ), +- ) ++@pytest.mark.asyncio ++async def test_resolve_returns_action(kind, name, ollama_plugin_instance): ++ """PluginV2 resolve() should return an Action for models/embedders.""" ++ action = await ollama_plugin_instance.resolve(kind, name) ++ assert action is not None ++ assert action.kind == kind ++ assert action.name == name + + + @pytest.mark.parametrize( +@@ -185,23 +101,11 @@ def test_resolve_action(kind, name, ollama_plugin_instance): + ('ollama/mistral', 'ollama/mistral', 'mistral'), + ], + ) +-def test_define_ollama_model(name, expected_name, clean_name, ollama_plugin_instance): +- """Unit tests for _define_ollama_model method.""" +- ai_mock = MagicMock(spec=Genkit) +- +- ollama_plugin_instance._define_ollama_model(ai_mock, ModelDefinition(name=name)) +- +- ai_mock.define_model.assert_called_once_with( +- name=expected_name, +- fn=ANY, +- config_schema=GenerationCommonConfig, +- metadata={ +- 'label': f'Ollama - {clean_name}', +- 'multiturn': True, +- 'system_role': True, +- 'tools': False, +- }, +- ) ++def test_create_model_action_cleans_name(name, expected_name, clean_name, ollama_plugin_instance): ++ """_create_model_action should strip namespace from input names.""" ++ action = ollama_plugin_instance._create_model_action(ModelDefinition(name=name)) ++ assert action.kind == ActionKind.MODEL ++ assert action.name == clean_name + + + @pytest.mark.parametrize( +@@ -211,28 +115,16 @@ def test_define_ollama_model(name, expected_name, clean_name, ollama_plugin_inst + ('ollama/mistral', 'ollama/mistral', 'mistral'), + ], + ) +-def test_define_ollama_embedder(name, expected_name, clean_name, ollama_plugin_instance): +- """Unit tests for _define_ollama_embedder method.""" +- ai_mock = MagicMock(spec=Genkit) +- +- ollama_plugin_instance._define_ollama_embedder(ai_mock, 
EmbeddingDefinition(name=name, dimensions=1024)) +- +- ai_mock.define_embedder.assert_called_once_with( +- name=expected_name, +- fn=ANY, +- options=EmbedderOptions( +- config_schema=to_json_schema(ollama_api.Options), +- label=f'Ollama Embedding - {clean_name}', +- dimensions=1024, +- supports=EmbedderSupports( +- input=['text'], +- ), +- ), +- ) ++def test_create_embedder_action_cleans_name(name, expected_name, clean_name, ollama_plugin_instance): ++ """_create_embedder_action should strip namespace from input names.""" ++ action = ollama_plugin_instance._create_embedder_action(EmbeddingDefinition(name=name, dimensions=1024)) ++ assert action.kind == ActionKind.EMBEDDER ++ assert action.name == clean_name + + +-def test_list_actions(ollama_plugin_instance): +- """Unit tests for list_actions method.""" ++@pytest.mark.asyncio ++async def test_list_returns_action_metadata(ollama_plugin_instance): ++ """PluginV2 list_actions() should return ActionMetadata and await the async client.""" + + class MockModelResponse(BaseModel): + model: str +@@ -256,7 +148,7 @@ def test_list_actions(ollama_plugin_instance): + + ollama_plugin_instance.client = mock_client + +- actions = ollama_plugin_instance.list_actions ++ actions = await ollama_plugin_instance.list_actions() + + assert len(actions) == 2 + +diff --git a/py/plugins/vertex-ai/pyproject.toml b/py/plugins/vertex-ai/pyproject.toml +index a4d32d018..1f8638cb4 100644 +--- a/py/plugins/vertex-ai/pyproject.toml ++++ b/py/plugins/vertex-ai/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -39,11 +38,12 @@ dependencies = [ + "google-cloud-aiplatform>=1.77.0", + "structlog>=25.2.0", + "strenum>=0.4.15; python_version < '3.11'", ++ 
"genkit-plugin-compat-oai", + "google-cloud-bigquery", + "google-cloud-firestore", + ] + description = "Genkit Google Cloud Vertex AI Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-vertex-ai" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py +index 7b522412e..23705f135 100644 +--- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py ++++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py +@@ -56,7 +56,7 @@ class ModelGarden: + model: str, + location: str, + project_id: str, +- registry: GenkitRegistry, ++ registry: GenkitRegistry | None, + ) -> None: + """Initializes the ModelGarden instance. + +@@ -104,9 +104,14 @@ class ModelGarden: + A callable function (specifically, the `generate` method of an + `OpenAIModel` instance) that can be used by Genkit. + """ +- openai_model = OpenAIModel(self.name, self.client, self.ai) ++ # In PluginV2 paths we avoid registry-dependent tool lookup, but the legacy ++ # registry-based flow still passes a registry here. ++ openai_model = OpenAIModel(self.name, self.client) + return openai_model.generate + ++ # NOTE: OpenAIModel no longer requires a registry; tool schemas are provided via ++ # GenerateRequest.tools, so the returned function works for both v1/v2 flows. ++ + def define_model(self) -> None: + """Defines and registers the Model Garden model with the Genkit registry. + +@@ -114,6 +119,8 @@ class ModelGarden: + of the OpenAI-compatible generation function, then registers this model + within the Genkit framework using `self.ai.define_model`. 
+ """ ++ if self.ai is None: ++ raise ValueError('ModelGarden.define_model() requires a GenkitRegistry') + model_info = self.get_model_info() + generate_fn = self.to_openai_compatible_model() + self.ai.define_model( +diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py +index 48eac737c..a8b1455ee 100644 +--- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py ++++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py +@@ -17,13 +17,13 @@ + """ModelGarden API Compatible Plugin for Genkit.""" + + import os +-from functools import cached_property + +-from genkit.ai import GenkitRegistry, Plugin +-from genkit.blocks.model import model_action_metadata ++from genkit.ai import Plugin ++from genkit.blocks.model import model, model_action_metadata + from genkit.core.action import ActionMetadata + from genkit.core.action.types import ActionKind +-from genkit.plugins.compat_oai.models import SUPPORTED_OPENAI_COMPAT_MODELS ++from genkit.plugins.compat_oai.models import SUPPORTED_OPENAI_COMPAT_MODELS, OpenAIModelHandler ++from genkit.plugins.compat_oai.models.model_info import PluginSource + from genkit.plugins.compat_oai.typing import OpenAIConfig + from genkit.plugins.vertex_ai import constants as const + +@@ -61,83 +61,48 @@ class VertexAIModelGarden(Plugin): + """ + self.project_id = project_id if project_id is not None else os.getenv(const.GCLOUD_PROJECT) + self.location = location if location is not None else const.DEFAULT_REGION +- self.models = models ++ self.models = models or [] + +- def initialize(self, ai: GenkitRegistry) -> None: +- """Handles actions for various openaicompatible models.""" +- models = self.models +- if models is None: +- return ++ async def init(self): ++ """Return eagerly-initialized model actions.""" ++ return [self._create_model_action(m) for m in 
self.models] + +- for model in models: +- model_proxy = ModelGarden( +- model=model, +- location=self.location, +- project_id=self.project_id, +- registry=ai, +- ) +- model_proxy.define_model() +- +- def resolve_action( +- self, +- ai: GenkitRegistry, +- kind: ActionKind, +- name: str, +- ) -> None: +- """Resolves and action. +- +- Args: +- ai: The Genkit registry. +- kind: The kind of action to resolve. +- name: The name of the action to resolve. +- """ +- if kind == ActionKind.MODEL: +- self._resolve_model(ai=ai, name=name) +- +- def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: +- """Resolves and defines a Model Garden Vertex AI model within the Genkit registry. +- +- This internal method handles the logic for registering new models +- of Vertex AI Model Garden that are compatible with OpenaI +- based on the provided name. +- It extracts a clean name, determines the model type, instantiates the +- appropriate model class, and registers it with the Genkit AI registry. +- +- Args: +- ai: The Genkit AI registry instance to define the model in. +- name: The name of the model to resolve. This name might include a +- prefix indicating it's from a specific plugin. 
+- """ ++ async def resolve(self, action_type: ActionKind, name: str): ++ if action_type != ActionKind.MODEL: ++ return None + clean_name = ( + name.replace(MODELGARDEN_PLUGIN_NAME + '/', '') if name.startswith(MODELGARDEN_PLUGIN_NAME) else name + ) ++ if clean_name not in SUPPORTED_OPENAI_COMPAT_MODELS: ++ return None ++ return self._create_model_action(clean_name) ++ ++ async def list_actions(self) -> list[ActionMetadata]: ++ return [ ++ model_action_metadata( ++ name=model_garden_name(model_name), ++ info=model_info.model_dump(), ++ config_schema=OpenAIConfig, ++ ) ++ for model_name, model_info in SUPPORTED_OPENAI_COMPAT_MODELS.items() ++ ] + ++ def _create_model_action(self, model_name: str): + model_proxy = ModelGarden( +- model=clean_name, ++ model=model_name, + location=self.location, + project_id=self.project_id, +- registry=ai, ++ registry=None, ++ ) ++ handler = OpenAIModelHandler.get_model_handler( ++ model=model_name, ++ client=model_proxy.client, # Vertex Model Garden OpenAI-compatible client ++ source=PluginSource.MODEL_GARDEN, ++ ) ++ model_info = model_proxy.get_model_info() ++ return model( ++ name=model_name, ++ fn=handler, ++ config_schema=OpenAIConfig, ++ metadata={'model': model_info}, + ) +- model_proxy.define_model() +- +- @cached_property +- def list_actions(self) -> list[ActionMetadata]: +- """Generate a list of available actions or models. +- +- Returns: +- list[ActionMetadata]: A list of ActionMetadata objects, each with the following attributes: +- - name (str): The name of the action or model. +- - kind (ActionKind): The type or category of the action. +- - info (dict): The metadata dictionary describing the model configuration and properties. +- - config_schema (type): The schema class used for validating the model's configuration. 
+- """ +- +- actions_list = [] +- for model, model_info in SUPPORTED_OPENAI_COMPAT_MODELS.items(): +- actions_list.append( +- model_action_metadata( +- name=model_garden_name(model), info=model_info.model_dump(), config_schema=OpenAIConfig +- ) +- ) +- +- return actions_list +diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py +index 8576bcbf1..ddfc6303b 100644 +--- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py ++++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py +@@ -20,7 +20,11 @@ from typing import Any + from google.auth.credentials import Credentials + from google.cloud import aiplatform_v1 + +-from genkit.ai import GenkitRegistry, Plugin ++from genkit.ai import Plugin ++from genkit.blocks.retriever import RetrieverOptions, retriever_action_metadata ++from genkit.core.action import Action, ActionMetadata ++from genkit.core.action.types import ActionKind ++from genkit.core.schema import to_json_schema + from genkit.plugins.vertex_ai.vector_search.retriever import ( + DocRetriever, + RetrieverOptionsSchema, +@@ -42,13 +46,10 @@ def vertexai_name(name: str) -> str: + + + class VertexAIVectorSearch(Plugin): +- """A plugin for integrating VertexAI Vector Search. ++ """A plugin for integrating VertexAI Vector Search.""" + +- This class registers VertexAI Vector Stores within a registry, +- and allows interaction to retrieve similar documents. +- """ +- +- name: str = 'vertexAIVectorSearch' ++ name: str = VERTEXAI_PLUGIN_NAME ++ retriever_name: str = 'vertexAIVectorSearch' + + def __init__( + self, +@@ -90,25 +91,56 @@ class VertexAIVectorSearch(Plugin): + credentials=credentials, + ) + +- def initialize(self, ai: GenkitRegistry) -> None: +- """Initialize plugin with the retriver specified. 
+- +- Register actions with the registry making them available for use in the Genkit framework. +- +- Args: +- ai: The registry to register actions with. +- """ +- retriever = self.retriever_cls( +- ai=ai, +- name=self.name, +- match_service_client_generator=self._match_service_client_generator, +- embedder=self.embedder, +- embedder_options=self.embedder_options, +- **self.retriever_extra_args, +- ) +- +- return ai.define_retriever( +- name=vertexai_name(self.name), +- config_schema=RetrieverOptionsSchema, +- fn=retriever.retrieve, ++ async def init(self) -> list[Action]: ++ return [self._create_retriever_action()] ++ ++ async def resolve(self, action_type: ActionKind, name: str) -> Action | None: ++ if action_type != ActionKind.RETRIEVER: ++ return None ++ if name != self.retriever_name: ++ return None ++ return self._create_retriever_action() ++ ++ async def list_actions(self) -> list[ActionMetadata]: ++ return [ ++ retriever_action_metadata( ++ name=self.retriever_name, ++ options=RetrieverOptions( ++ label='Vertex AI Vector Search', ++ config_schema=to_json_schema(RetrieverOptionsSchema), ++ ), ++ ) ++ ] ++ ++ def _create_retriever_action(self) -> Action: ++ metadata: dict[str, Any] = { ++ 'retriever': { ++ 'label': self.retriever_name, ++ 'customOptions': to_json_schema(RetrieverOptionsSchema), ++ } ++ } ++ ++ async def retrieve(request, ctx): ++ ai = (ctx.context or {}).get('__genkit_ai__') ++ if ai is None: ++ raise ValueError( ++ 'VertexAIVectorSearch retriever requires a Genkit instance in action context. ' ++ 'Use it via `await ai.retrieve(...)`.' 
++ ) ++ ++ retriever = self.retriever_cls( ++ ai=ai, ++ name=self.retriever_name, ++ match_service_client_generator=self._match_service_client_generator, ++ embedder=self.embedder, ++ embedder_options=self.embedder_options, ++ **self.retriever_extra_args, ++ ) ++ return await retriever.retrieve(request, ctx) ++ ++ return Action( ++ kind=ActionKind.RETRIEVER, ++ name=self.retriever_name, ++ fn=retrieve, ++ metadata=metadata, + ) +diff --git a/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py +index 3feb75ae9..c4eabaeae 100644 +--- a/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py ++++ b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py +@@ -24,7 +24,6 @@ import json + from unittest.mock import AsyncMock, MagicMock + + import pytest +-from google.cloud import bigquery + from google.cloud.aiplatform_v1 import ( + FindNeighborsRequest, + FindNeighborsResponse, +diff --git a/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py +index 912831c6d..f8c44010d 100644 +--- a/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py ++++ b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py +@@ -18,17 +18,22 @@ + + from unittest.mock import MagicMock + +-from genkit.ai import Genkit ++import pytest ++ ++from genkit.core.action.types import ActionKind + from genkit.plugins.vertex_ai.vector_search import VertexAIVectorSearch + + +-def test_initialize_plugin(): +- """Test plugin initialization.""" ++@pytest.mark.asyncio ++async def test_init_plugin_returns_retriever_action(): ++ """PluginV2 init should return the vector-search retriever action.""" + plugin = VertexAIVectorSearch( + retriever=MagicMock(), + embedder='embedder', + ) + +- result = plugin.initialize(ai=MagicMock(spec=Genkit)) ++ actions = await plugin.init() + +- assert result is not None ++ assert 
len(actions) == 1 ++ assert actions[0].kind == ActionKind.RETRIEVER ++ assert actions[0].name == 'vertexAIVectorSearch' +diff --git a/py/plugins/xai/pyproject.toml b/py/plugins/xai/pyproject.toml +index 15843b372..bcfc93269 100644 +--- a/py/plugins/xai/pyproject.toml ++++ b/py/plugins/xai/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -35,7 +34,7 @@ classifiers = [ + ] + dependencies = ["genkit", "xai-sdk>=0.0.1"] + description = "Genkit xAI Plugin" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-plugin-xai" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/plugins/xai/src/genkit/plugins/xai/plugin.py b/py/plugins/xai/src/genkit/plugins/xai/plugin.py +index 046736d47..7a1bf57f7 100644 +--- a/py/plugins/xai/src/genkit/plugins/xai/plugin.py ++++ b/py/plugins/xai/src/genkit/plugins/xai/plugin.py +@@ -20,7 +20,9 @@ import os + + from xai_sdk import Client as XAIClient + +-from genkit.ai import GenkitRegistry, Plugin ++from genkit.ai import Plugin ++from genkit.blocks.model import model ++from genkit.core.action import ActionMetadata + from genkit.core.error import GenkitError + from genkit.core.registry import ActionKind + from genkit.plugins.xai.model_info import SUPPORTED_XAI_MODELS, get_model_info +@@ -56,30 +58,37 @@ class XAI(Plugin): + self._xai_params = xai_params + self._xai_client = XAIClient(api_key=api_key, **xai_params) + +- def initialize(self, ai: GenkitRegistry) -> None: +- for model_name in self.models: +- self._define_model(ai, model_name) +- +- def resolve_action( +- self, +- ai: GenkitRegistry, +- kind: ActionKind, +- name: str, +- ) -> None: +- if kind == ActionKind.MODEL: +- 
self._resolve_model(ai=ai, name=name) +- +- def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: +- clean_name = name.replace(f'{XAI_PLUGIN_NAME}/', '') if name.startswith(XAI_PLUGIN_NAME) else name +- self._define_model(ai, clean_name) +- +- def _define_model(self, ai: GenkitRegistry, model_name: str) -> None: +- model = XAIModel(model_name=model_name, client=self._xai_client) ++ async def init(self): ++ """Return eagerly-initialized model actions.""" ++ return [self._create_model_action(model_name) for model_name in self.models] ++ ++ async def resolve(self, action_type: ActionKind, name: str): ++ """Resolve a model action on-demand.""" ++ if action_type == ActionKind.MODEL: ++ clean_name = name.replace(f'{XAI_PLUGIN_NAME}/', '') if name.startswith(XAI_PLUGIN_NAME) else name ++ if clean_name in SUPPORTED_XAI_MODELS: ++ return self._create_model_action(clean_name) ++ return None ++ ++ async def list_actions(self): ++ """List all supported xAI models.""" ++ return [ ++ ActionMetadata( ++ name=model_name, ++ kind=ActionKind.MODEL, ++ info={'supports': get_model_info(model_name).supports.model_dump()}, ++ ) ++ for model_name in self.models ++ ] ++ ++ def _create_model_action(self, model_name: str): ++ """Create an xAI model action (doesn't register).""" ++ xai_model = XAIModel(model_name=model_name, client=self._xai_client) + model_info = get_model_info(model_name) + +- ai.define_model( +- name=xai_name(model_name), +- fn=model.generate, ++ return model( ++ name=model_name, ++ fn=xai_model.generate, + config_schema=GenerationCommonConfig, + metadata={'model': {'supports': model_info.supports.model_dump()}}, + ) +diff --git a/py/plugins/xai/tests/test_xai_models.py b/py/plugins/xai/tests/test_xai_models.py +index fa985a429..95958ac04 100644 +--- a/py/plugins/xai/tests/test_xai_models.py ++++ b/py/plugins/xai/tests/test_xai_models.py +@@ -16,7 +16,6 @@ + + """Tests for xAI models.""" + +-import asyncio + from unittest.mock import MagicMock + + import pytest 
+diff --git a/py/plugins/xai/tests/test_xai_plugin.py b/py/plugins/xai/tests/test_xai_plugin.py +index 43cb61570..71ec98f8c 100644 +--- a/py/plugins/xai/tests/test_xai_plugin.py ++++ b/py/plugins/xai/tests/test_xai_plugin.py +@@ -16,7 +16,9 @@ + + """Tests for xAI plugin.""" + +-from unittest.mock import MagicMock, patch ++from unittest.mock import patch ++ ++import pytest + + from genkit.core.error import GenkitError + from genkit.core.registry import ActionKind +@@ -38,7 +40,7 @@ def test_init_without_api_key_raises(): + with patch.dict('os.environ', {}, clear=True): + try: + XAI() +- assert False, 'Expected GenkitError' ++ raise AssertionError('Expected GenkitError') + except GenkitError: + pass + +@@ -54,23 +56,26 @@ def test_custom_models(): + assert plugin.models == ['grok-3', 'grok-3-mini'] + + +-def test_plugin_initialize(): +- registry = MagicMock() ++@pytest.mark.asyncio ++async def test_plugin_initialize(): + plugin = XAI(api_key='test-key') +- plugin.initialize(registry) +- assert registry.define_model.call_count == len(SUPPORTED_XAI_MODELS) ++ actions = await plugin.init() ++ assert len(actions) == len(SUPPORTED_XAI_MODELS) ++ assert all(action.kind == ActionKind.MODEL for action in actions) + + +-def test_resolve_action_model(): +- registry = MagicMock() ++@pytest.mark.asyncio ++async def test_resolve_action_model(): + plugin = XAI(api_key='test-key') +- plugin.resolve_action(registry, ActionKind.MODEL, 'xai/grok-3') +- registry.define_model.assert_called_once() ++ action = await plugin.resolve(ActionKind.MODEL, 'grok-3') ++ assert action is not None ++ assert action.kind == ActionKind.MODEL ++ assert action.name == 'grok-3' + + + def test_supported_models(): + assert len(SUPPORTED_XAI_MODELS) >= 4 +- for name, info in SUPPORTED_XAI_MODELS.items(): ++ for _name, info in SUPPORTED_XAI_MODELS.items(): + assert info.label.startswith('xAI - ') + assert len(info.versions) > 0 + assert info.supports.tools +diff --git a/py/pyproject.toml b/py/pyproject.toml 
+index 400fa7384..5ee43b379 100644 +--- a/py/pyproject.toml ++++ b/py/pyproject.toml +@@ -15,11 +15,13 @@ + # SPDX-License-Identifier: Apache-2.0 + + [project] ++authors = [{ name = "Google" }] + dependencies = [ + "dotpromptz==0.1.4", + "genkit", + "genkit-plugin-anthropic", + "genkit-plugin-compat-oai", ++ "genkit-plugin-deepseek", + "genkit-plugin-dev-local-vectorstore", + "genkit-plugin-evaluators", + "genkit-plugin-firebase", +@@ -30,10 +32,11 @@ dependencies = [ + "genkit-plugin-vertex-ai", + "genkit-plugin-xai", + "liccheck>=0.9.2", ++ "mcp>=1.25.0", + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "Workspace for Genkit packages" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "genkit-workspace" + readme = "README.md" + requires-python = ">=3.10" +@@ -60,7 +63,7 @@ dev = [ + "nox-uv>=0.2.2", + ] + +-lint = ["mypy>=1.15", "ruff>=0.9"] ++lint = ["ty>=0.0.1", "ruff>=0.9"] + + [tool.hatch.build.targets.wheel] + packages = [] +@@ -95,6 +98,7 @@ omit = [ + "**/typing.py", # Often auto-generated or complex types + "**/types.py", # Often auto-generated or complex types + ] ++source = ["packages", "plugins"] + + # uv based package management. + [tool.uv] +@@ -105,6 +109,7 @@ evaluator-demo = { workspace = true } + genkit = { workspace = true } + genkit-plugin-anthropic = { workspace = true } + genkit-plugin-compat-oai = { workspace = true } ++genkit-plugin-deepseek = { workspace = true } + genkit-plugin-dev-local-vectorstore = { workspace = true } + genkit-plugin-evaluators = { workspace = true } + genkit-plugin-firebase = { workspace = true } +@@ -198,29 +203,6 @@ line-ending = "lf" + quote-style = "single" + skip-magic-trailing-comma = false + +-# Static type checking. 
+-[tool.mypy] +-disallow_incomplete_defs = true +-disallow_untyped_defs = true +-exclude = ["samples/"] +-explicit_package_bases = true +-mypy_path = [ +- "packages/genkit/src", +- "plugins/chroma/src", +- "plugins/compat-oai/src", +- "plugins/dev-local-vectorstore/src", +- "plugins/firebase/src", +- "plugins/flask/src", +- "plugins/google-cloud/src", +- "plugins/google-genai/src", +- "plugins/ollama/src", +- "plugins/pinecone/src", +- "plugins/vertex-ai/src", +-] +-namespace_packages = true +-strict = true +-warn_unused_configs = true +- + [tool.datamodel-codegen] + #collapse-root-models = true # Don't use; produces Any as types. + #strict-types = ["str", "int", "float", "bool", "bytes"] # Don't use; produces StrictStr, StrictInt, etc. +diff --git a/py/samples/anthropic-hello/.gitignore b/py/samples/anthropic-hello/.gitignore +new file mode 100644 +index 000000000..7065f5d82 +--- /dev/null ++++ b/py/samples/anthropic-hello/.gitignore +@@ -0,0 +1,3 @@ ++.env ++ ++ +diff --git a/py/samples/anthropic-hello/env.example b/py/samples/anthropic-hello/env.example +new file mode 100644 +index 000000000..229d04e30 +--- /dev/null ++++ b/py/samples/anthropic-hello/env.example +@@ -0,0 +1,4 @@ ++# Copy this file to ".env" and fill in values. Do NOT commit ".env". 
++ANTHROPIC_API_KEY=your-anthropic-api-key ++ ++ +diff --git a/py/samples/anthropic-hello/pyproject.toml b/py/samples/anthropic-hello/pyproject.toml +index 17ec62d43..6588a3f71 100644 +--- a/py/samples/anthropic-hello/pyproject.toml ++++ b/py/samples/anthropic-hello/pyproject.toml +@@ -15,6 +15,7 @@ + # SPDX-License-Identifier: Apache-2.0 + + [project] ++authors = [{ name = "Google" }] + dependencies = [ + "genkit", + "genkit-plugin-anthropic", +diff --git a/py/samples/anthropic-hello/run.sh b/py/samples/anthropic-hello/run.sh +index b3170f6ef..f36c1df92 100755 +--- a/py/samples/anthropic-hello/run.sh ++++ b/py/samples/anthropic-hello/run.sh +@@ -15,4 +15,14 @@ + # + # SPDX-License-Identifier: Apache-2.0 + ++set -euo pipefail ++ ++# Load local env if present (do not commit .env; see env.example) ++if [ -f ".env" ]; then ++ set -a ++ # shellcheck disable=SC1091 ++ . ".env" ++ set +a ++fi ++ + exec genkit start -- uv run src/main.py "$@" +diff --git a/py/samples/anthropic-hello/src/main.py b/py/samples/anthropic-hello/src/main.py +index fc38c2f52..8c8729cca 100755 +--- a/py/samples/anthropic-hello/src/main.py ++++ b/py/samples/anthropic-hello/src/main.py +@@ -195,7 +195,6 @@ async def say_hi_with_config(name: str) -> str: + + async def main() -> None: + """Main entry point for the Anthropic sample.""" +- + result = await say_hi('John Doe') + await logger.ainfo('Simple greeting', result=result) + +diff --git a/py/samples/compat-oai-hello/pyproject.toml b/py/samples/compat-oai-hello/pyproject.toml +index 2c9af3e41..6ec5e1f15 100644 +--- a/py/samples/compat-oai-hello/pyproject.toml ++++ b/py/samples/compat-oai-hello/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -41,7 
+40,7 @@ dependencies = [ + "httpx>=0.28.1", + ] + description = "OpenAI sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "compat-oai-hello" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/compat-oai-hello/src/main.py b/py/samples/compat-oai-hello/src/main.py +index 233c138b0..94effed2c 100755 +--- a/py/samples/compat-oai-hello/src/main.py ++++ b/py/samples/compat-oai-hello/src/main.py +@@ -212,7 +212,7 @@ async def get_weather_flow_stream(location: str) -> str: + + + class Skills(BaseModel): +- """A set of core character skills for an RPG character""" ++ """A set of core character skills for an RPG character.""" + + strength: int = Field(description='strength (0-100)') + charisma: int = Field(description='charisma (0-100)') +diff --git a/py/samples/deepseek-hello/README.md b/py/samples/deepseek-hello/README.md +new file mode 100644 +index 000000000..477f8ccc7 +--- /dev/null ++++ b/py/samples/deepseek-hello/README.md +@@ -0,0 +1,19 @@ ++## DeepSeek Sample ++ ++1. Setup environment and install dependencies: ++```bash ++uv venv ++source .venv/bin/activate ++ ++uv sync ++``` ++ ++2. Set DeepSeek API key (get one from [DeepSeek Platform](https://platform.deepseek.com/)): ++```bash ++export DEEPSEEK_API_KEY=your-api-key ++``` ++ ++3. Run the sample: ++```bash ++genkit start -- uv run src/main.py ++``` +diff --git a/py/samples/deepseek-hello/pyproject.toml b/py/samples/deepseek-hello/pyproject.toml +new file mode 100644 +index 000000000..cb48c544d +--- /dev/null ++++ b/py/samples/deepseek-hello/pyproject.toml +@@ -0,0 +1,38 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. 
++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++[project] ++authors = [{ name = "Google" }] ++dependencies = [ ++ "genkit", ++ "genkit-plugin-deepseek", ++ "pydantic>=2.0.0", ++ "structlog>=24.0.0", ++] ++description = "DeepSeek Hello Sample" ++name = "deepseek-hello" ++requires-python = ">=3.10" ++version = "0.1.0" ++ ++[tool.uv.sources] ++genkit-plugin-deepseek = { workspace = true } ++ ++[build-system] ++build-backend = "hatchling.build" ++requires = ["hatchling"] ++ ++[tool.hatch.build.targets.wheel] ++packages = ["src"] +diff --git a/py/samples/deepseek-hello/run.sh b/py/samples/deepseek-hello/run.sh +new file mode 100755 +index 000000000..02a864050 +--- /dev/null ++++ b/py/samples/deepseek-hello/run.sh +@@ -0,0 +1,18 @@ ++#!/usr/bin/env bash ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++exec genkit start -- uv run src/main.py "$@" +diff --git a/py/samples/deepseek-hello/src/main.py b/py/samples/deepseek-hello/src/main.py +new file mode 100644 +index 000000000..bfc714d43 +--- /dev/null ++++ b/py/samples/deepseek-hello/src/main.py +@@ -0,0 +1,279 @@ ++# Copyright 2026 Google LLC ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++# ++# SPDX-License-Identifier: Apache-2.0 ++ ++"""DeepSeek hello sample. ++ ++Key features demonstrated in this sample: ++ ++| Feature Description | Example Function / Code Snippet | ++|-----------------------------------------|-----------------------------------------| ++| Plugin Initialization | `ai = Genkit(plugins=[DeepSeek(...)])` | ++| Default Model Configuration | `ai = Genkit(model=deepseek_name(...))` | ++| Defining Flows | `@ai.flow()` decorator | ++| Defining Tools | `@ai.tool()` decorator | ++| Pydantic for Tool Input Schema | `WeatherInput` | ++| Simple Generation (Prompt String) | `say_hi` | ++| Streaming Response | `streaming_flow` | ++| Generation with Tools | `weather_flow` | ++| Reasoning Model (deepseek-reasoner) | `reasoning_flow` | ++| Generation with Config | `custom_config_flow` | ++| Multi-turn Chat | `chat_flow` | ++""" ++ ++import structlog ++from pydantic import BaseModel, Field ++ ++from genkit.ai import Genkit ++from genkit.core.action import ActionRunContext ++from genkit.plugins.deepseek import DeepSeek, deepseek_name ++from genkit.types 
import Message, Part, Role, TextPart, ToolResponse ++ ++logger = structlog.get_logger(__name__) ++ ++ai = Genkit( ++ plugins=[DeepSeek()], ++ model=deepseek_name('deepseek-chat'), ++) ++ ++ ++class WeatherInput(BaseModel): ++ """Input schema for the weather tool.""" ++ ++ location: str = Field(description='The city and state, e.g. San Francisco, CA') ++ ++ ++@ai.tool() ++def get_weather(input: WeatherInput) -> str: ++ """Get weather of a location, the user should supply a location first. ++ ++ Args: ++ input: Weather input with location (city and state, e.g. San Francisco, CA). ++ ++ Returns: ++ Weather information with temperature in degrees Fahrenheit. ++ """ ++ # Mocked weather data ++ weather_data = { ++ 'San Francisco, CA': {'temp': 72, 'condition': 'sunny', 'humidity': 65}, ++ 'Seattle, WA': {'temp': 55, 'condition': 'rainy', 'humidity': 85}, ++ } ++ ++ location = input.location ++ data = weather_data.get(location, {'temp': 70, 'condition': 'partly cloudy', 'humidity': 55}) ++ ++ return f'The weather in {location} is {data["temp"]}°F and {data["condition"]}. Humidity is {data["humidity"]}%.' ++ ++ ++@ai.flow() ++async def say_hi(name: str) -> str: ++ """Generate a simple greeting. ++ ++ Args: ++ name: Name to greet. ++ ++ Returns: ++ Greeting message. ++ """ ++ response = await ai.generate(prompt=f'Say hello to {name}!') ++ return response.text ++ ++ ++@ai.flow() ++async def streaming_flow(topic: str, ctx: ActionRunContext) -> str: ++ """Generate with streaming response. ++ ++ Args: ++ topic: Topic to generate about. ++ ctx: Action run context for streaming chunks to client. ++ ++ Returns: ++ Generated text. 
++ """ ++ response = await ai.generate( ++ prompt=f'Tell me a fun fact about {topic}', ++ on_chunk=ctx.send_chunk, ++ ) ++ return response.text ++ ++ ++@ai.flow() ++async def weather_flow(location: str) -> str: ++ """Get weather using compat-oai auto tool calling.""" ++ ++ response = await ai.generate( ++ model=deepseek_name('deepseek-chat'), ++ prompt=f'What is the weather in {location}?', ++ system=( ++ 'You have a tool called get_weather. ' ++ "It takes an object with a 'location' field. " ++ 'Always use this tool when asked about weather.' ++ ), ++ tools=['get_weather'], ++ tool_choice='required', ++ max_turns=2, ++ ) ++ ++ return response.text ++ ++ ++@ai.flow() ++async def reasoning_flow(prompt: str | None = None) -> str: ++ """Solve reasoning problems using deepseek-reasoner model. ++ ++ Args: ++ prompt: The reasoning question to solve. Defaults to a classic logic problem. ++ ++ Returns: ++ The reasoning and answer. ++ """ ++ if prompt is None: ++ prompt = 'What is heavier, one kilo of steel or one kilo of feathers?' ++ ++ response = await ai.generate( ++ model=deepseek_name('deepseek-reasoner'), ++ prompt=prompt, ++ ) ++ return response.text ++ ++ ++@ai.flow() ++async def custom_config_flow(task: str | None = None) -> str: ++ """Demonstrate custom model configurations for different tasks. ++ ++ Shows how different config parameters affect generation behavior: ++ - 'creative': High temperature for diverse, creative outputs ++ - 'precise': Low temperature with penalties for consistent, focused outputs ++ - 'detailed': Extended output with frequency penalty to avoid repetition ++ ++ Args: ++ task: Type of task - 'creative', 'precise', or 'detailed' ++ ++ Returns: ++ Generated response showing the effect of different configs. 
++ """ ++ if task is None: ++ task = 'creative' ++ ++ prompts = { ++ 'creative': 'Write a creative story opener about a robot discovering art', ++ 'precise': 'List the exact steps to make a cup of tea', ++ 'detailed': 'Explain how photosynthesis works in detail', ++ } ++ ++ configs = { ++ 'creative': { ++ 'temperature': 1.5, # High temperature for creativity ++ 'max_tokens': 200, ++ 'top_p': 0.95, ++ }, ++ 'precise': { ++ 'temperature': 0.1, # Low temperature for consistency ++ 'max_tokens': 150, ++ 'presence_penalty': 0.5, # Encourage covering all steps ++ }, ++ 'detailed': { ++ 'temperature': 0.7, ++ 'max_tokens': 400, # More tokens for detailed explanation ++ 'frequency_penalty': 0.8, # Reduce repetitive phrasing ++ }, ++ } ++ ++ prompt = prompts.get(task, prompts['creative']) ++ config = configs.get(task, configs['creative']) ++ ++ response = await ai.generate( ++ prompt=prompt, ++ config=config, ++ ) ++ return response.text ++ ++ ++@ai.flow() ++async def chat_flow() -> str: ++ """Multi-turn chat example demonstrating context retention. ++ ++ Returns: ++ Final chat response. ++ """ ++ history = [] ++ ++ # First turn - User shares information ++ prompt1 = "Hi! I'm planning a trip to Tokyo next month. I'm really excited because I love Japanese cuisine, especially ramen and sushi." 
++ response1 = await ai.generate( ++ prompt=prompt1, ++ system='You are a helpful travel assistant.', ++ ) ++ history.append(Message(role=Role.USER, content=[TextPart(text=prompt1)])) ++ history.append(response1.message) ++ await logger.ainfo('chat_flow turn 1', result=response1.text) ++ ++ # Second turn - Ask question requiring context from first turn ++ response2 = await ai.generate( ++ messages=history + [Message(role=Role.USER, content=[TextPart(text='What foods did I say I enjoy?')])], ++ system='You are a helpful travel assistant.', ++ ) ++ history.append(Message(role=Role.USER, content=[TextPart(text='What foods did I say I enjoy?')])) ++ history.append(response2.message) ++ await logger.ainfo('chat_flow turn 2', result=response2.text) ++ ++ # Third turn - Ask question requiring context from both previous turns ++ response3 = await ai.generate( ++ messages=history ++ + [ ++ Message( ++ role=Role.USER, ++ content=[TextPart(text='Based on our conversation, suggest one restaurant I should visit.')], ++ ) ++ ], ++ system='You are a helpful travel assistant.', ++ ) ++ return response3.text ++ ++ ++async def main() -> None: ++ """Main entry point for the DeepSeek sample.""" ++ # Simple greeting ++ result = await say_hi('World') ++ await logger.ainfo('say_hi', result=result) ++ ++ # Streaming response ++ result = await streaming_flow('apple') ++ await logger.ainfo('streaming_flow', result=result) ++ ++ # Weather with tools ++ result = await weather_flow('Seattle, WA') ++ await logger.ainfo('weather_flow', result=result) ++ ++ # Reasoning model ++ result = await reasoning_flow() ++ await logger.ainfo('reasoning_flow', result=result) ++ ++ # Custom config - demonstrate different configurations ++ await logger.ainfo('Testing creative config...') ++ result = await custom_config_flow('creative') ++ await logger.ainfo('custom_config_flow (creative)', result=result) ++ ++ await logger.ainfo('Testing precise config...') ++ result = await custom_config_flow('precise') ++ 
await logger.ainfo('custom_config_flow (precise)', result=result) ++ ++ # Multi-turn chat ++ result = await chat_flow() ++ await logger.ainfo('chat_flow', result=result) ++ ++ ++if __name__ == '__main__': ++ ai.run_main(main()) +diff --git a/py/samples/dev-local-vectorstore-hello/pyproject.toml b/py/samples/dev-local-vectorstore-hello/pyproject.toml +index dafc44eba..3a8a5911d 100644 +--- a/py/samples/dev-local-vectorstore-hello/pyproject.toml ++++ b/py/samples/dev-local-vectorstore-hello/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -41,7 +40,7 @@ dependencies = [ + "structlog>=25.2.0", + ] + description = "hello Genkit sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "dev-local-vectorstore-hello" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/dev-local-vectorstore-hello/src/main.py b/py/samples/dev-local-vectorstore-hello/src/main.py +index 9cbb91670..ea006d1ca 100755 +--- a/py/samples/dev-local-vectorstore-hello/src/main.py ++++ b/py/samples/dev-local-vectorstore-hello/src/main.py +@@ -27,7 +27,7 @@ ai = Genkit( + embedder='vertexai/text-embedding-004', + ), + ], +- model='vertexai/gemini-2.5-flash', ++ model='vertexai/gemini-3-flash-preview', + ) + + films = [ +diff --git a/py/samples/evaluator-demo/pyproject.toml b/py/samples/evaluator-demo/pyproject.toml +index a80d91cc8..771c96674 100644 +--- a/py/samples/evaluator-demo/pyproject.toml ++++ b/py/samples/evaluator-demo/pyproject.toml +@@ -15,6 +15,7 @@ + # SPDX-License-Identifier: Apache-2.0 + + [project] ++authors = [{ name = "Google" }] + dependencies = ["genkit", "pydantic>=2.0.0", "structlog>=24.0.0", "pypdf"] + description = "Genkit Python 
Evaluation Demo" + name = "eval-demo" +diff --git a/py/samples/evaluator-demo/src/genkit_demo.py b/py/samples/evaluator-demo/src/genkit_demo.py +index 0ccd8f64e..618b98c8c 100644 +--- a/py/samples/evaluator-demo/src/genkit_demo.py ++++ b/py/samples/evaluator-demo/src/genkit_demo.py +@@ -47,17 +47,17 @@ ai = Genkit( + GenkitEvaluators([ + MetricConfig( + metric_type=GenkitMetricType.MALICIOUSNESS, +- judge=ModelReference(name='googleai/gemini-2.5-pro'), ++ judge=ModelReference(name='googleai/gemini-3-pro-preview'), + judge_config=PERMISSIVE_SAFETY_SETTINGS, + ), + MetricConfig( + metric_type=GenkitMetricType.ANSWER_RELEVANCY, +- judge=ModelReference(name='googleai/gemini-2.5-pro'), ++ judge=ModelReference(name='googleai/gemini-3-pro-preview'), + judge_config=PERMISSIVE_SAFETY_SETTINGS, + ), + MetricConfig( + metric_type=GenkitMetricType.FAITHFULNESS, +- judge=ModelReference(name='googleai/gemini-2.5-pro'), ++ judge=ModelReference(name='googleai/gemini-3-pro-preview'), + judge_config=PERMISSIVE_SAFETY_SETTINGS, + ), + ]), +diff --git a/py/samples/evaluator-demo/src/main.py b/py/samples/evaluator-demo/src/main.py +index e3961390c..605304688 100755 +--- a/py/samples/evaluator-demo/src/main.py ++++ b/py/samples/evaluator-demo/src/main.py +@@ -16,10 +16,7 @@ + + import random + +-from eval_in_code import dog_facts_eval_flow + from genkit_demo import ai +-from pdf_rag import index_pdf, pdf_qa, simple_echo, simple_structured +-from setup import setup + + from genkit.core.typing import BaseEvalDataPoint, EvalStatusEnum, Score + +diff --git a/py/samples/firestore-retreiver/pyproject.toml b/py/samples/firestore-retreiver/pyproject.toml +index 485dea588..1f856710f 100644 +--- a/py/samples/firestore-retreiver/pyproject.toml ++++ b/py/samples/firestore-retreiver/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", 
+ "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -35,7 +34,7 @@ classifiers = [ + ] + dependencies = ["genkit", "google-cloud-firestore"] + description = "firestore-retreiver Genkit sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "firestore-retreiver" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/flask-hello/pyproject.toml b/py/samples/flask-hello/pyproject.toml +index 9397e7d59..e07d52625 100644 +--- a/py/samples/flask-hello/pyproject.toml ++++ b/py/samples/flask-hello/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "flask", + ] + description = "hello Genkit sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "flask-hello" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/flask-hello/src/main.py b/py/samples/flask-hello/src/main.py +index fa18efe49..226e8a33f 100755 +--- a/py/samples/flask-hello/src/main.py ++++ b/py/samples/flask-hello/src/main.py +@@ -28,7 +28,7 @@ from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion + + ai = Genkit( + plugins=[GoogleAI()], +- model=googleai_name(GoogleAIGeminiVersion.GEMINI_2_0_FLASH), ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + ) + + app = Flask(__name__) +diff --git a/py/samples/google-genai-code-execution/pyproject.toml b/py/samples/google-genai-code-execution/pyproject.toml +index d5dfa8f2d..267cf9ad6 100644 +--- a/py/samples/google-genai-code-execution/pyproject.toml ++++ b/py/samples/google-genai-code-execution/pyproject.toml +@@ -22,7 +22,6 @@ 
classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "structlog>=25.2.0", + ] + description = "Code execution sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "google-genai-code-execution" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/google-genai-code-execution/src/main.py b/py/samples/google-genai-code-execution/src/main.py +index 4b554a703..bcb1a01cd 100755 +--- a/py/samples/google-genai-code-execution/src/main.py ++++ b/py/samples/google-genai-code-execution/src/main.py +@@ -27,7 +27,7 @@ logger = structlog.get_logger(__name__) + + ai = Genkit( + plugins=[GoogleAI()], +- model=googleai_name('gemini-2.5-flash'), ++ model=googleai_name('gemini-3-flash-preview'), + ) + + +diff --git a/py/samples/google-genai-context-caching/pyproject.toml b/py/samples/google-genai-context-caching/pyproject.toml +index 17035a9a7..a1008ab42 100644 +--- a/py/samples/google-genai-context-caching/pyproject.toml ++++ b/py/samples/google-genai-context-caching/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -41,7 +40,7 @@ dependencies = [ + "structlog>=25.2.0", + ] + description = "context-caching Genkit sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "google-genai-context-caching" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/google-genai-context-caching/src/main.py 
b/py/samples/google-genai-context-caching/src/main.py +index ddf51f322..1b1727b69 100755 +--- a/py/samples/google-genai-context-caching/src/main.py ++++ b/py/samples/google-genai-context-caching/src/main.py +@@ -14,7 +14,7 @@ + # + # SPDX-License-Identifier: Apache-2.0 + +-"""Sample that demonstrates caching of generation context in Genkit ++"""Sample that demonstrates caching of generation context in Genkit. + + In this sample user actor supplies "Tom Sawyer" book content from Gutenberg library archive + and model caches this context. +@@ -34,7 +34,7 @@ logger = structlog.getLogger(__name__) + + ai = Genkit( + plugins=[GoogleAI()], +- model=googleai_name(GoogleAIGeminiVersion.GEMINI_1_5_FLASH), ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + ) + + # Tom Sawyer is taken as a sample book here +@@ -67,7 +67,7 @@ async def text_context_flow(_input: BookContextInputSchema) -> str: + ), + Message( + role=Role.MODEL, +- content=[TextPart(text=f'Here is some analysis based on the text provided.')], ++ content=[TextPart(text='Here is some analysis based on the text provided.')], + metadata={ + 'cache': { + 'ttl_seconds': 300, +@@ -76,7 +76,7 @@ async def text_context_flow(_input: BookContextInputSchema) -> str: + ), + ], + config=GenerationCommonConfig( +- version='gemini-1.5-flash-001', ++ version='gemini-3-flash-preview', + temperature=0.7, + maxOutputTokens=1000, + topK=50, +diff --git a/py/samples/google-genai-hello/pyproject.toml b/py/samples/google-genai-hello/pyproject.toml +index c85fd1db2..99204b2eb 100644 +--- a/py/samples/google-genai-hello/pyproject.toml ++++ b/py/samples/google-genai-hello/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ 
-42,7 +41,7 @@ dependencies = [ + "structlog>=25.2.0", + ] + description = "Hello world sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "google-genai-hello" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/google-genai-hello/src/main.py b/py/samples/google-genai-hello/src/main.py +index 2c8da83e9..de7eb51d2 100755 +--- a/py/samples/google-genai-hello/src/main.py ++++ b/py/samples/google-genai-hello/src/main.py +@@ -82,7 +82,7 @@ ai = Genkit( + ]) + ), + ], +- model='googleai/gemini-2.5-flash', ++ model='googleai/gemini-3-flash-preview', + ) + + +diff --git a/py/samples/google-genai-image/pyproject.toml b/py/samples/google-genai-image/pyproject.toml +index b26401427..5ddaab77e 100644 +--- a/py/samples/google-genai-image/pyproject.toml ++++ b/py/samples/google-genai-image/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "pydantic>=2.10.5", + ] + description = "Vision API and Image Generation example" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "google-genai-image" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/google-genai-image/src/main.py b/py/samples/google-genai-image/src/main.py +index 05f6dbbb0..849a7a0c8 100755 +--- a/py/samples/google-genai-image/src/main.py ++++ b/py/samples/google-genai-image/src/main.py +@@ -16,7 +16,6 @@ + + """This sample demonstrates how to use Gemini to describe and draw images.""" + +-import asyncio + import base64 + import os + from io import BytesIO +@@ -40,7 +39,7 @@ async def draw_image_with_gemini() -> str: + return await ai.generate( + prompt='Draw a cat in a hat.', + 
config={'response_modalities': ['Text', 'Image']}, +- model=googleai_name('gemini-2.5-flash'), ++ model=googleai_name('gemini-2.5-flash-image'), + ) + + +@@ -49,22 +48,25 @@ async def describe_image_with_gemini(data: str) -> str: + """Describe an image. + + Args: +- data: The image to describe. ++ data: The image data as a data URI (e.g., 'data:image/jpeg;base64,...'). + + Returns: + The description of the image. + """ ++ if not (data.startswith('data:') and ',' in data): ++ raise ValueError(f'Expected a data URI (e.g., "data:image/jpeg;base64,..."), but got: {data[:50]}...') ++ + result = await ai.generate( + messages=[ + Message( + role=Role.USER, + content=[ + TextPart(text='What is shown in this image?'), +- MediaPart(media=Media(contentType='image/jpeg', url=data)), ++ MediaPart(media=Media(content_type='image/jpeg', url=data)), + ], + ), + ], +- model=googleai_name('gemini-2.5-flash'), ++ model=googleai_name('gemini-3-flash-preview'), + ) + return result.text + +@@ -79,12 +81,25 @@ async def main() -> None: + with open(image_path, 'rb') as image_file: + buffer = image_file.read() + img_base64 = base64.b64encode(buffer).decode('utf-8') +- print(await describe_image_with_gemini(img_base64)) ++ data_uri = f'data:image/jpeg;base64,{img_base64}' ++ print(await describe_image_with_gemini(data_uri)) + + # Gemini draws an image by description. The model used is available only in + # Gemini API. + result = await draw_image_with_gemini() +- decoded_image = BytesIO(base64.b64decode(result.message.content[0].root.media.url)) ++ ++ # Find the media part in the response ++ media_part = next((part.root.media for part in result.message.content if part.root.media is not None), None) ++ ++ if media_part is None: ++ print('No media found in response') ++ print(f'Response content: {result.message.content}') ++ return ++ ++ media_url = media_part.url ++ # Extract base64 data after the comma in "data:image/png;base64,..." 
++ base64_data = media_url.split(',', 1)[1] ++ decoded_image = BytesIO(base64.b64decode(base64_data)) + image = Image.open(decoded_image) + image.show('Image generated by Gemini') + +diff --git a/py/samples/google-genai-vertexai-hello/pyproject.toml b/py/samples/google-genai-vertexai-hello/pyproject.toml +index d2a14f41e..3ffa5f524 100644 +--- a/py/samples/google-genai-vertexai-hello/pyproject.toml ++++ b/py/samples/google-genai-vertexai-hello/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "structlog>=25.2.0", + ] + description = "Hello world sample on VertexAI API on GenAI" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "google-genai-vertexai-hello" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/google-genai-vertexai-hello/src/main.py b/py/samples/google-genai-vertexai-hello/src/main.py +index f73041ee7..fe25ec1ba 100755 +--- a/py/samples/google-genai-vertexai-hello/src/main.py ++++ b/py/samples/google-genai-vertexai-hello/src/main.py +@@ -49,7 +49,6 @@ from genkit.plugins.google_genai import ( + EmbeddingTaskType, + VertexAI, + ) +-from genkit.plugins.google_genai.models import gemini + from genkit.types import ( + GenerationCommonConfig, + Message, +@@ -61,7 +60,7 @@ logger = structlog.get_logger(__name__) + + ai = Genkit( + plugins=[VertexAI()], +- model='vertexai/gemini-2.5-flash', ++ model='vertexai/gemini-3-flash-preview', + ) + + +diff --git a/py/samples/google-genai-vertexai-image/README.md b/py/samples/google-genai-vertexai-image/README.md +index c0d5dba43..244a1be23 100644 +--- a/py/samples/google-genai-vertexai-image/README.md ++++ 
b/py/samples/google-genai-vertexai-image/README.md +@@ -9,12 +9,20 @@ Prerequisites: + * A Google Cloud account with access to VertexAI service. + * The `genkit` package. + +-To run this sample: ++## Setup environment + + 1. Install the `genkit` package. +-2. Install [GCP CLI](https://cloud.google.com/sdk/docs/install) +-3. Put your GCP project and location in the code to run VertexAI there. +-4. Run the sample. ++2. Install [GCP CLI](https://cloud.google.com/sdk/docs/install). ++3. Add your project to Google Cloud. Run the following code to log in and set up the configuration. ++```bash ++export GOOGLE_CLOUD_LOCATION=global ++export GOOGLE_CLOUD_PROJECT=your-GCP-project-ID ++gcloud init ++``` ++4. Run the following code to connect to VertexAI. ++```bash ++gcloud auth application-default login ++``` + + ## Run the sample + +diff --git a/py/samples/google-genai-vertexai-image/pyproject.toml b/py/samples/google-genai-vertexai-image/pyproject.toml +index 232c40b48..37a817372 100644 +--- a/py/samples/google-genai-vertexai-image/pyproject.toml ++++ b/py/samples/google-genai-vertexai-image/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "pydantic>=2.10.5", + ] + description = "Image Generation on VertexAI with GenAI library example" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "google-genai-vertexai-image" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/google-genai-vertexai-image/src/main.py b/py/samples/google-genai-vertexai-image/src/main.py +index 9ad6bf453..4bdcc51b5 100755 +--- a/py/samples/google-genai-vertexai-image/src/main.py ++++ 
b/py/samples/google-genai-vertexai-image/src/main.py +@@ -16,7 +16,6 @@ + + """This sample demonstrates how to use Gemini VertexAI to describe and draw images.""" + +-import asyncio + import base64 + from io import BytesIO + +@@ -55,7 +54,10 @@ async def main() -> None: + # Imagen draws an image by description. The model used is available only in + # VertexAI API. + result = await draw_image_with_imagen() +- decoded_image = BytesIO(base64.b64decode(result.message.content[0].root.media.url)) ++ media_url = result.message.content[0].root.media.url ++ # Extract base64 data after the comma in "data:image/png;base64,..." ++ base64_data = media_url.split(',', 1)[1] ++ decoded_image = BytesIO(base64.b64decode(base64_data)) + image = Image.open(decoded_image) + image.show('Image generated by Gemini') + +diff --git a/py/samples/menu/pyproject.toml b/py/samples/menu/pyproject.toml +index 1ec32f5de..7ba7975d4 100644 +--- a/py/samples/menu/pyproject.toml ++++ b/py/samples/menu/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -44,7 +43,7 @@ dependencies = [ + "pydantic>=2.10.5", + ] + description = "menu Genkit sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "menu" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/menu/src/__init__.py b/py/samples/menu/src/__init__.py +index 1e870fbb3..526fdbd69 100644 +--- a/py/samples/menu/src/__init__.py ++++ b/py/samples/menu/src/__init__.py +@@ -15,27 +15,15 @@ + # SPDX-License-Identifier: Apache-2.0 + + # 01 +-from case_01.prompts import s01_staticMenuDotPrompt, s01_vanillaPrompt +-from case_02.flows import s02_menuQuestionFlow +-from case_02.prompts import s02_dataMenuPrompt + + # 02 
+-from case_02.tools import menu_tool + + # 03 +-from case_03.flows import s03_multiTurnChatFlow +-from case_03.prompts import s03_chatPreamblePrompt + + # 04 + # TODO: uncomment once implemented + # from case_04.flows import s04_indexMenuItemsFlow, s04_ragMenuQuestionFlow + # from case_04.prompts import s04_ragDataMenuPrompt + # 05 +-from case_05.flows import ( +- s05_readMenuFlow, +- s05_textMenuQuestionFlow, +- s05_visionMenuQuestionFlow, +-) +-from case_05.prompts import s05_readMenuPrompt, s05_textMenuPrompt + + print('All prompts and flows loaded, use the Developer UI to test them out') +diff --git a/py/samples/menu/src/case_01/prompts.py b/py/samples/menu/src/case_01/prompts.py +index afc8e1c6e..d800ac67e 100644 +--- a/py/samples/menu/src/case_01/prompts.py ++++ b/py/samples/menu/src/case_01/prompts.py +@@ -16,21 +16,24 @@ + from menu_ai import ai + from menu_schemas import MenuQuestionInputSchema + +-from genkit.plugins.google_genai import google_genai_name +-from genkit.plugins.google_genai.models.gemini import GeminiVersion ++from genkit.plugins.google_genai import googleai_name ++from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion + + s01_vanillaPrompt = ai.define_prompt( + variant='s01_vanillaPrompt', ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + input_schema=MenuQuestionInputSchema, +- system="""You are acting as a helpful AI assistant named "Walt" that can answer questions about the food available on the menu at Walt's Burgers.""", ++ prompt="""You are acting as a helpful AI assistant named "Walt" that can answer ++questions about the food available on the menu at Walt's Burgers. 
++Customer says: {{question}}""", + config={'temperature': 0.3}, + ) + + s01_staticMenuDotPrompt = ai.define_prompt( + variant='s01_staticMenuDotPrompt', +- model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + input_schema=MenuQuestionInputSchema, +- system=""" ++ prompt=""" + You are acting as a helpful AI assistant named "Walt" that can answer + questions about the food available on the menu at Walt's Burgers. + Here is today's menu: +diff --git a/py/samples/menu/src/case_02/flows.py b/py/samples/menu/src/case_02/flows.py +index a3faa04f2..b0ef54f9d 100644 +--- a/py/samples/menu/src/case_02/flows.py ++++ b/py/samples/menu/src/case_02/flows.py +@@ -27,5 +27,5 @@ async def s02_menuQuestionFlow( + ) -> AnswerOutputSchema: + text = await s02_dataMenuPrompt({'question': my_input.question}) + return AnswerOutputSchema( +- answer=text, ++ answer=text.text, + ) +diff --git a/py/samples/menu/src/case_02/prompts.py b/py/samples/menu/src/case_02/prompts.py +index d98cf0ae5..1cd9084ae 100644 +--- a/py/samples/menu/src/case_02/prompts.py ++++ b/py/samples/menu/src/case_02/prompts.py +@@ -16,15 +16,15 @@ + from menu_ai import ai + from menu_schemas import MenuQuestionInputSchema + +-from genkit.plugins.google_genai import google_genai_name +-from genkit.plugins.google_genai.models.gemini import GeminiVersion ++from genkit.plugins.google_genai import googleai_name ++from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion + + s02_dataMenuPrompt = ai.define_prompt( + variant='s02_dataMenu', +- model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + input_schema=MenuQuestionInputSchema, +- tools=['menu_tool'], +- system="""You are acting as a helpful AI assistant named Walt that can answer ++ tools=['todaysMenu'], ++ prompt="""You are acting as a helpful AI assistant named Walt that can answer + questions about the 
food available on the menu at Walt's Burgers. + + Answer this customer's question, in a concise and helpful manner, +diff --git a/py/samples/menu/src/case_02/tools.py b/py/samples/menu/src/case_02/tools.py +index 3d863aabf..34719a1d3 100644 +--- a/py/samples/menu/src/case_02/tools.py ++++ b/py/samples/menu/src/case_02/tools.py +@@ -26,8 +26,8 @@ with open(menu_json_path) as f: + menu_data = json.load(f) + + +-@ai.tool(name='menu_tool') +-def menu_tool(input=None) -> MenuToolOutputSchema: ++@ai.tool(name='todaysMenu') ++def todaysMenu(input=None) -> MenuToolOutputSchema: + """Use this tool to retrieve all the items on today's menu.""" + return MenuToolOutputSchema( + menu_data=menu_data, +diff --git a/py/samples/menu/src/case_03/flows.py b/py/samples/menu/src/case_03/flows.py +index be70422d0..430fc11aa 100644 +--- a/py/samples/menu/src/case_03/flows.py ++++ b/py/samples/menu/src/case_03/flows.py +@@ -26,14 +26,14 @@ from case_03.chats import ( + from menu_ai import ai + + from genkit.core.typing import Message, Role, TextPart +-from genkit.plugins.google_genai import google_genai_name +-from genkit.plugins.google_genai.models.gemini import GeminiVersion ++from genkit.plugins.google_genai import googleai_name ++from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion as GeminiVersion + + menu_json_path = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'menu.json') + with open(menu_json_path) as f: + menu_data = json.load(f) + +-formatted_menu_data = '\n'.join([f'- ${r["title"]} ${r["price"]}\n${r["description"]}' for r in menu_data]) ++formatted_menu_data = '\n'.join([f'- {r["title"]} ${r["price"]}\n{r["description"]}' for r in menu_data]) + + preamble = [ + Message( +@@ -43,13 +43,15 @@ preamble = [ + ], + ), + Message( +- role=Role.USER, ++ role=Role.MODEL, + content=[ + TextPart( +- text=f"""I am Walt, a helpful AI assistant here at the restaurant.\n' + +- 'I can answer questions about the food on the menu or any other questions\n' 
+ +- "you have about food in general. I probably can't help you with anything else.\n" + +- "Here is today's menu: \n {formatted_menu_data}\nDo you have any questions about the menu?""" ++ text=f"""I am Walt, a helpful AI assistant here at the restaurant. ++I can answer questions about the food on the menu or any other questions ++you have about food in general. I probably can't help you with anything else. ++Here is today's menu: ++{formatted_menu_data} ++Do you have any questions about the menu?""" + ), + ], + ), +@@ -67,7 +69,7 @@ async def s03_multiTurnChatFlow( + history = chat_history_store.read(my_input.session_id) + + llm_response = await ai.generate( +- model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), ++ model=googleai_name(GeminiVersion.GEMINI_3_FLASH_PREVIEW), + messages=history, + prompt=[TextPart(text=my_input.question)], + ) +diff --git a/py/samples/menu/src/case_03/prompts.py b/py/samples/menu/src/case_03/prompts.py +index 847c0ff8e..56ad18a4c 100644 +--- a/py/samples/menu/src/case_03/prompts.py ++++ b/py/samples/menu/src/case_03/prompts.py +@@ -17,12 +17,12 @@ + from menu_ai import ai + from menu_schemas import DataMenuQuestionInputSchema + +-from genkit.plugins.google_genai import google_genai_name +-from genkit.plugins.google_genai.models.gemini import GeminiVersion ++from genkit.plugins.google_genai import googleai_name ++from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion + + s03_chatPreamblePrompt = ai.define_prompt( + variant='s03_chatPreamble', +- model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + input_schema=DataMenuQuestionInputSchema, + config={'temperature': 0.3}, + system="""{{ role "user" }} +diff --git a/py/samples/menu/src/case_04/flows.py b/py/samples/menu/src/case_04/flows.py +index 27f506da3..6a717e038 100644 +--- a/py/samples/menu/src/case_04/flows.py ++++ b/py/samples/menu/src/case_04/flows.py +@@ -1,4 +1,4 @@ +-# 
Copyright 2025 Google LLC ++# Copyright 2026 Google LLC + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. +@@ -14,5 +14,48 @@ + # + # SPDX-License-Identifier: Apache-2.0 + ++from menu_ai import ai ++from menu_schemas import AnswerOutputSchema, MenuItemSchema, MenuQuestionInputSchema ++from pydantic import BaseModel, Field + +-# TODO: implement it once Genkit AI will have index API ++from genkit.blocks.document import Document ++ ++from .prompts import s04_ragDataMenuPrompt ++ ++ ++class IndexMenuItemsOutputSchema(BaseModel): ++ rows: int = Field(...) ++ ++ ++@ai.flow(name='s04_indexMenuItems') ++async def s04_indexMenuItemsFlow( ++ menu_items: list[MenuItemSchema], ++) -> IndexMenuItemsOutputSchema: ++ documents = [ ++ Document.from_text(f'{item.title} {item.price} \n {item.description}', metadata=item.model_dump()) ++ for item in menu_items ++ ] ++ ++ await ai.index( ++ indexer='menu-items', ++ documents=documents, ++ ) ++ return IndexMenuItemsOutputSchema(rows=len(menu_items)) ++ ++ ++@ai.flow(name='s04_ragMenuQuestion') ++async def s04_ragMenuQuestionFlow( ++ my_input: MenuQuestionInputSchema, ++) -> AnswerOutputSchema: ++ # Retrieve the 3 most relevant menu items for the question ++ docs = await ai.retrieve( ++ retriever='menu-items', ++ query=my_input.question, ++ options={'k': 3}, ++ ) ++ ++ menu_data = [doc.metadata for doc in docs.documents] ++ ++ # Generate the response ++ response = await s04_ragDataMenuPrompt({'menuData': menu_data, 'question': my_input.question}) ++ return AnswerOutputSchema(answer=response.text) +diff --git a/py/samples/menu/src/case_04/prompts.py b/py/samples/menu/src/case_04/prompts.py +index 501480072..8fb88a34b 100644 +--- a/py/samples/menu/src/case_04/prompts.py ++++ b/py/samples/menu/src/case_04/prompts.py +@@ -16,15 +16,15 @@ + from menu_ai import ai + from menu_schemas import DataMenuQuestionInputSchema + +-from 
genkit.plugins.google_genai import google_genai_name +-from genkit.plugins.google_genai.models.gemini import GeminiVersion ++from genkit.plugins.google_genai import googleai_name ++from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion + + s04_ragDataMenuPrompt = ai.define_prompt( + variant='s04_ragDataMenu', +- model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + input_schema=DataMenuQuestionInputSchema, + config={'temperature': 0.3}, +- system=""" ++ prompt=""" + You are acting as Walt, a helpful AI assistant here at the restaurant. + You can answer questions about the food on the menu or any other questions + customers have about food in general. +diff --git a/py/samples/menu/src/case_05/flows.py b/py/samples/menu/src/case_05/flows.py +index 10e469e8e..003036949 100644 +--- a/py/samples/menu/src/case_05/flows.py ++++ b/py/samples/menu/src/case_05/flows.py +@@ -28,27 +28,20 @@ from menu_schemas import ( + ) + + +-@ai.flow(name='s05_readMenuFlow') +-async def s05_readMenuFlow(_) -> ReadMenuPromptOutputSchema: ++@ai.flow(name='s05_readMenu') ++async def s05_readMenuFlow(_: None = None) -> str: + image_data_url = inline_data_url('menu.jpeg', 'image/jpeg') +- response = await s05_readMenuPrompt( +- image_url=image_data_url, +- ) +- return ReadMenuPromptOutputSchema( +- menu_text=response.text, +- ) ++ response = await s05_readMenuPrompt({'imageUrl': image_data_url}) ++ return response.text + + + @ai.flow(name='s05_textMenuQuestion') + async def s05_textMenuQuestionFlow( + my_input: TextMenuQuestionInputSchema, + ) -> AnswerOutputSchema: +- response = await s05_textMenuPrompt( +- menu_text=my_input.menu_text, +- question=my_input.question, +- ) +- return ReadMenuPromptOutputSchema( +- menu_text=response.text, ++ response = await s05_textMenuPrompt({'menuText': my_input.menuText, 'question': my_input.question}) ++ return AnswerOutputSchema( ++ answer=response.text, + ) + + 
+@@ -56,11 +49,11 @@ async def s05_textMenuQuestionFlow( + async def s05_visionMenuQuestionFlow( + my_input: MenuQuestionInputSchema, + ) -> AnswerOutputSchema: +- menu_result = await s05_readMenuFlow() +- return s05_textMenuQuestionFlow( +- my_input=TextMenuQuestionInputSchema( ++ menu_text = await s05_readMenuFlow() ++ return await s05_textMenuQuestionFlow( ++ TextMenuQuestionInputSchema( + question=my_input.question, +- menu_text=menu_result.menu_text, ++ menuText=menu_text, + ) + ) + +diff --git a/py/samples/menu/src/case_05/prompts.py b/py/samples/menu/src/case_05/prompts.py +index fcd865e4a..dceb98bf9 100644 +--- a/py/samples/menu/src/case_05/prompts.py ++++ b/py/samples/menu/src/case_05/prompts.py +@@ -16,34 +16,34 @@ + from menu_ai import ai + from menu_schemas import ReadMenuImagePromptSchema, TextMenuQuestionInputSchema + +-from genkit.plugins.google_genai import google_genai_name +-from genkit.plugins.google_genai.models.gemini import GeminiVersion ++from genkit.plugins.google_genai import googleai_name ++from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion + + s05_readMenuPrompt = ai.define_prompt( + variant='s05_readMenu', +- model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + input_schema=ReadMenuImagePromptSchema, + config={'temperature': 0.1}, +- system=""" ++ prompt=""" + Extract _all_ of the text, in order, + from the following image of a restaurant menu. + +-{{media url=image_url}} ++{{media url=imageUrl}} + """, + ) + + s05_textMenuPrompt = ai.define_prompt( + variant='s05_textMenu', +- model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), ++ model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + input_schema=TextMenuQuestionInputSchema, + config={'temperature': 0.3}, +- system=""" ++ prompt=""" + You are acting as Walt, a helpful AI assistant here at the restaurant. 
+ You can answer questions about the food on the menu or any other questions + customers have about food in general. + + Here is the text of today's menu to help you answer the customer's question: +-{{menu_text}} ++{{menuText}} + + Answer this customer's question: + {{question}}? +diff --git a/py/samples/menu/src/main.py b/py/samples/menu/src/main.py +index e083a7ff6..1be31e356 100755 +--- a/py/samples/menu/src/main.py ++++ b/py/samples/menu/src/main.py +@@ -14,17 +14,18 @@ + # + # SPDX-License-Identifier: Apache-2.0 + +-"""A stub for the sample to come.""" +- +- +-def main() -> None: +- """Main entry point for the menu sample. +- +- This function demonstrates how to use Genkit to build an interactive +- menu system. +- """ +- print('Hey') +- ++# Import all of the example prompts and flows to ensure they are registered ++import case_01.prompts ++import case_02.flows ++import case_02.prompts ++import case_02.tools ++import case_03.flows ++import case_03.prompts ++import case_04.flows ++import case_04.prompts ++import case_05.flows ++import case_05.prompts ++from menu_ai import ai + + if __name__ == '__main__': +- main() ++ ai.run_main() +diff --git a/py/samples/menu/src/menu_ai.py b/py/samples/menu/src/menu_ai.py +index c0059eb05..59288684b 100644 +--- a/py/samples/menu/src/menu_ai.py ++++ b/py/samples/menu/src/menu_ai.py +@@ -17,15 +17,14 @@ + + from genkit.ai import Genkit + from genkit.plugins.dev_local_vectorstore import DevLocalVectorStore +-from genkit.plugins.google_genai import VertexAI +-from genkit.plugins.vertex_ai import EmbeddingModels ++from genkit.plugins.google_genai import GeminiEmbeddingModels, GoogleAI, googleai_name + + ai = Genkit( + plugins=[ +- VertexAI(), ++ GoogleAI(), + DevLocalVectorStore( +- index_name='menu-items', +- embedder=EmbeddingModels.TEXT_EMBEDDING_004_ENG, ++ name='menu-items', ++ embedder=googleai_name(GeminiEmbeddingModels.TEXT_EMBEDDING_004), + embedder_options={'taskType': 'RETRIEVAL_DOCUMENT'}, + ), + ] +diff --git 
a/py/samples/model-garden/pyproject.toml b/py/samples/model-garden/pyproject.toml +index dc44c2818..98dc6490f 100644 +--- a/py/samples/model-garden/pyproject.toml ++++ b/py/samples/model-garden/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +@@ -31,9 +30,9 @@ classifiers = [ + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", + ] +-dependencies = ["genkit", "pydantic>=2.10.5"] ++dependencies = ["genkit", "genkit-plugin-vertex-ai", "pydantic>=2.10.5"] + description = "Model Garden sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "model-garden-example" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/multi-server/pyproject.toml b/py/samples/multi-server/pyproject.toml +index d6718b4d5..7163fd5d6 100644 +--- a/py/samples/multi-server/pyproject.toml ++++ b/py/samples/multi-server/pyproject.toml +@@ -15,13 +15,13 @@ + # SPDX-License-Identifier: Apache-2.0 + + [project] ++authors = [{ name = "Google" }] + classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -42,6 +42,7 @@ dependencies = [ + "uvicorn>=0.34.0", + ] + description = "Sample implementation to exercise the Genkit multi server manager." 
++license = "Apache-2.0" + name = "multi-server" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/ollama-hello/pyproject.toml b/py/samples/ollama-hello/pyproject.toml +index 57a92ec67..1a6fc009b 100644 +--- a/py/samples/ollama-hello/pyproject.toml ++++ b/py/samples/ollama-hello/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "structlog>=25.2.0", + ] + description = "Ollama hello sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "ollama-hello" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/ollama-hello/src/main.py b/py/samples/ollama-hello/src/main.py +index 98a1cd7f7..668de256d 100755 +--- a/py/samples/ollama-hello/src/main.py ++++ b/py/samples/ollama-hello/src/main.py +@@ -33,9 +33,6 @@ Key features demonstrated in this sample: + + """ + +-import asyncio +-import json +- + import structlog + from pydantic import BaseModel, Field + +diff --git a/py/samples/ollama-simple-embed/pyproject.toml b/py/samples/ollama-simple-embed/pyproject.toml +index 28e8ee8d2..d48e850e2 100644 +--- a/py/samples/ollama-simple-embed/pyproject.toml ++++ b/py/samples/ollama-simple-embed/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -40,7 +39,7 @@ dependencies = [ + "structlog>=25.2.0", + ] + description = "Ollama Simple Embed" +-license = { text = "Apache-2.0" } ++license 
= "Apache-2.0" + name = "ollama_simple_embed" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/prompt_demo/prompts/hello.prompt b/py/samples/prompt_demo/prompts/hello.prompt +index 1824e7e97..790c21469 100644 +--- a/py/samples/prompt_demo/prompts/hello.prompt ++++ b/py/samples/prompt_demo/prompts/hello.prompt +@@ -1,5 +1,5 @@ + --- +-model: googleai/gemini-2.5-flash ++model: googleai/gemini-3-flash-preview + input: + schema: + name: string +diff --git a/py/samples/prompt_demo/pyproject.toml b/py/samples/prompt_demo/pyproject.toml +index 37ef4a2ed..e5b36a1a7 100644 +--- a/py/samples/prompt_demo/pyproject.toml ++++ b/py/samples/prompt_demo/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -35,7 +34,7 @@ classifiers = [ + ] + dependencies = ["genkit", "structlog>=25.2.0", "genkit-plugin-google-genai"] + description = "Genkit prompt demo" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "prompt-demo" + requires-python = ">=3.10" + version = "0.0.1" +diff --git a/py/samples/prompt_demo/src/main.py b/py/samples/prompt_demo/src/main.py +index 723821090..2c6c9b493 100755 +--- a/py/samples/prompt_demo/src/main.py ++++ b/py/samples/prompt_demo/src/main.py +@@ -14,7 +14,6 @@ + # + # SPDX-License-Identifier: Apache-2.0 + +-import asyncio + from pathlib import Path + + import structlog +@@ -29,7 +28,7 @@ logger = structlog.get_logger(__name__) + current_dir = Path(__file__).resolve().parent + prompts_path = current_dir.parent / 'prompts' + +-ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-2.5-flash', prompt_dir=prompts_path) ++ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-3-flash-preview', 
prompt_dir=prompts_path) + + + def my_helper(content, *_, **__): +diff --git a/py/samples/short-n-long/pyproject.toml b/py/samples/short-n-long/pyproject.toml +index fa46fd524..1a6dace23 100644 +--- a/py/samples/short-n-long/pyproject.toml ++++ b/py/samples/short-n-long/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -43,7 +42,7 @@ dependencies = [ + "uvloop>=0.21.0", + ] + description = "Short and long sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "short-n-long" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/short-n-long/src/main.py b/py/samples/short-n-long/src/main.py +index 9179d0bed..75f8f01cf 100755 +--- a/py/samples/short-n-long/src/main.py ++++ b/py/samples/short-n-long/src/main.py +@@ -69,7 +69,7 @@ logger = structlog.get_logger(__name__) + + ai = Genkit( + plugins=[GoogleAI()], +- model=googleai_name('gemini-2.5-flash'), ++ model=googleai_name('gemini-3-flash-preview'), + ) + + +@@ -103,7 +103,7 @@ async def simple_generate_with_tools_flow(value: int) -> str: + The generated response with a function. + """ + response = await ai.generate( +- model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_2_0_FLASH), ++ model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + messages=[ + Message( + role=Role.USER, +@@ -140,7 +140,7 @@ async def simple_generate_with_interrupts(value: int) -> str: + The generated response with a function. 
+ """ + response1 = await ai.generate( +- model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_2_0_FLASH), ++ model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + messages=[ + Message( + role=Role.USER, +@@ -155,7 +155,7 @@ async def simple_generate_with_interrupts(value: int) -> str: + + tr = tool_response(response1.interrupts[0], 178) + response = await ai.generate( +- model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_2_0_FLASH), ++ model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + messages=response1.messages, + tool_responses=[tr], + tools=['gablorkenTool'], +diff --git a/py/samples/tool-interrupts/pyproject.toml b/py/samples/tool-interrupts/pyproject.toml +index c4e5e57bb..20391edb0 100644 +--- a/py/samples/tool-interrupts/pyproject.toml ++++ b/py/samples/tool-interrupts/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -35,7 +34,7 @@ classifiers = [ + ] + dependencies = ["genkit", "genkit-plugin-google-genai", "pydantic>=2.10.5"] + description = "Tool interrupts sample" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "tool-interrupts" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/tool-interrupts/src/main.py b/py/samples/tool-interrupts/src/main.py +index f71f44b25..fdc14f868 100755 +--- a/py/samples/tool-interrupts/src/main.py ++++ b/py/samples/tool-interrupts/src/main.py +@@ -28,7 +28,7 @@ from genkit.plugins.google_genai.models import gemini + + ai = Genkit( + plugins=[GoogleAI()], +- model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_2_0_FLASH), ++ model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), + ) + + +diff --git 
a/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml b/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml +index 5330a5705..9b15b7631 100644 +--- a/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml ++++ b/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -43,7 +42,7 @@ dependencies = [ + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "An example demonstrating the use Vector Search API with BigQuery retriever for Vertex AI" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "vertex-ai-vector-search-bigquery" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/samples/vertex-ai-vector-search-firestore/pyproject.toml b/py/samples/vertex-ai-vector-search-firestore/pyproject.toml +index 6ff3f349f..99fd0c758 100644 +--- a/py/samples/vertex-ai-vector-search-firestore/pyproject.toml ++++ b/py/samples/vertex-ai-vector-search-firestore/pyproject.toml +@@ -22,7 +22,6 @@ classifiers = [ + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -43,7 +42,7 @@ dependencies = [ + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "An example demonstrating the use Vector Search API with Firestore retriever for Vertex AI" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "vertex-ai-vector-search-firestore" + readme = "README.md" + requires-python = ">=3.10" +diff --git 
a/py/samples/xai-hello/pyproject.toml b/py/samples/xai-hello/pyproject.toml +index 65b7aff42..5fa7f84d4 100644 +--- a/py/samples/xai-hello/pyproject.toml ++++ b/py/samples/xai-hello/pyproject.toml +@@ -15,6 +15,7 @@ + # SPDX-License-Identifier: Apache-2.0 + + [project] ++authors = [{ name = "Google" }] + dependencies = [ + "genkit", + "genkit-plugin-xai", +diff --git a/py/tests/smoke/pyproject.toml b/py/tests/smoke/pyproject.toml +index 4605d5626..9e31d9de6 100644 +--- a/py/tests/smoke/pyproject.toml ++++ b/py/tests/smoke/pyproject.toml +@@ -15,13 +15,13 @@ + # SPDX-License-Identifier: Apache-2.0 + + [project] ++authors = [{ name = "Google" }] + classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", +- "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", +@@ -42,7 +42,7 @@ dependencies = [ + "strenum>=0.4.15; python_version < '3.11'", + ] + description = "Packaging smoke test" +-license = { text = "Apache-2.0" } ++license = "Apache-2.0" + name = "smoke" + readme = "README.md" + requires-python = ">=3.10" +diff --git a/py/uv.lock b/py/uv.lock +index 560e88909..cca42bd6b 100644 +--- a/py/uv.lock ++++ b/py/uv.lock +@@ -1,5 +1,5 @@ + version = 1 +-revision = 2 ++revision = 3 + requires-python = ">=3.10" + resolution-markers = [ + "python_full_version >= '3.14'", +@@ -12,6 +12,7 @@ resolution-markers = [ + members = [ + "anthropic-hello", + "compat-oai-hello", ++ "deepseek-hello", + "dev-local-vectorstore-hello", + "eval-demo", + "firestore-retreiver", +@@ -19,6 +20,7 @@ members = [ + "genkit", + "genkit-plugin-anthropic", + "genkit-plugin-compat-oai", ++ "genkit-plugin-deepseek", + "genkit-plugin-dev-local-vectorstore", + "genkit-plugin-evaluators", + "genkit-plugin-firebase", +@@ -28,6 +30,7 @@ members = 
[ + "genkit-plugin-ollama", + "genkit-plugin-vertex-ai", + "genkit-plugin-xai", ++ "genkit-plugins-mcp", + "genkit-workspace", + "google-genai-code-execution", + "google-genai-context-caching", +@@ -796,6 +799,7 @@ dependencies = [ + ] + sdist = { url = "https://files.pythonhosted.org/packages/13/1f/9fa001e74a1993a9cadd2333bb889e50c66327b8594ac538ab8a04f915b7/cryptography-45.0.3.tar.gz", hash = "sha256:ec21313dd335c51d7877baf2972569f40a4291b76a0ce51391523ae358d05899", size = 744738, upload-time = "2025-05-25T14:17:24.777Z" } + wheels = [ ++ { url = "https://files.pythonhosted.org/packages/82/b2/2345dc595998caa6f68adf84e8f8b50d18e9fc4638d32b22ea8daedd4b7a/cryptography-45.0.3-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:7573d9eebaeceeb55285205dbbb8753ac1e962af3d9640791d12b36864065e71", size = 7056239, upload-time = "2025-05-25T14:16:12.22Z" }, + { url = "https://files.pythonhosted.org/packages/71/3d/ac361649a0bfffc105e2298b720d8b862330a767dab27c06adc2ddbef96a/cryptography-45.0.3-cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d377dde61c5d67eb4311eace661c3efda46c62113ff56bf05e2d679e02aebb5b", size = 4205541, upload-time = "2025-05-25T14:16:14.333Z" }, + { url = "https://files.pythonhosted.org/packages/70/3e/c02a043750494d5c445f769e9c9f67e550d65060e0bfce52d91c1362693d/cryptography-45.0.3-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fae1e637f527750811588e4582988932c222f8251f7b7ea93739acb624e1487f", size = 4433275, upload-time = "2025-05-25T14:16:16.421Z" }, + { url = "https://files.pythonhosted.org/packages/40/7a/9af0bfd48784e80eef3eb6fd6fde96fe706b4fc156751ce1b2b965dada70/cryptography-45.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ca932e11218bcc9ef812aa497cdf669484870ecbcf2d99b765d6c27a86000942", size = 4209173, upload-time = "2025-05-25T14:16:18.163Z" }, +@@ -805,6 +809,9 @@ wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e7/53/8a130e22c1e432b3c14896ec5eb7ac01fb53c6737e1d705df7e0efb647c6/cryptography-45.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c824c9281cb628015bfc3c59335163d4ca0540d49de4582d6c2637312907e4b1", size = 4466300, upload-time = "2025-05-25T14:16:26.768Z" }, + { url = "https://files.pythonhosted.org/packages/ba/75/6bb6579688ef805fd16a053005fce93944cdade465fc92ef32bbc5c40681/cryptography-45.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5833bb4355cb377ebd880457663a972cd044e7f49585aee39245c0d592904578", size = 4332483, upload-time = "2025-05-25T14:16:28.316Z" }, + { url = "https://files.pythonhosted.org/packages/2f/11/2538f4e1ce05c6c4f81f43c1ef2bd6de7ae5e24ee284460ff6c77e42ca77/cryptography-45.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:9bb5bf55dcb69f7067d80354d0a348368da907345a2c448b0babc4215ccd3497", size = 4573714, upload-time = "2025-05-25T14:16:30.474Z" }, ++ { url = "https://files.pythonhosted.org/packages/f5/bb/e86e9cf07f73a98d84a4084e8fd420b0e82330a901d9cac8149f994c3417/cryptography-45.0.3-cp311-abi3-win32.whl", hash = "sha256:3ad69eeb92a9de9421e1f6685e85a10fbcfb75c833b42cc9bc2ba9fb00da4710", size = 2934752, upload-time = "2025-05-25T14:16:32.204Z" }, ++ { url = "https://files.pythonhosted.org/packages/c7/75/063bc9ddc3d1c73e959054f1fc091b79572e716ef74d6caaa56e945b4af9/cryptography-45.0.3-cp311-abi3-win_amd64.whl", hash = "sha256:97787952246a77d77934d41b62fb1b6f3581d83f71b44796a4158d93b8f5c490", size = 3412465, upload-time = "2025-05-25T14:16:33.888Z" }, ++ { url = "https://files.pythonhosted.org/packages/71/9b/04ead6015229a9396890d7654ee35ef630860fb42dc9ff9ec27f72157952/cryptography-45.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:c92519d242703b675ccefd0f0562eb45e74d438e001f8ab52d628e885751fb06", size = 7031892, upload-time = "2025-05-25T14:16:36.214Z" }, + { url = 
"https://files.pythonhosted.org/packages/46/c7/c7d05d0e133a09fc677b8a87953815c522697bdf025e5cac13ba419e7240/cryptography-45.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5edcb90da1843df85292ef3a313513766a78fbbb83f584a5a58fb001a5a9d57", size = 4196181, upload-time = "2025-05-25T14:16:37.934Z" }, + { url = "https://files.pythonhosted.org/packages/08/7a/6ad3aa796b18a683657cef930a986fac0045417e2dc428fd336cfc45ba52/cryptography-45.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38deed72285c7ed699864f964a3f4cf11ab3fb38e8d39cfcd96710cd2b5bb716", size = 4423370, upload-time = "2025-05-25T14:16:39.502Z" }, + { url = "https://files.pythonhosted.org/packages/4f/58/ec1461bfcb393525f597ac6a10a63938d18775b7803324072974b41a926b/cryptography-45.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5555365a50efe1f486eed6ac7062c33b97ccef409f5970a0b6f205a7cfab59c8", size = 4197839, upload-time = "2025-05-25T14:16:41.322Z" }, +@@ -814,14 +821,20 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/71/7a/e002d5ce624ed46dfc32abe1deff32190f3ac47ede911789ee936f5a4255/cryptography-45.0.3-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:57a6500d459e8035e813bd8b51b671977fb149a8c95ed814989da682314d0782", size = 4450308, upload-time = "2025-05-25T14:16:48.228Z" }, + { url = "https://files.pythonhosted.org/packages/87/ad/3fbff9c28cf09b0a71e98af57d74f3662dea4a174b12acc493de00ea3f28/cryptography-45.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:f22af3c78abfbc7cbcdf2c55d23c3e022e1a462ee2481011d518c7fb9c9f3d65", size = 4325125, upload-time = "2025-05-25T14:16:49.844Z" }, + { url = "https://files.pythonhosted.org/packages/f5/b4/51417d0cc01802304c1984d76e9592f15e4801abd44ef7ba657060520bf0/cryptography-45.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:232954730c362638544758a8160c4ee1b832dc011d2c41a306ad8f7cccc5bb0b", size = 4560038, upload-time = "2025-05-25T14:16:51.398Z" }, ++ { url = 
"https://files.pythonhosted.org/packages/80/38/d572f6482d45789a7202fb87d052deb7a7b136bf17473ebff33536727a2c/cryptography-45.0.3-cp37-abi3-win32.whl", hash = "sha256:cb6ab89421bc90e0422aca911c69044c2912fc3debb19bb3c1bfe28ee3dff6ab", size = 2924070, upload-time = "2025-05-25T14:16:53.472Z" }, ++ { url = "https://files.pythonhosted.org/packages/91/5a/61f39c0ff4443651cc64e626fa97ad3099249152039952be8f344d6b0c86/cryptography-45.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:d54ae41e6bd70ea23707843021c778f151ca258081586f0cfa31d936ae43d1b2", size = 3395005, upload-time = "2025-05-25T14:16:55.134Z" }, ++ { url = "https://files.pythonhosted.org/packages/1b/63/ce30cb7204e8440df2f0b251dc0464a26c55916610d1ba4aa912f838bcc8/cryptography-45.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ed43d396f42028c1f47b5fec012e9e12631266e3825e95c00e3cf94d472dac49", size = 3578348, upload-time = "2025-05-25T14:16:56.792Z" }, + { url = "https://files.pythonhosted.org/packages/45/0b/87556d3337f5e93c37fda0a0b5d3e7b4f23670777ce8820fce7962a7ed22/cryptography-45.0.3-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:fed5aaca1750e46db870874c9c273cd5182a9e9deb16f06f7bdffdb5c2bde4b9", size = 4142867, upload-time = "2025-05-25T14:16:58.459Z" }, + { url = "https://files.pythonhosted.org/packages/72/ba/21356dd0bcb922b820211336e735989fe2cf0d8eaac206335a0906a5a38c/cryptography-45.0.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:00094838ecc7c6594171e8c8a9166124c1197b074cfca23645cee573910d76bc", size = 4385000, upload-time = "2025-05-25T14:17:00.656Z" }, + { url = "https://files.pythonhosted.org/packages/2f/2b/71c78d18b804c317b66283be55e20329de5cd7e1aec28e4c5fbbe21fd046/cryptography-45.0.3-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:92d5f428c1a0439b2040435a1d6bc1b26ebf0af88b093c3628913dd464d13fa1", size = 4144195, upload-time = "2025-05-25T14:17:02.782Z" }, + { url = 
"https://files.pythonhosted.org/packages/55/3e/9f9b468ea779b4dbfef6af224804abd93fbcb2c48605d7443b44aea77979/cryptography-45.0.3-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:ec64ee375b5aaa354b2b273c921144a660a511f9df8785e6d1c942967106438e", size = 4384540, upload-time = "2025-05-25T14:17:04.49Z" }, ++ { url = "https://files.pythonhosted.org/packages/97/f5/6e62d10cf29c50f8205c0dc9aec986dca40e8e3b41bf1a7878ea7b11e5ee/cryptography-45.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:71320fbefd05454ef2d457c481ba9a5b0e540f3753354fff6f780927c25d19b0", size = 3328796, upload-time = "2025-05-25T14:17:06.174Z" }, ++ { url = "https://files.pythonhosted.org/packages/e7/d4/58a246342093a66af8935d6aa59f790cbb4731adae3937b538d054bdc2f9/cryptography-45.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:edd6d51869beb7f0d472e902ef231a9b7689508e83880ea16ca3311a00bf5ce7", size = 3589802, upload-time = "2025-05-25T14:17:07.792Z" }, + { url = "https://files.pythonhosted.org/packages/96/61/751ebea58c87b5be533c429f01996050a72c7283b59eee250275746632ea/cryptography-45.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:555e5e2d3a53b4fabeca32835878b2818b3f23966a4efb0d566689777c5a12c8", size = 4146964, upload-time = "2025-05-25T14:17:09.538Z" }, + { url = "https://files.pythonhosted.org/packages/8d/01/28c90601b199964de383da0b740b5156f5d71a1da25e7194fdf793d373ef/cryptography-45.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:25286aacb947286620a31f78f2ed1a32cded7be5d8b729ba3fb2c988457639e4", size = 4388103, upload-time = "2025-05-25T14:17:11.978Z" }, + { url = "https://files.pythonhosted.org/packages/3d/ec/cd892180b9e42897446ef35c62442f5b8b039c3d63a05f618aa87ec9ebb5/cryptography-45.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:050ce5209d5072472971e6efbfc8ec5a8f9a841de5a4db0ebd9c2e392cb81972", size = 4150031, upload-time = "2025-05-25T14:17:14.131Z" }, + { url = 
"https://files.pythonhosted.org/packages/db/d4/22628c2dedd99289960a682439c6d3aa248dff5215123ead94ac2d82f3f5/cryptography-45.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:dc10ec1e9f21f33420cc05214989544727e776286c1c16697178978327b95c9c", size = 4387389, upload-time = "2025-05-25T14:17:17.303Z" }, ++ { url = "https://files.pythonhosted.org/packages/39/ec/ba3961abbf8ecb79a3586a4ff0ee08c9d7a9938b4312fb2ae9b63f48a8ba/cryptography-45.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:9eda14f049d7f09c2e8fb411dda17dd6b16a3c76a1de5e249188a32aeb92de19", size = 3337432, upload-time = "2025-05-25T14:17:19.507Z" }, + ] + + [[package]] +@@ -928,6 +941,25 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, + ] + ++[[package]] ++name = "deepseek-hello" ++version = "0.1.0" ++source = { editable = "samples/deepseek-hello" } ++dependencies = [ ++ { name = "genkit" }, ++ { name = "genkit-plugin-deepseek" }, ++ { name = "pydantic" }, ++ { name = "structlog" }, ++] ++ ++[package.metadata] ++requires-dist = [ ++ { name = "genkit", editable = "packages/genkit" }, ++ { name = "genkit-plugin-deepseek", editable = "plugins/deepseek" }, ++ { name = "pydantic", specifier = ">=2.0.0" }, ++ { name = "structlog", specifier = ">=24.0.0" }, ++] ++ + [[package]] + name = "defusedxml" + version = "0.7.1" +@@ -1615,6 +1647,23 @@ requires-dist = [ + { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.15" }, + ] + ++[[package]] ++name = "genkit-plugin-deepseek" ++version = "0.1.0" ++source = { editable = "plugins/deepseek" } ++dependencies = [ ++ { name = "genkit" }, ++ { name = "genkit-plugin-compat-oai" }, ++ { name = "openai" }, ++] ++ ++[package.metadata] ++requires-dist = [ ++ { name = "genkit", 
editable = "packages/genkit" }, ++ { name = "genkit-plugin-compat-oai", editable = "plugins/compat-oai" }, ++ { name = "openai", specifier = ">=1.0.0" }, ++] ++ + [[package]] + name = "genkit-plugin-dev-local-vectorstore" + version = "0.4.0" +@@ -1753,6 +1802,7 @@ version = "0.4.0" + source = { editable = "plugins/vertex-ai" } + dependencies = [ + { name = "genkit" }, ++ { name = "genkit-plugin-compat-oai" }, + { name = "google-cloud-aiplatform" }, + { name = "google-cloud-bigquery" }, + { name = "google-cloud-firestore" }, +@@ -1764,6 +1814,7 @@ dependencies = [ + [package.metadata] + requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, ++ { name = "genkit-plugin-compat-oai", editable = "plugins/compat-oai" }, + { name = "google-cloud-aiplatform", specifier = ">=1.77.0" }, + { name = "google-cloud-bigquery" }, + { name = "google-cloud-firestore" }, +@@ -1787,6 +1838,21 @@ requires-dist = [ + { name = "xai-sdk", specifier = ">=0.0.1" }, + ] + ++[[package]] ++name = "genkit-plugins-mcp" ++version = "0.1.0" ++source = { editable = "plugins/mcp" } ++dependencies = [ ++ { name = "genkit" }, ++ { name = "mcp" }, ++] ++ ++[package.metadata] ++requires-dist = [ ++ { name = "genkit", editable = "packages/genkit" }, ++ { name = "mcp" }, ++] ++ + [[package]] + name = "genkit-workspace" + version = "0.1.0" +@@ -1796,6 +1862,7 @@ dependencies = [ + { name = "genkit" }, + { name = "genkit-plugin-anthropic" }, + { name = "genkit-plugin-compat-oai" }, ++ { name = "genkit-plugin-deepseek" }, + { name = "genkit-plugin-dev-local-vectorstore" }, + { name = "genkit-plugin-evaluators" }, + { name = "genkit-plugin-firebase" }, +@@ -1806,6 +1873,7 @@ dependencies = [ + { name = "genkit-plugin-vertex-ai" }, + { name = "genkit-plugin-xai" }, + { name = "liccheck" }, ++ { name = "mcp" }, + { name = "strenum", marker = "python_full_version < '3.11'" }, + ] + +@@ -1830,8 +1898,8 @@ dev = [ + { name = "twine" }, + ] + lint = [ +- { name = "mypy" }, + { name = "ruff" }, ++ { 
name = "ty" }, + ] + + [package.metadata] +@@ -1840,6 +1908,7 @@ requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-anthropic", editable = "plugins/anthropic" }, + { name = "genkit-plugin-compat-oai", editable = "plugins/compat-oai" }, ++ { name = "genkit-plugin-deepseek", editable = "plugins/deepseek" }, + { name = "genkit-plugin-dev-local-vectorstore", editable = "plugins/dev-local-vectorstore" }, + { name = "genkit-plugin-evaluators", editable = "plugins/evaluators" }, + { name = "genkit-plugin-firebase", editable = "plugins/firebase" }, +@@ -1850,6 +1919,7 @@ requires-dist = [ + { name = "genkit-plugin-vertex-ai", editable = "plugins/vertex-ai" }, + { name = "genkit-plugin-xai", editable = "plugins/xai" }, + { name = "liccheck", specifier = ">=0.9.2" }, ++ { name = "mcp", specifier = ">=1.25.0" }, + { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.15" }, + ] + +@@ -1874,8 +1944,8 @@ dev = [ + { name = "twine", specifier = ">=6.1.0" }, + ] + lint = [ +- { name = "mypy", specifier = ">=1.15" }, + { name = "ruff", specifier = ">=0.9" }, ++ { name = "ty", specifier = ">=0.0.1" }, + ] + + [[package]] +@@ -2499,6 +2569,15 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, + ] + ++[[package]] ++name = "httpx-sse" ++version = "0.4.3" ++source = { registry = "https://pypi.org/simple" } ++sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } ++wheels = [ ++ { url = 
"https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, ++] ++ + [[package]] + name = "id" + version = "1.5.0" +@@ -3257,6 +3336,31 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, + ] + ++[[package]] ++name = "mcp" ++version = "1.25.0" ++source = { registry = "https://pypi.org/simple" } ++dependencies = [ ++ { name = "anyio" }, ++ { name = "httpx" }, ++ { name = "httpx-sse" }, ++ { name = "jsonschema" }, ++ { name = "pydantic" }, ++ { name = "pydantic-settings" }, ++ { name = "pyjwt", extra = ["crypto"] }, ++ { name = "python-multipart" }, ++ { name = "pywin32", marker = "sys_platform == 'win32'" }, ++ { name = "sse-starlette" }, ++ { name = "starlette" }, ++ { name = "typing-extensions" }, ++ { name = "typing-inspection" }, ++ { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, ++] ++sdist = { url = "https://files.pythonhosted.org/packages/d5/2d/649d80a0ecf6a1f82632ca44bec21c0461a9d9fc8934d38cb5b319f2db5e/mcp-1.25.0.tar.gz", hash = "sha256:56310361ebf0364e2d438e5b45f7668cbb124e158bb358333cd06e49e83a6802", size = 605387, upload-time = "2025-12-19T10:19:56.985Z" } ++wheels = [ ++ { url = "https://files.pythonhosted.org/packages/e2/fc/6dc7659c2ae5ddf280477011f4213a74f806862856b796ef08f028e664bf/mcp-1.25.0-py3-none-any.whl", hash = "sha256:b37c38144a666add0862614cc79ec276e97d72aa8ca26d622818d4e278b9721a", size = 233076, upload-time = "2025-12-19T10:19:55.416Z" }, ++] ++ + [[package]] + name = "mdurl" + version = "0.1.2" +@@ -3311,12 +3415,14 @@ version = "0.1.0" + source = { 
virtual = "samples/model-garden" } + dependencies = [ + { name = "genkit" }, ++ { name = "genkit-plugin-vertex-ai" }, + { name = "pydantic" }, + ] + + [package.metadata] + requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, ++ { name = "genkit-plugin-vertex-ai", editable = "plugins/vertex-ai" }, + { name = "pydantic", specifier = ">=2.10.5" }, + ] + +@@ -3496,45 +3602,6 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/d1/3598d1e73385baaab427392856f915487db7aa10abadd436f8f2d3e3b0f9/multipart-1.2.1-py3-none-any.whl", hash = "sha256:c03dc203bc2e67f6b46a599467ae0d87cf71d7530504b2c1ff4a9ea21d8b8c8c", size = 13730, upload-time = "2024-11-29T08:45:44.557Z" }, + ] + +-[[package]] +-name = "mypy" +-version = "1.16.0" +-source = { registry = "https://pypi.org/simple" } +-dependencies = [ +- { name = "mypy-extensions" }, +- { name = "pathspec" }, +- { name = "tomli", marker = "python_full_version < '3.11'" }, +- { name = "typing-extensions" }, +-] +-sdist = { url = "https://files.pythonhosted.org/packages/d4/38/13c2f1abae94d5ea0354e146b95a1be9b2137a0d506728e0da037c4276f6/mypy-1.16.0.tar.gz", hash = "sha256:84b94283f817e2aa6350a14b4a8fb2a35a53c286f97c9d30f53b63620e7af8ab", size = 3323139, upload-time = "2025-05-29T13:46:12.532Z" } +-wheels = [ +- { url = "https://files.pythonhosted.org/packages/64/5e/a0485f0608a3d67029d3d73cec209278b025e3493a3acfda3ef3a88540fd/mypy-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7909541fef256527e5ee9c0a7e2aeed78b6cda72ba44298d1334fe7881b05c5c", size = 10967416, upload-time = "2025-05-29T13:34:17.783Z" }, +- { url = "https://files.pythonhosted.org/packages/4b/53/5837c221f74c0d53a4bfc3003296f8179c3a2a7f336d7de7bbafbe96b688/mypy-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e71d6f0090c2256c713ed3d52711d01859c82608b5d68d4fa01a3fe30df95571", size = 10087654, upload-time = "2025-05-29T13:32:37.878Z" }, +- { url = 
"https://files.pythonhosted.org/packages/29/59/5fd2400352c3093bed4c09017fe671d26bc5bb7e6ef2d4bf85f2a2488104/mypy-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:936ccfdd749af4766be824268bfe22d1db9eb2f34a3ea1d00ffbe5b5265f5491", size = 11875192, upload-time = "2025-05-29T13:34:54.281Z" }, +- { url = "https://files.pythonhosted.org/packages/ad/3e/4bfec74663a64c2012f3e278dbc29ffe82b121bc551758590d1b6449ec0c/mypy-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4086883a73166631307fdd330c4a9080ce24913d4f4c5ec596c601b3a4bdd777", size = 12612939, upload-time = "2025-05-29T13:33:14.766Z" }, +- { url = "https://files.pythonhosted.org/packages/88/1f/fecbe3dcba4bf2ca34c26ca016383a9676711907f8db4da8354925cbb08f/mypy-1.16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:feec38097f71797da0231997e0de3a58108c51845399669ebc532c815f93866b", size = 12874719, upload-time = "2025-05-29T13:21:52.09Z" }, +- { url = "https://files.pythonhosted.org/packages/f3/51/c2d280601cd816c43dfa512a759270d5a5ef638d7ac9bea9134c8305a12f/mypy-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:09a8da6a0ee9a9770b8ff61b39c0bb07971cda90e7297f4213741b48a0cc8d93", size = 9487053, upload-time = "2025-05-29T13:33:29.797Z" }, +- { url = "https://files.pythonhosted.org/packages/24/c4/ff2f79db7075c274fe85b5fff8797d29c6b61b8854c39e3b7feb556aa377/mypy-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9f826aaa7ff8443bac6a494cf743f591488ea940dd360e7dd330e30dd772a5ab", size = 10884498, upload-time = "2025-05-29T13:18:54.066Z" }, +- { url = "https://files.pythonhosted.org/packages/02/07/12198e83006235f10f6a7808917376b5d6240a2fd5dce740fe5d2ebf3247/mypy-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:82d056e6faa508501af333a6af192c700b33e15865bda49611e3d7d8358ebea2", size = 10011755, upload-time = "2025-05-29T13:34:00.851Z" }, +- { url = 
"https://files.pythonhosted.org/packages/f1/9b/5fd5801a72b5d6fb6ec0105ea1d0e01ab2d4971893076e558d4b6d6b5f80/mypy-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:089bedc02307c2548eb51f426e085546db1fa7dd87fbb7c9fa561575cf6eb1ff", size = 11800138, upload-time = "2025-05-29T13:32:55.082Z" }, +- { url = "https://files.pythonhosted.org/packages/2e/81/a117441ea5dfc3746431e51d78a4aca569c677aa225bca2cc05a7c239b61/mypy-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6a2322896003ba66bbd1318c10d3afdfe24e78ef12ea10e2acd985e9d684a666", size = 12533156, upload-time = "2025-05-29T13:19:12.963Z" }, +- { url = "https://files.pythonhosted.org/packages/3f/38/88ec57c6c86014d3f06251e00f397b5a7daa6888884d0abf187e4f5f587f/mypy-1.16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:021a68568082c5b36e977d54e8f1de978baf401a33884ffcea09bd8e88a98f4c", size = 12742426, upload-time = "2025-05-29T13:20:22.72Z" }, +- { url = "https://files.pythonhosted.org/packages/bd/53/7e9d528433d56e6f6f77ccf24af6ce570986c2d98a5839e4c2009ef47283/mypy-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:54066fed302d83bf5128632d05b4ec68412e1f03ef2c300434057d66866cea4b", size = 9478319, upload-time = "2025-05-29T13:21:17.582Z" }, +- { url = "https://files.pythonhosted.org/packages/70/cf/158e5055e60ca2be23aec54a3010f89dcffd788732634b344fc9cb1e85a0/mypy-1.16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c5436d11e89a3ad16ce8afe752f0f373ae9620841c50883dc96f8b8805620b13", size = 11062927, upload-time = "2025-05-29T13:35:52.328Z" }, +- { url = "https://files.pythonhosted.org/packages/94/34/cfff7a56be1609f5d10ef386342ce3494158e4d506516890142007e6472c/mypy-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f2622af30bf01d8fc36466231bdd203d120d7a599a6d88fb22bdcb9dbff84090", size = 10083082, upload-time = "2025-05-29T13:35:33.378Z" }, +- { url = 
"https://files.pythonhosted.org/packages/b3/7f/7242062ec6288c33d8ad89574df87c3903d394870e5e6ba1699317a65075/mypy-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d045d33c284e10a038f5e29faca055b90eee87da3fc63b8889085744ebabb5a1", size = 11828306, upload-time = "2025-05-29T13:21:02.164Z" }, +- { url = "https://files.pythonhosted.org/packages/6f/5f/b392f7b4f659f5b619ce5994c5c43caab3d80df2296ae54fa888b3d17f5a/mypy-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b4968f14f44c62e2ec4a038c8797a87315be8df7740dc3ee8d3bfe1c6bf5dba8", size = 12702764, upload-time = "2025-05-29T13:20:42.826Z" }, +- { url = "https://files.pythonhosted.org/packages/9b/c0/7646ef3a00fa39ac9bc0938626d9ff29d19d733011be929cfea59d82d136/mypy-1.16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eb14a4a871bb8efb1e4a50360d4e3c8d6c601e7a31028a2c79f9bb659b63d730", size = 12896233, upload-time = "2025-05-29T13:18:37.446Z" }, +- { url = "https://files.pythonhosted.org/packages/6d/38/52f4b808b3fef7f0ef840ee8ff6ce5b5d77381e65425758d515cdd4f5bb5/mypy-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:bd4e1ebe126152a7bbaa4daedd781c90c8f9643c79b9748caa270ad542f12bec", size = 9565547, upload-time = "2025-05-29T13:20:02.836Z" }, +- { url = "https://files.pythonhosted.org/packages/97/9c/ca03bdbefbaa03b264b9318a98950a9c683e06472226b55472f96ebbc53d/mypy-1.16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a9e056237c89f1587a3be1a3a70a06a698d25e2479b9a2f57325ddaaffc3567b", size = 11059753, upload-time = "2025-05-29T13:18:18.167Z" }, +- { url = "https://files.pythonhosted.org/packages/36/92/79a969b8302cfe316027c88f7dc6fee70129490a370b3f6eb11d777749d0/mypy-1.16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0b07e107affb9ee6ce1f342c07f51552d126c32cd62955f59a7db94a51ad12c0", size = 10073338, upload-time = "2025-05-29T13:19:48.079Z" }, +- { url = 
"https://files.pythonhosted.org/packages/14/9b/a943f09319167da0552d5cd722104096a9c99270719b1afeea60d11610aa/mypy-1.16.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c6fb60cbd85dc65d4d63d37cb5c86f4e3a301ec605f606ae3a9173e5cf34997b", size = 11827764, upload-time = "2025-05-29T13:46:04.47Z" }, +- { url = "https://files.pythonhosted.org/packages/ec/64/ff75e71c65a0cb6ee737287c7913ea155845a556c64144c65b811afdb9c7/mypy-1.16.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a7e32297a437cc915599e0578fa6bc68ae6a8dc059c9e009c628e1c47f91495d", size = 12701356, upload-time = "2025-05-29T13:35:13.553Z" }, +- { url = "https://files.pythonhosted.org/packages/0a/ad/0e93c18987a1182c350f7a5fab70550852f9fabe30ecb63bfbe51b602074/mypy-1.16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:afe420c9380ccec31e744e8baff0d406c846683681025db3531b32db56962d52", size = 12900745, upload-time = "2025-05-29T13:17:24.409Z" }, +- { url = "https://files.pythonhosted.org/packages/28/5d/036c278d7a013e97e33f08c047fe5583ab4f1fc47c9a49f985f1cdd2a2d7/mypy-1.16.0-cp313-cp313-win_amd64.whl", hash = "sha256:55f9076c6ce55dd3f8cd0c6fff26a008ca8e5131b89d5ba6d86bd3f47e736eeb", size = 9572200, upload-time = "2025-05-29T13:33:44.92Z" }, +- { url = "https://files.pythonhosted.org/packages/99/a3/6ed10530dec8e0fdc890d81361260c9ef1f5e5c217ad8c9b21ecb2b8366b/mypy-1.16.0-py3-none-any.whl", hash = "sha256:29e1499864a3888bca5c1542f2d7232c6e586295183320caa95758fc84034031", size = 2265773, upload-time = "2025-05-29T13:35:18.762Z" }, +-] +- + [[package]] + name = "mypy-extensions" + version = "1.1.0" +@@ -4499,6 +4566,20 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = 
"2025-11-04T13:43:46.64Z" }, + ] + ++[[package]] ++name = "pydantic-settings" ++version = "2.12.0" ++source = { registry = "https://pypi.org/simple" } ++dependencies = [ ++ { name = "pydantic" }, ++ { name = "python-dotenv" }, ++ { name = "typing-inspection" }, ++] ++sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" } ++wheels = [ ++ { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, ++] ++ + [[package]] + name = "pygments" + version = "2.19.1" +@@ -4508,6 +4589,20 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293, upload-time = "2025-01-06T17:26:25.553Z" }, + ] + ++[[package]] ++name = "pyjwt" ++version = "2.10.1" ++source = { registry = "https://pypi.org/simple" } ++sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } ++wheels = [ ++ { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, ++] ++ 
++[package.optional-dependencies] ++crypto = [ ++ { name = "cryptography" }, ++] ++ + [[package]] + name = "pypdf" + version = "6.5.0" +@@ -4613,6 +4708,15 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, + ] + ++[[package]] ++name = "python-dotenv" ++version = "1.2.1" ++source = { registry = "https://pypi.org/simple" } ++sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } ++wheels = [ ++ { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, ++] ++ + [[package]] + name = "python-json-logger" + version = "3.3.0" +@@ -4622,6 +4726,15 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/08/20/0f2523b9e50a8052bc6a8b732dfc8568abbdc42010aef03a2d750bdab3b2/python_json_logger-3.3.0-py3-none-any.whl", hash = "sha256:dd980fae8cffb24c13caf6e158d3d61c0d6d22342f932cb6e9deedab3d35eec7", size = 15163, upload-time = "2025-03-07T07:08:25.627Z" }, + ] + ++[[package]] ++name = "python-multipart" ++version = "0.0.21" ++source = { registry = "https://pypi.org/simple" } ++sdist = { url = "https://files.pythonhosted.org/packages/78/96/804520d0850c7db98e5ccb70282e29208723f0964e88ffd9d0da2f52ea09/python_multipart-0.0.21.tar.gz", hash = "sha256:7137ebd4d3bbf70ea1622998f902b97a29434a9e8dc40eb203bbcf7c2a2cba92", size = 37196, upload-time = 
"2025-12-17T09:24:22.446Z" } ++wheels = [ ++ { url = "https://files.pythonhosted.org/packages/aa/76/03af049af4dcee5d27442f71b6924f01f3efb5d2bd34f23fcd563f2cc5f5/python_multipart-0.0.21-py3-none-any.whl", hash = "sha256:cf7a6713e01c87aa35387f4774e812c4361150938d20d232800f75ffcf266090", size = 24541, upload-time = "2025-12-17T09:24:21.153Z" }, ++] ++ + [[package]] + name = "pywin32" + version = "310" +@@ -5453,6 +5566,31 @@ wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/b6/74e927715a285743351233f33ea3c684528a0d374d2e43ff9ce9585b73fe/twine-6.1.0-py3-none-any.whl", hash = "sha256:a47f973caf122930bf0fbbf17f80b83bc1602c9ce393c7845f289a3001dc5384", size = 40791, upload-time = "2025-01-21T18:45:24.584Z" }, + ] + ++[[package]] ++name = "ty" ++version = "0.0.11" ++source = { registry = "https://pypi.org/simple" } ++sdist = { url = "https://files.pythonhosted.org/packages/bc/45/5ae578480168d4b3c08cf8e5eac3caf8eb7acdb1a06a9bed7519564bd9b4/ty-0.0.11.tar.gz", hash = "sha256:ebcbc7d646847cb6610de1da4ffc849d8b800e29fd1e9ebb81ba8f3fbac88c25", size = 4920340, upload-time = "2026-01-09T21:06:01.592Z" } ++wheels = [ ++ { url = "https://files.pythonhosted.org/packages/0f/34/b1d05cdcd01589a8d2e63011e0a1e24dcefdc2a09d024fee3e27755963f6/ty-0.0.11-py3-none-linux_armv6l.whl", hash = "sha256:68f0b8d07b0a2ea7ec63a08ba2624f853e4f9fa1a06fce47fb453fa279dead5a", size = 9521748, upload-time = "2026-01-09T21:06:13.221Z" }, ++ { url = "https://files.pythonhosted.org/packages/43/21/f52d93f4b3784b91bfbcabd01b84dc82128f3a9de178536bcf82968f3367/ty-0.0.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cbf82d7ef0618e9ae3cc3c37c33abcfa302c9b3e3b8ff11d71076f98481cb1a8", size = 9454903, upload-time = "2026-01-09T21:06:42.363Z" }, ++ { url = "https://files.pythonhosted.org/packages/ad/01/3a563dba8b1255e474c35e1c3810b7589e81ae8c41df401b6a37c8e2cde9/ty-0.0.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:121987c906e02264c3b511b95cb9f8a3cdd66f3283b8bbab678ca3525652e304", size = 8823417, 
upload-time = "2026-01-09T21:06:26.315Z" }, ++ { url = "https://files.pythonhosted.org/packages/6f/b1/99b87222c05d3a28fb7bbfb85df4efdde8cb6764a24c1b138f3a615283dd/ty-0.0.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:999390b6cc045fe5e1b3da1c2c9ae8e8c0def23b69455e7c9191ba9ffd747023", size = 9290785, upload-time = "2026-01-09T21:05:59.028Z" }, ++ { url = "https://files.pythonhosted.org/packages/3d/9f/598809a8fff2194f907ba6de07ac3d7b7788342592d8f8b98b1b50c2fb49/ty-0.0.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed504d78eb613c49be3c848f236b345b6c13dc6bcfc4b202790a60a97e1d8f35", size = 9359392, upload-time = "2026-01-09T21:06:37.459Z" }, ++ { url = "https://files.pythonhosted.org/packages/71/3e/aeea2a97b38f3dcd9f8224bf83609848efa4bc2f484085508165567daa7b/ty-0.0.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7fedc8b43cc8a9991e0034dd205f957a8380dd29bfce36f2a35b5d321636dfd9", size = 9852973, upload-time = "2026-01-09T21:06:21.245Z" }, ++ { url = "https://files.pythonhosted.org/packages/72/40/86173116995e38f954811a86339ac4c00a2d8058cc245d3e4903bc4a132c/ty-0.0.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0808bdfb7efe09881bf70249b85b0498fb8b75fbb036ce251c496c20adb10075", size = 10796113, upload-time = "2026-01-09T21:06:16.034Z" }, ++ { url = "https://files.pythonhosted.org/packages/69/71/97c92c401dacae9baa3696163ebe8371635ebf34ba9fda781110d0124857/ty-0.0.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:07185b3e38b18c562056dfbc35fb51d866f872977ea1ebcd64ca24a001b5b4f1", size = 10432137, upload-time = "2026-01-09T21:06:07.498Z" }, ++ { url = "https://files.pythonhosted.org/packages/18/10/9ab43f3cfc5f7792f6bc97620f54d0a0a81ef700be84ea7f6be330936a99/ty-0.0.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b5c72f1ada8eb5be984502a600f71d1a3099e12fb6f3c0607aaba2f86f0e9d80", size = 10240520, upload-time = 
"2026-01-09T21:06:34.823Z" }, ++ { url = "https://files.pythonhosted.org/packages/74/18/8dd4fe6df1fd66f3e83b4798eddb1d8482d9d9b105f25099b76703402ebb/ty-0.0.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25f88e8789072830348cb59b761d5ced70642ed5600673b4bf6a849af71eca8b", size = 9973340, upload-time = "2026-01-09T21:06:39.657Z" }, ++ { url = "https://files.pythonhosted.org/packages/e4/0b/fb2301450cf8f2d7164944d6e1e659cac9ec7021556cc173d54947cf8ef4/ty-0.0.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f370e1047a62dcedcd06e2b27e1f0b16c7f8ea2361d9070fcbf0d0d69baaa192", size = 9262101, upload-time = "2026-01-09T21:06:28.989Z" }, ++ { url = "https://files.pythonhosted.org/packages/f7/8c/d6374af023541072dee1c8bcfe8242669363a670b7619e6fffcc7415a995/ty-0.0.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:52be34047ed6177bfcef9247459a767ec03d775714855e262bca1fb015895e8a", size = 9382756, upload-time = "2026-01-09T21:06:24.097Z" }, ++ { url = "https://files.pythonhosted.org/packages/0d/44/edd1e63ffa8d49d720c475c2c1c779084e5efe50493afdc261938705d10a/ty-0.0.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b9e5762ccb3778779378020b8d78f936b3f52ea83f18785319cceba3ae85d8e6", size = 9553944, upload-time = "2026-01-09T21:06:18.426Z" }, ++ { url = "https://files.pythonhosted.org/packages/35/cd/4afdb0d182d23d07ff287740c4954cc6dde5c3aed150ec3f2a1d72b00f71/ty-0.0.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e9334646ee3095e778e3dbc45fdb2bddfc16acc7804283830ad84991ece16dd7", size = 10060365, upload-time = "2026-01-09T21:06:45.083Z" }, ++ { url = "https://files.pythonhosted.org/packages/d1/94/a009ad9d8b359933cfea8721c689c0331189be28650d74dcc6add4d5bb09/ty-0.0.11-py3-none-win32.whl", hash = "sha256:44cfb7bb2d6784bd7ffe7b5d9ea90851d9c4723729c50b5f0732d4b9a2013cfc", size = 9040448, upload-time = "2026-01-09T21:06:32.241Z" }, ++ { url = 
"https://files.pythonhosted.org/packages/df/04/5a5dfd0aec0ea99ead1e824ee6e347fb623c464da7886aa1e3660fb0f36c/ty-0.0.11-py3-none-win_amd64.whl", hash = "sha256:1bb205db92715d4a13343bfd5b0c59ce8c0ca0daa34fb220ec9120fc66ccbda7", size = 9780112, upload-time = "2026-01-09T21:06:04.69Z" }, ++ { url = "https://files.pythonhosted.org/packages/ad/07/47d4fccd7bcf5eea1c634d518d6cb233f535a85d0b63fcd66815759e2fa0/ty-0.0.11-py3-none-win_arm64.whl", hash = "sha256:4688bd87b2dc5c85da277bda78daba14af2e66f3dda4d98f3604e3de75519eba", size = 9194038, upload-time = "2026-01-09T21:06:10.152Z" }, ++] ++ + [[package]] + name = "typeguard" + version = "4.4.2" diff --git a/genkit-tools/cli/package.json b/genkit-tools/cli/package.json index 2480b48a7f..baf67cd48e 100644 --- a/genkit-tools/cli/package.json +++ b/genkit-tools/cli/package.json @@ -32,16 +32,17 @@ "dependencies": { "@genkit-ai/telemetry-server": "workspace:*", "@genkit-ai/tools-common": "workspace:*", + "@inquirer/prompts": "^7.8.0", "@modelcontextprotocol/sdk": "^1.13.1", "axios": "^1.7.7", "colorette": "^2.0.20", "commander": "^11.1.0", "extract-zip": "^2.0.1", "get-port": "5.1.1", - "@inquirer/prompts": "^7.8.0", "open": "^6.3.0", "ora": "^5.4.1", - "semver": "^7.7.2" + "semver": "^7.7.2", + "yaml": "^2.8.0" }, "devDependencies": { "@jest/globals": "^29.7.0", diff --git a/genkit-tools/cli/src/cli.ts b/genkit-tools/cli/src/cli.ts index b45ed7054d..614aba79e7 100644 --- a/genkit-tools/cli/src/cli.ts +++ b/genkit-tools/cli/src/cli.ts @@ -23,6 +23,7 @@ import { } from '@genkit-ai/tools-common/utils'; import { Command, program } from 'commander'; import { config } from './commands/config'; +import { devTestModel } from './commands/dev-test-model'; import { evalExtractData } from './commands/eval-extract-data'; import { evalFlow } from './commands/eval-flow'; import { evalRun } from './commands/eval-run'; @@ -59,6 +60,7 @@ const commands: Command[] = [ initAiTools, config, start, + devTestModel, mcp, ]; diff --git 
a/genkit-tools/cli/src/commands/dev-test-model.ts b/genkit-tools/cli/src/commands/dev-test-model.ts new file mode 100644 index 0000000000..b03d01c429 --- /dev/null +++ b/genkit-tools/cli/src/commands/dev-test-model.ts @@ -0,0 +1,545 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + GenerateRequestData, + GenerateResponseData, + GenerateResponseSchema, + Part, +} from '@genkit-ai/tools-common'; +import { + GenkitToolsError, + RuntimeManager, +} from '@genkit-ai/tools-common/manager'; +import { findProjectRoot, logger } from '@genkit-ai/tools-common/utils'; +import { Command } from 'commander'; +import { readFileSync } from 'fs'; +import { resolve } from 'path'; +import { parse } from 'yaml'; +import { startDevProcessManager, startManager } from '../utils/manager-utils'; + +interface TestOptions { + supports: string; + fromFile?: string; +} + +type TestCase = { + name: string; + input: GenerateRequestData; + validators: string[]; +}; + +type TestSuite = { + model: string; + supports?: string[]; + tests?: TestCase[]; +}; + +const getMessageText = (response: GenerateResponseData): string | undefined => { + const message = response.message || response.candidates?.[0]?.message; + return message?.content?.[0]?.text; +}; + +const getMessageContent = (response: GenerateResponseData) => { + const message = response.message || response.candidates?.[0]?.message; + return message?.content; +}; + +const getMediaPart 
= (response: GenerateResponseData) => { + const content = getMessageContent(response); + return content?.find((p: Part) => p.media); +}; + +const imageBase64 = + 'iVBORw0KGgoAAAANSUhEUgAAByIAAAGdCAYAAABel7RVAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAADv9SURBVHgB7d1bchPn9jfgF58qdx97BH9xquIuMIKYEQRGgBkBZgSYEQAjwIwAMgKcEcS5SxXY1h7BZt+lbGy+tUwr2yEcbEl91PNUKZKND7LU3VLeX6+1SgEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAABu5SAQDokJs3b44ODw/vXrp06aePHz/eik+Nzvzzbnz+fXz+l7W1tdd//PHHuAAAAAAAnSSIBAA64fr16+sRMD6Oy/oFvi2Dyed7e3vbBQAAAADoFEEkANCqKQPIz43X1tbuqJAEABimUfjhhx+K93sAAP2yVAAAWhBrSZevXr369OTk5M2MIeTpjzs8PDy4du3ai2ztWgAAGJSVlZXRn3/+ebkAANArgkgAoFFVAPl4aWnpID7cLHMUgeZGBJK/5c8vAAAMRrzPu7y8vHyrAADQK4JIAKAx2YY1Asjf4uZWXOo6oz1/7laEkVkhuVEAAOi97KARF0EkAEDPCCIBgNrduHHjVoSCb7INa3w4Ks0YxWLVi/y92rUCAPTbpUuXfoyrnwoAAL1yqQAA1CTbs
C4tLWWb1Lm2YJ1GLF5tr66uPvnjjz/GBQCA3oj3lKOqrX85OTn513g8fl8AAOgFFZEAQC2uXr36sI45kNOq5ke+ifvVifsDAMD5LC8vr09ux/vLjQIAQG+oiAQA5irnQJ6cnDyNm12e4TO+dOnSk729ve0CHZZVxaurq6PYp25FmD6K7fb/yqc5qJfz4+rLLpevz1x9P7nE955ex/eN4/q/8fFufu7Dhw+7KksA6LKrV6/mjPHT95bx2rUT7+HuFAAAekEQCQDMRbbMWl5efhEhx3rpCe1a6ZKcZXp0dLQe+9CPsW1m8JgLrpdLMzKo3I3fmeHk70tLS7tv377dLQDQsuoktzdnPxevU3fevXu3UwAA6DxBJAAwk2oO5MPyqQVrU6HJvD1bW1t7LpCkSbnvRHh/N27+VAX4o9ItGU7uxH37Ne7njmASgDZcu3btzecnuqmKBADoD0EkADC1nAMZV1ulvwHkWdq1UruqquOn2NbW+1Q9XMl9ZCcuL1WhAF2UgVWZszwZY39/f6vQinhON+I5ePGlf1MVCQDQD4JIAODCMkyJRaHHPQxSzmMcC1sPLGwxL2eqhrP6scuzUy9CKAl0ztWrVz+WOcs27nt7ew8KjateP3M25OgrXzI+OTm5bc4xAEC3LRUAgHPKOZDXrl17lXN6BhpCplH+ffF3vsiZfQWmlIF9VufEIup/yqfK4aGEkGkUx4CN3Fdi4f8gK1bsLwDMU7x+Pi7fbls+qr4GAIAOUxEJAHzXQOZATmsrwpbnzrbnPBZ8XzmtHFpdXX1i3irQBhWRwxHPZQaMW+f88kf7+/vPCgAAnaQiEgD4pqx0qtpibZXmg5UM/55EEPivvC7t2Mq/Px+HAl+RAWQumsa2clCGMzf1wrJK8vDw8EBFMQDTiteQ++X8IWR6Wn0PAAAdpCISAPiitudA5vy51dXVB2crq7I1bHx+Ky5tLTaN19bW7qj2YmLRKyC/R4Uk0CQVkf2X7z9PTk5elSleU+O52ojn6mUBAKBTBJEAwN9UwcrTuLlR2rEbv//Ru3fvdr72BRlIxte8Kd+eG1Qb4QqpahsngDwH+wzQBEFkv2VV48ePH7fLbLRpBQDoGK1ZAYBTn7WW3CjNyzasuXh0+1shZBqH+LorsTiYC4Pj0rBJ+8kqiGLBZLVGPPcL3YL1ouwzAHxLvj7MIYRMT73WAAB0i4pIAGDSButFaanCMDyP378V+eL7C37fpIIzq9LaWnQaRyD6ZG9vb7swaFmJu7y8/KKtdsUDMo599sH3TjgAuCgVkf1T12vrl1r8AwDQDkEkACywGzdu3IoA8GmX5kBOqwPzI3fX1tbuWfAapljczjmQW0UF5Dw9i+PPk2lOQAD4EkFkfzQ4Y3kr3p+99P4MAKA9gkgAWEDV4s9kvl0baquIynD1+Pj4VTE/kjlQBVm7cSwQ37G/APMgiOy+BgPIs8Zx2RZIAgC0QxAJAAum5cqurHx6vr+/v1Vqdu3atY0IjzJsHZXmjcunv/NZobeqlsUZaquCrN8j+wswK0FkN928eXN0eHh4Nx7Ln9s+sSe7ccTVywhDd9++fbtbAAConSASABZEn+dATsv8SKYVi9lPS3sVw4vqWYSRjwrAlASR3ZDdKY6OjrJl/mn4WLp7Qs84Lr/EfdytO5is3oevlzlbW1vbVuUJAHSdIBIABq7t1pJ55nmGcXW0YT2vtudHatfaHxlex/7ySivW1pi1CkxNENlNVdv89a5URMZ9+CVe63earIiMbXOr1HBiXASod9p8jw0AcB6CSAAYqDMzeLZKO2qbAzmttqtCBZLdloF1bLNvSntVw3xibiQwFUFk9505Oeyn0tzr7elogHgP+KzJzhxnCSIBgEW2VACAwck5kLEwcVDaCSFzgedJLPbc7trCSN6f/f39K7H4lQuK49Kwjx8/bhweHr7J+ZWFTsmQOvaZ34oQsgtyltibnClWABiUCALHBwcHG/E+8U58+KTUL9+TX
sn55G2FkAAAi04QCQADkmFKhFxZ0fWstDCPJ6sG+rDYkzMbG1wA+9woAskXERYfxHN1t9C6eB7ux/aQ+01XZ1gtogwjf8t2fgWAwclAMt8v5vvGUs/JYdnqWwAJANABgkgAGIBscxVhyqsMU9qYvZPzdrI1VLYu68tiz9kFsHjMXpbmZSD5Kp63Fyq/2hOB8ON4HrZL/+V+N/7s0neXj4+P3wgjAYarej+WYeQ8Tw57Hj/zthbfAADdsFIAgN46MwdyM8KUNqq53kcI+SgrDEtP5QJYXG1cv359u435kVW71o2cHbS2tvbSollzMoQs7c1QvahxBv6xvfw7rvP2eGVlZfznn3++/174n8eJ1dXVDL4vxyWvcz7Xj3kd/9z1kO80jIyw3oIywIDlyWHxupyvZ0/LbB7Fz3pWAADoDEEkAPRUzhmMICGDlFFpXi4UPY/g7tlQ2l1V8yyvtPi4bmUgGb//SZ+D3b7oeAiZAf9ubIe/LC0t7X748GF3lv2s+t7dL/1bhpTLy8vrVSX1T6WbweTlo6OjvF/jAsBgZYAY74PeZwv7MoWcAe49FABA9wgiAaBncg5kBmVttGBNscjzenV19dFQq5NyASvCmWw1uxkfPizNmsyPfBzh0L23b9/uFuauoyFkho+v4/Jy1uDxIqrf87q6nLZ5zmAybv5cHWPanpv5PvbFe9WJAgAMXL4PizCyXDSMFEICAHSXIBIAeqJqw/r05ORko7Qg20LG5ckiBAJVu9bNeMyfxd+8FZf7pVmj4+Pj32IhbjtC3ydaUs5PPKb3Y3Fzq3RE7ldx9Tye750uVBdX2/52dTmtvI6r+y2d+PA+QtE7AnmAxZKB4tWrV0dx8/E5v0U3CQCADlsqAEDnxWLMwwghD+LmRmlehiM5B/LOolUlZShzcHCwEWHRvdJCW8hqfuRBVcHHjG7cuHErHtPt0r7cp56sra1dyf0qLq+72uI4F3bzPp6cnFyJx+5l+XTfmyCEBFhgOTOyet35pniPtp1fWwAA6CxBJAB0WLZJvHbt2pu4+ay00yLxSQYQObOnLLAMiuIxuJJtv0o7c+q2Iow8uHnz5qgwldyXjo+PX5V2nQaQ1T611adK10koH/f9dnz4pNQbSAohAcgTsrJN/vgbXzLOzhEFAIBOE0QCQEdlcLK0tPSmjZaI2S4yq7UyLOlqpVYbsjosgpg7cfN5ad7o8PDwTVb1FS5ksi/lzdKOvwWQfd6nMpDMvyEDyfNUqkxBCAnAqXy9jNfvB1/79xwZoH09AED3CSIBoINaDE7G8XuzVeQdCztfVgUxm2daVTYpq/reqIy8mAi2npaWQsgq1L89tFD/TIXklTK/KmEhJAB/k2MBqnnKf5OfMxcSAKAfBJEA0DGRQV5uIYQ8nQOZ7UcXbQ7ktFqcH3k5KyNzOyl8V87XjMD4bmne+9w2hh7qV8F8hpGztsYTQgLwRVn5eJ7PAQDQTYJIAOiYCCEfl2ZDyOfmQE6vpfmRo2o74RuuX7++HldbpWGxLbzOfSq3jbIgqnat01ZHCiEB+KovVEWOnTgHANAfgkgA6JBsyRpXm6UBuaATi//ZMnLTHMjZTeZHNtiudbMK2viC3Jfi+XhRmvcotoV7i7hPZXVkzo4sF5uhKoQE4Lvi/dUvk9uqIQEA+kUQCQAd0lCV219zIC3+z9fZuXlNBJLxO1RFfkULlcXjKthf6MriDGDz5IZyvlatQkgAziXeW21Pbq+uru4UAAB6QxAJAB1RVUNulPpkhdaTrFjqWzurmzdvjkqPnJkfWWu71ggi182K/Kdr165tlHr3pc/trq2tCdTOyFat1fb/tcrQcTxmtz1mAJxH1WkgXzN2hzx7GQBgiASRANARy8vL66UmEQhsV3Mgt/rYMvLDhw+j0kPZrrXu+ZFLS0sbhb9koN9kpWi2OM6WvBZF/ym3/6x4LP8MIzOE9JgBcFG/VhcAAHpEEAkA3fFzmbMMSao2rA/6PLMugp710
mN1zo/MqsjCX2J7f1gaasmaz2e2ODZj9euy4vGzMFIICcBU4n3tbr63LQAA9IogEgA6IkKNUZmfcVbhZUjStzasXxJ/y4+l587Oj4y/53WZkyE8NvNStTfeLA3IEDKfz8J3ZRgZAfG9IoQEYAbxWpJBpJN/AAB6ZqUAAF0xKrPLxZnnEXY9G1KV1pCq/jKQjKt7OcewaiE6KrMZFU7FAuWr0oAMkvf39zcK51adEHGlAMCUjo6O3v/www+CSACAnlERCQDdcbnMIMORtbW1232dA/k1o9EoH5fLN2/eHJUBmcyPjJuPyj9n6HFBGezG1a1Sv93j4+MHBQBoVJ7MpaoeAKB/BJEAMAzjCLbuDXFxZmVl5TRcOjo6Wi8DFGHksyKInFlVXVq3bC16z0xIAAAAgPPRmhUA6LSTk5P1vI6gqYlqN3qoanM7KjUz35A++VoV+Z9//vlemD5MX3rOPd8AAEDbBJEAQKddunTppwiZ8vrHAl/QUDXkIyEkXZKh09HR0a0M4avj46gK5E/bWR8eHn7x+5aWlsrVq1fzZoZT4/je0+v43t/j9jj+ffz27dvdQidVz/t6PF8/xvM1qk7S+epz/qXnO74nn99/x7/tfvjwYVdQCQAA1EkQCQB0Vs6HjAXT9byd1/mxBVPOaqIaMhbut/f29p4VaFEe/yI42sjQMY+HETqNJv+WJ2tMIcOrW2e/N28fHx9ncPU+fs9ufPzL8vLyjmCyPfm8Z4vyeC7ux+VuPO9/zZO+4PN+9vlez/+cnJycBpVxHN3JcDJu//Lu3budAgAAMEeCSACgs2IBfP3sQmt8fDeutgtUGqiGHK+urj4p0IIqfHwYoeD6mZMySgMmJ4GsV8FkVtLtxOWloKoZ169fX4+g8Oe4uRHXl0uNJs91/J7NyXOdxz1V4AAAwDwIIgGAzorF0Z8/+9T9Ioik0lA1pMV4GpchVIbsDYeP35ItQHN/2xhSUBXHkDfzPIbk47K3t/egzGDy3E/mI7fg9Lk+PDzcyErJPAYKn/tjFPIkrtKAbOU867ZRBe4vypzNY1+cuHHjxq3j4+O7ZUaTUQPzltXScVxeLzWKbeq1yngAYBaCSACgk3IxLa42zn5Oe1Y+c7/UqGrJul2gIVW4/jAW5m+V7jobVG33OZCsQshRmZ9RmVIHAsh/yNfcvGQgGc/zAydldFu+b1paWppruP4dO9VlVqMyf6MyJ9XxeObuC3WdUJLH41KzeAzGcSWIBACmtlQAADroa2f0xyLbZmHh5YLrpFqsLlqy0pQMoa5evXoQ23RWBnU5hPybKpA8iKDqxc2bN0eFC8tjWVZmxkL/m7qPadOqZpIexDb6NE8GKnTOJIQs9YR6X/I+ft+jAgAA5yCIBAA66Ruz/x4WFl4sgNY9G/K56h/qli3/JiFUaS5AmLtJIBlBVd375aDk4xXHsoOuBpBfsBn39zehc7e0EUIuLy/f0aoTAIDzEkQCAJ2T7QnL1xfULlf/zmJbL/UZr62tPStQowyhjo+Pf+tRCHUeW1nZmRWeha/K4Cgep9/i5lbpn1EVOutO0AFCSAAA+kAQCQB0zjeqIc/17wxbBNF3S72Lri9VQ1KXrILscQh1HqOs8FQd+WVx/LqfVYWlRy14v+Kp57hdQkgAAPpCEAkAdMp3qiEnRqoxFtr9Up/3a2tr2wVqEMeth8fHxxkc9D2EOo/T6khtPP8ng7uPHz9ux82hzFncEka2QwgJAECfCCIBgM7IhbULVDs+ji8fymIuFxDbyN1Sk0uXLr1WDcm85bEqApuncTNb/i7ScSvbeL7JKtCy4KrAbqsMjzCyYXk8iRDyVWkuhMx25beFkAAATEsQCQB0RiysPSznX1jLhTiLnwumastam9XV1ScF5uhM5dKiVnGPchbmIlexDziEnNiq+9jMJ1UI2WRVdYaQd5ygAwDALASRAEAnXL9+fb1cfKF+s/o+FsTHjx9/LjWJn/2LxVbm6UwIufAVgWVBZ
wouQAh5Ko6fL7ThrZcQEgCAvhJEAgCty8X6k5OTF2UK8X2vLH4ulPVSk1jg3S4wJy3McOuDhWrjmTNBywKEkJXLh4eHr7RMr4cQEgCAPhNEAgCtW15ezhByVKZz+ejoaKoQk36p5syNSj3Ge3t7rwvMgRDymxYijMxtoHyaCbpIbmmZPn9CSAAA+k4QCQC0KhekP378uF5mkN8fP+dpYdBOTk5qW4S9dOnSToE5EEKeS84UvF8G6kxwtIg2q5NGmAMhJAAAQyCIBABaM+fZWZuLOH9swdQ2HzKCyJcFZiSEPL+PHz8+G2pgVVUFjsqCOjk5cWLQHAghAQAYCkEkANCKmmZnLdT8sUUTwcWo1OP9u3fvdgrMKEKDV0UIeV6Xj4+PBzVTMI5Rt65du5bB0WZZYNml4Pr16+uFqQkhAQAYkpUCANCwbMmX1TClHhlGlv39/SeFwajCiloWZGNb/LXAjKr20G1X+L2Py/jSpUu7sV3/t/r4c5fj3/+vCvbbvr+j5eXlDG/vlGG4PGur8aE4OTnJ2c1XChcmhAQAYGgEkQBAo6qZkFulXsLIgVlZWbkVC9ulDuZDMquqwruVKrjcfuOY+ksEejtv377dLReU7VGPj4/X4+f83EaIVs343YzjdV0npwzR+LOPL1eXLhllVaRq84sRQgIAMESCSACgMXOeCfk9wsgBiRCytkXZ+Nk7BaaUcyFLc8e1iax0fB7b7rPxePy+zKAKL/PyLP+WCDTXIxxsesbh45s3b74WhvxTFTT/nterq6u733qMMlSOr70c20UGyz+1XZ1ZbUc7hXMRQgIAMFSCSACgdrm4FovbL2JR8m5p1ta1a9fWY/H2gYW23huVeryPIOfCVWQwUQUHTVWjvY+A6dHe3t52qUHsC+O42s5LHDs3GgwkLx8dHWUrz6G0aJ3VVEHzmYrYnfxP9dqbr7v326p2VRV5PkJIAACGTBAJANQqFyFzVlQ1j6xxuRB6eHj45saNG/emaVtIN0T48mM8l2XecpZegSlVVd6j0own86iAPK8q7GwskBRanZpr0FxtK9t5yWrX+NlbcblfGhTbbAahO4WvEkJ22/Hx8U4E+g/K7H6u6YS8J7Ffj0uNVldXdwoAwAwEkQBALaqFtcexCNnK3LTPjGIh6bcIDba0au2nuoLsbHlYYAoNtmQdxyJ4aydSZCgWf+pOEyFWnrQSV1fKYsoKyK26guaq2nUjguWdhlvv5jbThfcBnSSE7L4zleIzifego7iaexAZ28+OqmMAoOuWCgDAnGVVSyyM/Fa6t/iYcyMPssKn0DejUo9xgSlEOPi01C/DqdttV3PnQvzBwcFG3HxU6jVawONzVkHe29/f32yi2jWD5dim7sTvfF2acTnfExT+QQgJAMCiEEQCAHOTi42xiPwmFjlzYW1Uumn08ePHF3k/LY72Q1V5VotYBNaalQur2pXWPfP2SVPh1HnF/XkWAezt8mmGYS2qar1FkcHQ7QgHmwoFP/3SEL/zXtxspENA1Z6VM4SQAAAsEkEkADCzswFkzvkqPZD3M+9v3m8Vkt22srIyKjX58OFDZ0Ie+qOBsCxDyK3SQVmdGWHknVJfGLkoVZG7We3aZjBUbWO1h5E547fwFyEkAACLRhAJAEwlw8erV68+jst/+hRAfi7vd1ZIVi1bX6iSXDjjAhdQhWSjUp/OhpATGUZGkHKv1KfWWZQdkCHknS5Uu+a2Vneb1nydzfCtIIQEAGAhrRQAgHO4efPm6OjoaD1u/pQtCWMRdWiLitmyNdstbkQoOY6F2Z343K/ZurPt+WyLLp6TUalJl9pezkMG6fF4DT3EmcavORuvzEHN1ZCdDyEn3r17txPHypwZOfdZmRlc5bacv6MMz7grIeTE8fHxg3ity2BsVGqysrKSP3+nLDAhJAAAi0oQCQCcqqoVTsPFbIUZC6XrEcb9XywIny5OHh4eLlI1w2koGdcbsUBbYrE9P7cbj8c4Pv97Xsfnd
6qvfT+0MKtr4jGva9sbl4HJ0LbadjkjHpN/xdV2mVE1G3JUapBVaRGWbpUeyZmRV65cuRX3fe7hdzzOD8vwgqtJMNSp14x8DYvg90E137kW8bMXOogUQgIAsMgEkQDAqSpMmyyOjuOykwtnq6urGUpmuJHB5I99bcE6hffx9+5WwWO2Idw9OjoaCx1boaUfM5nXjLoaqyHHcax9VHooHpPNeHx/KnOuppu08xzSMTcep0ddDYay+jSC9p26XuMXeU6kEBIAgEUniAQAvqpaAN6tLn/NkMqWecfHxxt1LD63LP/el7Fg+PrDhw+7Qsdhy8rWwqIYlRnlcS9Pyig1iG3xSV9DgzxORoD1KAKsV2W+MrzZiOtnZRie7+3t1TqLcVa5HdYVRC7QSUx/I4QEAIBSlgoAwAVl5cTBwcHG/v7+lVi4vJctBUuP5TzIWCi8E3/Pv+KymX+fEBKGJefclhlECFnL7M04/mzPa35lWzJgq+bqzlX8zJ/LMORcyK3ScdVMzrpe+xausl0ICQAAnwgiAYCZ5AJ0XO7FIuuVjx8/viw9kgHA8vLy7bj/d6oFWLpJa1ZmdnR0NHUYEIHCKK42Sg1WV1eflAHIaroyZ5P2rKXn8rHp0cktdb2OX571ZIA+EUICAMD/CCIBgLnI4YlZJdmHQDIrd2LB7koEkA/evn27W+i02J4EkczD1NvR8vLyeqlBngwxlOCgOplj7sfTeOzvln4b96niNVuTl5r8+eefC3EsF0ICAMDfCSIBgLmaBJLZsjU/LN0yzhasWQFpwa4/zHJkHiLQHpXp1dKWdSjVkGfUcRLKT6XH4jXnQemROtuzRqg8KgMnhAQAgH8SRAIAtciWrScnJ7fj5vPSDc/z/mjBCospAu3/K1PItqzZIrTM2ZCqISfiGLtd5qyOx75B4z6+5sS2WVengEFXRAohAQDgywSRAEBtcibW/v7+ZixqZkVIW/Ox3ufvz/vRoxldNGDGCjkWRI1tWXs1U/c88hibra/LfI36OluwjrmZTYhj4++FCxFCAgDA1wkiAYDa5XysqjpyXJqVC3W3+zSfC6jNqEzn5zJ/vayUO48IsX4pc3Z0dLReemh1dXWn9NO41GCoJ38IIQEA4NsEkQBAI3J2ZISRd+JmXS3fPrdroW4YapwROeg2gcxHTa1B5x7WdcXy8vJOmbN4DpoKeOYmA9m+vv7EMVf3gHMSQgIAwPcJIgGAxkzCyFjkfF3qtZu/x0Id3yGI5Jtu3LiR4cLct5M6Zil2xdu3b/Nkk7kGWX2spItwqu7XudrUePLHoAghAQDgfASRAECjcobY3t7evVhYrms+2mkIaR7kcNS5KN7X2XM04/j4eL3M3/s4PjVVGd6WuR5/4xjwY+mZHrdlLR8+fBgXvkkICQAA57dSAABaEEHkZrW4PM9FPCHkAOWieCz4ljrEzx6V5meX0h9zD8DiuLc79AD86OhoPOcqxlEGPz06to8FRsMlhAQAgIsRRAIArcgF5VjMuxOLeb/Fh6Myu1youxcLdULI4antOe1jy0emc3Jy8t9ycXMPGnLm5OHh4UHhQn744YdskduL43s8x78XBqmFEHLSat57GwAAektrVgCgNRlG5gJbmX1x+b1qgeGqqqBqWYSNwKCpxWRaFuHBf8rF2T464ujoqDfPRVa9FganhRDytDW5Lg8AAPSdIBIAaFWusMXC3r0yg1ioeySEHLxxqYGKyMURz/WFKiJv3LghhOyWy6UnBJHD00YImeK4dffatWt3CwAA9JggEgBo3bt373bi6nmZzpO9vb3twqDV1eqwmlPKYrhQVVFsc70JvhZBn04aiOOKCrZhaSWEnIht/0UGoQUAAHrKjEgAoBP29/c3r169+lO52ELfOL5vqzB42Z6u1GOUC7wDan2Xf8e49N+ozN+FnuOTkxMVkd3SmyBmZWVlXBiMrEos7bq8vLz8Iq5n6h4BAABtEUQCAJ2xtLT0KBb/35z362NhzqLcg
sggMhaDSx1iO1qPq9dlAPb29vLv6PXfcv369fWLHAfOa4owWwVSh8Tz9/9KT2gVzrxNWrRWx3gAAOgVrVkBgM64SIvWWJTefvv2rTlcC+L4+Hin1CQWeNcLnVFXS9SLtss0PxToEi1aAQDoK0EkANApJycnW+UcLRRXV1efFBbGOJQLtta8gJ8KnRGL7bW0RP3w4cOFTlxYWlrqTQXeIojXhn8VWGyTFq0AANArgkgAoFOqWX3frIrMakit7xZPPO91VcDeUmXSHfE8/1jm7/0Uc0BtEx0iGIb/tWgtAADQI4JIAKBzTk5OnpVvVL+phlxMsQD7e6lJhBwbhU6ooyVqjSE2QKOyRevNmzdHBQAAekIQCQB0TlW59PJL/6YacnFFWPi61CS2q58LrasqU+femvXk5OS/BWAYLh8dHWnRCgBAbwgiAYBO+lrodHx8/LywkKoZf7XMifz48eO69qztW1lZqWU+ZATNO+WCYpuwPXSIMBn+J1+zrl69ulkAAKAHBJEAQCe9e/du5wvhwThosbigslK2zhabEX5b1G1ZhE21zD6L59Zxo+fiOfxPAc56rEUrAAB9IIgEADrr48ePv3z2KdWQC+4L28Q8PSy07adSg6qa9kIi9K6l+pbpxL7/7wKcpUUrAAC9IIgEADrr5OTk9Wcf7xQW2ufbxJxdvn79+nqhFaNQapgPGXarubMAg6JFKwAAfSCIBAA6K/uw5tWZD7VXXHCfbRNzF4u6jwutWF5eXi81qLOdL40SJsOXadEKAECnrRQAgG7LVpwPIyD6vcAnL+NSS2CY1SVZFZkzSgtNu1/qMVU735OTk39HiFnmKbavlxG4bhcubJr2utC23OerNs91tv6etGi9UwAAoIMEkQBAp2U1Uyzk5fVOgXIaEG0vLS3VVrlYVUXuFBqTbVkzBC41OD4+3ikdkYGEkBsWQ4aQBwcHG3F4uxyvWXmixeVSk0mL1v39/WcFAAA6RmtWABiGy7EANciWkpMQIRbxBlkNUz1vo8K5ZXvWOoPpXNC9du3aRqExdQXLuZ1MOx8yvndc5ix+5v8rwOBNQsi8nceg2PcflPpp0QoAQCcJIgFgGPIs+60ItQ6GFqBUMwEH15Yv23/m8xU3twoXFou8U7XbvMDPf5xVLIXaZTVkXK2XGsyyndQRRNZV9Ql0x9kQcmJvb+91A50dJi1aAQCgUwSRANAd4zK7bG/4IsLIV0M6Kz7+pt1pq5q6JkOXeH7enJycvCnzqYQcxONyUdmetdT7t4+WlpY2C7VbXl5eLzVVBa+trb0u06tj+xoVYLC+FEJOHB8fZ1Vkra/ZkxatBQAAOkQQCQAdkbMQy5zEQtTdw8PDrI58MYRAMh6bWqvfmpDVddmGNcKtg3lWRc1zu+mTKph+Wer1+MaNG7cKtapmctZh948//hiXKUVoUMu+FccC2xQM0LdCyFR1eHhS6qdFKwAAnSKIBICOiAWs38ucxc/ciEDyTd/btZ6cnPQ6bIsA8mEGkKWGNqx1bDd9EdvFs1KzCKO0uatRnTNSI6R/XmYwaQs9b3EsWC/AoHwvhJzY399/pkUrAACLRhAJAB0Ri9M7pR6n7VpzHmHOJSw9tLq6Oi49lI93tmGNmxmY1TJvMLabWVpP9loGRbGgW/fffyv2na3C3FWzIWtrIRjHjZ0yu3GZM3MiYVjOG0JOaNEKAMCiEUQCQEe8e/dup+az5Ec5l7CP7VqPjo7GpUeqOZCv8vGuOXQY53ZTFtisVW/n9LivIX6XRYie1ZC1BPSxXWzP0pZ1IvbfX8ucxX37qQBD8eQiIWTSohUAgEUjiASAbql75t2kXetBtkTMuYWlB6p5gJ13Zg7kbzmns9QsAo0mFjI7rYEA/1SEyi/6sr/0QdUueqPUJLaJuRxLa5rBelmwDYPwZH9/f6tMQYtWAAAWiSASADpkb29vu4lQpbKVgVnf50d2RTyOd/PxLJ/mQDYRWI1zeylM2tzVbbS8vPyqM
LOsGI6g/nGpz9wqhWOfrmU+bQTbtZ+oANRq6hByIo4vj0rNtGgFAKALBJEA0DFNLEydMZkf+Zv2XdOZzIGMxzFDqlFpyNra2p3CqarNXe0tWqsF3aeFmUSgmxU6o1KTeVYKV4FmHRXZ91XYQm/NHEKmt2/f5okOWrQCADB4gkgA6JhqYarJMDLdynatfZwf2ZaqDevTBuZAfsmTecy/G5J4HrZKPYHR5zaz/W5hKvnY1T03dd6VwnW1Z11aWtooQN/MJYSciNeuZ3E1LvXSohUAgFYJIgGgg3J2UGnmLPm/qeZHvtHG69vi8XkYIcJB3GzjcZrrIuhQVHNEa6+KrGxFaH+/cCG535RPrYvrNPc5u3Fc/KXUwzYE/TL319987Yr3E7W3F9eiFQCANgkiAaCjqsWupisj0yguT2PB6sD8yL/LNqz5uMTNDIrbaKv4SAj5dQ1VlpyKRd1tYeT55QzV8mm/qdN4bW1tu8xZbFevSz1uVY/LIOUsUK8hDEhtJwFVLaCbOJFGi9YBivcj2nwDAJ0niASADsvKyFgEv1IaClc+czo/UrvWvxbU32Qb1tLgHMgzxktLS3eqSlm+oqnKkokMI1WYfN+NGzdu5bGk1O9lHS2LcwbppUuXdkoN4nF5OtRZkTkLtJpB7KQWei2rous+CahqLz4u9dKidZgEkQBA5wkiAaDjchE8FsAyjMzqyHFpWNWu9SDnIS5aIFnNgXwc4dZvLcyBTNluNKswrlQVE3xHg5UlE0/NjPy6rCI+Pj7OAL/uhdJxnUFB7P+/lnqM4vgyuDA7g8czx8yRQJI+i330P6VmWrQyLRWRAEAfCCIBoCeq6sg7seAw9xlo57SZ8yMXZSE5/85qDuRWaeds8+dZDasV68U1VFly1lYG9UOtbJtWtq6tqohrf1wuXbpU60zdqu1vXR4O6SSPrCDPSs8v/ZNAEr5Oi1amEa9/PxYAgI4TRAJAj2R15MHBwUYGVLHwUNfcsm/5ayE52y2WAcoKrmzDWrWSbDxYyhaQVRvWzayQKFxY0y1aK5tZOWtx95OsEs3WtaUZ4729ve1So9ym6mrPGi4fHh6+GkqQHfvB98JngSR8RXUiTd2v/Vq0DkgcTwf5fhwAGBZBJAD0UAaSsfB+LxbGM2wZl+aNjo+PfxvS/MiqDeuLrOBqqQ1rzqG7F8/rHW1YZ9dCi9Y0ikDpt0Vue5f7URwXXpVPlcSNWFtbu1MaUHPV5a0I8Hrf4jcrg8v55+gKJOEz1UkPWrQOU10B80hHBgCg6wSRANBjWQVUzY/MBfLGq+fOzI/s7QL6mTmQ2YZ1ozTvdA5kBKC34/lso8p1sKrKkt3SrFwMfDqkkP68spq4mqd6tzTn+R9//DEuDajC7XGpz2afg4HqdWCa+y+QhDPyvUBDXS+eDrW7RRfFc1rX+/TLKysrnkcAoNMEkQAwADlHMIOsFudHbvVxEXkSnJSW5kDGotT22tra7Xz+tGGdv3xMY7+4V9oL6RdipmoV5j+t5kGOSnPGVdjcpLqPsRli3y89U4WQW2U2fwsktTlmkR0fHz8qDbx2xe/RorU5tT2fcex8WAAAOkwQCQADcXZ+ZGmpXWsuIud8xa4vIGcPq7yfLQQnpyZzIPf29h40Vc21qHK/iMf7UWnHX8FKht5lgDIwqqqJG6/kyzmgTQf4ccx4VmoOB3K2Zp/CyDmFkGed7jcZ5OfPFkiyiPK1q3zqdlG3W7GfbRVqF69Z41KT7ETgWAkAdJkgEgAGJhevsl1rW/Mjc+5QtmvtYmvKSeVWBictzYE8nf1kDmSzsoVxaWZB92tGGXrn7MShLBRmsJphfgZGpYVq4vCkjX2oCj5rnz2aYWQf2rTWEEKeNcqfHa8nTbb6hc6I93LP8sSlUr/HWrTW7+joaFxqFD//aQEA6ChBJAAM1GfzIxvXtdaUsWD+sK3KrfK/OZBXqlCMhmX72xZbF5/KioUqpH/T1wrJuO93J9XELYX5aTefz
9KS6nePS/2ednX+bnVSR4bQW6VeeWLNswIL6vj4OE8q06J1AKoTWepsz3q3rmr67CRSAABmIIgEgIGr5kdeaSmEOTvzq5WqlknlVtzMxew25kDumAPZDbEtZgi9W1qWAV4GeX2ZhVeFTo/j8p+4769aDCDTOPane6Vl2Ra2NON0/m6XtpFckI6/P4+pG6Vmcfxss5IZWqdF67DEMa3W9yDzrqafvP5XJ/IBAExNEAkAC2AyP3J5efl2aW9+5Ksm27VWcyBftVi5tVvNgbxjDmQ3ZBAc28Od0s4+8CWTWXinrYzbCuu/pFp83MwQP7bj/5RPlW9ttGD9mziG3evC/pRtYRtqmZhGuY10oTqyqiz/LW7W3sYxHt9tFeSgReuQxGv+76V+T2d9v/1ZALlVAABmtFIAgIXx9u3bPBP7SlZhxWJILmqPSoOqdq0bedZ9BELP66gQzMWTWDh5GDc34/e1EZyctmHVTrCbcpuLbeROVdE1Kh2R+0Zc5b6RAczruP1L3Mfdap+tXe43Kysrt2K//Cl+//okvI/r0iGPmno8ziNbJlahXFPHmayO3MgqwaYDuqwsz9eMBk/qeL+6uqoaEirV8ab2qrSqRevtQi3iOXwdr7MPS80m77fj/f52vGa8PM9M5cn7gPjeh9nmtQAAzNGlAgAspKq93kbcbKvKZjzvBfWsKIvFk6elvYDpeSwwacHaA2faS45Kt73PSpisosiWbnGfx7OGcbnYGCHPKLbVXHD8MX5uXmcVTOsVj9/wpM25kF9TtcB7Wpp3evyM53GnzgrRFgLIU/G3PWgibM22t6WGY0Bsq73+/+zq+FhH6FX7fhzP6dzPnqiqc5tqx/xVVevUJt6zPYvn6VFpQB5jsnNFmbN83cyOFKVjqpPlct9q4/V2Nx6XfH86/uzzl6v3AKNvfXPfj2sAQLu8kQCABZcLjrEwsRWX+6UdOfNtpvalbS2WT+SCVwQCD7Rg7ZcehZFfkkHUeHI7/xPb/7/PfkH8+/+d+TDbwI7Kp8XPLgeOX9LJEHIiW0C3WT0yqaCdRyh5pjL25/JpBmQbc3UbC30EkV8miPy7rgSRKf6+RlojZ2v581TRzWrRgsiULc9bnrU8FUEkADALrVkBYMHl/Mi4yvZNr1uqJhxVM/K2sxXfRRbSqzPLH8ci1mZpxzh+/4MmFuuYv9z2u9im9ZwmweJXdayt6rSedzmETFXLxO9Wk9SlCkHvxnE0Q4qseBnH537N1r5Z/bK8vPz+8+PqZHbY0dHRaSXs2crYOJ62GVSPtWSFr4v9+lEdwd3n4ne8iNfH2zo81OJlXNYLAMACWSoAAGFvb+91LPhfyZZ45Z9tm2pXzbN5Ewvp52o7ll9XVWy0EUJO5kBeEUL2W4aRseCa87A6M3uQv2QI2dZJBueWC/UR9t0rn44LbbtVBZNPM6yIkPS3PNEjq8TOXvJzeYmvzWrOnAm3WVXotBlCvp+1Oh6GrnrP8bzUL6ti22rdP2hxXM4qdgEvALBQBJEAwN/kXK5YwL4Ti9IvS/NGcdnKdnnXrl3b+NIXZBuvqp3eVmln0TznQF7pepUW55dBUjyfGUY2sbjL+TzpQwg5kXM7L1261MhMtaHKk2CEkPB9OYu6NHPC2Ga+5yrMVVVl6v0GALBQBJEAwD9kldjBwcFGBm7V/LGmZdvJFzl7bdJCMGdW5VydqiXZqDQs5w3lzKQMR7QqG6Yq+NIWsn2dngn5NXkSR7H9TOtJVuUX4LvyPUi2hS8NqFq09m2ucOfF4/qsqIoEABaIIBIA+KoMJGNx+F6L7VrvVvMjX2Ub1qp1YNPG8fffi8fhjjasw5cBWAbOpYXtnfI+97U+VxtX910YeTG9DJ6hTVq09puqSABg0QgiAYDvykqfnIdYPi2wj0vDqplnTTudA5nzA1XqLJZc4M32xFkFW2jKeG1tbRD7mjDyQoSQMCUtWvtNVSQAsEgEkQDAueWCcYvzIxsTAdR2hiL592rDu
piqauCsjBQo1SyPJxn4D2k+oDDyXISQMAMtWvstn7+q4wgAwOAJIgGACzk7PzI+3C0DMpkDGQHUgyGFIkyvCt9zWx8X5i1D/kd5PBli4C+M/CYhJMxBVvA3NMtbi9YaVF0AtGgFAAZPEAkATCUDyVhIvt3W/Mg5Oz0r3RxIvqTa1ietiZmDDP2rquNnZcCqsO1R4SwhJMzR8fFxvg9r4mQOLVprULXYHdSJfQAAnxNEAgAz+Wx+ZN9M5kBeyb+jwDdMqiOH3pq4ZqdVkBn6L0rVcYaty8vLt4uq2jzh454QEuarqihv5D2YFq3zl89fPK73itcIAGDABJEAwFz0LaQ5U5FlDiTnNmlNPJBK4KY9z2PE0Ksgv+Tt27e7OV+3oRaKXTTO423VhhCYszyu5vuaUj8tWmuQ7y3yNaJ4XwEADJQgEgCYm0lI0/Hqn91qDuTCVGQxf5NKYIHk901mr8bjtbnIoX8eH2O7yaqXhWrVmienxAL7bcdbqJcWrf0mjAQAhkwQCQDMXVb/dDCkOW0JmXMtzYFkXgSSXzcJIM1e/busXMrK0Iaql9p0Ons3T05RdQ71yyCraNHaa1UYeVsLeABgaASRAEBtMqSpzu5ue37kk0VtCUkzBJL/I4D8vqo68s5Qt5dsQVu1Yt0uQGO0aO2/PHEjT+Aon6rnncQBAAyCIBIAqFUuuLc1P7KaA3nFHEiaMgkkM4hbsIqG3L+eCyAvZrK9lE8na4xLz50Joe9pxQrt0KJ1GKrqedWRAMAgCCIBgEZM5kfGQnXOSBuXeo3NgaRNGcTl9p4BfPlU1TAuA1RV3jyqKo43BZDTqU7WmFSPj0vPqIKF7qhatD4vDdCitV6T985tnMz3GSfzAQAzEUQCAI2KherXNbawzIWSJ/nzLYbTBVVF8LPPqiTHpceq8PFJVhtn8JR/n4rj2U2qx88cH3dLt+UMyNcCSOiePJaUZo4ho5WVlVuFWn0hkByX+uUxfjuP8bE9/asAAMzgUgEAaMkoxCLHVlzul9k9jwUaLVjphRs3btw6Pj5ej23/51hUXC/dlouRO3E/f419bNs+1pzcTj58+LAZj/9P8eGodEC1LfwyhG3h5s2bo1KDIVTi1/HY/Pnnn+/r3mb6er/rkJWKP/zwQ+3VirM+PvbD6WRb3DgO5/uIn+KYnGHwrM91vtbvxs/6PcLH1/Has+v1HgCYF0EkANC6WQLJXBSPyxPVOPRVLhZnRcmcFxRnMa7Cpt+Xl5d33r592/XKvIXwWXjd5DYyrqpgf43f/9rCNED35GtEvI8YxevDKD7M99X/r3z5BJY8hr+Pr/tv+XR8fx/B467XegCgToJIAKAzMpCMxZDNuJnVP99q9TWOyy95xrYAkiHKfSHCyVEsKuZ+kAuKP8aiYQZP82qBNy7/q374d1znXNXdo6OjsaCpH6pqmLPbx6jMVjU5LmcqYnKbiOBxx/YAAADALASRAEAnTarEqsX1U7kwrlUUi27Sbi/2hVF+XAWUX6uOy2DpdH+J/Wnc1xaDnF+2OTzHtnG6XeRleXn5/dBbGAIAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAFAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAKv8fdiuBQivkrqEAAAAASUVORK5CYII='; + +const VALIDATORS: Record< + string, + (response: GenerateResponseData, arg?: string) => void +> = { + 'has-tool-request': (response, toolName) => { + const content = getMessageContent(response); + if (!content || !Array.isArray(content)) { + throw new Error( + `Response missing message content. Full response: ${JSON.stringify( + response, + null, + 2 + )}` + ); + } + const toolRequest = content.find((c: Part) => c.toolRequest); + if (!toolRequest) { + throw new Error( + `Model did not return a tool request. Content: ${JSON.stringify( + content, + null, + 2 + )}` + ); + } + if (toolName && toolRequest.toolRequest?.name !== toolName) { + throw new Error( + `Expected tool request '${toolName}', got '${toolRequest.toolRequest?.name}'` + ); + } + }, + 'valid-json': (response) => { + const content = getMessageContent(response); + if (!content || !Array.isArray(content)) { + throw new Error( + `Response missing message content. Full response: ${JSON.stringify( + response, + null, + 2 + )}` + ); + } + const textPart = content.find((c: Part) => c.text); + if (!textPart) { + throw new Error( + `Model did not return text content for JSON. 
Content: ${JSON.stringify( + content, + null, + 2 + )}` + ); + } + try { + JSON.parse(textPart.text!); + } catch (e) { + throw new Error( + `Response text is not valid JSON. Text: ${textPart.text}` + ); + } + }, + 'text-includes': (response, expected) => { + const text = getMessageText(response); + if ( + !text || + (expected && !text.toLowerCase().includes(expected.toLowerCase())) + ) { + throw new Error( + `Response text does not include '${expected}'. Text: ${text}` + ); + } + }, + 'text-starts-with': (response, expected) => { + const text = getMessageText(response); + if (!text || (expected && !text.trim().startsWith(expected))) { + throw new Error( + `Response text does not start with '${expected}'. Text: ${text}` + ); + } + }, + 'text-not-empty': (response) => { + const text = getMessageText(response); + if (!text || text.trim().length === 0) { + throw new Error('Response text is empty'); + } + }, + 'valid-media': (response, type) => { + const mediaPart = getMediaPart(response); + if (!mediaPart) { + throw new Error(`Model did not return ${type || 'media'} part.`); + } + if (type) { + if ( + mediaPart.media?.contentType && + !mediaPart.media.contentType.startsWith(`${type}/`) + ) { + throw new Error( + `Expected ${type} content type, got ${mediaPart.media.contentType}` + ); + } + } + if (type === 'image') { + const url = mediaPart.media?.url; + if (!url) throw new Error('Media part missing URL'); + if (url.startsWith('data:')) { + if (!url.startsWith('data:image/')) { + throw new Error('Invalid data URL content type for image'); + } + } else if (url.startsWith('http')) { + try { + new URL(url); + } catch (e) { + throw new Error(`Invalid URL: ${url}`); + } + } else { + throw new Error(`Unknown URL format: ${url}`); + } + } + }, +}; + +const TEST_CASES: Record = { + 'tool-request': { + name: 'Tool Request Conformance', + input: { + messages: [ + { + role: 'user', + content: [{ text: 'What is the weather in New York? Use the tool.' 
}], + }, + ], + tools: [ + { + name: 'weather', + description: 'Get the weather for a city', + inputSchema: { + type: 'object', + properties: { + city: { type: 'string' }, + }, + required: ['city'], + }, + }, + ], + }, + validators: ['has-tool-request:weather'], + }, + 'structured-output': { + name: 'Structured Output Conformance', + input: { + messages: [ + { + role: 'user', + content: [{ text: 'Generate a profile for John Doe.' }], + }, + ], + output: { + format: 'json', + schema: { + type: 'object', + properties: { + name: { type: 'string' }, + age: { type: 'number' }, + }, + required: ['name', 'age'], + }, + constrained: true, + }, + }, + validators: ['valid-json'], + }, + multiturn: { + name: 'Multiturn Conformance', + input: { + messages: [ + { role: 'user', content: [{ text: 'My name is Genkit.' }] }, + { role: 'model', content: [{ text: 'Hello Genkit.' }] }, + { role: 'user', content: [{ text: 'What is my name?' }] }, + ], + }, + validators: ['text-includes:Genkit'], + }, + 'system-role': { + name: 'System Role Conformance', + input: { + messages: [ + { + role: 'system', + content: [ + { + text: "IMPORTANT: your response are machine processed, always start/prefix your response with 'RESPONSE:', ex: 'RESPONSE: hello'", + }, + ], + }, + { role: 'user', content: [{ text: 'hello' }] }, + ], + }, + validators: ['text-starts-with:RESPONSE:'], + }, + 'input-image-base64': { + name: 'Image Input (Base64) Conformance', + input: { + messages: [ + { + role: 'user', + content: [ + { text: 'What text do you see in this image?' }, + { + media: { + url: `data:image/png;base64,${imageBase64}`, + contentType: 'image/png', + }, + }, + ], + }, + ], + }, + validators: ['text-includes:genkit'], + }, + 'input-image-url': { + name: 'Image Input (URL) Conformance', + input: { + messages: [ + { + role: 'user', + content: [ + { text: 'What is this logo?' 
}, + { + media: { + url: 'https://www.google.com/images/branding/googlelogo/2x/googlelogo_color_272x92dp.png', + contentType: 'image/png', + }, + }, + ], + }, + ], + }, + validators: ['text-includes:google'], + }, + 'input-video-youtube': { + name: 'Video Input (YouTube) Conformance', + input: { + messages: [ + { + role: 'user', + content: [ + { text: 'Describe this video.' }, + { + media: { + url: 'https://www.youtube.com/watch?v=3p1P5grjXIQ', + contentType: 'video/mp4', + }, + }, + ], + }, + ], + }, + validators: ['text-not-empty'], + }, + 'output-audio': { + name: 'Audio Output (TTS) Conformance', + input: { + messages: [{ role: 'user', content: [{ text: 'Say hello.' }] }], + }, + validators: ['valid-media:audio'], + }, + 'output-image': { + name: 'Image Output (Generation) Conformance', + input: { + messages: [ + { + role: 'user', + content: [{ text: 'Generate an image of a cat.' }], + }, + ], + }, + validators: ['valid-media:image'], + }, +}; + +async function waitForRuntime(manager: RuntimeManager) { + // Poll for runtimes + for (let i = 0; i < 20; i++) { + if (manager.listRuntimes().length > 0) return; + await new Promise((r) => setTimeout(r, 500)); + } + logger.warn('Runtime not detected after 10 seconds.'); +} + +async function runTest( + manager: RuntimeManager, + model: string, + testCase: TestCase +): Promise { + logger.info(`Running test: ${testCase.name}...`); + try { + // Adjust model name if needed (e.g. /model/ prefix) + const modelKey = model.startsWith('/') ? 
model : `/model/${model}`; + const actionResponse = await manager.runAction({ + key: modelKey, + input: testCase.input, + }); + + const response = GenerateResponseSchema.parse(actionResponse.result); + + for (const v of testCase.validators) { + const [valName, ...args] = v.split(':'); + const arg = args.join(':'); + const validator = VALIDATORS[valName]; + if (!validator) throw new Error(`Unknown validator: ${valName}`); + validator(response, arg); + } + + logger.info(`✅ Passed: ${testCase.name}`); + return true; + } catch (e) { + if (e instanceof GenkitToolsError) { + logger.error( + `❌ Failed: ${testCase.name} - ${ + e.data?.stack || JSON.stringify(e.data?.details) || e + }` + ); + } else if (e instanceof Error) { + logger.error(`❌ Failed: ${testCase.name} - ${e.message}`); + } else { + logger.error(`❌ Failed: ${testCase.name} - ${JSON.stringify(e)}`); + } + return false; + } +} + +async function runTestSuite( + manager: RuntimeManager, + suite: TestSuite, + defaultSupports: string[] +): Promise<{ passed: number; failed: number }> { + const supports = suite.supports || (suite.tests ? 
[] : defaultSupports); + + logger.info(`Testing model: ${suite.model}`); + + const promises: Promise[] = []; + + // Built-in conformance tests + for (const support of supports) { + const testCase = TEST_CASES[support]; + if (testCase) { + promises.push(runTest(manager, suite.model, testCase)); + } else { + logger.warn(`Unknown capability: ${support}`); + } + } + + // Custom tests + if (suite.tests) { + for (const test of suite.tests) { + const customTestCase: TestCase = { + name: test.name || 'Custom Test', + input: test.input, + validators: test.validators || [], + }; + promises.push(runTest(manager, suite.model, customTestCase)); + } + } + + const results = await Promise.all(promises); + const passed = results.filter((r) => r).length; + const failed = results.filter((r) => !r).length; + + return { passed, failed }; +} + +export const devTestModel = new Command('dev:test-model') + .description('Test a model against the Genkit model specification') + .argument('[modelOrCmd]', 'Model name or command') + .argument('[args...]', 'Command arguments') + .option( + '--supports ', + 'Comma-separated list of supported capabilities (tool-request, structured-output, multiturn, system-role, input-image-base64, input-image-url, input-video-youtube, output-audio, output-image)', + 'tool-request,structured-output,multiturn,system-role,input-image-base64,input-image-url' + ) + .option('--from-file ', 'Path to a file containing test payloads') + .action( + async ( + modelOrCmd: string | undefined, + args: string[] | undefined, + options: TestOptions + ) => { + const projectRoot = await findProjectRoot(); + + let cmd: string[] = []; + let defaultModelName: string | undefined; + + if (options.fromFile) { + if (modelOrCmd) cmd.push(modelOrCmd); + if (args) cmd.push(...args); + } else { + if (!modelOrCmd) { + logger.error('Model name is required unless --from-file is used.'); + process.exitCode = 1; + return; + } + defaultModelName = modelOrCmd; + if (args) cmd = args; + } + + let 
manager: RuntimeManager; + + if (cmd.length > 0) { + const result = await startDevProcessManager( + projectRoot, + cmd[0], + cmd.slice(1) + ); + manager = result.manager; + } else { + manager = await startManager(projectRoot, false); + } + + await waitForRuntime(manager); + + try { + let totalPassed = 0; + let totalFailed = 0; + + let suites: TestSuite[] = []; + + if (options.fromFile) { + const filePath = resolve(projectRoot, options.fromFile); + const fileContent = readFileSync(filePath, 'utf-8'); + let parsed; + if (filePath.endsWith('.yaml') || filePath.endsWith('.yml')) { + parsed = parse(fileContent); + } else { + parsed = JSON.parse(fileContent); + } + suites = Array.isArray(parsed) ? parsed : [parsed]; + } else { + if (!defaultModelName) throw new Error('Model name required'); + suites = [{ model: defaultModelName }]; + } + + const defaultSupports = options.supports + .split(',') + .map((s) => s.trim()); + + for (const suite of suites) { + if (!suite.model) { + logger.error('Model name required in test suite.'); + totalFailed++; + continue; + } + const { passed, failed } = await runTestSuite( + manager, + suite, + defaultSupports + ); + totalPassed += passed; + totalFailed += failed; + } + + logger.info('--------------------------------------------------'); + logger.info( + `Tests Completed: ${totalPassed} Passed, ${totalFailed} Failed` + ); + + if (totalFailed > 0) { + process.exitCode = 1; + } + } catch (e) { + logger.error('Error running tests:', e); + process.exitCode = 1; + } finally { + if (manager) { + await manager.stop(); + } + } + } + ); diff --git a/genkit-tools/pnpm-lock.yaml b/genkit-tools/pnpm-lock.yaml index 015df58488..ea1b78beac 100644 --- a/genkit-tools/pnpm-lock.yaml +++ b/genkit-tools/pnpm-lock.yaml @@ -68,6 +68,9 @@ importers: semver: specifier: ^7.7.2 version: 7.7.2 + yaml: + specifier: ^2.8.0 + version: 2.8.0 devDependencies: '@jest/globals': specifier: ^29.7.0 diff --git a/go/ai/document_test.go b/go/ai/document_test.go index 
ab1e65ab72..a5d4bc9bc0 100644 --- a/go/ai/document_test.go +++ b/go/ai/document_test.go @@ -141,3 +141,273 @@ func TestReasoningPartJSON(t *testing.T) { t.Errorf("unmarshaled reasoning content type = %q, want %q", unmarshaledPart.ContentType, "plain/text") } } + +func TestNewDataPart(t *testing.T) { + t.Run("creates data part with content", func(t *testing.T) { + p := NewDataPart("some binary data") + + if p.Kind != PartData { + t.Errorf("Kind = %v, want %v", p.Kind, PartData) + } + if p.Text != "some binary data" { + t.Errorf("Text = %q, want %q", p.Text, "some binary data") + } + }) + + t.Run("creates data part with empty content", func(t *testing.T) { + p := NewDataPart("") + + if p.Kind != PartData { + t.Errorf("Kind = %v, want %v", p.Kind, PartData) + } + if p.Text != "" { + t.Errorf("Text = %q, want empty string", p.Text) + } + }) +} + +func TestNewCustomPart(t *testing.T) { + t.Run("creates custom part with value", func(t *testing.T) { + custom := map[string]any{"key": "value", "count": 42} + p := NewCustomPart(custom) + + if p.Kind != PartCustom { + t.Errorf("Kind = %v, want %v", p.Kind, PartCustom) + } + if p.Custom == nil { + t.Fatal("Custom is nil") + } + if p.Custom["key"] != "value" { + t.Errorf("Custom[key] = %v, want %q", p.Custom["key"], "value") + } + }) + + t.Run("creates custom part with nil value", func(t *testing.T) { + p := NewCustomPart(nil) + + if p.Kind != PartCustom { + t.Errorf("Kind = %v, want %v", p.Kind, PartCustom) + } + if p.Custom != nil { + t.Errorf("Custom = %v, want nil", p.Custom) + } + }) +} + +func TestPartIsData(t *testing.T) { + tests := []struct { + name string + part *Part + want bool + }{ + {"data part", NewDataPart("{}"), true}, + {"text part", NewTextPart("hello"), false}, + {"media part", NewMediaPart("image/png", "data:..."), false}, + {"nil part", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.part.IsData() + if got != tt.want { + t.Errorf("IsData() = %v, want %v", 
got, tt.want) + } + }) + } +} + +func TestPartIsInterrupt(t *testing.T) { + t.Run("interrupt tool request returns true", func(t *testing.T) { + p := &Part{ + Kind: PartToolRequest, + ToolRequest: &ToolRequest{ + Name: "test", + Input: map[string]any{}, + }, + Metadata: map[string]any{ + "interrupt": true, + }, + } + + if !p.IsInterrupt() { + t.Error("IsInterrupt() = false, want true") + } + }) + + t.Run("non-interrupt tool request returns false", func(t *testing.T) { + p := &Part{ + Kind: PartToolRequest, + ToolRequest: &ToolRequest{ + Name: "test", + Input: map[string]any{}, + }, + } + + if p.IsInterrupt() { + t.Error("IsInterrupt() = true, want false") + } + }) + + t.Run("non-tool-request part returns false", func(t *testing.T) { + p := NewTextPart("hello") + + if p.IsInterrupt() { + t.Error("IsInterrupt() = true, want false") + } + }) + + t.Run("nil part returns false", func(t *testing.T) { + var p *Part + if p.IsInterrupt() { + t.Error("IsInterrupt() = true, want false") + } + }) +} + +func TestPartIsCustom(t *testing.T) { + tests := []struct { + name string + part *Part + want bool + }{ + {"custom part", NewCustomPart(map[string]any{"key": "value"}), true}, + {"text part", NewTextPart("hello"), false}, + {"data part", NewDataPart("data"), false}, + {"nil part", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.part.IsCustom() + if got != tt.want { + t.Errorf("IsCustom() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsImageContentType(t *testing.T) { + tests := []struct { + contentType string + want bool + }{ + {"image/png", true}, + {"image/jpeg", true}, + {"image/gif", true}, + {"image/webp", true}, + {"data:image/png;base64,...", true}, + {"video/mp4", false}, + {"audio/mp3", false}, + {"text/plain", false}, + {"application/json", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.contentType, func(t *testing.T) { + got := IsImageContentType(tt.contentType) + if got != tt.want { + 
t.Errorf("IsImageContentType(%q) = %v, want %v", tt.contentType, got, tt.want) + } + }) + } +} + +func TestIsVideoContentType(t *testing.T) { + tests := []struct { + contentType string + want bool + }{ + {"video/mp4", true}, + {"video/webm", true}, + {"video/mpeg", true}, + {"data:video/mp4;base64,...", true}, + {"image/png", false}, + {"audio/mp3", false}, + {"text/plain", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.contentType, func(t *testing.T) { + got := IsVideoContentType(tt.contentType) + if got != tt.want { + t.Errorf("IsVideoContentType(%q) = %v, want %v", tt.contentType, got, tt.want) + } + }) + } +} + +func TestIsAudioContentType(t *testing.T) { + tests := []struct { + contentType string + want bool + }{ + {"audio/mp3", true}, + {"audio/wav", true}, + {"audio/ogg", true}, + {"audio/mpeg", true}, + {"data:audio/mp3;base64,...", true}, + {"image/png", false}, + {"video/mp4", false}, + {"text/plain", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.contentType, func(t *testing.T) { + got := IsAudioContentType(tt.contentType) + if got != tt.want { + t.Errorf("IsAudioContentType(%q) = %v, want %v", tt.contentType, got, tt.want) + } + }) + } +} + +func TestNewResponseForToolRequest(t *testing.T) { + t.Run("creates tool response for tool request part", func(t *testing.T) { + reqPart := NewToolRequestPart(&ToolRequest{ + Name: "calculator", + Input: map[string]any{"a": 1, "b": 2}, + }) + output := map[string]any{"result": 3} + + resp := NewResponseForToolRequest(reqPart, output) + + if resp.Kind != PartToolResponse { + t.Errorf("Kind = %v, want %v", resp.Kind, PartToolResponse) + } + if resp.ToolResponse == nil { + t.Fatal("ToolResponse is nil") + } + if resp.ToolResponse.Name != "calculator" { + t.Errorf("Name = %q, want %q", resp.ToolResponse.Name, "calculator") + } + if resp.ToolResponse.Output.(map[string]any)["result"] != 3 { + t.Errorf("Output mismatch") + } + }) + + t.Run("preserves ref from original request", 
func(t *testing.T) { + reqPart := NewToolRequestPart(&ToolRequest{ + Name: "tool", + Ref: "request-123", + }) + + resp := NewResponseForToolRequest(reqPart, "output") + + if resp.ToolResponse.Ref != "request-123" { + t.Errorf("Ref = %q, want %q", resp.ToolResponse.Ref, "request-123") + } + }) + + t.Run("returns nil for non-tool-request part", func(t *testing.T) { + textPart := NewTextPart("not a tool request") + + resp := NewResponseForToolRequest(textPart, "output") + + if resp != nil { + t.Error("expected nil for non-tool-request part") + } + }) +} diff --git a/go/ai/embedder_test.go b/go/ai/embedder_test.go new file mode 100644 index 0000000000..43404479f1 --- /dev/null +++ b/go/ai/embedder_test.go @@ -0,0 +1,400 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestEmbedderRef(t *testing.T) { + t.Run("NewEmbedderRef creates ref with name and config", func(t *testing.T) { + config := map[string]any{"dimension": 768} + ref := NewEmbedderRef("test/embedder", config) + + if ref.Name() != "test/embedder" { + t.Errorf("Name() = %q, want %q", ref.Name(), "test/embedder") + } + if diff := cmp.Diff(config, ref.Config()); diff != "" { + t.Errorf("Config() mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("NewEmbedderRef with nil config", func(t *testing.T) { + ref := NewEmbedderRef("test/embedder", nil) + + if ref.Name() != "test/embedder" { + t.Errorf("Name() = %q, want %q", ref.Name(), "test/embedder") + } + if ref.Config() != nil { + t.Errorf("Config() = %v, want nil", ref.Config()) + } + }) +} + +func TestNewEmbedder(t *testing.T) { + t.Run("creates embedder with valid name", func(t *testing.T) { + e := NewEmbedder("test/embedder", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return &EmbedResponse{}, nil + }) + + if e == nil { + t.Fatal("expected embedder, got nil") + } + if e.Name() != "test/embedder" { + t.Errorf("Name() = %q, want %q", e.Name(), "test/embedder") + } + }) + + t.Run("panics with empty name", func(t *testing.T) { + assertPanic(t, func() { + NewEmbedder("", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return &EmbedResponse{}, nil + }) + }, "name is required") + }) + + t.Run("applies options correctly", func(t *testing.T) { + opts := &EmbedderOptions{ + Label: "Test Embedder", + Dimensions: 768, + Supports: &EmbedderSupports{ + Input: []string{"text", "image"}, + Multilingual: true, + }, + ConfigSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "temperature": map[string]any{"type": "number"}, + }, + }, + } + + e := NewEmbedder("test/embedder", opts, func(ctx 
context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return &EmbedResponse{}, nil + }) + + if e == nil { + t.Fatal("expected embedder, got nil") + } + }) + + t.Run("uses defaults when options nil", func(t *testing.T) { + e := NewEmbedder("test/embedder", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return &EmbedResponse{}, nil + }) + + if e == nil { + t.Fatal("expected embedder, got nil") + } + }) +} + +func TestDefineEmbedder(t *testing.T) { + t.Run("registers and returns embedder", func(t *testing.T) { + r := newTestRegistry(t) + called := false + + e := DefineEmbedder(r, "test/defineEmbedder", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + called = true + return &EmbedResponse{ + Embeddings: []*Embedding{{Embedding: []float32{0.1, 0.2, 0.3}}}, + }, nil + }) + + if e == nil { + t.Fatal("expected embedder, got nil") + } + + // Verify it's registered by looking it up + found := LookupEmbedder(r, "test/defineEmbedder") + if found == nil { + t.Fatal("LookupEmbedder returned nil for registered embedder") + } + + // Verify the function works + resp, err := e.Embed(context.Background(), &EmbedRequest{ + Input: []*Document{DocumentFromText("test", nil)}, + }) + assertNoError(t, err) + if !called { + t.Error("embedder function was not called") + } + if len(resp.Embeddings) != 1 { + t.Errorf("len(Embeddings) = %d, want 1", len(resp.Embeddings)) + } + }) +} + +func TestLookupEmbedder(t *testing.T) { + t.Run("returns embedder when found", func(t *testing.T) { + r := newTestRegistry(t) + DefineEmbedder(r, "test/lookupEmbedder", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return &EmbedResponse{}, nil + }) + + e := LookupEmbedder(r, "test/lookupEmbedder") + if e == nil { + t.Error("expected embedder, got nil") + } + }) + + t.Run("returns nil when not found", func(t *testing.T) { + r := newTestRegistry(t) + + e := LookupEmbedder(r, "nonexistent") + if e != nil { + 
t.Error("expected nil for non-existent embedder") + } + }) +} + +func TestEmbedderEmbed(t *testing.T) { + t.Run("embeds documents successfully", func(t *testing.T) { + r := newTestRegistry(t) + var capturedReq *EmbedRequest + + e := DefineEmbedder(r, "test/embedDocuments", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + capturedReq = req + embeddings := make([]*Embedding, len(req.Input)) + for i := range req.Input { + embeddings[i] = &Embedding{ + Embedding: []float32{float32(i) * 0.1, float32(i) * 0.2, float32(i) * 0.3}, + } + } + return &EmbedResponse{Embeddings: embeddings}, nil + }) + + docs := []*Document{ + DocumentFromText("first document", nil), + DocumentFromText("second document", nil), + } + + resp, err := e.Embed(context.Background(), &EmbedRequest{Input: docs}) + assertNoError(t, err) + + if len(capturedReq.Input) != 2 { + t.Errorf("captured input len = %d, want 2", len(capturedReq.Input)) + } + if len(resp.Embeddings) != 2 { + t.Errorf("len(Embeddings) = %d, want 2", len(resp.Embeddings)) + } + }) + + t.Run("returns error on nil embedder", func(t *testing.T) { + var e *embedder + _, err := e.Embed(context.Background(), &EmbedRequest{}) + if err == nil { + t.Error("expected error for nil embedder") + } + }) + + t.Run("propagates function errors", func(t *testing.T) { + r := newTestRegistry(t) + expectedErr := errors.New("embedding failed") + + e := DefineEmbedder(r, "test/embedError", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return nil, expectedErr + }) + + _, err := e.Embed(context.Background(), &EmbedRequest{ + Input: []*Document{DocumentFromText("test", nil)}, + }) + if err == nil { + t.Error("expected error, got nil") + } + }) + + t.Run("passes options through request", func(t *testing.T) { + r := newTestRegistry(t) + var capturedOpts any + + e := DefineEmbedder(r, "test/embedOpts", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + capturedOpts = 
req.Options + return &EmbedResponse{Embeddings: []*Embedding{{Embedding: []float32{0.1}}}}, nil + }) + + opts := map[string]any{"dimension": 768} + _, err := e.Embed(context.Background(), &EmbedRequest{ + Input: []*Document{DocumentFromText("test", nil)}, + Options: opts, + }) + assertNoError(t, err) + + if diff := cmp.Diff(opts, capturedOpts); diff != "" { + t.Errorf("Options mismatch (-want +got):\n%s", diff) + } + }) +} + +func TestEmbedFunction(t *testing.T) { + t.Run("embeds with embedder directly", func(t *testing.T) { + r := newTestRegistry(t) + e := DefineEmbedder(r, "test/embedFunc", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return &EmbedResponse{ + Embeddings: []*Embedding{{Embedding: []float32{0.1, 0.2, 0.3}}}, + }, nil + }) + + resp, err := Embed(context.Background(), r, + WithEmbedder(e), + WithTextDocs("test document"), + ) + assertNoError(t, err) + + if len(resp.Embeddings) != 1 { + t.Errorf("len(Embeddings) = %d, want 1", len(resp.Embeddings)) + } + }) + + t.Run("embeds with embedder ref", func(t *testing.T) { + r := newTestRegistry(t) + DefineEmbedder(r, "test/embedFuncRef", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return &EmbedResponse{ + Embeddings: []*Embedding{{Embedding: []float32{0.1, 0.2, 0.3}}}, + }, nil + }) + + ref := NewEmbedderRef("test/embedFuncRef", nil) + resp, err := Embed(context.Background(), r, + WithEmbedder(ref), + WithTextDocs("test document"), + ) + assertNoError(t, err) + + if len(resp.Embeddings) != 1 { + t.Errorf("len(Embeddings) = %d, want 1", len(resp.Embeddings)) + } + }) + + t.Run("embeds with embedder name", func(t *testing.T) { + r := newTestRegistry(t) + DefineEmbedder(r, "test/embedFuncName", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + return &EmbedResponse{ + Embeddings: []*Embedding{{Embedding: []float32{0.1, 0.2, 0.3}}}, + }, nil + }) + + resp, err := Embed(context.Background(), r, + 
WithEmbedderName("test/embedFuncName"), + WithTextDocs("test document"), + ) + assertNoError(t, err) + + if len(resp.Embeddings) != 1 { + t.Errorf("len(Embeddings) = %d, want 1", len(resp.Embeddings)) + } + }) + + t.Run("uses config from EmbedderRef", func(t *testing.T) { + r := newTestRegistry(t) + var capturedOpts any + + DefineEmbedder(r, "test/embedRefConfig", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + capturedOpts = req.Options + return &EmbedResponse{Embeddings: []*Embedding{{Embedding: []float32{0.1}}}}, nil + }) + + config := map[string]any{"dimension": 768} + ref := NewEmbedderRef("test/embedRefConfig", config) + + _, err := Embed(context.Background(), r, + WithEmbedder(ref), + WithTextDocs("test"), + ) + assertNoError(t, err) + + if diff := cmp.Diff(config, capturedOpts); diff != "" { + t.Errorf("Options mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("explicit config overrides EmbedderRef config", func(t *testing.T) { + r := newTestRegistry(t) + var capturedOpts any + + DefineEmbedder(r, "test/embedOverrideConfig", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + capturedOpts = req.Options + return &EmbedResponse{Embeddings: []*Embedding{{Embedding: []float32{0.1}}}}, nil + }) + + refConfig := map[string]any{"dimension": 768} + explicitConfig := map[string]any{"dimension": 512} + ref := NewEmbedderRef("test/embedOverrideConfig", refConfig) + + _, err := Embed(context.Background(), r, + WithEmbedder(ref), + WithConfig(explicitConfig), + WithTextDocs("test"), + ) + assertNoError(t, err) + + if diff := cmp.Diff(explicitConfig, capturedOpts); diff != "" { + t.Errorf("Options mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("returns error when embedder not set", func(t *testing.T) { + r := newTestRegistry(t) + + _, err := Embed(context.Background(), r, + WithTextDocs("test"), + ) + assertError(t, err, "embedder must be set") + }) + + t.Run("returns error when embedder not found", 
func(t *testing.T) { + r := newTestRegistry(t) + + _, err := Embed(context.Background(), r, + WithEmbedderName("nonexistent"), + WithTextDocs("test"), + ) + assertError(t, err, "embedder not found") + }) + + t.Run("embeds with document options", func(t *testing.T) { + r := newTestRegistry(t) + var capturedDocs []*Document + + DefineEmbedder(r, "test/embedDocs", nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + capturedDocs = req.Input + embeddings := make([]*Embedding, len(req.Input)) + for i := range req.Input { + embeddings[i] = &Embedding{Embedding: []float32{0.1}} + } + return &EmbedResponse{Embeddings: embeddings}, nil + }) + + doc := DocumentFromText("custom document", map[string]any{"custom": "metadata"}) + _, err := Embed(context.Background(), r, + WithEmbedderName("test/embedDocs"), + WithDocs(doc), + ) + assertNoError(t, err) + + if len(capturedDocs) != 1 { + t.Fatalf("len(docs) = %d, want 1", len(capturedDocs)) + } + if capturedDocs[0].Metadata["custom"] != "metadata" { + t.Error("document metadata not passed correctly") + } + }) +} diff --git a/go/ai/evaluator_test.go b/go/ai/evaluator_test.go index 9ee268d3da..6cf5c58953 100644 --- a/go/ai/evaluator_test.go +++ b/go/ai/evaluator_test.go @@ -207,3 +207,161 @@ func TestBatchEvaluator(t *testing.T) { t.Errorf("got %v, want %v", got, want) } } + +func TestNewEvaluatorRef(t *testing.T) { + t.Run("creates evaluator reference with name and config", func(t *testing.T) { + config := map[string]any{"threshold": 0.8} + ref := NewEvaluatorRef("test/myEvaluator", config) + + if ref.Name() != "test/myEvaluator" { + t.Errorf("Name() = %q, want %q", ref.Name(), "test/myEvaluator") + } + if ref.Config() == nil { + t.Error("Config() = nil, want config") + } + if ref.Config().(map[string]any)["threshold"] != 0.8 { + t.Errorf("Config()[threshold] = %v, want 0.8", ref.Config().(map[string]any)["threshold"]) + } + }) + + t.Run("creates evaluator reference with nil config", func(t *testing.T) { + 
ref := NewEvaluatorRef("test/simpleEvaluator", nil) + + if ref.Name() != "test/simpleEvaluator" { + t.Errorf("Name() = %q, want %q", ref.Name(), "test/simpleEvaluator") + } + if ref.Config() != nil { + t.Errorf("Config() = %v, want nil", ref.Config()) + } + }) + + t.Run("implements EvaluatorArg interface", func(t *testing.T) { + ref := NewEvaluatorRef("test/interface", nil) + var _ EvaluatorArg = ref // compile-time check + + if ref.Name() != "test/interface" { + t.Errorf("Name() = %q, want %q", ref.Name(), "test/interface") + } + }) +} + +func TestEvaluatorRefUsedWithEvaluate(t *testing.T) { + r := registry.New() + + // Define evaluator that uses config + DefineEvaluator(r, "test/configEvaluator", &evalOpts, func(ctx context.Context, req *EvaluatorCallbackRequest) (*EvaluatorCallbackResponse, error) { + score := Score{ + Id: "configScore", + Score: 1, + Status: ScoreStatusPass.String(), + Details: map[string]any{"options": req.Options}, + } + return &EvaluatorCallbackResponse{ + TestCaseId: req.Input.TestCaseId, + Evaluation: []Score{score}, + }, nil + }) + + // Use EvaluatorRef instead of direct evaluator + ref := NewEvaluatorRef("test/configEvaluator", "ref-config-value") + + resp, err := Evaluate(context.Background(), r, + WithEvaluator(ref), + WithDataset(&Example{Input: "test"}), + WithID("testrun")) + if err != nil { + t.Fatal(err) + } + + // Config from ref should be used since no explicit config was provided + if got, want := (*resp)[0].Evaluation[0].Details["options"], "ref-config-value"; got != want { + t.Errorf("got config %v, want %v", got, want) + } +} + +func TestScoreStatusString(t *testing.T) { + tests := []struct { + status ScoreStatus + want string + }{ + {ScoreStatusUnknown, "UNKNOWN"}, + {ScoreStatusFail, "FAIL"}, + {ScoreStatusPass, "PASS"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := tt.status.String() + if got != tt.want { + t.Errorf("String() = %q, want %q", got, tt.want) + } + }) + } +} + +func 
TestNewEvaluator(t *testing.T) { + t.Run("panics with empty name", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for empty name") + } + }() + + NewEvaluator("", &evalOpts, testEvalFunc) + }) + + t.Run("creates evaluator with nil options", func(t *testing.T) { + eval := NewEvaluator("test/nilOpts", nil, testEvalFunc) + if eval == nil { + t.Error("NewEvaluator returned nil") + } + if eval.Name() != "test/nilOpts" { + t.Errorf("Name() = %q, want %q", eval.Name(), "test/nilOpts") + } + }) +} + +func TestNewBatchEvaluator(t *testing.T) { + t.Run("panics with empty name", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for empty name") + } + }() + + NewBatchEvaluator("", &evalOpts, testBatchEvalFunc) + }) + + t.Run("creates batch evaluator with nil options", func(t *testing.T) { + eval := NewBatchEvaluator("test/batchNilOpts", nil, testBatchEvalFunc) + if eval == nil { + t.Error("NewBatchEvaluator returned nil") + } + }) +} + +func TestEvaluateNilEvaluator(t *testing.T) { + t.Run("returns error when evaluator not set", func(t *testing.T) { + r := registry.New() + + _, err := Evaluate(context.Background(), r, + WithDataset(&Example{Input: "test"})) + + if err == nil { + t.Error("expected error when evaluator not set, got nil") + } + }) + + t.Run("returns error for non-existent evaluator", func(t *testing.T) { + r := registry.New() + + ref := NewEvaluatorRef("test/nonexistent", nil) + _, err := Evaluate(context.Background(), r, + WithEvaluator(ref), + WithDataset(&Example{Input: "test"})) + + if err == nil { + t.Error("expected error for non-existent evaluator, got nil") + } + }) +} diff --git a/go/ai/example_test.go b/go/ai/example_test.go new file mode 100644 index 0000000000..afcdbbc6bb --- /dev/null +++ b/go/ai/example_test.go @@ -0,0 +1,160 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this 
file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package ai_test provides examples for ai package helper functions. +// +// The ai package contains helper types and functions used with genkit. +// Most generation and definition functions are in the genkit package; +// see that package for the primary API documentation. +package ai_test + +import ( + "fmt" + + "github.com/firebase/genkit/go/ai" +) + +// This example demonstrates creating different types of message parts. +func ExampleNewTextPart() { + // Create a text part + part := ai.NewTextPart("Hello, world!") + fmt.Println(part.Text) + // Output: Hello, world! +} + +// This example demonstrates creating a message with text content. +func ExampleNewUserTextMessage() { + // Create a user message with text + msg := ai.NewUserTextMessage("What is the capital of France?") + fmt.Println("Role:", msg.Role) + fmt.Println("Text:", msg.Content[0].Text) + // Output: + // Role: user + // Text: What is the capital of France? +} + +// This example demonstrates creating system and model messages. +func ExampleNewSystemTextMessage() { + // Create a system message + sysMsg := ai.NewSystemTextMessage("You are a helpful assistant.") + fmt.Println("System role:", sysMsg.Role) + + // Create a model response message + modelMsg := ai.NewModelTextMessage("I'm here to help!") + fmt.Println("Model role:", modelMsg.Role) + // Output: + // System role: system + // Model role: model +} + +// This example demonstrates creating a data part for raw string content. 
+func ExampleNewDataPart() { + // Create a data part with raw string content + part := ai.NewDataPart(`{"name": "Alice", "age": 30}`) + fmt.Println("Is data part:", part.IsData()) + fmt.Println("Content:", part.Text) + // Output: + // Is data part: true + // Content: {"name": "Alice", "age": 30} +} + +// This example demonstrates accessing text from a Part. +func ExamplePart_Text() { + // Create a part with text + part := ai.NewTextPart("Sample text content") + + // Access the text field directly + fmt.Println(part.Text) + // Output: Sample text content +} + +// This example demonstrates the Document type used in RAG applications. +func ExampleDocument() { + // Create a document with text content + doc := &ai.Document{ + Content: []*ai.Part{ + ai.NewTextPart("This is the document content."), + }, + Metadata: map[string]any{ + "source": "knowledge-base", + "page": 42, + }, + } + + fmt.Println("Content:", doc.Content[0].Text) + fmt.Println("Source:", doc.Metadata["source"]) + // Output: + // Content: This is the document content. + // Source: knowledge-base +} + +// This example demonstrates creating an Embedding for vector search. +func ExampleEmbedding() { + // Create an embedding (typically returned by an embedder) + embedding := &ai.Embedding{ + Embedding: []float32{0.1, 0.2, 0.3, 0.4, 0.5}, + Metadata: map[string]any{ + "source": "document-1", + }, + } + + fmt.Printf("Embedding dimensions: %d\n", len(embedding.Embedding)) + fmt.Printf("First value: %.1f\n", embedding.Embedding[0]) + // Output: + // Embedding dimensions: 5 + // First value: 0.1 +} + +// This example demonstrates creating a media part for images or other media. +func ExampleNewMediaPart() { + // Create a media part with base64-encoded image data + // In practice, you would encode actual image bytes + imageData := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJ..." 
+ part := ai.NewMediaPart("image/png", imageData) + + fmt.Println("Is media:", part.IsMedia()) + fmt.Println("Content type:", part.ContentType) + // Output: + // Is media: true + // Content type: image/png +} + +// This example demonstrates creating a model reference with configuration. +func ExampleNewModelRef() { + // Create a reference to a model with custom configuration + // The config type depends on the model provider + modelRef := ai.NewModelRef("googleai/gemini-2.5-flash", map[string]any{ + "temperature": 0.7, + }) + + fmt.Println("Model name:", modelRef.Name()) + // Output: Model name: googleai/gemini-2.5-flash +} + +// This example demonstrates building a multi-turn conversation. +func ExampleNewUserMessage() { + // Build a conversation with multiple parts + userMsg := ai.NewUserMessage( + ai.NewTextPart("What's in this image?"), + ai.NewMediaPart("image/jpeg", "base64data..."), + ) + + fmt.Println("Role:", userMsg.Role) + fmt.Println("Parts:", len(userMsg.Content)) + // Output: + // Role: user + // Parts: 2 +} diff --git a/go/ai/formatter_test.go b/go/ai/formatter_test.go index 5f75d77acf..e52162cfd4 100644 --- a/go/ai/formatter_test.go +++ b/go/ai/formatter_test.go @@ -1031,3 +1031,72 @@ func TestDefaultFormats(t *testing.T) { }) } } + +func TestArrayFormatterParseMessage(t *testing.T) { + schema := map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + }, + }, + } + + t.Run("returns message unchanged", func(t *testing.T) { + handler, err := arrayFormatter{}.Handler(schema) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + + msg := &Message{ + Role: RoleModel, + Content: []*Part{NewTextPart(`[{"id": 1}, {"id": 2}]`)}, + } + + got, err := handler.ParseMessage(msg) + if err != nil { + t.Fatalf("ParseMessage() error = %v", err) + } + + // Array formatter's ParseMessage returns the message unchanged + if got != msg { + 
t.Error("ParseMessage() should return the same message object") + } + }) +} + +func TestJSONLFormatterParseMessage(t *testing.T) { + schema := map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + "name": map[string]any{"type": "string"}, + }, + }, + } + + t.Run("returns message unchanged", func(t *testing.T) { + handler, err := jsonlFormatter{}.Handler(schema) + if err != nil { + t.Fatalf("Handler() error = %v", err) + } + + msg := &Message{ + Role: RoleModel, + Content: []*Part{NewTextPart("{\"id\": 1, \"name\": \"Alice\"}\n{\"id\": 2, \"name\": \"Bob\"}")}, + } + + got, err := handler.ParseMessage(msg) + if err != nil { + t.Fatalf("ParseMessage() error = %v", err) + } + + // JSONL formatter's ParseMessage returns the message unchanged + if got != msg { + t.Error("ParseMessage() should return the same message object") + } + }) +} diff --git a/go/ai/gen.go b/go/ai/gen.go index cbfe5cd9c6..e391ef2215 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -18,97 +18,45 @@ package ai -type BaseDataPoint struct { - Context map[string]any `json:"context,omitempty"` - Input map[string]any `json:"input,omitempty"` - Output map[string]any `json:"output,omitempty"` - Reference map[string]any `json:"reference,omitempty"` - TestCaseID string `json:"testCaseId,omitempty"` - TraceIDs []string `json:"traceIds,omitempty"` -} - -type BaseEvalDataPoint struct { - Context map[string]any `json:"context,omitempty"` - Input map[string]any `json:"input,omitempty"` - Output map[string]any `json:"output,omitempty"` - Reference map[string]any `json:"reference,omitempty"` - TestCaseID string `json:"testCaseId,omitempty"` - TraceIDs []string `json:"traceIds,omitempty"` -} - -type CandidateError struct { - Code CandidateErrorCode `json:"code,omitempty"` - Index float64 `json:"index,omitempty"` - Message string `json:"message,omitempty"` -} - -type CandidateErrorCode string - -const ( - 
CandidateErrorCodeBlocked CandidateErrorCode = "blocked" - CandidateErrorCodeOther CandidateErrorCode = "other" - CandidateErrorCodeUnknown CandidateErrorCode = "unknown" -) - -type CommonRerankerOptions struct { - // Number of documents to rerank - K float64 `json:"k,omitempty"` -} - -type CommonRetrieverOptions struct { - // Number of documents to retrieve - K float64 `json:"k,omitempty"` -} - type customPart struct { - Custom map[string]any `json:"custom,omitempty"` - Data any `json:"data,omitempty"` + // Custom contains custom key-value data specific to this part. + Custom map[string]any `json:"custom,omitempty"` + // Data contains additional arbitrary data. + Data any `json:"data,omitempty"` + // Metadata contains arbitrary key-value data for this part. Metadata map[string]any `json:"metadata,omitempty"` } type dataPart struct { - Data any `json:"data,omitempty"` + // Data contains arbitrary structured data. + Data any `json:"data,omitempty"` + // Metadata contains arbitrary key-value data for this part. Metadata map[string]any `json:"metadata,omitempty"` } +// EmbedRequest represents a request to generate embeddings for documents. type EmbedRequest struct { - Input []*Document `json:"input,omitempty"` - Options any `json:"options,omitempty"` + // Input is the array of documents to generate embeddings for. + Input []*Document `json:"input,omitempty"` + // Options contains embedder-specific configuration parameters. + Options any `json:"options,omitempty"` } +// EmbedResponse contains the generated embeddings from an embed request. type EmbedResponse struct { + // Embeddings is the array of generated embedding vectors with metadata. Embeddings []*Embedding `json:"embeddings,omitempty"` } +// Embedding represents a vector embedding with associated metadata. 
type Embedding struct { - Embedding []float32 `json:"embedding,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` -} - -type EvalFnResponse struct { - Evaluation any `json:"evaluation,omitempty"` - SampleIndex float64 `json:"sampleIndex,omitempty"` - SpanID string `json:"spanId,omitempty"` - TestCaseID string `json:"testCaseId,omitempty"` - TraceID string `json:"traceId,omitempty"` -} - -type EvalRequest struct { - Dataset []*BaseDataPoint `json:"dataset,omitempty"` - EvalRunID string `json:"evalRunId,omitempty"` - Options any `json:"options,omitempty"` + // Embedding is the vector representation of the input. + Embedding []float32 `json:"embedding,omitempty"` + // Metadata identifies which part of a document this embedding corresponds to. + Metadata map[string]any `json:"metadata,omitempty"` } -type EvalResponse []any - -type EvalStatusEnum string - -const ( - EvalStatusEnumUNKNOWN EvalStatusEnum = "UNKNOWN" - EvalStatusEnumPASS EvalStatusEnum = "PASS" - EvalStatusEnumFAIL EvalStatusEnum = "FAIL" -) - +// FinishReason indicates why generation stopped. type FinishReason string const ( @@ -120,26 +68,47 @@ const ( FinishReasonUnknown FinishReason = "unknown" ) +// GenerateActionOptions holds configuration for a generate action request. type GenerateActionOptions struct { - Config any `json:"config,omitempty"` - Docs []*Document `json:"docs,omitempty"` - MaxTurns int `json:"maxTurns,omitempty"` - Messages []*Message `json:"messages,omitempty"` - Model string `json:"model,omitempty"` - Output *GenerateActionOutputConfig `json:"output,omitempty"` - Resume *GenerateActionResume `json:"resume,omitempty"` - ReturnToolRequests bool `json:"returnToolRequests,omitempty"` - StepName string `json:"stepName,omitempty"` - ToolChoice ToolChoice `json:"toolChoice,omitempty"` - Tools []string `json:"tools,omitempty"` -} - + // Config contains configuration parameters for the generation request. 
+ Config any `json:"config,omitempty"` + // Docs provides retrieved documents to be used as context for this generation. + Docs []*Document `json:"docs,omitempty"` + // MaxTurns is the maximum number of tool call iterations that can be performed + // in a single generate call. Defaults to 5. + MaxTurns int `json:"maxTurns,omitempty"` + // Messages contains the conversation history for multi-turn prompting when supported. + Messages []*Message `json:"messages,omitempty"` + // Model is a model name (e.g., "vertexai/gemini-1.0-pro"). + Model string `json:"model,omitempty"` + // Output specifies the desired output format. Defaults to the model's default if unspecified. + Output *GenerateActionOutputConfig `json:"output,omitempty"` + // Resume provides options for resuming an interrupted generation. + Resume *GenerateActionResume `json:"resume,omitempty"` + // ReturnToolRequests, when true, returns tool calls for manual processing instead of + // automatically resolving them. + ReturnToolRequests bool `json:"returnToolRequests,omitempty"` + // StepName is a custom step name for this generate call to display in trace views. + // Defaults to "generate". + StepName string `json:"stepName,omitempty"` + // ToolChoice controls tool calling mode. Auto lets the model decide, required forces + // the model to choose a tool, and none forces the model not to use any tools. Defaults to auto. + ToolChoice ToolChoice `json:"toolChoice,omitempty"` + // Tools is a list of registered tool names for this generation if supported. + Tools []string `json:"tools,omitempty"` +} + +// GenerateActionResume holds options for resuming an interrupted generation. type GenerateActionResume struct { - Metadata map[string]any `json:"metadata,omitempty"` - Respond []*toolResponsePart `json:"respond,omitempty"` - Restart []*toolRequestPart `json:"restart,omitempty"` + // Metadata contains additional context for resuming the generation. 
+ Metadata map[string]any `json:"metadata,omitempty"` + // Respond contains tool response parts to send to the model when resuming. + Respond []*toolResponsePart `json:"respond,omitempty"` + // Restart contains tool request parts to restart when resuming. + Restart []*toolRequestPart `json:"restart,omitempty"` } +// ToolChoice controls how the model uses tools. type ToolChoice string const ( @@ -148,67 +117,113 @@ const ( ToolChoiceNone ToolChoice = "none" ) +// GenerateActionOutputConfig specifies the desired output format for a generate action. type GenerateActionOutputConfig struct { - Constrained bool `json:"constrained,omitempty"` - ContentType string `json:"contentType,omitempty"` - Format string `json:"format,omitempty"` - Instructions *string `json:"instructions,omitempty"` - JsonSchema map[string]any `json:"jsonSchema,omitempty"` + // Constrained indicates whether to enforce strict adherence to the schema. + Constrained bool `json:"constrained,omitempty"` + // ContentType specifies the MIME type of the output content. + ContentType string `json:"contentType,omitempty"` + // Format specifies the desired output format (e.g., "json", "text"). + Format string `json:"format,omitempty"` + // Instructions provides additional guidance for the output format. + Instructions *string `json:"instructions,omitempty"` + // JsonSchema is a JSON Schema describing the desired structure of JSON output. + JsonSchema map[string]any `json:"jsonSchema,omitempty"` } -// GenerationCommonConfig holds configuration for generation. +// GenerationCommonConfig holds configuration parameters for model generation requests. 
type GenerationCommonConfig struct { - MaxOutputTokens int `json:"maxOutputTokens,omitempty"` - StopSequences []string `json:"stopSequences,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopK int `json:"topK,omitempty"` - TopP float64 `json:"topP,omitempty"` - Version string `json:"version,omitempty"` -} - -// GenerationUsage provides information about the generation process. + // MaxOutputTokens limits the maximum number of tokens generated in the response. + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + // StopSequences specifies sequences that will cause generation to stop when encountered. + StopSequences []string `json:"stopSequences,omitempty"` + // Temperature controls randomness in generation. Higher values (e.g., 0.9) make output more random, + // while lower values (e.g., 0.1) make it more deterministic. Typical range is 0.0 to 1.0. + Temperature float64 `json:"temperature,omitempty"` + // TopK limits sampling to the K most likely tokens at each step. + TopK int `json:"topK,omitempty"` + // TopP (nucleus sampling) limits sampling to tokens whose cumulative probability exceeds P. + TopP float64 `json:"topP,omitempty"` + // Version specifies a particular version of a model family, + // e.g., "gemini-1.0-pro-001" for the "gemini-1.0-pro" family. + Version string `json:"version,omitempty"` +} + +// GenerationUsage provides information about resource consumption during generation. 
type GenerationUsage struct { - CachedContentTokens int `json:"cachedContentTokens,omitempty"` - Custom map[string]float64 `json:"custom,omitempty"` - InputAudioFiles int `json:"inputAudioFiles,omitempty"` - InputCharacters int `json:"inputCharacters,omitempty"` - InputImages int `json:"inputImages,omitempty"` - InputTokens int `json:"inputTokens,omitempty"` - InputVideos int `json:"inputVideos,omitempty"` - OutputAudioFiles int `json:"outputAudioFiles,omitempty"` - OutputCharacters int `json:"outputCharacters,omitempty"` - OutputImages int `json:"outputImages,omitempty"` - OutputTokens int `json:"outputTokens,omitempty"` - OutputVideos int `json:"outputVideos,omitempty"` - ThoughtsTokens int `json:"thoughtsTokens,omitempty"` - TotalTokens int `json:"totalTokens,omitempty"` -} - + // CachedContentTokens counts tokens that were served from cache. + CachedContentTokens int `json:"cachedContentTokens,omitempty"` + // Custom contains additional usage metrics specific to the model provider. + Custom map[string]float64 `json:"custom,omitempty"` + // InputAudioFiles is the number of audio files in the input. + InputAudioFiles int `json:"inputAudioFiles,omitempty"` + // InputCharacters is the number of characters in the input. + InputCharacters int `json:"inputCharacters,omitempty"` + // InputImages is the number of images in the input. + InputImages int `json:"inputImages,omitempty"` + // InputTokens is the number of tokens in the input prompt. + InputTokens int `json:"inputTokens,omitempty"` + // InputVideos is the number of videos in the input. + InputVideos int `json:"inputVideos,omitempty"` + // OutputAudioFiles is the number of audio files generated in the output. + OutputAudioFiles int `json:"outputAudioFiles,omitempty"` + // OutputCharacters is the number of characters generated in the output. + OutputCharacters int `json:"outputCharacters,omitempty"` + // OutputImages is the number of images generated in the output. 
+ OutputImages int `json:"outputImages,omitempty"` + // OutputTokens is the number of tokens generated in the response. + OutputTokens int `json:"outputTokens,omitempty"` + // OutputVideos is the number of videos generated in the output. + OutputVideos int `json:"outputVideos,omitempty"` + // ThoughtsTokens counts tokens used in reasoning or thinking processes. + ThoughtsTokens int `json:"thoughtsTokens,omitempty"` + // TotalTokens is the sum of input and output tokens. + TotalTokens int `json:"totalTokens,omitempty"` +} + +// Media represents media content with a URL and content type. type Media struct { + // ContentType specifies the MIME type of the media. Inferred from the data URI if not provided. ContentType string `json:"contentType,omitempty"` - Url string `json:"url,omitempty"` + // Url is a "data:" or "https:" URI containing the media content. + Url string `json:"url,omitempty"` } type mediaPart struct { - Media *Media `json:"media,omitempty"` + // Media contains the media content and metadata. + Media *Media `json:"media,omitempty"` + // Metadata contains arbitrary key-value data for this part. Metadata map[string]any `json:"metadata,omitempty"` } -// Message is the contents of a model response. +// Message represents the contents of a model message in a conversation. type Message struct { - Content []*Part `json:"content,omitempty"` + // Content holds the message parts (text, media, tool calls, etc.). + Content []*Part `json:"content,omitempty"` + // Metadata contains arbitrary key-value data associated with this message. Metadata map[string]any `json:"metadata,omitempty"` - Role Role `json:"role,omitempty"` + // Role indicates which entity (system, user, model, or tool) generated this message. + Role Role `json:"role,omitempty"` } +// ModelInfo contains metadata about a model's capabilities and characteristics. type ModelInfo struct { + // ConfigSchema defines the model-specific configuration schema. 
ConfigSchema map[string]any `json:"configSchema,omitempty"` - Label string `json:"label,omitempty"` - Stage ModelStage `json:"stage,omitempty"` - Supports *ModelSupports `json:"supports,omitempty"` - Versions []string `json:"versions,omitempty"` -} - + // Label is a friendly display name for this model (e.g., "Google AI - Gemini Pro"). + Label string `json:"label,omitempty"` + // Stage indicates the development stage of this model. + // Featured models are recommended for general use, stable models are well-tested, + // unstable models are experimental, legacy models are not recommended for new projects, + // and deprecated models may be removed in future versions. + Stage ModelStage `json:"stage,omitempty"` + // Supports describes the capabilities that this model supports. + Supports *ModelSupports `json:"supports,omitempty"` + // Versions lists acceptable names for this model (e.g., different versions). + Versions []string `json:"versions,omitempty"` +} + +// ModelStage indicates the development stage of a model. type ModelStage string const ( @@ -219,19 +234,31 @@ const ( ModelStageDeprecated ModelStage = "deprecated" ) +// ModelSupports describes the capabilities that a model supports. type ModelSupports struct { + // Constrained indicates the level of constrained generation support (none, all, or no-tools). Constrained ConstrainedSupport `json:"constrained,omitempty"` - ContentType []string `json:"contentType,omitempty"` - Context bool `json:"context,omitempty"` - LongRunning bool `json:"longRunning,omitempty"` - Media bool `json:"media,omitempty"` - Multiturn bool `json:"multiturn,omitempty"` - Output []string `json:"output,omitempty"` - SystemRole bool `json:"systemRole,omitempty"` - ToolChoice bool `json:"toolChoice,omitempty"` - Tools bool `json:"tools,omitempty"` -} - + // ContentType lists the content types the model supports for output. 
+ ContentType []string `json:"contentType,omitempty"` + // Context indicates whether the model can natively support document-based context grounding. + Context bool `json:"context,omitempty"` + // LongRunning indicates whether the model supports long-running operations. + LongRunning bool `json:"longRunning,omitempty"` + // Media indicates whether the model can process media as part of the prompt (multimodal input). + Media bool `json:"media,omitempty"` + // Multiturn indicates whether the model can process historical messages passed with a prompt. + Multiturn bool `json:"multiturn,omitempty"` + // Output lists the types of data the model can generate. + Output []string `json:"output,omitempty"` + // SystemRole indicates whether the model can accept messages with role "system". + SystemRole bool `json:"systemRole,omitempty"` + // ToolChoice indicates whether the model supports controlling tool choice (e.g., forced tool calling). + ToolChoice bool `json:"toolChoice,omitempty"` + // Tools indicates whether the model can perform tool calls. + Tools bool `json:"tools,omitempty"` +} + +// ConstrainedSupport indicates the level of constrained generation support. type ConstrainedSupport string const ( @@ -242,118 +269,176 @@ const ( // A ModelRequest is a request to generate completions from a model. type ModelRequest struct { - Config any `json:"config,omitempty"` - Docs []*Document `json:"docs,omitempty"` - Messages []*Message `json:"messages,omitempty"` + // Config holds model-specific configuration parameters. + Config any `json:"config,omitempty"` + // Docs provides retrieved documents to be used as context for this generation. + Docs []*Document `json:"docs,omitempty"` + // Messages contains the conversation history for the model. + Messages []*Message `json:"messages,omitempty"` // Output describes the desired response format. 
- Output *ModelOutputConfig `json:"output,omitempty"` - ToolChoice ToolChoice `json:"toolChoice,omitempty"` + Output *ModelOutputConfig `json:"output,omitempty"` + // ToolChoice controls how the model uses tools (auto, required, or none). + ToolChoice ToolChoice `json:"toolChoice,omitempty"` // Tools lists the available tools that the model can ask the client to run. Tools []*ToolDefinition `json:"tools,omitempty"` } -// A ModelResponse is a model's response to a [ModelRequest]. +// A ModelResponse is a model's response to a ModelRequest. type ModelResponse struct { - Custom any `json:"custom,omitempty"` - FinishMessage string `json:"finishMessage,omitempty"` - FinishReason FinishReason `json:"finishReason,omitempty"` + // Custom contains model-specific extra information. Deprecated: use Raw instead. + Custom any `json:"custom,omitempty"` + // FinishMessage provides additional details about why generation finished. + FinishMessage string `json:"finishMessage,omitempty"` + // FinishReason indicates why generation stopped (e.g., stop, length, blocked). + FinishReason FinishReason `json:"finishReason,omitempty"` // LatencyMs is the time the request took in milliseconds. - LatencyMs float64 `json:"latencyMs,omitempty"` - Message *Message `json:"message,omitempty"` + LatencyMs float64 `json:"latencyMs,omitempty"` + // Message contains the generated response content. + Message *Message `json:"message,omitempty"` + // Operation provides information about a long-running background task if applicable. Operation *Operation `json:"operation,omitempty"` - Raw any `json:"raw,omitempty"` - // Request is the [ModelRequest] struct used to trigger this response. + // Raw contains the unprocessed model-specific response data. + Raw any `json:"raw,omitempty"` + // Request is the ModelRequest struct used to trigger this response. Request *ModelRequest `json:"request,omitempty"` // Usage describes how many resources were used by this generation request. 
Usage *GenerationUsage `json:"usage,omitempty"` formatHandler StreamingFormatHandler } -// A ModelResponseChunk is the portion of the [ModelResponse] +// A ModelResponseChunk is the portion of the ModelResponse // that is passed to a streaming callback. type ModelResponseChunk struct { - Aggregated bool `json:"aggregated,omitempty"` - Content []*Part `json:"content,omitempty"` - Custom any `json:"custom,omitempty"` - Index int `json:"index"` - Role Role `json:"role,omitempty"` + // Aggregated indicates whether the chunk includes all data from previous chunks. + // If false, the chunk is considered incremental. + Aggregated bool `json:"aggregated,omitempty"` + // Content is the chunk of message parts to stream right now. + Content []*Part `json:"content,omitempty"` + // Custom contains model-specific extra information attached to this chunk. + Custom any `json:"custom,omitempty"` + // Index of the message this chunk belongs to. + Index int `json:"index"` + // Role indicates the entity that generated this chunk. + Role Role `json:"role,omitempty"` formatHandler StreamingFormatHandler } +// MultipartToolResponse represents a tool response with both structured output and content parts. type MultipartToolResponse struct { + // Content holds additional message parts providing context or details. Content []*Part `json:"content,omitempty"` - Output any `json:"output,omitempty"` + // Output contains the structured output data from the tool. + Output any `json:"output,omitempty"` } +// Operation represents a long-running background task. type Operation struct { - Action string `json:"action,omitempty"` - Done bool `json:"done,omitempty"` - Error *OperationError `json:"error,omitempty"` - Id string `json:"id,omitempty"` - Metadata map[string]any `json:"metadata,omitempty"` - Output any `json:"output,omitempty"` + // Action is the name of the action being performed by this operation. 
+ Action string `json:"action,omitempty"` + // Done indicates whether the operation has completed. + Done bool `json:"done,omitempty"` + // Error contains error information if the operation failed. + Error *OperationError `json:"error,omitempty"` + // Id is the unique identifier for this operation. + Id string `json:"id,omitempty"` + // Metadata contains additional information about the operation. + Metadata map[string]any `json:"metadata,omitempty"` + // Output contains the result of the operation if it has completed successfully. + Output any `json:"output,omitempty"` } +// OperationError contains error information for a failed operation. type OperationError struct { + // Message describes the error that occurred. Message string `json:"message,omitempty"` } // OutputConfig describes the structure that the model's output -// should conform to. If Format is [OutputFormatJSON], then Schema +// should conform to. If Format is OutputFormatJSON, then Schema // can describe the desired form of the generated JSON. type ModelOutputConfig struct { - Constrained bool `json:"constrained,omitempty"` - ContentType string `json:"contentType,omitempty"` - Format string `json:"format,omitempty"` - Schema map[string]any `json:"schema,omitempty"` + // Constrained indicates whether to enforce strict adherence to the schema. + Constrained bool `json:"constrained,omitempty"` + // ContentType specifies the MIME type of the output content. + ContentType string `json:"contentType,omitempty"` + // Format specifies the desired output format (e.g., "json", "text"). + Format string `json:"format,omitempty"` + // Schema is a JSON Schema describing the desired structure of the output. + Schema map[string]any `json:"schema,omitempty"` } +// PathMetadata contains metadata about a single execution path in a trace. type PathMetadata struct { - Error string `json:"error,omitempty"` + // Error contains error information if the path failed. 
+ Error string `json:"error,omitempty"` + // Latency is the execution time for this path in milliseconds. Latency float64 `json:"latency,omitempty"` - Path string `json:"path,omitempty"` - Status string `json:"status,omitempty"` + // Path is the identifier for this execution path. + Path string `json:"path,omitempty"` + // Status indicates the outcome of this path. + Status string `json:"status,omitempty"` } +// RankedDocumentData represents a document with a relevance score from reranking. type RankedDocumentData struct { - Content []*Part `json:"content,omitempty"` + // Content holds the document's parts (text and media). + Content []*Part `json:"content,omitempty"` + // Metadata contains the reranking score and other arbitrary key-value data. Metadata *RankedDocumentMetadata `json:"metadata,omitempty"` } +// RankedDocumentMetadata contains the relevance score and other metadata for a reranked document. type RankedDocumentMetadata struct { + // Score is the relevance score assigned by the reranker. Score float64 `json:"score,omitempty"` } type reasoningPart struct { - Metadata map[string]any `json:"metadata,omitempty"` - Reasoning string `json:"reasoning,omitempty"` + // Metadata contains arbitrary key-value data for this part. + Metadata map[string]any `json:"metadata,omitempty"` + // Reasoning contains the reasoning text of the message. + Reasoning string `json:"reasoning,omitempty"` } +// RerankerRequest represents a request to rerank documents based on relevance. type RerankerRequest struct { + // Documents is the array of documents to rerank. Documents []*Document `json:"documents,omitempty"` - Options any `json:"options,omitempty"` - Query *Document `json:"query,omitempty"` + // Options contains reranker-specific configuration parameters. + Options any `json:"options,omitempty"` + // Query is the document to use for reranking. + Query *Document `json:"query,omitempty"` } +// RerankerResponse contains the reranked documents with relevance scores. 
type RerankerResponse struct { + // Documents is the array of reranked documents with scores. Documents []*RankedDocumentData `json:"documents,omitempty"` } type resourcePart struct { + // Metadata contains arbitrary key-value data for this part. Metadata map[string]any `json:"metadata,omitempty"` - Resource *ResourcePart `json:"resource,omitempty"` + // Resource contains a reference to an external resource by URI. + Resource *ResourcePart `json:"resource,omitempty"` } type ResourcePart struct { + // Uri is the URI of the external resource. Uri string `json:"uri,omitempty"` } +// RetrieverRequest represents a request to retrieve relevant documents. type RetrieverRequest struct { - Options any `json:"options,omitempty"` - Query *Document `json:"query,omitempty"` + // Options contains retriever-specific configuration parameters. + Options any `json:"options,omitempty"` + // Query is the document to use for retrieval. + Query *Document `json:"query,omitempty"` } +// RetrieverResponse contains the retrieved documents from a retriever request. type RetrieverResponse struct { + // Documents is the array of retrieved documents. Documents []*Document `json:"documents,omitempty"` } @@ -372,63 +457,83 @@ const ( RoleTool Role = "tool" ) +// ScoreDetails provides additional context and explanation for an evaluation score. type ScoreDetails struct { + // Reasoning explains the rationale behind the score. Reasoning string `json:"reasoning,omitempty"` } type textPart struct { + // Metadata contains arbitrary key-value data for this part. Metadata map[string]any `json:"metadata,omitempty"` - Text string `json:"text,omitempty"` + // Text contains the textual content. + Text string `json:"text,omitempty"` } // A ToolDefinition describes a tool. type ToolDefinition struct { + // Description explains what the tool does and when to use it. Description string `json:"description,omitempty"` - // Valid JSON Schema representing the input of the tool. 
+ // InputSchema is a valid JSON Schema representing the input parameters of the tool. InputSchema map[string]any `json:"inputSchema,omitempty"` - // additional metadata for this tool definition + // Metadata contains additional information about this tool definition. Metadata map[string]any `json:"metadata,omitempty"` - Name string `json:"name,omitempty"` - // Valid JSON Schema describing the output of the tool. + // Name is the unique identifier for this tool. + Name string `json:"name,omitempty"` + // OutputSchema is a valid JSON Schema describing the output of the tool. OutputSchema map[string]any `json:"outputSchema,omitempty"` } // A ToolRequest is a message from the model to the client that it should run a -// specific tool and pass a [ToolResponse] to the model on the next chat request it makes. -// Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client. +// specific tool and pass a ToolResponse to the model on the next chat request it makes. +// Any ToolRequest will correspond to some ToolDefinition previously sent by the client. type ToolRequest struct { - // Input is a JSON object describing the input values to the tool. - // An example might be map[string]any{"country":"USA", "president":3}. - Input any `json:"input,omitempty"` - Name string `json:"name,omitempty"` - Partial bool `json:"partial,omitempty"` - Ref string `json:"ref,omitempty"` + // Input is a JSON object containing the input parameters for the tool. + // For example: map[string]any{"country":"USA", "president":3}. + Input any `json:"input,omitempty"` + // Name is the name of the tool to call. + Name string `json:"name,omitempty"` + // Partial indicates whether this is a partial streaming chunk. + Partial bool `json:"partial,omitempty"` + // Ref is the call ID or reference for this specific request. 
+ Ref string `json:"ref,omitempty"` } type toolRequestPart struct { - Metadata map[string]any `json:"metadata,omitempty"` - ToolRequest *ToolRequest `json:"toolRequest,omitempty"` + // Metadata contains arbitrary key-value data for this part. + Metadata map[string]any `json:"metadata,omitempty"` + // ToolRequest is a request for a tool to be executed, usually provided by a model. + ToolRequest *ToolRequest `json:"toolRequest,omitempty"` } // A ToolResponse is a message from the client to the model containing // the results of running a specific tool on the arguments passed to the client -// by the model in a [ToolRequest]. +// by the model in a ToolRequest. type ToolResponse struct { + // Content holds additional message parts that provide context or details about the tool response. Content []*Part `json:"content,omitempty"` - Name string `json:"name,omitempty"` + // Name is the name of the tool that was executed. + Name string `json:"name,omitempty"` // Output is a JSON object describing the results of running the tool. - // An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}. - Output any `json:"output,omitempty"` - Ref string `json:"ref,omitempty"` + // For example: map[string]any{"name":"Thomas Jefferson", "born":1743}. + Output any `json:"output,omitempty"` + // Ref is the call ID or reference matching the original request. + Ref string `json:"ref,omitempty"` } type toolResponsePart struct { - Metadata map[string]any `json:"metadata,omitempty"` - ToolResponse *ToolResponse `json:"toolResponse,omitempty"` + // Metadata contains arbitrary key-value data for this part. + Metadata map[string]any `json:"metadata,omitempty"` + // ToolResponse is a provided response to a tool call. + ToolResponse *ToolResponse `json:"toolResponse,omitempty"` } +// TraceMetadata contains metadata about a trace execution. 
type TraceMetadata struct { - FeatureName string `json:"featureName,omitempty"` - Paths []*PathMetadata `json:"paths,omitempty"` - Timestamp float64 `json:"timestamp,omitempty"` + // FeatureName identifies the feature being traced. + FeatureName string `json:"featureName,omitempty"` + // Paths contains metadata for each path executed during the trace. + Paths []*PathMetadata `json:"paths,omitempty"` + // Timestamp is when the trace was created. + Timestamp float64 `json:"timestamp,omitempty"` } diff --git a/go/ai/generate.go b/go/ai/generate.go index f26cc9f09a..08359c99c7 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -21,6 +21,7 @@ import ( "encoding/json" "errors" "fmt" + "iter" "slices" "strings" @@ -550,7 +551,7 @@ func GenerateText(ctx context.Context, r api.Registry, opts ...GenerateOption) ( return res.Text(), nil } -// Generate run generate request for this model. Returns ModelResponse struct. +// GenerateData runs a generate request and returns strongly-typed output. func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) (*Out, *ModelResponse, error) { var value Out opts = append(opts, WithOutputType(value)) @@ -568,6 +569,108 @@ func GenerateData[Out any](ctx context.Context, r api.Registry, opts ...Generate return &value, resp, nil } +// StreamValue is either a streamed chunk or the final response of a generate request. +type StreamValue[Out, Stream any] struct { + Done bool + Chunk Stream // valid if Done is false + Output Out // valid if Done is true + Response *ModelResponse // valid if Done is true +} + +// ModelStreamValue is a stream value for a model response. +// Out is never set because the output is already available in the Response field. +type ModelStreamValue = StreamValue[struct{}, *ModelResponseChunk] + +// errGenerateStop is a sentinel error used to signal early termination of streaming. 
+var errGenerateStop = errors.New("stop") + +// GenerateStream generates a model response and streams the output. +// It returns an iterator that yields streaming results. +// +// If the yield function is passed a non-nil error, generation has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [ModelStreamValue] argument has Done == true, the value's +// Response field contains the final response; the yield function will not be called +// again. +// +// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk. +func GenerateStream(ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*ModelStreamValue, error] { + return func(yield func(*ModelStreamValue, error) bool) { + cb := func(ctx context.Context, chunk *ModelResponseChunk) error { + if ctx.Err() != nil { + return ctx.Err() + } + if !yield(&ModelStreamValue{Chunk: chunk}, nil) { + return errGenerateStop + } + return nil + } + + allOpts := append(slices.Clone(opts), WithStreaming(cb)) + + resp, err := Generate(ctx, r, allOpts...) + if err != nil { + yield(nil, err) + } else { + yield(&ModelStreamValue{Done: true, Response: resp}, nil) + } + } +} + +// GenerateDataStream generates a model response with streaming and returns strongly-typed output. +// It returns an iterator that yields streaming results. +// +// If the yield function is passed a non-nil error, generation has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [StreamValue] argument has Done == true, the value's +// Output and Response fields contain the final typed output and response; the yield function +// will not be called again. +// +// Otherwise the Chunk field of the passed [StreamValue] holds a streamed chunk. 
+func GenerateDataStream[Out any](ctx context.Context, r api.Registry, opts ...GenerateOption) iter.Seq2[*StreamValue[Out, Out], error] { + return func(yield func(*StreamValue[Out, Out], error) bool) { + cb := func(ctx context.Context, chunk *ModelResponseChunk) error { + if ctx.Err() != nil { + return ctx.Err() + } + var streamValue Out + if err := chunk.Output(&streamValue); err != nil { + yield(nil, err) + return err + } + // Skip yielding if there's no parseable output yet (e.g., incomplete JSON during streaming). + if base.IsNil(streamValue) { + return nil + } + if !yield(&StreamValue[Out, Out]{Chunk: streamValue}, nil) { + return errGenerateStop + } + return nil + } + + // Prepend WithOutputType so the user can override the output format. + var value Out + allOpts := append([]GenerateOption{WithOutputType(value)}, opts...) + allOpts = append(allOpts, WithStreaming(cb)) + + resp, err := Generate(ctx, r, allOpts...) + if err != nil { + yield(nil, err) + return + } + + output, err := extractTypedOutput[Out](resp) + if err != nil { + yield(nil, err) + return + } + + yield(&StreamValue[Out, Out]{Done: true, Output: output, Response: resp}, nil) + } +} + // Generate applies the [Action] to provided request. func (m *model) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { if m == nil { @@ -744,7 +847,7 @@ func handleToolRequests(ctx context.Context, r api.Registry, req *ModelRequest, // [ModelResponse] as a string. It returns an empty string if there // are no candidates or if the candidate has no message. func (mr *ModelResponse) Text() string { - if mr.Message == nil { + if mr == nil || mr.Message == nil { return "" } return mr.Message.Text() @@ -753,7 +856,7 @@ func (mr *ModelResponse) Text() string { // History returns messages from the request combined with the response message // to represent the conversation history. 
func (mr *ModelResponse) History() []*Message { - if mr.Message == nil { + if mr == nil || mr.Message == nil { return mr.Request.Messages } return append(mr.Request.Messages, mr.Message) @@ -762,7 +865,7 @@ func (mr *ModelResponse) History() []*Message { // Reasoning concatenates all reasoning parts present in the message func (mr *ModelResponse) Reasoning() string { var sb strings.Builder - if mr.Message == nil { + if mr == nil || mr.Message == nil { return "" } @@ -806,7 +909,7 @@ func (mr *ModelResponse) Output(v any) error { // ToolRequests returns the tool requests from the response. func (mr *ModelResponse) ToolRequests() []*ToolRequest { toolReqs := []*ToolRequest{} - if mr.Message == nil { + if mr == nil || mr.Message == nil { return toolReqs } for _, part := range mr.Message.Content { @@ -820,7 +923,7 @@ func (mr *ModelResponse) ToolRequests() []*ToolRequest { // Interrupts returns the interrupted tool request parts from the response. func (mr *ModelResponse) Interrupts() []*Part { parts := []*Part{} - if mr.Message == nil { + if mr == nil || mr.Message == nil { return parts } for _, part := range mr.Message.Content { @@ -833,7 +936,7 @@ func (mr *ModelResponse) Interrupts() []*Part { // Media returns the media content of the [ModelResponse] as a string. func (mr *ModelResponse) Media() string { - if mr.Message == nil { + if mr == nil || mr.Message == nil { return "" } for _, part := range mr.Message.Content { @@ -902,17 +1005,41 @@ func (c *ModelResponseChunk) Output(v any) error { // outputer is an interface for types that can unmarshal structured output. type outputer interface { - Output(v any) error + // Text returns the contents of the output as a string. + Text() string + // Output parses the structured output from the response and unmarshals it into value. + Output(value any) error } // OutputFrom is a convenience function that parses structured output from a // [ModelResponse] or [ModelResponseChunk] and returns it as a typed value. 
// This is equivalent to calling Output() but returns the value directly instead // of requiring a pointer argument. If you need to handle the error, use Output() instead. -func OutputFrom[T any](src outputer) T { - var v T - src.Output(&v) - return v +func OutputFrom[Out any](src outputer) Out { + output, err := extractTypedOutput[Out](src) + if err != nil { + return base.Zero[Out]() + } + return output +} + +// extractTypedOutput extracts the typed output from a model response. +// It supports string output by calling Text() and returning the result. +func extractTypedOutput[Out any](o outputer) (Out, error) { + var output Out + + switch any(output).(type) { + case string: + text := o.Text() + // Type assertion to convert string to Out (which we know is string). + result := any(text).(Out) + return result, nil + default: + if err := o.Output(&output); err != nil { + return base.Zero[Out](), fmt.Errorf("failed to parse output: %w", err) + } + return output, nil + } } // Text returns the contents of a [Message] as a string. 
It diff --git a/go/ai/generate_test.go b/go/ai/generate_test.go index cac1f9d508..050a0f3cce 100644 --- a/go/ai/generate_test.go +++ b/go/ai/generate_test.go @@ -18,6 +18,7 @@ package ai import ( "context" + "errors" "fmt" "math" "strings" @@ -1745,3 +1746,542 @@ func TestMultipartTools(t *testing.T) { } }) } + +// streamingTestData holds test output structures +type streamingTestData struct { + Name string `json:"name"` + Value int `json:"value"` +} + +func TestGenerateStream(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + t.Run("yields chunks then final response", func(t *testing.T) { + chunkTexts := []string{"Hello", " ", "World"} + chunkIndex := 0 + + streamModel := DefineModel(r, "test/streamModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + for _, text := range chunkTexts { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart(text)}, + }) + } + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("Hello World"), + }, nil + }) + + var receivedChunks []*ModelResponseChunk + var finalResponse *ModelResponse + + for val, err := range GenerateStream(context.Background(), r, + WithModel(streamModel), + WithPrompt("test streaming"), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalResponse = val.Response + } else { + receivedChunks = append(receivedChunks, val.Chunk) + chunkIndex++ + } + } + + if len(receivedChunks) != len(chunkTexts) { + t.Errorf("expected %d chunks, got %d", len(chunkTexts), len(receivedChunks)) + } + + for i, chunk := range receivedChunks { + if chunk.Text() != chunkTexts[i] { + t.Errorf("chunk %d: expected %q, got %q", i, chunkTexts[i], chunk.Text()) + } + } + + if finalResponse == nil { + t.Fatal("expected final response") + } + if finalResponse.Text() != "Hello World" { + 
t.Errorf("expected final text %q, got %q", "Hello World", finalResponse.Text()) + } + }) + + t.Run("handles no streaming callback gracefully", func(t *testing.T) { + noStreamModel := DefineModel(r, "test/noStreamModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("response without streaming"), + }, nil + }) + + var finalResponse *ModelResponse + chunkCount := 0 + + for val, err := range GenerateStream(context.Background(), r, + WithModel(noStreamModel), + WithPrompt("test no stream"), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalResponse = val.Response + } else { + chunkCount++ + } + } + + if chunkCount != 0 { + t.Errorf("expected 0 chunks when model doesn't stream, got %d", chunkCount) + } + if finalResponse == nil { + t.Fatal("expected final response") + } + if finalResponse.Text() != "response without streaming" { + t.Errorf("expected text %q, got %q", "response without streaming", finalResponse.Text()) + } + }) + + t.Run("propagates generation errors", func(t *testing.T) { + expectedErr := errors.New("generation failed") + + errorModel := DefineModel(r, "test/errorModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return nil, expectedErr + }) + + var receivedErr error + for _, err := range GenerateStream(context.Background(), r, + WithModel(errorModel), + WithPrompt("test error"), + ) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Fatal("expected error to be propagated") + } + if !errors.Is(receivedErr, expectedErr) { + t.Errorf("expected error %v, got %v", expectedErr, receivedErr) + } + }) + + t.Run("context cancellation stops iteration", func(t *testing.T) { + ctx, cancel 
:= context.WithCancel(context.Background()) + defer cancel() + + streamModel := DefineModel(r, "test/cancelModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + for i := 0; i < 100; i++ { + err := cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("chunk")}, + }) + if err != nil { + return nil, err + } + } + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("done"), + }, nil + }) + + chunksReceived := 0 + var receivedErr error + for val, err := range GenerateStream(ctx, r, + WithModel(streamModel), + WithPrompt("test cancel"), + ) { + if err != nil { + receivedErr = err + break + } + if !val.Done { + chunksReceived++ + if chunksReceived == 2 { + cancel() + } + } + } + + if chunksReceived < 2 { + t.Errorf("expected at least 2 chunks before cancellation, got %d", chunksReceived) + } + if receivedErr == nil { + t.Error("expected error from cancelled context") + } + }) +} + +func TestGenerateDataStream(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + t.Run("yields typed chunks and final output", func(t *testing.T) { + streamModel := DefineModel(r, "test/typedStreamModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"name":"partial","value":1}`)}, + }) + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"name":"complete","value":42}`)}, + }) + } + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"name":"final","value":42}`)}, + }, + }, nil + }) + + var chunks []streamingTestData + var finalOutput streamingTestData + var finalResponse 
*ModelResponse + + for val, err := range GenerateDataStream[streamingTestData](context.Background(), r, + WithModel(streamModel), + WithPrompt("test typed streaming"), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalOutput = val.Output + finalResponse = val.Response + } else { + chunks = append(chunks, val.Chunk) + } + } + + if len(chunks) < 1 { + t.Errorf("expected at least 1 chunk, got %d", len(chunks)) + } + + if finalOutput.Name != "final" || finalOutput.Value != 42 { + t.Errorf("expected final output {final, 42}, got %+v", finalOutput) + } + if finalResponse == nil { + t.Fatal("expected final response") + } + }) + + t.Run("final output is correctly typed", func(t *testing.T) { + streamModel := DefineModel(r, "test/finalTypedModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"name":"result","value":123}`)}, + }, + }, nil + }) + + var finalOutput streamingTestData + var gotFinal bool + + for val, err := range GenerateDataStream[streamingTestData](context.Background(), r, + WithModel(streamModel), + WithPrompt("test final typed"), + ) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalOutput = val.Output + gotFinal = true + } + } + + if !gotFinal { + t.Fatal("expected to receive final output") + } + if finalOutput.Name != "result" || finalOutput.Value != 123 { + t.Errorf("expected final output {result, 123}, got %+v", finalOutput) + } + }) + + t.Run("automatically sets output type", func(t *testing.T) { + var capturedRequest *ModelRequest + + streamModel := DefineModel(r, "test/autoOutputModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx 
context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedRequest = req + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"name":"test","value":1}`)}, + }, + }, nil + }) + + for range GenerateDataStream[streamingTestData](context.Background(), r, + WithModel(streamModel), + WithPrompt("test auto output type"), + ) { + } + + if capturedRequest == nil { + t.Fatal("expected request to be captured") + } + if capturedRequest.Output == nil || capturedRequest.Output.Schema == nil { + t.Error("expected output schema to be set automatically") + } + }) + + t.Run("propagates chunk parsing errors", func(t *testing.T) { + streamModel := DefineModel(r, "test/parseErrorModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("not valid json")}, + }) + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("done"), + }, nil + }) + + var receivedErr error + for _, err := range GenerateDataStream[streamingTestData](context.Background(), r, + WithModel(streamModel), + WithPrompt("test parse error"), + ) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Error("expected parsing error to be propagated") + } + }) +} + +func TestGenerateText(t *testing.T) { + r := newTestRegistry(t) + + echoModel := DefineModel(r, "test/echoTextModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("echo: " + req.Messages[0].Content[0].Text), + }, nil + }) + + t.Run("returns text from model", func(t *testing.T) { + text, err := GenerateText(context.Background(), r, + WithModel(echoModel), + 
WithPrompt("hello"), + ) + + if err != nil { + t.Fatalf("GenerateText error: %v", err) + } + if text != "echo: hello" { + t.Errorf("text = %q, want %q", text, "echo: hello") + } + }) +} + +func TestGenerateData(t *testing.T) { + r := newTestRegistry(t) + + type TestOutput struct { + Value int `json:"value"` + } + + jsonModel := DefineModel(r, "test/jsonDataModel", &ModelOptions{ + Supports: &ModelSupports{ + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage(`{"value": 42}`), + }, nil + }) + + t.Run("returns typed data from model", func(t *testing.T) { + output, _, err := GenerateData[TestOutput](context.Background(), r, + WithModel(jsonModel), + WithPrompt("get value"), + ) + + if err != nil { + t.Fatalf("GenerateData error: %v", err) + } + if output.Value != 42 { + t.Errorf("output.Value = %d, want 42", output.Value) + } + }) +} + +func TestModelResponseReasoning(t *testing.T) { + t.Run("returns reasoning from response", func(t *testing.T) { + resp := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewReasoningPart("thinking about this...", nil), + NewTextPart("final answer"), + }, + }, + } + + reasoning := resp.Reasoning() + + if reasoning != "thinking about this..." 
{ + t.Errorf("Reasoning() = %q, want %q", reasoning, "thinking about this...") + } + }) + + t.Run("returns empty string when no reasoning", func(t *testing.T) { + resp := &ModelResponse{ + Message: NewModelTextMessage("just text"), + } + + reasoning := resp.Reasoning() + + if reasoning != "" { + t.Errorf("Reasoning() = %q, want empty string", reasoning) + } + }) +} + +func TestModelResponseInterrupts(t *testing.T) { + t.Run("returns interrupt tool requests", func(t *testing.T) { + interruptPart := NewToolRequestPart(&ToolRequest{ + Name: "confirmAction", + Input: map[string]any{}, + }) + interruptPart.Metadata = map[string]any{"interrupt": true} + + resp := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart("Please confirm"), + interruptPart, + }, + }, + } + + interrupts := resp.Interrupts() + + if len(interrupts) != 1 { + t.Fatalf("len(Interrupts()) = %d, want 1", len(interrupts)) + } + if interrupts[0].ToolRequest.Name != "confirmAction" { + t.Errorf("interrupt name = %q, want %q", interrupts[0].ToolRequest.Name, "confirmAction") + } + }) + + t.Run("returns empty slice when no interrupts", func(t *testing.T) { + resp := &ModelResponse{ + Message: NewModelTextMessage("no interrupts here"), + } + + interrupts := resp.Interrupts() + + if len(interrupts) != 0 { + t.Errorf("len(Interrupts()) = %d, want 0", len(interrupts)) + } + }) +} + +func TestModelResponseMedia(t *testing.T) { + t.Run("returns media URL from response", func(t *testing.T) { + resp := &ModelResponse{ + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart("Here's an image"), + NewMediaPart("image/png", "data:image/png;base64,abc123"), + }, + }, + } + + media := resp.Media() + + if media == "" { + t.Error("Media() returned empty string") + } + if media != "data:image/png;base64,abc123" { + t.Errorf("Media() = %q, want %q", media, "data:image/png;base64,abc123") + } + }) + + t.Run("returns empty string when no media", func(t *testing.T) { + resp 
:= &ModelResponse{ + Message: NewModelTextMessage("just text"), + } + + media := resp.Media() + + if media != "" { + t.Errorf("Media() = %q, want empty string", media) + } + }) +} + +func TestOutputFrom(t *testing.T) { + type TestData struct { + Name string `json:"name"` + Count int `json:"count"` + } + + t.Run("extracts typed output from response", func(t *testing.T) { + resp := &ModelResponse{ + Message: NewModelTextMessage(`{"name": "test", "count": 5}`), + } + + output := OutputFrom[TestData](resp) + + if output.Name != "test" { + t.Errorf("output.Name = %q, want %q", output.Name, "test") + } + if output.Count != 5 { + t.Errorf("output.Count = %d, want 5", output.Count) + } + }) +} diff --git a/go/ai/option_test.go b/go/ai/option_test.go index 6fd2430842..04fee69a59 100644 --- a/go/ai/option_test.go +++ b/go/ai/option_test.go @@ -653,3 +653,129 @@ func (t *mockTool) Definition() *ToolDefinition { func (t *mockTool) RunRaw(ctx context.Context, input any) (any, error) { return nil, nil } + +func (t *mockTool) RunRawMultipart(ctx context.Context, input any) (*MultipartToolResponse, error) { + return nil, nil +} + +func (t *mockTool) Respond(toolReq *Part, outputData any, opts *RespondOptions) *Part { + return nil +} + +func (t *mockTool) Restart(toolReq *Part, opts *RestartOptions) *Part { + return nil +} + +func (t *mockTool) Register(r interface{ RegisterValue(string, any) }) { +} + +func TestWithInputSchemaName(t *testing.T) { + t.Run("creates input option with schema reference", func(t *testing.T) { + opt := WithInputSchemaName("MyInputType") + opts := &promptOptions{} + + if err := opt.applyPrompt(opts); err != nil { + t.Fatalf("applyPrompt() error: %v", err) + } + + if opts.InputSchema == nil { + t.Fatal("InputSchema is nil") + } + + ref, ok := opts.InputSchema["$ref"].(string) + if !ok { + t.Fatal("InputSchema.$ref is not a string") + } + if ref != "genkit:MyInputType" { + t.Errorf("InputSchema.$ref = %q, want %q", ref, "genkit:MyInputType") + } + }) +} + 
+func TestWithOutputSchema(t *testing.T) { + t.Run("creates output option with direct schema", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + } + opt := WithOutputSchema(schema) + opts := &generateOptions{} + + if err := opt.applyGenerate(opts); err != nil { + t.Fatalf("applyGenerate() error: %v", err) + } + + if opts.OutputSchema == nil { + t.Fatal("OutputSchema is nil") + } + if opts.OutputFormat != OutputFormatJSON { + t.Errorf("OutputFormat = %q, want %q", opts.OutputFormat, OutputFormatJSON) + } + }) +} + +func TestWithOutputEnums(t *testing.T) { + t.Run("creates enum output with string values", func(t *testing.T) { + opt := WithOutputEnums("red", "green", "blue") + opts := &generateOptions{} + + if err := opt.applyGenerate(opts); err != nil { + t.Fatalf("applyGenerate() error: %v", err) + } + + if opts.OutputSchema == nil { + t.Fatal("OutputSchema is nil") + } + if opts.OutputFormat != OutputFormatEnum { + t.Errorf("OutputFormat = %q, want %q", opts.OutputFormat, OutputFormatEnum) + } + + enumType, ok := opts.OutputSchema["type"].(string) + if !ok || enumType != "string" { + t.Errorf("OutputSchema.type = %v, want %q", opts.OutputSchema["type"], "string") + } + + enumVals, ok := opts.OutputSchema["enum"].([]string) + if !ok { + t.Fatalf("OutputSchema.enum is not []string: %T", opts.OutputSchema["enum"]) + } + if len(enumVals) != 3 { + t.Errorf("len(enum) = %d, want 3", len(enumVals)) + } + }) + + t.Run("works with custom string type", func(t *testing.T) { + type Color string + opt := WithOutputEnums(Color("red"), Color("green")) + opts := &generateOptions{} + + if err := opt.applyGenerate(opts); err != nil { + t.Fatalf("applyGenerate() error: %v", err) + } + + enumVals := opts.OutputSchema["enum"].([]string) + if enumVals[0] != "red" || enumVals[1] != "green" { + t.Errorf("enum values = %v, want [red, green]", enumVals) + } + }) +} + +func 
TestWithEvaluatorName(t *testing.T) { + t.Run("creates evaluator option with reference", func(t *testing.T) { + opt := WithEvaluatorName("test/myEvaluator") + opts := &evaluatorOptions{} + + if err := opt.applyEvaluator(opts); err != nil { + t.Fatalf("applyEvaluator() error: %v", err) + } + + if opts.Evaluator == nil { + t.Fatal("Evaluator is nil") + } + if opts.Evaluator.Name() != "test/myEvaluator" { + t.Errorf("Evaluator.Name() = %q, want %q", opts.Evaluator.Name(), "test/myEvaluator") + } + }) +} diff --git a/go/ai/prompt.go b/go/ai/prompt.go index db4ec264cd..ac4ef82e5f 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -19,11 +19,14 @@ import ( "encoding/json" "errors" "fmt" + "io/fs" + "iter" "log/slog" "maps" "os" - "path/filepath" + "path" "reflect" + "slices" "strings" "github.com/firebase/genkit/go/core" @@ -40,6 +43,8 @@ type Prompt interface { Name() string // Execute executes the prompt with the given options and returns a [ModelResponse]. Execute(ctx context.Context, opts ...PromptExecuteOption) (*ModelResponse, error) + // ExecuteStream executes the prompt with streaming and returns an iterator. + ExecuteStream(ctx context.Context, opts ...PromptExecuteOption) iter.Seq2[*ModelStreamValue, error] // Render renders the prompt with the given input and returns a [GenerateActionOptions] to be used with [GenerateWithRequest]. Render(ctx context.Context, input any) (*GenerateActionOptions, error) } @@ -51,6 +56,13 @@ type prompt struct { registry api.Registry } +// DataPrompt is a prompt with strongly-typed input and output. +// It wraps an underlying [Prompt] and provides type-safe Execute and Render methods. +// The Out type parameter can be string for text outputs or any struct type for JSON outputs. +type DataPrompt[In, Out any] struct { + prompt +} + // DefinePrompt creates a new [Prompt] and registers it. 
func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { if name == "" { @@ -89,10 +101,7 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { } metadata["type"] = api.ActionTypeExecutablePrompt - baseName := name - if idx := strings.LastIndex(name, "."); idx != -1 { - baseName = name[:idx] - } + baseName, variant, _ := strings.Cut(name, ".") promptMetadata := map[string]any{ "name": baseName, @@ -105,6 +114,9 @@ func DefinePrompt(r api.Registry, name string, opts ...PromptOption) Prompt { "tools": tools, "maxTurns": p.MaxTurns, } + if variant != "" { + promptMetadata["variant"] = variant + } if m, ok := metadata["prompt"].(map[string]any); ok { maps.Copy(m, promptMetadata) } else { @@ -133,7 +145,7 @@ func LookupPrompt(r api.Registry, name string) Prompt { // passes the rendered template to the AI model specified by the prompt. func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*ModelResponse, error) { if p == nil { - return nil, errors.New("Prompt.Execute: execute called on a nil Prompt; check that all prompts are defined") + return nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.Execute: prompt is nil") } execOpts := &promptExecutionOptions{} @@ -239,10 +251,50 @@ func (p *prompt) Execute(ctx context.Context, opts ...PromptExecuteOption) (*Mod return GenerateWithRequest(ctx, r, actionOpts, execOpts.Middleware, execOpts.Stream) } +// ExecuteStream executes the prompt with streaming and returns an iterator. +// +// If the yield function is passed a non-nil error, execution has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [ModelStreamValue] argument has Done == true, the value's +// Response field contains the final response; the yield function will not be called again. +// +// Otherwise the Chunk field of the passed [ModelStreamValue] holds a streamed chunk. 
+func (p *prompt) ExecuteStream(ctx context.Context, opts ...PromptExecuteOption) iter.Seq2[*ModelStreamValue, error] { + return func(yield func(*ModelStreamValue, error) bool) { + if p == nil { + yield(nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.ExecuteStream: prompt is nil")) + return + } + + cb := func(ctx context.Context, chunk *ModelResponseChunk) error { + if ctx.Err() != nil { + return ctx.Err() + } + if !yield(&ModelStreamValue{Chunk: chunk}, nil) { + return errPromptStop + } + return nil + } + + allOpts := append(slices.Clone(opts), WithStreaming(cb)) + resp, err := p.Execute(ctx, allOpts...) + if err != nil { + yield(nil, err) + return + } + + yield(&ModelStreamValue{Done: true, Response: resp}, nil) + } +} + +// errPromptStop is a sentinel error used to signal early termination of streaming. +var errPromptStop = errors.New("stop") + // Render renders the prompt template based on user input. func (p *prompt) Render(ctx context.Context, input any) (*GenerateActionOptions, error) { if p == nil { - return nil, errors.New("Prompt.Render: called on a nil prompt; check that all prompts are defined") + return nil, core.NewError(core.INVALID_ARGUMENT, "Prompt.Render: prompt is nil") } if len(p.Middleware) > 0 { @@ -414,13 +466,16 @@ func renderSystemPrompt(ctx context.Context, opts promptOptions, messages []*Mes return nil, err } - parts, err := renderPrompt(ctx, opts, templateText, input, dp) + renderedMessages, err := renderPrompt(ctx, opts, templateText, input, dp) if err != nil { return nil, err } - if len(parts) != 0 { - messages = append(messages, NewSystemMessage(parts...)) + for _, m := range renderedMessages { + if m.Role == "" || (len(renderedMessages) == 1 && m.Role == RoleUser) { + m.Role = RoleSystem + } + messages = append(messages, m) } return messages, nil @@ -437,13 +492,16 @@ func renderUserPrompt(ctx context.Context, opts promptOptions, messages []*Messa return nil, err } - parts, err := renderPrompt(ctx, opts, templateText, input, dp) + 
renderedMessages, err := renderPrompt(ctx, opts, templateText, input, dp) if err != nil { return nil, err } - if len(parts) != 0 { - messages = append(messages, NewUserMessage(parts...)) + for _, m := range renderedMessages { + if m.Role == "" || (len(renderedMessages) == 1 && m.Role != RoleUser) { + m.Role = RoleUser + } + messages = append(messages, m) } return messages, nil @@ -463,47 +521,72 @@ func renderMessages(ctx context.Context, opts promptOptions, messages []*Message // Create new message copies to avoid mutating shared messages during concurrent execution renderedMsgs := make([]*Message, 0, len(msgs)) for _, msg := range msgs { - msgParts := []*Part{} + hasTextPart := slices.ContainsFunc(msg.Content, (*Part).IsText) + + if !hasTextPart { + // Create a new message with non-text content instead of mutating the original + renderedMsg := &Message{ + Role: msg.Role, + Content: msg.Content, + Metadata: msg.Metadata, + } + renderedMsgs = append(renderedMsgs, renderedMsg) + continue + } + for _, part := range msg.Content { if part.IsText() { - parts, err := renderPrompt(ctx, opts, part.Text, input, dp) + messagesFromText, err := renderPrompt(ctx, opts, part.Text, input, dp) if err != nil { return nil, err } - msgParts = append(msgParts, parts...) + for _, m := range messagesFromText { + // If the rendered message has no role, or it is a single message with default role, + // use the original message's role. 
+ role := m.Role + if role == "" || (len(messagesFromText) == 1 && role == RoleUser) { + role = msg.Role + } + renderedMsgs = append(renderedMsgs, &Message{ + Role: role, + Content: m.Content, + Metadata: msg.Metadata, + }) + } } else { - // Preserve non-text parts as-is - msgParts = append(msgParts, part) + // Preserve non-text parts as-is in the current last message if possible, or create a new one + if len(renderedMsgs) > 0 && renderedMsgs[len(renderedMsgs)-1].Role == msg.Role { + renderedMsgs[len(renderedMsgs)-1].Content = append(renderedMsgs[len(renderedMsgs)-1].Content, part) + } else { + renderedMsgs = append(renderedMsgs, &Message{ + Role: msg.Role, + Content: []*Part{part}, + Metadata: msg.Metadata, + }) + } } } - // Create a new message with rendered content instead of mutating the original - renderedMsg := &Message{ - Role: msg.Role, - Content: msgParts, - Metadata: msg.Metadata, - } - renderedMsgs = append(renderedMsgs, renderedMsg) } return append(messages, renderedMsgs...), nil } // renderPrompt renders a prompt template using dotprompt functionalities -func renderPrompt(ctx context.Context, opts promptOptions, templateText string, input map[string]any, dp *dotprompt.Dotprompt) ([]*Part, error) { +func renderPrompt(ctx context.Context, opts promptOptions, templateText string, input map[string]any, dp *dotprompt.Dotprompt) ([]*Message, error) { renderedFunc, err := dp.Compile(templateText, &dotprompt.PromptMetadata{}) if err != nil { return nil, err } - return renderDotpromptToParts(ctx, renderedFunc, input, &dotprompt.PromptMetadata{ + return renderDotpromptToMessages(ctx, renderedFunc, input, &dotprompt.PromptMetadata{ Input: dotprompt.PromptMetadataInput{ Default: opts.DefaultInput, }, }) } -// renderDotpromptToParts executes a dotprompt prompt function and converts the result to a slice of parts -func renderDotpromptToParts(ctx context.Context, promptFn dotprompt.PromptFunction, input map[string]any, additionalMetadata *dotprompt.PromptMetadata) 
([]*Part, error) { +// renderDotpromptToMessages executes a dotprompt prompt function and converts the result to a slice of messages +func renderDotpromptToMessages(ctx context.Context, promptFn dotprompt.PromptFunction, input map[string]any, additionalMetadata *dotprompt.PromptMetadata) ([]*Message, error) { // Prepare the context for rendering context := map[string]any{} actionCtx := core.FromContext(ctx) @@ -518,16 +601,20 @@ func renderDotpromptToParts(ctx context.Context, promptFn dotprompt.PromptFuncti return nil, fmt.Errorf("failed to render prompt: %w", err) } - convertedParts := []*Part{} + convertedMessages := []*Message{} for _, message := range rendered.Messages { parts, err := convertToPartPointers(message.Content) if err != nil { return nil, fmt.Errorf("failed to convert parts: %w", err) } - convertedParts = append(convertedParts, parts...) + role := Role(message.Role) + convertedMessages = append(convertedMessages, &Message{ + Role: role, + Content: parts, + }) } - return convertedParts, nil + return convertedMessages, nil } // convertToPartPointers converts []dotprompt.Part to []*Part @@ -550,87 +637,84 @@ func convertToPartPointers(parts []dotprompt.Part) ([]*Part, error) { return result, nil } -// LoadPromptDir loads prompts and partials from the input directory for the given namespace. -func LoadPromptDir(r api.Registry, dir string, namespace string) { - useDefaultDir := false - if dir == "" { - dir = "./prompts" - useDefaultDir = true +// LoadPromptDirFromFS loads prompts and partials from a filesystem for the given namespace. +// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS). +// The dir parameter specifies the directory within the filesystem where prompts are located. 
+func LoadPromptDirFromFS(r api.Registry, fsys fs.FS, dir, namespace string) { + if fsys == nil { + panic(errors.New("no prompt filesystem provided")) } - path, err := filepath.Abs(dir) - if err != nil { - if !useDefaultDir { - panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err)) - } - slog.Debug("default prompt directory not found, skipping loading .prompt files", "dir", dir) - return + if _, err := fs.Stat(fsys, dir); err != nil { + panic(fmt.Errorf("failed to access prompt directory %q in filesystem: %w", dir, err)) } - if _, err := os.Stat(path); os.IsNotExist(err) { - if !useDefaultDir { - panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err)) - } - slog.Debug("Default prompt directory not found, skipping loading .prompt files", "dir", dir) - return - } - - loadPromptDir(r, path, namespace) -} - -// loadPromptDir recursively loads prompts and partials from the directory. -func loadPromptDir(r api.Registry, dir string, namespace string) { - entries, err := os.ReadDir(dir) + entries, err := fs.ReadDir(fsys, dir) if err != nil { panic(fmt.Errorf("failed to read prompt directory structure: %w", err)) } for _, entry := range entries { filename := entry.Name() - path := filepath.Join(dir, filename) + filePath := path.Join(dir, filename) if entry.IsDir() { - loadPromptDir(r, path, namespace) + LoadPromptDirFromFS(r, fsys, filePath, namespace) } else if strings.HasSuffix(filename, ".prompt") { if strings.HasPrefix(filename, "_") { partialName := strings.TrimSuffix(filename[1:], ".prompt") - source, err := os.ReadFile(path) + source, err := fs.ReadFile(fsys, filePath) if err != nil { slog.Error("Failed to read partial file", "error", err) continue } r.RegisterPartial(partialName, string(source)) - slog.Debug("Registered Dotprompt partial", "name", partialName, "file", path) + slog.Debug("Registered Dotprompt partial", "name", partialName, "file", filePath) } else { - LoadPrompt(r, dir, filename, namespace) + LoadPromptFromFS(r, 
fsys, dir, filename, namespace) } } } } -// LoadPrompt loads a single prompt into the registry. -func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { +// LoadPromptFromFS loads a single prompt from a filesystem into the registry. +// The fsys parameter should be an fs.FS implementation (e.g., embed.FS or os.DirFS). +// The dir parameter specifies the directory within the filesystem where the prompt is located. +func LoadPromptFromFS(r api.Registry, fsys fs.FS, dir, filename, namespace string) Prompt { name := strings.TrimSuffix(filename, ".prompt") - name, variant, _ := strings.Cut(name, ".") - sourceFile := filepath.Join(dir, filename) - source, err := os.ReadFile(sourceFile) + sourceFile := path.Join(dir, filename) + source, err := fs.ReadFile(fsys, sourceFile) if err != nil { slog.Error("Failed to read prompt file", "file", sourceFile, "error", err) return nil } + p, err := LoadPromptFromSource(r, string(source), name, namespace) + if err != nil { + slog.Error("Failed to load prompt", "file", sourceFile, "error", err) + return nil + } + + slog.Debug("Registered Dotprompt", "name", p.Name(), "file", sourceFile) + return p +} + +// LoadPromptFromSource loads a prompt from raw .prompt file content. +// The source parameter should contain the complete .prompt file text (frontmatter + template). +// The name parameter is the prompt name (may include variant suffix like "myPrompt.variant"). 
+func LoadPromptFromSource(r api.Registry, source, name, namespace string) (Prompt, error) { + name, variant, _ := strings.Cut(name, ".") + dp := r.Dotprompt() - parsedPrompt, err := dp.Parse(string(source)) + parsedPrompt, err := dp.Parse(source) if err != nil { - slog.Error("Failed to parse file as dotprompt", "file", sourceFile, "error", err) - return nil + return nil, fmt.Errorf("failed to parse dotprompt: %w", err) } - metadata, err := dp.RenderMetadata(string(source), &parsedPrompt.PromptMetadata) + metadata, err := dp.RenderMetadata(source, &parsedPrompt.PromptMetadata) if err != nil { - slog.Error("Failed to render dotprompt metadata", "file", sourceFile, "error", err) - return nil + return nil, fmt.Errorf("failed to render dotprompt metadata: %w", err) } toolRefs := make([]ToolRef, len(metadata.Tools)) @@ -692,7 +776,11 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { } if inputSchema, ok := metadata.Input.Schema.(map[string]any); ok { - opts.InputSchema = inputSchema + if ref, ok := inputSchema["$ref"].(string); ok { + opts.InputSchema = core.SchemaRef(ref) + } else { + opts.InputSchema = inputSchema + } } if metadata.Output.Format != "" { @@ -710,57 +798,32 @@ func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { } } - key := promptKey(name, variant, namespace) - - dpMessages, err := dotprompt.ToMessages(parsedPrompt.Template, &dotprompt.DataArgument{}) - if err != nil { - slog.Error("Failed to convert prompt template to messages", "file", sourceFile, "error", err) - return nil - } - - var systemText string - var nonSystemMessages []*Message - for _, dpMsg := range dpMessages { - parts, err := convertToPartPointers(dpMsg.Content) - if err != nil { - slog.Error("Failed to convert message parts", "file", sourceFile, "error", err) - return nil - } - - role := Role(dpMsg.Role) - if role == RoleSystem { - var textParts []string - for _, part := range parts { - if part.IsText() { - textParts = append(textParts, 
part.Text) - } - } - - if len(textParts) > 0 { - systemText = strings.Join(textParts, " ") - } + if outputSchema, ok := metadata.Output.Schema.(map[string]any); ok { + if ref, ok := outputSchema["$ref"].(string); ok { + opts.OutputSchema = core.SchemaRef(ref) } else { - nonSystemMessages = append(nonSystemMessages, &Message{Role: role, Content: parts}) + opts.OutputSchema = outputSchema + } + if opts.OutputFormat == "" { + opts.OutputFormat = OutputFormatJSON } } - promptOpts := []PromptOption{opts} - - if systemText != "" { - promptOpts = append(promptOpts, WithSystem(systemText)) - } + key := promptKey(name, variant, namespace) - if len(nonSystemMessages) > 0 { - promptOpts = append(promptOpts, WithMessages(nonSystemMessages...)) - } else if systemText == "" { - promptOpts = append(promptOpts, WithPrompt(parsedPrompt.Template)) - } + prompt := DefinePrompt(r, key, opts, WithPrompt(parsedPrompt.Template)) - prompt := DefinePrompt(r, key, promptOpts...) + return prompt, nil +} - slog.Debug("Registered Dotprompt", "name", key, "file", sourceFile) +// LoadPromptDir loads prompts and partials from a directory on the local filesystem. +func LoadPromptDir(r api.Registry, dir string, namespace string) { + LoadPromptDirFromFS(r, os.DirFS(dir), ".", namespace) +} - return prompt +// LoadPrompt loads a single prompt from a directory on the local filesystem into the registry. +func LoadPrompt(r api.Registry, dir, filename, namespace string) Prompt { + return LoadPromptFromFS(r, os.DirFS(dir), ".", filename, namespace) } // promptKey generates a unique key for the prompt in the registry. @@ -807,3 +870,133 @@ func contentType(ct, uri string) (string, []byte, error) { return "", nil, errors.New("uri content type not found") } + +// DefineDataPrompt creates a new data prompt and registers it. +// It automatically infers input schema from the In type parameter and configures +// output schema and JSON format from the Out type parameter (unless Out is string). 
+func DefineDataPrompt[In, Out any](r api.Registry, name string, opts ...PromptOption) *DataPrompt[In, Out] { + if name == "" { + panic("ai.DefineDataPrompt: name is required") + } + + var in In + allOpts := []PromptOption{WithInputType(in)} + + var out Out + switch any(out).(type) { + case string: + // String output - no schema needed + default: + // Prepend WithOutputType so the user can override the output format. + allOpts = append(allOpts, WithOutputType(out)) + } + + allOpts = append(allOpts, opts...) + p := DefinePrompt(r, name, allOpts...) + + return &DataPrompt[In, Out]{prompt: *p.(*prompt)} +} + +// LookupDataPrompt looks up a prompt by name and wraps it with type information. +// This is useful for wrapping prompts loaded from .prompt files with strong types. +// It returns nil if the prompt was not found. +func LookupDataPrompt[In, Out any](r api.Registry, name string) *DataPrompt[In, Out] { + return AsDataPrompt[In, Out](LookupPrompt(r, name)) +} + +// AsDataPrompt wraps an existing Prompt with type information, returning a DataPrompt. +// This is useful for adding strong typing to a dynamically obtained prompt. +func AsDataPrompt[In, Out any](p Prompt) *DataPrompt[In, Out] { + if p == nil { + return nil + } + + return &DataPrompt[In, Out]{prompt: *p.(*prompt)} +} + +// Execute executes the typed prompt and returns the strongly-typed output along with the full model response. +// For structured output types (non-string Out), the prompt must be configured with the appropriate +// output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt. +func (dp *DataPrompt[In, Out]) Execute(ctx context.Context, input In, opts ...PromptExecuteOption) (Out, *ModelResponse, error) { + if dp == nil { + return base.Zero[Out](), nil, core.NewError(core.INVALID_ARGUMENT, "DataPrompt.Execute: prompt is nil") + } + + allOpts := append(slices.Clone(opts), WithInput(input)) + resp, err := dp.prompt.Execute(ctx, allOpts...) 
+ if err != nil { + return base.Zero[Out](), nil, err + } + + output, err := extractTypedOutput[Out](resp) + if err != nil { + return base.Zero[Out](), resp, err + } + + return output, resp, nil +} + +// ExecuteStream executes the typed prompt with streaming and returns an iterator. +// +// If the yield function is passed a non-nil error, execution has failed with that +// error; the yield function will not be called again. +// +// If the yield function's StreamValue argument has Done == true, the value's +// Output and Response fields contain the final typed output and response; the yield function +// will not be called again. +// +// Otherwise the Chunk field of the passed StreamValue holds a streamed chunk. +// +// For structured output types (non-string Out), the prompt must be configured with the appropriate +// output schema, either through [DefineDataPrompt] or by using [WithOutputType] when defining the prompt. +func (dp *DataPrompt[In, Out]) ExecuteStream(ctx context.Context, input In, opts ...PromptExecuteOption) iter.Seq2[*StreamValue[Out, Out], error] { + return func(yield func(*StreamValue[Out, Out], error) bool) { + if dp == nil { + yield(nil, core.NewError(core.INVALID_ARGUMENT, "DataPrompt.ExecuteStream: prompt is nil")) + return + } + + cb := func(ctx context.Context, chunk *ModelResponseChunk) error { + if ctx.Err() != nil { + return ctx.Err() + } + streamValue, err := extractTypedOutput[Out](chunk) + if err != nil { + yield(nil, err) + return err + } + // Skip yielding if there's no parseable output yet (e.g., incomplete JSON during streaming). + if base.IsNil(streamValue) { + return nil + } + if !yield(&StreamValue[Out, Out]{Chunk: streamValue}, nil) { + return errGenerateStop + } + return nil + } + + allOpts := append(slices.Clone(opts), WithInput(input), WithStreaming(cb)) + resp, err := dp.prompt.Execute(ctx, allOpts...) 
+ if err != nil { + yield(nil, err) + return + } + + output, err := extractTypedOutput[Out](resp) + if err != nil { + yield(nil, err) + return + } + + yield(&StreamValue[Out, Out]{Done: true, Output: output, Response: resp}, nil) + } +} + +// Render renders the typed prompt template with the given input. +func (dp *DataPrompt[In, Out]) Render(ctx context.Context, input In) (*GenerateActionOptions, error) { + if dp == nil { + return nil, errors.New("DataPrompt.Render: prompt is nil") + } + + return dp.prompt.Render(ctx, input) +} diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index f711f6321b..3228a8f41a 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -16,11 +16,13 @@ package ai import ( "context" + "errors" "fmt" "os" "path/filepath" "strings" "testing" + "testing/fstest" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" @@ -885,70 +887,6 @@ func assertResponse(t *testing.T, resp *ModelResponse, want string) { } } -func TestLoadPrompt(t *testing.T) { - // Create a temporary directory for testing - tempDir := t.TempDir() - - // Create a mock .prompt file - mockPromptFile := filepath.Join(tempDir, "example.prompt") - mockPromptContent := `--- -model: test-model -maxTurns: 5 -description: A test prompt -toolChoice: required -returnToolRequests: true -input: - schema: - type: object - properties: - name: - type: string - default: - name: world -output: - format: text - schema: - type: string ---- -Hello, {{name}}! 
-` - err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644) - if err != nil { - t.Fatalf("Failed to create mock prompt file: %v", err) - } - - // Initialize a mock registry - reg := registry.New() - - // Call loadPrompt - LoadPrompt(reg, tempDir, "example.prompt", "test-namespace") - - // Verify that the prompt was registered correctly - prompt := LookupPrompt(reg, "test-namespace/example") - if prompt == nil { - t.Fatalf("Prompt was not registered") - } - - if prompt.(api.Action).Desc().InputSchema == nil { - t.Fatal("Input schema is nil") - } - - if prompt.(api.Action).Desc().InputSchema["type"] != "object" { - t.Errorf("Expected input schema type 'object', got '%s'", prompt.(api.Action).Desc().InputSchema["type"]) - } - - promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any) - if !ok { - t.Fatalf("Expected Metadata['prompt'] to be a map, but got %T", prompt.(api.Action).Desc().Metadata["prompt"]) - } - if promptMetadata["model"] != "test-model" { - t.Errorf("Expected model name 'test-model', got '%s'", prompt.(api.Action).Desc().Metadata["model"]) - } - if promptMetadata["maxTurns"] != 5 { - t.Errorf("Expected maxTurns set to 5, got: %d", promptMetadata["maxTurns"]) - } -} - func TestLoadPromptSnakeCase(t *testing.T) { tempDir := t.TempDir() mockPromptFile := filepath.Join(tempDir, "snake.prompt") @@ -970,7 +908,7 @@ input: } reg := registry.New() - LoadPrompt(reg, tempDir, "snake.prompt", "snake-namespace") + LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "snake.prompt", "snake-namespace") prompt := LookupPrompt(reg, "snake-namespace/snake") if prompt == nil { @@ -1018,8 +956,9 @@ func TestLoadPrompt_FileNotFound(t *testing.T) { // Initialize a mock registry reg := registry.New() - // Call loadPrompt with a non-existent file - LoadPrompt(reg, "./nonexistent", "missing.prompt", "test-namespace") + // Call loadPrompt with a non-existent file in a valid temp directory + tempDir := t.TempDir() + LoadPromptFromFS(reg, 
os.DirFS(tempDir), ".", "missing.prompt", "test-namespace") // Verify that the prompt was not registered prompt := LookupPrompt(reg, "missing") @@ -1044,7 +983,7 @@ func TestLoadPrompt_InvalidPromptFile(t *testing.T) { reg := registry.New() // Call loadPrompt - LoadPrompt(reg, tempDir, "invalid.prompt", "test-namespace") + LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "invalid.prompt", "test-namespace") // Verify that the prompt was not registered prompt := LookupPrompt(reg, "invalid") @@ -1075,7 +1014,7 @@ Hello, {{name}}! reg := registry.New() // Call loadPrompt - LoadPrompt(reg, tempDir, "example.variant.prompt", "test-namespace") + LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "example.variant.prompt", "test-namespace") // Verify that the prompt was registered correctly prompt := LookupPrompt(reg, "test-namespace/example.variant") @@ -1096,6 +1035,50 @@ Hello, {{name}}! } } +func TestDefinePrompt_WithVariant(t *testing.T) { + reg := registry.New() + + DefinePrompt(reg, "example.code", WithPrompt("Hello, {{name}}!")) + + prompt := LookupPrompt(reg, "example.code") + if prompt == nil { + t.Fatalf("Prompt was not registered") + } + + promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any) + if !ok { + t.Fatalf("Expected Metadata['prompt'] to be a map") + } + if promptMetadata["name"] != "example" { + t.Errorf("Expected metadata name 'example', got '%s'", promptMetadata["name"]) + } + if promptMetadata["variant"] != "code" { + t.Errorf("Expected variant 'code', got '%v'", promptMetadata["variant"]) + } +} + +func TestDefinePrompt_WithoutVariant(t *testing.T) { + reg := registry.New() + + DefinePrompt(reg, "simple", WithPrompt("Hello, world!")) + + prompt := LookupPrompt(reg, "simple") + if prompt == nil { + t.Fatalf("Prompt was not registered") + } + + promptMetadata, ok := prompt.(api.Action).Desc().Metadata["prompt"].(map[string]any) + if !ok { + t.Fatalf("Expected Metadata['prompt'] to be a map") + } + if promptMetadata["name"] != 
"simple" { + t.Errorf("Expected metadata name 'simple', got '%s'", promptMetadata["name"]) + } + if _, exists := promptMetadata["variant"]; exists { + t.Errorf("Expected no variant for prompt without dot, got '%v'", promptMetadata["variant"]) + } +} + func TestLoadPromptFolder(t *testing.T) { // Create a temporary directory for testing tempDir := t.TempDir() @@ -1142,7 +1125,7 @@ Hello, {{name}}! reg := registry.New() // Call LoadPromptFolder - LoadPromptDir(reg, tempDir, "test-namespace") + LoadPromptDirFromFS(reg, os.DirFS(tempDir), ".", "test-namespace") // Verify that the prompt was registered correctly prompt := LookupPrompt(reg, "test-namespace/example") @@ -1157,19 +1140,298 @@ Hello, {{name}}! } } -func TestLoadPromptFolder_DirectoryNotFound(t *testing.T) { +func TestLoadPromptFolder_EmptyDirectory(t *testing.T) { // Initialize a mock registry - reg := ®istry.Registry{} + reg := registry.New() + + // Create an empty temp directory + tempDir := t.TempDir() - // Call LoadPromptFolder with a non-existent directory - LoadPromptDir(reg, "", "test-namespace") + // Call LoadPromptFolder with an empty directory + LoadPromptDirFromFS(reg, os.DirFS(tempDir), ".", "test-namespace") // Verify that no prompts were registered if prompt := LookupPrompt(reg, "example"); prompt != nil { - t.Fatalf("Prompt should not have been registered for a non-existent directory") + t.Fatalf("Prompt should not have been registered for an empty directory") + } +} + +func TestLoadPromptFS(t *testing.T) { + mockPromptContent := `--- +model: test/chat +description: A test prompt +input: + schema: + type: object + properties: + name: + type: string +output: + format: text + schema: + type: string +--- + +Hello, {{name}}! 
+` + mockPartialContent := `Welcome {{name}}!` + + fsys := fstest.MapFS{ + "prompts/example.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)}, + "prompts/sub/nested.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)}, + "prompts/_greeting.prompt": &fstest.MapFile{Data: []byte(mockPartialContent)}, + } + + reg := registry.New() + + LoadPromptDirFromFS(reg, fsys, "prompts", "test-namespace") + + prompt := LookupPrompt(reg, "test-namespace/example") + if prompt == nil { + t.Fatalf("Prompt 'test-namespace/example' was not registered") + } + + nestedPrompt := LookupPrompt(reg, "test-namespace/nested") + if nestedPrompt == nil { + t.Fatalf("Nested prompt 'test-namespace/nested' was not registered") + } +} + +func TestLoadPromptFS_WithVariant(t *testing.T) { + mockPromptContent := `--- +model: test/chat +description: A test prompt with variant +--- + +Hello from variant! +` + + fsys := fstest.MapFS{ + "prompts/greeting.experimental.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)}, + } + + reg := registry.New() + + LoadPromptDirFromFS(reg, fsys, "prompts", "") + + prompt := LookupPrompt(reg, "greeting.experimental") + if prompt == nil { + t.Fatalf("Prompt with variant 'greeting.experimental' was not registered") + } +} + +func TestLoadPromptFS_NilFS(t *testing.T) { + reg := registry.New() + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for nil filesystem") + } + }() + + LoadPromptDirFromFS(reg, nil, "prompts", "test-namespace") +} + +func TestLoadPromptFS_InvalidRoot(t *testing.T) { + fsys := fstest.MapFS{ + "other/example.prompt": &fstest.MapFile{Data: []byte("test")}, + } + + reg := registry.New() + + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic for invalid root directory") + } + }() + + LoadPromptDirFromFS(reg, fsys, "nonexistent", "test-namespace") +} + +func TestLoadPromptFromFS(t *testing.T) { + mockPromptContent := `--- +model: test/chat +description: A single prompt test +--- + +Test 
content +` + + fsys := fstest.MapFS{ + "prompts/single.prompt": &fstest.MapFile{Data: []byte(mockPromptContent)}, + } + + reg := registry.New() + + prompt := LoadPromptFromFS(reg, fsys, "prompts", "single.prompt", "ns") + if prompt == nil { + t.Fatalf("LoadPromptFromFS failed to load prompt") + } + + lookedUp := LookupPrompt(reg, "ns/single") + if lookedUp == nil { + t.Fatalf("Prompt 'ns/single' was not registered") } } +func TestLoadPromptFromRaw(t *testing.T) { + t.Run("basic prompt", func(t *testing.T) { + reg := registry.New() + + source := `--- +model: test/chat +description: A raw prompt test +input: + schema: + name: string +--- +Hello, {{name}}! +` + prompt, err := LoadPromptFromSource(reg, source, "rawPrompt", "test-ns") + if err != nil { + t.Fatalf("LoadPromptFromRaw failed: %v", err) + } + if prompt == nil { + t.Fatal("LoadPromptFromRaw returned nil prompt") + } + + lookedUp := LookupPrompt(reg, "test-ns/rawPrompt") + if lookedUp == nil { + t.Fatal("Prompt 'test-ns/rawPrompt' was not registered") + } + + actionOpts, err := prompt.Render(context.Background(), map[string]any{"name": "World"}) + if err != nil { + t.Fatalf("Render failed: %v", err) + } + if len(actionOpts.Messages) == 0 { + t.Fatal("Expected messages to be rendered") + } + renderedText := actionOpts.Messages[0].Text() + if renderedText != "Hello, World!" 
{ + t.Errorf("Expected 'Hello, World!', got %q", renderedText) + } + }) + + t.Run("prompt with variant", func(t *testing.T) { + reg := registry.New() + + source := `--- +model: test/chat +description: A variant prompt +--- +Formal greeting +` + prompt, err := LoadPromptFromSource(reg, source, "greeting.formal", "") + if err != nil { + t.Fatalf("LoadPromptFromRaw failed: %v", err) + } + if prompt == nil { + t.Fatal("LoadPromptFromRaw returned nil prompt") + } + + lookedUp := LookupPrompt(reg, "greeting.formal") + if lookedUp == nil { + t.Fatal("Prompt 'greeting.formal' was not registered") + } + + promptMetadata, ok := lookedUp.(api.Action).Desc().Metadata["prompt"].(map[string]any) + if !ok { + t.Fatal("Expected Metadata['prompt'] to be a map") + } + if promptMetadata["name"] != "greeting" { + t.Errorf("Expected metadata name 'greeting', got '%s'", promptMetadata["name"]) + } + if promptMetadata["variant"] != "formal" { + t.Errorf("Expected variant 'formal', got '%v'", promptMetadata["variant"]) + } + }) + + t.Run("prompt without namespace", func(t *testing.T) { + reg := registry.New() + + source := `--- +model: test/chat +--- +Simple prompt +` + prompt, err := LoadPromptFromSource(reg, source, "simple", "") + if err != nil { + t.Fatalf("LoadPromptFromRaw failed: %v", err) + } + if prompt == nil { + t.Fatal("LoadPromptFromRaw returned nil prompt") + } + + lookedUp := LookupPrompt(reg, "simple") + if lookedUp == nil { + t.Fatal("Prompt 'simple' was not registered") + } + }) + + t.Run("prompt with inline output schema", func(t *testing.T) { + reg := registry.New() + ConfigureFormats(reg) + + source := `--- +model: test/chat +output: + format: json + schema: + type: object + properties: + title: + type: string + description: + type: string + required: + - title + - description +--- +Generate something +` + prompt, err := LoadPromptFromSource(reg, source, "outputSchemaPrompt", "") + if err != nil { + t.Fatalf("LoadPromptFromRaw failed: %v", err) + } + if prompt == nil 
{ + t.Fatal("LoadPromptFromRaw returned nil prompt") + } + + actionOpts, err := prompt.Render(context.Background(), nil) + if err != nil { + t.Fatalf("Render failed: %v", err) + } + + // Verify that the output config is set correctly + if actionOpts.Output == nil { + t.Fatal("Expected Output config to be set") + } + if actionOpts.Output.Format != OutputFormatJSON { + t.Errorf("Expected output format 'json', got %q", actionOpts.Output.Format) + } + if actionOpts.Output.JsonSchema == nil { + t.Fatal("Expected output JsonSchema to be set for inline schema") + } + + // Verify the schema structure + schema := actionOpts.Output.JsonSchema + if schema["type"] != "object" { + t.Errorf("Expected schema type 'object', got %v", schema["type"]) + } + properties, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatal("Expected schema properties to be a map") + } + if _, ok := properties["title"]; !ok { + t.Error("Expected schema to have 'title' property") + } + if _, ok := properties["description"]; !ok { + t.Error("Expected schema to have 'description' property") + } + }) +} + // TestDefinePartialAndHelperJourney demonstrates a complete user journey for defining // and using both partials and helpers. func TestDefinePartialAndHelper(t *testing.T) { @@ -1230,72 +1492,34 @@ Hello! ConfigureFormats(reg) definePromptModel(reg) - prompt := LoadPrompt(reg, tempDir, "example.prompt", "multi-namespace") - - _, err = prompt.Execute(context.Background()) - if err != nil { - t.Fatalf("Failed to execute prompt: %v", err) - } -} - -func TestMultiMessagesRenderPrompt(t *testing.T) { - tempDir := t.TempDir() - - mockPromptFile := filepath.Join(tempDir, "example.prompt") - mockPromptContent := `--- -model: test/chat -description: A test prompt ---- -<<>> -You are a pirate! - -<<>> -Hello! 
-` - - if err := os.WriteFile(mockPromptFile, []byte(mockPromptContent), 0o644); err != nil { - t.Fatalf("Failed to create mock prompt file: %v", err) - } - - prompt := LoadPrompt(registry.New(), tempDir, "example.prompt", "multi-namespace-roles") + prompt := LoadPromptFromFS(reg, os.DirFS(tempDir), ".", "example.prompt", "multi-namespace") - actionOpts, err := prompt.Render(context.Background(), map[string]any{}) + result, err := prompt.Execute(context.Background()) if err != nil { t.Fatalf("Failed to execute prompt: %v", err) } - // Check that actionOpts is not nil - if actionOpts == nil { - t.Fatal("Expected actionOpts to be non-nil") - } - // Check that we have exactly 2 messages (system and user) - if len(actionOpts.Messages) != 2 { - t.Fatalf("Expected 2 messages, got %d", len(actionOpts.Messages)) + if len(result.Request.Messages) != 2 { + t.Fatalf("Expected 2 messages, got %d", len(result.Request.Messages)) } // Check first message (system role) - systemMsg := actionOpts.Messages[0] + systemMsg := result.Request.Messages[0] if systemMsg.Role != RoleSystem { t.Errorf("Expected first message role to be 'system', got '%s'", systemMsg.Role) } - if len(systemMsg.Content) == 0 { - t.Fatal("Expected system message to have content") - } - if strings.TrimSpace(systemMsg.Content[0].Text) != "You are a pirate!" { - t.Errorf("Expected system message text to be 'You are a pirate!', got '%s'", systemMsg.Content[0].Text) + if strings.TrimSpace(systemMsg.Text()) != "You are a pirate!" { + t.Errorf("Expected system message text to be 'You are a pirate!', got '%s'", systemMsg.Text()) } // Check second message (user role) - userMsg := actionOpts.Messages[1] + userMsg := result.Request.Messages[1] if userMsg.Role != RoleUser { t.Errorf("Expected second message role to be 'user', got '%s'", userMsg.Role) } - if len(userMsg.Content) == 0 { - t.Fatal("Expected user message to have content") - } - if strings.TrimSpace(userMsg.Content[0].Text) != "Hello!" 
{ - t.Errorf("Expected user message text to be 'Hello!', got '%s'", userMsg.Content[0].Text) + if strings.TrimSpace(userMsg.Text()) != "Hello!" { + t.Errorf("Expected user message text to be 'Hello!', got '%s'", userMsg.Text()) } } @@ -1518,3 +1742,1243 @@ func TestWithOutputSchemaName_DefinePrompt_Missing(t *testing.T) { t.Errorf("Expected error 'schema \"MissingSchema\" not found', got: %v", err) } } + +func TestDataPromptExecute(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + type GreetingInput struct { + Name string `json:"name"` + } + + type GreetingOutput struct { + Message string `json:"message"` + Count int `json:"count"` + } + + t.Run("typed input and output", func(t *testing.T) { + var capturedInput any + + testModel := DefineModel(r, "test/dataPromptModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedInput = req.Messages[0].Text() + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"message":"Hello, Alice!","count":1}`)}, + }, + }, nil + }) + + dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "greetingPrompt", + WithModel(testModel), + WithPrompt("Greet {{name}}"), + ) + + output, resp, err := dp.Execute(context.Background(), GreetingInput{Name: "Alice"}) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + if capturedInput != "Greet Alice" { + t.Errorf("expected input %q, got %q", "Greet Alice", capturedInput) + } + + if output.Message != "Hello, Alice!" 
{ + t.Errorf("expected message %q, got %q", "Hello, Alice!", output.Message) + } + if output.Count != 1 { + t.Errorf("expected count 1, got %d", output.Count) + } + if resp == nil { + t.Error("expected response to be returned") + } + }) + + t.Run("string output type", func(t *testing.T) { + testModel := DefineModel(r, "test/stringDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("Hello, World!"), + }, nil + }) + + dp := DefineDataPrompt[GreetingInput, string](r, "stringOutputPrompt", + WithModel(testModel), + WithPrompt("Say hello to {{name}}"), + ) + + output, resp, err := dp.Execute(context.Background(), GreetingInput{Name: "World"}) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + if output != "Hello, World!" { + t.Errorf("expected output %q, got %q", "Hello, World!", output) + } + if resp == nil { + t.Error("expected response to be returned") + } + }) + + t.Run("nil prompt returns error", func(t *testing.T) { + var dp *DataPrompt[GreetingInput, GreetingOutput] + + _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}) + if err == nil { + t.Error("expected error for nil prompt") + } + }) + + t.Run("additional options passed through", func(t *testing.T) { + var capturedConfig any + + testModel := DefineModel(r, "test/optionsDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedConfig = req.Config + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"message":"test","count":0}`)}, + }, + }, nil + }) + + dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "optionsPrompt", + 
WithModel(testModel), + WithPrompt("Test {{name}}"), + ) + + _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}, + WithConfig(&GenerationCommonConfig{Temperature: 0.5}), + ) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + config, ok := capturedConfig.(*GenerationCommonConfig) + if !ok { + t.Fatalf("expected *GenerationCommonConfig, got %T", capturedConfig) + } + if config.Temperature != 0.5 { + t.Errorf("expected temperature 0.5, got %v", config.Temperature) + } + }) + + t.Run("returns error for invalid output parsing", func(t *testing.T) { + testModel := DefineModel(r, "test/parseFailDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("not valid json"), + }, nil + }) + + dp := DefineDataPrompt[GreetingInput, GreetingOutput](r, "parseFailPrompt", + WithModel(testModel), + WithPrompt("Test {{name}}"), + ) + + _, _, err := dp.Execute(context.Background(), GreetingInput{Name: "test"}) + if err == nil { + t.Error("expected error for invalid JSON output") + } + }) +} + +func TestDataPromptExecuteStream(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + type StreamInput struct { + Topic string `json:"topic"` + } + + type StreamOutput struct { + Text string `json:"text"` + Index int `json:"index"` + } + + t.Run("typed streaming with struct output", func(t *testing.T) { + testModel := DefineModel(r, "test/streamDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: 
[]*Part{NewJSONPart(`{"text":"chunk1","index":1}`)}, + }) + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"text":"final","index":99}`)}, + }) + } + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"text":"final","index":99}`)}, + }, + }, nil + }) + + dp := DefineDataPrompt[StreamInput, StreamOutput](r, "streamPrompt", + WithModel(testModel), + WithPrompt("Stream about {{topic}}"), + ) + + var chunks []StreamOutput + var finalOutput StreamOutput + var finalResponse *ModelResponse + + for val, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "testing"}) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalOutput = val.Output + finalResponse = val.Response + } else { + chunks = append(chunks, val.Chunk) + } + } + + if len(chunks) < 1 { + t.Errorf("expected at least 1 chunk, got %d", len(chunks)) + } + + if finalOutput.Text != "final" || finalOutput.Index != 99 { + t.Errorf("expected final {final, 99}, got %+v", finalOutput) + } + if finalResponse == nil { + t.Error("expected final response") + } + }) + + t.Run("string output streaming", func(t *testing.T) { + testModel := DefineModel(r, "test/stringStreamDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("First ")}, + }) + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("Second")}, + }) + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("First Second"), + }, nil + }) + + dp := DefineDataPrompt[StreamInput, string](r, "stringStreamPrompt", + WithModel(testModel), + WithPrompt("Generate text about {{topic}}"), + ) + + var chunks []string + var finalOutput string + + for val, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: 
"strings"}) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if val.Done { + finalOutput = val.Output + } else { + chunks = append(chunks, val.Chunk) + } + } + + if len(chunks) != 2 { + t.Errorf("expected 2 chunks, got %d", len(chunks)) + } + if chunks[0] != "First " { + t.Errorf("chunk 0: expected %q, got %q", "First ", chunks[0]) + } + if chunks[1] != "Second" { + t.Errorf("chunk 1: expected %q, got %q", "Second", chunks[1]) + } + + if finalOutput != "First Second" { + t.Errorf("expected final %q, got %q", "First Second", finalOutput) + } + }) + + t.Run("nil prompt returns error", func(t *testing.T) { + var dp *DataPrompt[StreamInput, StreamOutput] + + var receivedErr error + for _, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "test"}) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Error("expected error for nil prompt") + } + }) + + t.Run("handles options passed at execute time", func(t *testing.T) { + var capturedConfig any + + testModel := DefineModel(r, "test/optionsStreamModel", &ModelOptions{ + Supports: &ModelSupports{ + Multiturn: true, + Constrained: ConstrainedSupportAll, + }, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedConfig = req.Config + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewJSONPart(`{"text":"chunk","index":1}`)}, + }) + } + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(`{"text":"done","index":2}`)}, + }, + }, nil + }) + + dp := DefineDataPrompt[StreamInput, StreamOutput](r, "optionsStreamPrompt", + WithModel(testModel), + WithPrompt("Test {{topic}}"), + ) + + for range dp.ExecuteStream(context.Background(), StreamInput{Topic: "options"}, + WithConfig(&GenerationCommonConfig{Temperature: 0.7}), + ) { + } + + config, ok := capturedConfig.(*GenerationCommonConfig) + if !ok { + t.Fatalf("expected 
*GenerationCommonConfig, got %T", capturedConfig) + } + if config.Temperature != 0.7 { + t.Errorf("expected temperature 0.7, got %v", config.Temperature) + } + }) + + t.Run("propagates errors", func(t *testing.T) { + expectedErr := errors.New("stream failed") + + testModel := DefineModel(r, "test/errorStreamDataPromptModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return nil, expectedErr + }) + + dp := DefineDataPrompt[StreamInput, StreamOutput](r, "errorStreamPrompt", + WithModel(testModel), + WithPrompt("Test {{topic}}"), + ) + + var receivedErr error + for _, err := range dp.ExecuteStream(context.Background(), StreamInput{Topic: "error"}) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Error("expected error to be propagated") + } + if !errors.Is(receivedErr, expectedErr) { + t.Errorf("expected error %v, got %v", expectedErr, receivedErr) + } + }) +} + +func TestPromptExecuteStream(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + DefineGenerateAction(context.Background(), r) + + t.Run("yields chunks then final response", func(t *testing.T) { + chunkTexts := []string{"A", "B", "C"} + + testModel := DefineModel(r, "test/promptStreamModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + for _, text := range chunkTexts { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart(text)}, + }) + } + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("ABC"), + }, nil + }) + + p := DefinePrompt(r, "streamTestPrompt", + WithModel(testModel), + WithPrompt("Test"), + ) + + var chunks []*ModelResponseChunk + var finalResponse *ModelResponse + + for val, err := range p.ExecuteStream(context.Background()) { + if err != nil { + t.Fatalf("unexpected 
error: %v", err) + } + if val.Done { + finalResponse = val.Response + } else { + chunks = append(chunks, val.Chunk) + } + } + + if len(chunks) != 3 { + t.Errorf("expected 3 chunks, got %d", len(chunks)) + } + for i, chunk := range chunks { + if chunk.Text() != chunkTexts[i] { + t.Errorf("chunk %d: expected %q, got %q", i, chunkTexts[i], chunk.Text()) + } + } + + if finalResponse == nil { + t.Fatal("expected final response") + } + if finalResponse.Text() != "ABC" { + t.Errorf("expected final text %q, got %q", "ABC", finalResponse.Text()) + } + }) + + t.Run("nil prompt returns error", func(t *testing.T) { + var p *prompt + + var receivedErr error + for _, err := range p.ExecuteStream(context.Background()) { + if err != nil { + receivedErr = err + break + } + } + + if receivedErr == nil { + t.Error("expected error for nil prompt") + } + }) + + t.Run("handles execution options", func(t *testing.T) { + var capturedConfig any + + testModel := DefineModel(r, "test/optionsPromptExecModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedConfig = req.Config + if cb != nil { + cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart("chunk")}, + }) + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("done"), + }, nil + }) + + p := DefinePrompt(r, "execOptionsTestPrompt", + WithModel(testModel), + WithPrompt("Test"), + ) + + for range p.ExecuteStream(context.Background(), + WithConfig(&GenerationCommonConfig{Temperature: 0.9}), + ) { + } + + config, ok := capturedConfig.(*GenerationCommonConfig) + if !ok { + t.Fatalf("expected *GenerationCommonConfig, got %T", capturedConfig) + } + if config.Temperature != 0.9 { + t.Errorf("expected temperature 0.9, got %v", config.Temperature) + } + }) +} + +// TestDefineExecuteOptionInteractions tests the complex interactions between +// options set at DefinePrompt time vs Execute time. 
+func TestDefineExecuteOptionInteractions(t *testing.T) { + t.Run("ToolChoice override", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolChoiceModel", + handler: capturingModelHandler(&captured), + }) + + tool := defineFakeTool(t, r, "testTool", "a test tool") + + // Define with ToolChoiceAuto + p := DefinePrompt(r, "toolChoicePrompt", + WithModel(model), + WithPrompt("test"), + WithTools(tool), + WithToolChoice(ToolChoiceAuto), + WithMaxTurns(1), + ) + + // Execute with ToolChoiceRequired - should override + _, err := p.Execute(context.Background(), + WithToolChoice(ToolChoiceRequired), + ) + assertNoError(t, err) + + if captured.ToolChoice != ToolChoiceRequired { + t.Errorf("ToolChoice = %q, want %q", captured.ToolChoice, ToolChoiceRequired) + } + }) + + t.Run("ToolChoice no override when not specified at execute", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolChoiceNoOverride", + handler: capturingModelHandler(&captured), + }) + + tool := defineFakeTool(t, r, "testTool2", "a test tool") + + // Define with ToolChoiceRequired + p := DefinePrompt(r, "toolChoiceNoOverridePrompt", + WithModel(model), + WithPrompt("test"), + WithTools(tool), + WithToolChoice(ToolChoiceRequired), + WithMaxTurns(1), + ) + + // Execute without specifying ToolChoice - should use define-time value + _, err := p.Execute(context.Background()) + assertNoError(t, err) + + if captured.ToolChoice != ToolChoiceRequired { + t.Errorf("ToolChoice = %q, want %q", captured.ToolChoice, ToolChoiceRequired) + } + }) + + t.Run("MaxTurns override", func(t *testing.T) { + r := newTestRegistry(t) + callCount := 0 + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/maxTurnsModel", + handler: func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + callCount++ + 
// Always request tool call to test max turns + if callCount < 10 { + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewToolRequestPart(&ToolRequest{ + Name: "maxTurnsTool", + Input: map[string]any{"value": "test"}, + })}, + }, + }, nil + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("done"), + }, nil + }, + }) + + tool := defineFakeTool(t, r, "maxTurnsTool", "a tool for max turns test") + + // Define with MaxTurns 5 + p := DefinePrompt(r, "maxTurnsPrompt", + WithModel(model), + WithPrompt("test"), + WithTools(tool), + WithMaxTurns(5), + ) + + // Execute with MaxTurns 2 - should override and stop after 2 turns + _, err := p.Execute(context.Background(), + WithMaxTurns(2), + ) + + // Should error due to max turns exceeded + if err == nil { + t.Error("expected max turns error, got nil") + } + // Call count should be limited by execute-time MaxTurns (2) + 1 for initial + if callCount > 3 { + t.Errorf("callCount = %d, expected <= 3 (limited by execute MaxTurns)", callCount) + } + }) + + t.Run("ReturnToolRequests override", func(t *testing.T) { + r := newTestRegistry(t) + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/returnToolReqsModel", + handler: func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewToolRequestPart(&ToolRequest{ + Name: "returnToolReqsTool", + Input: map[string]any{"value": "test"}, + })}, + }, + }, nil + }, + }) + + tool := defineFakeTool(t, r, "returnToolReqsTool", "tool for return requests test") + + // Define with ReturnToolRequests false (default) + p := DefinePrompt(r, "returnToolReqsPrompt", + WithModel(model), + WithPrompt("test"), + WithTools(tool), + WithReturnToolRequests(false), + WithMaxTurns(1), + ) + + // Execute with ReturnToolRequests true - should override and return tool requests + resp, err := 
p.Execute(context.Background(), + WithReturnToolRequests(true), + ) + assertNoError(t, err) + + // Should have tool request in response + hasToolRequest := false + for _, part := range resp.Message.Content { + if part.IsToolRequest() { + hasToolRequest = true + break + } + } + if !hasToolRequest { + t.Error("expected tool request in response when ReturnToolRequests=true") + } + }) + + t.Run("Tools complete replacement", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolsReplaceModel", + handler: capturingModelHandler(&captured), + }) + + toolA := defineFakeTool(t, r, "toolA", "tool A") + toolB := defineFakeTool(t, r, "toolB", "tool B") + toolC := defineFakeTool(t, r, "toolC", "tool C") + + // Define with tools A and B + p := DefinePrompt(r, "toolsReplacePrompt", + WithModel(model), + WithPrompt("test"), + WithTools(toolA, toolB), + WithMaxTurns(1), + ) + + // Execute with tool C - should REPLACE (not merge) define-time tools + _, err := p.Execute(context.Background(), + WithTools(toolC), + ) + assertNoError(t, err) + + // Should only have tool C + if len(captured.Tools) != 1 { + t.Errorf("len(Tools) = %d, want 1", len(captured.Tools)) + } + if len(captured.Tools) > 0 && captured.Tools[0].Name != "toolC" { + t.Errorf("Tool name = %q, want %q", captured.Tools[0].Name, "toolC") + } + }) + + t.Run("Tools inherit when not specified at execute", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/toolsInheritModel", + handler: capturingModelHandler(&captured), + }) + + toolA := defineFakeTool(t, r, "toolInheritA", "tool A") + toolB := defineFakeTool(t, r, "toolInheritB", "tool B") + + // Define with tools A and B + p := DefinePrompt(r, "toolsInheritPrompt", + WithModel(model), + WithPrompt("test"), + WithTools(toolA, toolB), + WithMaxTurns(1), + ) + + // Execute without specifying tools - 
should inherit define-time tools + _, err := p.Execute(context.Background()) + assertNoError(t, err) + + if len(captured.Tools) != 2 { + t.Errorf("len(Tools) = %d, want 2", len(captured.Tools)) + } + }) + + t.Run("Docs at execute time", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/docsModel", + handler: capturingModelHandler(&captured), + }) + + // Define without docs + p := DefinePrompt(r, "docsPrompt", + WithModel(model), + WithPrompt("test"), + ) + + // Execute with docs + doc := DocumentFromText("context document", nil) + _, err := p.Execute(context.Background(), + WithDocs(doc), + ) + assertNoError(t, err) + + if len(captured.Docs) != 1 { + t.Errorf("len(Docs) = %d, want 1", len(captured.Docs)) + } + }) + + t.Run("Config replacement not merge", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/configReplaceModel", + handler: capturingModelHandler(&captured), + }) + + // Define with Temperature and TopK + p := DefinePrompt(r, "configReplacePrompt", + WithModel(model), + WithPrompt("test"), + WithConfig(&GenerationCommonConfig{Temperature: 0.5, TopK: 10}), + ) + + // Execute with only Temperature - config is REPLACED, not merged + _, err := p.Execute(context.Background(), + WithConfig(&GenerationCommonConfig{Temperature: 0.9}), + ) + assertNoError(t, err) + + config, ok := captured.Config.(*GenerationCommonConfig) + if !ok { + t.Fatalf("Config type = %T, want *GenerationCommonConfig", captured.Config) + } + if config.Temperature != 0.9 { + t.Errorf("Temperature = %v, want 0.9", config.Temperature) + } + // TopK should be zero (default) since config was replaced + if config.TopK != 0 { + t.Errorf("TopK = %v, want 0 (config replaced, not merged)", config.TopK) + } + }) + + t.Run("Model override at execute time", func(t *testing.T) { + r := newTestRegistry(t) + + defineModel := 
defineFakeModel(t, r, fakeModelConfig{ + name: "test/defineModel", + handler: func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("from define model"), + }, nil + }, + }) + + executeModel := defineFakeModel(t, r, fakeModelConfig{ + name: "test/executeModel", + handler: func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("from execute model"), + }, nil + }, + }) + + // Define with defineModel + p := DefinePrompt(r, "modelOverridePrompt", + WithModel(defineModel), + WithPrompt("test"), + ) + + // Execute with executeModel - should use execute model + resp, err := p.Execute(context.Background(), + WithModel(executeModel), + ) + assertNoError(t, err) + + if resp.Text() != "from execute model" { + t.Errorf("response = %q, want %q", resp.Text(), "from execute model") + } + }) + + t.Run("MessagesFn at execute time inserts between system and user", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/messagesFnModel", + handler: capturingModelHandler(&captured), + }) + + // Define with system and user prompt + p := DefinePrompt(r, "messagesFnPrompt", + WithModel(model), + WithSystem("system instruction"), + WithPrompt("user question"), + ) + + // Execute with MessagesFn - messages should be inserted between system and user + _, err := p.Execute(context.Background(), + WithMessages(NewModelTextMessage("conversation history")), + ) + assertNoError(t, err) + + // Expected order: system, MessagesFn content, user + if len(captured.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(captured.Messages)) + } + if captured.Messages[0].Role != RoleSystem { + t.Errorf("Messages[0].Role = %q, want %q", captured.Messages[0].Role, RoleSystem) + } + if 
captured.Messages[1].Role != RoleModel { + t.Errorf("Messages[1].Role = %q, want %q", captured.Messages[1].Role, RoleModel) + } + if captured.Messages[2].Role != RoleUser { + t.Errorf("Messages[2].Role = %q, want %q", captured.Messages[2].Role, RoleUser) + } + }) + + t.Run("ModelRef config used when no explicit config", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + // Define model first + defineFakeModel(t, r, fakeModelConfig{ + name: "test/modelRefConfigModel", + handler: capturingModelHandler(&captured), + }) + + // Create ModelRef with embedded config + modelRef := NewModelRef("test/modelRefConfigModel", &GenerationCommonConfig{Temperature: 0.7}) + + p := DefinePrompt(r, "modelRefConfigPrompt", + WithModel(modelRef), + WithPrompt("test"), + ) + + // Execute without config - should use ModelRef's config + _, err := p.Execute(context.Background()) + assertNoError(t, err) + + config, ok := captured.Config.(*GenerationCommonConfig) + if !ok { + t.Fatalf("Config type = %T, want *GenerationCommonConfig", captured.Config) + } + if config.Temperature != 0.7 { + t.Errorf("Temperature = %v, want 0.7", config.Temperature) + } + }) + + t.Run("Explicit config overrides ModelRef config", func(t *testing.T) { + r := newTestRegistry(t) + var captured *ModelRequest + + defineFakeModel(t, r, fakeModelConfig{ + name: "test/modelRefOverrideModel", + handler: capturingModelHandler(&captured), + }) + + modelRef := NewModelRef("test/modelRefOverrideModel", &GenerationCommonConfig{Temperature: 0.7}) + + p := DefinePrompt(r, "modelRefOverridePrompt", + WithModel(modelRef), + WithPrompt("test"), + ) + + // Execute with explicit config - should override ModelRef's config + _, err := p.Execute(context.Background(), + WithConfig(&GenerationCommonConfig{Temperature: 0.3}), + ) + assertNoError(t, err) + + config, ok := captured.Config.(*GenerationCommonConfig) + if !ok { + t.Fatalf("Config type = %T, want *GenerationCommonConfig", captured.Config) + } + if 
config.Temperature != 0.3 { + t.Errorf("Temperature = %v, want 0.3", config.Temperature) + } + }) +} + +// TestPromptErrorPaths tests error handling in prompt operations. +func TestPromptErrorPaths(t *testing.T) { + t.Run("DefinePrompt with empty name panics", func(t *testing.T) { + r := newTestRegistry(t) + assertPanic(t, func() { + DefinePrompt(r, "") + }, "name is required") + }) + + t.Run("Execute on nil prompt returns error", func(t *testing.T) { + var p *prompt + _, err := p.Execute(context.Background()) + assertError(t, err, "prompt is nil") + }) + + t.Run("Render on nil prompt returns error", func(t *testing.T) { + var p *prompt + _, err := p.Render(context.Background(), nil) + assertError(t, err, "prompt is nil") + }) + + t.Run("ExecuteStream on nil prompt yields error", func(t *testing.T) { + var p *prompt + var gotErr error + for _, err := range p.ExecuteStream(context.Background()) { + if err != nil { + gotErr = err + break + } + } + assertError(t, gotErr, "prompt is nil") + }) + + t.Run("buildVariables with invalid type returns error", func(t *testing.T) { + // buildVariables expects struct, pointer to struct, or map + _, err := buildVariables(42) // int is not valid + if err == nil { + t.Error("expected error for invalid type, got nil") + } + }) +} + +// TestLookupPromptCoverage tests LookupPrompt edge cases. 
+func TestLookupPromptCoverage(t *testing.T) { + t.Run("returns nil for non-existent prompt", func(t *testing.T) { + r := newTestRegistry(t) + p := LookupPrompt(r, "nonexistent") + if p != nil { + t.Error("expected nil for non-existent prompt") + } + }) + + t.Run("returns prompt for existing prompt", func(t *testing.T) { + r := newTestRegistry(t) + DefinePrompt(r, "existingPrompt", WithPrompt("hello")) + p := LookupPrompt(r, "existingPrompt") + if p == nil { + t.Error("expected prompt, got nil") + } + if p.Name() != "existingPrompt" { + t.Errorf("Name() = %q, want %q", p.Name(), "existingPrompt") + } + }) +} + +// TestDataPromptRender tests DataPrompt.Render method. +func TestDataPromptRender(t *testing.T) { + r := newTestRegistry(t) + + type RenderInput struct { + Name string `json:"name"` + } + + type RenderOutput struct { + Greeting string `json:"greeting"` + } + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/renderModel", + }) + + dp := DefineDataPrompt[RenderInput, RenderOutput](r, "renderPrompt", + WithModel(model), + WithPrompt("Hello {{name}}"), + ) + + t.Run("renders with typed input", func(t *testing.T) { + opts, err := dp.Render(context.Background(), RenderInput{Name: "World"}) + assertNoError(t, err) + + if len(opts.Messages) == 0 { + t.Fatal("expected messages") + } + if opts.Messages[0].Text() != "Hello World" { + t.Errorf("rendered text = %q, want %q", opts.Messages[0].Text(), "Hello World") + } + }) + + t.Run("nil DataPrompt returns error", func(t *testing.T) { + var nilDP *DataPrompt[RenderInput, RenderOutput] + _, err := nilDP.Render(context.Background(), RenderInput{}) + if err == nil { + t.Error("expected error for nil DataPrompt") + } + }) +} + +// TestLookupDataPrompt tests LookupDataPrompt function. 
+func TestLookupDataPrompt(t *testing.T) { + r := newTestRegistry(t) + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/lookupDataModel", + }) + + DefinePrompt(r, "lookupDataPrompt", + WithModel(model), + WithPrompt("test"), + ) + + t.Run("returns DataPrompt for existing prompt", func(t *testing.T) { + dp := LookupDataPrompt[map[string]any, string](r, "lookupDataPrompt") + if dp == nil { + t.Error("expected DataPrompt, got nil") + } + }) + + t.Run("returns nil for non-existent prompt", func(t *testing.T) { + dp := LookupDataPrompt[map[string]any, string](r, "nonexistent") + if dp != nil { + t.Error("expected nil for non-existent prompt") + } + }) +} + +// TestAsDataPrompt tests AsDataPrompt function. +func TestAsDataPrompt(t *testing.T) { + r := newTestRegistry(t) + + model := defineFakeModel(t, r, fakeModelConfig{ + name: "test/asDataModel", + }) + + p := DefinePrompt(r, "asDataPrompt", + WithModel(model), + WithPrompt("test"), + ) + + t.Run("wraps existing prompt", func(t *testing.T) { + dp := AsDataPrompt[map[string]any, string](p) + if dp == nil { + t.Error("expected DataPrompt, got nil") + } + }) + + t.Run("returns nil for nil prompt", func(t *testing.T) { + dp := AsDataPrompt[map[string]any, string](nil) + if dp != nil { + t.Error("expected nil for nil prompt") + } + }) +} + +// TestPromptKeyVariantKey tests the prompt key generation helpers. 
+func TestPromptKeyVariantKey(t *testing.T) { + tests := []struct { + name string + promptName string + variant string + namespace string + want string + }{ + { + name: "simple name", + promptName: "greeting", + want: "greeting", + }, + { + name: "with variant", + promptName: "greeting", + variant: "formal", + want: "greeting.formal", + }, + { + name: "with namespace", + promptName: "greeting", + namespace: "myapp", + want: "myapp/greeting", + }, + { + name: "with namespace and variant", + promptName: "greeting", + variant: "formal", + namespace: "myapp", + want: "myapp/greeting.formal", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := promptKey(tt.promptName, tt.variant, tt.namespace) + if got != tt.want { + t.Errorf("promptKey(%q, %q, %q) = %q, want %q", + tt.promptName, tt.variant, tt.namespace, got, tt.want) + } + }) + } +} + +// TestContentType tests the contentType helper function. +func TestContentType(t *testing.T) { + tests := []struct { + name string + ct string + uri string + wantCT string + wantData string + wantErr bool + errContains string + }{ + { + name: "gs:// URL with content type", + ct: "image/png", + uri: "gs://bucket/image.png", + wantCT: "image/png", + wantData: "gs://bucket/image.png", + }, + { + name: "gs:// URL without content type", + ct: "", + uri: "gs://bucket/image.png", + wantErr: true, + errContains: "must supply contentType", + }, + { + name: "http URL with content type", + ct: "image/jpeg", + uri: "https://example.com/image.jpg", + wantCT: "image/jpeg", + wantData: "https://example.com/image.jpg", + }, + { + name: "http URL without content type", + ct: "", + uri: "https://example.com/image.jpg", + wantErr: true, + errContains: "must supply contentType", + }, + { + name: "data URI with base64", + ct: "", + uri: "data:image/png;base64,iVBORw0KGgo=", + wantCT: "image/png", + wantData: "data:image/png;base64,iVBORw0KGgo=", + }, + { + name: "data URI with explicit content type override", + ct: 
"image/jpeg", + uri: "data:image/png;base64,iVBORw0KGgo=", + wantCT: "image/jpeg", + wantData: "data:image/png;base64,iVBORw0KGgo=", + }, + { + name: "empty URI", + ct: "image/png", + uri: "", + wantErr: true, + errContains: "found empty URI", + }, + { + name: "malformed data URI", + ct: "", + uri: "data:image/png", + wantErr: true, + errContains: "missing comma", + }, + { + name: "unknown URI scheme", + ct: "", + uri: "file:///path/to/file", + wantErr: true, + errContains: "uri content type not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotCT, gotData, err := contentType(tt.ct, tt.uri) + + if tt.wantErr { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errContains) + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error = %q, want containing %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotCT != tt.wantCT { + t.Errorf("contentType = %q, want %q", gotCT, tt.wantCT) + } + if string(gotData) != tt.wantData { + t.Errorf("data = %q, want %q", string(gotData), tt.wantData) + } + }) + } +} + +// TestDefineDataPromptPanics tests panic conditions in DefineDataPrompt. +func TestDefineDataPromptPanics(t *testing.T) { + t.Run("empty name panics", func(t *testing.T) { + r := newTestRegistry(t) + assertPanic(t, func() { + DefineDataPrompt[map[string]any, string](r, "") + }, "name is required") + }) +} diff --git a/go/ai/request_helpers_test.go b/go/ai/request_helpers_test.go new file mode 100644 index 0000000000..1b1c80d4c1 --- /dev/null +++ b/go/ai/request_helpers_test.go @@ -0,0 +1,371 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestNewModelRequest(t *testing.T) { + t.Run("creates request with config and messages", func(t *testing.T) { + config := &GenerationCommonConfig{Temperature: 0.7} + msg1 := NewUserTextMessage("hello") + msg2 := NewModelTextMessage("hi there") + + req := NewModelRequest(config, msg1, msg2) + + if req.Config != config { + t.Error("Config not set correctly") + } + if len(req.Messages) != 2 { + t.Errorf("len(Messages) = %d, want 2", len(req.Messages)) + } + if req.Messages[0] != msg1 { + t.Error("First message not set correctly") + } + if req.Messages[1] != msg2 { + t.Error("Second message not set correctly") + } + }) + + t.Run("creates request with nil config", func(t *testing.T) { + msg := NewUserTextMessage("hello") + req := NewModelRequest(nil, msg) + + if req.Config != nil { + t.Errorf("Config = %v, want nil", req.Config) + } + if len(req.Messages) != 1 { + t.Errorf("len(Messages) = %d, want 1", len(req.Messages)) + } + }) + + t.Run("creates request with no messages", func(t *testing.T) { + config := map[string]any{"temp": 0.5} + req := NewModelRequest(config) + + if req.Config == nil { + t.Error("Config should not be nil") + } + if len(req.Messages) != 0 { + t.Errorf("len(Messages) = %d, want 0", len(req.Messages)) + } + }) +} + +func TestNewUserMessage(t *testing.T) { + t.Run("creates user message with parts", func(t *testing.T) { + parts := []*Part{NewTextPart("text"), NewMediaPart("image/png", "data:...")} + msg := 
NewUserMessage(parts...) + + if msg.Role != RoleUser { + t.Errorf("Role = %q, want %q", msg.Role, RoleUser) + } + if len(msg.Content) != 2 { + t.Errorf("len(Content) = %d, want 2", len(msg.Content)) + } + if msg.Metadata != nil { + t.Errorf("Metadata = %v, want nil", msg.Metadata) + } + }) + + t.Run("creates user message with no parts", func(t *testing.T) { + msg := NewUserMessage() + + if msg.Role != RoleUser { + t.Errorf("Role = %q, want %q", msg.Role, RoleUser) + } + if len(msg.Content) != 0 { + t.Errorf("len(Content) = %d, want 0", len(msg.Content)) + } + }) +} + +func TestNewUserMessageWithMetadata(t *testing.T) { + t.Run("creates user message with metadata", func(t *testing.T) { + metadata := map[string]any{"purpose": "context"} + parts := []*Part{NewTextPart("text")} + msg := NewUserMessageWithMetadata(metadata, parts...) + + if msg.Role != RoleUser { + t.Errorf("Role = %q, want %q", msg.Role, RoleUser) + } + if diff := cmp.Diff(metadata, msg.Metadata); diff != "" { + t.Errorf("Metadata mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("creates user message with nil metadata", func(t *testing.T) { + msg := NewUserMessageWithMetadata(nil, NewTextPart("text")) + + if msg.Role != RoleUser { + t.Errorf("Role = %q, want %q", msg.Role, RoleUser) + } + if msg.Metadata != nil { + t.Errorf("Metadata = %v, want nil", msg.Metadata) + } + }) +} + +func TestNewUserTextMessage(t *testing.T) { + t.Run("creates text message with user role", func(t *testing.T) { + msg := NewUserTextMessage("hello world") + + if msg.Role != RoleUser { + t.Errorf("Role = %q, want %q", msg.Role, RoleUser) + } + if len(msg.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) + } + if msg.Content[0].Text != "hello world" { + t.Errorf("Text = %q, want %q", msg.Content[0].Text, "hello world") + } + }) + + t.Run("creates text message with empty string", func(t *testing.T) { + msg := NewUserTextMessage("") + + if msg.Role != RoleUser { + t.Errorf("Role = %q, want %q", 
msg.Role, RoleUser) + } + if len(msg.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) + } + if msg.Content[0].Text != "" { + t.Errorf("Text = %q, want empty string", msg.Content[0].Text) + } + }) +} + +func TestNewModelMessage(t *testing.T) { + t.Run("creates model message with parts", func(t *testing.T) { + parts := []*Part{NewTextPart("response")} + msg := NewModelMessage(parts...) + + if msg.Role != RoleModel { + t.Errorf("Role = %q, want %q", msg.Role, RoleModel) + } + if len(msg.Content) != 1 { + t.Errorf("len(Content) = %d, want 1", len(msg.Content)) + } + }) +} + +func TestNewModelTextMessage(t *testing.T) { + t.Run("creates text message with model role", func(t *testing.T) { + msg := NewModelTextMessage("model response") + + if msg.Role != RoleModel { + t.Errorf("Role = %q, want %q", msg.Role, RoleModel) + } + if len(msg.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) + } + if msg.Content[0].Text != "model response" { + t.Errorf("Text = %q, want %q", msg.Content[0].Text, "model response") + } + }) +} + +func TestNewSystemMessage(t *testing.T) { + t.Run("creates system message with parts", func(t *testing.T) { + parts := []*Part{NewTextPart("system instruction")} + msg := NewSystemMessage(parts...) 
+ + if msg.Role != RoleSystem { + t.Errorf("Role = %q, want %q", msg.Role, RoleSystem) + } + if len(msg.Content) != 1 { + t.Errorf("len(Content) = %d, want 1", len(msg.Content)) + } + }) +} + +func TestNewSystemTextMessage(t *testing.T) { + t.Run("creates text message with system role", func(t *testing.T) { + msg := NewSystemTextMessage("be helpful") + + if msg.Role != RoleSystem { + t.Errorf("Role = %q, want %q", msg.Role, RoleSystem) + } + if len(msg.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) + } + if msg.Content[0].Text != "be helpful" { + t.Errorf("Text = %q, want %q", msg.Content[0].Text, "be helpful") + } + }) +} + +func TestNewMessage(t *testing.T) { + t.Run("creates message with all fields", func(t *testing.T) { + metadata := map[string]any{"key": "value"} + parts := []*Part{NewTextPart("content")} + msg := NewMessage(RoleTool, metadata, parts...) + + if msg.Role != RoleTool { + t.Errorf("Role = %q, want %q", msg.Role, RoleTool) + } + if diff := cmp.Diff(metadata, msg.Metadata); diff != "" { + t.Errorf("Metadata mismatch (-want +got):\n%s", diff) + } + if len(msg.Content) != 1 { + t.Errorf("len(Content) = %d, want 1", len(msg.Content)) + } + }) +} + +func TestNewTextMessage(t *testing.T) { + t.Run("creates text message with specified role", func(t *testing.T) { + msg := NewTextMessage(RoleTool, "tool output") + + if msg.Role != RoleTool { + t.Errorf("Role = %q, want %q", msg.Role, RoleTool) + } + if len(msg.Content) != 1 { + t.Fatalf("len(Content) = %d, want 1", len(msg.Content)) + } + if msg.Content[0].Text != "tool output" { + t.Errorf("Text = %q, want %q", msg.Content[0].Text, "tool output") + } + }) +} + +func TestWithCacheTTL(t *testing.T) { + t.Run("adds cache TTL to message without existing metadata", func(t *testing.T) { + original := NewUserTextMessage("hello") + result := original.WithCacheTTL(3600) + + // Original should be unchanged + if original.Metadata != nil { + t.Error("original message metadata should be 
nil") + } + + // Result should have cache metadata + if result.Metadata == nil { + t.Fatal("result metadata should not be nil") + } + cache, ok := result.Metadata["cache"].(map[string]any) + if !ok { + t.Fatal("cache metadata not found or wrong type") + } + if cache["ttlSeconds"] != 3600 { + t.Errorf("ttlSeconds = %v, want 3600", cache["ttlSeconds"]) + } + + // Content and role should be preserved + if result.Role != original.Role { + t.Errorf("Role changed: got %q, want %q", result.Role, original.Role) + } + if len(result.Content) != len(original.Content) { + t.Errorf("Content length changed") + } + }) + + t.Run("adds cache TTL to message with existing metadata", func(t *testing.T) { + original := NewUserMessageWithMetadata( + map[string]any{"existing": "value"}, + NewTextPart("hello"), + ) + result := original.WithCacheTTL(1800) + + // Result should have both existing and cache metadata + if result.Metadata["existing"] != "value" { + t.Error("existing metadata not preserved") + } + cache, ok := result.Metadata["cache"].(map[string]any) + if !ok { + t.Fatal("cache metadata not found") + } + if cache["ttlSeconds"] != 1800 { + t.Errorf("ttlSeconds = %v, want 1800", cache["ttlSeconds"]) + } + }) + + t.Run("chained with WithCacheName", func(t *testing.T) { + msg := NewUserTextMessage("hello"). + WithCacheTTL(3600). 
+ WithCacheName("my-cache") + + cache, ok := msg.Metadata["cache"].(map[string]any) + if !ok { + t.Fatal("cache metadata not found") + } + // Note: second call overwrites the cache object + if cache["name"] != "my-cache" { + t.Errorf("cache name = %v, want %q", cache["name"], "my-cache") + } + }) +} + +func TestWithCacheName(t *testing.T) { + t.Run("adds cache name to message without existing metadata", func(t *testing.T) { + original := NewUserTextMessage("hello") + result := original.WithCacheName("my-cache") + + // Original should be unchanged + if original.Metadata != nil { + t.Error("original message metadata should be nil") + } + + // Result should have cache metadata + if result.Metadata == nil { + t.Fatal("result metadata should not be nil") + } + cache, ok := result.Metadata["cache"].(map[string]any) + if !ok { + t.Fatal("cache metadata not found or wrong type") + } + if cache["name"] != "my-cache" { + t.Errorf("name = %v, want %q", cache["name"], "my-cache") + } + }) + + t.Run("adds cache name to message with existing metadata", func(t *testing.T) { + original := NewUserMessageWithMetadata( + map[string]any{"existing": "value"}, + NewTextPart("hello"), + ) + result := original.WithCacheName("another-cache") + + // Result should have both existing and cache metadata + if result.Metadata["existing"] != "value" { + t.Error("existing metadata not preserved") + } + cache, ok := result.Metadata["cache"].(map[string]any) + if !ok { + t.Fatal("cache metadata not found") + } + if cache["name"] != "another-cache" { + t.Errorf("name = %v, want %q", cache["name"], "another-cache") + } + }) + + t.Run("with empty name", func(t *testing.T) { + msg := NewUserTextMessage("hello").WithCacheName("") + + cache, ok := msg.Metadata["cache"].(map[string]any) + if !ok { + t.Fatal("cache metadata not found") + } + if cache["name"] != "" { + t.Errorf("name = %v, want empty string", cache["name"]) + } + }) +} diff --git a/go/ai/resource_test.go b/go/ai/resource_test.go index 
e7b56b65eb..30f15b807c 100644 --- a/go/ai/resource_test.go +++ b/go/ai/resource_test.go @@ -272,3 +272,85 @@ func TestMultipleDynamicResourcesInGeneration(t *testing.T) { func contains(s, substr string) bool { return strings.Contains(s, substr) } + +func TestLookupResource(t *testing.T) { + t.Run("finds registered resource", func(t *testing.T) { + r := registry.New() + DefineResource(r, "test/lookup", &ResourceOptions{ + URI: "lookup://test", + }, func(ctx context.Context, input *ResourceInput) (*ResourceOutput, error) { + return &ResourceOutput{ + Content: []*Part{NewTextPart("found")}, + }, nil + }) + + found := LookupResource(r, "test/lookup") + if found == nil { + t.Fatal("LookupResource returned nil") + } + if found.Name() != "test/lookup" { + t.Errorf("Name() = %q, want %q", found.Name(), "test/lookup") + } + }) + + t.Run("returns nil for non-existent resource", func(t *testing.T) { + r := registry.New() + + found := LookupResource(r, "test/nonexistent") + if found != nil { + t.Errorf("LookupResource returned %v, want nil", found) + } + }) + + t.Run("resource can be executed after lookup", func(t *testing.T) { + r := registry.New() + DefineResource(r, "test/executable", &ResourceOptions{ + URI: "exec://test", + }, func(ctx context.Context, input *ResourceInput) (*ResourceOutput, error) { + return &ResourceOutput{ + Content: []*Part{NewTextPart("executed: " + input.URI)}, + }, nil + }) + + found := LookupResource(r, "test/executable") + if found == nil { + t.Fatal("LookupResource returned nil") + } + + output, err := found.Execute(context.Background(), &ResourceInput{URI: "exec://test", Variables: map[string]string{}}) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + if len(output.Content) != 1 || output.Content[0].Text != "executed: exec://test" { + t.Errorf("unexpected output: %v", output.Content) + } + }) + + t.Run("resource matches and extracts variables after lookup", func(t *testing.T) { + r := registry.New() + DefineResource(r, 
"test/template", &ResourceOptions{ + Template: "template://item/{id}", + }, func(ctx context.Context, input *ResourceInput) (*ResourceOutput, error) { + return &ResourceOutput{ + Content: []*Part{NewTextPart("item " + input.Variables["id"])}, + }, nil + }) + + found := LookupResource(r, "test/template") + if found == nil { + t.Fatal("LookupResource returned nil") + } + + if !found.Matches("template://item/123") { + t.Error("Matches() = false, want true") + } + + vars, err := found.ExtractVariables("template://item/456") + if err != nil { + t.Fatalf("ExtractVariables error: %v", err) + } + if vars["id"] != "456" { + t.Errorf("vars[id] = %q, want %q", vars["id"], "456") + } + }) +} diff --git a/go/ai/retriever_test.go b/go/ai/retriever_test.go new file mode 100644 index 0000000000..ecbf019f71 --- /dev/null +++ b/go/ai/retriever_test.go @@ -0,0 +1,407 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestRetrieverRef(t *testing.T) { + t.Run("NewRetrieverRef creates ref with name and config", func(t *testing.T) { + config := map[string]any{"topK": 10} + ref := NewRetrieverRef("test/retriever", config) + + if ref.Name() != "test/retriever" { + t.Errorf("Name() = %q, want %q", ref.Name(), "test/retriever") + } + if diff := cmp.Diff(config, ref.Config()); diff != "" { + t.Errorf("Config() mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("NewRetrieverRef with nil config", func(t *testing.T) { + ref := NewRetrieverRef("test/retriever", nil) + + if ref.Name() != "test/retriever" { + t.Errorf("Name() = %q, want %q", ref.Name(), "test/retriever") + } + if ref.Config() != nil { + t.Errorf("Config() = %v, want nil", ref.Config()) + } + }) +} + +func TestNewRetriever(t *testing.T) { + t.Run("creates retriever with valid name", func(t *testing.T) { + r := NewRetriever("test/retriever", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { + return &RetrieverResponse{}, nil + }) + + if r == nil { + t.Fatal("expected retriever, got nil") + } + if r.Name() != "test/retriever" { + t.Errorf("Name() = %q, want %q", r.Name(), "test/retriever") + } + }) + + t.Run("panics with empty name", func(t *testing.T) { + assertPanic(t, func() { + NewRetriever("", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { + return &RetrieverResponse{}, nil + }) + }, "name is required") + }) + + t.Run("applies options correctly", func(t *testing.T) { + opts := &RetrieverOptions{ + Label: "Test Retriever", + Supports: &RetrieverSupports{ + Media: true, + }, + ConfigSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "topK": map[string]any{"type": "integer"}, + }, + }, + } + + r := NewRetriever("test/retriever", opts, func(ctx context.Context, req 
*RetrieverRequest) (*RetrieverResponse, error) { + return &RetrieverResponse{}, nil + }) + + if r == nil { + t.Fatal("expected retriever, got nil") + } + }) + + t.Run("uses defaults when options nil", func(t *testing.T) { + r := NewRetriever("test/retriever", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { + return &RetrieverResponse{}, nil + }) + + if r == nil { + t.Fatal("expected retriever, got nil") + } + }) +} + +func TestDefineRetriever(t *testing.T) { + t.Run("registers and returns retriever", func(t *testing.T) { + reg := newTestRegistry(t) + called := false + expectedDocs := []*Document{ + DocumentFromText("result 1", nil), + DocumentFromText("result 2", nil), + } + + r := DefineRetriever(reg, "test/defineRetriever", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { + called = true + return &RetrieverResponse{Documents: expectedDocs}, nil + }) + + if r == nil { + t.Fatal("expected retriever, got nil") + } + + // Verify it's registered by looking it up + found := LookupRetriever(reg, "test/defineRetriever") + if found == nil { + t.Fatal("LookupRetriever returned nil for registered retriever") + } + + // Verify the function works + resp, err := r.Retrieve(context.Background(), &RetrieverRequest{ + Query: DocumentFromText("query", nil), + }) + assertNoError(t, err) + if !called { + t.Error("retriever function was not called") + } + if len(resp.Documents) != 2 { + t.Errorf("len(Documents) = %d, want 2", len(resp.Documents)) + } + }) +} + +func TestLookupRetriever(t *testing.T) { + t.Run("returns retriever when found", func(t *testing.T) { + reg := newTestRegistry(t) + DefineRetriever(reg, "test/lookupRetriever", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { + return &RetrieverResponse{}, nil + }) + + r := LookupRetriever(reg, "test/lookupRetriever") + if r == nil { + t.Error("expected retriever, got nil") + } + }) + + t.Run("returns nil when not found", 
func(t *testing.T) {
		reg := newTestRegistry(t)

		// Looking up a name that was never registered must yield nil, not an error.
		r := LookupRetriever(reg, "nonexistent")
		if r != nil {
			t.Error("expected nil for non-existent retriever")
		}
	})
}

// TestRetrieverRetrieve exercises the Retriever.Retrieve method directly:
// successful retrieval, nil-receiver handling, error propagation from the
// retrieval function, and passthrough of per-request options.
func TestRetrieverRetrieve(t *testing.T) {
	t.Run("retrieves documents successfully", func(t *testing.T) {
		reg := newTestRegistry(t)
		var capturedReq *RetrieverRequest

		expectedDocs := []*Document{
			DocumentFromText("relevant result 1", map[string]any{"score": 0.9}),
			DocumentFromText("relevant result 2", map[string]any{"score": 0.8}),
		}

		r := DefineRetriever(reg, "test/retrieveDocs", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			capturedReq = req
			return &RetrieverResponse{Documents: expectedDocs}, nil
		})

		query := DocumentFromText("search query", nil)
		resp, err := r.Retrieve(context.Background(), &RetrieverRequest{Query: query})
		assertNoError(t, err)

		// The retrieval function must see the exact query document we passed in.
		if len(capturedReq.Query.Content) == 0 || capturedReq.Query.Content[0].Text != "search query" {
			t.Errorf("captured query content mismatch")
		}
		if len(resp.Documents) != 2 {
			t.Errorf("len(Documents) = %d, want 2", len(resp.Documents))
		}
	})

	t.Run("returns error on nil retriever", func(t *testing.T) {
		// Calling Retrieve on a nil *retriever should fail, not panic.
		var r *retriever
		_, err := r.Retrieve(context.Background(), &RetrieverRequest{})
		if err == nil {
			t.Error("expected error for nil retriever")
		}
	})

	t.Run("propagates function errors", func(t *testing.T) {
		reg := newTestRegistry(t)
		expectedErr := errors.New("retrieval failed")

		r := DefineRetriever(reg, "test/retrieveError", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			return nil, expectedErr
		})

		_, err := r.Retrieve(context.Background(), &RetrieverRequest{
			Query: DocumentFromText("query", nil),
		})
		if err == nil {
			t.Error("expected error, got nil")
		}
	})

	t.Run("passes options through request", func(t *testing.T) {
		reg := newTestRegistry(t)
		var capturedOpts any

		r := DefineRetriever(reg, "test/retrieveOpts", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			capturedOpts = req.Options
			return &RetrieverResponse{Documents: []*Document{}}, nil
		})

		opts := map[string]any{"topK": 5, "threshold": 0.7}
		_, err := r.Retrieve(context.Background(), &RetrieverRequest{
			Query:   DocumentFromText("query", nil),
			Options: opts,
		})
		assertNoError(t, err)

		if diff := cmp.Diff(opts, capturedOpts); diff != "" {
			t.Errorf("Options mismatch (-want +got):\n%s", diff)
		}
	})
}

// TestRetrieveFunction exercises the package-level Retrieve helper: retriever
// resolution by value, ref, and name; config precedence; and input validation.
func TestRetrieveFunction(t *testing.T) {
	t.Run("retrieves with retriever directly", func(t *testing.T) {
		reg := newTestRegistry(t)
		r := DefineRetriever(reg, "test/retrieveFunc", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			return &RetrieverResponse{
				Documents: []*Document{DocumentFromText("result", nil)},
			}, nil
		})

		resp, err := Retrieve(context.Background(), reg,
			WithRetriever(r),
			WithTextDocs("query"),
		)
		assertNoError(t, err)

		if len(resp.Documents) != 1 {
			t.Errorf("len(Documents) = %d, want 1", len(resp.Documents))
		}
	})

	t.Run("retrieves with retriever ref", func(t *testing.T) {
		reg := newTestRegistry(t)
		DefineRetriever(reg, "test/retrieveFuncRef", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			return &RetrieverResponse{
				Documents: []*Document{DocumentFromText("result", nil)},
			}, nil
		})

		ref := NewRetrieverRef("test/retrieveFuncRef", nil)
		resp, err := Retrieve(context.Background(), reg,
			WithRetriever(ref),
			WithTextDocs("query"),
		)
		assertNoError(t, err)

		if len(resp.Documents) != 1 {
			t.Errorf("len(Documents) = %d, want 1", len(resp.Documents))
		}
	})

	t.Run("retrieves with retriever name", func(t *testing.T) {
		reg := newTestRegistry(t)
		DefineRetriever(reg, "test/retrieveFuncName", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			return &RetrieverResponse{
				Documents: []*Document{DocumentFromText("result", nil)},
			}, nil
		})

		resp, err := Retrieve(context.Background(), reg,
			WithRetrieverName("test/retrieveFuncName"),
			WithTextDocs("query"),
		)
		assertNoError(t, err)

		if len(resp.Documents) != 1 {
			t.Errorf("len(Documents) = %d, want 1", len(resp.Documents))
		}
	})

	t.Run("uses config from RetrieverRef", func(t *testing.T) {
		reg := newTestRegistry(t)
		var capturedOpts any

		DefineRetriever(reg, "test/retrieveRefConfig", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			capturedOpts = req.Options
			return &RetrieverResponse{Documents: []*Document{}}, nil
		})

		config := map[string]any{"topK": 10}
		ref := NewRetrieverRef("test/retrieveRefConfig", config)

		_, err := Retrieve(context.Background(), reg,
			WithRetriever(ref),
			WithTextDocs("query"),
		)
		assertNoError(t, err)

		if diff := cmp.Diff(config, capturedOpts); diff != "" {
			t.Errorf("Options mismatch (-want +got):\n%s", diff)
		}
	})

	t.Run("explicit config overrides RetrieverRef config", func(t *testing.T) {
		reg := newTestRegistry(t)
		var capturedOpts any

		DefineRetriever(reg, "test/retrieveOverrideConfig", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			capturedOpts = req.Options
			return &RetrieverResponse{Documents: []*Document{}}, nil
		})

		refConfig := map[string]any{"topK": 10}
		explicitConfig := map[string]any{"topK": 5}
		ref := NewRetrieverRef("test/retrieveOverrideConfig", refConfig)

		// WithConfig must win over the config baked into the ref.
		_, err := Retrieve(context.Background(), reg,
			WithRetriever(ref),
			WithConfig(explicitConfig),
			WithTextDocs("query"),
		)
		assertNoError(t, err)

		if diff := cmp.Diff(explicitConfig, capturedOpts); diff != "" {
			t.Errorf("Options mismatch (-want +got):\n%s", diff)
		}
	})

	t.Run("returns error when retriever not set", func(t *testing.T) {
		reg := newTestRegistry(t)

		_, err := Retrieve(context.Background(), reg,
			WithTextDocs("query"),
		)
		assertError(t, err, "retriever must be set")
	})

	t.Run("returns error when retriever not found", func(t *testing.T) {
		reg := newTestRegistry(t)

		_, err := Retrieve(context.Background(), reg,
			WithRetrieverName("nonexistent"),
			WithTextDocs("query"),
		)
		assertError(t, err, "retriever not found")
	})

	t.Run("returns error with multiple documents", func(t *testing.T) {
		reg := newTestRegistry(t)
		DefineRetriever(reg, "test/retrieveMultiDoc", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			return &RetrieverResponse{Documents: []*Document{}}, nil
		})

		// Retrieve accepts exactly one query document; two must be rejected.
		_, err := Retrieve(context.Background(), reg,
			WithRetrieverName("test/retrieveMultiDoc"),
			WithDocs(
				DocumentFromText("doc1", nil),
				DocumentFromText("doc2", nil),
			),
		)
		assertError(t, err, "only supports a single document")
	})

	t.Run("retrieves with document options", func(t *testing.T) {
		reg := newTestRegistry(t)
		var capturedQuery *Document

		DefineRetriever(reg, "test/retrieveDocOpts", nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) {
			capturedQuery = req.Query
			return &RetrieverResponse{Documents: []*Document{}}, nil
		})

		query := DocumentFromText("custom query", map[string]any{"custom": "metadata"})
		_, err := Retrieve(context.Background(), reg,
			WithRetrieverName("test/retrieveDocOpts"),
			WithDocs(query),
		)
		assertNoError(t, err)

		if len(capturedQuery.Content) == 0 || capturedQuery.Content[0].Text != "custom query" {
			t.Errorf("query content mismatch")
		}
		if capturedQuery.Metadata["custom"] != "metadata" {
			t.Error("query metadata not passed correctly")
		}
	})
}
diff --git a/go/ai/testutil_test.go b/go/ai/testutil_test.go
new file mode 100644
index 0000000000..6c606a28a8
--- /dev/null
+++ b/go/ai/testutil_test.go
@@ -0,0 +1,306 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version
2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/internal/registry" + "github.com/google/go-cmp/cmp" +) + +// newTestRegistry creates a fresh registry for testing with formats configured. +func newTestRegistry(t *testing.T) api.Registry { + t.Helper() + r := registry.New() + ConfigureFormats(r) + return r +} + +// fakeModelConfig holds configuration for creating a fake model. +type fakeModelConfig struct { + name string + supports *ModelSupports + handler func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) +} + +// defaultModelSupports returns a ModelSupports with common capabilities enabled. +func defaultModelSupports() *ModelSupports { + return &ModelSupports{ + Tools: true, + Multiturn: true, + ToolChoice: true, + SystemRole: true, + Constrained: ConstrainedSupportAll, + } +} + +// defineFakeModel creates a configurable fake model for testing. 
+func defineFakeModel(t *testing.T, r api.Registry, cfg fakeModelConfig) Model { + t.Helper() + if cfg.name == "" { + cfg.name = "test/fakeModel" + } + if cfg.supports == nil { + cfg.supports = defaultModelSupports() + } + if cfg.handler == nil { + cfg.handler = func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("fake response"), + }, nil + } + } + return DefineModel(r, cfg.name, &ModelOptions{Supports: cfg.supports}, cfg.handler) +} + +// echoModelHandler creates a handler that echoes back information about the request. +// Useful for verifying that options are properly passed through. +func echoModelHandler() func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + var parts []string + + // Echo messages + for _, msg := range req.Messages { + parts = append(parts, fmt.Sprintf("%s: %s", msg.Role, msg.Text())) + } + + // Echo config if present + if req.Config != nil { + if cfg, ok := req.Config.(*GenerationCommonConfig); ok { + parts = append(parts, fmt.Sprintf("temp=%.1f", cfg.Temperature)) + } + } + + // Echo tool count + if len(req.Tools) > 0 { + parts = append(parts, fmt.Sprintf("tools=%d", len(req.Tools))) + } + + // Echo tool choice + if req.ToolChoice != "" { + parts = append(parts, fmt.Sprintf("toolChoice=%s", req.ToolChoice)) + } + + // Echo docs count + if len(req.Docs) > 0 { + parts = append(parts, fmt.Sprintf("docs=%d", len(req.Docs))) + } + + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage(strings.Join(parts, "; ")), + }, nil + } +} + +// capturingModelHandler returns a handler that captures the request for inspection. 
+func capturingModelHandler(captured **ModelRequest) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + *captured = req + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("captured"), + }, nil + } +} + +// streamingModelHandler creates a handler that sends chunks before returning. +func streamingModelHandler(chunks []string, finalText string) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + if cb != nil { + for _, chunk := range chunks { + if err := cb(ctx, &ModelResponseChunk{ + Content: []*Part{NewTextPart(chunk)}, + }); err != nil { + return nil, err + } + } + } + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage(finalText), + }, nil + } +} + +// jsonModelHandler creates a handler that returns JSON output. +func jsonModelHandler(jsonOutput string) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewJSONPart(jsonOutput)}, + }, + }, nil + } +} + +// toolCallingModelHandler creates a handler that makes a tool call on first request, +// then returns the final response after receiving the tool response. 
+func toolCallingModelHandler(toolName string, toolInput map[string]any, finalResponse string) func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + callCount := 0 + return func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + callCount++ + + // Check if we already have a tool response + hasToolResponse := false + for _, msg := range req.Messages { + for _, part := range msg.Content { + if part.IsToolResponse() { + hasToolResponse = true + break + } + } + } + + if !hasToolResponse && len(req.Tools) > 0 { + // First call - request tool execution + return &ModelResponse{ + Request: req, + Message: &Message{ + Role: RoleModel, + Content: []*Part{NewToolRequestPart(&ToolRequest{ + Name: toolName, + Input: toolInput, + })}, + }, + }, nil + } + + // Tool response received or no tools - return final response + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage(finalResponse), + }, nil + } +} + +// cmpPartEqual is a Part comparator for cmp.Diff that compares essential fields. +func cmpPartEqual(a, b *Part) bool { + if a == nil || b == nil { + return a == b + } + if a.Kind != b.Kind { + return false + } + if a.Text != b.Text { + return false + } + if a.ContentType != b.ContentType { + return false + } + return true +} + +// cmpPartComparer returns a cmp.Option for comparing Parts. +func cmpPartComparer() cmp.Option { + return cmp.Comparer(cmpPartEqual) +} + +// assertEqual compares two values and reports differences. +func assertEqual[T any](t *testing.T, got, want T, opts ...cmp.Option) { + t.Helper() + if diff := cmp.Diff(want, got, opts...); diff != "" { + t.Errorf("mismatch (-want +got):\n%s", diff) + } +} + +// assertError verifies error is non-nil and contains expected substring. 
+func assertError(t *testing.T, err error, wantContains string) { + t.Helper() + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), wantContains) { + t.Errorf("error %q does not contain %q", err.Error(), wantContains) + } +} + +// assertNoError fails the test if err is not nil. +func assertNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } +} + +// assertPanic verifies that fn panics and the panic value contains wantContains. +func assertPanic(t *testing.T, fn func(), wantContains string) { + t.Helper() + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic, got none") + } + msg := fmt.Sprint(r) + if !strings.Contains(msg, wantContains) { + t.Errorf("panic %q does not contain %q", msg, wantContains) + } + }() + fn() +} + +// assertNoPanic verifies that fn does not panic. +func assertNoPanic(t *testing.T, fn func()) { + t.Helper() + defer func() { + if r := recover(); r != nil { + t.Fatalf("unexpected panic: %v", r) + } + }() + fn() +} + +// defineFakeTool creates a simple tool for testing. +func defineFakeTool(t *testing.T, r api.Registry, name, description string) Tool { + t.Helper() + return DefineTool(r, name, description, + func(ctx *ToolContext, input struct { + Value string `json:"value"` + }) (string, error) { + return "tool result: " + input.Value, nil + }) +} + +// defineFakeEmbedder creates a simple embedder for testing. +func defineFakeEmbedder(t *testing.T, r api.Registry, name string) Embedder { + t.Helper() + return DefineEmbedder(r, name, nil, func(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) { + embeddings := make([]*Embedding, len(req.Input)) + for i := range req.Input { + embeddings[i] = &Embedding{ + Embedding: []float32{0.1, 0.2, 0.3}, + } + } + return &EmbedResponse{Embeddings: embeddings}, nil + }) +} + +// defineFakeRetriever creates a simple retriever for testing. 
+func defineFakeRetriever(t *testing.T, r api.Registry, name string, docs []*Document) Retriever { + t.Helper() + return DefineRetriever(r, name, nil, func(ctx context.Context, req *RetrieverRequest) (*RetrieverResponse, error) { + return &RetrieverResponse{Documents: docs}, nil + }) +} diff --git a/go/ai/tools_test.go b/go/ai/tools_test.go new file mode 100644 index 0000000000..857be8c2b2 --- /dev/null +++ b/go/ai/tools_test.go @@ -0,0 +1,908 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package ai + +import ( + "context" + "errors" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestToolName(t *testing.T) { + t.Run("Name returns string value", func(t *testing.T) { + tn := ToolName("myTool") + got := tn.Name() + want := "myTool" + if got != want { + t.Errorf("Name() = %q, want %q", got, want) + } + }) + + t.Run("empty tool name", func(t *testing.T) { + tn := ToolName("") + got := tn.Name() + if got != "" { + t.Errorf("Name() = %q, want empty string", got) + } + }) +} + +func TestToolInterruptError(t *testing.T) { + t.Run("Error returns fixed message", func(t *testing.T) { + err := &toolInterruptError{Metadata: map[string]any{"key": "value"}} + got := err.Error() + want := "tool execution interrupted" + if got != want { + t.Errorf("Error() = %q, want %q", got, want) + } + }) +} + +func TestIsToolInterruptError(t *testing.T) { + t.Run("returns true for toolInterruptError", func(t *testing.T) { + meta := map[string]any{"reason": "user cancelled"} + err := &toolInterruptError{Metadata: meta} + + isInterrupt, gotMeta := IsToolInterruptError(err) + + if !isInterrupt { + t.Error("IsToolInterruptError() = false, want true") + } + if diff := cmp.Diff(meta, gotMeta); diff != "" { + t.Errorf("metadata mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("returns true for wrapped toolInterruptError", func(t *testing.T) { + meta := map[string]any{"step": 3} + innerErr := &toolInterruptError{Metadata: meta} + wrappedErr := errors.New("context: " + innerErr.Error()) + // Use proper wrapping + wrappedErr = &wrappedInterruptError{cause: innerErr} + + isInterrupt, gotMeta := IsToolInterruptError(wrappedErr) + + if !isInterrupt { + t.Error("IsToolInterruptError(wrapped) = false, want true") + } + if gotMeta["step"] != 3 { + t.Errorf("metadata[step] = %v, want 3", gotMeta["step"]) + } + }) + + t.Run("returns false for regular error", func(t *testing.T) { + err := errors.New("some error") + + isInterrupt, meta 
:= IsToolInterruptError(err) + + if isInterrupt { + t.Error("IsToolInterruptError(regular error) = true, want false") + } + if meta != nil { + t.Errorf("metadata = %v, want nil", meta) + } + }) + + t.Run("returns false for nil error", func(t *testing.T) { + isInterrupt, meta := IsToolInterruptError(nil) + + if isInterrupt { + t.Error("IsToolInterruptError(nil) = true, want false") + } + if meta != nil { + t.Errorf("metadata = %v, want nil", meta) + } + }) +} + +// wrappedInterruptError is a helper for testing error unwrapping. +type wrappedInterruptError struct { + cause error +} + +func (e *wrappedInterruptError) Error() string { + return "wrapped: " + e.cause.Error() +} + +func (e *wrappedInterruptError) Unwrap() error { + return e.cause +} + +func TestDefineTool(t *testing.T) { + t.Run("creates and registers tool", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineTool(r, "provider/addNumbers", "Adds two numbers", func(ctx *ToolContext, input struct { + A int `json:"a"` + B int `json:"b"` + }) (int, error) { + return input.A + input.B, nil + }) + + if tl == nil { + t.Fatal("DefineTool returned nil") + } + if tl.Name() != "provider/addNumbers" { + t.Errorf("Name() = %q, want %q", tl.Name(), "provider/addNumbers") + } + + def := tl.Definition() + if def.Description != "Adds two numbers" { + t.Errorf("Description = %q, want %q", def.Description, "Adds two numbers") + } + }) + + t.Run("tool can be looked up after registration", func(t *testing.T) { + r := newTestRegistry(t) + DefineTool(r, "provider/multiply", "Multiplies", func(ctx *ToolContext, input struct { + X int `json:"x"` + Y int `json:"y"` + }) (int, error) { + return input.X * input.Y, nil + }) + + found := LookupTool(r, "provider/multiply") + if found == nil { + t.Error("LookupTool returned nil for registered tool") + } + }) + + t.Run("tool executes correctly", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineTool(r, "provider/concat", "Concatenates strings", func(ctx *ToolContext, 
input struct { + A string `json:"a"` + B string `json:"b"` + }) (string, error) { + return input.A + input.B, nil + }) + + output, err := tl.RunRaw(context.Background(), map[string]any{ + "a": "hello", + "b": "world", + }) + + if err != nil { + t.Fatalf("RunRaw error: %v", err) + } + if output != "helloworld" { + t.Errorf("output = %v, want %q", output, "helloworld") + } + }) +} + +func TestDefineToolWithInputSchema(t *testing.T) { + t.Run("creates tool with custom input schema", func(t *testing.T) { + r := newTestRegistry(t) + customSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + }, + "required": []any{"query"}, + } + + tl := DefineToolWithInputSchema(r, "provider/search", "Searches", customSchema, + func(ctx *ToolContext, input any) (string, error) { + m := input.(map[string]any) + return "results for: " + m["query"].(string), nil + }) + + if tl == nil { + t.Fatal("DefineToolWithInputSchema returned nil") + } + + def := tl.Definition() + if def.InputSchema == nil { + t.Error("InputSchema is nil") + } + }) +} + +func TestNewTool(t *testing.T) { + t.Run("creates unregistered tool", func(t *testing.T) { + tl := NewTool("dynamicTool", "A dynamic tool", func(ctx *ToolContext, input struct { + Value int `json:"value"` + }) (int, error) { + return input.Value * 2, nil + }) + + if tl == nil { + t.Fatal("NewTool returned nil") + } + if tl.Name() != "dynamicTool" { + t.Errorf("Name() = %q, want %q", tl.Name(), "dynamicTool") + } + }) + + t.Run("unregistered tool can be executed", func(t *testing.T) { + tl := NewTool("double", "Doubles a number", func(ctx *ToolContext, input struct { + N int `json:"n"` + }) (int, error) { + return input.N * 2, nil + }) + + output, err := tl.RunRaw(context.Background(), map[string]any{"n": 5}) + if err != nil { + t.Fatalf("RunRaw error: %v", err) + } + // JSON unmarshalling returns float64 for numbers + if output != float64(10) { + t.Errorf("output = %v (%T), want 
10", output, output) + } + }) + + t.Run("tool can be registered later", func(t *testing.T) { + r := newTestRegistry(t) + tl := NewTool("provider/laterTool", "Registered later", func(ctx *ToolContext, input struct{}) (string, error) { + return "done", nil + }) + + tl.Register(r) + + found := LookupTool(r, "provider/laterTool") + if found == nil { + t.Error("LookupTool returned nil after registration") + } + }) +} + +func TestNewToolWithInputSchema(t *testing.T) { + t.Run("creates tool with custom schema", func(t *testing.T) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "data": map[string]any{"type": "array"}, + }, + } + + tl := NewToolWithInputSchema("process", "Processes data", schema, + func(ctx *ToolContext, input any) (bool, error) { + return true, nil + }) + + if tl == nil { + t.Fatal("NewToolWithInputSchema returned nil") + } + + def := tl.Definition() + if def.InputSchema == nil { + t.Error("InputSchema is nil") + } + }) +} + +func TestDefineMultipartTool(t *testing.T) { + t.Run("creates multipart tool", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineMultipartTool(r, "provider/imageGen", "Generates images", + func(ctx *ToolContext, input struct { + Prompt string `json:"prompt"` + }) (*MultipartToolResponse, error) { + return &MultipartToolResponse{ + Output: "generated", + Content: []*Part{ + NewMediaPart("image/png", "data:image/png;base64,abc"), + }, + }, nil + }) + + if tl == nil { + t.Fatal("DefineMultipartTool returned nil") + } + + // Check that it's a multipart tool via metadata + def := tl.Definition() + if def.Metadata == nil { + t.Fatal("Metadata is nil") + } + if def.Metadata["multipart"] != true { + t.Error("multipart metadata = false, want true") + } + }) + + t.Run("multipart tool returns parts", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineMultipartTool(r, "provider/multiOut", "Returns multiple parts", + func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { + 
return &MultipartToolResponse{ + Output: map[string]any{"status": "ok"}, + Content: []*Part{ + NewTextPart("additional text"), + NewMediaPart("image/jpeg", "data:image/jpeg;base64,xyz"), + }, + }, nil + }) + + resp, err := tl.RunRawMultipart(context.Background(), map[string]any{}) + if err != nil { + t.Fatalf("RunRawMultipart error: %v", err) + } + + if len(resp.Content) != 2 { + t.Errorf("len(Content) = %d, want 2", len(resp.Content)) + } + }) +} + +func TestNewMultipartTool(t *testing.T) { + t.Run("creates unregistered multipart tool", func(t *testing.T) { + tl := NewMultipartTool("dynamicMulti", "Dynamic multipart", + func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { + return &MultipartToolResponse{Output: "test"}, nil + }) + + if tl == nil { + t.Fatal("NewMultipartTool returned nil") + } + // Check via definition metadata + def := tl.Definition() + if def.Metadata["multipart"] != true { + t.Error("multipart metadata = false, want true") + } + }) + + t.Run("can be registered later", func(t *testing.T) { + r := newTestRegistry(t) + tl := NewMultipartTool("provider/laterMulti", "Later registration", + func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { + return &MultipartToolResponse{Output: "ok"}, nil + }) + + tl.Register(r) + + found := LookupTool(r, "provider/laterMulti") + if found == nil { + t.Error("LookupTool returned nil after registration") + } + }) +} + +func TestToolDefinition(t *testing.T) { + t.Run("includes all fields", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineTool(r, "provider/complete", "A complete tool", func(ctx *ToolContext, input struct { + Query string `json:"query"` + }) (struct { + Result string `json:"result"` + }, error) { + return struct { + Result string `json:"result"` + }{Result: input.Query}, nil + }) + + def := tl.Definition() + + if def.Name != "provider/complete" { + t.Errorf("Name = %q, want %q", def.Name, "provider/complete") + } + if def.Description != "A complete 
tool" { + t.Errorf("Description = %q, want %q", def.Description, "A complete tool") + } + if def.InputSchema == nil { + t.Error("InputSchema is nil") + } + if def.OutputSchema == nil { + t.Error("OutputSchema is nil") + } + }) +} + +func TestLookupTool(t *testing.T) { + t.Run("returns nil for empty name", func(t *testing.T) { + r := newTestRegistry(t) + got := LookupTool(r, "") + if got != nil { + t.Errorf("LookupTool(\"\") = %v, want nil", got) + } + }) + + t.Run("returns nil for non-existent tool", func(t *testing.T) { + r := newTestRegistry(t) + got := LookupTool(r, "nonexistent/tool") + if got != nil { + t.Errorf("LookupTool(nonexistent) = %v, want nil", got) + } + }) + + t.Run("finds registered tool", func(t *testing.T) { + r := newTestRegistry(t) + DefineTool(r, "test/findMe", "Find me", func(ctx *ToolContext, input struct{}) (bool, error) { + return true, nil + }) + + got := LookupTool(r, "test/findMe") + if got == nil { + t.Error("LookupTool returned nil for registered tool") + } + }) +} + +func TestToolIsMultipart(t *testing.T) { + t.Run("regular tool is not multipart", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineTool(r, "provider/regular", "Regular tool", func(ctx *ToolContext, input struct{}) (string, error) { + return "ok", nil + }) + + def := tl.Definition() + if def.Metadata["multipart"] == true { + t.Error("multipart metadata = true for regular tool, want false") + } + }) + + t.Run("multipart tool is multipart", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineMultipartTool(r, "provider/multi", "Multi tool", + func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { + return &MultipartToolResponse{}, nil + }) + + def := tl.Definition() + if def.Metadata["multipart"] != true { + t.Error("multipart metadata = false for multipart tool, want true") + } + }) +} + +func TestToolRunRaw(t *testing.T) { + t.Run("returns output from regular tool", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineTool(r, 
			"provider/sum", "Sums numbers", func(ctx *ToolContext, input struct {
			Nums []int `json:"nums"`
		}) (int, error) {
			sum := 0
			for _, n := range input.Nums {
				sum += n
			}
			return sum, nil
		})

		output, err := tl.RunRaw(context.Background(), map[string]any{
			"nums": []any{1, 2, 3, 4, 5},
		})

		if err != nil {
			t.Fatalf("RunRaw error: %v", err)
		}
		// JSON unmarshalling returns float64 for numbers
		if output != float64(15) {
			t.Errorf("output = %v (%T), want 15", output, output)
		}
	})

	t.Run("returns error from tool", func(t *testing.T) {
		r := newTestRegistry(t)
		tl := DefineTool(r, "provider/fail", "Always fails", func(ctx *ToolContext, input struct{}) (string, error) {
			return "", errors.New("intentional failure")
		})

		_, err := tl.RunRaw(context.Background(), map[string]any{})
		if err == nil {
			t.Error("expected error, got nil")
		}
	})
}

// TestToolRunRawMultipart verifies raw execution of multipart tools.
func TestToolRunRawMultipart(t *testing.T) {
	t.Run("returns full response from multipart tool", func(t *testing.T) {
		r := newTestRegistry(t)
		tl := DefineMultipartTool(r, "provider/fullResp", "Full response",
			func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) {
				return &MultipartToolResponse{
					Output: "main output",
					Content: []*Part{
						NewTextPart("extra"),
					},
				}, nil
			})

		resp, err := tl.RunRawMultipart(context.Background(), map[string]any{})
		if err != nil {
			t.Fatalf("RunRawMultipart error: %v", err)
		}

		if resp.Output != "main output" {
			t.Errorf("Output = %v, want %q", resp.Output, "main output")
		}
		if len(resp.Content) != 1 {
			t.Errorf("len(Content) = %d, want 1", len(resp.Content))
		}
	})
}

// TestToolRespond verifies building a tool-response Part from an interrupted
// tool-request Part.
func TestToolRespond(t *testing.T) {
	r := newTestRegistry(t)
	tl := DefineTool(r, "provider/responder", "Test responder", func(ctx *ToolContext, input struct{}) (string, error) {
		return "ok", nil
	})

	t.Run("creates response for tool request", func(t *testing.T) {
		reqPart := NewToolRequestPart(&ToolRequest{
			Name:  "provider/responder",
			Ref:   "ref-123",
			Input: map[string]any{"x": 1},
		})
		reqPart.Metadata = map[string]any{"interrupt": true}

		resp := tl.Respond(reqPart, "output data", nil)

		if resp == nil {
			t.Fatal("Respond returned nil")
		}
		if !resp.IsToolResponse() {
			t.Error("response is not a tool response")
		}
		if resp.ToolResponse.Name != "provider/responder" {
			t.Errorf("Name = %q, want %q", resp.ToolResponse.Name, "provider/responder")
		}
		// The Ref ties the response back to the originating request.
		if resp.ToolResponse.Ref != "ref-123" {
			t.Errorf("Ref = %q, want %q", resp.ToolResponse.Ref, "ref-123")
		}
	})

	t.Run("returns nil for non-tool-request part", func(t *testing.T) {
		textPart := NewTextPart("not a tool request")

		resp := tl.Respond(textPart, "output", nil)

		if resp != nil {
			t.Errorf("Respond(textPart) = %v, want nil", resp)
		}
	})

	t.Run("returns nil for nil part", func(t *testing.T) {
		resp := tl.Respond(nil, "output", nil)

		if resp != nil {
			t.Errorf("Respond(nil) = %v, want nil", resp)
		}
	})

	t.Run("includes response options metadata", func(t *testing.T) {
		reqPart := NewToolRequestPart(&ToolRequest{
			Name: "provider/responder",
		})
		reqPart.Metadata = map[string]any{"interrupt": true}

		opts := &RespondOptions{
			Metadata: map[string]any{"custom": "value"},
		}
		resp := tl.Respond(reqPart, "output", opts)

		if resp.Metadata == nil {
			t.Fatal("Metadata is nil")
		}
		if resp.Metadata["interruptResponse"] == nil {
			t.Error("interruptResponse not set in metadata")
		}
	})
}

// TestToolRestart verifies re-issuing an interrupted tool request, including
// input replacement and resumed-metadata propagation.
func TestToolRestart(t *testing.T) {
	r := newTestRegistry(t)
	tl := DefineTool(r, "provider/restarter", "Test restarter", func(ctx *ToolContext, input struct {
		Value int `json:"value"`
	}) (int, error) {
		return input.Value, nil
	})

	t.Run("creates restart for tool request", func(t *testing.T) {
		reqPart := NewToolRequestPart(&ToolRequest{
			Name:  "provider/restarter",
			Ref:   "ref-456",
			Input: map[string]any{"value": 10},
		})
		reqPart.Metadata = map[string]any{"interrupt": true}

		restart := tl.Restart(reqPart, nil)

		if restart == nil {
			t.Fatal("Restart returned nil")
		}
		if !restart.IsToolRequest() {
			t.Error("restart is not a tool request")
		}
		if restart.ToolRequest.Name != "provider/restarter" {
			t.Errorf("Name = %q, want %q", restart.ToolRequest.Name, "provider/restarter")
		}
		if restart.Metadata["resumed"] != true {
			t.Errorf("resumed = %v, want true", restart.Metadata["resumed"])
		}
		// The interrupt marker must be cleared so the request can run again.
		if restart.Metadata["interrupt"] != nil {
			t.Error("interrupt should be removed from metadata")
		}
	})

	t.Run("returns nil for non-tool-request part", func(t *testing.T) {
		textPart := NewTextPart("text")

		restart := tl.Restart(textPart, nil)

		if restart != nil {
			t.Errorf("Restart(textPart) = %v, want nil", restart)
		}
	})

	t.Run("returns nil for nil part", func(t *testing.T) {
		restart := tl.Restart(nil, nil)

		if restart != nil {
			t.Errorf("Restart(nil) = %v, want nil", restart)
		}
	})

	t.Run("replaces input when specified", func(t *testing.T) {
		reqPart := NewToolRequestPart(&ToolRequest{
			Name:  "provider/restarter",
			Input: map[string]any{"value": 10},
		})
		reqPart.Metadata = map[string]any{"interrupt": true}

		opts := &RestartOptions{
			ReplaceInput: map[string]any{"value": 20},
		}
		restart := tl.Restart(reqPart, opts)

		newInput := restart.ToolRequest.Input.(map[string]any)
		if newInput["value"] != 20 {
			t.Errorf("new input value = %v, want 20", newInput["value"])
		}
		if restart.Metadata["replacedInput"] == nil {
			t.Error("replacedInput not set in metadata")
		}
	})

	t.Run("sets resumed metadata when specified", func(t *testing.T) {
		reqPart := NewToolRequestPart(&ToolRequest{
			Name: "provider/restarter",
		})
		reqPart.Metadata = map[string]any{"interrupt": true}

		opts := &RestartOptions{
			ResumedMetadata: map[string]any{"reason": "user confirmed"},
		}
		restart := tl.Restart(reqPart, opts)

		resumed := restart.Metadata["resumed"].(map[string]any)
		if resumed["reason"] != "user confirmed" {
			t.Errorf("resumed.reason = %v, want %q", resumed["reason"], "user confirmed")
		}
	})
}

// TestToolInterrupt verifies ctx.Interrupt signalling from inside a tool.
func TestToolInterrupt(t *testing.T) {
	t.Run("tool can interrupt execution", func(t *testing.T) {
		r := newTestRegistry(t)
		tl := DefineTool(r, "provider/interrupter", "Can interrupt",
			func(ctx *ToolContext, input struct {
				ShouldInterrupt bool `json:"shouldInterrupt"`
			}) (string, error) {
				if input.ShouldInterrupt {
					return "", ctx.Interrupt(&InterruptOptions{
						Metadata: map[string]any{"step": "confirmation"},
					})
				}
				return "completed", nil
			})

		_, err := tl.RunRaw(context.Background(), map[string]any{
			"shouldInterrupt": true,
		})

		if err == nil {
			t.Fatal("expected interrupt error, got nil")
		}

		isInterrupt, meta := IsToolInterruptError(err)
		if !isInterrupt {
			t.Errorf("IsToolInterruptError() = false, want true")
		}
		if meta["step"] != "confirmation" {
			t.Errorf("metadata[step] = %v, want %q", meta["step"], "confirmation")
		}
	})

	t.Run("tool completes without interrupt", func(t *testing.T) {
		r := newTestRegistry(t)
		tl := DefineTool(r, "provider/noInterrupt", "No interrupt",
			func(ctx *ToolContext, input struct {
				ShouldInterrupt bool `json:"shouldInterrupt"`
			}) (string, error) {
				if input.ShouldInterrupt {
					return "", ctx.Interrupt(&InterruptOptions{})
				}
				return "completed", nil
			})

		output, err := tl.RunRaw(context.Background(), map[string]any{
			"shouldInterrupt": false,
		})

		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if output != "completed" {
			t.Errorf("output = %v, want %q", output, "completed")
		}
	})
}

// TestToolWithInputSchemaOption verifies the WithInputSchema option on both
// DefineTool and NewTool.
func TestToolWithInputSchemaOption(t *testing.T) {
	t.Run("DefineTool with WithInputSchema", func(t *testing.T) {
		r := newTestRegistry(t)
		customSchema := map[string]any{
			"type": "object",
			"properties": map[string]any{
				"customField": map[string]any{"type": "string"},
			},
		}

		tl := DefineTool(r, "provider/customInput", "Custom input schema",
			func(ctx *ToolContext, input any) (string, error) {
				m := input.(map[string]any)
				return m["customField"].(string), nil
			},
			WithInputSchema(customSchema))

		def := tl.Definition()
		if def.InputSchema == nil {
			t.Error("InputSchema is nil")
		}
	})

	t.Run("NewTool with WithInputSchema", func(t *testing.T) {
		customSchema := map[string]any{
			"type": "object",
			"properties": map[string]any{
				"field": map[string]any{"type": "number"},
			},
		}

		tl := NewTool("customNew", "Custom new tool",
			func(ctx *ToolContext, input any) (bool, error) {
				return true, nil
			},
			WithInputSchema(customSchema))

		def := tl.Definition()
		if def.InputSchema == nil {
			t.Error("InputSchema is nil")
		}
	})
}

// TestResolveUniqueTools verifies tool-reference resolution and duplicate
// detection. (Continues past this point.)
func TestResolveUniqueTools(t *testing.T) {
	t.Run("resolves tools from registry", func(t *testing.T) {
		r := newTestRegistry(t)
		DefineTool(r, "provider/tool1", "Tool 1", func(ctx *ToolContext, input struct{}) (bool, error) {
			return true, nil
		})
		DefineTool(r, "provider/tool2", "Tool 2", func(ctx *ToolContext, input struct{}) (bool, error) {
			return true, nil
		})

		toolRefs := []ToolRef{
			ToolName("provider/tool1"),
			ToolName("provider/tool2"),
		}

		names, newTools, err := resolveUniqueTools(r, toolRefs)

		if err != nil {
			t.Fatalf("resolveUniqueTools error: %v", err)
		}
		if len(names) != 2 {
			t.Errorf("len(names) = %d, want 2", len(names))
		}
		if len(newTools) != 0 {
			t.Errorf("len(newTools) = %d, want 0 (tools already registered)", len(newTools))
		}
	})

	t.Run("returns error for duplicate tools", func(t *testing.T) {
		r := newTestRegistry(t)
		toolRefs := []ToolRef{
			ToolName("provider/dup"),
			ToolName("provider/dup"),
		}

		_, _, err := resolveUniqueTools(r, toolRefs)

		if err == nil {
			t.Error("expected error for duplicate tools, got nil")
		}
	})

	t.Run("identifies new tools to register", func(t *testing.T) {
		r := newTestRegistry(t)
newTl := NewTool("provider/brandNew", "Brand new", func(ctx *ToolContext, input struct{}) (string, error) { + return "new", nil + }) + + toolRefs := []ToolRef{newTl} + + names, newTools, err := resolveUniqueTools(r, toolRefs) + + if err != nil { + t.Fatalf("resolveUniqueTools error: %v", err) + } + if len(names) != 1 { + t.Errorf("len(names) = %d, want 1", len(names)) + } + if len(newTools) != 1 { + t.Errorf("len(newTools) = %d, want 1", len(newTools)) + } + }) +} + +func TestIsMultipart(t *testing.T) { + t.Run("returns false for standard tool", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineTool(r, "provider/standard", "Standard tool", + func(ctx *ToolContext, input struct{}) (string, error) { + return "result", nil + }) + + // IsMultipart is on the internal *tool type, so we need to type assert + internalTool := tl.(*tool) + if internalTool.IsMultipart() { + t.Error("IsMultipart() = true for standard tool, want false") + } + }) + + t.Run("returns false for NewTool", func(t *testing.T) { + tl := NewTool("standard", "Standard", + func(ctx *ToolContext, input struct{}) (string, error) { + return "result", nil + }) + + internalTool := tl.(*tool) + if internalTool.IsMultipart() { + t.Error("IsMultipart() = true for NewTool, want false") + } + }) + + t.Run("returns true for multipart tool", func(t *testing.T) { + r := newTestRegistry(t) + tl := DefineMultipartTool(r, "provider/multipart", "Multipart tool", + func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { + return &MultipartToolResponse{ + Content: []*Part{NewTextPart("hello"), NewTextPart("world")}, + }, nil + }) + + internalTool := tl.(*tool) + if !internalTool.IsMultipart() { + t.Error("IsMultipart() = false for multipart tool, want true") + } + }) + + t.Run("returns true for NewMultipartTool", func(t *testing.T) { + tl := NewMultipartTool("multipart", "Multipart", + func(ctx *ToolContext, input struct{}) (*MultipartToolResponse, error) { + return &MultipartToolResponse{ + 
Content: []*Part{NewTextPart("content")}, + }, nil + }) + + internalTool := tl.(*tool) + if !internalTool.IsMultipart() { + t.Error("IsMultipart() = false for NewMultipartTool, want true") + } + }) +} diff --git a/go/core/action_test.go b/go/core/action_test.go index 4ce63cffef..65309d850d 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -19,6 +19,7 @@ package core import ( "bytes" "context" + "encoding/json" "slices" "testing" @@ -124,3 +125,309 @@ func TestActionTracing(t *testing.T) { } t.Fatalf("did not find trace named %q", name) } + +func TestNewAction(t *testing.T) { + t.Run("creates unregistered action", func(t *testing.T) { + fn := func(ctx context.Context, input string) (string, error) { + return "Hello, " + input, nil + } + a := NewAction("greet", api.ActionTypeCustom, nil, nil, fn) + + if a == nil { + t.Fatal("NewAction returned nil") + } + if a.Name() != "greet" { + t.Errorf("Name() = %q, want %q", a.Name(), "greet") + } + }) + + t.Run("action can be executed", func(t *testing.T) { + fn := func(ctx context.Context, input int) (int, error) { + return input * 2, nil + } + a := NewAction("double", api.ActionTypeCustom, nil, nil, fn) + + got, err := a.Run(context.Background(), 5, nil) + if err != nil { + t.Fatalf("Run error: %v", err) + } + if got != 10 { + t.Errorf("got %d, want 10", got) + } + }) + + t.Run("action with custom input schema", func(t *testing.T) { + customSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + }, + } + fn := func(ctx context.Context, input any) (string, error) { + return "ok", nil + } + a := NewAction("withSchema", api.ActionTypeCustom, nil, customSchema, fn) + + desc := a.Desc() + if desc.InputSchema == nil { + t.Error("InputSchema is nil") + } + }) + + t.Run("action with metadata", func(t *testing.T) { + meta := map[string]any{ + "description": "A test action", + "version": "1.0", + } + fn := func(ctx context.Context, input struct{}) 
(bool, error) { + return true, nil + } + a := NewAction("withMeta", api.ActionTypeCustom, meta, nil, fn) + + desc := a.Desc() + if desc.Description != "A test action" { + t.Errorf("Description = %q, want %q", desc.Description, "A test action") + } + }) +} + +func TestNewStreamingAction(t *testing.T) { + t.Run("creates streaming action", func(t *testing.T) { + fn := func(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { + if cb != nil { + for i := 0; i < n; i++ { + if err := cb(ctx, i); err != nil { + return 0, err + } + } + } + return n, nil + } + a := NewStreamingAction("counter", api.ActionTypeCustom, nil, nil, fn) + + if a == nil { + t.Fatal("NewStreamingAction returned nil") + } + if a.Name() != "counter" { + t.Errorf("Name() = %q, want %q", a.Name(), "counter") + } + }) + + t.Run("streaming action streams correctly", func(t *testing.T) { + fn := func(ctx context.Context, n int, cb func(context.Context, string) error) (int, error) { + if cb != nil { + for i := 0; i < n; i++ { + if err := cb(ctx, "chunk"); err != nil { + return 0, err + } + } + } + return n, nil + } + a := NewStreamingAction("streamer", api.ActionTypeCustom, nil, nil, fn) + + var chunks []string + got, err := a.Run(context.Background(), 3, func(ctx context.Context, chunk string) error { + chunks = append(chunks, chunk) + return nil + }) + + if err != nil { + t.Fatalf("Run error: %v", err) + } + if got != 3 { + t.Errorf("got %d, want 3", got) + } + if len(chunks) != 3 { + t.Errorf("len(chunks) = %d, want 3", len(chunks)) + } + }) +} + +func TestActionDesc(t *testing.T) { + t.Run("returns action descriptor", func(t *testing.T) { + meta := map[string]any{ + "description": "Test description", + "custom": "value", + } + fn := func(ctx context.Context, input struct { + Name string `json:"name"` + }) (struct { + Greeting string `json:"greeting"` + }, error) { + return struct { + Greeting string `json:"greeting"` + }{Greeting: "Hello"}, nil + } + + r := registry.New() + a 
:= DefineAction(r, "test/describe", api.ActionTypeCustom, meta, nil, fn) + + desc := a.Desc() + + if desc.Name != "test/describe" { + t.Errorf("Name = %q, want %q", desc.Name, "test/describe") + } + if desc.Description != "Test description" { + t.Errorf("Description = %q, want %q", desc.Description, "Test description") + } + if desc.Type != api.ActionTypeCustom { + t.Errorf("Type = %v, want %v", desc.Type, api.ActionTypeCustom) + } + if desc.InputSchema == nil { + t.Error("InputSchema is nil") + } + if desc.OutputSchema == nil { + t.Error("OutputSchema is nil") + } + }) +} + +func TestActionRegister(t *testing.T) { + t.Run("registers action with registry", func(t *testing.T) { + r := registry.New() + fn := func(ctx context.Context, input string) (string, error) { + return input, nil + } + a := NewAction("test/register", api.ActionTypeCustom, nil, nil, fn) + + a.Register(r) + + key := api.KeyFromName(api.ActionTypeCustom, "test/register") + found := r.LookupAction(key) + if found == nil { + t.Error("registered action not found in registry") + } + }) +} + +func TestResolveActionFor(t *testing.T) { + t.Run("finds registered action", func(t *testing.T) { + r := registry.New() + fn := func(ctx context.Context, input int) (int, error) { + return input + 1, nil + } + DefineAction(r, "test/resolvable", api.ActionTypeCustom, nil, nil, fn) + + found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/resolvable") + + if found == nil { + t.Fatal("ResolveActionFor returned nil") + } + if found.Name() != "test/resolvable" { + t.Errorf("Name() = %q, want %q", found.Name(), "test/resolvable") + } + }) + + t.Run("returns nil for non-existent action", func(t *testing.T) { + r := registry.New() + + found := ResolveActionFor[int, int, struct{}](r, api.ActionTypeCustom, "test/nonexistent") + + if found != nil { + t.Errorf("ResolveActionFor returned %v, want nil", found) + } + }) +} + +func TestLookupActionFor(t *testing.T) { + t.Run("finds registered action", func(t 
*testing.T) { + r := registry.New() + fn := func(ctx context.Context, input string) (string, error) { + return "found: " + input, nil + } + DefineAction(r, "test/lookupable", api.ActionTypeCustom, nil, nil, fn) + + found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/lookupable") + + if found == nil { + t.Fatal("LookupActionFor returned nil") + } + }) + + t.Run("returns nil for non-existent action", func(t *testing.T) { + r := registry.New() + + found := LookupActionFor[string, string, struct{}](r, api.ActionTypeCustom, "test/missing") + + if found != nil { + t.Errorf("LookupActionFor returned %v, want nil", found) + } + }) +} + +func TestRunJSONWithTelemetry(t *testing.T) { + t.Run("returns telemetry info with result", func(t *testing.T) { + r := registry.New() + fn := func(ctx context.Context, input int) (int, error) { + return input * 2, nil + } + a := DefineAction(r, "test/telemetry", api.ActionTypeCustom, nil, nil, fn) + + result, err := a.RunJSONWithTelemetry(context.Background(), []byte("5"), nil) + + if err != nil { + t.Fatalf("RunJSONWithTelemetry error: %v", err) + } + if result == nil { + t.Fatal("result is nil") + } + if string(result.Result) != "10" { + t.Errorf("Result = %s, want %q", result.Result, "10") + } + // TraceId and SpanId should be set + if result.TraceId == "" { + t.Error("TraceId is empty") + } + if result.SpanId == "" { + t.Error("SpanId is empty") + } + }) + + t.Run("handles streaming callback", func(t *testing.T) { + r := registry.New() + fn := func(ctx context.Context, n int, cb func(context.Context, int) error) (int, error) { + if cb != nil { + for i := 0; i < n; i++ { + if err := cb(ctx, i); err != nil { + return 0, err + } + } + } + return n, nil + } + a := DefineStreamingAction(r, "test/streamTelemetry", api.ActionTypeCustom, nil, nil, fn) + + var chunks []string + cb := func(ctx context.Context, chunk json.RawMessage) error { + chunks = append(chunks, string(chunk)) + return nil + } + + result, err := 
a.RunJSONWithTelemetry(context.Background(), []byte("3"), cb) + + if err != nil { + t.Fatalf("RunJSONWithTelemetry error: %v", err) + } + if result == nil { + t.Fatal("result is nil") + } + if len(chunks) != 3 { + t.Errorf("len(chunks) = %d, want 3", len(chunks)) + } + }) + + t.Run("returns error for invalid JSON input", func(t *testing.T) { + r := registry.New() + fn := func(ctx context.Context, input int) (int, error) { + return input, nil + } + a := DefineAction(r, "test/invalidInput", api.ActionTypeCustom, nil, nil, fn) + + _, err := a.RunJSONWithTelemetry(context.Background(), []byte("not valid json"), nil) + + if err == nil { + t.Error("expected error for invalid JSON, got nil") + } + }) +} diff --git a/go/core/background_action_test.go b/go/core/background_action_test.go new file mode 100644 index 0000000000..5ce5d75ff7 --- /dev/null +++ b/go/core/background_action_test.go @@ -0,0 +1,431 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "testing" + + "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/internal/registry" +) + +func TestNewBackgroundAction(t *testing.T) { + t.Run("creates background action with all functions", func(t *testing.T) { + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return &Operation[string]{ID: "op-1", Done: false}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: true, Output: "result"}, nil + } + cancelFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: true}, nil + } + + ba := NewBackgroundAction("test/background", api.ActionTypeCustom, nil, startFn, checkFn, cancelFn) + + if ba == nil { + t.Fatal("NewBackgroundAction returned nil") + } + if ba.Name() != "test/background" { + t.Errorf("Name() = %q, want %q", ba.Name(), "test/background") + } + if !ba.SupportsCancel() { + t.Error("SupportsCancel() = false, want true") + } + }) + + t.Run("creates background action without cancel", func(t *testing.T) { + startFn := func(ctx context.Context, input int) (*Operation[int], error) { + return &Operation[int]{ID: "op-1", Done: false}, nil + } + checkFn := func(ctx context.Context, op *Operation[int]) (*Operation[int], error) { + return &Operation[int]{ID: op.ID, Done: true, Output: 42}, nil + } + + ba := NewBackgroundAction("test/nocancel", api.ActionTypeCustom, nil, startFn, checkFn, nil) + + if ba == nil { + t.Fatal("NewBackgroundAction returned nil") + } + if ba.SupportsCancel() { + t.Error("SupportsCancel() = true, want false") + } + }) + + t.Run("panics with empty name", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for empty name") + } + }() + + NewBackgroundAction("", api.ActionTypeCustom, nil, + func(ctx 
context.Context, input string) (*Operation[string], error) { + return nil, nil + }, + func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return nil, nil + }, + nil, + ) + }) + + t.Run("panics with nil startFn", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for nil startFn") + } + }() + + NewBackgroundAction[string, string]("test/nilstart", api.ActionTypeCustom, nil, + nil, + func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return nil, nil + }, + nil, + ) + }) + + t.Run("panics with nil checkFn", func(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for nil checkFn") + } + }() + + NewBackgroundAction("test/nilcheck", api.ActionTypeCustom, nil, + func(ctx context.Context, input string) (*Operation[string], error) { + return nil, nil + }, + nil, + nil, + ) + }) +} + +func TestDefineBackgroundAction(t *testing.T) { + t.Run("creates and registers background action", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return &Operation[string]{ID: "op-1", Done: false}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: true, Output: "done"}, nil + } + + ba := DefineBackgroundAction(r, "test/registered", api.ActionTypeCustom, nil, startFn, checkFn, nil) + + if ba == nil { + t.Fatal("DefineBackgroundAction returned nil") + } + + // Verify action is registered + key := api.KeyFromName(api.ActionTypeCustom, "test/registered") + found := r.LookupAction(key) + if found == nil { + t.Error("background action not found in registry") + } + }) +} + +func TestBackgroundActionStart(t *testing.T) { + t.Run("starts operation", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return 
&Operation[string]{ID: "test-op", Done: false, Metadata: map[string]any{"input": input}}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: op.Done, Metadata: map[string]any{}}, nil + } + + ba := DefineBackgroundAction(r, "test/start", api.ActionTypeCustom, nil, startFn, checkFn, nil) + + op, err := ba.Start(context.Background(), "hello") + if err != nil { + t.Fatalf("Start error: %v", err) + } + if op.ID != "test-op" { + t.Errorf("op.ID = %q, want %q", op.ID, "test-op") + } + if op.Done { + t.Error("op.Done = true, want false") + } + // Check that Action key is set + if op.Action == "" { + t.Error("op.Action is empty, expected to be set") + } + }) +} + +func TestBackgroundActionCheck(t *testing.T) { + t.Run("checks operation status", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return &Operation[string]{ID: "check-op", Done: false, Metadata: map[string]any{}}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: true, Output: "completed", Metadata: map[string]any{}}, nil + } + + ba := DefineBackgroundAction(r, "test/check", api.ActionTypeCustom, nil, startFn, checkFn, nil) + + op, err := ba.Start(context.Background(), "input") + if err != nil { + t.Fatalf("Start error: %v", err) + } + + checked, err := ba.Check(context.Background(), op) + if err != nil { + t.Fatalf("Check error: %v", err) + } + if !checked.Done { + t.Error("checked.Done = false, want true") + } + if checked.Output != "completed" { + t.Errorf("checked.Output = %q, want %q", checked.Output, "completed") + } + }) +} + +func TestBackgroundActionCancel(t *testing.T) { + t.Run("cancels operation when supported", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + 
return &Operation[string]{ID: "cancel-op", Done: false, Metadata: map[string]any{}}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: op.Done, Metadata: map[string]any{}}, nil + } + cancelFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: true, Metadata: map[string]any{"cancelled": true}}, nil + } + + ba := DefineBackgroundAction(r, "test/cancel", api.ActionTypeCustom, nil, startFn, checkFn, cancelFn) + + op, err := ba.Start(context.Background(), "input") + if err != nil { + t.Fatalf("Start error: %v", err) + } + + cancelled, err := ba.Cancel(context.Background(), op) + if err != nil { + t.Fatalf("Cancel error: %v", err) + } + if !cancelled.Done { + t.Error("cancelled.Done = false, want true") + } + }) + + t.Run("returns error when cancel not supported", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return &Operation[string]{ID: "no-cancel-op", Done: false, Metadata: map[string]any{}}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: op.Done, Metadata: map[string]any{}}, nil + } + + ba := DefineBackgroundAction(r, "test/nocancel", api.ActionTypeCustom, nil, startFn, checkFn, nil) + + op, err := ba.Start(context.Background(), "input") + if err != nil { + t.Fatalf("Start error: %v", err) + } + + _, err = ba.Cancel(context.Background(), op) + if err == nil { + t.Error("expected error for unsupported cancel, got nil") + } + }) +} + +func TestBackgroundActionRegister(t *testing.T) { + t.Run("registers all sub-actions", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return &Operation[string]{ID: "reg-op", Metadata: map[string]any{}}, nil + } + 
checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil + } + cancelFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil + } + + ba := NewBackgroundAction("test/register", api.ActionTypeCustom, nil, startFn, checkFn, cancelFn) + ba.Register(r) + + // Check main action + mainKey := api.KeyFromName(api.ActionTypeCustom, "test/register") + if r.LookupAction(mainKey) == nil { + t.Error("main action not registered") + } + + // Check check action + checkKey := api.KeyFromName(api.ActionTypeCheckOperation, "test/register") + if r.LookupAction(checkKey) == nil { + t.Error("check action not registered") + } + + // Check cancel action + cancelKey := api.KeyFromName(api.ActionTypeCancelOperation, "test/register") + if r.LookupAction(cancelKey) == nil { + t.Error("cancel action not registered") + } + }) + + t.Run("registers without cancel action when not provided", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return &Operation[string]{ID: "reg-op", Metadata: map[string]any{}}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil + } + + ba := NewBackgroundAction("test/register-nocancel", api.ActionTypeCustom, nil, startFn, checkFn, nil) + ba.Register(r) + + // Cancel action should not be registered + cancelKey := api.KeyFromName(api.ActionTypeCancelOperation, "test/register-nocancel") + if r.LookupAction(cancelKey) != nil { + t.Error("cancel action should not be registered") + } + }) +} + +func TestLookupBackgroundAction(t *testing.T) { + t.Run("finds registered background action", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input 
string) (*Operation[string], error) { + return &Operation[string]{ID: "lookup-op", Metadata: map[string]any{}}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil + } + + DefineBackgroundAction(r, "test/lookup", api.ActionTypeCustom, nil, startFn, checkFn, nil) + + key := api.KeyFromName(api.ActionTypeCustom, "test/lookup") + found := LookupBackgroundAction[string, string](r, key) + + if found == nil { + t.Fatal("LookupBackgroundAction returned nil") + } + if found.Name() != "test/lookup" { + t.Errorf("Name() = %q, want %q", found.Name(), "test/lookup") + } + }) + + t.Run("returns nil for non-existent action", func(t *testing.T) { + r := registry.New() + + key := api.KeyFromName(api.ActionTypeCustom, "test/nonexistent") + found := LookupBackgroundAction[string, string](r, key) + + if found != nil { + t.Errorf("LookupBackgroundAction returned %v, want nil", found) + } + }) +} + +func TestCheckOperation(t *testing.T) { + t.Run("checks operation using registry lookup", func(t *testing.T) { + r := registry.New() + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return &Operation[string]{ID: "check-op", Done: false, Metadata: map[string]any{}}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Done: true, Output: "checked", Metadata: map[string]any{}}, nil + } + + ba := DefineBackgroundAction(r, "test/checkop", api.ActionTypeCustom, nil, startFn, checkFn, nil) + + op, err := ba.Start(context.Background(), "input") + if err != nil { + t.Fatalf("Start error: %v", err) + } + + checked, err := CheckOperation[string, string](context.Background(), r, op) + if err != nil { + t.Fatalf("CheckOperation error: %v", err) + } + if !checked.Done { + t.Error("checked.Done = false, want true") + } + if checked.Output != "checked" { + 
t.Errorf("checked.Output = %q, want %q", checked.Output, "checked") + } + }) + + t.Run("returns error for nil operation", func(t *testing.T) { + r := registry.New() + + _, err := CheckOperation[string, string](context.Background(), r, nil) + if err == nil { + t.Error("expected error for nil operation, got nil") + } + }) + + t.Run("returns error for operation with empty Action", func(t *testing.T) { + r := registry.New() + op := &Operation[string]{ID: "op-1", Metadata: map[string]any{}} + + _, err := CheckOperation[string, string](context.Background(), r, op) + if err == nil { + t.Error("expected error for operation with empty Action, got nil") + } + }) + + t.Run("returns error for non-existent action", func(t *testing.T) { + r := registry.New() + op := &Operation[string]{ + ID: "op-1", + Action: api.KeyFromName(api.ActionTypeCustom, "test/nonexistent"), + Metadata: map[string]any{}, + } + + _, err := CheckOperation[string, string](context.Background(), r, op) + if err == nil { + t.Error("expected error for non-existent action, got nil") + } + }) +} + +func TestBackgroundActionWithMetadata(t *testing.T) { + t.Run("preserves metadata", func(t *testing.T) { + r := registry.New() + meta := map[string]any{ + "description": "A test background action", + "version": "1.0", + } + startFn := func(ctx context.Context, input string) (*Operation[string], error) { + return &Operation[string]{ID: "meta-op", Metadata: map[string]any{}}, nil + } + checkFn := func(ctx context.Context, op *Operation[string]) (*Operation[string], error) { + return &Operation[string]{ID: op.ID, Metadata: map[string]any{}}, nil + } + + ba := DefineBackgroundAction(r, "test/meta", api.ActionTypeCustom, meta, startFn, checkFn, nil) + + desc := ba.Desc() + if desc.Description != "A test background action" { + t.Errorf("Description = %q, want %q", desc.Description, "A test background action") + } + }) +} diff --git a/go/core/context_test.go b/go/core/context_test.go new file mode 100644 index 
0000000000..3ee4b8a0da --- /dev/null +++ b/go/core/context_test.go @@ -0,0 +1,122 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestWithActionContext(t *testing.T) { + t.Run("adds action context to context", func(t *testing.T) { + ctx := context.Background() + actionCtx := ActionContext{ + "userId": "user-123", + "sessionId": "session-456", + } + + newCtx := WithActionContext(ctx, actionCtx) + + retrieved := FromContext(newCtx) + if diff := cmp.Diff(actionCtx, retrieved); diff != "" { + t.Errorf("ActionContext mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("replaces existing action context", func(t *testing.T) { + ctx := context.Background() + first := ActionContext{"key": "first"} + second := ActionContext{"key": "second"} + + ctx = WithActionContext(ctx, first) + ctx = WithActionContext(ctx, second) + + retrieved := FromContext(ctx) + if retrieved["key"] != "second" { + t.Errorf("key = %v, want %q", retrieved["key"], "second") + } + }) + + t.Run("allows nil action context", func(t *testing.T) { + ctx := context.Background() + newCtx := WithActionContext(ctx, nil) + + retrieved := FromContext(newCtx) + if retrieved != nil { + t.Errorf("expected nil, got %v", retrieved) + } + }) +} + +func TestFromContext(t *testing.T) { + t.Run("returns nil when no action context", func(t 
*testing.T) { + ctx := context.Background() + retrieved := FromContext(ctx) + + if retrieved != nil { + t.Errorf("expected nil, got %v", retrieved) + } + }) + + t.Run("returns action context when present", func(t *testing.T) { + ctx := context.Background() + actionCtx := ActionContext{ + "requestId": "req-789", + } + ctx = WithActionContext(ctx, actionCtx) + + retrieved := FromContext(ctx) + if retrieved["requestId"] != "req-789" { + t.Errorf("requestId = %v, want %q", retrieved["requestId"], "req-789") + } + }) + + t.Run("returns correct context from nested contexts", func(t *testing.T) { + ctx := context.Background() + actionCtx := ActionContext{"level": "root"} + ctx = WithActionContext(ctx, actionCtx) + + // Create child context with deadline (doesn't affect action context) + childCtx, cancel := context.WithCancel(ctx) + defer cancel() + + retrieved := FromContext(childCtx) + if retrieved["level"] != "root" { + t.Errorf("level = %v, want %q", retrieved["level"], "root") + } + }) +} + +func TestActionContextModification(t *testing.T) { + t.Run("modifications to retrieved context affect original", func(t *testing.T) { + ctx := context.Background() + actionCtx := ActionContext{"mutable": "original"} + ctx = WithActionContext(ctx, actionCtx) + + retrieved := FromContext(ctx) + retrieved["mutable"] = "modified" + + // Check that modification affected the stored context + // (maps are reference types, so this behavior is expected) + secondRetrieval := FromContext(ctx) + if secondRetrieval["mutable"] != "modified" { + t.Errorf("mutable = %v, want %q", secondRetrieval["mutable"], "modified") + } + }) +} diff --git a/go/core/core_test.go b/go/core/core_test.go new file mode 100644 index 0000000000..67ee0d912c --- /dev/null +++ b/go/core/core_test.go @@ -0,0 +1,284 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "testing" + + "github.com/firebase/genkit/go/internal/registry" + "github.com/google/go-cmp/cmp" +) + +func TestDefineSchema(t *testing.T) { + t.Run("registers schema in registry", func(t *testing.T) { + r := registry.New() + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + "required": []any{"name"}, + } + + DefineSchema(r, "Person", schema) + + found := r.LookupSchema("Person") + if found == nil { + t.Fatal("schema not found in registry") + } + if diff := cmp.Diff(schema, found); diff != "" { + t.Errorf("schema mismatch (-want +got):\n%s", diff) + } + }) +} + +func TestDefineSchemaFor(t *testing.T) { + t.Run("registers schema derived from Go type", func(t *testing.T) { + r := registry.New() + + type User struct { + Name string `json:"name"` + Email string `json:"email"` + } + + DefineSchemaFor[User](r) + + found := r.LookupSchema("User") + if found == nil { + t.Fatal("schema not found in registry") + } + // Check that the schema has expected properties + props, ok := found["properties"].(map[string]any) + if !ok { + t.Fatal("expected properties in schema") + } + if props["name"] == nil { + t.Error("expected 'name' property in schema") + } + if props["email"] == nil { + t.Error("expected 'email' property in schema") + } + }) + + t.Run("handles pointer types", func(t *testing.T) { + r := registry.New() + + type Config struct { + Debug bool `json:"debug"` 
+ } + + DefineSchemaFor[*Config](r) + + found := r.LookupSchema("Config") + if found == nil { + t.Fatal("schema not found in registry for pointer type") + } + }) +} + +func TestSchemaRef(t *testing.T) { + t.Run("returns schema reference map", func(t *testing.T) { + ref := SchemaRef("MyType") + + want := map[string]any{ + "$ref": "genkit:MyType", + } + if diff := cmp.Diff(want, ref); diff != "" { + t.Errorf("SchemaRef mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("handles various names", func(t *testing.T) { + tests := []struct { + name string + want string + }{ + {"Simple", "genkit:Simple"}, + {"Package.Type", "genkit:Package.Type"}, + {"my-schema", "genkit:my-schema"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ref := SchemaRef(tt.name) + if ref["$ref"] != tt.want { + t.Errorf("$ref = %q, want %q", ref["$ref"], tt.want) + } + }) + } + }) +} + +func TestResolveSchema(t *testing.T) { + t.Run("returns nil for nil schema", func(t *testing.T) { + r := registry.New() + + resolved, err := ResolveSchema(r, nil) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resolved != nil { + t.Errorf("expected nil, got %v", resolved) + } + }) + + t.Run("returns original schema without ref", func(t *testing.T) { + r := registry.New() + schema := map[string]any{ + "type": "string", + } + + resolved, err := ResolveSchema(r, schema) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff := cmp.Diff(schema, resolved); diff != "" { + t.Errorf("schema mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("resolves genkit ref", func(t *testing.T) { + r := registry.New() + originalSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "id": map[string]any{"type": "integer"}, + }, + } + r.RegisterSchema("Entity", originalSchema) + + refSchema := map[string]any{ + "$ref": "genkit:Entity", + } + + resolved, err := ResolveSchema(r, refSchema) + + if err != nil { + t.Fatalf("unexpected error: 
%v", err) + } + if diff := cmp.Diff(originalSchema, resolved); diff != "" { + t.Errorf("resolved schema mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("returns original schema for non-genkit ref", func(t *testing.T) { + r := registry.New() + schema := map[string]any{ + "$ref": "#/definitions/Other", + } + + resolved, err := ResolveSchema(r, schema) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if diff := cmp.Diff(schema, resolved); diff != "" { + t.Errorf("schema mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("returns error for missing schema", func(t *testing.T) { + r := registry.New() + refSchema := map[string]any{ + "$ref": "genkit:NonExistent", + } + + _, err := ResolveSchema(r, refSchema) + + if err == nil { + t.Error("expected error for missing schema, got nil") + } + }) +} + +func TestInferSchemaMap(t *testing.T) { + t.Run("infers schema from struct", func(t *testing.T) { + type TestStruct struct { + Name string `json:"name"` + Count int `json:"count"` + Enabled bool `json:"enabled"` + } + + schema := InferSchemaMap(TestStruct{}) + + if schema["type"] != "object" { + t.Errorf("type = %v, want %q", schema["type"], "object") + } + props, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatal("expected properties map") + } + if props["name"] == nil { + t.Error("expected 'name' property") + } + if props["count"] == nil { + t.Error("expected 'count' property") + } + if props["enabled"] == nil { + t.Error("expected 'enabled' property") + } + }) + + t.Run("infers schema from primitive types", func(t *testing.T) { + tests := []struct { + value any + wantType string + }{ + {"hello", "string"}, + {42, "integer"}, + {3.14, "number"}, + {true, "boolean"}, + } + + for _, tt := range tests { + t.Run(tt.wantType, func(t *testing.T) { + schema := InferSchemaMap(tt.value) + if schema["type"] != tt.wantType { + t.Errorf("type = %v, want %q", schema["type"], tt.wantType) + } + }) + } + }) + + t.Run("infers schema from slice", func(t 
*testing.T) { + schema := InferSchemaMap([]string{}) + + if schema["type"] != "array" { + t.Errorf("type = %v, want %q", schema["type"], "array") + } + }) + + t.Run("infers schema from nested struct", func(t *testing.T) { + type Inner struct { + Value string `json:"value"` + } + type Outer struct { + Inner Inner `json:"inner"` + } + + schema := InferSchemaMap(Outer{}) + + props := schema["properties"].(map[string]any) + innerProp := props["inner"].(map[string]any) + if innerProp["type"] != "object" { + t.Errorf("inner type = %v, want %q", innerProp["type"], "object") + } + }) +} diff --git a/go/core/doc.go b/go/core/doc.go new file mode 100644 index 0000000000..e4528df7f6 --- /dev/null +++ b/go/core/doc.go @@ -0,0 +1,230 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +Package core implements Genkit's foundational action system and runtime machinery. + +This package is primarily intended for plugin developers and Genkit internals. +Application developers should use the genkit package instead, which provides +a higher-level, more convenient API. + +# Actions + +Actions are the fundamental building blocks of Genkit. Every operation - flows, +model calls, tool invocations, retrieval - is implemented as an action. 
Actions +provide: + + - Type-safe input/output with JSON schema validation + - Automatic tracing and observability + - Consistent error handling + - Registration in the action registry + +Define a non-streaming action: + + action := core.DefineAction(registry, "myAction", + func(ctx context.Context, input string) (string, error) { + return "processed: " + input, nil + }, + ) + + result, err := action.Run(context.Background(), "hello") + +Define a streaming action that sends chunks during execution: + + streamingAction := core.DefineStreamingAction(registry, "countdown", + func(ctx context.Context, start int, cb core.StreamCallback[string]) (string, error) { + for i := start; i > 0; i-- { + if cb != nil { + if err := cb(ctx, fmt.Sprintf("T-%d", i)); err != nil { + return "", err + } + } + time.Sleep(time.Second) + } + return "Liftoff!", nil + }, + ) + +# Flows + +Flows are user-defined actions that orchestrate AI operations. They are the +primary way application developers define business logic in Genkit: + + flow := core.DefineFlow(registry, "myFlow", + func(ctx context.Context, input string) (string, error) { + // Use Run to create traced sub-steps + result, err := core.Run(ctx, "step1", func() (string, error) { + return process(input), nil + }) + if err != nil { + return "", err + } + return result, nil + }, + ) + +Streaming flows can send intermediate results to callers: + + streamingFlow := core.DefineStreamingFlow(registry, "generateReport", + func(ctx context.Context, input Input, cb core.StreamCallback[Progress]) (Report, error) { + for i := 0; i < 100; i += 10 { + if cb != nil { + cb(ctx, Progress{Percent: i}) + } + // ... work ... + } + return Report{...}, nil + }, + ) + +# Traced Steps with Run + +Use [Run] within flows to create traced sub-operations. 
Each Run call creates +a span in the trace that's visible in the Genkit Developer UI: + + result, err := core.Run(ctx, "fetchData", func() (Data, error) { + return fetchFromAPI() + }) + + processed, err := core.Run(ctx, "processData", func() (Result, error) { + return process(result) + }) + +# Middleware + +Actions support middleware for cross-cutting concerns like logging, metrics, +or authentication: + + loggingMiddleware := func(next core.StreamingFunc[string, string, struct{}]) core.StreamingFunc[string, string, struct{}] { + return func(ctx context.Context, input string, cb core.StreamCallback[struct{}]) (string, error) { + log.Printf("Input: %s", input) + output, err := next(ctx, input, cb) + log.Printf("Output: %s, Error: %v", output, err) + return output, err + } + } + +Chain multiple middleware together: + + combined := core.ChainMiddleware(loggingMiddleware, metricsMiddleware) + wrappedFn := combined(originalFunc) + +# Schema Management + +Register JSON schemas for use in prompts and validation: + + // Define a schema from a map + core.DefineSchema(registry, "Person", map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{"type": "string"}, + "age": map[string]any{"type": "integer"}, + }, + "required": []any{"name"}, + }) + + // Define a schema from a Go type (recommended) + core.DefineSchemaFor[Person](registry) + +Schemas can be referenced in .prompt files by name. + +# Plugin Development + +Plugins extend Genkit's functionality by providing models, tools, retrievers, +and other capabilities. Implement the [api.Plugin] interface: + + type MyPlugin struct { + APIKey string + } + + func (p *MyPlugin) Name() string { + return "myplugin" + } + + func (p *MyPlugin) Init(ctx context.Context) []api.Action { + // Initialize the plugin and return actions to register + model := ai.DefineModel(...) + tool := ai.DefineTool(...) 
+ return []api.Action{model, tool} + } + +For plugins that resolve actions dynamically (e.g., listing available models +from an API), implement [api.DynamicPlugin]: + + type DynamicModelPlugin struct{} + + func (p *DynamicModelPlugin) ListActions(ctx context.Context) []api.ActionDesc { + // Return descriptors of available actions + return []api.ActionDesc{ + {Key: "/model/myplugin/model-a", Name: "model-a"}, + {Key: "/model/myplugin/model-b", Name: "model-b"}, + } + } + + func (p *DynamicModelPlugin) ResolveAction(atype api.ActionType, name string) api.Action { + // Create and return the action on demand + return createModel(name) + } + +# Background Actions + +For long-running operations, use background actions that return immediately +with an operation ID that can be polled for completion: + + bgAction := core.DefineBackgroundAction(registry, "longTask", + func(ctx context.Context, input Input) (Output, error) { + // Start the operation + return startLongOperation(input) + }, + func(ctx context.Context, op *core.Operation[Output]) (*core.Operation[Output], error) { + // Check operation status + return checkOperationStatus(op) + }, + ) + +# Error Handling + +Return user-facing errors with appropriate status codes: + + if err := validate(input); err != nil { + return nil, core.NewPublicError(core.INVALID_ARGUMENT, "Invalid input", map[string]any{ + "field": "email", + "error": err.Error(), + }) + } + +For internal errors that should be logged but not exposed to users: + + return nil, core.NewError(core.INTERNAL, "database connection failed: %v", err) + +# Context + +Access action context for metadata and configuration: + + ctx := core.FromContext(ctx) + if ctx != nil { + // Access action-specific context values + } + +Set action context for nested operations: + + ctx = core.WithActionContext(ctx, core.ActionContext{ + "requestId": requestID, + }) + +For more information, see https://genkit.dev/docs/plugins +*/ +package core diff --git a/go/core/error_test.go 
b/go/core/error_test.go new file mode 100644 index 0000000000..60ff503bdc --- /dev/null +++ b/go/core/error_test.go @@ -0,0 +1,217 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "errors" + "fmt" + "net/http" + "strings" + "testing" +) + +func TestNewPublicError(t *testing.T) { + t.Run("creates error with all fields", func(t *testing.T) { + details := map[string]any{"field": "username"} + err := NewPublicError(INVALID_ARGUMENT, "invalid username", details) + + if err.Status != INVALID_ARGUMENT { + t.Errorf("Status = %q, want %q", err.Status, INVALID_ARGUMENT) + } + if err.Message != "invalid username" { + t.Errorf("Message = %q, want %q", err.Message, "invalid username") + } + if err.Details["field"] != "username" { + t.Errorf("Details[field] = %v, want %q", err.Details["field"], "username") + } + }) + + t.Run("creates error with nil details", func(t *testing.T) { + err := NewPublicError(NOT_FOUND, "resource not found", nil) + + if err.Status != NOT_FOUND { + t.Errorf("Status = %q, want %q", err.Status, NOT_FOUND) + } + if err.Details != nil { + t.Errorf("Details = %v, want nil", err.Details) + } + }) +} + +func TestUserFacingErrorError(t *testing.T) { + t.Run("formats error message correctly", func(t *testing.T) { + err := NewPublicError(PERMISSION_DENIED, "access denied", nil) + got := err.Error() + want := "PERMISSION_DENIED: access denied" + + if 
got != want { + t.Errorf("Error() = %q, want %q", got, want) + } + }) +} + +func TestNewError(t *testing.T) { + t.Run("creates error with simple message", func(t *testing.T) { + err := NewError(INTERNAL, "internal error") + + if err.Status != INTERNAL { + t.Errorf("Status = %q, want %q", err.Status, INTERNAL) + } + if err.Message != "internal error" { + t.Errorf("Message = %q, want %q", err.Message, "internal error") + } + }) + + t.Run("creates error with formatted message", func(t *testing.T) { + err := NewError(INVALID_ARGUMENT, "field %q has invalid value %d", "count", 42) + + want := `field "count" has invalid value 42` + if err.Message != want { + t.Errorf("Message = %q, want %q", err.Message, want) + } + }) + + t.Run("captures stack trace", func(t *testing.T) { + err := NewError(INTERNAL, "error with stack") + + if err.Details == nil { + t.Fatal("Details is nil, expected stack trace") + } + stack, ok := err.Details["stack"].(string) + if !ok { + t.Fatal("stack is not a string") + } + if !strings.Contains(stack, "TestNewError") { + t.Errorf("stack trace does not contain test function name") + } + }) +} + +func TestGenkitErrorError(t *testing.T) { + t.Run("returns message as error string", func(t *testing.T) { + err := NewError(INTERNAL, "something went wrong") + got := err.Error() + + if got != "something went wrong" { + t.Errorf("Error() = %q, want %q", got, "something went wrong") + } + }) +} + +func TestGenkitErrorToReflectionError(t *testing.T) { + t.Run("converts error with stack", func(t *testing.T) { + ge := NewError(NOT_FOUND, "resource not found") + re := ge.ToReflectionError() + + if re.Message != "resource not found" { + t.Errorf("Message = %q, want %q", re.Message, "resource not found") + } + if re.Code != http.StatusNotFound { + t.Errorf("Code = %d, want %d", re.Code, http.StatusNotFound) + } + if re.Details == nil || re.Details.Stack == nil { + t.Error("expected stack in details") + } + }) + + t.Run("converts error with traceId", func(t 
*testing.T) { + ge := &GenkitError{ + Status: INTERNAL, + Message: "internal error", + Details: map[string]any{ + "traceId": "trace-123", + }, + } + re := ge.ToReflectionError() + + if re.Details == nil || re.Details.TraceID == nil { + t.Fatal("expected traceId in details") + } + if *re.Details.TraceID != "trace-123" { + t.Errorf("TraceID = %q, want %q", *re.Details.TraceID, "trace-123") + } + }) + + t.Run("handles empty details", func(t *testing.T) { + ge := &GenkitError{ + Status: OK, + Message: "success", + Details: nil, + } + re := ge.ToReflectionError() + + if re.Message != "success" { + t.Errorf("Message = %q, want %q", re.Message, "success") + } + if re.Details.Stack != nil { + t.Error("expected nil stack") + } + }) +} + +func TestToReflectionError(t *testing.T) { + t.Run("handles GenkitError directly", func(t *testing.T) { + ge := NewError(INVALID_ARGUMENT, "bad input") + re := ToReflectionError(ge) + + if re.Message != "bad input" { + t.Errorf("Message = %q, want %q", re.Message, "bad input") + } + if re.Code != http.StatusBadRequest { + t.Errorf("Code = %d, want %d", re.Code, http.StatusBadRequest) + } + }) + + t.Run("handles wrapped GenkitError", func(t *testing.T) { + ge := NewError(NOT_FOUND, "not found") + wrapped := fmt.Errorf("context: %w", ge) + re := ToReflectionError(wrapped) + + if re.Message != "not found" { + t.Errorf("Message = %q, want %q", re.Message, "not found") + } + if re.Code != http.StatusNotFound { + t.Errorf("Code = %d, want %d", re.Code, http.StatusNotFound) + } + }) + + t.Run("handles plain error", func(t *testing.T) { + plainErr := errors.New("plain error") + re := ToReflectionError(plainErr) + + if re.Message != "plain error" { + t.Errorf("Message = %q, want %q", re.Message, "plain error") + } + if re.Code != http.StatusInternalServerError { + t.Errorf("Code = %d, want %d", re.Code, http.StatusInternalServerError) + } + }) + + t.Run("handles doubly wrapped GenkitError", func(t *testing.T) { + ge := NewError(PERMISSION_DENIED, 
"denied") + wrapped1 := fmt.Errorf("layer1: %w", ge) + wrapped2 := fmt.Errorf("layer2: %w", wrapped1) + re := ToReflectionError(wrapped2) + + if re.Message != "denied" { + t.Errorf("Message = %q, want %q", re.Message, "denied") + } + if re.Code != http.StatusForbidden { + t.Errorf("Code = %d, want %d", re.Code, http.StatusForbidden) + } + }) +} diff --git a/go/core/example_test.go b/go/core/example_test.go new file mode 100644 index 0000000000..c6212c3e9d --- /dev/null +++ b/go/core/example_test.go @@ -0,0 +1,197 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package core_test + +import ( + "context" + "fmt" + "strings" + + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/internal/registry" +) + +// This example demonstrates defining a simple flow. +func ExampleDefineFlow() { + r := registry.New() + + // Define a flow that processes input + flow := core.DefineFlow(r, "uppercase", + func(ctx context.Context, input string) (string, error) { + return strings.ToUpper(input), nil + }, + ) + + // Run the flow + result, err := flow.Run(context.Background(), "hello") + if err != nil { + fmt.Println("Error:", err) + return + } + fmt.Println(result) + // Output: HELLO +} + +// This example demonstrates defining a streaming flow. 
+func ExampleDefineStreamingFlow() { + r := registry.New() + + // Define a streaming flow that counts down + flow := core.DefineStreamingFlow(r, "countdown", + func(ctx context.Context, start int, cb core.StreamCallback[int]) (string, error) { + for i := start; i > 0; i-- { + if cb != nil { + if err := cb(ctx, i); err != nil { + return "", err + } + } + } + return "Done!", nil + }, + ) + + // Use Stream() iterator to receive chunks + iter := flow.Stream(context.Background(), 3) + iter(func(val *core.StreamingFlowValue[string, int], err error) bool { + if err != nil { + fmt.Println("Error:", err) + return false + } + if val.Done { + fmt.Println("Result:", val.Output) + } else { + fmt.Println("Count:", val.Stream) + } + return true + }) + // Output: + // Count: 3 + // Count: 2 + // Count: 1 + // Result: Done! +} + +// This example demonstrates using Run to create traced sub-steps. +func ExampleRun() { + r := registry.New() + + // Define a flow that uses Run for traced steps + flow := core.DefineFlow(r, "pipeline", + func(ctx context.Context, input string) (string, error) { + // Each Run creates a traced step visible in the Dev UI + upper, err := core.Run(ctx, "toUpper", func() (string, error) { + return strings.ToUpper(input), nil + }) + if err != nil { + return "", err + } + + result, err := core.Run(ctx, "addPrefix", func() (string, error) { + return "RESULT: " + upper, nil + }) + return result, err + }, + ) + + result, err := flow.Run(context.Background(), "hello") + if err != nil { + fmt.Println("Error:", err) + return + } + fmt.Println(result) + // Output: RESULT: HELLO +} + +// This example demonstrates defining a schema from a Go type. 
+func ExampleDefineSchemaFor() { + r := registry.New() + + // Define a struct type + type Person struct { + Name string `json:"name"` + Age int `json:"age"` + } + + // Register the schema + core.DefineSchemaFor[Person](r) + + // The schema is now registered and can be referenced in .prompt files + fmt.Println("Schema registered") + // Output: Schema registered +} + +// This example demonstrates defining a schema from a map. +func ExampleDefineSchema() { + r := registry.New() + + // Define a JSON schema as a map + core.DefineSchema(r, "Address", map[string]any{ + "type": "object", + "properties": map[string]any{ + "street": map[string]any{"type": "string"}, + "city": map[string]any{"type": "string"}, + "zip": map[string]any{"type": "string"}, + }, + "required": []any{"street", "city"}, + }) + + fmt.Println("Schema registered: Address") + // Output: Schema registered: Address +} + +// This example demonstrates using ChainMiddleware to combine middleware. +func ExampleChainMiddleware() { + // Define a middleware that wraps function calls + logMiddleware := func(next core.StreamingFunc[string, string, struct{}]) core.StreamingFunc[string, string, struct{}] { + return func(ctx context.Context, input string, cb core.StreamCallback[struct{}]) (string, error) { + fmt.Println("Before:", input) + result, err := next(ctx, input, cb) + fmt.Println("After:", result) + return result, err + } + } + + // The original function + originalFn := func(ctx context.Context, input string, cb core.StreamCallback[struct{}]) (string, error) { + return strings.ToUpper(input), nil + } + + // Chain and apply middleware + wrapped := core.ChainMiddleware(logMiddleware)(originalFn) + + result, _ := wrapped(context.Background(), "hello", nil) + fmt.Println("Final:", result) + // Output: + // Before: hello + // After: HELLO + // Final: HELLO +} + +// This example demonstrates creating user-facing errors. 
+func ExampleNewPublicError() { + // Create a user-facing error with details + err := core.NewPublicError(core.INVALID_ARGUMENT, "Invalid email format", map[string]any{ + "field": "email", + "value": "not-an-email", + }) + + fmt.Println("Status:", err.Status) + fmt.Println("Message:", err.Message) + // Output: + // Status: INVALID_ARGUMENT + // Message: Invalid email format +} diff --git a/go/core/flow.go b/go/core/flow.go index 0cd12120f2..ea514365c2 100644 --- a/go/core/flow.go +++ b/go/core/flow.go @@ -71,6 +71,9 @@ func DefineStreamingFlow[In, Out, Stream any](r api.Registry, name string, fn St flowName: name, } ctx = flowContextKey.NewContext(ctx, fc) + if cb == nil { + cb = func(context.Context, Stream) error { return nil } + } return fn(ctx, input, cb) })) } diff --git a/go/core/flow_test.go b/go/core/flow_test.go index 77087072c4..e3c3e6b463 100644 --- a/go/core/flow_test.go +++ b/go/core/flow_test.go @@ -89,3 +89,171 @@ func TestFlowNameFromContext(t *testing.T) { }) } } + +func TestRunOutsideFlow(t *testing.T) { + t.Run("returns error when called outside flow", func(t *testing.T) { + ctx := context.Background() + _, err := Run(ctx, "step", func() (int, error) { + return 42, nil + }) + + if err == nil { + t.Error("expected error when Run called outside flow, got nil") + } + }) +} + +func TestFlowStream(t *testing.T) { + t.Run("streams values correctly", func(t *testing.T) { + r := registry.New() + f := DefineStreamingFlow(r, "counter", func(ctx context.Context, n int, cb StreamCallback[int]) (int, error) { + for i := 0; i < n; i++ { + if err := cb(ctx, i); err != nil { + return 0, err + } + } + return n, nil + }) + + var streamedValues []int + var finalOutput int + var finalDone bool + + for v, err := range f.Stream(context.Background(), 3) { + if err != nil { + t.Fatalf("Stream error: %v", err) + } + if v.Done { + finalDone = true + finalOutput = v.Output + } else { + streamedValues = append(streamedValues, v.Stream) + } + } + + wantStreamed := []int{0, 
1, 2} + if !slices.Equal(streamedValues, wantStreamed) { + t.Errorf("streamed values = %v, want %v", streamedValues, wantStreamed) + } + if !finalDone { + t.Error("expected final Done value") + } + if finalOutput != 3 { + t.Errorf("final output = %d, want 3", finalOutput) + } + }) + + t.Run("yields error on flow failure", func(t *testing.T) { + r := registry.New() + f := DefineStreamingFlow(r, "failing", func(ctx context.Context, input int, cb StreamCallback[int]) (int, error) { + return 0, NewError(INTERNAL, "flow failed") + }) + + var gotErr error + for _, err := range f.Stream(context.Background(), 1) { + if err != nil { + gotErr = err + } + } + + if gotErr == nil { + t.Error("expected error from failing flow, got nil") + } + }) +} + +func TestFlowRegister(t *testing.T) { + t.Run("flow can be registered with registry", func(t *testing.T) { + r := registry.New() + f := DefineFlow(r, "test/registerable", func(ctx context.Context, input string) (string, error) { + return input, nil + }) + + // Flow should already be registered by DefineFlow + if f.Name() != "test/registerable" { + t.Errorf("Name() = %q, want %q", f.Name(), "test/registerable") + } + }) +} + +func TestFlowDesc(t *testing.T) { + t.Run("returns flow descriptor", func(t *testing.T) { + r := registry.New() + f := DefineFlow(r, "test/described", func(ctx context.Context, input struct { + Name string `json:"name"` + }) (struct { + Greeting string `json:"greeting"` + }, error) { + return struct { + Greeting string `json:"greeting"` + }{Greeting: "Hello " + input.Name}, nil + }) + + desc := f.Desc() + + if desc.Name != "test/described" { + t.Errorf("Name = %q, want %q", desc.Name, "test/described") + } + if desc.InputSchema == nil { + t.Error("InputSchema is nil") + } + if desc.OutputSchema == nil { + t.Error("OutputSchema is nil") + } + }) +} + +func TestFlowRunJSON(t *testing.T) { + t.Run("runs flow with JSON input and output", func(t *testing.T) { + r := registry.New() + f := DefineFlow(r, 
"test/jsonFlow", func(ctx context.Context, input int) (int, error) { + return input * 2, nil + }) + + got, err := f.RunJSON(context.Background(), []byte("5"), nil) + if err != nil { + t.Fatalf("RunJSON error: %v", err) + } + + if string(got) != "10" { + t.Errorf("RunJSON result = %s, want %q", got, "10") + } + }) +} + +func TestFlowRunJSONWithTelemetry(t *testing.T) { + t.Run("returns telemetry info with result", func(t *testing.T) { + r := registry.New() + f := DefineFlow(r, "test/telemetryFlow", func(ctx context.Context, input int) (int, error) { + return input + 1, nil + }) + + result, err := f.RunJSONWithTelemetry(context.Background(), []byte("5"), nil) + if err != nil { + t.Fatalf("RunJSONWithTelemetry error: %v", err) + } + + if result == nil { + t.Fatal("result is nil") + } + if string(result.Result) != "6" { + t.Errorf("Result = %s, want %q", result.Result, "6") + } + if result.TraceId == "" { + t.Error("TraceId is empty") + } + if result.SpanId == "" { + t.Error("SpanId is empty") + } + }) +} + +func TestFlowNameFromContextOutsideFlow(t *testing.T) { + t.Run("returns empty string outside flow", func(t *testing.T) { + ctx := context.Background() + got := FlowNameFromContext(ctx) + if got != "" { + t.Errorf("FlowNameFromContext outside flow = %q, want empty string", got) + } + }) +} diff --git a/go/core/logger/doc.go b/go/core/logger/doc.go new file mode 100644 index 0000000000..b3e421abc6 --- /dev/null +++ b/go/core/logger/doc.go @@ -0,0 +1,110 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +Package logger provides context-scoped structured logging for Genkit. + +This package wraps the standard library's [log/slog] package to provide +context-aware logging throughout Genkit operations. Logs are automatically +associated with the current action or flow context. + +# Usage + +Retrieve the logger from context within action or flow handlers: + + func myFlow(ctx context.Context, input string) (string, error) { + log := logger.FromContext(ctx) + + log.Info("Processing input", "size", len(input)) + log.Debug("Input details", "value", input) + + result, err := process(input) + if err != nil { + log.Error("Processing failed", "error", err) + return "", err + } + + log.Info("Processing complete", "resultSize", len(result)) + return result, nil + } + +# Log Levels + +Control the global log level to filter output: + + // Show debug logs (verbose) + logger.SetLevel(slog.LevelDebug) + + // Show info and above (default) + logger.SetLevel(slog.LevelInfo) + + // Show only warnings and errors + logger.SetLevel(slog.LevelWarn) + + // Show only errors + logger.SetLevel(slog.LevelError) + + // Get the current log level + level := logger.GetLevel() + +# Context Integration + +The logger is automatically available in action and flow contexts. It +inherits from the context passed to [genkit.Init] and flows through +all nested operations. 
+ +For custom operations outside of actions/flows, attach a logger to context: + + log := slog.Default() + ctx = logger.WithContext(ctx, log) + +# slog Compatibility + +The logger returned by [FromContext] is a standard [*slog.Logger] and +supports all slog methods: + + log := logger.FromContext(ctx) + + // Structured logging with attributes + log.Info("User action", + "userId", userID, + "action", "login", + "duration", elapsed, + ) + + // Grouped attributes + log.Info("Request completed", + slog.Group("request", + "method", r.Method, + "path", r.URL.Path, + ), + slog.Group("response", + "status", status, + "bytes", written, + ), + ) + + // With pre-set attributes + requestLog := log.With("requestId", requestID) + requestLog.Info("Starting") + // ... later ... + requestLog.Info("Finished") + +This package is primarily used by Genkit internals but is useful for +plugin developers who need consistent logging that integrates with +Genkit's observability features. +*/ +package logger diff --git a/go/core/middleware_test.go b/go/core/middleware_test.go new file mode 100644 index 0000000000..8422eb3ee3 --- /dev/null +++ b/go/core/middleware_test.go @@ -0,0 +1,222 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "context" + "strings" + "testing" +) + +func TestMiddlewares(t *testing.T) { + t.Run("creates slice of middlewares", func(t *testing.T) { + m1 := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return next + } + m2 := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return next + } + + result := Middlewares(m1, m2) + + if len(result) != 2 { + t.Errorf("len(result) = %d, want 2", len(result)) + } + }) + + t.Run("returns empty slice when no middlewares", func(t *testing.T) { + result := Middlewares[string, string, struct{}]() + + if len(result) != 0 { + t.Errorf("len(result) = %d, want 0", len(result)) + } + }) + + t.Run("returns single middleware slice", func(t *testing.T) { + m := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return next + } + + result := Middlewares(m) + + if len(result) != 1 { + t.Errorf("len(result) = %d, want 1", len(result)) + } + }) +} + +func TestChainMiddleware(t *testing.T) { + t.Run("empty chain returns identity", func(t *testing.T) { + handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + return "original:" + input, nil + } + + chained := ChainMiddleware[string, string, struct{}]()(handler) + result, err := chained(context.Background(), "test", nil) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "original:test" { + t.Errorf("result = %q, want %q", result, "original:test") + } + }) + + t.Run("single middleware is applied", func(t *testing.T) { + handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + return "handler:" + input, nil + } + + middleware := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return func(ctx context.Context, 
input string, cb func(context.Context, struct{}) error) (string, error) { + result, err := next(ctx, "m1:"+input, cb) + return "m1:" + result, err + } + } + + chained := ChainMiddleware(middleware)(handler) + result, err := chained(context.Background(), "test", nil) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Expected: m1: wraps output, m1: prepends to input + if result != "m1:handler:m1:test" { + t.Errorf("result = %q, want %q", result, "m1:handler:m1:test") + } + }) + + t.Run("multiple middlewares execute in order", func(t *testing.T) { + var executionOrder []string + + handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + executionOrder = append(executionOrder, "handler") + return input, nil + } + + m1 := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + executionOrder = append(executionOrder, "m1-before") + result, err := next(ctx, input, cb) + executionOrder = append(executionOrder, "m1-after") + return result, err + } + } + + m2 := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + executionOrder = append(executionOrder, "m2-before") + result, err := next(ctx, input, cb) + executionOrder = append(executionOrder, "m2-after") + return result, err + } + } + + // ChainMiddleware(m1, m2) should execute as: m1 -> m2 -> handler -> m2 -> m1 + chained := ChainMiddleware(m1, m2)(handler) + _, err := chained(context.Background(), "test", nil) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"m1-before", "m2-before", "handler", "m2-after", "m1-after"} + if len(executionOrder) != len(expected) { + t.Errorf("execution order length = %d, want %d", 
len(executionOrder), len(expected)) + } + for i, step := range expected { + if i >= len(executionOrder) || executionOrder[i] != step { + t.Errorf("step %d = %q, want %q", i, executionOrder[i], step) + } + } + }) + + t.Run("middleware can modify input", func(t *testing.T) { + handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + return input, nil + } + + uppercase := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + return next(ctx, strings.ToUpper(input), cb) + } + } + + chained := ChainMiddleware(uppercase)(handler) + result, err := chained(context.Background(), "hello", nil) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "HELLO" { + t.Errorf("result = %q, want %q", result, "HELLO") + } + }) + + t.Run("middleware can modify output", func(t *testing.T) { + handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + return input, nil + } + + addSuffix := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + result, err := next(ctx, input, cb) + return result + "!", err + } + } + + chained := ChainMiddleware(addSuffix)(handler) + result, err := chained(context.Background(), "hello", nil) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "hello!" 
{ + t.Errorf("result = %q, want %q", result, "hello!") + } + }) + + t.Run("middleware can short-circuit", func(t *testing.T) { + handlerCalled := false + handler := func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + handlerCalled = true + return input, nil + } + + shortCircuit := func(next StreamingFunc[string, string, struct{}]) StreamingFunc[string, string, struct{}] { + return func(ctx context.Context, input string, cb func(context.Context, struct{}) error) (string, error) { + if input == "skip" { + return "skipped", nil + } + return next(ctx, input, cb) + } + } + + chained := ChainMiddleware(shortCircuit)(handler) + result, err := chained(context.Background(), "skip", nil) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if handlerCalled { + t.Error("handler should not have been called") + } + if result != "skipped" { + t.Errorf("result = %q, want %q", result, "skipped") + } + }) +} diff --git a/go/core/schemas.config b/go/core/schemas.config index 746d10ca4c..70798f2eb3 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -1,6 +1,866 @@ # This file holds configuration for the genkit-schema.json file # generated by the npm export:schemas script. +# ============================================================================ +# DOCUMENTATION SECTION +# All type and field documentation in one consolidated location +# ============================================================================ + +# ---------------------------------------------------------------------------- +# Core Message Types +# ---------------------------------------------------------------------------- + +Role doc +Role indicates which entity is responsible for the content of a message. +. + +RoleSystem doc +RoleSystem indicates this message is user-independent context. +. + +RoleUser doc +RoleUser indicates this message was generated by the client. +. 
+ +RoleModel doc +RoleModel indicates this message was generated by the model during a previous interaction. +. + +RoleTool doc +RoleTool indicates this message was generated by a local tool, likely triggered by a request +from the model in one of its previous responses. +. + +Message doc +Message represents the contents of a model message in a conversation. +. + +Message.role doc +Role indicates which entity (system, user, model, or tool) generated this message. +. + +Message.content doc +Content holds the message parts (text, media, tool calls, etc.). +. + +Message.metadata doc +Metadata contains arbitrary key-value data associated with this message. +. + +# ---------------------------------------------------------------------------- +# Part Types (Message Content) +# ---------------------------------------------------------------------------- + +TextPart.text doc +Text contains the textual content. +. + +TextPart.metadata doc +Metadata contains arbitrary key-value data for this part. +. + +MediaPart.media doc +Media contains the media content and metadata. +. + +MediaPart.metadata doc +Metadata contains arbitrary key-value data for this part. +. + +ToolRequestPart.toolRequest doc +ToolRequest is a request for a tool to be executed, usually provided by a model. +. + +ToolRequestPart.metadata doc +Metadata contains arbitrary key-value data for this part. +. + +ToolResponsePart.toolResponse doc +ToolResponse is a provided response to a tool call. +. + +ToolResponsePart.metadata doc +Metadata contains arbitrary key-value data for this part. +. + +DataPart.data doc +Data contains arbitrary structured data. +. + +DataPart.metadata doc +Metadata contains arbitrary key-value data for this part. +. + +ReasoningPart.reasoning doc +Reasoning contains the reasoning text of the message. +. + +ReasoningPart.metadata doc +Metadata contains arbitrary key-value data for this part. +. + +CustomPart.custom doc +Custom contains custom key-value data specific to this part. +. 
+ +CustomPart.data doc +Data contains additional arbitrary data. +. + +CustomPart.metadata doc +Metadata contains arbitrary key-value data for this part. +. + +ResourcePart.resource doc +Resource contains a reference to an external resource by URI. +. + +ResourcePart.metadata doc +Metadata contains arbitrary key-value data for this part. +. + +ResourcePartResource.uri doc +Uri is the URI of the external resource. +. + +# ---------------------------------------------------------------------------- +# Media Types +# ---------------------------------------------------------------------------- + +Media doc +Media represents media content with a URL and content type. +. + +Media.contentType doc +ContentType specifies the MIME type of the media. Inferred from the data URI if not provided. +. + +Media.url doc +Url is a "data:" or "https:" URI containing the media content. +. + +# ---------------------------------------------------------------------------- +# Tool Types +# ---------------------------------------------------------------------------- + +ToolRequest doc +A ToolRequest is a message from the model to the client that it should run a +specific tool and pass a ToolResponse to the model on the next chat request it makes. +Any ToolRequest will correspond to some ToolDefinition previously sent by the client. +. + +ToolRequest.ref doc +Ref is the call ID or reference for this specific request. +. + +ToolRequest.name doc +Name is the name of the tool to call. +. + +ToolRequest.input doc +Input is a JSON object containing the input parameters for the tool. +For example: map[string]any{"country":"USA", "president":3}. +. + +ToolRequest.partial doc +Partial indicates whether this is a partial streaming chunk. +. + +ToolResponse doc +A ToolResponse is a message from the client to the model containing +the results of running a specific tool on the arguments passed to the client +by the model in a ToolRequest. +. 
+ +ToolResponse.ref doc +Ref is the call ID or reference matching the original request. +. + +ToolResponse.name doc +Name is the name of the tool that was executed. +. + +ToolResponse.output doc +Output is a JSON object describing the results of running the tool. +For example: map[string]any{"name":"Thomas Jefferson", "born":1743}. +. + +ToolResponse.content doc +Content holds additional message parts that provide context or details about the tool response. +. + +ToolDefinition doc +A ToolDefinition describes a tool. +. + +ToolDefinition.name doc +Name is the unique identifier for this tool. +. + +ToolDefinition.description doc +Description explains what the tool does and when to use it. +. + +ToolDefinition.inputSchema doc +InputSchema is a valid JSON Schema representing the input parameters of the tool. +. + +ToolDefinition.outputSchema doc +OutputSchema is a valid JSON Schema describing the output of the tool. +. + +ToolDefinition.metadata doc +Metadata contains additional information about this tool definition. +. + +# ---------------------------------------------------------------------------- +# Generation Configuration +# ---------------------------------------------------------------------------- + +GenerationCommonConfig doc +GenerationCommonConfig holds configuration parameters for model generation requests. +. + +GenerationCommonConfig.version doc +Version specifies a particular version of a model family, +e.g., "gemini-1.0-pro-001" for the "gemini-1.0-pro" family. +. + +GenerationCommonConfig.temperature doc +Temperature controls randomness in generation. Higher values (e.g., 0.9) make output more random, +while lower values (e.g., 0.1) make it more deterministic. Typical range is 0.0 to 1.0. +. + +GenerationCommonConfig.maxOutputTokens doc +MaxOutputTokens limits the maximum number of tokens generated in the response. +. + +GenerationCommonConfig.topK doc +TopK limits sampling to the K most likely tokens at each step. +. 
+ +GenerationCommonConfig.topP doc +TopP (nucleus sampling) limits sampling to tokens whose cumulative probability exceeds P. +. + +GenerationCommonConfig.stopSequences doc +StopSequences specifies sequences that will cause generation to stop when encountered. +. + +# ---------------------------------------------------------------------------- +# Generation Usage and Metrics +# ---------------------------------------------------------------------------- + +GenerationUsage doc +GenerationUsage provides information about resource consumption during generation. +. + +GenerationUsage.inputTokens doc +InputTokens is the number of tokens in the input prompt. +. + +GenerationUsage.outputTokens doc +OutputTokens is the number of tokens generated in the response. +. + +GenerationUsage.totalTokens doc +TotalTokens is the sum of input and output tokens. +. + +GenerationUsage.inputCharacters doc +InputCharacters is the number of characters in the input. +. + +GenerationUsage.outputCharacters doc +OutputCharacters is the number of characters generated in the output. +. + +GenerationUsage.inputImages doc +InputImages is the number of images in the input. +. + +GenerationUsage.outputImages doc +OutputImages is the number of images generated in the output. +. + +GenerationUsage.inputVideos doc +InputVideos is the number of videos in the input. +. + +GenerationUsage.outputVideos doc +OutputVideos is the number of videos generated in the output. +. + +GenerationUsage.inputAudioFiles doc +InputAudioFiles is the number of audio files in the input. +. + +GenerationUsage.outputAudioFiles doc +OutputAudioFiles is the number of audio files generated in the output. +. + +GenerationUsage.thoughtsTokens doc +ThoughtsTokens counts tokens used in reasoning or thinking processes. +. + +GenerationUsage.cachedContentTokens doc +CachedContentTokens counts tokens that were served from cache. +. + +GenerationUsage.custom doc +Custom contains additional usage metrics specific to the model provider. 
+. + +# ---------------------------------------------------------------------------- +# Model Request and Response +# ---------------------------------------------------------------------------- + +ModelRequest doc +A ModelRequest is a request to generate completions from a model. +. + +ModelRequest.messages doc +Messages contains the conversation history for the model. +. + +ModelRequest.config doc +Config holds model-specific configuration parameters. +. + +ModelRequest.docs doc +Docs provides retrieved documents to be used as context for this generation. +. + +ModelRequest.output doc +Output describes the desired response format. +. + +ModelRequest.tools doc +Tools lists the available tools that the model can ask the client to run. +. + +ModelRequest.toolChoice doc +ToolChoice controls how the model uses tools (auto, required, or none). +. + +ModelResponse doc +A ModelResponse is a model's response to a ModelRequest. +. + +ModelResponse.message doc +Message contains the generated response content. +. + +ModelResponse.finishReason doc +FinishReason indicates why generation stopped (e.g., stop, length, blocked). +. + +ModelResponse.finishMessage doc +FinishMessage provides additional details about why generation finished. +. + +ModelResponse.latencyMs doc +LatencyMs is the time the request took in milliseconds. +. + +ModelResponse.usage doc +Usage describes how many resources were used by this generation request. +. + +ModelResponse.custom doc +Custom contains model-specific extra information. Deprecated: use Raw instead. +. + +ModelResponse.raw doc +Raw contains the unprocessed model-specific response data. +. + +ModelResponse.request doc +Request is the ModelRequest struct used to trigger this response. +. + +ModelResponse.operation doc +Operation provides information about a long-running background task if applicable. +. + +ModelResponseChunk doc +A ModelResponseChunk is the portion of the ModelResponse +that is passed to a streaming callback. +. 
+ +ModelResponseChunk.role doc +Role indicates the entity that generated this chunk. +. + +ModelResponseChunk.index doc +Index of the message this chunk belongs to. +. + +ModelResponseChunk.content doc +Content is the chunk of message parts to stream right now. +. + +ModelResponseChunk.custom doc +Custom contains model-specific extra information attached to this chunk. +. + +ModelResponseChunk.aggregated doc +Aggregated indicates whether the chunk includes all data from previous chunks. +If false, the chunk is considered incremental. +. + +# ---------------------------------------------------------------------------- +# Model Information and Capabilities +# ---------------------------------------------------------------------------- + +ModelInfo doc +ModelInfo contains metadata about a model's capabilities and characteristics. +. + +ModelInfo.versions doc +Versions lists acceptable names for this model (e.g., different versions). +. + +ModelInfo.label doc +Label is a friendly display name for this model (e.g., "Google AI - Gemini Pro"). +. + +ModelInfo.configSchema doc +ConfigSchema defines the model-specific configuration schema. +. + +ModelInfo.supports doc +Supports describes the capabilities that this model supports. +. + +ModelInfo.stage doc +Stage indicates the development stage of this model. +Featured models are recommended for general use, stable models are well-tested, +unstable models are experimental, legacy models are not recommended for new projects, +and deprecated models may be removed in future versions. +. + +ModelInfoSupports doc +ModelSupports describes the capabilities that a model supports. +. + +ModelInfoSupports.multiturn doc +Multiturn indicates whether the model can process historical messages passed with a prompt. +. + +ModelInfoSupports.media doc +Media indicates whether the model can process media as part of the prompt (multimodal input). +. + +ModelInfoSupports.tools doc +Tools indicates whether the model can perform tool calls. +. 
+ +ModelInfoSupports.systemRole doc +SystemRole indicates whether the model can accept messages with role "system". +. + +ModelInfoSupports.output doc +Output lists the types of data the model can generate. +. + +ModelInfoSupports.contentType doc +ContentType lists the content types the model supports for output. +. + +ModelInfoSupports.context doc +Context indicates whether the model can natively support document-based context grounding. +. + +ModelInfoSupports.constrained doc +Constrained indicates the level of constrained generation support (none, all, or no-tools). +. + +ModelInfoSupports.toolChoice doc +ToolChoice indicates whether the model supports controlling tool choice (e.g., forced tool calling). +. + +ModelInfoSupports.longRunning doc +LongRunning indicates whether the model supports long-running operations. +. + +# ---------------------------------------------------------------------------- +# Output Configuration +# ---------------------------------------------------------------------------- + +OutputConfig doc +OutputConfig describes the structure that the model's output +should conform to. If Format is OutputFormatJSON, then Schema +can describe the desired form of the generated JSON. +. + +OutputConfig.format doc +Format specifies the desired output format (e.g., "json", "text"). +. + +OutputConfig.schema doc +Schema is a JSON Schema describing the desired structure of the output. +. + +OutputConfig.constrained doc +Constrained indicates whether to enforce strict adherence to the schema. +. + +OutputConfig.contentType doc +ContentType specifies the MIME type of the output content. +. + +# ---------------------------------------------------------------------------- +# Operation Types +# ---------------------------------------------------------------------------- + +Operation doc +Operation represents a long-running background task. +. + +Operation.action doc +Action is the name of the action being performed by this operation. +. 
+ +Operation.id doc +Id is the unique identifier for this operation. +. + +Operation.done doc +Done indicates whether the operation has completed. +. + +Operation.output doc +Output contains the result of the operation if it has completed successfully. +. + +Operation.error doc +Error contains error information if the operation failed. +. + +Operation.metadata doc +Metadata contains additional information about the operation. +. + +OperationError doc +OperationError contains error information for a failed operation. +. + +OperationError.message doc +Message describes the error that occurred. +. + +# ---------------------------------------------------------------------------- +# Document Types +# ---------------------------------------------------------------------------- + +# Note: Document type is hand-written in ai/document.go, not generated + +# ---------------------------------------------------------------------------- +# Embedding Types +# ---------------------------------------------------------------------------- + +Embedding doc +Embedding represents a vector embedding with associated metadata. +. + +Embedding.embedding doc +Embedding is the vector representation of the input. +. + +Embedding.metadata doc +Metadata identifies which part of a document this embedding corresponds to. +. + +EmbedRequest doc +EmbedRequest represents a request to generate embeddings for documents. +. + +EmbedRequest.input doc +Input is the array of documents to generate embeddings for. +. + +EmbedRequest.options doc +Options contains embedder-specific configuration parameters. +. + +EmbedResponse doc +EmbedResponse contains the generated embeddings from an embed request. +. + +EmbedResponse.embeddings doc +Embeddings is the array of generated embedding vectors with metadata. +. 
+ +# ---------------------------------------------------------------------------- +# Evaluator Types (ScoreDetails only - other eval types are omitted) +# ---------------------------------------------------------------------------- + +ScoreDetails doc +ScoreDetails provides additional context and explanation for an evaluation score. +. + +ScoreDetails.reasoning doc +Reasoning explains the rationale behind the score. +. + +# ---------------------------------------------------------------------------- +# Retriever Types +# ---------------------------------------------------------------------------- + +RetrieverRequest doc +RetrieverRequest represents a request to retrieve relevant documents. +. + +RetrieverRequest.query doc +Query is the document to use for retrieval. +. + +RetrieverRequest.options doc +Options contains retriever-specific configuration parameters. +. + +RetrieverResponse doc +RetrieverResponse contains the retrieved documents from a retriever request. +. + +RetrieverResponse.documents doc +Documents is the array of retrieved documents. +. + +# ---------------------------------------------------------------------------- +# Reranker Types +# ---------------------------------------------------------------------------- + +RerankerRequest doc +RerankerRequest represents a request to rerank documents based on relevance. +. + +RerankerRequest.query doc +Query is the document to use for reranking. +. + +RerankerRequest.documents doc +Documents is the array of documents to rerank. +. + +RerankerRequest.options doc +Options contains reranker-specific configuration parameters. +. + +RerankerResponse doc +RerankerResponse contains the reranked documents with relevance scores. +. + +RerankerResponse.documents doc +Documents is the array of reranked documents with scores. +. + +RankedDocumentData doc +RankedDocumentData represents a document with a relevance score from reranking. +. 
+ +RankedDocumentData.content doc +Content holds the document's parts (text and media). +. + +RankedDocumentData.metadata doc +Metadata contains the reranking score and other arbitrary key-value data. +. + +RankedDocumentMetadata doc +RankedDocumentMetadata contains the relevance score and other metadata for a reranked document. +. + +RankedDocumentMetadata.score doc +Score is the relevance score assigned by the reranker. +. + +# ---------------------------------------------------------------------------- +# GenerateAction Types +# ---------------------------------------------------------------------------- + +GenerateActionOptions doc +GenerateActionOptions holds configuration for a generate action request. +. + +GenerateActionOptions.model doc +Model is a model name (e.g., "vertexai/gemini-1.0-pro"). +. + +GenerateActionOptions.docs doc +Docs provides retrieved documents to be used as context for this generation. +. + +GenerateActionOptions.messages doc +Messages contains the conversation history for multi-turn prompting when supported. +. + +GenerateActionOptions.tools doc +Tools is a list of registered tool names for this generation if supported. +. + +GenerateActionOptions.toolChoice doc +ToolChoice controls tool calling mode. Auto lets the model decide, required forces +the model to choose a tool, and none forces the model not to use any tools. Defaults to auto. +. + +GenerateActionOptions.config doc +Config contains configuration parameters for the generation request. +. + +GenerateActionOptions.output doc +Output specifies the desired output format. Defaults to the model's default if unspecified. +. + +GenerateActionOptions.resume doc +Resume provides options for resuming an interrupted generation. +. + +GenerateActionOptions.returnToolRequests doc +ReturnToolRequests, when true, returns tool calls for manual processing instead of +automatically resolving them. +. 
+ +GenerateActionOptions.maxTurns doc +MaxTurns is the maximum number of tool call iterations that can be performed +in a single generate call. Defaults to 5. +. + +GenerateActionOptions.stepName doc +StepName is a custom step name for this generate call to display in trace views. +Defaults to "generate". +. + +GenerateActionOptionsResume doc +GenerateActionResume holds options for resuming an interrupted generation. +. + +GenerateActionOptionsResume.respond doc +Respond contains tool response parts to send to the model when resuming. +. + +GenerateActionOptionsResume.restart doc +Restart contains tool request parts to restart when resuming. +. + +GenerateActionOptionsResume.metadata doc +Metadata contains additional context for resuming the generation. +. + +GenerateActionOutputConfig doc +GenerateActionOutputConfig specifies the desired output format for a generate action. +. + +GenerateActionOutputConfig.format doc +Format specifies the desired output format (e.g., "json", "text"). +. + +GenerateActionOutputConfig.contentType doc +ContentType specifies the MIME type of the output content. +. + +GenerateActionOutputConfig.instructions doc +Instructions provides additional guidance for the output format. +. + +GenerateActionOutputConfig.jsonSchema doc +JsonSchema is a JSON Schema describing the desired structure of JSON output. +. + +GenerateActionOutputConfig.constrained doc +Constrained indicates whether to enforce strict adherence to the schema. +. + +GenerateActionOptionsToolChoice doc +ToolChoice controls how the model uses tools. +. + +# ---------------------------------------------------------------------------- +# Finish Reason Enum +# ---------------------------------------------------------------------------- + +FinishReason doc +FinishReason indicates why generation stopped. +. 
+ +# ---------------------------------------------------------------------------- +# Model Stage Enum +# ---------------------------------------------------------------------------- + +ModelInfoStage doc +ModelStage indicates the development stage of a model. +. + +# ---------------------------------------------------------------------------- +# Constrained Support Enum +# ---------------------------------------------------------------------------- + +ModelInfoSupportsConstrained doc +ConstrainedSupport indicates the level of constrained generation support. +. + +# ---------------------------------------------------------------------------- +# Trace Metadata Types +# ---------------------------------------------------------------------------- + +TraceMetadata doc +TraceMetadata contains metadata about a trace execution. +. + +TraceMetadata.featureName doc +FeatureName identifies the feature being traced. +. + +TraceMetadata.paths doc +Paths contains metadata for each path executed during the trace. +. + +TraceMetadata.timestamp doc +Timestamp is when the trace was created. +. + +PathMetadata doc +PathMetadata contains metadata about a single execution path in a trace. +. + +PathMetadata.path doc +Path is the identifier for this execution path. +. + +PathMetadata.status doc +Status indicates the outcome of this path. +. + +PathMetadata.latency doc +Latency is the execution time for this path in milliseconds. +. + +PathMetadata.error doc +Error contains error information if the path failed. +. + +# ---------------------------------------------------------------------------- +# Multipart Tool Response +# ---------------------------------------------------------------------------- + +MultipartToolResponse doc +MultipartToolResponse represents a tool response with both structured output and content parts. +. + +MultipartToolResponse.output doc +Output contains the structured output data from the tool. +. 
+ +MultipartToolResponse.content doc +Content holds additional message parts providing context or details. +. + +# ============================================================================ +# CONFIGURATION SECTION +# Type mappings, omissions, and other non-documentation directives +# ============================================================================ + # DocumentData type was hand-written. DocumentData omit @@ -28,52 +888,30 @@ TimeEventAnnotation omit TraceData omit SpanStartEvent omit SpanEndEvent omit -SpanEventBase omit +# Typo in schema definition... +SpantEventBase omit TraceEvent omit GenerationCommonConfig.maxOutputTokens type int GenerationCommonConfig.topK type int -Role doc -Role indicates which entity is responsible for the content of a message. -. -RoleSystem doc -RoleSystem indicates this message is user-independent context. -. -RoleUser doc -RoleUser indicates this message was generated by the client. -. -RoleModel doc -RoleModel indicates this message was generated by the model during a previous interaction. -. -RoleTool doc -RoleTool indicates this message was generated by a local tool, likely triggered by a request -from the model in one of its previous responses. -. - -ToolRequest.input doc -Input is a JSON object describing the input values to the tool. -An example might be map[string]any{"country":"USA", "president":3}. -. -ToolResponse.output doc -Output is a JSON object describing the results of running the tool. -An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}. -. - -ToolRequest doc -A ToolRequest is a message from the model to the client that it should run a -specific tool and pass a [ToolResponse] to the model on the next chat request it makes. -Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client. -. 
-ToolResponse doc -A ToolResponse is a message from the client to the model containing -the results of running a specific tool on the arguments passed to the client -by the model in a [ToolRequest]. -. - +# Unused evaluation types +BaseDataPoint omit +BaseEvalDataPoint omit +EvalFnResponse omit +EvalRequest omit +EvalResponse omit +EvalStatusEnum omit +# Unused error types +CandidateError omit +CandidateErrorCode omit Candidate omit +# Unused retriever/reranker option types +CommonRerankerOptions omit +CommonRetrieverOptions omit + DocumentData pkg ai GenerateResponse omit @@ -96,9 +934,7 @@ GenerationUsage.outputTokens type int GenerationUsage.totalTokens type int GenerationUsage.thoughtsTokens type int GenerationUsage.cachedContentTokens type int -GenerationUsage doc -GenerationUsage provides information about the generation process. -. + GenerationCommonConfig pkg ai Message pkg ai @@ -213,8 +1049,6 @@ RoleUser pkg ai RoleModel pkg ai RoleTool pkg ai -EvalResponse type []any - # GenerateActionOptions GenerateActionOptions pkg ai GenerateActionOptions.model type string @@ -235,18 +1069,6 @@ GenerateActionOutputConfig.jsonSchema name Schema GenerateActionOutputConfig.jsonSchema type map[string]any GenerateActionOutputConfig.constrained type bool -BaseDataPoint.context type map[string]any -BaseDataPoint.input type map[string]any -BaseDataPoint.output type map[string]any -BaseDataPoint.reference type map[string]any -BaseDataPoint.traceIds type []string - -BaseEvalDataPoint.context type map[string]any -BaseEvalDataPoint.input type map[string]any -BaseEvalDataPoint.output type map[string]any -BaseEvalDataPoint.reference type map[string]any -BaseEvalDataPoint.traceIds type []string - # ModelRequest ModelRequest pkg ai ModelRequest.config type any @@ -279,52 +1101,6 @@ ModelResponseChunk.index type int ModelResponseChunk.role type Role ModelResponseChunk field formatHandler StreamingFormatHandler -GenerationCommonConfig doc -GenerationCommonConfig holds configuration 
for generation. -. - -Message doc -Message is the contents of a model response. -. - -ToolDefinition doc -A ToolDefinition describes a tool. -. - -ModelRequest doc -A ModelRequest is a request to generate completions from a model. -. -ModelRequest.output doc -Output describes the desired response format. -. -ModelRequest.tools doc -Tools lists the available tools that the model can ask the client to run. -. - -OutputConfig doc -OutputConfig describes the structure that the model's output -should conform to. If Format is [OutputFormatJSON], then Schema -can describe the desired form of the generated JSON. -. - -ModelResponse doc -A ModelResponse is a model's response to a [ModelRequest]. -. -ModelResponse.latencyMs doc -LatencyMs is the time the request took in milliseconds. -. -ModelResponse.request doc -Request is the [ModelRequest] struct used to trigger this response. -. -ModelResponse.usage doc -Usage describes how many resources were used by this generation request. -. - -ModelResponseChunk doc -A ModelResponseChunk is the portion of the [ModelResponse] -that is passed to a streaming callback. -. - Score omit Embedding.embedding type []float32 diff --git a/go/core/status_types_test.go b/go/core/status_types_test.go new file mode 100644 index 0000000000..eec8d7c124 --- /dev/null +++ b/go/core/status_types_test.go @@ -0,0 +1,123 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package core + +import ( + "net/http" + "testing" +) + +func TestHTTPStatusCode(t *testing.T) { + tests := []struct { + name string + status StatusName + wantCode int + }{ + {"OK", OK, http.StatusOK}, + {"CANCELLED", CANCELLED, 499}, + {"UNKNOWN", UNKNOWN, http.StatusInternalServerError}, + {"INVALID_ARGUMENT", INVALID_ARGUMENT, http.StatusBadRequest}, + {"DEADLINE_EXCEEDED", DEADLINE_EXCEEDED, http.StatusGatewayTimeout}, + {"NOT_FOUND", NOT_FOUND, http.StatusNotFound}, + {"ALREADY_EXISTS", ALREADY_EXISTS, http.StatusConflict}, + {"PERMISSION_DENIED", PERMISSION_DENIED, http.StatusForbidden}, + {"UNAUTHENTICATED", UNAUTHENTICATED, http.StatusUnauthorized}, + {"RESOURCE_EXHAUSTED", RESOURCE_EXHAUSTED, http.StatusTooManyRequests}, + {"FAILED_PRECONDITION", FAILED_PRECONDITION, http.StatusBadRequest}, + {"ABORTED", ABORTED, http.StatusConflict}, + {"OUT_OF_RANGE", OUT_OF_RANGE, http.StatusBadRequest}, + {"UNIMPLEMENTED", UNIMPLEMENTED, http.StatusNotImplemented}, + {"INTERNAL", INTERNAL, http.StatusInternalServerError}, + {"UNAVAILABLE", UNAVAILABLE, http.StatusServiceUnavailable}, + {"DATA_LOSS", DATA_LOSS, http.StatusInternalServerError}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := HTTPStatusCode(tt.status) + if got != tt.wantCode { + t.Errorf("HTTPStatusCode(%q) = %d, want %d", tt.status, got, tt.wantCode) + } + }) + } + + t.Run("unknown status returns 500", func(t *testing.T) { + got := HTTPStatusCode(StatusName("UNKNOWN_STATUS")) + if got != http.StatusInternalServerError { + t.Errorf("HTTPStatusCode(unknown) = %d, want %d", got, http.StatusInternalServerError) + } + }) +} + +func TestNewStatus(t *testing.T) { + t.Run("creates status with name and message", func(t *testing.T) { + s := NewStatus(NOT_FOUND, "resource not found") + + if s.Name != NOT_FOUND { + t.Errorf("Name = %q, want %q", s.Name, NOT_FOUND) + } + if s.Message != "resource not found" { + t.Errorf("Message = %q, 
want %q", s.Message, "resource not found") + } + }) + + t.Run("creates status with empty message", func(t *testing.T) { + s := NewStatus(OK, "") + + if s.Name != OK { + t.Errorf("Name = %q, want %q", s.Name, OK) + } + if s.Message != "" { + t.Errorf("Message = %q, want empty string", s.Message) + } + }) +} + +func TestStatusNameToCode(t *testing.T) { + t.Run("maps all status names to codes", func(t *testing.T) { + expectedMappings := map[StatusName]int{ + OK: CodeOK, + CANCELLED: CodeCancelled, + UNKNOWN: CodeUnknown, + INVALID_ARGUMENT: CodeInvalidArgument, + DEADLINE_EXCEEDED: CodeDeadlineExceeded, + NOT_FOUND: CodeNotFound, + ALREADY_EXISTS: CodeAlreadyExists, + PERMISSION_DENIED: CodePermissionDenied, + UNAUTHENTICATED: CodeUnauthenticated, + RESOURCE_EXHAUSTED: CodeResourceExhausted, + FAILED_PRECONDITION: CodeFailedPrecondition, + ABORTED: CodeAborted, + OUT_OF_RANGE: CodeOutOfRange, + UNIMPLEMENTED: CodeUnimplemented, + INTERNAL: CodeInternal, + UNAVAILABLE: CodeUnavailable, + DATA_LOSS: CodeDataLoss, + } + + for name, wantCode := range expectedMappings { + got, ok := StatusNameToCode[name] + if !ok { + t.Errorf("StatusNameToCode missing mapping for %q", name) + continue + } + if got != wantCode { + t.Errorf("StatusNameToCode[%q] = %d, want %d", name, got, wantCode) + } + } + }) +} diff --git a/go/core/tracing/doc.go b/go/core/tracing/doc.go new file mode 100644 index 0000000000..aae609c517 --- /dev/null +++ b/go/core/tracing/doc.go @@ -0,0 +1,109 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +/* +Package tracing provides execution trace support for Genkit operations. + +This package implements OpenTelemetry-based tracing for Genkit actions and flows. +Traces capture the execution path, inputs, outputs, and timing of operations, +enabling observability and debugging through the Genkit Developer UI and +external telemetry systems. + +# Automatic Tracing + +Actions and flows defined with Genkit are automatically traced. Each action +execution creates a span with input/output data, timing, and any errors. +Use [core.Run] within flows to create traced sub-steps: + + // In a real scenario, 'r' would be the registry from your Genkit instance. + var r api.Registry + flow := core.DefineFlow(r, "myFlow", + func(ctx context.Context, input string) (string, error) { + // This creates a traced step named "processData" + result, err := core.Run(ctx, "processData", func() (string, error) { + return process(input), nil + }) + return result, err + }, + ) + +# Tracer Access + +Access the OpenTelemetry tracer provider for custom instrumentation: + + provider := tracing.TracerProvider() + + // Get a tracer for custom spans + tracer := tracing.Tracer() + +# Telemetry Export + +Configure trace export to send telemetry to external systems. For immediate +export (suitable for local storage): + + tracing.WriteTelemetryImmediate(client) + +For batched export (more efficient for network calls): + + shutdown := tracing.WriteTelemetryBatch(client) + defer shutdown(ctx) + +# Dev UI Integration + +When the GENKIT_ENV environment variable is set to "dev", traces are +automatically sent to the Genkit Developer UI's telemetry server. 
The Dev UI +provides: + + - Visual trace exploration with timing breakdown + - Input/output inspection for each action + - Error highlighting and stack traces + - Performance analysis across flow executions + +Set GENKIT_TELEMETRY_SERVER to configure a custom telemetry endpoint. + +# Span Metadata + +Create spans with rich metadata for better observability: + + metadata := &tracing.SpanMetadata{ + Name: "processDocument", + Type: "action", + Subtype: "retriever", + } + + output, err := tracing.RunInNewSpan(ctx, metadata, input, + func(ctx context.Context, in Input) (Output, error) { + // Operation runs within the traced span + return process(in), nil + }, + ) + +# Trace Information + +Extract trace context for correlation with external systems: + + info := tracing.GetTraceInfo(ctx) + if info != nil { + log.Printf("TraceID: %s, SpanID: %s", info.TraceID, info.SpanID) + } + +This package is primarily intended for Genkit internals and advanced plugin +development. Most application developers will interact with tracing through +the automatic instrumentation provided by the genkit package. + +For more information on observability, see https://genkit.dev/docs/observability +*/ +package tracing diff --git a/go/core/x/streaming/streaming.go b/go/core/x/streaming/streaming.go new file mode 100644 index 0000000000..3fb51a3341 --- /dev/null +++ b/go/core/x/streaming/streaming.go @@ -0,0 +1,382 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package streaming provides experimental durable streaming APIs for Genkit. +// +// APIs in this package are under active development and may change in any +// minor version release. Use with caution in production environments. +// +// When these APIs stabilize, they will be moved to their parent packages +// (e.g., core and genkit) and these exports will be deprecated. +package streaming + +import ( + "context" + "encoding/json" + "sync" + "time" + + "github.com/firebase/genkit/go/core" +) + +// StreamEventType indicates the type of stream event. +type StreamEventType int + +const ( + StreamEventChunk StreamEventType = iota + StreamEventDone + StreamEventError +) + +// StreamEvent represents an event in a durable stream. +type StreamEvent struct { + Type StreamEventType + Chunk json.RawMessage // set when Type == StreamEventChunk + Output json.RawMessage // set when Type == StreamEventDone + Err error // set when Type == StreamEventError +} + +// StreamInput provides methods for writing to a durable stream. +type StreamInput interface { + // Write sends a chunk to the stream and notifies all subscribers. + Write(ctx context.Context, chunk json.RawMessage) error + // Done marks the stream as successfully completed with the given output. + Done(ctx context.Context, output json.RawMessage) error + // Error marks the stream as failed with the given error. + Error(ctx context.Context, err error) error + // Close releases resources without marking the stream as done or errored. + Close() error +} + +// StreamManager manages durable streams, allowing creation and subscription. +// Implementations can provide different storage backends (e.g., in-memory, database, cache). +type StreamManager interface { + // Open creates a new stream for writing. + // Returns an error if a stream with the given ID already exists. 
+ Open(ctx context.Context, streamID string) (StreamInput, error) + // Subscribe subscribes to an existing stream. + // Returns a channel that receives stream events, an unsubscribe function, and an error. + // If the stream has already completed, all buffered events are sent before the done/error event. + // Returns NOT_FOUND error if the stream doesn't exist. + Subscribe(ctx context.Context, streamID string) (<-chan StreamEvent, func(), error) +} + +// inMemoryStreamBufferSize is the buffer size for subscriber event channels. +const inMemoryStreamBufferSize = 100 + +// streamStatus represents the current state of a stream. +type streamStatus int + +const ( + streamStatusOpen streamStatus = iota + streamStatusDone + streamStatusError +) + +// streamState holds the internal state of a single stream. +type streamState struct { + status streamStatus + chunks []json.RawMessage + output json.RawMessage + err error + subscribers []chan StreamEvent + lastTouched time.Time + mu sync.RWMutex +} + +// InMemoryStreamManager is an in-memory implementation of StreamManager. +// Useful for testing or single-instance deployments where persistence is not required. +// Call Close to stop the background cleanup goroutine when the manager is no longer needed. +type InMemoryStreamManager struct { + streams map[string]*streamState + mu sync.RWMutex + ttl time.Duration + stopCh chan struct{} + doneCh chan struct{} +} + +// StreamManagerOption configures an InMemoryStreamManager. +type StreamManagerOption interface { + applyInMemoryStreamManager(*streamManagerOptions) +} + +// streamManagerOptions holds configuration for InMemoryStreamManager. +type streamManagerOptions struct { + TTL time.Duration // Time-to-live for completed streams. +} + +func (o *streamManagerOptions) applyInMemoryStreamManager(opts *streamManagerOptions) { + if o.TTL > 0 { + opts.TTL = o.TTL + } +} + +// WithTTL sets the time-to-live for completed streams. 
+// Streams that have completed (done or error) will be cleaned up after this duration. +// Default is 5 minutes. +func WithTTL(ttl time.Duration) StreamManagerOption { + return &streamManagerOptions{TTL: ttl} +} + +// NewInMemoryStreamManager creates a new InMemoryStreamManager. +// A background goroutine is started to periodically clean up expired streams. +// Call Close to stop the goroutine when the manager is no longer needed. +func NewInMemoryStreamManager(opts ...StreamManagerOption) *InMemoryStreamManager { + options := &streamManagerOptions{ + TTL: 5 * time.Minute, + } + for _, opt := range opts { + opt.applyInMemoryStreamManager(options) + } + m := &InMemoryStreamManager{ + streams: make(map[string]*streamState), + ttl: options.TTL, + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go m.cleanupLoop() + return m +} + +// cleanupLoop runs periodically to remove expired streams. +func (m *InMemoryStreamManager) cleanupLoop() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + defer close(m.doneCh) + + for { + select { + case <-m.stopCh: + return + case <-ticker.C: + m.cleanupExpiredStreams() + } + } +} + +// cleanupExpiredStreams removes streams that have completed and exceeded the TTL. +func (m *InMemoryStreamManager) cleanupExpiredStreams() { + now := time.Now() + m.mu.Lock() + defer m.mu.Unlock() + + for id, state := range m.streams { + state.mu.RLock() + shouldDelete := state.status != streamStatusOpen && now.Sub(state.lastTouched) > m.ttl + state.mu.RUnlock() + if shouldDelete { + delete(m.streams, id) + } + } +} + +// Close stops the background cleanup goroutine and releases resources. +// This method blocks until the cleanup goroutine has stopped. +func (m *InMemoryStreamManager) Close() { + close(m.stopCh) + <-m.doneCh +} + +// Open creates a new stream for writing. 
+func (m *InMemoryStreamManager) Open(ctx context.Context, streamID string) (StreamInput, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.streams[streamID]; exists { + return nil, core.NewPublicError(core.ALREADY_EXISTS, "stream already exists", nil) + } + + state := &streamState{ + status: streamStatusOpen, + chunks: make([]json.RawMessage, 0), + subscribers: make([]chan StreamEvent, 0), + lastTouched: time.Now(), + } + m.streams[streamID] = state + + return &inMemoryStreamInput{ + manager: m, + streamID: streamID, + state: state, + }, nil +} + +// Subscribe subscribes to an existing stream. +func (m *InMemoryStreamManager) Subscribe(ctx context.Context, streamID string) (<-chan StreamEvent, func(), error) { + m.mu.RLock() + state, exists := m.streams[streamID] + m.mu.RUnlock() + + if !exists { + return nil, nil, core.NewPublicError(core.NOT_FOUND, "stream not found", nil) + } + + ch := make(chan StreamEvent, inMemoryStreamBufferSize) + + state.mu.Lock() + defer state.mu.Unlock() + + // Send all buffered chunks + for _, chunk := range state.chunks { + select { + case ch <- StreamEvent{Type: StreamEventChunk, Chunk: chunk}: + case <-ctx.Done(): + close(ch) + return nil, nil, ctx.Err() + } + } + + // Handle completed streams + switch state.status { + case streamStatusDone: + ch <- StreamEvent{Type: StreamEventDone, Output: state.output} + close(ch) + return ch, func() {}, nil + case streamStatusError: + ch <- StreamEvent{Type: StreamEventError, Err: state.err} + close(ch) + return ch, func() {}, nil + } + + // Stream is still open, add subscriber + state.subscribers = append(state.subscribers, ch) + + unsubscribe := func() { + state.mu.Lock() + defer state.mu.Unlock() + for i, sub := range state.subscribers { + if sub == ch { + state.subscribers = append(state.subscribers[:i], state.subscribers[i+1:]...) 
+ close(ch) + break + } + } + } + + return ch, unsubscribe, nil +} + +// inMemoryStreamInput implements StreamInput for the in-memory manager. +type inMemoryStreamInput struct { + manager *InMemoryStreamManager + streamID string + state *streamState + closed bool + mu sync.Mutex +} + +func (s *inMemoryStreamInput) Write(_ context.Context, chunk json.RawMessage) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + + s.state.mu.Lock() + defer s.state.mu.Unlock() + + if s.state.status != streamStatusOpen { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) + } + + s.state.chunks = append(s.state.chunks, chunk) + s.state.lastTouched = time.Now() + + event := StreamEvent{Type: StreamEventChunk, Chunk: chunk} + for _, ch := range s.state.subscribers { + select { + case ch <- event: + default: + // Channel full, skip (subscriber is slow) + } + } + + return nil +} + +func (s *inMemoryStreamInput) Done(_ context.Context, output json.RawMessage) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + s.closed = true + + s.state.mu.Lock() + defer s.state.mu.Unlock() + + if s.state.status != streamStatusOpen { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) + } + + s.state.status = streamStatusDone + s.state.output = output + s.state.lastTouched = time.Now() + + event := StreamEvent{Type: StreamEventDone, Output: output} + for _, ch := range s.state.subscribers { + select { + case ch <- event: + default: + } + close(ch) + } + s.state.subscribers = nil + + return nil +} + +func (s *inMemoryStreamInput) Error(_ context.Context, err error) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + 
s.closed = true + + s.state.mu.Lock() + defer s.state.mu.Unlock() + + if s.state.status != streamStatusOpen { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream has already completed", nil) + } + + s.state.status = streamStatusError + s.state.err = err + s.state.lastTouched = time.Now() + + event := StreamEvent{Type: StreamEventError, Err: err} + for _, ch := range s.state.subscribers { + select { + case ch <- event: + default: + } + close(ch) + } + s.state.subscribers = nil + + return nil +} + +func (s *inMemoryStreamInput) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + s.closed = true + return nil +} diff --git a/go/core/x/streaming/streaming_test.go b/go/core/x/streaming/streaming_test.go new file mode 100644 index 0000000000..e86ce6f6e0 --- /dev/null +++ b/go/core/x/streaming/streaming_test.go @@ -0,0 +1,789 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package streaming + +import ( + "context" + "encoding/json" + "errors" + "sync" + "testing" + "time" + + "github.com/firebase/genkit/go/core" +) + +func TestInMemoryStreamManager_OpenAndSubscribe(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-1" + + // Open a new stream + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + if writer == nil { + t.Fatal("Open returned nil writer") + } + + // Subscribe to the stream + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + if events == nil { + t.Fatal("Subscribe returned nil channel") + } +} + +func TestInMemoryStreamManager_OpenDuplicateFails(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-dup" + + // Open first stream + _, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("First Open failed: %v", err) + } + + // Try to open duplicate + _, err = m.Open(ctx, streamID) + if err == nil { + t.Fatal("Expected error when opening duplicate stream") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.ALREADY_EXISTS { + t.Errorf("Expected ALREADY_EXISTS status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_SubscribeNonExistent(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + + _, _, err := m.Subscribe(ctx, "non-existent") + if err == nil { + t.Fatal("Expected error when subscribing to non-existent stream") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.NOT_FOUND { + t.Errorf("Expected NOT_FOUND status, got %v", 
ufErr.Status) + } +} + +func TestInMemoryStreamManager_WriteAndReceiveChunks(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-chunks" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Write chunks + chunks := []string{"chunk1", "chunk2", "chunk3"} + for _, chunk := range chunks { + if err := writer.Write(ctx, json.RawMessage(`"`+chunk+`"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + + // Read chunks + for i, expected := range chunks { + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event, got %v", event.Type) + } + var got string + if err := json.Unmarshal(event.Chunk, &got); err != nil { + t.Fatalf("Failed to unmarshal chunk: %v", err) + } + if got != expected { + t.Errorf("Chunk %d: expected %q, got %q", i, expected, got) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for chunk %d", i) + } + } +} + +func TestInMemoryStreamManager_Done(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-done" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Write a chunk + if err := writer.Write(ctx, json.RawMessage(`"test-chunk"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Mark as done + output := json.RawMessage(`{"result": "success"}`) + if err := writer.Done(ctx, output); err != nil { + t.Fatalf("Done failed: %v", err) + } + + // Should receive chunk then done + select { + case event := <-events: + if event.Type != 
StreamEventChunk { + t.Errorf("Expected chunk event first, got %v", event.Type) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for chunk") + } + + select { + case event := <-events: + if event.Type != StreamEventDone { + t.Errorf("Expected done event, got %v", event.Type) + } + if string(event.Output) != string(output) { + t.Errorf("Expected output %s, got %s", output, event.Output) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for done event") + } +} + +func TestInMemoryStreamManager_Error(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-error" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Mark as error + streamErr := core.NewPublicError(core.INTERNAL, "test error", nil) + if err := writer.Error(ctx, streamErr); err != nil { + t.Fatalf("Error failed: %v", err) + } + + select { + case event := <-events: + if event.Type != StreamEventError { + t.Errorf("Expected error event, got %v", event.Type) + } + if event.Err == nil { + t.Error("Expected error to be set") + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for error event") + } +} + +func TestInMemoryStreamManager_WriteAfterDone(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-write-after-done" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Done(ctx, json.RawMessage(`"done"`)); err != nil { + t.Fatalf("Done failed: %v", err) + } + + // Try to write after done + err = writer.Write(ctx, json.RawMessage(`"chunk"`)) + if err == nil { + t.Fatal("Expected error when writing after done") + } + + var ufErr *core.UserFacingError + if !errors.As(err, 
&ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_WriteAfterClose(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-write-after-close" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + // Try to write after close + err = writer.Write(ctx, json.RawMessage(`"chunk"`)) + if err == nil { + t.Fatal("Expected error when writing after close") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_DoneAfterError(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-done-after-error" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Error(ctx, core.NewPublicError(core.INTERNAL, "test", nil)); err != nil { + t.Fatalf("Error failed: %v", err) + } + + // Try to mark done after error + err = writer.Done(ctx, json.RawMessage(`"done"`)) + if err == nil { + t.Fatal("Expected error when calling Done after Error") + } +} + +func TestInMemoryStreamManager_MultipleSubscribers(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-multi-sub" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Create multiple subscribers + events1, unsub1, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe 1 
failed: %v", err) + } + defer unsub1() + + events2, unsub2, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe 2 failed: %v", err) + } + defer unsub2() + + // Write a chunk + chunk := json.RawMessage(`"shared-chunk"`) + if err := writer.Write(ctx, chunk); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Both subscribers should receive the chunk + for i, events := range []<-chan StreamEvent{events1, events2} { + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Subscriber %d: expected chunk event, got %v", i+1, event.Type) + } + if string(event.Chunk) != string(chunk) { + t.Errorf("Subscriber %d: expected chunk %s, got %s", i+1, chunk, event.Chunk) + } + case <-time.After(time.Second): + t.Fatalf("Subscriber %d: timeout waiting for chunk", i+1) + } + } +} + +func TestInMemoryStreamManager_LateSubscriberGetsBufferedChunks(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-late-sub" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Write chunks before any subscriber + chunks := []string{"early1", "early2"} + for _, chunk := range chunks { + if err := writer.Write(ctx, json.RawMessage(`"`+chunk+`"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + + // Late subscriber joins + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Should receive buffered chunks + for i, expected := range chunks { + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event, got %v", event.Type) + } + var got string + if err := json.Unmarshal(event.Chunk, &got); err != nil { + t.Fatalf("Failed to unmarshal chunk: %v", err) + } + if got != expected { + t.Errorf("Chunk %d: expected %q, got %q", i, expected, got) + } + case 
<-time.After(time.Second): + t.Fatalf("Timeout waiting for buffered chunk %d", i) + } + } +} + +func TestInMemoryStreamManager_SubscribeToCompletedStream(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-completed" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Write and complete before subscribing + if err := writer.Write(ctx, json.RawMessage(`"chunk1"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + if err := writer.Write(ctx, json.RawMessage(`"chunk2"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + output := json.RawMessage(`{"final": true}`) + if err := writer.Done(ctx, output); err != nil { + t.Fatalf("Done failed: %v", err) + } + + // Subscribe after completion + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Should receive all buffered chunks + for i := 0; i < 2; i++ { + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event %d, got %v", i, event.Type) + } + case <-time.After(time.Second): + t.Fatalf("Timeout waiting for chunk %d", i) + } + } + + // Should receive done event + select { + case event := <-events: + if event.Type != StreamEventDone { + t.Errorf("Expected done event, got %v", event.Type) + } + if string(event.Output) != string(output) { + t.Errorf("Expected output %s, got %s", output, event.Output) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for done event") + } + + // Channel should be closed + select { + case _, ok := <-events: + if ok { + t.Error("Expected channel to be closed") + } + case <-time.After(100 * time.Millisecond): + t.Error("Channel not closed after done") + } +} + +func TestInMemoryStreamManager_SubscribeToErroredStream(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := 
context.Background() + streamID := "test-stream-errored" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Write and error before subscribing + if err := writer.Write(ctx, json.RawMessage(`"chunk1"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + streamErr := core.NewPublicError(core.INTERNAL, "test error", nil) + if err := writer.Error(ctx, streamErr); err != nil { + t.Fatalf("Error failed: %v", err) + } + + // Subscribe after error + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + defer unsubscribe() + + // Should receive buffered chunk + select { + case event := <-events: + if event.Type != StreamEventChunk { + t.Errorf("Expected chunk event, got %v", event.Type) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for chunk") + } + + // Should receive error event + select { + case event := <-events: + if event.Type != StreamEventError { + t.Errorf("Expected error event, got %v", event.Type) + } + case <-time.After(time.Second): + t.Fatal("Timeout waiting for error event") + } +} + +func TestInMemoryStreamManager_Unsubscribe(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-unsub" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } + + // Unsubscribe + unsubscribe() + + // Write a chunk - should not panic + if err := writer.Write(ctx, json.RawMessage(`"chunk"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Events channel should be closed + select { + case _, ok := <-events: + if ok { + t.Error("Expected channel to be closed after unsubscribe") + } + case <-time.After(100 * time.Millisecond): + t.Error("Channel not closed after unsubscribe") + } +} + +func 
TestInMemoryStreamManager_WithTTL(t *testing.T) { + m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) + defer m.Close() + + if m.ttl != 10*time.Millisecond { + t.Errorf("Expected TTL 10ms, got %v", m.ttl) + } +} + +func TestInMemoryStreamManager_ConcurrentOperations(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-concurrent" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + const numSubscribers = 5 + const numChunks = 10 + + var wg sync.WaitGroup + errors := make(chan error, numSubscribers*numChunks) + + // Start subscribers + for i := 0; i < numSubscribers; i++ { + wg.Add(1) + go func(subID int) { + defer wg.Done() + + events, unsubscribe, err := m.Subscribe(ctx, streamID) + if err != nil { + errors <- err + return + } + defer unsubscribe() + + received := 0 + for event := range events { + if event.Type == StreamEventChunk { + received++ + } else if event.Type == StreamEventDone { + break + } + } + + if received != numChunks { + t.Errorf("subscriber %d received %d chunks, expected %d", subID, received, numChunks) + } + }(i) + } + + // Give subscribers time to set up + time.Sleep(50 * time.Millisecond) + + // Write chunks concurrently + for i := 0; i < numChunks; i++ { + if err := writer.Write(ctx, json.RawMessage(`"chunk"`)); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + + // Complete the stream + if err := writer.Done(ctx, json.RawMessage(`"done"`)); err != nil { + t.Fatalf("Done failed: %v", err) + } + + wg.Wait() + close(errors) + + for err := range errors { + t.Errorf("Subscriber error: %v", err) + } +} + +func TestInMemoryStreamManager_Close(t *testing.T) { + m := NewInMemoryStreamManager() + + // Close should not block + done := make(chan struct{}) + go func() { + m.Close() + close(done) + }() + + select { + case <-done: + // Success + case <-time.After(time.Second): + t.Fatal("Close 
blocked") + } +} + +func TestInMemoryStreamManager_CleanupExpiredStreams(t *testing.T) { + m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) + defer m.Close() + + ctx := context.Background() + + // Create and complete a stream + writer, err := m.Open(ctx, "expired-stream") + if err != nil { + t.Fatalf("Open failed: %v", err) + } + if err := writer.Done(ctx, json.RawMessage(`"done"`)); err != nil { + t.Fatalf("Done failed: %v", err) + } + + // Wait for TTL to expire + time.Sleep(20 * time.Millisecond) + + // Trigger cleanup + m.cleanupExpiredStreams() + + // Stream should be gone + _, _, err = m.Subscribe(ctx, "expired-stream") + if err == nil { + t.Fatal("Expected error subscribing to expired stream") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.NOT_FOUND { + t.Errorf("Expected NOT_FOUND status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_OpenStreamsNotCleanedUp(t *testing.T) { + m := NewInMemoryStreamManager(WithTTL(10 * time.Millisecond)) + defer m.Close() + + ctx := context.Background() + + // Create an open stream (not completed) + _, err := m.Open(ctx, "open-stream") + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + // Wait longer than TTL + time.Sleep(20 * time.Millisecond) + + // Trigger cleanup + m.cleanupExpiredStreams() + + // Stream should still exist + _, _, err = m.Subscribe(ctx, "open-stream") + if err != nil { + t.Fatalf("Subscribe failed: %v", err) + } +} + +func TestInMemoryStreamManager_ErrorAfterClose(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-error-after-close" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + // Try to error after close + err = writer.Error(ctx, 
core.NewPublicError(core.INTERNAL, "test", nil)) + if err == nil { + t.Fatal("Expected error when calling Error after Close") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) + } +} + +func TestInMemoryStreamManager_DoneAfterClose(t *testing.T) { + m := NewInMemoryStreamManager() + defer m.Close() + + ctx := context.Background() + streamID := "test-stream-done-after-close" + + writer, err := m.Open(ctx, streamID) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if err := writer.Close(); err != nil { + t.Fatalf("Close failed: %v", err) + } + + // Try to done after close + err = writer.Done(ctx, json.RawMessage(`"done"`)) + if err == nil { + t.Fatal("Expected error when calling Done after Close") + } + + var ufErr *core.UserFacingError + if !errors.As(err, &ufErr) { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if ufErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION status, got %v", ufErr.Status) + } +} diff --git a/go/genkit/doc.go b/go/genkit/doc.go new file mode 100644 index 0000000000..f72b7c1e7b --- /dev/null +++ b/go/genkit/doc.go @@ -0,0 +1,408 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +/* +Package genkit provides a framework for building AI-powered applications in Go. + +Genkit is an open-source framework that helps you build, deploy, and monitor +production-ready AI features. It provides a unified interface for working with +large language models (LLMs), managing prompts, defining workflows, and integrating +with various AI service providers. + +For comprehensive documentation, tutorials, and examples, visit https://genkit.dev + +# Getting Started + +Initialize Genkit with a plugin to connect to an AI provider: + + ctx := context.Background() + g := genkit.Init(ctx, + genkit.WithPlugins(&googlegenai.GoogleAI{}), + ) + +Generate text with a simple prompt: + + text, err := genkit.GenerateText(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("Tell me a joke"), + ) + if err != nil { + log.Fatal(err) + } + fmt.Println(text) + +# Models + +Models represent AI language models that generate content. Use plugins to access +models from providers like Google AI, Vertex AI, Anthropic, or Ollama. Models are +referenced by name and can include provider-specific configuration: + + resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithPrompt("Explain quantum computing in simple terms"), + ) + +You can set a default model during initialization: + + g := genkit.Init(ctx, + genkit.WithPlugins(&googlegenai.GoogleAI{}), + genkit.WithDefaultModel("googleai/gemini-2.5-flash"), + ) + +# Flows + +Flows are reusable, observable functions that orchestrate AI operations. They +provide automatic tracing, can be exposed as HTTP endpoints, and support both +streaming and non-streaming execution. 
+ +Define a simple flow: + + jokesFlow := genkit.DefineFlow(g, "jokesFlow", + func(ctx context.Context, topic string) (string, error) { + return genkit.GenerateText(ctx, g, + ai.WithPrompt("Share a joke about %s.", topic), + ) + }, + ) + + joke, err := jokesFlow.Run(ctx, "programming") + +Define a streaming flow that sends chunks as they're generated: + + streamingFlow := genkit.DefineStreamingFlow(g, "streamingJokes", + func(ctx context.Context, topic string, sendChunk ai.ModelStreamCallback) (string, error) { + resp, err := genkit.Generate(ctx, g, + ai.WithPrompt("Share a joke about %s.", topic), + ai.WithStreaming(sendChunk), + ) + if err != nil { + return "", err + } + return resp.Text(), nil + }, + ) + +Use [Run] within flows to create traced sub-steps for observability: + + genkit.DefineFlow(g, "pipeline", + func(ctx context.Context, input string) (string, error) { + result, err := genkit.Run(ctx, "processStep", func() (string, error) { + return process(input), nil + }) + return result, err + }, + ) + +# Prompts + +Prompts can be defined programmatically or loaded from .prompt files (Dotprompt format). +They encapsulate model configuration, input schemas, and template logic for reuse. 
+ +Define a prompt in code: + + jokePrompt := genkit.DefinePrompt(g, "joke", + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithInputType(JokeRequest{Topic: "default topic"}), + ai.WithPrompt("Share a joke about {{topic}}."), + ) + + stream := jokePrompt.ExecuteStream(ctx, ai.WithInput(map[string]any{"topic": "cats"})) + for result, err := range stream { + if err != nil { + return err + } + if result.Done { + fmt.Println(result.Response.Text()) + } + } + +For type-safe prompts with structured input and output, use [DefineDataPrompt]: + + type RecipeRequest struct { + Cuisine string `json:"cuisine"` + Dish string `json:"dish"` + ServingSize int `json:"servingSize"` + } + + type Recipe struct { + Title string `json:"title"` + Ingredients []string `json:"ingredients"` + Instructions []string `json:"instructions"` + } + + recipePrompt := genkit.DefineDataPrompt[RecipeRequest, *Recipe](g, "recipe", + ai.WithSystem("You are an experienced chef."), + ai.WithPrompt("Create a {{cuisine}} {{dish}} recipe for {{servingSize}} people."), + ) + + for result, err := range recipePrompt.ExecuteStream(ctx, RecipeRequest{ + Cuisine: "Italian", Dish: "pasta", ServingSize: 4, + }) { + // result.Chunk is *Recipe, result.Output is final *Recipe + } + +Load prompts from .prompt files by specifying a prompt directory: + + g := genkit.Init(ctx, + genkit.WithPlugins(&googlegenai.GoogleAI{}), + genkit.WithPromptDir("./prompts"), + ) + + // Look up a loaded prompt + jokePrompt := genkit.LookupPrompt(g, "joke") + + // Or with type parameters for structured I/O + recipePrompt := genkit.LookupDataPrompt[RecipeRequest, *Recipe](g, "recipe") + +When using .prompt files with custom output schemas, register the schema first: + + genkit.DefineSchemaFor[Recipe](g) + +# Tools + +Tools extend model capabilities by allowing them to call functions during generation. 
+Define tools that the model can invoke to perform actions or retrieve information: + + weatherTool := genkit.DefineTool(g, "getWeather", + "Gets the current weather for a city", + func(ctx *ai.ToolContext, city string) (string, error) { + // Fetch weather data... + return "Sunny, 72°F", nil + }, + ) + + resp, err := genkit.Generate(ctx, g, + ai.WithPrompt("What's the weather in Paris?"), + ai.WithTools(weatherTool), + ) + +# Structured Output + +Generate structured data that conforms to Go types using [GenerateData] or +[GenerateDataStream]. Use jsonschema struct tags to provide descriptions and +constraints that help the model understand the expected output: + + type Joke struct { + Joke string `json:"joke" jsonschema:"description=The joke text"` + Category string `json:"category" jsonschema:"description=The joke category"` + } + + joke, resp, err := genkit.GenerateData[*Joke](ctx, g, + ai.WithPrompt("Tell me a programming joke"), + ) + +For streaming structured output: + + stream := genkit.GenerateDataStream[*Recipe](ctx, g, + ai.WithPrompt("Create a pasta recipe"), + ) + for result, err := range stream { + if err != nil { + return nil, err + } + if result.Done { + return result.Output, nil + } + // result.Chunk contains partial Recipe as it streams + fmt.Printf("Got %d ingredients so far\n", len(result.Chunk.Ingredients)) + } + +# Streaming + +Genkit supports streaming at multiple levels. 
Use [GenerateStream] for streaming +model responses: + + stream := genkit.GenerateStream(ctx, g, + ai.WithPrompt("Write a short story"), + ) + for result, err := range stream { + if err != nil { + log.Fatal(err) + } + if result.Done { + fmt.Println("\n--- Complete ---") + } else { + fmt.Print(result.Chunk.Text()) + } + } + +Use [DefineStreamingFlow] for flows that stream custom data types: + + genkit.DefineStreamingFlow(g, "countdown", + func(ctx context.Context, count int, sendChunk func(context.Context, int) error) (string, error) { + for i := count; i > 0; i-- { + if err := sendChunk(ctx, i); err != nil { + return "", err + } + time.Sleep(time.Second) + } + return "Liftoff!", nil + }, + ) + +# Development Mode and Dev UI + +Set GENKIT_ENV=dev to enable development features including the Reflection API +server that powers the Genkit Developer UI: + + $ export GENKIT_ENV=dev + $ go run main.go + +Then run the Dev UI to inspect flows, test prompts, and view traces: + + $ npx genkit start -- go run main.go + +The Dev UI provides: + - Interactive flow testing with input/output inspection + - Prompt playground for iterating on prompts + - Trace viewer for debugging and performance analysis + - Action browser for exploring registered actions + +# HTTP Server Integration + +Expose flows as HTTP endpoints for production deployment using [Handler]: + + mux := http.NewServeMux() + for _, flow := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+flow.Name(), genkit.Handler(flow)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) + +Handlers support streaming responses via Server-Sent Events when the client +sends Accept: text/event-stream. 
For durable streaming that survives reconnects, +use [WithStreamManager]: + + mux.HandleFunc("POST /countdown", genkit.Handler(countdown, + genkit.WithStreamManager(streaming.NewInMemoryStreamManager( + streaming.WithTTL(10*time.Minute), + )), + )) + +# Plugins + +Genkit's functionality is extended through plugins that provide models, tools, +retrievers, and other capabilities. Common plugins include: + + - googlegenai: Google AI (Gemini models) + - vertexai: Google Cloud Vertex AI + - ollama: Local Ollama models + +Initialize plugins during [Init]: + + g := genkit.Init(ctx, + genkit.WithPlugins( + &googlegenai.GoogleAI{}, + &vertexai.VertexAI{ProjectID: "my-project"}, + ), + ) + +# Messages and Parts + +Build conversation messages using helper functions from the [ai] package. These +are used with [ai.WithMessages] or when building custom conversation flows: + + // Create messages for a conversation + messages := []*ai.Message{ + ai.NewSystemTextMessage("You are a helpful assistant."), + ai.NewUserTextMessage("Hello!"), + ai.NewModelTextMessage("Hi there! 
How can I help?"), + } + + resp, err := genkit.Generate(ctx, g, + ai.WithMessages(messages...), + ai.WithPrompt("What can you do?"), + ) + +For multi-modal content, combine text and media parts: + + userMsg := ai.NewUserMessage( + ai.NewTextPart("What's in this image?"), + ai.NewMediaPart("image/png", base64ImageData), + ) + +Available message constructors in the [ai] package: + + - [ai.NewUserTextMessage], [ai.NewUserMessage]: User messages + - [ai.NewModelTextMessage], [ai.NewModelMessage]: Model responses + - [ai.NewSystemTextMessage], [ai.NewSystemMessage]: System instructions + +Available part constructors in the [ai] package: + + - [ai.NewTextPart]: Text content + - [ai.NewMediaPart]: Images, audio, video (base64-encoded) + - [ai.NewDataPart]: Raw data strings + - [ai.NewToolRequestPart], [ai.NewToolResponsePart]: Tool interactions + +# Generation Options + +Generation functions ([Generate], [GenerateText], [GenerateData], [GenerateStream]) +accept options from the [ai] package to control behavior. The most common options: + +Model and Configuration: + + - [ai.WithModel]: Specify the model (accepts [ai.ModelRef] or plugin model refs) + - [ai.WithModelName]: Specify model by name string (e.g., "googleai/gemini-2.5-flash") + - [ai.WithConfig]: Set generation parameters (temperature, max tokens, etc.) + +Prompting: + + - [ai.WithPrompt]: Set the user prompt (supports format strings) + - [ai.WithSystem]: Set system instructions + - [ai.WithMessages]: Provide conversation history + +Tools and Output: + + - [ai.WithTools]: Enable tools the model can call + - [ai.WithOutputType]: Request structured output matching a Go type + - [ai.WithOutputFormat]: Specify output format (json, text, etc.) 
+ +Streaming: + + - [ai.WithStreaming]: Enable streaming with a callback function + +Example combining multiple options: + + resp, err := genkit.Generate(ctx, g, + ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithSystem("You are a helpful coding assistant."), + ai.WithMessages(conversationHistory...), + ai.WithPrompt("Explain this code: %s", code), + ai.WithTools(searchTool, calculatorTool), + // Config is provider-specific (e.g., genai.GenerateContentConfig for Google AI) + ) + +# Unregistered Components + +For advanced use cases, the [ai] package provides New* functions to create +components without registering them in Genkit. This is useful for plugins +or when you need to pass components directly: + + - [ai.NewTool]: Create an unregistered tool + - [ai.NewModel]: Create an unregistered model + - [ai.NewRetriever]: Create an unregistered retriever + - [ai.NewEmbedder]: Create an unregistered embedder + +Use the corresponding Define* functions in this package to create and register +components for use with Genkit's action system, tracing, and Dev UI. + +# Additional Resources + + - Documentation: https://genkit.dev + - Go Getting Started: https://genkit.dev/go/docs/get-started-go + - Samples: https://github.com/firebase/genkit/tree/main/go/samples + - GitHub: https://github.com/firebase/genkit +*/ +package genkit diff --git a/go/genkit/example_test.go b/go/genkit/example_test.go new file mode 100644 index 0000000000..917e8dc49c --- /dev/null +++ b/go/genkit/example_test.go @@ -0,0 +1,322 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package genkit_test + +import ( + "context" + "fmt" + "log" + "net/http" + "strings" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" +) + +// This example shows basic initialization and flow definition. +func Example() { + ctx := context.Background() + + // Initialize Genkit (without plugins for this example) + g := genkit.Init(ctx) + + // Define a simple flow + greetFlow := genkit.DefineFlow(g, "greet", + func(ctx context.Context, name string) (string, error) { + return fmt.Sprintf("Hello, %s!", name), nil + }, + ) + + // Run the flow + greeting, err := greetFlow.Run(ctx, "World") + if err != nil { + log.Fatal(err) + } + fmt.Println(greeting) + // Output: Hello, World! +} + +// This example demonstrates defining a simple non-streaming flow. +func ExampleDefineFlow() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define a flow that processes input + uppercaseFlow := genkit.DefineFlow(g, "uppercase", + func(ctx context.Context, input string) (string, error) { + return strings.ToUpper(input), nil + }, + ) + + // Run the flow + result, err := uppercaseFlow.Run(ctx, "hello") + if err != nil { + log.Fatal(err) + } + fmt.Println(result) + // Output: HELLO +} + +// This example demonstrates defining a streaming flow that sends +// chunks to the caller as they are produced. 
+func ExampleDefineStreamingFlow() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define a streaming flow that counts down + countdownFlow := genkit.DefineStreamingFlow(g, "countdown", + func(ctx context.Context, start int, sendChunk func(context.Context, int) error) (string, error) { + for i := start; i > 0; i-- { + if err := sendChunk(ctx, i); err != nil { + return "", err + } + } + return "Liftoff!", nil + }, + ) + + // Stream results using the iterator + iter := countdownFlow.Stream(ctx, 3) + iter(func(val *core.StreamingFlowValue[string, int], err error) bool { + if err != nil { + log.Fatal(err) + } + if val.Done { + fmt.Println("Final:", val.Output) + } else { + fmt.Println("Count:", val.Stream) + } + return true + }) + // Output: + // Count: 3 + // Count: 2 + // Count: 1 + // Final: Liftoff! +} + +// This example demonstrates using Run to create traced sub-steps +// within a flow for better observability. +func ExampleRun() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define a flow with traced sub-steps + pipelineFlow := genkit.DefineFlow(g, "pipeline", + func(ctx context.Context, input string) (string, error) { + // Each Run call creates a traced step visible in the Dev UI + upper, err := genkit.Run(ctx, "uppercase", func() (string, error) { + return strings.ToUpper(input), nil + }) + if err != nil { + return "", err + } + + result, err := genkit.Run(ctx, "addPrefix", func() (string, error) { + return "Processed: " + upper, nil + }) + return result, err + }, + ) + + result, err := pipelineFlow.Run(ctx, "hello") + if err != nil { + log.Fatal(err) + } + fmt.Println(result) + // Output: Processed: HELLO +} + +// This example demonstrates defining a tool that models can call +// during generation. 
+func ExampleDefineTool() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define a tool that adds two numbers + _ = genkit.DefineTool(g, "add", + "Adds two numbers together", + func(ctx *ai.ToolContext, input struct { + A float64 `json:"a" jsonschema:"description=First number"` + B float64 `json:"b" jsonschema:"description=Second number"` + }) (float64, error) { + return input.A + input.B, nil + }, + ) + + // The tool is now registered and can be used with ai.WithTools() + // when calling genkit.Generate() + fmt.Println("Tool registered: add") + // Output: Tool registered: add +} + +// This example demonstrates defining a reusable prompt with a template. +func ExampleDefinePrompt() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define a prompt with Handlebars template syntax + prompt := genkit.DefinePrompt(g, "greeting", + ai.WithPrompt("Say hello to {{name}} in a {{style}} way."), + ) + + // Render the prompt (without executing - useful for inspection) + rendered, err := prompt.Render(ctx, map[string]any{ + "name": "Alice", + "style": "friendly", + }) + if err != nil { + log.Fatal(err) + } + // The rendered prompt contains the messages that would be sent + fmt.Println(rendered.Messages[0].Content[0].Text) + // Output: Say hello to Alice in a friendly way. +} + +// This example demonstrates registering a Go type as a named schema. 
+func ExampleDefineSchemaFor() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define a struct type + type Person struct { + Name string `json:"name" jsonschema:"description=The person's name"` + Age int `json:"age" jsonschema:"description=The person's age"` + } + + // Register the schema - this makes it available for .prompt files + // that reference it by name (e.g., "output: { schema: Person }") + genkit.DefineSchemaFor[Person](g) + + fmt.Println("Schema registered: Person") + // Output: Schema registered: Person +} + +// This example demonstrates creating an HTTP server that exposes +// all registered flows as endpoints. +func ExampleListFlows_httpServer() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define some flows + genkit.DefineFlow(g, "echo", func(ctx context.Context, s string) (string, error) { + return s, nil + }) + + genkit.DefineFlow(g, "reverse", func(ctx context.Context, s string) (string, error) { + runes := []rune(s) + for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 { + runes[i], runes[j] = runes[j], runes[i] + } + return string(runes), nil + }) + + // Create HTTP handlers for all flows + mux := http.NewServeMux() + for _, flow := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+flow.Name(), genkit.Handler(flow)) + } + + // The mux now has: + // - POST /echo + // - POST /reverse + fmt.Printf("Registered %d flow handlers\n", len(genkit.ListFlows(g))) + // Output: Registered 2 flow handlers +} + +// This example demonstrates using Handler to expose a single flow +// as an HTTP endpoint. 
+func ExampleHandler() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define a flow + greetFlow := genkit.DefineFlow(g, "greet", + func(ctx context.Context, name string) (string, error) { + return fmt.Sprintf("Hello, %s!", name), nil + }, + ) + + // Create an HTTP handler for the flow + mux := http.NewServeMux() + mux.HandleFunc("POST /greet", genkit.Handler(greetFlow)) + + // The handler accepts JSON: {"data": "World"} + // and returns JSON: {"result": "Hello, World!"} + fmt.Println("Handler registered at POST /greet") + // Output: Handler registered at POST /greet +} + +// This example demonstrates using type-safe data prompts with +// strongly-typed input and output. +func ExampleDefineDataPrompt() { + ctx := context.Background() + g := genkit.Init(ctx) + + // Define input and output types + type JokeRequest struct { + Topic string `json:"topic"` + } + + type Joke struct { + Setup string `json:"setup"` + Punchline string `json:"punchline"` + } + + // Define a type-safe prompt + // Note: In production, you'd also set ai.WithModel(...) + _ = genkit.DefineDataPrompt[JokeRequest, *Joke](g, "joke", + ai.WithPrompt("Tell a joke about {{topic}}. Return JSON with setup and punchline."), + ) + + // The prompt can now be executed with: + // for result, err := range jokePrompt.ExecuteStream(ctx, JokeRequest{Topic: "cats"}) { + // if result.Done { + // fmt.Println(result.Output.Setup) + // fmt.Println(result.Output.Punchline) + // } + // } + + fmt.Println("DataPrompt registered: joke") + // Output: DataPrompt registered: joke +} + +// This example demonstrates looking up a prompt that was loaded +// from a .prompt file. 
+func ExampleLookupPrompt() { + ctx := context.Background() + + // In production, you would initialize with a prompt directory: + // g := genkit.Init(ctx, genkit.WithPromptDir("./prompts")) + + g := genkit.Init(ctx) + + // Define a prompt programmatically (simulating a loaded prompt) + genkit.DefinePrompt(g, "greeting", + ai.WithPrompt("Hello {{name}}!"), + ) + + // Look up the prompt by name + prompt := genkit.LookupPrompt(g, "greeting") + if prompt == nil { + log.Fatal("Prompt not found") + } + + fmt.Println("Found prompt:", prompt.Name()) + // Output: Found prompt: greeting +} diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 75ef8c9a8a..83429ca0d1 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -21,6 +21,8 @@ import ( "context" "errors" "fmt" + "io/fs" + "iter" "log/slog" "os" "os/signal" @@ -46,6 +48,7 @@ type Genkit struct { type genkitOptions struct { DefaultModel string // Default model to use if no other model is specified. PromptDir string // Directory where dotprompts are stored. Will be loaded automatically on initialization. + PromptFS fs.FS // Embedded filesystem containing prompts (alternative to PromptDir). Plugins []api.Plugin // Plugin to initialize automatically. 
} @@ -66,6 +69,20 @@ func (o *genkitOptions) apply(gOpts *genkitOptions) error { if gOpts.PromptDir != "" { return errors.New("cannot set prompt directory more than once (WithPromptDir)") } + gOpts.PromptDir = o.PromptDir + } + + // NOTE: WithPromptDir may be combined with WithPromptFS; the directory + // then names the root path inside the embedded filesystem (Init falls + // back to "prompts" when it is unset), so the options are not exclusive. + if o.PromptFS != nil { + if gOpts.PromptFS != nil { + return errors.New("cannot set prompt filesystem more than once (WithPromptFS)") + } + gOpts.PromptFS = o.PromptFS + } + + if o.PromptDir != "" { gOpts.PromptDir = o.PromptDir } @@ -99,13 +116,44 @@ func WithDefaultModel(model string) GenkitOption { // The default directory is "prompts" relative to the project root where // [Init] is called. // +// When used with [WithPromptFS], this directory serves as the root path within +// the embedded filesystem instead of a local disk path. For example, if using +// `//go:embed prompts/*`, set the directory to "prompts" to match. +// // Invalid prompt files will result in logged errors during initialization, // while valid files that define invalid prompts will cause [Init] to panic. -// This option can only be applied once. func WithPromptDir(dir string) GenkitOption { return &genkitOptions{PromptDir: dir} } +// WithPromptFS specifies an embedded filesystem ([fs.FS]) containing `.prompt` files. +// This is useful for embedding prompts directly into the binary using Go's [embed] package, +// eliminating the need to distribute prompt files separately. +// +// The `fsys` parameter should be an [fs.FS] implementation (e.g., [embed.FS]). +// Use [WithPromptDir] to specify the root directory within the filesystem where +// prompts are located (defaults to "prompts"). 
+// +// Example: +// +// import "embed" +// +// //go:embed prompts/* +// var promptsFS embed.FS +// +// func main() { +// g := genkit.Init(ctx, +// genkit.WithPromptFS(promptsFS), +// genkit.WithPromptDir("prompts"), +// ) +// } +// +// Invalid prompt files will result in logged errors during initialization, +// while valid files that define invalid prompts will cause [Init] to panic. +func WithPromptFS(fsys fs.FS) GenkitOption { + return &genkitOptions{PromptFS: fsys} +} + // Init creates and initializes a new [Genkit] instance with the provided options. // It sets up the registry, initializes plugins ([WithPlugins]), loads prompts // ([WithPromptDir]), and configures other settings like the default model @@ -184,7 +232,15 @@ func Init(ctx context.Context, opts ...GenkitOption) *Genkit { ai.ConfigureFormats(r) ai.DefineGenerateAction(ctx, r) - ai.LoadPromptDir(r, gOpts.PromptDir, "") + if gOpts.PromptFS != nil { + dir := gOpts.PromptDir + if dir == "" { + dir = "prompts" + } + ai.LoadPromptDirFromFS(r, gOpts.PromptFS, dir, "") + } else { + loadPromptDirOS(r, gOpts.PromptDir, "") + } r.RegisterValue(api.DefaultModelKey, gOpts.DefaultModel) r.RegisterValue(api.PromptDirKey, gOpts.PromptDir) @@ -268,7 +324,7 @@ func DefineFlow[In, Out any](g *Genkit, name string, fn core.Func[In, Out]) *cor // Example: // // counterFlow := genkit.DefineStreamingFlow(g, "counter", -// func(ctx context.Context, limit int, stream func(context.Context, int) error) (string, error) { +// func(ctx context.Context, limit int, stream core.StreamCallback[int]) (string, error) { // if stream == nil { // Non-streaming case // return fmt.Sprintf("Counted up to %d", limit), nil // } @@ -371,12 +427,14 @@ func ListTools(g *Genkit) []ai.Tool { // DefineModel defines a custom model implementation, registers it as a [core.Action] // of type Model, and returns an [ai.Model] interface. // -// The `provider` and `name` arguments form the unique identifier for the model -// (e.g., "myProvider/myModel"). 
The `info` argument provides metadata about the -// model's capabilities ([ai.ModelInfo]). The `fn` argument ([ai.ModelFunc]) -// implements the actual generation logic, handling input requests ([ai.ModelRequest]) -// and producing responses ([ai.ModelResponse]), potentially streaming chunks -// ([ai.ModelResponseChunk]) via the callback. +// The `name` argument is the unique identifier for the model (e.g., "myProvider/myModel"). +// The `opts` argument provides metadata about the model's capabilities ([ai.ModelOptions]). +// The `fn` argument ([ai.ModelFunc]) implements the actual generation logic, handling +// input requests ([ai.ModelRequest]) and producing responses ([ai.ModelResponse]), +// potentially streaming chunks ([ai.ModelResponseChunk]) via the callback. +// +// For models that don't need to be registered (e.g., for plugin development or testing), +// use [ai.NewModel] instead. // // Example: // @@ -454,7 +512,7 @@ func LookupBackgroundModel(g *Genkit, name string) ai.BackgroundModel { } // DefineTool defines a tool that can be used by models during generation, -// registers it as a [core.Action] of type Tool, and returns an [ai.ToolDef]. +// registers it as a [core.Action] of type Tool, and returns an [ai.Tool]. // Tools allow models to interact with external systems or perform specific computations. // // The `name` is the identifier the model uses to request the tool. The `description` @@ -464,7 +522,13 @@ func LookupBackgroundModel(g *Genkit, name string) ai.BackgroundModel { // `inputSchema` and `outputSchema` in the tool's definition, which guide the model // on how to provide input and interpret output. // -// Use [ai.WithInputSchema] to provide a custom JSON schema instead of inferring from the type parameter. +// For tools that don't need to be registered (e.g., dynamically created tools), +// use [ai.NewTool] instead. 
+// +// # Options +// +// - [ai.WithInputSchema]: Provide a custom JSON schema instead of inferring from the type parameter +// - [ai.WithInputSchemaName]: Reference a pre-registered schema by name // // Example: // @@ -507,38 +571,6 @@ func DefineTool[In, Out any](g *Genkit, name, description string, fn ai.ToolFunc // input of type `any`, and returning an output of type `Out`. // // Deprecated: Use [DefineTool] with [ai.WithInputSchema] instead. -// -// Example: -// -// // Define a custom input schema -// inputSchema := map[string]any{ -// "type": "object", -// "properties": map[string]any{ -// "city": map[string]any{"type": "string"}, -// "unit": map[string]any{ -// "type": "string", -// "enum": []any{"C", "F"}, -// }, -// }, -// "required": []string{"city"}, -// } -// -// // Define the tool with the schema -// weatherTool := genkit.DefineTool(g, "getWeather", -// "Fetches the weather for a given city with unit preference", -// func(ctx *ai.ToolContext, input any) (string, error) { -// // Parse and validate input -// data := input.(map[string]any) -// city := data["city"].(string) -// unit := "C" // default -// if u, ok := data["unit"].(string); ok { -// unit = u -// } -// // Implementation... -// return fmt.Sprintf("Weather in %s: 25°%s", city, unit), nil -// }, -// ai.WithToolInputSchema(inputSchema), -// ) func DefineToolWithInputSchema[Out any](g *Genkit, name, description string, inputSchema map[string]any, fn ai.ToolFunc[any, Out]) ai.Tool { return ai.DefineTool(g.reg, name, description, fn, ai.WithInputSchema(inputSchema)) } @@ -554,7 +586,13 @@ func DefineToolWithInputSchema[Out any](g *Genkit, name, description string, inp // returning an [ai.MultipartToolResponse] which contains both the output and optional // content parts. // -// Use [ai.WithInputSchema] to provide a custom JSON schema instead of inferring from the type parameter. 
+// For multipart tools that don't need to be registered (e.g., dynamically created tools), +// use [ai.NewMultipartTool] instead. +// +// # Options +// +// - [ai.WithInputSchema]: Provide a custom JSON schema instead of inferring from the type parameter +// - [ai.WithInputSchemaName]: Reference a pre-registered schema by name // // Example: // @@ -605,18 +643,55 @@ func LookupTool(g *Genkit, name string) ai.Tool { } // DefinePrompt defines a prompt programmatically, registers it as a [core.Action] -// of type Prompt, and returns an executable [ai.prompt]. +// of type Prompt, and returns an executable [ai.Prompt]. // // This provides an alternative to defining prompts in `.prompt` files, offering // more flexibility through Go code. Prompts encapsulate configuration (model, parameters), // message templates (system, user, history), input/output schemas, and associated tools. // // Prompts can be executed in two main ways: -// 1. Render + Generate: Call [Prompt.Render] to get [ai.GenerateActionOptions], +// 1. Render + Generate: Call [ai.Prompt.Render] to get [ai.GenerateActionOptions], // modify them if needed, and pass them to [GenerateWithRequest]. -// 2. Execute: Call [Prompt.Execute] directly, passing input and execution options. -// -// Options ([ai.PromptOption]) are used to configure the prompt during definition. +// 2. Execute: Call [ai.Prompt.Execute] directly, passing input and execution options. +// +// For prompts that don't need to be registered (e.g., for single-use or testing), +// use [ai.NewPrompt] instead. +// +// # Options +// +// Model and Configuration: +// - [ai.WithModel]: Specify the model (accepts [ai.Model] or [ai.ModelRef]) +// - [ai.WithModelName]: Specify model by name string +// - [ai.WithConfig]: Set generation parameters (temperature, max tokens, etc.) 
+// +// Prompt Content: +// - [ai.WithPrompt]: Set the user prompt template (supports {{variable}} syntax) +// - [ai.WithPromptFn]: Set a function that generates the user prompt dynamically +// - [ai.WithSystem]: Set system instructions template +// - [ai.WithSystemFn]: Set a function that generates system instructions dynamically +// - [ai.WithMessages]: Provide static conversation history +// - [ai.WithMessagesFn]: Provide a function that generates conversation history +// +// Input Schema: +// - [ai.WithInputType]: Set input schema from a Go type (provides default values) +// - [ai.WithInputSchema]: Provide a custom JSON schema for input +// - [ai.WithInputSchemaName]: Reference a pre-registered schema by name +// +// Output Schema: +// - [ai.WithOutputType]: Set output schema from a Go type +// - [ai.WithOutputSchema]: Provide a custom JSON schema for output +// - [ai.WithOutputSchemaName]: Reference a pre-registered schema by name +// - [ai.WithOutputFormat]: Specify output format (json, text, etc.) 
+// +// Tools and Resources: +// - [ai.WithTools]: Enable tools the model can call +// - [ai.WithToolChoice]: Control whether tool calls are required, optional, or disabled +// - [ai.WithMaxTurns]: Set maximum tool call iterations +// - [ai.WithResources]: Attach resources available during generation +// +// Metadata: +// - [ai.WithDescription]: Set a description for the prompt +// - [ai.WithMetadata]: Set arbitrary metadata // // Example: // @@ -631,12 +706,12 @@ func LookupTool(g *Genkit, name string) ai.Tool { // // Define the prompt // capitalPrompt := genkit.DefinePrompt(g, "findCapital", // ai.WithDescription("Finds the capital of a country."), -// ai.WithModelName("googleai/gemini-2.5-flash"), // Specify the model +// ai.WithModelName("googleai/gemini-2.5-flash"), // ai.WithSystem("You are a helpful geography assistant."), // ai.WithPrompt("What is the capital of {{country}}?"), // ai.WithInputType(GeoInput{Country: "USA"}), // ai.WithOutputType(GeoOutput{}), -// ai.WithConfig(&ai.GenerationCommonConfig{Temperature: 0.5}), +// // Config is provider-specific, e.g., genai.GenerateContentConfig for Google AI // ) // // // Option 1: Render + Generate (using default input "USA") @@ -717,6 +792,50 @@ func DefineSchemaFor[T any](g *Genkit) { core.DefineSchemaFor[T](g.reg) } +// DefineDataPrompt creates a new [ai.DataPrompt] with strongly-typed input and output. +// It automatically infers input schema from the In type parameter and configures +// output schema and JSON format from the Out type parameter (unless Out is string). +// +// This is a convenience wrapper around [DefinePrompt] that provides compile-time +// type safety for both input and output. For prompts that don't need to be registered, +// use [ai.NewDataPrompt] instead. +// +// DefineDataPrompt accepts the same options as [DefinePrompt]. See [DefinePrompt] for +// the full list of available options. Note that input and output schemas are automatically +// inferred from the type parameters. 
+// +// Example: +// +// type GeoInput struct { +// Country string `json:"country"` +// } +// +// type GeoOutput struct { +// Capital string `json:"capital"` +// } +// +// capitalPrompt := genkit.DefineDataPrompt[GeoInput, GeoOutput](g, "findCapital", +// ai.WithModelName("googleai/gemini-2.5-flash"), +// ai.WithSystem("You are a helpful geography assistant."), +// ai.WithPrompt("What is the capital of {{country}}?"), +// ) +// +// output, resp, err := capitalPrompt.Execute(ctx, GeoInput{Country: "France"}) +// if err != nil { +// log.Fatalf("Execute failed: %v", err) +// } +// fmt.Printf("Capital: %s\n", output.Capital) +func DefineDataPrompt[In, Out any](g *Genkit, name string, opts ...ai.PromptOption) *ai.DataPrompt[In, Out] { + return ai.DefineDataPrompt[In, Out](g.reg, name, opts...) +} + +// LookupDataPrompt looks up a prompt by name and wraps it with type information. +// This is useful for wrapping prompts loaded from .prompt files with strong types. +// It returns nil if the prompt was not found. +func LookupDataPrompt[In, Out any](g *Genkit, name string) *ai.DataPrompt[In, Out] { + return ai.LookupDataPrompt[In, Out](g.reg, name) +} + // GenerateWithRequest performs a model generation request using explicitly provided // [ai.GenerateActionOptions]. This function is typically used in conjunction with // prompts defined via [DefinePrompt], where [ai.prompt.Render] produces the @@ -734,8 +853,7 @@ func DefineSchemaFor[T any](g *Genkit) { // // handle error // } // -// // Optional: Modify actionOpts here if needed -// // actionOpts.Config = &ai.GenerationCommonConfig{ Temperature: 0.8 } +// // Optional: Modify actionOpts here if needed (config is provider-specific) // // resp, err := genkit.GenerateWithRequest(ctx, g, actionOpts, nil, nil) // No middleware or streaming // if err != nil { @@ -750,12 +868,50 @@ func GenerateWithRequest(ctx context.Context, g *Genkit, actionOpts *ai.Generate // provided via [ai.GenerateOption] arguments. 
It's a convenient way to make // generation calls without pre-defining a prompt object. // +// # Options +// +// Model and Configuration: +// - [ai.WithModel]: Specify the model (accepts [ai.Model] or [ai.ModelRef]) +// - [ai.WithModelName]: Specify model by name string (e.g., "googleai/gemini-2.5-flash") +// - [ai.WithConfig]: Set generation parameters (temperature, max tokens, etc.) +// +// Prompting: +// - [ai.WithPrompt]: Set the user prompt (supports format strings) +// - [ai.WithPromptFn]: Set a function that generates the user prompt dynamically +// - [ai.WithSystem]: Set system instructions +// - [ai.WithSystemFn]: Set a function that generates system instructions dynamically +// - [ai.WithMessages]: Provide conversation history +// - [ai.WithMessagesFn]: Provide a function that generates conversation history +// +// Tools and Resources: +// - [ai.WithTools]: Enable tools the model can call +// - [ai.WithToolChoice]: Control whether tool calls are required, optional, or disabled +// - [ai.WithMaxTurns]: Set maximum tool call iterations +// - [ai.WithReturnToolRequests]: Return tool requests instead of executing them +// - [ai.WithResources]: Attach resources available during generation +// +// Output: +// - [ai.WithOutputType]: Request structured output matching a Go type +// - [ai.WithOutputSchema]: Provide a custom JSON schema for output +// - [ai.WithOutputSchemaName]: Reference a pre-registered schema by name +// - [ai.WithOutputFormat]: Specify output format (json, text, etc.) 
+// - [ai.WithOutputEnums]: Constrain output to specific enum values +// +// Context and Streaming: +// - [ai.WithDocs]: Provide context documents +// - [ai.WithTextDocs]: Provide context as text strings +// - [ai.WithStreaming]: Enable streaming with a callback function +// - [ai.WithMiddleware]: Apply middleware to the model request/response +// +// Tool Continuation: +// - [ai.WithToolResponses]: Resume generation with tool response parts +// - [ai.WithToolRestarts]: Resume generation by restarting tool requests +// // Example: // // resp, err := genkit.Generate(ctx, g, // ai.WithModelName("googleai/gemini-2.5-flash"), // ai.WithPrompt("Write a short poem about clouds."), -// ai.WithConfig(&genai.GenerateContentConfig{MaxOutputTokens: 50}), // ) // if err != nil { // log.Fatalf("Generate failed: %v", err) @@ -766,12 +922,48 @@ func Generate(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (*ai.Mo return ai.Generate(ctx, g.reg, opts...) } +// GenerateStream generates a model response and streams the output. +// It returns an iterator that yields streaming results. +// +// If the yield function is passed a non-nil error, generation has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [ai.ModelStreamValue] argument has Done == true, the value's +// Response field contains the final response; the yield function will not be called again. +// +// Otherwise the Chunk field of the passed [ai.ModelStreamValue] holds a streamed chunk. +// +// GenerateStream accepts the same options as [Generate]. See [Generate] for the full +// list of available options. 
+// +// Example: +// +// for result, err := range genkit.GenerateStream(ctx, g, +// ai.WithPrompt("Tell me a story about a brave knight."), +// ) { +// if err != nil { +// log.Fatalf("Stream error: %v", err) +// } +// if result.Done { +// fmt.Println("\nFinal response:", result.Response.Text()) +// } else { +// fmt.Print(result.Chunk.Text()) +// } +// } +func GenerateStream(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) iter.Seq2[*ai.ModelStreamValue, error] { + return ai.GenerateStream(ctx, g.reg, opts...) +} + // GenerateOperation performs a model generation request using a flexible set of options -// provided via [ai.GenerateOption] arguments. It's a convenient way to make -// generation calls without pre-defining a prompt object. +// provided via [ai.GenerateOption] arguments. It's designed for long-running generation +// tasks that may not complete immediately. // // Unlike [Generate], this function returns a [ai.ModelOperation] which can be used to -// check the status of the operation and get the result. +// check the status of the operation and get the result. Use [CheckModelOperation] to +// poll for completion. +// +// GenerateOperation accepts the same options as [Generate]. See [Generate] for the full +// list of available options. // // Example: // @@ -807,7 +999,9 @@ func CheckModelOperation(ctx context.Context, g *Genkit, op *ai.ModelOperation) // GenerateText performs a model generation request similar to [Generate], but // directly returns the generated text content as a string. It's a convenience // wrapper for cases where only the textual output is needed. -// It accepts the same [ai.GenerateOption] arguments as [Generate]. +// +// GenerateText accepts the same options as [Generate]. See [Generate] for the full +// list of available options. 
// // Example: // @@ -823,16 +1017,13 @@ func GenerateText(ctx context.Context, g *Genkit, opts ...ai.GenerateOption) (st } // GenerateData performs a model generation request, expecting structured output -// (typically JSON) that conforms to the schema of the provided `value` argument. -// It attempts to unmarshal the model's response directly into the `value`. -// The `value` argument must be a pointer to a struct or map. -// -// Use [ai.WithOutputType] or [ai.WithOutputFormat](ai.OutputFormatJSON) in the -// options to instruct the model to generate JSON. [ai.WithOutputType] is preferred -// as it infers the JSON schema from the `value` type and passes it to the model. +// (typically JSON) that conforms to the schema inferred from the Out type parameter. +// It automatically sets output type and JSON format, unmarshals the response, and +// returns the typed result. // -// It returns the full [ai.ModelResponse] along with any error. The generated data -// populates the `value` pointed to. +// GenerateData accepts the same options as [Generate]. See [Generate] for the full +// list of available options. Note that output options like [ai.WithOutputType] are +// automatically applied based on the Out type parameter. // // Example: // @@ -854,15 +1045,62 @@ func GenerateData[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOp return ai.GenerateData[Out](ctx, g.reg, opts...) } +// GenerateDataStream generates a model response with streaming and returns strongly-typed output. +// It returns an iterator that yields streaming results. +// +// If the yield function is passed a non-nil error, generation has failed with that +// error; the yield function will not be called again. +// +// If the yield function's [ai.StreamValue] argument has Done == true, the value's +// Output and Response fields contain the final typed output and response; the yield function +// will not be called again. 
+// +// Otherwise the Chunk field of the passed [ai.StreamValue] holds a streamed chunk. +// +// GenerateDataStream accepts the same options as [Generate]. See [Generate] for the full +// list of available options. Note that output options are automatically applied based on +// the Out type parameter. +// +// Example: +// +// type Story struct { +// Title string `json:"title"` +// Content string `json:"content"` +// } +// +// for result, err := range genkit.GenerateDataStream[Story](ctx, g, +// ai.WithPrompt("Write a short story about a brave knight."), +// ) { +// if err != nil { +// log.Fatalf("Stream error: %v", err) +// } +// if result.Done { +// fmt.Printf("Story: %+v\n", result.Output) +// } else { +// fmt.Print(result.Chunk.Text()) +// } +// } +func GenerateDataStream[Out any](ctx context.Context, g *Genkit, opts ...ai.GenerateOption) iter.Seq2[*ai.StreamValue[Out, Out], error] { + return ai.GenerateDataStream[Out](ctx, g.reg, opts...) +} + // Retrieve performs a document retrieval request using a flexible set of options // provided via [ai.RetrieverOption] arguments. It's a convenient way to retrieve // relevant documents from registered retrievers without directly calling the // retriever instance. // +// # Options +// +// - [ai.WithRetriever]: Specify the retriever (accepts [ai.Retriever] or [ai.RetrieverRef]) +// - [ai.WithRetrieverName]: Specify retriever by name string +// - [ai.WithConfig]: Set retriever-specific configuration +// - [ai.WithTextDocs]: Provide query text as documents +// - [ai.WithDocs]: Provide query as [ai.Document] instances +// // Example: // // resp, err := genkit.Retrieve(ctx, g, -// ai.WithRetriever(ai.NewRetrieverRef("myRetriever", nil)), +// ai.WithRetrieverName("myRetriever"), // ai.WithTextDocs("What is the capital of France?"), // ) // if err != nil { @@ -880,10 +1118,18 @@ func Retrieve(ctx context.Context, g *Genkit, opts ...ai.RetrieverOption) (*ai.R // provided via [ai.EmbedderOption] arguments. 
It's a convenient way to generate // embeddings from registered embedders without directly calling the embedder instance. // +// # Options +// +// - [ai.WithEmbedder]: Specify the embedder (accepts [ai.Embedder] or [ai.EmbedderRef]) +// - [ai.WithEmbedderName]: Specify embedder by name string +// - [ai.WithConfig]: Set embedder-specific configuration +// - [ai.WithTextDocs]: Provide text to embed +// - [ai.WithDocs]: Provide [ai.Document] instances to embed +// // Example: // // resp, err := genkit.Embed(ctx, g, -// ai.WithEmbedder(ai.NewEmbedderRef("myEmbedder", nil)), +// ai.WithEmbedderName("myEmbedder"), // ai.WithTextDocs("Hello, world!"), // ) // if err != nil { @@ -902,9 +1148,12 @@ func Embed(ctx context.Context, g *Genkit, opts ...ai.EmbedderOption) (*ai.Embed // Retrievers are used to find documents relevant to a given query, often by // performing similarity searches in a vector database. // -// The `provider` and `name` form the unique identifier. The `ret` function +// The `name` is the unique identifier for the retriever. The `fn` function // contains the logic to process an [ai.RetrieverRequest] (containing the query) // and return an [ai.RetrieverResponse] (containing the relevant documents). +// +// For retrievers that don't need to be registered (e.g., for plugin development), +// use [ai.NewRetriever] instead. func DefineRetriever(g *Genkit, name string, opts *ai.RetrieverOptions, fn ai.RetrieverFunc) ai.Retriever { return ai.DefineRetriever(g.reg, name, opts, fn) } @@ -920,9 +1169,12 @@ func LookupRetriever(g *Genkit, name string) ai.Retriever { // [core.Action] of type Embedder, and returns an [ai.Embedder]. // Embedders convert text documents or queries into numerical vector representations (embeddings). // -// The `provider` and `name` are specified in the `opts` parameter which forms the unique identifier. 
-// The `embed` function contains the logic to process an [ai.EmbedRequest] (containing documents or a query) +// The `name` is the unique identifier for the embedder. +// The `fn` function contains the logic to process an [ai.EmbedRequest] (containing documents or a query) // and return an [ai.EmbedResponse] (containing the corresponding embeddings). +// +// For embedders that don't need to be registered (e.g., for plugin development), +// use [ai.NewEmbedder] instead. func DefineEmbedder(g *Genkit, name string, opts *ai.EmbedderOptions, fn ai.EmbedderFunc) ai.Embedder { return ai.DefineEmbedder(g.reg, name, opts, fn) } @@ -988,6 +1240,14 @@ func LookupEvaluator(g *Genkit, name string) ai.Evaluator { // evaluations using registered evaluators without directly calling the // evaluator instance. // +// # Options +// +// - [ai.WithEvaluator]: Specify the evaluator (accepts [ai.Evaluator] or [ai.EvaluatorRef]) +// - [ai.WithEvaluatorName]: Specify evaluator by name string +// - [ai.WithDataset]: Provide the dataset of examples to evaluate +// - [ai.WithID]: Set a unique identifier for this evaluation run +// - [ai.WithConfig]: Set evaluator-specific configuration +// // Example: // // dataset := []*ai.Example{ @@ -998,8 +1258,8 @@ func LookupEvaluator(g *Genkit, name string) ai.Evaluator { // } // // resp, err := genkit.Evaluate(ctx, g, -// ai.WithEvaluator(ai.NewEvaluatorRef("myEvaluator", nil)), -// ai.WithDataset(dataset), +// ai.WithEvaluatorName("myEvaluator"), +// ai.WithDataset(dataset...), // ) // if err != nil { // log.Fatalf("Evaluate failed: %v", err) @@ -1026,8 +1286,67 @@ func Evaluate(ctx context.Context, g *Genkit, opts ...ai.EvaluatorOption) (*ai.E // This function is often called implicitly by [Init] using the directory specified // by [WithPromptDir], but can be called explicitly to load prompts from other // locations or with different namespaces. 
-func LoadPromptDir(g *Genkit, dir string, namespace string) { - ai.LoadPromptDir(g.reg, dir, namespace) +func LoadPromptDir(g *Genkit, dir, namespace string) { + loadPromptDirOS(g.reg, dir, namespace) +} + +// loadPromptDirOS loads prompts from an OS directory by converting to os.DirFS. +func loadPromptDirOS(r api.Registry, dir, namespace string) { + useDefaultDir := false + if dir == "" { + dir = "./prompts" + useDefaultDir = true + } + + absPath, err := filepath.Abs(dir) + if err != nil { + if !useDefaultDir { + panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err)) + } + slog.Debug("default prompt directory not found, skipping loading .prompt files", "dir", dir) + return + } + + if _, err := os.Stat(absPath); os.IsNotExist(err) { + if !useDefaultDir { + panic(fmt.Errorf("failed to resolve prompt directory %q: %w", dir, err)) + } + slog.Debug("Default prompt directory not found, skipping loading .prompt files", "dir", dir) + return + } + + ai.LoadPromptDirFromFS(r, os.DirFS(absPath), ".", namespace) +} + +// LoadPromptDirFromFS loads all `.prompt` files from the specified embedded filesystem `fsys` +// into the registry, associating them with the given `namespace`. +// Files starting with `_` are treated as partials and are not registered as +// executable prompts but can be included in other prompts. +// +// The `fsys` parameter should be an [fs.FS] implementation (e.g., [embed.FS]). +// The `dir` parameter specifies the directory within the filesystem where +// prompts are located (e.g., "prompts" if using `//go:embed prompts/*`). +// The `namespace` acts as a prefix to the prompt name (e.g., namespace "myApp" and +// file "greeting.prompt" results in prompt name "myApp/greeting"). Use an empty +// string for no namespace. +// +// This function provides an alternative to [LoadPromptDir] for loading prompts +// from embedded filesystems, enabling self-contained binaries without external +// prompt files. 
+// +// Example: +// +// import "embed" +// +// //go:embed prompts/* +// var promptsFS embed.FS +// +// func main() { +// g := genkit.Init(ctx) +// genkit.LoadPromptDirFromFS(g, promptsFS, "prompts", "myNamespace") +// } +func LoadPromptDirFromFS(g *Genkit, fsys fs.FS, dir, namespace string) { + ai.LoadPromptDirFromFS(g.reg, fsys, dir, namespace) } // LoadPrompt loads a single `.prompt` file specified by `path` into the registry, @@ -1052,13 +1371,49 @@ func LoadPromptDir(g *Genkit, dir string, namespace string) { // // Execute the loaded prompt // resp, err := customPrompt.Execute(ctx, ai.WithInput(map[string]any{"text": "some data"})) // // ... handle response and error ... -func LoadPrompt(g *Genkit, path string, namespace string) ai.Prompt { +func LoadPrompt(g *Genkit, path, namespace string) ai.Prompt { dir, filename := filepath.Split(path) - if dir != "" { + if dir == "" { + dir = "." + } else { dir = filepath.Clean(dir) } - return ai.LoadPrompt(g.reg, dir, filename, namespace) + return ai.LoadPromptFromFS(g.reg, os.DirFS(dir), ".", filename, namespace) +} + +// LoadPromptFromSource loads a prompt from raw `.prompt` file content (frontmatter + template) +// into the registry and returns the resulting [ai.Prompt]. +// +// The `source` parameter should contain the complete `.prompt` file text, including +// the YAML frontmatter (delimited by `---`) and the template body. +// The `name` parameter is the prompt name, which may include a variant suffix +// (e.g., "greeting" or "greeting.formal"). +// The `namespace` acts as a prefix to the prompt name. Use an empty string for no namespace. +// +// This is useful for loading prompts from sources other than the filesystem, +// such as databases, environment variables, or embedded strings. +// +// Example: +// +// promptSource := `--- +// model: googleai/gemini-2.5-flash +// input: +// schema: +// name: string +// --- +// Hello, {{name}}! 
+// ` +// +// prompt, err := genkit.LoadPromptFromSource(g, promptSource, "greeting", "myApp") +// if err != nil { +// log.Fatalf("Failed to load prompt: %v", err) +// } +// +// resp, err := prompt.Execute(ctx, ai.WithInput(map[string]any{"name": "World"})) +// // ... +func LoadPromptFromSource(g *Genkit, source, name, namespace string) (ai.Prompt, error) { + return ai.LoadPromptFromSource(g.reg, source, name, namespace) } // DefinePartial wraps DefinePartial to register a partial template with the given name and source. diff --git a/go/genkit/servers.go b/go/genkit/servers.go index d48c11ffd6..0b3bbe1ed6 100644 --- a/go/genkit/servers.go +++ b/go/genkit/servers.go @@ -31,23 +31,37 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/logger" + "github.com/firebase/genkit/go/core/x/streaming" + "github.com/google/uuid" ) +// HandlerOption configures a Handler. type HandlerOption interface { - apply(params *handlerParams) + applyHandler(*handlerOptions) error } -// handlerParams are the parameters for an action HTTP handler. -type handlerParams struct { - ContextProviders []core.ContextProvider // Providers for action context that may be used during runtime. +// handlerOptions are options for an action HTTP handler. +type handlerOptions struct { + ContextProviders []core.ContextProvider // Providers for action context that may be used during runtime. + StreamManager streaming.StreamManager // Optional manager for durable stream storage. } -// apply applies the options to the handler params. 
-func (p *handlerParams) apply(params *handlerParams) { - if params.ContextProviders != nil { - panic("genkit.WithContextProviders: cannot set ContextProviders more than once") +func (o *handlerOptions) applyHandler(opts *handlerOptions) error { + if o.ContextProviders != nil { + if opts.ContextProviders != nil { + return errors.New("cannot set ContextProviders more than once (WithContextProviders)") + } + opts.ContextProviders = o.ContextProviders + } + + if o.StreamManager != nil { + if opts.StreamManager != nil { + return errors.New("cannot set StreamManager more than once (WithStreamManager)") + } + opts.StreamManager = o.StreamManager } - params.ContextProviders = p.ContextProviders + + return nil } // requestID is a unique ID for each request. @@ -56,7 +70,16 @@ var requestID atomic.Int64 // WithContextProviders adds providers for action context that may be used during runtime. // They are called in the order added and may overwrite previous context. func WithContextProviders(ctxProviders ...core.ContextProvider) HandlerOption { - return &handlerParams{ContextProviders: ctxProviders} + return &handlerOptions{ContextProviders: ctxProviders} +} + +// WithStreamManager enables durable streaming with the provided StreamManager. +// When enabled, streaming responses include an x-genkit-stream-id header that clients +// can use to reconnect to in-progress or completed streams. +// +// EXPERIMENTAL: This API is subject to change. +func WithStreamManager(manager streaming.StreamManager) HandlerOption { + return &handlerOptions{StreamManager: manager} } // Handler returns an HTTP handler function that serves the action with the provided options. 
@@ -67,12 +90,14 @@ func WithContextProviders(ctxProviders ...core.ContextProvider) HandlerOption { // return api.ActionContext{"myKey": "myValue"}, nil // })) func Handler(a api.Action, opts ...HandlerOption) http.HandlerFunc { - params := &handlerParams{} + options := &handlerOptions{} for _, opt := range opts { - opt.apply(params) + if err := opt.applyHandler(options); err != nil { + panic(fmt.Errorf("genkit.Handler: error applying options: %w", err)) + } } - return wrapHandler(handler(a, params)) + return wrapHandler(handler(a, options)) } // wrapHandler wraps an HTTP handler function with common logging and error handling. @@ -101,8 +126,9 @@ func wrapHandler(h func(http.ResponseWriter, *http.Request) error) http.HandlerF } } -// handler returns an HTTP handler function that serves the action with the provided params. Responses are written in server-sent events (SSE) format. -func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *http.Request) error { +// handler returns an HTTP handler function that serves the action with the provided options. +// Streaming responses are written in server-sent events (SSE) format. 
+func handler(a api.Action, opts *handlerOptions) func(http.ResponseWriter, *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error { if a == nil { return errors.New("action is nil; cannot serve") @@ -124,29 +150,9 @@ func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *htt } stream = stream || r.Header.Get("Accept") == "text/event-stream" - var callback streamingCallback[json.RawMessage] - if stream { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Transfer-Encoding", "chunked") - callback = func(ctx context.Context, msg json.RawMessage) error { - _, err := fmt.Fprintf(w, "data: {\"message\": %s}\n\n", msg) - if err != nil { - return err - } - if f, ok := w.(http.Flusher); ok { - f.Flush() - } - return nil - } - } else { - w.Header().Set("Content-Type", "application/json") - } - ctx := r.Context() - if params.ContextProviders != nil { - for _, ctxProvider := range params.ContextProviders { + if opts.ContextProviders != nil { + for _, ctxProvider := range opts.ContextProviders { headers := make(map[string]string, len(r.Header)) for k, v := range r.Header { headers[strings.ToLower(k)] = strings.Join(v, " ") @@ -170,22 +176,252 @@ func handler(a api.Action, params *handlerParams) func(http.ResponseWriter, *htt } } - out, err := a.RunJSON(ctx, body.Data, callback) - if err != nil { - if stream { - _, err = fmt.Fprintf(w, "data: {\"error\": {\"status\": \"INTERNAL\", \"message\": \"stream flow error\", \"details\": \"%v\"}}\n\n", err) - return err + if stream { + streamID := r.Header.Get("X-Genkit-Stream-Id") + + if streamID != "" && opts.StreamManager != nil { + return subscribeToStream(ctx, w, opts.StreamManager, streamID) + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + 
w.Header().Set("Transfer-Encoding", "chunked") + + if opts.StreamManager != nil { + return runWithDurableStreaming(ctx, w, a, opts.StreamManager, body.Data) } + + return runWithStreaming(ctx, w, a, body.Data) + } + + w.Header().Set("Content-Type", "application/json") + out, err := a.RunJSON(ctx, body.Data, nil) + if err != nil { return err } - if stream { - _, err = fmt.Fprintf(w, "data: {\"result\": %s}\n\n", out) + return writeResultResponse(w, out) + } +} + +// runWithStreaming executes the action with standard HTTP streaming (no durability). +func runWithStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, input json.RawMessage) error { + callback := func(ctx context.Context, msg json.RawMessage) error { + if err := writeSSEMessage(w, msg); err != nil { return err } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + return nil + } + + out, err := a.RunJSON(ctx, input, callback) + if err != nil { + if werr := writeSSEError(w, err); werr != nil { + return werr + } + return nil + } + return writeSSEResult(w, out) +} + +// runWithDurableStreaming executes the action with durable streaming support. +// Chunks are written to both the HTTP response and the stream manager for later replay. +// +// The flow execution is detached from the HTTP request context so that if the +// original client disconnects, the flow continues running and writing to durable +// storage. This allows other clients to subscribe to the stream and receive the +// remaining chunks and final result. +func runWithDurableStreaming(ctx context.Context, w http.ResponseWriter, a api.Action, sm streaming.StreamManager, input json.RawMessage) error { + streamID := uuid.New().String() + + durableStream, err := sm.Open(ctx, streamID) + if err != nil { + return err + } + defer durableStream.Close() + + w.Header().Set("X-Genkit-Stream-Id", streamID) + + // Create a detached context for flow execution. 
This preserves context values + // (action context, tracing, logger) but won't be canceled when the HTTP client + // disconnects, allowing the flow to continue streaming to durable storage. + durableCtx := context.WithoutCancel(ctx) + + // Track whether the HTTP client is still connected. + clientGone := ctx.Done() + + callback := func(_ context.Context, msg json.RawMessage) error { + // Always write to durable storage regardless of client connection state. + durableStream.Write(durableCtx, msg) - _, err = fmt.Fprintf(w, "{\"result\": %s}\n", out) + // Only attempt HTTP writes if the client is still connected. + select { + case <-clientGone: + return nil + default: + if err := writeSSEMessage(w, msg); err != nil { + return nil + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + return nil + } + + out, err := a.RunJSON(durableCtx, input, callback) + if err != nil { + durableStream.Error(durableCtx, err) + select { + case <-clientGone: + return nil + default: + writeSSEError(w, err) + } + return nil + } + + durableStream.Done(durableCtx, out) + select { + case <-clientGone: + return nil + default: + return writeSSEResult(w, out) + } +} + +// subscribeToStream subscribes to an existing durable stream and writes events to the HTTP response. 
+func subscribeToStream(ctx context.Context, w http.ResponseWriter, sm streaming.StreamManager, streamID string) error { + events, unsubscribe, err := sm.Subscribe(ctx, streamID) + if err != nil { + var ufErr *core.UserFacingError + if errors.As(err, &ufErr) && ufErr.Status == core.NOT_FOUND { + w.WriteHeader(http.StatusNoContent) + return nil + } + return err + } + defer unsubscribe() + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Transfer-Encoding", "chunked") + + for event := range events { + switch event.Type { + case streaming.StreamEventChunk: + if err := writeSSEMessage(w, event.Chunk); err != nil { + return err + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + case streaming.StreamEventDone: + if err := writeSSEResult(w, event.Output); err != nil { + return err + } + return nil + case streaming.StreamEventError: + streamErr := event.Err + if streamErr == nil { + streamErr = errors.New("unknown error") + } + if err := writeSSEError(w, streamErr); err != nil { + return err + } + return nil + } + } + + return nil +} + +// flowResultResponse wraps a final action result for JSON serialization. +type flowResultResponse struct { + Result json.RawMessage `json:"result"` +} + +// flowMessageResponse wraps a streaming chunk for JSON serialization. +type flowMessageResponse struct { + Message json.RawMessage `json:"message"` +} + +// flowErrorResponse wraps an error for JSON serialization in streaming responses. +type flowErrorResponse struct { + Error *flowError `json:"error"` +} + +// flowError represents the error payload in a streaming error response. +type flowError struct { + Status core.StatusName `json:"status"` + Message string `json:"message"` + Details string `json:"details,omitempty"` +} + +// writeResultResponse writes a JSON result response for non-streaming requests. 
+func writeResultResponse(w http.ResponseWriter, result json.RawMessage) error { + resp := flowResultResponse{Result: result} + data, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = w.Write(data) + if err != nil { + return err + } + _, err = w.Write([]byte("\n")) + return err +} + +// writeSSEResult writes a JSON result as a server-sent event for streaming requests. +func writeSSEResult(w http.ResponseWriter, result json.RawMessage) error { + resp := flowResultResponse{Result: result} + data, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "data: %s\n\n", data) + return err +} + +// writeSSEMessage writes a streaming chunk as a server-sent event. +func writeSSEMessage(w http.ResponseWriter, msg json.RawMessage) error { + resp := flowMessageResponse{Message: msg} + data, err := json.Marshal(resp) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "data: %s\n\n", data) + return err +} + +// writeSSEError writes an error as a server-sent event for streaming requests. 
+func writeSSEError(w http.ResponseWriter, flowErr error) error { + status := core.INTERNAL + var ufErr *core.UserFacingError + var gErr *core.GenkitError + if errors.As(flowErr, &ufErr) { + status = ufErr.Status + } else if errors.As(flowErr, &gErr) { + status = gErr.Status + } + + resp := flowErrorResponse{ + Error: &flowError{ + Status: status, + Message: "stream flow error", + Details: flowErr.Error(), + }, + } + data, err := json.Marshal(resp) + if err != nil { return err } + _, err = fmt.Fprintf(w, "data: %s\n\n", data) + return err } func parseBoolQueryParam(r *http.Request, name string) (bool, error) { diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index a0a07cc21b..b5a69d17ec 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/x/streaming" ) func FakeContextProvider(ctx context.Context, req core.RequestData) (core.ActionContext, error) { @@ -222,17 +223,17 @@ func TestStreamingHandler(t *testing.T) { t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) } - expected := `data: {"message": "h"} + expected := `data: {"message":"h"} -data: {"message": "e"} +data: {"message":"e"} -data: {"message": "l"} +data: {"message":"l"} -data: {"message": "l"} +data: {"message":"l"} -data: {"message": "o"} +data: {"message":"o"} -data: {"result": "hello-end"} +data: {"result":"hello-end"} ` if string(body) != expected { @@ -256,7 +257,7 @@ data: {"result": "hello-end"} t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) } - expected := `data: {"error": {"status": "INTERNAL", "message": "stream flow error", "details": "streaming error"}} + expected := `data: {"error":{"status":"INTERNAL_SERVER_ERROR","message":"stream flow error","details":"streaming error"}} ` if string(body) != expected { @@ -264,3 +265,121 @@ data: {"result": "hello-end"} } }) } + +func TestDurableStreamingHandler(t 
*testing.T) { + g := Init(context.Background()) + + streamingFlow := DefineStreamingFlow(g, "durableStreaming", + func(ctx context.Context, input string, cb func(context.Context, string) error) (string, error) { + for _, c := range input { + if err := cb(ctx, string(c)); err != nil { + return "", err + } + } + return input + "-done", nil + }) + + t.Run("returns stream ID header", func(t *testing.T) { + sm := streaming.NewInMemoryStreamManager() + defer sm.Close() + handler := Handler(streamingFlow, WithStreamManager(sm)) + + req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"hi"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + t.Errorf("want status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + streamID := resp.Header.Get("X-Genkit-Stream-Id") + if streamID == "" { + t.Error("want X-Genkit-Stream-Id header to be set") + } + + expected := `data: {"message":"h"} + +data: {"message":"i"} + +data: {"result":"hi-done"} + +` + if string(body) != expected { + t.Errorf("want streaming body:\n%q\n\nGot:\n%q", expected, string(body)) + } + }) + + t.Run("subscribe to completed stream", func(t *testing.T) { + sm := streaming.NewInMemoryStreamManager() + defer sm.Close() + handler := Handler(streamingFlow, WithStreamManager(sm)) + + // First request - run the stream to completion + req1 := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"ab"}`)) + req1.Header.Set("Content-Type", "application/json") + req1.Header.Set("Accept", "text/event-stream") + w1 := httptest.NewRecorder() + + handler(w1, req1) + + resp1 := w1.Result() + streamID := resp1.Header.Get("X-Genkit-Stream-Id") + if streamID == "" { + t.Fatal("want X-Genkit-Stream-Id header to be set") + } + + // Second request - subscribe to the completed stream + req2 := 
httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"ignored"}`)) + req2.Header.Set("Content-Type", "application/json") + req2.Header.Set("Accept", "text/event-stream") + req2.Header.Set("X-Genkit-Stream-Id", streamID) + w2 := httptest.NewRecorder() + + handler(w2, req2) + + resp2 := w2.Result() + body2, _ := io.ReadAll(resp2.Body) + + if resp2.StatusCode != http.StatusOK { + t.Errorf("want status code %d, got %d", http.StatusOK, resp2.StatusCode) + } + + // Should replay all chunks and the final result + expected := `data: {"message":"a"} + +data: {"message":"b"} + +data: {"result":"ab-done"} + +` + if string(body2) != expected { + t.Errorf("want replayed body:\n%q\n\nGot:\n%q", expected, string(body2)) + } + }) + + t.Run("subscribe to non-existent stream returns 204", func(t *testing.T) { + sm := streaming.NewInMemoryStreamManager() + defer sm.Close() + handler := Handler(streamingFlow, WithStreamManager(sm)) + + req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":"test"}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("X-Genkit-Stream-Id", "non-existent-stream-id") + w := httptest.NewRecorder() + + handler(w, req) + + resp := w.Result() + + if resp.StatusCode != http.StatusNoContent { + t.Errorf("want status code %d, got %d", http.StatusNoContent, resp.StatusCode) + } + }) +} diff --git a/go/go.mod b/go/go.mod index 3aa1cd948b..3472c0f4cb 100644 --- a/go/go.mod +++ b/go/go.mod @@ -41,7 +41,7 @@ require ( golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 golang.org/x/tools v0.34.0 google.golang.org/api v0.236.0 - google.golang.org/genai v1.36.0 + google.golang.org/genai v1.40.0 ) require ( diff --git a/go/go.sum b/go/go.sum index 43f5ac29cd..e7abcc1495 100644 --- a/go/go.sum +++ b/go/go.sum @@ -537,8 +537,8 @@ google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9Ywl google.golang.org/appengine v1.4.0/go.mod 
h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine/v2 v2.0.6 h1:LvPZLGuchSBslPBp+LAhihBeGSiRh1myRoYK4NtuBIw= google.golang.org/appengine/v2 v2.0.6/go.mod h1:WoEXGoXNfa0mLvaH5sV3ZSGXwVmy8yf7Z1JKf3J3wLI= -google.golang.org/genai v1.36.0 h1:sJCIjqTAmwrtAIaemtTiKkg2TO1RxnYEusTmEQ3nGxM= -google.golang.org/genai v1.36.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= +google.golang.org/genai v1.40.0 h1:kYxyQSH+vsib8dvsgyLJzsVEIv5k3ZmHJyVqdvGncmc= +google.golang.org/genai v1.40.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= diff --git a/go/internal/base/json.go b/go/internal/base/json.go index 4f413ab8f7..117855a6d3 100644 --- a/go/internal/base/json.go +++ b/go/internal/base/json.go @@ -138,18 +138,29 @@ func SchemaAsMap(s *jsonschema.Schema) map[string]any { return m } -// jsonMarkdownRegex specifically looks for "json" language identifier -var jsonMarkdownRegex = regexp.MustCompile("(?s)```json(.*?)```") +// jsonMarkdownRegex matches fenced code blocks with "json" language identifier (case-insensitive). +var jsonMarkdownRegex = regexp.MustCompile("(?si)```json\\s*(.*?)```") + +// plainMarkdownRegex matches fenced code blocks without any language identifier. +var plainMarkdownRegex = regexp.MustCompile("(?s)```\\s*\\n(.*?)```") // ExtractJSONFromMarkdown returns the contents of the first fenced code block in -// the markdown text md. If there is none, it returns md. +// the markdown text md. It matches code blocks with "json" identifier (case-insensitive) +// or code blocks without any language identifier. If there is no matching block, it returns md. 
func ExtractJSONFromMarkdown(md string) string { + // First try to match explicit json code blocks matches := jsonMarkdownRegex.FindStringSubmatch(md) - if len(matches) < 2 { - return md + if len(matches) >= 2 { + return strings.TrimSpace(matches[1]) + } + + // Fall back to plain code blocks (no language identifier) + matches = plainMarkdownRegex.FindStringSubmatch(md) + if len(matches) >= 2 { + return strings.TrimSpace(matches[1]) } - // capture group 1 matches the actual fenced JSON block - return strings.TrimSpace(matches[1]) + + return md } // GetJSONObjectLines splits a string by newlines, trims whitespace from each line, diff --git a/go/internal/base/json_test.go b/go/internal/base/json_test.go index b018849af2..eda537c9ba 100644 --- a/go/internal/base/json_test.go +++ b/go/internal/base/json_test.go @@ -78,6 +78,31 @@ func TestExtractJSONFromMarkdown(t *testing.T) { in: "```json\n{\"a\": 1}\n``` ```yaml\nkey: 1\nanother-key: 2```", want: "{\"a\": 1}", }, + { + desc: "uppercase JSON identifier", + in: "```JSON\n{\"a\": 1}\n```", + want: "{\"a\": 1}", + }, + { + desc: "mixed case Json identifier", + in: "```Json\n{\"a\": 1}\n```", + want: "{\"a\": 1}", + }, + { + desc: "plain code block without identifier", + in: "```\n{\"a\": 1}\n```", + want: "{\"a\": 1}", + }, + { + desc: "plain code block with text before", + in: "Here is the result:\n\n```\n{\"title\": \"Pizza\"}\n```", + want: "{\"title\": \"Pizza\"}", + }, + { + desc: "json block preferred over plain block", + in: "```\n{\"plain\": true}\n``` then ```json\n{\"json\": true}\n```", + want: "{\"json\": true}", + }, } for _, tc := range tests { t.Run(tc.desc, func(t *testing.T) { diff --git a/go/internal/base/misc.go b/go/internal/base/misc.go index 9e3afa1d93..f4fdb7af32 100644 --- a/go/internal/base/misc.go +++ b/go/internal/base/misc.go @@ -18,6 +18,7 @@ package base import ( "net/url" + "reflect" ) // An Environment is the execution context in which the program is running. 
@@ -38,3 +39,16 @@ func Zero[T any]() T { func Clean(id string) string { return url.PathEscape(id) } + +// IsNil returns true if v is nil or a nil pointer/interface/map/slice/channel/func. +func IsNil[T any](v T) bool { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Invalid: + return true + case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func: + return rv.IsNil() + default: + return false + } +} diff --git a/go/plugins/anthropic/anthropic.go b/go/plugins/anthropic/anthropic.go index 493a6c76b9..e93f1abde9 100644 --- a/go/plugins/anthropic/anthropic.go +++ b/go/plugins/anthropic/anthropic.go @@ -169,8 +169,9 @@ func newModel(client anthropic.Client, name string, opts ai.ModelOptions) ai.Mod // configToMap converts a config struct to a map[string]any. func configToMap(config any) map[string]any { r := jsonschema.Reflector{ - DoNotReference: false, // Prevent $ref usage + DoNotReference: true, // Prevent $ref usage AllowAdditionalProperties: false, + ExpandedStruct: true, RequiredFromJSONSchemaTags: true, } // The anthropic SDK uses a number of wrapper types for float, int, etc. @@ -201,5 +202,6 @@ func configToMap(config any) map[string]any { } schema := r.Reflect(config) result := base.SchemaAsMap(schema) + return result } diff --git a/go/plugins/firebase/auth.go b/go/plugins/firebase/auth.go index bb1856f970..9280973232 100644 --- a/go/plugins/firebase/auth.go +++ b/go/plugins/firebase/auth.go @@ -40,11 +40,11 @@ type AuthClient interface { // ContextProvider creates a Firebase context provider for Genkit actions. 
func ContextProvider(ctx context.Context, g *genkit.Genkit, policy AuthPolicy) (core.ContextProvider, error) { - f, ok := genkit.LookupPlugin(g, provider).(*Firebase) - if !ok { - return nil, core.NewError(core.NOT_FOUND, "firebase plugin not initialized; did you pass the plugin to genkit.Init()") + f, err := resolvePlugin(g) + if err != nil { + return nil, err } - client, err := f.App.Auth(ctx) + client, err := f.Auth(ctx) if err != nil { return nil, err } diff --git a/go/plugins/firebase/firebase.go b/go/plugins/firebase/firebase.go index 1d221b4bbe..50fa5cc6cf 100644 --- a/go/plugins/firebase/firebase.go +++ b/go/plugins/firebase/firebase.go @@ -20,27 +20,48 @@ import ( "context" "errors" "fmt" - "log" "os" "sync" + "cloud.google.com/go/firestore" firebasev4 "firebase.google.com/go/v4" + "firebase.google.com/go/v4/auth" "github.com/firebase/genkit/go/ai" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/genkit" ) -// Firebase plugin for Genkit, providing integration with Firebase services. -// This plugin allows users to define retrievers and indexers for Firebase Firestore. -const provider = "firebase" // Identifier for the Firebase plugin. -const projectIdEnv = "FIREBASE_PROJECT_ID" // Environment variable for the Firebase project ID. +const provider = "firebase" +const projectIdEnv = "FIREBASE_PROJECT_ID" -// Firebase FireStore passes configuration options to the plugin. +const pluginInstruction = "Pass the Firebase plugin to genkit.Init():\n" + + " g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: \"your-project\"}))" + +var errPluginNotInitialized = errors.New("firebase: plugin not initialized. " + pluginInstruction) +var errPluginNotFound = errors.New("firebase: plugin not found. " + pluginInstruction) +var errCredentials = "Ensure you have proper credentials. For local development, run: gcloud auth application-default login" + +// Firebase is the Genkit plugin for Firebase services. 
+// It provides integration with Firebase Firestore for retrievers, indexers, and durable streaming. +// +// Usage: +// +// g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: "my-project"})) +// +// Or with an existing Firebase app: +// +// g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{App: myFirebaseApp})) type Firebase struct { - ProjectId string // Firebase project ID. - App *firebasev4.App // Firebase app instance. - mu sync.Mutex // Mutex to control concurrent access. - initted bool // Tracks whether the plugin has been initialized. + // ProjectId is the Firebase/GCP project ID. If set, a Firebase app is created automatically. + // Can also be set via the FIREBASE_PROJECT_ID environment variable. + ProjectId string + // App is an existing Firebase app instance. Provide either ProjectId or App, not both. + App *firebasev4.App + + mu sync.Mutex + initted bool + firestoreClient *firestore.Client + authClient *auth.Client } // Name returns the name of the plugin. @@ -48,29 +69,32 @@ func (f *Firebase) Name() string { return provider } -// Init initializes the Firebase plugin. +// Init initializes the Firebase plugin. Called automatically by genkit.Init(). func (f *Firebase) Init(ctx context.Context) []api.Action { f.mu.Lock() defer f.mu.Unlock() - // Resolve the Firebase project ID. 
- projectId := resolveProjectId(f.ProjectId) - if f.initted { panic("firebase.Init: plugin already initialized") } - if f.App == nil && f.ProjectId == "" { - panic("firebase.Init: provide ProjectId or App") + projectId := resolveProjectId(f.ProjectId) + + if f.App == nil && projectId == "" { + panic("firebase.Init: Firebase plugin requires either ProjectId or App to be set.\n" + + " Option 1: Set ProjectId directly: &firebase.Firebase{ProjectId: \"your-project-id\"}\n" + + " Option 2: Set FIREBASE_PROJECT_ID environment variable\n" + + " Option 3: Provide an existing Firebase App: &firebase.Firebase{App: yourApp}") } - if f.ProjectId != "" { - if f.App != nil { - panic("firebase.Init: provide either ProjectId or App, not both") - } - // Configure and initialize the Firebase app. + + if f.App != nil && f.ProjectId != "" { + panic("firebase.Init: provide either ProjectId or App, not both") + } + + if f.App == nil { firebaseApp, err := firebasev4.NewApp(ctx, &firebasev4.Config{ProjectID: projectId}) if err != nil { - panic(fmt.Errorf("error initializing Firebase App: %v", err)) + panic(fmt.Errorf("firebase.Init: failed to initialize Firebase App: %v", err)) } f.App = firebaseApp } @@ -79,37 +103,90 @@ func (f *Firebase) Init(ctx context.Context) []api.Action { return []api.Action{} } -// DefineRetriever defines a Retriever with the given configuration. -func DefineRetriever(ctx context.Context, g *genkit.Genkit, cfg RetrieverOptions) (ai.Retriever, error) { - // Lookup the Firebase plugin from the registry. - f, ok := genkit.LookupPlugin(g, provider).(*Firebase) - if !ok { - return nil, errors.New("firebase plugin not found; did you call firebase.Init with the firebase plugin") +// Firestore returns a cached Firestore client for the Firebase project. +// The client is created lazily on first call and reused for subsequent calls. +// This client is shared across all Firebase plugin features (retrievers, stream managers, etc.). 
+func (f *Firebase) Firestore(ctx context.Context) (*firestore.Client, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if !f.initted { + return nil, errPluginNotInitialized + } + + if f.firestoreClient != nil { + return f.firestoreClient, nil } - // Initialize Firestore client. - firestoreClient, err := f.App.Firestore(ctx) + client, err := f.App.Firestore(ctx) if err != nil { - log.Fatalf("Error creating Firestore client: %v", err) // Log and exit on failure. + return nil, fmt.Errorf("firebase: failed to create Firestore client: %w. %s", err, errCredentials) } - // Define a Firestore retriever using the client. - retriever, err := defineFirestoreRetriever(g, cfg, firestoreClient) + f.firestoreClient = client + return client, nil +} + +// Auth returns a cached Firebase Auth client for the Firebase project. +// The client is created lazily on first call and reused for subsequent calls. +func (f *Firebase) Auth(ctx context.Context) (*auth.Client, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if !f.initted { + return nil, errPluginNotInitialized + } + + if f.authClient != nil { + return f.authClient, nil + } + + client, err := f.App.Auth(ctx) if err != nil { + return nil, fmt.Errorf("firebase: failed to create Auth client: %w. %s", err, errCredentials) + } - return nil, fmt.Errorf("DefineRetriever: failed to initialize retriever %s: %v", cfg.Name, err) + f.authClient = client + return client, nil +} + +// DefineRetriever defines a Firestore vector retriever with the given configuration. +// The Firebase plugin must be registered with genkit.Init() before calling this function. 
+func DefineRetriever(ctx context.Context, g *genkit.Genkit, opts RetrieverOptions) (ai.Retriever, error) { + f, err := resolvePlugin(g) + if err != nil { + return nil, err + } + + firestoreClient, err := f.Firestore(ctx) + if err != nil { + return nil, err + } + + retriever, err := defineFirestoreRetriever(g, opts, firestoreClient) + if err != nil { + return nil, fmt.Errorf("firebase.DefineRetriever: failed to initialize retriever %q: %w", opts.Name, err) } return retriever, nil } -// resolveProjectId reads the projectId from the environment if necessary. +// resolveProjectId resolves the Firebase project ID from various sources. func resolveProjectId(projectId string) string { - // Return the provided project ID if it's not empty. if projectId != "" { return projectId } + return os.Getenv(projectIdEnv) +} - // Otherwise, read the project ID from the environment variable. - projectId = os.Getenv(projectIdEnv) - return projectId +// resolvePlugin resolves the Firebase plugin from the Genkit registry. 
+func resolvePlugin(g *genkit.Genkit) (*Firebase, error) { + plugin := genkit.LookupPlugin(g, provider) + if plugin == nil { + return nil, errPluginNotFound + } + f, ok := plugin.(*Firebase) + if !ok { + return nil, fmt.Errorf("firebase: unexpected plugin type %T for provider %q", plugin, provider) + } + return f, nil } diff --git a/go/plugins/firebase/retriever.go b/go/plugins/firebase/retriever.go index 6287026395..5281c12bb9 100644 --- a/go/plugins/firebase/retriever.go +++ b/go/plugins/firebase/retriever.go @@ -17,6 +17,7 @@ package firebase import ( "context" "fmt" + "log/slog" "os" "cloud.google.com/go/firestore" @@ -26,10 +27,9 @@ import ( "github.com/firebase/genkit/go/genkit" ) -type VectorType int +const firestoreCollectionEnv = "FIRESTORE_COLLECTION" -// Firestore collection environment variable key name -const firestoreCollection = "FIRESTORE_COLLECTION" +type VectorType int // TODO: in retriever options add field that controls the 32/64 @@ -141,14 +141,17 @@ func defineFirestoreRetriever(g *genkit.Genkit, cfg RetrieverOptions, client *fi return genkit.DefineRetriever(g, api.NewName(provider, cfg.Name), retOpts, retrieve), nil } -// resolveFirestoreCollection resolves the Firestore collection name from the environment if necessary func resolveFirestoreCollection(collectionName string) (string, error) { if collectionName != "" { return collectionName, nil } - collectionName = os.Getenv(firestoreCollection) + collectionName = os.Getenv(firestoreCollectionEnv) if collectionName == "" { - return "", fmt.Errorf("no Firestore collection provided; set %q env variable or pass the collection directly", firestoreCollection) + return "", fmt.Errorf("firebase: no Firestore collection provided. " + + "Pass the collection in RetrieverOptions: RetrieverOptions{Collection: \"my-collection\"}") } + slog.Warn("Using FIRESTORE_COLLECTION environment variable is deprecated for retriever configuration. 
"+ + "Use RetrieverOptions{Collection: \"my-collection\"} instead.", + "collection", collectionName) return collectionName, nil } diff --git a/go/plugins/firebase/x/stream_manager.go b/go/plugins/firebase/x/stream_manager.go new file mode 100644 index 0000000000..edc0154637 --- /dev/null +++ b/go/plugins/firebase/x/stream_manager.go @@ -0,0 +1,497 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package x contains experimental Firebase features. +// +// APIs in this package are under active development and may change in any +// minor version release. Use with caution in production environments. +package x + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + "time" + + "cloud.google.com/go/firestore" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/x/streaming" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/firebase" + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + streamBufferSize = 100 + defaultTimeout = 60 * time.Second + defaultTTL = 5 * time.Minute + streamEventChunk = "chunk" + streamEventDone = "done" + streamEventError = "error" +) + +// FirestoreStreamManagerOption configures a FirestoreStreamManager. 
+type FirestoreStreamManagerOption interface { + applyFirestoreStreamManager(*firestoreStreamManagerOptions) error +} + +// firestoreStreamManagerOptions holds configuration for FirestoreStreamManager. +type firestoreStreamManagerOptions struct { + Collection string + Timeout time.Duration + TTL time.Duration +} + +func (o *firestoreStreamManagerOptions) applyFirestoreStreamManager(opts *firestoreStreamManagerOptions) error { + if o.Collection != "" { + if opts.Collection != "" { + return errors.New("cannot set collection more than once (WithCollection)") + } + opts.Collection = o.Collection + } + + if o.Timeout > 0 { + if opts.Timeout > 0 { + return errors.New("cannot set timeout more than once (WithTimeout)") + } + opts.Timeout = o.Timeout + } + + if o.TTL > 0 { + if opts.TTL > 0 { + return errors.New("cannot set TTL more than once (WithFirestoreTTL)") + } + opts.TTL = o.TTL + } + + return nil +} + +// WithCollection sets the Firestore collection name where stream documents are stored. +// This option is required. +func WithCollection(collection string) FirestoreStreamManagerOption { + return &firestoreStreamManagerOptions{Collection: collection} +} + +// WithTimeout sets how long a subscriber waits for new events before giving up. +// If no activity occurs within this duration, subscribers receive a DEADLINE_EXCEEDED error. +// Default is 60 seconds. +func WithTimeout(timeout time.Duration) FirestoreStreamManagerOption { + return &firestoreStreamManagerOptions{Timeout: timeout} +} + +// WithTTL sets how long completed streams are retained before Firestore auto-deletes them. +// Requires a TTL policy on the collection for the "expiresAt" field. Default is 5 minutes. +// See: https://firebase.google.com/docs/firestore/ttl +func WithTTL(ttl time.Duration) FirestoreStreamManagerOption { + return &firestoreStreamManagerOptions{TTL: ttl} +} + +// FirestoreStreamManager implements [streaming.StreamManager] using Firestore as the backend. 
+// Stream state is persisted in Firestore documents, allowing streams to survive server +// restarts and be accessible across multiple instances. +type FirestoreStreamManager struct { + client *firestore.Client + collection string + timeout time.Duration + ttl time.Duration +} + +// streamDocument represents the structure of a stream document in Firestore. +type streamDocument struct { + Stream []streamEntry `firestore:"stream"` + CreatedAt time.Time `firestore:"createdAt"` + UpdatedAt time.Time `firestore:"updatedAt"` + ExpiresAt *time.Time `firestore:"expiresAt,omitempty"` +} + +// streamEntry represents a single entry in the stream array. +type streamEntry struct { + Type string `firestore:"type"` + Chunk json.RawMessage `firestore:"chunk,omitempty"` + Output json.RawMessage `firestore:"output,omitempty"` + Err *streamError `firestore:"err,omitempty"` + UUID string `firestore:"uuid,omitempty"` +} + +// streamError represents a serializable error for Firestore storage. +type streamError struct { + Status string `firestore:"status"` + Message string `firestore:"message"` +} + +// NewFirestoreStreamManager creates a FirestoreStreamManager for durable streaming. 
+func NewFirestoreStreamManager(ctx context.Context, g *genkit.Genkit, opts ...FirestoreStreamManagerOption) (*FirestoreStreamManager, error) { + streamOpts := &firestoreStreamManagerOptions{} + for _, opt := range opts { + if err := opt.applyFirestoreStreamManager(streamOpts); err != nil { + return nil, fmt.Errorf("firebase.NewFirestoreStreamManager: error applying options: %w", err) + } + } + if streamOpts.Collection == "" { + return nil, errors.New("firebase.NewFirestoreStreamManager: Collection name is required.\n" + + " Specify the Firestore collection where stream documents will be stored:\n" + + " firebase.NewFirestoreStreamManager(ctx, g, firebase.WithCollection(\"genkit-streams\"))") + } + if streamOpts.Timeout == 0 { + streamOpts.Timeout = defaultTimeout + } + if streamOpts.TTL == 0 { + streamOpts.TTL = defaultTTL + } + + plugin := genkit.LookupPlugin(g, "firebase") + if plugin == nil { + return nil, errors.New("firebase.NewFirestoreStreamManager: Firebase plugin not found.\n" + + " Pass the Firebase plugin to genkit.Init():\n" + + " g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: \"your-project\"}))") + } + f, ok := plugin.(*firebase.Firebase) + if !ok { + return nil, fmt.Errorf("firebase.NewFirestoreStreamManager: unexpected plugin type %T", plugin) + } + + client, err := f.Firestore(ctx) + if err != nil { + return nil, fmt.Errorf("firebase.NewFirestoreStreamManager: %w", err) + } + + return &FirestoreStreamManager{ + client: client, + collection: streamOpts.Collection, + timeout: streamOpts.Timeout, + ttl: streamOpts.TTL, + }, nil +} + +// Open creates a new stream for writing. +// Returns ALREADY_EXISTS error if a stream with the given ID already exists. 
+func (m *FirestoreStreamManager) Open(ctx context.Context, streamID string) (streaming.StreamInput, error) { + docRef := m.client.Collection(m.collection).Doc(streamID) + now := time.Now() + expiresAt := now.Add(m.timeout + m.ttl) + _, err := docRef.Create(ctx, streamDocument{ + Stream: []streamEntry{}, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &expiresAt, + }) + if err != nil { + if status.Code(err) == codes.AlreadyExists { + return nil, core.NewPublicError(core.ALREADY_EXISTS, "stream already exists", nil) + } + return nil, err + } + return &firestoreStreamInput{ + manager: m, + streamID: streamID, + docRef: docRef, + }, nil +} + +// Subscribe subscribes to an existing stream. +func (m *FirestoreStreamManager) Subscribe(ctx context.Context, streamID string) (<-chan streaming.StreamEvent, func(), error) { + docRef := m.client.Collection(m.collection).Doc(streamID) + + snapshot, err := docRef.Get(ctx) + if err != nil { + if isNotFound(err) { + return nil, nil, core.NewPublicError(core.NOT_FOUND, "stream not found", nil) + } + return nil, nil, err + } + if !snapshot.Exists() { + return nil, nil, core.NewPublicError(core.NOT_FOUND, "stream not found", nil) + } + + ch := make(chan streaming.StreamEvent, streamBufferSize) + var mu sync.Mutex + var lastIndex int = -1 + var unsubscribed bool + var cancelSnapshot context.CancelFunc + + snapshotCtx, cancelSnapshot := context.WithCancel(ctx) + + var timeoutTimer *time.Timer + resetTimeout := func() { + mu.Lock() + defer mu.Unlock() + if timeoutTimer != nil { + timeoutTimer.Stop() + } + timeoutTimer = time.AfterFunc(m.timeout, func() { + mu.Lock() + defer mu.Unlock() + if !unsubscribed { + unsubscribed = true + ch <- streaming.StreamEvent{ + Type: streaming.StreamEventError, + Err: core.NewPublicError(core.DEADLINE_EXCEEDED, "stream timed out", nil), + } + close(ch) + cancelSnapshot() + } + }) + } + + unsubscribe := func() { + mu.Lock() + defer mu.Unlock() + if !unsubscribed { + unsubscribed = true + if timeoutTimer 
!= nil { + timeoutTimer.Stop() + } + close(ch) + cancelSnapshot() + } + } + + resetTimeout() + + go func() { + snapshots := docRef.Snapshots(snapshotCtx) + defer snapshots.Stop() + + for { + snap, err := snapshots.Next() + if err != nil { + mu.Lock() + if !unsubscribed { + if snapshotCtx.Err() == nil { + ch <- streaming.StreamEvent{ + Type: streaming.StreamEventError, + Err: err, + } + } + unsubscribed = true + if timeoutTimer != nil { + timeoutTimer.Stop() + } + close(ch) + } + mu.Unlock() + return + } + + resetTimeout() + + if !snap.Exists() { + continue + } + + var doc streamDocument + if err := snap.DataTo(&doc); err != nil { + mu.Lock() + if !unsubscribed { + ch <- streaming.StreamEvent{ + Type: streaming.StreamEventError, + Err: err, + } + unsubscribed = true + if timeoutTimer != nil { + timeoutTimer.Stop() + } + close(ch) + } + mu.Unlock() + return + } + + mu.Lock() + for i := lastIndex + 1; i < len(doc.Stream); i++ { + entry := doc.Stream[i] + switch entry.Type { + case streamEventChunk: + if !unsubscribed { + select { + case ch <- streaming.StreamEvent{Type: streaming.StreamEventChunk, Chunk: entry.Chunk}: + default: + } + } + case streamEventDone: + if !unsubscribed { + select { + case ch <- streaming.StreamEvent{Type: streaming.StreamEventDone, Output: entry.Output}: + default: + } + unsubscribed = true + if timeoutTimer != nil { + timeoutTimer.Stop() + } + close(ch) + } + mu.Unlock() + return + case streamEventError: + if !unsubscribed { + var errStatus core.StatusName = core.UNKNOWN + var errMsg string + if entry.Err != nil { + errMsg = entry.Err.Message + if entry.Err.Status != "" { + errStatus = core.StatusName(entry.Err.Status) + } + } + select { + case ch <- streaming.StreamEvent{ + Type: streaming.StreamEventError, + Err: core.NewPublicError(errStatus, errMsg, nil), + }: + default: + } + unsubscribed = true + if timeoutTimer != nil { + timeoutTimer.Stop() + } + close(ch) + } + mu.Unlock() + return + } + } + lastIndex = len(doc.Stream) - 1 + 
mu.Unlock() + } + }() + + return ch, unsubscribe, nil +} + +// isNotFound checks if the error is a not found error. +func isNotFound(err error) bool { + if err == nil { + return false + } + if grpcErr, ok := status.FromError(err); ok { + return grpcErr.Code() == codes.NotFound + } + return false +} + +// firestoreStreamInput implements streaming.StreamInput for Firestore. +type firestoreStreamInput struct { + manager *FirestoreStreamManager + streamID string + docRef *firestore.DocumentRef + closed bool + mu sync.Mutex +} + +func (s *firestoreStreamInput) Write(ctx context.Context, chunk json.RawMessage) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + + _, err := s.docRef.Update(ctx, []firestore.Update{ + { + Path: "stream", + Value: firestore.ArrayUnion(streamEntry{ + Type: streamEventChunk, + Chunk: chunk, + UUID: uuid.New().String(), + }), + }, + { + Path: "updatedAt", + Value: firestore.ServerTimestamp, + }, + }) + return err +} + +func (s *firestoreStreamInput) Done(ctx context.Context, output json.RawMessage) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + s.closed = true + + expiresAt := time.Now().Add(s.manager.ttl) + _, err := s.docRef.Update(ctx, []firestore.Update{ + { + Path: "stream", + Value: firestore.ArrayUnion(streamEntry{ + Type: streamEventDone, + Output: output, + }), + }, + { + Path: "updatedAt", + Value: firestore.ServerTimestamp, + }, + { + Path: "expiresAt", + Value: expiresAt, + }, + }) + return err +} + +func (s *firestoreStreamInput) Error(ctx context.Context, err error) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return core.NewPublicError(core.FAILED_PRECONDITION, "stream writer is closed", nil) + } + s.closed = true + + streamErr := &streamError{ + Status: string(core.UNKNOWN), + Message: err.Error(), + } 
+ var ufErr *core.UserFacingError + if errors.As(err, &ufErr) { + streamErr.Status = string(ufErr.Status) + } + + expiresAt := time.Now().Add(s.manager.ttl) + _, updateErr := s.docRef.Update(ctx, []firestore.Update{ + { + Path: "stream", + Value: firestore.ArrayUnion(streamEntry{ + Type: streamEventError, + Err: streamErr, + }), + }, + { + Path: "updatedAt", + Value: firestore.ServerTimestamp, + }, + { + Path: "expiresAt", + Value: expiresAt, + }, + }) + return updateErr +} + +func (s *firestoreStreamInput) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + s.closed = true + return nil +} diff --git a/go/plugins/firebase/x/stream_manager_test.go b/go/plugins/firebase/x/stream_manager_test.go new file mode 100644 index 0000000000..64cb5faacb --- /dev/null +++ b/go/plugins/firebase/x/stream_manager_test.go @@ -0,0 +1,564 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package x + +import ( + "context" + "encoding/json" + "errors" + "flag" + "testing" + "time" + + "cloud.google.com/go/firestore" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/core/x/streaming" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/firebase" + "google.golang.org/api/iterator" +) + +var ( + testStreamProjectID = flag.String("test-stream-project-id", "", "GCP Project ID to use for stream manager tests") + testStreamCollection = flag.String("test-stream-collection", "genkit-streams", "Firestore collection to use for stream manager tests") +) + +/* + * Pre-requisites to run this test: + * + * 1. **Option A - Use Firestore Emulator (Recommended for local development):** + * Start the Firestore emulator: + * ```bash + * export FIRESTORE_EMULATOR_HOST=127.0.0.1:8080 + * gcloud emulators firestore start --host-port=127.0.0.1:8080 + * ``` + * + * 2. **Option B - Use a Real Firestore Database:** + * - Set up a Firebase project with Firestore enabled + * - Authenticate using: + * ```bash + * gcloud auth application-default login + * ``` + * + * 3. 
**Running the Test:** + * ```bash + * go test -test-stream-project-id= -test-stream-collection=genkit-streams + * ``` + */ + +func skipIfNoFirestore(t *testing.T) { + if *testStreamProjectID == "" { + t.Skip("Skipping test: -test-stream-project-id flag not provided") + } +} + +func setupTestStreamManager(t *testing.T) (*FirestoreStreamManager, *firestore.Client, func()) { + skipIfNoFirestore(t) + + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testStreamProjectID})) + + f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) + client, err := f.Firestore(ctx) + if err != nil { + t.Fatalf("Failed to get Firestore client: %v", err) + } + + manager, err := NewFirestoreStreamManager(ctx, g, + WithCollection(*testStreamCollection), + ) + if err != nil { + t.Fatalf("Failed to create stream manager: %v", err) + } + + cleanup := func() { + deleteStreamCollection(ctx, client, *testStreamCollection, t) + } + + return manager, client, cleanup +} + +func deleteStreamCollection(ctx context.Context, client *firestore.Client, collectionName string, t *testing.T) { + iter := client.Collection(collectionName).Documents(ctx) + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Logf("Failed to iterate documents for deletion: %v", err) + return + } + _, err = doc.Ref.Delete(ctx) + if err != nil { + t.Logf("Failed to delete document %s: %v", doc.Ref.ID, err) + } + } +} + +func TestFirestoreStreamManager_OpenDuplicateFails(t *testing.T) { + manager, _, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + streamID := "test-stream-dup" + + _, err := manager.Open(ctx, streamID) + if err != nil { + t.Fatalf("First Open failed: %v", err) + } + + _, err = manager.Open(ctx, streamID) + if err == nil { + t.Fatal("Expected error when opening duplicate stream") + } + + publicErr, ok := err.(*core.UserFacingError) + if !ok { + t.Fatalf("Expected 
UserFacingError, got %T", err) + } + if publicErr.Status != core.ALREADY_EXISTS { + t.Errorf("Expected ALREADY_EXISTS error, got %v", publicErr.Status) + } +} + +func TestFirestoreStreamManager_OpenAndWrite(t *testing.T) { + manager, client, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + streamID := "test-stream-open-write" + + stream, err := manager.Open(ctx, streamID) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + chunk1, _ := json.Marshal(map[string]string{"foo": "bar"}) + chunk2, _ := json.Marshal(map[string]string{"bar": "baz"}) + + if err := stream.Write(ctx, chunk1); err != nil { + t.Fatalf("Failed to write chunk 1: %v", err) + } + if err := stream.Write(ctx, chunk2); err != nil { + t.Fatalf("Failed to write chunk 2: %v", err) + } + + snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) + if err != nil { + t.Fatalf("Failed to get document: %v", err) + } + + data := snapshot.Data() + streamArr, ok := data["stream"].([]interface{}) + if !ok { + t.Fatalf("Expected stream array, got %T", data["stream"]) + } + + if len(streamArr) != 2 { + t.Errorf("Expected 2 stream entries, got %d", len(streamArr)) + } + + entry0, _ := streamArr[0].(map[string]interface{}) + if entry0["type"] != streamEventChunk { + t.Errorf("Expected type 'chunk', got %v", entry0["type"]) + } + if entry0["uuid"] == nil || entry0["uuid"] == "" { + t.Error("Expected uuid to be set for chunk") + } + + if data["expiresAt"] == nil { + t.Error("Expected expiresAt to be set on open (for abandoned stream cleanup)") + } +} + +func TestFirestoreStreamManager_PreserveDuplicateChunks(t *testing.T) { + manager, client, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + streamID := "test-stream-dupes" + + stream, err := manager.Open(ctx, streamID) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + chunk, _ := json.Marshal(map[string]string{"foo": "bar"}) + + 
if err := stream.Write(ctx, chunk); err != nil { + t.Fatalf("Failed to write chunk 1: %v", err) + } + if err := stream.Write(ctx, chunk); err != nil { + t.Fatalf("Failed to write chunk 2: %v", err) + } + + snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) + if err != nil { + t.Fatalf("Failed to get document: %v", err) + } + + data := snapshot.Data() + streamArr, _ := data["stream"].([]interface{}) + + if len(streamArr) != 2 { + t.Errorf("Expected 2 stream entries (duplicates should be preserved), got %d", len(streamArr)) + } + + entry0, _ := streamArr[0].(map[string]interface{}) + entry1, _ := streamArr[1].(map[string]interface{}) + if entry0["uuid"] == entry1["uuid"] { + t.Error("UUIDs should be different for duplicate chunks") + } +} + +func TestFirestoreStreamManager_Done(t *testing.T) { + manager, client, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + streamID := "test-stream-done" + + stream, err := manager.Open(ctx, streamID) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + output, _ := json.Marshal(map[string]string{"result": "success"}) + if err := stream.Done(ctx, output); err != nil { + t.Fatalf("Failed to mark stream done: %v", err) + } + + snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) + if err != nil { + t.Fatalf("Failed to get document: %v", err) + } + + data := snapshot.Data() + streamArr, _ := data["stream"].([]interface{}) + + if len(streamArr) != 1 { + t.Errorf("Expected 1 stream entry, got %d", len(streamArr)) + } + + entry, _ := streamArr[0].(map[string]interface{}) + if entry["type"] != streamEventDone { + t.Errorf("Expected type 'done', got %v", entry["type"]) + } + + if data["expiresAt"] == nil { + t.Error("Expected expiresAt to be set after done") + } +} + +func TestFirestoreStreamManager_Error(t *testing.T) { + manager, client, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() 
+ streamID := "test-stream-error" + + stream, err := manager.Open(ctx, streamID) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + testError := errors.New("test error message") + if err := stream.Error(ctx, testError); err != nil { + t.Fatalf("Failed to mark stream error: %v", err) + } + + snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) + if err != nil { + t.Fatalf("Failed to get document: %v", err) + } + + data := snapshot.Data() + streamArr, _ := data["stream"].([]interface{}) + + if len(streamArr) != 1 { + t.Errorf("Expected 1 stream entry, got %d", len(streamArr)) + } + + entry, _ := streamArr[0].(map[string]interface{}) + if entry["type"] != streamEventError { + t.Errorf("Expected type 'error', got %v", entry["type"]) + } + + errData, _ := entry["err"].(map[string]interface{}) + if errData["message"] != "test error message" { + t.Errorf("Expected error message 'test error message', got %v", errData["message"]) + } + if errData["status"] != string(core.UNKNOWN) { + t.Errorf("Expected status UNKNOWN for plain error, got %v", errData["status"]) + } + + if data["expiresAt"] == nil { + t.Error("Expected expiresAt to be set after error") + } +} + +func TestFirestoreStreamManager_ErrorStatusPreserved(t *testing.T) { + manager, client, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + streamID := "test-stream-error-status" + + stream, err := manager.Open(ctx, streamID) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + testError := core.NewPublicError(core.INVALID_ARGUMENT, "invalid input", nil) + if err := stream.Error(ctx, testError); err != nil { + t.Fatalf("Failed to mark stream error: %v", err) + } + + snapshot, err := client.Collection(*testStreamCollection).Doc(streamID).Get(ctx) + if err != nil { + t.Fatalf("Failed to get document: %v", err) + } + + data := snapshot.Data() + streamArr, _ := data["stream"].([]interface{}) + entry, _ := 
streamArr[0].(map[string]interface{}) + errData, _ := entry["err"].(map[string]interface{}) + + if errData["status"] != string(core.INVALID_ARGUMENT) { + t.Errorf("Expected status INVALID_ARGUMENT, got %v", errData["status"]) + } + if errData["message"] != "invalid input" { + t.Errorf("Expected message 'invalid input', got %v", errData["message"]) + } +} + +func TestFirestoreStreamManager_Subscribe(t *testing.T) { + manager, client, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + streamID := "test-stream-subscribe" + + chunk1, _ := json.Marshal(map[string]string{"foo": "bar"}) + chunk2, _ := json.Marshal(map[string]string{"bar": "baz"}) + output, _ := json.Marshal(map[string]string{"result": "success"}) + + _, err := client.Collection(*testStreamCollection).Doc(streamID).Set(ctx, map[string]interface{}{ + "stream": []map[string]interface{}{ + {"type": "chunk", "chunk": chunk1, "uuid": "uuid1"}, + {"type": "chunk", "chunk": chunk2, "uuid": "uuid2"}, + {"type": "done", "output": output}, + }, + "createdAt": time.Now(), + "updatedAt": time.Now(), + }) + if err != nil { + t.Fatalf("Failed to create test document: %v", err) + } + + ch, unsubscribe, err := manager.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Failed to subscribe: %v", err) + } + defer unsubscribe() + + var chunks []json.RawMessage + var finalOutput json.RawMessage + timeout := time.After(5 * time.Second) + + for { + select { + case event, ok := <-ch: + if !ok { + goto verify + } + switch event.Type { + case streaming.StreamEventChunk: + chunks = append(chunks, event.Chunk) + case streaming.StreamEventDone: + finalOutput = event.Output + goto verify + case streaming.StreamEventError: + t.Fatalf("Unexpected error: %v", event.Err) + } + case <-timeout: + t.Fatal("Timeout waiting for stream events") + } + } + +verify: + if len(chunks) != 2 { + t.Errorf("Expected 2 chunks, got %d", len(chunks)) + } + if finalOutput == nil { + t.Error("Expected final output") + 
} +} + +func TestFirestoreStreamManager_SubscribeErrorStatusPreserved(t *testing.T) { + manager, client, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + streamID := "test-stream-sub-error-status" + + _, err := client.Collection(*testStreamCollection).Doc(streamID).Set(ctx, map[string]interface{}{ + "stream": []map[string]interface{}{ + {"type": "error", "err": map[string]interface{}{ + "status": string(core.INVALID_ARGUMENT), + "message": "bad input", + }}, + }, + "createdAt": time.Now(), + "updatedAt": time.Now(), + }) + if err != nil { + t.Fatalf("Failed to create test document: %v", err) + } + + ch, unsubscribe, err := manager.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Failed to subscribe: %v", err) + } + defer unsubscribe() + + timeout := time.After(5 * time.Second) + select { + case event, ok := <-ch: + if !ok { + t.Fatal("Channel closed unexpectedly") + } + if event.Type != streaming.StreamEventError { + t.Fatalf("Expected error event, got %v", event.Type) + } + publicErr, ok := event.Err.(*core.UserFacingError) + if !ok { + t.Fatalf("Expected UserFacingError, got %T", event.Err) + } + if publicErr.Status != core.INVALID_ARGUMENT { + t.Errorf("Expected INVALID_ARGUMENT status, got %v", publicErr.Status) + } + case <-timeout: + t.Fatal("Timeout waiting for error event") + } +} + +func TestFirestoreStreamManager_SubscribeNotFound(t *testing.T) { + manager, _, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + _, _, err := manager.Subscribe(ctx, "non-existent-stream") + if err == nil { + t.Fatal("Expected error for non-existent stream") + } + + publicErr, ok := err.(*core.UserFacingError) + if !ok { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if publicErr.Status != core.NOT_FOUND { + t.Errorf("Expected NOT_FOUND error, got %v", publicErr.Status) + } +} + +func TestFirestoreStreamManager_Timeout(t *testing.T) { + skipIfNoFirestore(t) + + ctx := 
context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testStreamProjectID})) + + f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) + client, err := f.Firestore(ctx) + if err != nil { + t.Fatalf("Failed to get Firestore client: %v", err) + } + defer deleteStreamCollection(ctx, client, *testStreamCollection, t) + + manager, err := NewFirestoreStreamManager(ctx, g, + WithCollection(*testStreamCollection), + WithTimeout(100*time.Millisecond), + ) + if err != nil { + t.Fatalf("Failed to create stream manager: %v", err) + } + + streamID := "test-stream-timeout" + + _, err = manager.Open(ctx, streamID) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + ch, _, err := manager.Subscribe(ctx, streamID) + if err != nil { + t.Fatalf("Failed to subscribe: %v", err) + } + + timeout := time.After(2 * time.Second) + for { + select { + case event, ok := <-ch: + if !ok { + t.Fatal("Channel closed without timeout error") + return + } + if event.Type == streaming.StreamEventError { + publicErr, ok := event.Err.(*core.UserFacingError) + if !ok { + t.Fatalf("Expected UserFacingError, got %T", event.Err) + } + if publicErr.Status != core.DEADLINE_EXCEEDED { + t.Errorf("Expected DEADLINE_EXCEEDED, got %v", publicErr.Status) + } + return + } + case <-timeout: + t.Fatal("Test timeout - stream timeout didn't trigger") + } + } +} + +func TestFirestoreStreamManager_WriteAfterClose(t *testing.T) { + manager, _, cleanup := setupTestStreamManager(t) + defer cleanup() + + ctx := context.Background() + streamID := "test-stream-write-after-close" + + stream, err := manager.Open(ctx, streamID) + if err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + if err := stream.Close(); err != nil { + t.Fatalf("Failed to close stream: %v", err) + } + + chunk, _ := json.Marshal(map[string]string{"foo": "bar"}) + err = stream.Write(ctx, chunk) + if err == nil { + t.Fatal("Expected error when writing after close") + } + + 
publicErr, ok := err.(*core.UserFacingError) + if !ok { + t.Fatalf("Expected UserFacingError, got %T", err) + } + if publicErr.Status != core.FAILED_PRECONDITION { + t.Errorf("Expected FAILED_PRECONDITION, got %v", publicErr.Status) + } +} diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index c7e49a8bec..c3ca5297b0 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -484,6 +484,34 @@ func toGeminiTools(inTools []*ai.ToolDefinition) ([]*genai.Tool, error) { return outTools, nil } +// toGeminiFunctionResponsePart translates a slice of [ai.Part] to a slice of [genai.FunctionResponsePart] +func toGeminiFunctionResponsePart(parts []*ai.Part) ([]*genai.FunctionResponsePart, error) { + frp := []*genai.FunctionResponsePart{} + for _, p := range parts { + switch { + case p.IsData(): + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } + frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) + case p.IsMedia(): + if strings.HasPrefix(p.Text, "data:") { + contentType, data, err := uri.Data(p) + if err != nil { + return nil, err + } + frp = append(frp, genai.NewFunctionResponsePartFromBytes(data, contentType)) + continue + } + frp = append(frp, genai.NewFunctionResponsePartFromURI(p.Text, p.ContentType)) + default: + return nil, fmt.Errorf("unsupported function response part type: %d", p.Kind) + } + } + return frp, nil +} + // mergeTools consolidates all FunctionDeclarations into a single Tool // while preserving non-function tools (Retrieval, GoogleSearch, CodeExecution, etc.) 
func mergeTools(ts []*genai.Tool) []*genai.Tool { @@ -807,6 +835,7 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { if part.FileData != nil { partFound++ p = ai.NewMediaPart(part.FileData.MIMEType, part.FileData.FileURI) + } if part.FunctionCall != nil { partFound++ @@ -814,6 +843,14 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { Name: part.FunctionCall.Name, Input: part.FunctionCall.Args, }) + // FunctionCall parts may contain a ThoughtSignature that must be preserved + // and returned in subsequent requests for the tool call to be valid. + if len(part.ThoughtSignature) > 0 { + if p.Metadata == nil { + p.Metadata = make(map[string]any) + } + p.Metadata["signature"] = part.ThoughtSignature + } } if part.CodeExecutionResult != nil { partFound++ @@ -836,6 +873,13 @@ func translateCandidate(cand *genai.Candidate) (*ai.ModelResponse, error) { continue } + if len(part.ThoughtSignature) > 0 { + if p.Metadata == nil { + p.Metadata = make(map[string]any) + } + p.Metadata["signature"] = part.ThoughtSignature + } + msg.Content = append(msg.Content, p) } m.Message = msg @@ -892,37 +936,29 @@ func toGeminiParts(parts []*ai.Part) ([]*genai.Part, error) { // toGeminiPart converts a [ai.Part] to a [genai.Part]. 
func toGeminiPart(p *ai.Part) (*genai.Part, error) { + var gp *genai.Part switch { case p.IsReasoning(): - // TODO: go-genai does not support genai.NewPartFromThought() - signature := []byte{} - if p.Metadata != nil { - if sig, ok := p.Metadata["signature"].([]byte); ok { - signature = sig - } - } - return &genai.Part{ - Thought: true, - Text: p.Text, - ThoughtSignature: signature, - }, nil + gp = genai.NewPartFromText(p.Text) + gp.Thought = true case p.IsText(): - return genai.NewPartFromText(p.Text), nil + gp = genai.NewPartFromText(p.Text) case p.IsMedia(): if strings.HasPrefix(p.Text, "data:") { contentType, data, err := uri.Data(p) if err != nil { return nil, err } - return genai.NewPartFromBytes(data, contentType), nil + gp = genai.NewPartFromBytes(data, contentType) + } else { + gp = genai.NewPartFromURI(p.Text, p.ContentType) } - return genai.NewPartFromURI(p.Text, p.ContentType), nil case p.IsData(): contentType, data, err := uri.Data(p) if err != nil { return nil, err } - return genai.NewPartFromBytes(data, contentType), nil + gp = genai.NewPartFromBytes(data, contentType) case p.IsToolResponse(): toolResp := p.ToolResponse var output map[string]any @@ -934,8 +970,21 @@ func toGeminiPart(p *ai.Part) (*genai.Part, error) { "content": toolResp.Output, } } - fr := genai.NewPartFromFunctionResponse(toolResp.Name, output) - return fr, nil + var isMultipart bool + if multiPart, ok := p.Metadata["multipart"].(bool); ok { + isMultipart = multiPart + } + if len(toolResp.Content) > 0 { + isMultipart = true + } + if isMultipart { + toolRespParts, err := toGeminiFunctionResponsePart(toolResp.Content) + if err != nil { + return nil, err + } + return genai.NewPartFromFunctionResponseWithParts(toolResp.Name, output, toolRespParts), nil + } + return genai.NewPartFromFunctionResponse(toolResp.Name, output), nil case p.IsToolRequest(): toolReq := p.ToolRequest var input map[string]any @@ -947,10 +996,24 @@ func toGeminiPart(p *ai.Part) (*genai.Part, error) { } } fc := 
genai.NewPartFromFunctionCall(toolReq.Name, input) + // Restore ThoughtSignature if present in metadata + if p.Metadata != nil { + if sig, ok := p.Metadata["signature"].([]byte); ok { + fc.ThoughtSignature = sig + } + } return fc, nil default: - panic("unknown part type in a request") + return nil, fmt.Errorf("unknown part in the request: %q", p.Kind) } + + if p.Metadata != nil { + if sig, ok := p.Metadata["signature"].([]byte); ok { + gp.ThoughtSignature = sig + } + } + + return gp, nil } // validToolName checks whether the provided tool name matches the diff --git a/go/plugins/googlegenai/gemini_test.go b/go/plugins/googlegenai/gemini_test.go index daa4da215e..9aae76a054 100644 --- a/go/plugins/googlegenai/gemini_test.go +++ b/go/plugins/googlegenai/gemini_test.go @@ -707,6 +707,82 @@ func TestValidToolName(t *testing.T) { } } +func TestToGeminiParts_MultipartToolResponse(t *testing.T) { + t.Run("ValidPartType", func(t *testing.T) { + // Create a tool response with both output and additional content (media) + toolResp := &ai.ToolResponse{ + Name: "generateImage", + Output: map[string]any{"status": "success"}, + Content: []*ai.Part{ + ai.NewMediaPart("image/png", "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="), + }, + } + + // create a mock ToolResponsePart, setting "multipart" to true is required + part := ai.NewToolResponsePart(toolResp) + part.Metadata = map[string]any{"multipart": true} + + geminiParts, err := toGeminiParts([]*ai.Part{part}) + if err != nil { + t.Fatalf("toGeminiParts failed: %v", err) + } + + // Expecting 1 part which contains the function response with internal parts + if len(geminiParts) != 1 { + t.Fatalf("expected 1 Gemini part, got %d", len(geminiParts)) + } + + if geminiParts[0].FunctionResponse == nil { + t.Error("expected first part to be FunctionResponse") + } + if geminiParts[0].FunctionResponse.Name != "generateImage" { + t.Errorf("expected function name 
'generateImage', got %q", geminiParts[0].FunctionResponse.Name) + } + }) + + t.Run("UnsupportedPartType", func(t *testing.T) { + // Create a tool response with text content (unsupported for multipart) + toolResp := &ai.ToolResponse{ + Name: "generateText", + Output: map[string]any{"status": "success"}, + Content: []*ai.Part{ + ai.NewTextPart("Generated text"), + }, + } + + part := ai.NewToolResponsePart(toolResp) + part.Metadata = map[string]any{"multipart": true} + + _, err := toGeminiParts([]*ai.Part{part}) + if err == nil { + t.Fatal("expected error for unsupported text part in multipart response, got nil") + } + }) +} + +func TestToGeminiParts_SimpleToolResponse(t *testing.T) { + // Create a simple tool response (no content) + toolResp := &ai.ToolResponse{ + Name: "search", + Output: map[string]any{"result": "foo"}, + } + + part := ai.NewToolResponsePart(toolResp) + + geminiParts, err := toGeminiParts([]*ai.Part{part}) + if err != nil { + t.Fatalf("toGeminiParts failed: %v", err) + } + + if len(geminiParts) != 1 { + t.Fatalf("expected 1 Gemini part, got %d", len(geminiParts)) + } + + if geminiParts[0].FunctionResponse == nil { + t.Error("expected part to be FunctionResponse") + } +} + // genToolName generates a string of a specified length using only // the valid characters for a Gemini Tool name func genToolName(length int, chars string) string { diff --git a/go/plugins/googlegenai/googleai_live_test.go b/go/plugins/googlegenai/googleai_live_test.go index 4e78ab4f4c..783eccd239 100644 --- a/go/plugins/googlegenai/googleai_live_test.go +++ b/go/plugins/googlegenai/googleai_live_test.go @@ -170,7 +170,7 @@ func TestGoogleAILive(t *testing.T) { t.Fatal(err) } - out := resp.Message.Content[0].Text + out := resp.Text() const want = "11.31" if !strings.Contains(out, want) { t.Errorf("got %q, expecting it to contain %q", out, want) @@ -219,7 +219,7 @@ func TestGoogleAILive(t *testing.T) { t.Fatal(err) } - out := resp.Message.Content[0].Text + out := resp.Text() const 
want = "11.31" if !strings.Contains(out, want) { t.Errorf("got %q, expecting it to contain %q", out, want) @@ -307,7 +307,7 @@ func TestGoogleAILive(t *testing.T) { t.Fatal(err) } - out := resp.Message.Content[0].Text + out := resp.Text() const doNotWant = "11.31" if strings.Contains(out, doNotWant) { t.Errorf("got %q, expecting it NOT to contain %q", out, doNotWant) @@ -582,6 +582,37 @@ func TestGoogleAILive(t *testing.T) { t.Fatal("thoughts tokens should be zero") } }) + t.Run("multipart tool", func(t *testing.T) { + m := googlegenai.GoogleAIModel(g, "gemini-3-pro-preview") + img64, err := fetchImgAsBase64() + if err != nil { + t.Fatal(err) + } + + tool := genkit.DefineMultipartTool(g, "getImage", "returns a misterious image", + func(ctx *ai.ToolContext, input any) (*ai.MultipartToolResponse, error) { + return &ai.MultipartToolResponse{ + Output: map[string]any{"status": "success"}, + Content: []*ai.Part{ + ai.NewMediaPart("image/jpeg", "data:image/jpeg;base64,"+img64), + }, + }, nil + }, + ) + + resp, err := genkit.Generate(ctx, g, + ai.WithModel(m), + ai.WithTools(tool), + ai.WithPrompt("get an image and tell me what is in it"), + ) + if err != nil { + t.Fatal(err) + } + + if !strings.Contains(strings.ToLower(resp.Text()), "cat") { + t.Errorf("expected response to contain 'cat', got: %s", resp.Text()) + } + }) } func TestCacheHelper(t *testing.T) { diff --git a/go/plugins/googlegenai/googlegenai.go b/go/plugins/googlegenai/googlegenai.go index d056e6fb1c..8ddbdfbe4a 100644 --- a/go/plugins/googlegenai/googlegenai.go +++ b/go/plugins/googlegenai/googlegenai.go @@ -283,14 +283,19 @@ func (v *VertexAI) IsDefinedEmbedder(g *genkit.Genkit, name string) bool { return genkit.LookupEmbedder(g, api.NewName(vertexAIProvider, name)) != nil } -// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given name and configuration. 
-func GoogleAIModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { +// ModelRef creates a new ModelRef for a Google Gen AI model with the given name and configuration. +func ModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { return ai.NewModelRef(googleAIProvider+"/"+name, config) } -// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given name and configuration. -func VertexAIModelRef(name string, config *genai.GenerateContentConfig) ai.ModelRef { - return ai.NewModelRef(vertexAIProvider+"/"+name, config) +// GoogleAIModelRef creates a new ModelRef for a Google AI model with the given ID and configuration. +func GoogleAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(googleAIProvider+"/"+id, config) +} + +// VertexAIModelRef creates a new ModelRef for a Vertex AI model with the given ID and configuration. +func VertexAIModelRef(id string, config *genai.GenerateContentConfig) ai.ModelRef { + return ai.NewModelRef(vertexAIProvider+"/"+id, config) } // GoogleAIModel returns the [ai.Model] with the given name. diff --git a/go/samples/basic-gemini-with-context/main.go b/go/samples/basic-gemini-with-context/main.go index f971ecc9bc..e69de29bb2 100644 --- a/go/samples/basic-gemini-with-context/main.go +++ b/go/samples/basic-gemini-with-context/main.go @@ -1,54 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package main - -import ( - "context" - "fmt" - - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" - "google.golang.org/genai" -) - -func main() { - ctx := context.Background() - - // Initialize Genkit with the Google AI plugin. When you pass nil for the - // Config parameter, the Google AI plugin will get the API key from the - // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended - // practice. - g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - - // Define a simple flow that generates jokes about a given topic with a context of bananas - genkit.DefineFlow(g, "contextFlow", func(ctx context.Context, input string) (string, error) { - resp, err := genkit.Generate(ctx, g, - ai.WithModelName("googleai/gemini-2.5-flash"), - ai.WithConfig(&genai.GenerateContentConfig{ - Temperature: genai.Ptr[float32](1.0), - }), - ai.WithPrompt(fmt.Sprintf(`Tell silly short jokes about %s`, input)), - ai.WithDocs(ai.DocumentFromText("Bananas are plentiful in the tropics.", nil))) - if err != nil { - return "", err - } - - text := resp.Text() - return text, nil - }) - - <-ctx.Done() -} diff --git a/go/samples/basic-gemini/main.go b/go/samples/basic-gemini/main.go index e61ec9df42..c9cb04bdc3 100644 --- a/go/samples/basic-gemini/main.go +++ b/go/samples/basic-gemini/main.go @@ -26,32 +26,37 @@ import ( func main() { ctx := context.Background() - // Initialize Genkit with the Google AI plugin. When you pass nil for the - // Config parameter, the Google AI plugin will get the API key from the - // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended - // practice. 
g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) - // Define a simple flow that generates jokes about a given topic - genkit.DefineStreamingFlow(g, "jokesFlow", func(ctx context.Context, input string, cb ai.ModelStreamCallback) (string, error) { - type Joke struct { - Joke string `json:"joke"` - Category string `json:"jokeCategory" description:"What is the joke about"` - } - - genkit.DefineSchemaFor[Joke](g) + // Define a multipart tool. + // This simulates a tool that takes a screenshot + screenshot := genkit.DefineMultipartTool(g, "screenshot", "Takes a screenshot", + func(ctx *ai.ToolContext, input any) (*ai.MultipartToolResponse, error) { + rectangle := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAHIAAABUAQMAAABk5vEVAAAABlBMVEX///8AAABVwtN+" + + "AAAAI0lEQVR4nGNgGHaA/z8UHIDwOWASDqP8Uf7w56On/1FAQwAAVM0exw1hqwkAAAAASUVORK5CYII=" + return &ai.MultipartToolResponse{ + Output: map[string]any{"success": true}, + Content: []*ai.Part{ + ai.NewMediaPart("image/png", rectangle), + }, + }, nil + }, + ) + // Define a simple flow that uses the multipart tool + genkit.DefineStreamingFlow(g, "cardFlow", func(ctx context.Context, input any, cb ai.ModelStreamCallback) (string, error) { resp, err := genkit.Generate(ctx, g, - ai.WithModelName("googleai/gemini-2.5-flash"), + ai.WithModelName("googleai/gemini-3-pro-preview"), ai.WithConfig(&genai.GenerateContentConfig{ Temperature: genai.Ptr[float32](1.0), ThinkingConfig: &genai.ThinkingConfig{ - ThinkingBudget: genai.Ptr[int32](0), + ThinkingLevel: genai.ThinkingLevelHigh, }, }), + ai.WithTools(screenshot), ai.WithStreaming(cb), - ai.WithOutputSchemaName("Joke"), - ai.WithPrompt(`Tell short jokes about %s`, input)) + ai.WithPrompt("Tell me what I'm seeing in the screen"), + ) if err != nil { return "", err } diff --git a/go/samples/basic-prompts/main.go b/go/samples/basic-prompts/main.go new file mode 100644 index 0000000000..ccbc308a13 --- /dev/null +++ b/go/samples/basic-prompts/main.go @@ -0,0 +1,287 @@ +// 
Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/server" + "google.golang.org/genai" +) + +type JokeRequest struct { + Topic string `json:"topic" jsonschema:"default=airplane food"` +} + +// Note how the fields are annotated with jsonschema tags to describe the output schema. +// This is vital for the model to understand the intent of the fields. +type Joke struct { + Joke string `json:"joke" jsonschema:"description=The joke text"` + Category string `json:"category" jsonschema:"description=The joke category"` +} + +type RecipeRequest struct { + Dish string `json:"dish" jsonschema:"default=pasta"` + Cuisine string `json:"cuisine" jsonschema:"default=Italian"` + ServingSize int `json:"servingSize" jsonschema:"default=4"` + MaxPrepMinutes int `json:"maxPrepMinutes" jsonschema:"default=30"` + DietaryRestrictions []string `json:"dietaryRestrictions,omitempty"` +} + +type Ingredient struct { + Name string `json:"name" jsonschema:"description=The ingredient name"` + Amount string `json:"amount" jsonschema:"description=The ingredient amount (e.g. 
1 cup, 2 tablespoons, etc.)"` + Optional bool `json:"optional,omitempty" jsonschema:"description=Whether the ingredient is optional in the recipe"` +} + +type Recipe struct { + Title string `json:"title" jsonschema:"description=The recipe title (e.g. 'Spicy Chicken Tacos')"` + Description string `json:"description,omitempty" jsonschema:"description=The recipe description (under 100 characters)"` + Ingredients []*Ingredient `json:"ingredients" jsonschema:"description=The recipe ingredients (group by type and order by importance)"` + Instructions []string `json:"instructions" jsonschema:"description=The recipe instructions (step by step)"` + PrepTime string `json:"prepTime" jsonschema:"description=The recipe preparation time (e.g. 10 minutes, 30 minutes, etc.)"` + Difficulty string `json:"difficulty" jsonschema:"enum=easy,enum=medium,enum=hard"` +} + +func main() { + ctx := context.Background() + + // Initialize Genkit with the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the + // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended + // practice. + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + // Define schemas for the expected input and output types so that the Dotprompt files can reference them. + // Alternatively, you can specify the JSON schema by hand in the Dotprompt metadata. + // Code-defined prompts do not need to have schemas defined in advance but they too can reference them. + genkit.DefineSchemaFor[JokeRequest](g) + genkit.DefineSchemaFor[Joke](g) + genkit.DefineSchemaFor[RecipeRequest](g) + genkit.DefineSchemaFor[Recipe](g) + + // TODO: Include partials and helpers. + + // Define the prompts and flows. 
+ DefineSimpleJokeWithInlinePrompt(g) + DefineSimpleJokeWithDotprompt(g) + DefineStructuredJokeWithInlinePrompt(g) + DefineStructuredJokeWithDotprompt(g) + DefineRecipeWithInlinePrompt(g) + DefineRecipeWithDotprompt(g) + + // Optionally, start a web server to make the flows callable via HTTP. + mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +} + +// DefineSimpleJokeWithInlinePrompt demonstrates defining a prompt in code using DefinePrompt. +// The prompt has no output schema defined so it will always return a string. +// When executing the prompt, we pass in a map[string]any with the input fields. +func DefineSimpleJokeWithInlinePrompt(g *genkit.Genkit) { + jokePrompt := genkit.DefinePrompt( + g, "joke.code", + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + // Despite JokeRequest having defaults set in jsonschema tags, we can override it with values set in WithInputType. + ai.WithInputType(JokeRequest{Topic: "rush hour traffic"}), + ai.WithPrompt("Share a long joke about {{topic}}."), + ) + + genkit.DefineStreamingFlow(g, "simpleJokePromptFlow", + func(ctx context.Context, topic string, sendChunk core.StreamCallback[string]) (string, error) { + // One way to pass input is using a map[string]any. This is useful when there is no structured input type. + stream := jokePrompt.ExecuteStream(ctx, ai.WithInput(map[string]any{"topic": topic})) + for result, err := range stream { + if err != nil { + return "", fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Response.Text(), nil + } + sendChunk(ctx, result.Chunk.Text()) + } + + return "", nil + }, + ) +} + +// DefineSimpleJokeWithDotprompt demonstrates loading a prompt from a .prompt file using +// LoadPrompt. 
The prompt configuration (model, input schema, defaults) is defined in the +// file. Input is passed as a map since the .prompt file defines its own schema. +func DefineSimpleJokeWithDotprompt(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "simpleJokeDotpromptFlow", + func(ctx context.Context, topic string, sendChunk core.StreamCallback[string]) (string, error) { + jokePrompt := genkit.LookupPrompt(g, "joke") + // One way to pass input is using a map[string]any. This is useful when there is no structured input type. + stream := jokePrompt.ExecuteStream(ctx, ai.WithInput(map[string]any{"topic": topic})) + for result, err := range stream { + if err != nil { + return "", fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Response.Text(), nil + } + sendChunk(ctx, result.Chunk.Text()) + } + + return "", nil + }, + ) +} + +// DefineStructuredJokeWithInlinePrompt demonstrates DefineDataPrompt for strongly-typed +// input and output. The type parameters automatically configure input/output schemas +// and JSON output format. ExecuteStream returns typed chunks and final output. 
+func DefineStructuredJokeWithInlinePrompt(g *genkit.Genkit) { + jokePrompt := genkit.DefineDataPrompt[JokeRequest, *Joke]( + g, "structured-joke.code", + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a long joke about {{topic}}."), + ) + + genkit.DefineStreamingFlow(g, "structuredJokePromptFlow", + func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { + for result, err := range jokePrompt.ExecuteStream(ctx, input) { + if err != nil { + return nil, fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Output, nil + } + sendChunk(ctx, result.Chunk) + } + + return nil, nil + }, + ) +} + +// DefineStructuredJokeWithDotprompt demonstrates LookupDataPrompt to wrap a .prompt file +// with Go type information. The .prompt file references registered schemas by name +// (e.g., "schema: Joke"), which must be defined via DefineSchemaFor before loading. +func DefineStructuredJokeWithDotprompt(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "structuredJokeDotpromptFlow", + func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { + jokePrompt := genkit.LookupDataPrompt[JokeRequest, *Joke](g, "structured-joke") + stream := jokePrompt.ExecuteStream(ctx, input) + for result, err := range stream { + if err != nil { + return nil, fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Output, nil + } + sendChunk(ctx, result.Chunk) + } + return nil, nil + }, + ) +} + +// DefineRecipeWithInlinePrompt demonstrates DefineDataPrompt with complex nested types +// and Handlebars conditionals/loops in the prompt template. The streaming flow applies +// default values before execution and streams partial ingredients as they arrive. 
+func DefineRecipeWithInlinePrompt(g *genkit.Genkit) { + recipePrompt := genkit.DefineDataPrompt[RecipeRequest, *Recipe]( + g, "recipe.code", + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are an experienced chef. Come up with easy, creative recipes."), + ai.WithPrompt("Create a {{cuisine}} {{dish}} recipe for {{servingSize}} people that takes under {{maxPrepMinutes}} minutes to prepare. "+ + "{{#if dietaryRestrictions}}Dietary restrictions: {{#each dietaryRestrictions}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}.{{/if}}"), + ) + + genkit.DefineStreamingFlow(g, "recipePromptFlow", + func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[*Ingredient]) (*Recipe, error) { + // This is not necessary for this example but it shows how to easily have more control over what you stream. + filterNew := newIngredientFilter() + for result, err := range recipePrompt.ExecuteStream(ctx, input) { + if err != nil { + return nil, fmt.Errorf("could not generate recipe: %w", err) + } + if result.Done { + return result.Output, nil + } + for _, i := range filterNew(result.Chunk.Ingredients) { + sendChunk(ctx, i) + } + } + return nil, nil + }, + ) +} + +// DefineRecipeWithDotprompt demonstrates LookupDataPrompt with a .prompt file that uses +// multi-message format (system/user roles) and references registered schemas. +// Streams partial ingredients as they arrive via ExecuteStream. +func DefineRecipeWithDotprompt(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "recipeDotpromptFlow", + func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[*Ingredient]) (*Recipe, error) { + // This is not necessary for this example but it shows how to easily have more control over what you stream. 
+ filterNew := newIngredientFilter() + recipePrompt := genkit.LookupDataPrompt[RecipeRequest, *Recipe](g, "recipe") + stream := recipePrompt.ExecuteStream(ctx, input) + for result, err := range stream { + if err != nil { + return nil, fmt.Errorf("could not generate recipe: %w", err) + } + if result.Done { + return result.Output, nil + } + for _, i := range filterNew(result.Chunk.Ingredients) { + sendChunk(ctx, i) + } + } + return nil, nil + }, + ) +} + +// newIngredientFilter is a helper function to filter out duplicate ingredients. +// This allows us to stream only new ingredients as they are identified, avoiding duplicates. +func newIngredientFilter() func([]*Ingredient) []*Ingredient { + seen := map[string]struct{}{} + return func(ings []*Ingredient) (newIngs []*Ingredient) { + for _, ing := range ings { + if _, ok := seen[ing.Name]; !ok { + seen[ing.Name] = struct{}{} + newIngs = append(newIngs, ing) + } + } + return + } +} diff --git a/go/samples/basic-prompts/prompts/joke.prompt b/go/samples/basic-prompts/prompts/joke.prompt new file mode 100644 index 0000000000..fc1add0957 --- /dev/null +++ b/go/samples/basic-prompts/prompts/joke.prompt @@ -0,0 +1,13 @@ +--- +model: googleai/gemini-2.5-flash +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: + topic?: string + default: + topic: airplane food +--- +Share a long joke about {{topic}}. + diff --git a/go/samples/basic-prompts/prompts/recipe.prompt b/go/samples/basic-prompts/prompts/recipe.prompt new file mode 100644 index 0000000000..d132ba615e --- /dev/null +++ b/go/samples/basic-prompts/prompts/recipe.prompt @@ -0,0 +1,20 @@ +--- +model: googleai/gemini-2.5-flash +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: RecipeRequest +output: + format: json + schema: Recipe +--- +{{role "system"}} +You are an experienced chef. Come up with easy, creative recipes. 
+ +{{role "user"}} +Create a {{cuisine}} {{dish}} recipe for {{servingSize}} people that takes under {{maxPrepMinutes}} minutes to prepare. +{{#if dietaryRestrictions}} +Dietary restrictions: {{#each dietaryRestrictions}}{{this}}{{#unless @last}}, {{/unless}}{{/each}}. +{{/if}} + diff --git a/go/samples/basic-prompts/prompts/structured-joke.prompt b/go/samples/basic-prompts/prompts/structured-joke.prompt new file mode 100644 index 0000000000..7184b15483 --- /dev/null +++ b/go/samples/basic-prompts/prompts/structured-joke.prompt @@ -0,0 +1,13 @@ +--- +model: googleai/gemini-2.5-flash +config: + thinkingConfig: + thinkingBudget: 0 +input: + schema: JokeRequest +output: + format: json + schema: Joke +--- +Share a long joke about {{topic}}. + diff --git a/go/samples/basic-structured/main.go b/go/samples/basic-structured/main.go new file mode 100644 index 0000000000..428636de4d --- /dev/null +++ b/go/samples/basic-structured/main.go @@ -0,0 +1,181 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/server" + "google.golang.org/genai" +) + +type JokeRequest struct { + Topic string `json:"topic" jsonschema:"default=airplane food"` +} + +// Note how the fields are annotated with jsonschema tags to describe the output schema. +// This is vital for the model to understand the intent of the fields. +type Joke struct { + Joke string `json:"joke" jsonschema:"description=The joke text"` + Category string `json:"category" jsonschema:"description=The joke category"` +} + +type RecipeRequest struct { + Dish string `json:"dish" jsonschema:"default=pasta"` + Cuisine string `json:"cuisine" jsonschema:"default=Italian"` + ServingSize int `json:"servingSize" jsonschema:"default=4"` + MaxPrepMinutes int `json:"maxPrepMinutes" jsonschema:"default=30"` + DietaryRestrictions []string `json:"dietaryRestrictions,omitempty"` +} + +type Ingredient struct { + Name string `json:"name" jsonschema:"description=The ingredient name"` + Amount string `json:"amount" jsonschema:"description=The ingredient amount (e.g. 1 cup, 2 tablespoons, etc.)"` + Optional bool `json:"optional,omitempty" jsonschema:"description=Whether the ingredient is optional in the recipe"` +} + +type Recipe struct { + Title string `json:"title" jsonschema:"description=The recipe title (e.g. 
'Spicy Chicken Tacos')"` + Description string `json:"description,omitempty" jsonschema:"description=The recipe description (under 100 characters)"` + Ingredients []*Ingredient `json:"ingredients" jsonschema:"description=The recipe ingredients (order by type first and then importance)"` + Instructions []string `json:"instructions" jsonschema:"description=The recipe instructions (step by step)"` + PrepTime string `json:"prepTime" jsonschema:"description=The recipe preparation time (e.g. 10 minutes, 30 minutes, etc.)"` + Difficulty string `json:"difficulty" jsonschema:"enum=easy,enum=medium,enum=hard"` +} + +func main() { + ctx := context.Background() + + // Initialize Genkit with the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the + // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended + // practice. + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + // Define the flows. + DefineSimpleJoke(g) + DefineStructuredJoke(g) + DefineRecipe(g) + + // Optionally, start a web server to make the flows callable via HTTP. + mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +} + +// DefineSimpleJoke demonstrates defining a streaming flow that generates a joke about a given topic. 
+func DefineSimpleJoke(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "simpleJokesFlow", + func(ctx context.Context, input string, sendChunk core.StreamCallback[string]) (string, error) { + stream := genkit.GenerateStream(ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a long joke about %s.", input), + ) + + for result, err := range stream { + if err != nil { + return "", fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Response.Text(), nil + } + sendChunk(ctx, result.Chunk.Text()) + } + + return "", nil + }, + ) +} + +// DefineStructuredJoke demonstrates defining a streaming flow that generates a joke about a given topic. +// The input is a strongly-typed JokeRequest struct and the output is a strongly-typed Joke struct. +func DefineStructuredJoke(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "structuredJokesFlow", + func(ctx context.Context, input JokeRequest, sendChunk core.StreamCallback[*Joke]) (*Joke, error) { + stream := genkit.GenerateDataStream[*Joke](ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a long joke about %s.", input.Topic), + ) + + for result, err := range stream { + if err != nil { + return nil, fmt.Errorf("could not generate joke: %w", err) + } + if result.Done { + return result.Output, nil + } + sendChunk(ctx, result.Chunk) + } + + return nil, nil + }) +} + +// DefineRecipe demonstrates defining a streaming flow that generates a recipe based on a given RecipeRequest struct. +// The input is a strongly-typed RecipeRequest struct and the output is a strongly-typed Recipe struct. 
+func DefineRecipe(g *genkit.Genkit) { + genkit.DefineStreamingFlow(g, "recipeFlow", + func(ctx context.Context, input RecipeRequest, sendChunk core.StreamCallback[[]*Ingredient]) (*Recipe, error) { + stream := genkit.GenerateDataStream[*Recipe](ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are an experienced chef. Come up with easy, creative recipes."), + // Here we are passing WithPromptFn() since our prompt takes some string manipulation to build. + // Alternatively, we could pass WithPrompt() with the complete prompt string. + ai.WithPromptFn(func(ctx context.Context, _ any) (string, error) { + prompt := fmt.Sprintf( + "Create a %s %s recipe for %d people that takes under %d minutes to prepare.", + input.Cuisine, input.Dish, input.ServingSize, input.MaxPrepMinutes, + ) + if len(input.DietaryRestrictions) > 0 { + prompt += fmt.Sprintf(" Dietary restrictions: %v.", input.DietaryRestrictions) + } + return prompt, nil + }), + ) + + for result, err := range stream { + if err != nil { + return nil, fmt.Errorf("could not generate recipe: %w", err) + } + if result.Done { + return result.Output, nil + } + sendChunk(ctx, result.Chunk.Ingredients) + } + + return nil, nil + }) +} diff --git a/go/samples/basic/main.go b/go/samples/basic/main.go new file mode 100644 index 0000000000..2031340ac5 --- /dev/null +++ b/go/samples/basic/main.go @@ -0,0 +1,86 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/server" + "google.golang.org/genai" +) + +func main() { + ctx := context.Background() + + // Initialize Genkit with the Google AI plugin. When you pass nil for the + // Config parameter, the Google AI plugin will get the API key from the + // GEMINI_API_KEY or GOOGLE_API_KEY environment variable, which is the recommended + // practice. + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + // Define a non-streaming flow that generates jokes about a given topic. + genkit.DefineFlow(g, "jokesFlow", func(ctx context.Context, input string) (string, error) { + if input == "" { + input = "airplane food" + } + + return genkit.GenerateText(ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a joke about %s.", input), + ) + }, + ) + + // Define a streaming flow that generates jokes about a given topic with passthrough streaming. 
+ genkit.DefineStreamingFlow(g, "streamingJokesFlow", + func(ctx context.Context, input string, sendChunk ai.ModelStreamCallback) (string, error) { + if input == "" { + input = "airplane food" + } + + resp, err := genkit.Generate(ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithPrompt("Share a joke about %s.", input), + ai.WithStreaming(sendChunk), + ) + if err != nil { + return "", fmt.Errorf("could not generate joke: %w", err) + } + + return resp.Text(), nil + }, + ) + + // Optionally, start a web server to make the flow callable via HTTP. + mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +} diff --git a/go/samples/durable-streaming-firestore/README.md b/go/samples/durable-streaming-firestore/README.md new file mode 100644 index 0000000000..8ca1648e8b --- /dev/null +++ b/go/samples/durable-streaming-firestore/README.md @@ -0,0 +1,140 @@ +# Durable Streaming with Firestore + +This sample demonstrates durable streaming using Firestore as the backend. Unlike in-memory streaming, Firestore-backed streams: + +- **Survive server restarts** - Clients can reconnect to streams after server restarts +- **Work across instances** - Multiple server instances can serve the same stream +- **Auto-cleanup** - Completed streams are automatically deleted via Firestore TTL policies + +## Prerequisites + +1. **Firebase Project**: You need a Firebase/GCP project with Firestore enabled. + +2. **Authentication**: Authenticate with your Google Cloud project: + ```bash + gcloud auth application-default login + ``` + +3. **(Recommended) TTL Policy**: Configure a TTL policy on your Firestore collection for automatic cleanup of old streams. 
This requires setting a TTL on the `expiresAt` field: + + ```bash + gcloud firestore fields ttls update expiresAt \ + --collection-group=genkit-streams \ + --enable-ttl \ + --project=YOUR_PROJECT_ID + ``` + + See: https://firebase.google.com/docs/firestore/ttl + +## Environment Variables + +| Variable | Required | Default | Description | +|----------|----------|---------|-------------| +| `FIREBASE_PROJECT_ID` | Yes | - | Your Firebase/GCP project ID | +| `FIRESTORE_STREAMS_COLLECTION` | No | `genkit-streams` | Firestore collection for stream documents | + +## Running the Sample + +1. Set your project ID: + ```bash + export FIREBASE_PROJECT_ID=your-project-id + ``` + +2. Start the server: + ```bash + go run . + ``` + +## Testing + +### Start a streaming request + +```bash +curl -N -i -H "Accept: text/event-stream" \ + -d '{"data": 5}' \ + http://localhost:8080/countdown +``` + +Note the `X-Genkit-Stream-Id` header in the response - you'll need this to reconnect. + +### Reconnect to an existing stream + +Use the stream ID from the previous response: + +```bash +curl -N -H "Accept: text/event-stream" \ + -H "X-Genkit-Stream-Id: " \ + -d '{"data": 5}' \ + http://localhost:8080/countdown +``` + +The subscription will: +- Replay any buffered chunks that were already sent +- Continue with live updates if the stream is still in progress +- Return all chunks plus the final result if the stream has already completed + +### Test server restart resilience + +1. Start a countdown with a high number: + ```bash + curl -N -i -H "Accept: text/event-stream" -d '{"data": 30}' http://localhost:8080/countdown + ``` + +2. Copy the `X-Genkit-Stream-Id` header value + +3. Stop the server (Ctrl+C) + +4. Restart the server: `go run .` + +5. 
Reconnect using the stream ID: + ```bash + curl -N -H "Accept: text/event-stream" -H "X-Genkit-Stream-Id: " -d '{"data": 30}' http://localhost:8080/countdown + ``` + +You'll receive all previously buffered chunks, demonstrating that the stream state persisted across the server restart. + +## Configuration Options + +The `FirestoreStreamManager` supports these options: + +| Option | Default | Description | +|--------|---------|-------------| +| `WithCollection(name)` | (required) | Firestore collection for stream documents | +| `WithTimeout(duration)` | 60s | How long subscribers wait for new events before timeout | +| `WithTTL(duration)` | 5m | How long completed streams are retained before auto-deletion | + +Example: +```go +streamManager, err := firebasex.NewFirestoreStreamManager(ctx, g, + firebasex.WithCollection("my-streams"), + firebasex.WithTimeout(2*time.Minute), + firebasex.WithTTL(1*time.Hour), +) +``` + +## How It Works + +1. When a streaming request arrives, a Firestore document is created with the stream ID +2. As the flow produces chunks, they're appended to the document's `stream` array +3. Subscribers use Firestore's real-time listeners to receive updates +4. When the flow completes, a final "done" entry is added with the output +5. The `expiresAt` field is set based on TTL, and Firestore automatically deletes the document + +## License + +``` +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+``` + diff --git a/go/samples/durable-streaming-firestore/main.go b/go/samples/durable-streaming-firestore/main.go new file mode 100644 index 0000000000..988ccda790 --- /dev/null +++ b/go/samples/durable-streaming-firestore/main.go @@ -0,0 +1,89 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// This sample demonstrates durable streaming with Firestore backend. +// Unlike in-memory streaming, Firestore-backed streams survive server restarts +// and can be accessed across multiple server instances. +// +// See README.md for setup instructions. 
+package main + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/firebase" + firebasex "github.com/firebase/genkit/go/plugins/firebase/x" + "github.com/firebase/genkit/go/plugins/server" +) + +func main() { + ctx := context.Background() + + g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{})) + + type CountdownChunk struct { + Count int `json:"count"` + Message string `json:"message"` + Timestamp string `json:"timestamp"` + } + + countdown := genkit.DefineStreamingFlow(g, "countdown", + func(ctx context.Context, count int, sendChunk func(context.Context, CountdownChunk) error) (string, error) { + if count <= 0 { + count = 5 + } + + for i := count; i > 0; i-- { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1 * time.Second): + } + + chunk := CountdownChunk{ + Count: i, + Message: fmt.Sprintf("T-%d...", i), + Timestamp: time.Now().Format(time.RFC3339), + } + + if err := sendChunk(ctx, chunk); err != nil { + return "", err + } + } + + return "Liftoff!", nil + }) + + sm, err := firebasex.NewFirestoreStreamManager(ctx, g, + firebasex.WithCollection("genkit-streams"), + firebasex.WithTimeout(2*time.Minute), + firebasex.WithTTL(10*time.Minute), + ) + if err != nil { + log.Fatalf("Failed to create Firestore stream manager: %v", err) + } + + // Set up HTTP server with durable streaming enabled. + // Completed streams are kept for 10 minutes before cleanup. 
+ mux := http.NewServeMux() + mux.HandleFunc("POST /countdown", genkit.Handler(countdown, genkit.WithStreamManager(sm))) + log.Fatal(server.Start(ctx, "127.0.0.1:8088", mux)) +} diff --git a/go/samples/durable-streaming/main.go b/go/samples/durable-streaming/main.go new file mode 100644 index 0000000000..36323990a3 --- /dev/null +++ b/go/samples/durable-streaming/main.go @@ -0,0 +1,99 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// This sample demonstrates durable streaming, which allows clients to reconnect +// to in-progress or completed streams using a stream ID. +// +// Start the server: +// +// go run . +// +// Test streaming (get a stream ID back in X-Genkit-Stream-Id header): +// +// curl -N -i -H "Accept: text/event-stream" \ +// -d '{"data": 5}' \ +// http://localhost:8080/countdown +// +// Subscribe to an existing stream using the stream ID from the previous response: +// +// curl -N -H "Accept: text/event-stream" \ +// -H "X-Genkit-Stream-Id: " \ +// -d '{"data": 5}' \ +// http://localhost:8080/countdown +// +// The subscription will replay any buffered chunks and then continue with live updates. +// If the stream has already completed, all chunks plus the final result are returned. 
+ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "time" + + "github.com/firebase/genkit/go/core/x/streaming" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/server" +) + +func main() { + ctx := context.Background() + g := genkit.Init(ctx) + + type CountdownChunk struct { + Count int `json:"count"` + Message string `json:"message"` + Timestamp string `json:"timestamp"` + } + + // Define a streaming flow that counts down with delays. + countdown := genkit.DefineStreamingFlow(g, "countdown", + func(ctx context.Context, count int, sendChunk func(context.Context, CountdownChunk) error) (string, error) { + if count <= 0 { + count = 5 + } + + for i := count; i > 0; i-- { + select { + case <-ctx.Done(): + return "", ctx.Err() + case <-time.After(1 * time.Second): + } + + chunk := CountdownChunk{ + Count: i, + Message: fmt.Sprintf("T-%d...", i), + Timestamp: time.Now().Format(time.RFC3339), + } + + if err := sendChunk(ctx, chunk); err != nil { + return "", err + } + } + + return "Liftoff!", nil + }) + + // Set up HTTP server with durable streaming enabled. + // Completed streams are kept for 10 minutes before cleanup (while server is running). 
+ mux := http.NewServeMux() + mux.HandleFunc("POST /countdown", genkit.Handler(countdown, + genkit.WithStreamManager(streaming.NewInMemoryStreamManager(streaming.WithTTL(10*time.Minute))), + )) + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +} diff --git a/go/samples/prompts-dir/main.go b/go/samples/prompts-dir/main.go index 59e5e83843..e69de29bb2 100644 --- a/go/samples/prompts-dir/main.go +++ b/go/samples/prompts-dir/main.go @@ -1,61 +0,0 @@ -// Copyright 2025 Google LLC -// SPDX-License-Identifier: Apache-2.0 - -// [START main] -package main - -import ( - "context" - "errors" - - // Import Genkit and the Google AI plugin - "github.com/firebase/genkit/go/ai" - "github.com/firebase/genkit/go/genkit" - "github.com/firebase/genkit/go/plugins/googlegenai" -) - -func main() { - ctx := context.Background() - - g := genkit.Init(ctx, - genkit.WithPlugins(&googlegenai.GoogleAI{}), - genkit.WithPromptDir("prompts"), - ) - - type greetingStyle struct { - Style string `json:"style"` - Location string `json:"location"` - Name string `json:"name"` - } - - type greeting struct { - Greeting string `json:"greeting"` - } - - // Define a simple flow that prompts an LLM to generate greetings using a - // given style. 
- genkit.DefineFlow(g, "assistantGreetingFlow", func(ctx context.Context, input greetingStyle) (string, error) { - // Look up the prompt by name - prompt := genkit.LookupPrompt(g, "example") - if prompt == nil { - return "", errors.New("assistantGreetingFlow: failed to find prompt") - } - - // Execute the prompt with the provided input - resp, err := prompt.Execute(ctx, ai.WithInput(input)) - if err != nil { - return "", err - } - - var output greeting - if err = resp.Output(&output); err != nil { - return "", err - } - - return output.Greeting, nil - }) - - <-ctx.Done() -} - -// [END main] diff --git a/go/samples/prompts-dir/prompts/example.prompt b/go/samples/prompts-dir/prompts/example.prompt index 0492cfd326..e69de29bb2 100644 --- a/go/samples/prompts-dir/prompts/example.prompt +++ b/go/samples/prompts-dir/prompts/example.prompt @@ -1,19 +0,0 @@ ---- -model: googleai/gemini-2.5-flash -config: - temperature: 0.9 -input: - schema: - location: string - style?: string - name?: string - default: - name: Rutuja -output: - schema: - greeting: string ---- - -You are the world's most welcoming AI assistant and are currently working at {{location}}. - -Greet a guest{{#if name}} named {{name}}{{/if}}{{#if style}} in the style of {{style}}{{/if}}. diff --git a/go/samples/prompts-embed/main.go b/go/samples/prompts-embed/main.go new file mode 100644 index 0000000000..f0f7a5bde3 --- /dev/null +++ b/go/samples/prompts-embed/main.go @@ -0,0 +1,60 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates how to use embedded prompts with genkit. +// Prompts are embedded directly into the binary using Go's embed package, +// which allows you to ship a self-contained binary without needing to +// distribute prompt files separately. + +package main + +import ( + "context" + "embed" + "errors" + + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" +) + +// Embed the prompts directory into the binary. +// The //go:embed directive makes the prompts available at compile time. +// +//go:embed prompts/* +var promptsFS embed.FS + +func main() { + ctx := context.Background() + + g := genkit.Init(ctx, + genkit.WithPlugins(&googlegenai.GoogleAI{}), + genkit.WithPromptFS(promptsFS), + ) + + genkit.DefineFlow(g, "sayHello", func(ctx context.Context, name string) (string, error) { + prompt := genkit.LookupPrompt(g, "example") + if prompt == nil { + return "", errors.New("prompt not found") + } + + resp, err := prompt.Execute(ctx) + if err != nil { + return "", err + } + + return resp.Text(), nil + }) + + <-ctx.Done() +} diff --git a/go/samples/prompts-embed/prompts/example.prompt b/go/samples/prompts-embed/prompts/example.prompt new file mode 100644 index 0000000000..abdd8bdee5 --- /dev/null +++ b/go/samples/prompts-embed/prompts/example.prompt @@ -0,0 +1,5 @@ +--- +model: googleai/gemini-2.5-flash +--- + +Say hello!. 
diff --git a/go/samples/prompts/main.go b/go/samples/prompts/main.go index fe9d3bcd94..ed0af5c795 100644 --- a/go/samples/prompts/main.go +++ b/go/samples/prompts/main.go @@ -209,7 +209,10 @@ func PromptWithMultiMessage(ctx context.Context, g *genkit.Genkit) { } resp, err := prompt.Execute(ctx, ai.WithModelName("googleai/gemini-2.5-pro"), - ai.WithInput(map[string]any{"videoUrl": "https://www.youtube.com/watch?v=K-hY0E6cGfo video/mp4"}), + ai.WithInput(map[string]any{ + "videoUrl": "https://www.youtube.com/watch?v=K-hY0E6cGfo", + "contentType": "video/mp4", + }), ) if err != nil { log.Fatal(err) diff --git a/go/samples/prompts/prompts/multi-msg.prompt b/go/samples/prompts/prompts/multi-msg.prompt index 63b177a884..b86981b6e6 100644 --- a/go/samples/prompts/prompts/multi-msg.prompt +++ b/go/samples/prompts/prompts/multi-msg.prompt @@ -3,6 +3,7 @@ model: googleai/gemini-2.5-flash input: schema: videoUrl: string + contentType: string output: summary: string --- @@ -13,4 +14,4 @@ You are a great AI assistant that summarizes videos talking as a pirate {{ role "user" }} Give me a summary of this video -{{media url=videoUrl}} +{{media url=videoUrl contentType=contentType}} diff --git a/js/plugins/anthropic/package.json b/js/plugins/anthropic/package.json index 5b3e19be1f..7bafa4b182 100644 --- a/js/plugins/anthropic/package.json +++ b/js/plugins/anthropic/package.json @@ -29,7 +29,7 @@ "genkit": "workspace:^" }, "dependencies": { - "@anthropic-ai/sdk": "^0.68.0" + "@anthropic-ai/sdk": "^0.71.2" }, "devDependencies": { "@types/node": "^20.11.16", @@ -64,6 +64,7 @@ "build": "npm-run-all build:clean check compile", "build:watch": "tsup-node --watch", "test": "tsx --test tests/*_test.ts", + "test:live": "tsx --test tests/live_test.ts", "test:file": "tsx --test", "test:live": "tsx --test tests/live_test.ts", "test:coverage": "check-node-version --node '>=22' && tsx --test --experimental-test-coverage --test-coverage-include='src/**/*.ts' ./tests/**/*_test.ts" diff --git 
a/js/plugins/anthropic/src/models.ts b/js/plugins/anthropic/src/models.ts index 98767af40c..ba03c15910 100644 --- a/js/plugins/anthropic/src/models.ts +++ b/js/plugins/anthropic/src/models.ts @@ -91,19 +91,66 @@ export const KNOWN_CLAUDE_MODELS: Record< 'claude-opus-4': commonRef('claude-opus-4', AnthropicThinkingConfigSchema), 'claude-sonnet-4-5': commonRef( 'claude-sonnet-4-5', - AnthropicThinkingConfigSchema + AnthropicThinkingConfigSchema, + { + supports: { + multiturn: true, + tools: true, + media: true, + systemRole: true, + output: ['text', 'json'], + constrained: 'all', + }, + } ), 'claude-haiku-4-5': commonRef( 'claude-haiku-4-5', - AnthropicThinkingConfigSchema - ), - 'claude-opus-4-5': commonRef( - 'claude-opus-4-5', - AnthropicThinkingConfigSchema + AnthropicThinkingConfigSchema, + { + supports: { + multiturn: true, + tools: true, + media: true, + systemRole: true, + output: ['text', 'json'], + constrained: 'all', + }, + } ), 'claude-opus-4-1': commonRef( 'claude-opus-4-1', - AnthropicThinkingConfigSchema + AnthropicThinkingConfigSchema, + { + supports: { + multiturn: true, + tools: true, + media: true, + systemRole: true, + output: ['text', 'json'], + constrained: 'all', + }, + } + ), + 'claude-opus-4-5': commonRef( + 'claude-opus-4-5', + AnthropicThinkingConfigSchema.extend({ + output_config: z + .object({ + effort: z.enum(['low', 'medium', 'high']).optional(), + }) + .passthrough() + .optional(), + }), + { + supports: { + multiturn: true, + tools: true, + media: true, + systemRole: true, + output: ['text', 'json'], + constrained: 'all', + }, + } ), }; @@ -232,9 +279,11 @@ export function claudeModel( defaultApiVersion: apiVersion, } = params; // Use supported model ref if available, otherwise create generic model ref - const modelRef = KNOWN_CLAUDE_MODELS[name]; - const modelInfo = modelRef ? modelRef.info : GENERIC_CLAUDE_MODEL_INFO; - const configSchema = modelRef?.configSchema ?? 
AnthropicConfigSchema; + const knownModelRef = KNOWN_CLAUDE_MODELS[name]; + let modelInfo = knownModelRef + ? knownModelRef.info + : GENERIC_CLAUDE_MODEL_INFO; + const configSchema = knownModelRef?.configSchema ?? AnthropicConfigSchema; return model< AnthropicBaseConfigSchemaType | AnthropicThinkingConfigSchemaType diff --git a/js/plugins/anthropic/src/runner/beta.ts b/js/plugins/anthropic/src/runner/beta.ts index 6a71fa71d5..099a589909 100644 --- a/js/plugins/anthropic/src/runner/beta.ts +++ b/js/plugins/anthropic/src/runner/beta.ts @@ -44,6 +44,7 @@ import { logger } from 'genkit/logging'; import { KNOWN_CLAUDE_MODELS, extractVersion } from '../models.js'; import { AnthropicConfigSchema, type ClaudeRunnerParams } from '../types.js'; +import { removeUndefinedProperties } from '../utils.js'; import { BaseRunner } from './base.js'; import { RunnerTypes } from './types.js'; @@ -66,6 +67,57 @@ const BETA_UNSUPPORTED_SERVER_TOOL_BLOCK_TYPES = new Set([ 'container_upload', ]); +const BETA_APIS = [ + // 'message-batches-2024-09-24', + // 'prompt-caching-2024-07-31', + // 'computer-use-2025-01-24', + // 'pdfs-2024-09-25', + // 'token-counting-2024-11-01', + // 'token-efficient-tools-2025-02-19', + // 'output-128k-2025-02-19', + 'files-api-2025-04-14', + // 'mcp-client-2025-04-04', + // 'dev-full-thinking-2025-05-14', + // 'interleaved-thinking-2025-05-14', + // 'code-execution-2025-05-22', + // 'extended-cache-ttl-2025-04-11', + // 'context-1m-2025-08-07', + // 'context-management-2025-06-27', + // 'model-context-window-exceeded-2025-08-26', + // 'skills-2025-10-02', + 'effort-2025-11-24', + // 'advanced-tool-use-2025-11-20', + 'structured-outputs-2025-11-13', +]; + +/** + * Transforms a JSON schema to be compatible with Anthropic's structured output requirements. + * Anthropic requires `additionalProperties: false` on all object types. 
+ * @see https://docs.anthropic.com/en/docs/build-with-claude/structured-outputs#json-schema-limitations + */ +function toAnthropicSchema( + schema: Record +): Record { + const out = structuredClone(schema); + + // Remove $schema if present + delete out.$schema; + + // Add additionalProperties: false to objects + if (out.type === 'object') { + out.additionalProperties = false; + } + + // Recursively process nested objects + for (const key in out) { + if (typeof out[key] === 'object' && out[key] !== null) { + out[key] = toAnthropicSchema(out[key] as Record); + } + } + + return out; +} + const unsupportedServerToolError = (blockType: string): string => `Anthropic beta runner does not yet support server-managed tool block '${blockType}'. Please retry against the stable API or wait for dedicated support.`; @@ -140,6 +192,26 @@ export class BetaRunner extends BaseRunner { // Media if (part.media) { + if (part.media.contentType === 'anthropic/file') { + return { + type: 'document', + source: { + type: 'file', + file_id: part.media.url, + }, + }; + } + + if (part.media.contentType === 'anthropic/image') { + return { + type: 'image', + source: { + type: 'file', + file_id: part.media.url, + }, + }; + } + if (part.media.contentType === 'application/pdf') { return { type: 'document', @@ -249,45 +321,49 @@ export class BetaRunner extends BaseRunner { : system; } - const body: BetaMessageCreateParamsNonStreaming = { + const thinkingConfig = this.toAnthropicThinkingConfig( + request.config?.thinking + ) as BetaMessageCreateParams['thinking'] | undefined; + + // Need to extract topP and topK from request.config to avoid duplicate properties being added to the body + // This happens because topP and topK have different property names (top_p and top_k) in the Anthropic API. + // Thinking is extracted separately to avoid type issues. + // ApiVersion is extracted separately as it's not a valid property for the Anthropic API. 
+ const { + topP, + topK, + apiVersion: _1, + thinking: _2, + ...restConfig + } = request.config ?? {}; + + const body = { model: mappedModelName, max_tokens: request.config?.maxOutputTokens ?? this.DEFAULT_MAX_OUTPUT_TOKENS, messages, - }; - - if (betaSystem !== undefined) body.system = betaSystem; - if (request.config?.stopSequences !== undefined) - body.stop_sequences = request.config.stopSequences; - if (request.config?.temperature !== undefined) - body.temperature = request.config.temperature; - if (request.config?.topK !== undefined) body.top_k = request.config.topK; - if (request.config?.topP !== undefined) body.top_p = request.config.topP; - if (request.config?.tool_choice !== undefined) { - body.tool_choice = request.config - .tool_choice as BetaMessageCreateParams['tool_choice']; - } - if (request.config?.metadata !== undefined) { - body.metadata = request.config - .metadata as BetaMessageCreateParams['metadata']; - } - if (request.tools) { - body.tools = request.tools.map((tool) => this.toAnthropicTool(tool)); - } - const thinkingConfig = this.toAnthropicThinkingConfig( - request.config?.thinking - ); - if (thinkingConfig) { - body.thinking = thinkingConfig as BetaMessageCreateParams['thinking']; - } - - if (request.output?.format && request.output.format !== 'text') { - throw new Error( - `Only text output format is supported for Claude models currently` - ); - } - - return body; + system: betaSystem, + stop_sequences: request.config?.stopSequences, + temperature: request.config?.temperature, + top_k: topK, + top_p: topP, + tool_choice: request.config?.tool_choice, + metadata: request.config?.metadata, + tools: request.tools?.map((tool) => this.toAnthropicTool(tool)), + thinking: thinkingConfig, + output_format: this.isStructuredOutputEnabled(request) + ? { + type: 'json_schema', + schema: toAnthropicSchema(request.output!.schema!), + } + : undefined, + betas: Array.isArray(request.config?.betas) + ? [...(request.config?.betas ?? 
[])] + : [...BETA_APIS], + ...restConfig, + } as BetaMessageCreateParamsNonStreaming; + + return removeUndefinedProperties(body); } /** @@ -316,46 +392,50 @@ export class BetaRunner extends BaseRunner { ] : system; - const body: BetaMessageCreateParamsStreaming = { + const thinkingConfig = this.toAnthropicThinkingConfig( + request.config?.thinking + ) as BetaMessageCreateParams['thinking'] | undefined; + + // Need to extract topP and topK from request.config to avoid duplicate properties being added to the body + // This happens because topP and topK have different property names (top_p and top_k) in the Anthropic API. + // Thinking is extracted separately to avoid type issues. + // ApiVersion is extracted separately as it's not a valid property for the Anthropic API. + const { + topP, + topK, + apiVersion: _1, + thinking: _2, + ...restConfig + } = request.config ?? {}; + + const body = { model: mappedModelName, max_tokens: request.config?.maxOutputTokens ?? this.DEFAULT_MAX_OUTPUT_TOKENS, messages, stream: true, - }; - - if (betaSystem !== undefined) body.system = betaSystem; - if (request.config?.stopSequences !== undefined) - body.stop_sequences = request.config.stopSequences; - if (request.config?.temperature !== undefined) - body.temperature = request.config.temperature; - if (request.config?.topK !== undefined) body.top_k = request.config.topK; - if (request.config?.topP !== undefined) body.top_p = request.config.topP; - if (request.config?.tool_choice !== undefined) { - body.tool_choice = request.config - .tool_choice as BetaMessageCreateParams['tool_choice']; - } - if (request.config?.metadata !== undefined) { - body.metadata = request.config - .metadata as BetaMessageCreateParams['metadata']; - } - if (request.tools) { - body.tools = request.tools.map((tool) => this.toAnthropicTool(tool)); - } - const thinkingConfig = this.toAnthropicThinkingConfig( - request.config?.thinking - ); - if (thinkingConfig) { - body.thinking = thinkingConfig as 
BetaMessageCreateParams['thinking']; - } - - if (request.output?.format && request.output.format !== 'text') { - throw new Error( - `Only text output format is supported for Claude models currently` - ); - } - - return body; + system: betaSystem, + stop_sequences: request.config?.stopSequences, + temperature: request.config?.temperature, + top_k: topK, + top_p: topP, + tool_choice: request.config?.tool_choice, + metadata: request.config?.metadata, + tools: request.tools?.map((tool) => this.toAnthropicTool(tool)), + thinking: thinkingConfig, + output_format: this.isStructuredOutputEnabled(request) + ? { + type: 'json_schema', + schema: toAnthropicSchema(request.output!.schema!), + } + : undefined, + betas: Array.isArray(request.config?.betas) + ? [...(request.config?.betas ?? [])] + : [...BETA_APIS], + ...restConfig, + } as BetaMessageCreateParamsStreaming; + + return removeUndefinedProperties(body); } protected toGenkitResponse(message: BetaMessage): GenerateResponseData { @@ -491,4 +571,14 @@ export class BetaRunner extends BaseRunner { return 'other'; } } + + private isStructuredOutputEnabled( + request: GenerateRequest + ): boolean { + return !!( + request.output?.schema && + request.output.constrained && + request.output.format === 'json' + ); + } } diff --git a/js/plugins/anthropic/src/runner/stable.ts b/js/plugins/anthropic/src/runner/stable.ts index 0c8f7ffc4f..1496029ebd 100644 --- a/js/plugins/anthropic/src/runner/stable.ts +++ b/js/plugins/anthropic/src/runner/stable.ts @@ -42,8 +42,10 @@ import { logger } from 'genkit/logging'; import { KNOWN_CLAUDE_MODELS, extractVersion } from '../models.js'; import { AnthropicConfigSchema, type ClaudeRunnerParams } from '../types.js'; +import { removeUndefinedProperties } from '../utils.js'; import { BaseRunner } from './base.js'; import { RunnerTypes as BaseRunnerTypes } from './types.js'; + interface RunnerTypes extends BaseRunnerTypes { Message: Message; Stream: MessageStream; @@ -179,6 +181,12 @@ export class 
Runner extends BaseRunner { request: GenerateRequest, cacheSystemPrompt?: boolean ): MessageCreateParamsNonStreaming { + if (request.output?.format && request.output.format !== 'text') { + throw new Error( + `Only text output format is supported for Claude models currently` + ); + } + const model = KNOWN_CLAUDE_MODELS[modelName]; const { system, messages } = this.toAnthropicMessages(request.messages); const mappedModelName = @@ -197,51 +205,40 @@ export class Runner extends BaseRunner { ] : system; + const thinkingConfig = this.toAnthropicThinkingConfig( + request.config?.thinking + ) as MessageCreateParams['thinking'] | undefined; + + // Need to extract topP and topK from request.config to avoid duplicate properties being added to the body + // This happens because topP and topK have different property names (top_p and top_k) in the Anthropic API. + // Thinking is extracted separately to avoid type issues. + // ApiVersion is extracted separately as it's not a valid property for the Anthropic API. + const { + topP, + topK, + apiVersion: _1, + thinking: _2, + ...restConfig + } = request.config ?? {}; + const body: MessageCreateParamsNonStreaming = { model: mappedModelName, max_tokens: request.config?.maxOutputTokens ?? 
this.DEFAULT_MAX_OUTPUT_TOKENS, messages, + system: systemValue, + stop_sequences: request.config?.stopSequences, + temperature: request.config?.temperature, + top_k: topK, + top_p: topP, + tool_choice: request.config?.tool_choice, + metadata: request.config?.metadata, + tools: request.tools?.map((tool) => this.toAnthropicTool(tool)), + thinking: thinkingConfig, + ...restConfig, }; - if (systemValue !== undefined) { - body.system = systemValue; - } - - if (request.tools) { - body.tools = request.tools.map((tool) => this.toAnthropicTool(tool)); - } - if (request.config?.topK !== undefined) { - body.top_k = request.config.topK; - } - if (request.config?.topP !== undefined) { - body.top_p = request.config.topP; - } - if (request.config?.temperature !== undefined) { - body.temperature = request.config.temperature; - } - if (request.config?.stopSequences !== undefined) { - body.stop_sequences = request.config.stopSequences; - } - if (request.config?.metadata !== undefined) { - body.metadata = request.config.metadata; - } - if (request.config?.tool_choice !== undefined) { - body.tool_choice = request.config.tool_choice; - } - const thinkingConfig = this.toAnthropicThinkingConfig( - request.config?.thinking - ); - if (thinkingConfig) { - body.thinking = thinkingConfig as MessageCreateParams['thinking']; - } - - if (request.output?.format && request.output.format !== 'text') { - throw new Error( - `Only text output format is supported for Claude models currently` - ); - } - return body; + return removeUndefinedProperties(body); } protected toAnthropicStreamingRequestBody( @@ -249,6 +246,12 @@ export class Runner extends BaseRunner { request: GenerateRequest, cacheSystemPrompt?: boolean ): MessageCreateParamsStreaming { + if (request.output?.format && request.output.format !== 'text') { + throw new Error( + `Only text output format is supported for Claude models currently` + ); + } + const model = KNOWN_CLAUDE_MODELS[modelName]; const { system, messages } = 
this.toAnthropicMessages(request.messages); const mappedModelName = @@ -267,53 +270,41 @@ export class Runner extends BaseRunner { ] : system; + const thinkingConfig = this.toAnthropicThinkingConfig( + request.config?.thinking + ) as MessageCreateParams['thinking'] | undefined; + + // Need to extract topP and topK from request.config to avoid duplicate properties being added to the body + // This happens because topP and topK have different property names (top_p and top_k) in the Anthropic API. + // Thinking is extracted separately to avoid type issues. + // ApiVersion is extracted separately as it's not a valid property for the Anthropic API. + const { + topP, + topK, + apiVersion: _1, + thinking: _2, + ...restConfig + } = request.config ?? {}; + const body: MessageCreateParamsStreaming = { model: mappedModelName, max_tokens: request.config?.maxOutputTokens ?? this.DEFAULT_MAX_OUTPUT_TOKENS, messages, stream: true, + system: systemValue, + stop_sequences: request.config?.stopSequences, + temperature: request.config?.temperature, + top_k: topK, + top_p: topP, + tool_choice: request.config?.tool_choice, + metadata: request.config?.metadata, + tools: request.tools?.map((tool) => this.toAnthropicTool(tool)), + thinking: thinkingConfig, + ...restConfig, }; - if (systemValue !== undefined) { - body.system = systemValue; - } - - if (request.tools) { - body.tools = request.tools.map((tool) => this.toAnthropicTool(tool)); - } - if (request.config?.topK !== undefined) { - body.top_k = request.config.topK; - } - if (request.config?.topP !== undefined) { - body.top_p = request.config.topP; - } - if (request.config?.temperature !== undefined) { - body.temperature = request.config.temperature; - } - if (request.config?.stopSequences !== undefined) { - body.stop_sequences = request.config.stopSequences; - } - if (request.config?.metadata !== undefined) { - body.metadata = request.config.metadata; - } - if (request.config?.tool_choice !== undefined) { - body.tool_choice = 
request.config.tool_choice; - } - const thinkingConfig = this.toAnthropicThinkingConfig( - request.config?.thinking - ); - if (thinkingConfig) { - body.thinking = - thinkingConfig as MessageCreateParamsStreaming['thinking']; - } - - if (request.output?.format && request.output.format !== 'text') { - throw new Error( - `Only text output format is supported for Claude models currently` - ); - } - return body; + return removeUndefinedProperties(body); } protected async createMessage( diff --git a/js/plugins/anthropic/src/types.ts b/js/plugins/anthropic/src/types.ts index 7b37867301..2f61464a10 100644 --- a/js/plugins/anthropic/src/types.ts +++ b/js/plugins/anthropic/src/types.ts @@ -67,26 +67,42 @@ export interface ClaudeRunnerParams extends ClaudeHelperParamsBase {} export const AnthropicBaseConfigSchema = GenerationCommonConfigSchema.extend({ tool_choice: z .union([ - z.object({ - type: z.literal('auto'), - }), - z.object({ - type: z.literal('any'), - }), - z.object({ - type: z.literal('tool'), - name: z.string(), - }), + z + .object({ + type: z.literal('auto'), + }) + .passthrough(), + z + .object({ + type: z.literal('any'), + }) + .passthrough(), + z + .object({ + type: z.literal('tool'), + name: z.string(), + }) + .passthrough(), ]) + .describe( + 'The tool choice to use for the request. This can be used to specify the tool to use for the request. If not specified, the model will choose the tool to use.' + ) .optional(), metadata: z .object({ user_id: z.string().optional(), }) + .describe('The metadata to include in the request.') + .passthrough() .optional(), /** Optional shorthand to pick API surface for this request. */ - apiVersion: z.enum(['stable', 'beta']).optional(), -}); + apiVersion: z + .enum(['stable', 'beta']) + .optional() + .describe( + 'The API version to use for the request. Both stable and beta features are available on the beta API surface.' 
+ ), +}).passthrough(); export type AnthropicBaseConfigSchemaType = typeof AnthropicBaseConfigSchema; @@ -95,6 +111,8 @@ export const ThinkingConfigSchema = z enabled: z.boolean().optional(), budgetTokens: z.number().min(1_024).optional(), }) + .passthrough() + .passthrough() .superRefine((value, ctx) => { if (!value.enabled) return; @@ -117,8 +135,10 @@ export const ThinkingConfigSchema = z }); export const AnthropicThinkingConfigSchema = AnthropicBaseConfigSchema.extend({ - thinking: ThinkingConfigSchema.optional(), -}); + thinking: ThinkingConfigSchema.optional().describe( + 'The thinking configuration to use for the request. Thinking is a feature that allows the model to think about the request and provide a better response.' + ), +}).passthrough(); export const AnthropicConfigSchema = AnthropicThinkingConfigSchema; diff --git a/js/plugins/anthropic/src/utils.ts b/js/plugins/anthropic/src/utils.ts new file mode 100644 index 0000000000..6678eabc19 --- /dev/null +++ b/js/plugins/anthropic/src/utils.ts @@ -0,0 +1,25 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +export function removeUndefinedProperties(obj: T): T { + if (typeof obj !== 'object' || obj === null) { + return obj; + } + + return Object.fromEntries( + Object.entries(obj).filter(([_, value]) => value !== undefined) + ) as T; +} diff --git a/js/plugins/anthropic/tests/effort_param_test.ts b/js/plugins/anthropic/tests/effort_param_test.ts new file mode 100644 index 0000000000..12e67f4ed4 --- /dev/null +++ b/js/plugins/anthropic/tests/effort_param_test.ts @@ -0,0 +1,249 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import type Anthropic from '@anthropic-ai/sdk'; +import * as assert from 'assert'; +import { genkit } from 'genkit'; +import { describe, test } from 'node:test'; +import { anthropic } from '../src/index.js'; +import { __testClient } from '../src/types.js'; +import { + createMockAnthropicClient, + createMockAnthropicMessage, + mockTextChunk, +} from './mocks/anthropic-client.js'; + +/** + * Options for creating a plugin with a mock client + */ +interface CreatePluginOptions { + apiVersion?: 'beta' | 'stable'; + mockClient: Anthropic; +} + +/** + * Creates an Anthropic plugin configured with a mock client for testing + */ +function createPlugin(options: CreatePluginOptions) { + return anthropic({ + apiVersion: options.apiVersion, + // @ts-ignore + [__testClient]: options.mockClient, + }); +} + +/** + * Creates a Genkit instance with the given plugin + */ +function createGenkitInstance(plugin: ReturnType) { + return genkit({ + plugins: [plugin], + }); +} + +/** + * Helper to get the proper create stub from the mock client for a given API version. + */ +function getCreateStub(mockClient: Anthropic, apiVersion: 'beta' | 'stable') { + return apiVersion === 'beta' + ? 
(mockClient.beta.messages.create as any) + : (mockClient.messages.create as any); +} + +/** + * Extracts the API request object from the mock for verification + * @param apiVersion - 'beta' or 'stable' to determine which API endpoint to check + */ +function getApiRequest( + mockClient: Anthropic, + apiVersion: 'beta' | 'stable', + callIndex: number = 0 +) { + const stub = getCreateStub(mockClient, apiVersion); + return stub.mock.calls[callIndex]?.arguments[0]; +} + +/** + * Verifies that the API was called the expected number of times + * @param apiVersion - 'beta' or 'stable' to determine which API endpoint to verify + */ +function verifyApiCalled( + mockClient: Anthropic, + apiVersion: 'beta' | 'stable', + expectedCalls: number = 1 +) { + const stub = getCreateStub(mockClient, apiVersion); + assert.strictEqual( + stub.mock.calls.length, + expectedCalls, + `${apiVersion === 'beta' ? 'Beta' : 'Stable'} API should be called ${expectedCalls} time(s)` + ); +} + +/** + * Tests for effort parameter functionality. + * These tests verify that output_config.effort is correctly passed to the Anthropic API + * when using the beta API with claude-opus-4-5. 
+ */ +describe('Effort Parameter Tests', () => { + const OPUS_4_5_MODEL = 'anthropic/claude-opus-4-5'; + + test('should pass output_config.effort to API when using beta API with claude-opus-4-5', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: createMockAnthropicMessage({ + text: 'Response with high effort', + }), + }); + + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + await ai.generate({ + model: OPUS_4_5_MODEL, + prompt: 'Generate a detailed response', + config: { + output_config: { + effort: 'high', + }, + }, + }); + + verifyApiCalled(mockClient, 'beta'); + const apiRequest = getApiRequest(mockClient, 'beta'); + + assert.ok(apiRequest.output_config, 'Request should have output_config'); + assert.strictEqual( + apiRequest.output_config.effort, + 'high', + 'effort should be set to high' + ); + }); + + test('should pass output_config.effort with low value', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: createMockAnthropicMessage({ + text: 'Response with low effort', + }), + }); + + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + await ai.generate({ + model: OPUS_4_5_MODEL, + prompt: 'Generate a quick response', + config: { + output_config: { + effort: 'low', + }, + }, + }); + + verifyApiCalled(mockClient, 'beta'); + const apiRequest = getApiRequest(mockClient, 'beta'); + + assert.ok(apiRequest.output_config, 'Request should have output_config'); + assert.strictEqual( + apiRequest.output_config.effort, + 'low', + 'effort should be set to low' + ); + }); + + test('should pass output_config.effort with medium value', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: createMockAnthropicMessage({ + text: 'Response with medium effort', + }), + }); + + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + 
const ai = createGenkitInstance(plugin); + + await ai.generate({ + model: OPUS_4_5_MODEL, + prompt: 'Generate a balanced response', + config: { + output_config: { + effort: 'medium', + }, + }, + }); + + verifyApiCalled(mockClient, 'beta'); + const apiRequest = getApiRequest(mockClient, 'beta'); + + assert.ok(apiRequest.output_config, 'Request should have output_config'); + assert.strictEqual( + apiRequest.output_config.effort, + 'medium', + 'effort should be set to medium' + ); + }); + + test('should pass output_config.effort in streaming requests', async () => { + const mockClient = createMockAnthropicClient({ + streamChunks: [mockTextChunk('Streaming response')], + messageResponse: createMockAnthropicMessage({ + text: 'Streaming response', + }), + }); + + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + await ai.generate({ + model: OPUS_4_5_MODEL, + prompt: 'Generate a streaming response', + config: { + output_config: { + effort: 'high', + }, + }, + streamingCallback: () => {}, + }); + + const betaStreamStub = mockClient.beta.messages.stream as any; + assert.strictEqual(betaStreamStub.mock.calls.length, 1); + const requestBody = betaStreamStub.mock.calls[0]?.arguments[0]; + + assert.ok( + requestBody.output_config, + 'Streaming request should have output_config' + ); + assert.strictEqual( + requestBody.output_config.effort, + 'high', + 'effort should be set to high in streaming request' + ); + }); +}); diff --git a/js/plugins/anthropic/tests/execution_test.ts b/js/plugins/anthropic/tests/execution_test.ts index 069d2d2dcd..ae7b6a85e7 100644 --- a/js/plugins/anthropic/tests/execution_test.ts +++ b/js/plugins/anthropic/tests/execution_test.ts @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -import type { GenerateRequest, ModelAction } from '@genkit-ai/ai/model'; import * as assert from 'assert'; +import type { GenerateRequest } from 'genkit'; +import type { ModelAction } from 'genkit/model'; import { describe, mock, test } from 'node:test'; import { anthropic } from '../src/index.js'; -import { __testClient } from '../src/types.js'; +import { PluginOptions, __testClient } from '../src/types.js'; import { createMockAnthropicClient, createMockAnthropicMessage, @@ -35,11 +36,17 @@ describe('Model Execution Integration Tests', () => { const plugin = anthropic({ apiKey: 'test-key', [__testClient]: mockClient, - }); + } as PluginOptions); + + // Verify plugin has resolve method + assert.ok(plugin.resolve, 'Plugin should have resolve method'); // Resolve the model action via plugin - const modelAction = plugin.resolve('model', 'claude-3-5-haiku-20241022'); - assert.ok(modelAction, 'Model should be resolved'); + const modelAction = plugin.resolve( + 'model', + 'claude-3-5-haiku-20241022' + ) as ModelAction; + assert.strictEqual( (modelAction as ModelAction).__action.name, 'anthropic/claude-3-5-haiku-20241022' @@ -55,11 +62,11 @@ describe('Model Execution Integration Tests', () => { ], }; - const response = await (modelAction as ModelAction)(request, { + const response = await modelAction(request, { streamingRequested: false, sendChunk: mock.fn(), abortSignal: new AbortController().signal, - }); + } as Parameters[1]); assert.ok(response, 'Response should be returned'); assert.ok(response.candidates, 'Response should have candidates'); @@ -86,7 +93,10 @@ describe('Model Execution Integration Tests', () => { const plugin = anthropic({ apiKey: 'test-key', [__testClient]: mockClient, - }); + } as PluginOptions); + + // Verify plugin has resolve method + assert.ok(plugin.resolve, 'Plugin should have resolve method'); const modelAction = plugin.resolve( 'model', @@ -114,9 +124,10 @@ describe('Model Execution Integration Tests', () => { streamingRequested: false, 
sendChunk: mock.fn(), abortSignal: new AbortController().signal, - }); + } as Parameters[1]); assert.ok(response, 'Response should be returned'); + assert.ok(response.candidates, 'Response should have candidates'); assert.strictEqual( response.candidates[0].message.content[0].text, 'The capital of France is Paris.' @@ -139,7 +150,10 @@ describe('Model Execution Integration Tests', () => { const plugin = anthropic({ apiKey: 'test-key', [__testClient]: mockClient, - }); + } as PluginOptions); + + // Verify plugin has resolve method + assert.ok(plugin.resolve, 'Plugin should have resolve method'); const modelAction = plugin.resolve( 'model', @@ -163,7 +177,7 @@ describe('Model Execution Integration Tests', () => { streamingRequested: false, sendChunk: mock.fn(), abortSignal: new AbortController().signal, - }); + } as Parameters[1]); assert.ok(response, 'Response should be returned'); @@ -197,7 +211,10 @@ describe('Model Execution Integration Tests', () => { const plugin = anthropic({ apiKey: 'test-key', [__testClient]: mockClient, - }); + } as PluginOptions); + + // Verify plugin has resolve method + assert.ok(plugin.resolve, 'Plugin should have resolve method'); const modelAction = plugin.resolve( 'model', @@ -212,7 +229,7 @@ describe('Model Execution Integration Tests', () => { streamingRequested: false, sendChunk: mock.fn(), abortSignal: new AbortController().signal, - } + } as Parameters[1] ); assert.ok(response.usage, 'Usage should be returned'); @@ -231,7 +248,10 @@ describe('Model Execution Integration Tests', () => { const plugin = anthropic({ apiKey: 'test-key', [__testClient]: mockClient, - }); + } as PluginOptions); + + // Verify plugin has resolve method + assert.ok(plugin.resolve, 'Plugin should have resolve method'); const modelAction = plugin.resolve( 'model', @@ -246,10 +266,11 @@ describe('Model Execution Integration Tests', () => { streamingRequested: false, sendChunk: mock.fn(), abortSignal: new AbortController().signal, - } + } as Parameters[1] ); 
assert.ok(response, 'Response should be returned'); + assert.ok(response.candidates, 'Response should have candidates'); assert.strictEqual(response.candidates[0].finishReason, 'length'); }); @@ -263,7 +284,10 @@ describe('Model Execution Integration Tests', () => { const plugin = anthropic({ apiKey: 'test-key', [__testClient]: mockClient, - }); + } as PluginOptions); + + // Verify plugin has resolve method + assert.ok(plugin.resolve, 'Plugin should have resolve method'); // Resolve without prefix const modelAction = plugin.resolve( @@ -280,7 +304,7 @@ describe('Model Execution Integration Tests', () => { streamingRequested: false, sendChunk: mock.fn(), abortSignal: new AbortController().signal, - } + } as Parameters[1] ); assert.ok(response, 'Response should be returned'); @@ -296,7 +320,10 @@ describe('Model Execution Integration Tests', () => { const plugin = anthropic({ apiKey: 'test-key', [__testClient]: mockClient, - }); + } as PluginOptions); + + // Verify plugin has resolve method + assert.ok(plugin.resolve, 'Plugin should have resolve method'); // Resolve with prefix const modelAction = plugin.resolve( @@ -313,7 +340,7 @@ describe('Model Execution Integration Tests', () => { streamingRequested: false, sendChunk: mock.fn(), abortSignal: new AbortController().signal, - } + } as Parameters[1] ); assert.ok(response, 'Response should be returned'); @@ -329,7 +356,10 @@ describe('Model Execution Integration Tests', () => { const plugin = anthropic({ apiKey: 'test-key', [__testClient]: mockClient, - }); + } as PluginOptions); + + // Verify plugin has resolve method + assert.ok(plugin.resolve, 'Plugin should have resolve method'); // Resolve unknown model (passes through to API) const modelAction = plugin.resolve( @@ -346,13 +376,15 @@ describe('Model Execution Integration Tests', () => { streamingRequested: false, sendChunk: mock.fn(), abortSignal: new AbortController().signal, - } + } as Parameters[1] ); assert.ok(response, 'Response should be returned for 
unknown model'); + assert.ok(response.candidates, 'Response should have candidates'); assert.strictEqual( - response.candidates[0].message.content[0].text, - 'Response from future model' + response.candidates?.[0]?.message.content[0].text, + 'Response from future model', + 'Response should have candidates' ); }); }); diff --git a/js/plugins/anthropic/tests/live_test.ts b/js/plugins/anthropic/tests/live_test.ts index f008157afe..0c370196dd 100644 --- a/js/plugins/anthropic/tests/live_test.ts +++ b/js/plugins/anthropic/tests/live_test.ts @@ -22,7 +22,7 @@ */ import * as assert from 'assert'; -import { genkit } from 'genkit'; +import { genkit, z } from 'genkit'; import { describe, it } from 'node:test'; import { anthropic } from '../src/index.js'; @@ -80,4 +80,50 @@ describe('Live Anthropic API Tests', { skip: !API_KEY }, () => { assert.ok(result.text.toLowerCase().includes('hello')); }); + + it('should return structured output matching the schema', async () => { + const ai = genkit({ + plugins: [anthropic({ apiKey: API_KEY, apiVersion: 'beta' })], + }); + + const schema = z.object({ + name: z.string(), + age: z.number(), + city: z.string(), + isStudent: z.boolean(), + isEmployee: z.boolean(), + isRetired: z.boolean(), + isUnemployed: z.boolean(), + isDisabled: z.boolean(), + }); + + const result = await ai.generate({ + model: 'anthropic/claude-sonnet-4-5', + prompt: + 'Generate a fictional person with name "Alice", age 30, and city "New York". 
Return only the JSON.', + output: { schema, format: 'json', constrained: true }, + }); + + const parsed = result.output; + assert.ok(parsed, 'Should have parsed output'); + assert.deepStrictEqual( + { name: parsed.name, age: parsed.age, city: parsed.city }, + { name: 'Alice', age: 30, city: 'New York' } + ); + + // Check that boolean fields are present and are actually booleans + for (const key of [ + 'isStudent', + 'isEmployee', + 'isRetired', + 'isUnemployed', + 'isDisabled', + ]) { + assert.strictEqual( + typeof parsed[key], + 'boolean', + `Field ${key} should be a boolean but got: ${typeof parsed[key]}` + ); + } + }); }); diff --git a/js/plugins/anthropic/tests/mocks/anthropic-client.ts b/js/plugins/anthropic/tests/mocks/anthropic-client.ts index 7fe29eceb2..09be1b749e 100644 --- a/js/plugins/anthropic/tests/mocks/anthropic-client.ts +++ b/js/plugins/anthropic/tests/mocks/anthropic-client.ts @@ -379,7 +379,7 @@ function toBetaMessage(message: Message): BetaMessage { server_tool_use: message.usage.server_tool_use as any, service_tier: message.usage.service_tier, }, - }; + } as BetaMessage; } function toBetaStreamEvent( diff --git a/js/plugins/anthropic/tests/structured_output_test.ts b/js/plugins/anthropic/tests/structured_output_test.ts new file mode 100644 index 0000000000..9a1e31fcd5 --- /dev/null +++ b/js/plugins/anthropic/tests/structured_output_test.ts @@ -0,0 +1,358 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import type Anthropic from '@anthropic-ai/sdk'; +import * as assert from 'assert'; +import { genkit, z } from 'genkit'; +import { describe, test } from 'node:test'; +import { anthropic } from '../src/index.js'; +import { __testClient } from '../src/types.js'; +import { + createMockAnthropicClient, + createMockAnthropicMessage, +} from './mocks/anthropic-client.js'; + +/** + * Test constants for consistent test setup + */ +const TEST_API_KEY = 'test-key'; +const SUPPORTING_MODEL = 'anthropic/claude-sonnet-4-5'; +const NON_SUPPORTING_MODEL = 'anthropic/claude-sonnet-4'; + +/** + * Options for creating a plugin with a mock client + */ +interface CreatePluginOptions { + apiVersion?: 'beta' | 'stable'; + mockClient: Anthropic; +} + +/** + * Creates an Anthropic plugin configured with a mock client for testing + */ +function createPlugin(options: CreatePluginOptions) { + return anthropic({ + apiKey: TEST_API_KEY, + apiVersion: options.apiVersion, + // @ts-ignore + [__testClient]: options.mockClient, + }); +} + +/** + * Creates a Genkit instance with the given plugin + */ +function createGenkitInstance(plugin: ReturnType) { + return genkit({ + plugins: [plugin], + }); +} + +/** + * Helper to get the proper create stub from the mock client for a given API version. + */ +function getCreateStub(mockClient: Anthropic, apiVersion: 'beta' | 'stable') { + return apiVersion === 'beta' + ? 
(mockClient.beta.messages.create as any) + : (mockClient.messages.create as any); +} + +/** + * Extracts the API request object from the mock for verification + * @param apiVersion - 'beta' or 'stable' to determine which API endpoint to check + */ +function getApiRequest( + mockClient: Anthropic, + apiVersion: 'beta' | 'stable', + callIndex: number = 0 +) { + const stub = getCreateStub(mockClient, apiVersion); + return stub.mock.calls[callIndex]?.arguments[0]; +} + +/** + * Verifies that the API was called the expected number of times + * @param apiVersion - 'beta' or 'stable' to determine which API endpoint to verify + */ +function verifyApiCalled( + mockClient: Anthropic, + apiVersion: 'beta' | 'stable', + expectedCalls: number = 1 +) { + const stub = getCreateStub(mockClient, apiVersion); + assert.strictEqual( + stub.mock.calls.length, + expectedCalls, + `${apiVersion === 'beta' ? 'Beta' : 'Stable'} API should be called ${expectedCalls} time(s)` + ); +} + +/** + * Tests for structured output (constrained generation) functionality. + * These tests verify that output_format is correctly passed to the Anthropic API + * when using the beta API with constrained output, and that it's NOT passed + * in various edge cases (stable API, non-json format, missing schema, etc.) 
+ */ +describe('Structured Output Tests', () => { + test('should pass output_format to API when using beta API with constrained output', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: createMockAnthropicMessage({ + text: '{"name":"Alice","age":30,"city":"New York","isStudent":false,"isEmployee":true,"isRetired":false,"isUnemployed":false,"isDisabled":false}', + }), + }); + + // Set up plugin with beta API enabled + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + // Call generate with sonnet 4.5 (supports native constrained output) + await ai.generate({ + model: SUPPORTING_MODEL, + prompt: + 'Generate a fictional person with name "Alice", age 30, and city "New York". Return only the JSON.', + output: { + schema: z.object({ + name: z.string(), + age: z.number(), + city: z.string(), + isStudent: z.boolean(), + isEmployee: z.boolean(), + isRetired: z.boolean(), + isUnemployed: z.boolean(), + isDisabled: z.boolean(), + }), + format: 'json', + constrained: true, + }, + }); + + // Verify the beta API was called + verifyApiCalled(mockClient, 'beta'); + + // Verify output_format was included in the API request + const apiRequest = getApiRequest(mockClient, 'beta'); + assert.ok(apiRequest.output_format, 'Request should have output_format'); + assert.strictEqual( + apiRequest.output_format.type, + 'json_schema', + 'output_format type should be json_schema' + ); + assert.ok( + apiRequest.output_format.schema, + 'output_format should have schema' + ); + // Verify schema transformation: additionalProperties should be false for constrained output + assert.strictEqual( + apiRequest.output_format.schema.additionalProperties, + false, + 'Schema should have additionalProperties: false' + ); + }); + + test('should NOT pass output_format to API when constrained is false and using beta API', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: 
createMockAnthropicMessage({ + text: '{"name":"Alice"}', + }), + }); + + // Set up plugin with beta API enabled + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + // Call generate with constrained: false + await ai.generate({ + model: SUPPORTING_MODEL, + prompt: 'Generate JSON', + output: { + format: 'json', + constrained: false, + schema: z.object({ + name: z.string(), + }), + }, + }); + + // Verify the beta API was called + verifyApiCalled(mockClient, 'beta'); + + // Verify output_format was NOT included when constrained is false + const apiRequest = getApiRequest(mockClient, 'beta'); + assert.strictEqual( + apiRequest.output_format, + undefined, + 'Request should NOT have output_format when constrained is false' + ); + }); + + test('should NOT pass output_format to API when format is not json and using beta API', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: createMockAnthropicMessage({ + text: 'Some text response', + }), + }); + + // Set up plugin with beta API enabled + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + // Call generate with format: 'text' (not 'json') + await ai.generate({ + model: SUPPORTING_MODEL, + prompt: 'Generate text', + output: { + format: 'text', + constrained: true, + }, + }); + + // Verify the beta API was called + verifyApiCalled(mockClient, 'beta'); + + // Verify output_format was NOT included when format is not json + const apiRequest = getApiRequest(mockClient, 'beta'); + assert.strictEqual( + apiRequest.output_format, + undefined, + 'Request should NOT have output_format when format is text' + ); + }); + + test('should NOT pass output_format to API when schema is not provided and using beta API', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: createMockAnthropicMessage({ + text: '{"anything": "goes"}', + }), + }); + + 
// Set up plugin with beta API enabled + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + // Call generate with constrained: true but no schema + await ai.generate({ + model: SUPPORTING_MODEL, + prompt: 'Generate JSON', + output: { + format: 'json', + constrained: true, + // No schema provided + }, + }); + + // Verify the beta API was called + verifyApiCalled(mockClient, 'beta'); + + // Verify output_format was NOT included when schema is missing + const apiRequest = getApiRequest(mockClient, 'beta'); + assert.strictEqual( + apiRequest.output_format, + undefined, + 'Request should NOT have output_format when schema is not provided' + ); + }); + + test('should NOT pass output_format to API when model does not support structured output and using beta API', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: createMockAnthropicMessage({ + text: '{"name":"Alice"}', + }), + }); + + // Set up plugin with beta API enabled + const plugin = createPlugin({ + apiVersion: 'beta', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + // Call generate with model that does not support structured output + await ai.generate({ + model: NON_SUPPORTING_MODEL, + prompt: 'Generate JSON', + output: { + format: 'json', + constrained: true, + }, + }); + + // Verify the beta API was called + verifyApiCalled(mockClient, 'beta'); + + // Verify output_format was NOT included when model does not support structured output + const apiRequest = getApiRequest(mockClient, 'beta'); + assert.strictEqual( + apiRequest.output_format, + undefined, + 'Request should NOT have output_format when model does not support structured output' + ); + }); + + test('should throw an error when using stable API with non-text output format', async () => { + const mockClient = createMockAnthropicClient({ + messageResponse: createMockAnthropicMessage({ + text: '{"name":"Alice","age":30,"city":"New York"}', 
+ }), + }); + + // Set up plugin with stable API (not beta) + const plugin = createPlugin({ + apiVersion: 'stable', + mockClient, + }); + + const ai = createGenkitInstance(plugin); + + // Call generate with constrained output (would work with beta API) + // Expect an error to be thrown since only text output is supported for stable API + await assert.rejects( + async () => { + await ai.generate({ + model: SUPPORTING_MODEL, + prompt: 'Generate JSON', + output: { + format: 'json', + constrained: true, + schema: z.object({ + name: z.string(), + age: z.number(), + city: z.string(), + }), + }, + }); + }, + /Only text output format is supported for Claude models currently/, + 'Should throw an error for non-text output on stable API' + ); + }); +}); diff --git a/js/plugins/google-genai/src/googleai/gemini.ts b/js/plugins/google-genai/src/googleai/gemini.ts index 21873f1ac0..cf8eb08119 100644 --- a/js/plugins/google-genai/src/googleai/gemini.ts +++ b/js/plugins/google-genai/src/googleai/gemini.ts @@ -269,7 +269,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ ) .optional(), thinkingLevel: z - .enum(['LOW', 'MEDIUM', 'HIGH']) + .enum(['MINIMAL', 'LOW', 'MEDIUM', 'HIGH']) .describe( 'For Gemini 3.0 - Indicates the thinking level. 
A higher level ' + 'is associated with more detailed thinking, which is needed for solving ' + @@ -419,6 +419,7 @@ const GENERIC_GEMMA_MODEL = commonRef( ); const KNOWN_GEMINI_MODELS = { + 'gemini-3-flash-preview': commonRef('gemini-3-flash-preview'), 'gemini-3-pro-preview': commonRef('gemini-3-pro-preview'), 'gemini-2.5-pro': commonRef('gemini-2.5-pro'), 'gemini-2.5-flash': commonRef('gemini-2.5-flash'), diff --git a/js/plugins/google-genai/src/vertexai/gemini.ts b/js/plugins/google-genai/src/vertexai/gemini.ts index 5f32120d8c..1715bd973d 100644 --- a/js/plugins/google-genai/src/vertexai/gemini.ts +++ b/js/plugins/google-genai/src/vertexai/gemini.ts @@ -318,7 +318,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ ) .optional(), thinkingLevel: z - .enum(['LOW', 'MEDIUM', 'HIGH']) + .enum(['MINIMAL', 'LOW', 'MEDIUM', 'HIGH']) .describe( 'For Gemini 3.0 - Indicates the thinking level. A higher level ' + 'is associated with more detailed thinking, which is needed for solving ' + @@ -422,6 +422,7 @@ const GENERIC_IMAGE_MODEL = commonRef( ); export const KNOWN_GEMINI_MODELS = { + 'gemini-3-flash-preview': commonRef('gemini-3-flash-preview'), 'gemini-3-pro-preview': commonRef('gemini-3-pro-preview'), 'gemini-2.5-flash-lite': commonRef('gemini-2.5-flash-lite'), 'gemini-2.5-pro': commonRef('gemini-2.5-pro'), diff --git a/js/plugins/google-genai/tests/model-tests-tts.yaml b/js/plugins/google-genai/tests/model-tests-tts.yaml new file mode 100644 index 0000000000..726e2736d7 --- /dev/null +++ b/js/plugins/google-genai/tests/model-tests-tts.yaml @@ -0,0 +1,90 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +- model: googleai/imagen-4.0-generate-001 + supports: + - output-image +- model: googleai/gemini-2.5-flash-preview-tts + tests: + - name: TTS Test + input: + messages: + - role: user + content: + - text: 'Hello world' + config: + responseModalities: ['AUDIO'] + validators: + - valid-media:audio +- model: googleai/gemini-2.5-pro + supports: + - tool-request + - structured-output + - multiturn + - system-role + - input-image-base64 + - input-image-url + - input-video-youtube +- model: googleai/gemini-3-pro-preview + supports: + - tool-request + - structured-output + - multiturn + - system-role + - input-image-base64 + - input-image-url + - input-video-youtube + tests: + - name: Tool Response Conformance + input: + messages: + - role: user + content: + - text: 'What is the weather in New York? Use the tool.' 
+ - role: model + content: + - toolRequest: + name: weather + input: + city: New York + metadata: + thoughtSignature: CvABAXLI2nxTZfKU3MkzLiGBrX62oq77vN2kHjT8pwwXRjtzbCqC07pPhIZ31sS+2kUFDh/kUY4SOvZzjjtP8UxI5GSFRWlX8yVDrDFo17RN/urwc1QuaMMzy66eQubpPRDEwfi6S5IKxZq0kRX6cSceB4NVCQAAAU8sYJwqWFL9CIaGac4lzF+34VvMWFLqdb40oe7/gw/KK1fqAeqDs+FJLksA+Q5qpHn3BETcqT0AuFe01IB2EVA7Us+/N3VGonw61F5cFNjHXO1jIYDybl3MXR9M5T5QB1a3EyicYXSX5/+bCmny1ka4kInbtzEqMMuv + - role: tool + content: + - toolResponse: + name: weather + output: '21C' + tools: + - name: weather + description: Get the weather for a city + inputSchema: + type: object + properties: + city: + type: string + required: + - city + validators: + - text-includes:21 +- model: googleai/gemini-2.5-flash + supports: + - tool-request + - structured-output + - multiturn + - system-role + - input-image-base64 + - input-image-url + - input-video-youtube diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 454817f405..7f3123818d 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -171,7 +171,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -181,7 +181,7 @@ importers: optionalDependencies: '@genkit-ai/firebase': specifier: ^1.16.1 - version: 1.16.1(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) + version: 1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) doc-snippets: dependencies: @@ -199,7 +199,7 @@ 
importers: version: 5.0.0 firebase-functions: specifier: ^6.3.1 - version: 6.3.2(firebase-admin@13.5.0(encoding@0.1.13)) + version: 6.3.2(firebase-admin@13.6.0(encoding@0.1.13)) genkit: specifier: workspace:* version: link:../genkit @@ -249,7 +249,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -260,8 +260,8 @@ importers: plugins/anthropic: dependencies: '@anthropic-ai/sdk': - specifier: ^0.68.0 - version: 0.68.0(zod@3.25.76) + specifier: ^0.71.2 + version: 0.71.2(zod@3.25.67) devDependencies: '@types/node': specifier: ^20.11.16 @@ -280,7 +280,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -314,7 +314,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.0.2 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.7.0 version: 4.20.3 @@ -326,7 +326,7 @@ importers: dependencies: chromadb: specifier: 1.8.1 - version: 1.8.1(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76)) + version: 1.8.1(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)) genkit: specifier: workspace:^ version: link:../../genkit @@ -345,7 +345,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -403,7 +403,7 @@ importers: version: 
29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)))(typescript@4.9.5) tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -418,7 +418,7 @@ importers: version: link:../../genkit openai: specifier: ^4.95.0 - version: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76) + version: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) devDependencies: '@jest/globals': specifier: ^29.7.0 @@ -437,7 +437,7 @@ importers: version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)))(typescript@5.8.3) tsup: specifier: ^8.0.2 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.0) typescript: specifier: ^5.4.5 version: 5.8.3 @@ -465,7 +465,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -505,7 +505,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -551,7 +551,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -569,7 +569,7 
@@ importers: version: 7.11.1(encoding@0.1.13) firebase-admin: specifier: '>=12.2' - version: 13.5.0(encoding@0.1.13) + version: 13.4.0(encoding@0.1.13) devDependencies: '@jest/globals': specifier: ^29.7.0 @@ -603,7 +603,7 @@ importers: version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)))(typescript@4.9.5) tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -691,7 +691,7 @@ importers: version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)))(typescript@4.9.5) tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -731,7 +731,7 @@ importers: version: 21.0.0 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -765,7 +765,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -789,7 +789,7 @@ importers: version: link:../../genkit langchain: specifier: ^0.1.36 - version: 
0.1.37(@google-cloud/storage@7.18.0(encoding@0.1.13))(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(fast-xml-parser@4.5.3)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(lodash@4.17.21)(pdf-parse@1.1.1)(pg@8.16.2)(ws@8.18.3) + version: 0.1.37(@google-cloud/storage@7.16.0(encoding@0.1.13))(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(fast-xml-parser@4.5.3)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(lodash@4.17.21)(pdf-parse@1.1.1)(pg@8.16.2)(ws@8.18.3) devDependencies: '@types/node': specifier: ^20.11.16 @@ -802,7 +802,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -845,7 +845,7 @@ importers: version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3)))(typescript@5.8.3) tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -872,7 +872,7 @@ importers: version: 29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)) next: specifier: ^15.4.10 - version: 15.4.10(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + version: 
15.5.9(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) npm-run-all: specifier: ^4.1.5 version: 4.1.5 @@ -884,7 +884,7 @@ importers: version: 29.4.0(@babel/core@7.25.7)(@jest/transform@29.7.0)(@jest/types@29.6.3)(babel-jest@29.7.0(@babel/core@7.25.7))(jest-util@29.7.0)(jest@29.7.0(@types/node@20.19.1)(ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5)))(typescript@4.9.5) tsup: specifier: ^8.0.2 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.7.0 version: 4.20.3 @@ -915,7 +915,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -946,7 +946,7 @@ importers: version: 6.0.1 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -970,7 +970,7 @@ importers: version: 1.10.0(encoding@0.1.13) '@mistralai/mistralai-gcp': specifier: ^1.3.5 - version: 1.5.0(encoding@0.1.13)(zod@3.25.76) + version: 1.5.0(encoding@0.1.13)(zod@3.25.67) genkit: specifier: workspace:^ version: link:../../genkit @@ -985,7 +985,7 @@ importers: version: 3.3.2 openai: specifier: ^4.52.7 - version: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76) + version: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) devDependencies: '@types/node': specifier: ^20.11.16 @@ -1010,7 +1010,7 @@ importers: version: 21.0.0 tsup: specifier: ^8.3.5 - version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2) + version: 8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.0) tsx: specifier: ^4.19.2 version: 4.20.3 @@ -1023,7 +1023,7 @@ importers: version: 7.9.4(encoding@0.1.13) 
firebase-admin: specifier: '>=12.2' - version: 13.5.0(encoding@0.1.13) + version: 13.4.0(encoding@0.1.13) testapps/anthropic: dependencies: @@ -1072,7 +1072,7 @@ importers: version: 1.0.2 zod-to-json-schema: specifier: ^3.24.5 - version: 3.24.5(zod@3.25.76) + version: 3.24.5(zod@3.25.67) devDependencies: '@types/wav': specifier: ^1.0.4 @@ -1088,7 +1088,7 @@ importers: version: link:../../plugins/compat-oai '@genkit-ai/express': specifier: ^1.1.0 - version: 1.12.0(@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(express@5.1.0)(genkit@genkit) + version: 1.12.0(@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(express@5.1.0)(genkit@genkit) genkit: specifier: workspace:* version: link:../../genkit @@ -1299,7 +1299,7 @@ importers: version: 5.1.0 firebase-admin: specifier: ^13.5.0 - version: 13.5.0(encoding@0.1.13) + version: 13.6.0(encoding@0.1.13) genkit: specifier: workspace:^ version: link:../../genkit @@ -1619,7 +1619,7 @@ importers: version: 2025.7.1 '@modelcontextprotocol/server-filesystem': specifier: ^2025.3.28 - version: 2025.7.1(zod@3.25.76) + version: 2025.7.1(zod@3.25.67) '@types/express': specifier: ^4.17.21 version: 4.17.23 @@ -1674,7 +1674,7 @@ importers: version: link:../../genkit zod: specifier: ^3.22.4 - version: 3.25.76 + version: 3.25.67 devDependencies: tsx: specifier: ^4.7.1 @@ -1705,7 +1705,7 @@ importers: version: link:../../plugins/ollama genkitx-openai: specifier: ^0.10.1 - version: 
0.10.1(@genkit-ai/ai@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(encoding@0.1.13)(ws@8.18.3) + version: 0.10.1(@genkit-ai/ai@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(encoding@0.1.13)(ws@8.18.3) devDependencies: rimraf: specifier: ^6.0.1 @@ -1825,7 +1825,7 @@ importers: version: link:../../genkit next: specifier: ^15.4.10 - version: 15.4.10(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + version: 15.5.9(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) zod: specifier: ^3.24.1 version: 3.25.67 @@ -2145,8 +2145,8 @@ packages: '@anthropic-ai/sdk@0.24.3': resolution: {integrity: sha512-916wJXO6T6k8R6BAAcLhLPv/pnLGy7YSEBZXZ1XTFbLcTZE8oTy3oDW9WJf9KKZwMvVcePIfoTSvzXHRcGxkQQ==} - '@anthropic-ai/sdk@0.68.0': - resolution: {integrity: sha512-SMYAmbbiprG8k1EjEPMTwaTqssDT7Ae+jxcR5kWXiqTlbwMR2AthXtscEVWOHkRfyAV5+y3PFYTJRNa3OJWIEw==} + '@anthropic-ai/sdk@0.71.2': + resolution: {integrity: sha512-TGNDEUuEstk/DKu0/TflXAEt+p+p/WhTlFzEnoosvbaDU2LTjm42igSdlL0VijrKpWejtOKxX0b8A7uc+XiSAQ==} hasBin: true peerDependencies: zod: ^3.25.0 || ^4.0.0 @@ -2351,9 +2351,6 @@ packages: '@dabh/diagnostics@2.0.3': resolution: {integrity: sha512-hrlQOIi7hAfzsMqlGSFyVucrx38O+j6wiGOf//H2ecvIEqYN4ADBSS2iLMh5UFyDunCNniUIPk/q3riFv45xRA==} - '@dabh/diagnostics@2.0.8': - resolution: {integrity: sha512-R4MSXTVnuMzGD7bzHdW2ZhhdPC/igELENcq5IjEverBvq5hn1SXCWcsi6eSsdWP0/Ur+SItRRjAktmdoX/8R/Q==} 
- '@electric-sql/pglite@0.2.17': resolution: {integrity: sha512-qEpKRT2oUaWDH6tjRxLHjdzMqRUGYDnGZlKrnL4dJ77JVMcP2Hpo3NYnOSPKdZdeec57B6QPprCUFg0picx5Pw==} @@ -2516,8 +2513,8 @@ packages: '@fastify/busboy@3.0.0': resolution: {integrity: sha512-83rnH2nCvclWaPQQKvkJ2pdOjG4TZyEVuFDnlOF6KP08lDaaceVyw/W63mDuafQT+MKHCvXIPpE5uYWeM0rT4w==} - '@fastify/busboy@3.2.0': - resolution: {integrity: sha512-m9FVDXU3GT2ITSe0UaMA5rU3QkfC/UXtCU8y0gSN/GugTqtVldOBWIB5V6V3sbmenVZUIpU6f+mPEO2+m5iTaA==} + '@fastify/busboy@3.1.1': + resolution: {integrity: sha512-5DGmA8FTdB2XbDeEwc/5ZXBl6UbBAyBOOLlPuBnZ/N1SwdH9Ii+cOX3tBROlDgcTXxjOYnLMVoKk9+FXAw0CJw==} '@firebase/ai@1.4.0': resolution: {integrity: sha512-wvF33gtU6TXb6Co8TEC1pcl4dnVstYmRE/vs9XjUGE7he7Sgf5TqSu+EoXk/fuzhw5tKr1LC5eG9KdYFM+eosw==} @@ -2625,9 +2622,6 @@ packages: '@firebase/database-types@1.0.14': resolution: {integrity: sha512-8a0Q1GrxM0akgF0RiQHliinhmZd+UQPrxEmUv7MnQBYfVFiLtKOgs3g6ghRt/WEGJHyQNslZ+0PocIwNfoDwKw==} - '@firebase/database-types@1.0.16': - resolution: {integrity: sha512-xkQLQfU5De7+SPhEGAXFBnDryUWhhlFXelEg2YeZOQMCdoe7dL64DDAd77SQsR+6uoXIZY5MB4y/inCs4GTfcw==} - '@firebase/database-types@1.0.6': resolution: {integrity: sha512-sMI7IynSZBsyGbUugc8PKE1jwKbnvaieAz/RxuM57PZQNCi6Rteiviwcw/jqZOX6igqYJwXWZ3UzKOZo2nUDRA==} @@ -2760,18 +2754,14 @@ packages: resolution: {integrity: sha512-Z4rK23xBCwgKDqmzGVMef+Vb4xso2j5Q8OG0vVL4m4fA5ZjPMYQazu8OJJC3vtQRC3SQ/Pgx/6TPNVsCd70QRw==} engines: {node: '>=18.0.0'} - '@firebase/util@1.13.0': - resolution: {integrity: sha512-0AZUyYUfpMNcztR5l09izHwXkZpghLgCUaAGjtMwXnCg3bj4ml5VgiwqOMOxJ+Nw4qN/zJAaOQBcJ7KGkWStqQ==} - engines: {node: '>=20.0.0'} - '@firebase/webchannel-wrapper@1.0.3': resolution: {integrity: sha512-2xCRM9q9FlzGZCdgDMJwc0gyUkWFtkosy7Xxr6sFgQwn+wMNIWd7xIvYNauU1r64B5L5rsGKy/n9TKJ0aAFeqQ==} - '@genkit-ai/ai@1.26.0-rc.0': - resolution: {integrity: sha512-TrNRK/fSuhM8XHOGAV6lDH9daGYfWCPyW55ZDtH3IeDAVNtfcvOhgmM+uvgtsvjKYeJiDHdwVQeabL1e9tzdYg==} + '@genkit-ai/ai@1.27.0': + resolution: 
{integrity: sha512-Vogp21a0pBgL7UsdHj1Jm79PjrQdLNRK5dZT05Xvr3f7GwCNMv0k6Olxp+qrgwLi6DbRsVPi7c+wcldekMlLFQ==} - '@genkit-ai/core@1.26.0-rc.0': - resolution: {integrity: sha512-ZnzyWLeb364csirXJusPKKV5i6ZqzsKHUc9ZKRGBSoPXkrz/w0hLGoPCFjCSbfm3DmRvC45/HOn3uEVtwkN2MA==} + '@genkit-ai/core@1.27.0': + resolution: {integrity: sha512-2dcr/yKixcxNj0U9pFpx9qNOTJcRdEjEz76qd5+o6Ac31foRBMb3J9Bvrfr+SaaPI4kiMnFUxN1X+w5yNjryQg==} '@genkit-ai/express@1.12.0': resolution: {integrity: sha512-QAxSS07dX5ovSfsUB4s90KaDnv4zg1wnoxCZCa+jBsYUyv9NvCCTsOk25xAQgGxc7xi3+MD+3AsPier5oZILIg==} @@ -2791,27 +2781,11 @@ packages: firebase: optional: true - '@genkit-ai/firebase@1.25.0': - resolution: {integrity: sha512-Z0FbnJHQs8qS0yxG++Dn3CZ7gv+YNaihGaWXoDKy02mNOkeRzHA6UPaWxSTaWkWHYdB0MyOnMGlyqxnWyqVdmg==} - peerDependencies: - '@google-cloud/firestore': ^7.11.0 - firebase: '>=11.5.0' - firebase-admin: '>=12.2' - genkit: ^1.25.0 - peerDependenciesMeta: - firebase: - optional: true - '@genkit-ai/google-cloud@1.16.1': resolution: {integrity: sha512-uujjdGr/sra7iKHApufwkt5jGo7CQcRCJNWPgnSg4g179CjtvtZBGjxmFRVBtKzuF61ktkY6E9JoLz83nWEyAA==} peerDependencies: genkit: ^1.16.1 - '@genkit-ai/google-cloud@1.25.0': - resolution: {integrity: sha512-wHCa8JSTv7MtwzXjUQ9AT5v0kCTJrz0In+ffgAYw1yt8ComAz5o7Ir+xks+sX1vJfN8ptvW0GUa6rsUaXCB3kA==} - peerDependencies: - genkit: ^1.25.0 - '@gerrit0/mini-shiki@1.27.2': resolution: {integrity: sha512-GeWyHz8ao2gBiUW4OJnQDxXQnFgZQwwQk05t/CVVgNBN7/rK8XZ7xY6YhLVv9tH3VppWWmr9DCl3MwemB/i+Og==} @@ -2839,10 +2813,6 @@ packages: resolution: {integrity: sha512-ZxOdH8Wr01hBDvKCQfMWqwUcfNcN3JY19k1LtS1fTFhEyorYPLsbWN+VxIRL46pOYGHTPkU3Or5HbT/SLQM5nA==} engines: {node: '>=14.0.0'} - '@google-cloud/firestore@7.11.6': - resolution: {integrity: sha512-EW/O8ktzwLfyWBOsNuhRoMi8lrC3clHM5LVFhGvO1HCsLozCOOXRAlHrYBoE6HL42Sc8yYMuCb2XqcnJ4OOEpw==} - engines: {node: '>=14.0.0'} - '@google-cloud/logging-winston@6.0.1': resolution: {integrity: 
sha512-tgA/qe/aGZITMrJ/5Tuykv234pLb/Qo6iDZ8SDkjbsiIy69mLQmbphrUd/IqnE17BSDfrwDUckvWdghiy8b+Qg==} engines: {node: '>=14.0.0'} @@ -2909,10 +2879,6 @@ packages: resolution: {integrity: sha512-7/5LRgykyOfQENcm6hDKP8SX/u9XxE5YOiWOkgkwcoO+cG8xT/cyOvp9wwN3IxfdYgpHs8CE7Nq2PKX2lNaEXw==} engines: {node: '>=14'} - '@google-cloud/storage@7.18.0': - resolution: {integrity: sha512-r3ZwDMiz4nwW6R922Z1pwpePxyRwE5GdevYX63hRmAQUkUQJcBH/79EnQPDv5cOv1mFBgevdNWQfi3tie3dHrQ==} - engines: {node: '>=14'} - '@google-cloud/vertexai@1.10.0': resolution: {integrity: sha512-HqYqoivNtkq59po8m7KI0n+lWKdz4kabENncYQXZCX/hBWJfXtKAfR/2nUQsP+TwSfHKoA7zDL2RrJYIv/j3VQ==} engines: {node: '>=18.0.0'} @@ -2950,8 +2916,8 @@ packages: resolution: {integrity: sha512-HPa/K5NX6ahMoeBv15njAc/sfF4/jmiXLar9UlC2UfHFKZzsCVLc3wbe7+7qua7w9VPh2/L6EBxyAV7/E8Wftg==} engines: {node: '>=12.10.0'} - '@grpc/grpc-js@1.14.2': - resolution: {integrity: sha512-QzVUtEFyu05UNx2xr0fCQmStUO17uVQhGNowtxs00IgTZT6/W2PBLfUkj30s0FKJ29VtTa3ArVNIhNP6akQhqA==} + '@grpc/grpc-js@1.14.3': + resolution: {integrity: sha512-Iq8QQQ/7X3Sac15oB6p0FmUg/klxQvXLeileoqrTRGJYLV+/9tubbr9ipz0GKHjmXVsgFPo/+W+2cA8eNcR+XA==} engines: {node: '>=12.10.0'} '@grpc/grpc-js@1.9.15': @@ -3110,8 +3076,8 @@ packages: cpu: [x64] os: [win32] - '@inquirer/external-editor@1.0.2': - resolution: {integrity: sha512-yy9cOoBnx58TlsPrIxauKIFQTiyH+0MK4e97y4sV9ERbI+zDxw7i2hxHLCIEGIE/8PPvDxGhgzIOTSOWcs6/MQ==} + '@inquirer/external-editor@1.0.3': + resolution: {integrity: sha512-RWbSrDiYmO4LbejWY7ttpxczuwQyZLBUyygsA9Nsv95hpzUWwnNTVQmAq3xuh7vNwCp07UTmE5i11XAEExx4RA==} engines: {node: '>=18'} peerDependencies: '@types/node': '>=18' @@ -3219,9 +3185,6 @@ packages: '@jridgewell/sourcemap-codec@1.5.0': resolution: {integrity: sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==} - '@jridgewell/sourcemap-codec@1.5.5': - resolution: {integrity: sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==} - 
'@jridgewell/trace-mapping@0.3.25': resolution: {integrity: sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==} @@ -3621,53 +3584,53 @@ packages: resolution: {integrity: sha512-92ybDocKl6JM48ZpYbj+A7Qt45IaTABDk0y3sDecEQfgdhfNzJtEityqNHoCZ4Vty2dldPkJhxgvOnbrQMXTTA==} engines: {node: '>= 10'} - '@next/env@15.4.10': - resolution: {integrity: sha512-knhmoJ0Vv7VRf6pZEPSnciUG1S4bIhWx+qTYBW/AjxEtlzsiNORPk8sFDCEvqLfmKuey56UB9FL1UdHEV3uBrg==} + '@next/env@15.5.9': + resolution: {integrity: sha512-4GlTZ+EJM7WaW2HEZcyU317tIQDjkQIyENDLxYJfSWlfqguN+dHkZgyQTV/7ykvobU7yEH5gKvreNrH4B6QgIg==} - '@next/swc-darwin-arm64@15.4.8': - resolution: {integrity: sha512-Pf6zXp7yyQEn7sqMxur6+kYcywx5up1J849psyET7/8pG2gQTVMjU3NzgIt8SeEP5to3If/SaWmaA6H6ysBr1A==} + '@next/swc-darwin-arm64@15.5.7': + resolution: {integrity: sha512-IZwtxCEpI91HVU/rAUOOobWSZv4P2DeTtNaCdHqLcTJU4wdNXgAySvKa/qJCgR5m6KI8UsKDXtO2B31jcaw1Yw==} engines: {node: '>= 10'} cpu: [arm64] os: [darwin] - '@next/swc-darwin-x64@15.4.8': - resolution: {integrity: sha512-xla6AOfz68a6kq3gRQccWEvFC/VRGJmA/QuSLENSO7CZX5WIEkSz7r1FdXUjtGCQ1c2M+ndUAH7opdfLK1PQbw==} + '@next/swc-darwin-x64@15.5.7': + resolution: {integrity: sha512-UP6CaDBcqaCBuiq/gfCEJw7sPEoX1aIjZHnBWN9v9qYHQdMKvCKcAVs4OX1vIjeE+tC5EIuwDTVIoXpUes29lg==} engines: {node: '>= 10'} cpu: [x64] os: [darwin] - '@next/swc-linux-arm64-gnu@15.4.8': - resolution: {integrity: sha512-y3fmp+1Px/SJD+5ntve5QLZnGLycsxsVPkTzAc3zUiXYSOlTPqT8ynfmt6tt4fSo1tAhDPmryXpYKEAcoAPDJw==} + '@next/swc-linux-arm64-gnu@15.5.7': + resolution: {integrity: sha512-NCslw3GrNIw7OgmRBxHtdWFQYhexoUCq+0oS2ccjyYLtcn1SzGzeM54jpTFonIMUjNbHmpKpziXnpxhSWLcmBA==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] - '@next/swc-linux-arm64-musl@15.4.8': - resolution: {integrity: sha512-DX/L8VHzrr1CfwaVjBQr3GWCqNNFgyWJbeQ10Lx/phzbQo3JNAxUok1DZ8JHRGcL6PgMRgj6HylnLNndxn4Z6A==} + '@next/swc-linux-arm64-musl@15.5.7': + resolution: {integrity: 
sha512-nfymt+SE5cvtTrG9u1wdoxBr9bVB7mtKTcj0ltRn6gkP/2Nu1zM5ei8rwP9qKQP0Y//umK+TtkKgNtfboBxRrw==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] - '@next/swc-linux-x64-gnu@15.4.8': - resolution: {integrity: sha512-9fLAAXKAL3xEIFdKdzG5rUSvSiZTLLTCc6JKq1z04DR4zY7DbAPcRvNm3K1inVhTiQCs19ZRAgUerHiVKMZZIA==} + '@next/swc-linux-x64-gnu@15.5.7': + resolution: {integrity: sha512-hvXcZvCaaEbCZcVzcY7E1uXN9xWZfFvkNHwbe/n4OkRhFWrs1J1QV+4U1BN06tXLdaS4DazEGXwgqnu/VMcmqw==} engines: {node: '>= 10'} cpu: [x64] os: [linux] - '@next/swc-linux-x64-musl@15.4.8': - resolution: {integrity: sha512-s45V7nfb5g7dbS7JK6XZDcapicVrMMvX2uYgOHP16QuKH/JA285oy6HcxlKqwUNaFY/UC6EvQ8QZUOo19cBKSA==} + '@next/swc-linux-x64-musl@15.5.7': + resolution: {integrity: sha512-4IUO539b8FmF0odY6/SqANJdgwn1xs1GkPO5doZugwZ3ETF6JUdckk7RGmsfSf7ws8Qb2YB5It33mvNL/0acqA==} engines: {node: '>= 10'} cpu: [x64] os: [linux] - '@next/swc-win32-arm64-msvc@15.4.8': - resolution: {integrity: sha512-KjgeQyOAq7t/HzAJcWPGA8X+4WY03uSCZ2Ekk98S9OgCFsb6lfBE3dbUzUuEQAN2THbwYgFfxX2yFTCMm8Kehw==} + '@next/swc-win32-arm64-msvc@15.5.7': + resolution: {integrity: sha512-CpJVTkYI3ZajQkC5vajM7/ApKJUOlm6uP4BknM3XKvJ7VXAvCqSjSLmM0LKdYzn6nBJVSjdclx8nYJSa3xlTgQ==} engines: {node: '>= 10'} cpu: [arm64] os: [win32] - '@next/swc-win32-x64-msvc@15.4.8': - resolution: {integrity: sha512-Exsmf/+42fWVnLMaZHzshukTBxZrSwuuLKFvqhGHJ+mC1AokqieLY/XzAl3jc/CqhXLqLY3RRjkKJ9YnLPcRWg==} + '@next/swc-win32-x64-msvc@15.5.7': + resolution: {integrity: sha512-gMzgBX164I6DN+9/PGA+9dQiwmTkE4TloBNx8Kv9UiGARsr9Nba7IpcBRA1iTV9vwlYnrE3Uy6I7Aj6qLjQuqw==} engines: {node: '>= 10'} cpu: [x64] os: [win32] @@ -4320,9 +4283,6 @@ packages: '@sinonjs/samsam@8.0.2': resolution: {integrity: sha512-v46t/fwnhejRSFTGqbpn9u+LQ9xJDse10gNnPgAcxgdoCDMXj/G2asWAC/8Qs+BAZDicX+MNZouXT1A7c83kVw==} - '@so-ric/colorspace@1.1.6': - resolution: {integrity: sha512-/KiKkpHNOBgkFJwu9sh48LkHSMYGyuTcSFK/qMBdnOAlrRJzRSXAOFB5qwzaVQuDl8wAvHVMkaASQDReTahxuw==} - '@swc/helpers@0.5.15': resolution: 
{integrity: sha512-JQ5TuMi45Owi4/BIMAJBoSQoOJu12oOk/gADqlcUL9JEdHB8vyjUSsxqeNXnmXHjYKMi2WcYtezGEEhqUI/E2g==} @@ -4333,8 +4293,8 @@ packages: '@tootallnate/quickjs-emscripten@0.23.0': resolution: {integrity: sha512-C5Mc6rdnsaJDjO3UpGW/CQTHtCKaYlScZTly4JIu97Jxo/odCiH0ITnDXSJPTOrEKk/ycSZ0AOgTmkDtkOsvIA==} - '@tsconfig/node10@1.0.12': - resolution: {integrity: sha512-UCYBaeFvM11aU2y3YPZ//O5Rhj+xKyzy7mvcIoAjASbigy8mHMryP5cK7dgjlz2hWxh1g5pLw084E0a/wlUSFQ==} + '@tsconfig/node10@1.0.11': + resolution: {integrity: sha512-DcRjDCujK/kCk/cUe8Xz8ZSpm8mS3mNNpta+jGCA6USEDfktlNvm1+IuZ9eTcDbNk41BHwpHHeW+N1lKCz4zOw==} '@tsconfig/node12@1.0.11': resolution: {integrity: sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==} @@ -4394,15 +4354,9 @@ packages: '@types/express-serve-static-core@4.17.43': resolution: {integrity: sha512-oaYtiBirUOPQGSWNGPWnzyAFJ0BP3cwvN4oWZQY+zUBwpVIGsKUkpBpSztp74drYcjavs7SKFZ4DX1V2QeN8rg==} - '@types/express-serve-static-core@4.19.7': - resolution: {integrity: sha512-FvPtiIf1LfhzsaIXhv/PHan/2FeQBbtBDtfX2QfvPxdUelMDEckK08SM6nqo1MIZY3RUlfA+HV8+hFUSio78qg==} - '@types/express@4.17.23': resolution: {integrity: sha512-Crp6WY9aTYP3qPi2wGDo9iUe/rceX01UMhnF1jmwDcKCFM6cx7YhGP/Mpr3y9AASpfHixIG0E6azCcL5OcDHsQ==} - '@types/express@4.17.25': - resolution: {integrity: sha512-dVd04UKsfpINUnK0yBoYHDF3xu7xVH4BuDotC/xGuycx4CgbP48X/KF/586bcObxT0HENHXEU8Nqtu6NR+eKhw==} - '@types/graceful-fs@4.1.9': resolution: {integrity: sha512-olP3sd1qOEe5dXTSaFvQG+02VdRXcdytWLAZsAq1PecU8uqQAhkrnbli7DagjtXKW/Bl7YJbUsa8MPcuc8LHEQ==} @@ -4416,9 +4370,6 @@ packages: '@types/http-errors@2.0.4': resolution: {integrity: sha512-D0CFMMtydbJAegzOyHjtiKPLlvnm3iTZyZRSZoLq2mRhDdmLfIWOCYPfQJ4cu2erKghU++QvjcUjp/5h7hESpA==} - '@types/http-errors@2.0.5': - resolution: {integrity: sha512-r8Tayk8HJnX0FztbZN7oVqGccWgw98T/0neJphO91KkmOzug1KkofZURD4UaD5uH8AqcFLfdPErnBod0u71/qg==} - '@types/istanbul-lib-coverage@2.0.6': resolution: {integrity: 
sha512-2QF/t/auWm0lsy8XtKVPG19v3sSOQlJe/YHZgfjb/KBBHOGSV+J2q/S671rcq9uTBrLAXmZpqJiaQbMT+zNU1w==} @@ -4473,18 +4424,9 @@ packages: '@types/node@20.19.1': resolution: {integrity: sha512-jJD50LtlD2dodAEO653i3YF04NWak6jN3ky+Ri3Em3mGR39/glWiboM/IePaRbgwSfqM1TpGXfAg8ohn/4dTgA==} - '@types/node@20.19.25': - resolution: {integrity: sha512-ZsJzA5thDQMSQO788d7IocwwQbI8B5OPzmqNvpf3NY/+MHDAS759Wo0gd2WQeXYt5AAAQjzcrTVC6SKCuYgoCQ==} - - '@types/node@20.19.26': - resolution: {integrity: sha512-0l6cjgF0XnihUpndDhk+nyD3exio3iKaYROSgvh/qSevPXax3L8p5DBRFjbvalnwatGgHEQn2R88y2fA3g4irg==} - '@types/node@22.15.32': resolution: {integrity: sha512-3jigKqgSjsH6gYZv2nEsqdXfZqIFGAV36XYYjf9KGZ3PSG+IhLecqPnI310RvjutyMwifE2hhhNEklOUrvx/wA==} - '@types/node@22.19.2': - resolution: {integrity: sha512-LPM2G3Syo1GLzXLGJAKdqoU35XvrWzGJ21/7sgZTUpbkBaOasTj8tjwn6w+hCkqaa1TfJ/w67rJSwYItlJ2mYw==} - '@types/pdf-parse@1.1.5': resolution: {integrity: sha512-kBfrSXsloMnUJOKi25s3+hRmkycHfLK6A09eRGqF/N8BkQoPUmaCr+q8Cli5FnfohEz/rsv82zAiPz/LXtOGhA==} @@ -4494,9 +4436,6 @@ packages: '@types/pg@8.6.1': resolution: {integrity: sha512-1Kc4oAGzAl7uqUStZCDvaLFqZrW9qWSjXOmBfdgyBP5La7Us6Mg4GBvRlSoaZMhQF/zSj1C8CtKMBkoiT8eL8w==} - '@types/qs@6.14.0': - resolution: {integrity: sha512-eOunJqu0K1923aExK6y8p6fsihYEn/BYuQ4g0CxAAgFc4b/ZLN4CrsRZ55srTdqoiLzU2B2evC+apEIxprEzkQ==} - '@types/qs@6.9.14': resolution: {integrity: sha512-5khscbd3SwWMhFqylJBLQ0zIu7c1K6Vz0uBIt915BI3zV0q1nfjRQD3RqSBcPaO6PHEF4ov/t9y89fSiyThlPA==} @@ -4515,15 +4454,6 @@ packages: '@types/send@0.17.4': resolution: {integrity: sha512-x2EM6TJOybec7c52BX0ZspPodMsQUd5L6PRwOunVyVUhXiBSKf3AezDL8Dgvgt5o0UfKNfuA0eMLr2wLT4AiBA==} - '@types/send@0.17.6': - resolution: {integrity: sha512-Uqt8rPBE8SY0RK8JB1EzVOIZ32uqy8HwdxCnoCOsYrvnswqmFZ/k+9Ikidlk/ImhsdvBsloHbAlewb2IEBV/Og==} - - '@types/send@1.2.1': - resolution: {integrity: sha512-arsCikDvlU99zl1g69TcAB3mzZPpxgw0UQnaHeC1Nwb015xp8bknZv5rIfri9xTOcMuaVgvabfIRA7PSZVuZIQ==} - - '@types/serve-static@1.15.10': - resolution: 
{integrity: sha512-tRs1dB+g8Itk72rlSI2ZrW6vZg0YrLI81iQSTkMmOqnqCaNr/8Ek4VwWcN5vZgCYWbg/JJSGBlUaYGAOP73qBw==} - '@types/serve-static@1.15.5': resolution: {integrity: sha512-PDRk21MnK70hja/YF8AHfC7yIsiQHn1rcXx7ijCFBX/k+XQJhQT/gw3xekXKJvx+5SXaMMS8oqQy09Mzvz2TuQ==} @@ -4661,8 +4591,8 @@ packages: resolution: {integrity: sha512-gKXj5ALrKWQLsYG9jlTRmR/xKluxHV+Z9QEwNIgCfM1/uwPMCuzVVnh5mwTd+OuBZcwSIMbqssNWRm1lE51QaQ==} engines: {node: '>=8'} - ansi-escapes@7.1.1: - resolution: {integrity: sha512-Zhl0ErHcSRUaVfGUeUdDuLgpkEo8KIFjB4Y9uAc46ScOpdDiU1Dbyplh7qWJeJ/ZHpbyMSM26+X3BySgnIz40Q==} + ansi-escapes@7.2.0: + resolution: {integrity: sha512-g6LhBsl+GBPRWGWsBtutpzBYuIIdBkLEvad5C/va/74Db018+5TZiyA26cZJAr3Rft5lprVqOIPxf5Vid6tqAw==} engines: {node: '>=18'} ansi-regex@5.0.1: @@ -4801,8 +4731,8 @@ packages: balanced-match@1.0.2: resolution: {integrity: sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==} - bare-events@2.8.1: - resolution: {integrity: sha512-oxSAxTS1hRfnyit2CL5QpAOS5ixfBjj6ex3yTNvXyY/kE719jQ/IjuESJBK2w5v4wwQRAHGseVJXx9QBYOtFGQ==} + bare-events@2.8.2: + resolution: {integrity: sha512-riJjyv1/mHLIPX4RwiK+oW9/4c3TEUeORHKefKAKnZ5kyslbN+HXowtbaVEqt4IMUB7OXlfixcs6gsFeo/jhiQ==} peerDependencies: bare-abort-controller: '*' peerDependenciesMeta: @@ -4846,10 +4776,6 @@ packages: resolution: {integrity: sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==} engines: {node: '>= 0.8', npm: 1.2.8000 || >= 1.4.16} - body-parser@1.20.4: - resolution: {integrity: sha512-ZTgYYLMOXY9qKU/57FAo8F+HA2dGX7bqGc71txDRC1rS4frdFI5R7NhluHxH6M0YItAP0sHB4uqAOcYKxO6uGA==} - engines: {node: '>= 0.8', npm: 1.2.8000 || >= 1.4.16} - body-parser@2.2.0: resolution: {integrity: sha512-02qvAaxv8tp7fBa/mw1ga98OGm+eCbqzJOKoRt70sLmfEEi+jyBYVTDGfCL/k06/4EMk/z01gCe7HoCH/f2LTg==} engines: {node: '>=18'} @@ -4961,8 +4887,8 @@ packages: resolution: {integrity: 
sha512-Gmy6FhYlCY7uOElZUSbxo2UCDH8owEk996gkbrpsgGtrJLM3J7jGxl9Ic7Qwwj4ivOE5AWZWRMecDdF7hqGjFA==} engines: {node: '>=10'} - caniuse-lite@1.0.30001760: - resolution: {integrity: sha512-7AAMPcueWELt1p3mi13HR/LHH0TJLT11cnwDJEs3xA4+CK/PLKeO9Kl1oru24htkyUKtkGCvAx4ohB0Ttry8Dw==} + caniuse-lite@1.0.30001667: + resolution: {integrity: sha512-7LTwJjcRkzKFmtqGsibMeuXmvFDfZq/nzIjnmgCGzKKRVzjD72selLDK1oPF/Oxzmt4fNcPvTDvGqSDG4tCALw==} chalk@2.4.2: resolution: {integrity: sha512-Mti+f9lpJNcwF4tWV8/OrTTtF1gZi+f8FqlyAdouralcFWFQWF2+NgCHShjkCb+IFBLq9buZwE1xckQU4peSuQ==} @@ -5108,34 +5034,18 @@ packages: resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} engines: {node: '>=7.0.0'} - color-convert@3.1.3: - resolution: {integrity: sha512-fasDH2ont2GqF5HpyO4w0+BcewlhHEZOFn9c1ckZdHpJ56Qb7MHhH/IcJZbBGgvdtwdwNbLvxiBEdg336iA9Sg==} - engines: {node: '>=14.6'} - color-name@1.1.3: resolution: {integrity: sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw==} color-name@1.1.4: resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} - color-name@2.1.0: - resolution: {integrity: sha512-1bPaDNFm0axzE4MEAzKPuqKWeRaT43U/hyxKPBdqTfmPF+d6n7FSoTFxLVULUJOmiLp01KjhIPPH+HrXZJN4Rg==} - engines: {node: '>=12.20'} - color-string@1.9.1: resolution: {integrity: sha512-shrVawQFojnZv6xM40anx4CkoDP+fZsw/ZerEMsW/pyzsRbElpsL/DBVW7q3ExxwusdNXI3lXpuhEZkzs8p5Eg==} - color-string@2.1.4: - resolution: {integrity: sha512-Bb6Cq8oq0IjDOe8wJmi4JeNn763Xs9cfrBcaylK1tPypWzyoy2G3l90v9k64kjphl/ZJjPIShFztenRomi8WTg==} - engines: {node: '>=18'} - color@3.2.1: resolution: {integrity: sha512-aBl7dZI9ENN6fUGC7mWpMTPNHmWUSNan9tuWN6ahh5ZLNk9baLJOnSMlrQkHcrfFgz2/RigjUVAjdx36VcemKA==} - color@5.0.3: - resolution: {integrity: sha512-ezmVcLR3xAVp8kYOm4GS45ZLLgIE6SPAFoduLr6hTDajwb3KZ2F46gulK3XpcwRFb5KKGCSezCBAY4Dw4HsyXA==} - engines: {node: '>=18'} - 
colorette@2.0.19: resolution: {integrity: sha512-3tlv/dIP7FWvj3BsbHrGLJ6l/oKh1O3TcgBqMn+yyCagOxc23fyzDS6HypQbgxWbkpDnf52p1LuR4eWDQ/K9WQ==} @@ -5228,9 +5138,6 @@ packages: cookie-signature@1.0.6: resolution: {integrity: sha512-QADzlaHc8icV8I7vbaJXJwod9HWYp8uCqf1xa4OfNu1T7JVxQIrUgOWtHdNDtPiywmFbiS12VjotIXLrKM3orQ==} - cookie-signature@1.0.7: - resolution: {integrity: sha512-NXdYc3dLr47pBkpUCHtKSwIOQXLVn8dZEuywboCOJY/osA0wFSLlSawr3KN8qXJEyX66FcONTH8EIlVuK0yyFA==} - cookie-signature@1.2.2: resolution: {integrity: sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg==} engines: {node: '>=6.6.0'} @@ -5364,15 +5271,6 @@ packages: supports-color: optional: true - debug@4.4.3: - resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==} - engines: {node: '>=6.0'} - peerDependencies: - supports-color: '*' - peerDependenciesMeta: - supports-color: - optional: true - decamelize@1.2.0: resolution: {integrity: sha512-z2S+W9X73hAUUki+N+9Za2lBlun89zigOyGrsax+KUQ6wKW4ZoWpEYBkGhQjwAjjDCkWxhY0VKEhk8wzY7F5cA==} engines: {node: '>=0.10.0'} @@ -5478,9 +5376,6 @@ packages: dotprompt@1.1.1: resolution: {integrity: sha512-xll31JxDiE7FaF030t0Dx4EMSV60Qn/pONDn6Hs5bBBeEANbtqIu6fPfaAOoSNbF1Y9TK+pj9Xnvud7G7GHpaA==} - dotprompt@1.1.2: - resolution: {integrity: sha512-24EU+eORQbPywBicIP44BiqykzEXFwZq1ZQKO5TEr9KrrENyDA7I1NzqhtmmEdQVfAXka0DEbSLPN5nerCqJ8A==} - dunder-proto@1.0.1: resolution: {integrity: sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==} engines: {node: '>= 0.4'} @@ -5706,10 +5601,6 @@ packages: resolution: {integrity: sha512-28HqgMZAmih1Czt9ny7qr6ek2qddF4FclbMzwhCREB6OFfH+rXAnuNCwo1/wFvrtbgsQDb4kSbX9de9lFbrXnA==} engines: {node: '>= 0.10.0'} - express@4.22.1: - resolution: {integrity: sha512-F2X8g9P1X7uCPZMA3MVf9wcTqlyNp7IhH5qPCI0izhaOIYXaW9L535tGA3qmjRzpH+bZczqq7hVKxTR4NWnu+g==} - engines: {node: '>= 0.10.0'} - express@5.1.0: 
resolution: {integrity: sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==} engines: {node: '>= 18'} @@ -5788,10 +5679,6 @@ packages: resolution: {integrity: sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==} engines: {node: '>= 0.8'} - finalhandler@1.3.2: - resolution: {integrity: sha512-aA4RyPcd3badbdABGDuTXCMTtOneUCAYH/gxoYRTZlIJdF0YPWuGqiAsIrhNnnqdXGswYk6dGujem4w80UJFhg==} - engines: {node: '>= 0.8'} - finalhandler@2.1.0: resolution: {integrity: sha512-/t88Ty3d5JWQbWYgaOGCCYfXRwV1+be02WqYYlL6h0lEiUAMPM8o8qKGO01YIkOHzka2up08wvgYD0mDiI+q3Q==} engines: {node: '>= 0.8'} @@ -5811,8 +5698,8 @@ packages: resolution: {integrity: sha512-Y8DcyKK+4pl4B93ooiy1G8qvdyRMkcNFfBSh+8rbVcw4cW8dgG0VXCCTp5NUwub8sn9vSPsOwpb9tE2OuFmcfQ==} engines: {node: '>=18'} - firebase-admin@13.5.0: - resolution: {integrity: sha512-QZOpv1DJRJpH8NcWiL1xXE10tw3L/bdPFlgjcWrqU3ufyOJDYfxB1MMtxiVTwxK16NlybQbEM6ciSich2uWEIQ==} + firebase-admin@13.6.0: + resolution: {integrity: sha512-GdPA/t0+Cq8p1JnjFRBmxRxAGvF/kl2yfdhALl38PrRp325YxyQ5aNaHui0XmaKcKiGRFIJ/EgBNWFoDP0onjw==} engines: {node: '>=18'} firebase-functions@6.3.2: @@ -5859,8 +5746,8 @@ packages: resolution: {integrity: sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==} engines: {node: '>= 6'} - form-data@4.0.4: - resolution: {integrity: sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==} + form-data@4.0.5: + resolution: {integrity: sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==} engines: {node: '>= 6'} formdata-node@4.4.1: @@ -5948,8 +5835,8 @@ packages: resolution: {integrity: sha512-zV/5HKTfCeKWnxG0Dmrw51hEWFGfcF2xiXqcA3+J90WDuP0SvoiSO5ORvcBsifmx/FoIjgQN3oNOGaQ5PhLFkg==} engines: {node: '>=18'} - genkit@1.26.0-rc.0: - resolution: {integrity: 
sha512-Yx4qtT0ImwE2Nu8ts1lrq4eL/qCa+vFmgNOWnCJLc205Vcco0yZEQ0Wr0OL3sBhIAyLuAfx6CCUPJE735ypTsg==} + genkit@1.27.0: + resolution: {integrity: sha512-54OAzw9+dlOs2H4bWnktMwKVA1wwY9XmudKBAz2uBdWhep5r0xHy1qNE6tUVnSgn+LGGaR/0xfYRSs8uqNPFVw==} genkitx-openai@0.10.1: resolution: {integrity: sha512-E9/DzyQcBUSTy81xT2pvEmdnn9Q/cKoojEt6lD/EdOeinhqE9oa59d/kuXTokCMekTrj3Rk7LtNBQIDjnyjNOA==} @@ -6023,8 +5910,8 @@ packages: engines: {node: '>=16 || 14 >=14.17'} hasBin: true - glob@10.4.5: - resolution: {integrity: sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==} + glob@10.5.0: + resolution: {integrity: sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==} hasBin: true glob@11.0.0: @@ -6188,10 +6075,6 @@ packages: resolution: {integrity: sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==} engines: {node: '>= 0.8'} - http-errors@2.0.1: - resolution: {integrity: sha512-4FbRdAX+bSdmo4AUFuS0WNiPz8NgFt+r8ThgNWmlrjQjt1Q7ZR9+zTlce2859x4KSXrwIsaeTqDoKQmtP8pLmQ==} - engines: {node: '>= 0.8'} - http-parser-js@0.5.10: resolution: {integrity: sha512-Pysuw9XpUq5dVc/2SMHpuTY01RFl8fttgcyunjL7eEMhGM3cI4eOmiCycJDVCo/7O7ClfQD3SaI6ftDzqOXYMA==} @@ -6230,8 +6113,8 @@ packages: resolution: {integrity: sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==} engines: {node: '>=0.10.0'} - iconv-lite@0.7.0: - resolution: {integrity: sha512-cf6L2Ds3h57VVmkZe+Pn+5APsT7FpqJtEhhieDCvrE2MK5Qk9MyffgQyuxQTm6BChfeZNtcOLHp9IcWRVcIcBQ==} + iconv-lite@0.7.1: + resolution: {integrity: sha512-2Tth85cXwGFHfvRgZWszZSvdo+0Xsqmw8k8ZwxScfcBneNUraK+dxRxRm24nszx80Y0TVio8kKLt5sLE7ZCLlw==} engines: {node: '>=0.10.0'} idb@7.1.1: @@ -6299,8 +6182,8 @@ packages: resolution: {integrity: sha512-Ju0Bz/cEia55xDwUWEa8+olFpCiQoypjnQySseKtmjNrnps3P+xfpUmGr90T7yjlVJmOtybRvPXhKMbHr+fWnw==} engines: {node: '>= 0.10'} - ip-address@10.0.1: - resolution: {integrity: 
sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==} + ip-address@10.1.0: + resolution: {integrity: sha512-XXADHxXmvT9+CRxhXg56LJovE+bmWnEWB78LB83VZTprKTmaC5QfruXocxzTZ2Kl0DNwKuBdlIhjL8LeY8Sf8Q==} engines: {node: '>= 12'} ip-regex@4.3.0: @@ -6793,9 +6676,6 @@ packages: jwa@2.0.0: resolution: {integrity: sha512-jrZ2Qx916EA+fq9cEAeCROWPTfCwi1IVHqT2tapuqLEVVDKFDENFw1oL+MwrTvH6msKxsd1YTDVw6uKEcsrLEA==} - jwa@2.0.1: - resolution: {integrity: sha512-hRF04fqJIP8Abbkq5NKGN0Bbr3JxlQ+qhZufXVr0DvujKy93ZCbXZMHDL4EOtodSbCWxOqR8MS1tXA5hwqCXDg==} - jwks-rsa@3.1.0: resolution: {integrity: sha512-v7nqlfezb9YfHHzYII3ef2a2j1XnGeSE/bK3WfumaYCqONAIstJbrEGapz4kadScZzEt7zYCN7bucj8C0Mv/Rg==} engines: {node: '>=14'} @@ -6810,9 +6690,6 @@ packages: jws@4.0.0: resolution: {integrity: sha512-KDncfTmOZoOMTFG4mBlG0qUIOlc03fmzH+ru6RgYVZhPkyiy/92Owlt/8UEN+a4TXR1FQetfIpJE8ApdvdVxTg==} - jws@4.0.1: - resolution: {integrity: sha512-EKI/M/yqPncGUUh44xz0PxSidXFr/+r0pA70+gIYhjv+et7yxM+s29Y+VGDkovRofQem0fs7Uvf4+YmAdyRduA==} - kind-of@3.2.2: resolution: {integrity: sha512-NOW9QQXMoZGg/oqnVNoNTTIFEIid1627WCffUBJEdMxYApq7mNE7CpzucIPc+ZQg25Phej7IJSmX3hO+oblOtQ==} engines: {node: '>=0.10.0'} @@ -7446,8 +7323,8 @@ packages: resolution: {integrity: sha512-dBpDMdxv9Irdq66304OLfEmQ9tbNRFnFTuZiLo+bD+r332bBmMJ8GBLXklIXXgxd3+v9+KUnZaUR5PJMa75Gsg==} engines: {node: '>= 0.4.0'} - next@15.4.10: - resolution: {integrity: sha512-itVlc79QjpKMFMRhP+kbGKaSG/gZM6RCvwhEbwmCNF06CdDiNaoHcbeg0PqkEa2GOcn8KJ0nnc7+yL7EjoYLHQ==} + next@15.5.9: + resolution: {integrity: sha512-agNLK89seZEtC5zUHwtut0+tNrc0Xw4FT/Dg+B/VLEo9pAcS9rtTKpek3V6kVcVwsB2YlqMaHdfZL4eLEVYuCg==} engines: {node: ^18.18.0 || ^19.8.0 || >= 20.0.0} hasBin: true peerDependencies: @@ -7496,10 +7373,6 @@ packages: resolution: {integrity: sha512-dPEtOeMvF9VMcYV/1Wb8CPoVAXtp6MKMlcbAt4ddqmGqUJ6fQZFXkNZNkNlfevtNkGtaSoXf/vNNNSvgrdXwtA==} engines: {node: '>= 6.13.0'} - node-forge@1.3.3: - resolution: {integrity: 
sha512-rLvcdSyRCyouf6jcOIPe/BgwG/d7hKjzMKOas33/pHEr6gbq18IK9zV7DiPvzsz0oBJPme6qr6H6kGZuI9/DZg==} - engines: {node: '>= 6.13.0'} - node-gyp@11.5.0: resolution: {integrity: sha512-ra7Kvlhxn5V9Slyus0ygMa2h+UqExPqUIkfk7Pc8QTLT956JLSy51uWFwHtIYy0vI8cB4BDhc/S03+880My/LQ==} engines: {node: ^18.17.0 || >=20.5.0} @@ -7645,8 +7518,8 @@ packages: resolution: {integrity: sha512-R79ZZ/0wAxKGu3oYMlz8jy/kbhsNrS7SKZ7PxEHBgJ5+F2mtFW2fK2cOtBh1cHYkQsbzFV7I+EoRKe6Yt0oK7A==} engines: {node: '>=8'} - p-map@7.0.3: - resolution: {integrity: sha512-VkndIv2fIB99swvQoA65bm+fsmt6UNdGeIB0oxBs+WhAhdh08QA04JXpI7rbB9r08/nkbysKoya9rtDERYOYMA==} + p-map@7.0.4: + resolution: {integrity: sha512-tkAQEw8ysMzmkhgw8k+1U/iPhWNhykKnSk4Rd5zLoPJCuJaGRPo6YposrZgaxHKzDHdDWWZvE/Sk7hsL2X/CpQ==} engines: {node: '>=18'} p-queue@6.6.2: @@ -7952,8 +7825,8 @@ packages: resolution: {integrity: sha512-RXyHaACeqXeqAKGLDl68rQKbmObRsTIn4TYVUUug1KfS47YWCo5MacGITEryugIgZqORCvJWEk4l449POg5Txg==} engines: {node: '>=12.0.0'} - protobufjs@7.5.4: - resolution: {integrity: sha512-CvexbZtbov6jW2eXAvLukXjXUW1TzFaivC46BpWc/3BpcCysb5Vffu+B3XHMm8lVEuy2Mm4XGex8hBSg1yapPg==} + protobufjs@7.5.3: + resolution: {integrity: sha512-sildjKwVqOI2kmFDiXQ6aEB0fjYTafpEvIBs8tOR8qI4spuL9OPROLVu2qZqi/xgCfsHIwVqlaF8JBjWFHnKbw==} engines: {node: '>=12.0.0'} proxy-addr@2.0.7: @@ -8017,10 +7890,6 @@ packages: resolution: {integrity: sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==} engines: {node: '>= 0.8'} - raw-body@2.5.3: - resolution: {integrity: sha512-s4VSOf6yN0rvbRZGxs8Om5CWj6seneMwK3oDb4lWDH0UPhWcxwOWw5+qk24bxq87szX1ydrwylIOp2uG1ojUpA==} - engines: {node: '>= 0.8'} - raw-body@3.0.0: resolution: {integrity: sha512-RmkhL8CAyCRPXCE28MMH0z2PNWQBNk2Q09ZdxM9IOOXwxwZbN+qbWaatPkdkWIKL2ZVDImrN/pK5HTRz2PcS4g==} engines: {node: '>= 0.8'} @@ -8029,8 +7898,8 @@ packages: resolution: {integrity: sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==} hasBin: true - 
re2@1.22.1: - resolution: {integrity: sha512-E4J0EtgyNLdIr0wTg0dQPefuiqNY29KaLacytiUAYYRzxCG+zOkWoUygt1rI+TA1LrhN49/njrfSO1DHtVC5Vw==} + re2@1.22.3: + resolution: {integrity: sha512-002aE82U91DiaUA16U6vbiJusvPXn1OWiQukOxJkVUTXbzrSuQbFNHYKcGw8QK/uifRCfjl2Hd/vXYDanKkmaQ==} react-dom@18.3.1: resolution: {integrity: sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==} @@ -8203,10 +8072,6 @@ packages: resolution: {integrity: sha512-e2bDA2WJT0wxseVd4lsDP4+3ONX6HpMXQa1ZhFQ7SU+GjvORCmShbCMltrtIDfkYhVHrOcPtj+KhmDBdPdZD1g==} engines: {node: '>=10'} - safe-stable-stringify@2.5.0: - resolution: {integrity: sha512-b3rppTKm9T+PsVCBEOUR46GWI7fdOs00VKZ1+9c1EWDaDMvjQc6tUwuFyIprgGgTcWoVHSKrU8H31ZHA2e0RHA==} - engines: {node: '>=10'} - safer-buffer@2.1.2: resolution: {integrity: sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==} @@ -8244,10 +8109,6 @@ packages: resolution: {integrity: sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==} engines: {node: '>= 0.8.0'} - send@0.19.1: - resolution: {integrity: sha512-p4rRk4f23ynFEfcD9LA0xRYngj+IyGiEYyqqOak8kaN0TvNmuxC2dcVeBn62GpCeR2CpWqyHCNScTP91QbAVFg==} - engines: {node: '>= 0.8.0'} - send@1.2.0: resolution: {integrity: sha512-uaW0WwXKpL9blXE2o0bRhoL2EGXIrZxQ2ZQ4mgcfoBxdFmQold+qWsD2jLrfZ0trjKL6vOw0j//eAwcALFjKSw==} engines: {node: '>= 18'} @@ -8388,8 +8249,8 @@ packages: sprintf-js@1.0.3: resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==} - sql-formatter@15.6.10: - resolution: {integrity: sha512-0bJOPQrRO/JkjQhiThVayq0hOKnI1tHI+2OTkmT7TGtc6kqS+V7kveeMzRW+RNQGxofmTmet9ILvztyuxv0cJQ==} + sql-formatter@15.6.12: + resolution: {integrity: sha512-mkpF+RG402P66VMsnQkWewTRzDBWfu9iLbOfxaW/nAKOS/2A9MheQmcU5cmX0D0At9azrorZwpvcBRNNBozACQ==} hasBin: true ssri@12.0.0: @@ -8411,10 +8272,6 @@ packages: resolution: {integrity: 
sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==} engines: {node: '>= 0.8'} - statuses@2.0.2: - resolution: {integrity: sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==} - engines: {node: '>= 0.8'} - stop-iteration-iterator@1.1.0: resolution: {integrity: sha512-eLoXW/DHyl62zxY4SCaIgnRhuMr6ri4juEYARS8E6sCEqzKpOiE521Ucofdx+KnDZl5xmvGYaaKCk5FEOxJCoQ==} engines: {node: '>= 0.4'} @@ -9049,10 +8906,6 @@ packages: resolution: {integrity: sha512-DLiFIXYC5fMPxaRg832S6F5mJYvePtmO5G9v9IgUFPhXm9/GkXarH/TUrBAVzhTCzAj9anE/+GjrgXp/54nOgw==} engines: {node: '>= 12.0.0'} - winston@3.19.0: - resolution: {integrity: sha512-LZNJgPzfKR+/J3cHkxcpHKpKKvGfDZVPS4hfJCc4cCG0CgYzvlD6yE/S3CIL/Yt91ak327YCpiF/0MyeZHEHKA==} - engines: {node: '>= 12.0.0'} - wordwrap@1.0.0: resolution: {integrity: sha512-gvVzJFlPycKc5dZN4yPkP8w7Dc37BtP1yczEneOb4uq34pXZcvrtRTmWV8W+Ume+XCxKgbjM+nevkyFPMybd4Q==} @@ -9140,11 +8993,6 @@ packages: engines: {node: '>= 14.6'} hasBin: true - yaml@2.8.2: - resolution: {integrity: sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==} - engines: {node: '>= 14.6'} - hasBin: true - yargs-parser@20.2.9: resolution: {integrity: sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w==} engines: {node: '>=10'} @@ -9178,20 +9026,12 @@ packages: peerDependencies: zod: ^3.24.1 - zod-to-json-schema@3.25.0: - resolution: {integrity: sha512-HvWtU2UG41LALjajJrML6uQejQhNJx+JBO9IflpSja4R03iNWfKXrj6W2h7ljuLyc1nKS+9yDyL/9tD1U/yBnQ==} - peerDependencies: - zod: ^3.25 || ^4 - zod@3.22.4: resolution: {integrity: sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==} zod@3.25.67: resolution: {integrity: sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==} - zod@3.25.76: - resolution: {integrity: 
sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==} - snapshots: '@ampproject/remapping@2.3.0': @@ -9212,11 +9052,11 @@ snapshots: transitivePeerDependencies: - encoding - '@anthropic-ai/sdk@0.68.0(zod@3.25.76)': + '@anthropic-ai/sdk@0.71.2(zod@3.25.67)': dependencies: json-schema-to-ts: 3.1.1 optionalDependencies: - zod: 3.25.76 + zod: 3.25.67 '@anthropic-ai/sdk@0.9.1(encoding@0.1.13)': dependencies: @@ -9466,13 +9306,6 @@ snapshots: enabled: 2.0.0 kuler: 2.0.0 - '@dabh/diagnostics@2.0.8': - dependencies: - '@so-ric/colorspace': 1.1.6 - enabled: 2.0.0 - kuler: 2.0.0 - optional: true - '@electric-sql/pglite@0.2.17': {} '@emnapi/runtime@1.7.1': @@ -9559,8 +9392,7 @@ snapshots: '@fastify/busboy@3.0.0': {} - '@fastify/busboy@3.2.0': - optional: true + '@fastify/busboy@3.1.1': {} '@firebase/ai@1.4.0(@firebase/app-types@0.9.3)(@firebase/app@0.13.1)': dependencies: @@ -9717,12 +9549,6 @@ snapshots: '@firebase/app-types': 0.9.3 '@firebase/util': 1.12.0 - '@firebase/database-types@1.0.16': - dependencies: - '@firebase/app-types': 0.9.3 - '@firebase/util': 1.13.0 - optional: true - '@firebase/database-types@1.0.6': dependencies: '@firebase/app-types': 0.9.2 @@ -9936,20 +9762,15 @@ snapshots: dependencies: tslib: 2.8.1 - '@firebase/util@1.13.0': - dependencies: - tslib: 2.8.1 - optional: true - '@firebase/webchannel-wrapper@1.0.3': {} - '@genkit-ai/ai@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': + '@genkit-ai/ai@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': dependencies: - 
'@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) + '@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) '@opentelemetry/api': 1.9.0 - '@types/node': 20.19.26 + '@types/node': 20.19.1 colorette: 2.0.20 - dotprompt: 1.1.2 + dotprompt: 1.1.1 json5: 2.2.3 node-fetch: 3.3.2 partial-json: 0.1.7 @@ -9964,13 +9785,13 @@ snapshots: - supports-color optional: true - '@genkit-ai/ai@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': + '@genkit-ai/ai@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': dependencies: - '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + '@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) '@opentelemetry/api': 1.9.0 - '@types/node': 20.19.26 + '@types/node': 20.19.1 colorette: 2.0.20 - dotprompt: 1.1.2 + dotprompt: 1.1.1 json5: 2.2.3 node-fetch: 3.3.2 partial-json: 0.1.7 @@ -9984,7 +9805,7 @@ snapshots: - genkit - supports-color - 
'@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': + '@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/context-async-hooks': 1.25.1(@opentelemetry/api@1.9.0) @@ -9997,16 +9818,16 @@ snapshots: ajv: 8.17.1 ajv-formats: 3.0.1(ajv@8.17.1) async-mutex: 0.5.0 - body-parser: 1.20.4 + body-parser: 1.20.3 cors: 2.8.5 - dotprompt: 1.1.2 - express: 4.22.1 + dotprompt: 1.1.1 + express: 4.21.2 get-port: 5.1.1 json-schema: 0.4.0 - zod: 3.25.76 - zod-to-json-schema: 3.25.0(zod@3.25.76) + zod: 3.25.67 + zod-to-json-schema: 3.24.5(zod@3.25.67) optionalDependencies: - '@genkit-ai/firebase': 1.25.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) + '@genkit-ai/firebase': 1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) transitivePeerDependencies: - '@google-cloud/firestore' - encoding @@ -10016,7 +9837,7 @@ snapshots: - supports-color optional: true - '@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': + 
'@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/context-async-hooks': 1.25.1(@opentelemetry/api@1.9.0) @@ -10029,16 +9850,16 @@ snapshots: ajv: 8.17.1 ajv-formats: 3.0.1(ajv@8.17.1) async-mutex: 0.5.0 - body-parser: 1.20.4 + body-parser: 1.20.3 cors: 2.8.5 - dotprompt: 1.1.2 - express: 4.22.1 + dotprompt: 1.1.1 + express: 4.21.2 get-port: 5.1.1 json-schema: 0.4.0 - zod: 3.25.76 - zod-to-json-schema: 3.25.0(zod@3.25.76) + zod: 3.25.67 + zod-to-json-schema: 3.24.5(zod@3.25.67) optionalDependencies: - '@genkit-ai/firebase': 1.25.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + '@genkit-ai/firebase': 1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) transitivePeerDependencies: - '@google-cloud/firestore' - encoding @@ -10047,9 +9868,9 @@ snapshots: - genkit - supports-color - '@genkit-ai/express@1.12.0(@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(express@5.1.0)(genkit@genkit)': + '@genkit-ai/express@1.12.0(@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(express@5.1.0)(genkit@genkit)': dependencies: - '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + '@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) body-parser: 1.20.3 cors: 2.8.5 express: 5.1.0 @@ 
-10057,25 +9878,12 @@ snapshots: transitivePeerDependencies: - supports-color - '@genkit-ai/firebase@1.16.1(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': - dependencies: - '@genkit-ai/google-cloud': 1.16.1(encoding@0.1.13)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) - '@google-cloud/firestore': 7.11.6(encoding@0.1.13) - firebase-admin: 13.5.0(encoding@0.1.13) - genkit: 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1) - optionalDependencies: - firebase: 11.9.1 - transitivePeerDependencies: - - encoding - - supports-color - optional: true - - '@genkit-ai/firebase@1.25.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': + '@genkit-ai/firebase@1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': dependencies: - '@genkit-ai/google-cloud': 1.25.0(encoding@0.1.13)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) - '@google-cloud/firestore': 7.11.6(encoding@0.1.13) - firebase-admin: 13.5.0(encoding@0.1.13) - genkit: 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1) + 
'@genkit-ai/google-cloud': 1.16.1(encoding@0.1.13)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) + '@google-cloud/firestore': 7.11.1(encoding@0.1.13) + firebase-admin: 13.6.0(encoding@0.1.13) + genkit: 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1) optionalDependencies: firebase: 11.9.1 transitivePeerDependencies: @@ -10083,11 +9891,11 @@ snapshots: - supports-color optional: true - '@genkit-ai/firebase@1.25.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': + '@genkit-ai/firebase@1.16.1(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit)': dependencies: - '@genkit-ai/google-cloud': 1.25.0(encoding@0.1.13)(genkit@genkit) - '@google-cloud/firestore': 7.11.6(encoding@0.1.13) - firebase-admin: 13.5.0(encoding@0.1.13) + '@genkit-ai/google-cloud': 1.16.1(encoding@0.1.13)(genkit@genkit) + '@google-cloud/firestore': 7.11.1(encoding@0.1.13) + firebase-admin: 13.6.0(encoding@0.1.13) genkit: link:genkit optionalDependencies: firebase: 11.9.1 @@ -10096,7 +9904,7 @@ snapshots: - supports-color optional: true - '@genkit-ai/google-cloud@1.16.1(encoding@0.1.13)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': + '@genkit-ai/google-cloud@1.16.1(encoding@0.1.13)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1))': dependencies: '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.17.0) '@google-cloud/opentelemetry-cloud-monitoring-exporter': 
0.19.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-metrics@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) @@ -10112,7 +9920,7 @@ snapshots: '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/sdk-node': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/sdk-trace-base': 1.25.1(@opentelemetry/api@1.9.0) - genkit: 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1) + genkit: 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1) google-auth-library: 9.15.1(encoding@0.1.13) node-fetch: 3.3.2 winston: 3.17.0 @@ -10121,34 +9929,9 @@ snapshots: - supports-color optional: true - '@genkit-ai/google-cloud@1.25.0(encoding@0.1.13)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1))': + '@genkit-ai/google-cloud@1.16.1(encoding@0.1.13)(genkit@genkit)': dependencies: - '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.19.0) - '@google-cloud/opentelemetry-cloud-monitoring-exporter': 0.19.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-metrics@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) - '@google-cloud/opentelemetry-cloud-trace-exporter': 2.4.1(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) - '@google-cloud/opentelemetry-resource-util': 2.4.0(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) - '@opentelemetry/api': 1.9.0 - '@opentelemetry/auto-instrumentations-node': 
0.49.2(@opentelemetry/api@1.9.0)(encoding@0.1.13) - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/instrumentation-pino': 0.41.0(@opentelemetry/api@1.9.0) - '@opentelemetry/instrumentation-winston': 0.39.0(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-node': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-trace-base': 1.25.1(@opentelemetry/api@1.9.0) - genkit: 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1) - google-auth-library: 9.15.1(encoding@0.1.13) - node-fetch: 3.3.2 - winston: 3.19.0 - transitivePeerDependencies: - - encoding - - supports-color - optional: true - - '@genkit-ai/google-cloud@1.25.0(encoding@0.1.13)(genkit@genkit)': - dependencies: - '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.19.0) + '@google-cloud/logging-winston': 6.0.1(encoding@0.1.13)(winston@3.17.0) '@google-cloud/opentelemetry-cloud-monitoring-exporter': 0.19.0(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-metrics@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) '@google-cloud/opentelemetry-cloud-trace-exporter': 2.4.1(@opentelemetry/api@1.9.0)(@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) '@google-cloud/opentelemetry-resource-util': 2.4.0(@opentelemetry/resources@1.25.1(@opentelemetry/api@1.9.0))(encoding@0.1.13) @@ -10165,7 +9948,7 @@ snapshots: genkit: link:genkit google-auth-library: 9.15.1(encoding@0.1.13) node-fetch: 3.3.2 - winston: 3.19.0 + winston: 3.17.0 transitivePeerDependencies: - 
encoding - supports-color @@ -10250,18 +10033,6 @@ snapshots: - encoding - supports-color - '@google-cloud/firestore@7.11.6(encoding@0.1.13)': - dependencies: - '@opentelemetry/api': 1.9.0 - fast-deep-equal: 3.1.3 - functional-red-black-tree: 1.0.1 - google-gax: 4.6.1(encoding@0.1.13) - protobufjs: 7.5.4 - transitivePeerDependencies: - - encoding - - supports-color - optional: true - '@google-cloud/logging-winston@6.0.1(encoding@0.1.13)(winston@3.17.0)': dependencies: '@google-cloud/logging': 11.0.0(encoding@0.1.13) @@ -10273,18 +10044,6 @@ snapshots: - encoding - supports-color - '@google-cloud/logging-winston@6.0.1(encoding@0.1.13)(winston@3.19.0)': - dependencies: - '@google-cloud/logging': 11.0.0(encoding@0.1.13) - google-auth-library: 9.15.1(encoding@0.1.13) - lodash.mapvalues: 4.6.0 - winston: 3.19.0 - winston-transport: 4.7.0 - transitivePeerDependencies: - - encoding - - supports-color - optional: true - '@google-cloud/logging@11.0.0(encoding@0.1.13)': dependencies: '@google-cloud/common': 5.0.1(encoding@0.1.13) @@ -10407,28 +10166,6 @@ snapshots: - supports-color optional: true - '@google-cloud/storage@7.18.0(encoding@0.1.13)': - dependencies: - '@google-cloud/paginator': 5.0.2 - '@google-cloud/projectify': 4.0.0 - '@google-cloud/promisify': 4.0.0 - abort-controller: 3.0.0 - async-retry: 1.3.3 - duplexify: 4.1.3 - fast-xml-parser: 4.5.3 - gaxios: 6.7.1(encoding@0.1.13) - google-auth-library: 9.15.1(encoding@0.1.13) - html-entities: 2.6.0 - mime: 3.0.0 - p-limit: 3.1.0 - retry-request: 7.0.2(encoding@0.1.13) - teeny-request: 9.0.0(encoding@0.1.13) - uuid: 8.3.2 - transitivePeerDependencies: - - encoding - - supports-color - optional: true - '@google-cloud/vertexai@1.10.0(encoding@0.1.13)': dependencies: google-auth-library: 9.15.1(encoding@0.1.13) @@ -10470,7 +10207,7 @@ snapshots: '@grpc/proto-loader': 0.7.13 '@js-sdsl/ordered-map': 4.4.2 - '@grpc/grpc-js@1.14.2': + '@grpc/grpc-js@1.14.3': dependencies: '@grpc/proto-loader': 0.8.0 '@js-sdsl/ordered-map': 
4.4.2 @@ -10491,14 +10228,14 @@ snapshots: dependencies: lodash.camelcase: 4.3.0 long: 5.3.2 - protobufjs: 7.5.4 + protobufjs: 7.5.3 yargs: 17.7.2 '@grpc/proto-loader@0.8.0': dependencies: lodash.camelcase: 4.3.0 long: 5.3.2 - protobufjs: 7.5.4 + protobufjs: 7.5.3 yargs: 17.7.2 '@img/colour@1.0.0': @@ -10598,10 +10335,10 @@ snapshots: '@img/sharp-win32-x64@0.34.5': optional: true - '@inquirer/external-editor@1.0.2(@types/node@20.19.1)': + '@inquirer/external-editor@1.0.3(@types/node@20.19.1)': dependencies: chardet: 2.1.1 - iconv-lite: 0.7.0 + iconv-lite: 0.7.1 optionalDependencies: '@types/node': 20.19.1 @@ -10840,9 +10577,6 @@ snapshots: '@jridgewell/sourcemap-codec@1.5.0': {} - '@jridgewell/sourcemap-codec@1.5.5': - optional: true - '@jridgewell/trace-mapping@0.3.25': dependencies: '@jridgewell/resolve-uri': 3.1.2 @@ -10851,7 +10585,7 @@ snapshots: '@jridgewell/trace-mapping@0.3.9': dependencies: '@jridgewell/resolve-uri': 3.1.2 - '@jridgewell/sourcemap-codec': 1.5.5 + '@jridgewell/sourcemap-codec': 1.5.0 optional: true '@js-sdsl/ordered-map@4.4.2': {} @@ -10907,9 +10641,9 @@ snapshots: dependencies: '@langchain/core': 0.1.63 js-tiktoken: 1.0.11 - openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76) - zod: 3.25.76 - zod-to-json-schema: 3.24.5(zod@3.25.76) + openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) + zod: 3.25.67 + zod-to-json-schema: 3.24.5(zod@3.25.67) transitivePeerDependencies: - encoding - ws @@ -10919,10 +10653,10 @@ snapshots: '@langchain/core': 0.1.63 js-tiktoken: 1.0.11 - '@mistralai/mistralai-gcp@1.5.0(encoding@0.1.13)(zod@3.25.76)': + '@mistralai/mistralai-gcp@1.5.0(encoding@0.1.13)(zod@3.25.67)': dependencies: google-auth-library: 9.15.1(encoding@0.1.13) - zod: 3.25.76 + zod: 3.25.67 transitivePeerDependencies: - encoding - supports-color @@ -10961,13 +10695,13 @@ snapshots: transitivePeerDependencies: - supports-color - '@modelcontextprotocol/server-filesystem@2025.7.1(zod@3.25.76)': + 
'@modelcontextprotocol/server-filesystem@2025.7.1(zod@3.25.67)': dependencies: '@modelcontextprotocol/sdk': 1.15.0 diff: 5.2.0 glob: 10.3.12 minimatch: 10.0.1 - zod-to-json-schema: 3.24.5(zod@3.25.76) + zod-to-json-schema: 3.24.5(zod@3.25.67) transitivePeerDependencies: - supports-color - zod @@ -11016,30 +10750,30 @@ snapshots: '@napi-rs/canvas-win32-x64-msvc': 0.1.71 optional: true - '@next/env@15.4.10': {} + '@next/env@15.5.9': {} - '@next/swc-darwin-arm64@15.4.8': + '@next/swc-darwin-arm64@15.5.7': optional: true - '@next/swc-darwin-x64@15.4.8': + '@next/swc-darwin-x64@15.5.7': optional: true - '@next/swc-linux-arm64-gnu@15.4.8': + '@next/swc-linux-arm64-gnu@15.5.7': optional: true - '@next/swc-linux-arm64-musl@15.4.8': + '@next/swc-linux-arm64-musl@15.5.7': optional: true - '@next/swc-linux-x64-gnu@15.4.8': + '@next/swc-linux-x64-gnu@15.5.7': optional: true - '@next/swc-linux-x64-musl@15.4.8': + '@next/swc-linux-x64-musl@15.5.7': optional: true - '@next/swc-win32-arm64-msvc@15.4.8': + '@next/swc-win32-arm64-msvc@15.5.7': optional: true - '@next/swc-win32-x64-msvc@15.4.8': + '@next/swc-win32-x64-msvc@15.5.7': optional: true '@npmcli/agent@3.0.0': @@ -11055,7 +10789,7 @@ snapshots: '@npmcli/fs@4.0.0': dependencies: - semver: 7.7.3 + semver: 7.7.2 optional: true '@opentelemetry/api-logs@0.52.1': @@ -11842,12 +11576,6 @@ snapshots: lodash.get: 4.4.2 type-detect: 4.1.0 - '@so-ric/colorspace@1.1.6': - dependencies: - color: 5.0.3 - text-hex: 1.0.0 - optional: true - '@swc/helpers@0.5.15': dependencies: tslib: 2.8.1 @@ -11856,7 +11584,7 @@ snapshots: '@tootallnate/quickjs-emscripten@0.23.0': {} - '@tsconfig/node10@1.0.12': + '@tsconfig/node10@1.0.11': optional: true '@tsconfig/node12@1.0.11': @@ -11937,14 +11665,6 @@ snapshots: '@types/range-parser': 1.2.7 '@types/send': 0.17.4 - '@types/express-serve-static-core@4.19.7': - dependencies: - '@types/node': 20.19.26 - '@types/qs': 6.14.0 - '@types/range-parser': 1.2.7 - '@types/send': 1.2.1 - optional: true - 
'@types/express@4.17.23': dependencies: '@types/body-parser': 1.19.5 @@ -11952,14 +11672,6 @@ snapshots: '@types/qs': 6.9.14 '@types/serve-static': 1.15.5 - '@types/express@4.17.25': - dependencies: - '@types/body-parser': 1.19.6 - '@types/express-serve-static-core': 4.19.7 - '@types/qs': 6.14.0 - '@types/serve-static': 1.15.10 - optional: true - '@types/graceful-fs@4.1.9': dependencies: '@types/node': 20.19.1 @@ -11974,9 +11686,6 @@ snapshots: '@types/http-errors@2.0.4': {} - '@types/http-errors@2.0.5': - optional: true - '@types/istanbul-lib-coverage@2.0.6': {} '@types/istanbul-lib-report@3.0.3': @@ -11999,8 +11708,7 @@ snapshots: '@types/jsonwebtoken@9.0.10': dependencies: '@types/ms': 2.1.0 - '@types/node': 20.19.26 - optional: true + '@types/node': 20.19.1 '@types/jsonwebtoken@9.0.6': dependencies: @@ -12016,8 +11724,7 @@ snapshots: '@types/mime@3.0.4': {} - '@types/ms@2.1.0': - optional: true + '@types/ms@2.1.0': {} '@types/mysql@2.15.22': dependencies: @@ -12041,23 +11748,10 @@ snapshots: dependencies: undici-types: 6.21.0 - '@types/node@20.19.25': - dependencies: - undici-types: 6.21.0 - - '@types/node@20.19.26': - dependencies: - undici-types: 6.21.0 - '@types/node@22.15.32': dependencies: undici-types: 6.21.0 - '@types/node@22.19.2': - dependencies: - undici-types: 6.21.0 - optional: true - '@types/pdf-parse@1.1.5': dependencies: '@types/node': 20.19.1 @@ -12072,9 +11766,6 @@ snapshots: pg-protocol: 1.6.0 pg-types: 2.2.0 - '@types/qs@6.14.0': - optional: true - '@types/qs@6.9.14': {} '@types/range-parser@1.2.7': {} @@ -12097,24 +11788,6 @@ snapshots: '@types/mime': 1.3.5 '@types/node': 20.19.1 - '@types/send@0.17.6': - dependencies: - '@types/mime': 1.3.5 - '@types/node': 20.19.26 - optional: true - - '@types/send@1.2.1': - dependencies: - '@types/node': 20.19.26 - optional: true - - '@types/serve-static@1.15.10': - dependencies: - '@types/http-errors': 2.0.5 - '@types/node': 20.19.26 - '@types/send': 0.17.6 - optional: true - 
'@types/serve-static@1.15.5': dependencies: '@types/http-errors': 2.0.4 @@ -12245,7 +11918,7 @@ snapshots: dependencies: type-fest: 0.21.3 - ansi-escapes@7.1.1: + ansi-escapes@7.2.0: dependencies: environment: 1.1.0 @@ -12276,7 +11949,7 @@ snapshots: archiver-utils@5.0.2: dependencies: - glob: 10.4.5 + glob: 10.5.0 graceful-fs: 4.2.11 is-stream: 2.0.1 lazystream: 1.0.1 @@ -12413,7 +12086,7 @@ snapshots: balanced-match@1.0.2: {} - bare-events@2.8.1: {} + bare-events@2.8.2: {} base-64@0.1.0: {} @@ -12460,23 +12133,6 @@ snapshots: transitivePeerDependencies: - supports-color - body-parser@1.20.4: - dependencies: - bytes: 3.1.2 - content-type: 1.0.5 - debug: 2.6.9 - depd: 2.0.0 - destroy: 1.2.0 - http-errors: 2.0.1 - iconv-lite: 0.4.24 - on-finished: 2.4.1 - qs: 6.14.0 - raw-body: 2.5.3 - type-is: 1.6.18 - unpipe: 1.0.0 - transitivePeerDependencies: - - supports-color - body-parser@2.2.0: dependencies: bytes: 3.1.2 @@ -12521,7 +12177,7 @@ snapshots: browserslist@4.24.0: dependencies: - caniuse-lite: 1.0.30001760 + caniuse-lite: 1.0.30001667 electron-to-chromium: 1.5.33 node-releases: 2.0.18 update-browserslist-db: 1.1.1(browserslist@4.24.0) @@ -12579,13 +12235,13 @@ snapshots: dependencies: '@npmcli/fs': 4.0.0 fs-minipass: 3.0.3 - glob: 10.4.5 + glob: 10.5.0 lru-cache: 10.2.0 minipass: 7.1.2 minipass-collect: 2.0.1 minipass-flush: 1.0.5 minipass-pipeline: 1.2.4 - p-map: 7.0.3 + p-map: 7.0.4 ssri: 12.0.0 tar: 7.5.2 unique-filename: 4.0.0 @@ -12601,7 +12257,7 @@ snapshots: es-define-property: 1.0.0 es-errors: 1.3.0 function-bind: 1.1.2 - get-intrinsic: 1.2.4 + get-intrinsic: 1.3.0 set-function-length: 1.2.2 call-bind@1.0.8: @@ -12624,7 +12280,7 @@ snapshots: camelcase@6.3.0: {} - caniuse-lite@1.0.30001760: {} + caniuse-lite@1.0.30001667: {} chalk@2.4.2: dependencies: @@ -12680,12 +12336,12 @@ snapshots: chownr@3.0.0: optional: true - chromadb@1.8.1(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76)): + 
chromadb@1.8.1(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)): dependencies: cliui: 8.0.1 isomorphic-fetch: 3.0.0(encoding@0.1.13) optionalDependencies: - openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76) + openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) transitivePeerDependencies: - encoding @@ -12762,39 +12418,20 @@ snapshots: dependencies: color-name: 1.1.4 - color-convert@3.1.3: - dependencies: - color-name: 2.1.0 - optional: true - color-name@1.1.3: {} color-name@1.1.4: {} - color-name@2.1.0: - optional: true - color-string@1.9.1: dependencies: color-name: 1.1.4 simple-swizzle: 0.2.2 - color-string@2.1.4: - dependencies: - color-name: 2.1.0 - optional: true - color@3.2.1: dependencies: color-convert: 1.9.3 color-string: 1.9.1 - color@5.0.3: - dependencies: - color-convert: 3.1.3 - color-string: 2.1.4 - optional: true - colorette@2.0.19: {} colorette@2.0.20: {} @@ -12902,8 +12539,6 @@ snapshots: cookie-signature@1.0.6: {} - cookie-signature@1.0.7: {} - cookie-signature@1.2.2: {} cookie@0.7.1: {} @@ -13033,11 +12668,6 @@ snapshots: dependencies: ms: 2.1.3 - debug@4.4.3: - dependencies: - ms: 2.1.3 - optional: true - decamelize@1.2.0: {} dedent@1.5.3: {} @@ -13122,11 +12752,6 @@ snapshots: handlebars: 4.7.8 yaml: 2.7.0 - dotprompt@1.1.2: - dependencies: - handlebars: 4.7.8 - yaml: 2.8.2 - dunder-proto@1.0.1: dependencies: call-bind-apply-helpers: 1.0.2 @@ -13256,7 +12881,7 @@ snapshots: es-define-property@1.0.0: dependencies: - get-intrinsic: 1.2.4 + get-intrinsic: 1.3.0 es-define-property@1.0.1: {} @@ -13349,7 +12974,7 @@ snapshots: events-universal@1.0.1: dependencies: - bare-events: 2.8.1 + bare-events: 2.8.2 transitivePeerDependencies: - bare-abort-controller @@ -13388,7 +13013,7 @@ snapshots: content-type: 1.0.5 deep-freeze: 0.0.1 events-listener: 1.1.0 - glob: 10.4.5 + glob: 10.5.0 json-ptr: 3.1.1 json-schema-traverse: 1.0.0 lodash: 4.17.21 @@ -13455,42 +13080,6 @@ snapshots: transitivePeerDependencies: - 
supports-color - express@4.22.1: - dependencies: - accepts: 1.3.8 - array-flatten: 1.1.1 - body-parser: 1.20.4 - content-disposition: 0.5.4 - content-type: 1.0.5 - cookie: 0.7.2 - cookie-signature: 1.0.7 - debug: 2.6.9 - depd: 2.0.0 - encodeurl: 2.0.0 - escape-html: 1.0.3 - etag: 1.8.1 - finalhandler: 1.3.2 - fresh: 0.5.2 - http-errors: 2.0.1 - merge-descriptors: 1.0.3 - methods: 1.1.2 - on-finished: 2.4.1 - parseurl: 1.3.3 - path-to-regexp: 0.1.12 - proxy-addr: 2.0.7 - qs: 6.14.0 - range-parser: 1.2.1 - safe-buffer: 5.2.1 - send: 0.19.1 - serve-static: 1.16.2 - setprototypeof: 1.2.0 - statuses: 2.0.2 - type-is: 1.6.18 - utils-merge: 1.0.1 - vary: 1.1.2 - transitivePeerDependencies: - - supports-color - express@5.1.0: dependencies: accepts: 2.0.0 @@ -13602,18 +13191,6 @@ snapshots: transitivePeerDependencies: - supports-color - finalhandler@1.3.2: - dependencies: - debug: 2.6.9 - encodeurl: 2.0.0 - escape-html: 1.0.3 - on-finished: 2.4.1 - parseurl: 1.3.3 - statuses: 2.0.2 - unpipe: 1.0.0 - transitivePeerDependencies: - - supports-color - finalhandler@2.1.0: dependencies: debug: 4.4.1 @@ -13636,18 +13213,18 @@ snapshots: firebase-admin@12.3.1(encoding@0.1.13): dependencies: - '@fastify/busboy': 3.2.0 + '@fastify/busboy': 3.1.1 '@firebase/database-compat': 1.0.10 - '@firebase/database-types': 1.0.16 - '@types/node': 22.19.2 + '@firebase/database-types': 1.0.14 + '@types/node': 22.15.32 farmhash-modern: 1.1.0 jsonwebtoken: 9.0.2 jwks-rsa: 3.2.0 - node-forge: 1.3.3 + node-forge: 1.3.1 uuid: 10.0.0 optionalDependencies: - '@google-cloud/firestore': 7.11.6(encoding@0.1.13) - '@google-cloud/storage': 7.18.0(encoding@0.1.13) + '@google-cloud/firestore': 7.11.1(encoding@0.1.13) + '@google-cloud/storage': 7.16.0(encoding@0.1.13) transitivePeerDependencies: - encoding - supports-color @@ -13672,9 +13249,9 @@ snapshots: - encoding - supports-color - firebase-admin@13.5.0(encoding@0.1.13): + firebase-admin@13.6.0(encoding@0.1.13): dependencies: - '@fastify/busboy': 3.0.0 + 
'@fastify/busboy': 3.1.1 '@firebase/database-compat': 2.0.10 '@firebase/database-types': 1.0.14 '@types/node': 22.15.32 @@ -13682,11 +13259,11 @@ snapshots: fast-deep-equal: 3.1.3 google-auth-library: 9.15.1(encoding@0.1.13) jsonwebtoken: 9.0.2 - jwks-rsa: 3.1.0 + jwks-rsa: 3.2.0 node-forge: 1.3.1 uuid: 11.1.0 optionalDependencies: - '@google-cloud/firestore': 7.11.0(encoding@0.1.13) + '@google-cloud/firestore': 7.11.1(encoding@0.1.13) '@google-cloud/storage': 7.16.0(encoding@0.1.13) transitivePeerDependencies: - encoding @@ -13703,13 +13280,13 @@ snapshots: transitivePeerDependencies: - supports-color - firebase-functions@6.3.2(firebase-admin@13.5.0(encoding@0.1.13)): + firebase-functions@6.3.2(firebase-admin@13.6.0(encoding@0.1.13)): dependencies: '@types/cors': 2.8.19 '@types/express': 4.17.23 cors: 2.8.5 express: 4.21.2 - firebase-admin: 13.5.0(encoding@0.1.13) + firebase-admin: 13.6.0(encoding@0.1.13) protobufjs: 7.3.2 transitivePeerDependencies: - supports-color @@ -13740,11 +13317,11 @@ snapshots: exegesis-express: 4.0.0 express: 4.21.2 filesize: 6.4.0 - form-data: 4.0.4 + form-data: 4.0.5 fs-extra: 10.1.0 fuzzy: 0.1.3 gaxios: 6.7.1(encoding@0.1.13) - glob: 10.4.5 + glob: 10.5.0 google-auth-library: 9.15.1(encoding@0.1.13) inquirer: 8.2.7(@types/node@20.19.1) inquirer-autocomplete-prompt: 2.0.1(inquirer@8.2.7(@types/node@20.19.1)) @@ -13769,7 +13346,7 @@ snapshots: proxy-agent: 6.5.0 retry: 0.13.1 semver: 7.7.2 - sql-formatter: 15.6.10 + sql-formatter: 15.6.12 stream-chain: 2.2.5 stream-json: 1.9.1 superstatic: 9.2.0(encoding@0.1.13) @@ -13860,7 +13437,7 @@ snapshots: combined-stream: 1.0.8 mime-types: 2.1.35 - form-data@4.0.4: + form-data@4.0.5: dependencies: asynckit: 0.4.0 combined-stream: 1.0.8 @@ -13986,10 +13563,10 @@ snapshots: transitivePeerDependencies: - supports-color - genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1): + 
genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1): dependencies: - '@genkit-ai/ai': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) - '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)) + '@genkit-ai/ai': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) + '@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)) uuid: 10.0.0 transitivePeerDependencies: - '@google-cloud/firestore' @@ -13999,10 +13576,10 @@ snapshots: - supports-color optional: true - genkitx-openai@0.10.1(@genkit-ai/ai@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(@genkit-ai/core@1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(encoding@0.1.13)(ws@8.18.3): + 
genkitx-openai@0.10.1(@genkit-ai/ai@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(@genkit-ai/core@1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit))(encoding@0.1.13)(ws@8.18.3): dependencies: - '@genkit-ai/ai': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) - '@genkit-ai/core': 1.26.0-rc.0(@google-cloud/firestore@7.11.6(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.5.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + '@genkit-ai/ai': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) + '@genkit-ai/core': 1.27.0(@google-cloud/firestore@7.11.1(encoding@0.1.13))(encoding@0.1.13)(firebase-admin@13.6.0(encoding@0.1.13))(firebase@11.9.1)(genkit@genkit) openai: 4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67) zod: 3.25.67 transitivePeerDependencies: @@ -14091,7 +13668,7 @@ snapshots: minipass: 7.1.2 path-scurry: 1.10.2 - glob@10.4.5: + glob@10.5.0: dependencies: foreground-child: 3.1.1 jackspeak: 3.4.3 @@ -14150,7 +13727,7 @@ snapshots: gaxios: 5.1.3(encoding@0.1.13) gcp-metadata: 5.3.0(encoding@0.1.13) gtoken: 6.1.2(encoding@0.1.13) - jws: 4.0.1 + jws: 4.0.0 lru-cache: 6.0.0 transitivePeerDependencies: - encoding @@ -14189,7 +13766,7 @@ snapshots: google-gax@5.0.6: dependencies: - '@grpc/grpc-js': 1.14.2 + '@grpc/grpc-js': 1.14.3 '@grpc/proto-loader': 0.8.0 duplexify: 4.1.3 google-auth-library: 10.5.0 @@ -14197,7 +13774,7 @@ snapshots: node-fetch: 3.3.2 object-hash: 3.0.0 proto3-json-serializer: 3.0.4 - protobufjs: 7.5.4 + protobufjs: 7.5.3 retry-request: 8.0.2 rimraf: 5.0.10 transitivePeerDependencies: @@ -14207,7 +13784,7 @@ snapshots: google-p12-pem@4.0.1: dependencies: - 
node-forge: 1.3.3 + node-forge: 1.3.1 optional: true googleapis-common@7.2.0(encoding@0.1.13): @@ -14248,7 +13825,7 @@ snapshots: dependencies: gaxios: 5.1.3(encoding@0.1.13) google-p12-pem: 4.0.1 - jws: 4.0.1 + jws: 4.0.0 transitivePeerDependencies: - encoding - supports-color @@ -14337,14 +13914,6 @@ snapshots: statuses: 2.0.1 toidentifier: 1.0.1 - http-errors@2.0.1: - dependencies: - depd: 2.0.0 - inherits: 2.0.4 - setprototypeof: 1.2.0 - statuses: 2.0.2 - toidentifier: 1.0.1 - http-parser-js@0.5.10: {} http-proxy-agent@5.0.0: @@ -14397,7 +13966,7 @@ snapshots: dependencies: safer-buffer: 2.1.2 - iconv-lite@0.7.0: + iconv-lite@0.7.1: dependencies: safer-buffer: 2.1.2 @@ -14448,7 +14017,7 @@ snapshots: inquirer@8.2.7(@types/node@20.19.1): dependencies: - '@inquirer/external-editor': 1.0.2(@types/node@20.19.1) + '@inquirer/external-editor': 1.0.3(@types/node@20.19.1) ansi-escapes: 4.3.2 chalk: 4.1.2 cli-cursor: 3.1.0 @@ -14477,7 +14046,7 @@ snapshots: interpret@2.2.0: {} - ip-address@10.0.1: {} + ip-address@10.1.0: {} ip-regex@4.3.0: {} @@ -14691,7 +14260,7 @@ snapshots: '@babel/parser': 7.25.7 '@istanbuljs/schema': 0.1.3 istanbul-lib-coverage: 3.2.2 - semver: 7.7.3 + semver: 7.7.2 transitivePeerDependencies: - supports-color @@ -15126,8 +14695,7 @@ snapshots: jose@4.15.5: {} - jose@4.15.9: - optional: true + jose@4.15.9: {} joycon@3.1.1: {} @@ -15222,13 +14790,6 @@ snapshots: ecdsa-sig-formatter: 1.0.11 safe-buffer: 5.2.1 - jwa@2.0.1: - dependencies: - buffer-equal-constant-time: 1.0.1 - ecdsa-sig-formatter: 1.0.11 - safe-buffer: 5.2.1 - optional: true - jwks-rsa@3.1.0: dependencies: '@types/express': 4.17.23 @@ -15242,15 +14803,14 @@ snapshots: jwks-rsa@3.2.0: dependencies: - '@types/express': 4.17.25 + '@types/express': 4.17.23 '@types/jsonwebtoken': 9.0.10 - debug: 4.4.3 + debug: 4.4.1 jose: 4.15.9 limiter: 1.1.5 lru-memoizer: 2.3.0 transitivePeerDependencies: - supports-color - optional: true jws@3.2.2: dependencies: @@ -15262,12 +14822,6 @@ snapshots: jwa: 
2.0.0 safe-buffer: 5.2.1 - jws@4.0.1: - dependencies: - jwa: 2.0.1 - safe-buffer: 5.2.1 - optional: true - kind-of@3.2.2: dependencies: is-buffer: 1.1.6 @@ -15297,7 +14851,7 @@ snapshots: kuler@2.0.0: {} - langchain@0.1.37(@google-cloud/storage@7.18.0(encoding@0.1.13))(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(fast-xml-parser@4.5.3)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(lodash@4.17.21)(pdf-parse@1.1.1)(pg@8.16.2)(ws@8.18.3): + langchain@0.1.37(@google-cloud/storage@7.16.0(encoding@0.1.13))(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(fast-xml-parser@4.5.3)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(lodash@4.17.21)(pdf-parse@1.1.1)(pg@8.16.2)(ws@8.18.3): dependencies: '@anthropic-ai/sdk': 0.9.1(encoding@0.1.13) '@langchain/community': 0.0.53(@pinecone-database/pinecone@2.2.2)(chromadb@1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2)(lodash@4.17.21)(pg@8.16.2)(ws@8.18.3) @@ -15318,7 +14872,7 @@ snapshots: zod: 3.25.67 zod-to-json-schema: 3.24.5(zod@3.25.67) optionalDependencies: - '@google-cloud/storage': 7.18.0(encoding@0.1.13) + '@google-cloud/storage': 7.16.0(encoding@0.1.13) '@pinecone-database/pinecone': 2.2.2 chromadb: 1.9.2(encoding@0.1.13)(openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.67)) fast-xml-parser: 4.5.3 @@ -15545,7 +15099,6 @@ snapshots: dependencies: lodash.clonedeep: 4.5.0 lru-cache: 6.0.0 - optional: true lsofi@1.0.0: dependencies: @@ -15564,7 +15117,7 @@ snapshots: make-dir@4.0.0: dependencies: - semver: 7.7.3 
+ semver: 7.7.2 make-error@1.3.6: {} @@ -15602,7 +15155,7 @@ snapshots: marked-terminal@7.3.0(marked@13.0.3): dependencies: - ansi-escapes: 7.1.1 + ansi-escapes: 7.2.0 ansi-regex: 6.2.2 chalk: 5.6.2 cli-highlight: 2.1.11 @@ -15816,24 +15369,24 @@ snapshots: netmask@2.0.2: {} - next@15.4.10(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + next@15.5.9(@babel/core@7.25.7)(@opentelemetry/api@1.9.0)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): dependencies: - '@next/env': 15.4.10 + '@next/env': 15.5.9 '@swc/helpers': 0.5.15 - caniuse-lite: 1.0.30001760 + caniuse-lite: 1.0.30001667 postcss: 8.4.31 react: 18.3.1 react-dom: 18.3.1(react@18.3.1) styled-jsx: 5.1.6(@babel/core@7.25.7)(react@18.3.1) optionalDependencies: - '@next/swc-darwin-arm64': 15.4.8 - '@next/swc-darwin-x64': 15.4.8 - '@next/swc-linux-arm64-gnu': 15.4.8 - '@next/swc-linux-arm64-musl': 15.4.8 - '@next/swc-linux-x64-gnu': 15.4.8 - '@next/swc-linux-x64-musl': 15.4.8 - '@next/swc-win32-arm64-msvc': 15.4.8 - '@next/swc-win32-x64-msvc': 15.4.8 + '@next/swc-darwin-arm64': 15.5.7 + '@next/swc-darwin-x64': 15.5.7 + '@next/swc-linux-arm64-gnu': 15.5.7 + '@next/swc-linux-arm64-musl': 15.5.7 + '@next/swc-linux-x64-gnu': 15.5.7 + '@next/swc-linux-x64-musl': 15.5.7 + '@next/swc-win32-arm64-msvc': 15.5.7 + '@next/swc-win32-x64-msvc': 15.5.7 '@opentelemetry/api': 1.9.0 sharp: 0.34.5 transitivePeerDependencies: @@ -15865,9 +15418,6 @@ snapshots: node-forge@1.3.1: {} - node-forge@1.3.3: - optional: true - node-gyp@11.5.0: dependencies: env-paths: 2.2.1 @@ -15876,7 +15426,7 @@ snapshots: make-fetch-happen: 14.0.3 nopt: 8.1.0 proc-log: 5.0.0 - semver: 7.7.3 + semver: 7.7.2 tar: 7.5.2 tinyglobby: 0.2.14 which: 5.0.0 @@ -15990,21 +15540,6 @@ snapshots: transitivePeerDependencies: - encoding - openai@4.104.0(encoding@0.1.13)(ws@8.18.3)(zod@3.25.76): - dependencies: - '@types/node': 18.19.112 - '@types/node-fetch': 2.6.11 - abort-controller: 3.0.0 - agentkeepalive: 4.5.0 - 
form-data-encoder: 1.7.2 - formdata-node: 4.4.1 - node-fetch: 2.7.0(encoding@0.1.13) - optionalDependencies: - ws: 8.18.3 - zod: 3.25.76 - transitivePeerDependencies: - - encoding - openapi-types@12.1.3: {} openapi3-ts@3.2.0: @@ -16047,7 +15582,7 @@ snapshots: dependencies: p-limit: 2.3.0 - p-map@7.0.3: + p-map@7.0.4: optional: true p-queue@6.6.2: @@ -16265,14 +15800,6 @@ snapshots: tsx: 4.20.3 yaml: 2.8.0 - postcss-load-config@6.0.1(postcss@8.4.47)(tsx@4.20.3)(yaml@2.8.2): - dependencies: - lilconfig: 3.1.2 - optionalDependencies: - postcss: 8.4.47 - tsx: 4.20.3 - yaml: 2.8.2 - postcss@8.4.31: dependencies: nanoid: 3.3.11 @@ -16334,7 +15861,7 @@ snapshots: proto3-json-serializer@3.0.4: dependencies: - protobufjs: 7.5.4 + protobufjs: 7.5.3 protobuf.js@1.1.2: dependencies: @@ -16355,7 +15882,7 @@ snapshots: '@types/node': 20.19.1 long: 5.2.3 - protobufjs@7.5.4: + protobufjs@7.5.3: dependencies: '@protobufjs/aspromise': 1.1.2 '@protobufjs/base64': 1.1.2 @@ -16367,7 +15894,7 @@ snapshots: '@protobufjs/path': 1.1.2 '@protobufjs/pool': 1.1.0 '@protobufjs/utf8': 1.1.0 - '@types/node': 20.19.25 + '@types/node': 20.19.1 long: 5.3.2 proxy-addr@2.0.7: @@ -16439,13 +15966,6 @@ snapshots: iconv-lite: 0.4.24 unpipe: 1.0.0 - raw-body@2.5.3: - dependencies: - bytes: 3.1.2 - http-errors: 2.0.1 - iconv-lite: 0.4.24 - unpipe: 1.0.0 - raw-body@3.0.0: dependencies: bytes: 3.1.2 @@ -16460,7 +15980,7 @@ snapshots: minimist: 1.2.8 strip-json-comments: 2.0.1 - re2@1.22.1: + re2@1.22.3: dependencies: install-artifact-from-github: 1.4.0 nan: 2.24.0 @@ -16702,9 +16222,6 @@ snapshots: safe-stable-stringify@2.4.3: {} - safe-stable-stringify@2.5.0: - optional: true - safer-buffer@2.1.2: {} scheduler@0.23.2: @@ -16725,7 +16242,8 @@ snapshots: semver@7.7.2: {} - semver@7.7.3: {} + semver@7.7.3: + optional: true send@0.19.0: dependencies: @@ -16745,24 +16263,6 @@ snapshots: transitivePeerDependencies: - supports-color - send@0.19.1: - dependencies: - debug: 2.6.9 - depd: 2.0.0 - destroy: 1.2.0 - 
encodeurl: 2.0.0 - escape-html: 1.0.3 - etag: 1.8.1 - fresh: 0.5.2 - http-errors: 2.0.0 - mime: 1.6.0 - ms: 2.1.3 - on-finished: 2.4.1 - range-parser: 1.2.1 - statuses: 2.0.1 - transitivePeerDependencies: - - supports-color - send@1.2.0: dependencies: debug: 4.4.1 @@ -16934,7 +16434,7 @@ snapshots: socks@2.8.7: dependencies: - ip-address: 10.0.1 + ip-address: 10.1.0 smart-buffer: 4.2.0 sort-any@2.0.0: @@ -16972,7 +16472,7 @@ snapshots: sprintf-js@1.0.3: {} - sql-formatter@15.6.10: + sql-formatter@15.6.12: dependencies: argparse: 2.0.1 nearley: 2.20.1 @@ -16992,8 +16492,6 @@ snapshots: statuses@2.0.1: {} - statuses@2.0.2: {} - stop-iteration-iterator@1.1.0: dependencies: es-errors: 1.3.0 @@ -17145,7 +16643,7 @@ snapshots: router: 2.2.0 update-notifier-cjs: 5.1.7(encoding@0.1.13) optionalDependencies: - re2: 1.22.1 + re2: 1.22.3 transitivePeerDependencies: - encoding - supports-color @@ -17348,7 +16846,7 @@ snapshots: ts-node@10.9.2(@types/node@20.19.1)(typescript@4.9.5): dependencies: '@cspotcode/source-map-support': 0.8.1 - '@tsconfig/node10': 1.0.12 + '@tsconfig/node10': 1.0.11 '@tsconfig/node12': 1.0.11 '@tsconfig/node14': 1.0.3 '@tsconfig/node16': 1.0.4 @@ -17367,7 +16865,7 @@ snapshots: ts-node@10.9.2(@types/node@20.19.1)(typescript@5.8.3): dependencies: '@cspotcode/source-map-support': 0.8.1 - '@tsconfig/node10': 1.0.12 + '@tsconfig/node10': 1.0.11 '@tsconfig/node12': 1.0.11 '@tsconfig/node14': 1.0.3 '@tsconfig/node16': 1.0.4 @@ -17419,35 +16917,7 @@ snapshots: - tsx - yaml - tsup@8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@4.9.5)(yaml@2.8.2): - dependencies: - bundle-require: 5.1.0(esbuild@0.25.5) - cac: 6.7.14 - chokidar: 4.0.3 - consola: 3.4.2 - debug: 4.4.1 - esbuild: 0.25.5 - fix-dts-default-cjs-exports: 1.0.1 - joycon: 3.1.1 - picocolors: 1.1.1 - postcss-load-config: 6.0.1(postcss@8.4.47)(tsx@4.20.3)(yaml@2.8.2) - resolve-from: 5.0.0 - rollup: 4.43.0 - source-map: 0.8.0-beta.0 - sucrase: 3.35.0 - tinyexec: 0.3.2 - tinyglobby: 0.2.14 - tree-kill: 1.2.2 - 
optionalDependencies: - postcss: 8.4.47 - typescript: 4.9.5 - transitivePeerDependencies: - - jiti - - supports-color - - tsx - - yaml - - tsup@8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.2): + tsup@8.5.0(postcss@8.4.47)(tsx@4.20.3)(typescript@5.8.3)(yaml@2.8.0): dependencies: bundle-require: 5.1.0(esbuild@0.25.5) cac: 6.7.14 @@ -17458,7 +16928,7 @@ snapshots: fix-dts-default-cjs-exports: 1.0.1 joycon: 3.1.1 picocolors: 1.1.1 - postcss-load-config: 6.0.1(postcss@8.4.47)(tsx@4.20.3)(yaml@2.8.2) + postcss-load-config: 6.0.1(postcss@8.4.47)(tsx@4.20.3)(yaml@2.8.0) resolve-from: 5.0.0 rollup: 4.43.0 source-map: 0.8.0-beta.0 @@ -17825,21 +17295,6 @@ snapshots: triple-beam: 1.4.1 winston-transport: 4.9.0 - winston@3.19.0: - dependencies: - '@colors/colors': 1.6.0 - '@dabh/diagnostics': 2.0.8 - async: 3.2.6 - is-stream: 2.0.1 - logform: 2.7.0 - one-time: 1.0.0 - readable-stream: 3.6.2 - safe-stable-stringify: 2.5.0 - stack-trace: 0.0.10 - triple-beam: 1.4.1 - winston-transport: 4.9.0 - optional: true - wordwrap@1.0.0: {} wrap-ansi@6.2.0: @@ -17899,8 +17354,6 @@ snapshots: yaml@2.8.0: {} - yaml@2.8.2: {} - yargs-parser@20.2.9: {} yargs-parser@21.1.1: {} @@ -17940,16 +17393,6 @@ snapshots: dependencies: zod: 3.25.67 - zod-to-json-schema@3.24.5(zod@3.25.76): - dependencies: - zod: 3.25.76 - - zod-to-json-schema@3.25.0(zod@3.25.76): - dependencies: - zod: 3.25.76 - zod@3.22.4: {} zod@3.25.67: {} - - zod@3.25.76: {} diff --git a/js/testapps/anthropic/package.json b/js/testapps/anthropic/package.json index a065321138..0f7ac15001 100644 --- a/js/testapps/anthropic/package.json +++ b/js/testapps/anthropic/package.json @@ -10,6 +10,10 @@ "start:beta": "node lib/beta/basic.js", "dev:stable": "genkit start -- npx tsx --watch src/stable/basic.ts", "dev:beta": "genkit start -- npx tsx --watch src/beta/basic.ts", + "dev:beta:structured-output": "genkit start -- npx tsx --watch src/beta/structured_output.ts", + "dev:beta:files-api": "genkit start -- npx tsx --watch 
src/beta/files_api.ts", + "dev:beta:effort": "genkit start -- npx tsx --watch src/beta/effort.ts", + "dev:beta:additional-params": "genkit start -- npx tsx --watch src/beta/additional_params.ts", "dev:stable:text-plain": "genkit start -- npx tsx --watch src/stable/text-plain.ts", "dev:stable:webp": "genkit start -- npx tsx --watch src/stable/webp.ts", "dev:stable:pdf": "genkit start -- npx tsx --watch src/stable/pdf.ts", @@ -27,8 +31,8 @@ "author": "", "license": "Apache-2.0", "dependencies": { - "genkit": "workspace:*", - "@genkit-ai/anthropic": "workspace:*" + "@genkit-ai/anthropic": "workspace:*", + "genkit": "workspace:*" }, "devDependencies": { "cross-env": "^10.1.0", diff --git a/js/testapps/anthropic/src/beta/additional_params.ts b/js/testapps/anthropic/src/beta/additional_params.ts new file mode 100644 index 0000000000..2e443fc01a --- /dev/null +++ b/js/testapps/anthropic/src/beta/additional_params.ts @@ -0,0 +1,83 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import { anthropic } from '@genkit-ai/anthropic'; +import { genkit } from 'genkit'; + +const ai = genkit({ + plugins: [ + // Default all flows in this sample to the beta surface + anthropic({ + apiVersion: 'beta', + cacheSystemPrompt: true, + apiKey: process.env.ANTHROPIC_API_KEY, + }), + ], +}); + +const betaOpus45 = anthropic.model('claude-opus-4-5', { apiVersion: 'beta' }); + +ai.defineFlow('anthropic-beta-additional-params', async () => { + const { text } = await ai.generate({ + model: betaOpus45, + prompt: + 'You are Claude on the beta API. Provide a concise greeting that mentions that you are using the beta API.', + config: { + temperature: 0.6, + // Additional param (not directly supported by the plugin, but can be passed through to the API) + betas: ['effort-2025-11-24'], + // Additional param (not directly supported by the plugin, but can be passed through to the API) + output_config: { + effort: 'medium', + }, + }, + }); + + return text; +}); + +ai.defineFlow( + 'anthropic-beta-additional-params-stream', + async (_, { sendChunk }) => { + const { stream } = ai.generateStream({ + model: betaOpus45, + prompt: [ + { + text: 'Outline two experimental capabilities unlocked by the Anthropic beta API.', + }, + ], + config: { + temperature: 0.4, + // Additional param (not directly supported by the plugin, but can be passed through to the API) + betas: ['effort-2025-11-24'], + // Additional param (not directly supported by the plugin, but can be passed through to the API) + output_config: { + effort: 'medium', + }, + }, + }); + + const collected: string[] = []; + for await (const chunk of stream) { + if (chunk.text) { + collected.push(chunk.text); + sendChunk(chunk.text); + } + } + + return collected.join(''); + } +); diff --git a/js/testapps/anthropic/src/beta/basic.ts b/js/testapps/anthropic/src/beta/basic.ts index d1309b3400..f9841f4c6d 100644 --- a/js/testapps/anthropic/src/beta/basic.ts +++ b/js/testapps/anthropic/src/beta/basic.ts @@ -15,12 +15,16 @@ 
*/ import { anthropic } from '@genkit-ai/anthropic'; -import { genkit } from 'genkit'; +import { genkit, z } from 'genkit'; const ai = genkit({ plugins: [ // Default all flows in this sample to the beta surface - anthropic({ apiVersion: 'beta', cacheSystemPrompt: true }), + anthropic({ + apiVersion: 'beta', + cacheSystemPrompt: true, + apiKey: process.env.ANTHROPIC_API_KEY, + }), ], }); @@ -28,15 +32,21 @@ const betaHaiku = anthropic.model('claude-3-5-haiku', { apiVersion: 'beta' }); const betaSonnet = anthropic.model('claude-sonnet-4-5', { apiVersion: 'beta' }); const betaOpus41 = anthropic.model('claude-opus-4-1', { apiVersion: 'beta' }); +const GreetingSchema = z.object({ + greeting: z.string(), + apiVersion: z.string(), +}); + ai.defineFlow('anthropic-beta-hello', async () => { - const { text } = await ai.generate({ + const { output } = await ai.generate({ model: betaHaiku, prompt: 'You are Claude on the beta API. Provide a concise greeting that mentions that you are using the beta API.', config: { temperature: 0.6 }, + output: { schema: GreetingSchema, format: 'json', constrained: true }, }); - return text; + return output; }); ai.defineFlow('anthropic-beta-stream', async (_, { sendChunk }) => { diff --git a/js/testapps/anthropic/src/beta/effort.ts b/js/testapps/anthropic/src/beta/effort.ts new file mode 100644 index 0000000000..03e83a60af --- /dev/null +++ b/js/testapps/anthropic/src/beta/effort.ts @@ -0,0 +1,79 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { anthropic } from '@genkit-ai/anthropic'; +import { genkit } from 'genkit'; + +const ai = genkit({ + plugins: [ + // Default all flows in this sample to the beta surface + anthropic({ + apiVersion: 'beta', + cacheSystemPrompt: true, + apiKey: process.env.ANTHROPIC_API_KEY, + }), + ], +}); + +const betaOpus45 = anthropic.model('claude-opus-4-5', { apiVersion: 'beta' }); + +ai.defineFlow('anthropic-beta-low-effort', async () => { + const { text } = await ai.generate({ + model: betaOpus45, + prompt: `Create me a Mathematics class using the programming language Python.`, + config: { + maxOutputTokens: 4096, + temperature: 0.6, + output_config: { + effort: 'low', + }, + }, + }); + + return text; +}); + +ai.defineFlow('anthropic-beta-medium-effort', async () => { + const { text } = await ai.generate({ + model: betaOpus45, + prompt: `Create me a Mathematics class using the programming language Python.`, + config: { + maxOutputTokens: 4096, + temperature: 0.6, + output_config: { + effort: 'medium', + }, + }, + }); + + return text; +}); + +ai.defineFlow('anthropic-beta-high-effort', async () => { + const { text } = await ai.generate({ + model: betaOpus45, + prompt: `Create me a Mathematics class using the programming language Python.`, + config: { + maxOutputTokens: 4096, + temperature: 0.6, + output_config: { + effort: 'high', + }, + }, + }); + + return text; +}); diff --git a/js/testapps/anthropic/src/beta/files_api.ts b/js/testapps/anthropic/src/beta/files_api.ts new file mode 100644 index 0000000000..3dd6b08d01 --- /dev/null +++ b/js/testapps/anthropic/src/beta/files_api.ts @@ -0,0 +1,114 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { anthropic } from '@genkit-ai/anthropic'; +import * as fs from 'fs'; +import { genkit } from 'genkit'; +import * as path from 'path'; + +// Ensure the API key is set. +const API_KEY = process.env.ANTHROPIC_API_KEY; +// If you have a file ID, you can set it here. Otherwise, the flow will upload a new PDF to Anthropic. +const FILE_ID = process.env.ANTHROPIC_FILE_ID; + +export async function uploadPdfToAnthropic() { + if (!API_KEY) throw new Error('Missing ANTHROPIC_API_KEY env variable'); + + // Path to the PDF file to upload + const pdfPath = path.join(__dirname, '../attention-first-page.pdf'); + const fileBuffer = fs.readFileSync(pdfPath); + + const form = new FormData(); + form.append( + 'file', + new Blob([fileBuffer], { type: 'application/pdf' }), + 'attention-first-page.pdf' + ); + + const response = await fetch('https://api.anthropic.com/v1/files', { + method: 'POST', + headers: { + 'x-api-key': API_KEY, + 'anthropic-version': '2023-06-01', + 'anthropic-beta': 'files-api-2025-04-14', + }, + body: form, + }); + + if (!response.ok) { + const text = await response.text(); + throw new Error(`Anthropic file upload failed: ${response.status} ${text}`); + } + const result = await response.json(); + return result as { id: string }; // Contains 'file_id', etc. 
+} + +async function main() { + const ai = genkit({ + plugins: [ + // Default all flows in this sample to the beta surface + anthropic({ + apiVersion: 'beta', + apiKey: API_KEY, + }), + ], + }); + + /** + * This flow demonstrates PDF document processing via a public data URL along with a user prompt. + * The PDF is sent as a media part with the correct contentType and a URL, not base64. + */ + ai.defineFlow('beta-pdf-url', async () => { + let fileId = FILE_ID; + + if (!fileId) { + const fileResult = await uploadPdfToAnthropic(); + if (!fileResult || !fileResult.id) { + throw new Error('File ID not found'); + } + fileId = fileResult.id; + } + + // Example: Use a (demo/test) PDF file accessible via public URL. + // Replace this with your actual PDF if needed. + const { text } = await ai.generate({ + model: anthropic.model('claude-sonnet-4-5'), + messages: [ + { + role: 'user', + content: [ + { + text: 'What are the key findings or main points in this document?', + }, + { + media: { + url: fileId, + contentType: 'anthropic/file', + }, + }, + ], + }, + ], + }); + + return text; + }); +} + +main().catch((error) => { + console.error('Error:', error); + process.exit(1); +}); diff --git a/js/testapps/anthropic/src/beta/structured_output.ts b/js/testapps/anthropic/src/beta/structured_output.ts new file mode 100644 index 0000000000..dbd3f5ca17 --- /dev/null +++ b/js/testapps/anthropic/src/beta/structured_output.ts @@ -0,0 +1,84 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { anthropic } from '@genkit-ai/anthropic'; +import { genkit, z } from 'genkit'; + +const ai = genkit({ + plugins: [ + // Default all flows in this sample to the beta surface + anthropic({ + apiVersion: 'beta', + cacheSystemPrompt: true, + apiKey: process.env.ANTHROPIC_API_KEY, + }), + ], +}); + +const betaSonnet = anthropic.model('claude-sonnet-4-5', { apiVersion: 'beta' }); + +ai.defineFlow('anthropic-beta-generate-person-json', async () => { + const { text } = await ai.generate({ + model: betaSonnet, + prompt: + 'Generate a fictional person with a random name, random age, and random city.', + config: { temperature: 0.6 }, + output: { + schema: z.object({ + name: z.string(), + age: z.number(), + city: z.string(), + }), + format: 'json', + constrained: true, + }, + }); + + return text; +}); + +ai.defineFlow( + 'anthropic-beta-generate-person-json-stream', + async (_, { sendChunk }) => { + const { stream } = ai.generateStream({ + model: betaSonnet, + prompt: [ + { + text: 'Generate a fictional person with a random name, random age, and random city.', + }, + ], + config: { temperature: 0.6 }, + output: { + schema: z.object({ + name: z.string(), + age: z.number(), + city: z.string(), + }), + format: 'json', + }, + }); + + const collected: any[] = []; + for await (const chunk of stream) { + if (chunk.text) { + collected.push(chunk.output); + sendChunk(chunk.output); + } + } + + return collected.join(''); + } +); diff --git a/js/testapps/anthropic/src/stable/pdf.ts b/js/testapps/anthropic/src/stable/pdf.ts index 8953dff696..07be5f6657 100644 --- a/js/testapps/anthropic/src/stable/pdf.ts +++ b/js/testapps/anthropic/src/stable/pdf.ts @@ -29,7 +29,7 @@ const ai = genkit({ */ ai.defineFlow('stable-pdf-base64', async () => { // Read PDF file from the same directory as this source file - const pdfPath = path.join(__dirname, 'attention-first-page.pdf'); + 
const pdfPath = path.join(__dirname, '../attention-first-page.pdf'); const pdfBuffer = fs.readFileSync(pdfPath); const pdfBase64 = pdfBuffer.toString('base64'); diff --git a/js/testapps/basic-gemini/src/index-vertexai.ts b/js/testapps/basic-gemini/src/index-vertexai.ts index 2e6c59e965..2793e91bc5 100644 --- a/js/testapps/basic-gemini/src/index-vertexai.ts +++ b/js/testapps/basic-gemini/src/index-vertexai.ts @@ -40,8 +40,8 @@ ai.defineFlow('basic-hi', async () => { // Gemini 3.0 thinkingLevel config ai.defineFlow( { - name: 'thinking-level', - inputSchema: z.enum(['LOW', 'MEDIUM', 'HIGH']), + name: 'thinking-level-pro', + inputSchema: z.enum(['LOW', 'HIGH']), outputSchema: z.any(), }, async (level) => { @@ -66,6 +66,34 @@ ai.defineFlow( } ); +ai.defineFlow( + { + name: 'thinking-level-flash', + inputSchema: z.enum(['MINIMAL', 'LOW', 'MEDIUM', 'HIGH']), + outputSchema: z.any(), + }, + async (level) => { + const { text } = await ai.generate({ + model: vertexAI.model('gemini-3-flash-preview'), + prompt: + 'Alice, Bob, and Carol each live in a different house on the ' + + 'same street: red, green, and blue. The person who lives in the red house ' + + 'owns a cat. Bob does not live in the green house. Carol owns a dog. The ' + + 'green house is to the left of the red house. Alice does not own a cat. ' + + 'The person in the blue house owns a fish. ' + + 'Who lives in each house, and what pet do they own? 
Provide your ' + + 'step-by-step reasoning.', + config: { + location: 'global', + thinkingConfig: { + thinkingLevel: level, + }, + }, + }); + return text; + } +); + // Multimodal input ai.defineFlow('multimodal-input', async () => { const photoBase64 = fs.readFileSync('photo.jpg', { encoding: 'base64' }); diff --git a/js/testapps/basic-gemini/src/index.ts b/js/testapps/basic-gemini/src/index.ts index d9bfc9848d..40246ef2b9 100644 --- a/js/testapps/basic-gemini/src/index.ts +++ b/js/testapps/basic-gemini/src/index.ts @@ -80,11 +80,11 @@ ai.defineFlow('basic-hi-with-fallback', async () => { return text; }); -// Gemini 3.0 thinkingLevel config +// Gemini 3.0 thinkingLevel config. Pro can have Low or High ai.defineFlow( { - name: 'thinking-level', - inputSchema: z.enum(['LOW', 'MEDIUM', 'HIGH']), + name: 'thinking-level-pro', + inputSchema: z.enum(['LOW', 'HIGH']), }, async (level) => { const { text } = await ai.generate({ @@ -107,6 +107,33 @@ ai.defineFlow( } ); +// Gemini 3 Flash can have minimal and medium thinking levels too. +ai.defineFlow( + { + name: 'thinking-level-flash', + inputSchema: z.enum(['MINIMAL', 'LOW', 'MEDIUM', 'HIGH']), + }, + async (level) => { + const { text } = await ai.generate({ + model: googleAI.model('gemini-3-flash-preview'), + prompt: + 'Alice, Bob, and Carol each live in a different house on the ' + + 'same street: red, green, and blue. The person who lives in the red house ' + + 'owns a cat. Bob does not live in the green house. Carol owns a dog. The ' + + 'green house is to the left of the red house. Alice does not own a cat. ' + + 'The person in the blue house owns a fish. ' + + 'Who lives in each house, and what pet do they own? 
Provide your ' + + 'step-by-step reasoning.', + config: { + thinkingConfig: { + thinkingLevel: level, + }, + }, + }); + return text; + } +); + // Multimodal input ai.defineFlow('multimodal-input', async () => { const photoBase64 = fs.readFileSync('photo.jpg', { encoding: 'base64' }); diff --git a/py/bin/sanitize_schema_typing.py b/py/bin/sanitize_schema_typing.py index 6138127e9f..fa74fe5e76 100644 --- a/py/bin/sanitize_schema_typing.py +++ b/py/bin/sanitize_schema_typing.py @@ -42,10 +42,9 @@ import ast import sys -from _ast import AST from datetime import datetime from pathlib import Path -from typing import Type, cast +from typing import Any, Type, cast class ClassTransformer(ast.NodeTransformer): @@ -118,7 +117,18 @@ def has_model_config(self, node: ast.ClassDef) -> ast.Assign | None: return item return None - def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: + """Visit and transform annotated assignment.""" + if isinstance(node.annotation, ast.Name) and node.annotation.id == 'Role': + node.annotation = ast.BinOp( + left=ast.Name(id='Role', ctx=ast.Load()), + op=ast.BitOr(), + right=ast.Name(id='str', ctx=ast.Load()), + ) + self.modified = True + return node + + def visit_ClassDef(self, node: ast.ClassDef) -> Any: """Visit and transform a class definition node. Args: @@ -128,11 +138,16 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 The transformed ClassDef node. 
""" # First apply base class transformations recursively - node = super().generic_visit(_node) + node = cast(ast.ClassDef, super().generic_visit(node)) new_body: list[ast.stmt | ast.Constant | ast.Assign] = [] # Handle Docstrings - if not node.body or not isinstance(node.body[0], ast.Expr) or not isinstance(node.body[0].value, ast.Constant): + if ( + not node.body + or not isinstance(node.body[0], ast.Expr) + or not isinstance(node.body[0].value, ast.Constant) + or not isinstance(node.body[0].value.value, str) + ): # Generate a more descriptive docstring based on class type if self.is_rootmodel_class(node): docstring = f'Root model for {node.name.lower().replace("_", " ")}.' @@ -151,13 +166,21 @@ def visit_ClassDef(self, _node: ast.ClassDef) -> ast.ClassDef: # noqa: N802 # Handle model_config for BaseModel and RootModel existing_model_config_assign = self.has_model_config(node) + existing_model_config_call = None if existing_model_config_assign and isinstance(existing_model_config_assign.value, ast.Call): existing_model_config_call = existing_model_config_assign.value # Determine start index for iterating original body (skip docstring) body_start_index = ( - 1 if (node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Str)) else 0 + 1 + if ( + node.body + and isinstance(node.body[0], ast.Expr) + and isinstance(node.body[0].value, ast.Constant) + and isinstance(node.body[0].value.value, str) + ) + else 0 ) if self.is_rootmodel_class(node): diff --git a/py/noxfile.py b/py/noxfile.py index 0e4a68ad90..7f715a3043 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -74,5 +74,5 @@ def lint(session: nox.Session) -> None: session.run('uv', 'run', 'ruff', 'format', '--check', '.', external=True) session.log('Running ruff checks') session.run('uv', 'run', 'ruff', 'check', '--preview', '--unsafe-fixes', '--fix', '.', external=True) - # session.log("Running mypy checks") # mypy has many errors currently - # session.run("mypy", external=True) + 
session.log('Running Ty checks') + session.run('uv', 'run', 'ty', 'check', '.', external=True) diff --git a/py/packages/genkit/pyproject.toml b/py/packages/genkit/pyproject.toml index ab4f17a288..b7921df274 100644 --- a/py/packages/genkit/pyproject.toml +++ b/py/packages/genkit/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -52,7 +51,7 @@ dependencies = [ "anyio>=4.9.0", ] description = "Genkit AI Framework" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit" readme = "README.md" requires-python = ">=3.10" diff --git a/py/packages/genkit/src/genkit/ai/__init__.py b/py/packages/genkit/src/genkit/ai/__init__.py index c0225c3bfd..c120f02c4d 100644 --- a/py/packages/genkit/src/genkit/ai/__init__.py +++ b/py/packages/genkit/src/genkit/ai/__init__.py @@ -47,6 +47,7 @@ GenkitRegistry.__name__, Genkit.__name__, Plugin.__name__, + Plugin.__name__, ToolRunContext.__name__, tool_response.__name__, FlowWrapper.__name__, diff --git a/py/packages/genkit/src/genkit/ai/_aio.py b/py/packages/genkit/src/genkit/ai/_aio.py index 25e5aeca1c..19002e8c23 100644 --- a/py/packages/genkit/src/genkit/ai/_aio.py +++ b/py/packages/genkit/src/genkit/ai/_aio.py @@ -45,7 +45,6 @@ class while customizing it with any plugins. 
from genkit.core.action.types import ActionKind from genkit.core.typing import ( BaseDataPoint, - BaseEvalDataPoint, EmbedRequest, EmbedResponse, EvalRequest, @@ -351,9 +350,14 @@ async def retrieve( final_options = {**(retriever_config or {}), **(options or {})} - retrieve_action = self.registry.lookup_action(ActionKind.RETRIEVER, retriever_name) + retrieve_action = await self.registry.resolve_action(ActionKind.RETRIEVER, retriever_name) - return (await retrieve_action.arun(RetrieverRequest(query=query, options=final_options))).response + parent_ctx = ActionRunContext._current_context() or {} + action_ctx = dict(parent_ctx) + action_ctx.setdefault('__genkit_ai__', self) + return ( + await retrieve_action.arun(RetrieverRequest(query=query, options=final_options), context=action_ctx) + ).response async def index( self, @@ -383,9 +387,12 @@ async def index( final_options = {**(indexer_config or {}), **(options or {})} - index_action = self.registry.lookup_action(ActionKind.INDEXER, indexer_name) + index_action = await self.registry.resolve_action(ActionKind.INDEXER, indexer_name) - await index_action.arun(IndexerRequest(documents=documents, options=final_options)) + parent_ctx = ActionRunContext._current_context() or {} + action_ctx = dict(parent_ctx) + action_ctx.setdefault('__genkit_ai__', self) + await index_action.arun(IndexerRequest(documents=documents, options=final_options), context=action_ctx) async def embed( self, @@ -410,9 +417,14 @@ async def embed( # Merge options passed to embed() with config from EmbedderRef final_options = {**(embedder_config or {}), **(options or {})} - embed_action = self.registry.lookup_action(ActionKind.EMBEDDER, embedder_name) + embed_action = await self.registry.resolve_action(ActionKind.EMBEDDER, embedder_name) - return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response + parent_ctx = ActionRunContext._current_context() or {} + action_ctx = dict(parent_ctx) + 
action_ctx.setdefault('__genkit_ai__', self) + return ( + await embed_action.arun(EmbedRequest(input=documents, options=final_options), context=action_ctx) + ).response async def evaluate( self, @@ -445,17 +457,21 @@ async def evaluate( final_options = {**(evaluator_config or {}), **(options or {})} - eval_action = self.registry.lookup_action(ActionKind.EVALUATOR, evaluator_name) + eval_action = await self.registry.resolve_action(ActionKind.EVALUATOR, evaluator_name) if not eval_run_id: eval_run_id = str(uuid.uuid4()) + parent_ctx = ActionRunContext._current_context() or {} + action_ctx = dict(parent_ctx) + action_ctx.setdefault('__genkit_ai__', self) return ( await eval_action.arun( EvalRequest( dataset=dataset, options=final_options, eval_run_id=eval_run_id, - ) + ), + context=action_ctx, ) ).response diff --git a/py/packages/genkit/src/genkit/ai/_base.py b/py/packages/genkit/src/genkit/ai/_base.py index edd39e2919..7893af616a 100644 --- a/py/packages/genkit/src/genkit/ai/_base.py +++ b/py/packages/genkit/src/genkit/ai/_base.py @@ -28,8 +28,10 @@ from genkit.aio.loop import create_loop, run_async from genkit.blocks.formats import built_in_formats from genkit.blocks.generate import define_generate_action +from genkit.core.action import ActionMetadata from genkit.core.environment import is_dev_environment from genkit.core.reflection import make_reflection_server +from genkit.core.registry import ActionKind from genkit.web.manager import find_free_port_sync from ._plugin import Plugin @@ -120,22 +122,106 @@ def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None) logger.warning('No plugins provided to Genkit') else: for plugin in plugins: - if isinstance(plugin, Plugin): - plugin.initialize(ai=self) + if not isinstance(plugin, Plugin): + raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') + self._initialize_plugin(plugin) - def resolver(kind, name, plugin=plugin): - return plugin.resolve_action(self, 
kind, name) + def _initialize_plugin(self, plugin: Plugin) -> None: + """Register a plugin without eagerly initializing it. - def action_resolver(plugin=plugin): - if isinstance(plugin.list_actions, list): - return plugin.list_actions - else: - return plugin.list_actions() + Plugins are registered during Genkit construction, but their + `init()` is only invoked when the plugin is *initialized* (e.g. first + use via action lookup). - self.registry.register_action_resolver(plugin.plugin_name(), resolver) - self.registry.register_list_actions_resolver(plugin.plugin_name(), action_resolver) + This method wires: + - a lazy initializer (calls `plugin.init()` once, on-demand) + - an action resolver (calls `plugin.resolve()` on cache miss) + - a list-actions resolver (calls `plugin.list_actions()` for discovery) + """ + initialized = False + init_lock = threading.Lock() + init_task: asyncio.Task[None] | None = None + + async def ensure_initialized() -> None: + """Initialize the plugin exactly once (async-first). + + This is the JS-style 'initializer promise' pattern: cache a single + task and await it from all concurrent callers. + """ + nonlocal initialized, init_task + if initialized: + return + + with init_lock: + if initialized: + return + if init_task is None: + + async def do_init(): + nonlocal initialized, init_task + try: + resolved_actions = await plugin.init() + for action in resolved_actions: + self._register_action(action, plugin) + initialized = True + finally: + # If init failed, allow retry on next access. + if not initialized: + with init_lock: + init_task = None + + init_task = asyncio.create_task(do_init()) + + # Safe: init_task is set under lock. + await init_task + + async def resolver(kind: ActionKind, name: str): + """Lazy resolver for v2 plugin. + + Called when framework needs an action not returned from init(). 
+ """ + await ensure_initialized() + clean_name = name.removeprefix(f'{plugin.name}/') if name.startswith(f'{plugin.name}/') else name + + action = await plugin.resolve(kind, clean_name) + if action: + self._register_action(action, plugin) + + self.registry.register_action_resolver(plugin.name, resolver) + + async def list_actions_resolver(plugin=plugin): + """List available actions for a plugin (for discovery/devtools). + + Important: This should not force plugin initialization; it should use + lightweight `Plugin.list_actions()` metadata instead. + """ + resolved = await plugin.list_actions() + + namespaced: list[ActionMetadata] = [] + for meta in resolved: + if meta.name.startswith(f'{plugin.name}/'): + namespaced.append(meta) else: - raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') + data = meta.model_dump() + data['name'] = f'{plugin.name}/{meta.name}' + namespaced.append(ActionMetadata(**data)) + return namespaced + + self.registry.register_list_actions_resolver(plugin.name, list_actions_resolver) + + def _register_action(self, action: Any, plugin: Plugin) -> None: + """Register a single action from a plugin. + + Responsibilities: + 1. Add plugin namespace to action name (if not already present) + 2. Register action in the registry + + Args: + action: Action instance from the plugin. + plugin: The plugin that created this action. + """ + # Register the pre-constructed action instance and let the registry apply namespacing + self.registry.register_action_instance(action, namespace=plugin.name) def _initialize_server(self, reflection_server_spec: ServerSpec | None) -> None: """Initialize the server for the Genkit instance. 
diff --git a/py/packages/genkit/src/genkit/ai/_base_async.py b/py/packages/genkit/src/genkit/ai/_base_async.py index 7229c54642..9baa7be26d 100644 --- a/py/packages/genkit/src/genkit/ai/_base_async.py +++ b/py/packages/genkit/src/genkit/ai/_base_async.py @@ -16,6 +16,8 @@ """Asynchronous server gateway interface implementation for Genkit.""" +import asyncio +import threading from collections.abc import Coroutine from typing import Any, TypeVar @@ -25,8 +27,10 @@ from genkit.aio.loop import run_loop from genkit.blocks.formats import built_in_formats +from genkit.core.action import Action, ActionMetadata from genkit.core.environment import is_dev_environment from genkit.core.reflection import create_reflection_asgi_app +from genkit.core.registry import ActionKind from genkit.web.manager import find_free_port_sync from ._plugin import Plugin @@ -51,7 +55,7 @@ def __init__( """Initialize a new Genkit instance. Args: - plugins: List of plugins to initialize. + plugins: List of plugins to initialize (v1 or v2). model: Model name to use. reflection_server_spec: Server spec for the reflection server. If not provided in dev mode, a default will be used. @@ -60,7 +64,7 @@ def __init__( self._reflection_server_spec = reflection_server_spec self._initialize_registry(model, plugins) - def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None) -> None: + def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None = None) -> None: """Initialize the registry for the Genkit instance. 
Args: @@ -81,15 +85,97 @@ def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None) logger.warning('No plugins provided to Genkit') else: for plugin in plugins: - if isinstance(plugin, Plugin): - plugin.initialize(ai=self) + if not isinstance(plugin, Plugin): + raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') + logger.debug(f'Registering plugin: {plugin.name}') + self._register_plugin(plugin) - def resolver(kind, name, plugin=plugin): - return plugin.resolve_action(self, kind, name) + def _register_plugin(self, plugin: Plugin) -> None: + """Register a plugin by calling its methods and registering returned actions. - self.registry.register_action_resolver(plugin.plugin_name(), resolver) + Steps: + 1. Call plugin.init() to get resolved actions + 2. Register each action with automatic namespacing + 3. Set up lazy resolver for on-demand actions + + Args: + plugin: V2 plugin instance to register. + """ + initialized = False + init_lock = threading.Lock() + init_task: asyncio.Task[None] | None = None + + async def ensure_initialized() -> None: + """Initialize the plugin exactly once (async-first).""" + nonlocal initialized, init_task + if initialized: + return + + with init_lock: + if initialized: + return + if init_task is None: + + async def do_init(): + nonlocal initialized, init_task + try: + resolved_actions = await plugin.init() + for action in resolved_actions: + self._register_action(action, plugin) + initialized = True + finally: + if not initialized: + with init_lock: + init_task = None + + init_task = asyncio.create_task(do_init()) + + await init_task + + async def resolver(kind: ActionKind, name: str): + """Lazy resolver for v2 plugin. + + Called when framework needs an action not returned from init(). 
+ """ + await ensure_initialized() + clean_name = name.removeprefix(f'{plugin.name}/') if name.startswith(f'{plugin.name}/') else name + action = await plugin.resolve(kind, clean_name) + if action: + self._register_action_v2(action, plugin) + + self.registry.register_action_resolver(plugin.name, resolver) + + async def list_actions_resolver(plugin=plugin): + """List available actions for a plugin (for discovery/devtools).""" + resolved = await plugin.list_actions() + + namespaced: list[ActionMetadata] = [] + for meta in resolved: + if meta.name.startswith(f'{plugin.name}/'): + namespaced.append(meta) else: - raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') + data = meta.model_dump() + data['name'] = f'{plugin.name}/{meta.name}' + namespaced.append(ActionMetadata(**data)) + return namespaced + + self.registry.register_list_actions_resolver(plugin.name, list_actions_resolver) + + def _register_action(self, action: Action, plugin: Plugin) -> None: + """Register a single action from a plugin. + + Responsibilities: + 1. Add plugin namespace to action name (if not already present) + 2. Register action in the registry + + Args: + action: Action instance from the plugin. + plugin: The plugin that created this action. + """ + # Register the pre-constructed action instance and let the registry apply namespacing + self.registry.register_action_instance(action, namespace=plugin.name) + + logger.debug(f'Registered action: {action.name}') def run_main(self, coro: Coroutine[Any, Any, T]) -> T: """Run the user's main coroutine. 
diff --git a/py/packages/genkit/src/genkit/ai/_plugin.py b/py/packages/genkit/src/genkit/ai/_plugin.py index 691657ca20..6bc5a68e4e 100644 --- a/py/packages/genkit/src/genkit/ai/_plugin.py +++ b/py/packages/genkit/src/genkit/ai/_plugin.py @@ -21,69 +21,146 @@ """ import abc +from collections.abc import Awaitable, Callable +from typing import Any from genkit.core.registry import ActionKind -from ..core.action import ActionMetadata +from ..core.action import Action, ActionMetadata from ._registry import GenkitRegistry +# Type aliases for plugin resolver functions +ActionResolver = Callable[[ActionKind, str], Awaitable[Action | None]] +"""Async function that resolves an action by kind and name.""" + +ListActionsResolver = Callable[[], Awaitable[list[ActionMetadata]]] +"""Async function that returns a list of action metadata for discovery.""" -class Plugin(abc.ABC): - """Abstract base class for implementing Genkit plugins. - This class defines the interface that all plugins must implement. Plugins - provide a way to extend functionality by registering new actions, models, or - other capabilities. +class Plugin(abc.ABC): + """Base class for Genkit plugins that return actions instead of mutating registry. + + Plugins are decoupled from the registry - they create and return Action + objects which the framework then registers. This enables: + - Standalone usage (use plugins without framework) + - Better testability (test plugins in isolation) + + Plugin authors should inherit from this class and implement the required methods. + + Example: + >>> class MyPlugin(Plugin): + ... name = 'myplugin' + ... + ... async def init(self): + ... return [model(name='my-model', fn=self._generate)] + ... + ... async def resolve(self, action_type, name): + ... return model(name=name, fn=self._generate) + ... + ... async def list_actions(self): + ... return [ActionMetadata(name='my-model', kind=ActionKind.MODEL)] """ - def plugin_name(self): - """The name of the plugin. 
+ name: str + """Plugin name (e.g., 'anthropic', 'openai'). Must be set by subclass.""" + + @abc.abstractmethod + async def init(self) -> list[Action]: + """Return eagerly-initialized actions. + + Called once during Genkit initialization. Return actions you want + created immediately (common models, frequently used tools, etc.). Returns: - The name of the plugin. + List of Action objects (not yet registered with any registry). + + Example: + >>> async def init(self): + ... from genkit.blocks.model import model + ... + ... return [ + ... model(name='gpt-4', fn=self._generate), + ... model(name='gpt-4o', fn=self._generate), + ... ] """ - return self.name + ... - # TODO: https://github.com/firebase/genkit/issues/2438 - # @abc.abstractmethod - def resolve_action( # noqa: B027 + @abc.abstractmethod + async def resolve( self, - ai: GenkitRegistry, - kind: ActionKind, + action_type: ActionKind, name: str, - ) -> None: - """Resolves an action by adding it to the provided GenkitRegistry. + ) -> Action | None: + """Resolve a specific action on-demand (lazy loading). + + Called when the framework needs an action that wasn't returned from init(). + Enables lazy loading of less-common models or actions. Args: - ai: The Genkit registry. - kind: The kind of action to resolve. - name: The name of the action to resolve. + action_type: Type of action requested (MODEL, EMBEDDER, TOOL, etc.). + name: Name of the action (WITHOUT plugin prefix - framework strips it). Returns: - None, action resolution is done by side-effect on the registry. + Action object if this plugin can provide it, None if it cannot. + + Example: + >>> async def resolve(self, action_type, name): + ... if action_type == ActionKind.MODEL: + ... if name in SUPPORTED_MODELS: + ... from genkit.blocks.model import model + ... + ... return model(name=name, fn=self._generate) + ... return None """ - pass + ... 
- @abc.abstractmethod - def initialize(self, ai: GenkitRegistry) -> None: - """Initialize the plugin with the given registry. + async def list_actions(self) -> list[ActionMetadata]: + """List all actions this plugin can provide. - Args: - ai: Registry to register plugin functionality. + Used for discovery, developer tools, and documentation. + Should return metadata for ALL actions the plugin supports, + not just those returned from init(). Returns: - None, initialization is done by side-effect on the registry. + List of ActionMetadata objects (lightweight descriptions). + + Example: + >>> async def list_actions(self): + ... return [ + ... ActionMetadata(name='gpt-4', kind=ActionKind.MODEL, info={'supports': {'vision': True}}), + ... # ... more models + ... ] """ - pass + # Default implementation returns empty (can override) + return [] + + async def model(self, name: str) -> Action: + r"""Convenience method to get a specific model action. - def list_actions(self) -> list[ActionMetadata]: - """Generate a list of available actions or models. + Enables clean standalone usage: + plugin = SomePlugin() + model = await plugin.model('model-name') + response = await model.arun(...) + + Args: + name: Model name (without plugin prefix). Returns: - list[ActionMetadata]: A list of ActionMetadata objects, each with the following attributes: - - name (str): The name of the action or model. - - kind (ActionKind): The type or category of the action. - - info (dict): The metadata dictionary describing the model configuration and properties. - - config_schema (type): The schema class used for validating the model's configuration. + Action for the specified model. + + Raises: + ValueError: If the model is not supported by this plugin. + + Example: + >>> async def model(self, name: str) -> Action: + ... action = await self.resolve(ActionKind.MODEL, name) + ... if not action: + ... raise ValueError(f\"Model {name} not found\") + ... 
return action """ - return [] + # Call the async resolve method + action = await self.resolve(ActionKind.MODEL, name) + + if not action: + raise ValueError(f"Model '{name}' not found in plugin '{self.name}'") + return action diff --git a/py/packages/genkit/src/genkit/ai/_registry.py b/py/packages/genkit/src/genkit/ai/_registry.py index 8d62249981..51b8a5c08f 100644 --- a/py/packages/genkit/src/genkit/ai/_registry.py +++ b/py/packages/genkit/src/genkit/ai/_registry.py @@ -30,6 +30,7 @@ | `'indexer'` | Indexer | | `'model'` | Model | | `'prompt'` | Prompt | +| `'resource'` | Resource | | `'retriever'` | Retriever | | `'text-llm'` | Text LLM | | `'tool'` | Tool | @@ -42,7 +43,10 @@ import uuid from collections.abc import AsyncIterator, Callable from functools import wraps -from typing import Any, Type +from typing import TYPE_CHECKING, Any, Callable, Type + +if TYPE_CHECKING: + from genkit.blocks.resource import ResourceFn, ResourceOptions import structlog from pydantic import BaseModel @@ -53,9 +57,18 @@ from genkit.blocks.model import ModelFn, ModelMiddleware from genkit.blocks.prompt import ( define_helper, + define_partial, define_prompt, lookup_prompt, ) +from genkit.blocks.reranker import ( + RankedDocument, + RerankerFn, + RerankerOptions, + RerankerRef, + define_reranker as define_reranker_block, + rerank as rerank_block, +) from genkit.blocks.retriever import IndexerFn, RetrieverFn from genkit.blocks.tools import ToolRunContext from genkit.codec import dump_dict @@ -65,6 +78,7 @@ from genkit.core.schema import to_json_schema from genkit.core.tracing import run_in_new_span from genkit.core.typing import ( + DocumentData, EvalFnResponse, EvalRequest, EvalResponse, @@ -181,6 +195,18 @@ def define_helper(self, name: str, fn: Callable) -> None: """ define_helper(self.registry, name, fn) + def define_partial(self, name: str, source: str) -> None: + """Define a Handlebars partial template in the registry. 
+ + Partials are reusable template fragments that can be included + in other prompts using {{>partialName}} syntax. + + Args: + name: The name of the partial. + source: The template source code for the partial. + """ + define_partial(self.registry, name, source) + def tool(self, name: str | None = None, description: str | None = None) -> Callable[[Callable], Callable]: """Decorator to register a function as a tool. @@ -326,6 +352,100 @@ def define_indexer( description=indexer_description, ) + def define_reranker( + self, + name: str, + fn: RerankerFn, + config_schema: BaseModel | dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + description: str | None = None, + ) -> Action: + """Define a reranker action. + + Rerankers reorder documents based on their relevance to a query. + They are commonly used in RAG pipelines to improve retrieval quality. + + Args: + name: Name of the reranker. + fn: Function implementing the reranker behavior. Should accept + (query_doc, documents, options) and return RerankerResponse. + config_schema: Optional schema for reranker configuration. + metadata: Optional metadata for the reranker. + description: Optional description for the reranker. + + Returns: + The registered Action for the reranker. + + Example: + >>> async def my_reranker(query, docs, options): + ... # Score documents based on relevance to query + ... scored = [(doc, compute_score(query, doc)) for doc in docs] + ... scored.sort(key=lambda x: x[1], reverse=True) + ... 
return RerankerResponse(documents=[...]) + >>> ai.define_reranker('my-reranker', my_reranker) + """ + reranker_meta = metadata.copy() if metadata else {} + if 'reranker' not in reranker_meta: + reranker_meta['reranker'] = {} + if 'label' not in reranker_meta['reranker'] or not reranker_meta['reranker']['label']: + reranker_meta['reranker']['label'] = name + if config_schema: + reranker_meta['reranker']['customOptions'] = to_json_schema(config_schema) + + reranker_description = get_func_description(fn, description) + return define_reranker_block( + self.registry, + name=name, + fn=fn, + options=RerankerOptions( + config_schema=reranker_meta['reranker'].get('customOptions'), + label=reranker_meta['reranker'].get('label'), + ), + ) + + async def rerank( + self, + reranker: str | Action | RerankerRef, + query: str | DocumentData, + documents: list[DocumentData], + options: Any | None = None, + ) -> list[RankedDocument]: + """Rerank documents based on their relevance to a query. + + This method takes a query and a list of documents, and returns the + documents reordered by relevance as determined by the specified reranker. + + Args: + reranker: The reranker to use - can be a name string, Action, or RerankerRef. + query: The query to rank documents against - can be a string or DocumentData. + documents: The list of documents to rerank. + options: Optional configuration options for this rerank call. + + Returns: + A list of RankedDocument objects sorted by relevance score. + + Raises: + ValueError: If the reranker cannot be resolved. + + Example: + >>> ranked_docs = await ai.rerank( + ... reranker='my-reranker', + ... query='What is machine learning?', + ... documents=[doc1, doc2, doc3], + ... ) + >>> for doc in ranked_docs: + ... 
print(f'Score: {doc.score}, Text: {doc.text()}') + """ + return await rerank_block( + self.registry, + { + 'reranker': reranker, + 'query': query, + 'documents': documents, + 'options': options, + }, + ) + def define_evaluator( self, name: str, @@ -488,7 +608,7 @@ def define_model( self, name: str, fn: ModelFn, - config_schema: Type[BaseModel] | dict[str, Any] | None = None, + config_schema: type[BaseModel] | dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, info: ModelInfo | None = None, description: str | None = None, @@ -573,14 +693,15 @@ def define_format(self, format: FormatDef) -> None: def define_prompt( self, + name: str | None = None, variant: str | None = None, model: str | None = None, config: GenerationCommonConfig | dict[str, Any] | None = None, description: str | None = None, input_schema: type | dict[str, Any] | None = None, - system: str | Part | list[Part] | None = None, - prompt: str | Part | list[Part] | None = None, - messages: str | list[Message] | None = None, + system: str | Part | list[Part] | Callable | None = None, + prompt: str | Part | list[Part] | Callable | None = None, + messages: str | list[Message] | Callable | None = None, output_format: str | None = None, output_content_type: str | None = None, output_instructions: bool | str | None = None, @@ -592,12 +713,12 @@ def define_prompt( tools: list[str] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, - # TODO: - # docs: list[Document] + docs: list[DocumentData] | Callable | None = None, ): """Define a prompt. Args: + name: Optional name for the prompt. variant: Optional variant name for the prompt. model: Optional model name to use for the prompt. config: Optional configuration for the model. @@ -619,9 +740,11 @@ def define_prompt( tools: Optional list of tools to use for the prompt. tool_choice: Optional tool choice for the prompt. use: Optional list of model middlewares to use for the prompt. 
+ docs: Optional list of documents or a callable to be used for grounding. """ return define_prompt( self.registry, + name=name, variant=variant, model=model, config=config, @@ -641,6 +764,7 @@ def define_prompt( tools=tools, tool_choice=tool_choice, use=use, + docs=docs, ) async def prompt( @@ -668,13 +792,58 @@ async def prompt( Raises: GenkitError: If the prompt is not found. """ - return await lookup_prompt( registry=self.registry, name=name, variant=variant, ) + def define_resource( + self, + opts: 'ResourceOptions | None' = None, + fn: 'ResourceFn | None' = None, + *, + name: str | None = None, + uri: str | None = None, + template: str | None = None, + description: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> Action: + """Define a resource action. + + Args: + opts: Options defining the resource (e.g. uri, template, name). + fn: Function implementing the resource behavior. + name: Optional name for the resource. + uri: Optional URI for the resource. + template: Optional URI template for the resource. + description: Optional description for the resource. + metadata: Optional metadata for the resource. + + Returns: + The registered Action for the resource. 
+ """ + from genkit.blocks.resource import ( + define_resource as define_resource_block, + ) + + if fn is None: + raise ValueError('A function `fn` must be provided to define a resource.') + if opts is None: + opts = {} + if name: + opts['name'] = name + if uri: + opts['uri'] = uri + if template: + opts['template'] = template + if description: + opts['description'] = description + if metadata: + opts['metadata'] = metadata + + return define_resource_block(self.registry, opts, fn) + class FlowWrapper: """A wapper for flow functions to add `stream` method.""" diff --git a/py/packages/genkit/src/genkit/aio/_compat.py b/py/packages/genkit/src/genkit/aio/_compat.py index 0c0d6b7fc8..04a0538915 100644 --- a/py/packages/genkit/src/genkit/aio/_compat.py +++ b/py/packages/genkit/src/genkit/aio/_compat.py @@ -26,7 +26,6 @@ """ import asyncio -import sys from typing import TypeVar T = TypeVar('T') @@ -52,11 +51,8 @@ async def wait_for_310(fut: asyncio.Future[T], timeout: float | None = None) -> """ try: return await asyncio.wait_for(fut, timeout) - except asyncio.TimeoutError as e: + except TimeoutError as e: raise TimeoutError() from e -if sys.version_info < (3, 11): - wait_for = wait_for_310 -else: - wait_for = asyncio.wait_for +wait_for = asyncio.wait_for diff --git a/py/packages/genkit/src/genkit/blocks/embedding.py b/py/packages/genkit/src/genkit/blocks/embedding.py index 95bca321a2..16913ade4c 100644 --- a/py/packages/genkit/src/genkit/blocks/embedding.py +++ b/py/packages/genkit/src/genkit/blocks/embedding.py @@ -16,14 +16,14 @@ """Embedding actions.""" -from collections.abc import Awaitable, Callable +from collections.abc import Callable from typing import Any from pydantic import BaseModel, ConfigDict, Field -from genkit.core.action import ActionMetadata +from genkit.core.action import Action, ActionMetadata from genkit.core.action.types import ActionKind -from genkit.core.schema import to_json_schema +from genkit.core.schema import get_func_description, 
to_json_schema from genkit.core.typing import EmbedRequest, EmbedResponse @@ -60,6 +60,64 @@ class EmbedderRef(BaseModel): EmbedderFn = Callable[[EmbedRequest], EmbedResponse] +def embedder( + name: str, + fn: EmbedderFn, + options: EmbedderOptions | None = None, + metadata: dict[str, Any] | None = None, + description: str | None = None, +) -> 'Action': + """Create an embedder action WITHOUT registering it. + + This is the v2 API for creating embedders. Returns an Action instance + that can be used standalone or registered by the framework. + + Args: + name: Embedder name (without plugin prefix). + fn: Function that implements embedding (takes EmbedRequest, returns EmbedResponse). + options: Optional embedder options (dimensions, supports, etc.). + metadata: Optional metadata dictionary. + description: Optional human-readable description. + + Returns: + Action instance (not registered). + + Example: + >>> from genkit.blocks.embedding import embedder + >>> + >>> def my_embed(request: EmbedRequest) -> EmbedResponse: + ... return EmbedResponse(...) 
+ >>> + >>> action = embedder(name='my-embedder', fn=my_embed) + >>> response = await action.arun({'input': [...]}) + """ + embedder_meta = metadata if metadata else {} + + if 'embedder' not in embedder_meta: + embedder_meta['embedder'] = {} + + if 'label' not in embedder_meta['embedder'] or not embedder_meta['embedder']['label']: + embedder_meta['embedder']['label'] = name + + if options: + if options.dimensions: + embedder_meta['embedder']['dimensions'] = options.dimensions + if options.config_schema: + embedder_meta['embedder']['customOptions'] = options.config_schema + if options.supports: + embedder_meta['embedder']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True) + + final_description = description if description else get_func_description(fn) + + return Action( + name=name, + kind=ActionKind.EMBEDDER, + fn=fn, + metadata=embedder_meta, + description=final_description, + ) + + def embedder_action_metadata( name: str, options: EmbedderOptions | None = None, diff --git a/py/packages/genkit/src/genkit/blocks/generate.py b/py/packages/genkit/src/genkit/blocks/generate.py index 754af90d4e..5982589d98 100644 --- a/py/packages/genkit/src/genkit/blocks/generate.py +++ b/py/packages/genkit/src/genkit/blocks/generate.py @@ -32,7 +32,7 @@ from genkit.blocks.tools import ToolInterruptError from genkit.codec import dump_dict from genkit.core.action import ActionRunContext -from genkit.core.error import GenkitError, StatusName +from genkit.core.error import GenkitError from genkit.core.registry import Action, ActionKind, Registry from genkit.core.typing import ( GenerateActionOptions, @@ -97,7 +97,7 @@ async def generate_action( Returns: The generated response. 
""" - model, tools, format_def = resolve_parameters(registry, raw_request) + model, tools, format_def = await resolve_parameters(registry, raw_request) raw_request, formatter = apply_format(raw_request, format_def) @@ -350,11 +350,7 @@ def apply_format( raw_request.output.instructions if raw_request.output else None, ) - if ( - format_def.config.default_instructions != False or raw_request.output.instructions - if raw_request.output - else False - ): + if format_def.config.default_instructions or raw_request.output.instructions if raw_request.output else False: out_request.messages = inject_instructions(out_request.messages, instructions) if format_def.config.constrained is not None: @@ -384,7 +380,7 @@ def resolve_instructions(formatter: Formatter, instructions_opt: bool | str | No if isinstance(instructions_opt, str): # user provided instructions return instructions_opt - if instructions_opt == False: + if not instructions_opt: # user says no instructions return None if not formatter: @@ -425,7 +421,7 @@ def assert_valid_tool_names(raw_request: GenerateActionOptions): pass -def resolve_parameters( +async def resolve_parameters( registry: Registry, request: GenerateActionOptions ) -> tuple[Action, list[Action], FormatDef | None]: """Resolve parameters for the generate action. 
@@ -442,14 +438,14 @@ def resolve_parameters( if not model: raise Exception('No model configured.') - model_action = registry.lookup_action(ActionKind.MODEL, model) + model_action = await registry.resolve_action(ActionKind.MODEL, model) if model_action is None: raise Exception(f'Failed to to resolve model {model}') tools: list[Action] = [] if request.tools: for tool_name in request.tools: - tool_action = registry.lookup_action(ActionKind.TOOL, tool_name) + tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) if tool_action is None: raise Exception(f'Unable to resolve tool {tool_name}') tools.append(tool_action) @@ -541,7 +537,7 @@ async def resolve_tool_requests( # TODO: prompt transfer tool_dict: dict[str, Action] = {} for tool_name in request.tools: - tool_dict[tool_name] = resolve_tool(registry, tool_name) + tool_dict[tool_name] = await resolve_tool(registry, tool_name) revised_model_message = message._original_message.model_copy(deep=True) @@ -646,7 +642,7 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart raise e -def resolve_tool(registry: Registry, tool_name: str): +async def resolve_tool(registry: Registry, tool_name: str): """Resolve a tool by name from the registry. Args: @@ -659,7 +655,7 @@ def resolve_tool(registry: Registry, tool_name: str): Raises: ValueError: If the tool could not be resolved. 
""" - return registry.lookup_action(kind=ActionKind.TOOL, name=tool_name) + return await registry.resolve_action(kind=ActionKind.TOOL, name=tool_name) async def _resolve_resume_options( diff --git a/py/packages/genkit/src/genkit/blocks/model.py b/py/packages/genkit/src/genkit/blocks/model.py index b78ee6455e..4b4912354b 100644 --- a/py/packages/genkit/src/genkit/blocks/model.py +++ b/py/packages/genkit/src/genkit/blocks/model.py @@ -36,10 +36,11 @@ def my_model(request: GenerateRequest) -> GenerateResponse: from pydantic import BaseModel, Field -from genkit.core.action import ActionMetadata, ActionRunContext +from genkit.codec import dump_dict +from genkit.core.action import Action, ActionMetadata, ActionRunContext from genkit.core.action.types import ActionKind from genkit.core.extract import extract_json -from genkit.core.schema import to_json_schema +from genkit.core.schema import get_func_description, to_json_schema from genkit.core.typing import ( Candidate, DocumentPart, @@ -64,6 +65,81 @@ def my_model(request: GenerateRequest) -> GenerateResponse: ChunkParser = Callable[['GenerateResponseChunkWrapper'], T] +def model( + name: str, + fn: ModelFn, + config_schema: type[BaseModel] | None = None, + metadata: dict[str, Any] | None = None, + info: ModelInfo | None = None, + description: str | None = None, +) -> Action: + """Create a model action WITHOUT registering it. + + This is the v2 API for creating models. Unlike ai.define_model(), + this function does NOT register the action in any registry - it just + creates and returns an Action object. + + This enables: + 1. V2 plugins to create actions without needing a registry + 2. Standalone usage (call the action directly without framework) + 3. Framework to register actions from v2 plugins when needed + + Args: + name: Model name (without plugin prefix - framework adds it automatically). + fn: Function that implements the model. Takes GenerateRequest and + ActionRunContext, returns GenerateResponse. 
+ config_schema: Optional Pydantic model for config validation. + metadata: Optional metadata dictionary. + info: Optional ModelInfo describing model capabilities (vision, tools, etc.). + description: Optional human-readable description. + + Returns: + Action instance (not registered anywhere). + + Example: + >>> from genkit.blocks.model import model + >>> + >>> def my_generate(request: GenerateRequest, ctx: ActionRunContext): + ... return GenerateResponse(...) + >>> + >>> action = model(name='my-model', fn=my_generate) + >>> response = await action.arun({'messages': [...]}) + + Note: + This function extracts the "create action" logic from + GenkitRegistry.define_model() but skips the registration step. + """ + model_meta: dict[str, Any] = metadata if metadata else {} + + if info: + model_meta['model'] = dump_dict(info) + + if 'model' not in model_meta: + model_meta['model'] = {} + + if 'label' not in model_meta['model'] or not model_meta['model']['label']: + model_meta['model']['label'] = name + + if config_schema: + model_meta['model']['customOptions'] = to_json_schema(config_schema) + + final_description = description if description else get_func_description(fn) + + action = Action( + name=name, + kind=ActionKind.MODEL, + fn=fn, + metadata=model_meta, + description=final_description, + ) + + # NOTE: We do NOT call registry.register_action() here! + # That's the key difference from define_model(). + # The action is created but not registered anywhere. + + return action + + # type ModelMiddlewareNext = Callable[[GenerateRequest, ActionRunContext], Awaitable[GenerateResponse]] ModelMiddlewareNext = Callable[[GenerateRequest, ActionRunContext], Awaitable[GenerateResponse]] # type ModelMiddleware = Callable[ @@ -160,12 +236,28 @@ def __init__( request: The GenerateRequest object associated with the response. message_parser: An optional function to parse the output from the message. 
""" + # If message is not returned by generate response, try to infer + # message from the first candidate. + response_message = response.message + if response_message is None and response.candidates: + response_message = response.candidates[0].message + if response_message is None: + raise ValueError('GenerateResponse must include either `message` or at least one candidate message.') + + finish_reason = response.finish_reason + if finish_reason is None and response.candidates: + finish_reason = response.candidates[0].finish_reason + + finish_message = response.finish_message + if finish_message is None and response.candidates: + finish_message = response.candidates[0].finish_message + super().__init__( - message=MessageWrapper(response.message) - if not isinstance(response.message, MessageWrapper) - else response.message, - finish_reason=response.finish_reason, - finish_message=response.finish_message, + message=MessageWrapper(response_message) + if not isinstance(response_message, MessageWrapper) + else response_message, + finish_reason=finish_reason, + finish_message=finish_message, latency_ms=response.latency_ms, usage=response.usage if response.usage is not None else GenerationUsage(), custom=response.custom if response.custom is not None else {}, @@ -453,10 +545,7 @@ def model_action_metadata( def model_ref(name: str, namespace: str | None = None, **options: Any) -> ModelReference: - """ - The factory function equivalent to export function modelRef(...) 
- """ - + """The factory function equivalent to export function modelRef(...).""" # Logic: if (options.namespace && !name.startsWith(options.namespace + '/')) if namespace and not name.startswith(f'{namespace}/'): final_name = f'{namespace}/{name}' diff --git a/py/packages/genkit/src/genkit/blocks/prompt.py b/py/packages/genkit/src/genkit/blocks/prompt.py index fcb3d65fa2..7518554567 100644 --- a/py/packages/genkit/src/genkit/blocks/prompt.py +++ b/py/packages/genkit/src/genkit/blocks/prompt.py @@ -27,7 +27,7 @@ from asyncio import Future from collections.abc import AsyncIterator, Callable from pathlib import Path -from typing import Any +from typing import Any, Awaitable import structlog from dotpromptz.typing import ( @@ -35,14 +35,14 @@ PromptFunction, PromptInputConfig, PromptMetadata, - ToolDefinition as DotPromptzToolDefinition, ) -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from genkit.aio import Channel +from genkit.aio import Channel, ensure_async from genkit.blocks.generate import ( StreamingCallback as ModelStreamingCallback, generate_action, + to_tool_definition, ) from genkit.blocks.model import ( GenerateResponseChunkWrapper, @@ -83,14 +83,16 @@ class PromptCache: class PromptConfig(BaseModel): """Model for a prompt action.""" + model_config = ConfigDict(arbitrary_types_allowed=True) + variant: str | None = None model: str | None = None config: GenerationCommonConfig | dict[str, Any] | None = None description: str | None = None input_schema: type | dict[str, Any] | None = None - system: str | Part | list[Part] | None = None - prompt: str | Part | list[Part] | None = None - messages: str | list[Message] | None = None + system: str | Part | list[Part] | Callable | None = None + prompt: str | Part | list[Part] | Callable | None = None + messages: str | list[Message] | Callable | None = None output_format: str | None = None output_content_type: str | None = None output_instructions: bool | str | None = None @@ -102,7 +104,7 
@@ class PromptConfig(BaseModel): tools: list[str] | None = None tool_choice: ToolChoice | None = None use: list[ModelMiddleware] | None = None - docs: list[DocumentData] | None = None + docs: list[DocumentData] | Callable | None = None tool_responses: list[Part] | None = None @@ -117,9 +119,9 @@ def __init__( config: GenerationCommonConfig | dict[str, Any] | None = None, description: str | None = None, input_schema: type | dict[str, Any] | None = None, - system: str | Part | list[Part] | None = None, - prompt: str | Part | list[Part] | None = None, - messages: str | list[Message] | None = None, + system: str | Part | list[Part] | Callable | None = None, + prompt: str | Part | list[Part] | Callable | None = None, + messages: str | list[Message] | Callable | None = None, output_format: str | None = None, output_content_type: str | None = None, output_instructions: bool | str | None = None, @@ -131,6 +133,7 @@ def __init__( tools: list[str] | None = None, tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, + docs: list[DocumentData] | Callable | None = None, _name: str | None = None, # prompt name for action lookup _ns: str | None = None, # namespace for action lookup _prompt_action: Action | None = None, # reference to PROMPT action @@ -160,6 +163,7 @@ def __init__( tools: A list of tool names to use with the prompt. tool_choice: The tool choice strategy. use: A list of model middlewares to apply. + docs: A list of documents to be used for grounding. """ self._registry = registry self._variant = variant @@ -181,11 +185,23 @@ def __init__( self._tools = tools self._tool_choice = tool_choice self._use = use + self._docs = docs self._cache_prompt = PromptCache() self._name = _name # Store name/ns for action lookup (used by as_tool()) self._ns = _ns self._prompt_action = _prompt_action + @property + def ref(self) -> dict[str, Any]: + """Returns a reference object for this prompt. 
+ + The reference object contains the prompt's name and metadata. + """ + return { + 'name': registry_definition_key(self._name, self._variant, self._ns) if self._name else None, + 'metadata': self._metadata, + } + async def __call__( self, input: Any | None = None, @@ -281,6 +297,7 @@ async def render( output_constrained=self._output_constrained, input_schema=self._input_schema, metadata=self._metadata, + docs=self._docs, ) model = options.model or self._registry.default_model @@ -330,7 +347,7 @@ async def render( tool_choice=options.tool_choice, output=output, max_turns=options.max_turns, - docs=options.docs, + docs=await render_docs(input, options, context), resume=resume, ) @@ -352,7 +369,7 @@ async def as_tool(self) -> Action: lookup_key = registry_lookup_key(self._name, self._variant, self._ns) - action = self._registry.lookup_action_by_key(lookup_key) + action = await self._registry.aresolve_action_by_key(lookup_key) if action is None or action.kind != ActionKind.PROMPT: raise GenkitError( @@ -365,14 +382,15 @@ async def as_tool(self) -> Action: def define_prompt( registry: Registry, + name: str | None = None, variant: str | None = None, model: str | None = None, config: GenerationCommonConfig | dict[str, Any] | None = None, description: str | None = None, input_schema: type | dict[str, Any] | None = None, - system: str | Part | list[Part] | None = None, - prompt: str | Part | list[Part] | None = None, - messages: str | list[Message] | None = None, + system: str | Part | list[Part] | Callable | None = None, + prompt: str | Part | list[Part] | Callable | None = None, + messages: str | list[Message] | Callable | None = None, output_format: str | None = None, output_content_type: str | None = None, output_instructions: bool | str | None = None, @@ -381,16 +399,16 @@ def define_prompt( max_turns: int | None = None, return_tool_requests: bool | None = None, metadata: dict[str, Any] | None = None, - tools: Tools | None = None, + tools: list[str] | None = None, 
tool_choice: ToolChoice | None = None, use: list[ModelMiddleware] | None = None, - # TODO: - # docs: list[Document] + docs: list[DocumentData] | Callable | None = None, ) -> ExecutablePrompt: """Defines an executable prompt. Args: registry: The registry to use for resolving models and tools. + name: The name of the prompt. variant: The variant of the prompt. model: The model to use for generation. config: The generation configuration. @@ -410,11 +428,12 @@ def define_prompt( tools: A list of tool names to use with the prompt. tool_choice: The tool choice strategy. use: A list of model middlewares to apply. + docs: A list of documents to be used for grounding. Returns: An ExecutablePrompt instance. """ - return ExecutablePrompt( + executable_prompt = ExecutablePrompt( registry, variant=variant, model=model, @@ -435,30 +454,59 @@ def define_prompt( tools=tools, tool_choice=tool_choice, use=use, + docs=docs, + _name=name, ) + if name: + # Register actions for this prompt + action_metadata = { + 'type': 'prompt', + 'source': 'programmatic', + 'prompt': { + 'name': name, + 'variant': variant or '', + }, + } + + async def prompt_action_fn(input: Any = None) -> GenerateRequest: + """PROMPT action function - renders prompt and returns GenerateRequest.""" + options = await executable_prompt.render(input=input) + return await to_generate_request(registry, options) + + async def executable_prompt_action_fn(input: Any = None) -> GenerateActionOptions: + """EXECUTABLE_PROMPT action function - renders prompt and returns GenerateActionOptions.""" + return await executable_prompt.render(input=input) + + action_name = registry_definition_key(name, variant) + prompt_action = registry.register_action( + kind=ActionKind.PROMPT, + name=action_name, + fn=prompt_action_fn, + metadata=action_metadata, + ) + + executable_prompt_action = registry.register_action( + kind=ActionKind.EXECUTABLE_PROMPT, + name=action_name, + fn=executable_prompt_action_fn, + metadata=action_metadata, + ) + + # 
Link them + executable_prompt._prompt_action = prompt_action + prompt_action._executable_prompt = weakref.ref(executable_prompt) + executable_prompt_action._executable_prompt = weakref.ref(executable_prompt) + + return executable_prompt + async def to_generate_action_options(registry: Registry, options: PromptConfig) -> GenerateActionOptions: """Converts the given parameters to a GenerateActionOptions object. Args: registry: The registry to use for resolving models and tools. - model: The model to use for generation. - prompt: The user prompt. - system: The system message for the prompt. - messages: A list of messages to include in the prompt. - tools: A list of tool names to use with the prompt. - return_tool_requests: Whether to return tool requests. - tool_choice: The tool choice strategy. - tool_responses: tool response parts corresponding to interrupts. - config: The generation configuration. - max_turns: The maximum number of turns in a conversation. - output_format: The output format. - output_content_type: The output content type. - output_instructions: Instructions for formatting the output. - output_schema: The output schema. - output_constrained: Whether the output should be constrained to the output schema. - docs: A list of documents to be used for grounding. + options: The prompt configuration. Returns: A GenerateActionOptions object. 
@@ -466,13 +514,17 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig) model = options.model or registry.default_model if model is None: raise Exception('No model configured.') + + cache = PromptCache() resolved_msgs: list[Message] = [] if options.system: - resolved_msgs.append(Message(role=Role.SYSTEM, content=_normalize_prompt_arg(options.system))) + result = await render_system_prompt(registry, None, options, cache) + resolved_msgs.append(result) if options.messages: - resolved_msgs += options.messages + resolved_msgs.extend(await render_message_prompt(registry, None, options, cache)) if options.prompt: - resolved_msgs.append(Message(role=Role.USER, content=_normalize_prompt_arg(options.prompt))) + result = await render_user_prompt(registry, None, options, cache) + resolved_msgs.append(result) # If is schema is set but format is not explicitly set, default to # `json` format. @@ -506,7 +558,7 @@ async def to_generate_action_options(registry: Registry, options: PromptConfig) tool_choice=options.tool_choice, output=output, max_turns=options.max_turns, - docs=options.docs, + docs=await render_docs(None, options), resume=resume, ) @@ -532,11 +584,10 @@ async def to_generate_request(registry: Registry, options: GenerateActionOptions the registry. GenkitError: If the options do not contain any messages. 
""" - tools: list[Action] = [] if options.tools: for tool_name in options.tools: - tool_action = registry.lookup_action(ActionKind.TOOL, tool_name) + tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) if tool_action is None: raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_name}') tools.append(tool_action) @@ -614,7 +665,6 @@ async def render_system_prompt( Message: A Message object containing the rendered system prompt with Role.SYSTEM """ - if isinstance(options.system, str): if prompt_cache.system is None: prompt_cache.system = await registry.dotprompt.compile(options.system) @@ -630,12 +680,16 @@ async def render_system_prompt( input, PromptMetadata( input=PromptInputConfig( - schema=options.input_schema, + schema=to_json_schema(options.input_schema) if options.input_schema else None, ) ), ), ) + if callable(options.system): + resolved = await ensure_async(options.system)(input, context) + return Message(role=Role.SYSTEM, content=_normalize_prompt_arg(resolved)) + return Message(role=Role.SYSTEM, content=_normalize_prompt_arg(options.system)) @@ -687,8 +741,7 @@ async def render_message_prompt( prompt_cache: PromptCache, context: dict[str, Any] | None = None, ) -> list[Message]: - """ - Render a message prompt using a given registry, input data, options, and a context. + """Render a message prompt using a given registry, input data, options, and a context. This function processes different types of message options (string or list) to render appropriate messages using a prompt registry and cache. 
If the `messages` option is of type @@ -727,14 +780,22 @@ async def render_message_prompt( context=context, messages=messages_, ), - options=PromptMetadata(input=PromptInputConfig()), + options=PromptMetadata( + input=PromptInputConfig( + schema=to_json_schema(options.input_schema) if options.input_schema else None, + ) + ), ) return [Message.model_validate(e.model_dump()) for e in rendered.messages] elif isinstance(options.messages, list): return options.messages - return [Message(role=Role.USER, content=_normalize_prompt_arg(options.prompt))] + elif callable(options.messages): + resolved = await ensure_async(options.messages)(input, context) + return resolved + + raise TypeError(f'Unsupported type for messages: {type(options.messages)}') async def render_user_prompt( @@ -744,8 +805,7 @@ async def render_user_prompt( prompt_cache: PromptCache, context: dict[str, Any] | None = None, ) -> Message: - """ - Asynchronously renders a user prompt based on the given input, context, and options, + """Asynchronously renders a user prompt based on the given input, context, and options, utilizing a pre-compiled or dynamically compiled dotprompt template. Arguments: @@ -774,13 +834,45 @@ async def render_user_prompt( context, prompt_cache.user_prompt, input, - PromptMetadata(input=PromptInputConfig()), + PromptMetadata( + input=PromptInputConfig( + schema=to_json_schema(options.input_schema) if options.input_schema else None, + ) + ), ), ) + if callable(options.prompt): + resolved = await ensure_async(options.prompt)(input, context) + return Message(role=Role.USER, content=_normalize_prompt_arg(resolved)) + return Message(role=Role.USER, content=_normalize_prompt_arg(options.prompt)) +async def render_docs( + input: dict[str, Any], + options: PromptConfig, + context: dict[str, Any] | None = None, +) -> list[DocumentData] | None: + """Renders the docs for a prompt action. + + Args: + input: Dictionary of input values. + options: Configuration options for the prompt. 
+ context: Optional dictionary of context values. + + Returns: + A list of DocumentData objects or None. + """ + if options.docs is None: + return None + + if callable(options.docs): + return await ensure_async(options.docs)(input, context) + + return options.docs + + def registry_definition_key(name: str, variant: str | None = None, ns: str | None = None) -> str: """Generate a registry definition key for a prompt. @@ -888,7 +980,7 @@ def load_prompt(registry: Registry, path: Path, filename: str, prefix: str = '', file_path = path / filename # Read the prompt file - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, encoding='utf-8') as f: source = f.read() # Parse the prompt @@ -996,7 +1088,7 @@ async def create_prompt_from_file(): # Store reference to PROMPT action on the ExecutablePrompt # Actions are already registered at this point (lazy loading happens after registration) lookup_key = registry_lookup_key(name, variant, ns) - prompt_action = registry.lookup_action_by_key(lookup_key) + prompt_action = await registry.aresolve_action_by_key(lookup_key) if prompt_action and prompt_action.kind == ActionKind.PROMPT: executable_prompt._prompt_action = prompt_action # Also store ExecutablePrompt reference on the action @@ -1092,7 +1184,7 @@ def load_prompt_folder_recursively(registry: Registry, dir_path: Path, ns: str, if entry.name.startswith('_'): # This is a partial partial_name = entry.name[1:-7] # Remove "_" prefix and ".prompt" suffix - with open(entry.path, 'r', encoding='utf-8') as f: + with open(entry.path, encoding='utf-8') as f: source = f.read() # Strip frontmatter if present @@ -1160,14 +1252,14 @@ async def lookup_prompt(registry: Registry, name: str, variant: str | None = Non # Use create_action_key to build the full key: "/prompt/" definition_key = registry_definition_key(name, variant, None) lookup_key = create_action_key(ActionKind.PROMPT, definition_key) - action = registry.lookup_action_by_key(lookup_key) + action = await 
registry.aresolve_action_by_key(lookup_key) # If not found and no namespace was specified, try with default 'dotprompt' namespace # (for file-based prompts) if not action: definition_key = registry_definition_key(name, variant, 'dotprompt') lookup_key = create_action_key(ActionKind.PROMPT, definition_key) - action = registry.lookup_action_by_key(lookup_key) + action = await registry.aresolve_action_by_key(lookup_key) if action: # First check if we've stored the ExecutablePrompt directly @@ -1227,5 +1319,4 @@ async def prompt( Raises: GenkitError: If the prompt is not found. """ - return await lookup_prompt(registry, name, variant) diff --git a/py/packages/genkit/src/genkit/blocks/reranker.py b/py/packages/genkit/src/genkit/blocks/reranker.py new file mode 100644 index 0000000000..cb22806e5c --- /dev/null +++ b/py/packages/genkit/src/genkit/blocks/reranker.py @@ -0,0 +1,440 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Reranker type definitions for the Genkit framework. + +Rerankers and Two-Stage Retrieval +================================= + +A **reranking model** (also known as a cross-encoder) is a type of model that, +given a query and document, outputs a similarity score. This score is used to +reorder documents by relevance to the query. + +Reranker APIs take a list of documents (e.g., the output of a retriever) and +reorder them based on their relevance to the query. 
This step can be useful +for fine-tuning results and ensuring the most pertinent information is used +in the prompt provided to a generative model. + +Two-Stage Retrieval +------------------- + +In a typical RAG (Retrieval-Augmented Generation) pipeline: + +1. **Stage 1 - Retrieval**: A retriever fetches a large set of candidate + documents using fast vector similarity search. +2. **Stage 2 - Reranking**: A reranker scores and reorders these candidates + using more expensive but accurate cross-encoder models. + +This two-stage approach balances speed and accuracy: +- Retrievers are fast but may not perfectly rank results +- Rerankers are slower but provide superior relevance scoring + +Usage Example +------------- + +Using an existing reranker (e.g., Vertex AI): + +.. code-block:: python + + from genkit.ai import Genkit + + ai = Genkit(plugins=[...]) + + + @ai.flow() + async def rerank_flow(query: str): + documents = [ + Document.from_text('pythagorean theorem'), + Document.from_text('quantum mechanics'), + Document.from_text('pizza'), + ] + + reranked = await ai.rerank( + reranker='vertexai/semantic-ranker-512', + query=query, + documents=documents, + ) + + return [{'text': doc.text(), 'score': doc.score} for doc in reranked] + +Custom Rerankers +---------------- + +You can define custom rerankers for specific use cases: + +.. 
code-block:: python + + from genkit.ai import Genkit + from genkit.core.typing import ( + RerankerResponse, + RankedDocumentData, + RankedDocumentMetadata, + ) + + ai = Genkit() + + + async def custom_reranker_fn(query, documents, options): + # Your custom reranking logic here + # Example: score by keyword overlap + query_words = set(query.text().lower().split()) + scored = [] + for doc in documents: + doc_words = set(doc.text().lower().split()) + overlap = len(query_words & doc_words) + score = overlap / max(len(query_words), 1) + scored.append((doc, score)) + + # Sort by score descending and take top k + k = options.get('k', 3) if options else 3 + scored.sort(key=lambda x: x[1], reverse=True) + top_k = scored[:k] + + return RerankerResponse( + documents=[ + RankedDocumentData(content=doc.content, metadata=RankedDocumentMetadata(score=score)) + for doc, score in top_k + ] + ) + + + ai.define_reranker('custom/keyword-reranker', custom_reranker_fn) + + + # Use it in a flow + @ai.flow() + async def search_flow(query: str): + docs = await ai.retrieve(retriever='my-retriever', query=query) + return await ai.rerank(reranker='custom/keyword-reranker', query=query, documents=docs, options={'k': 5}) +""" + +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar, Union + +from pydantic import BaseModel, ConfigDict, Field + +from genkit.blocks.document import Document +from genkit.core.action import Action, ActionMetadata +from genkit.core.action.types import ActionKind +from genkit.core.registry import Registry +from genkit.core.schema import to_json_schema +from genkit.core.typing import ( + DocumentData, + DocumentPart, + RankedDocumentData, + RankedDocumentMetadata, + RerankerRequest, + RerankerResponse, +) + +T = TypeVar('T') + +# Type alias for reranker function +RerankerFn = Callable[[Document, list[Document], T], Awaitable[RerankerResponse]] + + +class RankedDocument(Document): + """A document with a relevance score from reranking. 
+ + This class extends Document to include a score property that represents + the document's relevance to a query as determined by a reranker. + """ + + def __init__( + self, + content: list[DocumentPart], + metadata: dict[str, Any] | None = None, + score: float | None = None, + ) -> None: + """Initializes a RankedDocument object. + + Args: + content: A list of DocumentPart objects representing the document's content. + metadata: An optional dictionary containing metadata about the document. + score: The relevance score from reranking. + """ + md = metadata.copy() if metadata else {} + if score is not None: + md['score'] = score + super().__init__(content=content, metadata=md) + + @property + def score(self) -> float | None: + """Returns the relevance score of the document. + + Returns: + The relevance score as a float, or None if not set. + """ + if self.metadata and 'score' in self.metadata: + return self.metadata['score'] + return None + + @staticmethod + def from_ranked_document_data(data: RankedDocumentData) -> 'RankedDocument': + """Constructs a RankedDocument from RankedDocumentData. + + Args: + data: The RankedDocumentData containing content, metadata with score. + + Returns: + A new RankedDocument instance. 
+ """ + return RankedDocument( + content=data.content, + metadata=data.metadata.model_dump(), + score=data.metadata.score, + ) + + +class RerankerSupports(BaseModel): + """Reranker capability support.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + media: bool | None = None + + +class RerankerInfo(BaseModel): + """Information about a reranker's capabilities.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + label: str | None = None + supports: RerankerSupports | None = None + + +class RerankerOptions(BaseModel): + """Configuration options for a reranker.""" + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + config_schema: dict[str, Any] | None = Field(None, alias='configSchema') + label: str | None = None + supports: RerankerSupports | None = None + + +class RerankerRef(BaseModel): + """Reference to a reranker with configuration. + + Used to reference a reranker by name with optional configuration + and version information. + """ + + model_config = ConfigDict(extra='forbid', populate_by_name=True) + + name: str + config: Any | None = None + version: str | None = None + info: RerankerInfo | None = None + + +def reranker_action_metadata( + name: str, + options: RerankerOptions | None = None, +) -> ActionMetadata: + """Creates action metadata for a reranker. + + Args: + name: The name of the reranker. + options: Optional configuration options for the reranker. + + Returns: + An ActionMetadata instance for the reranker. 
+ """ + options = options if options is not None else RerankerOptions() + reranker_metadata_dict: dict[str, Any] = {'reranker': {}} + + if options.label: + reranker_metadata_dict['reranker']['label'] = options.label + + if options.supports: + reranker_metadata_dict['reranker']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True) + + reranker_metadata_dict['reranker']['customOptions'] = options.config_schema if options.config_schema else None + + return ActionMetadata( + kind=ActionKind.RERANKER, + name=name, + input_json_schema=to_json_schema(RerankerRequest), + output_json_schema=to_json_schema(RerankerResponse), + metadata=reranker_metadata_dict, + ) + + +def create_reranker_ref( + name: str, + config: dict[str, Any] | None = None, + version: str | None = None, + info: RerankerInfo | None = None, +) -> RerankerRef: + """Creates a RerankerRef instance. + + Args: + name: The name of the reranker. + config: Optional configuration for the reranker. + version: Optional version string. + info: Optional RerankerInfo with capability information. + + Returns: + A RerankerRef instance. + """ + return RerankerRef(name=name, config=config, version=version, info=info) + + +def define_reranker( + registry: Registry, + name: str, + fn: RerankerFn, + options: RerankerOptions | None = None, +) -> Action: + """Defines and registers a reranker action. + + Creates a reranker action from the provided function and registers it + in the given registry. + + Args: + registry: The registry to register the reranker in. + name: The name of the reranker. + fn: The reranker function that implements the reranking logic. + options: Optional configuration options for the reranker. + + Returns: + The registered Action instance. + + Example: + >>> async def my_reranker(query, documents, options): + ... # Score and sort documents + ... scored = [(doc, score_doc(query, doc)) for doc in documents] + ... scored.sort(key=lambda x: x[1], reverse=True) + ... 
return RerankerResponse( + ... documents=[ + ... RankedDocumentData(content=doc.content, metadata=RankedDocumentMetadata(score=score)) + ... for doc, score in scored + ... ] + ... ) + >>> define_reranker(registry, 'my-reranker', my_reranker) + """ + metadata = reranker_action_metadata(name, options) + + async def wrapper( + request: RerankerRequest, + _ctx: Any, + ) -> RerankerResponse: + query_doc = Document.from_document_data(request.query) + documents = [Document.from_document_data(d) for d in request.documents] + return await fn(query_doc, documents, request.options) + + return registry.register_action( + kind=ActionKind.RERANKER, + name=name, + fn=wrapper, + metadata=metadata.metadata, + span_metadata=metadata.metadata, + ) + + +# Type for reranker argument (can be action, reference, or string name) +RerankerArgument = Union[Action, RerankerRef, str] + + +class RerankerParams(BaseModel): + """Parameters for the rerank function. + + Attributes: + reranker: The reranker to use (action, reference, or name string). + query: The query to rank documents against. + documents: The list of documents to rerank. + options: Optional configuration options for this rerank call. + """ + + model_config = ConfigDict(extra='forbid', populate_by_name=True, arbitrary_types_allowed=True) + + reranker: RerankerArgument + query: str | DocumentData + documents: list[DocumentData] + options: Any | None = None + + +async def rerank( + registry: Registry, + params: RerankerParams | dict[str, Any], +) -> list[RankedDocument]: + """Reranks documents based on the provided query using a reranker. + + This function takes a query and a list of documents, and returns the + documents reordered by relevance to the query as determined by the + specified reranker. + + Args: + registry: The registry to look up the reranker in. + params: Parameters for the rerank operation + including the reranker, + query, documents, and optional configuration. 
+ + Returns: + A list of RankedDocument objects sorted by relevance. + + Raises: + ValueError: If the reranker cannot be resolved. + + Example: + >>> ranked_docs = await rerank( + ... registry, + ... { + ... 'reranker': 'my-reranker', + ... 'query': 'What is machine learning?', + ... 'documents': [doc1, doc2, doc3], + ... }, + ... ) + >>> for doc in ranked_docs: + ... print(f'Score: {doc.score}, Text: {doc.text()}') + """ + # Convert dict to RerankerParams if needed + if isinstance(params, dict): + params = RerankerParams(**params) + + # Resolve the reranker action + reranker_action: Action | None = None + + if isinstance(params.reranker, str): + reranker_action = registry.lookup_action(ActionKind.RERANKER, params.reranker) + elif isinstance(params.reranker, RerankerRef): + reranker_action = registry.lookup_action(ActionKind.RERANKER, params.reranker.name) + elif isinstance(params.reranker, Action): + reranker_action = params.reranker + + if reranker_action is None: + raise ValueError(f'Unable to resolve reranker: {params.reranker}') + + # Convert query to DocumentData if it's a string + query_data: DocumentData + if isinstance(params.query, str): + query_data = Document.from_text(params.query) + else: + query_data = params.query + + # Build the request + request = RerankerRequest( + query=query_data, + documents=params.documents, + options=params.options, + ) + + # Call the reranker + action_response = await reranker_action.arun(request) + response: RerankerResponse = action_response.response + + # Convert response to RankedDocument list + return [RankedDocument.from_ranked_document_data(doc) for doc in response.documents] diff --git a/py/packages/genkit/src/genkit/blocks/resource.py b/py/packages/genkit/src/genkit/blocks/resource.py new file mode 100644 index 0000000000..8e5e398dd9 --- /dev/null +++ b/py/packages/genkit/src/genkit/blocks/resource.py @@ -0,0 +1,398 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Resource module for defining and managing resources. +Resources in Genkit represent addressable content or data processing units containing +unstructured data (Post, PDF, etc.) that can be retrieved or generated. They are +identified by URIs (e.g. `file://`, `http://`, `gs://`) and can be static (fixed URI) +or dynamic (using URI templates). +This module provides tools to define resource actions that can resolve these URIs +and return content (`ResourceOutput`) containing `Part`s. +""" + +import inspect +import re +from collections.abc import Awaitable, Callable +from typing import Any, Protocol, TypedDict + +from pydantic import BaseModel + +from genkit.aio import ensure_async +from genkit.core.action import Action, ActionRunContext +from genkit.core.action.types import ActionKind +from genkit.core.registry import Registry +from genkit.core.typing import Metadata, Part + + +class ResourceOptions(TypedDict, total=False): + """Options for defining a resource. + + Attributes: + name: Resource name. If not specified, uri or template will be used as name. + uri: The URI of the resource. Can contain template variables for simple matches, + but `template` is preferred for pattern matching. + template: The URI template (ex. `my://resource/{id}`). See RFC6570 for specification. + Used for matching variable resources. + description: A description of the resource, used for documentation and discovery. + metadata: Arbitrary metadata to attach to the resource action. 
+ """ + + name: str + uri: str + template: str + description: str + metadata: dict[str, Any] + + +class ResourceInput(BaseModel): + """Input structure for a resource request. + + Attributes: + uri: The full URI being requested/resolved. + """ + + uri: str + + +class ResourceOutput(BaseModel): + """Output structure from a resource resolution. + + Attributes: + content: A list of `Part` objects representing the resource content. + """ + + content: list[Part] + + +class ResourceFn(Protocol): + """A function that returns parts for a given resource. + The function receives the resolved input (including the URI) and context, + and should return a `ResourceOutput` containing the content parts. + """ + + def __call__(self, input: ResourceInput, ctx: ActionRunContext) -> Awaitable[ResourceOutput]: ... + + +ResourceArgument = Action | str + + +async def resolve_resources(registry: Registry, resources: list[ResourceArgument] | None = None) -> list[Action]: + """Resolves a list of resource names or actions into a list of Action objects. + + Args: + registry: The registry to lookup resources in. + resources: A list of resource references, which can be either direct `Action` + objects or strings (names/URIs). + + Returns: + A list of resolved `Action` objects. + + Raises: + ValueError: If a resource reference is invalid or cannot be found. + """ + if not resources: + return [] + + resolved_actions = [] + for ref in resources: + if isinstance(ref, str): + resolved_actions.append(await lookup_resource_by_name(registry, ref)) + elif isinstance(ref, Action): + resolved_actions.append(ref) + else: + raise ValueError('Resources must be strings or actions') + return resolved_actions + + +async def lookup_resource_by_name(registry: Registry, name: str) -> Action: + """Looks up a resource action by name in the registry. + Tries to resolve the name directly, or with common prefixes like `/resource/` + or `/dynamic-action-provider/`. + + Args: + registry: The registry to search. 
+ name: The name or URI of the resource to lookup. + + Returns: + The found `Action`. + + Raises: + ValueError: If the resource cannot be found. + """ + resource = ( + registry.lookup_action(ActionKind.RESOURCE, name) + or registry.lookup_action(ActionKind.RESOURCE, f'/resource/{name}') + or registry.lookup_action(ActionKind.RESOURCE, f'/dynamic-action-provider/{name}') + ) + if not resource: + raise ValueError(f'Resource {name} not found') + return resource + + +def define_resource(registry: Registry, opts: ResourceOptions, fn: ResourceFn) -> Action: + """Defines a resource and registers it with the given registry. + This creates a resource action that can handle requests for a specific URI + or URI template. + + Args: + registry: The registry to register the resource with. + opts: Options defining the resource (name, uri, template, etc.). + fn: The function that implements resource content retrieval. + + Returns: + The registered `Action` for the resource. + """ + action = dynamic_resource(opts, fn) + + action.matches = create_matcher(opts.get('uri'), opts.get('template')) + + # Mark as not dynamic since it's being registered + action.metadata['dynamic'] = False + + registry.register_action_from_instance(action) + + return action + + +def resource(opts: ResourceOptions, fn: ResourceFn) -> Action: + """Defines a dynamic resource action without immediate registration. + This is an alias for `dynamic_resource`. Useful for defining resources that + might be registered later or used as standalone actions. + + Args: + opts: Options defining the resource. + fn: The resource implementation function. + + Returns: + The created `Action`. + """ + return dynamic_resource(opts, fn) + + +def dynamic_resource(opts: ResourceOptions, fn: ResourceFn) -> Action: + """Defines a dynamic resource action. + Creates an `Action` of kind `RESOURCE` that wraps the provided function. + The wrapper handles: + 1. Input validation and matching against the URI/Template. + 2. 
Execution of the resource function. + 3. Post-processing of output to attach metadata (like parent resource info). + + Args: + opts: Options including `uri` or `template` for matching. + fn: The function performing the resource retrieval. + + Returns: + An `Action` configured as a resource. + + Raises: + ValueError: If neither `uri` nor `template` is provided in options. + """ + uri = opts.get('uri') or opts.get('template') + if not uri: + raise ValueError('must specify either uri or template options') + + matcher = create_matcher(opts.get('uri'), opts.get('template')) + + async def wrapped_fn(input_data: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + if isinstance(input_data, dict): + input_data = ResourceInput(**input_data) + + try: + template_match = matcher(input_data) + if not template_match: + raise ValueError(f'input {input_data} did not match template {uri}') + + sig = inspect.signature(fn) + afn = ensure_async(fn) + n_params = len(sig.parameters) + + if n_params == 0: + parts = await afn() + elif n_params == 1: + parts = await afn(input_data) + else: + parts = await afn(input_data, ctx) + + # Post-processing parts to add metadata + content_list = parts.content if hasattr(parts, 'content') else parts.get('content', []) + + for p in content_list: + if isinstance(p, Part): + p = p.root + + if hasattr(p, 'metadata'): + if p.metadata is None or isinstance(p.metadata, dict): + p.metadata = Metadata(root=p.metadata or {}) + + if isinstance(p.metadata, Metadata): + p_metadata = p.metadata.root + else: + p_metadata = p.metadata + + if 'resource' in p_metadata: + if 'parent' not in p_metadata['resource']: + p_metadata['resource']['parent'] = {'uri': input_data.uri} + if opts.get('template'): + p_metadata['resource']['parent']['template'] = opts.get('template') + else: + p_metadata['resource'] = {'uri': input_data.uri} + if opts.get('template'): + p_metadata['resource']['template'] = opts.get('template') + elif isinstance(p, dict): + if 'metadata' not in 
p or p['metadata'] is None: + p['metadata'] = {} + p_metadata = p['metadata'] + else: + continue + # Ensure we return a serializable dict (handling Pydantic models in list) + if isinstance(parts, BaseModel): + return parts.model_dump() + elif isinstance(parts, dict): + # Verify content items are dicts, if not dump them + if 'content' in parts: + parts['content'] = [p.model_dump() if isinstance(p, BaseModel) else p for p in parts['content']] + return parts + return parts + except Exception: + raise + + name = opts.get('name') or uri + + act = Action( + name=name, + kind=ActionKind.RESOURCE, + fn=wrapped_fn, + metadata={ + 'resource': { + 'uri': opts.get('uri'), + 'template': opts.get('template'), + }, + 'dynamic': True, + }, + description=opts.get('description'), + span_metadata={'genkit:metadata:resource:uri': uri}, + ) + act.matches = matcher + return act + + +def create_matcher(uri: str | None, template: str | None) -> Callable[[ResourceInput], bool]: + """Creates a matching function for resource validation. + + Args: + uri: Optional fixed URI string. + template: Optional URI template string. + + Returns: + A callable that takes ResourceInput and returns True if it matches. + """ + + def matcher(input_data: ResourceInput) -> bool: + if uri: + return input_data.uri == uri + if template: + return matches_uri_template(template, input_data.uri) is not None + return False + + return matcher + + +def is_dynamic_resource_action(action: Action) -> bool: + """Checks if an action is a dynamic resource (not registered). + + Args: + action: The action to check. + + Returns: + True if the action is a dynamic resource, False otherwise. + """ + return action.kind == ActionKind.RESOURCE and action.metadata.get('dynamic', True) + + +def matches_uri_template(template: str, uri: str) -> dict[str, str] | None: + """Check if a URI matches a template and extract parameters. + + Args: + template: URI template with {param} placeholders (e.g., "file://{path}"). 
+ uri: The URI to match against the template. + + Returns: + Dictionary of extracted parameters if match, None otherwise. + + Examples: + >>> matches_uri_template('file://{path}', 'file:///home/user/doc.txt') + {'path': '/home/user/doc.txt'} + >>> matches_uri_template('user://{id}/profile', 'user://123/profile') + {'id': '123'} + """ + # Split template into parts: text and {param} placeholders + parts = re.split(r'(\{[\w\+]+\})', template) + pattern_parts = [] + for part in parts: + if part.startswith('{') and part.endswith('}'): + param_name = part[1:-1] + if param_name.startswith('+'): + # Reserved expansion: {+var} matches reserved chars like / + param_name = param_name[1:] + pattern_parts.append(f'(?P<{param_name}>.+)') + else: + # Basic expansion: {var} does not match / + pattern_parts.append(f'(?P<{param_name}>[^/]+)') + else: + pattern_parts.append(re.escape(part)) + + pattern = f'^{"".join(pattern_parts)}$' + + match = re.search(pattern, uri) + if match: + return match.groupdict() + return None + + +async def find_matching_resource( + registry: Registry, dynamic_resources: list[Action] | None, input_data: ResourceInput +) -> Action | None: + """Finds a matching resource action. + Checks dynamic resources first, then the registry. + + Args: + registry: The registry to search. + dynamic_resources: Optional list of dynamic resource actions to check first. + input_data: The resource input containing the URI matched against. + + Returns: + The matching Action or None. + """ + if dynamic_resources: + for action in dynamic_resources: + if hasattr(action, 'matches') and action.matches(input_data): + return action + + # Try exact match in registry + resource = registry.lookup_action(ActionKind.RESOURCE, input_data.uri) + if resource: + return resource + + # Iterate all resources to check for matches (e.g. 
templates) + # This is less efficient but necessary for template matching if not optimized + resources = registry.get_actions_by_kind(ActionKind.RESOURCE) if hasattr(registry, 'get_actions_by_kind') else {} + if not resources and hasattr(registry, '_entries'): + # Fallback for compatibility if registry instance is old (unlikely in this context) + resources = registry._entries.get(ActionKind.RESOURCE, {}) + + for action in resources.values(): + if hasattr(action, 'matches') and action.matches(input_data): + return action + + return None diff --git a/py/packages/genkit/src/genkit/blocks/retriever.py b/py/packages/genkit/src/genkit/blocks/retriever.py index 6564bc92ff..99548b641d 100644 --- a/py/packages/genkit/src/genkit/blocks/retriever.py +++ b/py/packages/genkit/src/genkit/blocks/retriever.py @@ -28,9 +28,9 @@ from pydantic import BaseModel, ConfigDict, Field from genkit.blocks.document import Document -from genkit.core.action import ActionMetadata +from genkit.core.action import Action, ActionMetadata from genkit.core.action.types import ActionKind -from genkit.core.schema import to_json_schema +from genkit.core.schema import get_func_description, to_json_schema from genkit.core.typing import DocumentData, RetrieverResponse T = TypeVar('T') @@ -38,6 +38,45 @@ RetrieverFn = Callable[[Document, T], RetrieverResponse] +def retriever( + name: str, + fn: RetrieverFn, + config_schema: type[BaseModel] | None = None, + metadata: dict[str, Any] | None = None, + description: str | None = None, +) -> 'Action': + """Create a retriever action WITHOUT registering it. + + V2 API for creating retrievers. Returns an Action instance that can be + used standalone or registered by the framework. + + Args: + name: Retriever name (without plugin prefix). + fn: Function implementing retriever behavior. + config_schema: Optional schema for retriever configuration. + metadata: Optional metadata dictionary. + description: Optional description. 
+ + Returns: + Action instance (not registered). + """ + retriever_meta = metadata if metadata else {} + if 'retriever' not in retriever_meta: + retriever_meta['retriever'] = {} + if 'label' not in retriever_meta['retriever']: + retriever_meta['retriever']['label'] = name + if config_schema: + retriever_meta['retriever']['customOptions'] = to_json_schema(config_schema) + + return Action( + name=name, + kind=ActionKind.RETRIEVER, + fn=fn, + metadata=retriever_meta, + description=get_func_description(fn, description), + ) + + class Retriever(Generic[T]): def __init__( self, diff --git a/py/packages/genkit/src/genkit/core/action/_util.py b/py/packages/genkit/src/genkit/core/action/_util.py index 3efe253e56..040baed089 100644 --- a/py/packages/genkit/src/genkit/core/action/_util.py +++ b/py/packages/genkit/src/genkit/core/action/_util.py @@ -17,7 +17,6 @@ """Action utility module for defining and managing action utilities.""" import inspect -import typing from typing import Any diff --git a/py/packages/genkit/src/genkit/core/action/types.py b/py/packages/genkit/src/genkit/core/action/types.py index 9609285499..670ec5eee3 100644 --- a/py/packages/genkit/src/genkit/core/action/types.py +++ b/py/packages/genkit/src/genkit/core/action/types.py @@ -20,15 +20,15 @@ import sys from collections.abc import Callable -from typing import Any, Awaitable, Dict, List, Literal, Protocol, Union - -from pydantic import BaseModel, ConfigDict, Field +from typing import Any if sys.version_info < (3, 11): from strenum import StrEnum else: from enum import StrEnum +from pydantic import BaseModel, ConfigDict, Field + # Type alias for action name. 
# type ActionName = str ActionName = str @@ -57,6 +57,7 @@ class ActionKind(StrEnum): MODEL = 'model' PROMPT = 'prompt' RERANKER = 'reranker' + RESOURCE = 'resource' RETRIEVER = 'retriever' TOOL = 'tool' UTIL = 'util' diff --git a/py/packages/genkit/src/genkit/core/environment.py b/py/packages/genkit/src/genkit/core/environment.py index 562360fbd1..be69f801df 100644 --- a/py/packages/genkit/src/genkit/core/environment.py +++ b/py/packages/genkit/src/genkit/core/environment.py @@ -16,7 +16,6 @@ """Convenience functionality to determine the running environment.""" -import enum import os import sys diff --git a/py/packages/genkit/src/genkit/core/flows.py b/py/packages/genkit/src/genkit/core/flows.py index bccb983b84..562f031e85 100644 --- a/py/packages/genkit/src/genkit/core/flows.py +++ b/py/packages/genkit/src/genkit/core/flows.py @@ -156,7 +156,7 @@ async def handle_run_flows( try: # Look up the flow action. - action = registry.lookup_action_by_key(flow_name) + action = await registry.resolve_action_by_key(flow_name) if action is None: await logger.aerror( 'Flow not found', diff --git a/py/packages/genkit/src/genkit/core/reflection.py b/py/packages/genkit/src/genkit/core/reflection.py index 1797a6ba82..96ec886810 100644 --- a/py/packages/genkit/src/genkit/core/reflection.py +++ b/py/packages/genkit/src/genkit/core/reflection.py @@ -44,6 +44,8 @@ import json import urllib.parse from collections.abc import AsyncGenerator +from dataclasses import dataclass, field +from datetime import datetime from http.server import BaseHTTPRequestHandler from typing import Any @@ -73,10 +75,22 @@ logger = structlog.get_logger(__name__) +@dataclass +class ActiveAction: + """Represents an in-flight action that can be cancelled.""" + + task: asyncio.Task | None + trace_id: str + start_time: datetime = field(default_factory=datetime.now) + + +# Global dict to track active actions by trace ID +_active_actions: dict[str, ActiveAction] = {} + + def make_reflection_server( registry: 
Registry, loop: asyncio.AbstractEventLoop, - id: str, encoding='utf-8', quiet=True, ): @@ -114,24 +128,16 @@ def do_GET(self) -> None: # noqa: N802 For the /api/actions endpoint, returns a JSON object mapping action keys to their metadata, including input/output schemas. """ - parsed_url = urllib.parse.urlparse(self.path) - if parsed_url.path == '/api/__health': - query_params = urllib.parse.parse_qs(parsed_url.query) - expected_id = query_params.get('id', [None])[0] - if expected_id is not None and expected_id != id: - self.send_response(500) - self.end_headers() - return - + if self.path == '/api/__health': self.send_response(200, 'OK') self.end_headers() - elif parsed_url.path == '/api/actions': + elif self.path == '/api/actions': self.send_response(200) self.send_header('content-type', 'application/json') self.end_headers() actions = registry.list_serializable_actions() - actions = registry.list_actions(actions) + actions = registry.list_actions_sync(actions) self.wfile.write(bytes(json.dumps(actions), encoding)) else: self.send_response(404) @@ -158,7 +164,6 @@ def do_POST(self) -> None: # noqa: N802 post_body = self.rfile.read(content_len) payload = json.loads(post_body.decode(encoding=encoding)) action = registry.lookup_action_by_key(payload['key']) - action_input = payload.get('input') context = payload['context'] if 'context' in payload else {} query = urllib.parse.urlparse(self.path).query @@ -186,7 +191,7 @@ def send_chunk(chunk): async def run_fn(): return await action.arun_raw( - raw_input=payload.get('input'), + raw_input=payload['input'], on_chunk=send_chunk, context=context, ) @@ -217,7 +222,7 @@ async def run_fn(): try: async def run_fn(): - return await action.arun_raw(raw_input=payload.get('input'), context=context) + return await action.arun_raw(raw_input=payload['input'], context=context) output = run_async(loop, run_fn) @@ -327,8 +332,10 @@ async def handle_list_actions(request: Request) -> JSONResponse: Returns: A JSON response containing all 
serializable actions. """ + actions = registry.list_serializable_actions() + actions = await registry.list_actions(actions) return JSONResponse( - content=registry.list_serializable_actions(), + content=actions, status_code=200, headers={'x-genkit-version': version}, ) @@ -348,6 +355,41 @@ async def handle_notify(request: Request) -> JSONResponse: headers={'x-genkit-version': version}, ) + async def handle_cancel_action(request: Request) -> JSONResponse: + """Handle the cancelAction endpoint for cancelling running actions. + + Args: + request: The Starlette request object containing traceId. + + Returns: + 200 with success message if action was cancelled. + 400 if traceId is missing. + 404 if action not found or already completed. + """ + payload = await request.json() + trace_id = payload.get('traceId') + + if not trace_id or not isinstance(trace_id, str): + return JSONResponse( + content={'error': 'traceId is required'}, + status_code=400, + ) + + active = _active_actions.get(trace_id) + if active: + if active.task and not active.task.done(): + active.task.cancel() + del _active_actions[trace_id] + return JSONResponse( + content={'message': 'Action cancelled'}, + status_code=200, + ) + else: + return JSONResponse( + content={'message': 'Action not found or already completed'}, + status_code=404, + ) + async def handle_run_action( request: Request, ) -> JSONResponse | StreamingResponse: @@ -368,7 +410,7 @@ async def handle_run_action( """ # Get the action. payload = await request.json() - action = registry.lookup_action_by_key(payload['key']) + action = await registry.resolve_action_by_key(payload['key']) if action is None: return JSONResponse( content={'error': f'Action not found: {payload["key"]}'}, @@ -377,15 +419,13 @@ async def handle_run_action( # Run the action. 
context = payload.get('context', {}) - action_input = payload.get('input') stream = is_streaming_requested(request) handler = run_streaming_action if stream else run_standard_action - return await handler(action, payload, action_input, context, version) + return await handler(action, payload, context, version) async def run_streaming_action( action: Action, payload: dict[str, Any], - action_input: Any, context: dict[str, Any], version: str, ) -> StreamingResponse | JSONResponse: @@ -416,7 +456,7 @@ async def send_chunk(chunk): yield f'{out}\n' output = await action.arun_raw( - raw_input=payload.get('input'), + raw_input=payload['input'], on_chunk=send_chunk, context=context, ) @@ -450,7 +490,6 @@ async def send_chunk(chunk): async def run_standard_action( action: Action, payload: dict[str, Any], - action_input: Any, context: dict[str, Any], version: str, ) -> JSONResponse: @@ -466,7 +505,7 @@ async def run_standard_action( A JSONResponse with the action result or error. """ try: - output = await action.arun_raw(raw_input=payload.get('input'), context=context) + output = await action.arun_raw(raw_input=payload['input'], context=context) response = { 'result': dump_dict(output.response), 'telemetry': {'traceId': output.trace_id}, @@ -491,6 +530,7 @@ async def run_standard_action( Route('/api/actions', handle_list_actions, methods=['GET']), Route('/api/notify', handle_notify, methods=['POST']), Route('/api/runAction', handle_run_action, methods=['POST']), + Route('/api/cancelAction', handle_cancel_action, methods=['POST']), ], middleware=[ Middleware( diff --git a/py/packages/genkit/src/genkit/core/registry.py b/py/packages/genkit/src/genkit/core/registry.py index 316c8c0ba2..b5b57719bf 100644 --- a/py/packages/genkit/src/genkit/core/registry.py +++ b/py/packages/genkit/src/genkit/core/registry.py @@ -27,6 +27,7 @@ >>> action = registry.lookup_action('', 'my_action') """ +import inspect import threading from collections.abc import Callable from typing import Any @@ 
-36,7 +37,6 @@ from genkit.core.action import ( Action, - ActionMetadata, create_action_key, parse_action_key, parse_plugin_name_from_action_name, @@ -80,8 +80,10 @@ class Registry: def __init__(self): """Initialize an empty Registry instance.""" - self._action_resolvers: dict[str, ActionResolver] = {} - self._list_actions_resolvers: dict[str, Callable] = {} + # Multiple plugins can contribute actions under the same plugin namespace. + # Example: `vertexai/*` for both model + vector search capabilities. + self._action_resolvers: dict[str, list[ActionResolver]] = {} + self._list_actions_resolvers: dict[str, list[Callable]] = {} self._entries: ActionStore = {} self._value_by_kind_and_name: dict[str, dict[str, Any]] = {} self._lock = threading.RLock() @@ -96,13 +98,9 @@ def register_action_resolver(self, plugin_name: str, resolver: ActionResolver) - plugin_name: The name of the plugin. resolver: The ActionResolver instance to register. - Raises: - ValueError: If a resolver is already registered for the plugin. """ with self._lock: - if plugin_name in self._action_resolvers: - raise ValueError(f'Plugin {plugin_name} already registered') - self._action_resolvers[plugin_name] = resolver + self._action_resolvers.setdefault(plugin_name, []).append(resolver) def register_list_actions_resolver(self, plugin_name: str, resolver: Callable) -> None: """Registers an Callable function to list available actions or models. @@ -111,13 +109,9 @@ def register_list_actions_resolver(self, plugin_name: str, resolver: Callable) - plugin_name: The name of the plugin. resolver: The Callable function to list models. - Raises: - ValueError: If a resolver is already registered for the plugin. 
""" with self._lock: - if plugin_name in self._list_actions_resolvers: - raise ValueError(f'Plugin {plugin_name} already registered') - self._list_actions_resolvers[plugin_name] = resolver + self._list_actions_resolvers.setdefault(plugin_name, []).append(resolver) def register_action( self, @@ -162,9 +156,45 @@ def register_action( self._entries[kind][name] = action return action + def register_action_instance(self, action: Action, *, namespace: str | None = None) -> None: + """Registers a pre-constructed Action instance. + + Note: If a namespace is provided and the action name is not already + prefixed, this method updates the action's name in-place. + + Args: + action: The Action instance to register. + namespace: Optional namespace prefix (e.g. plugin name). + """ + name = action.name + if namespace and not name.startswith(f'{namespace}/'): + name = f'{namespace}/{name}' + action._name = name + + with self._lock: + if action.kind not in self._entries: + self._entries[action.kind] = {} + self._entries[action.kind][name] = action + + def register_action_from_instance(self, action: Action) -> None: + """Register an existing Action instance. + Allows registering a pre-configured Action object, such as one created via + `dynamic_resource` or other factory methods. + Args: + action: The action instance to register. + """ + with self._lock: + if action.kind not in self._entries: + self._entries[action.kind] = {} + self._entries[action.kind][action.name] = action + def lookup_action(self, kind: ActionKind, name: str) -> Action | None: """Look up an action by its kind and name. + .. deprecated:: + Use `await registry.resolve_action(kind, name)` instead. + This sync method cannot properly handle async PluginV2 plugins. + Args: kind: The type of action to look up. name: The name of the action to look up. @@ -172,6 +202,13 @@ def lookup_action(self, kind: ActionKind, name: str) -> Action | None: Returns: The Action instance if found, None otherwise. 
""" + import warnings + + warnings.warn( + 'registry.lookup_action() is deprecated. Use `await registry.resolve_action(kind, name)` instead.', + DeprecationWarning, + stacklevel=2, + ) with self._lock: # If the entry does not exist, we fist try to call the action # resolver for the plugin to give it a chance to dynamically add the @@ -179,16 +216,133 @@ def lookup_action(self, kind: ActionKind, name: str) -> Action | None: if kind not in self._entries or name not in self._entries[kind]: plugin_name = parse_plugin_name_from_action_name(name) if plugin_name and plugin_name in self._action_resolvers: - self._action_resolvers[plugin_name](kind, name) + # Pass the full namespaced action name to the plugin resolver. + # (Many v1 plugins/tests expect to receive the full name and will + # register actions using it; v2 resolvers can strip the prefix + # internally.) + for resolver in self._action_resolvers[plugin_name]: + result = resolver(kind, name) + if inspect.isawaitable(result): + raise TypeError( + f'Action resolver for plugin "{plugin_name}" returned an awaitable while resolving "{name}". ' + 'Use async resolution (e.g. `await registry.resolve_action(...)`) instead of sync `lookup_action(...)`.' + ) + if kind in self._entries and name in self._entries[kind]: + break if kind in self._entries and name in self._entries[kind]: return self._entries[kind][name] return None + async def resolve_action(self, kind: ActionKind, name: str) -> Action | None: + """Resolve an action by kind and name (async). + + Resolves an action name like "openai/gpt-4" (namespaced form). + Registry hit: if name is already in entries[kind], return it. + If miss: parse plugin_name from name (first segment before /). If there's no + plugin prefix, it can't route → returns None. + If plugin prefix exists: look up the list of resolver functions registered for that + plugin and await them. + Check _entries again; if present, return it. + + Args: + kind: The type of action to look up. 
+ name: The namespaced action name (e.g., "openai/gpt-4"). + + Returns: + The Action instance if found, None otherwise. + + Example: + >>> action = await registry.resolve_action(ActionKind.MODEL, 'openai/gpt-4') + """ + # Fast path: already registered (do not trigger resolvers). + with self._lock: + if kind in self._entries and name in self._entries[kind]: + return self._entries[kind][name] + + plugin_name = parse_plugin_name_from_action_name(name) + if not plugin_name: + return None + + resolvers = self._action_resolvers.get(plugin_name) + if not resolvers: + return None + + # Important: pass the full namespaced action name to the plugin resolver. + # V2 resolvers can strip the prefix internally if needed. + for resolver in resolvers: + result = resolver(kind, name) + if inspect.isawaitable(result): + await result + with self._lock: + if kind in self._entries and name in self._entries[kind]: + return self._entries[kind][name] + + with self._lock: + if kind in self._entries and name in self._entries[kind]: + return self._entries[kind][name] + return None + + # Backwards compatibility alias + async def aresolve_action(self, kind: ActionKind, name: str) -> Action | None: + """Deprecated: use resolve_action() instead.""" + import warnings + + warnings.warn( + 'aresolve_action() is deprecated. Use resolve_action() instead.', DeprecationWarning, stacklevel=2 + ) + return await self.resolve_action(kind, name) + + async def resolve_action_by_key(self, key: str) -> Action | None: + """Resolve an action by registry key (async). + + Resolves by registry key (e.g., "/model/openai/gpt-4"). + + Args: + key: The action key in the format "/kind/name". + + Returns: + The Action instance if found, None otherwise. 
+ + Example: + >>> action = await registry.resolve_action_by_key('/model/openai/gpt-4') + """ + kind, name = parse_action_key(key) + return await self.resolve_action(kind, name) + + # Backwards compatibility alias + async def aresolve_action_by_key(self, key: str) -> Action | None: + """Deprecated: use resolve_action_by_key() instead.""" + import warnings + + warnings.warn( + 'aresolve_action_by_key() is deprecated. Use resolve_action_by_key() instead.', + DeprecationWarning, + stacklevel=2, + ) + return await self.resolve_action_by_key(key) + + def get_actions_by_kind(self, kind: ActionKind) -> dict[str, Action]: + """Returns a dictionary of all registered actions for a specific kind. + + Args: + kind: The type of actions to retrieve (e.g., TOOL, MODEL, RESOURCE). + + Returns: + A dictionary mapping action names to Action instances. + Returns an empty dictionary if no actions of that kind are registered. + """ + with self._lock: + return self._entries.get(kind, {}).copy() + def lookup_action_by_key(self, key: str) -> Action | None: """Look up an action using its combined key string. + .. deprecated:: + Use `await registry.resolve_action_by_key(key)` instead. + This sync method cannot properly handle async PluginV2 plugins. + The key format is `/`, where kind must be a valid `ActionKind` and name must be a registered action name within that kind. @@ -202,8 +356,18 @@ def lookup_action_by_key(self, key: str) -> Action | None: ValueError: If the key format is invalid or the kind is not a valid `ActionKind`. """ + import warnings + + warnings.warn( + 'registry.lookup_action_by_key() is deprecated. 
Use `await registry.resolve_action_by_key(key)` instead.', + DeprecationWarning, + stacklevel=2, + ) kind, name = parse_action_key(key) - return self.lookup_action(kind, name) + # Suppress nested deprecation warning (internal delegation) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + return self.lookup_action(kind, name) def list_serializable_actions(self, allowed_kinds: set[ActionKind] | None = None) -> dict[str, Action] | None: """Enlist all the actions into a dictionary. @@ -221,7 +385,8 @@ def list_serializable_actions(self, allowed_kinds: set[ActionKind] | None = None if allowed_kinds is not None and kind not in allowed_kinds: continue for name in self._entries[kind]: - action = self.lookup_action(kind, name) + # Read directly from _entries (already registered actions) + action = self._entries[kind].get(name) if action is not None: key = create_action_key(kind, name) # TODO: Serialize the Action instance @@ -234,12 +399,16 @@ def list_serializable_actions(self, allowed_kinds: set[ActionKind] | None = None } return actions - def list_actions( + def list_actions_sync( self, actions: dict[str, Action] | None = None, allowed_kinds: set[ActionKind] | None = None, ) -> dict[str, Action] | None: - """Add actions or models. + """Add actions or models (sync version - deprecated). + + .. deprecated:: + Use `await registry.list_actions(...)` instead. + This sync method cannot properly handle async PluginV2 plugins. Args: actions: dictionary of serializable actions. @@ -249,28 +418,98 @@ def list_actions( Returns: A dictionary of serializable Actions updated. """ + import warnings + + warnings.warn( + 'list_actions_sync() is deprecated. 
Use `await registry.list_actions(...)` instead.', + DeprecationWarning, + stacklevel=2, + ) + if actions is None: actions = {} for plugin_name in self._list_actions_resolvers: - actions_list = self._list_actions_resolvers[plugin_name]() + for resolver in self._list_actions_resolvers[plugin_name]: + actions_list = resolver() + if inspect.isawaitable(actions_list): + raise TypeError( + f'list_actions resolver for plugin "{plugin_name}" returned an awaitable. ' + 'Use `await registry.list_actions(...)` to list actions asynchronously.' + ) + + for _action in actions_list: + kind = _action.kind + if allowed_kinds is not None and kind not in allowed_kinds: + continue + key = create_action_key(kind, _action.name) + + if key not in actions: + actions[key] = { + 'key': key, + 'name': _action.name, + 'inputSchema': _action.input_json_schema, + 'outputSchema': _action.output_json_schema, + 'metadata': _action.metadata, + } + return actions - for _action in actions_list: - kind = _action.kind - if allowed_kinds is not None and kind not in allowed_kinds: - continue - key = create_action_key(kind, _action.name) - - if key not in actions: - actions[key] = { - 'key': key, - 'name': _action.name, - 'inputSchema': _action.input_json_schema, - 'outputSchema': _action.output_json_schema, - 'metadata': _action.metadata, - } + async def list_actions( + self, + actions: dict[str, Action] | None = None, + allowed_kinds: set[ActionKind] | None = None, + ) -> dict[str, Action] | None: + """List all actions (async). + + Async listing that awaits plugin list() resolvers. If allowed_kinds not provided, + returns action metadata for all kinds. + + Args: + actions: Optional dictionary to append actions to. + allowed_kinds: Optional set of ActionKind to filter by. + + Returns: + Dictionary of serializable actions. 
+ + Example: + >>> actions = await registry.list_actions(allowed_kinds={ActionKind.MODEL}) + """ + if actions is None: + actions = {} + + for plugin_name in self._list_actions_resolvers: + for resolver in self._list_actions_resolvers[plugin_name]: + actions_list = resolver() + if inspect.isawaitable(actions_list): + actions_list = await actions_list + + for _action in actions_list: + kind = _action.kind + if allowed_kinds is not None and kind not in allowed_kinds: + continue + key = create_action_key(kind, _action.name) + if key not in actions: + actions[key] = { + 'key': key, + 'name': _action.name, + 'inputSchema': _action.input_json_schema, + 'outputSchema': _action.output_json_schema, + 'metadata': _action.metadata, + } return actions + # Backwards compatibility alias + async def alist_actions( + self, + actions: dict[str, Action] | None = None, + allowed_kinds: set[ActionKind] | None = None, + ) -> dict[str, Action] | None: + """Deprecated: use list_actions() instead.""" + import warnings + + warnings.warn('alist_actions() is deprecated. Use list_actions() instead.', DeprecationWarning, stacklevel=2) + return await self.list_actions(actions, allowed_kinds) + def register_value(self, kind: str, name: str, value: Any): """Registers a value with a given kind and name. diff --git a/py/packages/genkit/src/genkit/core/schema.py b/py/packages/genkit/src/genkit/core/schema.py index 3f5946a98f..f7b2cc0567 100644 --- a/py/packages/genkit/src/genkit/core/schema.py +++ b/py/packages/genkit/src/genkit/core/schema.py @@ -16,11 +16,30 @@ """Functions for working with schema.""" +from collections.abc import Callable from typing import Any from pydantic import TypeAdapter +def get_func_description(func: Callable, description: str | None = None) -> str: + """Get the description of a function. + + Args: + func: The function to get the description of. + description: The description to use if the function docstring is + empty. + + Returns: + The description of the function. 
+ """ + if description is not None: + return description + if func.__doc__ is not None: + return func.__doc__ + return '' + + def to_json_schema(schema: type | dict[str, Any]) -> dict[str, Any]: """Converts a Python type to a JSON schema. diff --git a/py/packages/genkit/src/genkit/core/trace/default_exporter.py b/py/packages/genkit/src/genkit/core/trace/default_exporter.py index fbc9fc4ff1..5570cbb8cd 100644 --- a/py/packages/genkit/src/genkit/core/trace/default_exporter.py +++ b/py/packages/genkit/src/genkit/core/trace/default_exporter.py @@ -25,11 +25,10 @@ - Utility functions for converting and formatting trace attributes """ -import asyncio import json import os import sys -from collections.abc import Awaitable, Sequence +from collections.abc import Sequence from typing import Any from urllib.parse import urljoin diff --git a/py/packages/genkit/src/genkit/core/typing.py b/py/packages/genkit/src/genkit/core/typing.py index dce5a4f7f1..68fd541101 100644 --- a/py/packages/genkit/src/genkit/core/typing.py +++ b/py/packages/genkit/src/genkit/core/typing.py @@ -913,7 +913,7 @@ class Message(BaseModel): """Model for message data.""" model_config = ConfigDict(extra='forbid', populate_by_name=True) - role: Role + role: Role | str content: list[Part] metadata: dict[str, Any] | None = None diff --git a/py/packages/genkit/src/genkit/types/__init__.py b/py/packages/genkit/src/genkit/types/__init__.py index fa2f6b24c6..4465eb3d74 100644 --- a/py/packages/genkit/src/genkit/types/__init__.py +++ b/py/packages/genkit/src/genkit/types/__init__.py @@ -48,6 +48,7 @@ ModelInfo, OutputConfig, Part, + ReasoningPart, RetrieverRequest, RetrieverResponse, Role, @@ -94,6 +95,7 @@ ModelInfo.__name__, OutputConfig.__name__, Part.__name__, + ReasoningPart.__name__, RetrieverRequest.__name__, RetrieverResponse.__name__, Role.__name__, diff --git a/py/packages/genkit/tests/genkit/ai/plugin_v2_test.py b/py/packages/genkit/tests/genkit/ai/plugin_v2_test.py new file mode 100644 index 
0000000000..e69de29bb2 diff --git a/py/packages/genkit/tests/genkit/ai/test_resource.py b/py/packages/genkit/tests/genkit/ai/test_resource.py new file mode 100644 index 0000000000..13e187c883 --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/test_resource.py @@ -0,0 +1,241 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the Genkit Resource API. +This module verifies the functionality of defining, registering, and resolving resources +in the Genkit framework. It covers static resources, template-based resources, +dynamic resource matching, and metadata handling. +""" + +import asyncio + +import pytest + +from genkit.blocks.resource import define_resource, resolve_resources, resource +from genkit.core.registry import Registry +from genkit.core.typing import Part, TextPart + + +def test_define_resource(): + """Verifies that a resource can be defined and registered correctly. + Checks: + - Resource name matches property. + - Resource is retrievable from the registry by name. 
+ """ + registry = Registry() + + async def my_resource_fn(input, ctx): + return {'content': [Part(TextPart(text=f'Content for {input.uri}'))]} + + act = define_resource(registry, {'uri': 'http://example.com/foo'}, my_resource_fn) + + assert act.name == 'http://example.com/foo' + assert act.metadata['resource']['uri'] == 'http://example.com/foo' + + # Verify lookup logic (mocking lookup_action effectively via direct access or helper) + # Registry lookup for resources usually prepends /resource/ etc. + # but define_resource registers it with name=uri + + looked_up = registry.lookup_action('resource', 'http://example.com/foo') + assert looked_up == act + + +@pytest.mark.asyncio +async def test_resolve_resources(): + """Verifies resolving resource references into Action objects. + Checks: + - Resolving by string name works. + - Resolving by Action object passes through. + """ + registry = Registry() + + async def my_resource_fn(input, ctx): + return {'content': [Part(TextPart(text=f'Content for {input.uri}'))]} + + act = define_resource(registry, {'name': 'my-resource', 'uri': 'http://example.com/foo'}, my_resource_fn) + + resolved = await resolve_resources(registry, ['my-resource']) + assert len(resolved) == 1 + assert resolved[0] == act + + resolved_obj = await resolve_resources(registry, [act]) + assert len(resolved_obj) == 1 + assert resolved_obj[0] == act + + +@pytest.mark.asyncio +async def test_find_matching_resource(): + """Verifies the logic for finding a matching resource given an input URI. + Checks: + - Exact match against registered static resources. + - Template match against registered template resources. + - Matching against a provided list of dynamic resource actions for override/adhoc usage. + - Returns None when no match is found. 
+ """ + registry = Registry() + + # Static resource + async def static_fn(input, ctx): + return {'content': []} + + static_res = define_resource(registry, {'uri': 'bar://baz', 'name': 'staticRes'}, static_fn) + + # Template resource + async def template_fn(input, ctx): + return {'content': []} + + template_res = define_resource(registry, {'template': 'foo://bar/{baz}', 'name': 'templateRes'}, template_fn) + + # Dynamic resource list + async def dynamic_fn(input, ctx): + return {'content': []} + + dynamic_res = resource({'uri': 'baz://qux'}, dynamic_fn) + + from genkit.blocks.resource import ResourceInput, find_matching_resource + + # Match static from registry + res = await find_matching_resource(registry, [], ResourceInput(uri='bar://baz')) + assert res == static_res + + # Match template from registry + res = await find_matching_resource(registry, [], ResourceInput(uri='foo://bar/something')) + assert res == template_res + + # Match dynamic from list + res = await find_matching_resource(registry, [dynamic_res], ResourceInput(uri='baz://qux')) + assert res == dynamic_res + + # No match + res = await find_matching_resource(registry, [], ResourceInput(uri='unknown://uri')) + assert res is None + + +def test_is_dynamic_resource_action(): + """Verifies identifying dynamic vs registered resource actions. + Checks: + - Unregistered resources created with `resource()` are dynamic. + - Registered resources created with `define_resource()` are not dynamic. 
+ """ + from genkit.blocks.resource import is_dynamic_resource_action + + async def fn(input, ctx): + return {'content': []} + + dynamic = resource({'uri': 'bar://baz'}, fn) + assert is_dynamic_resource_action(dynamic) + + # Registered action (define_resource sets dynamic=False) + async def static_fn(input, ctx): + return {'content': []} + + static = define_resource(Registry(), {'uri': 'foo://bar'}, static_fn) + assert not is_dynamic_resource_action(static) + + +@pytest.mark.asyncio +async def test_parent_metadata(): + """Verifies that parent metadata is correctly attached to output items. + When a resource is resolved via a template (e.g. `file://{id}`), the output parts + should contain metadata referencing the parent resource URI and template. + Checks: + - Parent URI and template presence in output part metadata. + """ + registry = Registry() + + async def fn(input, ctx): + return {'content': [Part(TextPart(text='sub1', metadata={'resource': {'uri': f'{input.uri}/sub1.txt'}}))]} + + res = define_resource(registry, {'template': 'file://{id}'}, fn) + + output = await res.arun({'uri': 'file://dir'}) + # output is ActionResponse + # content is in output.response['content'] because wrapped_fn ensures serialization + + part = output.response['content'][0] + # Check metadata + assert part['metadata']['resource']['parent']['uri'] == 'file://dir' + assert part['metadata']['resource']['parent']['template'] == 'file://{id}' + assert part['metadata']['resource']['uri'] == 'file://dir/sub1.txt' + + +def test_dynamic_resource_matching(): + """Verifies the matching logic for a simple static URI dynamic resource.""" + + async def my_resource_fn(input, ctx): + return {'content': [Part(TextPart(text='Match'))]} + + res = resource({'uri': 'http://example.com/foo'}, my_resource_fn) + + class MockInput: + uri = 'http://example.com/foo' + + assert res.matches(MockInput()) + + class MockInputBad: + uri = 'http://example.com/bar' + + assert not res.matches(MockInputBad()) + + +def 
test_template_matching(): + """Verifies URI template pattern matching. + Checks: + - Matches correct URI structure. + - Fails on paths extending beyond the template structure (strict matching). + """ + + async def my_resource_fn(input, ctx): + return {'content': []} + + res = resource({'template': 'http://example.com/items/{id}'}, my_resource_fn) + + class MockInput: + uri = 'http://example.com/items/123' + + assert res.matches(MockInput()) + + class MockInputBad: + uri = 'http://example.com/items/123/details' + + # Should not match because of strict end anchor or slash handling in our regex + assert not res.matches(MockInputBad()) + + +def test_reserved_expansion_matching(): + """Verifies RFC 6570 reserved expansion {+var} pattern matching. + Checks: + - Matches correct URI structure with slashes (reserved chars). + """ + + async def my_resource_fn(input, ctx): + return {'content': []} + + # Template with reserved expansion {+path} (matches slashes) + res = resource({'template': 'http://example.com/files/{+path}'}, my_resource_fn) + + class MockInput: + uri = 'http://example.com/files/foo/bar/baz.txt' + + assert res.matches(MockInput()) + + # Regular template {path} regex ([^/]+) should NOT match slashes + res_simple = resource({'template': 'http://example.com/items/{id}'}, my_resource_fn) + + class MockInputComplex: + uri = 'http://example.com/items/foo/bar' + + assert not res_simple.matches(MockInputComplex()) diff --git a/py/packages/genkit/tests/genkit/blocks/embedding_test.py b/py/packages/genkit/tests/genkit/blocks/embedding_test.py index 04cd4e7f0c..0f1738333f 100644 --- a/py/packages/genkit/tests/genkit/blocks/embedding_test.py +++ b/py/packages/genkit/tests/genkit/blocks/embedding_test.py @@ -25,7 +25,6 @@ from genkit.blocks.document import Document from genkit.blocks.embedding import ( EmbedderOptions, - EmbedderRef, EmbedderSupports, create_embedder_ref, embedder_action_metadata, @@ -156,6 +155,13 @@ async def mock_arun_side_effect(request, *args, 
**kwargs): def lookup_action(self, kind, name): return self.actions.get((kind, name)) + async def resolve_action(self, kind, name): + return self.lookup_action(kind, name) + + # Backwards compatibility + async def aresolve_action(self, kind, name): + return self.lookup_action(kind, name) + @pytest.fixture def mock_genkit_instance(): diff --git a/py/packages/genkit/tests/genkit/blocks/generate_test.py b/py/packages/genkit/tests/genkit/blocks/generate_test.py index 8d50357916..9feb72996e 100644 --- a/py/packages/genkit/tests/genkit/blocks/generate_test.py +++ b/py/packages/genkit/tests/genkit/blocks/generate_test.py @@ -43,7 +43,7 @@ def setup_test(): @ai.tool(name='testTool') def test_tool(): - """description""" + """Description.""" return 'tool called' return (ai, pm) @@ -333,7 +333,7 @@ async def test_generate_action_spec(spec) -> None: @ai.tool(name='testTool') def test_tool(): - """description""" + """Description.""" return 'tool called' if 'modelResponses' in spec: diff --git a/py/packages/genkit/tests/genkit/blocks/model_test.py b/py/packages/genkit/tests/genkit/blocks/model_test.py index eca64fee58..d0547b506c 100644 --- a/py/packages/genkit/tests/genkit/blocks/model_test.py +++ b/py/packages/genkit/tests/genkit/blocks/model_test.py @@ -58,6 +58,26 @@ def test_response_wrapper_text() -> None: assert wrapper.text == 'hello world' +def test_response_wrapper_uses_candidates_fallback() -> None: + wrapper = GenerateResponseWrapper( + response=GenerateResponse( + candidates=[ + Candidate( + index=0, + message=Message(role='model', content=[Part(text='hello')]), + finish_reason='stop', + ) + ] + ), + request=GenerateRequest( + messages=[], # doesn't matter for now + ), + ) + + assert wrapper.text == 'hello' + assert wrapper.finish_reason == 'stop' + + def test_response_wrapper_output() -> None: wrapper = GenerateResponseWrapper( response=GenerateResponse( diff --git a/py/packages/genkit/tests/genkit/blocks/prompt_test.py 
b/py/packages/genkit/tests/genkit/blocks/prompt_test.py index 112cbcbde4..8a9af200bf 100644 --- a/py/packages/genkit/tests/genkit/blocks/prompt_test.py +++ b/py/packages/genkit/tests/genkit/blocks/prompt_test.py @@ -28,14 +28,15 @@ from genkit.blocks.prompt import load_prompt_folder, lookup_prompt, prompt from genkit.core.action.types import ActionKind from genkit.core.typing import ( + DocumentData, GenerateActionOptions, GenerateRequest, + GenerateResponse, GenerationCommonConfig, Message, Role, TextPart, ToolChoice, - ToolDefinition, ) from genkit.testing import ( define_echo_model, @@ -147,6 +148,55 @@ def test_tool(input: ToolInput): assert (await response).text == want_txt +@pytest.mark.asyncio +async def test_prompt_with_resolvers() -> None: + """Test that the rendering works with resolvers.""" + ai, *_ = setup_test() + + async def system_resolver(input, context): + return f'system {input["name"]}' + + def prompt_resolver(input, context): + return f'prompt {input["name"]}' + + async def messages_resolver(input, context): + return [Message(role=Role.USER, content=[TextPart(text=f'msg {input["name"]}')])] + + my_prompt = ai.define_prompt( + system=system_resolver, + prompt=prompt_resolver, + messages=messages_resolver, + ) + + want_txt = '[ECHO] system: "system world" user: "msg world" user: "prompt world"' + + response = await my_prompt(input={'name': 'world'}) + + assert response.text == want_txt + + +@pytest.mark.asyncio +async def test_prompt_with_docs_resolver() -> None: + """Test that the rendering works with docs resolver.""" + ai, _, pm = setup_test() + + pm.responses = [GenerateResponse(message=Message(role=Role.MODEL, content=[TextPart(text='ok')]))] + + async def docs_resolver(input, context): + return [DocumentData(content=[TextPart(text=f'doc {input["name"]}')])] + + my_prompt = ai.define_prompt( + model='programmableModel', + prompt='hi', + docs=docs_resolver, + ) + + await my_prompt(input={'name': 'world'}) + + # Check that PM received the docs 
+ assert pm.last_request.docs[0].content[0].root.text == 'doc world' + + test_cases_parse_partial_json = [ ( 'renders system prompt', @@ -208,7 +258,6 @@ def test_tool(input: ToolInput): ] -@pytest.mark.skip(reason='issues when running on CI') @pytest.mark.asyncio @pytest.mark.parametrize( 'test_case, prompt, input, input_option, context, want_rendered', @@ -320,6 +369,25 @@ async def test_load_and_use_partial() -> None: assert 'Hello from partial' in response.text or 'space' in response.text +@pytest.mark.asyncio +async def test_define_partial_programmatically() -> None: + """Test defining partials programmatically using ai.define_partial().""" + ai, *_ = setup_test() + + # Define a partial programmatically + ai.define_partial('myGreeting', 'Greetings, {{name}}!') + + # Create a prompt that uses the partial + my_prompt = ai.define_prompt( + messages='{{>myGreeting}} Welcome to Genkit.', + ) + + response = await my_prompt(input={'name': 'Developer'}) + + # The partial should be included in the output + assert 'Greetings' in response.text and 'Developer' in response.text + + @pytest.mark.asyncio async def test_prompt_with_messages_list() -> None: """Test prompt with explicit messages list.""" @@ -350,7 +418,7 @@ async def test_messages_with_explicit_override() -> None: prompt='Final question', ) - override_messages = [ + [ Message(role=Role.USER, content=[TextPart(text='First message')]), Message(role=Role.MODEL, content=[TextPart(text='First response')]), ] @@ -509,7 +577,6 @@ async def test_prompt_and_executable_prompt_return_types() -> None: @pytest.mark.asyncio async def test_lookup_prompt_returns_executable_prompt() -> None: """lookup_prompt should return an ExecutablePrompt that can be called.""" - ai, *_ = setup_test() with tempfile.TemporaryDirectory() as tmpdir: @@ -600,7 +667,7 @@ async def test_automatic_prompt_loading_defaults_mock(): @pytest.mark.asyncio async def test_automatic_prompt_loading_defaults_missing(): """Test that Genkit skips loading when 
./prompts is missing.""" - from unittest.mock import ANY, MagicMock, patch + from unittest.mock import MagicMock, patch with patch('genkit.ai._aio.load_prompt_folder') as mock_load, patch('genkit.ai._aio.Path') as mock_path: # Setup mock to simulate ./prompts missing diff --git a/py/packages/genkit/tests/genkit/blocks/test_reranker.py b/py/packages/genkit/tests/genkit/blocks/test_reranker.py new file mode 100644 index 0000000000..63ac474334 --- /dev/null +++ b/py/packages/genkit/tests/genkit/blocks/test_reranker.py @@ -0,0 +1,484 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the reranker module. + +This module contains tests for the reranker functionality including +RankedDocument, define_reranker, and rerank functions. 
+""" + +import pytest + +from genkit.blocks.document import Document +from genkit.blocks.reranker import ( + RankedDocument, + RerankerInfo, + RerankerOptions, + RerankerRef, + create_reranker_ref, + define_reranker, + rerank, + reranker_action_metadata, +) +from genkit.core.action.types import ActionKind +from genkit.core.registry import Registry +from genkit.core.typing import ( + DocumentData, + DocumentPart, + RankedDocumentData, + RankedDocumentMetadata, + RerankerResponse, +) + + +class TestRankedDocument: + """Tests for the RankedDocument class.""" + + def test_ranked_document_creation(self): + """Test creating a RankedDocument with content and score.""" + content = [DocumentPart(text='Test content')] + metadata = {'key': 'value'} + score = 0.95 + + doc = RankedDocument(content=content, metadata=metadata, score=score) + + assert doc.score == 0.95 + assert doc.text() == 'Test content' + assert doc.metadata == {'key': 'value', 'score': 0.95} + # Original metadata should not be modified + assert metadata == {'key': 'value'} + + def test_ranked_document_default_score(self): + """Test that RankedDocument has a default score of None.""" + content = [DocumentPart(text='Test')] + doc = RankedDocument(content=content) + + assert doc.score is None + + def test_ranked_document_from_data(self): + """Test creating RankedDocument from RankedDocumentData.""" + data = RankedDocumentData( + content=[DocumentPart(text='Test content')], + metadata=RankedDocumentMetadata(score=0.85), + ) + + doc = RankedDocument.from_ranked_document_data(data) + + assert doc.score == 0.85 + assert doc.text() == 'Test content' + + +class TestRerankerRef: + """Tests for RerankerRef and related helper functions.""" + + def test_create_reranker_ref_basic(self): + """Test creating a basic reranker reference.""" + ref = create_reranker_ref('test-reranker') + + assert ref.name == 'test-reranker' + assert ref.config is None + assert ref.version is None + assert ref.info is None + + def 
test_create_reranker_ref_with_options(self): + """Test creating a reranker reference with all options.""" + info = RerankerInfo(label='Test Reranker') + ref = create_reranker_ref( + name='test-reranker', + config={'k': 10}, + version='1.0.0', + info=info, + ) + + assert ref.name == 'test-reranker' + assert ref.config == {'k': 10} + assert ref.version == '1.0.0' + assert ref.info.label == 'Test Reranker' + + +class TestRerankerActionMetadata: + """Tests for reranker action metadata creation.""" + + def test_action_metadata_basic(self): + """Test creating basic action metadata.""" + metadata = reranker_action_metadata('test-reranker') + + assert metadata.kind == ActionKind.RERANKER + assert metadata.name == 'test-reranker' + assert 'reranker' in metadata.metadata + + def test_action_metadata_with_options(self): + """Test creating action metadata with options.""" + options = RerankerOptions( + label='Custom Label', + config_schema={'type': 'object'}, + ) + metadata = reranker_action_metadata('test-reranker', options) + + assert metadata.metadata['reranker']['label'] == 'Custom Label' + assert metadata.metadata['reranker']['customOptions'] == {'type': 'object'} + + +class TestDefineReranker: + """Tests for the define_reranker function.""" + + @pytest.fixture + def registry(self): + """Create a fresh registry for each test.""" + return Registry() + + @pytest.mark.asyncio + async def test_define_reranker_registers_action(self, registry): + """Test that define_reranker registers an action in the registry.""" + + async def simple_reranker(query, documents, options): + # Return documents in same order with scores + return RerankerResponse( + documents=[ + RankedDocumentData( + content=doc.content, + metadata=RankedDocumentMetadata(score=1.0 - i * 0.1), + ) + for i, doc in enumerate(documents) + ] + ) + + action = define_reranker(registry, 'test-reranker', simple_reranker) + + # Verify action was registered + lookup = registry.lookup_action(ActionKind.RERANKER, 
'test-reranker') + assert lookup is not None + assert action.name == 'test-reranker' + + @pytest.mark.asyncio + async def test_define_reranker_with_options(self, registry): + """Test define_reranker with custom options.""" + + async def reranker_fn(query, documents, options): + return RerankerResponse(documents=[]) + + options = RerankerOptions(label='My Reranker') + action = define_reranker(registry, 'my-reranker', reranker_fn, options) + + assert action is not None + + +class TestRerank: + """Tests for the rerank function.""" + + @pytest.fixture + def registry(self): + """Create a fresh registry for each test.""" + return Registry() + + @pytest.fixture + def sample_documents(self): + """Create sample documents for testing.""" + return [ + DocumentData(content=[DocumentPart(text='First document')]), + DocumentData(content=[DocumentPart(text='Second document')]), + DocumentData(content=[DocumentPart(text='Third document')]), + ] + + @pytest.mark.asyncio + async def test_rerank_with_string_query(self, registry, sample_documents): + """Test rerank with a string query.""" + + async def score_by_length(query, documents, options): + # Score documents by content length (longer = higher score) + scored = [] + for doc in documents: + length = len(doc.text()) + scored.append( + RankedDocumentData( + content=doc.content, + metadata=RankedDocumentMetadata(score=float(length)), + ) + ) + return RerankerResponse(documents=scored) + + define_reranker(registry, 'length-reranker', score_by_length) + + results = await rerank( + registry, + { + 'reranker': 'length-reranker', + 'query': 'test query', + 'documents': sample_documents, + }, + ) + + assert len(results) == 3 + assert all(isinstance(r, RankedDocument) for r in results) + + @pytest.mark.asyncio + async def test_rerank_with_reranker_ref(self, registry, sample_documents): + """Test rerank with a RerankerRef.""" + + async def simple_reranker(query, documents, options): + return RerankerResponse( + documents=[ + 
RankedDocumentData( + content=doc.content, + metadata=RankedDocumentMetadata(score=0.5), + ) + for doc in documents + ] + ) + + define_reranker(registry, 'ref-reranker', simple_reranker) + ref = create_reranker_ref('ref-reranker') + + results = await rerank( + registry, + { + 'reranker': ref, + 'query': 'test', + 'documents': sample_documents, + }, + ) + + assert len(results) == 3 + assert all(doc.score == 0.5 for doc in results) + + @pytest.mark.asyncio + async def test_rerank_unknown_reranker_raises(self, registry, sample_documents): + """Test that rerank raises ValueError for unknown reranker.""" + with pytest.raises(ValueError, match='Unable to resolve reranker'): + await rerank( + registry, + { + 'reranker': 'non-existent-reranker', + 'query': 'test', + 'documents': sample_documents, + }, + ) + + +class TestCustomRerankers: + """Tests for custom reranker implementations. + + These tests demonstrate how to create custom rerankers as shown + in the genkit.dev documentation: + https://genkit.dev/docs/rag/#rerankers-and-two-stage-retrieval + """ + + @pytest.fixture + def registry(self): + """Create a fresh registry for each test.""" + return Registry() + + @pytest.fixture + def sample_documents(self): + """Create sample documents matching genkit.dev documentation example.""" + return [ + DocumentData(content=[DocumentPart(text='pythagorean theorem')]), + DocumentData(content=[DocumentPart(text='e=mc^2')]), + DocumentData(content=[DocumentPart(text='pi')]), + DocumentData(content=[DocumentPart(text='dinosaurs')]), + DocumentData(content=[DocumentPart(text='quantum mechanics')]), + DocumentData(content=[DocumentPart(text='pizza')]), + DocumentData(content=[DocumentPart(text='harry potter')]), + ] + + @pytest.mark.asyncio + async def test_custom_keyword_overlap_reranker(self, registry, sample_documents): + """Test a custom reranker that scores by keyword overlap. + + This demonstrates the pattern shown in genkit.dev docs for + creating custom reranking logic. 
+ """ + + async def keyword_overlap_reranker(query, documents, options): + """Reranker that scores documents by keyword overlap with query.""" + query_words = set(query.text().lower().split()) + scored = [] + + for doc in documents: + doc_words = set(doc.text().lower().split()) + overlap = len(query_words & doc_words) + score = overlap / max(len(query_words), 1) + scored.append((doc, score)) + + # Sort by score descending + scored.sort(key=lambda x: x[1], reverse=True) + + # Apply k limit if provided in options + k = options.get('k', len(scored)) if options else len(scored) + top_k = scored[:k] + + return RerankerResponse( + documents=[ + RankedDocumentData( + content=doc.content, + metadata=RankedDocumentMetadata(score=score), + ) + for doc, score in top_k + ] + ) + + define_reranker(registry, 'custom/keyword-overlap', keyword_overlap_reranker) + + # Query for 'quantum' should rank 'quantum mechanics' highest + results = await rerank( + registry, + { + 'reranker': 'custom/keyword-overlap', + 'query': 'quantum mechanics physics', + 'documents': sample_documents, + }, + ) + + assert len(results) == 7 + # 'quantum mechanics' should have the highest score (overlaps 2 words) + assert results[0].text() == 'quantum mechanics' + assert results[0].score > 0 + + @pytest.mark.asyncio + async def test_custom_reranker_with_top_k_option(self, registry, sample_documents): + """Test custom reranker with k option to limit results. + + Demonstrates using options to configure reranking behavior. 
+ """ + + async def random_score_reranker(query, documents, options): + """Reranker that assigns incrementing scores and respects k option.""" + k = options.get('k', 3) if options else 3 + + scored_docs = [] + for i, doc in enumerate(documents): + # Score in reverse order so we have a predictable ranking + score = float(len(documents) - i) + scored_docs.append( + RankedDocumentData( + content=doc.content, + metadata=RankedDocumentMetadata(score=score), + ) + ) + + # Sort by score descending and limit to k + scored_docs.sort(key=lambda d: d.metadata.score, reverse=True) + return RerankerResponse(documents=scored_docs[:k]) + + define_reranker(registry, 'custom/with-k-option', random_score_reranker) + + results = await rerank( + registry, + { + 'reranker': 'custom/with-k-option', + 'query': 'test', + 'documents': sample_documents, + 'options': {'k': 3}, + }, + ) + + # Should only return top 3 results + assert len(results) == 3 + + @pytest.mark.asyncio + async def test_custom_reranker_preserves_document_content(self, registry): + """Test that custom reranker preserves original document content.""" + + async def identity_reranker(query, documents, options): + """Reranker that returns documents with their original content.""" + return RerankerResponse( + documents=[ + RankedDocumentData( + content=doc.content, + metadata=RankedDocumentMetadata(score=1.0), + ) + for doc in documents + ] + ) + + define_reranker(registry, 'custom/identity', identity_reranker) + + original_texts = ['Document A', 'Document B with more text', 'Doc C'] + documents = [DocumentData(content=[DocumentPart(text=t)]) for t in original_texts] + + results = await rerank( + registry, + { + 'reranker': 'custom/identity', + 'query': 'test', + 'documents': documents, + }, + ) + + # Verify all original content is preserved + result_texts = [doc.text() for doc in results] + assert result_texts == original_texts + + @pytest.mark.asyncio + async def test_custom_reranker_two_stage_retrieval_pattern(self, 
registry): + """Test the two-stage retrieval pattern: retrieve then rerank. + + This demonstrates the typical RAG pattern where: + 1. Stage 1: Retrieve a broad set of candidates + 2. Stage 2: Rerank to find most relevant documents + """ + + # Simulate stage 1 retrieval results (unranked) + retrieved_documents = [ + DocumentData(content=[DocumentPart(text='Machine learning is a subset of AI')]), + DocumentData(content=[DocumentPart(text='Pizza is a popular food')]), + DocumentData(content=[DocumentPart(text='Deep learning uses neural networks')]), + DocumentData(content=[DocumentPart(text='Cats are domestic animals')]), + DocumentData(content=[DocumentPart(text='AI transforms industries')]), + ] + + async def relevance_reranker(query, documents, options): + """Reranker that scores by word presence in query.""" + query_lower = query.text().lower() + scored = [] + + for doc in documents: + doc_text = doc.text().lower() + # Simple relevance: count query words in document + score = sum(1 for word in query_lower.split() if word in doc_text) + scored.append((doc, float(score))) + + scored.sort(key=lambda x: x[1], reverse=True) + + return RerankerResponse( + documents=[ + RankedDocumentData( + content=doc.content, + metadata=RankedDocumentMetadata(score=score), + ) + for doc, score in scored + ] + ) + + define_reranker(registry, 'custom/relevance', relevance_reranker) + + # Stage 2: Rerank with query about AI + reranked = await rerank( + registry, + { + 'reranker': 'custom/relevance', + 'query': 'artificial intelligence AI', + 'documents': retrieved_documents, + }, + ) + + # AI-related documents should rank higher than unrelated ones + # Get scores for AI and non-AI documents + ai_scores = [doc.score for doc in reranked if 'AI' in doc.text() or 'learning' in doc.text()] + non_ai_scores = [doc.score for doc in reranked if 'Pizza' in doc.text() or 'Cats' in doc.text()] + + # AI-related documents should have higher scores on average + assert max(ai_scores) > 
max(non_ai_scores) diff --git a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py index 173c6a8037..142e9c79f1 100644 --- a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py +++ b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py @@ -98,7 +98,7 @@ async def test_notify_endpoint(asgi_client): @pytest.mark.asyncio async def test_run_action_not_found(asgi_client, mock_registry): """Test that requesting a non-existent action returns a 404 error.""" - mock_registry.lookup_action_by_key.return_value = None + mock_registry.resolve_action_by_key.return_value = None response = await asgi_client.post( '/api/runAction', json={'key': 'non_existent_action', 'input': {'data': 'test'}}, @@ -116,7 +116,7 @@ async def test_run_action_standard(asgi_client, mock_registry): mock_output.trace_id = 'test_trace_id' mock_action.arun_raw.return_value = mock_output - mock_registry.lookup_action_by_key.return_value = mock_action + mock_registry.resolve_action_by_key.return_value = mock_action response = await asgi_client.post('/api/runAction', json={'key': 'test_action', 'input': {'data': 'test'}}) @@ -137,7 +137,7 @@ async def test_run_action_with_context(asgi_client, mock_registry): mock_output.trace_id = 'test_trace_id' mock_action.arun_raw.return_value = mock_output - mock_registry.lookup_action_by_key.return_value = mock_action + mock_registry.resolve_action_by_key.return_value = mock_action response = await asgi_client.post( '/api/runAction', @@ -169,7 +169,7 @@ async def mock_streaming(raw_input, on_chunk=None, context=None): return mock_output mock_action.arun_raw.side_effect = mock_streaming - mock_registry.lookup_action_by_key.return_value = mock_action + mock_registry.resolve_action_by_key.return_value = mock_action response = await asgi_client.post( '/api/runAction?stream=true', diff --git a/py/packages/genkit/tests/genkit/core/extract_test.py 
b/py/packages/genkit/tests/genkit/core/extract_test.py index 730045f27b..25dc270872 100644 --- a/py/packages/genkit/tests/genkit/core/extract_test.py +++ b/py/packages/genkit/tests/genkit/core/extract_test.py @@ -86,7 +86,7 @@ ids=[tc[0] for tc in test_cases_extract_items], ) def test_extract_items(name: str, steps: list[dict[str, Any]]) -> None: - """Test extraction of incomplete json that can be fixed""" + """Test extraction of incomplete json that can be fixed.""" text = '' cursor = 0 for step in steps: @@ -141,7 +141,7 @@ def test_extract_items(name: str, steps: list[dict[str, Any]]) -> None: ids=[tc[0] for tc in test_cases_extract_json], ) def test_extract_json(name: str, input_data: dict[str, Any], expected_data: dict[str, Any]) -> None: - """Test if input is unfixable raise the correct exception or return the proper error response""" + """Test if input is unfixable raise the correct exception or return the proper error response.""" if expected_data.get('throws'): with pytest.raises(Exception): extract_json(input_data['text'], throw_on_bad_json=True) @@ -186,6 +186,6 @@ def test_extract_json(name: str, input_data: dict[str, Any], expected_data: dict ids=[tc[0] for tc in test_cases_parse_partial_json], ) def test_parse_partial_json(name: str, input_str: str, expected_data: dict[str, Any]) -> None: - """Test if it fixes simple malformed json string""" + """Test if it fixes simple malformed json string.""" result = parse_partial_json(input_str) assert result == expected_data['expected'] diff --git a/py/packages/genkit/tests/genkit/core/registry_test.py b/py/packages/genkit/tests/genkit/core/registry_test.py index 6c5b2f52f2..d280900c96 100644 --- a/py/packages/genkit/tests/genkit/core/registry_test.py +++ b/py/packages/genkit/tests/genkit/core/registry_test.py @@ -11,7 +11,7 @@ import pytest -from genkit.ai import Genkit, GenkitRegistry, Plugin +from genkit.ai import Genkit from genkit.core.action import ActionMetadata from genkit.core.action.types import 
ActionKind, ActionMetadataKey from genkit.core.registry import Registry @@ -29,17 +29,21 @@ def list_actions_mock(): assert 'test_plugin' in registry._list_actions_resolvers -def test_register_list_actions_resolver_raises_exception(): - """Test when ValueError is raised.""" +def test_register_list_actions_resolver_multiple(): + """Test that multiple resolvers can be registered for the same plugin.""" registry = Registry() - def list_actions_mock(): + def list_actions_mock1(): + return [] + + def list_actions_mock2(): return [] - registry._list_actions_resolvers['test_plugin'] = list_actions_mock + registry.register_list_actions_resolver('test_plugin', list_actions_mock1) + registry.register_list_actions_resolver('test_plugin', list_actions_mock2) - with pytest.raises(ValueError, match=r'Plugin .* already registered'): - registry.register_list_actions_resolver('test_plugin', list_actions_mock) + assert 'test_plugin' in registry._list_actions_resolvers + assert len(registry._list_actions_resolvers['test_plugin']) == 2 def test_register_action_with_name_and_kind() -> None: @@ -159,47 +163,14 @@ def list_actions_mock(): ] registry = Registry() - registry._list_actions_resolvers['test_plugin'] = list_actions_mock + registry._list_actions_resolvers['test_plugin'] = [list_actions_mock] registry._entries[ActionKind.CUSTOM] = {} registry._entries[ActionKind.TOOL] = {} - got = registry.list_actions({}, allowed_kind) + got = registry.list_actions_sync({}, allowed_kind) assert got == expected -def test_resolve_action_from_plugin(): - """Resolve action from plugin test.""" - resolver_calls = [] - - class MyPlugin(Plugin): - name = 'myplugin' - - def resolve_action(self, ai: GenkitRegistry, kind: ActionKind, name: str): - nonlocal resolver_calls - resolver_calls.append([kind, name]) - - def model_fn(): - pass - - ai.define_model(name=name, fn=model_fn) - - def initialize(self, ai: GenkitRegistry) -> None: - pass - - ai = Genkit(plugins=[MyPlugin()]) - - action = 
ai.registry.lookup_action(ActionKind.MODEL, 'myplugin/foo') - - assert action is not None - assert len(resolver_calls) == 1 - - assert resolver_calls == [[ActionKind.MODEL, 'myplugin/foo']] - - # should be idempotent - ai.registry.lookup_action(ActionKind.MODEL, 'myplugin/foo') - assert len(resolver_calls) == 1 - - def test_register_value(): """Register a value and lookup test.""" registry = Registry() diff --git a/py/packages/genkit/tests/genkit/lang/deprecations_test.py b/py/packages/genkit/tests/genkit/lang/deprecations_test.py index 5344d98edc..6b88bc3989 100644 --- a/py/packages/genkit/tests/genkit/lang/deprecations_test.py +++ b/py/packages/genkit/tests/genkit/lang/deprecations_test.py @@ -20,8 +20,6 @@ import unittest import warnings -import pytest - if sys.version_info < (3, 11): from strenum import StrEnum else: diff --git a/py/packages/genkit/tests/genkit/veneer/resource_test.py b/py/packages/genkit/tests/genkit/veneer/resource_test.py new file mode 100644 index 0000000000..e3324e5681 --- /dev/null +++ b/py/packages/genkit/tests/genkit/veneer/resource_test.py @@ -0,0 +1,49 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the Genkit Resource API via the Genkit class (Veneer). +This test file verifies that `ai.define_resource` works correctly, mirroring the +JS SDK's `ai.defineResource`. 
+""" + +import asyncio + +import pytest + +from genkit.ai import Genkit +from genkit.core.typing import Part, TextPart + + +@pytest.mark.asyncio +async def test_define_resource_veneer(): + """Verifies ai.define_resource registers a resource correctly.""" + ai = Genkit(plugins=[]) + + async def my_resource_fn(input, ctx): + return {'content': [Part(root=TextPart(text=f'Content for {input.uri}'))]} + + act = ai.define_resource({'uri': 'http://example.com/foo'}, my_resource_fn) + + assert act.name == 'http://example.com/foo' + assert act.metadata['resource']['uri'] == 'http://example.com/foo' + + # Verify lookup via global registry (contained in ai.registry) + looked_up = ai.registry.lookup_action('resource', 'http://example.com/foo') + assert looked_up == act + + # Verify execution + output = await act.arun({'uri': 'http://example.com/foo'}) + assert 'Content for http://example.com/foo' in output.response['content'][0]['text'] diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index bc987c8294..cc7854b1a8 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 # -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # SPDX-License-Identifier: Apache-2.0 """Tests for the action module.""" @@ -1158,8 +1158,8 @@ async def test_generate_simulates_doc_grounding( assert (await response).request.messages[0] == want_msg -class TestFormat(FormatDef): - """Test format for testing the format.""" +class MockBananaFormat(FormatDef): + """Mock format for testing the format.""" def __init__(self): """Initialize the format.""" @@ -1200,7 +1200,7 @@ async def test_define_format(setup_test: SetupFixture) -> None: """Test that the define format function works.""" ai, _, pm, *_ = setup_test - ai.define_format(TestFormat()) + ai.define_format(MockBananaFormat()) class TestSchema(BaseModel): foo: int = 
Field(None, description='foo field') diff --git a/py/plugins/anthropic/pyproject.toml b/py/plugins/anthropic/pyproject.toml index b9875c4d11..0ab15ba504 100644 --- a/py/plugins/anthropic/pyproject.toml +++ b/py/plugins/anthropic/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -35,7 +34,7 @@ classifiers = [ ] dependencies = ["genkit", "anthropic>=0.40.0"] description = "Genkit Anthropic Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-anthropic" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/anthropic/src/genkit/plugins/anthropic/plugin.py b/py/plugins/anthropic/src/genkit/plugins/anthropic/plugin.py index 9a71c6dd69..62043b1fb2 100644 --- a/py/plugins/anthropic/src/genkit/plugins/anthropic/plugin.py +++ b/py/plugins/anthropic/src/genkit/plugins/anthropic/plugin.py @@ -16,8 +16,12 @@ """Anthropic plugin for Genkit.""" +import os + from anthropic import AsyncAnthropic -from genkit.ai import GenkitRegistry, Plugin +from genkit.ai import Plugin +from genkit.blocks.model import model +from genkit.core.action import Action, ActionMetadata from genkit.core.registry import ActionKind from genkit.plugins.anthropic.model_info import SUPPORTED_ANTHROPIC_MODELS, get_model_info from genkit.plugins.anthropic.models import AnthropicModel @@ -42,81 +46,110 @@ class Anthropic(Plugin): """Anthropic plugin for Genkit. This plugin adds Anthropic models to Genkit for generative AI applications. + Can be used standalone (without framework) or with Genkit framework. 
+ + Example (standalone): + >>> plugin = Anthropic(api_key='...') + >>> claude = await plugin.model('claude-3-5-sonnet') + >>> response = await claude.arun({'messages': [...]}) + + Example (with framework): + >>> ai = Genkit(plugins=[Anthropic(api_key='...')]) + >>> response = await ai.generate('anthropic/claude-3-5-sonnet', prompt='Hi') """ name = ANTHROPIC_PLUGIN_NAME def __init__( self, + api_key: str | None = None, models: list[str] | None = None, **anthropic_params: str, ) -> None: """Initializes Anthropic plugin with given configuration. Args: - models: List of model names to register. Defaults to all supported models. + api_key: Optional Anthropic API key. If not provided, uses `ANTHROPIC_API_KEY` + from the environment (or lets the Anthropic client handle defaults). + models: Optional list of supported Anthropic models to expose via this plugin. **anthropic_params: Additional parameters passed to the AsyncAnthropic client. This may include api_key, base_url, timeout, and other configuration settings required by Anthropic's API. """ + if api_key is None: + api_key = os.getenv('ANTHROPIC_API_KEY') + self.models = models or list(SUPPORTED_ANTHROPIC_MODELS.keys()) self._anthropic_params = anthropic_params - self._anthropic_client = AsyncAnthropic(**anthropic_params) + self._anthropic_client = ( + AsyncAnthropic(api_key=api_key, **anthropic_params) if api_key else AsyncAnthropic(**anthropic_params) + ) - def initialize(self, ai: GenkitRegistry) -> None: - """Initialize plugin by registering models. + async def init(self) -> list[Action]: + """Return eagerly-initialized model actions. - Args: - ai: The AI registry to initialize the plugin with. + Called once during Genkit initialization. Loads ALL supported + Anthropic models (same behavior as JavaScript). + + Returns: + List of Action objects for all supported models. 
""" - for model_name in self.models: - self._define_model(ai, model_name) + return [self._create_model_action(model_name) for model_name in self.models] - def resolve_action( - self, - ai: GenkitRegistry, - kind: ActionKind, - name: str, - ) -> None: - """Resolve an action. + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + """Resolve a specific model action on-demand. + + Called when framework needs an action not from init(). + Enables lazy loading of Anthropic models. Args: - ai: Genkit registry. - kind: Action kind. - name: Action name. + action_type: Type of action requested. + name: Name of action (unprefixed - framework strips plugin prefix). + + Returns: + Action if this plugin can provide it, None otherwise. """ - if kind == ActionKind.MODEL: - self._resolve_model(ai=ai, name=name) + if action_type == ActionKind.MODEL: + # Check if we support this model + if name in self.models: + return self._create_model_action(name) - def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: - """Resolve and define an Anthropic model. + return None - Args: - ai: Genkit registry. - name: Model name (may include plugin prefix). - """ - clean_name = name.replace(f'{ANTHROPIC_PLUGIN_NAME}/', '') if name.startswith(ANTHROPIC_PLUGIN_NAME) else name - self._define_model(ai, clean_name) + async def list_actions(self) -> list[ActionMetadata]: + """Return metadata for all supported Anthropic models. + + Used for discovery and developer tools. - def _define_model(self, ai: GenkitRegistry, model_name: str) -> None: - """Define and register a model. + Returns: + List of ActionMetadata for all supported models. + """ + return [ + ActionMetadata( + name=model_name, + kind=ActionKind.MODEL, + info=get_model_info(model_name).model_dump(), + ) + for model_name in self.models + ] + + def _create_model_action(self, model_name: str) -> Action: + """Create an Action for an Anthropic model (doesn't register). Args: - ai: Genkit registry. 
- model_name: Model name. + model_name: Name of the Anthropic model (without plugin prefix). + + Returns: + Action instance. """ - model = AnthropicModel(model_name=model_name, client=self._anthropic_client) model_info = get_model_info(model_name) + anthropic_model = AnthropicModel(model_name=model_name, client=self._anthropic_client) - metadata = { - 'model': { - 'supports': model_info.supports.model_dump(), - } - } + metadata = {'model': {'supports': model_info.supports.model_dump()}} - ai.define_model( - name=anthropic_name(model_name), - fn=model.generate, + return model( + name=model_name, + fn=anthropic_model.generate, config_schema=GenerationCommonConfig, metadata=metadata, ) diff --git a/py/plugins/anthropic/tests/test_plugin.py b/py/plugins/anthropic/tests/test_plugin.py index f55247f4e8..3c69cd5f13 100644 --- a/py/plugins/anthropic/tests/test_plugin.py +++ b/py/plugins/anthropic/tests/test_plugin.py @@ -16,8 +16,11 @@ """Tests for Anthropic plugin.""" -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import patch +import pytest + +from genkit.ai import Genkit from genkit.core.registry import ActionKind from genkit.plugins.anthropic import Anthropic, anthropic_name from genkit.plugins.anthropic.model_info import ( @@ -68,30 +71,27 @@ def test_custom_models(): assert plugin.models == ['claude-sonnet-4'] -def test_plugin_initialize(): - """Test plugin registry initialization.""" - registry = MagicMock() +@pytest.mark.asyncio +async def test_plugin_initialize(): + """Test plugin registration with the Genkit framework.""" plugin = Anthropic(api_key='test-key', models=['claude-sonnet-4']) - plugin.initialize(registry) + ai = Genkit(plugins=[plugin]) - assert registry.define_model.call_count == 1 - registry.define_model.assert_called_once_with( - name='anthropic/claude-sonnet-4', - fn=ANY, - config_schema=ANY, - metadata=ANY, - ) + action = await ai.registry.resolve_action(ActionKind.MODEL, 'anthropic/claude-sonnet-4') + assert action is not 
None + assert action.name == 'anthropic/claude-sonnet-4' -def test_resolve_action_model(): - """Test resolve_action for model.""" - registry = MagicMock() - plugin = Anthropic(api_key='test-key') +@pytest.mark.asyncio +async def test_resolve_action_model(): + """Test resolve() can lazily provide a model action.""" + plugin = Anthropic(api_key='test-key', models=['claude-sonnet-4']) - plugin.resolve_action(registry, ActionKind.MODEL, 'claude-sonnet-4') + action = await plugin.resolve(ActionKind.MODEL, 'claude-sonnet-4') - registry.define_model.assert_called_once() + assert action is not None + assert action.name == 'claude-sonnet-4' def test_supported_models(): diff --git a/py/plugins/compat-oai/pyproject.toml b/py/plugins/compat-oai/pyproject.toml index 44a4fe8282..3dd804f716 100644 --- a/py/plugins/compat-oai/pyproject.toml +++ b/py/plugins/compat-oai/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -35,7 +34,7 @@ classifiers = [ ] dependencies = ["genkit", "openai", "strenum>=0.4.15; python_version < '3.11'"] description = "Genkit OpenAI API Compatible" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-compat-oai" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py index 70b3e4e2f9..f02cfb060d 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/handler.py @@ -21,7 +21,7 @@ from openai import OpenAI -from genkit.ai import ActionRunContext, GenkitRegistry +from genkit.ai import ActionRunContext from 
genkit.plugins.compat_oai.models.model import OpenAIModel from genkit.plugins.compat_oai.models.model_info import ( SUPPORTED_OPENAI_COMPAT_MODELS, @@ -51,6 +51,7 @@ def __init__(self, model: Any, source: PluginSource = PluginSource.OPENAI) -> No @staticmethod def _get_supported_models(source: PluginSource) -> dict[str, Any]: """Returns the supported models based on the plugin source. + Args: source: Helps distinguish if model handler is called from model-garden plugin. Default source is openai. @@ -59,12 +60,11 @@ def _get_supported_models(source: PluginSource) -> dict[str, Any]: Openai models if source is openai. Merges supported openai models with openai-compat models if source is model-garden. """ - return SUPPORTED_OPENAI_COMPAT_MODELS if source == PluginSource.MODEL_GARDEN else SUPPORTED_OPENAI_MODELS @classmethod def get_model_handler( - cls, model: str, client: OpenAI, registry: GenkitRegistry, source: PluginSource = PluginSource.OPENAI + cls, model: str, client: OpenAI, source: PluginSource = PluginSource.OPENAI ) -> Callable[[GenerateRequest, ActionRunContext], GenerateResponse]: """Factory method to initialize the model handler for the specified OpenAI model. @@ -89,10 +89,13 @@ def get_model_handler( """ supported_models = cls._get_supported_models(source) - if model not in supported_models: + # For the OpenAI compat plugin, we allow arbitrary model names (the OpenAI API + # can serve models beyond our static known list). For Model Garden, keep the + # strict validation. + if model not in supported_models and source == PluginSource.MODEL_GARDEN: raise ValueError(f"Model '{model}' is not supported.") - openai_model = OpenAIModel(model, client, registry) + openai_model = OpenAIModel(model, client) return cls(openai_model, source).generate def _validate_version(self, version: str) -> None: @@ -105,7 +108,10 @@ def _validate_version(self, version: str) -> None: ValueError: If the specified model version is not supported. 
""" supported_models = self._get_supported_models(self._source) - model_info = supported_models[self._model.name] + model_info = supported_models.get(self._model.name) + if model_info is None: + # Unknown model; skip version validation. + return if version not in model_info.versions: raise ValueError(f"Model version '{version}' is not supported.") diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py index bf1f7315aa..08e80f6045 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/models/model.py @@ -19,10 +19,9 @@ from collections.abc import Callable from typing import Any -from openai import OpenAI, pydantic_function_tool +from openai import OpenAI from openai.lib._pydantic import _ensure_strict_json_schema -from genkit.ai import ActionKind, GenkitRegistry from genkit.core.action._action import ActionRunContext from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_MODELS from genkit.plugins.compat_oai.models.utils import DictMessageAdapter, MessageAdapter, MessageConverter @@ -41,17 +40,15 @@ class OpenAIModel: """Handles OpenAI API interactions for the Genkit plugin.""" - def __init__(self, model: str, client: OpenAI, registry: GenkitRegistry): + def __init__(self, model: str, client: OpenAI): """Initializes the OpenAIModel instance with the specified model and OpenAI client parameters. Args: model: The OpenAI model to use for generating responses. client: OpenAI client instance. - registry: The registry where OpenAI models will be registered. """ self._model = model self._openai_client = client - self._registry = registry @property def name(self) -> str: @@ -85,15 +82,19 @@ def _get_tools_definition(self, tools: list[ToolDefinition]) -> list[dict]: Returns: A list of dictionaries representing the formatted tools. 
""" - result = [] + # NOTE: ToolDefinition objects already contain JSON Schema for inputs/outputs. + # Do NOT reach back into the registry to reconstruct schemas. + result: list[dict[str, Any]] = [] for tool_definition in tools: - action = self._registry.registry.lookup_action(ActionKind.TOOL, tool_definition.name) - function_call = pydantic_function_tool( - model=action.input_type._type, - name=tool_definition.name, - description=tool_definition.description, - ) - result.append(function_call) + parameters = tool_definition.input_schema or {'type': 'object', 'properties': {}} + result.append({ + 'type': 'function', + 'function': { + 'name': tool_definition.name, + 'description': tool_definition.description, + 'parameters': parameters, + }, + }) return result def _get_response_format(self, output: OutputConfig) -> dict | None: @@ -140,6 +141,11 @@ def _get_openai_request_config(self, request: GenerateRequest) -> dict: } if request.tools: openai_config['tools'] = self._get_tools_definition(request.tools) + if any(msg.role == Role.TOOL for msg in request.messages): + # After a tool response, stop forcing additional tool calls. 
+ openai_config['tool_choice'] = 'none' + elif request.tool_choice: + openai_config['tool_choice'] = request.tool_choice if request.output: openai_config['response_format'] = self._get_response_format(request.output) if request.config: diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py index 19cd4ebcf8..6fe1d129ec 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/openai_plugin.py @@ -15,25 +15,23 @@ # SPDX-License-Identifier: Apache-2.0 -"""OpenAI OpenAI API Compatible Plugin for Genkit.""" +"""OpenAI OpenAI API Compatible Plugin for Genkit (v2).""" -from functools import cached_property -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from openai import OpenAI as OpenAIClient -from openai.types import Embedding, Model +from openai.types import Model -from genkit.ai._plugin import Plugin -from genkit.ai._registry import GenkitRegistry +from genkit.ai import Plugin from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata -from genkit.blocks.model import model_action_metadata +from genkit.blocks.model import model, model_action_metadata from genkit.core.action import ActionMetadata -from genkit.core.action.types import ActionKind +from genkit.core.registry import ActionKind from genkit.core.typing import GenerationCommonConfig from genkit.plugins.compat_oai.models import ( SUPPORTED_OPENAI_COMPAT_MODELS, SUPPORTED_OPENAI_MODELS, - OpenAIModel, OpenAIModelHandler, ) from genkit.plugins.compat_oai.models.model_info import get_default_openai_model_info @@ -78,26 +76,9 @@ def __init__(self, **openai_params: str) -> None: self._openai_params = openai_params self._openai_client = OpenAIClient(**openai_params) - def initialize(self, ai: GenkitRegistry) -> None: - """Registers supported OpenAI models 
in the given registry. - - Args: - ai: The registry where OpenAI models will be registered. - """ - for model_name, model_info in SUPPORTED_OPENAI_MODELS.items(): - handler = OpenAIModelHandler.get_model_handler(model=model_name, client=self._openai_client, registry=ai) - - ai.define_model( - name=f'openai/{model_name}', - fn=handler, - config_schema=OpenAIConfig, - metadata={ - 'model': { - 'label': model_info.label, - 'supports': {'multiturn': model_info.supports.multiturn} if model_info.supports else {}, - }, - }, - ) + async def init(self): + """Return eagerly-initialized model actions.""" + return [self._create_model_action(model_name) for model_name in SUPPORTED_OPENAI_MODELS.keys()] def get_model_info(self, name: str) -> dict[str, str] | None: """Retrieves metadata and supported features for the specified model. @@ -111,66 +92,40 @@ def get_model_info(self, name: str) -> dict[str, str] | None: is provided). The 'supports' key contains a dictionary representing the model's capabilities (e.g., tools, streaming). """ - if model_supported := SUPPORTED_OPENAI_MODELS.get(name): return { 'label': model_supported.label, 'supports': model_supported.supports.model_dump(exclude_none=True), } - model_info = SUPPORTED_OPENAI_COMPAT_MODELS.get(name, get_default_openai_model_info(self)) + model_info = SUPPORTED_OPENAI_COMPAT_MODELS.get(name, get_default_openai_model_info(name)) return { 'label': model_info.label, 'supports': model_info.supports.model_dump(exclude_none=True), } - def resolve_action( # noqa: B027 - self, - ai: GenkitRegistry, - kind: ActionKind, - name: str, - ) -> None: - if kind is not ActionKind.MODEL: + async def resolve(self, action_type: ActionKind, name: str): + if action_type != ActionKind.MODEL: return None - self._define_openai_model(ai, name) - return None - - def to_openai_compatible_model(self, name: str, ai: GenkitRegistry) -> Callable: - """Converts a OpenAi model into an OpenAI-compatible Genkit model function. 
- - Returns: - A callable function (specifically, the `generate` method of an - `OpenAIModel` instance) that can be used by Genkit. - """ - - openai_model = OpenAIModelHandler(OpenAIModel(name, self._openai_client, ai)) - return openai_model.generate + clean_name = name.replace('openai/', '') if name.startswith('openai/') else name + return self._create_model_action(clean_name) - def _define_openai_model(self, ai: GenkitRegistry, name: str) -> None: - """Defines and registers an OpenAI model with Genkit. + def to_openai_compatible_model(self, name: str) -> Callable: + """Return a Genkit model handler for a specific OpenAI model name.""" + return OpenAIModelHandler.get_model_handler(model=name, client=self._openai_client) - Cleans the model name, instantiates an OpenAI, and registers it - with the provided Genkit AI registry, including metadata about its capabilities. - - Args: - ai: The Genkit AI registry instance. - name: The name of the model to be registered. - """ - - handler = self.to_openai_compatible_model(name, ai) + def _create_model_action(self, name: str): + handler = self.to_openai_compatible_model(name) model_info = self.get_model_info(name) - ai.define_model( - name=open_ai_name(name), + return model( + name=name, fn=handler, config_schema=OpenAIConfig, - metadata={ - 'model': model_info, - }, + metadata={'model': model_info}, ) - @cached_property - def list_actions(self) -> list[ActionMetadata]: + async def list_actions(self) -> list[ActionMetadata]: """Generate a list of available actions or models. Returns: @@ -180,7 +135,6 @@ def list_actions(self) -> list[ActionMetadata]: - info (dict): The metadata dictionary describing the model configuration and properties. - config_schema (type): The schema class used for validating the model's configuration. 
""" - actions = [] models_ = self._openai_client.models.list() models: list[Model] = models_.data diff --git a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py index d8ac810dea..0b257bf820 100644 --- a/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py +++ b/py/plugins/compat-oai/src/genkit/plugins/compat_oai/typing.py @@ -24,7 +24,7 @@ else: # noqa from enum import StrEnum # noqa -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field class OpenAIConfig(BaseModel): @@ -38,10 +38,14 @@ class OpenAIConfig(BaseModel): stop: str | list[str] | None = None max_tokens: int | None = None stream: bool | None = None + frequency_penalty: float | None = Field(default=None, ge=-2, le=2) + presence_penalty: float | None = Field(default=None, ge=-2, le=2) + logprobs: bool | None = None + top_logprobs: int | None = Field(default=None, ge=0, le=20) class SupportedOutputFormat(StrEnum): - """Model Output Formats""" + """Model Output Formats.""" JSON_MODE = 'json_mode' STRUCTURED_OUTPUTS = 'structured_outputs' diff --git a/py/plugins/compat-oai/tests/test_handler.py b/py/plugins/compat-oai/tests/test_handler.py index f0321ef3d2..a6abcd8a45 100644 --- a/py/plugins/compat-oai/tests/test_handler.py +++ b/py/plugins/compat-oai/tests/test_handler.py @@ -19,28 +19,30 @@ import pytest -from genkit.ai import ActionRunContext from genkit.plugins.compat_oai.models import OpenAIModelHandler -from genkit.plugins.compat_oai.models.model import OpenAIModel from genkit.plugins.compat_oai.models.model_info import ( GPT_3_5_TURBO, GPT_4, SUPPORTED_OPENAI_MODELS, + PluginSource, ) -from genkit.types import GenerateRequest, GenerateResponse, Message, Role, TextPart def test_get_model_handler() -> None: """Test get_model_handler method returns a callable.""" model_name = GPT_4 - handler = OpenAIModelHandler.get_model_handler(model=model_name, client=MagicMock(), 
registry=MagicMock()) + handler = OpenAIModelHandler.get_model_handler(model=model_name, client=MagicMock()) assert callable(handler) def test_get_model_handler_invalid() -> None: """Test get_model_handler raises ValueError for unsupported models.""" with pytest.raises(ValueError, match="Model 'unsupported-model' is not supported."): - OpenAIModelHandler.get_model_handler(model='unsupported-model', client=MagicMock(), registry=MagicMock()) + OpenAIModelHandler.get_model_handler( + model='unsupported-model', + client=MagicMock(), + source=PluginSource.MODEL_GARDEN, + ) def test_validate_version() -> None: diff --git a/py/plugins/compat-oai/tests/test_model.py b/py/plugins/compat-oai/tests/test_model.py index 0d16d97b12..ec95bb266f 100644 --- a/py/plugins/compat-oai/tests/test_model.py +++ b/py/plugins/compat-oai/tests/test_model.py @@ -36,7 +36,7 @@ def test_get_messages(sample_request): Ensures the method correctly converts GenerateRequest messages into OpenAI-compatible ChatMessage format. """ - model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) + model = OpenAIModel(model=GPT_4, client=MagicMock()) messages = model._get_messages(sample_request.messages) assert len(messages) == 2 @@ -51,7 +51,7 @@ def test_get_openai_config(sample_request): Ensures the method correctly constructs the OpenAI API configuration dictionary. 
""" - model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) + model = OpenAIModel(model=GPT_4, client=MagicMock()) openai_config = model._get_openai_request_config(sample_request) assert isinstance(openai_config, dict) @@ -72,7 +72,7 @@ def test__generate(sample_request): mock_client = MagicMock() mock_client.chat.completions.create.return_value = mock_response - model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) + model = OpenAIModel(model=GPT_4, client=mock_client) response = model._generate(sample_request) mock_client.chat.completions.create.assert_called_once() @@ -112,7 +112,7 @@ def __next__(self): mock_client.chat.completions.create.return_value = MockStream(['Hello', ', world!']) - model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) + model = OpenAIModel(model=GPT_4, client=mock_client) collected_chunks = [] def callback(chunk: GenerateResponseChunk): @@ -137,7 +137,7 @@ def test_generate(stream, sample_request): mock_response = GenerateResponse(message=Message(role=Role.MODEL, content=[TextPart(text='mocked')])) - model = OpenAIModel(model=GPT_4, client=MagicMock(), registry=MagicMock()) + model = OpenAIModel(model=GPT_4, client=MagicMock()) model._generate_stream = MagicMock(return_value=mock_response) model._generate = MagicMock(return_value=mock_response) model.normalize_config = MagicMock(return_value={}) diff --git a/py/plugins/compat-oai/tests/test_plugin.py b/py/plugins/compat-oai/tests/test_plugin.py index c6ba1cf790..d88ad575be 100644 --- a/py/plugins/compat-oai/tests/test_plugin.py +++ b/py/plugins/compat-oai/tests/test_plugin.py @@ -14,69 +14,49 @@ # # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from openai.types import Model -from genkit.ai._aio import Genkit from genkit.core.action import ActionMetadata from genkit.core.action.types import ActionKind -from 
genkit.plugins.compat_oai import OpenAIConfig from genkit.plugins.compat_oai.models.model_info import SUPPORTED_OPENAI_MODELS from genkit.plugins.compat_oai.openai_plugin import OpenAI, openai_model -def test_openai_plugin_initialize() -> None: - """Test OpenAI plugin registry initialization.""" - registry = MagicMock(spec=Genkit) +@pytest.mark.asyncio +async def test_openai_plugin_initialize() -> None: + """Test OpenAI plugin init() returns model actions.""" plugin = OpenAI(api_key='test-key') with patch('genkit.plugins.compat_oai.models.OpenAIModelHandler.get_model_handler') as mock_get_handler: mock_handler = MagicMock() mock_get_handler.return_value = mock_handler - plugin.initialize(registry) + actions = await plugin.init() assert mock_get_handler.call_count == len(SUPPORTED_OPENAI_MODELS) - assert registry.define_model.call_count == len(SUPPORTED_OPENAI_MODELS) + assert len(actions) == len(SUPPORTED_OPENAI_MODELS) @pytest.mark.parametrize( 'kind, name', [(ActionKind.MODEL, 'gpt-3.5-turbo')], ) -def test_openai_plugin_resolve_action(kind, name): - """Unit Tests for resolve action method.""" +@pytest.mark.asyncio +async def test_openai_plugin_resolve_action(kind, name): + """Unit Tests for resolve method.""" plugin = OpenAI(api_key='test-key') - registry = MagicMock(spec=Genkit) - plugin.resolve_action(registry, kind, name) - - model_info = SUPPORTED_OPENAI_MODELS[name] - - registry.define_model.assert_called_once_with( - name=f'openai/{name}', - fn=ANY, - config_schema=OpenAIConfig, - metadata={ - 'model': { - 'label': model_info.label, - 'supports': { - 'media': False, - 'multiturn': True, - 'output': [ - 'json_mode', - 'text', - ], - 'system_role': True, - 'tools': True, - }, - }, - }, - ) - - -def test_openai_plugin_list_actions() -> None: + action = await plugin.resolve(kind, name) + assert action is not None + assert action.kind == ActionKind.MODEL + assert action.name == name + assert action.metadata is not None + + +@pytest.mark.asyncio +async def 
test_openai_plugin_list_actions() -> None: entries = [ Model(id='gpt-4-0613', created=1686588896, object='model', owned_by='openai'), Model(id='gpt-4', created=1687882411, object='model', owned_by='openai'), @@ -94,9 +74,7 @@ def test_openai_plugin_list_actions() -> None: plugin._openai_client = mock_client - actions: list[ActionMetadata] = plugin.list_actions - mock_client.models.list.assert_called_once() - _ = plugin.list_actions + actions: list[ActionMetadata] = await plugin.list_actions() mock_client.models.list.assert_called_once() assert len(actions) == len(entries) @@ -108,14 +86,14 @@ def test_openai_plugin_list_actions() -> None: 'kind, name', [(ActionKind.MODEL, 'model_doesnt_exist')], ) -def test_openai_plugin_resolve_action_not_found(kind, name): - """Unit Tests for resolve action method.""" - +@pytest.mark.asyncio +async def test_openai_plugin_resolve_action_not_found(kind, name): + """Unknown models are still resolvable (compat plugin).""" plugin = OpenAI(api_key='test-key') - registry = MagicMock(spec=Genkit) - plugin.resolve_action(registry, kind, name) - - registry.define_model.assert_called_once() + action = await plugin.resolve(kind, name) + assert action is not None + assert action.kind == ActionKind.MODEL + assert action.name == name def test_openai_model_function() -> None: diff --git a/py/plugins/compat-oai/tests/test_tool_calling.py b/py/plugins/compat-oai/tests/test_tool_calling.py index 6aaf581df9..288f4b3c44 100644 --- a/py/plugins/compat-oai/tests/test_tool_calling.py +++ b/py/plugins/compat-oai/tests/test_tool_calling.py @@ -56,7 +56,7 @@ def test_generate_with_tool_calls_executes_tools(sample_request: GenerateRequest second_response, ] - model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) + model = OpenAIModel(model=GPT_4, client=mock_client) response = model._generate(sample_request) @@ -79,9 +79,7 @@ def test_generate_with_tool_calls_executes_tools(sample_request: GenerateRequest def 
test_generate_stream_with_tool_calls(sample_request): - """ - Test generate_stream processes tool calls streamed in chunks correctly. - """ + """Test generate_stream processes tool calls streamed in chunks correctly.""" mock_client = MagicMock() class MockToolCall: @@ -127,7 +125,7 @@ def __next__(self): mock_client.chat.completions.create.return_value = MockStream() - model = OpenAIModel(model=GPT_4, client=mock_client, registry=MagicMock()) + model = OpenAIModel(model=GPT_4, client=mock_client) collected_chunks = [] def callback(chunk: GenerateResponseChunk): diff --git a/py/plugins/deepseek/pyproject.toml b/py/plugins/deepseek/pyproject.toml new file mode 100644 index 0000000000..d4cbe0ef82 --- /dev/null +++ b/py/plugins/deepseek/pyproject.toml @@ -0,0 +1,47 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +[project] +authors = [{ name = "Google" }] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", +] +dependencies = ["genkit", "genkit-plugin-compat-oai", "openai>=1.0.0"] +description = "Genkit DeepSeek Plugin" +license = "Apache-2.0" +name = "genkit-plugin-deepseek" +requires-python = ">=3.10" +version = "0.1.0" + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src/genkit", "src/genkit/plugins"] diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/__init__.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/__init__.py new file mode 100644 index 0000000000..24021a619e --- /dev/null +++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""DeepSeek plugin for Genkit.""" + +from .models import deepseek_name +from .plugin import DeepSeek + +__all__ = ['DeepSeek', 'deepseek_name'] diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/client.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/client.py new file mode 100644 index 0000000000..56ed84f247 --- /dev/null +++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/client.py @@ -0,0 +1,40 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""DeepSeek API client.""" + +from openai import OpenAI as _OpenAI + + +class DeepSeekClient: + """DeepSeek API client initialization.""" + + def __new__(cls, **deepseek_params) -> _OpenAI: + """Initialize the DeepSeek client. + + Args: + **deepseek_params: Client configuration parameters including: + - api_key: DeepSeek API key. + - base_url: API base URL (defaults to https://api.deepseek.com). + - Additional OpenAI client parameters. + + Returns: + Configured OpenAI client instance. 
+ """ + api_key = deepseek_params.pop('api_key') + base_url = deepseek_params.pop('base_url', 'https://api.deepseek.com') + + return _OpenAI(api_key=api_key, base_url=base_url, **deepseek_params) diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/model_info.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/model_info.py new file mode 100644 index 0000000000..9601f58c61 --- /dev/null +++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/model_info.py @@ -0,0 +1,58 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""DeepSeek model information and metadata.""" + +from genkit.types import ModelInfo, Supports + +__all__ = ['SUPPORTED_DEEPSEEK_MODELS', 'get_default_model_info'] + +# Model capabilities matching JS implementation +_DEEPSEEK_SUPPORTS = Supports( + multiturn=True, + tools=True, + media=False, + system_role=True, + output=['text', 'json'], +) + +SUPPORTED_DEEPSEEK_MODELS: dict[str, ModelInfo] = { + 'deepseek-reasoner': ModelInfo( + label='DeepSeek - Reasoner', + versions=['deepseek-reasoner'], + supports=_DEEPSEEK_SUPPORTS, + ), + 'deepseek-chat': ModelInfo( + label='DeepSeek - Chat', + versions=['deepseek-chat'], + supports=_DEEPSEEK_SUPPORTS, + ), +} + + +def get_default_model_info(name: str) -> ModelInfo: + """Get default model information for unknown DeepSeek models. + + Args: + name: Model name. + + Returns: + Default ModelInfo with standard DeepSeek capabilities. 
+ """ + return ModelInfo( + label=f'DeepSeek - {name}', + supports=_DEEPSEEK_SUPPORTS, + ) diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/models.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/models.py new file mode 100644 index 0000000000..65a5b76cb0 --- /dev/null +++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/models.py @@ -0,0 +1,124 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""DeepSeek model integration for Genkit.""" + +from collections.abc import Callable +from typing import Any + +from genkit.ai import GenkitRegistry +from genkit.plugins.compat_oai.models.model import OpenAIModel +from genkit.plugins.compat_oai.typing import OpenAIConfig +from genkit.plugins.deepseek.client import DeepSeekClient +from genkit.plugins.deepseek.model_info import ( + SUPPORTED_DEEPSEEK_MODELS, + get_default_model_info, +) + +DEEPSEEK_PLUGIN_NAME = 'deepseek' + + +def deepseek_name(name: str) -> str: + """Create a DeepSeek action name. + + Args: + name: Base name for the action. + + Returns: + The fully qualified DeepSeek action name. + """ + return f'{DEEPSEEK_PLUGIN_NAME}/{name}' + + +class DeepSeekModel: + """Manages DeepSeek model integration for Genkit. + + This class provides integration with DeepSeek's OpenAI-compatible API, + allowing DeepSeek models to be exposed as Genkit models. 
It handles + client initialization, model information retrieval, and dynamic model + definition within the Genkit registry. + + Follows the Model Garden pattern for implementation consistency. + """ + + def __init__( + self, + model: str, + api_key: str, + registry: GenkitRegistry, + **deepseek_params, + ) -> None: + """Initialize the DeepSeek instance. + + Args: + model: The name of the specific DeepSeek model (e.g., 'deepseek-chat'). + api_key: DeepSeek API key for authentication. + registry: An instance of GenkitRegistry to register the model. + **deepseek_params: Additional parameters for the DeepSeek client. + """ + self.name = model + self.ai = registry + client_params = {'api_key': api_key, **deepseek_params} + self.client = DeepSeekClient(**client_params) + + def get_model_info(self) -> dict[str, Any] | None: + """Retrieve metadata and supported features for the specified model. + + This method looks up the model's information from a predefined list + of supported DeepSeek models or provides default information. + + Returns: + A dictionary containing the model's 'name' and 'supports' features. + The 'supports' key contains a dictionary representing the model's + capabilities (e.g., tools, streaming). + """ + model_info = SUPPORTED_DEEPSEEK_MODELS.get(self.name, get_default_model_info(self.name)) + return { + 'name': model_info.label, + 'supports': model_info.supports.model_dump(), + } + + def to_deepseek_model(self) -> Callable: + """Convert the DeepSeek model into a Genkit-compatible model function. + + This method wraps the underlying DeepSeek client and its generation + logic into a callable that adheres to the OpenAI model interface + expected by Genkit. + + Returns: + A callable function (the generate method of an OpenAIModel instance) + that can be used by Genkit. 
+ """ + deepseek_model = OpenAIModel(self.name, self.client, self.ai) + return deepseek_model.generate + + def define_model(self) -> None: + """Define and register the DeepSeek model with the Genkit registry. + + This method orchestrates the retrieval of model metadata and the + creation of the generation function, then registers this model + within the Genkit framework using self.ai.define_model. + """ + model_info = self.get_model_info() + generate_fn = self.to_deepseek_model() + self.ai.define_model( + name=deepseek_name(self.name), + fn=generate_fn, + config_schema=OpenAIConfig, + metadata={ + 'model': model_info, + }, + ) diff --git a/py/plugins/deepseek/src/genkit/plugins/deepseek/plugin.py b/py/plugins/deepseek/src/genkit/plugins/deepseek/plugin.py new file mode 100644 index 0000000000..2943838c87 --- /dev/null +++ b/py/plugins/deepseek/src/genkit/plugins/deepseek/plugin.py @@ -0,0 +1,140 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""DeepSeek Plugin for Genkit.""" + +import os +from functools import cached_property + +from genkit.ai import GenkitRegistry, Plugin +from genkit.blocks.model import model_action_metadata +from genkit.core.action import ActionMetadata +from genkit.core.action.types import ActionKind +from genkit.core.error import GenkitError +from genkit.plugins.compat_oai.typing import OpenAIConfig +from genkit.plugins.deepseek.model_info import SUPPORTED_DEEPSEEK_MODELS +from genkit.plugins.deepseek.models import DEEPSEEK_PLUGIN_NAME, DeepSeekModel, deepseek_name + + +class DeepSeek(Plugin): + """DeepSeek plugin for Genkit. + + This plugin provides integration with DeepSeek's OpenAI-compatible API, + enabling the use of DeepSeek models within the Genkit framework. + """ + + name = DEEPSEEK_PLUGIN_NAME + + def __init__( + self, + api_key: str | None = None, + models: list[str] | None = None, + **deepseek_params, + ) -> None: + """Initialize the plugin and set up its configuration. + + Args: + api_key: The DeepSeek API key. If not provided, it attempts to load + from the DEEPSEEK_API_KEY environment variable. + models: An optional list of model names to register with the plugin. + If None, all supported models will be registered. + **deepseek_params: Additional parameters for the DeepSeek client. + + Raises: + GenkitError: If no API key is provided via parameter or environment. + """ + self.api_key = api_key if api_key is not None else os.getenv('DEEPSEEK_API_KEY') + + if not self.api_key: + raise GenkitError(message='Please provide api_key or set DEEPSEEK_API_KEY environment variable.') + + self.models = models + self.deepseek_params = deepseek_params + + def initialize(self, ai: GenkitRegistry) -> None: + """Initialize the plugin by registering specified models. + + Args: + ai: The Genkit registry where models will be registered. 
+ """ + models = self.models + if models is None: + models = list(SUPPORTED_DEEPSEEK_MODELS.keys()) + + for model in models: + deepseek_model = DeepSeekModel( + model=model, + api_key=self.api_key, + registry=ai, + **self.deepseek_params, + ) + deepseek_model.define_model() + + def resolve_action( + self, + ai: GenkitRegistry, + kind: ActionKind, + name: str, + ) -> None: + """Resolve and register an action dynamically. + + Args: + ai: The Genkit registry. + kind: The kind of action to resolve. + name: The name of the action to resolve. + """ + if kind == ActionKind.MODEL: + self._resolve_model(ai=ai, name=name) + + def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: + """Resolve and define a DeepSeek model within the Genkit registry. + + This internal method handles the logic for registering DeepSeek models + dynamically based on the provided name. It extracts a clean name, + instantiates the DeepSeek class, and registers it with the registry. + + Args: + ai: The Genkit AI registry instance to define the model in. + name: The name of the model to resolve. This name might include a + prefix indicating it's from the DeepSeek plugin. + """ + clean_name = name.replace(DEEPSEEK_PLUGIN_NAME + '/', '') if name.startswith(DEEPSEEK_PLUGIN_NAME) else name + + deepseek_model = DeepSeekModel( + model=clean_name, + api_key=self.api_key, + registry=ai, + **self.deepseek_params, + ) + deepseek_model.define_model() + + @cached_property + def list_actions(self) -> list[ActionMetadata]: + """Generate a list of available DeepSeek models. + + Returns: + list[ActionMetadata]: A list of ActionMetadata objects for each + supported DeepSeek model, including name, metadata, and config schema. 
+ """ + actions_list = [] + for model, model_info in SUPPORTED_DEEPSEEK_MODELS.items(): + actions_list.append( + model_action_metadata( + name=deepseek_name(model), info=model_info.model_dump(), config_schema=OpenAIConfig + ) + ) + + return actions_list diff --git a/py/plugins/deepseek/tests/test_deepseek_plugin.py b/py/plugins/deepseek/tests/test_deepseek_plugin.py new file mode 100644 index 0000000000..150d1d23e6 --- /dev/null +++ b/py/plugins/deepseek/tests/test_deepseek_plugin.py @@ -0,0 +1,185 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for DeepSeek plugin.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from genkit.core.error import GenkitError +from genkit.core.registry import ActionKind +from genkit.plugins.deepseek import DeepSeek, deepseek_name + + +def test_deepseek_name(): + """Test name helper function.""" + assert deepseek_name('deepseek-chat') == 'deepseek/deepseek-chat' + assert deepseek_name('deepseek-reasoner') == 'deepseek/deepseek-reasoner' + + +def test_plugin_initialization_with_api_key(): + """Test plugin initializes with API key.""" + plugin = DeepSeek(api_key='test-key') + assert plugin.name == 'deepseek' + assert plugin.api_key == 'test-key' + + +def test_plugin_initialization_from_env(): + """Test plugin reads API key from environment.""" + with patch.dict(os.environ, {'DEEPSEEK_API_KEY': 'env-key'}): + plugin = DeepSeek() + assert plugin.api_key == 'env-key' + + +def test_plugin_initialization_without_api_key(): + """Test plugin raises error without API key.""" + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(GenkitError) as exc_info: + DeepSeek() + assert 'DEEPSEEK_API_KEY' in str(exc_info.value) + + +@patch('genkit.plugins.deepseek.models.DeepSeekClient') +def test_plugin_initialize(mock_client): + """Test plugin registers models during initialization.""" + plugin = DeepSeek(api_key='test-key', models=['deepseek-chat']) + mock_registry = MagicMock() + + plugin.initialize(mock_registry) + + # Should call define_model for the specified model + mock_registry.define_model.assert_called_once() + + +@patch('genkit.plugins.deepseek.models.DeepSeekClient') +def test_plugin_resolve_action(mock_client): + """Test plugin resolves models dynamically.""" + plugin = DeepSeek(api_key='test-key', models=[]) + mock_registry = MagicMock() + + plugin.resolve_action(mock_registry, ActionKind.MODEL, 'deepseek/deepseek-chat') + + # Should register the requested model + 
mock_registry.define_model.assert_called_once() + + +def test_plugin_list_actions(): + """Test plugin lists available models.""" + plugin = DeepSeek(api_key='test-key') + actions = plugin.list_actions + + assert len(actions) == 2 + action_names = [action.name for action in actions] + assert 'deepseek/deepseek-reasoner' in action_names + assert 'deepseek/deepseek-chat' in action_names + + +@patch('genkit.plugins.deepseek.models.DeepSeekClient') +def test_plugin_with_custom_params(mock_client): + """Test plugin accepts custom parameters.""" + plugin = DeepSeek( + api_key='test-key', + models=['deepseek-chat'], + timeout=60, + max_retries=3, + ) + + assert plugin.deepseek_params['timeout'] == 60 + assert plugin.deepseek_params['max_retries'] == 3 + + +@patch('genkit.plugins.deepseek.models.DeepSeekClient') +def test_plugin_initialize_no_models(mock_client): + """Test plugin registers all supported models when models is None.""" + from genkit.plugins.deepseek.model_info import SUPPORTED_DEEPSEEK_MODELS + + plugin = DeepSeek(api_key='test-key') + mock_registry = MagicMock() + + # When models is None, all supported models should be registered + plugin.initialize(mock_registry) + + assert mock_registry.define_model.call_count == len(SUPPORTED_DEEPSEEK_MODELS) + + +def test_plugin_resolve_action_non_model_kind(): + """Test resolve_action does nothing for non-MODEL kinds.""" + plugin = DeepSeek(api_key='test-key') + mock_registry = MagicMock() + + # Using PROMPT kind to test the case where kind != MODEL + plugin.resolve_action(mock_registry, ActionKind.PROMPT, 'some-prompt') + + # Should not attempt to register anything + mock_registry.define_model.assert_not_called() + + +@patch('genkit.plugins.deepseek.models.DeepSeekClient') +def test_plugin_resolve_action_without_prefix(mock_client): + """Test plugin resolves models without plugin prefix.""" + plugin = DeepSeek(api_key='test-key', models=[]) + mock_registry = MagicMock() + + # Pass name without 'deepseek/' prefix + 
plugin.resolve_action(mock_registry, ActionKind.MODEL, 'deepseek-chat') + + mock_registry.define_model.assert_called_once() + + +@patch('genkit.plugins.deepseek.client.DeepSeekClient.__new__') +def test_deepseek_client_initialization(mock_new): + """Test DeepSeekClient creates OpenAI client with correct params.""" + from genkit.plugins.deepseek.client import DeepSeekClient + + # Set up mock to return a fake client + mock_client_instance = MagicMock() + mock_new.return_value = mock_client_instance + + # Create a DeepSeekClient + result = DeepSeekClient(api_key='test-key', timeout=30) + + # Verify __new__ was called with correct parameters + mock_new.assert_called_once() + + +def test_deepseek_client_with_custom_base_url(): + """Test DeepSeekClient accepts custom base_url.""" + from openai import OpenAI + + from genkit.plugins.deepseek.client import DeepSeekClient + + with patch.object(OpenAI, '__init__', return_value=None) as mock_init: + DeepSeekClient(api_key='test-key', base_url='https://custom.api.deepseek.com') + mock_init.assert_called_once_with( + api_key='test-key', + base_url='https://custom.api.deepseek.com', + ) + + +def test_deepseek_client_default_base_url(): + """Test DeepSeekClient uses default base_url when not provided.""" + from openai import OpenAI + + from genkit.plugins.deepseek.client import DeepSeekClient + + with patch.object(OpenAI, '__init__', return_value=None) as mock_init: + DeepSeekClient(api_key='test-key') + mock_init.assert_called_once_with( + api_key='test-key', + base_url='https://api.deepseek.com', + ) diff --git a/py/plugins/deepseek/tests/test_model_info.py b/py/plugins/deepseek/tests/test_model_info.py new file mode 100644 index 0000000000..dd61b137ba --- /dev/null +++ b/py/plugins/deepseek/tests/test_model_info.py @@ -0,0 +1,55 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for DeepSeek model information.""" + +import pytest + +from genkit.plugins.deepseek.model_info import SUPPORTED_DEEPSEEK_MODELS, get_default_model_info + + +def test_supported_models_exist(): + """Test that supported models are defined.""" + assert 'deepseek-reasoner' in SUPPORTED_DEEPSEEK_MODELS + assert 'deepseek-chat' in SUPPORTED_DEEPSEEK_MODELS + + +def test_model_order(): + """Test models are in correct order (matching JS).""" + keys = list(SUPPORTED_DEEPSEEK_MODELS.keys()) + assert keys[0] == 'deepseek-reasoner' + assert keys[1] == 'deepseek-chat' + + +def test_model_info_structure(): + """Test model info has required fields.""" + for model_name, model_info in SUPPORTED_DEEPSEEK_MODELS.items(): + assert model_info.label + assert model_info.supports + assert model_info.supports.multiturn is True + assert model_info.supports.tools is True + assert model_info.supports.media is False + assert model_info.supports.system_role is True + assert 'text' in model_info.supports.output + assert 'json' in model_info.supports.output + + +def test_get_default_model_info(): + """Test getting default info for unknown models.""" + info = get_default_model_info('deepseek-future-model') + assert 'deepseek-future-model' in info.label + assert info.supports.multiturn is True + assert info.supports.tools is True diff --git a/py/plugins/dev-local-vectorstore/pyproject.toml b/py/plugins/dev-local-vectorstore/pyproject.toml index 7302053da9..fe965b2399 100644 --- a/py/plugins/dev-local-vectorstore/pyproject.toml +++ 
b/py/plugins/dev-local-vectorstore/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "strenum>=0.4.15; python_version < '3.11'", ] description = "Genkit Local Vector Store Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-dev-local-vectorstore" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/indexer.py b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/indexer.py index 18c36482c1..8f83790957 100644 --- a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/indexer.py +++ b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/indexer.py @@ -22,7 +22,7 @@ from genkit.blocks.document import Document from genkit.blocks.retriever import IndexerRequest from genkit.codec import dump_json -from genkit.types import DocumentData, Embedding +from genkit.types import Embedding from .constant import DbValue from .local_vector_store_api import ( diff --git a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py index 21cea58314..22df23c638 100644 --- a/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py +++ b/py/plugins/dev-local-vectorstore/src/genkit/plugins/dev_local_vectorstore/plugin_api.py @@ -18,9 +18,16 @@ from typing import Any -from genkit.ai import GenkitRegistry, Plugin -from genkit.core.action import Action -from genkit.types import Docs +from genkit.ai import Plugin +from genkit.blocks.retriever import ( 
+ IndexerOptions, + RetrieverOptions, + indexer_action_metadata, + retriever_action_metadata, +) +from genkit.core.action import Action, ActionMetadata +from genkit.core.action.types import ActionKind +from genkit.core.schema import to_json_schema from .indexer import ( DevLocalVectorStoreIndexer, @@ -44,62 +51,83 @@ def __init__(self, name: str, embedder: str, embedder_options: dict[str, Any] | self.embedder = embedder self.embedder_options = embedder_options - def initialize(self, ai: GenkitRegistry) -> None: - """Initialize the plugin by registering actions with the registry. - - This method registers the Local Vector Store actions with the provided - registry, making them available for use in the Genkit framework. - - Args: - ai: The registry to register actions with. - - Returns: - None - """ - self._configure_dev_local_retriever(ai=ai) - self._configure_dev_local_indexer(ai=ai) - - def _configure_dev_local_retriever(self, ai: GenkitRegistry) -> Action: - """Registers Local Vector Store retriever for provided parameters. - - Args: - ai: The registry to register retriever with. - params: Parameters to register retriever with. - - Returns: - registered Action instance - """ - retriever = DevLocalVectorStoreRetriever( - ai=ai, - index_name=self.index_name, - embedder=self.embedder, - embedder_options=self.embedder_options, - ) - - return ai.define_retriever( - name=self.index_name, - config_schema=RetrieverOptionsSchema, - fn=retriever.retrieve, - ) - - def _configure_dev_local_indexer(self, ai: GenkitRegistry) -> Action: - """Registers Local Vector Store indexer for provided parameters. - - Args: - ai: The registry to register indexer with. - params: Parameters to register indexer with. 
- - Returns: - registered Action instance - """ - indexer = DevLocalVectorStoreIndexer( - ai=ai, - index_name=self.index_name, - embedder=self.embedder, - embedder_options=self.embedder_options, - ) - - return ai.define_indexer( - name=self.index_name, - fn=indexer.index, - ) + async def init(self) -> list[Action]: + return [ + self._create_retriever_action(), + self._create_indexer_action(), + ] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + if name != self.index_name: + return None + if action_type == ActionKind.RETRIEVER: + return self._create_retriever_action() + if action_type == ActionKind.INDEXER: + return self._create_indexer_action() + return None + + async def list_actions(self) -> list[ActionMetadata]: + return [ + retriever_action_metadata( + name=self.index_name, + options=RetrieverOptions( + label=self.index_name, + config_schema=to_json_schema(RetrieverOptionsSchema), + ), + ), + indexer_action_metadata( + name=self.index_name, + options=IndexerOptions( + label=self.index_name, + ), + ), + ] + + def _create_retriever_action(self) -> Action: + metadata: dict[str, Any] = { + 'retriever': { + 'label': self.index_name, + 'customOptions': to_json_schema(RetrieverOptionsSchema), + } + } + + async def retrieve(request, ctx): + ai = (ctx.context or {}).get('__genkit_ai__') + if ai is None: + raise ValueError( + 'DevLocalVectorStore retriever requires a Genkit instance in action context. ' + 'Use it via `await ai.retrieve(...)`.' 
+ ) + retriever = DevLocalVectorStoreRetriever( + ai=ai, + index_name=self.index_name, + embedder=self.embedder, + embedder_options=self.embedder_options, + ) + return await retriever.retrieve(request, ctx) + + return Action(kind=ActionKind.RETRIEVER, name=self.index_name, fn=retrieve, metadata=metadata) + + def _create_indexer_action(self) -> Action: + metadata: dict[str, Any] = { + 'indexer': { + 'label': self.index_name, + } + } + + async def index(request, ctx): + ai = (ctx.context or {}).get('__genkit_ai__') + if ai is None: + raise ValueError( + 'DevLocalVectorStore indexer requires a Genkit instance in action context. ' + 'Use it via `await ai.index(...)`.' + ) + indexer = DevLocalVectorStoreIndexer( + ai=ai, + index_name=self.index_name, + embedder=self.embedder, + embedder_options=self.embedder_options, + ) + return await indexer.index(request) + + return Action(kind=ActionKind.INDEXER, name=self.index_name, fn=index, metadata=metadata) diff --git a/py/plugins/dev-local-vectorstore/tests/test_dev_local_vectorstore_plugin_v2.py b/py/plugins/dev-local-vectorstore/tests/test_dev_local_vectorstore_plugin_v2.py new file mode 100644 index 0000000000..609161cda8 --- /dev/null +++ b/py/plugins/dev-local-vectorstore/tests/test_dev_local_vectorstore_plugin_v2.py @@ -0,0 +1,48 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from genkit.core.action.types import ActionKind +from genkit.plugins.dev_local_vectorstore import DevLocalVectorStore + + +@pytest.mark.asyncio +async def test_init_returns_retriever_and_indexer_actions(): + plugin = DevLocalVectorStore( + name='films', + embedder='vertexai/text-embedding-004', + ) + + actions = await plugin.init() + + assert {a.kind for a in actions} == {ActionKind.RETRIEVER, ActionKind.INDEXER} + assert {a.name for a in actions} == {'films'} + + +@pytest.mark.asyncio +async def test_list_returns_action_metadata(): + plugin = DevLocalVectorStore( + name='films', + embedder='vertexai/text-embedding-004', + ) + + metas = await plugin.list_actions() + + assert {m.kind for m in metas} == {ActionKind.RETRIEVER, ActionKind.INDEXER} + assert {m.name for m in metas} == {'films'} diff --git a/py/plugins/evaluators/pyproject.toml b/py/plugins/evaluators/pyproject.toml index 257fad1fb8..c7b9385d19 100644 --- a/py/plugins/evaluators/pyproject.toml +++ b/py/plugins/evaluators/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", @@ -38,7 +37,7 @@ dependencies = [ "strenum>=0.4.15; python_version < '3.11'", ] description = "Genkit Evaluators Plugin for RAGAS" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-evaluators" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/evaluators/src/genkit/plugins/evaluators/plugin_api.py b/py/plugins/evaluators/src/genkit/plugins/evaluators/plugin_api.py index e8da849b9a..a9b07c37a5 100644 --- a/py/plugins/evaluators/src/genkit/plugins/evaluators/plugin_api.py +++ 
b/py/plugins/evaluators/src/genkit/plugins/evaluators/plugin_api.py @@ -22,9 +22,12 @@ from typing import Any import jsonata -from dotpromptz.typing import DataArgument -from genkit.ai import Genkit, Plugin +from genkit.ai import Plugin +from genkit.core.action import Action, ActionMetadata +from genkit.core.action.types import ActionKind +from genkit.core.schema import to_json_schema +from genkit.core.typing import EvalRequest, EvalResponse from genkit.plugins.evaluators.constant import ( AnswerRelevancyResponseSchema, GenkitMetricType, @@ -80,25 +83,108 @@ def __init__(self, params: PluginOptions | list[MetricConfig]): params = PluginOptions(root=params) self.params = params - def initialize(self, ai: Genkit) -> None: - """Initialize the plugin by registering actions with the registry.""" + async def init(self) -> list[Action]: + return [self._create_evaluator_action(param) for param in self.params.root] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + if action_type != ActionKind.EVALUATOR: + return None + for param in self.params.root: + metric_name, _, _ = self._metric_descriptor(param) + if name == metric_name: + return self._create_evaluator_action(param) + return None + + async def list_actions(self) -> list[ActionMetadata]: + metas: list[ActionMetadata] = [] for param in self.params.root: - self._configure_evaluator(ai=ai, param=param) + metric_name, display_name, definition = self._metric_descriptor(param) + metas.append( + ActionMetadata( + kind=ActionKind.EVALUATOR, + name=metric_name, + input_json_schema=to_json_schema(EvalRequest), + output_json_schema=to_json_schema(EvalResponse), + metadata={ + 'evaluator': { + 'label': metric_name, + 'displayName': display_name, + 'definition': definition, + 'isBilled': bool(param.judge), + } + }, + ) + ) + return metas - def _configure_evaluator(self, ai: Genkit, param: MetricConfig): - """Validates and configures supported evaluators.""" + def _metric_descriptor(self, param: 
MetricConfig) -> tuple[str, str, str]: metric_type = param.metric_type match metric_type: case GenkitMetricType.ANSWER_RELEVANCY: + return ( + str(metric_type).lower(), + 'Answer Relevancy', + 'Assesses how pertinent the generated answer is to the given prompt', + ) + case GenkitMetricType.FAITHFULNESS: + return ( + str(metric_type).lower(), + 'Faithfulness', + 'Measures the factual consistency of the generated answer against the given context', + ) + case GenkitMetricType.MALICIOUSNESS: + return ( + str(metric_type).lower(), + 'Maliciousness', + 'Measures whether the generated output intends to deceive, harm, or exploit', + ) + case GenkitMetricType.REGEX: + return ( + str(metric_type).lower(), + 'RegExp', + 'Tests output against the regexp provided as reference', + ) + case GenkitMetricType.DEEP_EQUAL: + return ( + str(metric_type).lower(), + 'Deep Equals', + 'Tests equality of output against the provided reference', + ) + case GenkitMetricType.JSONATA: + return ( + str(metric_type).lower(), + 'JSONata', + 'Tests JSONata expression (provided in reference) against output', + ) + case _: + raise ValueError(f'Unsupported metric type: {metric_type}') + + def _create_evaluator_action(self, param: MetricConfig) -> Action: + metric_name, display_name, definition = self._metric_descriptor(param) + metadata = { + 'evaluator': { + 'label': metric_name, + 'displayName': display_name, + 'definition': definition, + 'isBilled': bool(param.judge), + } + } + + metric_type = param.metric_type + + # Cache for prompts (loaded on first use) - scoped per-action to avoid cross-test coupling. 
+ _faithfulness_prompts: dict[str, Any] = {} - async def _relevancy_eval(datapoint: BaseEvalDataPoint, options: Any | None): + async def eval_one(datapoint: BaseEvalDataPoint, options: Any | None, ai) -> EvalFnResponse: + match metric_type: + case GenkitMetricType.ANSWER_RELEVANCY: assert datapoint.output is not None, 'output is required' output_string = ( datapoint.output if isinstance(datapoint.output, str) else json.dumps(datapoint.output) ) input_string = datapoint.input if isinstance(datapoint.input, str) else json.dumps(datapoint.input) prompt_function = await load_prompt_file(_get_prompt_path('faithfulness_long_form.prompt')) - context = ' '.join(json.dumps(e) for e in datapoint.context) + context = ' '.join(json.dumps(e) for e in (datapoint.context or [])) prompt = await render_text( prompt_function, {'input': input_string, 'output': output_string, 'context': context} ) @@ -106,24 +192,17 @@ async def _relevancy_eval(datapoint: BaseEvalDataPoint, options: Any | None): response = await ai.generate( model=param.judge.name, prompt=prompt, - config=param.config, + config=param.judge_config, output_schema=AnswerRelevancyResponseSchema, ) - # TODO: embedding comparison between the input and the result of the llm - status = EvalStatusEnum.PASS_ if response.output else EvalStatusEnum.FAIL - return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) - ai.define_evaluator( - name=evaluators_name(str(GenkitMetricType.ANSWER_RELEVANCY).lower()), - display_name='Answer Relevancy', - definition='Assesses how pertinent the generated answer is to the given prompt', - fn=_relevancy_eval, - ) - case GenkitMetricType.FAITHFULNESS: - # Cache for prompts (loaded on first use) - _faithfulness_prompts = {} + out = response.output + answered = out.get('answered') if isinstance(out, dict) else (out.answered if out else False) + score = bool(answered) + status = EvalStatusEnum.PASS_ if score else EvalStatusEnum.FAIL + return fill_scores(datapoint, 
Score(score=score, status=status), param.status_override_fn) - async def _faithfulness_eval(datapoint: BaseEvalDataPoint, options: Any | None): + case GenkitMetricType.FAITHFULNESS: assert datapoint.output is not None, 'output is required' output_string = ( datapoint.output if isinstance(datapoint.output, str) else json.dumps(datapoint.output) @@ -131,7 +210,6 @@ async def _faithfulness_eval(datapoint: BaseEvalDataPoint, options: Any | None): input_string = datapoint.input if isinstance(datapoint.input, str) else json.dumps(datapoint.input) context_list = [(json.dumps(e) if not isinstance(e, str) else e) for e in (datapoint.context or [])] - # Lazy load and cache prompts if 'longform' not in _faithfulness_prompts: _faithfulness_prompts['longform'] = await load_prompt_file( _get_prompt_path('faithfulness_long_form.prompt') @@ -141,7 +219,6 @@ async def _faithfulness_eval(datapoint: BaseEvalDataPoint, options: Any | None): _get_prompt_path('faithfulness_nli.prompt') ) - # Step 1: Extract statements prompt = await render_text( _faithfulness_prompts['longform'], {'question': input_string, 'answer': output_string} ) @@ -159,7 +236,6 @@ async def _faithfulness_eval(datapoint: BaseEvalDataPoint, options: Any | None): if not statements: raise ValueError('No statements returned') - # Step 2: NLI Check all_statements = '\n'.join([f'statement: {s}' for s in statements]) all_context = '\n'.join(context_list) prompt = await render_text( @@ -174,68 +250,51 @@ async def _faithfulness_eval(datapoint: BaseEvalDataPoint, options: Any | None): ) nli_output = nli_response.output - if isinstance(nli_output, dict): - responses = nli_output.get('responses', []) - else: - responses = nli_output.responses if nli_output else [] - + responses = ( + nli_output.get('responses', []) + if isinstance(nli_output, dict) + else (nli_output.responses if nli_output else []) + ) if not responses: raise ValueError('Evaluator response empty') - # Handle both dict and object responses faithful_count = 
sum( 1 for r in responses if (r.get('verdict') if isinstance(r, dict) else r.verdict) ) score_val = faithful_count / len(responses) reasoning = '; '.join([r.get('reason', '') if isinstance(r, dict) else r.reason for r in responses]) status = EvalStatusEnum.PASS_ if score_val > 0.5 else EvalStatusEnum.FAIL - return fill_scores( datapoint, Score(score=score_val, status=status, details={'reasoning': reasoning}), param.status_override_fn, ) - ai.define_evaluator( - name=evaluators_name(str(GenkitMetricType.FAITHFULNESS).lower()), - display_name='Faithfulness', - definition='Measures the factual consistency of the generated answer against the given context', - fn=_faithfulness_eval, - ) - - case GenkitMetricType.MALICIOUSNESS: - - async def _maliciousness_eval(datapoint: BaseEvalDataPoint, options: Any | None): + case GenkitMetricType.MALICIOUSNESS: assert datapoint.output is not None, 'output is required' output_string = ( datapoint.output if isinstance(datapoint.output, str) else json.dumps(datapoint.output) ) input_string = datapoint.input if isinstance(datapoint.input, str) else json.dumps(datapoint.input) prompt_function = await load_prompt_file(_get_prompt_path('maliciousness.prompt')) - context = ' '.join(json.dumps(e) for e in datapoint.context) + context = ' '.join(json.dumps(e) for e in (datapoint.context or [])) prompt = await render_text( prompt_function, {'input': input_string, 'output': output_string, 'context': context} ) - score = await ai.generate( + response = await ai.generate( model=param.judge.name, prompt=prompt, - config=param.config, + config=param.judge_config, output_schema=MaliciousnessResponseSchema, ) + out = response.output + verdict = out.get('verdict') if isinstance(out, dict) else (out.verdict if out else False) + score = bool(verdict) status = EvalStatusEnum.PASS_ if score else EvalStatusEnum.FAIL return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) - ai.define_evaluator( - 
name=evaluators_name(str(GenkitMetricType.MALICIOUSNESS).lower()), - display_name='Maliciousness', - definition='Measures whether the generated output intends to deceive, harm, or exploit', - fn=_maliciousness_eval, - ) - # - case GenkitMetricType.REGEX: - - async def _regex_eval(datapoint: BaseEvalDataPoint, options: Any | None): + case GenkitMetricType.REGEX: assert datapoint.output is not None, 'output is required' assert datapoint.reference is not None, 'reference is required' assert isinstance(datapoint.reference, str), 'reference must be of string (regex)' @@ -243,39 +302,20 @@ async def _regex_eval(datapoint: BaseEvalDataPoint, options: Any | None): datapoint.output if isinstance(datapoint.output, str) else json.dumps(datapoint.output) ) pattern = re.compile(datapoint.reference) - score = False if pattern.search(output_string) is None else True + score = pattern.search(output_string) is not None status = EvalStatusEnum.PASS_ if score else EvalStatusEnum.FAIL return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) - ai.define_evaluator( - name=evaluators_name(str(GenkitMetricType.REGEX).lower()), - display_name='RegExp', - definition='Tests output against the regexp provided as reference', - fn=_regex_eval, - ) - - case GenkitMetricType.DEEP_EQUAL: - - async def _deep_equal_eval(datapoint: BaseEvalDataPoint, options: Any | None): + case GenkitMetricType.DEEP_EQUAL: assert datapoint.reference is not None, 'reference is required' assert datapoint.output is not None, 'output is required' - score = False - if type(datapoint.output) is type(datapoint.reference): - if datapoint.output == datapoint.reference: - score = True + score = ( + type(datapoint.output) is type(datapoint.reference) and datapoint.output == datapoint.reference + ) status = EvalStatusEnum.PASS_ if score else EvalStatusEnum.FAIL return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) - ai.define_evaluator( - 
name=evaluators_name(str(GenkitMetricType.DEEP_EQUAL).lower()), - display_name='Deep Equals', - definition="""Tests equality of output against the provided reference""", - fn=_deep_equal_eval, - ) - - case GenkitMetricType.JSONATA: - - async def _jsonata_eval(datapoint: BaseEvalDataPoint, options: Any | None): + case GenkitMetricType.JSONATA: assert datapoint.output is not None, 'output is required' assert datapoint.reference is not None, 'reference is required' assert isinstance(datapoint.reference, str), 'reference must be of string (jsonata)' @@ -284,9 +324,33 @@ async def _jsonata_eval(datapoint: BaseEvalDataPoint, options: Any | None): status = EvalStatusEnum.PASS_ if bool(score) else EvalStatusEnum.FAIL return fill_scores(datapoint, Score(score=score, status=status), param.status_override_fn) - ai.define_evaluator( - name=evaluators_name(str(GenkitMetricType.JSONATA).lower()), - display_name='JSONata', - definition="""Tests JSONata expression (provided in reference) against output""", - fn=_jsonata_eval, + case _: + raise ValueError(f'Unsupported metric type: {metric_type}') + + async def eval_stepper(req: EvalRequest, ctx): + ai = (ctx.context or {}).get('__genkit_ai__') + if ai is None: + raise ValueError( + 'GenkitEvaluators requires a Genkit instance in action context. Use `await ai.evaluate(...)`.' ) + + responses: list[EvalFnResponse] = [] + for datapoint in req.dataset: + if datapoint.test_case_id is None: + # Keep behavior consistent with core evaluator runner. 
+ datapoint.test_case_id = 'unknown' + try: + responses.append(await eval_one(datapoint, req.options, ai)) + except Exception as e: + responses.append( + EvalFnResponse( + test_case_id=datapoint.test_case_id, + evaluation=Score( + error=f'Evaluation of test case {datapoint.test_case_id} failed: \n{str(e)}', + status=EvalStatusEnum.FAIL, + ), + ) + ) + return EvalResponse(root=responses) + + return Action(kind=ActionKind.EVALUATOR, name=metric_name, fn=eval_stepper, metadata=metadata) diff --git a/py/plugins/evaluators/src/genkit/plugins/metrics/helper.py b/py/plugins/evaluators/src/genkit/plugins/metrics/helper.py index 43c60bf3d6..7c9efb7f22 100644 --- a/py/plugins/evaluators/src/genkit/plugins/metrics/helper.py +++ b/py/plugins/evaluators/src/genkit/plugins/metrics/helper.py @@ -18,14 +18,14 @@ from typing import Any -from dotpromptz import Dotprompt -from dotpromptz.typing import DataArgument, PromptFunction +from dotprompt import Dotprompt +from dotprompt.typing import DataArgument, PromptFunction dp = Dotprompt() async def load_prompt_file(path: str) -> PromptFunction: - with open(path, 'r') as f: + with open(path) as f: result = await dp.compile(f.read()) return result diff --git a/py/plugins/evaluators/tests/test_evaluators_plugin_v2.py b/py/plugins/evaluators/tests/test_evaluators_plugin_v2.py new file mode 100644 index 0000000000..7ab7f0f9e7 --- /dev/null +++ b/py/plugins/evaluators/tests/test_evaluators_plugin_v2.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from genkit.core.action.types import ActionKind +from genkit.plugins.evaluators.constant import GenkitMetricType, MetricConfig +from genkit.plugins.evaluators.plugin_api import GenkitEvaluators + + +@pytest.mark.asyncio +async def test_init_returns_evaluator_actions(): + plugin = GenkitEvaluators( + params=[ + MetricConfig(metric_type=GenkitMetricType.REGEX), + MetricConfig(metric_type=GenkitMetricType.DEEP_EQUAL), + ] + ) + + actions = await plugin.init() + + assert {a.kind for a in actions} == {ActionKind.EVALUATOR} + assert {a.name for a in actions} == {str(GenkitMetricType.REGEX).lower(), str(GenkitMetricType.DEEP_EQUAL).lower()} + + +@pytest.mark.asyncio +async def test_list_returns_action_metadata(): + plugin = GenkitEvaluators( + params=[ + MetricConfig(metric_type=GenkitMetricType.REGEX), + ] + ) + + metas = await plugin.list_actions() + + assert len(metas) == 1 + assert metas[0].kind == ActionKind.EVALUATOR + assert metas[0].name == str(GenkitMetricType.REGEX).lower() diff --git a/py/plugins/firebase/pyproject.toml b/py/plugins/firebase/pyproject.toml index 05bb7c57fa..fd747dbb31 100644 --- a/py/plugins/firebase/pyproject.toml +++ b/py/plugins/firebase/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -39,7 +38,7 @@ dependencies = [ "strenum>=0.4.15; python_version < '3.11'", ] description = "Genkit Firebase Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-firebase" readme = "README.md" requires-python = ">=3.10" diff --git 
a/py/plugins/firebase/src/genkit/plugins/firebase/firestore.py b/py/plugins/firebase/src/genkit/plugins/firebase/firestore.py index 8552fe4a66..074348d02b 100644 --- a/py/plugins/firebase/src/genkit/plugins/firebase/firestore.py +++ b/py/plugins/firebase/src/genkit/plugins/firebase/firestore.py @@ -20,7 +20,10 @@ from google.cloud.firestore_v1 import DocumentSnapshot from google.cloud.firestore_v1.base_vector_query import DistanceMeasure -from genkit.ai import GenkitRegistry, Plugin +from genkit.ai import Plugin +from genkit.blocks.retriever import RetrieverOptions, retriever_action_metadata +from genkit.core.action import Action, ActionMetadata +from genkit.core.action.types import ActionKind from genkit.plugins.firebase.retriever import FirestoreRetriever from .constant import MetadataTransformFn @@ -40,19 +43,9 @@ def firestore_action_name(name: str) -> str: class FirestoreVectorStore(Plugin): - """Firestore retriever plugin. + """Firestore retriever plugin (PluginV2).""" - Args: - name: name if the retriever. - collection: The name of the Firestore collection to query. - vector_field: The name of the field containing the vector embeddings. - content_field: The name of the field containing the document content, you wish to return. - embedder: The embedder to use with this retriever. - embedder_options: Optional configuration to pass to the embedder. - distance_measure: The distance measure to use when comparing vectors. Defaults to 'COSINE'. - firestore_client: The Firestore database instance from which to query. - metadata_fields: Optional list of metadata fields to include. - """ + name: str = 'firestore' def __init__( self, @@ -79,7 +72,7 @@ def __init__( firestore_client: The Firestore database instance from which to query. metadata_fields: Optional list of metadata fields to include. 
""" - self.name = name + self.store_name = name self.firestore_client = firestore_client self.collection = collection self.vector_field = vector_field @@ -89,31 +82,57 @@ def __init__( self.distance_measure = distance_measure self.metadata_fields = metadata_fields - def initialize(self, ai: GenkitRegistry) -> None: - """Initialize firestore plugin. - - Register actions with the registry making them available for use in the Genkit framework. - - Args: - ai: The registry to register actions with. - - Returns: - None - """ - retriever = FirestoreRetriever( - ai=ai, - name=self.name, - firestore_client=self.firestore_client, - collection=self.collection, - vector_field=self.vector_field, - content_field=self.content_field, - embedder=self.embedder, - embedder_options=self.embedder_options, - distance_measure=self.distance_measure, - metadata_fields=self.metadata_fields, - ) - - return ai.define_retriever( - name=firestore_action_name(self.name), - fn=retriever.retrieve, + async def init(self) -> list[Action]: + return [self._create_retriever_action()] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + if action_type != ActionKind.RETRIEVER: + return None + if name != self.store_name: + return None + return self._create_retriever_action() + + async def list_actions(self) -> list[ActionMetadata]: + return [ + retriever_action_metadata( + name=self.store_name, + options=RetrieverOptions( + label=self.store_name, + ), + ) + ] + + def _create_retriever_action(self) -> Action: + metadata: dict[str, Any] = { + 'retriever': { + 'label': self.store_name, + } + } + + async def retrieve(request, ctx): + ai = (ctx.context or {}).get('__genkit_ai__') + if ai is None: + raise ValueError( + 'FirestoreVectorStore retriever requires a Genkit instance in action context. ' + 'Use it via `await ai.retrieve(...)`.' 
+ ) + retriever = FirestoreRetriever( + ai=ai, + name=self.store_name, + firestore_client=self.firestore_client, + collection=self.collection, + vector_field=self.vector_field, + content_field=self.content_field, + embedder=self.embedder, + embedder_options=self.embedder_options, + distance_measure=self.distance_measure, + metadata_fields=self.metadata_fields, + ) + return await retriever.retrieve(request, ctx) + + return Action( + kind=ActionKind.RETRIEVER, + name=self.store_name, + fn=retrieve, + metadata=metadata, ) diff --git a/py/plugins/firebase/src/genkit/plugins/firebase/tests/test_firestore_vectorstore_plugin.py b/py/plugins/firebase/src/genkit/plugins/firebase/tests/test_firestore_vectorstore_plugin.py new file mode 100644 index 0000000000..df4345fce3 --- /dev/null +++ b/py/plugins/firebase/src/genkit/plugins/firebase/tests/test_firestore_vectorstore_plugin.py @@ -0,0 +1,62 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from google.cloud.firestore_v1.base_vector_query import DistanceMeasure + +from genkit.core.action.types import ActionKind +from genkit.plugins.firebase.firestore import FirestoreVectorStore + + +@pytest.mark.asyncio +async def test_init_returns_retriever_action(): + plugin = FirestoreVectorStore( + name='kb', + firestore_client=MagicMock(), + collection='docs', + vector_field='embedding', + content_field='text', + embedder='vertexai/text-embedding-004', + distance_measure=DistanceMeasure.COSINE, + ) + + actions = await plugin.init() + + assert len(actions) == 1 + assert actions[0].kind == ActionKind.RETRIEVER + assert actions[0].name == 'kb' + + +@pytest.mark.asyncio +async def test_list_returns_metadata(): + plugin = FirestoreVectorStore( + name='kb', + firestore_client=MagicMock(), + collection='docs', + vector_field='embedding', + content_field='text', + embedder='vertexai/text-embedding-004', + ) + + metas = await plugin.list_actions() + + assert len(metas) == 1 + assert metas[0].kind == ActionKind.RETRIEVER + assert metas[0].name == 'kb' diff --git a/py/plugins/flask/pyproject.toml b/py/plugins/flask/pyproject.toml index 43bb01a88d..0b43ad401e 100644 --- a/py/plugins/flask/pyproject.toml +++ b/py/plugins/flask/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -41,7 +40,7 @@ dependencies = [ "flask", ] description = "Genkit Firebase Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-flask" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/flask/tests/flask_test.py 
b/py/plugins/flask/tests/flask_test.py index 1308e9deef..e314644df6 100644 --- a/py/plugins/flask/tests/flask_test.py +++ b/py/plugins/flask/tests/flask_test.py @@ -65,7 +65,7 @@ def test_streaming(): headers={'Authorization': 'Pavel', 'content-Type': 'application/json', 'accept': 'text/event-stream'}, ) - assert response.is_streamed == True + assert response.is_streamed chunks = [] for chunk in response.response: diff --git a/py/plugins/google-cloud/pyproject.toml b/py/plugins/google-cloud/pyproject.toml index 2a58f3e6f8..43985cab8a 100644 --- a/py/plugins/google-cloud/pyproject.toml +++ b/py/plugins/google-cloud/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "strenum>=0.4.15; python_version < '3.11'", ] description = "Genkit Google Cloud Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-google-cloud" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/google-genai/pyproject.toml b/py/plugins/google-genai/pyproject.toml index bc4df8a955..e98788f12a 100644 --- a/py/plugins/google-genai/pyproject.toml +++ b/py/plugins/google-genai/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -41,7 +40,7 @@ dependencies = [ "strenum>=0.4.15; python_version < '3.11'", ] description = "Genkit Google GenAI Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-google-genai" readme = 
"README.md" requires-python = ">=3.10" diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py index 437cf9506b..c22d6a546e 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/google.py @@ -20,15 +20,14 @@ from google import genai from google.auth.credentials import Credentials from google.genai.client import DebugConfig -from google.genai.types import EmbedContentConfig, GenerateImagesConfigOrDict, HttpOptions, HttpOptionsDict +from google.genai.types import HttpOptions, HttpOptionsDict import genkit.plugins.google_genai.constants as const -from genkit.ai import GENKIT_CLIENT_HEADER, GenkitRegistry, Plugin -from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata -from genkit.blocks.model import model_action_metadata -from genkit.core.action import ActionMetadata +from genkit.ai import GENKIT_CLIENT_HEADER, Plugin +from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder, embedder_action_metadata +from genkit.blocks.model import model, model_action_metadata +from genkit.core.action import Action, ActionMetadata from genkit.core.registry import ActionKind -from genkit.core.schema import to_json_schema from genkit.plugins.google_genai.models.embedder import ( Embedder, GeminiEmbeddingModels, @@ -37,7 +36,6 @@ ) from genkit.plugins.google_genai.models.gemini import ( SUPPORTED_MODELS, - GeminiConfigSchema, GeminiModel, GoogleAIGeminiVersion, VertexAIGeminiVersion, @@ -128,107 +126,65 @@ def __init__( http_options=_inject_attribution_headers(http_options), ) - def initialize(self, ai: GenkitRegistry) -> None: - """Initialize the plugin by registering actions in the registry. + async def init(self) -> list[Action]: + actions: list[Action] = [] - Args: - ai: the action registry. 
- """ for version in GoogleAIGeminiVersion: - gemini_model = GeminiModel(version, self._client, ai) - ai.define_model( - name=googleai_name(version), - fn=gemini_model.generate, - metadata=gemini_model.metadata, - # config_schema=GeminiConfigSchema, + gemini_model = GeminiModel(version, self._client) + actions.append( + model( + name=str(version), + fn=gemini_model.generate, + metadata=gemini_model.metadata, + # config_schema=GeminiConfigSchema, + ) ) for version in GeminiEmbeddingModels: - embedder = Embedder(version=version, client=self._client) + embedder_impl = Embedder(version=version, client=self._client) embedder_info = default_embedder_info(version) - ai.define_embedder( - name=googleai_name(version), - fn=embedder.generate, + actions.append( + embedder( + name=str(version), + fn=embedder_impl.generate, + options=EmbedderOptions( + label=embedder_info.get('label'), + dimensions=embedder_info.get('dimensions'), + supports=EmbedderSupports(**embedder_info['supports']) + if embedder_info.get('supports') + else None, + ), + ) + ) + + return actions + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + if action_type == ActionKind.MODEL: + model_ref = google_model_info(name) + SUPPORTED_MODELS[name] = model_ref + gemini_model = GeminiModel(name, self._client) + return model( + name=name, + fn=gemini_model.generate, + metadata=gemini_model.metadata, + ) + if action_type == ActionKind.EMBEDDER: + embedder_impl = Embedder(version=name, client=self._client) + embedder_info = default_embedder_info(name) + return embedder( + name=name, + fn=embedder_impl.generate, options=EmbedderOptions( label=embedder_info.get('label'), dimensions=embedder_info.get('dimensions'), supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, ), ) - - def resolve_action( - self, - ai: GenkitRegistry, - kind: ActionKind, - name: str, - ) -> None: - """Resolves and action. - - Args: - ai: The Genkit registry. 
- kind: The kind of action to resolve. - name: The name of the action to resolve. - """ - if kind == ActionKind.MODEL: - self._resolve_model(ai, name) - elif kind == ActionKind.EMBEDDER: - self._resolve_embedder(ai, name) - - def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: - """Resolves and defines a Google AI model within the Genkit registry. - - This internal method handles the logic for registering different types of - Google AI models (e.g., Gemini text models) based on the provided name. - It extracts a clean name, determines the model type, instantiates the - appropriate model class, and registers it with the Genkit AI registry. - - Args: - ai: The Genkit AI registry instance to define the model in. - name: The name of the model to resolve. This name might include a - prefix indicating it's from a specific plugin (e.g., 'googleai/gemini-pro'). - """ - _clean_name = name.replace(GOOGLEAI_PLUGIN_NAME + '/', '') if name.startswith(GOOGLEAI_PLUGIN_NAME) else name - model_ref = google_model_info(_clean_name) - - SUPPORTED_MODELS[_clean_name] = model_ref - - gemini_model = GeminiModel(_clean_name, self._client, ai) - - ai.define_model( - name=googleai_name(_clean_name), - fn=gemini_model.generate, - metadata=gemini_model.metadata, - # config_schema=GeminiConfigSchema, - ) - - def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: - """Resolves and defines a Google AI embedder within the Genkit registry. - - This internal method handles the logic for registering Google AI embedder - models. It extracts a clean name, instantiates the embedder class, and - registers it with the Genkit AI registry. - - Args: - ai: The Genkit AI registry instance to define the embedder in. - name: The name of the embedder to resolve. This name might include a - prefix indicating it's from a specific plugin (e.g., 'googleai/embedding-001'). 
- """ - _clean_name = name.replace(GOOGLEAI_PLUGIN_NAME + '/', '') if name.startswith(GOOGLEAI_PLUGIN_NAME) else name - embedder = Embedder(version=_clean_name, client=self._client) - - embedder_info = default_embedder_info(_clean_name) - ai.define_embedder( - name=googleai_name(_clean_name), - fn=embedder.generate, - options=EmbedderOptions( - label=embedder_info.get('label'), - dimensions=embedder_info.get('dimensions'), - supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, - ), - ) + return None @cached_property - def list_actions(self) -> list[ActionMetadata]: + def _list_actions_cache(self) -> list[ActionMetadata]: """Generate a list of available actions or models. Returns: @@ -264,6 +220,9 @@ def list_actions(self) -> list[ActionMetadata]: return actions_list + async def list_actions(self) -> list[ActionMetadata]: + return list(self._list_actions_cache) + class VertexAI(Plugin): """Vertex AI plugin for Genkit. @@ -315,125 +274,81 @@ def __init__( http_options=_inject_attribution_headers(http_options), ) - def initialize(self, ai: GenkitRegistry) -> None: - """Initialize the plugin by registering actions with the registry. - - This method registers the Vertex AI model actions with the provided - registry, making them available for use in the Genkit framework. + async def init(self) -> list[Action]: + actions: list[Action] = [] - Args: - ai: the action registry. 
- """ for version in VertexAIGeminiVersion: - gemini_model = GeminiModel(version, self._client, ai) - ai.define_model( - name=vertexai_name(version), - fn=gemini_model.generate, - metadata=gemini_model.metadata, - # config_schema=GeminiConfigSchema, + gemini_model = GeminiModel(version, self._client) + actions.append( + model( + name=str(version), + fn=gemini_model.generate, + metadata=gemini_model.metadata, + ) ) for version in VertexEmbeddingModels: - embedder = Embedder(version=version, client=self._client) + embedder_impl = Embedder(version=version, client=self._client) embedder_info = default_embedder_info(version) - ai.define_embedder( - name=vertexai_name(version), - fn=embedder.generate, - options=EmbedderOptions( - label=embedder_info.get('label'), - dimensions=embedder_info.get('dimensions'), - supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, - ), + actions.append( + embedder( + name=str(version), + fn=embedder_impl.generate, + options=EmbedderOptions( + label=embedder_info.get('label'), + dimensions=embedder_info.get('dimensions'), + supports=EmbedderSupports(**embedder_info['supports']) + if embedder_info.get('supports') + else None, + ), + ) ) for version in ImagenVersion: imagen_model = ImagenModel(version, self._client) - ai.define_model( - name=vertexai_name(version), - fn=imagen_model.generate, - metadata=imagen_model.metadata, + actions.append( + model( + name=str(version), + fn=imagen_model.generate, + metadata=imagen_model.metadata, + ) ) - def resolve_action( - self, - ai: GenkitRegistry, - kind: ActionKind, - name: str, - ) -> None: - """Resolves and action. - - Args: - ai: The Genkit registry. - kind: The kind of action to resolve. - name: The name of the action to resolve. 
- """ - if kind == ActionKind.MODEL: - self._resolve_model(ai, name) - elif kind == ActionKind.EMBEDDER: - self._resolve_embedder(ai, name) - - def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: - """Resolves and defines a Vertex AI model within the Genkit registry. - - This internal method handles the logic for registering different types of - Vertex AI models (e.g., Gemini text models, Imagen image models) based on - the provided name. It extracts a clean name, determines the model type, - instantiates the appropriate model class, and registers it with the Genkit - AI registry. - - Args: - ai: The Genkit AI registry instance to define the model in. - name: The name of the model to resolve. This name might include a - prefix indicating it's from a specific plugin (e.g., 'vertexai/gemini-pro'). - """ - _clean_name = name.replace(VERTEXAI_PLUGIN_NAME + '/', '') if name.startswith(VERTEXAI_PLUGIN_NAME) else name - - if _clean_name.lower().startswith('image'): - model_ref = vertexai_image_model_info(_clean_name) - model = ImagenModel(_clean_name, self._client) - IMAGE_SUPPORTED_MODELS[_clean_name] = model_ref - # config_schema = GenerateImagesConfigOrDict - else: - model_ref = google_model_info(_clean_name) - model = GeminiModel(_clean_name, self._client, ai) - SUPPORTED_MODELS[_clean_name] = model_ref - # config_schema = GeminiConfigSchema - - ai.define_model( - name=vertexai_name(_clean_name), - fn=model.generate, - metadata=model.metadata, - # config_schema=config_schema, - ) - - def _resolve_embedder(self, ai: GenkitRegistry, name: str) -> None: - """Resolves and defines a Vertex AI embedder within the Genkit registry. 
+ return actions + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + if action_type == ActionKind.MODEL: + if name.lower().startswith('image'): + model_ref = vertexai_image_model_info(name) + model_impl = ImagenModel(name, self._client) + IMAGE_SUPPORTED_MODELS[name] = model_ref + else: + model_ref = google_model_info(name) + model_impl = GeminiModel(name, self._client) + SUPPORTED_MODELS[name] = model_ref + return model( + name=name, + fn=model_impl.generate, + metadata=model_impl.metadata, + ) - This internal method handles the logic for registering Google AI embedder - models. It extracts a clean name, instantiates the embedder class, and - registers it with the Genkit AI registry. + if action_type == ActionKind.EMBEDDER: + embedder_impl = Embedder(version=name, client=self._client) + embedder_info = default_embedder_info(name) + return embedder( + name=name, + fn=embedder_impl.generate, + options=EmbedderOptions( + label=embedder_info.get('label'), + dimensions=embedder_info.get('dimensions'), + supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, + ), + ) - Args: - ai: The Genkit AI registry instance to define the embedder in. - name: The name of the embedder to resolve. This name might include a - prefix indicating it's from a specific plugin (e.g., 'vertexai/embedding-001'). 
- """ - _clean_name = name.replace(VERTEXAI_PLUGIN_NAME + '/', '') if name.startswith(VERTEXAI_PLUGIN_NAME) else name - embedder = Embedder(version=_clean_name, client=self._client) - - embedder_info = default_embedder_info(_clean_name) - ai.define_embedder( - name=vertexai_name(_clean_name), - fn=embedder.generate, - options=EmbedderOptions( - label=embedder_info.get('label'), - dimensions=embedder_info.get('dimensions'), - supports=EmbedderSupports(**embedder_info['supports']) if embedder_info.get('supports') else None, - ), - ) + return None @cached_property - def list_actions(self) -> list[ActionMetadata]: + def _list_actions_cache(self) -> list[ActionMetadata]: """Generate a list of available actions or models. Returns: @@ -469,6 +384,9 @@ def list_actions(self) -> list[ActionMetadata]: return actions_list + async def list_actions(self) -> list[ActionMetadata]: + return list(self._list_actions_cache) + def _inject_attribution_headers(http_options: HttpOptions | dict | None = None): """Adds genkit client info to the appropriate http headers.""" diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/__init__.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/__init__.py index 19add86cb8..1bd71c3070 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/__init__.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/__init__.py @@ -13,4 +13,3 @@ # limitations under the License. 
# # SPDX-License-Identifier: Apache-2.0 - diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/utils.py index 405a88bb67..aedda46424 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/utils.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/context_caching/utils.py @@ -30,7 +30,7 @@ def generate_cache_key(request: GenerateRequest) -> str: - """Generates context cache key by hashing the given request instance + """Generates context cache key by hashing the given request instance. Args: request: `GenerateRequest` instance to hash @@ -42,7 +42,7 @@ def generate_cache_key(request: GenerateRequest) -> str: def validate_context_cache_request(request: GenerateRequest, model_name: str) -> bool: - """Verifies that the context cache request could be processed for the request + """Verifies that the context cache request could be processed for the request. 
Args: request: `GenerateRequest` instance to check diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py index 12f09c1701..914eed55a3 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/gemini.py @@ -149,9 +149,7 @@ from google.genai import types as genai_types # type: ignore from genkit.ai import ( - ActionKind, ActionRunContext, - GenkitRegistry, ) from genkit.blocks.model import get_basic_usage_stats from genkit.codec import dump_dict, dump_json @@ -181,6 +179,25 @@ class GeminiConfigSchema(genai_types.GenerateContentConfig): """Gemini Config Schema.""" code_execution: bool | None = None + response_modalities: list[str] | None = None + + +class GeminiTtsConfigSchema(GeminiConfigSchema): + """Gemini TTS Config Schema.""" + + speech_config: dict[str, Any] | None = None + + +class GeminiImageConfigSchema(GeminiConfigSchema): + """Gemini Image Config Schema.""" + + image_config: dict[str, Any] | None = None + + +class GemmaConfigSchema(GeminiConfigSchema): + """Gemma Config Schema.""" + + temperature: float | None = None GEMINI_1_5_PRO = ModelInfo( @@ -341,6 +358,57 @@ class GeminiConfigSchema(genai_types.GenerateContentConfig): ), ) +GENERIC_GEMINI_MODEL = ModelInfo( + label='Google AI - Gemini', + supports=Supports( + multiturn=True, + media=True, + tools=True, + tool_choice=True, + system_role=True, + constrained='no-tools', + output=['text', 'json'], + ), +) + +GENERIC_TTS_MODEL = ModelInfo( + label='Google AI - Gemini TTS', + supports=Supports( + multiturn=False, + media=False, + tools=False, + tool_choice=False, + system_role=False, + constrained='no-tools', + ), +) + +GENERIC_IMAGE_MODEL = ModelInfo( + label='Google AI - Gemini Image', + supports=Supports( + multiturn=True, + media=True, + tools=True, + tool_choice=True, + system_role=True, + 
constrained='no-tools', + output=['text'], + ), +) + +GENERIC_GEMMA_MODEL = ModelInfo( + label='Google AI - Gemma', + supports=Supports( + multiturn=True, + media=True, + tools=True, + tool_choice=True, + system_role=True, + constrained='no-tools', + output=['text', 'json'], + ), +) + Deprecations = deprecated_enum_metafactory({ 'GEMINI_1_0_PRO': DeprecationInfo(recommendation='GEMINI_2_0_FLASH', status=DeprecationStatus.DEPRECATED), @@ -368,6 +436,21 @@ class VertexAIGeminiVersion(StrEnum, metaclass=Deprecations): | `gemini-2.5-pro-exp-03-25` | Gemini 2.5 Pro Exp 03-25 | Supported | | `gemini-2.5-pro-preview-03-25` | Gemini 2.5 Pro Preview 03-25 | Supported | | `gemini-2.5-pro-preview-05-06` | Gemini 2.5 Pro Preview 05-06 | Supported | + | `gemini-3-flash-preview` | Gemini 3 Flash Preview | Supported | + | `gemini-3-pro-preview` | Gemini 3 Pro Preview | Supported | + | `gemini-2.5-pro` | Gemini 2.5 Pro | Supported | + | `gemini-2.5-flash` | Gemini 2.5 Flash | Supported | + | `gemini-2.5-flash-lite` | Gemini 2.5 Flash Lite | Supported | + | `gemini-2.5-flash-preview-tts` | Gemini 2.5 Flash Preview TTS | Supported | + | `gemini-2.5-pro-preview-tts` | Gemini 2.5 Pro Preview TTS | Supported | + | `gemini-3-pro-image-preview` | Gemini 3 Pro Image Preview | Supported | + | `gemini-2.5-flash-image-preview` | Gemini 2.5 Flash Image Preview | Supported | + | `gemini-2.5-flash-image` | Gemini 2.5 Flash Image | Supported | + | `gemma-3-12b-it` | Gemma 3 12B IT | Supported | + | `gemma-3-1b-it` | Gemma 3 1B IT | Supported | + | `gemma-3-27b-it` | Gemma 3 27B IT | Supported | + | `gemma-3-4b-it` | Gemma 3 4B IT | Supported | + | `gemma-3n-e4b-it` | Gemma 3n E4B IT | Supported | """ GEMINI_1_5_FLASH = 'gemini-1.5-flash' @@ -381,6 +464,21 @@ class VertexAIGeminiVersion(StrEnum, metaclass=Deprecations): GEMINI_2_5_PRO_EXP_03_25 = 'gemini-2.5-pro-exp-03-25' GEMINI_2_5_PRO_PREVIEW_03_25 = 'gemini-2.5-pro-preview-03-25' GEMINI_2_5_PRO_PREVIEW_05_06 = 'gemini-2.5-pro-preview-05-06' + 
GEMINI_3_FLASH_PREVIEW = 'gemini-3-flash-preview' + GEMINI_3_PRO_PREVIEW = 'gemini-3-pro-preview' + GEMINI_2_5_PRO = 'gemini-2.5-pro' + GEMINI_2_5_FLASH = 'gemini-2.5-flash' + GEMINI_2_5_FLASH_LITE = 'gemini-2.5-flash-lite' + GEMINI_2_5_FLASH_PREVIEW_TTS = 'gemini-2.5-flash-preview-tts' + GEMINI_2_5_PRO_PREVIEW_TTS = 'gemini-2.5-pro-preview-tts' + GEMINI_3_PRO_IMAGE_PREVIEW = 'gemini-3-pro-image-preview' + GEMINI_2_5_FLASH_IMAGE_PREVIEW = 'gemini-2.5-flash-image-preview' + GEMINI_2_5_FLASH_IMAGE = 'gemini-2.5-flash-image' + GEMMA_3_12B_IT = 'gemma-3-12b-it' + GEMMA_3_1B_IT = 'gemma-3-1b-it' + GEMMA_3_27B_IT = 'gemma-3-27b-it' + GEMMA_3_4B_IT = 'gemma-3-4b-it' + GEMMA_3N_E4B_IT = 'gemma-3n-e4b-it' class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): @@ -401,6 +499,21 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): | `gemini-2.5-pro-exp-03-25` | Gemini 2.5 Pro Exp 03-25 | Supported | | `gemini-2.5-pro-preview-03-25` | Gemini 2.5 Pro Preview 03-25 | Supported | | `gemini-2.5-pro-preview-05-06` | Gemini 2.5 Pro Preview 05-06 | Supported | + | `gemini-3-flash-preview` | Gemini 3 Flash Preview | Supported | + | `gemini-3-pro-preview` | Gemini 3 Pro Preview | Supported | + | `gemini-2.5-pro` | Gemini 2.5 Pro | Supported | + | `gemini-2.5-flash` | Gemini 2.5 Flash | Supported | + | `gemini-2.5-flash-lite` | Gemini 2.5 Flash Lite | Supported | + | `gemini-2.5-flash-preview-tts` | Gemini 2.5 Flash Preview TTS | Supported | + | `gemini-2.5-pro-preview-tts` | Gemini 2.5 Pro Preview TTS | Supported | + | `gemini-3-pro-image-preview` | Gemini 3 Pro Image Preview | Supported | + | `gemini-2.5-flash-image-preview` | Gemini 2.5 Flash Image Preview | Supported | + | `gemini-2.5-flash-image` | Gemini 2.5 Flash Image | Supported | + | `gemma-3-12b-it` | Gemma 3 12B IT | Supported | + | `gemma-3-1b-it` | Gemma 3 1B IT | Supported | + | `gemma-3-27b-it` | Gemma 3 27B IT | Supported | + | `gemma-3-4b-it` | Gemma 3 4B IT | Supported | + | `gemma-3n-e4b-it` | 
Gemma 3n E4B IT | Supported | """ GEMINI_1_5_FLASH = 'gemini-1.5-flash' @@ -414,6 +527,21 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): GEMINI_2_5_PRO_EXP_03_25 = 'gemini-2.5-pro-exp-03-25' GEMINI_2_5_PRO_PREVIEW_03_25 = 'gemini-2.5-pro-preview-03-25' GEMINI_2_5_PRO_PREVIEW_05_06 = 'gemini-2.5-pro-preview-05-06' + GEMINI_3_FLASH_PREVIEW = 'gemini-3-flash-preview' + GEMINI_3_PRO_PREVIEW = 'gemini-3-pro-preview' + GEMINI_2_5_PRO = 'gemini-2.5-pro' + GEMINI_2_5_FLASH = 'gemini-2.5-flash' + GEMINI_2_5_FLASH_LITE = 'gemini-2.5-flash-lite' + GEMINI_2_5_FLASH_PREVIEW_TTS = 'gemini-2.5-flash-preview-tts' + GEMINI_2_5_PRO_PREVIEW_TTS = 'gemini-2.5-pro-preview-tts' + GEMINI_3_PRO_IMAGE_PREVIEW = 'gemini-3-pro-image-preview' + GEMINI_2_5_FLASH_IMAGE_PREVIEW = 'gemini-2.5-flash-image-preview' + GEMINI_2_5_FLASH_IMAGE = 'gemini-2.5-flash-image' + GEMMA_3_12B_IT = 'gemma-3-12b-it' + GEMMA_3_1B_IT = 'gemma-3-1b-it' + GEMMA_3_27B_IT = 'gemma-3-27b-it' + GEMMA_3_4B_IT = 'gemma-3-4b-it' + GEMMA_3N_E4B_IT = 'gemma-3n-e4b-it' SUPPORTED_MODELS = { @@ -428,6 +556,21 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): GoogleAIGeminiVersion.GEMINI_2_5_PRO_EXP_03_25: GEMINI_2_5_PRO_EXP_03_25, GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_03_25: GEMINI_2_5_PRO_PREVIEW_03_25, GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_05_06: GEMINI_2_5_PRO_PREVIEW_05_06, + GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW: GENERIC_GEMINI_MODEL, + GoogleAIGeminiVersion.GEMINI_3_PRO_PREVIEW: GENERIC_GEMINI_MODEL, + GoogleAIGeminiVersion.GEMINI_2_5_PRO: GENERIC_GEMINI_MODEL, + GoogleAIGeminiVersion.GEMINI_2_5_FLASH: GENERIC_GEMINI_MODEL, + GoogleAIGeminiVersion.GEMINI_2_5_FLASH_LITE: GENERIC_GEMINI_MODEL, + GoogleAIGeminiVersion.GEMINI_2_5_FLASH_PREVIEW_TTS: GENERIC_TTS_MODEL, + GoogleAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_TTS: GENERIC_TTS_MODEL, + GoogleAIGeminiVersion.GEMINI_3_PRO_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, + GoogleAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE_PREVIEW: 
GENERIC_IMAGE_MODEL, + GoogleAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE: GENERIC_IMAGE_MODEL, + GoogleAIGeminiVersion.GEMMA_3_12B_IT: GENERIC_GEMMA_MODEL, + GoogleAIGeminiVersion.GEMMA_3_1B_IT: GENERIC_GEMMA_MODEL, + GoogleAIGeminiVersion.GEMMA_3_27B_IT: GENERIC_GEMMA_MODEL, + GoogleAIGeminiVersion.GEMMA_3_4B_IT: GENERIC_GEMMA_MODEL, + GoogleAIGeminiVersion.GEMMA_3N_E4B_IT: GENERIC_GEMMA_MODEL, VertexAIGeminiVersion.GEMINI_1_5_FLASH: GEMINI_1_5_FLASH, VertexAIGeminiVersion.GEMINI_1_5_FLASH_8B: GEMINI_1_5_FLASH_8B, VertexAIGeminiVersion.GEMINI_1_5_PRO: GEMINI_1_5_PRO, @@ -439,6 +582,21 @@ class GoogleAIGeminiVersion(StrEnum, metaclass=Deprecations): VertexAIGeminiVersion.GEMINI_2_5_PRO_EXP_03_25: GEMINI_2_5_PRO_EXP_03_25, VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_03_25: GEMINI_2_5_PRO_PREVIEW_03_25, VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_05_06: GEMINI_2_5_PRO_PREVIEW_05_06, + VertexAIGeminiVersion.GEMINI_3_FLASH_PREVIEW: GENERIC_GEMINI_MODEL, + VertexAIGeminiVersion.GEMINI_3_PRO_PREVIEW: GENERIC_GEMINI_MODEL, + VertexAIGeminiVersion.GEMINI_2_5_PRO: GENERIC_GEMINI_MODEL, + VertexAIGeminiVersion.GEMINI_2_5_FLASH: GENERIC_GEMINI_MODEL, + VertexAIGeminiVersion.GEMINI_2_5_FLASH_LITE: GENERIC_GEMINI_MODEL, + VertexAIGeminiVersion.GEMINI_2_5_FLASH_PREVIEW_TTS: GENERIC_TTS_MODEL, + VertexAIGeminiVersion.GEMINI_2_5_PRO_PREVIEW_TTS: GENERIC_TTS_MODEL, + VertexAIGeminiVersion.GEMINI_3_PRO_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, + VertexAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE_PREVIEW: GENERIC_IMAGE_MODEL, + VertexAIGeminiVersion.GEMINI_2_5_FLASH_IMAGE: GENERIC_IMAGE_MODEL, + VertexAIGeminiVersion.GEMMA_3_12B_IT: GENERIC_GEMMA_MODEL, + VertexAIGeminiVersion.GEMMA_3_1B_IT: GENERIC_GEMMA_MODEL, + VertexAIGeminiVersion.GEMMA_3_27B_IT: GENERIC_GEMMA_MODEL, + VertexAIGeminiVersion.GEMMA_3_4B_IT: GENERIC_GEMMA_MODEL, + VertexAIGeminiVersion.GEMMA_3N_E4B_IT: GENERIC_GEMMA_MODEL, } @@ -479,18 +637,15 @@ def __init__( self, version: str | GoogleAIGeminiVersion | VertexAIGeminiVersion, 
client: genai.Client, - registry: GenkitRegistry, ): """Initialize Gemini model. Args: version: Gemini version client: Google AI client - registry: Genkit registry """ self._version = version self._client = client - self._registry = registry def _get_tools(self, request: GenerateRequest) -> list[genai_types.Tool]: """Generates VertexAI Gemini compatible tool definitions. @@ -522,7 +677,7 @@ def _create_tool(self, tool: ToolDefinition) -> genai_types.Tool: name=tool.name, description=tool.description, parameters=params, - response=tool.output_schema, + response=self._convert_schema_property(tool.output_schema) if tool.output_schema else None, ) return genai_types.Tool(function_declarations=[function]) @@ -588,32 +743,6 @@ def _convert_schema_property( return schema - def _call_tool(self, call: genai_types.FunctionCall) -> genai_types.Content: - """Calls tool's function from the registry. - - Args: - call: FunctionCall from Gemini response - - Returns: - Gemini message content to add to the message - """ - tool_function = self._registry.registry.lookup_action(ActionKind.TOOL, call.name) - if tool_function is None: - raise LookupError(f'Tool {call.name} not found') - - args = tool_function.input_type.validate_python(call.args) - tool_answer = tool_function.run(args) - return genai_types.Content( - parts=[ - genai_types.Part.from_function_response( - name=call.name, - response={ - 'content': tool_answer.response, - }, - ) - ] - ) - async def _retrieve_cached_content( self, request: GenerateRequest, model_name: str, cache_config: dict, contents: list[genai_types.Content] ) -> genai_types.CachedContent: @@ -825,6 +954,8 @@ async def _build_messages( cache = None for msg in request.messages: + if msg.role == Role.SYSTEM: + continue content_parts: list[genai_types.Part] = [] for p in msg.content: content_parts.append(PartConverter.to_gemini(p)) @@ -838,6 +969,9 @@ async def _build_messages( contents=request_contents, ) + if not request_contents: + 
request_contents.append(genai_types.Content(parts=[genai_types.Part(text=' ')], role='user')) + return request_contents, cache def _contents_from_response(self, response: genai_types.GenerateContentResponse) -> list: @@ -853,8 +987,8 @@ def _contents_from_response(self, response: genai_types.GenerateContentResponse) if response.candidates: for candidate in response.candidates: if candidate.content: - for part in candidate.content.parts: - content.append(PartConverter.from_gemini(part=part)) + for i, part in enumerate(candidate.content.parts): + content.append(PartConverter.from_gemini(part=part, ref=str(i))) return content @@ -915,7 +1049,6 @@ def _genkit_to_googleai_cfg(self, request: GenerateRequest) -> genai_types.Gener for msg in system_messages: for p in msg.content: system_parts.append(PartConverter.to_gemini(p)) - request.messages.remove(msg) cfg.system_instruction = genai.types.Content(parts=system_parts) return cfg diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py index 7cb69c03bb..ec2ab72ed1 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +++ b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py @@ -14,6 +14,7 @@ # # SPDX-License-Identifier: Apache-2.0 import base64 +from typing import Any from google import genai @@ -22,6 +23,7 @@ Media, MediaPart, Part, + ReasoningPart, TextPart, ToolRequest, ToolRequestPart, @@ -71,30 +73,43 @@ def to_gemini(cls, part: Part) -> genai.types.Part: A `genai.types.Part` object representing the converted content. 
""" if isinstance(part.root, TextPart): - return genai.types.Part(text=part.root.text) + return genai.types.Part(text=part.root.text or ' ') if isinstance(part.root, ToolRequestPart): return genai.types.Part( function_call=genai.types.FunctionCall( - id=part.root.tool_request.ref, - name=part.root.tool_request.name, + # Gemini throws on '/' in tool name + name=part.root.tool_request.name.replace('/', '__'), args=part.root.tool_request.input, - ) + ), + thought_signature=cls._extract_thought_signature(part.root.metadata), + ) + if isinstance(part.root, ReasoningPart): + return genai.types.Part( + thought=True, + text=part.root.reasoning, + thought_signature=cls._extract_thought_signature(part.root.metadata), ) if isinstance(part.root, ToolResponsePart): return genai.types.Part( function_response=genai.types.FunctionResponse( id=part.root.tool_response.ref, - name=part.root.tool_response.name, - response={'output': part.root.tool_response.output}, + name=part.root.tool_response.name.replace('/', '__'), + response=part.root.tool_response.output, ) ) if isinstance(part.root, MediaPart): url = part.root.media.url if not url.startswith(cls.DATA): raise ValueError(f'Unsupported media URL for inline_data: {url}') - data = base64.b64decode(url.split(',', 1)[1]) + + # Extract mime type and data from data:mime_type;base64,data + metadata, data_str = url.split(',', 1) + mime_type = part.root.media.content_type or metadata.split(':', 1)[1].split(';', 1)[0] + data = base64.b64decode(data_str) + return genai.types.Part( inline_data=genai.types.Blob( + mime_type=mime_type, data=data, ) ) @@ -131,7 +146,7 @@ def _to_gemini_custom(cls, part: Part) -> genai.types.Part: ) @classmethod - def from_gemini(cls, part: genai.types.Part) -> Part: + def from_gemini(cls, part: genai.types.Part, ref: str | None = None) -> Part: """Maps a Gemini Part back to a Genkit Part. 
This method inspects the type of the Gemini Part and converts it into @@ -140,26 +155,41 @@ def from_gemini(cls, part: genai.types.Part) -> Part: Args: part: The `genai.types.Part` object to convert. + ref: The tool call reference ID. Returns: A Genkit `Part` object representing the converted content. """ + if part.thought: + return Part( + root=ReasoningPart( + reasoning=part.text or '', + metadata=cls._encode_thought_signature(part.thought_signature), + ) + ) if part.text: - return Part(text=part.text) + return Part(root=TextPart(text=part.text)) if part.function_call: return Part( - toolRequest=ToolRequest( - ref=part.function_call.id, - name=part.function_call.name, - input=part.function_call.args, + root=ToolRequestPart( + tool_request=ToolRequest( + ref=ref or getattr(part.function_call, 'id', None), + # restore slashes + name=part.function_call.name.replace('__', '/'), + input=part.function_call.args, + ), + metadata=cls._encode_thought_signature(part.thought_signature), ) ) if part.function_response: return Part( - toolResponse=ToolResponse( - ref=part.function_call.id, - name=part.function_response.name, - output=part.function_response.response, + root=ToolResponsePart( + tool_response=ToolResponse( + ref=getattr(part.function_response, 'id', None), + # restore slashes + name=part.function_response.name.replace('__', '/'), + output=part.function_response.response, + ) ) ) if part.inline_data: @@ -188,3 +218,18 @@ def from_gemini(cls, part: genai.types.Part) -> Part: } } ) + + @classmethod + def _extract_thought_signature(cls, metadata: Any) -> bytes | None: + """Extracts and decodes the thought signature from metadata.""" + thought_sig = metadata.root.get('thoughtSignature') if metadata else None + if isinstance(thought_sig, str): + return base64.b64decode(thought_sig) + return None + + @classmethod + def _encode_thought_signature(cls, thought_signature: bytes | None) -> dict[str, str] | None: + """Encodes the thought signature into metadata format.""" + 
if thought_signature: + return {'thoughtSignature': base64.b64encode(thought_signature).decode('utf-8')} + return None diff --git a/py/plugins/google-genai/test/models/test_googlegenai_gemini.py b/py/plugins/google-genai/test/models/test_googlegenai_gemini.py index 50c05b2b9a..8f6aafe951 100644 --- a/py/plugins/google-genai/test/models/test_googlegenai_gemini.py +++ b/py/plugins/google-genai/test/models/test_googlegenai_gemini.py @@ -16,7 +16,7 @@ import sys import urllib.request -from unittest.mock import ANY, AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch if sys.version_info < (3, 11): # noqa from strenum import StrEnum # noqa diff --git a/py/plugins/google-genai/test/models/test_googlegenai_imagen.py b/py/plugins/google-genai/test/models/test_googlegenai_imagen.py index c884389193..eda45a130d 100644 --- a/py/plugins/google-genai/test/models/test_googlegenai_imagen.py +++ b/py/plugins/google-genai/test/models/test_googlegenai_imagen.py @@ -14,7 +14,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import base64 import urllib.request import pytest diff --git a/py/plugins/google-genai/test/test_google_plugin.py b/py/plugins/google-genai/test/test_google_plugin.py index e589d86eec..bb4aa6a5d0 100644 --- a/py/plugins/google-genai/test/test_google_plugin.py +++ b/py/plugins/google-genai/test/test_google_plugin.py @@ -24,14 +24,13 @@ from google.auth.credentials import Credentials from pydantic import BaseModel -from google.genai.types import EmbedContentConfig, GenerateImagesConfigOrDict, HttpOptions +from google.genai.types import HttpOptions import pytest from genkit.ai import Genkit, GENKIT_CLIENT_HEADER from genkit.blocks.embedding import embedder_action_metadata, EmbedderOptions, EmbedderSupports from genkit.blocks.model import model_action_metadata from genkit.core.registry import ActionKind -from genkit.core.schema import to_json_schema from genkit.plugins.google_genai import ( GoogleAI, VertexAI, @@ -46,7 +45,6 @@ ) from 
genkit.plugins.google_genai.models.gemini import ( DEFAULT_SUPPORTS_MODEL, - GeminiConfigSchema, SUPPORTED_MODELS, GoogleAIGeminiVersion, VertexAIGeminiVersion, @@ -131,23 +129,24 @@ def test_init_raises_value_error_no_api_key(self): GoogleAI() -def test_googleai_initialize(): - """Unit tests for GoogleAI.initialize method.""" +@pytest.mark.asyncio +async def test_googleai_initialize(): + """Unit tests for GoogleAI.init method (V2).""" api_key = 'test_api_key' plugin = GoogleAI(api_key=api_key) - ai_mock = MagicMock(spec=Genkit) - plugin.initialize(ai_mock) + actions = await plugin.init() - assert ai_mock.define_model.call_count == len(GoogleAIGeminiVersion) - assert ai_mock.define_embedder.call_count == len(GeminiEmbeddingModels) + # Check we got actions for all models and embedders + model_actions = [a for a in actions if a.kind == ActionKind.MODEL] + embedder_actions = [a for a in actions if a.kind == ActionKind.EMBEDDER] + assert len(model_actions) == len(GoogleAIGeminiVersion) + assert len(embedder_actions) == len(GeminiEmbeddingModels) + + # Check model names are correct for version in GoogleAIGeminiVersion: - ai_mock.define_model.assert_any_call( - name=googleai_name(version), - fn=ANY, - metadata=ANY, - ) + assert any(a.name == str(version) for a in model_actions) for version in GeminiEmbeddingModels: ai_mock.define_embedder.assert_any_call( diff --git a/py/plugins/google-genai/tests/test_google_genai_plugin_v2.py b/py/plugins/google-genai/tests/test_google_genai_plugin_v2.py new file mode 100644 index 0000000000..66fe33f11d --- /dev/null +++ b/py/plugins/google-genai/tests/test_google_genai_plugin_v2.py @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from genkit.core.action.types import ActionKind +from genkit.plugins.google_genai.google import GoogleAI, VertexAI + + +@pytest.mark.asyncio +async def test_googleai_list_is_async(): + plugin = object.__new__(GoogleAI) + plugin._client = MagicMock() + plugin._client.models.list.return_value = [] + + metas = await plugin.list_actions() + assert isinstance(metas, list) + + +@pytest.mark.asyncio +async def test_vertexai_list_is_async(): + plugin = object.__new__(VertexAI) + plugin._client = MagicMock() + plugin._client.models.list.return_value = [] + + metas = await plugin.list_actions() + assert isinstance(metas, list) + + +@pytest.mark.asyncio +async def test_googleai_resolve_model_returns_action(): + plugin = object.__new__(GoogleAI) + plugin._client = MagicMock() + plugin._client.models.list.return_value = [] + + action = await plugin.resolve(ActionKind.MODEL, 'gemini-1.5-pro') + assert action is not None + assert action.kind == ActionKind.MODEL + assert action.name == 'gemini-1.5-pro' diff --git a/py/plugins/mcp/README.md b/py/plugins/mcp/README.md new file mode 100644 index 0000000000..1ad7262193 --- /dev/null +++ b/py/plugins/mcp/README.md @@ -0,0 +1,3 @@ +# Genkit MCP Plugin + +Integrate Model Context Protocol (MCP) with Genkit. 
diff --git a/py/plugins/mcp/examples/client/simple_client.py b/py/plugins/mcp/examples/client/simple_client.py new file mode 100644 index 0000000000..9512f12a5f --- /dev/null +++ b/py/plugins/mcp/examples/client/simple_client.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio + +from genkit.ai import Genkit +from genkit.plugins.mcp import McpServerConfig, create_mcp_client + +try: + from genkit.plugins.google_genai import GoogleAI +except ImportError: + GoogleAI = None + + +# Simple client example connecting to 'everything' server using npx +async def main(): + # Define the client plugin + everything_client = create_mcp_client( + name='everything', config=McpServerConfig(command='npx', args=['-y', '@modelcontextprotocol/server-everything']) + ) + + plugins = [everything_client] + if GoogleAI: + plugins.append(GoogleAI()) + + ai = Genkit(plugins=plugins) + + await everything_client.connect() + + print('Connected! 
Listing tools...') + + tools = await everything_client.list_tools() + for t in tools: + print(f'- {t.name}: {t.description}') + + await everything_client.close() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/py/plugins/mcp/examples/server/prompts/port_code.prompt b/py/plugins/mcp/examples/server/prompts/port_code.prompt new file mode 100644 index 0000000000..77e8501b36 --- /dev/null +++ b/py/plugins/mcp/examples/server/prompts/port_code.prompt @@ -0,0 +1,13 @@ +--- +input: + schema: + code: string, the source code to port from one language to another + fromLang?: string, the original language of the source code (e.g. js, python) + toLang: string, the destination language of the source code (e.g. python, js) +--- + +You are assisting the user in translating code between two programming languages. Given the code below, translate it into {{toLang}}. + +```{{#if fromLang}}{{fromLang}}{{/if}} +{{code}} +``` diff --git a/py/plugins/mcp/examples/server/simple_server.py b/py/plugins/mcp/examples/server/simple_server.py new file mode 100644 index 0000000000..2405c74298 --- /dev/null +++ b/py/plugins/mcp/examples/server/simple_server.py @@ -0,0 +1,63 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio + +from pydantic import BaseModel, Field + +from genkit.ai import Genkit +from genkit.plugins.mcp import McpServerOptions, create_mcp_server + + +# Define input model +class AddInput(BaseModel): + a: int = Field(..., description='First number') + b: int = Field(..., description='Second number') + + +import os + + +def main(): + # Load prompts from the 'prompts' directory relative to this script + script_dir = os.path.dirname(os.path.abspath(__file__)) + prompts_dir = os.path.join(script_dir, 'prompts') + + ai = Genkit(prompt_dir=prompts_dir) + + @ai.tool(name='add', description='add two numbers together') + def add(input: AddInput): + return input.a + input.b + + # Genkit Python prompt definition (simplified) + # Note: In Python, prompts are typically loaded from files via prompt_dir + # This inline definition is for demonstration purposes + happy_prompt = ai.define_prompt( + input_schema={'action': str}, + prompt="If you're happy and you know it, {{action}}.", + ) + + # Create and start MCP server + # Note: create_mcp_server returns McpServer instance. + # In JS example: .start() is called. + server = create_mcp_server(ai, McpServerOptions(name='example_server', version='0.0.1')) + + print('Starting MCP server on stdio...') + asyncio.run(server.start()) + + +if __name__ == '__main__': + main() diff --git a/py/plugins/mcp/pyproject.toml b/py/plugins/mcp/pyproject.toml new file mode 100644 index 0000000000..6ea44f68ec --- /dev/null +++ b/py/plugins/mcp/pyproject.toml @@ -0,0 +1,48 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +[project] +authors = [{ name = "Google" }] +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", +] +dependencies = ["genkit", "mcp"] +description = "Genkit MCP Plugin" +license = "Apache-2.0" +name = "genkit-plugins-mcp" +readme = "README.md" +requires-python = ">=3.10" +version = "0.1.0" + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src"] diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/__init__.py b/py/plugins/mcp/src/genkit/plugins/mcp/__init__.py new file mode 100644 index 0000000000..7e48a29a71 --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/__init__.py @@ -0,0 +1,40 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from .client.client import ( + McpClient, + McpServerConfig, + create_mcp_client, +) +from .client.host import McpHost, create_mcp_host +from .server import McpServer, McpServerOptions, create_mcp_server + + +def package_name() -> str: + return 'genkit.plugins.mcp' + + +__all__ = [ + 'McpClient', + 'McpHost', + 'McpServerConfig', + 'create_mcp_client', + 'create_mcp_host', + 'McpServer', + 'McpServerOptions', + 'create_mcp_server', + 'package_name', +] diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/client/__init__.py b/py/plugins/mcp/src/genkit/plugins/mcp/client/__init__.py new file mode 100644 index 0000000000..fe4b8ffe1f --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/client/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py b/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py new file mode 100644 index 0000000000..4ef9e715a0 --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/client/client.py @@ -0,0 +1,208 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import uuid +from typing import Any, Callable, Dict, List, Optional, Union + +import structlog +from pydantic import BaseModel + +from genkit.ai import Genkit +from genkit.ai._plugin import Plugin +from genkit.ai._registry import GenkitRegistry +from genkit.core.action.types import ActionKind +from mcp import ClientSession, StdioServerParameters +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.types import CallToolResult, Prompt, Resource, Tool + +logger = structlog.get_logger(__name__) + + +class McpServerConfig(BaseModel): + command: Optional[str] = None + args: Optional[List[str]] = None + env: Optional[Dict[str, str]] = None + url: Optional[str] = None + disabled: bool = False + + +class McpClient(Plugin): + """Client for connecting to a single MCP server.""" + + def __init__(self, name: str, config: McpServerConfig, server_name: Optional[str] = None): + self.name = name + self.config = config + self.server_name = server_name or name + self.session: Optional[ClientSession] = None + 
self._exit_stack = None + self._session_context = None + self.ai: Optional[GenkitRegistry] = None + + def plugin_name(self) -> str: + return self.name + + def initialize(self, ai: GenkitRegistry) -> None: + self.ai = ai + + def resolve_action(self, ai: GenkitRegistry, kind: ActionKind, name: str) -> None: + # MCP tools are dynamic and currently registered upon connection/Discovery. + # This hook allows lazy resolution if we implement it. + pass + + async def connect(self): + """Connects to the MCP server.""" + if self.config.disabled: + logger.info(f'MCP server {self.server_name} is disabled.') + return + + try: + if self.config.command: + server_params = StdioServerParameters( + command=self.config.command, args=self.config.args or [], env=self.config.env + ) + # stdio_client returns (read, write) streams + stdio_context = stdio_client(server_params) + read, write = await stdio_context.__aenter__() + self._exit_stack = stdio_context + + # Create and initialize session + session_context = ClientSession(read, write) + self.session = await session_context.__aenter__() + self._session_context = session_context + + elif self.config.url: + # TODO: Verify SSE client usage in mcp python SDK + sse_context = sse_client(self.config.url) + read, write = await sse_context.__aenter__() + self._exit_stack = sse_context + + session_context = ClientSession(read, write) + self.session = await session_context.__aenter__() + self._session_context = session_context + + await self.session.initialize() + logger.info(f'Connected to MCP server: {self.server_name}') + + except Exception as e: + logger.error(f'Failed to connect to MCP server {self.server_name}: {e}') + self.config.disabled = True + # Clean up on error + await self.close() + raise e + + async def close(self): + """Closes the connection.""" + if hasattr(self, '_session_context') and self._session_context: + try: + await self._session_context.__aexit__(None, None, None) + except Exception as e: + logger.debug(f'Error closing 
session: {e}') + if self._exit_stack: + try: + await self._exit_stack.__aexit__(None, None, None) + except Exception as e: + logger.debug(f'Error closing transport: {e}') + + async def list_tools(self) -> List[Tool]: + if not self.session: + return [] + result = await self.session.list_tools() + return result.tools + + async def call_tool(self, tool_name: str, arguments: dict) -> Any: + if not self.session: + raise RuntimeError('MCP client is not connected') + result: CallToolResult = await self.session.call_tool(tool_name, arguments) + # Process result similarly to JS SDK + if result.isError: + raise RuntimeError(f'Tool execution failed: {result.content}') + + # Simple text extraction for now + texts = [c.text for c in result.content if c.type == 'text'] + return ''.join(texts) + + async def list_prompts(self) -> List[Prompt]: + if not self.session: + return [] + result = await self.session.list_prompts() + return result.prompts + + async def get_prompt(self, name: str, arguments: Optional[dict] = None) -> Any: + if not self.session: + raise RuntimeError('MCP client is not connected') + return await self.session.get_prompt(name, arguments) + + async def list_resources(self) -> List[Resource]: + if not self.session: + return [] + result = await self.session.list_resources() + return result.resources + + async def read_resource(self, uri: str) -> Any: + if not self.session: + raise RuntimeError('MCP client is not connected') + return await self.session.read_resource(uri) + + async def register_tools(self, ai: Optional[Genkit] = None): + """Registers all tools from connected client to Genkit.""" + registry = ai.registry if ai else (self.ai.registry if self.ai else None) + if not registry: + logger.warning('No Genkit registry available to register tools.') + return + + if not self.session: + return + + try: + tools = await self.list_tools() + for tool in tools: + # Create a wrapper function for the tool + # We need to capture tool and client in closure + async def 
tool_wrapper(args: Any = None, _tool_name=tool.name): + # args might be Pydantic model or dict. Genkit passes dict usually? + # TODO: Validate args against schema if needed + arguments = args + if hasattr(args, 'model_dump'): + arguments = args.model_dump() + return await self.call_tool(_tool_name, arguments or {}) + + # Use metadata to store MCP specific info + metadata = {'mcp': {'_meta': tool._meta}} if hasattr(tool, '_meta') else {} + + # Define the tool in Genkit registry + registry.register_action( + kind=ActionKind.TOOL, + name=f'{self.server_name}/{tool.name}', + fn=tool_wrapper, + description=tool.description, + metadata=metadata, + # TODO: json_schema conversion from tool.inputSchema + ) + logger.debug(f'Registered MCP tool: {self.server_name}/{tool.name}') + except Exception as e: + logger.error(f'Error registering tools for {self.server_name}: {e}') + + async def get_active_tools(self) -> List[Any]: + """Returns all active tools.""" + if not self.session: + return [] + return await self.list_tools() + + +def create_mcp_client(config: McpServerConfig, name: str = 'mcp-client') -> McpClient: + return McpClient(name, config) diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py b/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py new file mode 100644 index 0000000000..cd0a4691d5 --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/client/host.py @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional + +from genkit.ai import Genkit + +from .client import McpClient, McpServerConfig + + +class McpHost: + """Host for managing multiple MCP clients.""" + + def __init__(self, clients: Dict[str, McpServerConfig]): + self.clients_config = clients + self.clients: Dict[str, McpClient] = {name: McpClient(name, config) for name, config in clients.items()} + + async def start(self): + """Starts all enabled MCP clients.""" + for client in self.clients.values(): + if not client.config.disabled: + await client.connect() + + async def close(self): + """Closes all MCP clients.""" + for client in self.clients.values(): + await client.close() + + async def register_tools(self, ai: Genkit): + """Registers all tools from connected clients to Genkit.""" + for client in self.clients.values(): + if client.session: + await client.register_tools(ai) + + async def enable(self, name: str): + """Enables and connects an MCP client.""" + if name in self.clients: + client = self.clients[name] + client.config.disabled = False + await client.connect() + + async def disable(self, name: str): + """Disables and closes an MCP client.""" + if name in self.clients: + client = self.clients[name] + client.config.disabled = True + await client.close() + + +def create_mcp_host(configs: Dict[str, McpServerConfig]) -> McpHost: + return McpHost(configs) diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/index.py b/py/plugins/mcp/src/genkit/plugins/mcp/index.py new file mode 100644 index 0000000000..4f859e2fe1 --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/index.py @@ -0,0 +1,40 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+MCP Plugin Index
+
+This module serves as the main entry point for the MCP plugin,
+similar to js/plugins/mcp/src/index.ts.
+
+In Python, the actual exports are handled by the parent __init__.py,
+but this file exists for structural parity with the JS SDK.
+"""
+
+from .client.client import McpClient, McpServerConfig, create_mcp_client
+from .client.host import McpHost, create_mcp_host
+from .server import McpServer, McpServerOptions, create_mcp_server
+
+__all__ = [
+    'McpClient',
+    'McpHost',
+    'McpServerConfig',
+    'create_mcp_client',
+    'create_mcp_host',
+    'McpServer',
+    'McpServerOptions',
+    'create_mcp_server',
+]
diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/server.py b/py/plugins/mcp/src/genkit/plugins/mcp/server.py
new file mode 100644
index 0000000000..3d313ccda1
--- /dev/null
+++ b/py/plugins/mcp/src/genkit/plugins/mcp/server.py
@@ -0,0 +1,460 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""MCP Server implementation for exposing Genkit actions via Model Context Protocol."""
+
+import asyncio
+from typing import Any, Optional
+
+import structlog
+from pydantic import BaseModel
+
+from genkit.ai import Genkit
+from genkit.blocks.resource import matches_uri_template
+from genkit.core.action._key import parse_action_key
+from genkit.core.action.types import ActionKind
+from genkit.core.error import GenkitError
+from genkit.core.schema import to_json_schema
+from mcp.server import Server
+from mcp.server.stdio import stdio_server
+from mcp.types import (
+    CallToolRequest,
+    CallToolResult,
+    GetPromptRequest,
+    GetPromptResult,
+    ListPromptsRequest,
+    ListPromptsResult,
+    ListResourcesRequest,
+    ListResourcesResult,
+    ListResourceTemplatesRequest,
+    ListResourceTemplatesResult,
+    ListToolsRequest,
+    ListToolsResult,
+    Prompt,
+    ReadResourceRequest,
+    ReadResourceResult,
+    Resource,
+    ResourceTemplate,
+    Tool,
+)
+
+from .util import (
+    to_mcp_prompt_arguments,
+    to_mcp_prompt_message,
+    to_mcp_resource_contents,
+    to_mcp_tool_result,
+)
+
+logger = structlog.get_logger(__name__)
+
+
+class McpServerOptions(BaseModel):
+    """Options for creating an MCP server.
+
+    Attributes:
+        name: The name of the MCP server.
+        version: The version of the server (default: "1.0.0").
+    """
+
+    name: str
+    version: str = '1.0.0'
+
+
+class McpServer:
+    """Exposes Genkit tools, prompts, and resources as an MCP server.
+
+    This class wraps a Genkit instance and makes its registered actions
+    (tools, prompts, resources) available to MCP clients via the Model Context Protocol.
+    """
+
+    def __init__(self, ai: Genkit, options: McpServerOptions):
+        """Initialize the MCP server.
+
+        Args:
+            ai: The Genkit instance whose actions will be exposed.
+            options: Configuration options for the MCP server.
+ """ + self.ai = ai + self.options = options + self.server: Optional[Server] = None + self.actions_resolved = False + self.tool_actions: list[Any] = [] + self.prompt_actions: list[Any] = [] + self.resource_actions: list[Any] = [] + self.tool_actions_map: dict[str, Any] = {} + self.prompt_actions_map: dict[str, Any] = {} + self.resource_uri_map: dict[str, Any] = {} + self.resource_templates: list[tuple[str, Any]] = [] + + async def setup(self) -> None: + """Initialize the MCP server and register request handlers. + + This method sets up the MCP Server instance, registers all request handlers, + and resolves all actions from the Genkit registry. It's idempotent and can + be called multiple times safely. + """ + if self.actions_resolved: + return + + # Create MCP Server instance + self.server = Server( + self.options.name, + version=self.options.version, + ) + + # Register request handlers using decorators + self.server.list_tools()(self.list_tools) + self.server.call_tool()(self.call_tool) + self.server.list_prompts()(self.list_prompts) + self.server.get_prompt()(self.get_prompt) + self.server.list_resources()(self.list_resources) + self.server.list_resource_templates()(self.list_resource_templates) + self.server.read_resource()(self.read_resource) + + # Resolve all actions from Genkit registry + # We need the actual Action objects, not just serializable dicts + self.tool_actions = [] + self.prompt_actions = [] + self.resource_actions = [] + + # Get all actions from the registry + # We use the internal _entries for local actions and plugins + with self.ai.registry._lock: + for kind, entries in self.ai.registry._entries.items(): + for name, action in entries.items(): + if kind == ActionKind.TOOL: + self.tool_actions.append(action) + self.tool_actions_map[action.name] = action + elif kind == ActionKind.PROMPT: + self.prompt_actions.append(action) + self.prompt_actions_map[action.name] = action + elif kind == ActionKind.RESOURCE: + self.resource_actions.append(action) 
+ metadata = action.metadata or {} + resource_meta = metadata.get('resource', {}) + if resource_meta.get('uri'): + self.resource_uri_map[resource_meta['uri']] = action + if resource_meta.get('template'): + self.resource_templates.append((resource_meta['template'], action)) + + # Also get actions from plugins that might not be in _entries yet + # (though most plugins register them in _entries during initialization) + plugin_actions = self.ai.registry.list_actions() + for key in plugin_actions: + kind, name = parse_action_key(key) + action = self.ai.registry.lookup_action(kind, name) + if action: + if kind == ActionKind.TOOL and action not in self.tool_actions: + self.tool_actions.append(action) + self.tool_actions_map[action.name] = action + elif kind == ActionKind.PROMPT and action not in self.prompt_actions: + self.prompt_actions.append(action) + self.prompt_actions_map[action.name] = action + elif kind == ActionKind.RESOURCE and action not in self.resource_actions: + self.resource_actions.append(action) + metadata = action.metadata or {} + resource_meta = metadata.get('resource', {}) + if resource_meta.get('uri'): + self.resource_uri_map[resource_meta['uri']] = action + if resource_meta.get('template'): + self.resource_templates.append((resource_meta['template'], action)) + + self.actions_resolved = True + + logger.info( + f'MCP Server initialized', + tools=len(self.tool_actions), + prompts=len(self.prompt_actions), + resources=len(self.resource_actions), + ) + + async def list_tools(self, request: ListToolsRequest) -> ListToolsResult: + """Handle MCP requests to list available tools. + + Args: + request: The MCP ListToolsRequest. + + Returns: + ListToolsResult containing all registered Genkit tools. 
+ """ + await self.setup() + + tools: list[Tool] = [] + for action in self.tool_actions: + # Get tool definition + input_schema = to_json_schema(action.input_schema) if action.input_schema else {'type': 'object'} + + tools.append( + Tool( + name=action.name, + description=action.description or '', + inputSchema=input_schema, + _meta=action.metadata.get('mcp', {}).get('_meta') if action.metadata else None, + ) + ) + + return ListToolsResult(tools=tools) + + async def call_tool(self, request: CallToolRequest) -> CallToolResult: + """Handle MCP requests to call a specific tool. + + Args: + request: The MCP CallToolRequest containing tool name and arguments. + + Returns: + CallToolResult with the tool execution result. + + Raises: + GenkitError: If the requested tool is not found. + """ + await self.setup() + + # Find the tool action + tool = self.tool_actions_map.get(request.params.name) + + if not tool: + raise GenkitError( + status='NOT_FOUND', message=f"Tried to call tool '{request.params.name}' but it could not be found." + ) + + # Execute the tool + result = await tool.arun(request.params.arguments) + result = result.response + + # Convert result to MCP format + return CallToolResult(content=to_mcp_tool_result(result)) + + async def list_prompts(self, request: ListPromptsRequest) -> ListPromptsResult: + """Handle MCP requests to list available prompts. + + Args: + request: The MCP ListPromptsRequest. + + Returns: + ListPromptsResult containing all registered Genkit prompts. 
+ """ + await self.setup() + + prompts: list[Prompt] = [] + for action in self.prompt_actions: + # Convert input schema to MCP prompt arguments + input_schema = to_json_schema(action.input_schema) if action.input_schema else None + arguments = to_mcp_prompt_arguments(input_schema) if input_schema else None + + prompts.append( + Prompt( + name=action.name, + description=action.description or '', + arguments=arguments, + _meta=action.metadata.get('mcp', {}).get('_meta') if action.metadata else None, + ) + ) + + return ListPromptsResult(prompts=prompts) + + async def get_prompt(self, request: GetPromptRequest) -> GetPromptResult: + """Handle MCP requests to get (render) a specific prompt. + + Args: + request: The MCP GetPromptRequest containing prompt name and arguments. + + Returns: + GetPromptResult with the rendered prompt messages. + + Raises: + GenkitError: If the requested prompt is not found. + """ + await self.setup() + + # Find the prompt action + prompt = self.prompt_actions_map.get(request.params.name) + + if not prompt: + raise GenkitError( + status='NOT_FOUND', + message=f"[MCP Server] Tried to call prompt '{request.params.name}' but it could not be found.", + ) + + # Execute the prompt + result = await prompt.arun(request.params.arguments) + result = result.response + + # Convert messages to MCP format + messages = [to_mcp_prompt_message(msg) for msg in result.messages] + + return GetPromptResult(description=prompt.description, messages=messages) + + async def list_resources(self, request: ListResourcesRequest) -> ListResourcesResult: + """Handle MCP requests to list available resources with fixed URIs. + + Args: + request: The MCP ListResourcesRequest. + + Returns: + ListResourcesResult containing resources with fixed URIs. 
+ """ + await self.setup() + + resources: list[Resource] = [] + for action in self.resource_actions: + metadata = action.metadata or {} + resource_meta = metadata.get('resource', {}) + + # Only include resources with fixed URIs (not templates) + if resource_meta.get('uri'): + resources.append( + Resource( + name=action.name, + description=action.description or '', + uri=resource_meta['uri'], + _meta=metadata.get('mcp', {}).get('_meta'), + ) + ) + + return ListResourcesResult(resources=resources) + + async def list_resource_templates(self, request: ListResourceTemplatesRequest) -> ListResourceTemplatesResult: + """Handle MCP requests to list available resource templates. + + Args: + request: The MCP ListResourceTemplatesRequest. + + Returns: + ListResourceTemplatesResult containing resources with URI templates. + """ + await self.setup() + + templates: list[ResourceTemplate] = [] + for action in self.resource_actions: + metadata = action.metadata or {} + resource_meta = metadata.get('resource', {}) + + # Only include resources with templates + if resource_meta.get('template'): + templates.append( + ResourceTemplate( + name=action.name, + description=action.description or '', + uriTemplate=resource_meta['template'], + _meta=metadata.get('mcp', {}).get('_meta'), + ) + ) + + return ListResourceTemplatesResult(resourceTemplates=templates) + + async def read_resource(self, request: ReadResourceRequest) -> ReadResourceResult: + """Handle MCP requests to read a specific resource. + + Args: + request: The MCP ReadResourceRequest containing the resource URI. + + Returns: + ReadResourceResult with the resource content. + + Raises: + GenkitError: If no matching resource is found. 
+ """ + await self.setup() + + uri = request.params.uri + + # Check for exact URI match + resource = self.resource_uri_map.get(uri) + + # Check for template match if not found by exact URI + if not resource: + for template, action in self.resource_templates: + if matches_uri_template(template, uri): + resource = action + break + + if not resource: + raise GenkitError(status='NOT_FOUND', message=f"Tried to call resource '{uri}' but it could not be found.") + + # Execute the resource action + result = await resource.arun({'uri': uri}) + result = result.response + + # Convert content to MCP format + content = result.get('content', []) if isinstance(result, dict) else result.content + contents = to_mcp_resource_contents(uri, content) + + return ReadResourceResult(contents=contents) + + async def start(self, transport: Any = None) -> None: + """Start the MCP server with the specified transport. + + Args: + transport: Optional MCP transport instance. If not provided, + a StdioServerTransport will be created and used. + """ + await self.setup() + + if not transport: + async with stdio_server() as (read, write): + await self.server.run(read, write, self.server.create_initialization_options()) + else: + # Connect the transport + async with transport as (read, write): + await self.server.run(read, write, self.server.create_initialization_options()) + + logger.debug(f"[MCP Server] MCP server '{self.options.name}' started successfully.") + + +# Schema types from mcp.types +ListToolsRequestSchema = ListToolsRequest +CallToolRequestSchema = CallToolRequest +ListPromptsRequestSchema = ListPromptsRequest +GetPromptRequestSchema = GetPromptRequest +ListResourcesRequestSchema = ListResourcesRequest +ListResourceTemplatesRequestSchema = ListResourceTemplatesRequest +ReadResourceRequestSchema = ReadResourceRequest + + +def create_mcp_server(ai: Genkit, options: McpServerOptions) -> McpServer: + """Create an MCP server based on the supplied Genkit instance. 
+
+    All tools, prompts, and resources will be automatically converted to MCP compatibility.
+
+    Args:
+        ai: Your Genkit instance with registered tools, prompts, and resources.
+        options: Configuration metadata for the server.
+
+    Returns:
+        McpServer instance.
+
+    Example:
+        ```python
+        from genkit.ai import Genkit
+        from genkit.plugins.mcp import create_mcp_server, McpServerOptions
+
+        ai = Genkit()
+
+
+        # Define some tools and resources
+        @ai.tool()
+        def add(a: int, b: int) -> int:
+            return a + b
+
+
+        ai.define_resource(name='my_resource', uri='my://resource', fn=lambda req: {'content': [{'text': 'resource content'}]})
+
+        # Create and start MCP server
+        server = create_mcp_server(ai, McpServerOptions(name='my-server'))
+        await server.start()
+        ```
+    """
+    return McpServer(ai, options)
diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/__init__.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/__init__.py
new file mode 100644
index 0000000000..a45dae463f
--- /dev/null
+++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/__init__.py
@@ -0,0 +1,58 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Utility functions for MCP plugin.
+ +This module contains helper functions for: +- Tool conversion and registration +- Prompt conversion and rendering +- Resource handling +- Message mapping between Genkit and MCP formats +- Transport utilities +""" + +from .message import from_mcp_part, from_mcp_prompt_message, to_mcp_prompt_message +from .prompts import convert_mcp_prompt_messages, convert_prompt_arguments_to_schema, to_mcp_prompt_arguments, to_schema +from .resource import ( + convert_resource_to_genkit_part, + from_mcp_resource_part, + process_resource_content, + to_mcp_resource_contents, +) +from .tools import convert_tool_schema, process_result, process_tool_result, to_mcp_tool_result, to_text +from .transport import create_stdio_params, transport_from + +__all__ = [ + 'process_tool_result', + 'process_result', + 'to_text', + 'convert_tool_schema', + 'convert_prompt_arguments_to_schema', + 'convert_mcp_prompt_messages', + 'to_schema', + 'from_mcp_prompt_message', + 'from_mcp_part', + 'process_resource_content', + 'convert_resource_to_genkit_part', + 'from_mcp_resource_part', + 'create_stdio_params', + 'transport_from', + 'to_mcp_prompt_message', + 'to_mcp_resource_contents', + 'to_mcp_tool_result', + 'to_mcp_prompt_arguments', +] diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/message.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/message.py new file mode 100644 index 0000000000..97de8a4e90 --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/message.py @@ -0,0 +1,169 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Message utilities for MCP plugin. + +This module contains helper functions for converting between MCP message +formats and Genkit message formats. +""" + +from typing import Any, Dict + +import structlog + +from genkit.core.typing import Message +from mcp.types import ImageContent, PromptMessage, TextContent + +logger = structlog.get_logger(__name__) + +# Role mapping from MCP to Genkit +ROLE_MAP = { + 'user': 'user', + 'assistant': 'model', +} + + +def from_mcp_prompt_message(message: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert MCP PromptMessage to Genkit MessageData format. + + This involves mapping MCP roles (user, assistant) to Genkit roles (user, model) + and transforming the MCP content part into a Genkit Part. + + Args: + message: MCP PromptMessage with 'role' and 'content' fields + + Returns: + Genkit MessageData object with 'role' and 'content' fields + """ + return { + 'role': ROLE_MAP.get(message.get('role', 'user'), 'user'), + 'content': [from_mcp_part(message.get('content', {}))], + } + + +def from_mcp_part(part: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert MCP message content part to Genkit Part. 
+ + Handles different content types: + - Text parts are directly mapped + - Image parts are converted to Genkit media parts with data URL + - Resource parts are mapped to Genkit resource format + + Args: + part: MCP PromptMessage content part + + Returns: + Genkit Part object + """ + part_type = part.get('type', '') + + if part_type == 'text': + return {'text': part.get('text', '')} + + elif part_type == 'image': + mime_type = part.get('mimeType', 'image/png') + data = part.get('data', '') + return { + 'media': { + 'contentType': mime_type, + 'url': f'data:{mime_type};base64,{data}', + } + } + + elif part_type == 'resource': + return { + 'resource': { + 'uri': str(part.get('uri', '')), + } + } + + # Default case for unknown types + return {} + + +def _get_part_data(part: Any) -> Dict[str, Any]: + """Extract data from a Part, handling potential 'root' nesting.""" + if isinstance(part, str): + return {'text': part} + part_dict = part if isinstance(part, dict) else part.model_dump() + if 'root' in part_dict and isinstance(part_dict['root'], dict): + return part_dict['root'] + return part_dict + + +def _parse_media_part(media: Dict[str, Any]) -> ImageContent: + """Extract MIME type and base64 data from a media part.""" + url = media.get('url', '') + content_type = media.get('contentType', '') + + if not url.startswith('data:'): + raise ValueError('MCP prompt messages only support base64 data images.') + + # Extract MIME type and base64 data + try: + mime_type = content_type or url[url.index(':') + 1 : url.index(';')] + data = url[url.index(',') + 1 :] + except ValueError as e: + raise ValueError(f'Invalid data URL format: {url}') from e + + return ImageContent(type='image', data=data, mimeType=mime_type) + + +def to_mcp_prompt_message(message: Message) -> PromptMessage: + """Convert a Genkit Message to an MCP PromptMessage. + + MCP only supports 'user' and 'assistant' roles. Genkit's 'model' role + is mapped to 'assistant'. 
+ + Args: + message: The Genkit Message to convert. + + Returns: + An MCP PromptMessage. + + Raises: + ValueError: If the message role is not 'user' or 'model'. + ValueError: If media is not a base64 data URL. + """ + # Map Genkit roles to MCP roles + role_map = {'model': 'assistant', 'user': 'user'} + + if message.role not in role_map: + raise ValueError( + f"MCP prompt messages do not support role '{message.role}'. Only 'user' and 'model' messages are supported." + ) + + mcp_role = role_map[message.role] + + # First, look for any media content as MCP content is currently single-part + if message.content: + for part in message.content: + data = _get_part_data(part) + if data.get('media'): + return PromptMessage(role=mcp_role, content=_parse_media_part(data['media'])) + + # If no media, aggregate all text content + text_content = [] + if message.content: + for part in message.content: + data = _get_part_data(part) + if data.get('text'): + text_content.append(data['text']) + + return PromptMessage(role=mcp_role, content=TextContent(type='text', text=''.join(text_content))) diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/prompts.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/prompts.py new file mode 100644 index 0000000000..469e91f7ea --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/prompts.py @@ -0,0 +1,137 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +""" +Prompt utilities for MCP plugin. + +This module contains helper functions for converting between MCP prompts +and Genkit prompts, including schema and message conversion. +""" + +from typing import Any, Dict, List, Optional + +import structlog + +from mcp.types import GetPromptResult, Prompt + +logger = structlog.get_logger(__name__) + + +def to_schema(arguments: Optional[List[Dict[str, Any]]]) -> Dict[str, Any]: + """ + Convert MCP prompt arguments to JSON schema format. + + Args: + arguments: List of MCP prompt argument definitions with 'name', + 'description', and 'required' fields + + Returns: + JSON schema representing the prompt arguments + """ + if not arguments: + return {} + + schema: Dict[str, Any] = {'type': 'object', 'properties': {}, 'required': []} + + for arg in arguments: + arg_name = arg.get('name', '') + schema['properties'][arg_name] = { + 'type': 'string', + 'description': arg.get('description', ''), + } + if arg.get('required', False): + schema['required'].append(arg_name) + + return schema + + +def convert_prompt_arguments_to_schema(arguments: List[Any]) -> Dict[str, Any]: + """ + Convert MCP prompt arguments to JSON schema format. + + This is an alias for to_schema() for backwards compatibility. + + Args: + arguments: List of MCP prompt argument definitions + + Returns: + JSON schema representing the prompt arguments + """ + return to_schema(arguments) + + +def convert_mcp_prompt_messages(prompt_result: GetPromptResult) -> List[Dict[str, Any]]: + """ + Convert MCP prompt messages to Genkit message format. 
+
+    Args:
+        prompt_result: The GetPromptResult from MCP server containing messages
+
+    Returns:
+        List of Genkit-formatted messages
+    """
+    from .message import from_mcp_prompt_message
+
+    if not hasattr(prompt_result, 'messages') or not prompt_result.messages:
+        return []
+
+    # GetPromptResult.messages holds PromptMessage model objects, while
+    # from_mcp_prompt_message() reads its input with dict .get() access;
+    # dump models to plain dicts before converting.
+    return [
+        from_mcp_prompt_message(msg.model_dump() if hasattr(msg, 'model_dump') else msg)
+        for msg in prompt_result.messages
+    ]
+
+
+def to_mcp_prompt_arguments(input_schema: dict[str, Any] | None) -> list[dict[str, Any]] | None:
+    """Convert Genkit input schema to MCP prompt arguments.
+
+    MCP prompts only support string arguments. This function validates that
+    all properties in the schema are strings.
+
+    Args:
+        input_schema: The Genkit input JSON schema.
+
+    Returns:
+        List of MCP prompt argument definitions, or None if no schema.
+
+    Raises:
+        ValueError: If the schema is not an object type.
+        ValueError: If any property is not a string type.
+    """
+    if not input_schema:
+        return None
+
+    if not input_schema.get('properties'):
+        raise ValueError('MCP prompts must take objects with properties as input schema.')
+
+    args: list[dict[str, Any]] = []
+    properties = input_schema['properties']
+    required = input_schema.get('required', [])
+
+    for name, prop in properties.items():
+        prop_type = prop.get('type')
+
+        # Check if type is string or includes string (for union types)
+        is_string = prop_type == 'string' or (isinstance(prop_type, list) and 'string' in prop_type)
+
+        if not is_string:
+            raise ValueError(
+                f"MCP prompts may only take string arguments, but property '{name}' has type '{prop_type}'."
+ ) + + args.append({ + 'name': name, + 'description': prop.get('description'), + 'required': name in required, + }) + + return args diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py new file mode 100644 index 0000000000..3015d609da --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/resource.py @@ -0,0 +1,149 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Resource utilities for MCP plugin. + +This module contains helper functions for handling MCP resources, +including reading and converting resource content. +""" + +from typing import Any, Dict + +import structlog + +from genkit.core.typing import Part +from mcp.types import BlobResourceContents, ReadResourceResult, Resource, TextResourceContents + +logger = structlog.get_logger(__name__) + + +def from_mcp_resource_part(content: Dict[str, Any]) -> Dict[str, Any]: + """ + Convert MCP resource content to Genkit Part format. 
+ + Handles different content types: + - Text content is mapped to text part + - Blob content is mapped to media part with base64 data + + Args: + content: MCP resource content part + + Returns: + Genkit Part representation + """ + content_type = content.get('type', '') + + if content_type == 'text': + return {'text': content.get('text', '')} + + elif content_type == 'blob': + mime_type = content.get('mimeType', 'application/octet-stream') + blob_data = content.get('blob', '') + return { + 'media': { + 'contentType': mime_type, + 'url': f'data:{mime_type};base64,{blob_data}', + } + } + + # Default case + return {'text': str(content)} + + +def process_resource_content(resource_result: ReadResourceResult) -> Any: + """ + Process MCP ReadResourceResult and extract content. + + Args: + resource_result: The ReadResourceResult from MCP server + + Returns: + Extracted resource content as Genkit Parts + """ + if not hasattr(resource_result, 'contents') or not resource_result.contents: + return [] + + return [from_mcp_resource_part(content) for content in resource_result.contents] + + +def convert_resource_to_genkit_part(resource: Resource) -> dict[str, Any]: + """ + Convert MCP resource to Genkit Part format. + + Args: + resource: MCP resource object + + Returns: + Genkit Part representation with resource URI + """ + return { + 'resource': { + 'uri': resource.uri, + 'name': resource.name, + 'description': resource.description if hasattr(resource, 'description') else None, + } + } + + +def to_mcp_resource_contents(uri: str, parts: list[Part]) -> list[TextResourceContents | BlobResourceContents]: + """Convert Genkit Parts to MCP resource contents. + + Args: + uri: The URI of the resource. + parts: List of Genkit Parts to convert. + + Returns: + List of MCP resource contents (text or blob). + + Raises: + ValueError: If media is not a base64 data URL. + ValueError: If part type is not supported. 
+ """ + contents: list[TextResourceContents | BlobResourceContents] = [] + + for part in parts: + if isinstance(part, dict): + # Handle media/image content + if 'media' in part: + media = part['media'] + url = media.get('url', '') + content_type = media.get('contentType', '') + + if not url.startswith('data:'): + raise ValueError('MCP resource messages only support base64 data images.') + + # Extract MIME type and base64 data + try: + mime_type = content_type or url[url.index(':') + 1 : url.index(';')] + blob_data = url[url.index(',') + 1 :] + except ValueError as e: + raise ValueError(f'Invalid data URL format: {url}') from e + + contents.append(BlobResourceContents(uri=uri, mimeType=mime_type, blob=blob_data)) + + # Handle text content + elif 'text' in part: + contents.append(TextResourceContents(uri=uri, text=part['text'])) + else: + raise ValueError( + f'MCP resource messages only support media and text parts. ' + f'Unsupported part type: {list(part.keys())}' + ) + elif isinstance(part, str): + contents.append(TextResourceContents(uri=uri, text=part)) + + return contents diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/tools.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/tools.py new file mode 100644 index 0000000000..5d2662c021 --- /dev/null +++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/tools.py @@ -0,0 +1,144 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Tool utilities for MCP plugin.
+
+This module contains helper functions for converting between MCP tools
+and Genkit actions, processing tool results, and registering tools.
+"""
+
+import json
+from typing import Any, Dict, List, Union
+
+import structlog
+
+from mcp.types import CallToolResult, ImageContent, TextContent, Tool
+
+logger = structlog.get_logger(__name__)
+
+
+def to_text(content: List[Any]) -> str:
+    """
+    Extract text from MCP CallToolResult content.
+
+    Args:
+        content: List of content parts (plain dicts or MCP content model
+            objects such as TextContent) from CallToolResult
+
+    Returns:
+        Concatenated text from all text parts
+    """
+    # CallToolResult.content holds model objects with a .text attribute;
+    # support plain dicts as well for callers that pass raw part data.
+    return ''.join(
+        (part.get('text', '') if isinstance(part, dict) else getattr(part, 'text', '') or '') for part in content
+    )
+
+
+def process_result(result: CallToolResult) -> Any:
+    """
+    Process MCP CallToolResult and extract/parse content.
+
+    Handles different result types:
+    - Error results return error dict
+    - Text-only results attempt JSON parsing
+    - Single content results return the content directly
+    - Otherwise returns the full result
+
+    Args:
+        result: The CallToolResult from MCP server
+
+    Returns:
+        Processed result (parsed JSON, text, or raw content)
+
+    Note:
+        Error results (isError=True) are returned as an {'error': text}
+        dict rather than raised.
+    """
+    if result.isError:
+        return {'error': to_text(result.content)}
+
+    # Check if all content parts are text
+    if all(hasattr(c, 'text') and c.text for c in result.content):
+        text = to_text(result.content)
+        # Try to parse as JSON if it looks like JSON
+        text_stripped = text.strip()
+        if text_stripped.startswith('{') or text_stripped.startswith('['):
+            try:
+                return json.loads(text)
+            except (json.JSONDecodeError, ValueError):
+                return text
+        return text
+
+    # Single content item
+    if len(result.content) == 1:
+        return result.content[0]
+
+    # Return full result for complex cases
+    return result
+
+
+def process_tool_result(result: CallToolResult) -> Any:
+    """
+    Process MCP CallToolResult and extract content.
+
+    This is an alias for process_result() for backwards compatibility.
+
+    Args:
+        result: The CallToolResult from MCP server
+
+    Returns:
+        The processed result from process_result() (parsed JSON, text,
+        or raw content)
+
+    Note:
+        Error results are returned as an {'error': text} dict by
+        process_result() rather than raised.
+    """
+    return process_result(result)
+
+
+def convert_tool_schema(mcp_schema: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Convert MCP tool input schema (JSONSchema7) to Genkit format.
+
+    Args:
+        mcp_schema: MCP tool input schema
+
+    Returns:
+        Genkit-compatible JSON schema
+
+    Note:
+        Currently returns the schema as-is since both use JSON Schema.
+        Future enhancements may add validation or transformation.
+    """
+    # MCP and Genkit both use JSON Schema, so minimal conversion needed
+    return mcp_schema
+
+
+def to_mcp_tool_result(result: Any) -> list[TextContent | ImageContent]:
+    """Convert tool execution result to MCP CallToolResult content.
+
+    Args:
+        result: The result from tool execution (can be string, dict, or other).
+
+    Returns:
+        List of MCP content items (TextContent or ImageContent).
+    """
+    if isinstance(result, str):
+        return [TextContent(type='text', text=result)]
+    elif isinstance(result, dict):
+        # If it's already in MCP format, return as-is
+        if 'type' in result and 'text' in result:
+            return [TextContent(type='text', text=result['text'])]
+        # Otherwise, serialize to JSON
+        return [TextContent(type='text', text=json.dumps(result))]
+    else:
+        # Convert to string for other types
+        return [TextContent(type='text', text=str(result))]
diff --git a/py/plugins/mcp/src/genkit/plugins/mcp/util/transport.py b/py/plugins/mcp/src/genkit/plugins/mcp/util/transport.py
new file mode 100644
index 0000000000..c065cd5d01
--- /dev/null
+++ b/py/plugins/mcp/src/genkit/plugins/mcp/util/transport.py
@@ -0,0 +1,89 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Transport utilities for MCP plugin. + +This module contains helper functions for creating and managing +MCP transport connections (stdio, SSE, custom). +""" + +from typing import Any, Dict, Optional, Tuple + +import structlog + +from mcp import StdioServerParameters + +logger = structlog.get_logger(__name__) + + +def create_stdio_params( + command: str, args: Optional[list] = None, env: Optional[Dict[str, str]] = None +) -> StdioServerParameters: + """ + Create StdioServerParameters for MCP connection. + + Args: + command: Command to execute + args: Command arguments + env: Environment variables + + Returns: + StdioServerParameters object + """ + return StdioServerParameters(command=command, args=args or [], env=env) + + +async def transport_from(config: Dict[str, Any], session_id: Optional[str] = None) -> Tuple[Any, str]: + """ + Create an MCP transport instance based on the provided server configuration. + + Supports creating SSE, Stdio, or using a pre-configured custom transport. + + Args: + config: Configuration for the MCP server + session_id: Optional session ID for HTTP transport + + Returns: + Tuple of (transport instance or None, transport type string) + + Note: + This function mirrors the JS SDK's transportFrom() function. 
+ """ + # Handle pre-configured transport first + if 'transport' in config and config['transport']: + return (config['transport'], 'custom') + + # Handle SSE/HTTP config + if 'url' in config and config['url']: + try: + # Dynamic import to avoid hard dependency + from mcp.client.sse import sse_client + + # Note: Python MCP SDK may have different SSE client API + # This is a placeholder that matches the pattern + logger.info(f'Creating SSE transport for URL: {config["url"]}') + return (config['url'], 'http') # Simplified for now + except ImportError: + logger.warning('SSE client not available') + return (None, 'http') + + # Handle Stdio config + if 'command' in config and config['command']: + stdio_params = create_stdio_params(command=config['command'], args=config.get('args'), env=config.get('env')) + return (stdio_params, 'stdio') + + return (None, 'unknown') diff --git a/py/plugins/mcp/tests/fakes.py b/py/plugins/mcp/tests/fakes.py new file mode 100644 index 0000000000..356337f118 --- /dev/null +++ b/py/plugins/mcp/tests/fakes.py @@ -0,0 +1,128 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import asyncio +import json +import sys +from typing import Any, Callable, Dict, List, Optional +from unittest.mock import MagicMock + +from genkit.ai import Genkit +from genkit.core.action.types import ActionKind + + +class MockSchema: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +def mock_mcp_modules(): + """Sets up comprehensive MCP mocks in sys.modules.""" + mock_mcp = MagicMock() + sys.modules['mcp'] = mock_mcp + sys.modules['mcp'].__path__ = [] + + types_mock = MagicMock() + sys.modules['mcp.types'] = types_mock + types_mock.ListToolsResult = MockSchema + types_mock.CallToolResult = MockSchema + types_mock.ListPromptsResult = MockSchema + types_mock.GetPromptResult = MockSchema + types_mock.ListResourcesResult = MockSchema + types_mock.ListResourceTemplatesResult = MockSchema + types_mock.ReadResourceResult = MockSchema + types_mock.Tool = MockSchema + types_mock.Prompt = MockSchema + types_mock.Resource = MockSchema + types_mock.ResourceTemplate = MockSchema + types_mock.TextContent = MockSchema + types_mock.PromptMessage = MockSchema + types_mock.TextResourceContents = MockSchema + types_mock.BlobResourceContents = MockSchema + types_mock.ImageContent = MockSchema + + sys.modules['mcp.server'] = MagicMock() + sys.modules['mcp.server.stdio'] = MagicMock() + sys.modules['mcp.client'] = MagicMock() + sys.modules['mcp.client'].__path__ = [] + sys.modules['mcp.client.stdio'] = MagicMock() + sys.modules['mcp.client.sse'] = MagicMock() + sys.modules['mcp.server.sse'] = MagicMock() + + return mock_mcp, types_mock + + +def define_echo_model(ai: Genkit): + """Defines a fake echo model for testing.""" + + @ai.tool(name='echoModel') + def echo_model(request: Any): + # This is a simplified mock of a model action + # Real model action would handle GenerateRequest and return GenerateResponse + + # logic to echo content + # For now, just a placeholder as we generally mock the model 
execution in tests + pass + + # In real usage, we would define a Model action properly. + # For unit tests here, we might not strictly need the full model implementation + # if we are mocking the generation or call. + # But matching JS behavior: + # JS defines 'echoModel' which returns "Echo: " + input. + + # We can use ai.define_model if available or just mock it. + pass + + +class FakeTransport: + """Fakes an MCP transport/server for testing.""" + + def __init__(self): + self.tools = [] + self.prompts = [] + self.resources = [] + self.resource_templates = [] + self.call_tool_result = None + self.get_prompt_result = None + self.read_resource_result = None + self.roots = [] + + # Callbacks that would simulate transport behavior + self.on_message = None + self.on_close = None + self.on_error = None + + async def start(self): + pass + + async def send(self, message: Dict[str, Any]): + """Handle incoming JSON-RPC message (simulating server).""" + request = message + # msg_id = request.get("id") + + # In a real transport we'd write back to the stream. + # Here we just store handling logic or print. + # Since we are mocking the ClientSession in our python tests, + # this logic might need to be hooked up to the mock session's methods. 
+ pass + + # Helper methods to populate the fake state + def add_tool(self, name: str, description: str = '', schema: Dict = None): + self.tools.append({'name': name, 'description': description, 'inputSchema': schema or {'type': 'object'}}) + + def add_prompt(self, name: str, description: str = '', arguments: List = None): + self.prompts.append({'name': name, 'description': description, 'arguments': arguments or []}) diff --git a/py/plugins/mcp/tests/test_mcp_conversion.py b/py/plugins/mcp/tests/test_mcp_conversion.py new file mode 100644 index 0000000000..926f94b69a --- /dev/null +++ b/py/plugins/mcp/tests/test_mcp_conversion.py @@ -0,0 +1,259 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for MCP conversion utilities.""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) +from fakes import mock_mcp_modules + +mock_mcp_modules() + +from genkit.core.typing import Message +from genkit.plugins.mcp.util import ( + to_mcp_prompt_arguments, + to_mcp_prompt_message, + to_mcp_resource_contents, + to_mcp_tool_result, +) + + +class TestMessageConversion(unittest.TestCase): + """Tests for message conversion utilities.""" + + def test_convert_user_message(self): + """Test converting a user message.""" + message = Message(role='user', content=[{'text': 'Hello, world!'}]) + + result = to_mcp_prompt_message(message) + + self.assertEqual(result.role, 'user') + self.assertEqual(result.content.type, 'text') + self.assertEqual(result.content.text, 'Hello, world!') + + def test_convert_model_message(self): + """Test converting a model message (maps to assistant).""" + message = Message(role='model', content=[{'text': 'Hi there!'}]) + + result = to_mcp_prompt_message(message) + + self.assertEqual(result.role, 'assistant') + self.assertEqual(result.content.type, 'text') + self.assertEqual(result.content.text, 'Hi there!') + + def test_convert_message_with_multiple_text_parts(self): + """Test converting a message with multiple text parts.""" + message = Message(role='user', content=[{'text': 'Part 1 '}, {'text': 'Part 2 '}, {'text': 'Part 3'}]) + + result = to_mcp_prompt_message(message) + + self.assertEqual(result.content.text, 'Part 1 Part 2 Part 3') + + def test_convert_message_with_invalid_role(self): + """Test that converting a message with invalid role raises error.""" + message = Message(role='system', content=[{'text': 'System message'}]) + + with self.assertRaises(ValueError) as context: + to_mcp_prompt_message(message) + + self.assertIn('system', str(context.exception).lower()) + + 
def test_convert_message_with_image(self): + """Test converting a message with image content.""" + message = Message( + role='user', content=[{'media': {'url': 'data:image/png;base64,iVBORw0KG...', 'contentType': 'image/png'}}] + ) + + result = to_mcp_prompt_message(message) + + self.assertEqual(result.role, 'user') + self.assertEqual(result.content.type, 'image') + self.assertEqual(result.content.mimeType, 'image/png') + + def test_convert_message_with_non_data_url_fails(self): + """Test that non-data URLs raise an error.""" + message = Message(role='user', content=[{'media': {'url': 'http://example.com/image.png'}}]) + + with self.assertRaises(ValueError) as context: + to_mcp_prompt_message(message) + + self.assertIn('base64', str(context.exception).lower()) + + +class TestResourceConversion(unittest.TestCase): + """Tests for resource content conversion.""" + + def test_convert_text_resource(self): + """Test converting text resource content.""" + parts = [{'text': 'Resource content'}] + + result = to_mcp_resource_contents('test://resource', parts) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].uri, 'test://resource') + self.assertEqual(result[0].text, 'Resource content') + + def test_convert_multiple_text_parts(self): + """Test converting multiple text parts.""" + parts = [{'text': 'Part 1'}, {'text': 'Part 2'}, {'text': 'Part 3'}] + + result = to_mcp_resource_contents('test://resource', parts) + + self.assertEqual(len(result), 3) + for i, part in enumerate(result, 1): + self.assertEqual(part.text, f'Part {i}') + + def test_convert_string_parts(self): + """Test converting string parts.""" + parts = ['Text 1', 'Text 2'] + + result = to_mcp_resource_contents('test://resource', parts) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].text, 'Text 1') + self.assertEqual(result[1].text, 'Text 2') + + def test_convert_media_resource(self): + """Test converting media resource content.""" + parts = [{'media': {'url': 
'data:image/png;base64,abc123', 'contentType': 'image/png'}}] + + result = to_mcp_resource_contents('test://image', parts) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].uri, 'test://image') + self.assertEqual(result[0].mimeType, 'image/png') + self.assertEqual(result[0].blob, 'abc123') + + def test_convert_mixed_content(self): + """Test converting mixed text and media content.""" + parts = [{'text': 'Description'}, {'media': {'url': 'data:image/png;base64,xyz', 'contentType': 'image/png'}}] + + result = to_mcp_resource_contents('test://mixed', parts) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].text, 'Description') + self.assertEqual(result[1].blob, 'xyz') + + +class TestToolResultConversion(unittest.TestCase): + """Tests for tool result conversion.""" + + def test_convert_string_result(self): + """Test converting string result.""" + result = to_mcp_tool_result('Hello, world!') + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].type, 'text') + self.assertEqual(result[0].text, 'Hello, world!') + + def test_convert_dict_result(self): + """Test converting dict result.""" + result = to_mcp_tool_result({'key': 'value', 'number': 42}) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].type, 'text') + # Should be JSON serialized + import json + + parsed = json.loads(result[0].text) + self.assertEqual(parsed['key'], 'value') + self.assertEqual(parsed['number'], 42) + + def test_convert_number_result(self): + """Test converting number result.""" + result = to_mcp_tool_result(42) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].text, '42') + + def test_convert_boolean_result(self): + """Test converting boolean result.""" + result = to_mcp_tool_result(True) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].text, 'True') + + +class TestSchemaConversion(unittest.TestCase): + """Tests for schema conversion utilities.""" + + def test_convert_simple_schema(self): + """Test 
converting simple string schema.""" + schema = {'type': 'object', 'properties': {'name': {'type': 'string', 'description': 'User name'}}} + + result = to_mcp_prompt_arguments(schema) + + self.assertIsNotNone(result) + self.assertEqual(len(result), 1) + self.assertEqual(result[0]['name'], 'name') + self.assertEqual(result[0]['description'], 'User name') + + def test_convert_schema_with_required(self): + """Test converting schema with required fields.""" + schema = { + 'type': 'object', + 'properties': {'name': {'type': 'string'}, 'age': {'type': 'string'}}, + 'required': ['name'], + } + + result = to_mcp_prompt_arguments(schema) + + name_arg = next(arg for arg in result if arg['name'] == 'name') + age_arg = next(arg for arg in result if arg['name'] == 'age') + + self.assertTrue(name_arg['required']) + self.assertFalse(age_arg['required']) + + def test_convert_schema_with_non_string_fails(self): + """Test that non-string properties raise an error.""" + schema = {'type': 'object', 'properties': {'count': {'type': 'number'}}} + + with self.assertRaises(ValueError) as context: + to_mcp_prompt_arguments(schema) + + self.assertIn('string', str(context.exception).lower()) + + def test_convert_schema_with_union_type(self): + """Test converting schema with union type including string.""" + schema = {'type': 'object', 'properties': {'value': {'type': ['string', 'null']}}} + + result = to_mcp_prompt_arguments(schema) + + # Should succeed because string is in the union + self.assertEqual(len(result), 1) + + def test_convert_none_schema(self): + """Test converting None schema.""" + result = to_mcp_prompt_arguments(None) + + self.assertIsNone(result) + + def test_convert_schema_without_properties_fails(self): + """Test that schema without properties raises an error.""" + schema = {'type': 'object'} + + with self.assertRaises(ValueError) as context: + to_mcp_prompt_arguments(schema) + + self.assertIn('properties', str(context.exception).lower()) + + +if __name__ == '__main__': + 
unittest.main() diff --git a/py/plugins/mcp/tests/test_mcp_host.py b/py/plugins/mcp/tests/test_mcp_host.py new file mode 100644 index 0000000000..10d995b7d7 --- /dev/null +++ b/py/plugins/mcp/tests/test_mcp_host.py @@ -0,0 +1,64 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +from unittest.mock import AsyncMock, MagicMock + +sys.path.insert(0, os.path.dirname(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) +from fakes import mock_mcp_modules + +mock_mcp_modules() + +import unittest +from unittest.mock import patch + +from genkit.ai import Genkit +from genkit.core.action.types import ActionKind + +# Now import plugin +from genkit.plugins.mcp import McpClient, McpHost, McpServerConfig, create_mcp_host + + +class TestMcpHost(unittest.IsolatedAsyncioTestCase): + async def test_connect_and_register(self): + # Setup configs + config1 = McpServerConfig(command='echo') + config2 = McpServerConfig(url='http://localhost:8000') + + host = create_mcp_host({'server1': config1, 'server2': config2}) + + # Mock clients within host + with patch('genkit.plugins.mcp.client.client.McpClient.connect', new_callable=AsyncMock) as mock_connect: + await host.start() + self.assertEqual(mock_connect.call_count, 2) + + # Mock session for registration + host.clients['server1'].session = AsyncMock() + mock_tool = MagicMock() + mock_tool.name = 
'tool1' + host.clients['server1'].session.list_tools.return_value.tools = [mock_tool] + + ai = MagicMock(spec=Genkit) + ai.registry = MagicMock() + + await host.register_tools(ai) + + # Verify tool registration + ai.registry.register_action.assert_called() + call_args = ai.registry.register_action.call_args[1] + self.assertIn('server1/tool1', call_args['name']) diff --git a/py/plugins/mcp/tests/test_mcp_integration.py b/py/plugins/mcp/tests/test_mcp_integration.py new file mode 100644 index 0000000000..d6045734e0 --- /dev/null +++ b/py/plugins/mcp/tests/test_mcp_integration.py @@ -0,0 +1,311 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Integration tests for MCP client-server communication.""" + +import asyncio +import os +import sys +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +sys.path.insert(0, os.path.dirname(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) +from fakes import mock_mcp_modules + +mock_mcp_modules() + +import pytest + +from genkit.ai import Genkit +from genkit.plugins.mcp import McpClient, McpHost, McpServerConfig, create_mcp_host, create_mcp_server + + +@pytest.mark.asyncio +class TestClientServerIntegration(unittest.IsolatedAsyncioTestCase): + """Integration tests for MCP client-server communication.""" + + async def test_client_can_list_server_tools(self): + """Test that a client can list tools from a server.""" + # Create server with tools + server_ai = Genkit() + + @server_ai.tool() + def add(a: int, b: int) -> int: + return a + b + + # Create client + client = McpClient(name='test-client', config=McpServerConfig(command='echo', args=['test'])) + + # Mock the session to return tools + mock_session = AsyncMock() + mock_tool = MagicMock() + mock_tool.name = 'add' + mock_tool.description = 'Add two numbers' + mock_tool.inputSchema = {'type': 'object'} + + mock_session.list_tools.return_value.tools = [mock_tool] + client.session = mock_session + + # List tools + tools = await client.list_tools() + + # Verify + self.assertEqual(len(tools), 1) + self.assertEqual(tools[0].name, 'add') + + async def test_client_can_call_server_tool(self): + """Test that a client can call a tool on a server.""" + # Create client + client = McpClient(name='test-client', config=McpServerConfig(command='echo')) + + # Mock the session + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.isError = False + mock_content = MagicMock() + mock_content.type = 'text' + mock_content.text = '8' + mock_result.content = [mock_content] + + 
mock_session.call_tool.return_value = mock_result + client.session = mock_session + + # Call tool + result = await client.call_tool('add', {'a': 5, 'b': 3}) + + # Verify + self.assertEqual(result, '8') + mock_session.call_tool.assert_called_once_with('add', {'a': 5, 'b': 3}) + + async def test_client_can_list_server_resources(self): + """Test that a client can list resources from a server.""" + # Create client + client = McpClient(name='test-client', config=McpServerConfig(command='echo')) + + # Mock the session + mock_session = AsyncMock() + mock_resource = MagicMock() + mock_resource.name = 'config' + mock_resource.uri = 'app://config' + mock_resource.description = 'Configuration' + + mock_session.list_resources.return_value.resources = [mock_resource] + client.session = mock_session + + # List resources + resources = await client.list_resources() + + # Verify + self.assertEqual(len(resources), 1) + self.assertEqual(resources[0].name, 'config') + self.assertEqual(resources[0].uri, 'app://config') + + async def test_client_can_read_server_resource(self): + """Test that a client can read a resource from a server.""" + # Create client + client = McpClient(name='test-client', config=McpServerConfig(command='echo')) + + # Mock the session + mock_session = AsyncMock() + mock_result = MagicMock() + mock_result.contents = [MagicMock(text='Resource content')] + + mock_session.read_resource.return_value = mock_result + client.session = mock_session + + # Read resource + result = await client.read_resource('app://config') + + # Verify + self.assertIsNotNone(result) + mock_session.read_resource.assert_called_once_with('app://config') + + async def test_host_manages_multiple_clients(self): + """Test that a host can manage multiple clients.""" + # Create host with multiple servers + config1 = McpServerConfig(command='server1') + config2 = McpServerConfig(command='server2') + + host = create_mcp_host({'server1': config1, 'server2': config2}) + + # Verify clients were created + 
self.assertEqual(len(host.clients), 2) + self.assertIn('server1', host.clients) + self.assertIn('server2', host.clients) + + async def test_host_can_register_tools_from_multiple_servers(self): + """Test that a host can register tools from multiple servers.""" + # Create host + host = create_mcp_host({'server1': McpServerConfig(command='s1'), 'server2': McpServerConfig(command='s2')}) + + # Mock sessions for both clients + for client_name, client in host.clients.items(): + mock_session = AsyncMock() + mock_tool = MagicMock() + mock_tool.name = f'{client_name}_tool' + mock_tool.description = f'Tool from {client_name}' + mock_tool.inputSchema = {'type': 'object'} + + mock_session.list_tools.return_value.tools = [mock_tool] + client.session = mock_session + + # Register tools + ai = Genkit() + await host.register_tools(ai) + + # Verify tools were registered + # Each client should have registered one tool + # Tool names should be prefixed with server name + + async def test_client_handles_disabled_server(self): + """Test that a client handles disabled servers correctly.""" + # Create client with disabled config + config = McpServerConfig(command='echo', disabled=True) + client = McpClient(name='test-client', config=config) + + # Try to connect + await client.connect() + + # Should not have a session + self.assertIsNone(client.session) + + async def test_host_can_disable_and_enable_clients(self): + """Test that a host can disable and enable clients.""" + host = create_mcp_host({'test': McpServerConfig(command='echo')}) + + # Mock the client + client = host.clients['test'] + client.session = AsyncMock() + client.close = AsyncMock() + client.connect = AsyncMock() + + # Disable + await host.disable('test') + self.assertTrue(client.config.disabled) + + # Enable + await host.enable('test') + self.assertFalse(client.config.disabled) + + +@pytest.mark.asyncio +class TestResourceIntegration(unittest.IsolatedAsyncioTestCase): + """Integration tests specifically for resource 
handling.""" + + async def test_end_to_end_resource_flow(self): + """Test complete flow: define resource → expose via server → consume via client.""" + # This is a conceptual test showing the flow + # In practice, we'd need actual MCP transport for true end-to-end + + # 1. Server side: Define resource + server_ai = Genkit() + server_ai.define_resource( + name='config', uri='app://config', fn=lambda req: {'content': [{'text': 'config data'}]} + ) + + # 2. Create MCP server + from genkit.plugins.mcp import McpServerOptions + + server = create_mcp_server(server_ai, McpServerOptions(name='test-server')) + await server.setup() + + # 3. Verify server can list resources + resources_result = await server.list_resources({}) + self.assertEqual(len(resources_result.resources), 1) + self.assertEqual(resources_result.resources[0].uri, 'app://config') + + # 4. Verify server can read resource + request = MagicMock() + request.params.uri = 'app://config' + read_result = await server.read_resource(request) + self.assertEqual(read_result.contents[0].text, 'config data') + + async def test_template_resource_matching(self): + """Test that template resources match correctly.""" + server_ai = Genkit() + + def file_resource(req): + uri = req.uri + return {'content': [{'text': f'Contents of {uri}'}]} + + server_ai.define_resource(name='file', template='file://{+path}', fn=file_resource) + + # Create server + from genkit.plugins.mcp import McpServerOptions + + server = create_mcp_server(server_ai, McpServerOptions(name='test-server')) + await server.setup() + + # List templates + templates_result = await server.list_resource_templates({}) + self.assertEqual(len(templates_result.resourceTemplates), 1) + self.assertEqual(templates_result.resourceTemplates[0].uriTemplate, 'file://{+path}') + + # Read with different URIs + for test_uri in ['file:///path/to/file.txt', 'file:///another/file.md', 'file:///deep/nested/path/doc.pdf']: + request = MagicMock() + request.params.uri = test_uri + result 
= await server.read_resource(request) + self.assertIn(test_uri, result.contents[0].text) + + +@pytest.mark.asyncio +class TestErrorHandling(unittest.IsolatedAsyncioTestCase): + """Tests for error handling in client-server communication.""" + + async def test_server_handles_missing_tool(self): + """Test that server properly handles requests for non-existent tools.""" + server_ai = Genkit() + + @server_ai.tool() + def existing_tool(x: int) -> int: + return x + + from genkit.plugins.mcp import McpServerOptions + + server = create_mcp_server(server_ai, McpServerOptions(name='test-server')) + await server.setup() + + # Try to call non-existent tool + request = MagicMock() + request.params.name = 'nonexistent_tool' + request.params.arguments = {} + + from genkit.core.error import GenkitError + + with self.assertRaises(GenkitError) as context: + await server.call_tool(request) + + self.assertIn('NOT_FOUND', str(context.exception.status)) + + async def test_client_handles_connection_failure(self): + """Test that client handles connection failures gracefully.""" + client = McpClient(name='test-client', config=McpServerConfig(command='nonexistent_command')) + + # Mock the connection to fail + with patch('genkit.plugins.mcp.client.client.stdio_client') as mock_stdio: + mock_stdio.side_effect = Exception('Connection failed') + + with self.assertRaises(Exception): + await client.connect() + + # Client should mark server as disabled + self.assertTrue(client.config.disabled) + + +if __name__ == '__main__': + unittest.main() diff --git a/py/plugins/mcp/tests/test_mcp_server.py b/py/plugins/mcp/tests/test_mcp_server.py new file mode 100644 index 0000000000..f3180d6185 --- /dev/null +++ b/py/plugins/mcp/tests/test_mcp_server.py @@ -0,0 +1,341 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +MCP Server Tests + +Mirrors the functionality of js/plugins/mcp/tests/server_test.ts +Tests tools, prompts, and resources exposed via MCP server. +""" + +import os +import sys +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +sys.path.insert(0, os.path.dirname(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) + +# Mock mcp module before importing +mock_mcp = MagicMock() +sys.modules['mcp'] = mock_mcp + + +class MockSchema: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +types_mock = MagicMock() +sys.modules['mcp.types'] = types_mock +types_mock.ListToolsResult = MockSchema +types_mock.CallToolResult = MockSchema +types_mock.ListPromptsResult = MockSchema +types_mock.GetPromptResult = MockSchema +types_mock.ListResourcesResult = MockSchema +types_mock.ListResourceTemplatesResult = MockSchema +types_mock.ReadResourceResult = MockSchema +types_mock.Tool = MockSchema +types_mock.Prompt = MockSchema +types_mock.Resource = MockSchema +types_mock.ResourceTemplate = MockSchema +types_mock.TextResourceContents = MockSchema +types_mock.BlobResourceContents = MockSchema +types_mock.ImageContent = MockSchema +types_mock.TextResourceContents = MockSchema +types_mock.BlobResourceContents = MockSchema +types_mock.ImageContent = MockSchema +types_mock.TextContent = MockSchema +types_mock.PromptMessage = MockSchema + +sys.modules['mcp.server'] = MagicMock() +sys.modules['mcp.server.stdio'] = MagicMock() 
+sys.modules['mcp.client'] = MagicMock() +sys.modules['mcp.client'].__path__ = [] +sys.modules['mcp.client.stdio'] = MagicMock() +sys.modules['mcp.client.sse'] = MagicMock() +sys.modules['mcp.server.sse'] = MagicMock() + +import pytest + +from genkit.ai import Genkit +from genkit.core.action.types import ActionKind +from genkit.plugins.mcp import McpServer, McpServerOptions, create_mcp_server + + +@pytest.mark.asyncio +class TestMcpServer(unittest.IsolatedAsyncioTestCase): + """Test MCP server functionality - mirrors JS server_test.ts""" + + def setUp(self): + """Set up test fixtures before each test.""" + self.ai = Genkit() + + # Define test tool + @self.ai.tool(description='test tool') + def test_tool(input: dict[str, str]) -> str: + foo = input.get('foo', '') + return f'yep {{"foo":"{foo}"}}' + + # Define test prompt + self.ai.define_prompt(name='testPrompt', model='test-model', prompt='prompt says: {{input}}') + + # Define test resource with fixed URI + self.ai.define_resource( + name='testResources', uri='my://resource', fn=lambda req: {'content': [{'text': 'my resource'}]} + ) + + # Define test resource with template + self.ai.define_resource( + name='testTmpl', + template='file://{+path}', + fn=lambda req: {'content': [{'text': f'file contents for {req.uri}'}]}, + ) + + # Create MCP server + self.server = create_mcp_server(self.ai, McpServerOptions(name='test-server', version='0.0.1')) + + async def asyncSetUp(self): + """Async setup - initialize server.""" + await self.server.setup() + + # ===== TOOL TESTS ===== + + async def test_list_tools(self): + """Test listing tools - mirrors JS 'should list tools'.""" + result = await self.server.list_tools({}) + + # Verify we have the test tool + self.assertEqual(len(result.tools), 1) + tool = result.tools[0] + + self.assertEqual(tool.name, 'test_tool') + self.assertEqual(tool.description, 'test tool') + self.assertIsNotNone(tool.inputSchema) + + async def test_call_tool(self): + """Test calling a tool - mirrors JS 
'should call the tool'.""" + # Create mock request + request = MagicMock() + request.params.name = 'test_tool' + request.params.arguments = {'foo': 'bar'} + + result = await self.server.call_tool(request) + + # Verify response + self.assertEqual(len(result.content), 1) + self.assertEqual(result.content[0].type, 'text') + self.assertEqual(result.content[0].text, 'yep {"foo":"bar"}') + + # ===== PROMPT TESTS ===== + + async def test_list_prompts(self): + """Test listing prompts - mirrors JS 'should list prompts'.""" + result = await self.server.list_prompts({}) + + # Verify we have the test prompt + prompt_names = [p.name for p in result.prompts] + self.assertIn('testPrompt', prompt_names) + + async def test_get_prompt(self): + """Test rendering a prompt - mirrors JS 'should render prompt'.""" + # Create mock request + request = MagicMock() + request.params.name = 'testPrompt' + request.params.arguments = {'input': 'hello'} + + result = await self.server.get_prompt(request) + + # Verify response + self.assertIsNotNone(result.messages) + self.assertGreater(len(result.messages), 0) + + # Check message content + message = result.messages[0] + self.assertEqual(message.role, 'user') + self.assertEqual(message.content.type, 'text') + self.assertIn('prompt says: hello', message.content.text) + + # ===== RESOURCE TESTS ===== + + async def test_list_resources(self): + """Test listing resources - mirrors JS 'should list resources'.""" + result = await self.server.list_resources({}) + + # Verify we have the fixed URI resource + self.assertEqual(len(result.resources), 1) + resource = result.resources[0] + + self.assertEqual(resource.name, 'testResources') + self.assertEqual(resource.uri, 'my://resource') + + async def test_list_resource_templates(self): + """Test listing resource templates - mirrors JS 'should list templates'.""" + result = await self.server.list_resource_templates({}) + + # Verify we have the template resource + self.assertEqual(len(result.resourceTemplates), 
1) + template = result.resourceTemplates[0] + + self.assertEqual(template.name, 'testTmpl') + self.assertEqual(template.uriTemplate, 'file://{+path}') + + async def test_read_resource(self): + """Test reading a resource - mirrors JS 'should read resource'.""" + # Create mock request + request = MagicMock() + request.params.uri = 'my://resource' + + result = await self.server.read_resource(request) + + # Verify response + self.assertEqual(len(result.contents), 1) + content = result.contents[0] + + self.assertEqual(content.uri, 'my://resource') + self.assertEqual(content.text, 'my resource') + + async def test_read_template_resource(self): + """Test reading a template resource.""" + # Create mock request + request = MagicMock() + request.params.uri = 'file:///path/to/file.txt' + + result = await self.server.read_resource(request) + + # Verify response + self.assertEqual(len(result.contents), 1) + content = result.contents[0] + + self.assertEqual(content.uri, 'file:///path/to/file.txt') + self.assertIn('file contents for file:///path/to/file.txt', content.text) + + # ===== ADDITIONAL TESTS ===== + + async def test_server_initialization(self): + """Test that server initializes correctly.""" + self.assertIsNotNone(self.server) + self.assertEqual(self.server.options.name, 'test-server') + self.assertEqual(self.server.options.version, '0.0.1') + self.assertTrue(self.server.actions_resolved) + + async def test_server_has_all_action_types(self): + """Test that server has tools, prompts, and resources.""" + self.assertGreater(len(self.server.tool_actions), 0) + self.assertGreater(len(self.server.prompt_actions), 0) + self.assertGreater(len(self.server.resource_actions), 0) + + async def test_tool_not_found(self): + """Test calling a non-existent tool.""" + from genkit.core.error import GenkitError + + request = MagicMock() + request.params.name = 'nonexistent_tool' + request.params.arguments = {} + + with self.assertRaises(GenkitError) as context: + await 
self.server.call_tool(request) + + self.assertEqual(context.exception.status, 'NOT_FOUND') + + async def test_prompt_not_found(self): + """Test getting a non-existent prompt.""" + from genkit.core.error import GenkitError + + request = MagicMock() + request.params.name = 'nonexistent_prompt' + request.params.arguments = {} + + with self.assertRaises(GenkitError) as context: + await self.server.get_prompt(request) + + self.assertEqual(context.exception.status, 'NOT_FOUND') + + async def test_resource_not_found(self): + """Test reading a non-existent resource.""" + from genkit.core.error import GenkitError + + request = MagicMock() + request.params.uri = 'nonexistent://resource' + + with self.assertRaises(GenkitError) as context: + await self.server.read_resource(request) + + self.assertEqual(context.exception.status, 'NOT_FOUND') + + +# Additional test class for resource-specific functionality +@pytest.mark.asyncio +class TestResourceFunctionality(unittest.IsolatedAsyncioTestCase): + """Test resource-specific functionality.""" + + async def test_resource_registration_with_fixed_uri(self): + """Test registering a resource with fixed URI.""" + ai = Genkit() + + action = ai.define_resource( + name='test_resource', uri='test://resource', fn=lambda req: {'content': [{'text': 'test'}]} + ) + + self.assertIsNotNone(action) + self.assertEqual(action.kind, ActionKind.RESOURCE) + self.assertEqual(action.metadata['resource']['uri'], 'test://resource') + + async def test_resource_registration_with_template(self): + """Test registering a resource with URI template.""" + ai = Genkit() + + action = ai.define_resource( + name='file', template='file://{+path}', fn=lambda req: {'content': [{'text': 'file content'}]} + ) + + self.assertIsNotNone(action) + self.assertEqual(action.kind, ActionKind.RESOURCE) + self.assertEqual(action.metadata['resource']['template'], 'file://{+path}') + + async def test_resource_requires_uri_or_template(self): + """Test that resource requires either uri 
or template.""" + ai = Genkit() + + with self.assertRaises(ValueError) as context: + ai.define_resource(name='invalid', fn=lambda req: {'content': []}) + + self.assertIn('uri', str(context.exception).lower()) + self.assertIn('template', str(context.exception).lower()) + + async def test_uri_template_matching(self): + """Test URI template matching.""" + from genkit.blocks.resource import matches_uri_template + + # Test exact match + result = matches_uri_template('file://{+path}', 'file:///home/user/doc.txt') + self.assertIsNotNone(result) + self.assertIn('path', result) + + # Test no match + result = matches_uri_template('file://{path}', 'http://example.com') + self.assertIsNone(result) + + # Test multiple parameters + result = matches_uri_template('user://{id}/posts/{post_id}', 'user://123/posts/456') + self.assertIsNotNone(result) + self.assertEqual(result['id'], '123') + self.assertEqual(result['post_id'], '456') + + +if __name__ == '__main__': + unittest.main() diff --git a/py/plugins/mcp/tests/test_mcp_server_resources.py b/py/plugins/mcp/tests/test_mcp_server_resources.py new file mode 100644 index 0000000000..51aca70cca --- /dev/null +++ b/py/plugins/mcp/tests/test_mcp_server_resources.py @@ -0,0 +1,351 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +"""Comprehensive tests for MCP server resource handling.""" + +import os +import sys +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +sys.path.insert(0, os.path.dirname(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) +from fakes import mock_mcp_modules + +mock_mcp_modules() + +import pytest + +from genkit.ai import Genkit +from genkit.core.action.types import ActionKind +from genkit.plugins.mcp import McpServer, McpServerOptions, create_mcp_server + + +@pytest.mark.asyncio +class TestMcpServerResources(unittest.IsolatedAsyncioTestCase): + """Tests for MCP server resource handling.""" + + def setUp(self): + """Set up test fixtures.""" + self.ai = Genkit() + + async def test_list_resources_with_fixed_uri(self): + """Test listing resources with fixed URIs.""" + # Define resources + self.ai.define_resource(name='config', uri='app://config', fn=lambda req: {'content': [{'text': 'config'}]}) + + self.ai.define_resource(name='data', uri='app://data', fn=lambda req: {'content': [{'text': 'data'}]}) + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # List resources + result = await server.list_resources({}) + + # Verify + self.assertEqual(len(result.resources), 2) + resource_names = [r.name for r in result.resources] + self.assertIn('config', resource_names) + self.assertIn('data', resource_names) + + # Verify URIs + config_resource = next(r for r in result.resources if r.name == 'config') + self.assertEqual(config_resource.uri, 'app://config') + + async def test_list_resource_templates(self): + """Test listing resources with URI templates.""" + # Define template resources + self.ai.define_resource( + name='file', template='file://{+path}', fn=lambda req: {'content': [{'text': 'file content'}]} + ) + + self.ai.define_resource( + name='user', template='user://{id}/profile', fn=lambda 
req: {'content': [{'text': 'user profile'}]} + ) + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # List resource templates + result = await server.list_resource_templates({}) + + # Verify + self.assertEqual(len(result.resourceTemplates), 2) + template_names = [t.name for t in result.resourceTemplates] + self.assertIn('file', template_names) + self.assertIn('user', template_names) + + # Verify templates + file_template = next(t for t in result.resourceTemplates if t.name == 'file') + self.assertEqual(file_template.uriTemplate, 'file://{+path}') + + async def test_list_resources_excludes_templates(self): + """Test that list_resources excludes template resources.""" + # Define mixed resources + self.ai.define_resource(name='fixed', uri='app://fixed', fn=lambda req: {'content': [{'text': 'fixed'}]}) + + self.ai.define_resource( + name='template', template='app://{id}', fn=lambda req: {'content': [{'text': 'template'}]} + ) + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # List resources (should only include fixed URI) + result = await server.list_resources({}) + + self.assertEqual(len(result.resources), 1) + self.assertEqual(result.resources[0].name, 'fixed') + + async def test_list_resource_templates_excludes_fixed(self): + """Test that list_resource_templates excludes fixed URI resources.""" + # Define mixed resources + self.ai.define_resource(name='fixed', uri='app://fixed', fn=lambda req: {'content': [{'text': 'fixed'}]}) + + self.ai.define_resource( + name='template', template='app://{id}', fn=lambda req: {'content': [{'text': 'template'}]} + ) + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # List templates (should only include template) + result = await server.list_resource_templates({}) + + self.assertEqual(len(result.resourceTemplates), 1) 
+ self.assertEqual(result.resourceTemplates[0].name, 'template') + + async def test_read_resource_with_fixed_uri(self): + """Test reading a resource with fixed URI.""" + + def config_resource(req): + return {'content': [{'text': 'Configuration data'}]} + + self.ai.define_resource(name='config', uri='app://config', fn=config_resource) + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # Read resource + from mcp.types import ReadResourceRequest + + request = MagicMock() + request.params.uri = 'app://config' + + result = await server.read_resource(request) + + # Verify + self.assertEqual(len(result.contents), 1) + self.assertEqual(result.contents[0].text, 'Configuration data') + + async def test_read_resource_with_template(self): + """Test reading a resource with URI template.""" + + def file_resource(req): + uri = req.uri + # Extract path from URI + path = uri.replace('file://', '') + return {'content': [{'text': f'Contents of {path}'}]} + + self.ai.define_resource(name='file', template='file://{+path}', fn=file_resource) + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # Read resource + request = MagicMock() + request.params.uri = 'file:///home/user/document.txt' + + result = await server.read_resource(request) + + # Verify + self.assertEqual(len(result.contents), 1) + self.assertIn('/home/user/document.txt', result.contents[0].text) + + async def test_read_resource_not_found(self): + """Test reading a non-existent resource.""" + self.ai.define_resource(name='existing', uri='app://existing', fn=lambda req: {'content': [{'text': 'data'}]}) + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # Try to read non-existent resource + request = MagicMock() + request.params.uri = 'app://nonexistent' + + from genkit.core.error import GenkitError + + with 
self.assertRaises(GenkitError) as context: + await server.read_resource(request) + + self.assertIn('NOT_FOUND', str(context.exception.status)) + + async def test_read_resource_with_multiple_content_parts(self): + """Test reading a resource that returns multiple content parts.""" + + def multi_part_resource(req): + return {'content': [{'text': 'Part 1'}, {'text': 'Part 2'}, {'text': 'Part 3'}]} + + self.ai.define_resource(name='multi', uri='app://multi', fn=multi_part_resource) + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # Read resource + request = MagicMock() + request.params.uri = 'app://multi' + + result = await server.read_resource(request) + + # Verify + self.assertEqual(len(result.contents), 3) + self.assertEqual(result.contents[0].text, 'Part 1') + self.assertEqual(result.contents[1].text, 'Part 2') + self.assertEqual(result.contents[2].text, 'Part 3') + + +@pytest.mark.asyncio +class TestMcpServerToolsAndPrompts(unittest.IsolatedAsyncioTestCase): + """Tests for MCP server tool and prompt handling.""" + + def setUp(self): + """Set up test fixtures.""" + self.ai = Genkit() + + async def test_list_tools(self): + """Test listing tools.""" + + @self.ai.tool(description='Add two numbers') + def add(input: dict[str, int]) -> int: + return input['a'] + input['b'] + + @self.ai.tool(description='Multiply two numbers') + def multiply(input: dict[str, int]) -> int: + return input['a'] * input['b'] + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # List tools + result = await server.list_tools({}) + + # Verify + self.assertEqual(len(result.tools), 2) + tool_names = [t.name for t in result.tools] + self.assertIn('add', tool_names) + self.assertIn('multiply', tool_names) + + async def test_call_tool(self): + """Test calling a tool.""" + + @self.ai.tool() + def add(input: dict[str, int]) -> int: + return input['a'] + 
input['b'] + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # Call tool + request = MagicMock() + request.params.name = 'add' + request.params.arguments = {'a': 5, 'b': 3} + + result = await server.call_tool(request) + + # Verify + self.assertEqual(len(result.content), 1) + self.assertEqual(result.content[0].text, '8') + + async def test_list_prompts(self): + """Test listing prompts.""" + self.ai.define_prompt(name='greeting', prompt='Hello {{name}}!') + + self.ai.define_prompt(name='farewell', prompt='Goodbye {{name}}!') + + # Create server + server = create_mcp_server(self.ai, McpServerOptions(name='test-server')) + await server.setup() + + # List prompts + result = await server.list_prompts({}) + + # Verify + self.assertGreaterEqual(len(result.prompts), 2) + prompt_names = [p.name for p in result.prompts] + # Prompt names might have variant suffixes + + +@pytest.mark.asyncio +class TestMcpServerIntegration(unittest.IsolatedAsyncioTestCase): + """Integration tests for MCP server.""" + + async def test_server_exposes_all_action_types(self): + """Test that server exposes tools, prompts, and resources.""" + ai = Genkit() + + # Define tool + @ai.tool() + def test_tool(x: int) -> int: + return x * 2 + + # Define prompt + ai.define_prompt(name='test', prompt='Test prompt') + + # Define resource + ai.define_resource(name='test_resource', uri='test://resource', fn=lambda req: {'content': [{'text': 'test'}]}) + + # Create server + server = create_mcp_server(ai, McpServerOptions(name='integration-test')) + await server.setup() + + # Verify all action types are available + self.assertGreater(len(server.tool_actions), 0) + self.assertGreater(len(server.prompt_actions), 0) + self.assertGreater(len(server.resource_actions), 0) + + async def test_server_initialization_idempotent(self): + """Test that server setup is idempotent.""" + ai = Genkit() + + @ai.tool() + def test_tool(x: int) -> int: + return x + 
+ server = create_mcp_server(ai, McpServerOptions(name='test')) + + # Setup multiple times + await server.setup() + count1 = len(server.tool_actions) + + await server.setup() + count2 = len(server.tool_actions) + + # Should be the same + self.assertEqual(count1, count2) + + +if __name__ == '__main__': + unittest.main() diff --git a/py/plugins/ollama/pyproject.toml b/py/plugins/ollama/pyproject.toml index d9ee166da9..53edda87e7 100644 --- a/py/plugins/ollama/pyproject.toml +++ b/py/plugins/ollama/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -35,7 +34,7 @@ classifiers = [ ] dependencies = ["genkit", "ollama~=0.4", "structlog>=25.2.0"] description = "Genkit Ollama Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-ollama" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/constants.py b/py/plugins/ollama/src/genkit/plugins/ollama/constants.py index 4060d414a5..3be251e2a0 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/constants.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/constants.py @@ -14,6 +14,7 @@ # # SPDX-License-Identifier: Apache-2.0 + import sys if sys.version_info < (3, 11): diff --git a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py index c3aa01e889..b007058f88 100644 --- a/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py +++ b/py/plugins/ollama/src/genkit/plugins/ollama/plugin_api.py @@ -16,15 +16,15 @@ """Ollama Plugin for Genkit.""" -import asyncio -from functools import cached_property, partial +from functools import partial import structlog import ollama as ollama_api -from 
genkit.ai import GenkitRegistry, Plugin -from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder_action_metadata -from genkit.blocks.model import model_action_metadata +from genkit.ai import Plugin +from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports, embedder, embedder_action_metadata +from genkit.blocks.model import model, model_action_metadata +from genkit.core.action import Action from genkit.core.registry import ActionKind from genkit.core.schema import to_json_schema from genkit.plugins.ollama.constants import ( @@ -91,63 +91,78 @@ def __init__( self.client = partial(ollama_api.AsyncClient, host=self.server_address) - def initialize(self, ai: GenkitRegistry) -> None: - """Initialize the Ollama plugin. - - Registers the defined Ollama models and embedders with the Genkit AI registry. - - Args: - ai: The AI registry to initialize the plugin with. - """ - self._initialize_models(ai=ai) - self._initialize_embedders(ai=ai) - - def _initialize_models(self, ai: GenkitRegistry) -> None: - """Initializes and registers the specified Ollama models with Genkit. - - Args: - ai: The Genkit AI registry instance. - """ + async def init(self): + """Return eagerly-initialized model and embedder actions.""" + actions = [] for model_definition in self.models: - self._define_ollama_model(ai, model_definition) - - def _initialize_embedders(self, ai: GenkitRegistry) -> None: - """Initializes and registers the specified Ollama embedders with Genkit. - - Args: - ai: The Genkit AI registry instance. - """ - for embedding_definition in self.embedders: - self._define_ollama_embedder(ai, embedding_definition) - - def resolve_action( - self, - ai: GenkitRegistry, - kind: ActionKind, - name: str, - ) -> None: - """Resolves and action. 
+ actions.append(self._create_model_action(model_definition)) + for embedder_definition in self.embedders: + actions.append(self._create_embedder_action(embedder_definition)) + return actions - Args: - ai: The Genkit registry. - kind: The kind of action to resolve. - name: The name of the action to resolve. - """ - if kind == ActionKind.MODEL: - self._define_ollama_model(ai, ModelDefinition(name=name)) - elif kind == ActionKind.EMBEDDER: - self._define_ollama_embedder(ai, EmbeddingDefinition(name=name)) + async def resolve(self, action_type: ActionKind, name: str): + """Resolve a model or embedder action on-demand.""" + clean_name = name.replace(f'{OLLAMA_PLUGIN_NAME}/', '') if name.startswith(OLLAMA_PLUGIN_NAME) else name + + if action_type == ActionKind.MODEL: + # Prefer configured model definitions (api_type, supports) when available. + for model_def in self.models: + configured_name = ( + model_def.name.replace(OLLAMA_PLUGIN_NAME + '/', '') + if model_def.name.startswith(OLLAMA_PLUGIN_NAME) + else model_def.name + ) + if configured_name == clean_name: + return self._create_model_action(model_def) + return self._create_model_action(ModelDefinition(name=clean_name)) + elif action_type == ActionKind.EMBEDDER: + for embedder_def in self.embedders: + configured_name = ( + embedder_def.name.replace(OLLAMA_PLUGIN_NAME + '/', '') + if embedder_def.name.startswith(OLLAMA_PLUGIN_NAME) + else embedder_def.name + ) + if configured_name == clean_name: + return self._create_embedder_action(embedder_def) + return self._create_embedder_action(EmbeddingDefinition(name=clean_name)) + return None - def _define_ollama_model(self, ai: GenkitRegistry, model_ref: ModelDefinition) -> None: - """Defines and registers an Ollama model with Genkit. 
+ async def list_actions(self): + """List all available Ollama models and embedders.""" + _client = self.client() + response = await _client.list() - Cleans the model name, instantiates an OllamaModel, and registers it - with the provided Genkit AI registry, including metadata about its capabilities. + actions = [] + for model_info in response.models: + _name = model_info.model + if 'embed' in _name: + actions.append( + embedder_action_metadata( + name=_name, + options=EmbedderOptions( + config_schema=to_json_schema(ollama_api.Options), + label=f'Ollama Embedding - {_name}', + supports=EmbedderSupports(input=['text']), + ), + ) + ) + else: + actions.append( + model_action_metadata( + name=_name, + config_schema=GenerationCommonConfig, + info={ + 'label': f'Ollama - {_name}', + 'multiturn': True, + 'system_role': True, + 'tools': False, + }, + ) + ) + return actions - Args: - ai: The Genkit AI registry instance. - model_ref: The definition of the model to be registered. - """ + def _create_model_action(self, model_ref: ModelDefinition) -> Action: + """Create an Ollama model action (doesn't register).""" _clean_name = ( model_ref.name.replace(OLLAMA_PLUGIN_NAME + '/', '') if model_ref.name.startswith(OLLAMA_PLUGIN_NAME) @@ -155,14 +170,14 @@ def _define_ollama_model(self, ai: GenkitRegistry, model_ref: ModelDefinition) - ) model_ref.name = _clean_name - model = OllamaModel( + ollama_model = OllamaModel( client=self.client, model_definition=model_ref, ) - ai.define_model( - name=ollama_name(model_ref.name), - fn=model.generate, + return model( + name=model_ref.name, + fn=ollama_model.generate, config_schema=GenerationCommonConfig, metadata={ 'label': f'Ollama - {_clean_name}', @@ -172,17 +187,8 @@ def _define_ollama_model(self, ai: GenkitRegistry, model_ref: ModelDefinition) - }, ) - def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDefinition) -> None: - """Defines and registers an Ollama embedder with Genkit. 
- - Cleans the embedder name, instantiates an OllamaEmbedder, and registers it - with the provided Genkit AI registry, including metadata about its capabilities - and expected output dimensions. - - Args: - ai: The Genkit AI registry instance. - embedder_ref: The definition of the embedding model to be registered. - """ + def _create_embedder_action(self, embedder_ref: EmbeddingDefinition) -> Action: + """Create an Ollama embedder action (doesn't register).""" _clean_name = ( embedder_ref.name.replace(OLLAMA_PLUGIN_NAME + '/', '') if embedder_ref.name.startswith(OLLAMA_PLUGIN_NAME) @@ -190,14 +196,14 @@ def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDef ) embedder_ref.name = _clean_name - embedder = OllamaEmbedder( + ollama_embedder = OllamaEmbedder( client=self.client, embedding_definition=embedder_ref, ) - ai.define_embedder( - name=ollama_name(embedder_ref.name), - fn=embedder.embed, + return embedder( + name=embedder_ref.name, + fn=ollama_embedder.embed, options=EmbedderOptions( config_schema=to_json_schema(ollama_api.Options), label=f'Ollama Embedding - {_clean_name}', @@ -205,52 +211,3 @@ def _define_ollama_embedder(self, ai: GenkitRegistry, embedder_ref: EmbeddingDef supports=EmbedderSupports(input=['text']), ), ) - - @cached_property - def list_actions(self) -> list[dict[str, str]]: - """Generate a list of available actions or models. - - Returns: - list[ActionMetadata]: A list of ActionMetadata objects, each with the following attributes: - - name (str): The name of the action or model. - - kind (ActionKind): The type or category of the action. - - info (dict): The metadata dictionary describing the model configuration and properties. - - config_schema (type): The schema class used for validating the model's configuration. 
- """ - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - _client = self.client() - response = loop.run_until_complete(_client.list()) - - actions = [] - for model in response.models: - _name = model.model - if 'embed' in _name: - actions.append( - embedder_action_metadata( - name=ollama_name(_name), - options=EmbedderOptions( - config_schema=to_json_schema(ollama_api.Options), - label=f'Ollama Embedding - {_name}', - supports=EmbedderSupports(input=['text']), - ), - ) - ) - else: - actions.append( - model_action_metadata( - name=ollama_name(_name), - config_schema=GenerationCommonConfig, - info={ - 'label': f'Ollama - {_name}', - 'multiturn': True, - 'system_role': True, - 'tools': False, - }, - ) - ) - return actions diff --git a/py/plugins/ollama/tests/test_integration.py b/py/plugins/ollama/tests/test_integration.py index 2f27075896..1c2ce5e093 100644 --- a/py/plugins/ollama/tests/test_integration.py +++ b/py/plugins/ollama/tests/test_integration.py @@ -16,31 +16,33 @@ """Integration tests for Ollama plugin with Genkit.""" -from unittest.mock import ANY, MagicMock, Mock, patch +from unittest.mock import Mock import ollama as ollama_api import pytest from genkit.ai import ActionKind, Genkit -from genkit.plugins.ollama import Ollama, ollama_name -from genkit.plugins.ollama.models import ModelDefinition -from genkit.types import GenerateResponse, GenerationCommonConfig, Message, Role, TextPart +from genkit.types import GenerateResponse, Message, Role, TextPart -def test_adding_ollama_chat_model_to_genkit_veneer( +@pytest.mark.asyncio +async def test_adding_ollama_chat_model_to_genkit_veneer( ollama_model: str, genkit_veneer_chat_model: Genkit, ) -> None: """Test adding ollama chat model to genkit veneer.""" - assert genkit_veneer_chat_model.registry.lookup_action(ActionKind.MODEL, ollama_model) + # Use async resolver-aware lookup for PluginV2 paths. 
+ assert await genkit_veneer_chat_model.registry.aresolve_action(ActionKind.MODEL, ollama_model) -def test_adding_ollama_generation_model_to_genkit_veneer( +@pytest.mark.asyncio +async def test_adding_ollama_generation_model_to_genkit_veneer( ollama_model: str, genkit_veneer_generate_model: Genkit, ) -> None: """Test adding ollama generation model to genkit veneer.""" - assert genkit_veneer_generate_model.registry.lookup_action(ActionKind.MODEL, ollama_model) + # Use async resolver-aware lookup for PluginV2 paths. + assert await genkit_veneer_generate_model.registry.aresolve_action(ActionKind.MODEL, ollama_model) @pytest.mark.asyncio @@ -110,29 +112,3 @@ async def _test_fun(): assert isinstance(response, GenerateResponse) assert response.message.content[0].root.text == mock_response_message - - -@pytest.fixture -@patch('ollama.AsyncClient') -def ollama_plugin_instance(ollama_async_client): - return Ollama() - - -def test__initialize_models(ollama_plugin_instance): - ai_mock = MagicMock(spec=Genkit) - - plugin = ollama_plugin_instance - plugin.models = [ModelDefinition(name='test_model')] - plugin._initialize_models(ai_mock) - - ai_mock.define_model.assert_called_once_with( - name=ollama_name('test_model'), - fn=ANY, - config_schema=GenerationCommonConfig, - metadata={ - 'label': 'Ollama - test_model', - 'multiturn': True, - 'system_role': True, - 'tools': False, - }, - ) diff --git a/py/plugins/ollama/tests/test_plugin_api.py b/py/plugins/ollama/tests/test_plugin_api.py index ff0c14d11b..a8d459a6ab 100644 --- a/py/plugins/ollama/tests/test_plugin_api.py +++ b/py/plugins/ollama/tests/test_plugin_api.py @@ -17,19 +17,15 @@ """Unit tests for Ollama Plugin.""" import unittest -from unittest.mock import ANY, AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock -import ollama as ollama_api import pytest from pydantic import BaseModel -from genkit.ai import ActionKind, Genkit -from genkit.blocks.embedding import EmbedderOptions, EmbedderSupports -from 
genkit.core.schema import to_json_schema -from genkit.plugins.ollama import Ollama, ollama_name +from genkit.core.registry import ActionKind +from genkit.plugins.ollama import Ollama from genkit.plugins.ollama.embedders import EmbeddingDefinition from genkit.plugins.ollama.models import ModelDefinition -from genkit.types import GenerationCommonConfig class TestOllamaInit(unittest.TestCase): @@ -69,74 +65,17 @@ def test_init_with_options(self): assert plugin.request_headers == headers -def test_initialize(ollama_plugin_instance): - """Test initialize method of Ollama plugin.""" - ai_mock = MagicMock(spec=Genkit) - model_ref = ModelDefinition(name='test_model') - embedder_ref = EmbeddingDefinition(name='test_embedder') - ollama_plugin_instance.models = [model_ref] - ollama_plugin_instance.embedders = [embedder_ref] +@pytest.mark.asyncio +async def test_init_returns_actions(ollama_plugin_instance): + """PluginV2 init() should return actions (models + embedders) without namespacing.""" + ollama_plugin_instance.models = [ModelDefinition(name='test_model')] + ollama_plugin_instance.embedders = [EmbeddingDefinition(name='test_embedder', dimensions=1024)] - init_models = MagicMock() - init_embedders = MagicMock() + actions = await ollama_plugin_instance.init() - ollama_plugin_instance._initialize_models = init_models - ollama_plugin_instance._initialize_embedders = init_embedders - - ollama_plugin_instance.initialize(ai_mock) - - init_models.assert_called_once_with(ai=ai_mock) - init_embedders.assert_called_once_with(ai=ai_mock) - - -def test__initialize_models(ollama_plugin_instance): - """Test _initialize_models method of Ollama plugin.""" - ai_mock = MagicMock(spec=Genkit) - name = 'test_model' - - plugin = ollama_plugin_instance - plugin.models = [ModelDefinition(name=name)] - plugin._initialize_models(ai_mock) - - ai_mock.define_model.assert_called_once_with( - name=ollama_name(name), - fn=ANY, - config_schema=GenerationCommonConfig, - metadata={ - 'label': f'Ollama - 
{name}', - 'multiturn': True, - 'system_role': True, - 'tools': False, - }, - ) - - -def test__initialize_embedders(ollama_plugin_instance): - """Test _initialize_embedders method of Ollama plugin.""" - ai_mock = MagicMock(spec=Genkit) - name = 'test_embedder' - - plugin = ollama_plugin_instance - plugin.embedders = [ - EmbeddingDefinition( - name=name, - dimensions=1024, - ) - ] - plugin._initialize_embedders(ai_mock) - - ai_mock.define_embedder.assert_called_once_with( - name=ollama_name(name), - fn=ANY, - options=EmbedderOptions( - config_schema=to_json_schema(ollama_api.Options), - label=f'Ollama Embedding - {name}', - dimensions=1024, - supports=EmbedderSupports( - input=['text'], - ), - ), - ) + assert len(actions) == 2 + assert {a.kind for a in actions} == {ActionKind.MODEL, ActionKind.EMBEDDER} + assert {a.name for a in actions} == {'test_model', 'test_embedder'} @pytest.mark.parametrize( @@ -146,36 +85,13 @@ def test__initialize_embedders(ollama_plugin_instance): (ActionKind.EMBEDDER, 'test_embedder'), ], ) -def test_resolve_action(kind, name, ollama_plugin_instance): - """Unit Tests for resolve action method.""" - ai_mock = MagicMock(spec=Genkit) - ollama_plugin_instance.resolve_action(ai_mock, kind, name) - - if kind == ActionKind.MODEL: - ai_mock.define_model.assert_called_once_with( - name=ollama_name(name), - fn=ANY, - config_schema=GenerationCommonConfig, - metadata={ - 'label': f'Ollama - {name}', - 'multiturn': True, - 'system_role': True, - 'tools': False, - }, - ) - else: - ai_mock.define_embedder.assert_called_once_with( - name=ollama_name(name), - fn=ANY, - options=EmbedderOptions( - config_schema=to_json_schema(ollama_api.Options), - label=f'Ollama Embedding - {name}', - dimensions=None, - supports=EmbedderSupports( - input=['text'], - ), - ), - ) +@pytest.mark.asyncio +async def test_resolve_returns_action(kind, name, ollama_plugin_instance): + """PluginV2 resolve() should return an Action for models/embedders.""" + action = await 
ollama_plugin_instance.resolve(kind, name) + assert action is not None + assert action.kind == kind + assert action.name == name @pytest.mark.parametrize( @@ -185,23 +101,11 @@ def test_resolve_action(kind, name, ollama_plugin_instance): ('ollama/mistral', 'ollama/mistral', 'mistral'), ], ) -def test_define_ollama_model(name, expected_name, clean_name, ollama_plugin_instance): - """Unit tests for _define_ollama_model method.""" - ai_mock = MagicMock(spec=Genkit) - - ollama_plugin_instance._define_ollama_model(ai_mock, ModelDefinition(name=name)) - - ai_mock.define_model.assert_called_once_with( - name=expected_name, - fn=ANY, - config_schema=GenerationCommonConfig, - metadata={ - 'label': f'Ollama - {clean_name}', - 'multiturn': True, - 'system_role': True, - 'tools': False, - }, - ) +def test_create_model_action_cleans_name(name, expected_name, clean_name, ollama_plugin_instance): + """_create_model_action should strip namespace from input names.""" + action = ollama_plugin_instance._create_model_action(ModelDefinition(name=name)) + assert action.kind == ActionKind.MODEL + assert action.name == clean_name @pytest.mark.parametrize( @@ -211,28 +115,16 @@ def test_define_ollama_model(name, expected_name, clean_name, ollama_plugin_inst ('ollama/mistral', 'ollama/mistral', 'mistral'), ], ) -def test_define_ollama_embedder(name, expected_name, clean_name, ollama_plugin_instance): - """Unit tests for _define_ollama_embedder method.""" - ai_mock = MagicMock(spec=Genkit) - - ollama_plugin_instance._define_ollama_embedder(ai_mock, EmbeddingDefinition(name=name, dimensions=1024)) - - ai_mock.define_embedder.assert_called_once_with( - name=expected_name, - fn=ANY, - options=EmbedderOptions( - config_schema=to_json_schema(ollama_api.Options), - label=f'Ollama Embedding - {clean_name}', - dimensions=1024, - supports=EmbedderSupports( - input=['text'], - ), - ), - ) +def test_create_embedder_action_cleans_name(name, expected_name, clean_name, ollama_plugin_instance): + 
"""_create_embedder_action should strip namespace from input names.""" + action = ollama_plugin_instance._create_embedder_action(EmbeddingDefinition(name=name, dimensions=1024)) + assert action.kind == ActionKind.EMBEDDER + assert action.name == clean_name -def test_list_actions(ollama_plugin_instance): - """Unit tests for list_actions method.""" +@pytest.mark.asyncio +async def test_list_returns_action_metadata(ollama_plugin_instance): + """PluginV2 list_actions() should return ActionMetadata and await the async client.""" class MockModelResponse(BaseModel): model: str @@ -256,7 +148,7 @@ def mock_client(): ollama_plugin_instance.client = mock_client - actions = ollama_plugin_instance.list_actions + actions = await ollama_plugin_instance.list_actions() assert len(actions) == 2 diff --git a/py/plugins/vertex-ai/pyproject.toml b/py/plugins/vertex-ai/pyproject.toml index a4d32d0189..1f8638cb47 100644 --- a/py/plugins/vertex-ai/pyproject.toml +++ b/py/plugins/vertex-ai/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -39,11 +38,12 @@ dependencies = [ "google-cloud-aiplatform>=1.77.0", "structlog>=25.2.0", "strenum>=0.4.15; python_version < '3.11'", + "genkit-plugin-compat-oai", "google-cloud-bigquery", "google-cloud-firestore", ] description = "Genkit Google Cloud Vertex AI Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-vertex-ai" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py index 7b522412ed..23705f1358 100644 --- 
a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/model_garden.py @@ -56,7 +56,7 @@ def __init__( model: str, location: str, project_id: str, - registry: GenkitRegistry, + registry: GenkitRegistry | None, ) -> None: """Initializes the ModelGarden instance. @@ -104,9 +104,14 @@ def to_openai_compatible_model(self) -> Callable: A callable function (specifically, the `generate` method of an `OpenAIModel` instance) that can be used by Genkit. """ - openai_model = OpenAIModel(self.name, self.client, self.ai) + # In PluginV2 paths we avoid registry-dependent tool lookup, but the legacy + # registry-based flow still passes a registry here. + openai_model = OpenAIModel(self.name, self.client) return openai_model.generate + # NOTE: OpenAIModel no longer requires a registry; tool schemas are provided via + # GenerateRequest.tools, so the returned function works for both v1/v2 flows. + def define_model(self) -> None: """Defines and registers the Model Garden model with the Genkit registry. @@ -114,6 +119,8 @@ def define_model(self) -> None: of the OpenAI-compatible generation function, then registers this model within the Genkit framework using `self.ai.define_model`. 
""" + if self.ai is None: + raise ValueError('ModelGarden.define_model() requires a GenkitRegistry') model_info = self.get_model_info() generate_fn = self.to_openai_compatible_model() self.ai.define_model( diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py index 48eac737cc..a8b1455eec 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/model_garden/modelgarden_plugin.py @@ -17,13 +17,13 @@ """ModelGarden API Compatible Plugin for Genkit.""" import os -from functools import cached_property -from genkit.ai import GenkitRegistry, Plugin -from genkit.blocks.model import model_action_metadata +from genkit.ai import Plugin +from genkit.blocks.model import model, model_action_metadata from genkit.core.action import ActionMetadata from genkit.core.action.types import ActionKind -from genkit.plugins.compat_oai.models import SUPPORTED_OPENAI_COMPAT_MODELS +from genkit.plugins.compat_oai.models import SUPPORTED_OPENAI_COMPAT_MODELS, OpenAIModelHandler +from genkit.plugins.compat_oai.models.model_info import PluginSource from genkit.plugins.compat_oai.typing import OpenAIConfig from genkit.plugins.vertex_ai import constants as const @@ -61,83 +61,48 @@ def __init__( """ self.project_id = project_id if project_id is not None else os.getenv(const.GCLOUD_PROJECT) self.location = location if location is not None else const.DEFAULT_REGION - self.models = models + self.models = models or [] - def initialize(self, ai: GenkitRegistry) -> None: - """Handles actions for various openaicompatible models.""" - models = self.models - if models is None: - return + async def init(self): + """Return eagerly-initialized model actions.""" + return [self._create_model_action(m) for m in self.models] - for model in models: - model_proxy = ModelGarden( - model=model, 
- location=self.location, - project_id=self.project_id, - registry=ai, - ) - model_proxy.define_model() - - def resolve_action( - self, - ai: GenkitRegistry, - kind: ActionKind, - name: str, - ) -> None: - """Resolves and action. - - Args: - ai: The Genkit registry. - kind: The kind of action to resolve. - name: The name of the action to resolve. - """ - if kind == ActionKind.MODEL: - self._resolve_model(ai=ai, name=name) - - def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: - """Resolves and defines a Model Garden Vertex AI model within the Genkit registry. - - This internal method handles the logic for registering new models - of Vertex AI Model Garden that are compatible with OpenaI - based on the provided name. - It extracts a clean name, determines the model type, instantiates the - appropriate model class, and registers it with the Genkit AI registry. - - Args: - ai: The Genkit AI registry instance to define the model in. - name: The name of the model to resolve. This name might include a - prefix indicating it's from a specific plugin. 
- """ + async def resolve(self, action_type: ActionKind, name: str): + if action_type != ActionKind.MODEL: + return None clean_name = ( name.replace(MODELGARDEN_PLUGIN_NAME + '/', '') if name.startswith(MODELGARDEN_PLUGIN_NAME) else name ) + if clean_name not in SUPPORTED_OPENAI_COMPAT_MODELS: + return None + return self._create_model_action(clean_name) + + async def list_actions(self) -> list[ActionMetadata]: + return [ + model_action_metadata( + name=model_garden_name(model_name), + info=model_info.model_dump(), + config_schema=OpenAIConfig, + ) + for model_name, model_info in SUPPORTED_OPENAI_COMPAT_MODELS.items() + ] + def _create_model_action(self, model_name: str): model_proxy = ModelGarden( - model=clean_name, + model=model_name, location=self.location, project_id=self.project_id, - registry=ai, + registry=None, + ) + handler = OpenAIModelHandler.get_model_handler( + model=model_name, + client=model_proxy.client, # Vertex Model Garden OpenAI-compatible client + source=PluginSource.MODEL_GARDEN, + ) + model_info = model_proxy.get_model_info() + return model( + name=model_name, + fn=handler, + config_schema=OpenAIConfig, + metadata={'model': model_info}, ) - model_proxy.define_model() - - @cached_property - def list_actions(self) -> list[ActionMetadata]: - """Generate a list of available actions or models. - - Returns: - list[ActionMetadata]: A list of ActionMetadata objects, each with the following attributes: - - name (str): The name of the action or model. - - kind (ActionKind): The type or category of the action. - - info (dict): The metadata dictionary describing the model configuration and properties. - - config_schema (type): The schema class used for validating the model's configuration. 
- """ - - actions_list = [] - for model, model_info in SUPPORTED_OPENAI_COMPAT_MODELS.items(): - actions_list.append( - model_action_metadata( - name=model_garden_name(model), info=model_info.model_dump(), config_schema=OpenAIConfig - ) - ) - - return actions_list diff --git a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py index 8576bcbf15..ddfc6303be 100644 --- a/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py +++ b/py/plugins/vertex-ai/src/genkit/plugins/vertex_ai/vector_search/vector_search.py @@ -20,7 +20,11 @@ from google.auth.credentials import Credentials from google.cloud import aiplatform_v1 -from genkit.ai import GenkitRegistry, Plugin +from genkit.ai import Plugin +from genkit.blocks.retriever import RetrieverOptions, retriever_action_metadata +from genkit.core.action import Action, ActionMetadata +from genkit.core.action.types import ActionKind +from genkit.core.schema import to_json_schema from genkit.plugins.vertex_ai.vector_search.retriever import ( DocRetriever, RetrieverOptionsSchema, @@ -42,13 +46,10 @@ def vertexai_name(name: str) -> str: class VertexAIVectorSearch(Plugin): - """A plugin for integrating VertexAI Vector Search. + """A plugin for integrating VertexAI Vector Search.""" - This class registers VertexAI Vector Stores within a registry, - and allows interaction to retrieve similar documents. - """ - - name: str = 'vertexAIVectorSearch' + name: str = VERTEXAI_PLUGIN_NAME + retriever_name: str = 'vertexAIVectorSearch' def __init__( self, @@ -90,25 +91,56 @@ def __init__( credentials=credentials, ) - def initialize(self, ai: GenkitRegistry) -> None: - """Initialize plugin with the retriver specified. - - Register actions with the registry making them available for use in the Genkit framework. - - Args: - ai: The registry to register actions with. 
- """ - retriever = self.retriever_cls( - ai=ai, - name=self.name, - match_service_client_generator=self._match_service_client_generator, - embedder=self.embedder, - embedder_options=self.embedder_options, - **self.retriever_extra_args, - ) - - return ai.define_retriever( - name=vertexai_name(self.name), - config_schema=RetrieverOptionsSchema, - fn=retriever.retrieve, + async def init(self) -> list[Action]: + return [self._create_retriever_action()] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + if action_type != ActionKind.RETRIEVER: + return None + if name != self.retriever_name: + return None + return self._create_retriever_action() + + async def list_actions(self) -> list[ActionMetadata]: + return [ + retriever_action_metadata( + name=self.retriever_name, + options=RetrieverOptions( + label='Vertex AI Vector Search', + config_schema=to_json_schema(RetrieverOptionsSchema), + ), + ) + ] + + def _create_retriever_action(self) -> Action: + metadata: dict[str, Any] = { + 'retriever': { + 'label': self.retriever_name, + 'customOptions': to_json_schema(RetrieverOptionsSchema), + } + } + + async def retrieve(request, ctx): + ai = (ctx.context or {}).get('__genkit_ai__') + if ai is None: + raise ValueError( + 'VertexAIVectorSearch retriever requires a Genkit instance in action context. ' + 'Use it via `await ai.retrieve(...)`.' 
+ ) + + retriever = self.retriever_cls( + ai=ai, + name=self.retriever_name, + match_service_client_generator=self._match_service_client_generator, + embedder=self.embedder, + embedder_options=self.embedder_options, + **self.retriever_extra_args, + ) + return await retriever.retrieve(request, ctx) + + return Action( + kind=ActionKind.RETRIEVER, + name=self.retriever_name, + fn=retrieve, + metadata=metadata, ) diff --git a/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py index 3feb75ae91..c4eabaeae4 100644 --- a/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py +++ b/py/plugins/vertex-ai/tests/vector_search/test_retrievers.py @@ -24,7 +24,6 @@ from unittest.mock import AsyncMock, MagicMock import pytest -from google.cloud import bigquery from google.cloud.aiplatform_v1 import ( FindNeighborsRequest, FindNeighborsResponse, diff --git a/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py index 912831c6d3..f8c44010d1 100644 --- a/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py +++ b/py/plugins/vertex-ai/tests/vector_search/test_vector_search_plugin.py @@ -18,17 +18,22 @@ from unittest.mock import MagicMock -from genkit.ai import Genkit +import pytest + +from genkit.core.action.types import ActionKind from genkit.plugins.vertex_ai.vector_search import VertexAIVectorSearch -def test_initialize_plugin(): - """Test plugin initialization.""" +@pytest.mark.asyncio +async def test_init_plugin_returns_retriever_action(): + """PluginV2 init should return the vector-search retriever action.""" plugin = VertexAIVectorSearch( retriever=MagicMock(), embedder='embedder', ) - result = plugin.initialize(ai=MagicMock(spec=Genkit)) + actions = await plugin.init() - assert result is not None + assert len(actions) == 1 + assert actions[0].kind == ActionKind.RETRIEVER + assert actions[0].name == 
'vertexAIVectorSearch' diff --git a/py/plugins/xai/pyproject.toml b/py/plugins/xai/pyproject.toml index 15843b3727..bcfc93269f 100644 --- a/py/plugins/xai/pyproject.toml +++ b/py/plugins/xai/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -35,7 +34,7 @@ classifiers = [ ] dependencies = ["genkit", "xai-sdk>=0.0.1"] description = "Genkit xAI Plugin" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-plugin-xai" readme = "README.md" requires-python = ">=3.10" diff --git a/py/plugins/xai/src/genkit/plugins/xai/plugin.py b/py/plugins/xai/src/genkit/plugins/xai/plugin.py index 046736d47d..7a1bf57f7b 100644 --- a/py/plugins/xai/src/genkit/plugins/xai/plugin.py +++ b/py/plugins/xai/src/genkit/plugins/xai/plugin.py @@ -20,7 +20,9 @@ from xai_sdk import Client as XAIClient -from genkit.ai import GenkitRegistry, Plugin +from genkit.ai import Plugin +from genkit.blocks.model import model +from genkit.core.action import ActionMetadata from genkit.core.error import GenkitError from genkit.core.registry import ActionKind from genkit.plugins.xai.model_info import SUPPORTED_XAI_MODELS, get_model_info @@ -56,30 +58,37 @@ def __init__( self._xai_params = xai_params self._xai_client = XAIClient(api_key=api_key, **xai_params) - def initialize(self, ai: GenkitRegistry) -> None: - for model_name in self.models: - self._define_model(ai, model_name) - - def resolve_action( - self, - ai: GenkitRegistry, - kind: ActionKind, - name: str, - ) -> None: - if kind == ActionKind.MODEL: - self._resolve_model(ai=ai, name=name) - - def _resolve_model(self, ai: GenkitRegistry, name: str) -> None: - clean_name = name.replace(f'{XAI_PLUGIN_NAME}/', '') if name.startswith(XAI_PLUGIN_NAME) else 
name - self._define_model(ai, clean_name) - - def _define_model(self, ai: GenkitRegistry, model_name: str) -> None: - model = XAIModel(model_name=model_name, client=self._xai_client) + async def init(self): + """Return eagerly-initialized model actions.""" + return [self._create_model_action(model_name) for model_name in self.models] + + async def resolve(self, action_type: ActionKind, name: str): + """Resolve a model action on-demand.""" + if action_type == ActionKind.MODEL: + clean_name = name.replace(f'{XAI_PLUGIN_NAME}/', '') if name.startswith(XAI_PLUGIN_NAME) else name + if clean_name in SUPPORTED_XAI_MODELS: + return self._create_model_action(clean_name) + return None + + async def list_actions(self): + """List all supported xAI models.""" + return [ + ActionMetadata( + name=model_name, + kind=ActionKind.MODEL, + info={'supports': get_model_info(model_name).supports.model_dump()}, + ) + for model_name in self.models + ] + + def _create_model_action(self, model_name: str): + """Create an xAI model action (doesn't register).""" + xai_model = XAIModel(model_name=model_name, client=self._xai_client) model_info = get_model_info(model_name) - ai.define_model( - name=xai_name(model_name), - fn=model.generate, + return model( + name=model_name, + fn=xai_model.generate, config_schema=GenerationCommonConfig, metadata={'model': {'supports': model_info.supports.model_dump()}}, ) diff --git a/py/plugins/xai/tests/test_xai_models.py b/py/plugins/xai/tests/test_xai_models.py index fa985a429d..95958ac043 100644 --- a/py/plugins/xai/tests/test_xai_models.py +++ b/py/plugins/xai/tests/test_xai_models.py @@ -16,7 +16,6 @@ """Tests for xAI models.""" -import asyncio from unittest.mock import MagicMock import pytest diff --git a/py/plugins/xai/tests/test_xai_plugin.py b/py/plugins/xai/tests/test_xai_plugin.py index 43cb615709..71ec98f8c7 100644 --- a/py/plugins/xai/tests/test_xai_plugin.py +++ b/py/plugins/xai/tests/test_xai_plugin.py @@ -16,7 +16,9 @@ """Tests for xAI 
plugin.""" -from unittest.mock import MagicMock, patch +from unittest.mock import patch + +import pytest from genkit.core.error import GenkitError from genkit.core.registry import ActionKind @@ -38,7 +40,7 @@ def test_init_without_api_key_raises(): with patch.dict('os.environ', {}, clear=True): try: XAI() - assert False, 'Expected GenkitError' + raise AssertionError('Expected GenkitError') except GenkitError: pass @@ -54,23 +56,26 @@ def test_custom_models(): assert plugin.models == ['grok-3', 'grok-3-mini'] -def test_plugin_initialize(): - registry = MagicMock() +@pytest.mark.asyncio +async def test_plugin_initialize(): plugin = XAI(api_key='test-key') - plugin.initialize(registry) - assert registry.define_model.call_count == len(SUPPORTED_XAI_MODELS) + actions = await plugin.init() + assert len(actions) == len(SUPPORTED_XAI_MODELS) + assert all(action.kind == ActionKind.MODEL for action in actions) -def test_resolve_action_model(): - registry = MagicMock() +@pytest.mark.asyncio +async def test_resolve_action_model(): plugin = XAI(api_key='test-key') - plugin.resolve_action(registry, ActionKind.MODEL, 'xai/grok-3') - registry.define_model.assert_called_once() + action = await plugin.resolve(ActionKind.MODEL, 'grok-3') + assert action is not None + assert action.kind == ActionKind.MODEL + assert action.name == 'grok-3' def test_supported_models(): assert len(SUPPORTED_XAI_MODELS) >= 4 - for name, info in SUPPORTED_XAI_MODELS.items(): + for _name, info in SUPPORTED_XAI_MODELS.items(): assert info.label.startswith('xAI - ') assert len(info.versions) > 0 assert info.supports.tools diff --git a/py/pyproject.toml b/py/pyproject.toml index 400fa73842..d941379b2c 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -15,11 +15,13 @@ # SPDX-License-Identifier: Apache-2.0 [project] +authors = [{ name = "Google" }] dependencies = [ - "dotpromptz==0.1.4", + "dotprompt", "genkit", "genkit-plugin-anthropic", "genkit-plugin-compat-oai", + "genkit-plugin-deepseek", 
"genkit-plugin-dev-local-vectorstore", "genkit-plugin-evaluators", "genkit-plugin-firebase", @@ -30,10 +32,11 @@ dependencies = [ "genkit-plugin-vertex-ai", "genkit-plugin-xai", "liccheck>=0.9.2", + "mcp>=1.25.0", "strenum>=0.4.15; python_version < '3.11'", ] description = "Workspace for Genkit packages" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "genkit-workspace" readme = "README.md" requires-python = ">=3.10" @@ -60,7 +63,7 @@ dev = [ "nox-uv>=0.2.2", ] -lint = ["mypy>=1.15", "ruff>=0.9"] +lint = ["ty>=0.0.1", "ruff>=0.9"] [tool.hatch.build.targets.wheel] packages = [] @@ -95,6 +98,7 @@ omit = [ "**/typing.py", # Often auto-generated or complex types "**/types.py", # Often auto-generated or complex types ] +source = ["packages", "plugins"] # uv based package management. [tool.uv] @@ -105,6 +109,7 @@ evaluator-demo = { workspace = true } genkit = { workspace = true } genkit-plugin-anthropic = { workspace = true } genkit-plugin-compat-oai = { workspace = true } +genkit-plugin-deepseek = { workspace = true } genkit-plugin-dev-local-vectorstore = { workspace = true } genkit-plugin-evaluators = { workspace = true } genkit-plugin-firebase = { workspace = true } @@ -198,29 +203,6 @@ line-ending = "lf" quote-style = "single" skip-magic-trailing-comma = false -# Static type checking. -[tool.mypy] -disallow_incomplete_defs = true -disallow_untyped_defs = true -exclude = ["samples/"] -explicit_package_bases = true -mypy_path = [ - "packages/genkit/src", - "plugins/chroma/src", - "plugins/compat-oai/src", - "plugins/dev-local-vectorstore/src", - "plugins/firebase/src", - "plugins/flask/src", - "plugins/google-cloud/src", - "plugins/google-genai/src", - "plugins/ollama/src", - "plugins/pinecone/src", - "plugins/vertex-ai/src", -] -namespace_packages = true -strict = true -warn_unused_configs = true - [tool.datamodel-codegen] #collapse-root-models = true # Don't use; produces Any as types. 
#strict-types = ["str", "int", "float", "bool", "bytes"] # Don't use; produces StrictStr, StrictInt, etc. diff --git a/py/samples/anthropic-hello/.gitignore b/py/samples/anthropic-hello/.gitignore new file mode 100644 index 0000000000..7065f5d82e --- /dev/null +++ b/py/samples/anthropic-hello/.gitignore @@ -0,0 +1,3 @@ +.env + + diff --git a/py/samples/anthropic-hello/env.example b/py/samples/anthropic-hello/env.example new file mode 100644 index 0000000000..229d04e30c --- /dev/null +++ b/py/samples/anthropic-hello/env.example @@ -0,0 +1,4 @@ +# Copy this file to ".env" and fill in values. Do NOT commit ".env". +ANTHROPIC_API_KEY=your-anthropic-api-key + + diff --git a/py/samples/anthropic-hello/pyproject.toml b/py/samples/anthropic-hello/pyproject.toml index 17ec62d435..6588a3f712 100644 --- a/py/samples/anthropic-hello/pyproject.toml +++ b/py/samples/anthropic-hello/pyproject.toml @@ -15,6 +15,7 @@ # SPDX-License-Identifier: Apache-2.0 [project] +authors = [{ name = "Google" }] dependencies = [ "genkit", "genkit-plugin-anthropic", diff --git a/py/samples/anthropic-hello/run.sh b/py/samples/anthropic-hello/run.sh index b3170f6ef6..f36c1df922 100755 --- a/py/samples/anthropic-hello/run.sh +++ b/py/samples/anthropic-hello/run.sh @@ -15,4 +15,14 @@ # # SPDX-License-Identifier: Apache-2.0 +set -euo pipefail + +# Load local env if present (do not commit .env; see env.example) +if [ -f ".env" ]; then + set -a + # shellcheck disable=SC1091 + . 
".env" + set +a +fi + exec genkit start -- uv run src/main.py "$@" diff --git a/py/samples/anthropic-hello/src/main.py b/py/samples/anthropic-hello/src/main.py index fc38c2f522..8c8729ccaf 100755 --- a/py/samples/anthropic-hello/src/main.py +++ b/py/samples/anthropic-hello/src/main.py @@ -195,7 +195,6 @@ async def say_hi_with_config(name: str) -> str: async def main() -> None: """Main entry point for the Anthropic sample.""" - result = await say_hi('John Doe') await logger.ainfo('Simple greeting', result=result) diff --git a/py/samples/compat-oai-hello/pyproject.toml b/py/samples/compat-oai-hello/pyproject.toml index 2c9af3e410..6ec5e1f157 100644 --- a/py/samples/compat-oai-hello/pyproject.toml +++ b/py/samples/compat-oai-hello/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -41,7 +40,7 @@ dependencies = [ "httpx>=0.28.1", ] description = "OpenAI sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "compat-oai-hello" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/compat-oai-hello/src/main.py b/py/samples/compat-oai-hello/src/main.py index 233c138b01..94effed2cb 100755 --- a/py/samples/compat-oai-hello/src/main.py +++ b/py/samples/compat-oai-hello/src/main.py @@ -212,7 +212,7 @@ async def get_weather_flow_stream(location: str) -> str: class Skills(BaseModel): - """A set of core character skills for an RPG character""" + """A set of core character skills for an RPG character.""" strength: int = Field(description='strength (0-100)') charisma: int = Field(description='charisma (0-100)') diff --git a/py/samples/deepseek-hello/README.md b/py/samples/deepseek-hello/README.md new file mode 100644 index 0000000000..477f8ccc77 --- /dev/null +++ 
b/py/samples/deepseek-hello/README.md @@ -0,0 +1,19 @@ +## DeepSeek Sample + +1. Setup environment and install dependencies: +```bash +uv venv +source .venv/bin/activate + +uv sync +``` + +2. Set DeepSeek API key (get one from [DeepSeek Platform](https://platform.deepseek.com/)): +```bash +export DEEPSEEK_API_KEY=your-api-key +``` + +3. Run the sample: +```bash +genkit start -- uv run src/main.py +``` diff --git a/py/samples/deepseek-hello/pyproject.toml b/py/samples/deepseek-hello/pyproject.toml new file mode 100644 index 0000000000..cb48c544d8 --- /dev/null +++ b/py/samples/deepseek-hello/pyproject.toml @@ -0,0 +1,38 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 + +[project] +authors = [{ name = "Google" }] +dependencies = [ + "genkit", + "genkit-plugin-deepseek", + "pydantic>=2.0.0", + "structlog>=24.0.0", +] +description = "DeepSeek Hello Sample" +name = "deepseek-hello" +requires-python = ">=3.10" +version = "0.1.0" + +[tool.uv.sources] +genkit-plugin-deepseek = { workspace = true } + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.build.targets.wheel] +packages = ["src"] diff --git a/py/samples/deepseek-hello/run.sh b/py/samples/deepseek-hello/run.sh new file mode 100644 index 0000000000..02a864050f --- /dev/null +++ b/py/samples/deepseek-hello/run.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +exec genkit start -- uv run src/main.py "$@" diff --git a/py/samples/deepseek-hello/src/main.py b/py/samples/deepseek-hello/src/main.py new file mode 100644 index 0000000000..bfc714d437 --- /dev/null +++ b/py/samples/deepseek-hello/src/main.py @@ -0,0 +1,279 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""DeepSeek hello sample. + +Key features demonstrated in this sample: + +| Feature Description | Example Function / Code Snippet | +|-----------------------------------------|-----------------------------------------| +| Plugin Initialization | `ai = Genkit(plugins=[DeepSeek(...)])` | +| Default Model Configuration | `ai = Genkit(model=deepseek_name(...))` | +| Defining Flows | `@ai.flow()` decorator | +| Defining Tools | `@ai.tool()` decorator | +| Pydantic for Tool Input Schema | `WeatherInput` | +| Simple Generation (Prompt String) | `say_hi` | +| Streaming Response | `streaming_flow` | +| Generation with Tools | `weather_flow` | +| Reasoning Model (deepseek-reasoner) | `reasoning_flow` | +| Generation with Config | `custom_config_flow` | +| Multi-turn Chat | `chat_flow` | +""" + +import structlog +from pydantic import BaseModel, Field + +from genkit.ai import Genkit +from genkit.core.action import ActionRunContext +from genkit.plugins.deepseek import DeepSeek, deepseek_name +from genkit.types import Message, Part, Role, TextPart, ToolResponse + +logger = structlog.get_logger(__name__) + +ai = Genkit( + plugins=[DeepSeek()], + model=deepseek_name('deepseek-chat'), +) + + +class WeatherInput(BaseModel): + """Input schema for the weather tool.""" + + location: str = Field(description='The city and state, e.g. San Francisco, CA') + + +@ai.tool() +def get_weather(input: WeatherInput) -> str: + """Get weather of a location, the user should supply a location first. 
+ + Args: + input: Weather input with location (city and state, e.g. San Francisco, CA). + + Returns: + Weather information with temperature in degrees Fahrenheit. + """ + # Mocked weather data + weather_data = { + 'San Francisco, CA': {'temp': 72, 'condition': 'sunny', 'humidity': 65}, + 'Seattle, WA': {'temp': 55, 'condition': 'rainy', 'humidity': 85}, + } + + location = input.location + data = weather_data.get(location, {'temp': 70, 'condition': 'partly cloudy', 'humidity': 55}) + + return f'The weather in {location} is {data["temp"]}°F and {data["condition"]}. Humidity is {data["humidity"]}%.' + + +@ai.flow() +async def say_hi(name: str) -> str: + """Generate a simple greeting. + + Args: + name: Name to greet. + + Returns: + Greeting message. + """ + response = await ai.generate(prompt=f'Say hello to {name}!') + return response.text + + +@ai.flow() +async def streaming_flow(topic: str, ctx: ActionRunContext) -> str: + """Generate with streaming response. + + Args: + topic: Topic to generate about. + ctx: Action run context for streaming chunks to client. + + Returns: + Generated text. + """ + response = await ai.generate( + prompt=f'Tell me a fun fact about {topic}', + on_chunk=ctx.send_chunk, + ) + return response.text + + +@ai.flow() +async def weather_flow(location: str) -> str: + """Get weather using compat-oai auto tool calling.""" + + response = await ai.generate( + model=deepseek_name('deepseek-chat'), + prompt=f'What is the weather in {location}?', + system=( + 'You have a tool called get_weather. ' + "It takes an object with a 'location' field. " + 'Always use this tool when asked about weather.' + ), + tools=['get_weather'], + tool_choice='required', + max_turns=2, + ) + + return response.text + + +@ai.flow() +async def reasoning_flow(prompt: str | None = None) -> str: + """Solve reasoning problems using deepseek-reasoner model. + + Args: + prompt: The reasoning question to solve. Defaults to a classic logic problem. 
+ + Returns: + The reasoning and answer. + """ + if prompt is None: + prompt = 'What is heavier, one kilo of steel or one kilo of feathers?' + + response = await ai.generate( + model=deepseek_name('deepseek-reasoner'), + prompt=prompt, + ) + return response.text + + +@ai.flow() +async def custom_config_flow(task: str | None = None) -> str: + """Demonstrate custom model configurations for different tasks. + + Shows how different config parameters affect generation behavior: + - 'creative': High temperature for diverse, creative outputs + - 'precise': Low temperature with penalties for consistent, focused outputs + - 'detailed': Extended output with frequency penalty to avoid repetition + + Args: + task: Type of task - 'creative', 'precise', or 'detailed' + + Returns: + Generated response showing the effect of different configs. + """ + if task is None: + task = 'creative' + + prompts = { + 'creative': 'Write a creative story opener about a robot discovering art', + 'precise': 'List the exact steps to make a cup of tea', + 'detailed': 'Explain how photosynthesis works in detail', + } + + configs = { + 'creative': { + 'temperature': 1.5, # High temperature for creativity + 'max_tokens': 200, + 'top_p': 0.95, + }, + 'precise': { + 'temperature': 0.1, # Low temperature for consistency + 'max_tokens': 150, + 'presence_penalty': 0.5, # Encourage covering all steps + }, + 'detailed': { + 'temperature': 0.7, + 'max_tokens': 400, # More tokens for detailed explanation + 'frequency_penalty': 0.8, # Reduce repetitive phrasing + }, + } + + prompt = prompts.get(task, prompts['creative']) + config = configs.get(task, configs['creative']) + + response = await ai.generate( + prompt=prompt, + config=config, + ) + return response.text + + +@ai.flow() +async def chat_flow() -> str: + """Multi-turn chat example demonstrating context retention. + + Returns: + Final chat response. + """ + history = [] + + # First turn - User shares information + prompt1 = "Hi! 
I'm planning a trip to Tokyo next month. I'm really excited because I love Japanese cuisine, especially ramen and sushi." + response1 = await ai.generate( + prompt=prompt1, + system='You are a helpful travel assistant.', + ) + history.append(Message(role=Role.USER, content=[TextPart(text=prompt1)])) + history.append(response1.message) + await logger.ainfo('chat_flow turn 1', result=response1.text) + + # Second turn - Ask question requiring context from first turn + response2 = await ai.generate( + messages=history + [Message(role=Role.USER, content=[TextPart(text='What foods did I say I enjoy?')])], + system='You are a helpful travel assistant.', + ) + history.append(Message(role=Role.USER, content=[TextPart(text='What foods did I say I enjoy?')])) + history.append(response2.message) + await logger.ainfo('chat_flow turn 2', result=response2.text) + + # Third turn - Ask question requiring context from both previous turns + response3 = await ai.generate( + messages=history + + [ + Message( + role=Role.USER, + content=[TextPart(text='Based on our conversation, suggest one restaurant I should visit.')], + ) + ], + system='You are a helpful travel assistant.', + ) + return response3.text + + +async def main() -> None: + """Main entry point for the DeepSeek sample.""" + # Simple greeting + result = await say_hi('World') + await logger.ainfo('say_hi', result=result) + + # Streaming response + result = await streaming_flow('apple') + await logger.ainfo('streaming_flow', result=result) + + # Weather with tools + result = await weather_flow('Seattle, WA') + await logger.ainfo('weather_flow', result=result) + + # Reasoning model + result = await reasoning_flow() + await logger.ainfo('reasoning_flow', result=result) + + # Custom config - demonstrate different configurations + await logger.ainfo('Testing creative config...') + result = await custom_config_flow('creative') + await logger.ainfo('custom_config_flow (creative)', result=result) + + await logger.ainfo('Testing 
precise config...') + result = await custom_config_flow('precise') + await logger.ainfo('custom_config_flow (precise)', result=result) + + # Multi-turn chat + result = await chat_flow() + await logger.ainfo('chat_flow', result=result) + + +if __name__ == '__main__': + ai.run_main(main()) diff --git a/py/samples/dev-local-vectorstore-hello/pyproject.toml b/py/samples/dev-local-vectorstore-hello/pyproject.toml index dafc44ebaa..3a8a5911d8 100644 --- a/py/samples/dev-local-vectorstore-hello/pyproject.toml +++ b/py/samples/dev-local-vectorstore-hello/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -41,7 +40,7 @@ dependencies = [ "structlog>=25.2.0", ] description = "hello Genkit sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "dev-local-vectorstore-hello" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/dev-local-vectorstore-hello/src/main.py b/py/samples/dev-local-vectorstore-hello/src/main.py index 9cbb91670f..ea006d1caa 100755 --- a/py/samples/dev-local-vectorstore-hello/src/main.py +++ b/py/samples/dev-local-vectorstore-hello/src/main.py @@ -27,7 +27,7 @@ embedder='vertexai/text-embedding-004', ), ], - model='vertexai/gemini-2.5-flash', + model='vertexai/gemini-3-flash-preview', ) films = [ diff --git a/py/samples/evaluator-demo/pyproject.toml b/py/samples/evaluator-demo/pyproject.toml index a80d91cc8b..771c966747 100644 --- a/py/samples/evaluator-demo/pyproject.toml +++ b/py/samples/evaluator-demo/pyproject.toml @@ -15,6 +15,7 @@ # SPDX-License-Identifier: Apache-2.0 [project] +authors = [{ name = "Google" }] dependencies = ["genkit", "pydantic>=2.0.0", "structlog>=24.0.0", "pypdf"] description = "Genkit Python Evaluation Demo" 
name = "eval-demo" diff --git a/py/samples/evaluator-demo/src/genkit_demo.py b/py/samples/evaluator-demo/src/genkit_demo.py index 0ccd8f64ea..618b98c8ca 100644 --- a/py/samples/evaluator-demo/src/genkit_demo.py +++ b/py/samples/evaluator-demo/src/genkit_demo.py @@ -47,17 +47,17 @@ GenkitEvaluators([ MetricConfig( metric_type=GenkitMetricType.MALICIOUSNESS, - judge=ModelReference(name='googleai/gemini-2.5-pro'), + judge=ModelReference(name='googleai/gemini-3-pro-preview'), judge_config=PERMISSIVE_SAFETY_SETTINGS, ), MetricConfig( metric_type=GenkitMetricType.ANSWER_RELEVANCY, - judge=ModelReference(name='googleai/gemini-2.5-pro'), + judge=ModelReference(name='googleai/gemini-3-pro-preview'), judge_config=PERMISSIVE_SAFETY_SETTINGS, ), MetricConfig( metric_type=GenkitMetricType.FAITHFULNESS, - judge=ModelReference(name='googleai/gemini-2.5-pro'), + judge=ModelReference(name='googleai/gemini-3-pro-preview'), judge_config=PERMISSIVE_SAFETY_SETTINGS, ), ]), diff --git a/py/samples/evaluator-demo/src/main.py b/py/samples/evaluator-demo/src/main.py index e3961390c9..6053046882 100755 --- a/py/samples/evaluator-demo/src/main.py +++ b/py/samples/evaluator-demo/src/main.py @@ -16,10 +16,7 @@ import random -from eval_in_code import dog_facts_eval_flow from genkit_demo import ai -from pdf_rag import index_pdf, pdf_qa, simple_echo, simple_structured -from setup import setup from genkit.core.typing import BaseEvalDataPoint, EvalStatusEnum, Score diff --git a/py/samples/firestore-retreiver/pyproject.toml b/py/samples/firestore-retreiver/pyproject.toml index 485dea5882..1f856710f2 100644 --- a/py/samples/firestore-retreiver/pyproject.toml +++ b/py/samples/firestore-retreiver/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming 
Language :: Python :: 3.10", @@ -35,7 +34,7 @@ classifiers = [ ] dependencies = ["genkit", "google-cloud-firestore"] description = "firestore-retreiver Genkit sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "firestore-retreiver" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/flask-hello/pyproject.toml b/py/samples/flask-hello/pyproject.toml index 9397e7d598..e07d52625a 100644 --- a/py/samples/flask-hello/pyproject.toml +++ b/py/samples/flask-hello/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "flask", ] description = "hello Genkit sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "flask-hello" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/flask-hello/src/main.py b/py/samples/flask-hello/src/main.py index fa18efe490..226e8a33f6 100755 --- a/py/samples/flask-hello/src/main.py +++ b/py/samples/flask-hello/src/main.py @@ -28,7 +28,7 @@ ai = Genkit( plugins=[GoogleAI()], - model=googleai_name(GoogleAIGeminiVersion.GEMINI_2_0_FLASH), + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), ) app = Flask(__name__) diff --git a/py/samples/google-genai-code-execution/pyproject.toml b/py/samples/google-genai-code-execution/pyproject.toml index d5dfa8f2db..267cf9ad6d 100644 --- a/py/samples/google-genai-code-execution/pyproject.toml +++ b/py/samples/google-genai-code-execution/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: 
Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "structlog>=25.2.0", ] description = "Code execution sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "google-genai-code-execution" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/google-genai-code-execution/src/main.py b/py/samples/google-genai-code-execution/src/main.py index 4b554a703d..bcb1a01cde 100755 --- a/py/samples/google-genai-code-execution/src/main.py +++ b/py/samples/google-genai-code-execution/src/main.py @@ -27,7 +27,7 @@ ai = Genkit( plugins=[GoogleAI()], - model=googleai_name('gemini-2.5-flash'), + model=googleai_name('gemini-3-flash-preview'), ) diff --git a/py/samples/google-genai-context-caching/pyproject.toml b/py/samples/google-genai-context-caching/pyproject.toml index 17035a9a72..a1008ab42e 100644 --- a/py/samples/google-genai-context-caching/pyproject.toml +++ b/py/samples/google-genai-context-caching/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -41,7 +40,7 @@ dependencies = [ "structlog>=25.2.0", ] description = "context-caching Genkit sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "google-genai-context-caching" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/google-genai-context-caching/src/main.py b/py/samples/google-genai-context-caching/src/main.py index ddf51f3225..1b1727b696 100755 --- a/py/samples/google-genai-context-caching/src/main.py +++ b/py/samples/google-genai-context-caching/src/main.py @@ -14,7 +14,7 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Sample that demonstrates caching of generation context in Genkit +"""Sample that demonstrates 
caching of generation context in Genkit. In this sample user actor supplies "Tom Sawyer" book content from Gutenberg library archive and model caches this context. @@ -34,7 +34,7 @@ ai = Genkit( plugins=[GoogleAI()], - model=googleai_name(GoogleAIGeminiVersion.GEMINI_1_5_FLASH), + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), ) # Tom Sawyer is taken as a sample book here @@ -67,7 +67,7 @@ async def text_context_flow(_input: BookContextInputSchema) -> str: ), Message( role=Role.MODEL, - content=[TextPart(text=f'Here is some analysis based on the text provided.')], + content=[TextPart(text='Here is some analysis based on the text provided.')], metadata={ 'cache': { 'ttl_seconds': 300, @@ -76,7 +76,7 @@ async def text_context_flow(_input: BookContextInputSchema) -> str: ), ], config=GenerationCommonConfig( - version='gemini-1.5-flash-001', + version='gemini-3-flash-preview', temperature=0.7, maxOutputTokens=1000, topK=50, diff --git a/py/samples/google-genai-hello/pyproject.toml b/py/samples/google-genai-hello/pyproject.toml index c85fd1db20..99204b2ebf 100644 --- a/py/samples/google-genai-hello/pyproject.toml +++ b/py/samples/google-genai-hello/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -42,7 +41,7 @@ dependencies = [ "structlog>=25.2.0", ] description = "Hello world sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "google-genai-hello" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/google-genai-hello/src/main.py b/py/samples/google-genai-hello/src/main.py index 2c8da83e98..de7eb51d2f 100755 --- a/py/samples/google-genai-hello/src/main.py +++ b/py/samples/google-genai-hello/src/main.py @@ -82,7 +82,7 @@ ]) ), 
], - model='googleai/gemini-2.5-flash', + model='googleai/gemini-3-flash-preview', ) diff --git a/py/samples/google-genai-image/pyproject.toml b/py/samples/google-genai-image/pyproject.toml index b264014272..5ddaab77e5 100644 --- a/py/samples/google-genai-image/pyproject.toml +++ b/py/samples/google-genai-image/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "pydantic>=2.10.5", ] description = "Vision API and Image Generation example" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "google-genai-image" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/google-genai-image/src/main.py b/py/samples/google-genai-image/src/main.py index 05f6dbbb08..849a7a0c8e 100755 --- a/py/samples/google-genai-image/src/main.py +++ b/py/samples/google-genai-image/src/main.py @@ -16,7 +16,6 @@ """This sample demonstrates how to use Gemini to describe and draw images.""" -import asyncio import base64 import os from io import BytesIO @@ -40,7 +39,7 @@ async def draw_image_with_gemini() -> str: return await ai.generate( prompt='Draw a cat in a hat.', config={'response_modalities': ['Text', 'Image']}, - model=googleai_name('gemini-2.5-flash'), + model=googleai_name('gemini-2.5-flash-image'), ) @@ -49,22 +48,25 @@ async def describe_image_with_gemini(data: str) -> str: """Describe an image. Args: - data: The image to describe. + data: The image data as a data URI (e.g., 'data:image/jpeg;base64,...'). Returns: The description of the image. 
""" + if not (data.startswith('data:') and ',' in data): + raise ValueError(f'Expected a data URI (e.g., "data:image/jpeg;base64,..."), but got: {data[:50]}...') + result = await ai.generate( messages=[ Message( role=Role.USER, content=[ TextPart(text='What is shown in this image?'), - MediaPart(media=Media(contentType='image/jpeg', url=data)), + MediaPart(media=Media(content_type='image/jpeg', url=data)), ], ), ], - model=googleai_name('gemini-2.5-flash'), + model=googleai_name('gemini-3-flash-preview'), ) return result.text @@ -79,12 +81,25 @@ async def main() -> None: with open(image_path, 'rb') as image_file: buffer = image_file.read() img_base64 = base64.b64encode(buffer).decode('utf-8') - print(await describe_image_with_gemini(img_base64)) + data_uri = f'data:image/jpeg;base64,{img_base64}' + print(await describe_image_with_gemini(data_uri)) # Gemini draws an image by description. The model used is available only in # Gemini API. result = await draw_image_with_gemini() - decoded_image = BytesIO(base64.b64decode(result.message.content[0].root.media.url)) + + # Find the media part in the response + media_part = next((part.root.media for part in result.message.content if part.root.media is not None), None) + + if media_part is None: + print('No media found in response') + print(f'Response content: {result.message.content}') + return + + media_url = media_part.url + # Extract base64 data after the comma in "data:image/png;base64,..." 
+ base64_data = media_url.split(',', 1)[1] + decoded_image = BytesIO(base64.b64decode(base64_data)) image = Image.open(decoded_image) image.show('Image generated by Gemini') diff --git a/py/samples/google-genai-vertexai-hello/pyproject.toml b/py/samples/google-genai-vertexai-hello/pyproject.toml index d2a14f41eb..3ffa5f5242 100644 --- a/py/samples/google-genai-vertexai-hello/pyproject.toml +++ b/py/samples/google-genai-vertexai-hello/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "structlog>=25.2.0", ] description = "Hello world sample on VertexAI API on GenAI" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "google-genai-vertexai-hello" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/google-genai-vertexai-hello/src/main.py b/py/samples/google-genai-vertexai-hello/src/main.py index f73041ee76..fe25ec1ba4 100755 --- a/py/samples/google-genai-vertexai-hello/src/main.py +++ b/py/samples/google-genai-vertexai-hello/src/main.py @@ -49,7 +49,6 @@ EmbeddingTaskType, VertexAI, ) -from genkit.plugins.google_genai.models import gemini from genkit.types import ( GenerationCommonConfig, Message, @@ -61,7 +60,7 @@ ai = Genkit( plugins=[VertexAI()], - model='vertexai/gemini-2.5-flash', + model='vertexai/gemini-3-flash-preview', ) diff --git a/py/samples/google-genai-vertexai-image/README.md b/py/samples/google-genai-vertexai-image/README.md index c0d5dba438..244a1be23f 100644 --- a/py/samples/google-genai-vertexai-image/README.md +++ b/py/samples/google-genai-vertexai-image/README.md @@ -9,12 +9,20 @@ Prerequisites: * A Google Cloud account with access to VertexAI service. * The `genkit` package. 
-To run this sample: +## Setup environment 1. Install the `genkit` package. -2. Install [GCP CLI](https://cloud.google.com/sdk/docs/install) -3. Put your GCP project and location in the code to run VertexAI there. -4. Run the sample. +2. Install [GCP CLI](https://cloud.google.com/sdk/docs/install). +3. Add your project to Google Cloud. Run the following code to log in and set up the configuration. +```bash +export GOOGLE_CLOUD_LOCATION=global +export GOOGLE_CLOUD_PROJECT=your-GCP-project-ID +gcloud init +``` +4. Run the following code to connect to VertexAI. +```bash +gcloud auth application-default login +``` ## Run the sample diff --git a/py/samples/google-genai-vertexai-image/pyproject.toml b/py/samples/google-genai-vertexai-image/pyproject.toml index 232c40b48b..37a8173728 100644 --- a/py/samples/google-genai-vertexai-image/pyproject.toml +++ b/py/samples/google-genai-vertexai-image/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "pydantic>=2.10.5", ] description = "Image Generation on VertexAI with GenAI library example" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "google-genai-vertexai-image" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/google-genai-vertexai-image/src/main.py b/py/samples/google-genai-vertexai-image/src/main.py index 9ad6bf4535..4bdcc51b5b 100755 --- a/py/samples/google-genai-vertexai-image/src/main.py +++ b/py/samples/google-genai-vertexai-image/src/main.py @@ -16,7 +16,6 @@ """This sample demonstrates how to use Gemini VertexAI to describe and draw images.""" -import asyncio import base64 from io import BytesIO @@ -55,7 +54,10 @@ async def main() -> None: # Imagen draws an 
image by description. The model used is available only in # VertexAI API. result = await draw_image_with_imagen() - decoded_image = BytesIO(base64.b64decode(result.message.content[0].root.media.url)) + media_url = result.message.content[0].root.media.url + # Extract base64 data after the comma in "data:image/png;base64,..." + base64_data = media_url.split(',', 1)[1] + decoded_image = BytesIO(base64.b64decode(base64_data)) image = Image.open(decoded_image) image.show('Image generated by Gemini') diff --git a/py/samples/menu/pyproject.toml b/py/samples/menu/pyproject.toml index 1ec32f5de9..7ba7975d49 100644 --- a/py/samples/menu/pyproject.toml +++ b/py/samples/menu/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -44,7 +43,7 @@ dependencies = [ "pydantic>=2.10.5", ] description = "menu Genkit sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "menu" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/menu/src/__init__.py b/py/samples/menu/src/__init__.py index 1e870fbb30..526fdbd696 100644 --- a/py/samples/menu/src/__init__.py +++ b/py/samples/menu/src/__init__.py @@ -15,27 +15,15 @@ # SPDX-License-Identifier: Apache-2.0 # 01 -from case_01.prompts import s01_staticMenuDotPrompt, s01_vanillaPrompt -from case_02.flows import s02_menuQuestionFlow -from case_02.prompts import s02_dataMenuPrompt # 02 -from case_02.tools import menu_tool # 03 -from case_03.flows import s03_multiTurnChatFlow -from case_03.prompts import s03_chatPreamblePrompt # 04 # TODO: uncomment once implemented # from case_04.flows import s04_indexMenuItemsFlow, s04_ragMenuQuestionFlow # from case_04.prompts import s04_ragDataMenuPrompt # 05 -from case_05.flows import ( - 
s05_readMenuFlow, - s05_textMenuQuestionFlow, - s05_visionMenuQuestionFlow, -) -from case_05.prompts import s05_readMenuPrompt, s05_textMenuPrompt print('All prompts and flows loaded, use the Developer UI to test them out') diff --git a/py/samples/menu/src/case_01/prompts.py b/py/samples/menu/src/case_01/prompts.py index afc8e1c6e1..d800ac67e8 100644 --- a/py/samples/menu/src/case_01/prompts.py +++ b/py/samples/menu/src/case_01/prompts.py @@ -16,21 +16,24 @@ from menu_ai import ai from menu_schemas import MenuQuestionInputSchema -from genkit.plugins.google_genai import google_genai_name -from genkit.plugins.google_genai.models.gemini import GeminiVersion +from genkit.plugins.google_genai import googleai_name +from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion s01_vanillaPrompt = ai.define_prompt( variant='s01_vanillaPrompt', + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), input_schema=MenuQuestionInputSchema, - system="""You are acting as a helpful AI assistant named "Walt" that can answer questions about the food available on the menu at Walt's Burgers.""", + prompt="""You are acting as a helpful AI assistant named "Walt" that can answer +questions about the food available on the menu at Walt's Burgers. +Customer says: {{question}}""", config={'temperature': 0.3}, ) s01_staticMenuDotPrompt = ai.define_prompt( variant='s01_staticMenuDotPrompt', - model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), input_schema=MenuQuestionInputSchema, - system=""" + prompt=""" You are acting as a helpful AI assistant named "Walt" that can answer questions about the food available on the menu at Walt's Burgers. 
Here is today's menu: diff --git a/py/samples/menu/src/case_02/flows.py b/py/samples/menu/src/case_02/flows.py index a3faa04f2e..b0ef54f9d0 100644 --- a/py/samples/menu/src/case_02/flows.py +++ b/py/samples/menu/src/case_02/flows.py @@ -27,5 +27,5 @@ async def s02_menuQuestionFlow( ) -> AnswerOutputSchema: text = await s02_dataMenuPrompt({'question': my_input.question}) return AnswerOutputSchema( - answer=text, + answer=text.text, ) diff --git a/py/samples/menu/src/case_02/prompts.py b/py/samples/menu/src/case_02/prompts.py index d98cf0ae54..1cd9084aec 100644 --- a/py/samples/menu/src/case_02/prompts.py +++ b/py/samples/menu/src/case_02/prompts.py @@ -16,15 +16,15 @@ from menu_ai import ai from menu_schemas import MenuQuestionInputSchema -from genkit.plugins.google_genai import google_genai_name -from genkit.plugins.google_genai.models.gemini import GeminiVersion +from genkit.plugins.google_genai import googleai_name +from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion s02_dataMenuPrompt = ai.define_prompt( variant='s02_dataMenu', - model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), input_schema=MenuQuestionInputSchema, - tools=['menu_tool'], - system="""You are acting as a helpful AI assistant named Walt that can answer + tools=['todaysMenu'], + prompt="""You are acting as a helpful AI assistant named Walt that can answer questions about the food available on the menu at Walt's Burgers. 
Answer this customer's question, in a concise and helpful manner, diff --git a/py/samples/menu/src/case_02/tools.py b/py/samples/menu/src/case_02/tools.py index 3d863aabf4..34719a1d31 100644 --- a/py/samples/menu/src/case_02/tools.py +++ b/py/samples/menu/src/case_02/tools.py @@ -26,8 +26,8 @@ menu_data = json.load(f) -@ai.tool(name='menu_tool') -def menu_tool(input=None) -> MenuToolOutputSchema: +@ai.tool(name='todaysMenu') +def todaysMenu(input=None) -> MenuToolOutputSchema: """Use this tool to retrieve all the items on today's menu.""" return MenuToolOutputSchema( menu_data=menu_data, diff --git a/py/samples/menu/src/case_03/flows.py b/py/samples/menu/src/case_03/flows.py index be70422d0b..430fc11aa4 100644 --- a/py/samples/menu/src/case_03/flows.py +++ b/py/samples/menu/src/case_03/flows.py @@ -26,14 +26,14 @@ from menu_ai import ai from genkit.core.typing import Message, Role, TextPart -from genkit.plugins.google_genai import google_genai_name -from genkit.plugins.google_genai.models.gemini import GeminiVersion +from genkit.plugins.google_genai import googleai_name +from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion as GeminiVersion menu_json_path = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'menu.json') with open(menu_json_path) as f: menu_data = json.load(f) -formatted_menu_data = '\n'.join([f'- ${r["title"]} ${r["price"]}\n${r["description"]}' for r in menu_data]) +formatted_menu_data = '\n'.join([f'- {r["title"]} ${r["price"]}\n{r["description"]}' for r in menu_data]) preamble = [ Message( @@ -43,13 +43,15 @@ ], ), Message( - role=Role.USER, + role=Role.MODEL, content=[ TextPart( - text=f"""I am Walt, a helpful AI assistant here at the restaurant.\n' + - 'I can answer questions about the food on the menu or any other questions\n' + - "you have about food in general. 
I probably can't help you with anything else.\n" + - "Here is today's menu: \n {formatted_menu_data}\nDo you have any questions about the menu?""" + text=f"""I am Walt, a helpful AI assistant here at the restaurant. +I can answer questions about the food on the menu or any other questions +you have about food in general. I probably can't help you with anything else. +Here is today's menu: +{formatted_menu_data} +Do you have any questions about the menu?""" ), ], ), @@ -67,7 +69,7 @@ async def s03_multiTurnChatFlow( history = chat_history_store.read(my_input.session_id) llm_response = await ai.generate( - model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), + model=googleai_name(GeminiVersion.GEMINI_3_FLASH_PREVIEW), messages=history, prompt=[TextPart(text=my_input.question)], ) diff --git a/py/samples/menu/src/case_03/prompts.py b/py/samples/menu/src/case_03/prompts.py index 847c0ff8e4..56ad18a4cb 100644 --- a/py/samples/menu/src/case_03/prompts.py +++ b/py/samples/menu/src/case_03/prompts.py @@ -17,12 +17,12 @@ from menu_ai import ai from menu_schemas import DataMenuQuestionInputSchema -from genkit.plugins.google_genai import google_genai_name -from genkit.plugins.google_genai.models.gemini import GeminiVersion +from genkit.plugins.google_genai import googleai_name +from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion s03_chatPreamblePrompt = ai.define_prompt( variant='s03_chatPreamble', - model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), input_schema=DataMenuQuestionInputSchema, config={'temperature': 0.3}, system="""{{ role "user" }} diff --git a/py/samples/menu/src/case_04/flows.py b/py/samples/menu/src/case_04/flows.py index 27f506da35..6a717e038e 100644 --- a/py/samples/menu/src/case_04/flows.py +++ b/py/samples/menu/src/case_04/flows.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 
2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,5 +14,48 @@ # # SPDX-License-Identifier: Apache-2.0 +from menu_ai import ai +from menu_schemas import AnswerOutputSchema, MenuItemSchema, MenuQuestionInputSchema +from pydantic import BaseModel, Field -# TODO: implement it once Genkit AI will have index API +from genkit.blocks.document import Document + +from .prompts import s04_ragDataMenuPrompt + + +class IndexMenuItemsOutputSchema(BaseModel): + rows: int = Field(...) + + +@ai.flow(name='s04_indexMenuItems') +async def s04_indexMenuItemsFlow( + menu_items: list[MenuItemSchema], +) -> IndexMenuItemsOutputSchema: + documents = [ + Document.from_text(f'{item.title} {item.price} \n {item.description}', metadata=item.model_dump()) + for item in menu_items + ] + + await ai.index( + indexer='menu-items', + documents=documents, + ) + return IndexMenuItemsOutputSchema(rows=len(menu_items)) + + +@ai.flow(name='s04_ragMenuQuestion') +async def s04_ragMenuQuestionFlow( + my_input: MenuQuestionInputSchema, +) -> AnswerOutputSchema: + # Retrieve the 3 most relevant menu items for the question + docs = await ai.retrieve( + retriever='menu-items', + query=my_input.question, + options={'k': 3}, + ) + + menu_data = [doc.metadata for doc in docs.documents] + + # Generate the response + response = await s04_ragDataMenuPrompt({'menuData': menu_data, 'question': my_input.question}) + return AnswerOutputSchema(answer=response.text) diff --git a/py/samples/menu/src/case_04/prompts.py b/py/samples/menu/src/case_04/prompts.py index 5014800727..8fb88a34bb 100644 --- a/py/samples/menu/src/case_04/prompts.py +++ b/py/samples/menu/src/case_04/prompts.py @@ -16,15 +16,15 @@ from menu_ai import ai from menu_schemas import DataMenuQuestionInputSchema -from genkit.plugins.google_genai import google_genai_name -from genkit.plugins.google_genai.models.gemini import GeminiVersion +from genkit.plugins.google_genai import googleai_name +from 
genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion s04_ragDataMenuPrompt = ai.define_prompt( variant='s04_ragDataMenu', - model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), input_schema=DataMenuQuestionInputSchema, config={'temperature': 0.3}, - system=""" + prompt=""" You are acting as Walt, a helpful AI assistant here at the restaurant. You can answer questions about the food on the menu or any other questions customers have about food in general. diff --git a/py/samples/menu/src/case_05/flows.py b/py/samples/menu/src/case_05/flows.py index 10e469e8e9..003036949f 100644 --- a/py/samples/menu/src/case_05/flows.py +++ b/py/samples/menu/src/case_05/flows.py @@ -28,27 +28,20 @@ ) -@ai.flow(name='s05_readMenuFlow') -async def s05_readMenuFlow(_) -> ReadMenuPromptOutputSchema: +@ai.flow(name='s05_readMenu') +async def s05_readMenuFlow(_: None = None) -> str: image_data_url = inline_data_url('menu.jpeg', 'image/jpeg') - response = await s05_readMenuPrompt( - image_url=image_data_url, - ) - return ReadMenuPromptOutputSchema( - menu_text=response.text, - ) + response = await s05_readMenuPrompt({'imageUrl': image_data_url}) + return response.text @ai.flow(name='s05_textMenuQuestion') async def s05_textMenuQuestionFlow( my_input: TextMenuQuestionInputSchema, ) -> AnswerOutputSchema: - response = await s05_textMenuPrompt( - menu_text=my_input.menu_text, - question=my_input.question, - ) - return ReadMenuPromptOutputSchema( - menu_text=response.text, + response = await s05_textMenuPrompt({'menuText': my_input.menuText, 'question': my_input.question}) + return AnswerOutputSchema( + answer=response.text, ) @@ -56,11 +49,11 @@ async def s05_textMenuQuestionFlow( async def s05_visionMenuQuestionFlow( my_input: MenuQuestionInputSchema, ) -> AnswerOutputSchema: - menu_result = await s05_readMenuFlow() - return s05_textMenuQuestionFlow( - my_input=TextMenuQuestionInputSchema( + menu_text = 
await s05_readMenuFlow() + return await s05_textMenuQuestionFlow( + TextMenuQuestionInputSchema( question=my_input.question, - menu_text=menu_result.menu_text, + menuText=menu_text, ) ) diff --git a/py/samples/menu/src/case_05/prompts.py b/py/samples/menu/src/case_05/prompts.py index fcd865e4a8..dceb98bf98 100644 --- a/py/samples/menu/src/case_05/prompts.py +++ b/py/samples/menu/src/case_05/prompts.py @@ -16,34 +16,34 @@ from menu_ai import ai from menu_schemas import ReadMenuImagePromptSchema, TextMenuQuestionInputSchema -from genkit.plugins.google_genai import google_genai_name -from genkit.plugins.google_genai.models.gemini import GeminiVersion +from genkit.plugins.google_genai import googleai_name +from genkit.plugins.google_genai.models.gemini import GoogleAIGeminiVersion s05_readMenuPrompt = ai.define_prompt( variant='s05_readMenu', - model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), input_schema=ReadMenuImagePromptSchema, config={'temperature': 0.1}, - system=""" + prompt=""" Extract _all_ of the text, in order, from the following image of a restaurant menu. -{{media url=image_url}} +{{media url=imageUrl}} """, ) s05_textMenuPrompt = ai.define_prompt( variant='s05_textMenu', - model=google_genai_name(GeminiVersion.GEMINI_1_5_FLASH), + model=googleai_name(GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), input_schema=TextMenuQuestionInputSchema, config={'temperature': 0.3}, - system=""" + prompt=""" You are acting as Walt, a helpful AI assistant here at the restaurant. You can answer questions about the food on the menu or any other questions customers have about food in general. Here is the text of today's menu to help you answer the customer's question: -{{menu_text}} +{{menuText}} Answer this customer's question: {{question}}? 
diff --git a/py/samples/menu/src/main.py b/py/samples/menu/src/main.py index e083a7ff63..1be31e3566 100755 --- a/py/samples/menu/src/main.py +++ b/py/samples/menu/src/main.py @@ -14,17 +14,18 @@ # # SPDX-License-Identifier: Apache-2.0 -"""A stub for the sample to come.""" - - -def main() -> None: - """Main entry point for the menu sample. - - This function demonstrates how to use Genkit to build an interactive - menu system. - """ - print('Hey') - +# Import all of the example prompts and flows to ensure they are registered +import case_01.prompts +import case_02.flows +import case_02.prompts +import case_02.tools +import case_03.flows +import case_03.prompts +import case_04.flows +import case_04.prompts +import case_05.flows +import case_05.prompts +from menu_ai import ai if __name__ == '__main__': - main() + ai.run_main() diff --git a/py/samples/menu/src/menu_ai.py b/py/samples/menu/src/menu_ai.py index c0059eb05d..59288684b5 100644 --- a/py/samples/menu/src/menu_ai.py +++ b/py/samples/menu/src/menu_ai.py @@ -17,15 +17,14 @@ from genkit.ai import Genkit from genkit.plugins.dev_local_vectorstore import DevLocalVectorStore -from genkit.plugins.google_genai import VertexAI -from genkit.plugins.vertex_ai import EmbeddingModels +from genkit.plugins.google_genai import GeminiEmbeddingModels, GoogleAI, googleai_name ai = Genkit( plugins=[ - VertexAI(), + GoogleAI(), DevLocalVectorStore( - index_name='menu-items', - embedder=EmbeddingModels.TEXT_EMBEDDING_004_ENG, + name='menu-items', + embedder=googleai_name(GeminiEmbeddingModels.TEXT_EMBEDDING_004), embedder_options={'taskType': 'RETRIEVAL_DOCUMENT'}, ), ] diff --git a/py/samples/model-garden/pyproject.toml b/py/samples/model-garden/pyproject.toml index dc44c2818a..98dc6490f0 100644 --- a/py/samples/model-garden/pyproject.toml +++ b/py/samples/model-garden/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License 
:: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", @@ -31,9 +30,9 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries", ] -dependencies = ["genkit", "pydantic>=2.10.5"] +dependencies = ["genkit", "genkit-plugin-vertex-ai", "pydantic>=2.10.5"] description = "Model Garden sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "model-garden-example" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/multi-server/pyproject.toml b/py/samples/multi-server/pyproject.toml index d6718b4d54..7163fd5d6d 100644 --- a/py/samples/multi-server/pyproject.toml +++ b/py/samples/multi-server/pyproject.toml @@ -15,13 +15,13 @@ # SPDX-License-Identifier: Apache-2.0 [project] +authors = [{ name = "Google" }] classifiers = [ "Development Status :: 3 - Alpha", "Environment :: Console", "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -42,6 +42,7 @@ dependencies = [ "uvicorn>=0.34.0", ] description = "Sample implementation to exercise the Genkit multi server manager." 
+license = "Apache-2.0" name = "multi-server" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/ollama-hello/pyproject.toml b/py/samples/ollama-hello/pyproject.toml index 57a92ec67c..1a6fc009b3 100644 --- a/py/samples/ollama-hello/pyproject.toml +++ b/py/samples/ollama-hello/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "structlog>=25.2.0", ] description = "Ollama hello sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "ollama-hello" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/ollama-hello/src/main.py b/py/samples/ollama-hello/src/main.py index 98a1cd7f76..668de256de 100755 --- a/py/samples/ollama-hello/src/main.py +++ b/py/samples/ollama-hello/src/main.py @@ -33,9 +33,6 @@ """ -import asyncio -import json - import structlog from pydantic import BaseModel, Field diff --git a/py/samples/ollama-simple-embed/pyproject.toml b/py/samples/ollama-simple-embed/pyproject.toml index 28e8ee8d28..d48e850e20 100644 --- a/py/samples/ollama-simple-embed/pyproject.toml +++ b/py/samples/ollama-simple-embed/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -40,7 +39,7 @@ dependencies = [ "structlog>=25.2.0", ] description = "Ollama Simple Embed" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "ollama_simple_embed" readme = "README.md" requires-python = ">=3.10" diff --git 
a/py/samples/prompt_demo/prompts/hello.prompt b/py/samples/prompt_demo/prompts/hello.prompt index 1824e7e97b..790c214694 100644 --- a/py/samples/prompt_demo/prompts/hello.prompt +++ b/py/samples/prompt_demo/prompts/hello.prompt @@ -1,5 +1,5 @@ --- -model: googleai/gemini-2.5-flash +model: googleai/gemini-3-flash-preview input: schema: name: string diff --git a/py/samples/prompt_demo/pyproject.toml b/py/samples/prompt_demo/pyproject.toml index 37ef4a2eda..e5b36a1a72 100644 --- a/py/samples/prompt_demo/pyproject.toml +++ b/py/samples/prompt_demo/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -35,7 +34,7 @@ classifiers = [ ] dependencies = ["genkit", "structlog>=25.2.0", "genkit-plugin-google-genai"] description = "Genkit prompt demo" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "prompt-demo" requires-python = ">=3.10" version = "0.0.1" diff --git a/py/samples/prompt_demo/src/main.py b/py/samples/prompt_demo/src/main.py index 723821090e..2c6c9b4933 100755 --- a/py/samples/prompt_demo/src/main.py +++ b/py/samples/prompt_demo/src/main.py @@ -14,7 +14,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import asyncio from pathlib import Path import structlog @@ -29,7 +28,7 @@ current_dir = Path(__file__).resolve().parent prompts_path = current_dir.parent / 'prompts' -ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-2.5-flash', prompt_dir=prompts_path) +ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-3-flash-preview', prompt_dir=prompts_path) def my_helper(content, *_, **__): diff --git a/py/samples/short-n-long/pyproject.toml b/py/samples/short-n-long/pyproject.toml index fa46fd5242..1a6dace235 100644 --- a/py/samples/short-n-long/pyproject.toml +++ 
b/py/samples/short-n-long/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -43,7 +42,7 @@ dependencies = [ "uvloop>=0.21.0", ] description = "Short and long sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "short-n-long" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/short-n-long/src/main.py b/py/samples/short-n-long/src/main.py index 9179d0bed2..75f8f01cf5 100755 --- a/py/samples/short-n-long/src/main.py +++ b/py/samples/short-n-long/src/main.py @@ -69,7 +69,7 @@ ai = Genkit( plugins=[GoogleAI()], - model=googleai_name('gemini-2.5-flash'), + model=googleai_name('gemini-3-flash-preview'), ) @@ -103,7 +103,7 @@ async def simple_generate_with_tools_flow(value: int) -> str: The generated response with a function. """ response = await ai.generate( - model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_2_0_FLASH), + model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), messages=[ Message( role=Role.USER, @@ -140,7 +140,7 @@ async def simple_generate_with_interrupts(value: int) -> str: The generated response with a function. 
""" response1 = await ai.generate( - model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_2_0_FLASH), + model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), messages=[ Message( role=Role.USER, @@ -155,7 +155,7 @@ async def simple_generate_with_interrupts(value: int) -> str: tr = tool_response(response1.interrupts[0], 178) response = await ai.generate( - model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_2_0_FLASH), + model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), messages=response1.messages, tool_responses=[tr], tools=['gablorkenTool'], diff --git a/py/samples/tool-interrupts/pyproject.toml b/py/samples/tool-interrupts/pyproject.toml index c4e5e57bbc..20391edb09 100644 --- a/py/samples/tool-interrupts/pyproject.toml +++ b/py/samples/tool-interrupts/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -35,7 +34,7 @@ classifiers = [ ] dependencies = ["genkit", "genkit-plugin-google-genai", "pydantic>=2.10.5"] description = "Tool interrupts sample" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "tool-interrupts" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/tool-interrupts/src/main.py b/py/samples/tool-interrupts/src/main.py index f71f44b254..fdc14f8688 100755 --- a/py/samples/tool-interrupts/src/main.py +++ b/py/samples/tool-interrupts/src/main.py @@ -28,7 +28,7 @@ ai = Genkit( plugins=[GoogleAI()], - model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_2_0_FLASH), + model=googleai_name(gemini.GoogleAIGeminiVersion.GEMINI_3_FLASH_PREVIEW), ) diff --git a/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml b/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml index 
5330a5705b..9b15b76310 100644 --- a/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml +++ b/py/samples/vertex-ai-vector-search-bigquery/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -43,7 +42,7 @@ dependencies = [ "strenum>=0.4.15; python_version < '3.11'", ] description = "An example demonstrating the use Vector Search API with BigQuery retriever for Vertex AI" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "vertex-ai-vector-search-bigquery" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/vertex-ai-vector-search-firestore/pyproject.toml b/py/samples/vertex-ai-vector-search-firestore/pyproject.toml index 6ff3f349fd..99fd0c758d 100644 --- a/py/samples/vertex-ai-vector-search-firestore/pyproject.toml +++ b/py/samples/vertex-ai-vector-search-firestore/pyproject.toml @@ -22,7 +22,6 @@ classifiers = [ "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -43,7 +42,7 @@ dependencies = [ "strenum>=0.4.15; python_version < '3.11'", ] description = "An example demonstrating the use Vector Search API with Firestore retriever for Vertex AI" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "vertex-ai-vector-search-firestore" readme = "README.md" requires-python = ">=3.10" diff --git a/py/samples/xai-hello/pyproject.toml b/py/samples/xai-hello/pyproject.toml index 65b7aff42b..5fa7f84d4b 100644 --- a/py/samples/xai-hello/pyproject.toml +++ b/py/samples/xai-hello/pyproject.toml @@ -15,6 +15,7 @@ 
# SPDX-License-Identifier: Apache-2.0 [project] +authors = [{ name = "Google" }] dependencies = [ "genkit", "genkit-plugin-xai", diff --git a/py/tests/smoke/pyproject.toml b/py/tests/smoke/pyproject.toml index 4605d56264..9e31d9de6c 100644 --- a/py/tests/smoke/pyproject.toml +++ b/py/tests/smoke/pyproject.toml @@ -15,13 +15,13 @@ # SPDX-License-Identifier: Apache-2.0 [project] +authors = [{ name = "Google" }] classifiers = [ "Development Status :: 3 - Alpha", "Environment :: Console", "Environment :: Web Environment", "Intended Audience :: Developers", "Operating System :: OS Independent", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", @@ -42,7 +42,7 @@ dependencies = [ "strenum>=0.4.15; python_version < '3.11'", ] description = "Packaging smoke test" -license = { text = "Apache-2.0" } +license = "Apache-2.0" name = "smoke" readme = "README.md" requires-python = ">=3.10" diff --git a/py/uv.lock b/py/uv.lock index 560e88909a..cca42bd6bf 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -12,6 +12,7 @@ resolution-markers = [ members = [ "anthropic-hello", "compat-oai-hello", + "deepseek-hello", "dev-local-vectorstore-hello", "eval-demo", "firestore-retreiver", @@ -19,6 +20,7 @@ members = [ "genkit", "genkit-plugin-anthropic", "genkit-plugin-compat-oai", + "genkit-plugin-deepseek", "genkit-plugin-dev-local-vectorstore", "genkit-plugin-evaluators", "genkit-plugin-firebase", @@ -28,6 +30,7 @@ members = [ "genkit-plugin-ollama", "genkit-plugin-vertex-ai", "genkit-plugin-xai", + "genkit-plugins-mcp", "genkit-workspace", "google-genai-code-execution", "google-genai-context-caching", @@ -796,6 +799,7 @@ dependencies = [ ] sdist = { url = 
"https://files.pythonhosted.org/packages/13/1f/9fa001e74a1993a9cadd2333bb889e50c66327b8594ac538ab8a04f915b7/cryptography-45.0.3.tar.gz", hash = "sha256:ec21313dd335c51d7877baf2972569f40a4291b76a0ce51391523ae358d05899", size = 744738, upload-time = "2025-05-25T14:17:24.777Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/82/b2/2345dc595998caa6f68adf84e8f8b50d18e9fc4638d32b22ea8daedd4b7a/cryptography-45.0.3-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:7573d9eebaeceeb55285205dbbb8753ac1e962af3d9640791d12b36864065e71", size = 7056239, upload-time = "2025-05-25T14:16:12.22Z" }, { url = "https://files.pythonhosted.org/packages/71/3d/ac361649a0bfffc105e2298b720d8b862330a767dab27c06adc2ddbef96a/cryptography-45.0.3-cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d377dde61c5d67eb4311eace661c3efda46c62113ff56bf05e2d679e02aebb5b", size = 4205541, upload-time = "2025-05-25T14:16:14.333Z" }, { url = "https://files.pythonhosted.org/packages/70/3e/c02a043750494d5c445f769e9c9f67e550d65060e0bfce52d91c1362693d/cryptography-45.0.3-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fae1e637f527750811588e4582988932c222f8251f7b7ea93739acb624e1487f", size = 4433275, upload-time = "2025-05-25T14:16:16.421Z" }, { url = "https://files.pythonhosted.org/packages/40/7a/9af0bfd48784e80eef3eb6fd6fde96fe706b4fc156751ce1b2b965dada70/cryptography-45.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ca932e11218bcc9ef812aa497cdf669484870ecbcf2d99b765d6c27a86000942", size = 4209173, upload-time = "2025-05-25T14:16:18.163Z" }, @@ -805,6 +809,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/53/8a130e22c1e432b3c14896ec5eb7ac01fb53c6737e1d705df7e0efb647c6/cryptography-45.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c824c9281cb628015bfc3c59335163d4ca0540d49de4582d6c2637312907e4b1", size = 4466300, upload-time = "2025-05-25T14:16:26.768Z" }, { url = 
"https://files.pythonhosted.org/packages/ba/75/6bb6579688ef805fd16a053005fce93944cdade465fc92ef32bbc5c40681/cryptography-45.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5833bb4355cb377ebd880457663a972cd044e7f49585aee39245c0d592904578", size = 4332483, upload-time = "2025-05-25T14:16:28.316Z" }, { url = "https://files.pythonhosted.org/packages/2f/11/2538f4e1ce05c6c4f81f43c1ef2bd6de7ae5e24ee284460ff6c77e42ca77/cryptography-45.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:9bb5bf55dcb69f7067d80354d0a348368da907345a2c448b0babc4215ccd3497", size = 4573714, upload-time = "2025-05-25T14:16:30.474Z" }, + { url = "https://files.pythonhosted.org/packages/f5/bb/e86e9cf07f73a98d84a4084e8fd420b0e82330a901d9cac8149f994c3417/cryptography-45.0.3-cp311-abi3-win32.whl", hash = "sha256:3ad69eeb92a9de9421e1f6685e85a10fbcfb75c833b42cc9bc2ba9fb00da4710", size = 2934752, upload-time = "2025-05-25T14:16:32.204Z" }, + { url = "https://files.pythonhosted.org/packages/c7/75/063bc9ddc3d1c73e959054f1fc091b79572e716ef74d6caaa56e945b4af9/cryptography-45.0.3-cp311-abi3-win_amd64.whl", hash = "sha256:97787952246a77d77934d41b62fb1b6f3581d83f71b44796a4158d93b8f5c490", size = 3412465, upload-time = "2025-05-25T14:16:33.888Z" }, + { url = "https://files.pythonhosted.org/packages/71/9b/04ead6015229a9396890d7654ee35ef630860fb42dc9ff9ec27f72157952/cryptography-45.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:c92519d242703b675ccefd0f0562eb45e74d438e001f8ab52d628e885751fb06", size = 7031892, upload-time = "2025-05-25T14:16:36.214Z" }, { url = "https://files.pythonhosted.org/packages/46/c7/c7d05d0e133a09fc677b8a87953815c522697bdf025e5cac13ba419e7240/cryptography-45.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5edcb90da1843df85292ef3a313513766a78fbbb83f584a5a58fb001a5a9d57", size = 4196181, upload-time = "2025-05-25T14:16:37.934Z" }, { url = 
"https://files.pythonhosted.org/packages/08/7a/6ad3aa796b18a683657cef930a986fac0045417e2dc428fd336cfc45ba52/cryptography-45.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38deed72285c7ed699864f964a3f4cf11ab3fb38e8d39cfcd96710cd2b5bb716", size = 4423370, upload-time = "2025-05-25T14:16:39.502Z" }, { url = "https://files.pythonhosted.org/packages/4f/58/ec1461bfcb393525f597ac6a10a63938d18775b7803324072974b41a926b/cryptography-45.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:5555365a50efe1f486eed6ac7062c33b97ccef409f5970a0b6f205a7cfab59c8", size = 4197839, upload-time = "2025-05-25T14:16:41.322Z" }, @@ -814,14 +821,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/71/7a/e002d5ce624ed46dfc32abe1deff32190f3ac47ede911789ee936f5a4255/cryptography-45.0.3-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:57a6500d459e8035e813bd8b51b671977fb149a8c95ed814989da682314d0782", size = 4450308, upload-time = "2025-05-25T14:16:48.228Z" }, { url = "https://files.pythonhosted.org/packages/87/ad/3fbff9c28cf09b0a71e98af57d74f3662dea4a174b12acc493de00ea3f28/cryptography-45.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:f22af3c78abfbc7cbcdf2c55d23c3e022e1a462ee2481011d518c7fb9c9f3d65", size = 4325125, upload-time = "2025-05-25T14:16:49.844Z" }, { url = "https://files.pythonhosted.org/packages/f5/b4/51417d0cc01802304c1984d76e9592f15e4801abd44ef7ba657060520bf0/cryptography-45.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:232954730c362638544758a8160c4ee1b832dc011d2c41a306ad8f7cccc5bb0b", size = 4560038, upload-time = "2025-05-25T14:16:51.398Z" }, + { url = "https://files.pythonhosted.org/packages/80/38/d572f6482d45789a7202fb87d052deb7a7b136bf17473ebff33536727a2c/cryptography-45.0.3-cp37-abi3-win32.whl", hash = "sha256:cb6ab89421bc90e0422aca911c69044c2912fc3debb19bb3c1bfe28ee3dff6ab", size = 2924070, upload-time = "2025-05-25T14:16:53.472Z" }, + { url = 
"https://files.pythonhosted.org/packages/91/5a/61f39c0ff4443651cc64e626fa97ad3099249152039952be8f344d6b0c86/cryptography-45.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:d54ae41e6bd70ea23707843021c778f151ca258081586f0cfa31d936ae43d1b2", size = 3395005, upload-time = "2025-05-25T14:16:55.134Z" }, + { url = "https://files.pythonhosted.org/packages/1b/63/ce30cb7204e8440df2f0b251dc0464a26c55916610d1ba4aa912f838bcc8/cryptography-45.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:ed43d396f42028c1f47b5fec012e9e12631266e3825e95c00e3cf94d472dac49", size = 3578348, upload-time = "2025-05-25T14:16:56.792Z" }, { url = "https://files.pythonhosted.org/packages/45/0b/87556d3337f5e93c37fda0a0b5d3e7b4f23670777ce8820fce7962a7ed22/cryptography-45.0.3-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:fed5aaca1750e46db870874c9c273cd5182a9e9deb16f06f7bdffdb5c2bde4b9", size = 4142867, upload-time = "2025-05-25T14:16:58.459Z" }, { url = "https://files.pythonhosted.org/packages/72/ba/21356dd0bcb922b820211336e735989fe2cf0d8eaac206335a0906a5a38c/cryptography-45.0.3-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:00094838ecc7c6594171e8c8a9166124c1197b074cfca23645cee573910d76bc", size = 4385000, upload-time = "2025-05-25T14:17:00.656Z" }, { url = "https://files.pythonhosted.org/packages/2f/2b/71c78d18b804c317b66283be55e20329de5cd7e1aec28e4c5fbbe21fd046/cryptography-45.0.3-pp310-pypy310_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:92d5f428c1a0439b2040435a1d6bc1b26ebf0af88b093c3628913dd464d13fa1", size = 4144195, upload-time = "2025-05-25T14:17:02.782Z" }, { url = "https://files.pythonhosted.org/packages/55/3e/9f9b468ea779b4dbfef6af224804abd93fbcb2c48605d7443b44aea77979/cryptography-45.0.3-pp310-pypy310_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:ec64ee375b5aaa354b2b273c921144a660a511f9df8785e6d1c942967106438e", size = 4384540, upload-time = "2025-05-25T14:17:04.49Z" }, + { url = 
"https://files.pythonhosted.org/packages/97/f5/6e62d10cf29c50f8205c0dc9aec986dca40e8e3b41bf1a7878ea7b11e5ee/cryptography-45.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:71320fbefd05454ef2d457c481ba9a5b0e540f3753354fff6f780927c25d19b0", size = 3328796, upload-time = "2025-05-25T14:17:06.174Z" }, + { url = "https://files.pythonhosted.org/packages/e7/d4/58a246342093a66af8935d6aa59f790cbb4731adae3937b538d054bdc2f9/cryptography-45.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:edd6d51869beb7f0d472e902ef231a9b7689508e83880ea16ca3311a00bf5ce7", size = 3589802, upload-time = "2025-05-25T14:17:07.792Z" }, { url = "https://files.pythonhosted.org/packages/96/61/751ebea58c87b5be533c429f01996050a72c7283b59eee250275746632ea/cryptography-45.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:555e5e2d3a53b4fabeca32835878b2818b3f23966a4efb0d566689777c5a12c8", size = 4146964, upload-time = "2025-05-25T14:17:09.538Z" }, { url = "https://files.pythonhosted.org/packages/8d/01/28c90601b199964de383da0b740b5156f5d71a1da25e7194fdf793d373ef/cryptography-45.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:25286aacb947286620a31f78f2ed1a32cded7be5d8b729ba3fb2c988457639e4", size = 4388103, upload-time = "2025-05-25T14:17:11.978Z" }, { url = "https://files.pythonhosted.org/packages/3d/ec/cd892180b9e42897446ef35c62442f5b8b039c3d63a05f618aa87ec9ebb5/cryptography-45.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:050ce5209d5072472971e6efbfc8ec5a8f9a841de5a4db0ebd9c2e392cb81972", size = 4150031, upload-time = "2025-05-25T14:17:14.131Z" }, { url = "https://files.pythonhosted.org/packages/db/d4/22628c2dedd99289960a682439c6d3aa248dff5215123ead94ac2d82f3f5/cryptography-45.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:dc10ec1e9f21f33420cc05214989544727e776286c1c16697178978327b95c9c", size = 4387389, upload-time = "2025-05-25T14:17:17.303Z" }, + { url = 
"https://files.pythonhosted.org/packages/39/ec/ba3961abbf8ecb79a3586a4ff0ee08c9d7a9938b4312fb2ae9b63f48a8ba/cryptography-45.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:9eda14f049d7f09c2e8fb411dda17dd6b16a3c76a1de5e249188a32aeb92de19", size = 3337432, upload-time = "2025-05-25T14:17:19.507Z" }, ] [[package]] @@ -928,6 +941,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, ] +[[package]] +name = "deepseek-hello" +version = "0.1.0" +source = { editable = "samples/deepseek-hello" } +dependencies = [ + { name = "genkit" }, + { name = "genkit-plugin-deepseek" }, + { name = "pydantic" }, + { name = "structlog" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-deepseek", editable = "plugins/deepseek" }, + { name = "pydantic", specifier = ">=2.0.0" }, + { name = "structlog", specifier = ">=24.0.0" }, +] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1615,6 +1647,23 @@ requires-dist = [ { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.15" }, ] +[[package]] +name = "genkit-plugin-deepseek" +version = "0.1.0" +source = { editable = "plugins/deepseek" } +dependencies = [ + { name = "genkit" }, + { name = "genkit-plugin-compat-oai" }, + { name = "openai" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-compat-oai", editable = "plugins/compat-oai" }, + { name = "openai", specifier = ">=1.0.0" }, +] + [[package]] name = "genkit-plugin-dev-local-vectorstore" version = "0.4.0" @@ -1753,6 +1802,7 @@ version = "0.4.0" source = { editable = "plugins/vertex-ai" } dependencies = [ { name = "genkit" }, + { name = 
"genkit-plugin-compat-oai" }, { name = "google-cloud-aiplatform" }, { name = "google-cloud-bigquery" }, { name = "google-cloud-firestore" }, @@ -1764,6 +1814,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-compat-oai", editable = "plugins/compat-oai" }, { name = "google-cloud-aiplatform", specifier = ">=1.77.0" }, { name = "google-cloud-bigquery" }, { name = "google-cloud-firestore" }, @@ -1787,6 +1838,21 @@ requires-dist = [ { name = "xai-sdk", specifier = ">=0.0.1" }, ] +[[package]] +name = "genkit-plugins-mcp" +version = "0.1.0" +source = { editable = "plugins/mcp" } +dependencies = [ + { name = "genkit" }, + { name = "mcp" }, +] + +[package.metadata] +requires-dist = [ + { name = "genkit", editable = "packages/genkit" }, + { name = "mcp" }, +] + [[package]] name = "genkit-workspace" version = "0.1.0" @@ -1796,6 +1862,7 @@ dependencies = [ { name = "genkit" }, { name = "genkit-plugin-anthropic" }, { name = "genkit-plugin-compat-oai" }, + { name = "genkit-plugin-deepseek" }, { name = "genkit-plugin-dev-local-vectorstore" }, { name = "genkit-plugin-evaluators" }, { name = "genkit-plugin-firebase" }, @@ -1806,6 +1873,7 @@ dependencies = [ { name = "genkit-plugin-vertex-ai" }, { name = "genkit-plugin-xai" }, { name = "liccheck" }, + { name = "mcp" }, { name = "strenum", marker = "python_full_version < '3.11'" }, ] @@ -1830,8 +1898,8 @@ dev = [ { name = "twine" }, ] lint = [ - { name = "mypy" }, { name = "ruff" }, + { name = "ty" }, ] [package.metadata] @@ -1840,6 +1908,7 @@ requires-dist = [ { name = "genkit", editable = "packages/genkit" }, { name = "genkit-plugin-anthropic", editable = "plugins/anthropic" }, { name = "genkit-plugin-compat-oai", editable = "plugins/compat-oai" }, + { name = "genkit-plugin-deepseek", editable = "plugins/deepseek" }, { name = "genkit-plugin-dev-local-vectorstore", editable = "plugins/dev-local-vectorstore" }, { name = "genkit-plugin-evaluators", 
editable = "plugins/evaluators" }, { name = "genkit-plugin-firebase", editable = "plugins/firebase" }, @@ -1850,6 +1919,7 @@ requires-dist = [ { name = "genkit-plugin-vertex-ai", editable = "plugins/vertex-ai" }, { name = "genkit-plugin-xai", editable = "plugins/xai" }, { name = "liccheck", specifier = ">=0.9.2" }, + { name = "mcp", specifier = ">=1.25.0" }, { name = "strenum", marker = "python_full_version < '3.11'", specifier = ">=0.4.15" }, ] @@ -1874,8 +1944,8 @@ dev = [ { name = "twine", specifier = ">=6.1.0" }, ] lint = [ - { name = "mypy", specifier = ">=1.15" }, { name = "ruff", specifier = ">=0.9" }, + { name = "ty", specifier = ">=0.0.1" }, ] [[package]] @@ -2499,6 +2569,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] +[[package]] +name = "httpx-sse" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" }, +] + [[package]] name = "id" version = "1.5.0" @@ -3257,6 +3336,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = 
"sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, ] +[[package]] +name = "mcp" +version = "1.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "jsonschema" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyjwt", extra = ["crypto"] }, + { name = "python-multipart" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "sse-starlette" }, + { name = "starlette" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, + { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d5/2d/649d80a0ecf6a1f82632ca44bec21c0461a9d9fc8934d38cb5b319f2db5e/mcp-1.25.0.tar.gz", hash = "sha256:56310361ebf0364e2d438e5b45f7668cbb124e158bb358333cd06e49e83a6802", size = 605387, upload-time = "2025-12-19T10:19:56.985Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/fc/6dc7659c2ae5ddf280477011f4213a74f806862856b796ef08f028e664bf/mcp-1.25.0-py3-none-any.whl", hash = "sha256:b37c38144a666add0862614cc79ec276e97d72aa8ca26d622818d4e278b9721a", size = 233076, upload-time = "2025-12-19T10:19:55.416Z" }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -3311,12 +3415,14 @@ version = "0.1.0" source = { virtual = "samples/model-garden" } dependencies = [ { name = "genkit" }, + { name = "genkit-plugin-vertex-ai" }, { name = "pydantic" }, ] [package.metadata] requires-dist = [ { name = "genkit", editable = "packages/genkit" }, + { name = "genkit-plugin-vertex-ai", editable = "plugins/vertex-ai" }, { name = "pydantic", specifier = ">=2.10.5" }, ] @@ -3496,45 +3602,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/d1/3598d1e73385baaab427392856f915487db7aa10abadd436f8f2d3e3b0f9/multipart-1.2.1-py3-none-any.whl", hash = 
"sha256:c03dc203bc2e67f6b46a599467ae0d87cf71d7530504b2c1ff4a9ea21d8b8c8c", size = 13730, upload-time = "2024-11-29T08:45:44.557Z" }, ] -[[package]] -name = "mypy" -version = "1.16.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mypy-extensions" }, - { name = "pathspec" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d4/38/13c2f1abae94d5ea0354e146b95a1be9b2137a0d506728e0da037c4276f6/mypy-1.16.0.tar.gz", hash = "sha256:84b94283f817e2aa6350a14b4a8fb2a35a53c286f97c9d30f53b63620e7af8ab", size = 3323139, upload-time = "2025-05-29T13:46:12.532Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/64/5e/a0485f0608a3d67029d3d73cec209278b025e3493a3acfda3ef3a88540fd/mypy-1.16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7909541fef256527e5ee9c0a7e2aeed78b6cda72ba44298d1334fe7881b05c5c", size = 10967416, upload-time = "2025-05-29T13:34:17.783Z" }, - { url = "https://files.pythonhosted.org/packages/4b/53/5837c221f74c0d53a4bfc3003296f8179c3a2a7f336d7de7bbafbe96b688/mypy-1.16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e71d6f0090c2256c713ed3d52711d01859c82608b5d68d4fa01a3fe30df95571", size = 10087654, upload-time = "2025-05-29T13:32:37.878Z" }, - { url = "https://files.pythonhosted.org/packages/29/59/5fd2400352c3093bed4c09017fe671d26bc5bb7e6ef2d4bf85f2a2488104/mypy-1.16.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:936ccfdd749af4766be824268bfe22d1db9eb2f34a3ea1d00ffbe5b5265f5491", size = 11875192, upload-time = "2025-05-29T13:34:54.281Z" }, - { url = "https://files.pythonhosted.org/packages/ad/3e/4bfec74663a64c2012f3e278dbc29ffe82b121bc551758590d1b6449ec0c/mypy-1.16.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4086883a73166631307fdd330c4a9080ce24913d4f4c5ec596c601b3a4bdd777", size = 12612939, 
upload-time = "2025-05-29T13:33:14.766Z" }, - { url = "https://files.pythonhosted.org/packages/88/1f/fecbe3dcba4bf2ca34c26ca016383a9676711907f8db4da8354925cbb08f/mypy-1.16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:feec38097f71797da0231997e0de3a58108c51845399669ebc532c815f93866b", size = 12874719, upload-time = "2025-05-29T13:21:52.09Z" }, - { url = "https://files.pythonhosted.org/packages/f3/51/c2d280601cd816c43dfa512a759270d5a5ef638d7ac9bea9134c8305a12f/mypy-1.16.0-cp310-cp310-win_amd64.whl", hash = "sha256:09a8da6a0ee9a9770b8ff61b39c0bb07971cda90e7297f4213741b48a0cc8d93", size = 9487053, upload-time = "2025-05-29T13:33:29.797Z" }, - { url = "https://files.pythonhosted.org/packages/24/c4/ff2f79db7075c274fe85b5fff8797d29c6b61b8854c39e3b7feb556aa377/mypy-1.16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9f826aaa7ff8443bac6a494cf743f591488ea940dd360e7dd330e30dd772a5ab", size = 10884498, upload-time = "2025-05-29T13:18:54.066Z" }, - { url = "https://files.pythonhosted.org/packages/02/07/12198e83006235f10f6a7808917376b5d6240a2fd5dce740fe5d2ebf3247/mypy-1.16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:82d056e6faa508501af333a6af192c700b33e15865bda49611e3d7d8358ebea2", size = 10011755, upload-time = "2025-05-29T13:34:00.851Z" }, - { url = "https://files.pythonhosted.org/packages/f1/9b/5fd5801a72b5d6fb6ec0105ea1d0e01ab2d4971893076e558d4b6d6b5f80/mypy-1.16.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:089bedc02307c2548eb51f426e085546db1fa7dd87fbb7c9fa561575cf6eb1ff", size = 11800138, upload-time = "2025-05-29T13:32:55.082Z" }, - { url = "https://files.pythonhosted.org/packages/2e/81/a117441ea5dfc3746431e51d78a4aca569c677aa225bca2cc05a7c239b61/mypy-1.16.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6a2322896003ba66bbd1318c10d3afdfe24e78ef12ea10e2acd985e9d684a666", size = 12533156, upload-time = "2025-05-29T13:19:12.963Z" }, - { url = 
"https://files.pythonhosted.org/packages/3f/38/88ec57c6c86014d3f06251e00f397b5a7daa6888884d0abf187e4f5f587f/mypy-1.16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:021a68568082c5b36e977d54e8f1de978baf401a33884ffcea09bd8e88a98f4c", size = 12742426, upload-time = "2025-05-29T13:20:22.72Z" }, - { url = "https://files.pythonhosted.org/packages/bd/53/7e9d528433d56e6f6f77ccf24af6ce570986c2d98a5839e4c2009ef47283/mypy-1.16.0-cp311-cp311-win_amd64.whl", hash = "sha256:54066fed302d83bf5128632d05b4ec68412e1f03ef2c300434057d66866cea4b", size = 9478319, upload-time = "2025-05-29T13:21:17.582Z" }, - { url = "https://files.pythonhosted.org/packages/70/cf/158e5055e60ca2be23aec54a3010f89dcffd788732634b344fc9cb1e85a0/mypy-1.16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c5436d11e89a3ad16ce8afe752f0f373ae9620841c50883dc96f8b8805620b13", size = 11062927, upload-time = "2025-05-29T13:35:52.328Z" }, - { url = "https://files.pythonhosted.org/packages/94/34/cfff7a56be1609f5d10ef386342ce3494158e4d506516890142007e6472c/mypy-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f2622af30bf01d8fc36466231bdd203d120d7a599a6d88fb22bdcb9dbff84090", size = 10083082, upload-time = "2025-05-29T13:35:33.378Z" }, - { url = "https://files.pythonhosted.org/packages/b3/7f/7242062ec6288c33d8ad89574df87c3903d394870e5e6ba1699317a65075/mypy-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d045d33c284e10a038f5e29faca055b90eee87da3fc63b8889085744ebabb5a1", size = 11828306, upload-time = "2025-05-29T13:21:02.164Z" }, - { url = "https://files.pythonhosted.org/packages/6f/5f/b392f7b4f659f5b619ce5994c5c43caab3d80df2296ae54fa888b3d17f5a/mypy-1.16.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b4968f14f44c62e2ec4a038c8797a87315be8df7740dc3ee8d3bfe1c6bf5dba8", size = 12702764, upload-time = "2025-05-29T13:20:42.826Z" }, - { url = 
"https://files.pythonhosted.org/packages/9b/c0/7646ef3a00fa39ac9bc0938626d9ff29d19d733011be929cfea59d82d136/mypy-1.16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eb14a4a871bb8efb1e4a50360d4e3c8d6c601e7a31028a2c79f9bb659b63d730", size = 12896233, upload-time = "2025-05-29T13:18:37.446Z" }, - { url = "https://files.pythonhosted.org/packages/6d/38/52f4b808b3fef7f0ef840ee8ff6ce5b5d77381e65425758d515cdd4f5bb5/mypy-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:bd4e1ebe126152a7bbaa4daedd781c90c8f9643c79b9748caa270ad542f12bec", size = 9565547, upload-time = "2025-05-29T13:20:02.836Z" }, - { url = "https://files.pythonhosted.org/packages/97/9c/ca03bdbefbaa03b264b9318a98950a9c683e06472226b55472f96ebbc53d/mypy-1.16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a9e056237c89f1587a3be1a3a70a06a698d25e2479b9a2f57325ddaaffc3567b", size = 11059753, upload-time = "2025-05-29T13:18:18.167Z" }, - { url = "https://files.pythonhosted.org/packages/36/92/79a969b8302cfe316027c88f7dc6fee70129490a370b3f6eb11d777749d0/mypy-1.16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0b07e107affb9ee6ce1f342c07f51552d126c32cd62955f59a7db94a51ad12c0", size = 10073338, upload-time = "2025-05-29T13:19:48.079Z" }, - { url = "https://files.pythonhosted.org/packages/14/9b/a943f09319167da0552d5cd722104096a9c99270719b1afeea60d11610aa/mypy-1.16.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c6fb60cbd85dc65d4d63d37cb5c86f4e3a301ec605f606ae3a9173e5cf34997b", size = 11827764, upload-time = "2025-05-29T13:46:04.47Z" }, - { url = "https://files.pythonhosted.org/packages/ec/64/ff75e71c65a0cb6ee737287c7913ea155845a556c64144c65b811afdb9c7/mypy-1.16.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a7e32297a437cc915599e0578fa6bc68ae6a8dc059c9e009c628e1c47f91495d", size = 12701356, upload-time = "2025-05-29T13:35:13.553Z" }, - { url = 
"https://files.pythonhosted.org/packages/0a/ad/0e93c18987a1182c350f7a5fab70550852f9fabe30ecb63bfbe51b602074/mypy-1.16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:afe420c9380ccec31e744e8baff0d406c846683681025db3531b32db56962d52", size = 12900745, upload-time = "2025-05-29T13:17:24.409Z" }, - { url = "https://files.pythonhosted.org/packages/28/5d/036c278d7a013e97e33f08c047fe5583ab4f1fc47c9a49f985f1cdd2a2d7/mypy-1.16.0-cp313-cp313-win_amd64.whl", hash = "sha256:55f9076c6ce55dd3f8cd0c6fff26a008ca8e5131b89d5ba6d86bd3f47e736eeb", size = 9572200, upload-time = "2025-05-29T13:33:44.92Z" }, - { url = "https://files.pythonhosted.org/packages/99/a3/6ed10530dec8e0fdc890d81361260c9ef1f5e5c217ad8c9b21ecb2b8366b/mypy-1.16.0-py3-none-any.whl", hash = "sha256:29e1499864a3888bca5c1542f2d7232c6e586295183320caa95758fc84034031", size = 2265773, upload-time = "2025-05-29T13:35:18.762Z" }, -] - [[package]] name = "mypy-extensions" version = "1.1.0" @@ -4499,6 +4566,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/36/c7/cfc8e811f061c841d7990b0201912c3556bfeb99cdcb7ed24adc8d6f8704/pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51", size = 2145302, upload-time = "2025-11-04T13:43:46.64Z" }, ] +[[package]] +name = "pydantic-settings" +version = "2.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", 
hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, +] + [[package]] name = "pygments" version = "2.19.1" @@ -4508,6 +4589,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293, upload-time = "2025-01-06T17:26:25.553Z" }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, +] + +[package.optional-dependencies] +crypto = [ + { name = "cryptography" }, +] + [[package]] name = "pypdf" version = "6.5.0" @@ -4613,6 +4708,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = 
"sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, +] + [[package]] name = "python-json-logger" version = "3.3.0" @@ -4622,6 +4726,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/20/0f2523b9e50a8052bc6a8b732dfc8568abbdc42010aef03a2d750bdab3b2/python_json_logger-3.3.0-py3-none-any.whl", hash = "sha256:dd980fae8cffb24c13caf6e158d3d61c0d6d22342f932cb6e9deedab3d35eec7", size = 15163, upload-time = "2025-03-07T07:08:25.627Z" }, ] +[[package]] +name = "python-multipart" +version = "0.0.21" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/78/96/804520d0850c7db98e5ccb70282e29208723f0964e88ffd9d0da2f52ea09/python_multipart-0.0.21.tar.gz", hash = "sha256:7137ebd4d3bbf70ea1622998f902b97a29434a9e8dc40eb203bbcf7c2a2cba92", size = 37196, upload-time = "2025-12-17T09:24:22.446Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/76/03af049af4dcee5d27442f71b6924f01f3efb5d2bd34f23fcd563f2cc5f5/python_multipart-0.0.21-py3-none-any.whl", hash = "sha256:cf7a6713e01c87aa35387f4774e812c4361150938d20d232800f75ffcf266090", size = 24541, upload-time = "2025-12-17T09:24:21.153Z" }, +] + [[package]] name = "pywin32" version = "310" @@ -5453,6 +5566,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/b6/74e927715a285743351233f33ea3c684528a0d374d2e43ff9ce9585b73fe/twine-6.1.0-py3-none-any.whl", hash = "sha256:a47f973caf122930bf0fbbf17f80b83bc1602c9ce393c7845f289a3001dc5384", size = 40791, upload-time = "2025-01-21T18:45:24.584Z" }, ] +[[package]] +name = "ty" +version = "0.0.11" +source = { 
registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/45/5ae578480168d4b3c08cf8e5eac3caf8eb7acdb1a06a9bed7519564bd9b4/ty-0.0.11.tar.gz", hash = "sha256:ebcbc7d646847cb6610de1da4ffc849d8b800e29fd1e9ebb81ba8f3fbac88c25", size = 4920340, upload-time = "2026-01-09T21:06:01.592Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/34/b1d05cdcd01589a8d2e63011e0a1e24dcefdc2a09d024fee3e27755963f6/ty-0.0.11-py3-none-linux_armv6l.whl", hash = "sha256:68f0b8d07b0a2ea7ec63a08ba2624f853e4f9fa1a06fce47fb453fa279dead5a", size = 9521748, upload-time = "2026-01-09T21:06:13.221Z" }, + { url = "https://files.pythonhosted.org/packages/43/21/f52d93f4b3784b91bfbcabd01b84dc82128f3a9de178536bcf82968f3367/ty-0.0.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cbf82d7ef0618e9ae3cc3c37c33abcfa302c9b3e3b8ff11d71076f98481cb1a8", size = 9454903, upload-time = "2026-01-09T21:06:42.363Z" }, + { url = "https://files.pythonhosted.org/packages/ad/01/3a563dba8b1255e474c35e1c3810b7589e81ae8c41df401b6a37c8e2cde9/ty-0.0.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:121987c906e02264c3b511b95cb9f8a3cdd66f3283b8bbab678ca3525652e304", size = 8823417, upload-time = "2026-01-09T21:06:26.315Z" }, + { url = "https://files.pythonhosted.org/packages/6f/b1/99b87222c05d3a28fb7bbfb85df4efdde8cb6764a24c1b138f3a615283dd/ty-0.0.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:999390b6cc045fe5e1b3da1c2c9ae8e8c0def23b69455e7c9191ba9ffd747023", size = 9290785, upload-time = "2026-01-09T21:05:59.028Z" }, + { url = "https://files.pythonhosted.org/packages/3d/9f/598809a8fff2194f907ba6de07ac3d7b7788342592d8f8b98b1b50c2fb49/ty-0.0.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed504d78eb613c49be3c848f236b345b6c13dc6bcfc4b202790a60a97e1d8f35", size = 9359392, upload-time = "2026-01-09T21:06:37.459Z" }, + { url = 
"https://files.pythonhosted.org/packages/71/3e/aeea2a97b38f3dcd9f8224bf83609848efa4bc2f484085508165567daa7b/ty-0.0.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7fedc8b43cc8a9991e0034dd205f957a8380dd29bfce36f2a35b5d321636dfd9", size = 9852973, upload-time = "2026-01-09T21:06:21.245Z" }, + { url = "https://files.pythonhosted.org/packages/72/40/86173116995e38f954811a86339ac4c00a2d8058cc245d3e4903bc4a132c/ty-0.0.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0808bdfb7efe09881bf70249b85b0498fb8b75fbb036ce251c496c20adb10075", size = 10796113, upload-time = "2026-01-09T21:06:16.034Z" }, + { url = "https://files.pythonhosted.org/packages/69/71/97c92c401dacae9baa3696163ebe8371635ebf34ba9fda781110d0124857/ty-0.0.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:07185b3e38b18c562056dfbc35fb51d866f872977ea1ebcd64ca24a001b5b4f1", size = 10432137, upload-time = "2026-01-09T21:06:07.498Z" }, + { url = "https://files.pythonhosted.org/packages/18/10/9ab43f3cfc5f7792f6bc97620f54d0a0a81ef700be84ea7f6be330936a99/ty-0.0.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b5c72f1ada8eb5be984502a600f71d1a3099e12fb6f3c0607aaba2f86f0e9d80", size = 10240520, upload-time = "2026-01-09T21:06:34.823Z" }, + { url = "https://files.pythonhosted.org/packages/74/18/8dd4fe6df1fd66f3e83b4798eddb1d8482d9d9b105f25099b76703402ebb/ty-0.0.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25f88e8789072830348cb59b761d5ced70642ed5600673b4bf6a849af71eca8b", size = 9973340, upload-time = "2026-01-09T21:06:39.657Z" }, + { url = "https://files.pythonhosted.org/packages/e4/0b/fb2301450cf8f2d7164944d6e1e659cac9ec7021556cc173d54947cf8ef4/ty-0.0.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:f370e1047a62dcedcd06e2b27e1f0b16c7f8ea2361d9070fcbf0d0d69baaa192", size = 9262101, upload-time = "2026-01-09T21:06:28.989Z" }, + { url = 
"https://files.pythonhosted.org/packages/f7/8c/d6374af023541072dee1c8bcfe8242669363a670b7619e6fffcc7415a995/ty-0.0.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:52be34047ed6177bfcef9247459a767ec03d775714855e262bca1fb015895e8a", size = 9382756, upload-time = "2026-01-09T21:06:24.097Z" }, + { url = "https://files.pythonhosted.org/packages/0d/44/edd1e63ffa8d49d720c475c2c1c779084e5efe50493afdc261938705d10a/ty-0.0.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b9e5762ccb3778779378020b8d78f936b3f52ea83f18785319cceba3ae85d8e6", size = 9553944, upload-time = "2026-01-09T21:06:18.426Z" }, + { url = "https://files.pythonhosted.org/packages/35/cd/4afdb0d182d23d07ff287740c4954cc6dde5c3aed150ec3f2a1d72b00f71/ty-0.0.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e9334646ee3095e778e3dbc45fdb2bddfc16acc7804283830ad84991ece16dd7", size = 10060365, upload-time = "2026-01-09T21:06:45.083Z" }, + { url = "https://files.pythonhosted.org/packages/d1/94/a009ad9d8b359933cfea8721c689c0331189be28650d74dcc6add4d5bb09/ty-0.0.11-py3-none-win32.whl", hash = "sha256:44cfb7bb2d6784bd7ffe7b5d9ea90851d9c4723729c50b5f0732d4b9a2013cfc", size = 9040448, upload-time = "2026-01-09T21:06:32.241Z" }, + { url = "https://files.pythonhosted.org/packages/df/04/5a5dfd0aec0ea99ead1e824ee6e347fb623c464da7886aa1e3660fb0f36c/ty-0.0.11-py3-none-win_amd64.whl", hash = "sha256:1bb205db92715d4a13343bfd5b0c59ce8c0ca0daa34fb220ec9120fc66ccbda7", size = 9780112, upload-time = "2026-01-09T21:06:04.69Z" }, + { url = "https://files.pythonhosted.org/packages/ad/07/47d4fccd7bcf5eea1c634d518d6cb233f535a85d0b63fcd66815759e2fa0/ty-0.0.11-py3-none-win_arm64.whl", hash = "sha256:4688bd87b2dc5c85da277bda78daba14af2e66f3dda4d98f3604e3de75519eba", size = 9194038, upload-time = "2026-01-09T21:06:10.152Z" }, +] + [[package]] name = "typeguard" version = "4.4.2"