Skip to content

Commit

Permalink
Merge pull request #87 from vintasoftware/fix/remove-custom-_get_seri…
Browse files Browse the repository at this point in the history
…alizer_class

Refactor strategy to handle `RecursionError` between `get_serializer_class`, `get_read_serializer`, and `get_write_serializer`
  • Loading branch information
pamella authored Jun 6, 2024
2 parents ff845ca + 181e22c commit fecd903
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 34 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ Change Log
Unreleased
~~~~~~~~~~

[1.4.0] - 2024-06-05
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fixed
_____
* Fix a regression in the `get_read_serializer_class` and `get_write_serializer_class`
methods to return `get_serializer_class` as default.

[1.3.0] - 2024-06-03
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fixed
Expand Down
10 changes: 10 additions & 0 deletions docs/cross_library_integrations.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
==========================
Cross-Library Integrations
==========================

drf-spectacular
---------------

If your project is using both `drf-rw-serializers` and `drf-spectacular`, there
are specific configurations to be made. Detailed steps for this integration are
provided in the `drf-spectacular documentation <https://drf-spectacular.readthedocs.io/en/latest/blueprints.html#drf-rw-serializers>`_.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Contents:
readme
installation
usage
cross_library_integrations
contributing
authors
changelog
2 changes: 1 addition & 1 deletion drf_rw_serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import absolute_import, unicode_literals

__version__ = "1.3.0"
__version__ = "1.4.0"

# pylint: disable=invalid-name
default_app_config = "drf_rw_serializers.apps.DrfRwSerializersConfig"
Expand Down
57 changes: 31 additions & 26 deletions drf_rw_serializers/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,19 @@


class GenericAPIView(generics.GenericAPIView):
def _get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
assert (
self.serializer_class is not None
or getattr(self, "read_serializer_class", None) is not None
), (
"'%s' should either include one of `serializer_class` and `read_serializer_class` "
"attribute, or override one of the `get_serializer_class()`, "
"`get_read_serializer_class()` method." % self.__class__.__name__
)

return self.serializer_class

def get_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
If the request method is GET, it tries to use `self.read_serializer_class`.
If the request method is not GET, it tries to use `self.write_serializer_class`.
If the specific serializer class for the request method is not set, it falls back to
`self.serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
if hasattr(self, "request"):
Expand All @@ -52,7 +36,8 @@ def get_serializer_class(self):
"attribute, or override the `get_read_serializer_class()` or "
"`get_serializer_class()` method." % self.__class__.__name__
)
return self.get_read_serializer_class()
# `default_to_serializer_class` is used to prevent a `RecursionError`
return self.get_read_serializer_class(default_to_serializer_class=True)

if self.request.method in ["POST", "PUT", "PATCH", "DELETE"]:
assert (
Expand All @@ -63,9 +48,19 @@ def get_serializer_class(self):
"attribute, or override the `get_write_serializer_class()` or "
"`get_serializer_class()` method." % self.__class__.__name__
)
return self.get_write_serializer_class()
# `default_to_serializer_class` is used to prevent a `RecursionError`
return self.get_write_serializer_class(default_to_serializer_class=True)

assert (
self.serializer_class is not None
or getattr(self, "read_serializer_class", None) is not None
), (
"'%s' should either include one of `serializer_class` and `read_serializer_class` "
"attribute, or override one of the `get_serializer_class()`, "
"`get_read_serializer_class()` method." % self.__class__.__name__
)

return self._get_serializer_class()
return self.serializer_class

def get_read_serializer(self, *args, **kwargs):
"""
Expand All @@ -75,16 +70,21 @@ def get_read_serializer(self, *args, **kwargs):
kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)

def get_read_serializer_class(self):
def get_read_serializer_class(self, default_to_serializer_class: bool = False):
"""
Return the class to use for the serializer.
Defaults to using `self.read_serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins get full serialization, others get basic serialization)
"""
if getattr(self, "read_serializer_class", None) is None:
return self._get_serializer_class()
if default_to_serializer_class:
return self.serializer_class

return self.get_serializer_class()

return self.read_serializer_class

Expand All @@ -97,16 +97,21 @@ def get_write_serializer(self, *args, **kwargs):
kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)

def get_write_serializer_class(self):
def get_write_serializer_class(self, default_to_serializer_class: bool = False):
"""
Return the class to use for the serializer.
Defaults to using `self.write_serializer_class`.
You may want to override this if you need to provide different
serializations depending on the incoming request.
(Eg. admins can send extra fields, others cannot)
"""
if getattr(self, "write_serializer_class", None) is None:
return self._get_serializer_class()
if default_to_serializer_class:
return self.serializer_class

return self.get_serializer_class()

return self.write_serializer_class

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "drf-rw-serializers"
version = "1.3.0"
version = "1.4.0"
description = "Generic views, viewsets and mixins that extend the Django REST Framework ones adding separated serializers for read and write operations"
authors = ["Vinta Software <[email protected]>"]
license = "MIT"
Expand Down
2 changes: 2 additions & 0 deletions test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ def root(*args):
ROOT_URLCONF = "example_app.urls"

SECRET_KEY = "insecure-secret-key"

DEFAULT_AUTO_FIELD = "django.db.models.AutoField"
96 changes: 90 additions & 6 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,21 @@ def test_serializer_class_not_provided(self):
),
)

def test_get_serializer_class_override_provided(self):
class GetSerializerClassView(generics.GenericAPIView):
def get_serializer_class(self):
return OrderedMealDetailsSerializer

self.assertEqual(
GetSerializerClassView().get_serializer_class(), OrderedMealDetailsSerializer
)
self.assertEqual(
GetSerializerClassView().get_read_serializer_class(), OrderedMealDetailsSerializer
)
self.assertEqual(
GetSerializerClassView().get_write_serializer_class(), OrderedMealDetailsSerializer
)

def test_no_request_provided(self):
# Return serializer_class over read_serializer_class and write_serializer_class
self.assertEqual(
Expand Down Expand Up @@ -108,18 +123,37 @@ def test_non_read_write_request_method_provided(self):
self.FullSerializerView().get_serializer_class(), OrderedMealDetailsSerializer
)

def test_all_get_serializer_class_override_provided(self):
class GetSerializerClassView(generics.GenericAPIView):
def get_serializer_class(self):
return OrderedMealDetailsSerializer

def get_read_serializer_class(self, default_to_serializer_class: bool = False):
return OrderListSerializer

def get_write_serializer_class(self, default_to_serializer_class: bool = False):
return OrderCreateSerializer

self.assertEqual(
GetSerializerClassView().get_serializer_class(), OrderedMealDetailsSerializer
)
self.assertEqual(GetSerializerClassView().get_read_serializer_class(), OrderListSerializer)
self.assertEqual(
GetSerializerClassView().get_write_serializer_class(), OrderCreateSerializer
)


class GenericAPIViewGetReadSerializerClassTests(BaseTestCase):
def test_read_serializer_class_not_provided(self):
class NoReadSerializerView(generics.GenericAPIView):
pass

with mock.patch.object(
NoReadSerializerView, "_get_serializer_class"
) as mock__get_serializer_class:
NoReadSerializerView, "get_serializer_class"
) as mock_get_serializer_class:
NoReadSerializerView().get_read_serializer_class()

mock__get_serializer_class.assert_called_once()
mock_get_serializer_class.assert_called_once()

def test_read_serializer_class_provided(self):
class ReadSerializerClassProvided(generics.GenericAPIView):
Expand All @@ -130,18 +164,43 @@ class ReadSerializerClassProvided(generics.GenericAPIView):
OrderListSerializer,
)

def test_use_serializer_class_fallback(self):
class SerializerClassView(generics.GenericAPIView):
serializer_class = OrderedMealDetailsSerializer

self.assertEqual(
SerializerClassView().get_read_serializer_class(default_to_serializer_class=True),
OrderedMealDetailsSerializer,
)

with mock.patch.object(
SerializerClassView, "get_serializer_class"
) as mock_get_serializer_class:
SerializerClassView().get_read_serializer_class(default_to_serializer_class=False)

mock_get_serializer_class.assert_called_once()

def test_get_read_serializer_class_override_provided(self):
class GetReadSerializerClassView(generics.GenericAPIView):
def get_read_serializer_class(self, default_to_serializer_class: bool = False):
return OrderListSerializer

self.assertEqual(
GetReadSerializerClassView().get_read_serializer_class(), OrderListSerializer
)


class GenericAPIViewGetWriteSerializerClassTests(BaseTestCase):
def test_write_serializer_class_not_provided(self):
class NoWriteSerializerView(generics.GenericAPIView):
pass

with mock.patch.object(
NoWriteSerializerView, "_get_serializer_class"
) as mock__get_serializer_class:
NoWriteSerializerView, "get_serializer_class"
) as mock_get_serializer_class:
NoWriteSerializerView().get_write_serializer_class()

mock__get_serializer_class.assert_called_once()
mock_get_serializer_class.assert_called_once()

def test_write_serializer_class_provided(self):
class WriteSerializerClassProvided(generics.GenericAPIView):
Expand All @@ -152,6 +211,31 @@ class WriteSerializerClassProvided(generics.GenericAPIView):
OrderCreateSerializer,
)

def test_use_serializer_class_fallback(self):
class SerializerClassView(generics.GenericAPIView):
serializer_class = OrderedMealDetailsSerializer

self.assertEqual(
SerializerClassView().get_write_serializer_class(default_to_serializer_class=True),
OrderedMealDetailsSerializer,
)

with mock.patch.object(
SerializerClassView, "get_serializer_class"
) as mock_get_serializer_class:
SerializerClassView().get_write_serializer_class(default_to_serializer_class=False)

mock_get_serializer_class.assert_called_once()

def test_get_write_serializer_class_override_provided(self):
class GetWriteSerializerClassView(generics.GenericAPIView):
def get_write_serializer_class(self, default_to_serializer_class: bool = False):
return OrderCreateSerializer

self.assertEqual(
GetWriteSerializerClassView().get_write_serializer_class(), OrderCreateSerializer
)


class OrderListCreateEndpointTests(BaseTestCase, TestListRequestSuccess, TestCreateRequestSuccess):
def setUp(self):
Expand Down

0 comments on commit fecd903

Please sign in to comment.