diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..185d634a6b --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,77 @@ +Release type: minor + +Add first-class support for Pydantic v2+ models in Strawberry GraphQL. + +This release introduces a new `strawberry.pydantic` module that allows you to directly +decorate Pydantic `BaseModel` classes to create GraphQL types, inputs, and interfaces +without requiring separate wrapper classes. + +## Basic Usage + +```python +import strawberry +from pydantic import BaseModel + + +@strawberry.pydantic.type +class User(BaseModel): + name: str + age: int + + +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + age: int +``` + +## Features + +- `@strawberry.pydantic.type` - Convert Pydantic models to GraphQL types +- `@strawberry.pydantic.input` - Convert Pydantic models to GraphQL input types +- `@strawberry.pydantic.interface` - Convert Pydantic models to GraphQL interfaces +- Automatic field extraction from Pydantic models +- Pydantic field descriptions preserved in GraphQL schema +- Pydantic field aliases used as GraphQL field names +- Support for `strawberry.Private` to exclude fields from schema +- Support for `strawberry.field()` with `Annotated` for directives, permissions, deprecation +- Generic Pydantic model support +- `strawberry.pydantic.Error` type for validation error handling in union return types + +## Migration from Experimental + +The experimental `strawberry.experimental.pydantic` integration is now deprecated. +See the documentation for migration guide. + +--- + +## TODO: Improvements Needed Before/After Release + +### Documentation Fixes +- [x] Remove or implement `from_pydantic()`/`to_pydantic()` methods mentioned in docs +- [x] Update test snapshots for Pydantic version (2.11 → 2.12) - made assertions version-agnostic + +### Code Quality +- [x] Replace `typing._GenericAlias` private API usage - now uses `is_generic_alias` utility +- [x] Narrow exception handling in `schema_converter.py` - refactored to `_maybe_convert_validation_error` +- [x] Add computed field tests to verify `include_computed` works with `@computed_field` + +### Future Pydantic v2 Features to Implement + +#### High Priority +- [ ] Validation context - Pass `strawberry.Info` to Pydantic validators +- [ ] Honor `model_config` settings (`strict`, `extra='forbid'`, etc.) +- [ ] Expose `@field_validator` errors in a structured way + +#### Medium Priority +- [ ] Discriminated unions - Use `__typename` as discriminator +- [ ] `@model_validator` support for cross-field validation +- [ ] Separate input/output aliases (`validation_alias` vs `serialization_alias`) +- [ ] Strict mode configuration per field/type + +#### Low Priority +- [ ] `TypeAdapter` support for non-model types +- [ ] `RootModel` support for custom scalar-like types +- [ ] Custom serializers via `@model_serializer` +- [ ] JSON schema generation integration +- [ ] Functional validators (`BeforeValidator`, `AfterValidator`, etc.) diff --git a/docs/integrations/pydantic.md b/docs/integrations/pydantic.md index e1c1967772..46ccdc7d13 100644 --- a/docs/integrations/pydantic.md +++ b/docs/integrations/pydantic.md @@ -1,85 +1,655 @@ --- title: Pydantic support -experimental: true --- # Pydantic support -Strawberry comes with support for -[Pydantic](https://pydantic-docs.helpmanual.io/). This allows for the creation -of Strawberry types from pydantic models without having to write code twice. +Strawberry provides first-class support for [Pydantic](https://pydantic.dev/) +models, allowing you to directly decorate your Pydantic `BaseModel` classes to +create GraphQL types without writing code twice. -Here's a basic example of how this works, let's say we have a pydantic Model for -a user, like this: +## Installation + +```bash +pip install strawberry-graphql[pydantic] +``` + +## Basic Usage + +The simplest way to use Pydantic with Strawberry is to decorate your Pydantic +models directly: ```python -from datetime import datetime -from typing import List, Optional +import strawberry from pydantic import BaseModel +@strawberry.pydantic.type class User(BaseModel): id: int name: str - signup_ts: Optional[datetime] = None - friends: List[int] = [] + email: str + + +@strawberry.type +class Query: + @strawberry.field + def get_user(self) -> User: + return User(id=1, name="John", email="john@example.com") + + +schema = strawberry.Schema(query=Query) ``` -We can create a Strawberry type by using the -`strawberry.experimental.pydantic.type` decorator: +This automatically creates a GraphQL type that includes all fields from your +Pydantic model. + +## Type Decorators + +### `@strawberry.pydantic.type` + +Creates a GraphQL object type from a Pydantic model: + +```python +@strawberry.pydantic.type +class User(BaseModel): + name: str + age: int + is_active: bool = True +``` + +### `@strawberry.pydantic.input` + +Creates a GraphQL input type from a Pydantic model: + +```python +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + age: int + email: str + + +@strawberry.type +class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) +``` + +### `@strawberry.pydantic.interface` + +Creates a GraphQL interface from a Pydantic model: + +```python +@strawberry.pydantic.interface +class Node(BaseModel): + id: str + + +@strawberry.pydantic.type +class User(BaseModel): + id: str + name: str + # User implements Node interface +``` + +## Configuration Options + +All decorators accept optional configuration parameters: + +```python +@strawberry.pydantic.type( + name="CustomUser", # Override the GraphQL type name + description="A user in the system", # Add type description +) +class User(BaseModel): + name: str = Field(alias="fullName") + age: int +``` + +## Field Features + +### Field Descriptions + +Pydantic field descriptions are automatically preserved in the GraphQL schema: + +```python +from pydantic import Field + + +@strawberry.pydantic.type +class User(BaseModel): + name: str = Field(description="The user's full name") + age: int = Field(description="The user's age in years") +``` + +### Field Aliases + +Pydantic field aliases are automatically used as GraphQL field names: + +```python +@strawberry.pydantic.type +class User(BaseModel): + name: str = Field(alias="fullName") + age: int = Field(alias="yearsOld") +``` + +### Optional Fields + +Pydantic optional fields are properly handled: + +```python +from typing import Optional + + +@strawberry.pydantic.type +class User(BaseModel): + name: str + email: Optional[str] = None + age: Optional[int] = None +``` + +### Private Fields + +You can use `strawberry.Private` to mark fields that should not be exposed in +the GraphQL schema but are still accessible in your Python code: ```python import strawberry -from .models import User +@strawberry.pydantic.type +class User(BaseModel): + id: int + name: str + password: strawberry.Private[str] # Not exposed in GraphQL + email: str +``` -@strawberry.experimental.pydantic.type(model=User) -class UserType: - id: strawberry.auto - name: strawberry.auto - friends: strawberry.auto +This generates a GraphQL schema with only the public fields: + +```graphql +type User { + id: Int! + name: String! + email: String! +} +``` + +The private fields are still accessible in Python code for use in resolvers or +business logic: + +```python +@strawberry.type +class Query: + @strawberry.field + def get_user(self) -> User: + user = User(id=1, name="John", password="secret", email="john@example.com") + # Can access private field in Python + if user.password: + return user + return None +``` + +## Advanced Usage + +### Nested Types + +Pydantic models can contain other Pydantic models: + +```python +@strawberry.pydantic.type +class Address(BaseModel): + street: str + city: str + zipcode: str + + +@strawberry.pydantic.type +class User(BaseModel): + name: str + address: Address +``` + +### Lists and Collections + +Lists of Pydantic models work seamlessly: + +```python +from typing import List + + +@strawberry.pydantic.type +class User(BaseModel): + name: str + age: int + + +@strawberry.type +class Query: + @strawberry.field + def get_users(self) -> List[User]: + return [User(name="John", age=30), User(name="Jane", age=25)] +``` + +### Validation + +Pydantic validation is automatically applied to input types. Strawberry supports +all Pydantic v2 validation features including field validators, model +validators, and functional validators. + +#### Field Validators + +```python +from pydantic import field_validator + + +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + age: int + + @field_validator("age") + @classmethod + def validate_age(cls, v: int) -> int: + if v < 0: + raise ValueError("Age must be non-negative") + return v +``` + +#### Model Validators + +Cross-field validation using `@model_validator`: + +```python +from pydantic import model_validator + + +@strawberry.pydantic.input +class DateRangeInput(BaseModel): + start_date: date + end_date: date + + @model_validator(mode="after") + def check_dates(self) -> "DateRangeInput": + if self.start_date > self.end_date: + raise ValueError("start_date must be before end_date") + return self +``` + +#### Functional Validators + +Reusable validation with `Annotated` types: + +```python +from typing import Annotated +from pydantic import AfterValidator + + +def validate_email(v: str) -> str: + if "@" not in v: + raise ValueError("Invalid email") + return v.lower() + + +Email = Annotated[str, AfterValidator(validate_email)] + + +@strawberry.pydantic.input +class UserInput(BaseModel): + email: Email # Validator runs during GraphQL input processing ``` -The `strawberry.experimental.pydantic.type` decorator accepts a Pydantic model -and wraps a class that contains dataclass style fields with `strawberry.auto` as -the type annotation. The fields marked with `strawberry.auto` will inherit their -types from the Pydantic model. +#### Validation Context + +Strawberry automatically passes GraphQL context to Pydantic validators, allowing +access to request information, user authentication, database sessions, etc: + +```python +from pydantic import field_validator, ValidationInfo + + +@strawberry.pydantic.input +class CreatePostInput(BaseModel): + title: str + + @field_validator("title") + @classmethod + def check_permissions(cls, v: str, info: ValidationInfo) -> str: + # Access GraphQL context passed during validation + strawberry_info = info.context.get("info") if info.context else None + if strawberry_info: + user = strawberry_info.context.get("user") + if user and not user.can_create_posts: + raise ValueError("User cannot create posts") + return v +``` -If you want to include all of the fields from your Pydantic model, you can -instead pass `all_fields=True` to the decorator. +### Model Config --> **Note** Care should be taken to avoid accidentally exposing fields that -> -weren't meant to be exposed on an API using this feature. +Pydantic's `model_config` settings are respected during validation: ```python +from pydantic import ConfigDict + + +@strawberry.pydantic.input +class StrictUserInput(BaseModel): + model_config = ConfigDict(strict=True, extra="forbid") + + age: int # Will NOT accept "25" as string + name: str +``` + +#### Per-Field Strict Mode + +You can also enable strict mode on individual fields: + +```python +from pydantic import Field + + +@strawberry.pydantic.input +class UserInput(BaseModel): + age: int = Field(strict=True) # Must be int, not "25" + name: str # Normal coercion allowed +``` + +### Aliases + +Pydantic field aliases are supported for both input and output types: + +```python +from pydantic import Field + + +@strawberry.pydantic.type +class User(BaseModel): + user_id: int = Field(alias="userId") + full_name: str = Field(validation_alias="fullName") +``` + +### Field Directives and Customization + +You can use `strawberry.field()` with `Annotated` types to add GraphQL-specific +features like directives, permissions, and deprecation to individual Pydantic +model fields: + +```python +from typing import Annotated +from pydantic import BaseModel, Field import strawberry -from .models import User +@strawberry.schema_directive( + locations=[strawberry.schema_directive.Location.FIELD_DEFINITION] +) +class Sensitive: + reason: str -@strawberry.experimental.pydantic.type(model=User, all_fields=True) -class UserType: + +@strawberry.schema_directive( + locations=[strawberry.schema_directive.Location.FIELD_DEFINITION] +) +class Range: + min: int + max: int + + +@strawberry.pydantic.type +class User(BaseModel): + # Regular field - uses Pydantic description + name: Annotated[str, Field(description="The user's full name")] + + # Field with directive + email: Annotated[str, strawberry.field(directives=[Sensitive(reason="PII")])] + + # Field with multiple directives and Pydantic features + age: Annotated[ + int, + Field(alias="userAge", description="User's age"), + strawberry.field(directives=[Range(min=0, max=150)]), + ] + + # Field with permissions + phone: Annotated[ + str, + strawberry.field( + permission_classes=[IsAuthenticated], + directives=[Sensitive(reason="Contact Info")], + ), + ] + + # Deprecated field + old_id: Annotated[int, strawberry.field(deprecation_reason="Use 'id' instead")] +``` + +#### Field Customization Options + +When using `strawberry.field()` with Pydantic models, you can specify: + +- **`directives`**: List of GraphQL directives to apply to the field +- **`permission_classes`**: List of permission classes for field-level + authorization +- **`deprecation_reason`**: Mark a field as deprecated with a reason +- **`description`**: Override the Pydantic field description for GraphQL +- **`name`**: Override the GraphQL field name (takes precedence over Pydantic + aliases) + +#### Input Types with Directives + +Field directives work with input types too: + +```python +@strawberry.schema_directive( + locations=[strawberry.schema_directive.Location.INPUT_FIELD_DEFINITION] +) +class Validate: + pattern: str + + +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + email: Annotated[ + str, strawberry.field(directives=[Validate(pattern=r"^[^@]+@[^@]+\.[^@]+")]) + ] +``` + +### Discriminated Unions + +Pydantic discriminated unions using `Literal` types work seamlessly with +Strawberry's union type resolution: + +```python +from typing import Literal, Union +from pydantic import Field + + +@strawberry.pydantic.type +class Cat(BaseModel): + pet_type: Literal["cat"] + meow_volume: int + + +@strawberry.pydantic.type +class Dog(BaseModel): + pet_type: Literal["dog"] + bark_volume: int + + +@strawberry.type +class Query: + @strawberry.field + def pet(self) -> Cat | Dog: + return Cat(pet_type="cat", meow_volume=10) +``` + +The `Literal` fields are converted to the appropriate GraphQL scalar type +(String, Int, Boolean) and work as discriminators for union type resolution. + +### TypeAdapter and RootModel + +Pydantic's `TypeAdapter` and `RootModel` can be used in resolvers for additional +validation: + +```python +from pydantic import TypeAdapter, RootModel, Field +from typing import Annotated + +# Using TypeAdapter for scalar validation +PositiveInt = Annotated[int, Field(gt=0)] +positive_adapter = TypeAdapter(PositiveInt) + + +@strawberry.type +class Query: + @strawberry.field + def validate_positive(self, value: int) -> int: + return positive_adapter.validate_python(value) + + +# Using RootModel for list validation +class BoundedList(RootModel[Annotated[list[int], Field(min_length=1, max_length=5)]]): + pass + + +@strawberry.type +class Mutation: + @strawberry.mutation + def process_items(self, items: list[int]) -> int: + validated = BoundedList.model_validate(items) + return sum(validated.root) +``` + +## Migration from Experimental + +If you're using the experimental Pydantic integration, here's how to migrate: + +### Before (Experimental) + +```python +from strawberry.experimental.pydantic import type as pydantic_type + + +class UserModel(BaseModel): + name: str + age: int + + +@pydantic_type(UserModel, all_fields=True) +class User: pass ``` -By default, computed fields are excluded. To also include all computed fields -pass `include_computed=True` to the decorator. +### After (First-class) ```python +@strawberry.pydantic.type +class User(BaseModel): + name: str + age: int +``` + +## Complete Example + +```python +from pydantic import BaseModel, Field, field_validator +from typing import List, Optional import strawberry -from .models import User + +@strawberry.pydantic.type +class User(BaseModel): + id: int + name: str = Field(description="The user's full name") + email: str + age: int = Field(ge=0, description="The user's age in years") + is_active: bool = True + tags: List[str] = Field(default_factory=list) -@strawberry.experimental.pydantic.type( - model=User, all_fields=True, include_computed=True -) +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + email: str + age: int + tags: Optional[List[str]] = None + + @field_validator("age") + @classmethod + def validate_age(cls, v: int) -> int: + if v < 0: + raise ValueError("Age must be non-negative") + return v + + +@strawberry.type +class Query: + @strawberry.field + def get_user(self, id: int) -> Optional[User]: + return User( + id=id, + name="John Doe", + email="john@example.com", + age=30, + tags=["developer", "python"], + ) + + +@strawberry.type +class Mutation: + @strawberry.mutation + def create_user(self, input: CreateUserInput) -> User: + return User( + id=1, + name=input.name, + email=input.email, + age=input.age, + tags=input.tags or [], + ) + + +schema = strawberry.Schema(query=Query, mutation=Mutation) +``` + +--- + +# Experimental Pydantic Support (Deprecated) + +The experimental Pydantic integration is deprecated in favor of the first-class +support above. The experimental integration will be removed in a future version. + +## Experimental Usage + +The experimental integration required creating separate wrapper classes: + +```python +from strawberry.experimental.pydantic import type as pydantic_type + + +class UserModel(BaseModel): + id: int + name: str + signup_ts: Optional[datetime] = None + friends: List[int] = [] + + +@pydantic_type(model=UserModel) +class UserType: + id: strawberry.auto + name: strawberry.auto + friends: strawberry.auto + + +# Or include all fields +@pydantic_type(model=UserModel, all_fields=True) class UserType: pass ``` -## Input types +### Input types Input types are similar to types; we can create one by using the `strawberry.experimental.pydantic.input` decorator: diff --git a/strawberry/__init__.py b/strawberry/__init__.py index 3cedd7c9b8..db51564877 100644 --- a/strawberry/__init__.py +++ b/strawberry/__init__.py @@ -4,7 +4,7 @@ specification and allow for a more natural way of defining GraphQL schemas. """ -from . import experimental, federation, relay +from . import experimental, federation, pydantic, relay from .directive import directive, directive_field from .parent import Parent from .permission import BasePermission @@ -54,6 +54,7 @@ "interface", "lazy", "mutation", + "pydantic", "relay", "scalar", "schema_directive", diff --git a/strawberry/pydantic/__init__.py b/strawberry/pydantic/__init__.py new file mode 100644 index 0000000000..ea5bd7f81a --- /dev/null +++ b/strawberry/pydantic/__init__.py @@ -0,0 +1,22 @@ +"""Strawberry Pydantic integration. + +This module provides first-class support for Pydantic models in Strawberry GraphQL. +You can directly decorate Pydantic BaseModel classes to create GraphQL types. + +Example: + @strawberry.pydantic.type + class User(BaseModel): + name: str + age: int +""" + +from .error import Error +from .object_type import input as input_decorator +from .object_type import interface +from .object_type import type as type_decorator + +# Re-export with proper names +input = input_decorator +type = type_decorator + +__all__ = ["Error", "input", "interface", "type"] diff --git a/strawberry/pydantic/error.py b/strawberry/pydantic/error.py new file mode 100644 index 0000000000..77223e0335 --- /dev/null +++ b/strawberry/pydantic/error.py @@ -0,0 +1,51 @@ +"""Generic error type for Pydantic validation errors in Strawberry GraphQL. + +This module provides a generic Error type that can be used to represent +Pydantic validation errors in GraphQL responses. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from strawberry.types.object_type import type as strawberry_type + +if TYPE_CHECKING: + from pydantic import ValidationError + + +@strawberry_type +class ErrorDetail: + """Represents a single validation error detail.""" + + type: str + loc: list[str] + msg: str + + +@strawberry_type +class Error: + """Generic error type for Pydantic validation errors.""" + + errors: list[ErrorDetail] + + @staticmethod + def from_validation_error(exc: ValidationError) -> Error: + """Create an Error instance from a Pydantic ValidationError. + + Args: + exc: The Pydantic ValidationError to convert + + Returns: + An Error instance containing all validation errors + """ + return Error( + errors=[ + ErrorDetail( + type=error["type"], + loc=[str(loc) for loc in error["loc"]], + msg=error["msg"], + ) + for error in exc.errors() + ] + ) diff --git a/strawberry/pydantic/exceptions.py b/strawberry/pydantic/exceptions.py new file mode 100644 index 0000000000..0fa71ea882 --- /dev/null +++ b/strawberry/pydantic/exceptions.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pydantic import BaseModel + + +class UnregisteredTypeException(Exception): + def __init__(self, type: type[BaseModel]) -> None: + message = ( + f"Cannot find a Strawberry Type for {type} did you forget to register it?" + ) + + super().__init__(message) diff --git a/strawberry/pydantic/fields.py b/strawberry/pydantic/fields.py new file mode 100644 index 0000000000..ff0603ffe7 --- /dev/null +++ b/strawberry/pydantic/fields.py @@ -0,0 +1,222 @@ +"""Field processing utilities for Pydantic models in Strawberry GraphQL. + +This module provides functions to extract and process fields from Pydantic BaseModel +classes, converting them to StrawberryField instances that can be used in GraphQL schemas. +""" + +from __future__ import annotations + +import functools +import operator +import sys +from typing import TYPE_CHECKING, Any, get_args, get_origin + +from strawberry.annotation import StrawberryAnnotation +from strawberry.experimental.pydantic._compat import PydanticCompat +from strawberry.experimental.pydantic.utils import get_default_factory_for_field +from strawberry.types.field import StrawberryField +from strawberry.types.private import is_private +from strawberry.utils.typing import is_generic_alias, is_union + +from .exceptions import UnregisteredTypeException + +if TYPE_CHECKING: + from pydantic import BaseModel + + from strawberry.experimental.pydantic._compat import CompatModelField + +from strawberry.experimental.pydantic._compat import lenient_issubclass + + +def _extract_strawberry_field_from_annotation( + annotation: Any, +) -> StrawberryField | None: + """Extract StrawberryField from an Annotated type annotation. + + Args: + annotation: The type annotation, possibly Annotated[Type, strawberry.field(...)] + + Returns: + StrawberryField instance if found in annotation metadata, None otherwise + """ + # Check if this is an Annotated type + if hasattr(annotation, "__metadata__"): + # Look for StrawberryField in the metadata + for metadata_item in annotation.__metadata__: + if isinstance(metadata_item, StrawberryField): + return metadata_item + + return None + + +def replace_pydantic_types(type_: Any, is_input: bool) -> Any: + """Replace Pydantic types with their Strawberry equivalents for first-class integration.""" + from pydantic import BaseModel + + if lenient_issubclass(type_, BaseModel): + if hasattr(type_, "__strawberry_definition__"): + return type_ + + raise UnregisteredTypeException(type_) + + return type_ + + +def replace_types_recursively( + type_: Any, + is_input: bool, + compat: PydanticCompat, +) -> Any: + """Recursively replace Pydantic types with their Strawberry equivalents.""" + # For now, use a simpler approach similar to the experimental module + basic_type = compat.get_basic_type(type_) + replaced_type = replace_pydantic_types(basic_type, is_input) + + origin = get_origin(type_) + + if not origin or not hasattr(type_, "__args__"): + return replaced_type + + converted = tuple( + replace_types_recursively(t, is_input=is_input, compat=compat) + for t in get_args(replaced_type) + ) + + # Handle special cases for typing generics + if is_generic_alias(replaced_type): + # Use origin[converted] to reconstruct the generic type + return origin[converted] + if is_union(replaced_type): + # Use functools.reduce with operator.or_ to create X | Y | Z union type + return functools.reduce(operator.or_, converted) + + # Fallback to origin[converted] for standard generic types + return origin[converted] + + +def get_type_for_field( + field: CompatModelField, is_input: bool, compat: PydanticCompat +) -> Any: + """Get the GraphQL type for a Pydantic field.""" + return replace_types_recursively(field.outer_type_, is_input, compat=compat) + + +def _get_pydantic_fields( + cls: type[BaseModel], + original_type_annotations: dict[str, type[Any]], + is_input: bool = False, + include_computed: bool = False, +) -> list[StrawberryField]: + """Extract StrawberryFields from a Pydantic BaseModel class. + + This function processes a Pydantic BaseModel and extracts its fields, + converting them to StrawberryField instances that can be used in GraphQL schemas. + All fields from the Pydantic model are included by default, except those marked + with strawberry.Private. + + Fields can be customized using strawberry.field() overrides: + + @strawberry.pydantic.type + class User(BaseModel): + name: str + age: int = strawberry.field(directives=[SomeDirective()]) + + Args: + cls: The Pydantic BaseModel class to extract fields from + original_type_annotations: Type annotations that may override field types + is_input: Whether this is for an input type + include_computed: Whether to include computed fields + + Returns: + List of StrawberryField instances + """ + fields: list[StrawberryField] = [] + + # Get compatibility layer for this model + compat = PydanticCompat.from_model(cls) + + # Extract Pydantic model fields + model_fields = compat.get_model_fields(cls, include_computed=include_computed) + + # Get annotations from the class to check for strawberry.Private and strawberry.field() overrides + existing_annotations = getattr(cls, "__annotations__", {}) + + # Process each field from the Pydantic model + for field_name, pydantic_field in model_fields.items(): + # Check if this field is marked as private or has strawberry.field() metadata + strawberry_override = None + if field_name in existing_annotations: + field_annotation = existing_annotations[field_name] + + # Skip private fields - they shouldn't be included in GraphQL schema + if is_private(field_annotation): + continue + + # Check for strawberry.field() in Annotated metadata + strawberry_override = _extract_strawberry_field_from_annotation( + field_annotation + ) + + # Get the field type from the Pydantic model + field_type = get_type_for_field(pydantic_field, is_input, compat=compat) + + # Start with values from Pydantic field + graphql_name = pydantic_field.alias if pydantic_field.has_alias else None + description = pydantic_field.description + directives = [] + permission_classes = [] + extensions = [] + deprecation_reason = None + + # If there's a strawberry.field() override, merge its values + if strawberry_override: + # strawberry.field() overrides take precedence for GraphQL-specific settings + if strawberry_override.graphql_name is not None: + graphql_name = strawberry_override.graphql_name + if strawberry_override.description is not None: + description = strawberry_override.description + if strawberry_override.directives: + directives = list(strawberry_override.directives) + if strawberry_override.permission_classes: + permission_classes = list(strawberry_override.permission_classes) + if strawberry_override.extensions: + extensions = list(strawberry_override.extensions) + if strawberry_override.deprecation_reason is not None: + deprecation_reason = strawberry_override.deprecation_reason + + strawberry_field = StrawberryField( + python_name=field_name, + graphql_name=graphql_name, + type_annotation=StrawberryAnnotation.from_annotation(field_type), + description=description, + default_factory=get_default_factory_for_field( + pydantic_field, compat=compat + ), + directives=directives, + permission_classes=permission_classes, + extensions=extensions, + deprecation_reason=deprecation_reason, + ) + + # Set the origin module for proper type resolution + origin = cls + module = sys.modules[origin.__module__] + + if ( + isinstance(strawberry_field.type_annotation, StrawberryAnnotation) + and strawberry_field.type_annotation.namespace is None + ): + strawberry_field.type_annotation.namespace = module.__dict__ + + strawberry_field.origin = origin + + fields.append(strawberry_field) + + return fields + + +__all__ = [ + "_get_pydantic_fields", + "replace_pydantic_types", + "replace_types_recursively", +] diff --git a/strawberry/pydantic/object_type.py b/strawberry/pydantic/object_type.py new file mode 100644 index 0000000000..a4ce0a5ed8 --- /dev/null +++ b/strawberry/pydantic/object_type.py @@ -0,0 +1,322 @@ +"""Object type decorators for Pydantic models in Strawberry GraphQL. + +This module provides decorators to convert Pydantic BaseModel classes directly +into GraphQL types, inputs, and interfaces without requiring a separate wrapper class. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, overload + +if TYPE_CHECKING: + import builtins + from collections.abc import Callable, Sequence + + from graphql import GraphQLResolveInfo + from pydantic import BaseModel + +from strawberry.types.base import StrawberryObjectDefinition +from strawberry.types.cast import get_strawberry_type_cast +from strawberry.utils.str_converters import to_camel_case + +from .fields import _get_pydantic_fields + + +def _get_interfaces(cls: builtins.type[Any]) -> list[StrawberryObjectDefinition]: + """Extract interfaces from a class's inheritance hierarchy.""" + interfaces: list[StrawberryObjectDefinition] = [] + + for base in cls.__mro__[1:]: # Exclude current class + if hasattr(base, "__strawberry_definition__"): + type_definition = base.__strawberry_definition__ + if type_definition.is_interface: + interfaces.append(type_definition) + + return interfaces + + +def _process_pydantic_type( + cls: builtins.type[BaseModel], + *, + name: str | None = None, + is_input: bool = False, + is_interface: bool = False, + description: str | None = None, + directives: Sequence[object] | None = (), + include_computed: bool = False, +) -> builtins.type[BaseModel]: + """Process a Pydantic BaseModel class and add GraphQL metadata. + + Args: + cls: The Pydantic BaseModel class to process + name: The GraphQL type name (defaults to class name) + is_input: Whether this is an input type + is_interface: Whether this is an interface type + description: The GraphQL type description + directives: GraphQL directives to apply + include_computed: Whether to include computed fields + + Returns: + The processed BaseModel class with GraphQL metadata + """ + # Get the GraphQL type name + name = name or to_camel_case(cls.__name__) + + # Extract fields using our custom function + # All fields from the Pydantic model are included by default, except strawberry.Private fields + fields = _get_pydantic_fields( + cls=cls, + original_type_annotations={}, + is_input=is_input, + include_computed=include_computed, + ) + + # Get interfaces from inheritance hierarchy + interfaces = _get_interfaces(cls) + + # Create the is_type_of method for proper type resolution + def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool: + if (type_cast := get_strawberry_type_cast(obj)) is not None: + return type_cast is cls + return isinstance(obj, cls) + + # Create the GraphQL type definition + cls.__strawberry_definition__ = StrawberryObjectDefinition( # type: ignore + name=name, + is_input=is_input, + is_interface=is_interface, + interfaces=interfaces, + description=description, + directives=directives, + origin=cls, + extend=False, + fields=fields, + is_type_of=is_type_of, + resolve_type=getattr(cls, "resolve_type", None), + ) + + # Add the is_type_of method to the class for testing purposes + cls.is_type_of = is_type_of # type: ignore + + return cls + + +@overload +def type( + cls: builtins.type[BaseModel], + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), + include_computed: bool = False, +) -> builtins.type[BaseModel]: ... + + +@overload +def type( + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), + include_computed: bool = False, +) -> Callable[[builtins.type[BaseModel]], builtins.type[BaseModel]]: ... + + +def type( + cls: builtins.type[BaseModel] | None = None, + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), + include_computed: bool = False, +) -> ( + builtins.type[BaseModel] + | Callable[[builtins.type[BaseModel]], builtins.type[BaseModel]] +): + """Decorator to convert a Pydantic BaseModel directly into a GraphQL type. + + This decorator allows you to use Pydantic models directly as GraphQL types + without needing to create a separate wrapper class. + + Args: + cls: The Pydantic BaseModel class to convert + name: The GraphQL type name (defaults to class name) + description: The GraphQL type description + directives: GraphQL directives to apply to the type + include_computed: Whether to include computed fields + + Returns: + The decorated BaseModel class with GraphQL metadata + + Example: + @strawberry.pydantic.type + class User(BaseModel): + name: str + age: int + + # All fields from the Pydantic model will be included in the GraphQL type + + # You can also use strawberry.field() for field-level customization: + @strawberry.pydantic.type + class User(BaseModel): + name: str + age: int = strawberry.field(directives=[SomeDirective()]) + """ + + def wrap(cls: builtins.type[BaseModel]) -> builtins.type[BaseModel]: + return _process_pydantic_type( + cls, + name=name, + is_input=False, + is_interface=False, + description=description, + directives=directives, + include_computed=include_computed, + ) + + if cls is None: + return wrap + + return wrap(cls) + + +@overload +def input( + cls: builtins.type[BaseModel], + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), +) -> builtins.type[BaseModel]: ... + + +@overload +def input( + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), +) -> Callable[[builtins.type[BaseModel]], builtins.type[BaseModel]]: ... + + +def input( + cls: builtins.type[BaseModel] | None = None, + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), +) -> ( + builtins.type[BaseModel] + | Callable[[builtins.type[BaseModel]], builtins.type[BaseModel]] +): + """Decorator to convert a Pydantic BaseModel directly into a GraphQL input type. + + This decorator allows you to use Pydantic models directly as GraphQL input types + without needing to create a separate wrapper class. + + Args: + cls: The Pydantic BaseModel class to convert + name: The GraphQL input type name (defaults to class name) + description: The GraphQL input type description + directives: GraphQL directives to apply to the input type + + Returns: + The decorated BaseModel class with GraphQL input metadata + + Example: + @strawberry.pydantic.input + class CreateUserInput(BaseModel): + name: str + age: int + + # All fields from the Pydantic model will be included in the GraphQL input type + """ + + def wrap(cls: builtins.type[BaseModel]) -> builtins.type[BaseModel]: + return _process_pydantic_type( + cls, + name=name, + is_input=True, + is_interface=False, + description=description, + directives=directives, + include_computed=False, # Input types don't need computed fields + ) + + if cls is None: + return wrap + + return wrap(cls) + + +@overload +def interface( + cls: builtins.type[BaseModel], + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), + include_computed: bool = False, +) -> builtins.type[BaseModel]: ... + + +@overload +def interface( + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), + include_computed: bool = False, +) -> Callable[[builtins.type[BaseModel]], builtins.type[BaseModel]]: ... + + +def interface( + cls: builtins.type[BaseModel] | None = None, + *, + name: str | None = None, + description: str | None = None, + directives: Sequence[object] | None = (), + include_computed: bool = False, +) -> ( + builtins.type[BaseModel] + | Callable[[builtins.type[BaseModel]], builtins.type[BaseModel]] +): + """Decorator to convert a Pydantic BaseModel directly into a GraphQL interface. + + This decorator allows you to use Pydantic models directly as GraphQL interfaces + without needing to create a separate wrapper class. + + Args: + cls: The Pydantic BaseModel class to convert + name: The GraphQL interface name (defaults to class name) + description: The GraphQL interface description + directives: GraphQL directives to apply to the interface + include_computed: Whether to include computed fields + + Returns: + The decorated BaseModel class with GraphQL interface metadata + + Example: + @strawberry.pydantic.interface + class Node(BaseModel): + id: str + """ + + def wrap(cls: builtins.type[BaseModel]) -> builtins.type[BaseModel]: + return _process_pydantic_type( + cls, + name=name, + is_input=False, + is_interface=True, + description=description, + directives=directives, + include_computed=include_computed, + ) + + if cls is None: + return wrap + + return wrap(cls) + + +__all__ = ["input", "interface", "type"] diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 584fda8c33..b9e548dd07 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -17,20 +17,24 @@ from graphql import ( GraphQLAbstractType, GraphQLArgument, + GraphQLBoolean, GraphQLDirective, GraphQLEnumType, GraphQLEnumValue, GraphQLError, GraphQLField, + GraphQLFloat, GraphQLID, GraphQLInputField, GraphQLInputObjectType, + GraphQLInt, GraphQLInterfaceType, GraphQLList, GraphQLNamedType, GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLString, GraphQLType, GraphQLUnionType, Undefined, @@ -97,6 +101,11 @@ from strawberry.types.scalar import ScalarDefinition +def _is_literal_type(type_: Any) -> bool: + """Check if a type is a Literal type (e.g., Literal["cat"]).""" + return typing.get_origin(type_) is typing.Literal + + FieldType = TypeVar( "FieldType", bound=GraphQLField | GraphQLInputField, @@ -213,6 +222,7 @@ def get_arguments( field_arguments, scalar_registry=scalar_registry, config=config, + info=info, ) # the following code allows to omit info and root arguments @@ -344,6 +354,45 @@ def from_enum_value(self, enum_value: EnumValue) -> GraphQLEnumValue: }, ) + def from_literal(self, literal_type: type) -> GraphQLScalarType: + """Convert a Literal type to the appropriate GraphQL scalar type. + + Literal types are commonly used in Pydantic discriminated unions to identify + which union member a value belongs to. For example: + Literal["cat"] -> String (value "cat") + Literal[1] -> Int + Literal[True] -> Boolean + + We use scalar types rather than enums because: + 1. Different union members may have different Literal values for the same field + 2. GraphQL requires fields with the same name to have the same type in unions + 3. Scalars allow the discriminator pattern to work naturally + + Raises: + TypeError: If the Literal contains values that cannot be converted. + """ + args = typing.get_args(literal_type) + if not args: + raise TypeError("Literal type must have at least one value") + + # Get the type of the first argument to determine the scalar type + first_arg = args[0] + + if isinstance(first_arg, str): + return GraphQLString + if isinstance(first_arg, bool): + # bool must come before int since bool is a subclass of int in Python + return GraphQLBoolean + if isinstance(first_arg, int): + return GraphQLInt + if isinstance(first_arg, float): + return GraphQLFloat + + raise TypeError( + f"Unsupported Literal type: {args}. " + "Only string, int, float, and bool Literals are supported." + ) + def from_directive(self, directive: StrawberryDirective) -> GraphQLDirective: graphql_arguments = {} @@ -730,14 +779,22 @@ def extension_resolver( ) -> Any: # parse field arguments into Strawberry input types and convert # field names to Python equivalents - field_args, field_kwargs = get_arguments( - field=field, - source=_source, - info=info, - kwargs=kwargs, - config=self.config, - scalar_registry=self.scalar_registry, - ) + try: + field_args, field_kwargs = get_arguments( + field=field, + source=_source, + info=info, + kwargs=kwargs, + config=self.config, + scalar_registry=self.scalar_registry, + ) + except Exception as exc: + # Check if this is a Pydantic ValidationError that should be + # converted to a strawberry.pydantic.Error type + converted = self._maybe_convert_validation_error(exc, field) + if converted is not None: + return converted + raise resolver_requested_info = False if "info" in field_kwargs: @@ -798,6 +855,40 @@ async def _async_resolver( _resolver._is_default = not field.base_resolver # type: ignore return _resolver + def _maybe_convert_validation_error( + self, exc: Exception, field: StrawberryField + ) -> Any | None: + """Convert Pydantic ValidationError to strawberry.pydantic.Error if applicable. + + Returns the Error instance if conversion was successful, None otherwise. + """ + # Try to import Pydantic - it's an optional dependency + try: + from pydantic import ValidationError + except ImportError: + return None + + # Check if the exception is a Pydantic ValidationError + if not isinstance(exc, ValidationError): + return None + + # Check if the field's return type includes strawberry.pydantic.Error + from strawberry.types.union import StrawberryUnion + + field_type = field.type + if not isinstance(field_type, StrawberryUnion): + return None + + try: + from strawberry.pydantic import Error + + if any(union_type is Error for union_type in field_type.types): + return Error.from_validation_error(exc) + except ImportError: + pass + + return None + def from_scalar(self, scalar: type) -> GraphQLScalarType: from strawberry.relay.types import GlobalID @@ -901,6 +992,11 @@ def from_type(self, type_: StrawberryType | type) -> GraphQLNullableType: ): # TODO: Replace with StrawberryScalar return self.from_scalar(type_) + # Handle Literal types (e.g., Literal["cat"], Literal[1]) + # These are commonly used in Pydantic discriminated unions + if _is_literal_type(type_): + return self.from_literal(type_) + raise TypeError(f"Unexpected type '{type_}'") def from_union(self, union: StrawberryUnion) -> GraphQLUnionType: diff --git a/strawberry/types/arguments.py b/strawberry/types/arguments.py index dc63674ec2..7c8c03bfb1 100644 --- a/strawberry/types/arguments.py +++ b/strawberry/types/arguments.py @@ -33,6 +33,7 @@ from strawberry.schema.config import StrawberryConfig from strawberry.types.base import StrawberryType + from strawberry.types.info import Info from strawberry.types.scalar import ScalarDefinition, ScalarWrapper @@ -182,6 +183,7 @@ def convert_argument( type_: StrawberryType | type, scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], config: StrawberryConfig, + info: Info | None = None, ) -> object: from strawberry.relay.types import GlobalID @@ -191,19 +193,19 @@ def convert_argument( # Check if this is Maybe[T | None] (has StrawberryOptional as of_type) if isinstance(type_.of_type, StrawberryOptional): # This is Maybe[T | None] - allows null values - res = convert_argument(value, type_.of_type, scalar_registry, config) + res = convert_argument(value, type_.of_type, scalar_registry, config, info) return Some(res) # This is Maybe[T] - validation for null values is handled by MaybeNullValidationRule # Convert the value and wrap in Some() - res = convert_argument(value, type_.of_type, scalar_registry, config) + res = convert_argument(value, type_.of_type, scalar_registry, config, info) return Some(res) # Handle regular StrawberryOptional (not Maybe) if isinstance(type_, StrawberryOptional): - return convert_argument(value, type_.of_type, scalar_registry, config) + return convert_argument(value, type_.of_type, scalar_registry, config, info) if value is None: return None @@ -224,7 +226,7 @@ def convert_argument( value_list = cast("Iterable", value) return [ - convert_argument(x, type_.of_type, scalar_registry, config) + convert_argument(x, type_.of_type, scalar_registry, config, info) for x in value_list ] @@ -235,11 +237,13 @@ def convert_argument( return value if isinstance(type_, LazyType): - return convert_argument(value, type_.resolve_type(), scalar_registry, config) + return convert_argument( + value, type_.resolve_type(), scalar_registry, config, info + ) if has_enum_definition(type_): enum_definition: StrawberryEnumDefinition = type_.__strawberry_definition__ - return convert_argument(value, enum_definition, scalar_registry, config) + return convert_argument(value, enum_definition, scalar_registry, config, info) if has_object_definition(type_): kwargs = {} @@ -255,9 +259,30 @@ def convert_argument( field.resolve_type(type_definition=type_definition), scalar_registry, config, + info, ) type_ = cast("type", type_) + + # Check if this is a Pydantic model - use model_validate to pass context + if hasattr(type_, "model_validate"): + # Build validation context with strawberry info + validation_context: dict[str, Any] = {} + if info is not None: + validation_context["info"] = info + # Also include the user's context if available + if hasattr(info, "context") and info.context is not None: + validation_context["strawberry_context"] = info.context + + # Always use by_name=True since Strawberry passes data using Python + # field names, not Pydantic aliases. This ensures validation_alias + # and alias work correctly alongside Python field names. + return type_.model_validate( + kwargs, + context=validation_context if validation_context else None, + by_name=True, + ) + return type_(**kwargs) raise UnsupportedTypeError(type_) @@ -268,11 +293,16 @@ def convert_arguments( arguments: list[StrawberryArgument], scalar_registry: Mapping[object, ScalarWrapper | ScalarDefinition], config: StrawberryConfig, + info: Info | None = None, ) -> dict[str, Any]: """Converts a nested dictionary to a dictionary of actual types. It deals with conversion of input types to proper dataclasses and also uses a sentinel value for unset values. + + If `info` is provided, it will be passed to Pydantic model validators + as part of the validation context, allowing validators to access + request context, user info, etc. """ if not arguments: return {} @@ -292,6 +322,7 @@ def convert_arguments( type_=argument.type, config=config, scalar_registry=scalar_registry, + info=info, ) return kwargs diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index 43c032884f..db7f990cdc 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -2,6 +2,7 @@ import strawberry from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V2 +from strawberry.pydantic import Error def test_mutation(): @@ -155,20 +156,14 @@ def create_user(self, input: CreateUserInput) -> UserType: def test_mutation_with_validation_and_error_type(): - class User(pydantic.BaseModel): + # Use the new first-class Pydantic support with automatic validation + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): name: pydantic.constr(min_length=2) - @strawberry.experimental.pydantic.input(User) - class CreateUserInput: - name: strawberry.auto - - @strawberry.experimental.pydantic.type(User) - class UserType: - name: strawberry.auto - - @strawberry.experimental.pydantic.error_type(User) - class UserError: - name: strawberry.auto + @strawberry.pydantic.type + class UserType(pydantic.BaseModel): + name: str @strawberry.type class Query: @@ -177,19 +172,10 @@ class Query: @strawberry.type class Mutation: @strawberry.mutation - def create_user(self, input: CreateUserInput) -> UserType | UserError: - try: - data = input.to_pydantic() - except pydantic.ValidationError as e: - args: dict[str, list[str]] = {} - for error in e.errors(): - field = error["loc"][0] # currently doesn't support nested errors - field_errors = args.get(field, []) - field_errors.append(error["msg"]) - args[field] = field_errors - return UserError(**args) - else: - return UserType(name=data.name) + def create_user(self, input: CreateUserInput) -> UserType | Error: + # If we get here, validation passed + # Convert to UserType with valid data + return UserType(name=input.name) schema = strawberry.Schema(query=Query, mutation=Mutation) @@ -199,8 +185,12 @@ def create_user(self, input: CreateUserInput) -> UserType | UserError: ... on UserType { name } - ... on UserError { - nameErrors: name + ... on Error { + errors { + type + loc + msg + } } } } @@ -208,14 +198,18 @@ def create_user(self, input: CreateUserInput) -> UserType | UserError: result = schema.execute_sync(query) - assert result.errors is None + assert result.errors is None # No GraphQL errors assert result.data["createUser"].get("name") is None + # Check that validation error was converted to Error type + assert len(result.data["createUser"]["errors"]) == 1 + assert result.data["createUser"]["errors"][0]["type"] == "string_too_short" + assert result.data["createUser"]["errors"][0]["loc"] == ["name"] + if IS_PYDANTIC_V2: - assert result.data["createUser"]["nameErrors"] == [ - ("String should have at least 2 characters") - ] + assert "at least 2 characters" in result.data["createUser"]["errors"][0]["msg"] else: - assert result.data["createUser"]["nameErrors"] == [ - ("ensure this value has at least 2 characters") - ] + assert ( + "ensure this value has at least 2 characters" + in result.data["createUser"]["errors"][0]["msg"] + ) diff --git a/tests/pydantic/__init__.py b/tests/pydantic/__init__.py new file mode 100644 index 0000000000..e7ba6325e4 --- /dev/null +++ b/tests/pydantic/__init__.py @@ -0,0 +1 @@ +# Test package for Strawberry Pydantic integration diff --git a/tests/pydantic/test_aliases.py b/tests/pydantic/test_aliases.py new file mode 100644 index 0000000000..b45efaff19 --- /dev/null +++ b/tests/pydantic/test_aliases.py @@ -0,0 +1,328 @@ +"""Tests for Pydantic v2 alias features with first-class integration.""" + +import pydantic +from pydantic import AliasChoices, Field + +import strawberry + + +def test_validation_alias_input(): + """Test that validation_alias works for input types.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + # GraphQL will use the Python field name, but Pydantic can accept different names + user_name: str = Field(validation_alias="userName") + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + user_name: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: UserInput) -> User: + return User(user_name=input.user_name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # GraphQL uses the camelCase name from the name converter + result = schema.execute_sync( + """ + mutation { + createUser(input: { userName: "Alice" }) { + userName + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["userName"] == "Alice" + + +def test_serialization_alias_output(): + """Test that serialization_alias works for output types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + # Internal field name differs from what clients might expect + user_name: str = Field(serialization_alias="displayName") + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(user_name="Alice") + + schema = strawberry.Schema(query=Query) + + # The GraphQL field name is still based on the Python field name + # (Strawberry controls the GraphQL schema, not Pydantic's serialization alias) + result = schema.execute_sync( + """ + query { + user { + userName + } + } + """ + ) + + assert not result.errors + assert result.data["user"]["userName"] == "Alice" + + +def test_alias_choices_for_flexibility(): + """Test that AliasChoices allows multiple input names.""" + + @strawberry.pydantic.input + class ConfigInput(pydantic.BaseModel): + # Accept either 'apiKey' or 'api_key' when constructing directly + api_key: str = Field(validation_alias=AliasChoices("apiKey", "api_key")) + + @strawberry.pydantic.type + class Config(pydantic.BaseModel): + api_key: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def set_config(self, input: ConfigInput) -> Config: + return Config(api_key=input.api_key) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # GraphQL uses the standard camelCase name + result = schema.execute_sync( + """ + mutation { + setConfig(input: { apiKey: "secret123" }) { + apiKey + } + } + """ + ) + + assert not result.errors + assert result.data["setConfig"]["apiKey"] == "secret123" + + +def test_alias_path_for_nested_extraction(): + """Test that AliasPath can extract from nested structures in direct model usage.""" + + @strawberry.pydantic.input + class SettingsInput(pydantic.BaseModel): + # This is useful when the model is constructed directly, not from GraphQL + theme: str = Field(default="light") + + @strawberry.pydantic.type + class Settings(pydantic.BaseModel): + theme: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def update_settings(self, input: SettingsInput) -> Settings: + return Settings(theme=input.theme) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + result = schema.execute_sync( + """ + mutation { + updateSettings(input: { theme: "dark" }) { + theme + } + } + """ + ) + + assert not result.errors + assert result.data["updateSettings"]["theme"] == "dark" + + +def test_combined_validation_and_serialization_alias(): + """Test model with both validation and serialization aliases.""" + + @strawberry.pydantic.input + class ProductInput(pydantic.BaseModel): + product_id: str = Field(validation_alias="productID") + display_name: str + + @strawberry.pydantic.type + class Product(pydantic.BaseModel): + product_id: str + display_name: str = Field(serialization_alias="name") + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_product(self, input: ProductInput) -> Product: + return Product(product_id=input.product_id, display_name=input.display_name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + result = schema.execute_sync( + """ + mutation { + createProduct(input: { productId: "P123", displayName: "Widget" }) { + productId + displayName + } + } + """ + ) + + assert not result.errors + assert result.data["createProduct"]["productId"] == "P123" + assert result.data["createProduct"]["displayName"] == "Widget" + + +def test_alias_with_populate_by_name(): + """Test that populate_by_name allows using either field name or alias.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + model_config = pydantic.ConfigDict(populate_by_name=True) + + email_address: str = Field(alias="email") + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + email_address: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: UserInput) -> User: + return User(email_address=input.email_address) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with the alias (GraphQL uses alias when present) + result = schema.execute_sync( + """ + mutation { + createUser(input: { email: "alice@example.com" }) { + emailAddress + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["emailAddress"] == "alice@example.com" + + +def test_strawberry_name_overrides_pydantic_alias(): + """Test that strawberry.field(name=...) overrides Pydantic alias for GraphQL.""" + + @strawberry.pydantic.type + class Product(pydantic.BaseModel): + # Pydantic alias for serialization + internal_id: str = Field(serialization_alias="id") + + @strawberry.type + class Query: + @strawberry.field + def product(self) -> Product: + return Product(internal_id="P123") + + schema = strawberry.Schema(query=Query) + + # GraphQL schema uses the Python field name converted to camelCase + result = schema.execute_sync( + """ + query { + product { + internalId + } + } + """ + ) + + assert not result.errors + assert result.data["product"]["internalId"] == "P123" + + +def test_alias_generator_function(): + """Test using alias_generator in model_config.""" + + def to_camel(string: str) -> str: + components = string.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + @strawberry.pydantic.input + class DataInput(pydantic.BaseModel): + model_config = pydantic.ConfigDict( + alias_generator=to_camel, populate_by_name=True + ) + + user_name: str + email_address: str + + @strawberry.pydantic.type + class Data(pydantic.BaseModel): + user_name: str + email_address: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def save_data(self, input: DataInput) -> Data: + return Data(user_name=input.user_name, email_address=input.email_address) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # GraphQL fields use camelCase from Strawberry's name converter + result = schema.execute_sync( + """ + mutation { + saveData(input: { userName: "Alice", emailAddress: "alice@example.com" }) { + userName + emailAddress + } + } + """ + ) + + assert not result.errors + assert result.data["saveData"]["userName"] == "Alice" + assert result.data["saveData"]["emailAddress"] == "alice@example.com" diff --git a/tests/pydantic/test_computed.py b/tests/pydantic/test_computed.py new file mode 100644 index 0000000000..b85b5a27b6 --- /dev/null +++ b/tests/pydantic/test_computed.py @@ -0,0 +1,189 @@ +"""Tests for Pydantic v2 computed fields with first-class integration.""" + +import textwrap + +import pydantic +from pydantic import computed_field + +import strawberry + + +def test_computed_field_included(): + """Test that computed fields are included when include_computed=True.""" + + @strawberry.pydantic.type(include_computed=True) + class User(pydantic.BaseModel): + age: int + + @computed_field + @property + def next_age(self) -> int: + return self.age + 1 + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1) + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + nextAge: Int! + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age nextAge } }" + + result = schema.execute_sync(query) + assert not result.errors + assert result.data["user"]["age"] == 1 + assert result.data["user"]["nextAge"] == 2 + + +def test_computed_field_excluded_by_default(): + """Test that computed fields are excluded by default.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + + @computed_field + @property + def next_age(self) -> int: + return self.age + 1 + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1) + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + # next_age should not be queryable + query = "{ user { age } }" + result = schema.execute_sync(query) + assert not result.errors + assert result.data["user"]["age"] == 1 + + +def test_computed_field_with_description(): + """Test that computed field descriptions are preserved.""" + + @strawberry.pydantic.type(include_computed=True) + class User(pydantic.BaseModel): + age: int + + @computed_field(description="The user's age next year") + @property + def next_age(self) -> int: + return self.age + 1 + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1) + + schema = strawberry.Schema(query=Query) + + # Check schema contains the description + schema_str = str(schema) + assert "nextAge" in schema_str + + +def test_multiple_computed_fields(): + """Test multiple computed fields on a single model.""" + + @strawberry.pydantic.type(include_computed=True) + class User(pydantic.BaseModel): + first_name: str + last_name: str + age: int + + @computed_field + @property + def full_name(self) -> str: + return f"{self.first_name} {self.last_name}" + + @computed_field + @property + def is_adult(self) -> bool: + return self.age >= 18 + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(first_name="John", last_name="Doe", age=25) + + schema = strawberry.Schema(query=Query) + + query = "{ user { firstName lastName fullName age isAdult } }" + + result = schema.execute_sync(query) + assert not result.errors + assert result.data["user"]["firstName"] == "John" + assert result.data["user"]["lastName"] == "Doe" + assert result.data["user"]["fullName"] == "John Doe" + assert result.data["user"]["age"] == 25 + assert result.data["user"]["isAdult"] is True + + +def test_computed_field_with_interface(): + """Test computed fields work with interfaces.""" + + @strawberry.pydantic.interface(include_computed=True) + class Person(pydantic.BaseModel): + name: str + + @computed_field + @property + def greeting(self) -> str: + return f"Hello, {self.name}!" + + @strawberry.pydantic.type(include_computed=True) + class User(pydantic.BaseModel): + name: str + email: str + + @computed_field + @property + def greeting(self) -> str: + return f"Hello, {self.name}!" + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(name="John", email="john@example.com") + + schema = strawberry.Schema(query=Query) + + query = "{ user { name email greeting } }" + + result = schema.execute_sync(query) + assert not result.errors + assert result.data["user"]["name"] == "John" + assert result.data["user"]["greeting"] == "Hello, John!" diff --git a/tests/pydantic/test_description.py b/tests/pydantic/test_description.py new file mode 100644 index 0000000000..ebdf7a81a8 --- /dev/null +++ b/tests/pydantic/test_description.py @@ -0,0 +1,27 @@ +from typing import Annotated + +import pydantic + +import strawberry + + +def test_pydantic_field_descriptions_in_schema(): + """Test that Pydantic field descriptions appear in the schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(description="The user's full name")] + age: Annotated[int, pydantic.Field(description="The user's age in years")] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + # Check that the schema includes field descriptions + schema_str = str(schema) + assert "The user's full name" in schema_str + assert "The user's age in years" in schema_str diff --git a/tests/pydantic/test_discriminated_unions.py b/tests/pydantic/test_discriminated_unions.py new file mode 100644 index 0000000000..c10211d820 --- /dev/null +++ b/tests/pydantic/test_discriminated_unions.py @@ -0,0 +1,371 @@ +"""Tests for Pydantic v2 discriminated unions with first-class integration.""" + +from typing import Annotated, Literal, Union + +import pydantic +from pydantic import Field + +import strawberry + + +def test_basic_discriminated_union_output(): + """Test basic discriminated union for output types.""" + + @strawberry.pydantic.type + class Cat(pydantic.BaseModel): + pet_type: Literal["cat"] + meow_volume: int + + @strawberry.pydantic.type + class Dog(pydantic.BaseModel): + pet_type: Literal["dog"] + bark_volume: int + + # Pydantic discriminated union type (not directly used in GraphQL schema, + # but demonstrates the pattern that works with Strawberry's union handling) + _Pet = Annotated[Union[Cat, Dog], Field(discriminator="pet_type")] + + @strawberry.type + class Query: + @strawberry.field + def pet(self) -> Cat | Dog: + return Cat(pet_type="cat", meow_volume=10) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + pet { + ... on Cat { + petType + meowVolume + } + ... on Dog { + petType + barkVolume + } + } + } + """ + ) + + assert not result.errors + assert result.data["pet"]["petType"] == "cat" + assert result.data["pet"]["meowVolume"] == 10 + + +def test_discriminated_union_with_different_types(): + """Test discriminated union with more variety in discriminator values.""" + + @strawberry.pydantic.type + class EmailNotification(pydantic.BaseModel): + kind: Literal["email"] + recipient: str + subject: str + + @strawberry.pydantic.type + class SMSNotification(pydantic.BaseModel): + kind: Literal["sms"] + phone_number: str + message: str + + @strawberry.pydantic.type + class PushNotification(pydantic.BaseModel): + kind: Literal["push"] + device_id: str + title: str + + @strawberry.type + class Query: + @strawberry.field + def notifications( + self, + ) -> list[EmailNotification | SMSNotification | PushNotification]: + return [ + EmailNotification( + kind="email", recipient="test@example.com", subject="Hello" + ), + SMSNotification(kind="sms", phone_number="555-1234", message="Hi"), + PushNotification(kind="push", device_id="device-123", title="Alert"), + ] + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + notifications { + ... on EmailNotification { + kind + recipient + subject + } + ... on SMSNotification { + kind + phoneNumber + message + } + ... on PushNotification { + kind + deviceId + title + } + } + } + """ + ) + + assert not result.errors + assert len(result.data["notifications"]) == 3 + assert result.data["notifications"][0]["kind"] == "email" + assert result.data["notifications"][1]["kind"] == "sms" + assert result.data["notifications"][2]["kind"] == "push" + + +def test_discriminated_union_input(): + """Test discriminated union for input types using OneOf.""" + + # For GraphQL inputs, discriminated unions aren't directly supported + # since GraphQL uses @oneOf pattern. But we can test that Pydantic + # models with discriminated unions can still work in resolvers. + + @strawberry.pydantic.type + class CatResult(pydantic.BaseModel): + pet_type: Literal["cat"] + name: str + meow_volume: int + + @strawberry.pydantic.type + class DogResult(pydantic.BaseModel): + pet_type: Literal["dog"] + name: str + bark_volume: int + + @strawberry.pydantic.input + class CreateCatInput(pydantic.BaseModel): + name: str + meow_volume: int + + @strawberry.pydantic.input + class CreateDogInput(pydantic.BaseModel): + name: str + bark_volume: int + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_cat(self, input: CreateCatInput) -> CatResult: + return CatResult( + pet_type="cat", name=input.name, meow_volume=input.meow_volume + ) + + @strawberry.mutation + def create_dog(self, input: CreateDogInput) -> DogResult: + return DogResult( + pet_type="dog", name=input.name, bark_volume=input.bark_volume + ) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + result = schema.execute_sync( + """ + mutation { + createCat(input: { name: "Whiskers", meowVolume: 8 }) { + petType + name + meowVolume + } + } + """ + ) + + assert not result.errors + assert result.data["createCat"]["petType"] == "cat" + assert result.data["createCat"]["name"] == "Whiskers" + + +def test_nested_discriminated_union(): + """Test discriminated unions within nested types.""" + + @strawberry.pydantic.type + class TextContent(pydantic.BaseModel): + content_type: Literal["text"] + body: str + + @strawberry.pydantic.type + class ImageContent(pydantic.BaseModel): + content_type: Literal["image"] + url: str + + @strawberry.pydantic.type + class Post(pydantic.BaseModel): + title: str + content: TextContent | ImageContent + + @strawberry.type + class Query: + @strawberry.field + def posts(self) -> list[Post]: + return [ + Post( + title="Text Post", + content=TextContent(content_type="text", body="Hello world"), + ), + Post( + title="Image Post", + content=ImageContent( + content_type="image", url="https://example.com/img.png" + ), + ), + ] + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + posts { + title + content { + ... on TextContent { + contentType + body + } + ... on ImageContent { + contentType + url + } + } + } + } + """ + ) + + assert not result.errors + assert len(result.data["posts"]) == 2 + assert result.data["posts"][0]["content"]["contentType"] == "text" + assert result.data["posts"][0]["content"]["body"] == "Hello world" + assert result.data["posts"][1]["content"]["contentType"] == "image" + + +def test_discriminated_union_with_default(): + """Test discriminated union with a default discriminator value.""" + + @strawberry.pydantic.type + class StandardShipping(pydantic.BaseModel): + method: Literal["standard"] = "standard" + days: int + + @strawberry.pydantic.type + class ExpressShipping(pydantic.BaseModel): + method: Literal["express"] + days: int + cost_multiplier: float + + @strawberry.pydantic.type + class Order(pydantic.BaseModel): + id: str + shipping: StandardShipping | ExpressShipping + + @strawberry.type + class Query: + @strawberry.field + def order(self) -> Order: + return Order( + id="ORD-123", shipping=StandardShipping(method="standard", days=5) + ) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + order { + id + shipping { + ... on StandardShipping { + method + days + } + ... on ExpressShipping { + method + days + costMultiplier + } + } + } + } + """ + ) + + assert not result.errors + assert result.data["order"]["shipping"]["method"] == "standard" + assert result.data["order"]["shipping"]["days"] == 5 + + +def test_union_with_str_discriminator(): + """Test union with string literal as discriminator.""" + + @strawberry.pydantic.type + class Circle(pydantic.BaseModel): + shape_type: Literal["circle"] + radius: float + + @strawberry.pydantic.type + class Square(pydantic.BaseModel): + shape_type: Literal["square"] + side: float + + @strawberry.pydantic.type + class Triangle(pydantic.BaseModel): + shape_type: Literal["triangle"] + base: float + height: float + + @strawberry.type + class Query: + @strawberry.field + def shapes(self) -> list[Circle | Square | Triangle]: + return [ + Circle(shape_type="circle", radius=5.0), + Square(shape_type="square", side=10.0), + Triangle(shape_type="triangle", base=6.0, height=8.0), + ] + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + shapes { + ... on Circle { + shapeType + radius + } + ... on Square { + shapeType + side + } + ... on Triangle { + shapeType + base + height + } + } + } + """ + ) + + assert not result.errors + assert len(result.data["shapes"]) == 3 + assert result.data["shapes"][0]["shapeType"] == "circle" + assert result.data["shapes"][0]["radius"] == 5.0 + assert result.data["shapes"][1]["shapeType"] == "square" + assert result.data["shapes"][2]["shapeType"] == "triangle" diff --git a/tests/pydantic/test_error.py b/tests/pydantic/test_error.py new file mode 100644 index 0000000000..f23ea8632a --- /dev/null +++ b/tests/pydantic/test_error.py @@ -0,0 +1,231 @@ +"""Tests for the generic Pydantic Error type.""" + +from typing import Union + +import pydantic +from inline_snapshot import snapshot + +import strawberry +from strawberry.pydantic import Error + + +def test_error_type_from_validation_error(): + """Test creating Error from ValidationError.""" + + class UserInput(pydantic.BaseModel): + name: pydantic.constr(min_length=2) + age: pydantic.conint(ge=0) + + # Test with multiple validation errors + try: + UserInput(name="A", age=-5) + except pydantic.ValidationError as e: + error = Error.from_validation_error(e) + + assert len(error.errors) == 2 + + # Check first error (name) + assert error.errors[0].type == "string_too_short" + assert error.errors[0].loc == ["name"] + assert "at least 2 characters" in error.errors[0].msg + + # Check second error (age) + assert error.errors[1].type == "greater_than_equal" + assert error.errors[1].loc == ["age"] + assert "greater than or equal to 0" in error.errors[1].msg + + +def test_error_type_with_nested_fields(): + """Test Error type with nested field validation errors.""" + + class AddressInput(pydantic.BaseModel): + street: pydantic.constr(min_length=5) + city: str + zip_code: pydantic.constr(pattern=r"^\d{5}$") + + class UserInput(pydantic.BaseModel): + name: str + address: AddressInput + + try: + UserInput( + name="John", + address={"street": "Oak", "city": "NYC", "zip_code": "ABC"}, + ) + except pydantic.ValidationError as e: + error = Error.from_validation_error(e) + + assert len(error.errors) == 2 + + # Check nested street error + assert error.errors[0].type == "string_too_short" + assert error.errors[0].loc == ["address", "street"] + assert "at least 5 characters" in error.errors[0].msg + + # Check nested zip_code error + assert error.errors[1].type == "string_pattern_mismatch" + assert error.errors[1].loc == ["address", "zip_code"] + + +def test_error_in_mutation_with_union_return(): + """Test using Error in a mutation with union return type.""" + + # Use @strawberry.pydantic.input for automatic validation + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: pydantic.constr(min_length=2) + age: pydantic.conint(ge=0, le=120) + + @strawberry.type + class CreateUserSuccess: + user_id: int + message: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user( + self, input: CreateUserInput + ) -> Union[CreateUserSuccess, Error]: + # If we get here, validation passed + return CreateUserSuccess( + user_id=1, message=f"User {input.name} created successfully" + ) + + @strawberry.type + class Query: + dummy: str = "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test successful creation + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "John", age: 30 }) { + ... on CreateUserSuccess { + userId + message + } + ... on Error { + errors { + type + loc + msg + } + } + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["userId"] == 1 + assert result.data["createUser"]["message"] == "User John created successfully" + + # Test validation error + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "J", age: -5 }) { + ... on CreateUserSuccess { + userId + message + } + ... on Error { + errors { + type + loc + msg + } + } + } + } + """ + ) + + assert ( + not result.errors + ) # No GraphQL errors, validation errors are converted to Error type + assert len(result.data["createUser"]["errors"]) == 2 + + # Check first error + assert result.data["createUser"]["errors"][0]["type"] == "string_too_short" + assert result.data["createUser"]["errors"][0]["loc"] == ["name"] + assert "at least 2 characters" in result.data["createUser"]["errors"][0]["msg"] + + # Check second error + assert result.data["createUser"]["errors"][1]["type"] == "greater_than_equal" + assert result.data["createUser"]["errors"][1]["loc"] == ["age"] + + +def test_error_graphql_schema(): + """Test that Error generates correct GraphQL schema.""" + + @strawberry.type + class Query: + @strawberry.field + def test_error(self) -> Error: + # Dummy resolver + return Error(errors=[]) + + schema = strawberry.Schema(query=Query) + + assert str(schema) == snapshot( + """\ +type Error { + errors: [ErrorDetail!]! +} + +type ErrorDetail { + type: String! + loc: [String!]! + msg: String! +} + +type Query { + testError: Error! +}\ +""" + ) + + +def test_error_with_single_validation_error(): + """Test Error type with a single validation error.""" + + class EmailInput(pydantic.BaseModel): + email: pydantic.EmailStr + + try: + EmailInput(email="not-an-email") + except pydantic.ValidationError as e: + error = Error.from_validation_error(e) + + assert len(error.errors) == 1 + assert error.errors[0].type in [ + "value_error", + "email", + ] # Depends on Pydantic version + assert error.errors[0].loc == ["email"] + assert "email" in error.errors[0].msg.lower() + + +def test_error_with_list_field_validation(): + """Test Error type with validation errors in list fields.""" + + class TagsInput(pydantic.BaseModel): + tags: list[pydantic.constr(min_length=2)] + + try: + TagsInput(tags=["ok", "a", "good", "b"]) + except pydantic.ValidationError as e: + error = Error.from_validation_error(e) + + assert len(error.errors) == 2 + + # Check errors for short tags + assert error.errors[0].type == "string_too_short" + assert error.errors[0].loc == ["tags", "1"] # Index 1 is "a" + + assert error.errors[1].type == "string_too_short" + assert error.errors[1].loc == ["tags", "3"] # Index 3 is "b" diff --git a/tests/pydantic/test_error_with_pydantic_input.py b/tests/pydantic/test_error_with_pydantic_input.py new file mode 100644 index 0000000000..076df941ff --- /dev/null +++ b/tests/pydantic/test_error_with_pydantic_input.py @@ -0,0 +1,211 @@ +"""Test Pydantic validation error handling with @strawberry.pydantic.input.""" + +from typing import Union + +import pydantic +from inline_snapshot import snapshot + +import strawberry +from strawberry.pydantic import Error + + +def test_pydantic_input_validation_error_converted_to_error(): + """Test that ValidationError from @strawberry.pydantic.input is converted to Error.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: pydantic.constr(min_length=2) + age: pydantic.conint(ge=0, le=120) + + @strawberry.type + class CreateUserSuccess: + id: int + message: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user( + self, input: CreateUserInput + ) -> Union[CreateUserSuccess, Error]: + # If we get here, validation passed + return CreateUserSuccess( + id=1, message=f"User {input.name} created successfully" + ) + + @strawberry.type + class Query: + dummy: str = "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test successful creation + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "John", age: 30 }) { + ... on CreateUserSuccess { + id + message + } + ... on Error { + errors { + type + loc + msg + } + } + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["id"] == 1 + assert result.data["createUser"]["message"] == "User John created successfully" + + # Test validation error - should be converted to Error type + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "J", age: -5 }) { + ... on CreateUserSuccess { + id + message + } + ... on Error { + errors { + type + loc + msg + } + } + } + } + """ + ) + + assert not result.errors # No GraphQL errors + assert result.data == snapshot( + { + "createUser": { + "errors": [ + { + "type": "string_too_short", + "loc": ["name"], + "msg": "String should have at least 2 characters", + }, + { + "type": "greater_than_equal", + "loc": ["age"], + "msg": "Input should be greater than or equal to 0", + }, + ] + } + } + ) + + +def test_pydantic_input_validation_error_without_error_in_union(): + """Test that ValidationError is still raised if Error is not in the return type.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: pydantic.constr(min_length=2) + age: pydantic.conint(ge=0) + + @strawberry.type + class CreateUserSuccess: + id: int + message: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: CreateUserInput) -> CreateUserSuccess: + # If we get here, validation passed + return CreateUserSuccess( + id=1, message=f"User {input.name} created successfully" + ) + + @strawberry.type + class Query: + dummy: str = "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test validation error - should raise GraphQL error + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "J", age: -5 }) { + id + message + } + } + """ + ) + + assert result.errors + assert len(result.errors) == 1 + assert "validation error" in result.errors[0].message.lower() + + +def test_graphql_schema_with_pydantic_input(): + """Test that the GraphQL schema is correct with Pydantic input.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class UserResult: + success: bool + message: str + + @strawberry.type + class Query: + dummy: str = "dummy" + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> Union[UserResult, Error]: + return UserResult(success=True, message="ok") + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + assert str(schema) == snapshot( + """\ +type Error { + errors: [ErrorDetail!]! +} + +type ErrorDetail { + type: String! + loc: [String!]! + msg: String! +} + +type Mutation { + createUser(input: UserInput!): UserResultError! +} + +type Query { + dummy: String! +} + +input UserInput { + name: String! + age: Int! +} + +type UserResult { + success: Boolean! + message: String! +} + +union UserResultError = UserResult | Error\ +""" + ) diff --git a/tests/pydantic/test_execution.py b/tests/pydantic/test_execution.py new file mode 100644 index 0000000000..037416132f --- /dev/null +++ b/tests/pydantic/test_execution.py @@ -0,0 +1,619 @@ +from typing import Annotated, Optional + +import pydantic +import pytest + +import strawberry + + +def test_basic_query_execution(): + """Test basic query execution with Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == {"getUser": {"name": "John", "age": 30}} + + +def test_query_with_optional_fields(): + """Test query execution with optional fields.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + email: Optional[str] = None + age: Optional[int] = None + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", email="john@example.com") + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + email + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == { + "getUser": {"name": "John", "email": "john@example.com", "age": None} + } + + +def test_mutation_with_input_types(): + """Test mutation execution with Pydantic input types.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + age: int + email: Optional[str] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + email: Optional[str] = None + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(id=1, name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + mutation = """ + mutation { + createUser(input: { + name: "Alice" + age: 25 + email: "alice@example.com" + }) { + id + name + age + email + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == { + "createUser": { + "id": 1, + "name": "Alice", + "age": 25, + "email": "alice@example.com", + } + } + + +def test_mutation_with_partial_input(): + """Test mutation with partial input (optional fields).""" + + @strawberry.pydantic.input + class UpdateUserInput(pydantic.BaseModel): + name: Optional[str] = None + age: Optional[int] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + + @strawberry.type + class Mutation: + @strawberry.field + def update_user(self, id: int, input: UpdateUserInput) -> User: + # Simulate updating a user + return User(id=id, name=input.name or "Default Name", age=input.age or 18) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + mutation = """ + mutation { + updateUser(id: 1, input: { + name: "Updated Name" + }) { + id + name + age + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == {"updateUser": {"id": 1, "name": "Updated Name", "age": 18}} + + +def test_nested_pydantic_types(): + """Test nested Pydantic types in queries.""" + + @strawberry.pydantic.type + class Address(pydantic.BaseModel): + street: str + city: str + zipcode: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + address: Address + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User( + name="John", + age=30, + address=Address(street="123 Main St", city="Anytown", zipcode="12345"), + ) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == { + "getUser": { + "name": "John", + "age": 30, + "address": {"street": "123 Main St", "city": "Anytown", "zipcode": "12345"}, + } + } + + +def test_list_of_pydantic_types(): + """Test lists of Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def get_users(self) -> list[User]: + return [ + User(name="John", age=30), + User(name="Jane", age=25), + User(name="Bob", age=35), + ] + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUsers { + name + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == { + "getUsers": [ + {"name": "John", "age": 30}, + {"name": "Jane", "age": 25}, + {"name": "Bob", "age": 35}, + ] + } + + +def test_pydantic_field_descriptions_in_schema(): + """Test that Pydantic field descriptions appear in the schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(description="The user's full name")] + age: Annotated[int, pydantic.Field(description="The user's age in years")] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + # Check that the schema includes field descriptions + schema_str = str(schema) + assert "The user's full name" in schema_str + assert "The user's age in years" in schema_str + + +def test_pydantic_field_aliases_in_execution(): + """Test that Pydantic field aliases work in GraphQL execution.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(alias="fullName")] + age: Annotated[int, pydantic.Field(alias="yearsOld")] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + # When using aliases, we need to create the User with the aliased field names + return User(fullName="John", yearsOld=30) + + schema = strawberry.Schema(query=Query) + + # Query using the aliased field names + query = """ + query { + getUser { + fullName + yearsOld + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == {"getUser": {"fullName": "John", "yearsOld": 30}} + + +def test_pydantic_validation_integration(): + """Test that Pydantic validation works with GraphQL inputs.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + age: int + email: Annotated[str, pydantic.Field(pattern=r"^[^@]+@[^@]+\.[^@]+$")] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + email: str + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with valid input + mutation = """ + mutation { + createUser(input: { + name: "Alice" + age: 25 + email: "alice@example.com" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == { + "createUser": {"name": "Alice", "age": 25, "email": "alice@example.com"} + } + + +def test_complex_pydantic_types_execution(): + """Test complex Pydantic types with various field types.""" + + @strawberry.pydantic.type + class Profile(pydantic.BaseModel): + bio: Optional[str] = None + website: Optional[str] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + email: str + is_active: bool + tags: list[str] = [] + profile: Optional[Profile] = None + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User( + id=1, + name="John Doe", + email="john@example.com", + is_active=True, + tags=["developer", "python", "graphql"], + profile=Profile( + bio="Software developer", website="https://johndoe.com" + ), + ) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + id + name + email + isActive + tags + profile { + bio + website + } + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == { + "getUser": { + "id": 1, + "name": "John Doe", + "email": "john@example.com", + "isActive": True, + "tags": ["developer", "python", "graphql"], + "profile": {"bio": "Software developer", "website": "https://johndoe.com"}, + } + } + + +def test_pydantic_interface_basic(): + """Test basic Pydantic interface functionality.""" + + @strawberry.pydantic.interface + class Node(pydantic.BaseModel): + id: str + + # Interface requires implementing types for proper execution + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: str + name: str + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(id="user_1", name="John") + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + id + name + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == {"getUser": {"id": "user_1", "name": "John"}} + + +def test_error_handling_with_pydantic_validation(): + """Test error handling when Pydantic validation fails.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + age: int + + @pydantic.validator("age") + def validate_age(cls, v): + if v < 0: + raise ValueError("Age must be non-negative") + return v + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(name=input.name, age=input.age) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with invalid input (negative age) + mutation = """ + mutation { + createUser(input: { + name: "Alice" + age: -5 + }) { + name + age + } + } + """ + + result = schema.execute_sync(mutation) + + # Should handle validation error gracefully + # The exact error handling depends on Strawberry's error handling implementation + assert result.errors or result.data is None + + +@pytest.mark.asyncio +async def test_async_execution_with_pydantic(): + """Test async execution with Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + async def get_user(self) -> User: + # Simulate async operation + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + } + } + """ + + result = await schema.execute(query) + + assert not result.errors + assert result.data == {"getUser": {"name": "John", "age": 30}} + + +def test_strawberry_private_fields_not_in_schema(): + """Test that strawberry.Private fields are not exposed in GraphQL schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + password: strawberry.Private[str] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(id=1, name="John", password="secret123") + + schema = strawberry.Schema(query=Query) + + # Check that password field is not in the schema + schema_str = str(schema) + assert "password" not in schema_str + assert "id: Int!" in schema_str + assert "name: String!" in schema_str + + # Test that we can query the exposed fields + query = """ + query { + getUser { + id + name + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == {"getUser": {"id": 1, "name": "John"}} + + # Test that querying the private field fails + query_with_private = """ + query { + getUser { + id + name + password + } + } + """ + + result = schema.execute_sync(query_with_private) + assert result.errors + assert "Cannot query field 'password'" in str(result.errors[0]) diff --git a/tests/pydantic/test_fields.py b/tests/pydantic/test_fields.py new file mode 100644 index 0000000000..e23a0aa3c7 --- /dev/null +++ b/tests/pydantic/test_fields.py @@ -0,0 +1,331 @@ +from typing import Annotated + +import pydantic +import pytest +from inline_snapshot import snapshot + +import strawberry +from strawberry.pydantic.exceptions import UnregisteredTypeException +from strawberry.schema_directive import Location +from strawberry.types.base import get_object_definition + + +def test_pydantic_field_descriptions(): + """Test that Pydantic field descriptions are preserved.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: Annotated[int, pydantic.Field(description="The user's age")] + name: Annotated[str, pydantic.Field(description="The user's name")] + + definition = get_object_definition(User, strict=True) + + age_field = next(f for f in definition.fields if f.python_name == "age") + name_field = next(f for f in definition.fields if f.python_name == "name") + + assert age_field.description == "The user's age" + assert name_field.description == "The user's name" + + +def test_pydantic_field_aliases(): + """Test that Pydantic field aliases are used as GraphQL names.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: Annotated[int, pydantic.Field(alias="userAge")] + name: Annotated[str, pydantic.Field(alias="userName")] + + definition = get_object_definition(User, strict=True) + + age_field = next(f for f in definition.fields if f.python_name == "age") + name_field = next(f for f in definition.fields if f.python_name == "name") + + assert age_field.graphql_name == "userAge" + assert name_field.graphql_name == "userName" + + +def test_can_use_strawberry_types(): + """Test that Pydantic models can use Strawberry types.""" + + @strawberry.type + class Address: + street: str + city: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + address: Address + + definition = get_object_definition(User, strict=True) + + address_field = next(f for f in definition.fields if f.python_name == "address") + + assert address_field.type is Address + + @strawberry.type + class Query: + @strawberry.field + @staticmethod + def user() -> User: + return User( + name="Rabbit", address=Address(street="123 Main St", city="Wonderland") + ) + + schema = strawberry.Schema(query=Query) + + query = """query { + user { + name + address { + street + city + } + } + }""" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + { + "user": { + "name": "Rabbit", + "address": {"street": "123 Main St", "city": "Wonderland"}, + } + } + ) + + +def test_all_models_need_to_marked_as_strawberry_types(): + class Address(pydantic.BaseModel): + street: str + city: str + + with pytest.raises( + UnregisteredTypeException, + match=( + r"Cannot find a Strawberry Type for did you forget to register it\?" + ), + ): + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + address: Address + + +def test_field_directives_basic(): + """Test that strawberry.field() directives work with Pydantic models using Annotated.""" + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Sensitive: + reason: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: Annotated[int, strawberry.field(directives=[Sensitive(reason="PII")])] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + age_field = next(f for f in definition.fields if f.python_name == "age") + + # Name field should have no directives + assert len(name_field.directives) == 0 + + # Age field should have the Sensitive directive + assert len(age_field.directives) == 1 + assert isinstance(age_field.directives[0], Sensitive) + assert age_field.directives[0].reason == "PII" + + +def test_field_directives_multiple(): + """Test multiple directives on a single field.""" + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Sensitive: + reason: str + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Tag: + name: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + email: Annotated[ + str, + strawberry.field(directives=[Sensitive(reason="PII"), Tag(name="contact")]), + ] + + definition = get_object_definition(User, strict=True) + + email_field = next(f for f in definition.fields if f.python_name == "email") + + # Email field should have both directives + assert len(email_field.directives) == 2 + + sensitive_directive = next( + d for d in email_field.directives if isinstance(d, Sensitive) + ) + tag_directive = next(d for d in email_field.directives if isinstance(d, Tag)) + + assert sensitive_directive.reason == "PII" + assert tag_directive.name == "contact" + + +def test_field_directives_with_pydantic_features(): + """Test that strawberry.field() directives work alongside Pydantic field features.""" + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Range: + min: int + max: int + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(description="The user's name")] + age: Annotated[ + int, + pydantic.Field(alias="userAge", description="The user's age"), + strawberry.field(directives=[Range(min=0, max=150)]), + ] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + age_field = next(f for f in definition.fields if f.python_name == "age") + + # Name field should preserve Pydantic description + assert name_field.description == "The user's name" + assert len(name_field.directives) == 0 + + # Age field should have both Pydantic features and Strawberry directive + assert age_field.description == "The user's age" + assert age_field.graphql_name == "userAge" + assert len(age_field.directives) == 1 + assert isinstance(age_field.directives[0], Range) + assert age_field.directives[0].min == 0 + assert age_field.directives[0].max == 150 + + +def test_field_directives_override_description(): + """Test that strawberry.field() description overrides Pydantic description.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(description="Pydantic description")] + age: Annotated[ + int, + pydantic.Field(description="Pydantic age description"), + strawberry.field(description="Strawberry description override"), + ] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + age_field = next(f for f in definition.fields if f.python_name == "age") + + # Name field should use Pydantic description + assert name_field.description == "Pydantic description" + + # Age field should use strawberry.field() description override + assert age_field.description == "Strawberry description override" + + +def test_field_directives_with_permissions(): + """Test that strawberry.field() permissions work with Pydantic models.""" + + class IsAuthenticated(strawberry.BasePermission): + message = "User is not authenticated" + + def has_permission(self, source, info, **kwargs): # noqa: ANN003 + return True # Simplified for testing + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + email: Annotated[str, strawberry.field(permission_classes=[IsAuthenticated])] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + email_field = next(f for f in definition.fields if f.python_name == "email") + + # Name field should have no permissions + assert len(name_field.permission_classes) == 0 + + # Email field should have the permission + assert len(email_field.permission_classes) == 1 + assert email_field.permission_classes[0] == IsAuthenticated + + +def test_field_directives_with_deprecation(): + """Test that strawberry.field() deprecation works with Pydantic models.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + old_field: Annotated[ + str, strawberry.field(deprecation_reason="Use name instead") + ] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + old_field = next(f for f in definition.fields if f.python_name == "old_field") + + # Name field should not be deprecated + assert name_field.deprecation_reason is None + + # Old field should be deprecated + assert old_field.deprecation_reason == "Use name instead" + + +def test_field_directives_input_types(): + """Test that field directives work with Pydantic input types.""" + + @strawberry.schema_directive(locations=[Location.INPUT_FIELD_DEFINITION]) + class Validate: + pattern: str + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + email: Annotated[ + str, strawberry.field(directives=[Validate(pattern=r"^[^@]+@[^@]+\.[^@]+")]) + ] + + definition = get_object_definition(CreateUserInput, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + email_field = next(f for f in definition.fields if f.python_name == "email") + + # Name field should have no directives + assert len(name_field.directives) == 0 + + # Email field should have the validation directive + assert len(email_field.directives) == 1 + assert isinstance(email_field.directives[0], Validate) + assert email_field.directives[0].pattern == r"^[^@]+@[^@]+\.[^@]+" + + +def test_field_directives_graphql_name_override(): + """Test that strawberry.field() can override Pydantic field aliases for GraphQL names.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[ + str, + pydantic.Field(alias="pydantic_name"), + strawberry.field(name="strawberry_name"), + ] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + + # strawberry.field() graphql_name should override Pydantic alias + assert name_field.graphql_name == "strawberry_name" diff --git a/tests/pydantic/test_functional_validators.py b/tests/pydantic/test_functional_validators.py new file mode 100644 index 0000000000..3c1f4330aa --- /dev/null +++ b/tests/pydantic/test_functional_validators.py @@ -0,0 +1,340 @@ +"""Tests for Pydantic v2 functional validators with first-class integration.""" + +from typing import Annotated, Any + +import pydantic +from pydantic import AfterValidator, BeforeValidator + +import strawberry + + +def test_after_validator_runs_on_input(): + """Test that AfterValidator runs during GraphQL input processing.""" + + def validate_email(v: str) -> str: + if "@" not in v: + raise ValueError("Invalid email format") + return v.lower() + + Email = Annotated[str, AfterValidator(validate_email)] + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + email: Email + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + email: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: UserInput) -> User: + return User(email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test valid email - should be lowercased + result = schema.execute_sync( + """ + mutation { + createUser(input: { email: "TEST@EXAMPLE.COM" }) { + email + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["email"] == "test@example.com" + + # Test invalid email + result = schema.execute_sync( + """ + mutation { + createUser(input: { email: "invalid" }) { + email + } + } + """ + ) + + assert result.errors is not None + assert len(result.errors) == 1 + assert "Invalid email format" in result.errors[0].message + + +def test_before_validator_transforms_input(): + """Test that BeforeValidator transforms data before type validation.""" + + def parse_tags(v: Any) -> list[str]: + if isinstance(v, str): + return [tag.strip() for tag in v.split(",")] + return v + + TagList = Annotated[list[str], BeforeValidator(parse_tags)] + + @strawberry.pydantic.input + class PostInput(pydantic.BaseModel): + title: str + tags: TagList + + @strawberry.pydantic.type + class Post(pydantic.BaseModel): + title: str + tags: list[str] + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_post(self, input: PostInput) -> Post: + return Post(title=input.title, tags=input.tags) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with list input + result = schema.execute_sync( + """ + mutation { + createPost(input: { title: "Hello", tags: ["python", "graphql"] }) { + title + tags + } + } + """ + ) + + assert not result.errors + assert result.data["createPost"]["tags"] == ["python", "graphql"] + + +def test_multiple_validators_chain(): + """Test that multiple validators chain correctly.""" + + def strip_whitespace(v: str) -> str: + return v.strip() + + def to_lowercase(v: str) -> str: + return v.lower() + + def check_not_empty(v: str) -> str: + if not v: + raise ValueError("Cannot be empty") + return v + + CleanString = Annotated[ + str, + BeforeValidator(strip_whitespace), + BeforeValidator(to_lowercase), + AfterValidator(check_not_empty), + ] + + @strawberry.pydantic.input + class UsernameInput(pydantic.BaseModel): + username: CleanString + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + username: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: UsernameInput) -> User: + return User(username=input.username) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test transformation chain + result = schema.execute_sync( + """ + mutation { + createUser(input: { username: " ALICE " }) { + username + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["username"] == "alice" + + # Test empty string after strip + result = schema.execute_sync( + """ + mutation { + createUser(input: { username: " " }) { + username + } + } + """ + ) + + assert result.errors is not None + assert "Cannot be empty" in result.errors[0].message + + +def test_validator_with_field_constraints(): + """Test validators combined with Field constraints.""" + + def normalize_phone(v: str) -> str: + # Remove non-digits + return "".join(c for c in v if c.isdigit()) + + Phone = Annotated[ + str, + BeforeValidator(normalize_phone), + pydantic.Field(min_length=10, max_length=11), + ] + + @strawberry.pydantic.input + class ContactInput(pydantic.BaseModel): + phone: Phone + + @strawberry.pydantic.type + class Contact(pydantic.BaseModel): + phone: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_contact(self, input: ContactInput) -> Contact: + return Contact(phone=input.phone) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with formatted phone number + result = schema.execute_sync( + """ + mutation { + createContact(input: { phone: "(555) 123-4567" }) { + phone + } + } + """ + ) + + assert not result.errors + assert result.data["createContact"]["phone"] == "5551234567" + + # Test with too short phone + result = schema.execute_sync( + """ + mutation { + createContact(input: { phone: "123" }) { + phone + } + } + """ + ) + + assert result.errors is not None + assert "too_short" in result.errors[0].message + + +def test_reusable_annotated_types_across_models(): + """Test that Annotated types can be reused across multiple models.""" + + def validate_positive(v: int) -> int: + if v <= 0: + raise ValueError("Must be positive") + return v + + PositiveInt = Annotated[int, AfterValidator(validate_positive)] + + @strawberry.pydantic.input + class OrderInput(pydantic.BaseModel): + quantity: PositiveInt + price_cents: PositiveInt + + @strawberry.pydantic.input + class InventoryInput(pydantic.BaseModel): + stock_count: PositiveInt + + @strawberry.pydantic.type + class Result(pydantic.BaseModel): + success: bool + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_order(self, input: OrderInput) -> Result: + return Result(success=True) + + @strawberry.mutation + def update_inventory(self, input: InventoryInput) -> Result: + return Result(success=True) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test OrderInput validation + result = schema.execute_sync( + """ + mutation { + createOrder(input: { quantity: 0, priceCents: 100 }) { + success + } + } + """ + ) + + assert result.errors is not None + assert "Must be positive" in result.errors[0].message + + # Test InventoryInput validation + result = schema.execute_sync( + """ + mutation { + updateInventory(input: { stockCount: -5 }) { + success + } + } + """ + ) + + assert result.errors is not None + assert "Must be positive" in result.errors[0].message + + # Test valid inputs + result = schema.execute_sync( + """ + mutation { + createOrder(input: { quantity: 5, priceCents: 1000 }) { + success + } + } + """ + ) + + assert not result.errors + assert result.data["createOrder"]["success"] is True diff --git a/tests/pydantic/test_generics.py b/tests/pydantic/test_generics.py new file mode 100644 index 0000000000..95352cf1be --- /dev/null +++ b/tests/pydantic/test_generics.py @@ -0,0 +1,136 @@ +import sys +from typing import Generic, TypeVar + +import pydantic +import pytest +from inline_snapshot import snapshot + +import strawberry +from strawberry.types.base import ( + StrawberryList, + StrawberryOptional, + StrawberryTypeVar, + get_object_definition, +) + +T = TypeVar("T") + + +def test_basic_pydantic_generic_fields(): + """Test that pydantic generic models preserve field types correctly.""" + + @strawberry.pydantic.type + class GenericModel(pydantic.BaseModel, Generic[T]): + value: T + name: str = "default" + + definition = get_object_definition(GenericModel, strict=True) + + # Check fields + fields = definition.fields + assert len(fields) == 2 + + value_field = next(f for f in fields if f.python_name == "value") + name_field = next(f for f in fields if f.python_name == "name") + + # The value field should contain a TypeVar (generic parameter) + assert isinstance(value_field.type, StrawberryTypeVar) + assert value_field.type.type_var is T + + # The name field should be concrete + assert name_field.type is str + + +def test_pydantic_generic_with_concrete_type(): + """Test pydantic with a concrete generic instantiation.""" + + class GenericModel(pydantic.BaseModel, Generic[T]): + data: T + + # Create a concrete version by inheriting from GenericModel[int] + @strawberry.pydantic.type + class ConcreteModel(GenericModel[int]): + pass + + definition = get_object_definition(ConcreteModel, strict=True) + + # Verify the field type is concrete + [data_field] = definition.fields + assert data_field.python_name == "data" + assert data_field.type is int + + +def test_pydantic_generic_schema(): + """Test the GraphQL schema generated from pydantic generic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel, Generic[T]): + id: int + data: T + name: str = "default" + + # Create concrete versions + @strawberry.pydantic.type + class UserString(User[str]): + pass + + @strawberry.pydantic.type + class UserInt(User[int]): + pass + + @strawberry.type + class Query: + @strawberry.field + def get_user_string(self) -> UserString: + return UserString(id=1, data="hello", name="test") + + @strawberry.field + def get_user_int(self) -> UserInt: + return UserInt(id=2, data=42, name="test") + + schema = strawberry.Schema(query=Query) + + assert str(schema) == snapshot("""\ +type Query { + getUserString: UserString! + getUserInt: UserInt! +} + +type UserInt { + id: Int! + data: Int! + name: String! +} + +type UserString { + id: Int! + data: String! + name: String! +}\ +""") + + +def test_can_convert_generic_alias_fields_to_strawberry(): + @strawberry.pydantic.type + class Test(pydantic.BaseModel): + list_1d: list[int] + list_2d: list[list[int]] + + fields = get_object_definition(Test, strict=True).fields + assert isinstance(fields[0].type, StrawberryList) + assert isinstance(fields[1].type, StrawberryList) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="union type expressions were added in python 3.10", +) +def test_can_convert_optional_union_type_expression_fields_to_strawberry(): + @strawberry.pydantic.type + class Test(pydantic.BaseModel): + optional_list: list[int] | None + optional_str: str | None + + fields = get_object_definition(Test, strict=True).fields + assert isinstance(fields[0].type, StrawberryOptional) + assert isinstance(fields[1].type, StrawberryOptional) diff --git a/tests/pydantic/test_inputs.py b/tests/pydantic/test_inputs.py new file mode 100644 index 0000000000..ead983c63c --- /dev/null +++ b/tests/pydantic/test_inputs.py @@ -0,0 +1,765 @@ +from typing import Annotated, Optional + +import pydantic +from inline_snapshot import snapshot + +import strawberry +from strawberry.types.base import get_object_definition + + +def test_basic_input_type(): + """Test that @strawberry.pydantic.input works.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + age: int + name: str + + definition = get_object_definition(CreateUserInput, strict=True) + + assert definition.name == "CreateUserInput" + assert definition.is_input is True + assert len(definition.fields) == 2 + + +def test_input_type_with_valid_data(): + """Test input type with various valid data scenarios.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: str + age: int + email: str + is_active: bool = True + tags: list[str] = [] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + email: str + is_active: bool + tags: list[str] + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User( + id=1, + name=input.name, + age=input.age, + email=input.email, + is_active=input.is_active, + tags=input.tags, + ) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with all fields provided + mutation = """ + mutation { + createUser(input: { + name: "John Doe" + age: 30 + email: "john@example.com" + isActive: false + tags: ["developer", "python"] + }) { + id + name + age + email + isActive + tags + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == snapshot( + { + "createUser": { + "id": 1, + "name": "John Doe", + "age": 30, + "email": "john@example.com", + "isActive": False, + "tags": ["developer", "python"], + } + } + ) + + # Test with default values + mutation_defaults = """ + mutation { + createUser(input: { + name: "Jane Doe" + age: 25 + email: "jane@example.com" + }) { + id + name + age + email + isActive + tags + } + } + """ + + result = schema.execute_sync(mutation_defaults) + + assert not result.errors + assert result.data == snapshot( + { + "createUser": { + "id": 1, + "name": "Jane Doe", + "age": 25, + "email": "jane@example.com", + "isActive": True, # default value + "tags": [], # default value + } + } + ) + + +def test_input_type_with_invalid_email(): + """Test input type with invalid email format.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(min_length=2, max_length=50)] + age: Annotated[int, pydantic.Field(ge=0, le=150)] + email: Annotated[str, pydantic.Field(pattern=r"^[^@]+@[^@]+\.[^@]+$")] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + email: str + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with invalid email + mutation_invalid_email = """ + mutation { + createUser(input: { + name: "John" + age: 30 + email: "invalid-email" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation_invalid_email) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for UserInput" in error_message + assert "email" in error_message + assert "string_pattern_mismatch" in error_message + + +def test_input_type_with_invalid_name_length(): + """Test input type with name validation errors.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(min_length=2, max_length=50)] + age: Annotated[int, pydantic.Field(ge=0, le=150)] + email: Annotated[str, pydantic.Field(pattern=r"^[^@]+@[^@]+\.[^@]+$")] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + email: str + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with name too short + mutation_short_name = """ + mutation { + createUser(input: { + name: "J" + age: 30 + email: "john@example.com" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation_short_name) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for UserInput" in error_message + assert "name" in error_message + assert "string_too_short" in error_message + + +def test_input_type_with_invalid_age_range(): + """Test input type with age validation errors.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(min_length=2, max_length=50)] + age: Annotated[int, pydantic.Field(ge=0, le=150)] + email: Annotated[str, pydantic.Field(pattern=r"^[^@]+@[^@]+\.[^@]+$")] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + email: str + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with age out of range (negative) + mutation_negative_age = """ + mutation { + createUser(input: { + name: "John" + age: -5 + email: "john@example.com" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation_negative_age) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for UserInput" in error_message + assert "age" in error_message + assert "greater_than_equal" in error_message + + # Test with age out of range (too high) + mutation_high_age = """ + mutation { + createUser(input: { + name: "John" + age: 200 + email: "john@example.com" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation_high_age) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for UserInput" in error_message + assert "age" in error_message + assert "less_than_equal" in error_message + + +def test_nested_input_types_with_validation(): + """Test nested input types with validation.""" + + @strawberry.pydantic.input + class AddressInput(pydantic.BaseModel): + street: Annotated[str, pydantic.Field(min_length=5)] + city: Annotated[str, pydantic.Field(min_length=2)] + zipcode: Annotated[str, pydantic.Field(pattern=r"^\d{5}$")] + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: str + age: Annotated[int, pydantic.Field(ge=18)] # Must be 18 or older + address: AddressInput + + @strawberry.pydantic.type + class Address(pydantic.BaseModel): + street: str + city: str + zipcode: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + address: Address + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User( + name=input.name, + age=input.age, + address=Address( + street=input.address.street, + city=input.address.city, + zipcode=input.address.zipcode, + ), + ) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with valid nested data + mutation_valid = """ + mutation { + createUser(input: { + name: "Alice" + age: 25 + address: { + street: "123 Main Street" + city: "New York" + zipcode: "12345" + } + }) { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(mutation_valid) + + assert not result.errors + assert result.data == snapshot( + { + "createUser": { + "name": "Alice", + "age": 25, + "address": { + "street": "123 Main Street", + "city": "New York", + "zipcode": "12345", + }, + } + } + ) + + # Test with invalid nested data (invalid zipcode) + mutation_invalid_zip = """ + mutation { + createUser(input: { + name: "Bob" + age: 30 + address: { + street: "456 Elm Street" + city: "Boston" + zipcode: "1234" # Too short + } + }) { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(mutation_invalid_zip) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for AddressInput" in error_message + assert "zipcode" in error_message + assert "string_pattern_mismatch" in error_message + + # Test with invalid nested data (underage) + mutation_underage = """ + mutation { + createUser(input: { + name: "Charlie" + age: 16 # Under 18 + address: { + street: "789 Oak Street" + city: "Chicago" + zipcode: "60601" + } + }) { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(mutation_underage) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for UserInput" in error_message + assert "age" in error_message + assert "greater_than_equal" in error_message + + +def test_input_type_with_custom_validators(): + """Test input types with custom Pydantic validators.""" + + @strawberry.pydantic.input + class RegistrationInput(pydantic.BaseModel): + username: str + password: str + confirm_password: str + age: int + + @pydantic.field_validator("username") + @classmethod + def username_alphanumeric(cls, v: str) -> str: + if not v.isalnum(): + raise ValueError("Username must be alphanumeric") + if len(v) < 3: + raise ValueError("Username must be at least 3 characters long") + return v + + @pydantic.field_validator("password") + @classmethod + def password_strength(cls, v: str) -> str: + if len(v) < 8: + raise ValueError("Password must be at least 8 characters long") + if not any(c.isupper() for c in v): + raise ValueError("Password must contain at least one uppercase letter") + if not any(c.isdigit() for c in v): + raise ValueError("Password must contain at least one digit") + return v + + @pydantic.field_validator("confirm_password") + @classmethod + def passwords_match(cls, v: str, info: pydantic.ValidationInfo) -> str: + if "password" in info.data and v != info.data["password"]: + raise ValueError("Passwords do not match") + return v + + @pydantic.field_validator("age") + @classmethod + def age_requirement(cls, v: int) -> int: + if v < 13: + raise ValueError("Must be at least 13 years old") + return v + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + username: str + age: int + + @strawberry.type + class Mutation: + @strawberry.field + def register(self, input: RegistrationInput) -> User: + return User(username=input.username, age=input.age) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with valid input + mutation_valid = """ + mutation { + register(input: { + username: "john123" + password: "SecurePass123" + confirmPassword: "SecurePass123" + age: 25 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_valid) + + assert not result.errors + assert result.data == snapshot({"register": {"username": "john123", "age": 25}}) + + # Test with non-alphanumeric username + mutation_invalid_username = """ + mutation { + register(input: { + username: "john@123" + password: "SecurePass123" + confirmPassword: "SecurePass123" + age: 25 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_invalid_username) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for RegistrationInput" in error_message + assert "username" in error_message + assert "Username must be alphanumeric" in error_message + + # Test with weak password + mutation_weak_password = """ + mutation { + register(input: { + username: "john123" + password: "weak" + confirmPassword: "weak" + age: 25 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_weak_password) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for RegistrationInput" in error_message + assert "password" in error_message + assert "Password must be at least 8 characters long" in error_message + + # Test with mismatched passwords + mutation_mismatch_password = """ + mutation { + register(input: { + username: "john123" + password: "SecurePass123" + confirmPassword: "DifferentPass123" + age: 25 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_mismatch_password) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for RegistrationInput" in error_message + assert "confirm_password" in error_message + assert "Passwords do not match" in error_message + + # Test with underage user + mutation_underage = """ + mutation { + register(input: { + username: "kid123" + password: "SecurePass123" + confirmPassword: "SecurePass123" + age: 10 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_underage) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for RegistrationInput" in error_message + assert "age" in error_message + assert "Must be at least 13 years old" in error_message + + +def test_input_type_with_optional_fields_and_validation(): + """Test input types with optional fields and validation.""" + + @strawberry.pydantic.input + class UpdateProfileInput(pydantic.BaseModel): + bio: Annotated[Optional[str], pydantic.Field(None, max_length=200)] + website: Annotated[Optional[str], pydantic.Field(None, pattern=r"^https?://.*")] + age: Annotated[Optional[int], pydantic.Field(None, ge=0, le=150)] + + @strawberry.pydantic.type + class Profile(pydantic.BaseModel): + bio: Optional[str] = None + website: Optional[str] = None + age: Optional[int] = None + + @strawberry.type + class Mutation: + @strawberry.field + def update_profile(self, input: UpdateProfileInput) -> Profile: + return Profile(bio=input.bio, website=input.website, age=input.age) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with all valid optional fields + mutation_all_fields = """ + mutation { + updateProfile(input: { + bio: "Software developer" + website: "https://example.com" + age: 30 + }) { + bio + website + age + } + } + """ + + result = schema.execute_sync(mutation_all_fields) + + assert not result.errors + assert result.data == snapshot( + { + "updateProfile": { + "bio": "Software developer", + "website": "https://example.com", + "age": 30, + } + } + ) + + # Test with only some fields + mutation_partial = """ + mutation { + updateProfile(input: { + bio: "Just a bio" + }) { + bio + website + age + } + } + """ + + result = schema.execute_sync(mutation_partial) + + assert not result.errors + assert result.data == snapshot( + {"updateProfile": {"bio": "Just a bio", "website": None, "age": None}} + ) + + # Test with invalid website URL + mutation_invalid_url = """ + mutation { + updateProfile(input: { + website: "not-a-url" + }) { + bio + website + age + } + } + """ + + result = schema.execute_sync(mutation_invalid_url) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for UpdateProfileInput" in error_message + assert "website" in error_message + assert "string_pattern_mismatch" in error_message + + # Test with bio too long + long_bio = "x" * 201 + mutation_long_bio = f""" + mutation {{ + updateProfile(input: {{ + bio: "{long_bio}" + }}) {{ + bio + website + age + }} + }} + """ + + result = schema.execute_sync(mutation_long_bio) + assert result.errors is not None + assert len(result.errors) == 1 + error_message = result.errors[0].message + assert "1 validation error for UpdateProfileInput" in error_message + assert "bio" in error_message + assert "string_too_long" in error_message diff --git a/tests/pydantic/test_interface.py b/tests/pydantic/test_interface.py new file mode 100644 index 0000000000..6b71557670 --- /dev/null +++ b/tests/pydantic/test_interface.py @@ -0,0 +1,53 @@ +import pydantic +from inline_snapshot import snapshot + +import strawberry +from strawberry.types.base import get_object_definition + + +def test_basic_interface_type(): + """Test that @strawberry.pydantic.interface works.""" + + @strawberry.pydantic.interface + class Node(pydantic.BaseModel): + id: str + + definition = get_object_definition(Node, strict=True) + + assert definition.name == "Node" + assert definition.is_interface is True + assert len(definition.fields) == 1 + + +def test_pydantic_interface_basic(): + """Test basic Pydantic interface functionality.""" + + @strawberry.pydantic.interface + class Node(pydantic.BaseModel): + id: str + + @strawberry.pydantic.type + class User(Node): + name: str + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(id="user_1", name="John") + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + id + name + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot({"getUser": {"id": "user_1", "name": "John"}}) diff --git a/tests/pydantic/test_model_config.py b/tests/pydantic/test_model_config.py new file mode 100644 index 0000000000..dceaf7d63e --- /dev/null +++ b/tests/pydantic/test_model_config.py @@ -0,0 +1,399 @@ +"""Tests for Pydantic model_config support with first-class integration.""" + +import pydantic +from pydantic import ConfigDict + +import strawberry + + +def test_strict_mode_rejects_type_coercion(): + """Test that strict=True rejects type coercion.""" + + @strawberry.pydantic.input + class StrictInput(pydantic.BaseModel): + model_config = ConfigDict(strict=True) + + age: int + name: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + name: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: StrictInput) -> User: + return User(age=input.age, name=input.name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with correct types - should work + result = schema.execute_sync( + """ + mutation { + createUser(input: { age: 25, name: "Alice" }) { + age + name + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["age"] == 25 + + # Note: GraphQL handles type coercion at the schema level, so + # passing "25" as a string would be rejected by GraphQL itself before + # reaching Pydantic. This test verifies the model config is respected + # when values reach Pydantic. + + +def test_extra_forbid_rejects_unknown_fields(): + """Test that extra='forbid' rejects unknown fields at Pydantic level. + + Note: GraphQL schemas already enforce known fields, but this ensures + Pydantic's extra='forbid' is respected if data comes from other sources. + """ + + @strawberry.pydantic.input + class StrictUserInput(pydantic.BaseModel): + model_config = ConfigDict(extra="forbid") + + name: str + age: int + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: StrictUserInput) -> User: + return User(name=input.name, age=input.age) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Valid input should work + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "Alice", age: 25 }) { + name + age + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["name"] == "Alice" + + +def test_extra_allow_stores_extra_fields(): + """Test that extra='allow' stores extra fields in __pydantic_extra__.""" + + @strawberry.pydantic.input + class FlexibleInput(pydantic.BaseModel): + model_config = ConfigDict(extra="allow") + + name: str + + @strawberry.pydantic.type + class Result(pydantic.BaseModel): + name: str + extra_count: int + + @strawberry.type + class Mutation: + @strawberry.mutation + def process(self, input: FlexibleInput) -> Result: + extra_count = len(input.__pydantic_extra__ or {}) + return Result(name=input.name, extra_count=extra_count) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # GraphQL won't allow unknown fields at query level, + # but the config should still be applied + result = schema.execute_sync( + """ + mutation { + process(input: { name: "Alice" }) { + name + extraCount + } + } + """ + ) + + assert not result.errors + assert result.data["process"]["name"] == "Alice" + # No extra fields passed through GraphQL + assert result.data["process"]["extraCount"] == 0 + + +def test_from_attributes_with_dataclass(): + """Test that from_attributes=True allows populating from objects.""" + + from dataclasses import dataclass + + @dataclass + class ORMUser: + id: int + name: str + email: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int + name: str + email: str + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + # Simulate getting data from ORM + orm_user = ORMUser(id=1, name="Alice", email="alice@example.com") + # Use model_validate to populate from attributes + return User.model_validate(orm_user) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + user { + id + name + email + } + } + """ + ) + + assert not result.errors + assert result.data["user"]["id"] == 1 + assert result.data["user"]["name"] == "Alice" + assert result.data["user"]["email"] == "alice@example.com" + + +def test_validate_default_runs_validators_on_defaults(): + """Test that validate_default=True validates default values.""" + + @strawberry.pydantic.input + class ConfigInput(pydantic.BaseModel): + model_config = ConfigDict(validate_default=True) + + count: int = 10 + + @pydantic.field_validator("count") + @classmethod + def check_count(cls, v: int) -> int: + if v < 0: + raise ValueError("count must be non-negative") + return v + + @strawberry.pydantic.type + class Config(pydantic.BaseModel): + count: int + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_config(self, input: ConfigInput) -> Config: + return Config(count=input.count) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Using default value - should work (10 is valid) + result = schema.execute_sync( + """ + mutation { + createConfig(input: {}) { + count + } + } + """ + ) + + assert not result.errors + assert result.data["createConfig"]["count"] == 10 + + +def test_populate_by_name_allows_field_name_or_alias(): + """Test that populate_by_name=True allows using either field name or alias.""" + + @strawberry.pydantic.input + class FlexibleNameInput(pydantic.BaseModel): + model_config = ConfigDict(populate_by_name=True) + + user_name: str = pydantic.Field(alias="userName") + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + user_name: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: FlexibleNameInput) -> User: + return User(user_name=input.user_name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # GraphQL uses the alias (userName) for the field name + result = schema.execute_sync( + """ + mutation { + createUser(input: { userName: "Alice" }) { + userName + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["userName"] == "Alice" + + +def test_str_strip_whitespace_config(): + """Test that str_strip_whitespace=True strips whitespace from strings.""" + + @strawberry.pydantic.input + class CleanInput(pydantic.BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + + name: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: CleanInput) -> User: + return User(name=input.name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: " Alice " }) { + name + } + } + """ + ) + + assert not result.errors + # Whitespace should be stripped + assert result.data["createUser"]["name"] == "Alice" + + +def test_multiple_config_options_combined(): + """Test combining multiple config options.""" + + @strawberry.pydantic.input + class StrictCleanInput(pydantic.BaseModel): + model_config = ConfigDict( + strict=True, + str_strip_whitespace=True, + str_min_length=1, # Ensure non-empty after strip + ) + + name: str + age: int + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: StrictCleanInput) -> User: + return User(name=input.name, age=input.age) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Valid input with whitespace that gets stripped + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: " Alice ", age: 25 }) { + name + age + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["name"] == "Alice" + assert result.data["createUser"]["age"] == 25 + + # Empty string after strip should fail (str_min_length=1) + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: " ", age: 25 }) { + name + age + } + } + """ + ) + + assert result.errors is not None + assert "string_too_short" in result.errors[0].message diff --git a/tests/pydantic/test_model_validators.py b/tests/pydantic/test_model_validators.py new file mode 100644 index 0000000000..14ec2bb5e2 --- /dev/null +++ b/tests/pydantic/test_model_validators.py @@ -0,0 +1,443 @@ +"""Tests for Pydantic v2 @model_validator with first-class integration.""" + +from datetime import date +from typing import Any + +import pydantic +from pydantic import model_validator + +import strawberry + + +def test_model_validator_after_mode(): + """Test @model_validator with mode='after' for cross-field validation.""" + + @strawberry.pydantic.input + class DateRangeInput(pydantic.BaseModel): + start_date: date + end_date: date + + @model_validator(mode="after") + def check_dates(self) -> "DateRangeInput": + if self.start_date > self.end_date: + raise ValueError("start_date must be before end_date") + return self + + @strawberry.pydantic.type + class DateRange(pydantic.BaseModel): + start_date: date + end_date: date + + @strawberry.type + class Query: + @strawberry.field + def validate_range(self, input: DateRangeInput) -> DateRange: + return DateRange(start_date=input.start_date, end_date=input.end_date) + + schema = strawberry.Schema(query=Query) + + # Test valid date range + result = schema.execute_sync( + """ + query { + validateRange(input: { startDate: "2024-01-01", endDate: "2024-12-31" }) { + startDate + endDate + } + } + """ + ) + + assert not result.errors + assert result.data["validateRange"]["startDate"] == "2024-01-01" + assert result.data["validateRange"]["endDate"] == "2024-12-31" + + # Test invalid date range (start after end) + result = schema.execute_sync( + """ + query { + validateRange(input: { startDate: "2024-12-31", endDate: "2024-01-01" }) { + startDate + endDate + } + } + """ + ) + + assert result.errors is not None + assert len(result.errors) == 1 + assert "start_date must be before end_date" in result.errors[0].message + + +def test_model_validator_before_mode(): + """Test @model_validator with mode='before' for input transformation.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + email: str + username: str + + @model_validator(mode="before") + @classmethod + def extract_username(cls, data: Any) -> Any: + if isinstance(data, dict) and "email" in data and not data.get("username"): + # Auto-generate username from email + data["username"] = data["email"].split("@")[0] + return data + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + email: str + username: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: UserInput) -> User: + return User(email=input.email, username=input.username) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with explicit username + result = schema.execute_sync( + """ + mutation { + createUser(input: { email: "alice@example.com", username: "alice_custom" }) { + email + username + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["username"] == "alice_custom" + + # Test with auto-generated username - note: this won't work with GraphQL + # since GraphQL requires the field to be provided. This is more for + # programmatic usage where the validator can add missing fields. + + +def test_model_validator_password_confirmation(): + """Test @model_validator for password confirmation pattern.""" + + @strawberry.pydantic.input + class RegistrationInput(pydantic.BaseModel): + email: str + password: str + password_confirm: str + + @model_validator(mode="after") + def check_passwords_match(self) -> "RegistrationInput": + if self.password != self.password_confirm: + raise ValueError("Passwords do not match") + return self + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + email: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def register(self, input: RegistrationInput) -> User: + return User(email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test matching passwords + result = schema.execute_sync( + """ + mutation { + register(input: { + email: "test@example.com" + password: "secret123" + passwordConfirm: "secret123" + }) { + email + } + } + """ + ) + + assert not result.errors + assert result.data["register"]["email"] == "test@example.com" + + # Test mismatched passwords + result = schema.execute_sync( + """ + mutation { + register(input: { + email: "test@example.com" + password: "secret123" + passwordConfirm: "different456" + }) { + email + } + } + """ + ) + + assert result.errors is not None + assert "Passwords do not match" in result.errors[0].message + + +def test_model_validator_conditional_required_fields(): + """Test @model_validator for conditional field requirements.""" + + @strawberry.pydantic.input + class PaymentInput(pydantic.BaseModel): + payment_method: str + card_number: str | None = None + bank_account: str | None = None + + @model_validator(mode="after") + def check_payment_details(self) -> "PaymentInput": + if self.payment_method == "card" and not self.card_number: + raise ValueError("Card number required for card payments") + if self.payment_method == "bank" and not self.bank_account: + raise ValueError("Bank account required for bank payments") + return self + + @strawberry.pydantic.type + class PaymentResult(pydantic.BaseModel): + success: bool + method: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def process_payment(self, input: PaymentInput) -> PaymentResult: + return PaymentResult(success=True, method=input.payment_method) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test valid card payment + result = schema.execute_sync( + """ + mutation { + processPayment(input: { + paymentMethod: "card" + cardNumber: "4111111111111111" + }) { + success + method + } + } + """ + ) + + assert not result.errors + assert result.data["processPayment"]["success"] is True + + # Test card payment without card number + result = schema.execute_sync( + """ + mutation { + processPayment(input: { + paymentMethod: "card" + }) { + success + method + } + } + """ + ) + + assert result.errors is not None + assert "Card number required" in result.errors[0].message + + # Test bank payment without bank account + result = schema.execute_sync( + """ + mutation { + processPayment(input: { + paymentMethod: "bank" + }) { + success + method + } + } + """ + ) + + assert result.errors is not None + assert "Bank account required" in result.errors[0].message + + +def test_model_validator_nested_inputs(): + """Test @model_validator with nested input types.""" + + @strawberry.pydantic.input + class AddressInput(pydantic.BaseModel): + street: str + city: str + country: str + + @strawberry.pydantic.input + class OrderInput(pydantic.BaseModel): + billing_address: AddressInput + shipping_address: AddressInput | None = None + same_as_billing: bool = False + + @model_validator(mode="after") + def check_shipping(self) -> "OrderInput": + if not self.same_as_billing and not self.shipping_address: + raise ValueError( + "Shipping address required unless same_as_billing is true" + ) + return self + + @strawberry.pydantic.type + class Order(pydantic.BaseModel): + id: int + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_order(self, input: OrderInput) -> Order: + return Order(id=1) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with same_as_billing=true (no shipping needed) + result = schema.execute_sync( + """ + mutation { + createOrder(input: { + billingAddress: { street: "123 Main", city: "NYC", country: "USA" } + sameAsBilling: true + }) { + id + } + } + """ + ) + + assert not result.errors + assert result.data["createOrder"]["id"] == 1 + + # Test with explicit shipping address + result = schema.execute_sync( + """ + mutation { + createOrder(input: { + billingAddress: { street: "123 Main", city: "NYC", country: "USA" } + shippingAddress: { street: "456 Oak", city: "LA", country: "USA" } + sameAsBilling: false + }) { + id + } + } + """ + ) + + assert not result.errors + + # Test missing shipping when required + result = schema.execute_sync( + """ + mutation { + createOrder(input: { + billingAddress: { street: "123 Main", city: "NYC", country: "USA" } + sameAsBilling: false + }) { + id + } + } + """ + ) + + assert result.errors is not None + assert "Shipping address required" in result.errors[0].message + + +def test_model_validator_multiple_errors(): + """Test @model_validator that raises multiple validation errors.""" + + @strawberry.pydantic.input + class ProfileInput(pydantic.BaseModel): + username: str + age: int + website: str | None = None + + @model_validator(mode="after") + def validate_profile(self) -> "ProfileInput": + errors = [] + + if len(self.username) < 3: + errors.append("Username must be at least 3 characters") + + if self.age < 13: + errors.append("Must be at least 13 years old") + + if self.website and not self.website.startswith(("http://", "https://")): + errors.append("Website must start with http:// or https://") + + if errors: + raise ValueError("; ".join(errors)) + + return self + + @strawberry.pydantic.type + class Profile(pydantic.BaseModel): + username: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_profile(self, input: ProfileInput) -> Profile: + return Profile(username=input.username) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with multiple validation errors + result = schema.execute_sync( + """ + mutation { + createProfile(input: { + username: "ab" + age: 10 + website: "invalid" + }) { + username + } + } + """ + ) + + assert result.errors is not None + error_message = result.errors[0].message + assert "Username must be at least 3 characters" in error_message + assert "Must be at least 13 years old" in error_message + assert "Website must start with" in error_message diff --git a/tests/pydantic/test_nested_types.py b/tests/pydantic/test_nested_types.py new file mode 100644 index 0000000000..3ade42251d --- /dev/null +++ b/tests/pydantic/test_nested_types.py @@ -0,0 +1,184 @@ +""" +Nested type tests for Pydantic integration. + +These tests verify that nested Pydantic types work correctly in GraphQL. +""" + +from typing import Optional + +import pydantic +from inline_snapshot import snapshot + +import strawberry + + +def test_nested_pydantic_types(): + """Test nested Pydantic types in queries.""" + + @strawberry.pydantic.type + class Address(pydantic.BaseModel): + street: str + city: str + zipcode: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + address: Address + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User( + name="John", + age=30, + address=Address(street="123 Main St", city="Anytown", zipcode="12345"), + ) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + { + "getUser": { + "name": "John", + "age": 30, + "address": { + "street": "123 Main St", + "city": "Anytown", + "zipcode": "12345", + }, + } + } + ) + + +def test_list_of_pydantic_types(): + """Test lists of Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def get_users(self) -> list[User]: + return [ + User(name="John", age=30), + User(name="Jane", age=25), + User(name="Bob", age=35), + ] + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUsers { + name + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + { + "getUsers": [ + {"name": "John", "age": 30}, + {"name": "Jane", "age": 25}, + {"name": "Bob", "age": 35}, + ] + } + ) + + +def test_complex_pydantic_types_execution(): + """Test complex Pydantic types with various field types.""" + + @strawberry.pydantic.type + class Profile(pydantic.BaseModel): + bio: Optional[str] = None + website: Optional[str] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + email: str + is_active: bool + tags: list[str] = [] + profile: Optional[Profile] = None + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User( + id=1, + name="John Doe", + email="john@example.com", + is_active=True, + tags=["developer", "python", "graphql"], + profile=Profile( + bio="Software developer", website="https://johndoe.com" + ), + ) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + id + name + email + isActive + tags + profile { + bio + website + } + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + { + "getUser": { + "id": 1, + "name": "John Doe", + "email": "john@example.com", + "isActive": True, + "tags": ["developer", "python", "graphql"], + "profile": { + "bio": "Software developer", + "website": "https://johndoe.com", + }, + } + } + ) diff --git a/tests/pydantic/test_private.py b/tests/pydantic/test_private.py new file mode 100644 index 0000000000..e7487066df --- /dev/null +++ b/tests/pydantic/test_private.py @@ -0,0 +1,126 @@ +import pydantic +from inline_snapshot import snapshot + +import strawberry +from strawberry.types.base import get_object_definition + + +def test_strawberry_private_fields(): + """Test that strawberry.Private fields are excluded from the GraphQL schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + password: strawberry.Private[str] + + definition = get_object_definition(User, strict=True) + assert definition.name == "User" + + # Should have three fields (id, name, age) - password should be excluded + assert len(definition.fields) == 3 + + field_names = {f.python_name for f in definition.fields} + assert field_names == {"id", "name", "age"} + + # password field should not be in the GraphQL schema + assert "password" not in field_names + + # But the python object should still have the password field + user = User(id=1, name="John", age=30, password="secret") + assert user.id == 1 + assert user.name == "John" + assert user.age == 30 + assert user.password == "secret" + + +def test_strawberry_private_fields_access(): + """Test that strawberry.Private fields can be accessed in Python code.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + password: strawberry.Private[str] + + definition = get_object_definition(User, strict=True) + assert definition.name == "User" + + # Should have two fields (id, name) - password should be excluded + assert len(definition.fields) == 2 + + field_names = {f.python_name for f in definition.fields} + assert field_names == {"id", "name"} + + # Test that the private field is still accessible on the instance + user = User(id=1, name="John", password="secret") + assert user.id == 1 + assert user.name == "John" + assert user.password == "secret" + + # Test that we can use the private field in Python logic + def has_password(user: User) -> bool: + return bool(user.password) + + assert has_password(user) is True + + user_no_password = User(id=2, name="Jane", password="") + assert has_password(user_no_password) is False + + +def test_strawberry_private_fields_not_in_schema(): + """Test that strawberry.Private fields are not exposed in GraphQL schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + password: strawberry.Private[str] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(id=1, name="John", password="secret123") + + schema = strawberry.Schema(query=Query) + + # Check that password field is not in the schema + schema_str = str(schema) + assert "password" not in schema_str + assert "id: Int!" in schema_str + assert "name: String!" in schema_str + + # Test that we can query the exposed fields + query = """ + query { + getUser { + id + name + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot({"getUser": {"id": 1, "name": "John"}}) + + # Test that querying the private field fails + query_with_private = """ + query { + getUser { + id + name + password + } + } + """ + + result = schema.execute_sync(query_with_private) + assert result.errors + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot( + "Cannot query field 'password' on type 'User'." + ) diff --git a/tests/pydantic/test_queries_mutations.py b/tests/pydantic/test_queries_mutations.py new file mode 100644 index 0000000000..166d593068 --- /dev/null +++ b/tests/pydantic/test_queries_mutations.py @@ -0,0 +1,187 @@ +""" +Query and mutation execution tests for Pydantic integration. + +These tests verify that Pydantic models work correctly in GraphQL queries and mutations. +""" + +from typing import Optional + +import pydantic +from inline_snapshot import snapshot + +import strawberry + + +def test_basic_query_execution(): + """Test basic query execution with Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot({"getUser": {"name": "John", "age": 30}}) + + +def test_query_with_optional_fields(): + """Test query execution with optional fields.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + email: Optional[str] = None + age: Optional[int] = None + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", email="john@example.com") + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + email + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + {"getUser": {"name": "John", "email": "john@example.com", "age": None}} + ) + + +def test_mutation_with_input_types(): + """Test mutation execution with Pydantic input types.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + age: int + email: Optional[str] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + email: Optional[str] = None + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(id=1, name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + mutation = """ + mutation { + createUser(input: { + name: "Alice" + age: 25 + email: "alice@example.com" + }) { + id + name + age + email + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == snapshot( + { + "createUser": { + "id": 1, + "name": "Alice", + "age": 25, + "email": "alice@example.com", + } + } + ) + + +def test_mutation_with_partial_input(): + """Test mutation with partial input (optional fields).""" + + @strawberry.pydantic.input + class UpdateUserInput(pydantic.BaseModel): + name: Optional[str] = None + age: Optional[int] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + + @strawberry.type + class Mutation: + @strawberry.field + def update_user(self, id: int, input: UpdateUserInput) -> User: + # Simulate updating a user + return User(id=id, name=input.name or "Default Name", age=input.age or 18) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + mutation = """ + mutation { + updateUser(id: 1, input: { + name: "Updated Name" + }) { + id + name + age + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == snapshot( + {"updateUser": {"id": 1, "name": "Updated Name", "age": 18}} + ) diff --git a/tests/pydantic/test_root_model.py b/tests/pydantic/test_root_model.py new file mode 100644 index 0000000000..0c5e0e4b64 --- /dev/null +++ b/tests/pydantic/test_root_model.py @@ -0,0 +1,313 @@ +"""Tests for Pydantic v2 RootModel with Strawberry. + +RootModel allows wrapping a single value in a model with validation. +This is useful for validating scalars, lists, or dicts with custom validation. +""" + +from typing import Annotated + +import pydantic +from pydantic import Field, RootModel + +import strawberry + + +def test_root_model_with_list_in_resolver(): + """Test using RootModel to wrap a list with validation in a resolver.""" + + class TagList(RootModel[list[str]]): + """A validated list of tags.""" + + def __iter__(self): + return iter(self.root) + + def __len__(self): + return len(self.root) + + @strawberry.type + class Query: + @strawberry.field + def process_tags(self, tags: list[str]) -> list[str]: + # Use RootModel for validation + validated = TagList.model_validate(tags) + return list(validated) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + processTags(tags: ["python", "graphql", "strawberry"]) + } + """ + ) + + assert not result.errors + assert result.data["processTags"] == ["python", "graphql", "strawberry"] + + +def test_root_model_with_constrained_list(): + """Test RootModel with a constrained list (min/max items).""" + + class BoundedList( + RootModel[Annotated[list[int], Field(min_length=1, max_length=5)]] + ): + """A list with 1-5 items.""" + + @strawberry.type + class Query: + @strawberry.field + def bounded_list(self, items: list[int]) -> list[int]: + validated = BoundedList.model_validate(items) + return validated.root + + schema = strawberry.Schema(query=Query) + + # Valid input + result = schema.execute_sync( + """ + query { + boundedList(items: [1, 2, 3]) + } + """ + ) + + assert not result.errors + assert result.data["boundedList"] == [1, 2, 3] + + # Too many items + result = schema.execute_sync( + """ + query { + boundedList(items: [1, 2, 3, 4, 5, 6]) + } + """ + ) + + assert result.errors is not None + assert "too_long" in result.errors[0].message + + +def test_root_model_with_dict(): + """Test RootModel wrapping a dictionary.""" + + class StringDict(RootModel[dict[str, str]]): + """A string-to-string dictionary.""" + + @strawberry.type + class Query: + @strawberry.field + def dict_values(self) -> list[str]: + data = StringDict.model_validate({"key1": "value1", "key2": "value2"}) + return list(data.root.values()) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + dictValues + } + """ + ) + + assert not result.errors + assert set(result.data["dictValues"]) == {"value1", "value2"} + + +def test_root_model_with_scalar(): + """Test RootModel wrapping a scalar with constraints.""" + + class PositiveInt(RootModel[Annotated[int, Field(gt=0)]]): + """A positive integer wrapper.""" + + @strawberry.type + class Query: + @strawberry.field + def positive_value(self, value: int) -> int: + validated = PositiveInt.model_validate(value) + return validated.root + + schema = strawberry.Schema(query=Query) + + # Valid input + result = schema.execute_sync( + """ + query { + positiveValue(value: 42) + } + """ + ) + + assert not result.errors + assert result.data["positiveValue"] == 42 + + # Invalid input + result = schema.execute_sync( + """ + query { + positiveValue(value: -1) + } + """ + ) + + assert result.errors is not None + assert "greater_than" in result.errors[0].message + + +def test_root_model_with_validators(): + """Test RootModel with custom validators.""" + + class UniqueStrings(RootModel[list[str]]): + """A list that must contain unique strings.""" + + @pydantic.field_validator("root") + @classmethod + def check_unique(cls, v: list[str]) -> list[str]: + if len(v) != len(set(v)): + raise ValueError("All items must be unique") + return v + + @strawberry.type + class Query: + @strawberry.field + def unique_items(self, items: list[str]) -> list[str]: + validated = UniqueStrings.model_validate(items) + return validated.root + + schema = strawberry.Schema(query=Query) + + # Valid input (unique items) + result = schema.execute_sync( + """ + query { + uniqueItems(items: ["a", "b", "c"]) + } + """ + ) + + assert not result.errors + assert result.data["uniqueItems"] == ["a", "b", "c"] + + # Invalid input (duplicate items) + result = schema.execute_sync( + """ + query { + uniqueItems(items: ["a", "b", "a"]) + } + """ + ) + + assert result.errors is not None + assert "unique" in result.errors[0].message.lower() + + +def test_root_model_in_output_type(): + """Test using RootModel in output type context.""" + + class Scores(RootModel[list[int]]): + """A list of scores.""" + + @strawberry.pydantic.type + class GameResult(pydantic.BaseModel): + player_name: str + scores: list[int] + + @strawberry.type + class Query: + @strawberry.field + def game_result(self) -> GameResult: + # Use RootModel to validate scores before creating result + validated_scores = Scores.model_validate([85, 90, 78]) + return GameResult(player_name="Alice", scores=validated_scores.root) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + gameResult { + playerName + scores + } + } + """ + ) + + assert not result.errors + assert result.data["gameResult"]["playerName"] == "Alice" + assert result.data["gameResult"]["scores"] == [85, 90, 78] + + +def test_root_model_nested_validation(): + """Test RootModel with nested models in a resolver.""" + + class InnerItem(pydantic.BaseModel): + name: str + quantity: int = Field(ge=1) + + class ItemList(RootModel[list[InnerItem]]): + """A validated list of items.""" + + @pydantic.field_validator("root") + @classmethod + def check_not_empty(cls, v: list[InnerItem]) -> list[InnerItem]: + if not v: + raise ValueError("Item list cannot be empty") + return v + + @strawberry.pydantic.type + class OrderSummary(pydantic.BaseModel): + total_items: int + item_names: list[str] + + @strawberry.type + class Query: + @strawberry.field + def process_items( + self, names: list[str], quantities: list[int] + ) -> OrderSummary: + # Build list of items and validate with RootModel + raw_items = [ + {"name": n, "quantity": q} + for n, q in zip(names, quantities, strict=True) + ] + validated = ItemList.model_validate(raw_items) + return OrderSummary( + total_items=sum(item.quantity for item in validated.root), + item_names=[item.name for item in validated.root], + ) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + processItems( + names: ["Widget", "Gadget"], + quantities: [2, 3] + ) { + totalItems + itemNames + } + } + """ + ) + + assert not result.errors + assert result.data["processItems"]["totalItems"] == 5 + assert result.data["processItems"]["itemNames"] == ["Widget", "Gadget"] + + # Test validation failure (empty list) + result = schema.execute_sync( + """ + query { + processItems(names: [], quantities: []) { + totalItems + } + } + """ + ) + + assert result.errors is not None + assert "empty" in result.errors[0].message.lower() diff --git a/tests/pydantic/test_strict_mode.py b/tests/pydantic/test_strict_mode.py new file mode 100644 index 0000000000..6fdada1517 --- /dev/null +++ b/tests/pydantic/test_strict_mode.py @@ -0,0 +1,322 @@ +"""Tests for Pydantic per-field strict mode with first-class integration.""" + +import pydantic +from pydantic import Field + +import strawberry + + +def test_strict_field_rejects_wrong_type(): + """Test that Field(strict=True) enforces exact types.""" + + @strawberry.pydantic.input + class MixedInput(pydantic.BaseModel): + strict_age: int = Field(strict=True) + flexible_count: int # Not strict - allows coercion + + @strawberry.pydantic.type + class Result(pydantic.BaseModel): + strict_age: int + flexible_count: int + + @strawberry.type + class Mutation: + @strawberry.mutation + def process(self, input: MixedInput) -> Result: + return Result( + strict_age=input.strict_age, flexible_count=input.flexible_count + ) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Both with correct types - should work + result = schema.execute_sync( + """ + mutation { + process(input: { strictAge: 25, flexibleCount: 10 }) { + strictAge + flexibleCount + } + } + """ + ) + + assert not result.errors + assert result.data["process"]["strictAge"] == 25 + assert result.data["process"]["flexibleCount"] == 10 + + +def test_strict_string_field(): + """Test strict mode for string fields.""" + + @strawberry.pydantic.input + class StrictStringInput(pydantic.BaseModel): + name: str = Field(strict=True) + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: StrictStringInput) -> User: + return User(name=input.name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # String input - should work + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "Alice" }) { + name + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["name"] == "Alice" + + +def test_strict_bool_field(): + """Test strict mode for boolean fields.""" + + @strawberry.pydantic.input + class StrictBoolInput(pydantic.BaseModel): + active: bool = Field(strict=True) + + @strawberry.pydantic.type + class Status(pydantic.BaseModel): + active: bool + + @strawberry.type + class Mutation: + @strawberry.mutation + def set_status(self, input: StrictBoolInput) -> Status: + return Status(active=input.active) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Boolean input - should work + result = schema.execute_sync( + """ + mutation { + setStatus(input: { active: true }) { + active + } + } + """ + ) + + assert not result.errors + assert result.data["setStatus"]["active"] is True + + +def test_strict_float_field(): + """Test strict mode for float fields.""" + + @strawberry.pydantic.input + class StrictFloatInput(pydantic.BaseModel): + price: float = Field(strict=True) + + @strawberry.pydantic.type + class Product(pydantic.BaseModel): + price: float + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_product(self, input: StrictFloatInput) -> Product: + return Product(price=input.price) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Float input - should work + result = schema.execute_sync( + """ + mutation { + createProduct(input: { price: 9.99 }) { + price + } + } + """ + ) + + assert not result.errors + assert result.data["createProduct"]["price"] == 9.99 + + # Int input to strict float - GraphQL allows this coercion at schema level + # The value arrives at Pydantic as a Python int/float + result = schema.execute_sync( + """ + mutation { + createProduct(input: { price: 10 }) { + price + } + } + """ + ) + + # GraphQL converts 10 to 10.0 before reaching Pydantic + assert not result.errors + + +def test_mixed_strict_and_non_strict_fields(): + """Test a model with both strict and non-strict fields.""" + + @strawberry.pydantic.input + class MixedStrictnessInput(pydantic.BaseModel): + strict_int: int = Field(strict=True) + strict_str: str = Field(strict=True) + flexible_int: int + flexible_str: str + + @strawberry.pydantic.type + class Result(pydantic.BaseModel): + strict_int: int + strict_str: str + flexible_int: int + flexible_str: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def process(self, input: MixedStrictnessInput) -> Result: + return Result( + strict_int=input.strict_int, + strict_str=input.strict_str, + flexible_int=input.flexible_int, + flexible_str=input.flexible_str, + ) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + result = schema.execute_sync( + """ + mutation { + process(input: { + strictInt: 42 + strictStr: "hello" + flexibleInt: 100 + flexibleStr: "world" + }) { + strictInt + strictStr + flexibleInt + flexibleStr + } + } + """ + ) + + assert not result.errors + assert result.data["process"]["strictInt"] == 42 + assert result.data["process"]["strictStr"] == "hello" + assert result.data["process"]["flexibleInt"] == 100 + assert result.data["process"]["flexibleStr"] == "world" + + +def test_strict_with_constraints(): + """Test strict mode combined with field constraints.""" + + @strawberry.pydantic.input + class ConstrainedInput(pydantic.BaseModel): + age: int = Field(strict=True, ge=0, le=150) + name: str = Field(strict=True, min_length=1, max_length=50) + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + name: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: ConstrainedInput) -> User: + return User(age=input.age, name=input.name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Valid input + result = schema.execute_sync( + """ + mutation { + createUser(input: { age: 25, name: "Alice" }) { + age + name + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["age"] == 25 + + # Age out of range + result = schema.execute_sync( + """ + mutation { + createUser(input: { age: 200, name: "Alice" }) { + age + name + } + } + """ + ) + + assert result.errors is not None + assert "less_than_equal" in result.errors[0].message + + # Name too long + long_name = "A" * 51 + result = schema.execute_sync( + f""" + mutation {{ + createUser(input: {{ age: 25, name: "{long_name}" }}) {{ + age + name + }} + }} + """ + ) + + assert result.errors is not None + assert "string_too_long" in result.errors[0].message diff --git a/tests/pydantic/test_type.py b/tests/pydantic/test_type.py new file mode 100644 index 0000000000..3df07181b3 --- /dev/null +++ b/tests/pydantic/test_type.py @@ -0,0 +1,136 @@ +from typing import Optional + +import pydantic +from inline_snapshot import snapshot + +import strawberry +from strawberry.types.base import ( + StrawberryOptional, + get_object_definition, +) + + +def test_basic_type_includes_all_fields(): + """Test that @strawberry.pydantic.type includes all fields from the model.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + definition = get_object_definition(User, strict=True) + assert definition.name == "User" + + # Should have two fields + assert len(definition.fields) == 2 + + # Find fields by name + age_field = next(f for f in definition.fields if f.python_name == "age") + password_field = next(f for f in definition.fields if f.python_name == "password") + + assert age_field.python_name == "age" + assert age_field.graphql_name is None + assert age_field.type is int + + assert password_field.python_name == "password" + assert password_field.graphql_name is None + assert isinstance(password_field.type, StrawberryOptional) + assert password_field.type.of_type is str + + +def test_basic_type_with_name_override(): + """Test that @strawberry.pydantic.type with name parameter works.""" + + @strawberry.pydantic.type(name="CustomUser") + class User(pydantic.BaseModel): + age: int + + definition = get_object_definition(User, strict=True) + assert definition.name == "CustomUser" + + +def test_basic_type_with_description(): + """Test that @strawberry.pydantic.type with description parameter works.""" + + @strawberry.pydantic.type(description="A user model") + class User(pydantic.BaseModel): + age: int + + definition = get_object_definition(User, strict=True) + assert definition.description == "A user model" + + +def test_is_type_of_method(): + """Test that is_type_of method is added for proper type resolution.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + name: str + + # Check that is_type_of method exists + assert hasattr(User, "is_type_of") + assert callable(User.is_type_of) + + # Test type checking + user_instance = User(age=25, name="John") + assert User.is_type_of(user_instance, None) is True + + # Test with different type + class Other: + pass + + other_instance = Other() + assert User.is_type_of(other_instance, None) is False + + +def test_schema_generation(): + """Test that the decorated models work in schema generation.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + name: str + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + age: int + name: str + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(age=25, name="John") + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(age=input.age, name=input.name) + + # Test that schema can be created successfully + schema = strawberry.Schema(query=Query, mutation=Mutation) + assert schema is not None + + assert str(schema) == snapshot( + """\ +input CreateUserInput { + age: Int! + name: String! +} + +type Mutation { + createUser(input: CreateUserInput!): User! +} + +type Query { + getUser: User! +} + +type User { + age: Int! + name: String! +}\ +""" + ) diff --git a/tests/pydantic/test_type_adapter.py b/tests/pydantic/test_type_adapter.py new file mode 100644 index 0000000000..478bef4022 --- /dev/null +++ b/tests/pydantic/test_type_adapter.py @@ -0,0 +1,258 @@ +"""Tests for Pydantic v2 TypeAdapter with Strawberry. + +TypeAdapter allows validation of arbitrary types without creating a full BaseModel. +This is useful for validating scalars, lists, and other types in resolvers. +""" + +from typing import Annotated + +import pydantic +from pydantic import Field, TypeAdapter, ValidationError + +import strawberry + + +def test_type_adapter_for_scalar_validation(): + """Test using TypeAdapter to validate scalar values in a resolver.""" + + # Create a type adapter for validating positive integers + PositiveInt = Annotated[int, Field(gt=0)] + positive_int_adapter = TypeAdapter(PositiveInt) + + @strawberry.type + class Query: + @strawberry.field + def validate_positive(self, value: int) -> int: + # Use TypeAdapter to validate the input + return positive_int_adapter.validate_python(value) + + schema = strawberry.Schema(query=Query) + + # Valid input + result = schema.execute_sync( + """ + query { + validatePositive(value: 42) + } + """ + ) + + assert not result.errors + assert result.data["validatePositive"] == 42 + + # Invalid input (negative) + result = schema.execute_sync( + """ + query { + validatePositive(value: -1) + } + """ + ) + + assert result.errors is not None + assert "greater_than" in result.errors[0].message + + +def test_type_adapter_for_string_validation(): + """Test using TypeAdapter for string validation.""" + + # Create a type adapter for validating email-like strings + EmailStr = Annotated[str, Field(min_length=3, pattern=r".*@.*\..*")] + email_adapter = TypeAdapter(EmailStr) + + @strawberry.type + class Query: + @strawberry.field + def validate_email(self, email: str) -> str: + return email_adapter.validate_python(email) + + schema = strawberry.Schema(query=Query) + + # Valid email + result = schema.execute_sync( + """ + query { + validateEmail(email: "test@example.com") + } + """ + ) + + assert not result.errors + assert result.data["validateEmail"] == "test@example.com" + + # Invalid email + result = schema.execute_sync( + """ + query { + validateEmail(email: "notanemail") + } + """ + ) + + assert result.errors is not None + assert "string_pattern_mismatch" in result.errors[0].message + + +def test_type_adapter_for_list_validation(): + """Test using TypeAdapter to validate list contents.""" + + # Create a type adapter for validating a list of positive integers + PositiveIntList = list[Annotated[int, Field(ge=0)]] + list_adapter = TypeAdapter(PositiveIntList) + + @strawberry.type + class Query: + @strawberry.field + def validate_numbers(self, numbers: list[int]) -> list[int]: + return list_adapter.validate_python(numbers) + + schema = strawberry.Schema(query=Query) + + # Valid input + result = schema.execute_sync( + """ + query { + validateNumbers(numbers: [1, 2, 3]) + } + """ + ) + + assert not result.errors + assert result.data["validateNumbers"] == [1, 2, 3] + + # Invalid input (contains negative) + result = schema.execute_sync( + """ + query { + validateNumbers(numbers: [1, -2, 3]) + } + """ + ) + + assert result.errors is not None + assert "greater_than_equal" in result.errors[0].message + + +def test_type_adapter_with_complex_type(): + """Test TypeAdapter with a more complex nested type.""" + + from typing import TypedDict + + class UserData(TypedDict): + name: str + age: int + + user_adapter = TypeAdapter(UserData) + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def validate_user(self, input: UserInput) -> str: + # Convert to dict and validate with TypeAdapter + user_data = {"name": input.name, "age": input.age} + validated = user_adapter.validate_python(user_data) + return f"{validated['name']} is {validated['age']} years old" + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + validateUser(input: { name: "Alice", age: 30 }) + } + """ + ) + + assert not result.errors + assert result.data["validateUser"] == "Alice is 30 years old" + + +def test_type_adapter_validation_error_handling(): + """Test that TypeAdapter validation errors are properly handled.""" + + BoundedInt = Annotated[int, Field(ge=0, le=100)] + adapter = TypeAdapter(BoundedInt) + + @strawberry.type + class Query: + @strawberry.field + def bounded_value(self, value: int) -> int: + try: + return adapter.validate_python(value) + except ValidationError as e: + # Re-raise as a more user-friendly error + raise ValueError(f"Invalid value: {e.errors()[0]['msg']}") from None + + schema = strawberry.Schema(query=Query) + + # Value too high + result = schema.execute_sync( + """ + query { + boundedValue(value: 150) + } + """ + ) + + assert result.errors is not None + assert "Invalid value" in result.errors[0].message + + +def test_type_adapter_with_coercion(): + """Test that TypeAdapter coercion works as expected.""" + + # TypeAdapter can coerce string to int if strict=False (default) + adapter = TypeAdapter(int) + + # In resolvers, we can use this for flexible input handling + @strawberry.type + class Query: + @strawberry.field + def coerce_to_int(self, value: str) -> int: + # This would coerce "42" to 42 + return adapter.validate_python(value) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + coerceToInt(value: "42") + } + """ + ) + + assert not result.errors + assert result.data["coerceToInt"] == 42 + + +def test_type_adapter_strict_mode(): + """Test TypeAdapter with strict mode.""" + + adapter = TypeAdapter(int) + + @strawberry.type + class Query: + @strawberry.field + def strict_int(self, value: str) -> int: + # Strict mode will reject string input + return adapter.validate_python(value, strict=True) + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + strictInt(value: "42") + } + """ + ) + + assert result.errors is not None + # Strict mode rejects string input for int + assert "int_type" in result.errors[0].message or "type" in result.errors[0].message diff --git a/tests/pydantic/test_validation_context.py b/tests/pydantic/test_validation_context.py new file mode 100644 index 0000000000..026a9a1253 --- /dev/null +++ b/tests/pydantic/test_validation_context.py @@ -0,0 +1,348 @@ +"""Tests for Pydantic validation context with Strawberry Info.""" + +from typing import Any + +import pydantic +from pydantic import ValidationInfo, field_validator, model_validator + +import strawberry +from strawberry.types.info import Info + + +def test_validation_context_passed_to_field_validator(): + """Test that Strawberry Info is passed to Pydantic field validators.""" + + received_context: dict[str, Any] = {} + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: str + + @field_validator("name") + @classmethod + def capture_context(cls, v: str, info: ValidationInfo) -> str: + # Capture the context for testing + if info.context: + received_context.update(info.context) + return v + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: UserInput) -> User: + return User(name=input.name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "Alice" }) { + name + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["name"] == "Alice" + + # Verify that context was passed (should contain 'info') + assert "info" in received_context + assert isinstance(received_context["info"], Info) + + +def test_validation_context_passed_to_model_validator(): + """Test that Strawberry Info is passed to Pydantic model validators.""" + + received_context: dict[str, Any] = {} + + @strawberry.pydantic.input + class OrderInput(pydantic.BaseModel): + quantity: int + price: int + + @model_validator(mode="after") + def capture_context(self, info: ValidationInfo) -> "OrderInput": + if info.context: + received_context.update(info.context) + return self + + @strawberry.pydantic.type + class Order(pydantic.BaseModel): + quantity: int + price: int + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_order(self, input: OrderInput) -> Order: + return Order(quantity=input.quantity, price=input.price) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + result = schema.execute_sync( + """ + mutation { + createOrder(input: { quantity: 5, price: 100 }) { + quantity + price + } + } + """ + ) + + assert not result.errors + assert "info" in received_context + assert isinstance(received_context["info"], Info) + + +def test_validation_context_with_custom_context(): + """Test that user's custom context is also passed to validators.""" + + received_context: dict[str, Any] = {} + + @strawberry.pydantic.input + class PostInput(pydantic.BaseModel): + title: str + + @field_validator("title") + @classmethod + def check_context(cls, v: str, info: ValidationInfo) -> str: + if info.context: + received_context.update(info.context) + return v + + @strawberry.pydantic.type + class Post(pydantic.BaseModel): + title: str + + class CustomContext: + def __init__(self, user_id: int): + self.user_id = user_id + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_post(self, input: PostInput) -> Post: + return Post(title=input.title) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Execute with custom context + result = schema.execute_sync( + """ + mutation { + createPost(input: { title: "Hello World" }) { + title + } + } + """, + context_value=CustomContext(user_id=42), + ) + + assert not result.errors + + # Verify both info and strawberry_context are available + assert "info" in received_context + assert "strawberry_context" in received_context + assert received_context["strawberry_context"].user_id == 42 + + +def test_validation_context_for_permission_based_validation(): + """Test using validation context for permission-based validation.""" + + class UserContext: + def __init__(self, role: str): + self.role = role + + @strawberry.pydantic.input + class AdminActionInput(pydantic.BaseModel): + action: str + + @field_validator("action") + @classmethod + def check_admin_permission(cls, v: str, info: ValidationInfo) -> str: + if info.context: + strawberry_ctx = info.context.get("strawberry_context") + if ( + strawberry_ctx + and hasattr(strawberry_ctx, "role") + and strawberry_ctx.role != "admin" + ): + raise ValueError("Only admins can perform this action") + return v + + @strawberry.pydantic.type + class ActionResult(pydantic.BaseModel): + success: bool + action: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def perform_admin_action(self, input: AdminActionInput) -> ActionResult: + return ActionResult(success=True, action=input.action) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with admin user - should succeed + result = schema.execute_sync( + """ + mutation { + performAdminAction(input: { action: "delete_all" }) { + success + action + } + } + """, + context_value=UserContext(role="admin"), + ) + + assert not result.errors + assert result.data["performAdminAction"]["success"] is True + + # Test with non-admin user - should fail + result = schema.execute_sync( + """ + mutation { + performAdminAction(input: { action: "delete_all" }) { + success + action + } + } + """, + context_value=UserContext(role="user"), + ) + + assert result.errors is not None + assert "Only admins can perform this action" in result.errors[0].message + + +def test_validation_context_with_nested_inputs(): + """Test that validation context is passed to nested input validators.""" + + nested_context_received = {"outer": False, "inner": False} + + @strawberry.pydantic.input + class AddressInput(pydantic.BaseModel): + city: str + + @field_validator("city") + @classmethod + def check_inner(cls, v: str, info: ValidationInfo) -> str: + if info.context and "info" in info.context: + nested_context_received["inner"] = True + return v + + @strawberry.pydantic.input + class PersonInput(pydantic.BaseModel): + name: str + address: AddressInput + + @field_validator("name") + @classmethod + def check_outer(cls, v: str, info: ValidationInfo) -> str: + if info.context and "info" in info.context: + nested_context_received["outer"] = True + return v + + @strawberry.pydantic.type + class Person(pydantic.BaseModel): + name: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_person(self, input: PersonInput) -> Person: + return Person(name=input.name) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + result = schema.execute_sync( + """ + mutation { + createPerson(input: { + name: "Alice" + address: { city: "NYC" } + }) { + name + } + } + """ + ) + + assert not result.errors + # Both outer and inner validators should receive context + assert nested_context_received["outer"] is True + assert nested_context_received["inner"] is True + + +def test_validation_context_in_query_arguments(): + """Test that validation context works for query arguments too.""" + + context_received = {"received": False} + + @strawberry.pydantic.input + class FilterInput(pydantic.BaseModel): + search: str + + @field_validator("search") + @classmethod + def check_context(cls, v: str, info: ValidationInfo) -> str: + if info.context and "info" in info.context: + context_received["received"] = True + return v + + @strawberry.type + class Query: + @strawberry.field + def search(self, filter: FilterInput) -> str: + return f"Searching for: {filter.search}" + + schema = strawberry.Schema(query=Query) + + result = schema.execute_sync( + """ + query { + search(filter: { search: "test" }) + } + """ + ) + + assert not result.errors + assert context_received["received"] is True