Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 27 additions & 100 deletions src/vellum/workflows/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
Callable,
ForwardRef,
List,
Literal,
Optional,
Type,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from pydantic import BaseModel, TypeAdapter
from pydash import snake_case

from vellum import Vellum
Expand All @@ -38,26 +36,19 @@
if TYPE_CHECKING:
from vellum.workflows.workflows.base import BaseWorkflow

type_map: dict[Any, str] = {
str: "string",
int: "integer",
float: "number",
bool: "boolean",
list: "array",
dict: "object",
None: "null",
type(None): "null",
inspect._empty: "null",
"None": "null",
}

for k, v in list(type_map.items()):
if isinstance(k, type):
type_map[k.__name__] = v


def _get_def_name(annotation: Type) -> str:
return f"{annotation.__module__}.{annotation.__qualname__}"
def _strip_title_fields(schema: Any) -> Any:
"""Recursively remove 'title' fields from JSON schema."""
if isinstance(schema, dict):
result = {}
for key, value in schema.items():
if key != "title":
result[key] = _strip_title_fields(value)
return result
elif isinstance(schema, list):
return [_strip_title_fields(item) for item in schema]
else:
return schema


def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
Expand All @@ -70,89 +61,25 @@ def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
if annotation is datetime:
return {"type": "string", "format": "date-time"}

if get_origin(annotation) is Union:
if is_json_type(get_args(annotation)):
return {"$ref": "#/$defs/vellum.workflows.types.core.Json"}

return {"anyOf": [compile_annotation(a, defs) for a in get_args(annotation)]}

if get_origin(annotation) is Literal:
values = list(get_args(annotation))
types = {type(value) for value in values}
if len(types) == 1:
value_type = types.pop()
if value_type in type_map:
return {"type": type_map[value_type], "enum": values}
else:
return {"enum": values}
else:
return {"enum": values}

if get_origin(annotation) is dict:
_, value_type = get_args(annotation)
return {"type": "object", "additionalProperties": compile_annotation(value_type, defs)}

if get_origin(annotation) is list:
item_type = get_args(annotation)[0]
return {"type": "array", "items": compile_annotation(item_type, defs)}

if get_origin(annotation) is tuple:
args = get_args(annotation)
if len(args) == 2 and args[1] is Ellipsis:
# Tuple[int, ...] with homogeneous items
return {"type": "array", "items": compile_annotation(args[0], defs)}
else:
# Tuple[int, str] with fixed length items
result = {
"type": "array",
"prefixItems": [compile_annotation(arg, defs) for arg in args],
"minItems": len(args),
"maxItems": len(args),
}
return result

if dataclasses.is_dataclass(annotation) and isinstance(annotation, type):
def_name = _get_def_name(annotation)
if def_name not in defs:
properties = {}
required = []
for field in dataclasses.fields(annotation):
properties[field.name] = compile_annotation(field.type, defs)
if field.default is dataclasses.MISSING:
required.append(field.name)
else:
properties[field.name]["default"] = _compile_default_value(field.default)
defs[def_name] = {"type": "object", "properties": properties, "required": required}
return {"$ref": f"#/$defs/{def_name}"}

if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
def_name = _get_def_name(annotation)
if def_name not in defs:
properties = {}
required = []
for field_name, field_info in annotation.model_fields.items():
# field_info is a FieldInfo object which has an annotation attribute
properties[field_name] = compile_annotation(field_info.annotation, defs)

if field_info.description is not None:
properties[field_name]["description"] = field_info.description

if field_info.default is PydanticUndefined:
required.append(field_name)
else:
properties[field_name]["default"] = _compile_default_value(field_info.default)
defs[def_name] = {"type": "object", "properties": properties, "required": required}

return {"$ref": f"#/$defs/{def_name}"}

# Handle ForwardRef early (before trying to use Pydantic)
if type(annotation) is ForwardRef:
# Ignore forward references for now
return {}

if annotation not in type_map:
raise ValueError(f"Failed to compile type: {annotation}")
# Handle Union types - check for Json type special case
if get_origin(annotation) is Union:
if is_json_type(get_args(annotation)):
return {"$ref": "#/$defs/vellum.workflows.types.core.Json"}

return {"type": type_map[annotation]}
try:
adapter = TypeAdapter(annotation)
schema = adapter.json_schema(mode="serialization", ref_template="#/$defs/{model}")

This comment was marked as outdated.

schema_defs = schema.pop("$defs", {})
schema_defs = {k: _strip_title_fields(v) for k, v in schema_defs.items()}
defs.update(schema_defs)
return _strip_title_fields(schema)
except Exception:
raise ValueError(f"Failed to compile type: {annotation}")


def _compile_default_value(default: Any) -> Any:
Expand Down
77 changes: 33 additions & 44 deletions src/vellum/workflows/utils/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def my_function(a: str, b: int, c: float, d: bool, e: list, f: dict):
"b": {"type": "integer"},
"c": {"type": "number"},
"d": {"type": "boolean"},
"e": {"type": "array"},
"f": {"type": "object"},
"e": {"type": "array", "items": {}},
"f": {"type": "object", "additionalProperties": True},
},
"required": ["a", "b", "c", "d", "e", "f"],
},
Expand Down Expand Up @@ -187,20 +187,21 @@ def my_function(c: MyDataClass):
compiled_function = compile_function_definition(my_function)

# THEN it should return the compiled function definition
ref_name = f"{__name__}.test_compile_function_definition__dataclasses.<locals>.MyDataClass"
assert compiled_function == FunctionDefinition(
name="my_function",
parameters={
"type": "object",
"properties": {"c": {"$ref": f"#/$defs/{ref_name}"}},
"required": ["c"],
"$defs": {
ref_name: {
"properties": {
"c": {
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "string"}},
"properties": {
"a": {"type": "integer"},
"b": {"type": "string"},
},
"required": ["a", "b"],
}
},
"required": ["c"],
},
)

Expand All @@ -218,15 +219,12 @@ def my_function(c: MyPydanticModel):
compiled_function = compile_function_definition(my_function)

# THEN it should return the compiled function definition
ref_name = f"{__name__}.test_compile_function_definition__pydantic.<locals>.MyPydanticModel"
assert compiled_function == FunctionDefinition(
name="my_function",
parameters={
"type": "object",
"properties": {"c": {"$ref": f"#/$defs/{ref_name}"}},
"required": ["c"],
"$defs": {
ref_name: {
"properties": {
"c": {
"type": "object",
"properties": {
"a": {"type": "integer", "description": "The first number"},
Expand All @@ -235,6 +233,7 @@ def my_function(c: MyPydanticModel):
"required": ["a", "b"],
}
},
"required": ["c"],
},
)

Expand All @@ -253,20 +252,22 @@ def my_function(c: MyDataClass = MyDataClass(a=1, b="hello")):
compiled_function = compile_function_definition(my_function)

# THEN it should return the compiled function definition
ref_name = f"{__name__}.test_compile_function_definition__default_dataclass.<locals>.MyDataClass"
assert compiled_function == FunctionDefinition(
name="my_function",
parameters={
"type": "object",
"properties": {"c": {"$ref": f"#/$defs/{ref_name}", "default": {"a": 1, "b": "hello"}}},
"required": [],
"$defs": {
ref_name: {
"properties": {
"c": {
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "string"}},
"properties": {
"a": {"type": "integer"},
"b": {"type": "string"},
},
"required": ["a", "b"],
"default": {"a": 1, "b": "hello"},
}
},
"required": [],
},
)

Expand All @@ -284,38 +285,26 @@ def my_function(c: MyPydanticModel = MyPydanticModel(a=1, b="hello")):
compiled_function = compile_function_definition(my_function)

# THEN it should return the compiled function definition
ref_name = f"{__name__}.test_compile_function_definition__default_pydantic.<locals>.MyPydanticModel"
assert compiled_function == FunctionDefinition(
name="my_function",
parameters={
"type": "object",
"properties": {"c": {"$ref": f"#/$defs/{ref_name}", "default": {"a": 1, "b": "hello"}}},
"required": [],
"$defs": {
ref_name: {
"properties": {
"c": {
"type": "object",
"properties": {"a": {"type": "integer"}, "b": {"type": "string"}},
"properties": {
"a": {"type": "integer"},
"b": {"type": "string"},
},
"required": ["a", "b"],
"default": {"a": 1, "b": "hello"},
}
},
"required": [],
},
)


def test_compile_function_definition__lambda():
# GIVEN a lambda
lambda_function = lambda x: x + 1 # noqa: E731

# WHEN compiling the function
compiled_function = compile_function_definition(lambda_function)

# THEN it should return the compiled function definition
assert compiled_function == FunctionDefinition(
name="<lambda>",
parameters={"type": "object", "properties": {"x": {"type": "null"}}, "required": ["x"]},
)


def test_compile_inline_workflow_function_definition():
class MyNode(BaseNode):
pass
Expand Down Expand Up @@ -383,8 +372,8 @@ class MyWorkflow(BaseWorkflow[MyInputs, BaseState]):
"b": {"type": "integer"},
"c": {"type": "number"},
"d": {"type": "boolean"},
"e": {"type": "array"},
"f": {"type": "object"},
"e": {"type": "array", "items": {}},
"f": {"type": "object", "additionalProperties": True},
},
"required": ["a", "b", "c", "d", "e", "f"],
},
Expand Down Expand Up @@ -622,7 +611,7 @@ def my_function(a: Literal[MyEnum.FOO, MyEnum.BAR]):

compiled_function = compile_function_definition(my_function)
assert isinstance(compiled_function.parameters, dict)
assert compiled_function.parameters["properties"]["a"] == {"enum": [MyEnum.FOO, MyEnum.BAR]}
assert compiled_function.parameters["properties"]["a"] == {"enum": ["foo", "bar"], "type": "string"}


def test_compile_function_definition__annotated_descriptions():
Expand Down Expand Up @@ -770,8 +759,8 @@ def my_function_with_string_annotations(
"b": {"type": "integer"},
"c": {"type": "number"},
"d": {"type": "boolean"},
"e": {"type": "array"},
"f": {"type": "object"},
"e": {"type": "array", "items": {}},
"f": {"type": "object", "additionalProperties": True},
"g": {"type": "null"},
},
"required": ["a", "b", "c", "d", "e", "f", "g"],
Expand Down
Loading