From 1d5de342ed51977f9ca5c9e0438882c73cc36484 Mon Sep 17 00:00:00 2001 From: matllubos Date: Tue, 10 Jul 2018 13:41:43 +0200 Subject: [PATCH] Schema is splitted to request and response parts --- src/drf_yasg/inspectors/base.py | 18 ++++++---- src/drf_yasg/inspectors/field.py | 62 ++++++++++++++++++++------------ src/drf_yasg/inspectors/view.py | 8 ++--- 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py index c54fe7df..7089b65d 100644 --- a/src/drf_yasg/inspectors/base.py +++ b/src/drf_yasg/inspectors/base.py @@ -136,7 +136,8 @@ def __init__(self, view, path, method, components, request, field_inspectors): super(FieldInspector, self).__init__(view, path, method, components, request) self.field_inspectors = field_inspectors - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): """Convert a drf Serializer or Field instance into a Swagger object. Should return :data:`.NotHandled` if this inspector does not know how to handle the given `field`. @@ -152,7 +153,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** """ return NotHandled - def probe_field_inspectors(self, field, swagger_object_type, use_references, **kwargs): + def probe_field_inspectors(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): """Helper method for recursively probing `field_inspectors` to handle a given field. All arguments are the same as :meth:`.field_to_swagger_object`. @@ -161,7 +163,8 @@ def probe_field_inspectors(self, field, swagger_object_type, use_references, **k """ return self.probe_inspectors( self.field_inspectors, 'field_to_swagger_object', field, {'field_inspectors': self.field_inspectors}, - swagger_object_type=swagger_object_type, use_references=use_references, **kwargs + swagger_object_type=swagger_object_type, use_references=use_references, is_request=is_request, + is_response=is_response, **kwargs ) def _get_partial_types(self, field, swagger_object_type, use_references, **kwargs): @@ -253,7 +256,8 @@ def SwaggerType(existing_object=None, **instance_kwargs): class SerializerInspector(FieldInspector): - def get_schema(self, serializer): + + def get_schema(self, serializer, is_request=False, is_response=False): """Convert a DRF Serializer instance to an :class:`.openapi.Schema`. Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`. @@ -366,15 +370,17 @@ def get_pagination_parameters(self): return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or [] - def serializer_to_schema(self, serializer): + def serializer_to_schema(self, serializer, is_request=False, is_response=False): """Convert a serializer to an OpenAPI :class:`.Schema`. :param serializers.BaseSerializer serializer: the ``Serializer`` instance :returns: the converted :class:`.Schema`, or ``None`` in case of an unknown serializer :rtype: openapi.Schema,openapi.SchemaRef,None """ + return self.probe_inspectors( - self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors} + self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors}, + is_request=is_request, is_response=is_response ) def serializer_to_parameters(self, serializer, in_): diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index d1ea24fa..1b06a9ee 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -35,8 +35,9 @@ def add_manual_fields(self, serializer, schema): for attr, val in swagger_schema_fields.items(): setattr(schema, attr, val) - def get_schema(self, serializer): - return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions) + def get_schema(self, serializer, is_request=False, is_response=False): + return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions, is_request=is_request, + is_response=is_response) def add_manual_parameters(self, serializer, parameters): """Add/replace parameters from the given list of automatically generated request parameters. This method @@ -54,7 +55,7 @@ def get_request_parameters(self, serializer, in_): parameters = [ self.probe_field_inspectors( value, openapi.Parameter, self.use_definitions, - name=self.get_parameter_name(key), in_=in_ + name=self.get_parameter_name(key), in_=in_, is_request=True ) for key, value in fields.items() @@ -71,33 +72,39 @@ def get_parameter_name(self, field_name): def get_serializer_ref_name(self, serializer): return get_serializer_ref_name(serializer) - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) if isinstance(field, (serializers.ListSerializer, serializers.ListField)): - child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references) + child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references, + is_request=is_request, is_response=is_response) return SwaggerType( type=openapi.TYPE_ARRAY, - items=child_schema, + items=child_schema ) elif isinstance(field, serializers.Serializer): if swagger_object_type != openapi.Schema: raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__) ref_name = self.get_serializer_ref_name(field) + if ref_name and is_request: + ref_name += 'Request' + if ref_name and is_response: + ref_name += 'Response' def make_schema_definition(): properties = OrderedDict() required = [] for property_name, child in field.fields.items(): + if is_request and child.write_only: + continue + elif is_response and child.read_only: + continue property_name = self.get_property_name(property_name) - prop_kwargs = { - 'read_only': bool(child.read_only) or None - } - prop_kwargs = filter_none(prop_kwargs) child_schema = self.probe_field_inspectors( - child, ChildSwaggerType, use_references, **prop_kwargs + child, ChildSwaggerType, use_references, is_request=is_request, is_response=is_response ) properties[property_name] = child_schema @@ -197,11 +204,13 @@ def get_related_model(model, source): class RelatedFieldInspector(FieldInspector): """Provides conversions for ``RelatedField``\ s.""" - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) if isinstance(field, serializers.ManyRelatedField): - child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references) + child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references, + is_request=is_request, is_response=is_response) return SwaggerType( type=openapi.TYPE_ARRAY, items=child_schema, @@ -217,7 +226,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** if getattr(field, 'pk_field', ''): # a PrimaryKeyRelatedField can have a `pk_field` attribute which is a # serializer field that will convert the PK value - result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references, **kwargs) + result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references, + is_request=is_request, is_response=is_response, **kwargs) # take the type, format, etc from `pk_field`, and the field-level information # like title, description, default from the PrimaryKeyRelatedField return SwaggerType(existing_object=result) @@ -414,7 +424,8 @@ class SimpleFieldInspector(FieldInspector): and min/max validators. """ - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): type_info = get_basic_type_info(field) if type_info is None: return NotHandled @@ -426,7 +437,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** class ChoiceFieldInspector(FieldInspector): """Provides conversions for ``ChoiceField`` and ``MultipleChoiceField``.""" - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) if isinstance(field, serializers.ChoiceField): @@ -459,7 +471,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** class FileFieldInspector(FieldInspector): """Provides conversions for ``FileField``\ s.""" - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) if isinstance(field, serializers.FileField): @@ -486,11 +499,13 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** class DictFieldInspector(FieldInspector): """Provides conversion for ``DictField``.""" - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) if isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema: - child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references) + child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references, + is_request=is_request, is_response=is_response) return SwaggerType( type=openapi.TYPE_OBJECT, additional_properties=child_schema @@ -502,7 +517,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** class HiddenFieldInspector(FieldInspector): """Hide ``HiddenField``.""" - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): if isinstance(field, serializers.HiddenField): return None @@ -512,7 +528,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, ** class StringDefaultFieldInspector(FieldInspector): """For otherwise unhandled fields, return them as plain :data:`.TYPE_STRING` objects.""" - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): # pragma: no cover + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False, + **kwargs): # pragma: no cover # TODO unhandled fields: TimeField JSONField SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) return SwaggerType(type=openapi.TYPE_STRING) @@ -569,7 +586,8 @@ class RecursiveFieldInspector(FieldInspector): else: class RecursiveFieldInspector(FieldInspector): """Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)""" - def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, + is_response=False, **kwargs): if isinstance(field, RecursiveField) and swagger_object_type == openapi.Schema: assert use_references is True, "Can not create schema for RecursiveField when use_references is False" diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index 957c8f75..79dc7573 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -138,7 +138,7 @@ def get_request_body_schema(self, serializer): :param serializer: the view's request serializer as returned by :meth:`.get_request_serializer` :rtype: openapi.Schema """ - return self.serializer_to_schema(serializer) + return self.serializer_to_schema(serializer, is_request=True) def make_body_parameter(self, schema): """Given a :class:`.Schema` object, create an ``in: body`` :class:`.Parameter`. @@ -206,7 +206,7 @@ def get_default_responses(self): if any(is_form_media_type(encoding) for encoding in self.get_consumes()): default_schema = '' if default_schema and not isinstance(default_schema, openapi.Schema): - default_schema = self.serializer_to_schema(default_schema) or '' + default_schema = self.serializer_to_schema(default_schema, is_response=True) or '' if default_schema: if is_list_view(self.path, self.method, self.view) and self.method.lower() == 'get': @@ -254,7 +254,7 @@ def get_response_schemas(self, response_serializers): response = serializer if hasattr(response, 'schema') and not isinstance(response.schema, openapi.Schema.OR_REF): serializer = force_serializer_instance(response.schema) - response.schema = self.serializer_to_schema(serializer) + response.schema = self.serializer_to_schema(serializer, is_response=True) elif isinstance(serializer, openapi.Schema.OR_REF): response = openapi.Response( description='', @@ -264,7 +264,7 @@ def get_response_schemas(self, response_serializers): serializer = force_serializer_instance(serializer) response = openapi.Response( description='', - schema=self.serializer_to_schema(serializer), + schema=self.serializer_to_schema(serializer, is_response=True), ) responses[str(sc)] = response