Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,19 @@
"name": "text",
"type": "STRING",
"value": null,
"schema": {"type": "string"}
"schema": { "type": "string" }
},
{
"id": "d278b1bc-dd58-4da0-bf58-84b03b1f5438",
"name": "chat_history",
"type": "CHAT_HISTORY",
"value": null,
"schema": {"type": "array", "items": {"$ref": "#/$defs/vellum.client.types.chat_message.ChatMessage"}}
"schema": {
"type": "array",
"items": {
"$ref": "#/$defs/vellum.client.types.chat_message.ChatMessage"
}
}
}
]
}
Expand Down
183 changes: 83 additions & 100 deletions src/vellum/workflows/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@
Callable,
ForwardRef,
List,
Literal,
Optional,
Type,
Union,
get_args,
get_origin,
)

from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from pydantic import BaseModel, TypeAdapter
from pydantic.json_schema import CoreModeRef, CoreRef, DefsRef, GenerateJsonSchema, JsonRef, JsonSchemaValue
from pydash import snake_case

from vellum import Vellum
Expand All @@ -38,26 +37,72 @@
if TYPE_CHECKING:
from vellum.workflows.workflows.base import BaseWorkflow

type_map: dict[Any, str] = {
str: "string",
int: "integer",
float: "number",
bool: "boolean",
list: "array",
dict: "object",
None: "null",
type(None): "null",
inspect._empty: "null",
"None": "null",
}

for k, v in list(type_map.items()):
if isinstance(k, type):
type_map[k.__name__] = v


def _get_def_name(annotation: Type) -> str:
return f"{annotation.__module__}.{annotation.__qualname__}"
def _strip_title_fields(schema: Any) -> Any:
"""Recursively remove 'title' fields from JSON schema."""
if isinstance(schema, dict):
result = {}
for key, value in schema.items():
if key != "title":
result[key] = _strip_title_fields(value)
return result
elif isinstance(schema, list):
return [_strip_title_fields(item) for item in schema]
else:
return schema


class ModulePathJsonSchemaGenerator(GenerateJsonSchema):
"""Custom JSON schema generator that includes full module paths in $ref references."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Map from original defs_ref to module path
self._defs_ref_to_module_path: dict[str, str] = {}

def get_cache_defs_ref_schema(self, core_ref: CoreRef) -> tuple[DefsRef, JsonSchemaValue]:
"""
Override to store mapping but keep Pydantic's internal key structure.

Args:
core_ref: The core reference (e.g., "vellum.client.types.chat_message.ChatMessage:41224715280")

Returns:
Tuple of (defs_ref, json_schema)
"""
# Get the default defs_ref and schema (Pydantic will use its sanitized key)
original_defs_ref, json_schema = super().get_cache_defs_ref_schema(core_ref)

# Extract the module path from core_ref and store the mapping
core_ref_str = str(core_ref)
if ":" in core_ref_str:
model_path = core_ref_str.split(":")[0]
self._defs_ref_to_module_path[original_defs_ref] = model_path
# Update the json_schema ref to use the full module path with dots
if "$ref" in json_schema:
json_schema["$ref"] = f"#/$defs/{model_path}"
# Update the json_to_defs_refs mapping to point to the original defs_ref
# (which is the sanitized key in definitions)
json_ref: JsonRef = JsonRef(f"#/$defs/{model_path}")
self.json_to_defs_refs[json_ref] = original_defs_ref

return original_defs_ref, json_schema

def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef:
"""
Override to return the full module path instead of just the model name.

Args:
core_mode_ref: Tuple containing (core_ref, mode)

Returns:
The full module-qualified model name
"""
# Get the default defs_ref
original_defs_ref = super().get_defs_ref(core_mode_ref)
# Return the module path if we have a mapping, otherwise return the original
module_path = self._defs_ref_to_module_path.get(original_defs_ref, original_defs_ref)
return DefsRef(module_path)


def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
Expand All @@ -70,89 +115,27 @@ def compile_annotation(annotation: Optional[Any], defs: dict[str, Any]) -> dict:
if annotation is datetime:
return {"type": "string", "format": "date-time"}

if get_origin(annotation) is Union:
if is_json_type(get_args(annotation)):
return {"$ref": "#/$defs/vellum.workflows.types.core.Json"}

return {"anyOf": [compile_annotation(a, defs) for a in get_args(annotation)]}

if get_origin(annotation) is Literal:
values = list(get_args(annotation))
types = {type(value) for value in values}
if len(types) == 1:
value_type = types.pop()
if value_type in type_map:
return {"type": type_map[value_type], "enum": values}
else:
return {"enum": values}
else:
return {"enum": values}

if get_origin(annotation) is dict:
_, value_type = get_args(annotation)
return {"type": "object", "additionalProperties": compile_annotation(value_type, defs)}

if get_origin(annotation) is list:
item_type = get_args(annotation)[0]
return {"type": "array", "items": compile_annotation(item_type, defs)}

if get_origin(annotation) is tuple:
args = get_args(annotation)
if len(args) == 2 and args[1] is Ellipsis:
# Tuple[int, ...] with homogeneous items
return {"type": "array", "items": compile_annotation(args[0], defs)}
else:
# Tuple[int, str] with fixed length items
result = {
"type": "array",
"prefixItems": [compile_annotation(arg, defs) for arg in args],
"minItems": len(args),
"maxItems": len(args),
}
return result

if dataclasses.is_dataclass(annotation) and isinstance(annotation, type):
def_name = _get_def_name(annotation)
if def_name not in defs:
properties = {}
required = []
for field in dataclasses.fields(annotation):
properties[field.name] = compile_annotation(field.type, defs)
if field.default is dataclasses.MISSING:
required.append(field.name)
else:
properties[field.name]["default"] = _compile_default_value(field.default)
defs[def_name] = {"type": "object", "properties": properties, "required": required}
return {"$ref": f"#/$defs/{def_name}"}

if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
def_name = _get_def_name(annotation)
if def_name not in defs:
properties = {}
required = []
for field_name, field_info in annotation.model_fields.items():
# field_info is a FieldInfo object which has an annotation attribute
properties[field_name] = compile_annotation(field_info.annotation, defs)

if field_info.description is not None:
properties[field_name]["description"] = field_info.description

if field_info.default is PydanticUndefined:
required.append(field_name)
else:
properties[field_name]["default"] = _compile_default_value(field_info.default)
defs[def_name] = {"type": "object", "properties": properties, "required": required}

return {"$ref": f"#/$defs/{def_name}"}

# Handle ForwardRef early (before trying to use Pydantic)
if type(annotation) is ForwardRef:
# Ignore forward references for now
return {}

if annotation not in type_map:
raise ValueError(f"Failed to compile type: {annotation}")
# Handle Union types - check for Json type special case
if get_origin(annotation) is Union:
if is_json_type(get_args(annotation)):
return {"$ref": "#/$defs/vellum.workflows.types.core.Json"}

return {"type": type_map[annotation]}
try:
adapter = TypeAdapter(annotation)
schema = adapter.json_schema(
mode="serialization", ref_template="#/$defs/{model}", schema_generator=ModulePathJsonSchemaGenerator
)
schema_defs = schema.pop("$defs", {})
schema_defs = {k: _strip_title_fields(v) for k, v in schema_defs.items()}
defs.update(schema_defs)
return _strip_title_fields(schema)
except Exception as e:
raise ValueError(f"Failed to compile type: {annotation}: {e}")


def _compile_default_value(default: Any) -> Any:
Expand Down
Loading