diff --git a/go/ai/prompt.go b/go/ai/prompt.go index e4ef99a643..15c2c33fcb 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -32,6 +32,7 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" "github.com/firebase/genkit/go/core/logger" + "github.com/firebase/genkit/go/core/x/session" "github.com/firebase/genkit/go/internal/base" "github.com/google/dotprompt/go/dotprompt" "github.com/invopop/jsonschema" @@ -588,14 +589,19 @@ func renderPrompt(ctx context.Context, opts promptOptions, templateText string, // renderDotpromptToMessages executes a dotprompt prompt function and converts the result to a slice of messages func renderDotpromptToMessages(ctx context.Context, promptFn dotprompt.PromptFunction, input map[string]any, additionalMetadata *dotprompt.PromptMetadata) ([]*Message, error) { // Prepare the context for rendering - context := map[string]any{} + templateContext := map[string]any{} actionCtx := core.FromContext(ctx) - maps.Copy(context, actionCtx) + maps.Copy(templateContext, actionCtx) + + // Inject session state if available (accessible via {{@state.field}} in templates) + if state := session.StateFromContext(ctx); state != nil { + templateContext["state"] = state + } // Call the prompt function with the input and context rendered, err := promptFn(&dotprompt.DataArgument{ Input: input, - Context: context, + Context: templateContext, }, additionalMetadata) if err != nil { return nil, fmt.Errorf("failed to render prompt: %w", err) diff --git a/go/ai/prompt_test.go b/go/ai/prompt_test.go index 7bc0ed3a5b..e570e40313 100644 --- a/go/ai/prompt_test.go +++ b/go/ai/prompt_test.go @@ -26,6 +26,7 @@ import ( "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/api" + "github.com/firebase/genkit/go/core/x/session" "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/internal/registry" "github.com/google/go-cmp/cmp" @@ -2168,6 +2169,143 @@ func TestPromptExecuteStream(t *testing.T) { }) } +// TestSessionStateInjection tests that session state is automatically injected +// into prompt templates and accessible via {{@state.field}} syntax. +func TestSessionStateInjection(t *testing.T) { + r := registry.New() + ConfigureFormats(r) + + // Define a test state type + type UserState struct { + Name string `json:"name"` + Preferences map[string]string `json:"preferences"` + } + + t.Run("session state accessible in prompt template", func(t *testing.T) { + var capturedPrompt string + + testModel := DefineModel(r, "test/sessionStateModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedPrompt = req.Messages[0].Text() + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("response"), + }, nil + }) + + // Create a prompt that uses {{@state.name}} syntax + p := DefinePrompt(r, "sessionStatePrompt", + WithModel(testModel), + WithPrompt("Hello {{@state.name}}, your theme is {{@state.preferences.theme}}"), + ) + + // Create a session with state + ctx := context.Background() + sess, err := session.New(ctx, session.WithInitialState(UserState{ + Name: "Alice", + Preferences: map[string]string{"theme": "dark"}, + })) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + + // Attach session to context + ctx = session.NewContext(ctx, sess) + + // Execute prompt with session in context + _, err = p.Execute(ctx) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + // Verify the session state was injected into the template + expected := "Hello Alice, your theme is dark" + if capturedPrompt != expected { + t.Errorf("Expected prompt %q, got %q", expected, capturedPrompt) + } + }) + + t.Run("prompt works without session in context", func(t *testing.T) { + var capturedPrompt string + + testModel := DefineModel(r, "test/noSessionModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedPrompt = req.Messages[0].Text() + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("response"), + }, nil + }) + + // Create a prompt that uses regular input variables (not session state) + p := DefinePrompt(r, "noSessionPrompt", + WithModel(testModel), + WithPrompt("Hello {{name}}"), + WithInputType(struct { + Name string `json:"name"` + }{}), + ) + + // Execute without session in context + ctx := context.Background() + _, err := p.Execute(ctx, WithInput(map[string]any{"name": "Bob"})) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + expected := "Hello Bob" + if capturedPrompt != expected { + t.Errorf("Expected prompt %q, got %q", expected, capturedPrompt) + } + }) + + t.Run("session state and input variables can be used together", func(t *testing.T) { + var capturedPrompt string + + testModel := DefineModel(r, "test/mixedModel", &ModelOptions{ + Supports: &ModelSupports{Multiturn: true}, + }, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) { + capturedPrompt = req.Messages[0].Text() + return &ModelResponse{ + Request: req, + Message: NewModelTextMessage("response"), + }, nil + }) + + // Create a prompt that uses both input and session state + p := DefinePrompt(r, "mixedPrompt", + WithModel(testModel), + WithPrompt("User {{@state.name}} asks: {{question}}"), + WithInputType(struct { + Question string `json:"question"` + }{}), + ) + + // Create session + ctx := context.Background() + sess, err := session.New(ctx, session.WithInitialState(UserState{ + Name: "Charlie", + })) + if err != nil { + t.Fatalf("Failed to create session: %v", err) + } + ctx = session.NewContext(ctx, sess) + + // Execute with both session and input + _, err = p.Execute(ctx, WithInput(map[string]any{"question": "What is the weather?"})) + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + expected := "User Charlie asks: What is the weather?" + if capturedPrompt != expected { + t.Errorf("Expected prompt %q, got %q", expected, capturedPrompt) + } + }) +} + // TestDefineExecuteOptionInteractions tests the complex interactions between // options set at DefinePrompt time vs Execute time. func TestDefineExecuteOptionInteractions(t *testing.T) { diff --git a/go/core/x/session/session.go b/go/core/x/session/session.go new file mode 100644 index 0000000000..8a9f0387f9 --- /dev/null +++ b/go/core/x/session/session.go @@ -0,0 +1,358 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Package session provides experimental session management APIs for Genkit. +// +// A session encapsulates a stateful execution environment with strongly-typed +// state that can be persisted across requests. Sessions are useful for maintaining +// user preferences, conversation context, or any application state that needs +// to survive between interactions. +// +// APIs in this package are under active development and may change in any +// minor version release. Use with caution in production environments. +// +// When these APIs stabilize, they will be moved to the core package +// and these exports will be deprecated. +package session + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync" + + "github.com/google/uuid" +) + +// Session represents a stateful environment with typed state. +// The type parameter S defines the shape of the session state and must be +// JSON-serializable for persistence. +type Session[S any] struct { + id string + state S + store Store[S] + mu sync.RWMutex +} + +// Data is the serializable session state persisted by Store. +type Data[S any] struct { + ID string `json:"id"` + State S `json:"state,omitempty"` +} + +// Store persists session data to a backend (database, file, memory, etc). +// Implementations must be safe for concurrent use. +type Store[S any] interface { + // Get retrieves session data by ID. Returns nil if not found. + Get(ctx context.Context, sessionID string) (*Data[S], error) + // Save persists session data, creating or updating as needed. + Save(ctx context.Context, sessionID string, data *Data[S]) error +} + +// options holds configuration for creating a Session. +type options[S any] struct { + ID string + InitialState S + Store Store[S] + hasID bool + hasState bool + hasStore bool +} + +// Option configures a Session during creation. +type Option[S any] interface { + apply(*options[S]) error +} + +// apply implements Option for options, enabling composition. +func (o *options[S]) apply(opts *options[S]) error { + if o.hasID { + if opts.hasID { + return errors.New("cannot set ID more than once (WithID)") + } + opts.ID = o.ID + opts.hasID = true + } + + if o.hasState { + if opts.hasState { + return errors.New("cannot set initial state more than once (WithInitialState)") + } + opts.InitialState = o.InitialState + opts.hasState = true + } + + if o.hasStore { + if opts.hasStore { + return errors.New("cannot set store more than once (WithStore)") + } + opts.Store = o.Store + opts.hasStore = true + } + + return nil +} + +// WithID sets a custom session ID. If not provided, a UUID is generated. +func WithID[S any](id string) Option[S] { + return &options[S]{ID: id, hasID: true} +} + +// WithInitialState sets the initial state for a new session. +func WithInitialState[S any](state S) Option[S] { + return &options[S]{InitialState: state, hasState: true} +} + +// WithStore sets the persistence backend for the session. +// If not provided, the session is not persisted and exists only in memory. +func WithStore[S any](store Store[S]) Option[S] { + return &options[S]{Store: store, hasStore: true} +} + +// New creates a new session with the provided options. +// If a store is provided via [WithStore], the session is persisted immediately. +// If no store is provided, the session exists only in memory for the current +// request and can be propagated via context using [NewContext]. +// If no ID is provided, a new UUID is generated. +// If no initial state is provided, the session is created with an empty state. +func New[S any](ctx context.Context, opts ...Option[S]) (*Session[S], error) { + o := &options[S]{} + for _, opt := range opts { + if err := opt.apply(o); err != nil { + return nil, fmt.Errorf("session.New: %w", err) + } + } + + id := o.ID + if !o.hasID { + id = uuid.New().String() + } + + // Only persist if a store was explicitly provided + if o.hasStore { + data := &Data[S]{ + ID: id, + State: o.InitialState, + } + if err := o.Store.Save(ctx, id, data); err != nil { + return nil, fmt.Errorf("session.New: failed to persist initial state: %w", err) + } + } + + return &Session[S]{ + id: id, + state: o.InitialState, + store: o.Store, // nil if no store provided + }, nil +} + +// Load loads an existing session from the store. +// Returns an error if the session is not found or if loading fails. +func Load[S any](ctx context.Context, store Store[S], sessionID string) (*Session[S], error) { + data, err := store.Get(ctx, sessionID) + if err != nil { + return nil, err + } + if data == nil { + return nil, &NotFoundError{SessionID: sessionID} + } + + return &Session[S]{ + id: data.ID, + state: data.State, + store: store, + }, nil +} + +// ID returns the session's unique identifier. +func (s *Session[S]) ID() string { + return s.id +} + +// State returns the current session state. +// The returned value is a copy; modifications do not affect the session. +func (s *Session[S]) State() S { + s.mu.RLock() + defer s.mu.RUnlock() + return deepCopyState(s.state) +} + +// deepCopyState creates a deep copy of the state using JSON marshaling. +// Panics if serialization fails, as this indicates a programming error +// (the state type S must be JSON-serializable per the Session contract). +func deepCopyState[S any](state S) S { + bytes, err := json.Marshal(state) + if err != nil { + panic(fmt.Sprintf("session.State: failed to marshal state: %v", err)) + } + + var copied S + if err := json.Unmarshal(bytes, &copied); err != nil { + panic(fmt.Sprintf("session.State: failed to unmarshal state: %v", err)) + } + + return copied +} + +// UpdateState updates the session state and persists it to the store (if configured). +func (s *Session[S]) UpdateState(ctx context.Context, state S) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.state = state + + if s.store != nil { + data := &Data[S]{ + ID: s.id, + State: state, + } + if err := s.store.Save(ctx, s.id, data); err != nil { + return err + } + } + + return nil +} + +// contextKey is a private type for context keys to avoid collisions. +type contextKey struct{} + +// sessionContextKey is the key used to store sessions in context. +var sessionContextKey = contextKey{} + +// sessionHolder wraps a session with its type erased for context storage. +type sessionHolder struct { + session any +} + +// NewContext returns a new context with the session attached. +func NewContext[S any](ctx context.Context, s *Session[S]) context.Context { + return context.WithValue(ctx, sessionContextKey, &sessionHolder{session: s}) +} + +// FromContext retrieves the current session from context. +// Returns nil if no session is in context or if the type doesn't match. +func FromContext[S any](ctx context.Context) *Session[S] { + holder, ok := ctx.Value(sessionContextKey).(*sessionHolder) + if !ok || holder == nil { + return nil + } + session, ok := holder.session.(*Session[S]) + if !ok { + return nil + } + return session +} + +// stateGetter is an internal interface for retrieving state without type parameters. +type stateGetter interface { + getState() any +} + +// getState implements stateGetter, returning the session state as any. +func (s *Session[S]) getState() any { + return s.State() +} + +// StateFromContext retrieves the current session state from context without +// requiring knowledge of the state type. This is useful for template rendering +// where the state type is not known at compile time. +// Returns nil if no session is in context. +func StateFromContext(ctx context.Context) any { + holder, ok := ctx.Value(sessionContextKey).(*sessionHolder) + if !ok || holder == nil { + return nil + } + if getter, ok := holder.session.(stateGetter); ok { + return getter.getState() + } + return nil +} + +// NotFoundError is returned when a session cannot be found in the store. +type NotFoundError struct { + SessionID string +} + +func (e *NotFoundError) Error() string { + return "session not found: " + e.SessionID +} + +// InMemoryStore is a thread-safe in-memory implementation of Store. +// Useful for testing or single-instance deployments where persistence is not required. +type InMemoryStore[S any] struct { + data map[string]*Data[S] + mu sync.RWMutex +} + +// NewInMemoryStore creates a new in-memory session store. +func NewInMemoryStore[S any]() *InMemoryStore[S] { + return &InMemoryStore[S]{ + data: make(map[string]*Data[S]), + } +} + +// Get retrieves session data by ID. +func (s *InMemoryStore[S]) Get(_ context.Context, sessionID string) (*Data[S], error) { + s.mu.RLock() + defer s.mu.RUnlock() + + data, exists := s.data[sessionID] + if !exists { + return nil, nil + } + + // Return a copy to prevent external modifications + copied, err := copyData(data) + if err != nil { + return nil, err + } + return copied, nil +} + +// Save persists session data. +func (s *InMemoryStore[S]) Save(_ context.Context, sessionID string, data *Data[S]) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Store a copy to prevent external modifications + copied, err := copyData(data) + if err != nil { + return err + } + s.data[sessionID] = copied + return nil +} + +// copyData creates a deep copy of Data using JSON marshaling. +func copyData[S any](data *Data[S]) (*Data[S], error) { + if data == nil { + return nil, nil + } + + bytes, err := json.Marshal(data) + if err != nil { + return nil, err + } + + var copied Data[S] + if err := json.Unmarshal(bytes, &copied); err != nil { + return nil, err + } + + return &copied, nil +} diff --git a/go/core/x/session/session_test.go b/go/core/x/session/session_test.go new file mode 100644 index 0000000000..b34100a44c --- /dev/null +++ b/go/core/x/session/session_test.go @@ -0,0 +1,782 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package session + +import ( + "context" + "errors" + "strings" + "sync" + "testing" +) + +// UserState is a test state type with various field types. +type UserState struct { + Name string `json:"name"` + Count int `json:"count"` + Preferences map[string]string `json:"preferences,omitempty"` +} + +func TestNew_DefaultID(t *testing.T) { + ctx := context.Background() + sess, err := New[UserState](ctx) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + if sess.ID() == "" { + t.Error("Expected session to have a generated ID") + } +} + +func TestNew_WithID(t *testing.T) { + ctx := context.Background() + customID := "my-custom-id" + sess, err := New(ctx, WithID[UserState](customID)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + if sess.ID() != customID { + t.Errorf("Expected ID %q, got %q", customID, sess.ID()) + } +} + +func TestNew_WithInitialState(t *testing.T) { + ctx := context.Background() + initial := UserState{Name: "Alice", Count: 42} + sess, err := New(ctx, WithInitialState(initial)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + got := sess.State() + if got.Name != initial.Name { + t.Errorf("Expected Name %q, got %q", initial.Name, got.Name) + } + if got.Count != initial.Count { + t.Errorf("Expected Count %d, got %d", initial.Count, got.Count) + } +} + +func TestNew_WithStore(t *testing.T) { + ctx := context.Background() + store := NewInMemoryStore[UserState]() + sess, err := New(ctx, WithStore(store)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + if sess.store != store { + t.Error("Expected store to be set") + } +} + +func TestNew_MultipleOptions(t *testing.T) { + ctx := context.Background() + store := NewInMemoryStore[UserState]() + customID := "multi-option-id" + initial := UserState{Name: "Bob", Count: 100} + + sess, err := New(ctx, + WithID[UserState](customID), + WithInitialState(initial), + WithStore(store), + ) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + if sess.ID() != customID { + t.Errorf("Expected ID %q, got %q", customID, sess.ID()) + } + if sess.State().Name != initial.Name { + t.Errorf("Expected Name %q, got %q", initial.Name, sess.State().Name) + } + if sess.store != store { + t.Error("Expected store to be set") + } +} + +func TestNew_DuplicateID(t *testing.T) { + ctx := context.Background() + _, err := New(ctx, + WithID[UserState]("first"), + WithID[UserState]("second"), + ) + if err == nil { + t.Fatal("Expected error for duplicate WithID") + } + if !strings.Contains(err.Error(), "cannot set ID more than once") { + t.Errorf("Expected duplicate ID error, got: %v", err) + } +} + +func TestNew_DuplicateInitialState(t *testing.T) { + ctx := context.Background() + _, err := New(ctx, + WithInitialState(UserState{Name: "First"}), + WithInitialState(UserState{Name: "Second"}), + ) + if err == nil { + t.Fatal("Expected error for duplicate WithInitialState") + } + if !strings.Contains(err.Error(), "cannot set initial state more than once") { + t.Errorf("Expected duplicate state error, got: %v", err) + } +} + +func TestNew_DuplicateStore(t *testing.T) { + ctx := context.Background() + store1 := NewInMemoryStore[UserState]() + store2 := NewInMemoryStore[UserState]() + _, err := New(ctx, + WithStore(store1), + WithStore(store2), + ) + if err == nil { + t.Fatal("Expected error for duplicate WithStore") + } + if !strings.Contains(err.Error(), "cannot set store more than once") { + t.Errorf("Expected duplicate store error, got: %v", err) + } +} + +func TestSession_State(t *testing.T) { + ctx := context.Background() + initial := UserState{ + Name: "Alice", + Count: 10, + Preferences: map[string]string{"theme": "dark"}, + } + sess, err := New(ctx, WithInitialState(initial)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + t.Run("returns correct values", func(t *testing.T) { + got := sess.State() + if got.Name != initial.Name { + t.Errorf("Expected Name %q, got %q", initial.Name, got.Name) + } + if got.Count != initial.Count { + t.Errorf("Expected Count %d, got %d", initial.Count, got.Count) + } + if got.Preferences["theme"] != "dark" { + t.Errorf("Expected theme %q, got %q", "dark", got.Preferences["theme"]) + } + }) + + t.Run("modifications to returned copy do not affect session", func(t *testing.T) { + // Get a copy of the state + copy1 := sess.State() + + // Modify the map in the returned copy + copy1.Preferences["theme"] = "light" + copy1.Preferences["newKey"] = "newValue" + copy1.Name = "Modified" + copy1.Count = 999 + + // Get another copy and verify the session's internal state is unchanged + copy2 := sess.State() + + if copy2.Name != "Alice" { + t.Errorf("Session state was mutated: expected Name %q, got %q", "Alice", copy2.Name) + } + if copy2.Count != 10 { + t.Errorf("Session state was mutated: expected Count %d, got %d", 10, copy2.Count) + } + if copy2.Preferences["theme"] != "dark" { + t.Errorf("Session state was mutated: expected theme %q, got %q", "dark", copy2.Preferences["theme"]) + } + if _, exists := copy2.Preferences["newKey"]; exists { + t.Errorf("Session state was mutated: unexpected key 'newKey' in Preferences") + } + }) +} + +func TestSession_UpdateState_NoStore(t *testing.T) { + ctx := context.Background() + sess, err := New(ctx, WithInitialState(UserState{Name: "Alice"})) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + // Verify no store is set when not provided + if sess.store != nil { + t.Fatal("Expected no store when not provided") + } + + newState := UserState{Name: "Bob", Count: 5} + if err := sess.UpdateState(ctx, newState); err != nil { + t.Fatalf("UpdateState failed: %v", err) + } + + // State should still be updated in memory + got := sess.State() + if got.Name != newState.Name { + t.Errorf("Expected Name %q, got %q", newState.Name, got.Name) + } + if got.Count != newState.Count { + t.Errorf("Expected Count %d, got %d", newState.Count, got.Count) + } +} + +func TestSession_UpdateState_WithStore(t *testing.T) { + ctx := context.Background() + store := NewInMemoryStore[UserState]() + sess, err := New(ctx, + WithID[UserState]("test-session"), + WithInitialState(UserState{Name: "Alice"}), + WithStore(store), + ) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + newState := UserState{Name: "Bob", Count: 5} + if err := sess.UpdateState(ctx, newState); err != nil { + t.Fatalf("UpdateState failed: %v", err) + } + + // Verify state is updated in session + got := sess.State() + if got.Name != newState.Name { + t.Errorf("Expected Name %q, got %q", newState.Name, got.Name) + } + + // Verify state is persisted in store + data, err := store.Get(ctx, "test-session") + if err != nil { + t.Fatalf("Store.Get failed: %v", err) + } + if data == nil { + t.Fatal("Expected data in store, got nil") + } + if data.State.Name != newState.Name { + t.Errorf("Store: expected Name %q, got %q", newState.Name, data.State.Name) + } +} + +func TestLoad_Success(t *testing.T) { + store := NewInMemoryStore[UserState]() + ctx := context.Background() + + // Save some data + data := &Data[UserState]{ + ID: "existing-session", + State: UserState{Name: "Charlie", Count: 99}, + } + if err := store.Save(ctx, data.ID, data); err != nil { + t.Fatalf("Store.Save failed: %v", err) + } + + // Load the session + loaded, err := Load(ctx, store, "existing-session") + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if loaded.ID() != "existing-session" { + t.Errorf("Expected ID %q, got %q", "existing-session", loaded.ID()) + } + if loaded.State().Name != "Charlie" { + t.Errorf("Expected Name %q, got %q", "Charlie", loaded.State().Name) + } + if loaded.State().Count != 99 { + t.Errorf("Expected Count %d, got %d", 99, loaded.State().Count) + } +} + +func TestLoad_NotFound(t *testing.T) { + store := NewInMemoryStore[UserState]() + ctx := context.Background() + + _, err := Load(ctx, store, "non-existent") + if err == nil { + t.Fatal("Expected error for non-existent session") + } + + var notFoundErr *NotFoundError + if !errors.As(err, ¬FoundErr) { + t.Errorf("Expected NotFoundError, got %T: %v", err, err) + } + if notFoundErr.SessionID != "non-existent" { + t.Errorf("Expected SessionID %q, got %q", "non-existent", notFoundErr.SessionID) + } +} + +func TestNewContext_FromContext(t *testing.T) { + ctx := context.Background() + sess, err := New(ctx, + WithID[UserState]("ctx-test"), + WithInitialState(UserState{Name: "Diana"}), + ) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + // Attach session to context + ctx = NewContext(ctx, sess) + + // Retrieve from context + retrieved := FromContext[UserState](ctx) + if retrieved == nil { + t.Fatal("Expected session from context, got nil") + } + if retrieved.ID() != "ctx-test" { + t.Errorf("Expected ID %q, got %q", "ctx-test", retrieved.ID()) + } + if retrieved.State().Name != "Diana" { + t.Errorf("Expected Name %q, got %q", "Diana", retrieved.State().Name) + } +} + +func TestStateFromContext(t *testing.T) { + t.Run("returns state when session exists", func(t *testing.T) { + ctx := context.Background() + initial := UserState{ + Name: "Alice", + Count: 42, + Preferences: map[string]string{"theme": "dark"}, + } + sess, err := New(ctx, WithInitialState(initial)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + ctx = NewContext(ctx, sess) + + state := StateFromContext(ctx) + if state == nil { + t.Fatal("Expected state from context, got nil") + } + + // StateFromContext returns the state as any, so we need to type assert + userState, ok := state.(UserState) + if !ok { + t.Fatalf("Expected UserState, got %T", state) + } + + if userState.Name != "Alice" { + t.Errorf("Expected Name %q, got %q", "Alice", userState.Name) + } + if userState.Count != 42 { + t.Errorf("Expected Count %d, got %d", 42, userState.Count) + } + if userState.Preferences["theme"] != "dark" { + t.Errorf("Expected theme %q, got %q", "dark", userState.Preferences["theme"]) + } + }) + + t.Run("returns nil when no session in context", func(t *testing.T) { + ctx := context.Background() + state := StateFromContext(ctx) + if state != nil { + t.Errorf("Expected nil for empty context, got %v", state) + } + }) + + t.Run("returns deep copy that cannot mutate session", func(t *testing.T) { + ctx := context.Background() + initial := UserState{ + Name: "Bob", + Preferences: map[string]string{"lang": "en"}, + } + sess, err := New(ctx, WithInitialState(initial)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + ctx = NewContext(ctx, sess) + + // Get state via StateFromContext + state := StateFromContext(ctx) + userState := state.(UserState) + + // Modify the returned state + userState.Name = "Modified" + userState.Preferences["lang"] = "fr" + + // Verify the session's internal state is unchanged + originalState := sess.State() + if originalState.Name != "Bob" { + t.Errorf("Session state was mutated: expected Name %q, got %q", "Bob", originalState.Name) + } + if originalState.Preferences["lang"] != "en" { + t.Errorf("Session state was mutated: expected lang %q, got %q", "en", originalState.Preferences["lang"]) + } + }) +} + +func TestFromContext_NoSession(t *testing.T) { + ctx := context.Background() + + retrieved := FromContext[UserState](ctx) + if retrieved != nil { + t.Errorf("Expected nil for empty context, got %v", retrieved) + } +} + +func TestFromContext_WrongType(t *testing.T) { + ctx := context.Background() + // Create session with one type + type OtherState struct { + Value string + } + sess, err := New(ctx, WithInitialState(OtherState{Value: "test"})) + if err != nil { + t.Fatalf("New failed: %v", err) + } + ctx = NewContext(ctx, sess) + + // Try to retrieve with different type + retrieved := FromContext[UserState](ctx) + if retrieved != nil { + t.Errorf("Expected nil for wrong type, got %v", retrieved) + } +} + +func TestInMemoryStore_GetSave(t *testing.T) { + store := NewInMemoryStore[UserState]() + ctx := context.Background() + + // Initially empty + data, err := store.Get(ctx, "test-id") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if data != nil { + t.Errorf("Expected nil for non-existent key, got %v", data) + } + + // Save data + original := &Data[UserState]{ + ID: "test-id", + State: UserState{Name: "Eve", Count: 7}, + } + if err := store.Save(ctx, "test-id", original); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Retrieve data + retrieved, err := store.Get(ctx, "test-id") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved == nil { + t.Fatal("Expected data, got nil") + } + if retrieved.ID != original.ID { + t.Errorf("Expected ID %q, got %q", original.ID, retrieved.ID) + } + if retrieved.State.Name != original.State.Name { + t.Errorf("Expected Name %q, got %q", original.State.Name, retrieved.State.Name) + } +} + +func TestInMemoryStore_Isolation(t *testing.T) { + store := NewInMemoryStore[UserState]() + ctx := context.Background() + + // Save data + original := &Data[UserState]{ + ID: "isolation-test", + State: UserState{Name: "Frank", Count: 1}, + } + if err := store.Save(ctx, "isolation-test", original); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Modify original after save + original.State.Name = "Modified" + + // Retrieved data should not be affected + retrieved, err := store.Get(ctx, "isolation-test") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved.State.Name != "Frank" { + t.Errorf("Expected Name %q (isolation), got %q", "Frank", retrieved.State.Name) + } + + // Modify retrieved data + retrieved.State.Name = "Also Modified" + + // Get again - should still be original + retrieved2, err := store.Get(ctx, "isolation-test") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved2.State.Name != "Frank" { + t.Errorf("Expected Name %q (isolation), got %q", "Frank", retrieved2.State.Name) + } +} + +func TestInMemoryStore_Overwrite(t *testing.T) { + store := NewInMemoryStore[UserState]() + ctx := context.Background() + + // Save initial data + initial := &Data[UserState]{ + ID: "overwrite-test", + State: UserState{Name: "Grace", Count: 1}, + } + if err := store.Save(ctx, "overwrite-test", initial); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Overwrite with new data + updated := &Data[UserState]{ + ID: "overwrite-test", + State: UserState{Name: "Grace Updated", Count: 2}, + } + if err := store.Save(ctx, "overwrite-test", updated); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Retrieve and verify + retrieved, err := store.Get(ctx, "overwrite-test") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved.State.Name != "Grace Updated" { + t.Errorf("Expected Name %q, got %q", "Grace Updated", retrieved.State.Name) + } + if retrieved.State.Count != 2 { + t.Errorf("Expected Count %d, got %d", 2, retrieved.State.Count) + } +} + +func TestSession_ConcurrentAccess(t *testing.T) { + ctx := context.Background() + store := NewInMemoryStore[UserState]() + sess, err := New(ctx, + WithID[UserState]("concurrent-test"), + WithInitialState(UserState{Name: "Initial", Count: 0}), + WithStore(store), + ) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + const numGoroutines = 10 + const numUpdates = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numUpdates; j++ { + // Read state + _ = sess.State() + + // Update state + _ = sess.UpdateState(ctx, UserState{ + Name: "Goroutine", + Count: id*numUpdates + j, + }) + } + }(i) + } + + wg.Wait() + + // Verify no data corruption + state := sess.State() + if state.Name != "Goroutine" { + t.Errorf("Expected Name %q, got %q", "Goroutine", state.Name) + } +} + +func TestInMemoryStore_ConcurrentAccess(t *testing.T) { + store := NewInMemoryStore[UserState]() + ctx := context.Background() + + const numGoroutines = 10 + const numOperations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + key := "shared-key" + for j := 0; j < numOperations; j++ { + // Save + data := &Data[UserState]{ + ID: key, + State: UserState{Name: "Concurrent", Count: id*numOperations + j}, + } + _ = store.Save(ctx, key, data) + + // Get + _, _ = store.Get(ctx, key) + } + }(i) + } + + wg.Wait() + + // Verify we can still read + data, err := store.Get(ctx, "shared-key") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if data == nil { + t.Fatal("Expected data, got nil") + } +} + +func TestNotFoundError(t *testing.T) { + err := &NotFoundError{SessionID: "test-123"} + + expected := "session not found: test-123" + if err.Error() != expected { + t.Errorf("Expected error message %q, got %q", expected, err.Error()) + } +} + +func TestSession_ZeroState(t *testing.T) { + ctx := context.Background() + // Create session without initial state + sess, err := New[UserState](ctx) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + state := sess.State() + if state.Name != "" { + t.Errorf("Expected empty Name, got %q", state.Name) + } + if state.Count != 0 { + t.Errorf("Expected zero Count, got %d", state.Count) + } + if state.Preferences != nil { + t.Errorf("Expected nil Preferences, got %v", state.Preferences) + } +} + +func TestSession_ComplexState(t *testing.T) { + ctx := context.Background() + type NestedState struct { + Inner struct { + Value string `json:"value"` + } `json:"inner"` + List []int `json:"list"` + } + + store := NewInMemoryStore[NestedState]() + initial := NestedState{ + List: []int{1, 2, 3}, + } + initial.Inner.Value = "nested" + + sess, err := New(ctx, + WithID[NestedState]("complex-state"), + WithInitialState(initial), + WithStore(store), + ) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + // Update with nested modifications + newState := NestedState{ + List: []int{4, 5, 6, 7}, + } + newState.Inner.Value = "updated nested" + + if err := sess.UpdateState(ctx, newState); err != nil { + t.Fatalf("UpdateState failed: %v", err) + } + + // Verify nested state is correct + got := sess.State() + if got.Inner.Value != "updated nested" { + t.Errorf("Expected Inner.Value %q, got %q", "updated nested", got.Inner.Value) + } + if len(got.List) != 4 { + t.Errorf("Expected List length %d, got %d", 4, len(got.List)) + } + + // Verify persistence + data, err := store.Get(ctx, "complex-state") + if err != nil { + t.Fatalf("Store.Get failed: %v", err) + } + if data.State.Inner.Value != "updated nested" { + t.Errorf("Store: expected Inner.Value %q, got %q", "updated nested", data.State.Inner.Value) + } +} + +// mockFailingStore is a store that fails on Save for testing error handling. +type mockFailingStore[S any] struct { + saveErr error +} + +func (s *mockFailingStore[S]) Get(_ context.Context, _ string) (*Data[S], error) { + return nil, nil +} + +func (s *mockFailingStore[S]) Save(_ context.Context, _ string, _ *Data[S]) error { + return s.saveErr +} +func TestNew_StoreError(t *testing.T) { + ctx := context.Background() + expectedErr := errors.New("store failure") + store := &mockFailingStore[UserState]{saveErr: expectedErr} + _, err := New(ctx, + WithID[UserState]("error-test"), + WithStore(store), + ) + if err == nil { + t.Fatal("Expected error from failing store") + } + if !strings.Contains(err.Error(), "failed to persist initial state") { + t.Errorf("Expected persist error, got: %v", err) + } + if !errors.Is(err, expectedErr) { + t.Errorf("Expected wrapped error %v, got %v", expectedErr, err) + } +} + +func TestSession_UpdateState_StoreError(t *testing.T) { + ctx := context.Background() + store := NewInMemoryStore[UserState]() + sess, err := New(ctx, + WithID[UserState]("error-test"), + WithStore(store), + ) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + expectedErr := errors.New("store failure") + sess.store = &mockFailingStore[UserState]{saveErr: expectedErr} + + err = sess.UpdateState(ctx, UserState{Name: "Test"}) + if err == nil { + t.Fatal("Expected error from failing store") + } + if err != expectedErr { + t.Errorf("Expected error %v, got %v", expectedErr, err) + } +} diff --git a/go/plugins/firebase/x/option.go b/go/plugins/firebase/x/option.go new file mode 100644 index 0000000000..9de3443d60 --- /dev/null +++ b/go/plugins/firebase/x/option.go @@ -0,0 +1,76 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package x + +import ( + "errors" + "time" +) + +const ( + // DefaultTTL is the default time-to-live for Firestore documents. + DefaultTTL = 5 * time.Minute +) + +// firestoreOptions holds common configuration for Firestore-based services. +type firestoreOptions struct { + Collection string + TTL time.Duration +} + +// applyFirestore applies common Firestore options. +func (o *firestoreOptions) applyFirestore(opts *firestoreOptions) error { + if o.Collection != "" { + if opts.Collection != "" { + return errors.New("cannot set collection more than once (WithCollection)") + } + opts.Collection = o.Collection + } + + if o.TTL > 0 { + if opts.TTL > 0 { + return errors.New("cannot set TTL more than once (WithTTL)") + } + opts.TTL = o.TTL + } + + return nil +} + +// applyStreamManager implements StreamManagerOption for firestoreOptions. +func (o *firestoreOptions) applyStreamManager(opts *streamManagerOptions) error { + return o.applyFirestore(&opts.firestoreOptions) +} + +// applySessionStore implements SessionStoreOption for firestoreOptions. +func (o *firestoreOptions) applySessionStore(opts *sessionStoreOptions) error { + return o.applyFirestore(&opts.firestoreOptions) +} + +// WithCollection sets the Firestore collection name where documents are stored. +// This option is required for all Firestore-based services. +func WithCollection(collection string) *firestoreOptions { + return &firestoreOptions{Collection: collection} +} + +// WithTTL sets how long documents are retained before Firestore auto-deletes them. +// Requires a TTL policy on the collection for the "expiresAt" field. +// Default is 5 minutes. +// See: https://firebase.google.com/docs/firestore/ttl +func WithTTL(ttl time.Duration) *firestoreOptions { + return &firestoreOptions{TTL: ttl} +} diff --git a/go/plugins/firebase/x/session_store.go b/go/plugins/firebase/x/session_store.go new file mode 100644 index 0000000000..57ce322f4b --- /dev/null +++ b/go/plugins/firebase/x/session_store.go @@ -0,0 +1,184 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package x + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "cloud.google.com/go/firestore" + "github.com/firebase/genkit/go/core/x/session" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/firebase" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// SessionStoreOption configures a FirestoreSessionStore. +// Implemented by firestoreOptions (WithCollection, WithTTL). +type SessionStoreOption interface { + applySessionStore(*sessionStoreOptions) error +} + +// sessionStoreOptions holds configuration for FirestoreSessionStore. +type sessionStoreOptions struct { + firestoreOptions +} + +// applySessionStore implements SessionStoreOption for sessionStoreOptions. +func (o *sessionStoreOptions) applySessionStore(opts *sessionStoreOptions) error { + return o.firestoreOptions.applyFirestore(&opts.firestoreOptions) +} + +// FirestoreSessionStore implements [session.Store[S]] using Firestore as the backend. +// Session state is persisted in Firestore documents, allowing sessions to survive +// server restarts and be accessible across multiple instances. +type FirestoreSessionStore[S any] struct { + client *firestore.Client + collection string + ttl time.Duration +} + +// sessionDocument represents the structure of a session document in Firestore. +type sessionDocument struct { + State json.RawMessage `firestore:"state"` + CreatedAt time.Time `firestore:"createdAt"` + UpdatedAt time.Time `firestore:"updatedAt"` + ExpiresAt *time.Time `firestore:"expiresAt,omitempty"` +} + +// NewFirestoreSessionStore creates a Firestore-backed session store. +// Requires the Firebase plugin to be initialized in the Genkit instance. +func NewFirestoreSessionStore[S any](ctx context.Context, g *genkit.Genkit, opts ...SessionStoreOption) (*FirestoreSessionStore[S], error) { + storeOpts := &sessionStoreOptions{} + for _, opt := range opts { + if err := opt.applySessionStore(storeOpts); err != nil { + return nil, fmt.Errorf("firebase.NewFirestoreSessionStore: error applying options: %w", err) + } + } + if storeOpts.Collection == "" { + return nil, errors.New("firebase.NewFirestoreSessionStore: Collection name is required.\n" + + " Specify the Firestore collection where session documents will be stored:\n" + + " firebase.NewFirestoreSessionStore[MyState](ctx, g, firebase.WithCollection(\"genkit-sessions\"))") + } + if storeOpts.TTL == 0 { + storeOpts.TTL = DefaultTTL + } + + plugin := genkit.LookupPlugin(g, "firebase") + if plugin == nil { + return nil, errors.New("firebase.NewFirestoreSessionStore: Firebase plugin not found.\n" + + " Pass the Firebase plugin to genkit.Init():\n" + + " g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: \"your-project\"}))") + } + f, ok := plugin.(*firebase.Firebase) + if !ok { + return nil, fmt.Errorf("firebase.NewFirestoreSessionStore: unexpected plugin type %T", plugin) + } + + client, err := f.Firestore(ctx) + if err != nil { + return nil, fmt.Errorf("firebase.NewFirestoreSessionStore: %w", err) + } + + return &FirestoreSessionStore[S]{ + client: client, + collection: storeOpts.Collection, + ttl: storeOpts.TTL, + }, nil +} + +// Get retrieves session data by ID from Firestore. +// Returns nil if the session does not exist. +func (s *FirestoreSessionStore[S]) Get(ctx context.Context, sessionID string) (*session.Data[S], error) { + docRef := s.client.Collection(s.collection).Doc(sessionID) + + snapshot, err := docRef.Get(ctx) + if err != nil { + if status.Code(err) == codes.NotFound { + return nil, nil + } + return nil, fmt.Errorf("firebase.FirestoreSessionStore.Get: %w", err) + } + if !snapshot.Exists() { + return nil, nil + } + + var doc sessionDocument + if err := snapshot.DataTo(&doc); err != nil { + return nil, fmt.Errorf("firebase.FirestoreSessionStore.Get: failed to parse document: %w", err) + } + + var state S + if len(doc.State) > 0 { + if err := json.Unmarshal(doc.State, &state); err != nil { + return nil, fmt.Errorf("firebase.FirestoreSessionStore.Get: failed to unmarshal state: %w", err) + } + } + + return &session.Data[S]{ + ID: sessionID, + State: state, + }, nil +} + +// Save persists session data to Firestore, creating or updating as needed. +// CreatedAt is only set when the document is first created; subsequent saves +// only update UpdatedAt and ExpiresAt. +func (s *FirestoreSessionStore[S]) Save(ctx context.Context, sessionID string, data *session.Data[S]) error { + docRef := s.client.Collection(s.collection).Doc(sessionID) + + stateJSON, err := json.Marshal(data.State) + if err != nil { + return fmt.Errorf("firebase.FirestoreSessionStore.Save: failed to marshal state: %w", err) + } + + now := time.Now() + expiresAt := now.Add(s.ttl) + + err = s.client.RunTransaction(ctx, func(ctx context.Context, tx *firestore.Transaction) error { + snapshot, err := tx.Get(docRef) + if err != nil && status.Code(err) != codes.NotFound { + return err + } + + if !snapshot.Exists() { + // Document doesn't exist - create it with CreatedAt + return tx.Create(docRef, sessionDocument{ + State: stateJSON, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &expiresAt, + }) + } + + // Document exists - update without modifying CreatedAt + return tx.Update(docRef, []firestore.Update{ + {Path: "state", Value: stateJSON}, + {Path: "updatedAt", Value: now}, + {Path: "expiresAt", Value: &expiresAt}, + }) + }) + if err != nil { + return fmt.Errorf("firebase.FirestoreSessionStore.Save: %w", err) + } + + return nil +} diff --git a/go/plugins/firebase/x/session_store_test.go b/go/plugins/firebase/x/session_store_test.go new file mode 100644 index 0000000000..7b7c998e45 --- /dev/null +++ b/go/plugins/firebase/x/session_store_test.go @@ -0,0 +1,439 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package x + +import ( + "context" + "flag" + "testing" + "time" + + "cloud.google.com/go/firestore" + "github.com/firebase/genkit/go/core/x/session" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/firebase" + "google.golang.org/api/iterator" +) + +var ( + testSessionProjectID = flag.String("test-session-project-id", "", "GCP Project ID to use for session store tests") + testSessionCollection = flag.String("test-session-collection", "genkit-sessions", "Firestore collection to use for session store tests") +) + +/* + * Pre-requisites to run this test: + * + * 1. **Option A - Use Firestore Emulator (Recommended for local development):** + * Start the Firestore emulator: + * ```bash + * export FIRESTORE_EMULATOR_HOST=127.0.0.1:8080 + * gcloud emulators firestore start --host-port=127.0.0.1:8080 + * ``` + * + * 2. **Option B - Use a Real Firestore Database:** + * - Set up a Firebase project with Firestore enabled + * - Authenticate using: + * ```bash + * gcloud auth application-default login + * ``` + * + * 3. **Running the Test:** + * ```bash + * go test -test-session-project-id= -test-session-collection=genkit-sessions + * ``` + */ + +// TestState is a test state type with various field types. +type TestState struct { + Name string `json:"name"` + Count int `json:"count"` + Preferences map[string]string `json:"preferences,omitempty"` +} + +func skipIfNoFirestoreSession(t *testing.T) { + if *testSessionProjectID == "" { + t.Skip("Skipping test: -test-session-project-id flag not provided") + } +} + +func setupTestSessionStore(t *testing.T) (*FirestoreSessionStore[TestState], *firestore.Client, func()) { + skipIfNoFirestoreSession(t) + + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testSessionProjectID})) + + f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) + client, err := f.Firestore(ctx) + if err != nil { + t.Fatalf("Failed to get Firestore client: %v", err) + } + + store, err := NewFirestoreSessionStore[TestState](ctx, g, + WithCollection(*testSessionCollection), + ) + if err != nil { + t.Fatalf("Failed to create session store: %v", err) + } + + cleanup := func() { + deleteSessionCollection(ctx, client, *testSessionCollection, t) + } + + return store, client, cleanup +} + +func deleteSessionCollection(ctx context.Context, client *firestore.Client, collectionName string, t *testing.T) { + iter := client.Collection(collectionName).Documents(ctx) + for { + doc, err := iter.Next() + if err == iterator.Done { + break + } + if err != nil { + t.Logf("Failed to iterate documents for deletion: %v", err) + return + } + _, err = doc.Ref.Delete(ctx) + if err != nil { + t.Logf("Failed to delete document %s: %v", doc.Ref.ID, err) + } + } +} + +func TestNewFirestoreSessionStore_MissingCollection(t *testing.T) { + skipIfNoFirestoreSession(t) + + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testSessionProjectID})) + + _, err := NewFirestoreSessionStore[TestState](ctx, g) + if err == nil { + t.Fatal("Expected error when collection is missing") + } +} + +func TestFirestoreSessionStore_SaveAndGet(t *testing.T) { + store, _, cleanup := setupTestSessionStore(t) + defer cleanup() + + ctx := context.Background() + sessionID := "test-session-save-get" + + // Initially empty + data, err := store.Get(ctx, sessionID) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if data != nil { + t.Errorf("Expected nil for non-existent session, got %v", data) + } + + // Save data + original := &session.Data[TestState]{ + ID: sessionID, + State: TestState{ + Name: "Alice", + Count: 42, + Preferences: map[string]string{"theme": "dark"}, + }, + } + if err := store.Save(ctx, sessionID, original); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Retrieve data + retrieved, err := store.Get(ctx, sessionID) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved == nil { + t.Fatal("Expected data, got nil") + } + if retrieved.ID != sessionID { + t.Errorf("Expected ID %q, got %q", sessionID, retrieved.ID) + } + if retrieved.State.Name != original.State.Name { + t.Errorf("Expected Name %q, got %q", original.State.Name, retrieved.State.Name) + } + if retrieved.State.Count != original.State.Count { + t.Errorf("Expected Count %d, got %d", original.State.Count, retrieved.State.Count) + } + if retrieved.State.Preferences["theme"] != "dark" { + t.Errorf("Expected theme %q, got %q", "dark", retrieved.State.Preferences["theme"]) + } +} + +func TestFirestoreSessionStore_Overwrite(t *testing.T) { + store, client, cleanup := setupTestSessionStore(t) + defer cleanup() + + ctx := context.Background() + sessionID := "test-session-overwrite" + + // Save initial data + initial := &session.Data[TestState]{ + ID: sessionID, + State: TestState{Name: "Alice", Count: 1}, + } + if err := store.Save(ctx, sessionID, initial); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Get the initial document to capture CreatedAt and UpdatedAt + snapshot1, err := client.Collection(*testSessionCollection).Doc(sessionID).Get(ctx) + if err != nil { + t.Fatalf("Failed to get initial document: %v", err) + } + initialData := snapshot1.Data() + initialCreatedAt, ok := initialData["createdAt"].(time.Time) + if !ok { + t.Fatal("Expected createdAt to be a timestamp") + } + initialUpdatedAt, ok := initialData["updatedAt"].(time.Time) + if !ok { + t.Fatal("Expected updatedAt to be a timestamp") + } + + // Wait a moment to ensure timestamp difference is detectable + time.Sleep(10 * time.Millisecond) + + // Overwrite with new data + updated := &session.Data[TestState]{ + ID: sessionID, + State: TestState{Name: "Alice Updated", Count: 2}, + } + if err := store.Save(ctx, sessionID, updated); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Get the updated document to verify timestamps + snapshot2, err := client.Collection(*testSessionCollection).Doc(sessionID).Get(ctx) + if err != nil { + t.Fatalf("Failed to get updated document: %v", err) + } + updatedData := snapshot2.Data() + updatedCreatedAt, ok := updatedData["createdAt"].(time.Time) + if !ok { + t.Fatal("Expected createdAt to be a timestamp after update") + } + updatedUpdatedAt, ok := updatedData["updatedAt"].(time.Time) + if !ok { + t.Fatal("Expected updatedAt to be a timestamp after update") + } + + // Verify CreatedAt is preserved (not modified during overwrite) + if !updatedCreatedAt.Equal(initialCreatedAt) { + t.Errorf("CreatedAt was modified during overwrite: initial=%v, after=%v", initialCreatedAt, updatedCreatedAt) + } + + // Verify UpdatedAt is modified (should be later than initial) + if !updatedUpdatedAt.After(initialUpdatedAt) { + t.Errorf("UpdatedAt should be later after overwrite: initial=%v, after=%v", initialUpdatedAt, updatedUpdatedAt) + } + + // Retrieve and verify state data + retrieved, err := store.Get(ctx, sessionID) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved.State.Name != "Alice Updated" { + t.Errorf("Expected Name %q, got %q", "Alice Updated", retrieved.State.Name) + } + if retrieved.State.Count != 2 { + t.Errorf("Expected Count %d, got %d", 2, retrieved.State.Count) + } +} + +func TestFirestoreSessionStore_ExpiresAt(t *testing.T) { + store, client, cleanup := setupTestSessionStore(t) + defer cleanup() + + ctx := context.Background() + sessionID := "test-session-expires" + + data := &session.Data[TestState]{ + ID: sessionID, + State: TestState{Name: "ExpiresTest"}, + } + if err := store.Save(ctx, sessionID, data); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Verify expiresAt is set in Firestore + snapshot, err := client.Collection(*testSessionCollection).Doc(sessionID).Get(ctx) + if err != nil { + t.Fatalf("Failed to get document: %v", err) + } + + docData := snapshot.Data() + if docData["expiresAt"] == nil { + t.Error("Expected expiresAt to be set") + } +} + +func TestFirestoreSessionStore_WithTTL(t *testing.T) { + skipIfNoFirestoreSession(t) + + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testSessionProjectID})) + + f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) + client, err := f.Firestore(ctx) + if err != nil { + t.Fatalf("Failed to get Firestore client: %v", err) + } + defer deleteSessionCollection(ctx, client, *testSessionCollection, t) + + customTTL := 1 * time.Hour + store, err := NewFirestoreSessionStore[TestState](ctx, g, + WithCollection(*testSessionCollection), + WithTTL(customTTL), + ) + if err != nil { + t.Fatalf("Failed to create session store: %v", err) + } + + if store.ttl != customTTL { + t.Errorf("Expected TTL %v, got %v", customTTL, store.ttl) + } +} + +func TestFirestoreSessionStore_EmptyState(t *testing.T) { + store, _, cleanup := setupTestSessionStore(t) + defer cleanup() + + ctx := context.Background() + sessionID := "test-session-empty" + + // Save session with zero-value state + data := &session.Data[TestState]{ + ID: sessionID, + State: TestState{}, + } + if err := store.Save(ctx, sessionID, data); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Retrieve and verify + retrieved, err := store.Get(ctx, sessionID) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved == nil { + t.Fatal("Expected data, got nil") + } + if retrieved.State.Name != "" { + t.Errorf("Expected empty Name, got %q", retrieved.State.Name) + } + if retrieved.State.Count != 0 { + t.Errorf("Expected zero Count, got %d", retrieved.State.Count) + } +} + +func TestFirestoreSessionStore_ComplexState(t *testing.T) { + skipIfNoFirestoreSession(t) + + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&firebase.Firebase{ProjectId: *testSessionProjectID})) + + f := genkit.LookupPlugin(g, "firebase").(*firebase.Firebase) + client, err := f.Firestore(ctx) + if err != nil { + t.Fatalf("Failed to get Firestore client: %v", err) + } + defer deleteSessionCollection(ctx, client, *testSessionCollection, t) + + type NestedState struct { + Inner struct { + Value string `json:"value"` + } `json:"inner"` + List []int `json:"list"` + } + + store, err := NewFirestoreSessionStore[NestedState](ctx, g, + WithCollection(*testSessionCollection), + ) + if err != nil { + t.Fatalf("Failed to create session store: %v", err) + } + + sessionID := "test-session-complex" + + // Save complex state + state := NestedState{ + List: []int{1, 2, 3, 4, 5}, + } + state.Inner.Value = "nested value" + + data := &session.Data[NestedState]{ + ID: sessionID, + State: state, + } + if err := store.Save(ctx, sessionID, data); err != nil { + t.Fatalf("Save failed: %v", err) + } + + // Retrieve and verify + retrieved, err := store.Get(ctx, sessionID) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if retrieved == nil { + t.Fatal("Expected data, got nil") + } + if retrieved.State.Inner.Value != "nested value" { + t.Errorf("Expected Inner.Value %q, got %q", "nested value", retrieved.State.Inner.Value) + } + if len(retrieved.State.List) != 5 { + t.Errorf("Expected List length %d, got %d", 5, len(retrieved.State.List)) + } +} + +func TestFirestoreSessionStore_IntegrationWithSession(t *testing.T) { + store, _, cleanup := setupTestSessionStore(t) + defer cleanup() + + ctx := context.Background() + + // Create a session with the Firestore store + sess, err := session.New(ctx, + session.WithID[TestState]("integration-test"), + session.WithInitialState(TestState{Name: "Integration", Count: 0}), + session.WithStore(store), + ) + if err != nil { + t.Fatalf("New failed: %v", err) + } + + // Update state (should persist to Firestore) + if err := sess.UpdateState(ctx, TestState{Name: "Updated", Count: 10}); err != nil { + t.Fatalf("UpdateState failed: %v", err) + } + + // Load session from store + loaded, err := session.Load(ctx, store, "integration-test") + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if loaded.State().Name != "Updated" { + t.Errorf("Expected Name %q, got %q", "Updated", loaded.State().Name) + } + if loaded.State().Count != 10 { + t.Errorf("Expected Count %d, got %d", 10, loaded.State().Count) + } +} diff --git a/go/plugins/firebase/x/stream_manager.go b/go/plugins/firebase/x/stream_manager.go index edc0154637..7adb06ad54 100644 --- a/go/plugins/firebase/x/stream_manager.go +++ b/go/plugins/firebase/x/stream_manager.go @@ -41,30 +41,27 @@ import ( const ( streamBufferSize = 100 defaultTimeout = 60 * time.Second - defaultTTL = 5 * time.Minute streamEventChunk = "chunk" streamEventDone = "done" streamEventError = "error" ) -// FirestoreStreamManagerOption configures a FirestoreStreamManager. -type FirestoreStreamManagerOption interface { - applyFirestoreStreamManager(*firestoreStreamManagerOptions) error +// StreamManagerOption configures a FirestoreStreamManager. +// Implemented by firestoreOptions (WithCollection, WithTTL) and streamManagerOptions (WithTimeout). +type StreamManagerOption interface { + applyStreamManager(*streamManagerOptions) error } -// firestoreStreamManagerOptions holds configuration for FirestoreStreamManager. -type firestoreStreamManagerOptions struct { - Collection string - Timeout time.Duration - TTL time.Duration +// streamManagerOptions holds configuration for FirestoreStreamManager. +type streamManagerOptions struct { + firestoreOptions + Timeout time.Duration } -func (o *firestoreStreamManagerOptions) applyFirestoreStreamManager(opts *firestoreStreamManagerOptions) error { - if o.Collection != "" { - if opts.Collection != "" { - return errors.New("cannot set collection more than once (WithCollection)") - } - opts.Collection = o.Collection +// applyStreamManager implements StreamManagerOption for streamManagerOptions. +func (o *streamManagerOptions) applyStreamManager(opts *streamManagerOptions) error { + if err := o.firestoreOptions.applyFirestore(&opts.firestoreOptions); err != nil { + return err } if o.Timeout > 0 { @@ -74,34 +71,14 @@ func (o *firestoreStreamManagerOptions) applyFirestoreStreamManager(opts *firest opts.Timeout = o.Timeout } - if o.TTL > 0 { - if opts.TTL > 0 { - return errors.New("cannot set TTL more than once (WithFirestoreTTL)") - } - opts.TTL = o.TTL - } - return nil } -// WithCollection sets the Firestore collection name where stream documents are stored. -// This option is required. -func WithCollection(collection string) FirestoreStreamManagerOption { - return &firestoreStreamManagerOptions{Collection: collection} -} - // WithTimeout sets how long a subscriber waits for new events before giving up. // If no activity occurs within this duration, subscribers receive a DEADLINE_EXCEEDED error. // Default is 60 seconds. -func WithTimeout(timeout time.Duration) FirestoreStreamManagerOption { - return &firestoreStreamManagerOptions{Timeout: timeout} -} - -// WithTTL sets how long completed streams are retained before Firestore auto-deletes them. -// Requires a TTL policy on the collection for the "expiresAt" field. Default is 5 minutes. -// See: https://firebase.google.com/docs/firestore/ttl -func WithTTL(ttl time.Duration) FirestoreStreamManagerOption { - return &firestoreStreamManagerOptions{TTL: ttl} +func WithTimeout(timeout time.Duration) StreamManagerOption { + return &streamManagerOptions{Timeout: timeout} } // FirestoreStreamManager implements [streaming.StreamManager] using Firestore as the backend. @@ -137,11 +114,12 @@ type streamError struct { Message string `firestore:"message"` } -// NewFirestoreStreamManager creates a FirestoreStreamManager for durable streaming. -func NewFirestoreStreamManager(ctx context.Context, g *genkit.Genkit, opts ...FirestoreStreamManagerOption) (*FirestoreStreamManager, error) { - streamOpts := &firestoreStreamManagerOptions{} +// NewFirestoreStreamManager creates a [FirestoreStreamManager] for durable streaming. +// Requires the Firebase plugin to be initialized in the Genkit instance. +func NewFirestoreStreamManager(ctx context.Context, g *genkit.Genkit, opts ...StreamManagerOption) (*FirestoreStreamManager, error) { + streamOpts := &streamManagerOptions{} for _, opt := range opts { - if err := opt.applyFirestoreStreamManager(streamOpts); err != nil { + if err := opt.applyStreamManager(streamOpts); err != nil { return nil, fmt.Errorf("firebase.NewFirestoreStreamManager: error applying options: %w", err) } } @@ -154,7 +132,7 @@ func NewFirestoreStreamManager(ctx context.Context, g *genkit.Genkit, opts ...Fi streamOpts.Timeout = defaultTimeout } if streamOpts.TTL == 0 { - streamOpts.TTL = defaultTTL + streamOpts.TTL = DefaultTTL } plugin := genkit.LookupPlugin(g, "firebase") diff --git a/go/samples/basic-prompts/main.go b/go/samples/basic-prompts/main.go index ccbc308a13..c4d2d009de 100644 --- a/go/samples/basic-prompts/main.go +++ b/go/samples/basic-prompts/main.go @@ -12,6 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. +// This sample demonstrates prompts using both inline code definitions and +// .prompt files (Dotprompt). It shows simple prompts, structured output with +// typed schemas, and complex prompts with Handlebars conditionals. +// +// To run: +// +// go run . +// +// In another terminal, test a simple joke flow: +// +// curl -N -X POST http://localhost:8080/simpleJokePromptFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": "bananas"}' +// +// Test a structured joke flow (returns JSON): +// +// curl -N -X POST http://localhost:8080/structuredJokePromptFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"topic": "bananas"}}' +// +// Test a recipe flow: +// +// curl -N -X POST http://localhost:8080/recipePromptFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"dish": "tacos", "cuisine": "Mexican", "servingSize": 4}}' package main import ( diff --git a/go/samples/basic-structured/main.go b/go/samples/basic-structured/main.go index 428636de4d..57e1de4079 100644 --- a/go/samples/basic-structured/main.go +++ b/go/samples/basic-structured/main.go @@ -12,6 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. +// This sample demonstrates structured input/output with strongly-typed Go +// structs. It shows GenerateStream for simple output and GenerateDataStream +// for typed JSON output with streaming partial results. +// +// To run: +// +// go run . +// +// In another terminal, test a simple joke flow: +// +// curl -N -X POST http://localhost:8080/simpleJokesFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": "bananas"}' +// +// Test a structured joke flow (returns JSON): +// +// curl -N -X POST http://localhost:8080/structuredJokesFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"topic": "bananas"}}' +// +// Test a recipe flow: +// +// curl -N -X POST http://localhost:8080/recipeFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": {"dish": "tacos", "cuisine": "Mexican", "servingSize": 4}}' package main import ( diff --git a/go/samples/basic/main.go b/go/samples/basic/main.go index 2031340ac5..c2fb382399 100644 --- a/go/samples/basic/main.go +++ b/go/samples/basic/main.go @@ -12,6 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +// This sample demonstrates basic Genkit flows: a non-streaming flow and a +// streaming flow that generate jokes about a given topic. +// +// To run: +// +// go run . +// +// In another terminal, test the non-streaming flow: +// +// curl -X POST http://localhost:8080/jokesFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": "bananas"}' +// +// Test the streaming flow: +// +// curl -N -X POST http://localhost:8080/streamingJokesFlow \ +// -H "Content-Type: application/json" \ +// -d '{"data": "bananas"}' package main import ( diff --git a/go/samples/durable-streaming-firestore/README.md b/go/samples/durable-streaming-firestore/README.md deleted file mode 100644 index 8ca1648e8b..0000000000 --- a/go/samples/durable-streaming-firestore/README.md +++ /dev/null @@ -1,140 +0,0 @@ -# Durable Streaming with Firestore - -This sample demonstrates durable streaming using Firestore as the backend. Unlike in-memory streaming, Firestore-backed streams: - -- **Survive server restarts** - Clients can reconnect to streams after server restarts -- **Work across instances** - Multiple server instances can serve the same stream -- **Auto-cleanup** - Completed streams are automatically deleted via Firestore TTL policies - -## Prerequisites - -1. **Firebase Project**: You need a Firebase/GCP project with Firestore enabled. - -2. **Authentication**: Authenticate with your Google Cloud project: - ```bash - gcloud auth application-default login - ``` - -3. **(Recommended) TTL Policy**: Configure a TTL policy on your Firestore collection for automatic cleanup of old streams. This requires setting a TTL on the `expiresAt` field: - - ```bash - gcloud firestore fields ttls update expiresAt \ - --collection-group=genkit-streams \ - --enable-ttl \ - --project=YOUR_PROJECT_ID - ``` - - See: https://firebase.google.com/docs/firestore/ttl - -## Environment Variables - -| Variable | Required | Default | Description | -|----------|----------|---------|-------------| -| `FIREBASE_PROJECT_ID` | Yes | - | Your Firebase/GCP project ID | -| `FIRESTORE_STREAMS_COLLECTION` | No | `genkit-streams` | Firestore collection for stream documents | - -## Running the Sample - -1. Set your project ID: - ```bash - export FIREBASE_PROJECT_ID=your-project-id - ``` - -2. Start the server: - ```bash - go run . - ``` - -## Testing - -### Start a streaming request - -```bash -curl -N -i -H "Accept: text/event-stream" \ - -d '{"data": 5}' \ - http://localhost:8080/countdown -``` - -Note the `X-Genkit-Stream-Id` header in the response - you'll need this to reconnect. - -### Reconnect to an existing stream - -Use the stream ID from the previous response: - -```bash -curl -N -H "Accept: text/event-stream" \ - -H "X-Genkit-Stream-Id: " \ - -d '{"data": 5}' \ - http://localhost:8080/countdown -``` - -The subscription will: -- Replay any buffered chunks that were already sent -- Continue with live updates if the stream is still in progress -- Return all chunks plus the final result if the stream has already completed - -### Test server restart resilience - -1. Start a countdown with a high number: - ```bash - curl -N -i -H "Accept: text/event-stream" -d '{"data": 30}' http://localhost:8080/countdown - ``` - -2. Copy the `X-Genkit-Stream-Id` header value - -3. Stop the server (Ctrl+C) - -4. Restart the server: `go run .` - -5. Reconnect using the stream ID: - ```bash - curl -N -H "Accept: text/event-stream" -H "X-Genkit-Stream-Id: " -d '{"data": 30}' http://localhost:8080/countdown - ``` - -You'll receive all previously buffered chunks, demonstrating that the stream state persisted across the server restart. - -## Configuration Options - -The `FirestoreStreamManager` supports these options: - -| Option | Default | Description | -|--------|---------|-------------| -| `WithCollection(name)` | (required) | Firestore collection for stream documents | -| `WithTimeout(duration)` | 60s | How long subscribers wait for new events before timeout | -| `WithTTL(duration)` | 5m | How long completed streams are retained before auto-deletion | - -Example: -```go -streamManager, err := firebasex.NewFirestoreStreamManager(ctx, g, - firebasex.WithCollection("my-streams"), - firebasex.WithTimeout(2*time.Minute), - firebasex.WithTTL(1*time.Hour), -) -``` - -## How It Works - -1. When a streaming request arrives, a Firestore document is created with the stream ID -2. As the flow produces chunks, they're appended to the document's `stream` array -3. Subscribers use Firestore's real-time listeners to receive updates -4. When the flow completes, a final "done" entry is added with the output -5. The `expiresAt` field is set based on TTL, and Firestore automatically deletes the document - -## License - -``` -Copyright 2025 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -``` - diff --git a/go/samples/durable-streaming-firestore/main.go b/go/samples/durable-streaming-firestore/main.go index 988ccda790..80fb0b47a2 100644 --- a/go/samples/durable-streaming-firestore/main.go +++ b/go/samples/durable-streaming-firestore/main.go @@ -18,7 +18,27 @@ // Unlike in-memory streaming, Firestore-backed streams survive server restarts // and can be accessed across multiple server instances. // -// See README.md for setup instructions. +// Prerequisites: +// - Firebase/GCP project with Firestore enabled +// - Run: gcloud auth application-default login +// - Set: export FIREBASE_PROJECT_ID=your-project-id +// +// To run: +// +// go run . +// +// In another terminal, start a streaming request: +// +// curl -N -i -H "Accept: text/event-stream" \ +// -d '{"data": 5}' \ +// http://localhost:8088/countdown +// +// Note the X-Genkit-Stream-Id header. To reconnect to the same stream: +// +// curl -N -H "Accept: text/event-stream" \ +// -H "X-Genkit-Stream-Id: " \ +// -d '{"data": 5}' \ +// http://localhost:8088/countdown package main import ( diff --git a/go/samples/durable-streaming/main.go b/go/samples/durable-streaming/main.go index 36323990a3..2c22f4b3b6 100644 --- a/go/samples/durable-streaming/main.go +++ b/go/samples/durable-streaming/main.go @@ -36,7 +36,6 @@ // // The subscription will replay any buffered chunks and then continue with live updates. // If the stream has already completed, all chunks plus the final result are returned. - package main import ( diff --git a/go/samples/session/main.go b/go/samples/session/main.go new file mode 100644 index 0000000000..c1f6d7b0b9 --- /dev/null +++ b/go/samples/session/main.go @@ -0,0 +1,129 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This sample demonstrates how to use sessions to maintain state across +// multiple requests. It implements a shopping cart where items persist +// between calls using the session API. +// +// To run: +// +// go run . +// +// In another terminal, test (items persist across requests): +// +// curl -X POST http://localhost:8080/manageCart \ +// -H "Content-Type: application/json" \ +// -d '{"data": "Add apples and bananas to my cart"}' +// +// curl -X POST http://localhost:8080/manageCart \ +// -H "Content-Type: application/json" \ +// -d '{"data": "What is in my cart?"}' +package main + +import ( + "context" + "fmt" + "log" + "net/http" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core/x/session" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/googlegenai" + "github.com/firebase/genkit/go/plugins/server" + "google.golang.org/genai" +) + +// CartState holds the shopping cart items. +type CartState struct { + Items []string `json:"items"` +} + +func main() { + ctx := context.Background() + g := genkit.Init(ctx, genkit.WithPlugins(&googlegenai.GoogleAI{})) + + // Create in-memory store (shared across requests). + store := session.NewInMemoryStore[CartState]() + + // Fixed session ID for simplicity. + const sessionID = "shopping-session" + + // Define addToCart tool - adds an item to the cart stored in session state. + addToCartTool := genkit.DefineTool(g, "addToCart", + "Adds items to the shopping cart", + func(ctx *ai.ToolContext, input struct{ Items []string }) ([]string, error) { + sess := session.FromContext[CartState](ctx.Context) + if sess == nil { + return nil, fmt.Errorf("no session in context") + } + state := sess.State() + state.Items = append(state.Items, input.Items...) + if err := sess.UpdateState(ctx.Context, state); err != nil { + return nil, err + } + return state.Items, nil + }, + ) + + // Define getCart tool - returns all items currently in the cart. + getCartTool := genkit.DefineTool(g, "getCart", + "Returns all items currently in the shopping cart", + func(ctx *ai.ToolContext, input struct{}) ([]string, error) { + sess := session.FromContext[CartState](ctx.Context) + if sess == nil { + return nil, fmt.Errorf("no session in context") + } + return sess.State().Items, nil + }, + ) + + // Define flow that uses session to maintain cart state across requests. + genkit.DefineFlow(g, "manageCart", func(ctx context.Context, input string) (string, error) { + // Load existing session or create new one. + sess, err := session.Load(ctx, store, sessionID) + if err != nil { + // Session doesn't exist, create it. + sess, err = session.New(ctx, + session.WithID[CartState](sessionID), + session.WithStore(store), + session.WithInitialState(CartState{Items: []string{}}), + ) + if err != nil { + return "", err + } + } + + // Attach session to context for tools. + ctx = session.NewContext(ctx, sess) + + return genkit.GenerateText(ctx, g, + ai.WithModel(googlegenai.ModelRef("gemini-2.5-flash", &genai.GenerateContentConfig{ + ThinkingConfig: &genai.ThinkingConfig{ + ThinkingBudget: genai.Ptr[int32](0), + }, + })), + ai.WithSystem("You are a helpful shopping assistant. Use the provided tools to manage the user's cart."), + ai.WithTools(addToCartTool, getCartTool), + ai.WithPrompt(input), + ) + }) + + // Start server. + mux := http.NewServeMux() + for _, a := range genkit.ListFlows(g) { + mux.HandleFunc("POST /"+a.Name(), genkit.Handler(a)) + } + log.Fatal(server.Start(ctx, "127.0.0.1:8080", mux)) +}