Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions ai/src/main/java/com/google/genkit/ai/Prompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
public class Prompt<I> implements Action<I, ModelRequest, Void> {

private final String name;
private final String variant;
private final String model;
private final String template;
private final Map<String, Object> inputSchema;
Expand All @@ -55,6 +56,7 @@ public class Prompt<I> implements Action<I, ModelRequest, Void> {
* Creates a new Prompt.
*
* @param name the prompt name
* @param variant the prompt variant
* @param model the default model name
* @param template the prompt template
* @param inputSchema the input JSON schema
Expand All @@ -64,13 +66,15 @@ public class Prompt<I> implements Action<I, ModelRequest, Void> {
*/
public Prompt(
String name,
String variant,
String model,
String template,
Map<String, Object> inputSchema,
GenerationConfig config,
Class<I> inputClass,
BiFunction<ActionContext, I, ModelRequest> renderer) {
this.name = name;
this.variant = variant;
this.model = model;
this.template = template;
this.inputSchema = inputSchema;
Expand All @@ -86,7 +90,17 @@ public Prompt(

// Build the prompt sub-object with detailed metadata
Map<String, Object> promptMetadata = new HashMap<>();
promptMetadata.put("name", name);

// The 'name' in metadata should be the base name (without variant)
String baseName = name;
if (variant != null && name.endsWith("." + variant)) {
baseName = name.substring(0, name.length() - variant.length() - 1);
}

promptMetadata.put("name", baseName);
if (variant != null) {
promptMetadata.put("variant", variant);
}
promptMetadata.put("model", model);
promptMetadata.put("template", template);
if (inputSchema != null) {
Expand All @@ -113,6 +127,15 @@ public String getName() {
return name;
}

/**
* Gets the prompt variant.
*
* @return the variant name, or null if none
*/
public String getVariant() {
return variant;
}

@Override
public ActionType getType() {
return ActionType.EXECUTABLE_PROMPT;
Expand Down Expand Up @@ -213,6 +236,7 @@ public GenerationConfig getConfig() {
*/
public static class Builder<I> {
private String name;
private String variant;
private String model;
private String template;
private Map<String, Object> inputSchema;
Expand All @@ -225,6 +249,11 @@ public Builder<I> name(String name) {
return this;
}

public Builder<I> variant(String variant) {
this.variant = variant;
return this;
}

public Builder<I> model(String model) {
this.model = model;
return this;
Expand Down Expand Up @@ -262,7 +291,8 @@ public Prompt<I> build() {
if (renderer == null) {
throw new IllegalStateException("Prompt renderer is required");
}
return new Prompt<>(name, model, template, inputSchema, config, inputClass, renderer);
return new Prompt<>(
name, variant, model, template, inputSchema, config, inputClass, renderer);
}
}
}
98 changes: 98 additions & 0 deletions ai/src/test/java/com/google/genkit/ai/PromptTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package com.google.genkit.ai;

import static org.junit.jupiter.api.Assertions.*;

import com.google.genkit.core.ActionType;
import java.util.Map;
import org.junit.jupiter.api.Test;

/** Unit tests for Prompt. */
class PromptTest {

@Test
void testPromptVariantAndMetadata() {
Prompt<String> prompt =
Prompt.<String>builder()
.name("recipe.robot")
.variant("robot")
.model("openai/gpt-4o")
.template("Tell me a recipe for {{input}}")
.renderer((ctx, input) -> ModelRequest.builder().addUserMessage(input).build())
.build();

assertEquals("recipe.robot", prompt.getName());
assertEquals("robot", prompt.getVariant());
assertEquals(ActionType.EXECUTABLE_PROMPT, prompt.getType());

Map<String, Object> metadata = prompt.getMetadata();
assertEquals(ActionType.EXECUTABLE_PROMPT.getValue(), metadata.get("type"));

@SuppressWarnings("unchecked")
Map<String, Object> promptMetadata = (Map<String, Object>) metadata.get("prompt");
assertNotNull(promptMetadata);
assertEquals("recipe", promptMetadata.get("name"));
assertEquals("robot", promptMetadata.get("variant"));
assertEquals("openai/gpt-4o", promptMetadata.get("model"));
assertEquals("Tell me a recipe for {{input}}", promptMetadata.get("template"));
}

@Test
void testPromptWithoutVariant() {
Prompt<String> prompt =
Prompt.<String>builder()
.name("recipe")
.model("openai/gpt-4o")
.template("Tell me a recipe for {{input}}")
.renderer((ctx, input) -> ModelRequest.builder().addUserMessage(input).build())
.build();

assertEquals("recipe", prompt.getName());
assertNull(prompt.getVariant());

Map<String, Object> metadata = prompt.getMetadata();
@SuppressWarnings("unchecked")
Map<String, Object> promptMetadata = (Map<String, Object>) metadata.get("prompt");
assertEquals("recipe", promptMetadata.get("name"));
assertNull(promptMetadata.get("variant"));
}

@Test
void testPromptVariantMismatch() {
// If name doesn't end with variant, baseName should be the full name
Prompt<String> prompt =
Prompt.<String>builder()
.name("recipe")
.variant("robot")
.model("openai/gpt-4o")
.template("Tell me a recipe for {{input}}")
.renderer((ctx, input) -> ModelRequest.builder().addUserMessage(input).build())
.build();

assertEquals("recipe", prompt.getName());
assertEquals("robot", prompt.getVariant());

Map<String, Object> metadata = prompt.getMetadata();
@SuppressWarnings("unchecked")
Map<String, Object> promptMetadata = (Map<String, Object>) metadata.get("prompt");
assertEquals("recipe", promptMetadata.get("name"));
assertEquals("robot", promptMetadata.get("variant"));
}
}
19 changes: 18 additions & 1 deletion genkit/src/main/java/com/google/genkit/prompt/DotPrompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ public Charset getCharset() {
private static final Handlebars sharedHandlebars = new Handlebars(partialLoader);

private final String name;
private final String variant;
private final String model;
private final String template;
private final Map<String, Object> inputSchema;
Expand All @@ -115,18 +116,21 @@ public Charset getCharset() {
* Creates a new DotPrompt.
*
* @param name the prompt name
* @param variant the prompt variant
* @param model the default model name
* @param template the Handlebars template
* @param inputSchema the input JSON schema
* @param config the default generation config
*/
public DotPrompt(
String name,
String variant,
String model,
String template,
Map<String, Object> inputSchema,
GenerationConfig config) {
this.name = name;
this.variant = variant;
this.model = model;
this.template = template;
this.inputSchema = inputSchema;
Expand Down Expand Up @@ -298,7 +302,15 @@ public static <I> DotPrompt<I> parse(String name, String content) throws GenkitE
name = name.substring(name.lastIndexOf('/') + 1);
}

return new DotPrompt<>(name, model, template, inputSchema, config);
// Extract variant from name (e.g., "recipe.robot" -> variant="robot")
// Note: name remains the full name (e.g., "recipe.robot") for registry key uniqueness
String variant = null;
int dotIndex = name.lastIndexOf('.');
if (dotIndex != -1) {
variant = name.substring(dotIndex + 1);
}

return new DotPrompt<>(name, variant, model, template, inputSchema, config);
}

/**
Expand Down Expand Up @@ -364,6 +376,7 @@ public ModelRequest toModelRequest(I input) throws GenkitException {
public Prompt<I> toPrompt(Class<I> inputClass) {
return Prompt.<I>builder()
.name(name)
.variant(variant)
.model(model)
.template(template)
.inputSchema(inputSchema)
Expand Down Expand Up @@ -512,6 +525,10 @@ public String getName() {
return name;
}

public String getVariant() {
return variant;
}

public String getModel() {
return model;
}
Expand Down
78 changes: 78 additions & 0 deletions genkit/src/test/java/com/google/genkit/prompt/DotPromptTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/

package com.google.genkit.prompt;

import static org.junit.jupiter.api.Assertions.*;

import com.google.genkit.ai.Prompt;
import com.google.genkit.core.GenkitException;
import java.util.Map;
import org.junit.jupiter.api.Test;

/** Unit tests for DotPrompt. */
class DotPromptTest {

@Test
void testParseVariant() throws GenkitException {
String content = "---\nmodel: openai/gpt-4o\n---\nHello {{input}}";
DotPrompt<Map<String, Object>> dotPrompt = DotPrompt.parse("recipe.robot", content);

assertEquals("recipe.robot", dotPrompt.getName());
assertEquals("robot", dotPrompt.getVariant());
assertEquals("openai/gpt-4o", dotPrompt.getModel());
assertEquals("Hello {{input}}", dotPrompt.getTemplate());
}

@Test
void testParseWithoutVariant() throws GenkitException {
String content = "---\nmodel: openai/gpt-4o\n---\nHello {{input}}";
DotPrompt<Map<String, Object>> dotPrompt = DotPrompt.parse("recipe", content);

assertEquals("recipe", dotPrompt.getName());
assertNull(dotPrompt.getVariant());
}

@Test
void testParseWithMultipleDots() throws GenkitException {
String content = "---\nmodel: openai/gpt-4o\n---\nHello {{input}}";
// Should use last dot for variant
DotPrompt<Map<String, Object>> dotPrompt = DotPrompt.parse("my.awesome.recipe.robot", content);

assertEquals("my.awesome.recipe.robot", dotPrompt.getName());
assertEquals("robot", dotPrompt.getVariant());
}

@Test
void testToPrompt() throws GenkitException {
String content = "---\nmodel: openai/gpt-4o\n---\nHello {{input}}";
DotPrompt<Map<String, Object>> dotPrompt = DotPrompt.parse("recipe.robot", content);

@SuppressWarnings("unchecked")
Prompt<Map<String, Object>> prompt =
(Prompt<Map<String, Object>>) dotPrompt.toPrompt((Class) Map.class);

assertEquals("recipe.robot", prompt.getName());
assertEquals("robot", prompt.getVariant());

@SuppressWarnings("unchecked")
Map<String, Object> promptMetadata = (Map<String, Object>) prompt.getMetadata().get("prompt");
assertEquals("recipe", promptMetadata.get("name"));
assertEquals("robot", promptMetadata.get("variant"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,18 @@ public static void main(String[] args) throws Exception {

// Load and auto-register prompts using genkit.prompt()
// This automatically loads from /prompts directory and registers as actions
ExecutablePrompt<StoryInput> storyPrompt = genkit.prompt("story", StoryInput.class);
ExecutablePrompt<TravelInput> travelPrompt = genkit.prompt("travel-planner", TravelInput.class);
ExecutablePrompt<CodeReviewInput> codeReviewPrompt = genkit.prompt("code-review", CodeReviewInput.class);

// Load prompt with variant (e.g., recipe.robot.prompt)
ExecutablePrompt<RecipeInput> recipePrompt = genkit.prompt("recipe", RecipeInput.class);
ExecutablePrompt<TravelInput> travelPrompt = genkit.prompt("travel-planner", TravelInput.class);

// Load prompt with variant (e.g., recipe.prompt, recipe.robot.prompt)
ExecutablePrompt<RecipeInput> robotRecipePrompt = genkit.prompt("recipe", RecipeInput.class, "robot");

// ============================================================
// Method 2: Load prompts manually using DotPrompt.loadFromResource()
// Useful when you need more control over the loading process
// ============================================================
DotPrompt<RecipeInput> recipePrompt = DotPrompt.loadFromResource("/prompts/recipe.prompt");
DotPrompt<StoryInput> manualStoryPrompt = DotPrompt.loadFromResource("/prompts/story.prompt");

// ============================================================
// Flow Examples: Different ways to use prompts
Expand Down Expand Up @@ -220,7 +220,7 @@ public static void main(String[] args) throws Exception {
String.class,
(ctx, input) -> {
// Generate with custom temperature override
ModelResponse response = storyPrompt.generate(input,
ModelResponse response = manualStoryPrompt.generate(genkit.getRegistry(), input,
GenerateOptions.builder()
.config(GenerationConfig.builder()
.temperature(0.9)
Expand Down
Loading