Skip to content

Commit 061d138

Browse files
Support input_examples codegen for tool decorator
Co-Authored-By: [email protected] <[email protected]>
1 parent a6f0a6a commit 061d138

File tree

4 files changed

+263
-26
lines changed

4 files changed

+263
-26
lines changed

ee/codegen/src/__test__/nodes/__snapshots__/tool-calling-node.test.ts.snap

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,62 @@ Summarize the following text:
252252
"
253253
`;
254254

255+
exports[`ToolCallingNode > input_examples > should generate tool decorator with both inputs and input_examples 1`] = `
256+
"from .search import search
257+
258+
from vellum.workflows.nodes.displayable.tool_calling_node import (
259+
ToolCallingNode as BaseToolCallingNode,
260+
)
261+
from vellum.workflows.utils.functions import tool
262+
263+
from ..inputs import Inputs
264+
265+
266+
class ToolCallingNode(BaseToolCallingNode):
267+
functions = [
268+
tool(
269+
inputs={
270+
"parent_context": Inputs.location,
271+
},
272+
input_examples=[
273+
{
274+
"query": "weather in SF",
275+
},
276+
{
277+
"query": "stock prices",
278+
},
279+
],
280+
)(search)
281+
]
282+
"
283+
`;
284+
285+
exports[`ToolCallingNode > input_examples > should generate tool decorator with input_examples 1`] = `
286+
"from .get_weather import get_weather
287+
288+
from vellum.workflows.nodes.displayable.tool_calling_node import (
289+
ToolCallingNode as BaseToolCallingNode,
290+
)
291+
from vellum.workflows.utils.functions import tool
292+
293+
294+
class ToolCallingNode(BaseToolCallingNode):
295+
functions = [
296+
tool(
297+
input_examples=[
298+
{
299+
"location": "San Francisco",
300+
},
301+
{
302+
"location": "New York",
303+
"units": "celsius",
304+
},
305+
]
306+
)(get_weather)
307+
]
308+
"
309+
`;
310+
255311
exports[`ToolCallingNode > mcp server > should generate mcp server 1`] = `
256312
"from vellum.workflows.constants import AuthorizationType
257313
from vellum.workflows.nodes.displayable.tool_calling_node import (

ee/codegen/src/__test__/nodes/tool-calling-node.test.ts

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,4 +1482,149 @@ describe("ToolCallingNode", () => {
14821482
expect(output).toMatchSnapshot();
14831483
});
14841484
});
1485+
1486+
describe("input_examples", () => {
1487+
it("should generate tool decorator with input_examples", async () => {
1488+
/**
1489+
* Tests that a CODE_EXECUTION function with input_examples generates
1490+
* a tool decorator with the input_examples parameter.
1491+
*/
1492+
1493+
// GIVEN a code execution function with input_examples
1494+
const codeExecutionFunctionWithExamples: FunctionArgs = {
1495+
type: "CODE_EXECUTION",
1496+
src: 'def get_weather(location: str, units: str = "fahrenheit") -> str:\n """Get the current weather for a location."""\n return "sunny"\n',
1497+
name: "get_weather",
1498+
description: "Get the current weather for a location.",
1499+
definition: {
1500+
name: "get_weather",
1501+
parameters: {
1502+
type: "object",
1503+
required: ["location"],
1504+
properties: {
1505+
location: { type: "string" },
1506+
units: { type: "string", default: "fahrenheit" },
1507+
},
1508+
},
1509+
},
1510+
input_examples: [
1511+
{ location: "San Francisco" },
1512+
{ location: "New York", units: "celsius" },
1513+
],
1514+
};
1515+
1516+
const nodePortData: NodePort[] = [
1517+
nodePortFactory({
1518+
id: "port-id",
1519+
}),
1520+
];
1521+
1522+
const functionsAttribute = nodeAttributeFactory(
1523+
"functions-attr-id",
1524+
"functions",
1525+
{
1526+
type: "CONSTANT_VALUE",
1527+
value: {
1528+
type: "JSON",
1529+
value: [codeExecutionFunctionWithExamples],
1530+
},
1531+
}
1532+
);
1533+
1534+
const nodeData = toolCallingNodeFactory({
1535+
nodePorts: nodePortData,
1536+
nodeAttributes: [functionsAttribute],
1537+
});
1538+
1539+
// WHEN we create the node and generate the node file
1540+
const nodeContext = (await createNodeContext({
1541+
workflowContext,
1542+
nodeData,
1543+
})) as GenericNodeContext;
1544+
1545+
const node = new GenericNode({
1546+
workflowContext,
1547+
nodeContext,
1548+
});
1549+
1550+
node.getNodeFile().write(writer);
1551+
const output = await writer.toStringFormatted();
1552+
1553+
// THEN the generated code should include the tool decorator with input_examples
1554+
expect(output).toMatchSnapshot();
1555+
});
1556+
1557+
it("should generate tool decorator with both inputs and input_examples", async () => {
1558+
/**
1559+
* Tests that a CODE_EXECUTION function with both inputs and input_examples
1560+
* generates a tool decorator with both parameters.
1561+
*/
1562+
1563+
// GIVEN a code execution function with both inputs and input_examples
1564+
const codeExecutionFunctionWithBoth: FunctionArgs = {
1565+
type: "CODE_EXECUTION",
1566+
src: 'def search(query: str, parent_context: str) -> str:\n """Search for information."""\n return "results"\n',
1567+
name: "search",
1568+
description: "Search for information.",
1569+
definition: {
1570+
name: "search",
1571+
parameters: {
1572+
type: "object",
1573+
required: ["query"],
1574+
properties: {
1575+
query: { type: "string" },
1576+
parent_context: { type: "string" },
1577+
},
1578+
},
1579+
inputs: {
1580+
parent_context: {
1581+
type: "WORKFLOW_INPUT",
1582+
input_variable_id: "input-1",
1583+
},
1584+
},
1585+
},
1586+
input_examples: [{ query: "weather in SF" }, { query: "stock prices" }],
1587+
};
1588+
1589+
const nodePortData: NodePort[] = [
1590+
nodePortFactory({
1591+
id: "port-id",
1592+
}),
1593+
];
1594+
1595+
const functionsAttribute = nodeAttributeFactory(
1596+
"functions-attr-id",
1597+
"functions",
1598+
{
1599+
type: "CONSTANT_VALUE",
1600+
value: {
1601+
type: "JSON",
1602+
value: [codeExecutionFunctionWithBoth],
1603+
},
1604+
}
1605+
);
1606+
1607+
const nodeData = toolCallingNodeFactory({
1608+
nodePorts: nodePortData,
1609+
nodeAttributes: [functionsAttribute],
1610+
});
1611+
1612+
// WHEN we create the node and generate the node file
1613+
const nodeContext = (await createNodeContext({
1614+
workflowContext,
1615+
nodeData,
1616+
})) as GenericNodeContext;
1617+
1618+
const node = new GenericNode({
1619+
workflowContext,
1620+
nodeContext,
1621+
});
1622+
1623+
node.getNodeFile().write(writer);
1624+
const output = await writer.toStringFormatted();
1625+
1626+
// THEN the generated code should include the tool decorator with both inputs and input_examples
1627+
expect(output).toMatchSnapshot();
1628+
});
1629+
});
14851630
});

ee/codegen/src/generators/nodes/generic-node.ts

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import { StarImport } from "src/generators/extensions/star-import";
2525
import { StrInstantiation } from "src/generators/extensions/str-instantiation";
2626
import { WrappedCall } from "src/generators/extensions/wrapped-call";
2727
import { InitFile } from "src/generators/init-file";
28+
import { Json } from "src/generators/json";
2829
import { NodeOutputs } from "src/generators/node-outputs";
2930
import { BaseNode } from "src/generators/nodes/bases/base";
3031
import { AttributeType, NODE_ATTRIBUTES } from "src/generators/nodes/constants";
@@ -321,11 +322,15 @@ export class GenericNode extends BaseNode<GenericNodeType, GenericNodeContext> {
321322
modulePath: [`.${snakeName}`], // Import from snake_case module
322323
});
323324

324-
// Check if function has inputs that need to be wrapped with tool()
325+
// Check if function has inputs or input_examples that need to be wrapped with tool()
325326
const parsedInputs = this.parseToolInputs(codeExecutionFunction);
326-
if (parsedInputs && Object.keys(parsedInputs).length > 0) {
327+
const inputExamples = codeExecutionFunction.input_examples ?? null;
328+
const hasInputs = parsedInputs && Object.keys(parsedInputs).length > 0;
329+
const hasInputExamples = inputExamples && inputExamples.length > 0;
330+
331+
if (hasInputs || hasInputExamples) {
327332
// Wrap the function reference with tool(...)(func)
328-
const wrapper = this.getToolInvocation(parsedInputs);
333+
const wrapper = this.getToolInvocation(parsedInputs, inputExamples);
329334
return new WrappedCall({
330335
wrapper,
331336
inner: functionReference,
@@ -888,41 +893,71 @@ export class GenericNode extends BaseNode<GenericNodeType, GenericNodeContext> {
888893
}
889894

890895
/**
891-
* Creates a tool(inputs={...}) method invocation for wrapping function references.
896+
* Creates a tool(inputs={...}, input_examples=[...]) method invocation for wrapping function references.
892897
*/
893898
private getToolInvocation(
894-
inputs: Record<string, WorkflowValueDescriptorType>
899+
inputs: Record<string, WorkflowValueDescriptorType> | null,
900+
inputExamples: Array<Record<string, unknown>> | null
895901
): python.MethodInvocation {
896-
// Build dict entries for the inputs parameter
897-
const dictEntries = Object.entries(inputs).map(([inputName, inputDef]) => {
898-
const workflowValueDescriptor = new WorkflowValueDescriptor({
899-
workflowValueDescriptor: inputDef,
900-
nodeContext: this.nodeContext,
901-
workflowContext: this.workflowContext,
902+
const arguments_: python.MethodArgument[] = [];
903+
904+
// Build dict entries for the inputs parameter if provided
905+
if (inputs && Object.keys(inputs).length > 0) {
906+
const dictEntries = Object.entries(inputs).map(
907+
([inputName, inputDef]) => {
908+
const workflowValueDescriptor = new WorkflowValueDescriptor({
909+
workflowValueDescriptor: inputDef,
910+
nodeContext: this.nodeContext,
911+
workflowContext: this.workflowContext,
912+
});
913+
914+
return {
915+
key: new StrInstantiation(inputName),
916+
value: workflowValueDescriptor,
917+
};
918+
}
919+
);
920+
921+
const inputsDict = python.TypeInstantiation.dict(dictEntries, {
922+
endWithComma: true,
902923
});
903924

904-
return {
905-
key: new StrInstantiation(inputName),
906-
value: workflowValueDescriptor,
907-
};
908-
});
925+
arguments_.push(
926+
new MethodArgument({
927+
name: "inputs",
928+
value: inputsDict,
929+
})
930+
);
931+
}
909932

910-
// Create the dict literal for inputs
911-
const inputsDict = python.TypeInstantiation.dict(dictEntries, {
912-
endWithComma: true,
913-
});
933+
// Build list entries for the input_examples parameter if provided
934+
if (inputExamples && inputExamples.length > 0) {
935+
const exampleDicts = inputExamples.map((example) => {
936+
const dictEntries = Object.entries(example).map(([key, value]) => ({
937+
key: new StrInstantiation(key),
938+
value: new Json(value),
939+
}));
940+
return python.TypeInstantiation.dict(dictEntries, {
941+
endWithComma: true,
942+
});
943+
});
944+
945+
arguments_.push(
946+
new MethodArgument({
947+
name: "input_examples",
948+
value: python.TypeInstantiation.list(exampleDicts, {
949+
endWithComma: true,
950+
}),
951+
})
952+
);
953+
}
914954

915955
return python.invokeMethod({
916956
methodReference: new Reference({
917957
name: "tool",
918958
modulePath: ["vellum", "workflows", "utils", "functions"],
919959
}),
920-
arguments_: [
921-
new MethodArgument({
922-
name: "inputs",
923-
value: inputsDict,
924-
}),
925-
],
960+
arguments_,
926961
});
927962
}
928963

ee/codegen/src/types/vellum.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,7 @@ export type FunctionArgs = {
10681068
type: "CODE_EXECUTION";
10691069
src: string;
10701070
definition?: FunctionDefinition; // `legacy frontend` does not send definition field
1071+
input_examples?: Array<Record<string, unknown>>;
10711072
} & NameDescription;
10721073

10731074
export type InlineWorkflowFunctionArgs = {

0 commit comments

Comments
 (0)