Skip to content

Commit 6a830dc

Browse files
authored
Revert behavior change in error reporting for function after validators (#495)
1 parent 7662bdd commit 6a830dc

File tree

3 files changed

+145
-24
lines changed

3 files changed

+145
-24
lines changed

src/validators/function.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,10 @@ impl FunctionAfterValidator {
157157
extra: &Extra,
158158
) -> ValResult<'data, PyObject> {
159159
let info = ValidationInfo::new(py, extra, &self.config, self.is_field_validator)?;
160-
let input = call(input, extra)?;
160+
let v = call(input, extra)?;
161161
self.func
162-
.call1(py, (input.to_object(py), info))
163-
.map_err(|e| convert_err(py, e, input.into_ref(py)))
162+
.call1(py, (v.to_object(py), info))
163+
.map_err(|e| convert_err(py, e, input))
164164
}
165165
}
166166

tests/validators/test_function.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,39 @@ def f(input_value, validator, info):
238238
v.validate_python(4)
239239

240240

241+
def test_function_after():
242+
def f(input_value, _info):
243+
return input_value + ' Changed'
244+
245+
v = SchemaValidator(
246+
{'type': 'function-after', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}}
247+
)
248+
249+
assert v.validate_python('input value') == 'input value Changed'
250+
251+
252+
def test_function_after_raise():
253+
def f(input_value, info):
254+
raise ValueError('foobar')
255+
256+
v = SchemaValidator(
257+
{'type': 'function-after', 'function': {'type': 'general', 'function': f}, 'schema': {'type': 'str'}}
258+
)
259+
260+
with pytest.raises(ValidationError) as exc_info:
261+
assert v.validate_python('input value') == 'input value Changed'
262+
# debug(str(exc_info.value))
263+
assert exc_info.value.errors() == [
264+
{
265+
'type': 'value_error',
266+
'loc': (),
267+
'msg': 'Value error, foobar',
268+
'input': 'input value',
269+
'ctx': {'error': 'foobar'},
270+
}
271+
]
272+
273+
241274
def test_function_after_config():
242275
f_kwargs = None
243276

tests/validators/test_model.py

Lines changed: 109 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
from copy import deepcopy
3-
from typing import Any, List
3+
from typing import Any, Callable, Dict, List, Set, Tuple
44

55
import pytest
66

@@ -88,33 +88,121 @@ def __setattr__(self, key, value):
8888
assert setattr_calls == []
8989

9090

91-
def test_model_class_root_validator():
91+
def test_model_class_root_validator_wrap():
9292
class MyModel:
93-
pass
93+
def __init__(self, **kwargs: Any) -> None:
94+
self.__dict__.update(kwargs)
9495

95-
def f(input_value, validator, info):
96+
def f(
97+
input_value: Dict[str, Any],
98+
validator: Callable[[Dict[str, Any]], Dict[str, Any]],
99+
info: core_schema.ValidationInfo,
100+
):
101+
assert input_value['field_a'] == 123
96102
output = validator(input_value)
97-
return str(output)
103+
return output
98104

99-
v = SchemaValidator(
105+
schema = core_schema.model_schema(
106+
MyModel,
107+
core_schema.general_wrap_validator_function(
108+
f,
109+
core_schema.typed_dict_schema(
110+
{'field_a': core_schema.typed_dict_field(core_schema.int_schema())}, return_fields_set=True
111+
),
112+
),
113+
)
114+
115+
v = SchemaValidator(schema)
116+
m = v.validate_python({'field_a': 123})
117+
assert m.field_a == 123
118+
119+
with pytest.raises(ValidationError) as e:
120+
v.validate_python({'field_a': 456})
121+
122+
assert e.value.errors() == [
100123
{
101-
'type': 'function-wrap',
102-
'function': {'type': 'general', 'function': f},
103-
'schema': {
104-
'type': 'model',
105-
'cls': MyModel,
106-
'schema': {
107-
'type': 'typed-dict',
108-
'return_fields_set': True,
109-
'fields': {'field_a': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}},
110-
},
111-
},
124+
'type': 'assertion_error',
125+
'loc': (),
126+
'msg': 'Assertion failed, assert 456 == 123',
127+
'input': {'field_a': 456},
128+
'ctx': {'error': 'assert 456 == 123'},
129+
}
130+
]
131+
132+
133+
def test_model_class_root_validator_before():
134+
class MyModel:
135+
def __init__(self, **kwargs: Any) -> None:
136+
self.__dict__.update(kwargs)
137+
138+
def f(input_value: Dict[str, Any], info: core_schema.ValidationInfo):
139+
assert input_value['field_a'] == 123
140+
return input_value
141+
142+
schema = core_schema.model_schema(
143+
MyModel,
144+
core_schema.general_before_validator_function(
145+
f,
146+
core_schema.typed_dict_schema(
147+
{'field_a': core_schema.typed_dict_field(core_schema.int_schema())}, return_fields_set=True
148+
),
149+
),
150+
)
151+
152+
v = SchemaValidator(schema)
153+
m = v.validate_python({'field_a': 123})
154+
assert m.field_a == 123
155+
156+
with pytest.raises(ValidationError) as e:
157+
v.validate_python({'field_a': 456})
158+
159+
assert e.value.errors() == [
160+
{
161+
'type': 'assertion_error',
162+
'loc': (),
163+
'msg': 'Assertion failed, assert 456 == 123',
164+
'input': {'field_a': 456},
165+
'ctx': {'error': 'assert 456 == 123'},
112166
}
167+
]
168+
169+
170+
def test_model_class_root_validator_after():
171+
class MyModel:
172+
def __init__(self, **kwargs: Any) -> None:
173+
self.__dict__.update(kwargs)
174+
175+
def f(input_value_and_fields_set: Tuple[Dict[str, Any], Set[str]], info: core_schema.ValidationInfo):
176+
input_value, _ = input_value_and_fields_set
177+
assert input_value['field_a'] == 123
178+
return input_value_and_fields_set
179+
180+
schema = core_schema.model_schema(
181+
MyModel,
182+
core_schema.general_after_validator_function(
183+
f,
184+
core_schema.typed_dict_schema(
185+
{'field_a': core_schema.typed_dict_field(core_schema.int_schema())}, return_fields_set=True
186+
),
187+
),
113188
)
114-
assert 'expect_fields_set:true' in plain_repr(v)
115-
m = v.validate_python({'field_a': 'test'})
116-
assert isinstance(m, str)
117-
assert 'test_model_class_root_validator.<locals>.MyModel' in m
189+
190+
v = SchemaValidator(schema)
191+
m = v.validate_python({'field_a': 123})
192+
assert m.field_a == 123
193+
194+
with pytest.raises(ValidationError) as e:
195+
v.validate_python({'field_a': 456})
196+
197+
assert e.value.errors() == [
198+
{
199+
'type': 'assertion_error',
200+
'loc': (),
201+
'msg': 'Assertion failed, assert 456 == 123',
202+
'input': {'field_a': 456},
203+
'ctx': {'error': 'assert 456 == 123'},
204+
}
205+
]
118206

119207

120208
@pytest.mark.parametrize('mode', ['before', 'after', 'wrap'])

0 commit comments

Comments
 (0)