diff --git a/gradle.properties b/gradle.properties index 2305d65..ad699fb 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,2 +1,2 @@ version=1.1.0-SNAPSHOT -kestraVersion=1.0.4 \ No newline at end of file +kestraVersion=1.1.0-SNAPSHOT \ No newline at end of file diff --git a/src/main/java/io/kestra/plugin/ai/agent/AIAgent.java b/src/main/java/io/kestra/plugin/ai/agent/AIAgent.java index 97dd804..2af71a8 100644 --- a/src/main/java/io/kestra/plugin/ai/agent/AIAgent.java +++ b/src/main/java/io/kestra/plugin/ai/agent/AIAgent.java @@ -9,11 +9,14 @@ import dev.langchain4j.rag.query.router.QueryRouter; import dev.langchain4j.service.AiServices; import dev.langchain4j.service.Result; +import io.kestra.core.exceptions.IllegalVariableEvaluationException; import io.kestra.core.models.annotations.Example; import io.kestra.core.models.annotations.Metric; import io.kestra.core.models.annotations.Plugin; import io.kestra.core.models.annotations.PluginProperty; import io.kestra.core.models.executions.metrics.Counter; +import io.kestra.core.models.executions.TaskRun; +import io.kestra.core.models.hierarchies.*; import io.kestra.core.models.property.Property; import io.kestra.core.models.tasks.OutputFilesInterface; import io.kestra.core.models.tasks.RunnableTask; @@ -21,6 +24,7 @@ import io.kestra.core.models.tasks.runners.ScriptService; import io.kestra.core.runners.FilesService; import io.kestra.core.runners.RunContext; +import io.kestra.core.utils.GraphUtils; import io.kestra.core.utils.ListUtils; import io.kestra.plugin.ai.AIUtils; import io.kestra.plugin.ai.domain.*; @@ -351,7 +355,8 @@ public class AIAgent extends Task implements RunnableTask, OutputFiles title = "Content retrievers", description = "Some content retrievers, like WebSearch, can also be used as tools. However, when configured as content retrievers, they will always be used, whereas tools are only invoked when the LLM decides to use them." ) - private Property> contentRetrievers; + @PluginProperty + private List contentRetrievers; @Schema( title = "Agent memory", @@ -361,6 +366,31 @@ public class AIAgent extends Task implements RunnableTask, OutputFiles private Property> outputFiles; + @Override + public AbstractGraph graph(TaskRun taskRun, List values, RelationType relationType) { + GraphTask root = new GraphTask(this, taskRun, values, relationType); + + List nodes = new ArrayList<>(); + nodes.add(new CustomGraphNode("provider", provider.getType(), provider)); + if (this.tools != null) { + int count = 0; + for (ToolProvider tool : this.tools) { + nodes.add(new CustomGraphNode("tool-" + count++, tool.getType(), tool)); + } + } + if (this.contentRetrievers != null) { + int count = 0; + for (ContentRetrieverProvider contentRetriever : this.contentRetrievers) { + nodes.add(new CustomGraphNode("contentRetriever-" + count++, contentRetriever.getType(), contentRetriever)); + } + } + if (this.memory != null) { + nodes.add(new CustomGraphNode("memory", memory.getType(), memory)); + } + + return new CustomGraphCluster(this.id, root, nodes); + } + @Override public AIOutput run(RunContext runContext) throws Exception { Map additionalVariables = outputFiles != null ? Map.of(ScriptService.VAR_WORKING_DIR, runContext.workingDir().path(true).toString()) : Collections.emptyMap(); @@ -385,7 +415,7 @@ public AIOutput run(RunContext runContext) throws Exception { agent.chatMemory(memory.chatMemory(runContext)); } - List toolContentRetrievers = runContext.render(contentRetrievers).asList(ContentRetrieverProvider.class).stream() + List toolContentRetrievers = ListUtils.emptyOnNull(this.contentRetrievers).stream() .map(throwFunction(provider -> provider.contentRetriever(runContext))) .toList(); if (!toolContentRetrievers.isEmpty()) { @@ -419,6 +449,8 @@ public AIOutput run(RunContext runContext) throws Exception { } } + + // output files should all be inside the working directory private Map gatherOutputFiles(RunContext runContext) throws Exception { Map outputFiles = new HashMap<>();