diff --git a/ai/src/main/java/com/google/genkit/ai/Prompt.java b/ai/src/main/java/com/google/genkit/ai/Prompt.java index 7b90e0c7d..f5655b366 100644 --- a/ai/src/main/java/com/google/genkit/ai/Prompt.java +++ b/ai/src/main/java/com/google/genkit/ai/Prompt.java @@ -43,6 +43,7 @@ public class Prompt implements Action { private final String name; + private final String variant; private final String model; private final String template; private final Map inputSchema; @@ -55,6 +56,7 @@ public class Prompt implements Action { * Creates a new Prompt. * * @param name the prompt name + * @param variant the prompt variant * @param model the default model name * @param template the prompt template * @param inputSchema the input JSON schema @@ -64,6 +66,7 @@ public class Prompt implements Action { */ public Prompt( String name, + String variant, String model, String template, Map inputSchema, @@ -71,6 +74,7 @@ public Prompt( Class inputClass, BiFunction renderer) { this.name = name; + this.variant = variant; this.model = model; this.template = template; this.inputSchema = inputSchema; @@ -86,7 +90,17 @@ public Prompt( // Build the prompt sub-object with detailed metadata Map promptMetadata = new HashMap<>(); - promptMetadata.put("name", name); + + // The 'name' in metadata should be the base name (without variant) + String baseName = name; + if (variant != null && name.endsWith("." + variant)) { + baseName = name.substring(0, name.length() - variant.length() - 1); + } + + promptMetadata.put("name", baseName); + if (variant != null) { + promptMetadata.put("variant", variant); + } promptMetadata.put("model", model); promptMetadata.put("template", template); if (inputSchema != null) { @@ -113,6 +127,15 @@ public String getName() { return name; } + /** + * Gets the prompt variant. + * + * @return the variant name, or null if none + */ + public String getVariant() { + return variant; + } + @Override public ActionType getType() { return ActionType.EXECUTABLE_PROMPT; @@ -213,6 +236,7 @@ public GenerationConfig getConfig() { */ public static class Builder { private String name; + private String variant; private String model; private String template; private Map inputSchema; @@ -225,6 +249,11 @@ public Builder name(String name) { return this; } + public Builder variant(String variant) { + this.variant = variant; + return this; + } + public Builder model(String model) { this.model = model; return this; @@ -262,7 +291,8 @@ public Prompt build() { if (renderer == null) { throw new IllegalStateException("Prompt renderer is required"); } - return new Prompt<>(name, model, template, inputSchema, config, inputClass, renderer); + return new Prompt<>( + name, variant, model, template, inputSchema, config, inputClass, renderer); } } } diff --git a/ai/src/test/java/com/google/genkit/ai/PromptTest.java b/ai/src/test/java/com/google/genkit/ai/PromptTest.java new file mode 100644 index 000000000..e0585cdad --- /dev/null +++ b/ai/src/test/java/com/google/genkit/ai/PromptTest.java @@ -0,0 +1,98 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.ai; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.genkit.core.ActionType; +import java.util.Map; +import org.junit.jupiter.api.Test; + +/** Unit tests for Prompt. */ +class PromptTest { + + @Test + void testPromptVariantAndMetadata() { + Prompt prompt = + Prompt.builder() + .name("recipe.robot") + .variant("robot") + .model("openai/gpt-4o") + .template("Tell me a recipe for {{input}}") + .renderer((ctx, input) -> ModelRequest.builder().addUserMessage(input).build()) + .build(); + + assertEquals("recipe.robot", prompt.getName()); + assertEquals("robot", prompt.getVariant()); + assertEquals(ActionType.EXECUTABLE_PROMPT, prompt.getType()); + + Map metadata = prompt.getMetadata(); + assertEquals(ActionType.EXECUTABLE_PROMPT.getValue(), metadata.get("type")); + + @SuppressWarnings("unchecked") + Map promptMetadata = (Map) metadata.get("prompt"); + assertNotNull(promptMetadata); + assertEquals("recipe", promptMetadata.get("name")); + assertEquals("robot", promptMetadata.get("variant")); + assertEquals("openai/gpt-4o", promptMetadata.get("model")); + assertEquals("Tell me a recipe for {{input}}", promptMetadata.get("template")); + } + + @Test + void testPromptWithoutVariant() { + Prompt prompt = + Prompt.builder() + .name("recipe") + .model("openai/gpt-4o") + .template("Tell me a recipe for {{input}}") + .renderer((ctx, input) -> ModelRequest.builder().addUserMessage(input).build()) + .build(); + + assertEquals("recipe", prompt.getName()); + assertNull(prompt.getVariant()); + + Map metadata = prompt.getMetadata(); + @SuppressWarnings("unchecked") + Map promptMetadata = (Map) metadata.get("prompt"); + assertEquals("recipe", promptMetadata.get("name")); + assertNull(promptMetadata.get("variant")); + } + + @Test + void testPromptVariantMismatch() { + // If name doesn't end with variant, baseName should be the full name + Prompt prompt = + Prompt.builder() + .name("recipe") + .variant("robot") + .model("openai/gpt-4o") + .template("Tell me a recipe for {{input}}") + .renderer((ctx, input) -> ModelRequest.builder().addUserMessage(input).build()) + .build(); + + assertEquals("recipe", prompt.getName()); + assertEquals("robot", prompt.getVariant()); + + Map metadata = prompt.getMetadata(); + @SuppressWarnings("unchecked") + Map promptMetadata = (Map) metadata.get("prompt"); + assertEquals("recipe", promptMetadata.get("name")); + assertEquals("robot", promptMetadata.get("variant")); + } +} diff --git a/genkit/src/main/java/com/google/genkit/prompt/DotPrompt.java b/genkit/src/main/java/com/google/genkit/prompt/DotPrompt.java index 07a1db2be..2c5f092f2 100644 --- a/genkit/src/main/java/com/google/genkit/prompt/DotPrompt.java +++ b/genkit/src/main/java/com/google/genkit/prompt/DotPrompt.java @@ -105,6 +105,7 @@ public Charset getCharset() { private static final Handlebars sharedHandlebars = new Handlebars(partialLoader); private final String name; + private final String variant; private final String model; private final String template; private final Map inputSchema; @@ -115,6 +116,7 @@ public Charset getCharset() { * Creates a new DotPrompt. * * @param name the prompt name + * @param variant the prompt variant * @param model the default model name * @param template the Handlebars template * @param inputSchema the input JSON schema @@ -122,11 +124,13 @@ public Charset getCharset() { */ public DotPrompt( String name, + String variant, String model, String template, Map inputSchema, GenerationConfig config) { this.name = name; + this.variant = variant; this.model = model; this.template = template; this.inputSchema = inputSchema; @@ -298,7 +302,15 @@ public static DotPrompt parse(String name, String content) throws GenkitE name = name.substring(name.lastIndexOf('/') + 1); } - return new DotPrompt<>(name, model, template, inputSchema, config); + // Extract variant from name (e.g., "recipe.robot" -> variant="robot") + // Note: name remains the full name (e.g., "recipe.robot") for registry key uniqueness + String variant = null; + int dotIndex = name.lastIndexOf('.'); + if (dotIndex != -1) { + variant = name.substring(dotIndex + 1); + } + + return new DotPrompt<>(name, variant, model, template, inputSchema, config); } /** @@ -364,6 +376,7 @@ public ModelRequest toModelRequest(I input) throws GenkitException { public Prompt toPrompt(Class inputClass) { return Prompt.builder() .name(name) + .variant(variant) .model(model) .template(template) .inputSchema(inputSchema) @@ -512,6 +525,10 @@ public String getName() { return name; } + public String getVariant() { + return variant; + } + public String getModel() { return model; } diff --git a/genkit/src/test/java/com/google/genkit/prompt/DotPromptTest.java b/genkit/src/test/java/com/google/genkit/prompt/DotPromptTest.java new file mode 100644 index 000000000..0a5256165 --- /dev/null +++ b/genkit/src/test/java/com/google/genkit/prompt/DotPromptTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.genkit.prompt; + +import static org.junit.jupiter.api.Assertions.*; + +import com.google.genkit.ai.Prompt; +import com.google.genkit.core.GenkitException; +import java.util.Map; +import org.junit.jupiter.api.Test; + +/** Unit tests for DotPrompt. */ +class DotPromptTest { + + @Test + void testParseVariant() throws GenkitException { + String content = "---\nmodel: openai/gpt-4o\n---\nHello {{input}}"; + DotPrompt> dotPrompt = DotPrompt.parse("recipe.robot", content); + + assertEquals("recipe.robot", dotPrompt.getName()); + assertEquals("robot", dotPrompt.getVariant()); + assertEquals("openai/gpt-4o", dotPrompt.getModel()); + assertEquals("Hello {{input}}", dotPrompt.getTemplate()); + } + + @Test + void testParseWithoutVariant() throws GenkitException { + String content = "---\nmodel: openai/gpt-4o\n---\nHello {{input}}"; + DotPrompt> dotPrompt = DotPrompt.parse("recipe", content); + + assertEquals("recipe", dotPrompt.getName()); + assertNull(dotPrompt.getVariant()); + } + + @Test + void testParseWithMultipleDots() throws GenkitException { + String content = "---\nmodel: openai/gpt-4o\n---\nHello {{input}}"; + // Should use last dot for variant + DotPrompt> dotPrompt = DotPrompt.parse("my.awesome.recipe.robot", content); + + assertEquals("my.awesome.recipe.robot", dotPrompt.getName()); + assertEquals("robot", dotPrompt.getVariant()); + } + + @Test + void testToPrompt() throws GenkitException { + String content = "---\nmodel: openai/gpt-4o\n---\nHello {{input}}"; + DotPrompt> dotPrompt = DotPrompt.parse("recipe.robot", content); + + @SuppressWarnings("unchecked") + Prompt> prompt = + (Prompt>) dotPrompt.toPrompt((Class) Map.class); + + assertEquals("recipe.robot", prompt.getName()); + assertEquals("robot", prompt.getVariant()); + + @SuppressWarnings("unchecked") + Map promptMetadata = (Map) prompt.getMetadata().get("prompt"); + assertEquals("recipe", promptMetadata.get("name")); + assertEquals("robot", promptMetadata.get("variant")); + } +} diff --git a/samples/dotprompt/src/main/java/com/google/genkit/samples/DotPromptSample.java b/samples/dotprompt/src/main/java/com/google/genkit/samples/DotPromptSample.java index 765519819..ffe12d932 100644 --- a/samples/dotprompt/src/main/java/com/google/genkit/samples/DotPromptSample.java +++ b/samples/dotprompt/src/main/java/com/google/genkit/samples/DotPromptSample.java @@ -143,18 +143,18 @@ public static void main(String[] args) throws Exception { // Load and auto-register prompts using genkit.prompt() // This automatically loads from /prompts directory and registers as actions - ExecutablePrompt storyPrompt = genkit.prompt("story", StoryInput.class); - ExecutablePrompt travelPrompt = genkit.prompt("travel-planner", TravelInput.class); ExecutablePrompt codeReviewPrompt = genkit.prompt("code-review", CodeReviewInput.class); - - // Load prompt with variant (e.g., recipe.robot.prompt) + ExecutablePrompt recipePrompt = genkit.prompt("recipe", RecipeInput.class); + ExecutablePrompt travelPrompt = genkit.prompt("travel-planner", TravelInput.class); + + // Load prompt with variant (e.g., recipe.prompt, recipe.robot.prompt) ExecutablePrompt robotRecipePrompt = genkit.prompt("recipe", RecipeInput.class, "robot"); - + // ============================================================ // Method 2: Load prompts manually using DotPrompt.loadFromResource() // Useful when you need more control over the loading process // ============================================================ - DotPrompt recipePrompt = DotPrompt.loadFromResource("/prompts/recipe.prompt"); + DotPrompt manualStoryPrompt = DotPrompt.loadFromResource("/prompts/story.prompt"); // ============================================================ // Flow Examples: Different ways to use prompts @@ -220,7 +220,7 @@ public static void main(String[] args) throws Exception { String.class, (ctx, input) -> { // Generate with custom temperature override - ModelResponse response = storyPrompt.generate(input, + ModelResponse response = manualStoryPrompt.generate(genkit.getRegistry(), input, GenerateOptions.builder() .config(GenerationConfig.builder() .temperature(0.9)