Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/api"
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/core/x/session"
"github.com/firebase/genkit/go/internal/base"
"github.com/google/dotprompt/go/dotprompt"
"github.com/invopop/jsonschema"
Expand Down Expand Up @@ -588,14 +589,19 @@ func renderPrompt(ctx context.Context, opts promptOptions, templateText string,
// renderDotpromptToMessages executes a dotprompt prompt function and converts the result to a slice of messages
func renderDotpromptToMessages(ctx context.Context, promptFn dotprompt.PromptFunction, input map[string]any, additionalMetadata *dotprompt.PromptMetadata) ([]*Message, error) {
// Prepare the context for rendering
context := map[string]any{}
templateContext := map[string]any{}
actionCtx := core.FromContext(ctx)
maps.Copy(context, actionCtx)
maps.Copy(templateContext, actionCtx)

// Inject session state if available (accessible via {{@state.field}} in templates)
if state := session.StateFromContext(ctx); state != nil {
templateContext["state"] = state
}

// Call the prompt function with the input and context
rendered, err := promptFn(&dotprompt.DataArgument{
Input: input,
Context: context,
Context: templateContext,
}, additionalMetadata)
if err != nil {
return nil, fmt.Errorf("failed to render prompt: %w", err)
Expand Down
138 changes: 138 additions & 0 deletions go/ai/prompt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/api"
"github.com/firebase/genkit/go/core/x/session"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -2168,6 +2169,143 @@ func TestPromptExecuteStream(t *testing.T) {
})
}

// TestSessionStateInjection tests that session state is automatically injected
// into prompt templates and accessible via {{@state.field}} syntax.
func TestSessionStateInjection(t *testing.T) {
r := registry.New()
ConfigureFormats(r)

// Define a test state type
type UserState struct {
Name string `json:"name"`
Preferences map[string]string `json:"preferences"`
}

t.Run("session state accessible in prompt template", func(t *testing.T) {
var capturedPrompt string

testModel := DefineModel(r, "test/sessionStateModel", &ModelOptions{
Supports: &ModelSupports{Multiturn: true},
}, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
capturedPrompt = req.Messages[0].Text()
return &ModelResponse{
Request: req,
Message: NewModelTextMessage("response"),
}, nil
})

// Create a prompt that uses {{@state.name}} syntax
p := DefinePrompt(r, "sessionStatePrompt",
WithModel(testModel),
WithPrompt("Hello {{@state.name}}, your theme is {{@state.preferences.theme}}"),
)

// Create a session with state
ctx := context.Background()
sess, err := session.New(ctx, session.WithInitialState(UserState{
Name: "Alice",
Preferences: map[string]string{"theme": "dark"},
}))
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}

// Attach session to context
ctx = session.NewContext(ctx, sess)

// Execute prompt with session in context
_, err = p.Execute(ctx)
if err != nil {
t.Fatalf("Execute failed: %v", err)
}

// Verify the session state was injected into the template
expected := "Hello Alice, your theme is dark"
if capturedPrompt != expected {
t.Errorf("Expected prompt %q, got %q", expected, capturedPrompt)
}
})

t.Run("prompt works without session in context", func(t *testing.T) {
var capturedPrompt string

testModel := DefineModel(r, "test/noSessionModel", &ModelOptions{
Supports: &ModelSupports{Multiturn: true},
}, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
capturedPrompt = req.Messages[0].Text()
return &ModelResponse{
Request: req,
Message: NewModelTextMessage("response"),
}, nil
})

// Create a prompt that uses regular input variables (not session state)
p := DefinePrompt(r, "noSessionPrompt",
WithModel(testModel),
WithPrompt("Hello {{name}}"),
WithInputType(struct {
Name string `json:"name"`
}{}),
)

// Execute without session in context
ctx := context.Background()
_, err := p.Execute(ctx, WithInput(map[string]any{"name": "Bob"}))
if err != nil {
t.Fatalf("Execute failed: %v", err)
}

expected := "Hello Bob"
if capturedPrompt != expected {
t.Errorf("Expected prompt %q, got %q", expected, capturedPrompt)
}
})

t.Run("session state and input variables can be used together", func(t *testing.T) {
var capturedPrompt string

testModel := DefineModel(r, "test/mixedModel", &ModelOptions{
Supports: &ModelSupports{Multiturn: true},
}, func(ctx context.Context, req *ModelRequest, cb ModelStreamCallback) (*ModelResponse, error) {
capturedPrompt = req.Messages[0].Text()
return &ModelResponse{
Request: req,
Message: NewModelTextMessage("response"),
}, nil
})

// Create a prompt that uses both input and session state
p := DefinePrompt(r, "mixedPrompt",
WithModel(testModel),
WithPrompt("User {{@state.name}} asks: {{question}}"),
WithInputType(struct {
Question string `json:"question"`
}{}),
)

// Create session
ctx := context.Background()
sess, err := session.New(ctx, session.WithInitialState(UserState{
Name: "Charlie",
}))
if err != nil {
t.Fatalf("Failed to create session: %v", err)
}
ctx = session.NewContext(ctx, sess)

// Execute with both session and input
_, err = p.Execute(ctx, WithInput(map[string]any{"question": "What is the weather?"}))
if err != nil {
t.Fatalf("Execute failed: %v", err)
}

expected := "User Charlie asks: What is the weather?"
if capturedPrompt != expected {
t.Errorf("Expected prompt %q, got %q", expected, capturedPrompt)
}
})
}

// TestDefineExecuteOptionInteractions tests the complex interactions between
// options set at DefinePrompt time vs Execute time.
func TestDefineExecuteOptionInteractions(t *testing.T) {
Expand Down
Loading
Loading