diff --git a/ee/codegen_integration/fixtures/simple_composio_tool_calling_node/display_data/simple_composio_tool_calling_node.json b/ee/codegen_integration/fixtures/simple_composio_tool_calling_node/display_data/simple_composio_tool_calling_node.json index 4302061843..168eee160d 100644 --- a/ee/codegen_integration/fixtures/simple_composio_tool_calling_node/display_data/simple_composio_tool_calling_node.json +++ b/ee/codegen_integration/fixtures/simple_composio_tool_calling_node/display_data/simple_composio_tool_calling_node.json @@ -169,14 +169,19 @@ "name": "text", "type": "STRING", "value": null, - "schema": {"type": "string"} + "schema": { "type": "string" } }, { "id": "d278b1bc-dd58-4da0-bf58-84b03b1f5438", "name": "chat_history", "type": "CHAT_HISTORY", "value": null, - "schema": {"type": "array", "items": {"$ref": "#/$defs/vellum.client.types.chat_message.ChatMessage"}} + "schema": { + "type": "array", + "items": { + "$ref": "#/$defs/vellum.client.types.chat_message.ChatMessage" + } + } } ] } diff --git a/src/vellum/workflows/utils/functions.py b/src/vellum/workflows/utils/functions.py index 5311ca3cb9..11bcf22fa3 100644 --- a/src/vellum/workflows/utils/functions.py +++ b/src/vellum/workflows/utils/functions.py @@ -8,7 +8,6 @@ Callable, ForwardRef, List, - Literal, Optional, Type, Union, @@ -16,8 +15,8 @@ get_origin, ) -from pydantic import BaseModel -from pydantic_core import PydanticUndefined +from pydantic import BaseModel, TypeAdapter +from pydantic.json_schema import CoreModeRef, CoreRef, DefsRef, GenerateJsonSchema, JsonRef, JsonSchemaValue from pydash import snake_case from vellum import Vellum @@ -38,26 +37,72 @@ if TYPE_CHECKING: from vellum.workflows.workflows.base import BaseWorkflow -type_map: dict[Any, str] = { - str: "string", - int: "integer", - float: "number", - bool: "boolean", - list: "array", - dict: "object", - None: "null", - type(None): "null", - inspect._empty: "null", - "None": "null", -} -for k, v in list(type_map.items()): - if isinstance(k, type): - type_map[k.__name__] = v - - -def _get_def_name(annotation: Type) -> str: - return f"{annotation.__module__}.{annotation.__qualname__}" +def _strip_title_fields(schema: Any) -> Any: + """Recursively remove 'title' fields from JSON schema.""" + if isinstance(schema, dict): + result = {} + for key, value in schema.items(): + if key != "title": + result[key] = _strip_title_fields(value) + return result + elif isinstance(schema, list): + return [_strip_title_fields(item) for item in schema] + else: + return schema + + +class ModulePathJsonSchemaGenerator(GenerateJsonSchema): + """Custom JSON schema generator that includes full module paths in $ref references.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Map from original defs_ref to module path + self._defs_ref_to_module_path: dict[str, str] = {} + + def get_cache_defs_ref_schema(self, core_ref: CoreRef) -> tuple[DefsRef, JsonSchemaValue]: + """ + Override to store mapping but keep Pydantic's internal key structure. 
+ + Args: + core_ref: The core reference (e.g., "vellum.client.types.chat_message.ChatMessage:41224715280") + + Returns: + Tuple of (defs_ref, json_schema) + """ + # Get the default defs_ref and schema (Pydantic will use its sanitized key) + original_defs_ref, json_schema = super().get_cache_defs_ref_schema(core_ref) + + # Extract the module path from core_ref and store the mapping + core_ref_str = str(core_ref) + if ":" in core_ref_str: + model_path = core_ref_str.split(":")[0] + self._defs_ref_to_module_path[original_defs_ref] = model_path + # Update the json_schema ref to use the full module path with dots + if "$ref" in json_schema: + json_schema["$ref"] = f"#/$defs/{model_path}" + # Update the json_to_defs_refs mapping to point to the original defs_ref + # (which is the sanitized key in definitions) + json_ref: JsonRef = JsonRef(f"#/$defs/{model_path}") + self.json_to_defs_refs[json_ref] = original_defs_ref + + return original_defs_ref, json_schema + + def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef: + """ + Override to return the full module path instead of just the model name. + + Args: + core_mode_ref: Tuple containing (core_ref, mode) + + Returns: + The full module-qualified model name + """ + # Get the default defs_ref + original_defs_ref = super().get_defs_ref(core_mode_ref) + # Return the module path if we have a mapping, otherwise return the original + module_path = self._defs_ref_to_module_path.get(original_defs_ref, original_defs_ref) + return DefsRef(module_path) def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict: @@ -70,89 +115,27 @@ def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict: if annotation is datetime: return {"type": "string", "format": "date-time"} - if get_origin(annotation) is Union: - if is_json_type(get_args(annotation)): - return {"$ref": "#/$defs/vellum.workflows.types.core.Json"} - - return {"anyOf": [compile_annotation(a, defs) for a in get_args(annotation)]} - - if get_origin(annotation) is Literal: - values = list(get_args(annotation)) - types = {type(value) for value in values} - if len(types) == 1: - value_type = types.pop() - if value_type in type_map: - return {"type": type_map[value_type], "enum": values} - else: - return {"enum": values} - else: - return {"enum": values} - - if get_origin(annotation) is dict: - _, value_type = get_args(annotation) - return {"type": "object", "additionalProperties": compile_annotation(value_type, defs)} - - if get_origin(annotation) is list: - item_type = get_args(annotation)[0] - return {"type": "array", "items": compile_annotation(item_type, defs)} - - if get_origin(annotation) is tuple: - args = get_args(annotation) - if len(args) == 2 and args[1] is Ellipsis: - # Tuple[int, ...] 
with homogeneous items - return {"type": "array", "items": compile_annotation(args[0], defs)} - else: - # Tuple[int, str] with fixed length items - result = { - "type": "array", - "prefixItems": [compile_annotation(arg, defs) for arg in args], - "minItems": len(args), - "maxItems": len(args), - } - return result - - if dataclasses.is_dataclass(annotation) and isinstance(annotation, type): - def_name = _get_def_name(annotation) - if def_name not in defs: - properties = {} - required = [] - for field in dataclasses.fields(annotation): - properties[field.name] = compile_annotation(field.type, defs) - if field.default is dataclasses.MISSING: - required.append(field.name) - else: - properties[field.name]["default"] = _compile_default_value(field.default) - defs[def_name] = {"type": "object", "properties": properties, "required": required} - return {"$ref": f"#/$defs/{def_name}"} - - if inspect.isclass(annotation) and issubclass(annotation, BaseModel): - def_name = _get_def_name(annotation) - if def_name not in defs: - properties = {} - required = [] - for field_name, field_info in annotation.model_fields.items(): - # field_info is a FieldInfo object which has an annotation attribute - properties[field_name] = compile_annotation(field_info.annotation, defs) - - if field_info.description is not None: - properties[field_name]["description"] = field_info.description - - if field_info.default is PydanticUndefined: - required.append(field_name) - else: - properties[field_name]["default"] = _compile_default_value(field_info.default) - defs[def_name] = {"type": "object", "properties": properties, "required": required} - - return {"$ref": f"#/$defs/{def_name}"} - + # Handle ForwardRef early (before trying to use Pydantic) if type(annotation) is ForwardRef: # Ignore forward references for now return {} - if annotation not in type_map: - raise ValueError(f"Failed to compile type: {annotation}") + # Handle Union types - check for Json type special case + if get_origin(annotation) is Union: + if is_json_type(get_args(annotation)): + return {"$ref": "#/$defs/vellum.workflows.types.core.Json"} - return {"type": type_map[annotation]} + try: + adapter = TypeAdapter(annotation) + schema = adapter.json_schema( + mode="serialization", ref_template="#/$defs/{model}", schema_generator=ModulePathJsonSchemaGenerator + ) + schema_defs = schema.pop("$defs", {}) + schema_defs = {k: _strip_title_fields(v) for k, v in schema_defs.items()} + defs.update(schema_defs) + return _strip_title_fields(schema) + except Exception as e: + raise ValueError(f"Failed to compile type: {annotation}: {e}") def _compile_default_value(default: Any) -> Any: diff --git a/src/vellum/workflows/utils/tests/test_functions.py b/src/vellum/workflows/utils/tests/test_functions.py index 2e92e49d0a..46c0c2de62 100644 --- a/src/vellum/workflows/utils/tests/test_functions.py +++ b/src/vellum/workflows/utils/tests/test_functions.py @@ -71,8 +71,8 @@ def my_function(a: str, b: int, c: float, d: bool, e: list, f: dict): "b": {"type": "integer"}, "c": {"type": "number"}, "d": {"type": "boolean"}, - "e": {"type": "array"}, - "f": {"type": "object"}, + "e": {"type": "array", "items": {}}, + "f": {"type": "object", "additionalProperties": True}, }, "required": ["a", "b", "c", "d", "e", "f"], }, @@ -187,20 +187,21 @@ def my_function(c: MyDataClass): compiled_function = compile_function_definition(my_function) # THEN it should return the compiled function definition - ref_name = f"{__name__}.test_compile_function_definition__dataclasses..MyDataClass" assert 
compiled_function == FunctionDefinition( name="my_function", parameters={ "type": "object", - "properties": {"c": {"$ref": f"#/$defs/{ref_name}"}}, - "required": ["c"], - "$defs": { - ref_name: { + "properties": { + "c": { "type": "object", - "properties": {"a": {"type": "integer"}, "b": {"type": "string"}}, + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"}, + }, "required": ["a", "b"], } }, + "required": ["c"], }, ) @@ -218,15 +219,12 @@ def my_function(c: MyPydanticModel): compiled_function = compile_function_definition(my_function) # THEN it should return the compiled function definition - ref_name = f"{__name__}.test_compile_function_definition__pydantic..MyPydanticModel" assert compiled_function == FunctionDefinition( name="my_function", parameters={ "type": "object", - "properties": {"c": {"$ref": f"#/$defs/{ref_name}"}}, - "required": ["c"], - "$defs": { - ref_name: { + "properties": { + "c": { "type": "object", "properties": { "a": {"type": "integer", "description": "The first number"}, @@ -235,6 +233,7 @@ def my_function(c: MyPydanticModel): "required": ["a", "b"], } }, + "required": ["c"], }, ) @@ -253,20 +252,22 @@ def my_function(c: MyDataClass = MyDataClass(a=1, b="hello")): compiled_function = compile_function_definition(my_function) # THEN it should return the compiled function definition - ref_name = f"{__name__}.test_compile_function_definition__default_dataclass..MyDataClass" assert compiled_function == FunctionDefinition( name="my_function", parameters={ "type": "object", - "properties": {"c": {"$ref": f"#/$defs/{ref_name}", "default": {"a": 1, "b": "hello"}}}, - "required": [], - "$defs": { - ref_name: { + "properties": { + "c": { "type": "object", - "properties": {"a": {"type": "integer"}, "b": {"type": "string"}}, + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"}, + }, "required": ["a", "b"], + "default": {"a": 1, "b": "hello"}, } }, + "required": [], }, ) @@ -284,38 +285,26 @@ def my_function(c: MyPydanticModel = MyPydanticModel(a=1, b="hello")): compiled_function = compile_function_definition(my_function) # THEN it should return the compiled function definition - ref_name = f"{__name__}.test_compile_function_definition__default_pydantic..MyPydanticModel" assert compiled_function == FunctionDefinition( name="my_function", parameters={ "type": "object", - "properties": {"c": {"$ref": f"#/$defs/{ref_name}", "default": {"a": 1, "b": "hello"}}}, - "required": [], - "$defs": { - ref_name: { + "properties": { + "c": { "type": "object", - "properties": {"a": {"type": "integer"}, "b": {"type": "string"}}, + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"}, + }, "required": ["a", "b"], + "default": {"a": 1, "b": "hello"}, } }, + "required": [], }, ) -def test_compile_function_definition__lambda(): - # GIVEN a lambda - lambda_function = lambda x: x + 1 # noqa: E731 - - # WHEN compiling the function - compiled_function = compile_function_definition(lambda_function) - - # THEN it should return the compiled function definition - assert compiled_function == FunctionDefinition( - name="", - parameters={"type": "object", "properties": {"x": {"type": "null"}}, "required": ["x"]}, - ) - - def test_compile_inline_workflow_function_definition(): class MyNode(BaseNode): pass @@ -383,8 +372,8 @@ class MyWorkflow(BaseWorkflow[MyInputs, BaseState]): "b": {"type": "integer"}, "c": {"type": "number"}, "d": {"type": "boolean"}, - "e": {"type": "array"}, - "f": {"type": "object"}, + "e": {"type": "array", "items": {}}, + "f": 
{"type": "object", "additionalProperties": True}, }, "required": ["a", "b", "c", "d", "e", "f"], }, @@ -622,7 +611,7 @@ def my_function(a: Literal[MyEnum.FOO, MyEnum.BAR]): compiled_function = compile_function_definition(my_function) assert isinstance(compiled_function.parameters, dict) - assert compiled_function.parameters["properties"]["a"] == {"enum": [MyEnum.FOO, MyEnum.BAR]} + assert compiled_function.parameters["properties"]["a"] == {"enum": ["foo", "bar"], "type": "string"} def test_compile_function_definition__annotated_descriptions(): @@ -770,8 +759,8 @@ def my_function_with_string_annotations( "b": {"type": "integer"}, "c": {"type": "number"}, "d": {"type": "boolean"}, - "e": {"type": "array"}, - "f": {"type": "object"}, + "e": {"type": "array", "items": {}}, + "f": {"type": "object", "additionalProperties": True}, "g": {"type": "null"}, }, "required": ["a", "b", "c", "d", "e", "f", "g"],