Skip to content

Commit f52fb3f

Browse files
committed
Schema is splitted to request and response parts
1 parent db86981 commit f52fb3f

File tree

3 files changed

+57
-33
lines changed

3 files changed

+57
-33
lines changed

src/drf_yasg/inspectors/base.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def __init__(self, view, path, method, components, request, field_inspectors):
136136
super(FieldInspector, self).__init__(view, path, method, components, request)
137137
self.field_inspectors = field_inspectors
138138

139-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
139+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
140+
**kwargs):
140141
"""Convert a drf Serializer or Field instance into a Swagger object.
141142
142143
Should return :data:`.NotHandled` if this inspector does not know how to handle the given `field`.
@@ -152,7 +153,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
152153
"""
153154
return NotHandled
154155

155-
def probe_field_inspectors(self, field, swagger_object_type, use_references, **kwargs):
156+
def probe_field_inspectors(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
157+
**kwargs):
156158
"""Helper method for recursively probing `field_inspectors` to handle a given field.
157159
158160
All arguments are the same as :meth:`.field_to_swagger_object`.
@@ -161,7 +163,8 @@ def probe_field_inspectors(self, field, swagger_object_type, use_references, **k
161163
"""
162164
return self.probe_inspectors(
163165
self.field_inspectors, 'field_to_swagger_object', field, {'field_inspectors': self.field_inspectors},
164-
swagger_object_type=swagger_object_type, use_references=use_references, **kwargs
166+
swagger_object_type=swagger_object_type, use_references=use_references, is_request=is_request,
167+
is_response=is_response, **kwargs
165168
)
166169

167170
def _get_partial_types(self, field, swagger_object_type, use_references, **kwargs):
@@ -253,7 +256,8 @@ def SwaggerType(existing_object=None, **instance_kwargs):
253256

254257

255258
class SerializerInspector(FieldInspector):
256-
def get_schema(self, serializer):
259+
260+
def get_schema(self, serializer, is_request=False, is_response=False):
257261
"""Convert a DRF Serializer instance to an :class:`.openapi.Schema`.
258262
259263
Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`.
@@ -366,15 +370,17 @@ def get_pagination_parameters(self):
366370

367371
return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or []
368372

369-
def serializer_to_schema(self, serializer):
373+
def serializer_to_schema(self, serializer, is_request=False, is_response=False):
370374
"""Convert a serializer to an OpenAPI :class:`.Schema`.
371375
372376
:param serializers.BaseSerializer serializer: the ``Serializer`` instance
373377
:returns: the converted :class:`.Schema`, or ``None`` in case of an unknown serializer
374378
:rtype: openapi.Schema,openapi.SchemaRef,None
375379
"""
380+
376381
return self.probe_inspectors(
377-
self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors}
382+
self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors},
383+
is_request=is_request, is_response=is_response
378384
)
379385

380386
def serializer_to_parameters(self, serializer, in_):

src/drf_yasg/inspectors/field.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,11 @@ def add_manual_fields(self, serializer, schema):
3535
for attr, val in swagger_schema_fields.items():
3636
setattr(schema, attr, val)
3737

38-
def get_schema(self, serializer):
39-
return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions)
38+
def get_schema(self, serializer, is_request=False, is_response=False):
39+
return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions, is_request=is_request,
40+
is_response=is_response)
4041

41-
def add_manual_parameters(self, serializer, parameters):
42+
def add_manual_parameters(self, serializer, is_request, is_response):
4243
"""Add/replace parameters from the given list of automatically generated request parameters. This method
4344
is called only when the serializer is converted into a list of parameters for use in a form data request.
4445
@@ -54,7 +55,7 @@ def get_request_parameters(self, serializer, in_):
5455
parameters = [
5556
self.probe_field_inspectors(
5657
value, openapi.Parameter, self.use_definitions,
57-
name=self.get_parameter_name(key), in_=in_
58+
name=self.get_parameter_name(key), in_=in_, is_request=True
5859
)
5960
for key, value
6061
in fields.items()
@@ -71,33 +72,39 @@ def get_parameter_name(self, field_name):
7172
def get_serializer_ref_name(self, serializer):
7273
return get_serializer_ref_name(serializer)
7374

74-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
75+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
76+
**kwargs):
7577
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
7678

7779
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
78-
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
80+
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references,
81+
is_request=is_request, is_response=is_response)
7982
return SwaggerType(
8083
type=openapi.TYPE_ARRAY,
81-
items=child_schema,
84+
items=child_schema
8285
)
8386
elif isinstance(field, serializers.Serializer):
8487
if swagger_object_type != openapi.Schema:
8588
raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)
8689

8790
ref_name = self.get_serializer_ref_name(field)
91+
if is_request:
92+
ref_name += 'Request'
93+
if is_response:
94+
ref_name += 'Response'
8895

8996
def make_schema_definition():
9097
properties = OrderedDict()
9198
required = []
9299
for property_name, child in field.fields.items():
100+
if is_request and child.write_only:
101+
continue
102+
elif is_response and child.read_only:
103+
continue
93104
property_name = self.get_property_name(property_name)
94-
prop_kwargs = {
95-
'read_only': bool(child.read_only) or None
96-
}
97-
prop_kwargs = filter_none(prop_kwargs)
98105

99106
child_schema = self.probe_field_inspectors(
100-
child, ChildSwaggerType, use_references, **prop_kwargs
107+
child, ChildSwaggerType, use_references, is_request=is_request, is_response=is_response
101108
)
102109
properties[property_name] = child_schema
103110

@@ -197,11 +204,13 @@ def get_related_model(model, source):
197204
class RelatedFieldInspector(FieldInspector):
198205
"""Provides conversions for ``RelatedField``\ s."""
199206

200-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
207+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
208+
**kwargs):
201209
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
202210

203211
if isinstance(field, serializers.ManyRelatedField):
204-
child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references)
212+
child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references,
213+
is_request=is_request, is_response=is_response)
205214
return SwaggerType(
206215
type=openapi.TYPE_ARRAY,
207216
items=child_schema,
@@ -217,7 +226,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
217226
if getattr(field, 'pk_field', ''):
218227
# a PrimaryKeyRelatedField can have a `pk_field` attribute which is a
219228
# serializer field that will convert the PK value
220-
result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references, **kwargs)
229+
result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references,
230+
is_request=is_request, is_response=is_response, **kwargs)
221231
# take the type, format, etc from `pk_field`, and the field-level information
222232
# like title, description, default from the PrimaryKeyRelatedField
223233
return SwaggerType(existing_object=result)
@@ -414,7 +424,8 @@ class SimpleFieldInspector(FieldInspector):
414424
and min/max validators.
415425
"""
416426

417-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
427+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
428+
**kwargs):
418429
type_info = get_basic_type_info(field)
419430
if type_info is None:
420431
return NotHandled
@@ -426,7 +437,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
426437
class ChoiceFieldInspector(FieldInspector):
427438
"""Provides conversions for ``ChoiceField`` and ``MultipleChoiceField``."""
428439

429-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
440+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
441+
**kwargs):
430442
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
431443

432444
if isinstance(field, serializers.ChoiceField):
@@ -459,7 +471,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
459471
class FileFieldInspector(FieldInspector):
460472
"""Provides conversions for ``FileField``\ s."""
461473

462-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
474+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
475+
**kwargs):
463476
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
464477

465478
if isinstance(field, serializers.FileField):
@@ -486,11 +499,13 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
486499
class DictFieldInspector(FieldInspector):
487500
"""Provides conversion for ``DictField``."""
488501

489-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
502+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
503+
**kwargs):
490504
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
491505

492506
if isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
493-
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
507+
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references,
508+
is_request=is_request, is_response=is_response)
494509
return SwaggerType(
495510
type=openapi.TYPE_OBJECT,
496511
additional_properties=child_schema
@@ -502,7 +517,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
502517
class HiddenFieldInspector(FieldInspector):
503518
"""Hide ``HiddenField``."""
504519

505-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
520+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
521+
**kwargs):
506522
if isinstance(field, serializers.HiddenField):
507523
return None
508524

@@ -512,7 +528,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
512528
class StringDefaultFieldInspector(FieldInspector):
513529
"""For otherwise unhandled fields, return them as plain :data:`.TYPE_STRING` objects."""
514530

515-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): # pragma: no cover
531+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
532+
**kwargs): # pragma: no cover
516533
# TODO unhandled fields: TimeField JSONField
517534
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
518535
return SwaggerType(type=openapi.TYPE_STRING)
@@ -569,7 +586,8 @@ class RecursiveFieldInspector(FieldInspector):
569586
else:
570587
class RecursiveFieldInspector(FieldInspector):
571588
"""Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
572-
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
589+
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False,
590+
is_response=False, **kwargs):
573591
if isinstance(field, RecursiveField) and swagger_object_type == openapi.Schema:
574592
assert use_references is True, "Can not create schema for RecursiveField when use_references is False"
575593

src/drf_yasg/inspectors/view.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def get_request_body_schema(self, serializer):
138138
:param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
139139
:rtype: openapi.Schema
140140
"""
141-
return self.serializer_to_schema(serializer)
141+
return self.serializer_to_schema(serializer, is_request=True)
142142

143143
def make_body_parameter(self, schema):
144144
"""Given a :class:`.Schema` object, create an ``in: body`` :class:`.Parameter`.
@@ -206,7 +206,7 @@ def get_default_responses(self):
206206
if any(is_form_media_type(encoding) for encoding in self.get_consumes()):
207207
default_schema = ''
208208
if default_schema and not isinstance(default_schema, openapi.Schema):
209-
default_schema = self.serializer_to_schema(default_schema) or ''
209+
default_schema = self.serializer_to_schema(default_schema, is_response=True) or ''
210210

211211
if default_schema:
212212
if is_list_view(self.path, self.method, self.view) and self.method.lower() == 'get':
@@ -254,7 +254,7 @@ def get_response_schemas(self, response_serializers):
254254
response = serializer
255255
if hasattr(response, 'schema') and not isinstance(response.schema, openapi.Schema.OR_REF):
256256
serializer = force_serializer_instance(response.schema)
257-
response.schema = self.serializer_to_schema(serializer)
257+
response.schema = self.serializer_to_schema(serializer, is_response=True)
258258
elif isinstance(serializer, openapi.Schema.OR_REF):
259259
response = openapi.Response(
260260
description='',
@@ -264,7 +264,7 @@ def get_response_schemas(self, response_serializers):
264264
serializer = force_serializer_instance(serializer)
265265
response = openapi.Response(
266266
description='',
267-
schema=self.serializer_to_schema(serializer),
267+
schema=self.serializer_to_schema(serializer, is_response=True),
268268
)
269269

270270
responses[str(sc)] = response

0 commit comments

Comments
 (0)