Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/drf_yasg/inspectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def __init__(self, view, path, method, components, request, field_inspectors):
super(FieldInspector, self).__init__(view, path, method, components, request)
self.field_inspectors = field_inspectors

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
"""Convert a drf Serializer or Field instance into a Swagger object.

Should return :data:`.NotHandled` if this inspector does not know how to handle the given `field`.
Expand All @@ -151,7 +152,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
"""
return NotHandled

def probe_field_inspectors(self, field, swagger_object_type, use_references, **kwargs):
def probe_field_inspectors(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
"""Helper method for recursively probing `field_inspectors` to handle a given field.

All arguments are the same as :meth:`.field_to_swagger_object`.
Expand All @@ -160,7 +162,8 @@ def probe_field_inspectors(self, field, swagger_object_type, use_references, **k
"""
return self.probe_inspectors(
self.field_inspectors, 'field_to_swagger_object', field, {'field_inspectors': self.field_inspectors},
swagger_object_type=swagger_object_type, use_references=use_references, **kwargs
swagger_object_type=swagger_object_type, use_references=use_references, is_request=is_request,
is_response=is_response, **kwargs
)

def _get_partial_types(self, field, swagger_object_type, use_references, **kwargs):
Expand Down Expand Up @@ -252,7 +255,8 @@ def SwaggerType(existing_object=None, **instance_kwargs):


class SerializerInspector(FieldInspector):
def get_schema(self, serializer):

def get_schema(self, serializer, is_request=False, is_response=False):
"""Convert a DRF Serializer instance to an :class:`.openapi.Schema`.

Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`.
Expand Down Expand Up @@ -365,15 +369,17 @@ def get_pagination_parameters(self):

return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or []

def serializer_to_schema(self, serializer):
def serializer_to_schema(self, serializer, is_request=False, is_response=False):
"""Convert a serializer to an OpenAPI :class:`.Schema`.

:param serializers.BaseSerializer serializer: the ``Serializer`` instance
:returns: the converted :class:`.Schema`, or ``None`` in case of an unknown serializer
:rtype: openapi.Schema,openapi.SchemaRef
"""

return self.probe_inspectors(
self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors}
self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors},
is_request=is_request, is_response=is_response
)

def serializer_to_parameters(self, serializer, in_):
Expand Down
62 changes: 40 additions & 22 deletions src/drf_yasg/inspectors/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ def add_manual_fields(self, serializer, schema):
for attr, val in swagger_schema_fields.items():
setattr(schema, attr, val)

def get_schema(self, serializer):
return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions)
def get_schema(self, serializer, is_request=False, is_response=False):
return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions, is_request=is_request,
is_response=is_response)

def add_manual_parameters(self, serializer, parameters):
"""Add/replace parameters from the given list of automatically generated request parameters. This method
Expand All @@ -54,7 +55,7 @@ def get_request_parameters(self, serializer, in_):
parameters = [
self.probe_field_inspectors(
value, openapi.Parameter, self.use_definitions,
name=self.get_parameter_name(key), in_=in_
name=self.get_parameter_name(key), in_=in_, is_request=True
)
for key, value
in fields.items()
Expand All @@ -71,33 +72,39 @@ def get_parameter_name(self, field_name):
def get_serializer_ref_name(self, serializer):
return get_serializer_ref_name(serializer)

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references,
is_request=is_request, is_response=is_response)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
items=child_schema
)
elif isinstance(field, serializers.Serializer):
if swagger_object_type != openapi.Schema:
raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)

ref_name = self.get_serializer_ref_name(field)
if ref_name and is_request:
ref_name += 'Request'
if ref_name and is_response:
ref_name += 'Response'

def make_schema_definition():
properties = OrderedDict()
required = []
for property_name, child in field.fields.items():
if is_request and child.write_only:
continue
elif is_response and child.read_only:
continue
property_name = self.get_property_name(property_name)
prop_kwargs = {
'read_only': bool(child.read_only) or None
}
prop_kwargs = filter_none(prop_kwargs)

child_schema = self.probe_field_inspectors(
child, ChildSwaggerType, use_references, **prop_kwargs
child, ChildSwaggerType, use_references, is_request=is_request, is_response=is_response
)
properties[property_name] = child_schema

Expand Down Expand Up @@ -197,11 +204,13 @@ def get_related_model(model, source):
class RelatedFieldInspector(FieldInspector):
"""Provides conversions for ``RelatedField``\ s."""

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

if isinstance(field, serializers.ManyRelatedField):
child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references)
child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references,
is_request=is_request, is_response=is_response)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
Expand All @@ -217,7 +226,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
if getattr(field, 'pk_field', ''):
# a PrimaryKeyRelatedField can have a `pk_field` attribute which is a
# serializer field that will convert the PK value
result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references, **kwargs)
result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references,
is_request=is_request, is_response=is_response, **kwargs)
# take the type, format, etc from `pk_field`, and the field-level information
# like title, description, default from the PrimaryKeyRelatedField
return SwaggerType(existing_object=result)
Expand Down Expand Up @@ -414,7 +424,8 @@ class SimpleFieldInspector(FieldInspector):
and min/max validators.
"""

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
type_info = get_basic_type_info(field)
if type_info is None:
return NotHandled
Expand All @@ -426,7 +437,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
class ChoiceFieldInspector(FieldInspector):
"""Provides conversions for ``ChoiceField`` and ``MultipleChoiceField``."""

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

if isinstance(field, serializers.ChoiceField):
Expand Down Expand Up @@ -459,7 +471,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
class FileFieldInspector(FieldInspector):
"""Provides conversions for ``FileField``\ s."""

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

if isinstance(field, serializers.FileField):
Expand All @@ -486,11 +499,13 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
class DictFieldInspector(FieldInspector):
"""Provides conversion for ``DictField``."""

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

if isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references,
is_request=is_request, is_response=is_response)
return SwaggerType(
type=openapi.TYPE_OBJECT,
additional_properties=child_schema
Expand All @@ -502,7 +517,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
class HiddenFieldInspector(FieldInspector):
"""Hide ``HiddenField``."""

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs):
if isinstance(field, serializers.HiddenField):
return None

Expand All @@ -512,7 +528,8 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
class StringDefaultFieldInspector(FieldInspector):
"""For otherwise unhandled fields, return them as plain :data:`.TYPE_STRING` objects."""

def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): # pragma: no cover
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False, is_response=False,
**kwargs): # pragma: no cover
# TODO unhandled fields: TimeField JSONField
SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
return SwaggerType(type=openapi.TYPE_STRING)
Expand Down Expand Up @@ -569,7 +586,8 @@ class RecursiveFieldInspector(FieldInspector):
else:
class RecursiveFieldInspector(FieldInspector):
"""Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
def field_to_swagger_object(self, field, swagger_object_type, use_references, is_request=False,
is_response=False, **kwargs):
if isinstance(field, RecursiveField) and swagger_object_type == openapi.Schema:
assert use_references is True, "Can not create schema for RecursiveField when use_references is False"

Expand Down
8 changes: 4 additions & 4 deletions src/drf_yasg/inspectors/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_request_body_schema(self, serializer):
:param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
:rtype: openapi.Schema
"""
return self.serializer_to_schema(serializer)
return self.serializer_to_schema(serializer, is_request=True)

def make_body_parameter(self, schema):
"""Given a :class:`.Schema` object, create an ``in: body`` :class:`.Parameter`.
Expand Down Expand Up @@ -206,7 +206,7 @@ def get_default_responses(self):
if any(is_form_media_type(encoding) for encoding in self.get_consumes()):
default_schema = ''
if default_schema and not isinstance(default_schema, openapi.Schema):
default_schema = self.serializer_to_schema(default_schema) or ''
default_schema = self.serializer_to_schema(default_schema, is_response=True) or ''

if default_schema:
if is_list_view(self.path, self.method, self.view) and self.method.lower() == 'get':
Expand Down Expand Up @@ -254,7 +254,7 @@ def get_response_schemas(self, response_serializers):
response = serializer
if hasattr(response, 'schema') and not isinstance(response.schema, openapi.Schema.OR_REF):
serializer = force_serializer_instance(response.schema)
response.schema = self.serializer_to_schema(serializer)
response.schema = self.serializer_to_schema(serializer, is_response=True)
elif isinstance(serializer, openapi.Schema.OR_REF):
response = openapi.Response(
description='',
Expand All @@ -264,7 +264,7 @@ def get_response_schemas(self, response_serializers):
serializer = force_serializer_instance(serializer)
response = openapi.Response(
description='',
schema=self.serializer_to_schema(serializer),
schema=self.serializer_to_schema(serializer, is_response=True),
)

responses[str(sc)] = response
Expand Down