diff --git a/.gitignore b/.gitignore index a87d076799..9e61a53036 100644 --- a/.gitignore +++ b/.gitignore @@ -52,6 +52,8 @@ js/plugins/firebase/database-debug.log js/plugins/firebase/firestore-debug.log .firebaserc js/plugins/middleware/workspace/ +js/testapps/agents/.snapshots/ +js/testapps/agents/.snapshots-pruning/ # auto-generated /js/core/src/__codegen diff --git a/docs/agent-js.md b/docs/agent-js.md new file mode 100644 index 0000000000..1e5fbb48ca --- /dev/null +++ b/docs/agent-js.md @@ -0,0 +1,185 @@ +# Design Document: Genkit Agents in JS/TS + +**Status**: Implemented +**Authors**: Antigravity +**Target Package**: `@genkit-ai/ai` and `@genkit-ai/core` + +--- + +## 1. Objective + +This document proposes a unified **Agent** primitive for Genkit JS/TS, drawing architectural parity with the reference Go implementation while addressing modern agentic needs (State, History, and Artifact Persistence) for Tooling and the Genkit Dev UI. + +Agents act as a standard agent abstraction, allowing the framework to support multi-turn conversations and long-running agentic tasks with state persisted between independent runs. + +--- + +## 2. Background & Relationship to Existing APIs + +Genkit currently contains a beta `Session` class (`js/ai/src/session.ts`) and `SessionStore` interface. +**Decision**: We have decided to replace the existing beta `Session` and `SessionStore` APIs with the snapshot-based model described here to achieve full parity with the Go implementation and enable advanced tooling features. +- **Flow Integration**: Built on top of `defineBidiAction`, Agents natively support bidirectional streaming. +- **Durable Streaming**: Connects with Genkit's durable streaming protocol for robust reconnects over WebSockets. + +--- + +## 3. Core Schemas & Wire Protocol + +Agents enforce structured input/output payloads for predictable agent lifecycle management. These schemas match `genkit-tools/common/src/types/agent.ts` from the reference Go PR. + +### Session State +```ts +import { z } from 'zod'; +import { MessageSchema, PartSchema, ModelResponseChunkSchema } from './model-types.js'; + +export const ArtifactSchema = z.object({ + name: z.string().optional(), + parts: z.array(PartSchema), + metadata: z.record(z.any()).optional(), +}); + +export const SessionStateSchema = z.object({ + messages: z.array(MessageSchema).optional(), + custom: z.any().optional(), + artifacts: z.array(ArtifactSchema).optional(), +}); +``` + +### Wire Payloads +```ts +export const AgentInitSchema = z.object({ + snapshotId: z.string().optional(), + state: SessionStateSchema.optional(), +}); + +export const AgentInputSchema = z.object({ + messages: z.array(MessageSchema).optional(), + resume: z.object({ + respond: z.array(ToolResponsePartSchema).optional(), + restart: z.array(ToolRequestPartSchema).optional(), + }).optional(), +}); + +export const TurnEndSchema = z.object({ + snapshotId: z.string().optional(), +}); + +export const AgentStreamChunkSchema = z.object({ + modelChunk: ModelResponseChunkSchema.optional(), + status: z.any().optional(), + artifact: ArtifactSchema.optional(), + turnEnd: TurnEndSchema.optional(), +}); + +export const AgentOutputSchema = z.object({ + snapshotId: z.string().optional(), + state: SessionStateSchema.optional(), + message: MessageSchema.optional(), + artifacts: z.array(ArtifactSchema).optional(), +}); +``` + +--- + +## 4. Persistence & The Snapshot System + +To maintain state across environments, Genkit provides `SessionStore` abstractions for saving and loading point-in-time captures (`SessionSnapshot`). + +### Interfaces (Strongly Typed) +```ts +export interface SnapshotContext { + state: SessionState; + prevState?: SessionState; + turnIndex: number; + event: 'turnEnd' | 'invocationEnd'; +} + +export type SnapshotCallback = (ctx: SnapshotContext) => boolean; + +export interface SessionSnapshot { + snapshotId: string; + parentId?: string; + createdAt: string; + event: 'turnEnd' | 'invocationEnd'; + state: SessionState; +} + +export interface SessionStore { + getSnapshot(snapshotId: string): Promise | undefined>; + saveSnapshot(snapshot: SessionSnapshot): Promise; +} +``` + +--- + +## 5. SDK APIs + +### 5.1 `defineCustomAgent` +Allows programmatic declaration of agent logic with an injected `SessionRunner`. + +```ts +export function defineCustomAgent( + registry: Registry, + config: { + name: string; + description?: string; + store?: SessionStore; + snapshotCallback?: SnapshotCallback; + toClient?: { + messages?: (msgs: Message[]) => Message[]; + state?: (state: State) => Partial; + }; + }, + fn: AgentFn +): Agent; +``` + +### 5.2 `definePromptAgent` +Ergonomic shortcut for standard prompt-backed loop orchestration. Automatically manages history, tool restarts, and renders prompts. References a prompt defined separately via `definePrompt`. + +```ts +export function definePromptAgent( + registry: Registry, + config: { + promptName: string; + store?: SessionStore; + } +): Agent; +``` + +### 5.3 `defineAgent` +The most ergonomic API — combines `definePrompt` and `definePromptAgent` into a single flat config. The config accepts all `PromptConfig` fields (name, model, system, tools, etc.) plus agent-specific fields (`store`, `snapshotCallback`). + +```ts +export function defineAgent( + registry: Registry, + config: AgentConfig +): Agent; + +export interface AgentConfig extends PromptConfig { + store?: SessionStore; + snapshotCallback?: SnapshotCallback; +} +``` + +--- + +## 6. Tooling & Dev UI Integration + +- **Live State Playground**: Renders a continuous view of the accumulated `Artifacts` and `statePatch` streams. +- **Time-Travel Debugging**: Snapshots are tied directly to trace spans, enabling the developer to resume sessions from a past `snapshotId`. + +--- + +## 7. Execution & Verification Plan + +1. **Phase 1: Core Types & Schemas** + - Declare Zod schemas in `js/ai/src/agent.ts`. +2. **Phase 2: Context Runner & Wrappers** + - Construct the `SessionRunner` orchestrator. + - Integrate snapshot event callbacks into execution loops. +3. **Phase 3: Prompt Engine Hooks** + - Integrate prompt rendering into agent execution loops. +4. **Phase 4: Verification** + - Test state retention between single-turn `.run()` bounds. + - Simulate client disconnections over `.streamBidi()`. diff --git a/js/ai/src/agent.ts b/js/ai/src/agent.ts new file mode 100644 index 0000000000..6cea031e1d --- /dev/null +++ b/js/ai/src/agent.ts @@ -0,0 +1,1139 @@ +/** + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + GenkitError, + deepEqual, + defineAction, + defineBidiAction, + getContext, + run, + z, + type Action, + type ActionContext, + type ActionFnArg, + type BidiAction, +} from '@genkit-ai/core'; +import { Channel } from '@genkit-ai/core/async'; +import type { Registry } from '@genkit-ai/core/registry'; +import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; +import { generateStream } from './generate.js'; +import { + MessageData, + MessageSchema, + ModelResponseChunkSchema, +} from './model-types.js'; +import { + ToolRequestPartSchema, + ToolResponsePartSchema, + type ToolRequestPart, + type ToolResponsePart, +} from './parts.js'; +import { + definePrompt, + type PromptAction, + type PromptConfig, +} from './prompt.js'; +import { + Artifact, + ArtifactSchema, + InMemorySessionStore, + Session, + SessionSnapshot, + SessionState, + SessionStateSchema, + SessionStore, + SnapshotCallback, + runWithSession, + type SessionSnapshotInput, + type SessionStoreOptions, +} from './session.js'; + +/** + * Schema for initializing an agent turn. + */ +export const AgentInitSchema = z.object({ + snapshotId: z.string().optional(), + state: SessionStateSchema.optional(), +}); + +/** + * Initialization options for an agent turn. + */ +export interface AgentInit { + snapshotId?: string; + newSnapshotId?: string; + state?: SessionState; +} + +/** + * Schema for agent input messages and commands. + */ +export const AgentInputSchema = z.object({ + messages: z.array(MessageSchema).optional(), + /** Options for resuming an interrupted generation. */ + resume: z + .object({ + respond: z.array(ToolResponsePartSchema).optional(), + restart: z.array(ToolRequestPartSchema).optional(), + }) + .optional(), + detach: z.boolean().optional(), +}); + +/** + * Input received by an agent turn. + */ +export type AgentInput = z.infer; + +/** + * Schema identifying a turn termination event. + */ +export const TurnEndSchema = z.object({ + snapshotId: z.string().optional(), +}); + +/** + * Identifies a turn termination event. + */ +export type TurnEnd = z.infer; + +/** + * Schema for stream chunks emitted during agent execution. + */ +export const AgentStreamChunkSchema = z.object({ + modelChunk: ModelResponseChunkSchema.optional(), + status: z.any().optional(), + artifact: ArtifactSchema.optional(), + turnEnd: TurnEndSchema.optional(), +}); + +/** + * Streamed chunk emitted during agent execution. + * The `Stream` parameter types the `status` field for custom status payloads. + */ +export type AgentStreamChunk = Omit< + z.infer, + 'status' +> & { status?: Stream }; + +/** + * Schema for final results of an agent execution. + */ +export const AgentResultSchema = z.object({ + message: MessageSchema.optional(), + artifacts: z.array(ArtifactSchema).optional(), +}); + +/** + * Result returned upon completing an agent execution. + */ +export type AgentResult = z.infer; + +/** + * Schema for output returned at turn completion. + */ +export const AgentOutputSchema = z.object({ + snapshotId: z.string().optional(), + state: SessionStateSchema.optional(), + message: MessageSchema.optional(), + artifacts: z.array(ArtifactSchema).optional(), +}); + +/** + * Output returned at turn completion. + */ +export interface AgentOutput { + artifacts?: Artifact[]; + message?: MessageData; + snapshotId?: string; + state?: SessionState; +} + +/** + * Executor responsible for running turns over input streams and persisting state. + */ +export class SessionRunner { + readonly session: Session; + readonly inputCh: AsyncIterable; + turnIndex: number = 0; + public onEndTurn?: (snapshotId?: string) => void; + public onDetach?: (snapshotId: string) => void; + public newSnapshotId?: string; + private snapshotCallback?: SnapshotCallback; + private lastSnapshot?: SessionSnapshot; + + private lastSnapshotVersion: number = 0; + private store?: SessionStore; + public isDetached: boolean = false; + + constructor( + session: Session, + inputCh: AsyncIterable, + options?: { + snapshotCallback?: SnapshotCallback; + lastSnapshot?: SessionSnapshot; + store?: SessionStore; + onEndTurn?: (snapshotId?: string) => void; + onDetach?: (snapshotId: string) => void; + newSnapshotId?: string; + } + ) { + this.session = session; + this.inputCh = inputCh; + + this.snapshotCallback = options?.snapshotCallback; + this.lastSnapshot = options?.lastSnapshot; + this.store = options?.store; + this.onEndTurn = options?.onEndTurn; + this.onDetach = options?.onDetach; + this.newSnapshotId = options?.newSnapshotId; + } + + // ── Session delegate methods ──────────────────────────────────────── + // These forward to `this.session` so callers can write `sess.addMessages()` + // instead of the verbose `sess.session.addMessages()`. + + /** Returns a deep copy of the current session state. */ + getState(): SessionState { + return this.session.getState(); + } + + /** Retrieves all messages associated with the session. */ + getMessages(): MessageData[] { + return this.session.getMessages(); + } + + /** Appends messages to the session. */ + addMessages(messages: MessageData[]): void { + this.session.addMessages(messages); + } + + /** Overwrites the session messages. */ + setMessages(messages: MessageData[]): void { + this.session.setMessages(messages); + } + + /** Retrieves the custom state of the session. */ + getCustom(): State | undefined { + return this.session.getCustom(); + } + + /** Updates the custom state using a mutator function. */ + updateCustom(fn: (custom?: State) => State): void { + this.session.updateCustom(fn); + } + + /** Retrieves the list of artifacts generated during the session. */ + getArtifacts(): Artifact[] { + return this.session.getArtifacts(); + } + + /** Adds artifacts to the session, deduplicating by name. */ + addArtifacts(artifacts: Artifact[]): void { + this.session.addArtifacts(artifacts); + } + + /** + * Executes the flow handler against incoming input messages sequentially. + */ + async run(fn: (input: AgentInput) => Promise): Promise { + for await (const input of this.inputCh) { + if (input.messages) { + this.session.addMessages(input.messages); + } + + const turnSnapshotId = this.newSnapshotId; + this.newSnapshotId = undefined; + + try { + await run(`runTurn-${this.turnIndex + 1}`, input, async () => { + await fn(input); + + const snapshotId = await this.maybeSnapshot( + 'turnEnd', + 'done', + undefined, + turnSnapshotId + ); + try { + if (this.onEndTurn) { + this.onEndTurn(snapshotId); + } + } catch (e) { + // Stream was closed, absorb exception + } + return { + lastSnapshot: this.lastSnapshot, + }; + }); + this.turnIndex++; + } catch (e: any) { + const errStatus = e.status || 'INTERNAL'; + const errMessage = e.message || 'Internal failure'; + const errDetails = e.detail || e.details || e; + const snapshotId = await this.maybeSnapshot( + 'turnEnd', + 'failed', + { + status: errStatus, + message: errMessage, + details: errDetails, + }, + turnSnapshotId + ); + try { + if (this.onEndTurn) { + this.onEndTurn(snapshotId); + } + } catch (_) { + // Stream was closed, absorb exception + } + throw e; + } + } + } + + /** + * Evaluates whether to save a snapshot to the persistent store. + * + * Uses the mutator-based `saveSnapshot` to atomically check that the + * snapshot has not been concurrently aborted before writing — preventing + * a race where a "done" write could overwrite a concurrent "aborted" + * status. + */ + async maybeSnapshot( + event: 'turnEnd' | 'invocationEnd', + status?: 'pending' | 'done' | 'failed', + error?: { status: string; message: string; details?: any }, + snapshotId?: string + ): Promise { + if ( + !this.store || + (this.isDetached && snapshotId !== this.lastSnapshot?.snapshotId) + ) + return this.lastSnapshot?.snapshotId; + + const currentVersion = this.session.getVersion(); + if (currentVersion === this.lastSnapshotVersion && !status) { + return this.lastSnapshot?.snapshotId; + } + + const currentState = this.session.getState(); + const prevState = this.lastSnapshot ? this.lastSnapshot.state : undefined; + + if (this.snapshotCallback && !this.isDetached) { + if ( + !this.snapshotCallback({ + state: currentState as SessionState, + prevState: prevState as SessionState | undefined, + turnIndex: this.turnIndex, + event: event, + }) + ) { + return undefined; + } + } + + const snapshotInput: SessionSnapshotInput = { + ...(snapshotId || this.newSnapshotId + ? { snapshotId: (snapshotId || this.newSnapshotId)! } + : {}), + createdAt: new Date().toISOString(), + event: event, + state: currentState as SessionState, + parentId: this.lastSnapshot?.snapshotId, + status, + error, + }; + + const effectiveId = snapshotId || this.newSnapshotId; + + // Use the mutator-based saveSnapshot to atomically check the current + // status before writing. If the snapshot was concurrently aborted, + // the mutator returns null and the write is skipped. + const assignedId = await this.store.saveSnapshot( + effectiveId, + (current) => { + if (current?.status === 'aborted') { + return null; // Respect the abort — skip the write. + } + return snapshotInput; + }, + { context: getContext() } + ); + if (assignedId === null) { + // Snapshot was aborted concurrently; preserve the existing ID + // without overwriting. + return effectiveId; + } + + this.lastSnapshot = { ...snapshotInput, snapshotId: assignedId }; + this.lastSnapshotVersion = currentVersion; + + return assignedId; + } +} + +/** + * Optional transform applied to session state before it is exposed to the + * client (e.g. in `AgentOutput.state` or via `getSnapshotData`). This lets + * agents redact sensitive fields or reshape the state for the client. + */ +export type ClientStateTransform = ( + state: SessionState +) => SessionState; + +/** + * Function handler definition for custom agent actions. + */ +export type AgentFn = ( + sess: SessionRunner, + options: { + sendChunk: (chunk: AgentStreamChunk) => void; + abortSignal?: AbortSignal; + context?: ActionContext; + } +) => Promise; + +export type GetSnapshotDataAction = Action< + z.ZodString, + z.ZodType> +>; + +/** + * Represents a configured, registered Agent. + */ +export interface Agent + extends BidiAction< + typeof AgentInputSchema, + typeof AgentOutputSchema, + typeof AgentStreamChunkSchema, + typeof AgentInitSchema + > { + getSnapshotData( + snapshotId: string, + options?: SessionStoreOptions + ): Promise | undefined>; + + abort( + snapshotId: string, + options?: SessionStoreOptions + ): Promise; + + readonly getSnapshotDataAction: GetSnapshotDataAction; + readonly abortAgentAction: Action>; +} + +/** + * Registers a multi-turn custom agent action capable of maintaining persistent state. + * + * When `stateSchema` is provided the custom state is validated at load time + * (from a snapshot store or from the client-supplied `init.state`) and the + * JSON Schema representation is included in the action metadata so that + * tooling (e.g. the Dev UI) can inspect / validate the state shape. + */ +export function defineCustomAgent( + registry: Registry, + config: { + name: string; + description?: string; + stateSchema?: z.ZodType; + store?: SessionStore; + snapshotCallback?: SnapshotCallback; + clientStateTransform?: ClientStateTransform; + }, + fn: AgentFn +): Agent { + // Helper that applies the optional transform before exposing state to the + // client. When no transform is configured it returns the raw state. + const toClientState = ( + state: SessionState + ): SessionState | undefined => { + if (config.clientStateTransform) { + return config.clientStateTransform(state); + } + return state as SessionState; + }; + + // If a state schema was provided, pre-compute the JSON schema once so it + // can be embedded in metadata and reused for validation. + const stateJsonSchema = config.stateSchema + ? toJsonSchema({ schema: config.stateSchema }) + : undefined; + + /** + * Validates the `custom` field of a session state against the configured + * `stateSchema`. No-ops when no schema was provided. + */ + const validateCustomState = (custom: unknown, label: string): void => { + if (config.stateSchema && custom !== undefined) { + parseSchema(custom, { schema: config.stateSchema }); + } + }; + + const primaryAction = defineBidiAction( + registry, + { + name: config.name, + description: config.description, + actionType: 'agent', + inputSchema: AgentInputSchema, + outputSchema: AgentOutputSchema, + streamSchema: AgentStreamChunkSchema, + initSchema: AgentInitSchema, + metadata: { + agent: { + stateManagement: config.store ? 'server' : 'client', + abortable: !!config.store?.onSnapshotStateChange, + ...(stateJsonSchema && { stateSchema: stateJsonSchema }), + }, + }, + }, + async function* ( + arg: ActionFnArg + ) { + const init = arg.init; + const store = config.store || new InMemorySessionStore(); + + // Validate that the init strategy matches the agent's state management + // mode. Server-managed agents (with a store) expect a snapshotId; + // client-managed agents (no store) expect the full state blob. + if (init?.snapshotId && !config.store) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: + `Cannot use 'snapshotId' with agent '${config.name}': this agent ` + + `has no store configured (client-managed state). Send 'state' instead.`, + }); + } + if (init?.state && config.store) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: + `Cannot send 'state' to agent '${config.name}': this agent uses ` + + `a server-managed store. Send 'snapshotId' instead.`, + }); + } + + let session: Session; + + let snapshot: SessionSnapshot | undefined; + + if (init?.snapshotId) { + snapshot = await store.getSnapshot(init.snapshotId, { + context: getContext(), + }); + if (!snapshot) { + throw new Error(`Snapshot ${init.snapshotId} not found`); + } + validateCustomState( + snapshot.state?.custom, + `snapshot ${init.snapshotId}` + ); + session = new Session(snapshot.state as SessionState); + } else if (init?.state && !config.store) { + validateCustomState(init.state.custom, 'client-supplied init.state'); + session = new Session(init.state as SessionState); + } else { + session = new Session({ + custom: {} as State, + artifacts: [], + messages: [], + }); + } + + let detachedSnapshotId: string | undefined; + let resolveDetach: + | ((value: void | PromiseLike) => void) + | undefined; + let rejectDetach: ((reason: any) => void) | undefined; + const detachPromise = new Promise((resolve, reject) => { + resolveDetach = resolve; + rejectDetach = reject; + }); + + const abortController = new AbortController(); + let unsubscribe: any = undefined; + + let runner!: SessionRunner; + + // We construct an asynchronous proxy channel over the inputStream. + // This enables immediate interception of `detach: true` directives. Without this proxy, + // a backlog of pre-queued inputs would have to be resolved sequentially by the runner first. + const runnerInputChannel = new Channel(); + + (async () => { + try { + for await (const input of arg.inputStream) { + if (input.detach) { + if (!config.store) { + if (rejectDetach) { + rejectDetach( + new GenkitError({ + status: 'FAILED_PRECONDITION', + message: + 'Detach is only supported when a session store is provided.', + }) + ); + } + } else { + const turnSnapshotId = + runner.newSnapshotId || crypto.randomUUID(); + runner.newSnapshotId = turnSnapshotId; + await runner.maybeSnapshot( + 'turnEnd', + 'pending', + undefined, + turnSnapshotId + ); + runner.isDetached = true; + + if (runner.onDetach) { + runner.onDetach(turnSnapshotId); + } + } + // Only forward to runner if the input carries a payload beyond the + // detach directive; a detach-only message has no turn to process. + const hasPayload = !!( + input.messages?.length || + input.resume?.restart?.length || + input.resume?.respond?.length + ); + if (hasPayload) { + runnerInputChannel.send(input); + } + } else { + runnerInputChannel.send(input); + } + } + runnerInputChannel.close(); + } catch (e) { + runnerInputChannel.error(e); + } + })(); + + runner = new SessionRunner(session, runnerInputChannel, { + store, + snapshotCallback: config.snapshotCallback, + lastSnapshot: snapshot, + newSnapshotId: init?.newSnapshotId, + onDetach: (snapshotId) => { + detachedSnapshotId = snapshotId; + if (resolveDetach) { + resolveDetach(); + } + + if (store.onSnapshotStateChange) { + unsubscribe = store.onSnapshotStateChange( + snapshotId, + (snap) => { + if (snap.status === 'aborted') { + abortController.abort(); + if (unsubscribe) unsubscribe(); + } + }, + { context: getContext() } + ); + } + }, + + onEndTurn: (snapshotId) => { + if (!runner.isDetached) { + arg.sendChunk({ + turnEnd: { ...(config.store && { snapshotId }) }, + }); + } + }, + }); + + const sendArtifactChunk = (a: Artifact) => { + if (!runner.isDetached) { + arg.sendChunk({ artifact: a }); + } + }; + session.on('artifactAdded', sendArtifactChunk); + session.on('artifactUpdated', sendArtifactChunk); + + const sendChunk = (chunk: AgentStreamChunk) => { + if (!runner.isDetached) { + arg.sendChunk(chunk as AgentStreamChunk); + } + }; + + const flowPromise = (async () => { + try { + const result = await runWithSession(registry, session, () => + fn(runner, { + sendChunk, + abortSignal: abortController.signal, + context: getContext(), + }) + ); + const finalSnapshotId = await runner.maybeSnapshot('invocationEnd'); + return { result, finalSnapshotId }; + } finally { + if (unsubscribe) unsubscribe(); + session.off('artifactAdded', sendArtifactChunk); + session.off('artifactUpdated', sendArtifactChunk); + } + })(); + + // We race the background flow execution against the detach signal. + // If detachment is requested, we yield output metadata early, but allow + // the flow handler promise to continue its asynchronous completion. + const outcome = await Promise.race([ + flowPromise, + detachPromise.then(() => 'detached' as const), + ]); + + if (outcome === 'detached') { + return { + snapshotId: detachedSnapshotId!, + ...(!config.store && { state: toClientState(session.getState()) }), + }; + } + + const { result, finalSnapshotId } = outcome; + + return { + ...(result.artifacts?.length && { artifacts: result.artifacts }), + ...(result.message && { message: result.message }), + ...(config.store && { snapshotId: finalSnapshotId }), + ...(!config.store && { state: toClientState(session.getState()) }), + }; + } + ); + + // Helper that applies the clientStateTransform to a snapshot's state, + // returning a new snapshot object with the transformed state. + const toClientSnapshot = ( + snapshot: SessionSnapshot + ): SessionSnapshot => { + if (!config.clientStateTransform) { + return snapshot as SessionSnapshot; + } + return { + ...snapshot, + state: config.clientStateTransform(snapshot.state), + }; + }; + + const getSnapshotDataAction = defineAction( + registry, + { + name: config.name, + description: `Gets snapshot data for ${config.name} by snapshotId`, + actionType: 'agent-snapshot', + inputSchema: z.string(), + outputSchema: z.any(), // SessionSnapshot Schema + }, + async (snapshotId) => { + if (!config.store) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: `getSnapshotData requires a persistent store. Provide a 'store' when defining '${config.name}'.`, + }); + } + const snapshot = await config.store.getSnapshot(snapshotId, { + context: getContext(), + }); + return snapshot ? toClientSnapshot(snapshot) : undefined; + } + ); + + const abortAgentAction = defineAction( + registry, + { + name: config.name, + description: `Aborts ${config.name} agent by snapshotId. Returns the previous status of the snapshot before it was set to 'aborted', or undefined if the snapshot was not found.`, + actionType: 'agent-abort', + inputSchema: z.string(), + outputSchema: z.string().optional(), + }, + async (snapshotId) => { + if (!config.store) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: `abort requires a persistent store. Provide a 'store' when defining '${config.name}'.`, + }); + } + let previousStatus: SessionSnapshot['status'] | undefined; + await config.store.saveSnapshot( + snapshotId, + (current) => { + if (!current) return null; + previousStatus = current.status; + if ( + current.status === 'done' || + current.status === 'failed' || + current.status === 'aborted' + ) { + return null; // Already terminal — don't override. + } + return { ...current, status: 'aborted' }; + }, + { context: getContext() } + ); + return previousStatus; + } + ); + + const composite = Object.assign(primaryAction, { + getSnapshotData: async ( + snapshotId: string, + options?: SessionStoreOptions + ) => { + if (!config.store) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: `getSnapshotData requires a persistent store. Provide a 'store' when defining '${config.name}'.`, + }); + } + const snapshot = await config.store.getSnapshot(snapshotId, options); + return snapshot ? toClientSnapshot(snapshot) : undefined; + }, + abort: async (snapshotId: string, options?: SessionStoreOptions) => { + if (!config.store) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: `abort requires a persistent store. Provide a 'store' when defining '${config.name}'.`, + }); + } + let previousStatus: SessionSnapshot['status'] | undefined; + await config.store.saveSnapshot( + snapshotId, + (current) => { + if (!current) return null; + previousStatus = current.status; + if ( + current.status === 'done' || + current.status === 'failed' || + current.status === 'aborted' + ) { + return null; // Already terminal — don't override. + } + return { ...current, status: 'aborted' }; + }, + options + ); + return previousStatus; + }, + getSnapshotDataAction: + getSnapshotDataAction as unknown as GetSnapshotDataAction, + abortAgentAction: abortAgentAction as unknown as Action< + z.ZodString, + z.ZodType + >, + }); + + return composite as unknown as Agent; +} + +/** + * Registers an agent from an existing PromptAction. + */ +export function definePromptAgent( + registry: Registry, + config: { + promptName: string; + stateSchema?: z.ZodType; + store?: SessionStore; + snapshotCallback?: SnapshotCallback; + clientStateTransform?: ClientStateTransform; + } +) { + let cachedPromptAction: PromptAction | undefined; + + const fn: AgentFn = async ( + sess, + { sendChunk, abortSignal } + ) => { + await sess.run(async (input) => { + const promptInput = {}; + + if (!cachedPromptAction) { + cachedPromptAction = (await registry.lookupAction( + `/prompt/${config.promptName}` + )) as PromptAction; + if (!cachedPromptAction) { + throw new Error( + `Prompt '${config.promptName}' not found. Ensure it is defined before the agent is invoked.` + ); + } + } + + const historyTag = '_genkit_history'; + const promptTag = 'agentPreamble'; + + // Tag every history message so we can identify them after render. + const history = (sess.getMessages() || []).map((m) => ({ + ...m, + metadata: { ...m.metadata, [historyTag]: true }, + })); + + // Let the prompt control where history is placed (e.g. dotprompt + // {{history}}). When the prompt has no explicit `messages` config + // the render helper simply appends history after system/user. + const genOpts = await cachedPromptAction.__executablePrompt.render( + promptInput as unknown as z.ZodTypeAny, + { messages: history } + ); + + // After render: tag everything that is NOT history as a prompt + // message so we can strip it after generation. Also strip the + // internal history tag — it is an implementation detail that + // should not leak to the model. + if (genOpts.messages) { + genOpts.messages = genOpts.messages.map((m) => { + if (m.metadata?.[historyTag]) { + // Strip the history tag before sending to the model. + const { [historyTag]: _, ...restMeta } = m.metadata!; + return { + ...m, + metadata: Object.keys(restMeta).length ? restMeta : undefined, + }; + } + return { ...m, metadata: { ...m.metadata, [promptTag]: true } }; + }); + } + + if (input.resume) { + // Safety: validate that every restart/respond entry references + // a tool request that actually exists in the session history. + // For restarts, also verify that the input has not been tampered with. + validateResumeAgainstHistory(input.resume, sess.getMessages()); + + genOpts.resume = { + ...(input.resume.restart?.length && { + restart: input.resume.restart as ToolRequestPart[], + }), + ...(input.resume.respond?.length && { + respond: input.resume.respond as ToolResponsePart[], + }), + }; + } + + const result = generateStream(registry, { ...genOpts, abortSignal }); + + for await (const chunk of result.stream) { + sendChunk({ modelChunk: chunk }); + } + + const res = await result.response; + + // Keep everything that is NOT a prompt-template message: + // • history messages (clean — history tag was stripped before generate) + // • new messages from tool loops (untagged) + // • model response + if (res.request?.messages) { + const msgs = res.request.messages.filter( + (m) => !m.metadata?.[promptTag] + ); + if (res.message) { + msgs.push(res.message); + } + sess.setMessages(msgs); + } else if (res.message) { + sess.addMessages([res.message]); + } + + if (res.finishReason === 'interrupted') { + const parts = + res.message?.content?.filter((p) => !!p.toolRequest) || []; + if (parts.length > 0) { + sendChunk({ + modelChunk: { + role: 'tool', + content: parts, + }, + }); + } + } + }); + + const msgs = sess.getMessages(); + return { + artifacts: sess.getArtifacts(), + message: msgs.length > 0 ? msgs[msgs.length - 1] : undefined, + }; + }; + + return defineCustomAgent( + registry, + { + name: config.promptName, + stateSchema: config.stateSchema, + store: config.store, + snapshotCallback: config.snapshotCallback, + clientStateTransform: config.clientStateTransform, + }, + fn + ); +} + +// --------------------------------------------------------------------------- +// Resume validation — ensure restart/respond entries match session history +// --------------------------------------------------------------------------- + +/** + * Validates that every `resume.restart` and `resume.respond` entry references + * a tool request that actually exists in the session history. + * + * For **restart** entries, also validates that the `input` has not been modified + * compared to the original tool request — preventing a malicious client from + * forging tool inputs. + * + * For **respond** entries, validates that a matching tool request (by name + ref) + * exists in history. + * + * Searches the **entire history** (all model messages), not just the last one. + */ +export function validateResumeAgainstHistory( + resume: { + restart?: Array<{ + toolRequest: { name: string; ref?: string; input?: unknown }; + metadata?: Record; + }>; + respond?: Array<{ + toolResponse: { name: string; ref?: string; output?: unknown }; + }>; + }, + history: MessageData[] +): void { + // Collect all tool requests from all model messages in the stored history. + const allToolRequests: Array<{ + name: string; + ref?: string; + input?: unknown; + }> = []; + for (const msg of history) { + if (msg.role === 'model') { + for (const part of msg.content) { + if (part.toolRequest) { + allToolRequests.push(part.toolRequest); + } + } + } + } + + // Validate restart entries: name + ref must exist AND input must match exactly + for (const restart of resume.restart || []) { + const { name, ref, input } = restart.toolRequest; + const match = allToolRequests.find( + (tr) => tr.name === name && tr.ref === ref + ); + if (!match) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: + `resume.restart references tool '${name}'` + + (ref ? ` (ref: ${ref})` : '') + + ` which was not found in session history.`, + }); + } + if (!deepEqual(input, match.input)) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: + `resume.restart for tool '${name}'` + + (ref ? ` (ref: ${ref})` : '') + + ` has modified inputs that do not match the original tool request ` + + `in session history. Restart inputs must exactly match the ` + + `interrupted tool request.`, + }); + } + } + + // Validate respond entries: name + ref must match a tool request in history + for (const respond of resume.respond || []) { + const { name, ref } = respond.toolResponse; + const match = allToolRequests.find( + (tr) => tr.name === name && tr.ref === ref + ); + if (!match) { + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: + `resume.respond references tool '${name}'` + + (ref ? ` (ref: ${ref})` : '') + + ` which was not found in session history.`, + }); + } + } +} + +// --------------------------------------------------------------------------- +// defineAgent — shortcut that combines definePrompt + definePromptAgent +// --------------------------------------------------------------------------- + +/** + * Configuration for `defineAgent`, which combines prompt definition and agent + * registration into a single call. + */ +export interface AgentConfig extends PromptConfig { + /** + * Optional Zod schema describing the shape of the custom session state. + * + * When provided: + * - The `State` type is inferred from the schema (no explicit generic needed). + * - The JSON Schema is included in action metadata (`metadata.agent.stateSchema`) + * so the Dev UI and other tooling can inspect / validate the state. + * - Custom state is validated at load time (from a snapshot store or from the + * client-supplied `init.state`). + */ + stateSchema?: z.ZodType; + store?: SessionStore; + snapshotCallback?: SnapshotCallback; + clientStateTransform?: ClientStateTransform; +} + +/** + * Defines and registers an agent by creating a prompt and wiring it into a + * multi-turn agent in one step. + * + * This is a convenience shortcut for: + * ```ts + * definePrompt(registry, promptConfig); + * definePromptAgent(registry, { promptName: promptConfig.name, ... }); + * ``` + */ +export function defineAgent( + registry: Registry, + config: AgentConfig +): Agent { + // Extract agent-specific fields from the combined config; the rest is + // forwarded to definePrompt. + const { + stateSchema, + store, + snapshotCallback, + clientStateTransform, + ...promptConfig + } = config; + + // Register the prompt. + definePrompt(registry, promptConfig); + + // Wire it into a prompt agent. + return definePromptAgent(registry, { + promptName: promptConfig.name, + stateSchema, + store, + snapshotCallback, + clientStateTransform, + }); +} diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index edf234a368..6f12ed9726 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -14,6 +14,7 @@ * limitations under the License. */ +export * from './agent.js'; export { cancelOperation } from './cancel-operation.js'; export { checkOperation } from './check-operation.js'; export { Document, DocumentDataSchema, type DocumentData } from './document.js'; diff --git a/js/ai/src/prompt.ts b/js/ai/src/prompt.ts index 894de3168c..363ba22241 100644 --- a/js/ai/src/prompt.ts +++ b/js/ai/src/prompt.ts @@ -309,7 +309,8 @@ function definePromptAsync< let docs: DocumentData[] | undefined; if (typeof resolvedOptions.docs === 'function') { docs = await resolvedOptions.docs(input, { - state: session?.state, + state: session?.getCustom(), + context: renderOptions?.context || getContext() || {}, }); } else { @@ -542,7 +543,7 @@ async function renderSystemPrompt< role: 'system', content: normalizeParts( await options.system(input, { - state: session?.state, + state: session?.getCustom(), context: renderOptions?.context || getContext() || {}, }) ), @@ -588,7 +589,7 @@ async function renderMessages< if (typeof options.messages === 'function') { messages.push( ...(await options.messages(input, { - state: session?.state, + state: session?.getCustom(), context: renderOptions?.context || getContext() || {}, history: renderOptions?.messages, })) @@ -604,7 +605,7 @@ async function renderMessages< input, context: { ...(renderOptions?.context || getContext()), - state: session?.state, + state: session?.getCustom(), }, messages: renderOptions?.messages?.map((m) => Message.parseData(m) @@ -643,7 +644,7 @@ async function renderUserPrompt< role: 'user', content: normalizeParts( await options.prompt(input, { - state: session?.state, + state: session?.getCustom(), context: renderOptions?.context || getContext() || {}, }) ), @@ -744,7 +745,7 @@ async function renderDotpromptToParts< input, context: { ...(renderOptions?.context || getContext()), - state: session?.state, + state: session?.getCustom(), }, }); if (renderred.messages.length !== 1) { diff --git a/js/ai/src/session.ts b/js/ai/src/session.ts index 6a82f2cbfe..10ac5e0ad5 100644 --- a/js/ai/src/session.ts +++ b/js/ai/src/session.ts @@ -1,5 +1,5 @@ /** - * Copyright 2024 Google LLC + * Copyright 2026 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,173 +14,577 @@ * limitations under the License. */ -import { getAsyncContext, type z } from '@genkit-ai/core'; +import { getAsyncContext, z, type ActionContext } from '@genkit-ai/core'; +import { EventEmitter } from '@genkit-ai/core/async'; import type { Registry } from '@genkit-ai/core/registry'; -import { v4 as uuidv4 } from 'uuid'; -import { type GenerateOptions, type MessageData } from './index.js'; - -export type BaseGenerateOptions< - O extends z.ZodTypeAny = z.ZodTypeAny, - CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, -> = Omit, 'prompt'>; - -export interface SessionOptions { - /** Session store implementation for persisting the session state. */ - store?: SessionStore; - /** Initial state of the session. */ - initialState?: S; - /** Custom session Id. */ - sessionId?: string; +import * as fs from 'fs'; +import * as fsp from 'fs/promises'; +import * as path from 'path'; +import { MessageData, MessageSchema } from './model-types.js'; + +import { PartSchema } from './model-types.js'; + +/** + * Schema for tracking persistent artifacts generated during a session turn. + */ +export const ArtifactSchema = z.object({ + name: z.string().optional(), + parts: z.array(PartSchema), + metadata: z.record(z.any()).optional(), +}); + +/** + * Artifact generated during a session turn. + */ +export type Artifact = z.infer; + +/** + * Events signifying a session snapshot persistence point. + */ +export const SnapshotEventSchema = z.enum(['turnEnd', 'invocationEnd']); + +/** + * Event signifying a session snapshot persistence point. + */ +export type SnapshotEvent = z.infer; + +/** + * Schema for session execution state. + */ +export const SessionStateSchema = z.object({ + messages: z.array(MessageSchema).optional(), + custom: z.any().optional(), + artifacts: z.array(ArtifactSchema).optional(), +}); + +/** + * State persisted for a session across turns. + */ +export interface SessionState { + messages?: MessageData[]; + custom?: S; + artifacts?: Artifact[]; +} + +/** + * The execution context provided to a snapshot callback. + */ +export interface SnapshotContext { + state: SessionState; + prevState?: SessionState; + turnIndex: number; + event: 'turnEnd' | 'invocationEnd'; +} + +/** + * Callback triggered before a snapshot is saved. Return false to reject persistence. + */ +export type SnapshotCallback = ( + ctx: SnapshotContext +) => boolean; + +/** + * Saved snapshot of a session's state at a given event point. + */ +export interface SessionSnapshot { + snapshotId: string; + parentId?: string; + createdAt: string; + event: 'turnEnd' | 'invocationEnd'; + state: SessionState; + status?: 'pending' | 'done' | 'failed' | 'aborted'; + + error?: { + status: string; + message: string; + details?: any; + }; } /** - * Session encapsulates a statful execution environment for chat. - * Chat session executed within a session in this environment will have acesss to - * session session convesation history. + * Input type for {@link SessionStore.saveSnapshot}. * - * ```ts - * const ai = genkit({...}); - * const chat = ai.chat(); // create a Session - * let response = await chat.send('hi'); // session/history aware conversation - * response = await chat.send('tell me a story'); - * ``` - */ -export class Session { - readonly id: string; - private sessionData?: SessionData; - private store: SessionStore; + * Identical to {@link SessionSnapshot} except that `snapshotId` is optional. + * When omitted the store is responsible for assigning a new identifier + * (enabling stores to encode grouping or routing information in the ID). + * When provided the store performs an upsert — updating the existing snapshot. + */ +export type SessionSnapshotInput = Omit< + SessionSnapshot, + 'snapshotId' +> & { + snapshotId?: string; +}; - constructor( - readonly registry: Registry, - options?: { - id?: string; - stateSchema?: S; - sessionData?: SessionData; - store?: SessionStore; - } - ) { - this.id = options?.id ?? uuidv4(); - this.sessionData = options?.sessionData ?? { - id: this.id, - }; - if (!this.sessionData) { - this.sessionData = { id: this.id }; - } - if (!this.sessionData.threads) { - this.sessionData!.threads = {}; - } - this.store = options?.store ?? new InMemorySessionStore(); +/** + * Options provided to the session store methods. + */ +export interface SessionStoreOptions { + context?: ActionContext; +} + +/** + * A function that receives the current snapshot and returns the updated + * snapshot to persist. + * + * - Return the mutated snapshot to save it. + * - Return `null` to silently skip the update (no-op). + * - Throw to abort with an error (e.g. precondition failure). + */ +export type SnapshotMutator = ( + current: SessionSnapshot | undefined +) => SessionSnapshotInput | null; + +/** + * Interface for persistent session snapshot storage. + */ +export interface SessionStore { + getSnapshot( + snapshotId: string, + options?: SessionStoreOptions + ): Promise | undefined>; + + /** + * Atomically reads the current snapshot (if `snapshotId` is provided), + * passes it to `mutator`, and persists the result. + * + * - When `snapshotId` is provided the store reads the existing snapshot + * and passes it to the mutator. The mutator can inspect the current + * state (e.g. to check for concurrent status changes) and return the + * updated snapshot to save, or `null` to skip the write. + * - When `snapshotId` is `undefined` the store passes `undefined` to + * the mutator (signaling a new snapshot). The store assigns a new + * identifier. + * + * Implementations should ensure the read→mutate→write cycle is atomic + * to prevent race conditions (e.g. a "done" write overwriting a + * concurrent "aborted" status). + * + * The mutator can: + * + * - Return a snapshot to save it. + * - Return `null` to silently skip the write. + * - Throw to abort with an error. + * + * @returns The `snapshotId` that was used, or `null` when the mutator + * returned `null`. + */ + saveSnapshot( + snapshotId: string | undefined, + mutator: SnapshotMutator, + options?: SessionStoreOptions + ): Promise; + + onSnapshotStateChange?( + snapshotId: string, + callback: (snapshot: SessionSnapshot) => void, + options?: SessionStoreOptions + ): void | (() => void); +} + +/** + * State manager for a session turn, tracking messages, custom state, and artifacts. + */ +export class Session extends EventEmitter { + private state: SessionState; + private version: number = 0; + + constructor(initialState: SessionState) { + super(); + this.state = initialState; } - get state(): S | undefined { - return this.sessionData!.state; + /** + * Returns a deep copy of the current session state. + */ + getState(): SessionState { + return structuredClone(this.state); } /** - * Update session state data. + * Retrieves all messages associated with the session. */ - async updateState(data: S): Promise { - let sessionData = this.sessionData; - if (!sessionData) { - sessionData = {} as SessionData; - } - sessionData.state = data; - this.sessionData = sessionData; + getMessages(): MessageData[] { + return this.state.messages || []; + } - await this.store.save(this.id, sessionData); + /** + * Appends a list of messages to the session. + */ + addMessages(messages: MessageData[]) { + this.state.messages = [...(this.state.messages || []), ...messages]; + this.version++; } /** - * Update messages for a given thread. + * Overwrites the session messages. */ - async updateMessages(thread: string, messages: MessageData[]): Promise { - let sessionData = this.sessionData; - if (!sessionData) { - sessionData = {} as SessionData; - } - if (!sessionData.threads) { - sessionData.threads = {}; - } - sessionData.threads[thread] = messages.map((m: any) => - m.toJSON ? m.toJSON() : m - ); - this.sessionData = sessionData; + setMessages(messages: MessageData[]) { + this.state.messages = messages; + this.version++; + } - await this.store.save(this.id, sessionData); + /** + * Retrieves the custom state of the session. + */ + getCustom(): S | undefined { + return this.state.custom; } /** - * Create a chat session with the provided options. - * - * ```ts + * Updates the custom state of the session using a mutator function. + */ + updateCustom(fn: (custom?: S) => S) { + this.state.custom = fn(this.state.custom); + this.version++; + } + /** + * Retrieves the list of artifacts generated during the session. + */ + getArtifacts(): Artifact[] { + return this.state.artifacts || []; + } + + /** + * Adds artifacts to the session, deduplicating items by name. + * Emits 'artifactAdded' for new artifacts and 'artifactUpdated' for replacements. + */ + addArtifacts(artifacts: Artifact[]) { + const existing = this.state.artifacts || []; + const added: Artifact[] = []; + const updated: Artifact[] = []; + + for (const a of artifacts) { + if (a.name) { + const idx = existing.findIndex((e) => e.name === a.name); + if (idx >= 0) { + existing[idx] = a; + updated.push(a); + continue; + } + } + existing.push(a); + added.push(a); + } + this.state.artifacts = existing; + if (added.length + updated.length > 0) { + this.version++; + } + for (const a of added) { + this.emit('artifactAdded', a); + } + for (const a of updated) { + this.emit('artifactUpdated', a); + } + } /** - * Executes provided function within this session context allowing calling - * `ai.currentSession().state` + * Runs the provided function inside the session's context. */ run(fn: () => O) { - return runWithSession(this.registry, this, fn); + return getAsyncContext().run('ai.session', this, fn); } - toJSON() { - return this.sessionData; + /** + * Gets the current mutation version of the session state. + */ + getVersion(): number { + return this.version; } } -export interface SessionData { - id: string; - state?: S; - threads?: Record; -} +/** + * In-memory implementation of persistent Session Store. + */ +export class InMemorySessionStore implements SessionStore { + private snapshots = new Map>(); + private listeners = new Map< + string, + Array<(snapshot: SessionSnapshot) => void> + >(); + + async getSnapshot( + snapshotId: string, + options?: SessionStoreOptions + ): Promise | undefined> { + const snap = this.snapshots.get(snapshotId); + if (!snap) return undefined; + return structuredClone(snap); + } + + async saveSnapshot( + snapshotId: string | undefined, + mutator: SnapshotMutator, + options?: SessionStoreOptions + ): Promise { + const current = snapshotId ? this.snapshots.get(snapshotId) : undefined; + const result = mutator(current ? structuredClone(current) : undefined); + if (result === null) return null; + + const id = snapshotId || result.snapshotId || crypto.randomUUID(); + const full: SessionSnapshot = { + ...result, + snapshotId: id, + }; + this.snapshots.set(id, structuredClone(full)); + const snapshotListeners = this.listeners.get(id); + if (snapshotListeners) { + for (const listener of snapshotListeners) { + listener(structuredClone(full)); + } + } + return id; + } -const sessionAlsKey = 'ai.session'; + onSnapshotStateChange( + snapshotId: string, + callback: (snapshot: SessionSnapshot) => void, + options?: SessionStoreOptions + ): void | (() => void) { + if (!this.listeners.has(snapshotId)) { + this.listeners.set(snapshotId, []); + } + this.listeners.get(snapshotId)!.push(callback); + return () => { + const list = this.listeners.get(snapshotId); + if (list) { + const index = list.indexOf(callback); + if (index >= 0) list.splice(index, 1); + } + }; + } +} /** - * Executes provided function within the provided session state. + * Utility to execute a function bound to a Session instance context. */ export function runWithSession( registry: Registry, session: Session, fn: () => O ): O { - return getAsyncContext().run(sessionAlsKey, session, fn); + return getAsyncContext().run('ai.session', session, fn); } -/** Returns the current session. */ +/** + * Returns the Session instance active in the current context. + */ export function getCurrentSession( registry: Registry ): Session | undefined { - return getAsyncContext().getStore(sessionAlsKey); + return getAsyncContext().getStore('ai.session'); } -/** Throw when session state errors occur, ex. missing state, etc. */ +/** + * Error thrown during session execution. + */ export class SessionError extends Error { constructor(msg: string) { super(msg); } } -/** Session store persists session data such as state and chat messages. */ -export interface SessionStore { - get(sessionId: string): Promise | undefined>; +// Only UUID-shaped strings are accepted for the convoId component. +const UUID_PATTERN = + /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i; - save(sessionId: string, data: Omit, 'id'>): Promise; +// The suffix part (after the convoId) must be alphanumeric / hyphens / underscores. +const SAFE_SUFFIX_PATTERN = /^[0-9a-zA-Z_-]+$/; + +/** + * Generates a short, unique suffix for a snapshot ID. + * + * Format: `{epochMs}_{random4}` — e.g. `1747000878123_k9m2` + */ +function generateSnapshotSuffix(): string { + const timestamp = Date.now(); + const random = Math.random().toString(36).slice(2, 6); + return `${timestamp}_${random}`; } -export function inMemorySessionStore() { - return new InMemorySessionStore(); +/** + * Composes a snapshot ID from a conversation ID and a short suffix. + * + * Format: `{convoId}_{epochMs}_{random}` + */ +function composeSnapshotId(convoId: string, suffix: string): string { + return `${convoId}_${suffix}`; } -class InMemorySessionStore implements SessionStore { - private data: Record> = {}; +/** + * Parses a composite snapshot ID into its conversation ID and file-name suffix. + * + * The convoId is a UUID (36 chars with hyphens). Since UUIDs never contain + * underscores, the first `_` after the UUID boundary reliably separates the + * two parts. + * + * @throws If the ID cannot be parsed or the convoId is not a valid UUID. + */ +function parseSnapshotId(snapshotId: string): { + convoId: string; + suffix: string; +} { + // UUID is always 36 chars (8-4-4-4-12). The separator `_` follows at index 36. + if (snapshotId.length < 38 || snapshotId[36] !== '_') { + throw new Error( + `Invalid snapshotId: expected format "{uuid}_{suffix}", got "${snapshotId}"` + ); + } + const convoId = snapshotId.slice(0, 36); + const suffix = snapshotId.slice(37); + if (!UUID_PATTERN.test(convoId)) { + throw new Error( + `Invalid snapshotId: convoId component is not a valid UUID ("${convoId}")` + ); + } + if (!suffix || !SAFE_SUFFIX_PATTERN.test(suffix)) { + throw new Error( + `Invalid snapshotId: suffix component is invalid ("${suffix}")` + ); + } + return { convoId, suffix }; +} + +/** + * A Node.js file-system backed session snapshot store. + * + * Snapshots belonging to the same conversation are grouped in a shared + * sub-directory keyed by a conversation ID that is embedded in the + * `snapshotId` itself. + * + * ID format: `{convoId}_{epochMs}_{random}` + * + * File layout: `dirPath///_.json` + */ +export class FileSessionStore implements SessionStore { + private dirPath: string; + private maxPersistedChainLength?: number; + private snapshotPathPrefix?: ( + snapshotId: string, + options?: SessionStoreOptions + ) => string; - async get(sessionId: string): Promise | undefined> { - return this.data[sessionId]; + /** + * @param dirPath Directory where snapshot JSON files are stored. + * @param options.maxPersistedChainLength When set, snapshots older than this + * many entries in a chain are automatically deleted on each save. + * @param options.snapshotPathPrefix Returns a sub-directory prefix per + * snapshot, useful for multi-tenant isolation. Defaults to `"global"`. + */ + constructor( + dirPath: string, + options?: { + maxPersistedChainLength?: number; + snapshotPathPrefix?: ( + snapshotId: string, + options?: SessionStoreOptions + ) => string; + } + ) { + this.dirPath = path.resolve(dirPath); + fs.mkdirSync(this.dirPath, { recursive: true }); + this.maxPersistedChainLength = options?.maxPersistedChainLength; + this.snapshotPathPrefix = options?.snapshotPathPrefix; } - async save(sessionId: string, sessionData: SessionData): Promise { - this.data[sessionId] = sessionData; + private async ensureDir(dir: string): Promise { + await fsp.mkdir(dir, { recursive: true }); + } + + /** + * Resolves the file path for a given composite snapshotId. + */ + private async getFilePath( + snapshotId: string, + options?: SessionStoreOptions + ): Promise { + const { convoId, suffix } = parseSnapshotId(snapshotId); + const prefix = this.snapshotPathPrefix + ? this.snapshotPathPrefix(snapshotId, options) + : 'global'; + const dir = path.join(this.dirPath, prefix, convoId); + await this.ensureDir(dir); + return path.join(dir, `${suffix}.json`); + } + + async getSnapshot( + snapshotId: string, + options?: SessionStoreOptions + ): Promise | undefined> { + const filePath = await this.getFilePath(snapshotId, options); + try { + const fileContents = await fsp.readFile(filePath, 'utf-8'); + return JSON.parse(fileContents) as SessionSnapshot; + } catch (e: unknown) { + if ((e as NodeJS.ErrnoException).code === 'ENOENT') return undefined; + throw e; + } + } + + async saveSnapshot( + snapshotId: string | undefined, + mutator: SnapshotMutator, + options?: SessionStoreOptions + ): Promise { + // Read the current snapshot when an ID is provided. + const current = snapshotId + ? await this.getSnapshot(snapshotId, options) + : undefined; + + const snapshot = mutator(current); + if (snapshot === null) return null; + + // Determine the final ID. + let id: string; + if (snapshotId) { + // Upsert — the caller supplied an ID. + id = snapshotId; + } else if (snapshot.snapshotId) { + id = snapshot.snapshotId; + } else { + // New snapshot — derive the convoId from parentId or start a new + // conversation. + let convoId: string; + if (snapshot.parentId) { + ({ convoId } = parseSnapshotId(snapshot.parentId)); + } else { + convoId = crypto.randomUUID(); + } + id = composeSnapshotId(convoId, generateSnapshotSuffix()); + } + + const full: SessionSnapshot = { + ...snapshot, + snapshotId: id, + }; + const filePath = await this.getFilePath(id, options); + await fsp.writeFile(filePath, JSON.stringify(full, null, 2), 'utf-8'); + + if (this.maxPersistedChainLength && this.maxPersistedChainLength > 0) { + let cur: SessionSnapshot | undefined = full; + const chain: string[] = []; + + while (cur) { + chain.push(cur.snapshotId); + if (cur.parentId) { + cur = await this.getSnapshot(cur.parentId, options); + } else { + break; + } + } + + if (chain.length > this.maxPersistedChainLength) { + for (let i = this.maxPersistedChainLength; i < chain.length; i++) { + const pathToDelete = await this.getFilePath(chain[i], options); + await fsp.unlink(pathToDelete).catch((e: unknown) => { + if ((e as NodeJS.ErrnoException).code !== 'ENOENT') throw e; + }); + } + } + } + + return id; } } diff --git a/js/ai/tests/agent_test.ts b/js/ai/tests/agent_test.ts new file mode 100644 index 0000000000..23672e2ec2 --- /dev/null +++ b/js/ai/tests/agent_test.ts @@ -0,0 +1,2432 @@ +/** + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { initNodeFeatures } from '@genkit-ai/core/node'; +import { Registry } from '@genkit-ai/core/registry'; +import * as assert from 'assert'; +import { describe, it } from 'node:test'; + +import { z } from '@genkit-ai/core'; +import { + AgentStreamChunk, + SessionRunner, + defineAgent, + defineCustomAgent, + definePromptAgent, +} from '../src/agent.js'; +import { definePrompt } from '../src/prompt.js'; +import { + InMemorySessionStore, + Session, + type SessionSnapshot, +} from '../src/session.js'; +import { ToolInterruptError, defineTool, interrupt } from '../src/tool.js'; +import { + defineEchoModel, + defineProgrammableModel, + type ProgrammableModel, +} from './helpers.js'; + +initNodeFeatures(); + +/** + * Returns a Promise that resolves once the given snapshotId reaches targetStatus + * in the store. Rejects after timeoutMs if the status is never reached. + */ +function waitForSnapshotStatus( + store: InMemorySessionStore, + snapshotId: string, + targetStatus: NonNullable['status']>, + timeoutMs = 5000 +): Promise> { + return new Promise((resolve, reject) => { + const timer = setTimeout( + () => + reject( + new Error( + `Timed out waiting for snapshot ${snapshotId} to reach status "${targetStatus}"` + ) + ), + timeoutMs + ); + + const unsubscribeFn = store.onSnapshotStateChange(snapshotId, (snap) => { + if (snap.status === targetStatus) { + clearTimeout(timer); + if (typeof unsubscribeFn === 'function') unsubscribeFn(); + resolve(snap); + } + }); + + // Check in case already at the target status. + store.getSnapshot(snapshotId).then((snap) => { + if (snap?.status === targetStatus) { + clearTimeout(timer); + if (typeof unsubscribeFn === 'function') unsubscribeFn(); + resolve(snap); + } + }); + }); +} + +describe('Agent', () => { + describe('Session', () => { + it('should maintain custom state', () => { + const session = new Session<{ foo: string }>({ custom: { foo: 'bar' } }); + assert.strictEqual(session.getCustom()?.foo, 'bar'); + + session.updateCustom((c) => ({ ...c!, foo: 'baz' })); + assert.strictEqual(session.getCustom()?.foo, 'baz'); + }); + + it('should add and set messages', () => { + const session = new Session({}); + session.addMessages([{ role: 'user', content: [{ text: 'hi' }] }]); + assert.strictEqual(session.getMessages().length, 1); + assert.strictEqual(session.getMessages()[0].role, 'user'); + + session.setMessages([{ role: 'model', content: [{ text: 'hello' }] }]); + assert.strictEqual(session.getMessages().length, 1); + assert.strictEqual(session.getMessages()[0].role, 'model'); + }); + + it('should add and deduplicate artifacts', () => { + const session = new Session({}); + session.addArtifacts([{ name: 'art1', parts: [{ text: 'content1' }] }]); + assert.strictEqual(session.getArtifacts().length, 1); + + // Add with same name should replace + session.addArtifacts([{ name: 'art1', parts: [{ text: 'content2' }] }]); + assert.strictEqual(session.getArtifacts().length, 1); + assert.deepStrictEqual(session.getArtifacts()[0].parts, [ + { text: 'content2' }, + ]); + + // Add with different name should append + session.addArtifacts([{ name: 'art2', parts: [{ text: 'content3' }] }]); + assert.strictEqual(session.getArtifacts().length, 2); + }); + + it('should process all artifacts in a batch without dropping any', () => { + const session = new Session({}); + session.addArtifacts([{ name: 'art1', parts: [{ text: 'v1' }] }]); + + // Replace art1 and add art2 and art3 in the same batch. + session.addArtifacts([ + { name: 'art1', parts: [{ text: 'v2' }] }, + { name: 'art2', parts: [{ text: 'new' }] }, + { name: 'art3', parts: [{ text: 'another' }] }, + ]); + + const arts = session.getArtifacts(); + assert.strictEqual(arts.length, 3); + assert.strictEqual( + arts.find((a) => a.name === 'art1')?.parts[0].text, + 'v2' + ); + assert.strictEqual( + arts.find((a) => a.name === 'art2')?.parts[0].text, + 'new' + ); + assert.strictEqual( + arts.find((a) => a.name === 'art3')?.parts[0].text, + 'another' + ); + }); + + it('should emit artifactAdded for new and artifactUpdated for replaced', () => { + const session = new Session({}); + const added: string[] = []; + const updated: string[] = []; + session.on('artifactAdded', (a: { name?: string }) => + added.push(a.name ?? '') + ); + session.on('artifactUpdated', (a: { name?: string }) => + updated.push(a.name ?? '') + ); + + session.addArtifacts([{ name: 'art1', parts: [] }]); + session.addArtifacts([ + { name: 'art1', parts: [] }, // replace + { name: 'art2', parts: [] }, // new + ]); + + assert.deepStrictEqual(added, ['art1', 'art2']); + assert.deepStrictEqual(updated, ['art1']); + }); + + it('should increment version on mutation', () => { + const session = new Session({}); + const v0 = session.getVersion(); + + session.addMessages([{ role: 'user', content: [{ text: 'hi' }] }]); + const v1 = session.getVersion(); + assert.ok(v1 > v0); + + session.updateCustom((c) => c); + const v2 = session.getVersion(); + assert.ok(v2 > v1); + + session.addArtifacts([{ name: 'a', parts: [] }]); + const v3 = session.getVersion(); + assert.ok(v3 > v2); + }); + }); + + describe('InMemorySessionStore', () => { + it('should save and get snapshots', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + const snapshot = { + snapshotId: 'snap-123', + createdAt: new Date().toISOString(), + event: 'turnEnd' as const, + state: { custom: { foo: 'bar' } }, + }; + await store.saveSnapshot('snap-123', () => snapshot); + + const got = await store.getSnapshot('snap-123'); + assert.deepStrictEqual(got, snapshot); + }); + + it('should return undefined for missing snapshot', async () => { + const store = new InMemorySessionStore(); + const got = await store.getSnapshot('missing'); + assert.strictEqual(got, undefined); + }); + + it('should deep copy on save and get', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + const state = { foo: 'bar' }; + const snapshot = { + snapshotId: 'snap-123', + createdAt: new Date().toISOString(), + event: 'turnEnd' as const, + state: { custom: state }, + }; + await store.saveSnapshot('snap-123', () => snapshot); + + // Mutate local state + state.foo = 'baz'; + + const got = await store.getSnapshot('snap-123'); + assert.strictEqual(got?.state.custom?.foo, 'bar'); + }); + }); + + describe('SessionRunner', () => { + it('should loop over inputs and call handler', async () => { + const session = new Session({}); + const inputs = [ + { messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }] }, + { messages: [{ role: 'user' as const, content: [{ text: 'bye' }] }] }, + ]; + + async function* inputGen() { + for (const input of inputs) { + yield input; + } + } + + const runner = new SessionRunner(session, inputGen()); + let turns = 0; + const seenInputs: any[] = []; + + await runner.run(async (input) => { + turns++; + seenInputs.push(input); + }); + + assert.strictEqual(turns, 2); + assert.deepStrictEqual(seenInputs, inputs); + assert.strictEqual(session.getMessages().length, 2); + }); + + it('should trigger snapshots if store is present', async () => { + const store = new InMemorySessionStore(); + const session = new Session({}); + const inputs = [ + { messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }] }, + ]; + + async function* inputGen() { + for (const input of inputs) { + yield input; + } + } + + let turnEnded = false; + let turnSnapshotId: string | undefined; + + const runner = new SessionRunner(session, inputGen(), { + store, + onEndTurn: (snapshotId) => { + turnEnded = true; + turnSnapshotId = snapshotId; + }, + }); + + await runner.run(async () => {}); + + assert.ok(turnEnded); + assert.ok(turnSnapshotId); + + const saved = await store.getSnapshot(turnSnapshotId!); + assert.ok(saved); + assert.strictEqual(saved?.snapshotId, turnSnapshotId); + }); + + it('should respect snapshot callback', async () => { + const store = new InMemorySessionStore(); + const session = new Session({}); + const inputs = [ + { messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }] }, + ]; + + async function* inputGen() { + for (const input of inputs) { + yield input; + } + } + + const runner = new SessionRunner(session, inputGen(), { + store, + snapshotCallback: () => false, // Never snapshot + }); + + await runner.run(async () => {}); + + // Verify the store is empty (callback suppressed all snapshots). + const onEndTurnSnapshotId = await new Promise( + (resolve) => { + const r = new SessionRunner(session, inputGen(), { + store, + onEndTurn: resolve, + }); + r.run(async () => {}).catch(() => {}); + } + ); + // The callback-suppressed runner should have produced no entries. + const keys = Array.from((store as any).snapshots.keys()) as string[]; + // Only the snapshot from the second (non-callback) runner should exist. + assert.ok(keys.every((k) => k === onEndTurnSnapshotId)); + }); + }); + + describe('defineCustomAgent', () => { + it('should set client stateManagement and abortable=false when no store is provided', () => { + const registry = new Registry(); + const agent = defineCustomAgent( + registry, + { name: 'noStoreMetadataTest' }, + async () => ({ artifacts: [] }) + ); + assert.strictEqual( + agent.__action.metadata?.agent?.stateManagement, + 'client' + ); + assert.strictEqual(agent.__action.metadata?.agent?.abortable, false); + }); + + it('should set server stateManagement and abortable=true when store with onSnapshotStateChange is provided', () => { + const registry = new Registry(); + const store = new InMemorySessionStore(); + const agent = defineCustomAgent( + registry, + { name: 'fullStoreMetadataTest', store }, + async () => ({ artifacts: [] }) + ); + assert.strictEqual( + agent.__action.metadata?.agent?.stateManagement, + 'server' + ); + assert.strictEqual(agent.__action.metadata?.agent?.abortable, true); + }); + + it('should reject init.state for server-managed agents (store is set)', async () => { + const registry = new Registry(); + const store = new InMemorySessionStore<{ foo: string }>(); + + const flow = defineCustomAgent( + registry, + { name: 'rejectInitStateTest', store }, + async (sess) => { + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'done' }] }, + }; + } + ); + + // Pass init.state — should throw FAILED_PRECONDITION for server-managed agents + const session = flow.streamBidi({ + state: { + custom: { foo: 'should-be-rejected' }, + messages: [{ role: 'user', content: [{ text: 'stale history' }] }], + artifacts: [], + }, + }); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hello' }] }], + }); + session.close(); + + try { + for await (const _ of session.stream) { + } + await session.output; + assert.fail('Expected FAILED_PRECONDITION error'); + } catch (e: any) { + assert.ok( + e.message.includes("Cannot send 'state' to agent"), + `Expected FAILED_PRECONDITION error, got: ${e.message}` + ); + assert.strictEqual(e.status, 'FAILED_PRECONDITION'); + } + }); + + it('should use init.state for client-managed agents (no store)', async () => { + const registry = new Registry(); + + const flow = defineCustomAgent( + registry, + { name: 'useInitStateTest' }, + async (sess) => { + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'done' }] }, + }; + } + ); + + // Pass init.state — it should be used because no store is set + const session = flow.streamBidi({ + state: { + custom: { foo: 'seeded' }, + messages: [{ role: 'user', content: [{ text: 'prior msg' }] }], + artifacts: [], + }, + }); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hello' }] }], + }); + session.close(); + + for await (const _ of session.stream) { + } + const output = await session.output; + + // State should include the seeded state plus the new message + assert.ok(output.state); + assert.strictEqual((output.state!.custom as any).foo, 'seeded'); + // Messages: 1 from init.state + 1 from input + assert.strictEqual(output.state!.messages!.length, 2); + assert.strictEqual( + output.state!.messages![0].content[0].text, + 'prior msg' + ); + assert.strictEqual(output.state!.messages![1].content[0].text, 'hello'); + }); + + it('should set server stateManagement and abortable=false when store lacks onSnapshotStateChange', () => { + const registry = new Registry(); + const store: any = { + getSnapshot: async () => undefined, + saveSnapshot: async () => {}, + // no onSnapshotStateChange + }; + const agent = defineCustomAgent( + registry, + { name: 'noAbortStoreMetadataTest', store }, + async () => ({ artifacts: [] }) + ); + assert.strictEqual( + agent.__action.metadata?.agent?.stateManagement, + 'server' + ); + assert.strictEqual(agent.__action.metadata?.agent?.abortable, false); + }); + + it('should register and execute agent', async () => { + const registry = new Registry(); + + const flow = defineCustomAgent( + registry, + { name: 'testFlow' }, + async (sess, { sendChunk }) => { + let receivedInput = false; + await sess.run(async (input) => { + receivedInput = true; + assert.strictEqual(input.messages?.[0].role, 'user'); + }); + assert.ok(receivedInput); + return { message: { role: 'model', content: [{ text: 'done' }] } }; + } + ); + + const session = flow.streamBidi({}); + + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + }); + session.close(); + + const chunks: AgentStreamChunk[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + const output = await session.output; + assert.strictEqual(output.message?.role, 'model'); + assert.strictEqual(output.message?.content[0].text, 'done'); + }); + + it('should automatically stream artifacts added via Session.addArtifacts()', async () => { + const registry = new Registry(); + + const flow = defineCustomAgent( + registry, + { name: 'testEventFlow' }, + async (sess, { sendChunk }) => { + await sess.run(async (input) => { + sess.session.addArtifacts([ + { name: 'testArt', parts: [{ text: 'testPart' }] }, + ]); + }); + return { message: { role: 'model', content: [{ text: 'done' }] } }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + }); + session.close(); + + const chunks: AgentStreamChunk[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + const artChunks = chunks.filter((c) => !!c.artifact); + assert.strictEqual(artChunks.length, 1); + assert.strictEqual(artChunks[0].artifact?.name, 'testArt'); + }); + + it('should stream artifactUpdated chunks when an artifact is replaced', async () => { + const registry = new Registry(); + + const flow = defineCustomAgent( + registry, + { name: 'testArtifactUpdateFlow' }, + async (sess) => { + await sess.run(async () => { + sess.session.addArtifacts([{ name: 'a', parts: [{ text: 'v1' }] }]); + sess.session.addArtifacts([{ name: 'a', parts: [{ text: 'v2' }] }]); + }); + return {}; + } + ); + + const session = flow.streamBidi({}); + session.send({ messages: [{ role: 'user', content: [{ text: 'go' }] }] }); + session.close(); + + const chunks: AgentStreamChunk[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + const artChunks = chunks.filter((c) => !!c.artifact); + assert.strictEqual(artChunks.length, 2); + assert.strictEqual(artChunks[0].artifact?.parts[0].text, 'v1'); + assert.strictEqual(artChunks[1].artifact?.parts[0].text, 'v2'); + }); + }); + + describe('definePromptAgent', () => { + it('should register and execute agent from prompt', async () => { + const registry = new Registry(); + defineEchoModel(registry); + definePrompt(registry, { + name: 'agent', + model: 'echoModel', + config: { temperature: 1 }, + system: 'hello from template', + }); + + const flow = definePromptAgent(registry, { + promptName: 'agent', + }); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + }); + session.close(); + + const chunks: AgentStreamChunk[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + const output = await session.output; + assert.strictEqual(output.message?.role, 'model'); + }); + + it('should detach asynchronously and continue execution in the background', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + let resolvePromise: () => void = () => {}; + const releasePromise = new Promise((resolve) => { + resolvePromise = resolve; + }); + + const flow = defineCustomAgent( + new Registry(), + { + name: 'detachTest', + store, + }, + async (sess, { sendChunk }) => { + await sess.run(async () => { + await releasePromise; + }); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + detach: true, + }); + + const output = await session.output; + const snapshotId = output.snapshotId; + assert.ok(snapshotId); + + const snapPending = await store.getSnapshot(snapshotId!); + assert.strictEqual(snapPending?.status, 'pending'); + + resolvePromise(); + session.close(); + + const snapDone = await waitForSnapshotStatus(store, snapshotId!, 'done'); + assert.strictEqual(snapDone.status, 'done'); + }); + + it('should abort a detached agent', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + let aborted = false; + + const flow = defineCustomAgent( + new Registry(), + { + name: 'abortTest', + store, + }, + async (sess, { abortSignal }) => { + if (abortSignal) { + abortSignal.onabort = () => { + aborted = true; + }; + } + await sess.run(async () => { + await new Promise((resolve) => setTimeout(resolve, 5000)); + }); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + detach: true, + }); + + const output = await session.output; + const snapshotId = output.snapshotId; + assert.ok(snapshotId); + + const previousStatus = await flow.abort(snapshotId!); + + assert.strictEqual(previousStatus, 'pending'); + const snapAborted = await store.getSnapshot(snapshotId!); + assert.strictEqual(snapAborted?.status, 'aborted'); + // AbortController.abort() fires onabort synchronously, so no delay needed. + assert.strictEqual(aborted, true); + }); + + it('should not override terminal status when aborting an already-completed flow', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + + const flow = defineCustomAgent( + new Registry(), + { + name: 'abortDoneTest', + store, + }, + async (sess) => { + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + }); + session.close(); + const output = await session.output; + assert.ok(output.snapshotId); + + // Snapshot should be 'done' now + const snapBefore = await store.getSnapshot(output.snapshotId!); + assert.strictEqual(snapBefore?.status, 'done'); + + // Abort returns the previous status but does not override terminal states + const previousStatus = await flow.abort(output.snapshotId!); + assert.strictEqual(previousStatus, 'done'); + + // Snapshot should still be 'done' — the mutator skips terminal states + const snapAfter = await store.getSnapshot(output.snapshotId!); + assert.strictEqual(snapAfter?.status, 'done'); + }); + + it('should return undefined when aborting a non-existent snapshot', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + + const flow = defineCustomAgent( + new Registry(), + { + name: 'abortMissingTest', + store, + }, + async (sess) => { + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const previousStatus = await flow.abort('non-existent-id'); + assert.strictEqual(previousStatus, undefined); + }); + + it('should throw error when detach is requested without session store', async () => { + const flow = defineCustomAgent( + new Registry(), + { + name: 'noStoreTest', + }, + async (sess) => { + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + detach: true, + }); + + try { + await session.output; + assert.fail('Should have thrown error'); + } catch (e: any) { + assert.strictEqual( + e.message, + 'FAILED_PRECONDITION: Detach is only supported when a session store is provided.' + ); + } + }); + + it('should save failed snapshot if detached flow throws', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + let resolvePromise: () => void = () => {}; + const releasePromise = new Promise((resolve) => { + resolvePromise = resolve; + }); + + const flow = defineCustomAgent( + new Registry(), + { + name: 'detachErrorTest', + store, + }, + async (sess, { sendChunk }) => { + await sess.run(async () => { + await releasePromise; + throw new Error('intentional background failure'); + }); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + detach: true, + }); + + const output = await session.output; + const snapshotId = output.snapshotId; + assert.ok(snapshotId); + + resolvePromise(); + session.close(); + + const snapFailed = await waitForSnapshotStatus( + store, + snapshotId!, + 'failed' + ); + assert.strictEqual(snapFailed.status, 'failed'); + assert.strictEqual( + snapFailed.error?.message, + 'intentional background failure' + ); + }); + + it('should mark snapshot aborted even without subscription support', async () => { + const baseStore = new InMemorySessionStore(); + const store = Object.assign(Object.create(baseStore), { + onSnapshotStateChange: undefined, + getSnapshot: baseStore.getSnapshot.bind(baseStore), + saveSnapshot: baseStore.saveSnapshot.bind(baseStore), + }) as InMemorySessionStore; + + let resolveBlock: () => void = () => {}; + const blockPromise = new Promise((resolve) => { + resolveBlock = resolve; + }); + + const flow = defineCustomAgent( + new Registry(), + { + name: 'legacyStoreTest', + store, + }, + async (sess, { sendChunk }) => { + await sess.run(async () => { + await blockPromise; // Keep flow pending until abort is called + }); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + detach: true, + }); + + const output = await session.output; + const snapshotId = output.snapshotId; + + // Snapshot should be 'pending' since the flow is still blocked + const snapBefore = await store.getSnapshot(snapshotId!); + assert.strictEqual(snapBefore?.status, 'pending'); + + await flow.abort(snapshotId!); + + const snapshot = await store.getSnapshot(snapshotId!); + assert.strictEqual(snapshot?.status, 'aborted'); + + // Release the flow so it doesn't hang + resolveBlock(); + session.close(); + }); + + it('should fetch snapshot data via companion action', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + const flow = defineCustomAgent( + new Registry(), + { + name: 'companionActionFlow', + store, + }, + async (sess) => { + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'hi' }] }], + }); + session.close(); + const output = await session.output; + + const snapData = await flow.getSnapshotData(output.snapshotId!); + assert.strictEqual(snapData?.snapshotId, output.snapshotId); + }); + + it('should chain parentId properly across session snapshots', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + const flow = defineCustomAgent( + new Registry(), + { + name: 'lineageTest', + store, + }, + async (sess) => { + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session1 = flow.streamBidi({}); + session1.send({ + messages: [{ role: 'user' as const, content: [{ text: 'first' }] }], + }); + session1.close(); + const output1 = await session1.output; + + const session2 = flow.streamBidi({ + snapshotId: output1.snapshotId, + }); + + session2.send({ + messages: [{ role: 'user' as const, content: [{ text: 'second' }] }], + }); + session2.close(); + const output2 = await session2.output; + + const snapshot2 = await store.getSnapshot(output2.snapshotId!); + assert.strictEqual(snapshot2?.parentId, output1.snapshotId); + }); + + it('should detach immediately when a detach input is queued', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + let releasePromise: () => void = () => {}; + const blockPromise = new Promise((resolve) => { + releasePromise = resolve; + }); + + const flow = defineCustomAgent( + new Registry(), + { + name: 'immediateDetachTest', + store, + }, + async (sess) => { + await sess.run(async () => { + await blockPromise; + }); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [ + { role: 'user' as const, content: [{ text: 'heavy task' }] }, + ], + }); + session.send({ + detach: true, + }); + + const output = await session.output; + assert.ok(output.snapshotId); + const snapshot = await store.getSnapshot(output.snapshotId!); + assert.strictEqual(snapshot?.status, 'pending'); + + releasePromise(); + session.close(); + }); + + it('should process messages even when detach is present in the same payload', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + const flow = defineCustomAgent( + new Registry(), + { + name: 'mixedPayloadTest', + store, + }, + async (sess) => { + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [ + { role: 'user' as const, content: [{ text: 'appended message' }] }, + ], + detach: true, + }); + + const output = await session.output; + assert.ok(output.snapshotId); + + const snapDone = await waitForSnapshotStatus( + store, + output.snapshotId!, + 'done' + ); + assert.ok(snapDone.state.messages); + assert.strictEqual(snapDone.state.messages.length, 1); + assert.strictEqual( + snapDone.state.messages[0].content[0].text, + 'appended message' + ); + + session.close(); + }); + + it('should accumulate message history across multiple turns in one invocation', async () => { + const registry = new Registry(); + defineEchoModel(registry); + definePrompt(registry, { + name: 'multiTurnAccumPrompt', + model: 'echoModel', + config: { temperature: 1 }, + system: 'sys', + }); + + const flow = definePromptAgent(registry, { + promptName: 'multiTurnAccumPrompt', + }); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'turn1' }] }], + }); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'turn2' }] }], + }); + session.close(); + + const chunks: AgentStreamChunk[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + // Two turns must have completed. + const turnEndChunks = chunks.filter((c) => c.turnEnd !== undefined); + assert.strictEqual(turnEndChunks.length, 2); + + const output = await session.output; + assert.strictEqual(output.message?.role, 'model'); + + // The second-turn echo should contain the first model reply in its history, + // proving the session history was passed to the second generate call. + const turn2Text = + output.message?.content.map((c) => c.text).join('') ?? ''; + assert.ok( + turn2Text.includes('Echo:'), + `Expected second turn to be an echo response, got: ${turn2Text}` + ); + + // Model chunks must have been emitted for both turns. + const modelChunks = chunks.filter((c) => c.modelChunk !== undefined); + assert.ok( + modelChunks.length >= 2, + 'Expected model chunks from both turns' + ); + }); + + it('should successfully handle native tool interrupts and tool response resumption', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const store = new InMemorySessionStore<{}>(); + + const pm = defineProgrammableModel(registry, undefined, 'interruptModel'); + + const myInterrupt = interrupt({ + name: 'myInterrupt', + description: 'Ask user', + inputSchema: z.object({ query: z.string() }), + outputSchema: z.object({ answer: z.string() }), + }); + registry.registerAction('tool', myInterrupt); + + definePrompt(registry, { + name: 'interruptPrompt', + model: 'interruptModel', + tools: ['myInterrupt'], + config: { temperature: 1 }, + }); + + const flow = definePromptAgent(registry, { + promptName: 'interruptPrompt', + store, + }); + + // Phase 1: User says hello, model responds with a toolRequest (interrupt) + pm.handleResponse = async () => { + return { + message: { + role: 'model', + content: [ + { + toolRequest: { + name: 'myInterrupt', + input: { query: 'yes?' }, + ref: '123', + }, + }, + ], + }, + finishReason: 'stop', + }; + }; + + const session1 = flow.streamBidi({}); + session1.send({ + messages: [{ role: 'user', content: [{ text: 'hello' }] }], + }); + session1.close(); // IMPORTANT: close the stream so it doesn't hang! + + for await (const chunk of session1.stream) { + } + const output1 = await session1.output; + + assert.ok(output1.snapshotId); + assert.ok(output1.message); + assert.ok(output1.message.content[0].toolRequest); + assert.strictEqual( + output1.message.content[0].toolRequest.name, + 'myInterrupt' + ); + + // Phase 2: Resume with the tool response + pm.handleResponse = async (req) => { + // Assert that the resumed request contains the tool response! + const lastMsg = req.messages[req.messages.length - 1]; + assert.strictEqual(lastMsg.role, 'tool'); + assert.strictEqual( + (lastMsg.content[0] as any).toolResponse.output.answer, + 'yes indeed' + ); + + return { + message: { + role: 'model', + content: [{ text: 'Task completed successfully!' }], + }, + finishReason: 'stop', + }; + }; + + const session2 = flow.streamBidi({ snapshotId: output1.snapshotId }); + session2.send({ + resume: { + respond: [ + { + toolResponse: { + name: 'myInterrupt', + ref: '123', + output: { answer: 'yes indeed' }, + }, + }, + ], + }, + }); + session2.close(); // IMPORTANT: close the stream so it doesn't hang! + + for await (const chunk of session2.stream) { + } + const output2 = await session2.output; + + assert.strictEqual(output2.message?.role, 'model'); + assert.strictEqual( + output2.message?.content[0].text, + 'Task completed successfully!' + ); + }); + + it('should handle resume.restart for tool re-execution with metadata', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const store = new InMemorySessionStore<{}>(); + + const pm = defineProgrammableModel(registry, undefined, 'restartModel'); + + // Track whether the tool was called and with what resumed metadata + let toolCallCount = 0; + let lastResumedMetadata: any = undefined; + + defineTool( + registry, + { + name: 'dangerousTool', + description: 'A tool that requires confirmation', + inputSchema: z.object({ action: z.string() }), + outputSchema: z.object({ result: z.string() }), + }, + async (input, { resumed }) => { + toolCallCount++; + lastResumedMetadata = resumed; + + if (!resumed) { + // First call — interrupt to ask for user confirmation + throw new ToolInterruptError({ requiresConfirmation: true }); + } + // Restarted with confirmation metadata + return { result: `confirmed and executed ${input.action}` }; + } + ); + + definePrompt(registry, { + name: 'restartPrompt', + model: 'restartModel', + tools: ['dangerousTool'], + config: { temperature: 1 }, + }); + + const flow = definePromptAgent(registry, { + promptName: 'restartPrompt', + store, + }); + + // Phase 1: Model requests the tool. The tool throws ToolInterruptError, + // causing the generate action to return finishReason: 'interrupted'. + pm.handleResponse = async () => { + return { + message: { + role: 'model', + content: [ + { + toolRequest: { + name: 'dangerousTool', + input: { action: 'delete files' }, + ref: 'tr1', + }, + }, + ], + }, + finishReason: 'stop', + }; + }; + + const session1 = flow.streamBidi({}); + session1.send({ + messages: [ + { role: 'user', content: [{ text: 'please delete files' }] }, + ], + }); + session1.close(); + + for await (const chunk of session1.stream) { + } + const output1 = await session1.output; + + assert.ok(output1.snapshotId); + assert.ok(output1.message); + assert.ok(output1.message.content[0].toolRequest); + assert.strictEqual( + output1.message.content[0].toolRequest.name, + 'dangerousTool' + ); + + // Phase 2: Client resumes with restart — re-execute the tool with metadata + toolCallCount = 0; // Reset counter + + pm.handleResponse = async (req) => { + // After restart, the model should receive the tool response from re-execution + const toolMsgs = req.messages.filter((m: any) => m.role === 'tool'); + assert.ok( + toolMsgs.length > 0, + 'Model should receive a tool response message' + ); + const lastToolMsg = toolMsgs[toolMsgs.length - 1]; + assert.strictEqual( + (lastToolMsg.content[0] as any).toolResponse.output.result, + 'confirmed and executed delete files' + ); + + return { + message: { + role: 'model', + content: [{ text: 'Files deleted successfully!' }], + }, + finishReason: 'stop', + }; + }; + + const session2 = flow.streamBidi({ snapshotId: output1.snapshotId }); + session2.send({ + resume: { + restart: [ + { + toolRequest: { + name: 'dangerousTool', + input: { action: 'delete files' }, + ref: 'tr1', + }, + metadata: { resumed: { approved: true } }, + }, + ], + }, + }); + session2.close(); + + for await (const chunk of session2.stream) { + } + const output2 = await session2.output; + + // Verify the tool was actually re-executed + assert.strictEqual( + toolCallCount, + 1, + 'Tool should be called once on restart' + ); + assert.ok(lastResumedMetadata, 'Tool should receive resumed metadata'); + assert.strictEqual(lastResumedMetadata.approved, true); + + assert.strictEqual(output2.message?.role, 'model'); + assert.strictEqual( + output2.message?.content[0].text, + 'Files deleted successfully!' + ); + }); + + it('should reject resume.restart with forged (modified) inputs', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const store = new InMemorySessionStore<{}>(); + + const pm = defineProgrammableModel( + registry, + undefined, + 'forgedRestartModel' + ); + + defineTool( + registry, + { + name: 'sensitiveTool', + description: 'Tool with sensitive inputs', + inputSchema: z.object({ target: z.string() }), + outputSchema: z.object({ result: z.string() }), + }, + async (input, { resumed }) => { + if (!resumed) { + throw new ToolInterruptError({ needsApproval: true }); + } + return { result: `executed on ${input.target}` }; + } + ); + + definePrompt(registry, { + name: 'forgedRestartPrompt', + model: 'forgedRestartModel', + tools: ['sensitiveTool'], + config: { temperature: 1 }, + }); + + const flow = definePromptAgent(registry, { + promptName: 'forgedRestartPrompt', + store, + }); + + // Phase 1: Model requests tool, tool interrupts + pm.handleResponse = async () => ({ + message: { + role: 'model', + content: [ + { + toolRequest: { + name: 'sensitiveTool', + input: { target: 'safe-file.txt' }, + ref: 'ref1', + }, + }, + ], + }, + finishReason: 'stop', + }); + + const session1 = flow.streamBidi({}); + session1.send({ + messages: [{ role: 'user', content: [{ text: 'do it' }] }], + }); + session1.close(); + for await (const _ of session1.stream) { + } + const output1 = await session1.output; + assert.ok(output1.snapshotId); + + // Phase 2: Client forges restart with DIFFERENT input + const session2 = flow.streamBidi({ snapshotId: output1.snapshotId }); + session2.send({ + resume: { + restart: [ + { + toolRequest: { + name: 'sensitiveTool', + input: { target: '/etc/passwd' }, // FORGED! + ref: 'ref1', + }, + metadata: { resumed: { approved: true } }, + }, + ], + }, + }); + session2.close(); + + try { + for await (const _ of session2.stream) { + } + await session2.output; + assert.fail( + 'Expected INVALID_ARGUMENT error for forged restart inputs' + ); + } catch (e: any) { + assert.ok( + e.message.includes('modified inputs'), + `Expected modified inputs error, got: ${e.message}` + ); + assert.strictEqual(e.status, 'INVALID_ARGUMENT'); + } + }); + + it('should reject resume.respond referencing a non-existent tool', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const store = new InMemorySessionStore<{}>(); + + const pm = defineProgrammableModel( + registry, + undefined, + 'fakeRespondModel' + ); + + const myInterrupt = interrupt({ + name: 'realInterrupt', + description: 'A real interrupt', + inputSchema: z.object({ q: z.string() }), + outputSchema: z.object({ a: z.string() }), + }); + registry.registerAction('tool', myInterrupt); + + definePrompt(registry, { + name: 'fakeRespondPrompt', + model: 'fakeRespondModel', + tools: ['realInterrupt'], + config: { temperature: 1 }, + }); + + const flow = definePromptAgent(registry, { + promptName: 'fakeRespondPrompt', + store, + }); + + // Phase 1: Model requests the real interrupt tool + pm.handleResponse = async () => ({ + message: { + role: 'model', + content: [ + { + toolRequest: { + name: 'realInterrupt', + input: { q: 'confirm?' }, + ref: 'r1', + }, + }, + ], + }, + finishReason: 'stop', + }); + + const session1 = flow.streamBidi({}); + session1.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }); + session1.close(); + for await (const _ of session1.stream) { + } + const output1 = await session1.output; + assert.ok(output1.snapshotId); + + // Phase 2: Client responds with a FAKE tool name/ref + const session2 = flow.streamBidi({ snapshotId: output1.snapshotId }); + session2.send({ + resume: { + respond: [ + { + toolResponse: { + name: 'fakeToolThatDoesNotExist', + ref: 'fake-ref', + output: { a: 'hacked' }, + }, + }, + ], + }, + }); + session2.close(); + + try { + for await (const _ of session2.stream) { + } + await session2.output; + assert.fail( + 'Expected INVALID_ARGUMENT error for non-existent tool respond' + ); + } catch (e: any) { + assert.ok( + e.message.includes('not found in session history'), + `Expected not found error, got: ${e.message}` + ); + assert.strictEqual(e.status, 'INVALID_ARGUMENT'); + } + }); + + it('should reject resume.restart referencing a non-existent tool', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const store = new InMemorySessionStore<{}>(); + + const pm = defineProgrammableModel( + registry, + undefined, + 'fakeRestartModel' + ); + + definePrompt(registry, { + name: 'fakeRestartPrompt', + model: 'fakeRestartModel', + config: { temperature: 1 }, + }); + + const flow = definePromptAgent(registry, { + promptName: 'fakeRestartPrompt', + store, + }); + + // Phase 1: Model returns a simple text response (no tools at all) + pm.handleResponse = async () => ({ + message: { + role: 'model', + content: [{ text: 'hello' }], + }, + finishReason: 'stop', + }); + + const session1 = flow.streamBidi({}); + session1.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }); + session1.close(); + for await (const _ of session1.stream) { + } + const output1 = await session1.output; + assert.ok(output1.snapshotId); + + // Phase 2: Client fabricates a restart for a tool that was never requested + const session2 = flow.streamBidi({ snapshotId: output1.snapshotId }); + session2.send({ + resume: { + restart: [ + { + toolRequest: { + name: 'inventedTool', + input: { evil: true }, + ref: 'fake-ref', + }, + metadata: { resumed: true }, + }, + ], + }, + }); + session2.close(); + + try { + for await (const _ of session2.stream) { + } + await session2.output; + assert.fail('Expected INVALID_ARGUMENT error for fabricated restart'); + } catch (e: any) { + assert.ok( + e.message.includes('not found in session history'), + `Expected not found error, got: ${e.message}` + ); + assert.strictEqual(e.status, 'INVALID_ARGUMENT'); + } + }); + + it('should process all pre-queued messages in the background after detaching', async () => { + const store = new InMemorySessionStore<{ foo: string }>(); + let processedCount = 0; + + const flow = defineCustomAgent( + new Registry(), + { + name: 'sequentialBackgroundTest', + store, + }, + async (sess) => { + await sess.run(async () => { + processedCount++; + }); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'hi' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'task 1' }] }], + }); + session.send({ + messages: [{ role: 'user' as const, content: [{ text: 'task 2' }] }], + }); + session.send({ detach: true }); + + const output = await session.output; + assert.ok(output.snapshotId); + + // Detach-only messages are not forwarded to the runner — 2 turns, not 3. + const snapDone = await waitForSnapshotStatus( + store, + output.snapshotId!, + 'done' + ); + assert.strictEqual(snapDone.status, 'done'); + assert.strictEqual(processedCount, 2); + + session.close(); + }); + }); + + describe('clientStateTransform', () => { + it('should transform state in AgentOutput for client-managed agents', async () => { + const registry = new Registry(); + + const flow = defineCustomAgent< + unknown, + { publicField: string; secretField: string } + >( + registry, + { + name: 'clientTransformTest', + clientStateTransform: (state) => ({ + custom: { publicField: (state.custom as any)?.publicField }, + // Strip messages and artifacts + }), + }, + async (sess) => { + sess.session.updateCustom(() => ({ + publicField: 'visible', + secretField: 'top-secret', + })); + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'done' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }); + session.close(); + + for await (const _ of session.stream) { + } + const output = await session.output; + + assert.ok(output.state); + assert.strictEqual((output.state!.custom as any).publicField, 'visible'); + assert.strictEqual((output.state!.custom as any).secretField, undefined); + // Messages were stripped by the transform + assert.strictEqual(output.state!.messages, undefined); + }); + + it('should return full state when no clientStateTransform is provided', async () => { + const registry = new Registry(); + + const flow = defineCustomAgent< + unknown, + { publicField: string; secretField: string } + >(registry, { name: 'noTransformTest' }, async (sess) => { + sess.session.updateCustom(() => ({ + publicField: 'visible', + secretField: 'top-secret', + })); + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'done' }] }, + }; + }); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }); + session.close(); + + for await (const _ of session.stream) { + } + const output = await session.output; + + assert.ok(output.state); + assert.strictEqual((output.state!.custom as any).publicField, 'visible'); + assert.strictEqual( + (output.state!.custom as any).secretField, + 'top-secret' + ); + // Messages should be present + assert.ok(output.state!.messages); + assert.strictEqual(output.state!.messages!.length, 1); + }); + + it('should transform snapshot state in getSnapshotData for server-managed agents', async () => { + const store = new InMemorySessionStore<{ + publicField: string; + secretField: string; + }>(); + + const flow = defineCustomAgent< + unknown, + { publicField: string; secretField: string } + >( + new Registry(), + { + name: 'snapshotTransformTest', + store, + clientStateTransform: (state) => ({ + custom: { publicField: (state.custom as any)?.publicField }, + }), + }, + async (sess) => { + sess.session.updateCustom(() => ({ + publicField: 'visible', + secretField: 'top-secret', + })); + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'done' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }); + session.close(); + + for await (const _ of session.stream) { + } + const output = await session.output; + assert.ok(output.snapshotId); + + // getSnapshotData should return transformed state + const snapshot = await flow.getSnapshotData(output.snapshotId!); + assert.ok(snapshot); + assert.strictEqual( + (snapshot!.state.custom as any).publicField, + 'visible' + ); + assert.strictEqual( + (snapshot!.state.custom as any).secretField, + undefined + ); + // Messages were stripped + assert.strictEqual(snapshot!.state.messages, undefined); + + // But the raw store should still have the full state + const rawSnapshot = await store.getSnapshot(output.snapshotId!); + assert.ok(rawSnapshot); + assert.strictEqual(rawSnapshot!.state.custom?.secretField, 'top-secret'); + assert.ok(rawSnapshot!.state.messages); + }); + + it('should transform snapshot state in getSnapshotDataAction for server-managed agents', async () => { + const registry = new Registry(); + const store = new InMemorySessionStore<{ + publicField: string; + secretField: string; + }>(); + + const flow = defineCustomAgent< + unknown, + { publicField: string; secretField: string } + >( + registry, + { + name: 'snapshotActionTransformTest', + store, + clientStateTransform: (state) => ({ + custom: { publicField: (state.custom as any)?.publicField }, + }), + }, + async (sess) => { + sess.session.updateCustom(() => ({ + publicField: 'visible', + secretField: 'top-secret', + })); + await sess.run(async () => {}); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'done' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }); + session.close(); + + for await (const _ of session.stream) { + } + const output = await session.output; + assert.ok(output.snapshotId); + + // Invoke the companion action directly + const actionResult = await flow.getSnapshotDataAction(output.snapshotId!); + assert.ok(actionResult); + assert.strictEqual( + (actionResult as any).state.custom.publicField, + 'visible' + ); + assert.strictEqual( + (actionResult as any).state.custom.secretField, + undefined + ); + }); + + it('should transform state in detached output for client-managed agents', async () => { + const store = new InMemorySessionStore<{ + publicField: string; + secretField: string; + }>(); + let resolvePromise: () => void = () => {}; + const releasePromise = new Promise((resolve) => { + resolvePromise = resolve; + }); + + // Client-managed (no store in config), but we need a store for detach; + // use a server-managed config to test detach transform path + const flow = defineCustomAgent< + unknown, + { publicField: string; secretField: string } + >( + new Registry(), + { + name: 'detachTransformTest', + store, + clientStateTransform: (state) => ({ + custom: { publicField: (state.custom as any)?.publicField }, + }), + }, + async (sess) => { + sess.session.updateCustom(() => ({ + publicField: 'visible', + secretField: 'top-secret', + })); + await sess.run(async () => { + await releasePromise; + }); + return { + artifacts: [], + message: { role: 'model', content: [{ text: 'done' }] }, + }; + } + ); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + detach: true, + }); + + const output = await session.output; + assert.ok(output.snapshotId); + // Server-managed agents don't return state in output (state is undefined) + // but the snapshot should have the transformed state + const snapshot = await flow.getSnapshotData(output.snapshotId!); + assert.ok(snapshot); + assert.strictEqual( + (snapshot!.state.custom as any).publicField, + 'visible' + ); + assert.strictEqual( + (snapshot!.state.custom as any).secretField, + undefined + ); + + resolvePromise(); + session.close(); + }); + + it('should pass clientStateTransform through definePromptAgent', async () => { + const registry = new Registry(); + defineEchoModel(registry); + definePrompt(registry, { + name: 'transformPromptAgent', + model: 'echoModel', + config: { temperature: 1 }, + }); + + const flow = definePromptAgent<{ secret: string }>(registry, { + promptName: 'transformPromptAgent', + clientStateTransform: (state) => ({ + // strip custom state entirely, keep messages + messages: state.messages, + }), + }); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }); + session.close(); + + for await (const _ of session.stream) { + } + const output = await session.output; + + assert.ok(output.state); + // Custom state should be stripped + assert.strictEqual(output.state!.custom, undefined); + // Messages should be present + assert.ok(output.state!.messages); + assert.ok(output.state!.messages!.length > 0); + }); + + it('should pass clientStateTransform through defineAgent', async () => { + const registry = new Registry(); + defineEchoModel(registry); + + const flow = defineAgent<{ secret: string }>(registry, { + name: 'transformDefineAgent', + model: 'echoModel', + config: { temperature: 1 }, + clientStateTransform: (state) => ({ + // strip custom state entirely, keep messages + messages: state.messages, + }), + }); + + const session = flow.streamBidi({}); + session.send({ + messages: [{ role: 'user', content: [{ text: 'hi' }] }], + }); + session.close(); + + for await (const _ of session.stream) { + } + const output = await session.output; + + assert.ok(output.state); + // Custom state should be stripped + assert.strictEqual(output.state!.custom, undefined); + // Messages should be present + assert.ok(output.state!.messages); + assert.ok(output.state!.messages!.length > 0); + }); + }); + + // ========================================================================= + // Prompt rendering across turns + // ========================================================================= + + describe('prompt rendering across turns', () => { + /** Run a single invocation, collecting all model requests made during it. */ + async function runAgent( + agent: ReturnType, + pm: ProgrammableModel, + opts: { + init?: any; + inputs: any[]; + modelResponses: any[]; + } + ) { + const modelRequests: any[] = []; + let reqCounter = 0; + + pm.handleResponse = async (req) => { + modelRequests.push(JSON.parse(JSON.stringify(req))); + return opts.modelResponses[reqCounter++]!; + }; + + const session = agent.streamBidi(opts.init || {}); + for (const input of opts.inputs) { + session.send(input); + } + session.close(); + + const chunks: AgentStreamChunk[] = []; + for await (const chunk of session.stream) { + chunks.push(chunk); + } + + const output = await session.output; + return { output, chunks, modelRequests }; + } + + it('system-only: system appears in model request each turn, not in stored history', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const pm = defineProgrammableModel(registry); + + const agent = defineAgent(registry, { + name: 'systemOnlyAgent', + model: 'programmableModel', + system: 'You are a helpful assistant.', + }); + + const { output, modelRequests } = await runAgent(agent, pm, { + inputs: [ + { messages: [{ role: 'user', content: [{ text: 'turn1' }] }] }, + { messages: [{ role: 'user', content: [{ text: 'turn2' }] }] }, + ], + modelResponses: [ + { + message: { role: 'model', content: [{ text: 'reply1' }] }, + finishReason: 'stop', + }, + { + message: { role: 'model', content: [{ text: 'reply2' }] }, + finishReason: 'stop', + }, + ], + }); + + // --- Model request assertions --- + + // Turn 1: model sees [system("You are a helpful assistant."), user("turn1")] + const t1 = modelRequests[0].messages; + assert.strictEqual( + t1.length, + 2, + 'Turn 1: model should receive 2 messages' + ); + assert.strictEqual(t1[0].role, 'system'); + assert.strictEqual(t1[0].content[0].text, 'You are a helpful assistant.'); + assert.strictEqual(t1[1].role, 'user'); + assert.strictEqual(t1[1].content[0].text, 'turn1'); + + // Turn 2: model sees [system, user("turn1"), model("reply1"), user("turn2")] + const t2 = modelRequests[1].messages; + assert.strictEqual( + t2.length, + 4, + 'Turn 2: model should receive 4 messages' + ); + assert.strictEqual(t2[0].role, 'system'); + assert.strictEqual(t2[0].content[0].text, 'You are a helpful assistant.'); + assert.strictEqual(t2[1].role, 'user'); + assert.strictEqual(t2[1].content[0].text, 'turn1'); + assert.strictEqual(t2[2].role, 'model'); + assert.strictEqual(t2[2].content[0].text, 'reply1'); + assert.strictEqual(t2[3].role, 'user'); + assert.strictEqual(t2[3].content[0].text, 'turn2'); + + // No duplicate system messages + assert.strictEqual(t2.filter((m: any) => m.role === 'system').length, 1); + + // --- Stored messages assertions --- + const storedMessages = output.state?.messages || []; + assert.strictEqual( + storedMessages.filter((m: any) => m.role === 'system').length, + 0, + 'Stored history should not contain system messages' + ); + assert.strictEqual(storedMessages.length, 4); + }); + + it('system + user prompt: template user prompt appears each turn but does not accumulate', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const pm = defineProgrammableModel(registry); + + const agent = defineAgent(registry, { + name: 'systemAndPromptAgent', + model: 'programmableModel', + system: 'You are a helpful assistant.', + prompt: 'Always respond concisely.', + }); + + const { output, modelRequests } = await runAgent(agent, pm, { + inputs: [ + { messages: [{ role: 'user', content: [{ text: 'turn1' }] }] }, + { messages: [{ role: 'user', content: [{ text: 'turn2' }] }] }, + ], + modelResponses: [ + { + message: { role: 'model', content: [{ text: 'reply1' }] }, + finishReason: 'stop', + }, + { + message: { role: 'model', content: [{ text: 'reply2' }] }, + finishReason: 'stop', + }, + ], + }); + + // Turn 2: template user prompt should appear exactly once + const templateMsgs = modelRequests[1].messages.filter( + (m: any) => + m.role === 'user' && + m.content?.[0]?.text?.includes('Always respond concisely') + ); + assert.strictEqual(templateMsgs.length, 1); + + // Stored history should NOT contain system or template user prompt + const storedMessages = output.state?.messages || []; + assert.strictEqual( + storedMessages.filter((m: any) => m.role === 'system').length, + 0 + ); + assert.strictEqual( + storedMessages.filter( + (m: any) => + m.role === 'user' && + m.content?.[0]?.text?.includes('Always respond concisely') + ).length, + 0 + ); + assert.strictEqual(storedMessages.length, 4); + }); + + it('cross-invocation: system + prompt do not duplicate when state is carried over', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const pm = defineProgrammableModel(registry); + + const agent = defineAgent(registry, { + name: 'crossInvAgent', + model: 'programmableModel', + system: 'You are a helpful assistant.', + prompt: 'Always respond concisely.', + }); + + // Invocation 1 + const result1 = await runAgent(agent, pm, { + inputs: [ + { messages: [{ role: 'user', content: [{ text: 'first' }] }] }, + ], + modelResponses: [ + { + message: { role: 'model', content: [{ text: 'reply1' }] }, + finishReason: 'stop', + }, + ], + }); + + // Invocation 2: seed with state from invocation 1 + const result2 = await runAgent(agent, pm, { + init: { state: result1.output.state }, + inputs: [ + { messages: [{ role: 'user', content: [{ text: 'second' }] }] }, + ], + modelResponses: [ + { + message: { role: 'model', content: [{ text: 'reply2' }] }, + finishReason: 'stop', + }, + ], + }); + + const req2msgs = result2.modelRequests[0].messages; + assert.strictEqual( + req2msgs.filter((m: any) => m.role === 'system').length, + 1 + ); + assert.strictEqual( + req2msgs.filter( + (m: any) => + m.role === 'user' && + m.content?.[0]?.text?.includes('Always respond concisely') + ).length, + 1 + ); + + // Stored messages should be clean + const storedMessages = result2.output.state?.messages || []; + assert.strictEqual(storedMessages.length, 4); + assert.strictEqual( + storedMessages.filter((m: any) => m.role === 'system').length, + 0 + ); + }); + + it('message ordering: [system, ...history, user_prompt_from_template]', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const pm = defineProgrammableModel(registry); + + const agent = defineAgent(registry, { + name: 'orderingAgent', + model: 'programmableModel', + system: 'Be helpful.', + prompt: 'Be concise.', + }); + + const { modelRequests } = await runAgent(agent, pm, { + inputs: [ + { messages: [{ role: 'user', content: [{ text: 'q1' }] }] }, + { messages: [{ role: 'user', content: [{ text: 'q2' }] }] }, + ], + modelResponses: [ + { + message: { role: 'model', content: [{ text: 'a1' }] }, + finishReason: 'stop', + }, + { + message: { role: 'model', content: [{ text: 'a2' }] }, + finishReason: 'stop', + }, + ], + }); + + // Turn 2: render places history between system and user prompt + const req2msgs = modelRequests[1].messages; + const roles = req2msgs.map((m: any) => m.role); + // Expected: [system, user(q1), model(a1), user(q2), user(Be concise.)] + assert.deepStrictEqual(roles, [ + 'system', + 'user', + 'model', + 'user', + 'user', + ]); + // Preamble messages are tagged agentPreamble; history messages are + // clean (the internal _genkit_history tag is stripped before the model + // sees them). + assert.ok( + req2msgs[0].metadata?.agentPreamble, + 'system is preamble-tagged' + ); + assert.strictEqual( + req2msgs[1].metadata?.agentPreamble, + undefined, + 'q1 has no preamble tag' + ); + assert.strictEqual( + req2msgs[1].metadata?._genkit_history, + undefined, + 'q1 has no history tag (stripped)' + ); + assert.strictEqual( + req2msgs[2].metadata?._genkit_history, + undefined, + 'a1 has no history tag (stripped)' + ); + assert.strictEqual( + req2msgs[3].metadata?._genkit_history, + undefined, + 'q2 has no history tag (stripped)' + ); + assert.ok( + req2msgs[4].metadata?.agentPreamble, + 'Be concise is preamble-tagged' + ); + }); + + it('dotprompt {{history}}: history is inserted where the template specifies', async () => { + const registry = new Registry(); + registry.apiStability = 'beta'; + const pm = defineProgrammableModel(registry); + + // Define a prompt with a dotprompt messages template that uses {{history}} + definePrompt(registry, { + name: 'historyTemplatePrompt', + model: 'programmableModel', + system: 'You are a helpful assistant.', + messages: `{{role "user"}}Here is the conversation so far: +{{history}} +Now respond to the latest message.`, + }); + + const agent = definePromptAgent(registry, { + promptName: 'historyTemplatePrompt', + }); + + const { output, modelRequests } = await runAgent(agent, pm, { + inputs: [ + { messages: [{ role: 'user', content: [{ text: 'hello' }] }] }, + { messages: [{ role: 'user', content: [{ text: 'how are you' }] }] }, + ], + modelResponses: [ + { + message: { role: 'model', content: [{ text: 'hi there' }] }, + finishReason: 'stop', + }, + { + message: { role: 'model', content: [{ text: 'doing well' }] }, + finishReason: 'stop', + }, + ], + }); + + // --- Turn 1 model request assertions --- + // Model sees: [system, user(template-before), user(hello), model(template-after)] + const t1 = modelRequests[0].messages; + assert.strictEqual(t1.length, 4, 'Turn 1: 4 messages'); + + assert.strictEqual(t1[0].role, 'system'); + assert.strictEqual(t1[0].content[0].text, 'You are a helpful assistant.'); + assert.ok(t1[0].metadata?.agentPreamble, 'T1: system is preamble'); + + assert.strictEqual(t1[1].role, 'user'); + assert.ok( + t1[1].content[0].text.includes('Here is the conversation so far'), + 'T1: template text before {{history}}' + ); + assert.ok( + t1[1].metadata?.agentPreamble, + 'T1: template-before is preamble' + ); + + assert.strictEqual(t1[2].role, 'user'); + assert.strictEqual(t1[2].content[0].text, 'hello'); + assert.strictEqual( + t1[2].metadata?.agentPreamble, + undefined, + 'T1: hello is not preamble' + ); + assert.strictEqual( + t1[2].metadata?._genkit_history, + undefined, + 'T1: hello has no internal tag' + ); + + assert.strictEqual(t1[3].role, 'model'); + assert.ok( + t1[3].content[0].text.includes('Now respond to the latest message'), + 'T1: template text after {{history}}' + ); + assert.ok( + t1[3].metadata?.agentPreamble, + 'T1: template-after is preamble' + ); + + // --- Turn 2 model request assertions --- + // Model sees: [system, user(template-before), user(hello), model(hi there), + // user(how are you), model(template-after)] + const t2 = modelRequests[1].messages; + assert.strictEqual(t2.length, 6, 'Turn 2: 6 messages'); + + assert.strictEqual(t2[0].role, 'system'); + assert.ok(t2[0].metadata?.agentPreamble, 'T2: system is preamble'); + + assert.strictEqual(t2[1].role, 'user'); + assert.ok( + t2[1].metadata?.agentPreamble, + 'T2: template-before is preamble' + ); + + // History messages are embedded between template parts, clean of internal tags + assert.strictEqual(t2[2].role, 'user'); + assert.strictEqual(t2[2].content[0].text, 'hello'); + assert.strictEqual( + t2[2].metadata?.agentPreamble, + undefined, + 'T2: hello not preamble' + ); + assert.strictEqual( + t2[2].metadata?._genkit_history, + undefined, + 'T2: hello no internal tag' + ); + + assert.strictEqual(t2[3].role, 'model'); + assert.strictEqual(t2[3].content[0].text, 'hi there'); + assert.strictEqual( + t2[3].metadata?.agentPreamble, + undefined, + 'T2: hi there not preamble' + ); + assert.strictEqual( + t2[3].metadata?._genkit_history, + undefined, + 'T2: hi there no internal tag' + ); + + assert.strictEqual(t2[4].role, 'user'); + assert.strictEqual(t2[4].content[0].text, 'how are you'); + assert.strictEqual( + t2[4].metadata?.agentPreamble, + undefined, + 'T2: how are you not preamble' + ); + + assert.strictEqual(t2[5].role, 'model'); + assert.ok( + t2[5].content[0].text.includes('Now respond to the latest message'), + 'T2: template-after text' + ); + assert.ok( + t2[5].metadata?.agentPreamble, + 'T2: template-after is preamble' + ); + + // --- Stored messages should be clean (no system, no template wrapper) --- + const storedMessages = output.state?.messages || []; + assert.strictEqual( + storedMessages.filter((m: any) => m.role === 'system').length, + 0, + 'No system in stored history' + ); + // Should have the 4 conversation messages + assert.strictEqual(storedMessages.length, 4); + assert.strictEqual(storedMessages[0].content[0].text, 'hello'); + assert.strictEqual(storedMessages[1].content[0].text, 'hi there'); + assert.strictEqual(storedMessages[2].content[0].text, 'how are you'); + assert.strictEqual(storedMessages[3].content[0].text, 'doing well'); + }); + }); +}); diff --git a/js/ai/tests/prompt/prompt_test.ts b/js/ai/tests/prompt/prompt_test.ts index e2d856897f..d58ef2de5d 100644 --- a/js/ai/tests/prompt/prompt_test.ts +++ b/js/ai/tests/prompt/prompt_test.ts @@ -787,10 +787,7 @@ describe('prompt', () => { it(test.name, async () => { let session: Session | undefined; if (test.state) { - session = new Session(registry, { - id: '123', - sessionData: { id: '123', state: test.state }, - }); + session = new Session({ custom: test.state } as any); } const p = definePrompt(registry, test.prompt); diff --git a/js/core/src/async.ts b/js/core/src/async.ts index 7f5df2aea6..e64a2bcb5f 100644 --- a/js/core/src/async.ts +++ b/js/core/src/async.ts @@ -47,6 +47,7 @@ export class Channel implements AsyncIterable { private ready: Task = createTask(); private buffer: (T | null)[] = []; private err: unknown = null; + private done: boolean = false; send(value: T): void { this.buffer.push(value); @@ -69,6 +70,9 @@ export class Channel implements AsyncIterable { [Symbol.asyncIterator](): AsyncIterator { return { next: async (): Promise> => { + if (this.done) { + return { value: undefined as unknown as T, done: true }; + } if (this.err) { throw this.err; } @@ -81,6 +85,8 @@ export class Channel implements AsyncIterable { this.ready = createTask(); } + if (value === null) this.done = true; + return { value, done: value === null, @@ -174,3 +180,25 @@ export class AsyncTaskQueue { await this.last; } } + +/** A lightweight, cross-platform EventEmitter. */ +export class EventEmitter { + private listeners: Record void)[]> = {}; + + on(event: string, listener: (...args: any[]) => void) { + if (!this.listeners[event]) { + this.listeners[event] = []; + } + this.listeners[event].push(listener); + } + + off(event: string, listener: (...args: any[]) => void) { + if (!this.listeners[event]) return; + this.listeners[event] = this.listeners[event].filter((l) => l !== listener); + } + + emit(event: string, ...args: any[]) { + if (!this.listeners[event]) return; + this.listeners[event].forEach((l) => l(...args)); + } +} diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index a27060fb80..085b3937a5 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -56,6 +56,9 @@ const ACTION_TYPES = [ 'tool.v2', 'util', 'resource', + 'agent', + 'agent-snapshot', + 'agent-abort', ] as const; export type ActionType = (typeof ACTION_TYPES)[number]; diff --git a/js/core/src/utils.ts b/js/core/src/utils.ts index a15ab4c02d..16dfaa23da 100644 --- a/js/core/src/utils.ts +++ b/js/core/src/utils.ts @@ -73,3 +73,29 @@ export function isDevEnv(): boolean { export function featureMetadataPrefix(name: string) { return `feature:${name}`; } + +/** + * Deep-equality check for plain JSON-serializable values. + * Handles objects, arrays, and primitives. Does not handle functions, dates, + * or other non-JSON types. + */ +export function deepEqual(a: unknown, b: unknown): boolean { + if (a === b) return true; + if (a == null || b == null) return a === b; + if (typeof a !== typeof b) return false; + if (typeof a !== 'object') return false; + + if (Array.isArray(a)) { + if (!Array.isArray(b) || a.length !== b.length) return false; + return a.every((v, i) => deepEqual(v, b[i])); + } + + const aObj = a as Record; + const bObj = b as Record; + const aKeys = Object.keys(aObj).sort(); + const bKeys = Object.keys(bObj).sort(); + if (aKeys.length !== bKeys.length) return false; + return aKeys.every( + (key, i) => key === bKeys[i] && deepEqual(aObj[key], bObj[key]) + ); +} diff --git a/js/core/tests/utils_test.ts b/js/core/tests/utils_test.ts new file mode 100644 index 0000000000..8e2778c0ed --- /dev/null +++ b/js/core/tests/utils_test.ts @@ -0,0 +1,171 @@ +/** + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as assert from 'assert'; +import { describe, it } from 'node:test'; +import { deepEqual } from '../src/utils.js'; + +describe('deepEqual', () => { + // ── Primitives ────────────────────────────────────────────────────── + it('returns true for identical primitives', () => { + assert.strictEqual(deepEqual(1, 1), true); + assert.strictEqual(deepEqual('hello', 'hello'), true); + assert.strictEqual(deepEqual(true, true), true); + assert.strictEqual(deepEqual(null, null), true); + assert.strictEqual(deepEqual(undefined, undefined), true); + }); + + it('returns false for different primitives', () => { + assert.strictEqual(deepEqual(1, 2), false); + assert.strictEqual(deepEqual('a', 'b'), false); + assert.strictEqual(deepEqual(true, false), false); + assert.strictEqual(deepEqual(0, '0'), false); + assert.strictEqual(deepEqual(0, false), false); + }); + + // ── null / undefined ──────────────────────────────────────────────── + it('distinguishes null from undefined', () => { + assert.strictEqual(deepEqual(null, undefined), false); + assert.strictEqual(deepEqual(undefined, null), false); + }); + + it('returns false when comparing null to an object', () => { + assert.strictEqual(deepEqual(null, {}), false); + assert.strictEqual(deepEqual({}, null), false); + }); + + it('returns false when comparing undefined to an object', () => { + assert.strictEqual(deepEqual(undefined, {}), false); + assert.strictEqual(deepEqual({}, undefined), false); + }); + + // ── Arrays ────────────────────────────────────────────────────────── + it('returns true for equal arrays', () => { + assert.strictEqual(deepEqual([], []), true); + assert.strictEqual(deepEqual([1, 2, 3], [1, 2, 3]), true); + assert.strictEqual(deepEqual(['a', 'b'], ['a', 'b']), true); + }); + + it('returns false for arrays with different lengths', () => { + assert.strictEqual(deepEqual([1, 2], [1, 2, 3]), false); + assert.strictEqual(deepEqual([1, 2, 3], [1, 2]), false); + }); + + it('returns false for arrays with different elements', () => { + assert.strictEqual(deepEqual([1, 2, 3], [1, 2, 4]), false); + }); + + it('returns false for array vs non-array', () => { + assert.strictEqual(deepEqual([1, 2], { 0: 1, 1: 2 }), false); + }); + + // ── Objects ───────────────────────────────────────────────────────── + it('returns true for equal objects', () => { + assert.strictEqual(deepEqual({}, {}), true); + assert.strictEqual(deepEqual({ a: 1, b: 2 }, { a: 1, b: 2 }), true); + }); + + it('returns true regardless of key order', () => { + assert.strictEqual(deepEqual({ a: 1, b: 2 }, { b: 2, a: 1 }), true); + }); + + it('returns false for objects with different keys', () => { + assert.strictEqual(deepEqual({ a: 1 }, { b: 1 }), false); + assert.strictEqual(deepEqual({ a: 1 }, { a: 1, b: 2 }), false); + }); + + it('returns false for objects with different values', () => { + assert.strictEqual(deepEqual({ a: 1 }, { a: 2 }), false); + }); + + // ── Nested structures ────────────────────────────────────────────── + it('returns true for deeply nested equal structures', () => { + const a = { x: { y: { z: [1, { w: 'hello' }] } } }; + const b = { x: { y: { z: [1, { w: 'hello' }] } } }; + assert.strictEqual(deepEqual(a, b), true); + }); + + it('returns false for deeply nested structures with differences', () => { + const a = { x: { y: { z: [1, { w: 'hello' }] } } }; + const b = { x: { y: { z: [1, { w: 'world' }] } } }; + assert.strictEqual(deepEqual(a, b), false); + }); + + it('handles mixed arrays and objects', () => { + assert.strictEqual( + deepEqual([{ a: 1 }, { b: [2, 3] }], [{ a: 1 }, { b: [2, 3] }]), + true + ); + assert.strictEqual( + deepEqual([{ a: 1 }, { b: [2, 3] }], [{ a: 1 }, { b: [2, 4] }]), + false + ); + }); + + // ── Type mismatches ──────────────────────────────────────────────── + it('returns false for different primitive types', () => { + assert.strictEqual(deepEqual(1, '1'), false); + assert.strictEqual(deepEqual(true, 1), false); + }); + + // Note: deepEqual treats {} and [] as equal because both are objects + // with zero enumerable keys. This is acceptable for its intended use + // case (comparing JSON-serializable tool inputs, where an empty object + // and empty array are unlikely to appear in practice). + it('treats empty object and empty array as equal (known limitation)', () => { + assert.strictEqual(deepEqual({}, []), true); + }); + + // ── Same reference ───────────────────────────────────────────────── + it('returns true for the same reference', () => { + const obj = { a: 1, b: [2, 3] }; + assert.strictEqual(deepEqual(obj, obj), true); + }); + + // ── Edge case: empty nested structures ───────────────────────────── + it('handles empty nested structures', () => { + assert.strictEqual(deepEqual({ a: {} }, { a: {} }), true); + assert.strictEqual(deepEqual({ a: [] }, { a: [] }), true); + // { a: {} } vs { a: [] } — same known limitation as {} vs [] + assert.strictEqual(deepEqual({ a: {} }, { a: [] }), true); + }); + + it('distinguishes non-empty arrays from objects', () => { + assert.strictEqual(deepEqual({ a: [1] }, { a: { 0: 1 } }), false); + assert.strictEqual(deepEqual([1, 2], { 0: 1, 1: 2 }), false); + }); + + // ── Realistic tool input comparison ──────────────────────────────── + it('correctly compares realistic tool request inputs', () => { + const original = { + action: 'delete files', + target: '/tmp/test', + options: { recursive: true, force: false }, + }; + const legitimate = { + action: 'delete files', + target: '/tmp/test', + options: { recursive: true, force: false }, + }; + const forged = { + action: 'delete files', + target: '/etc/passwd', + options: { recursive: true, force: false }, + }; + assert.strictEqual(deepEqual(original, legitimate), true); + assert.strictEqual(deepEqual(original, forged), false); + }); +}); diff --git a/js/genkit/src/beta.ts b/js/genkit/src/beta.ts index c1cf7de963..23c2042955 100644 --- a/js/genkit/src/beta.ts +++ b/js/genkit/src/beta.ts @@ -14,6 +14,15 @@ * limitations under the License. */ +export { + AgentInitSchema, + AgentInputSchema, + AgentOutputSchema, + AgentStreamChunkSchema, + type AgentInit, + type AgentInput, + type AgentOutput, +} from '@genkit-ai/ai'; export { InMemoryStreamManager, StreamNotFoundError, @@ -23,4 +32,21 @@ export { } from '@genkit-ai/core'; export { AsyncTaskQueue, lazy } from '@genkit-ai/core/async'; export * from './common.js'; -export { GenkitBeta, genkit, type GenkitBetaOptions } from './genkit-beta.js'; +export { + FileSessionStore, + GenkitBeta, + InMemorySessionStore, + SessionRunner, + genkit, + type AgentFn, + type AgentStreamChunk, + type GenkitBetaOptions, + type SessionSnapshot, + type SessionSnapshotInput, + type SessionState, + type SessionStore, + type SessionStoreOptions, + type SnapshotCallback, + type SnapshotContext, + type SnapshotMutator, +} from './genkit-beta.js'; diff --git a/js/genkit/src/common.ts b/js/genkit/src/common.ts index 4c123efd73..219474e712 100644 --- a/js/genkit/src/common.ts +++ b/js/genkit/src/common.ts @@ -121,11 +121,7 @@ export { type ToolResponse, type ToolResponsePart, } from '@genkit-ai/ai'; -export { - Session, - type SessionData, - type SessionStore, -} from '@genkit-ai/ai/session'; +export { Session, type SessionStore } from '@genkit-ai/ai/session'; export { dynamicTool, tool } from '@genkit-ai/ai/tool'; export { GENKIT_CLIENT_HEADER, diff --git a/js/genkit/src/genkit-beta.ts b/js/genkit/src/genkit-beta.ts index 39eddf6498..e668834136 100644 --- a/js/genkit/src/genkit-beta.ts +++ b/js/genkit/src/genkit-beta.ts @@ -15,34 +15,63 @@ */ import { - defineInterrupt, - defineResource, - generateOperation, GenerateOptions, GenerateResponseData, GenerationCommonConfigSchema, - ResourceAction, - ResourceFn, - ResourceOptions, + SessionRunner, + defineAgent, + defineCustomAgent, + defineInterrupt, + definePromptAgent, + defineResource, + generateOperation, + type AgentConfig, + type AgentFn, + type AgentStreamChunk, type InterruptConfig, + type PromptConfig, + type ResourceAction, + type ResourceFn, + type ResourceOptions, type ToolAction, } from '@genkit-ai/ai'; import { defineFormat } from '@genkit-ai/ai/formats'; - import { - getCurrentSession, + FileSessionStore, + InMemorySessionStore, Session, SessionError, - type SessionData, - type SessionOptions, + getCurrentSession, + type SessionSnapshot, + type SessionSnapshotInput, + type SessionState, + type SessionStore, + type SessionStoreOptions, + type SnapshotCallback, + type SnapshotContext, + type SnapshotMutator, } from '@genkit-ai/ai/session'; + import { type Operation, type z } from '@genkit-ai/core'; -import { v4 as uuidv4 } from 'uuid'; import type { Formatter } from './formats'; import { Genkit, type GenkitOptions } from './genkit'; -export type { GenkitOptions as GenkitBetaOptions }; // in case they drift later +export { FileSessionStore, InMemorySessionStore, SessionRunner }; +export type { + AgentFn, + AgentStreamChunk, + GenkitOptions as GenkitBetaOptions, + PromptConfig, + SessionSnapshot, + SessionSnapshotInput, + SessionState, + SessionStore, + SessionStoreOptions, + SnapshotCallback, + SnapshotContext, + SnapshotMutator, +}; /** * WARNING: these APIs are considered unstable and subject to frequent breaking changes that may not honor semver. @@ -70,40 +99,58 @@ export class GenkitBeta extends Genkit { } /** - * Create a session for this environment. + * Defines and registers a custom agent with a custom handler function. + * + * @beta */ - createSession(options?: SessionOptions): Session { - const sessionId = options?.sessionId?.trim() || uuidv4(); - const sessionData: SessionData = { - id: sessionId, - state: options?.initialState, - }; - return new Session(this.registry, { - id: sessionId, - sessionData, - store: options?.store, - }); + defineCustomAgent( + config: { + name: string; + description?: string; + stateSchema?: z.ZodType; + store?: SessionStore; + snapshotCallback?: SnapshotCallback; + }, + fn: AgentFn + ) { + return defineCustomAgent(this.registry, config, fn); } /** - * Loads a session from the store. + * Defines and registers an agent from an existing Prompt template. * * @beta */ - async loadSession( - sessionId: string, - options: SessionOptions - ): Promise { - if (!options.store) { - throw new Error('options.store is required'); - } - const sessionData = await options.store.get(sessionId); + definePromptAgent(config: { + promptName: string; + stateSchema?: z.ZodType; + store?: SessionStore; + snapshotCallback?: SnapshotCallback; + }) { + return definePromptAgent(this.registry, config); + } - return new Session(this.registry, { - id: sessionId, - sessionData, - store: options.store, - }); + /** + * Defines and registers an agent by creating a prompt and wiring it into a + * multi-turn agent in one step. + * + * This is a convenience shortcut that combines `definePrompt` and + * `definePromptAgent` into a single call. + * + * ```ts + * const myAgent = ai.defineAgent({ + * name: 'myAgent', + * model: 'googleai/gemini-2.5-flash', + * system: 'Talk like a pirate.', + * tools: [weatherTool], + * store: new FileSessionStore('./.snapshots'), + * }); + * ``` + * + * @beta + */ + defineAgent(config: AgentConfig) { + return defineAgent(this.registry, config); } /** @@ -116,7 +163,7 @@ export class GenkitBeta extends Genkit { if (!currentSession) { throw new SessionError('not running within a session'); } - return currentSession as Session; + return currentSession as any as Session; } /** @@ -217,6 +264,7 @@ export class GenkitBeta extends Genkit { * await ai.generate({ * prompt: [{ resource: 'my://resource/value' }] * }) + * ``` */ defineResource(opts: ResourceOptions, fn: ResourceFn): ResourceAction { return defineResource(this.registry, opts, fn); diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 32324108e4..23e8cbfd45 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -1061,6 +1061,22 @@ importers: specifier: '>=12.2' version: 13.7.0(encoding@0.1.13) + testapps/agents: + dependencies: + '@genkit-ai/google-genai': + specifier: workspace:* + version: link:../../plugins/google-genai + express: + specifier: ^4.20.0 + version: 4.22.1 + genkit: + specifier: workspace:* + version: link:../../genkit + devDependencies: + typescript: + specifier: ^5.9.3 + version: 5.9.3 + testapps/anthropic: dependencies: '@anthropic-ai/sdk':