Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise dumping to reduce unnecessary overhead #1649

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,4 @@ Contributors (chronological)
- Ryan Morehart `@traherom <https://github.com/traherom>`_
- Ben Windsor `@bwindsor <https://github.com/bwindsor>`_
- Kevin Kirsche `@kkirsche <https://github.com/kkirsche>`_
- Dusko Simidzija `@dsimidzija <https://github.com/dsimidzija>`_
46 changes: 46 additions & 0 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
missing as missing_,
resolve_field_instance,
is_aware,
_get_value_for_key,
_get_value_for_keys,
)
from marshmallow.exceptions import (
ValidationError,
Expand Down Expand Up @@ -310,6 +312,50 @@ def _validate_missing(self, value):
if value is None and not self.allow_none:
raise self.make_error("null")

def get_serializer(
self,
attr: str,
accessor: typing.Optional[
typing.Callable[[typing.Any, str, typing.Any], typing.Any]
] = None,
**kwargs,
) -> typing.Callable[[typing.Any], typing.Any]:
"""Return an optimized serializer for this Field object.

:param str attr: The attribute or key on the object to be serialized.
:param dict kwargs: Field-specific keyword arguments.
:return: Serializer function.
"""
if not self._CHECK_ATTRIBUTE:
return lambda obj: self._serialize(None, attr, obj, **kwargs)

attribute = getattr(self, "attribute", None)
check_key = attr if attribute is None else attribute
dump_default = None
callable_default = False
has_default = hasattr(self, "dump_default")
if has_default:
dump_default = self.dump_default
callable_default = callable(dump_default)
if accessor:
accessor_func = accessor
else:
if not isinstance(check_key, int) and "." in check_key:
accessor_func = _get_value_for_keys
check_key = check_key.split(".")
else:
accessor_func = _get_value_for_key

def _serializer(obj):
value = accessor_func(obj, check_key, missing_)
if value is missing_ and has_default:
value = dump_default() if callable_default else dump_default
if value is missing_:
return value
return self._serialize(value, attr, obj, **kwargs)

return _serializer

def serialize(
self,
attr: str,
Expand Down
56 changes: 42 additions & 14 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,9 @@ def __init__(
self.fields = {} # type: typing.Dict[str, ma_fields.Field]
self.load_fields = {} # type: typing.Dict[str, ma_fields.Field]
self.dump_fields = {} # type: typing.Dict[str, ma_fields.Field]
self.dump_serializers = (
self.dict_class()
) # type: typing.Dict[str, typing.Callable]
self._init_fields()
messages = {}
messages.update(self._default_error_messages)
Expand Down Expand Up @@ -466,7 +469,7 @@ def handle_error(
"""
pass

def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):
def default_get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):
"""Defines how to pull values from an object to serialize.

.. versionadded:: 2.0.0
Expand All @@ -476,6 +479,8 @@ def get_attribute(self, obj: typing.Any, attr: str, default: typing.Any):
"""
return get_value(obj, attr, default)

get_attribute = default_get_attribute

##### Serialization/Deserialization API #####

@staticmethod
Expand Down Expand Up @@ -510,19 +515,41 @@ def _serialize(
.. versionchanged:: 1.0.0
Renamed from ``marshal``.
"""
if many and obj is not None:
return [
self._serialize(d, many=False)
for d in typing.cast(typing.Iterable[_T], obj)
]
ret = self.dict_class()
for attr_name, field_obj in self.dump_fields.items():
value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute)
if value is missing:
continue
key = field_obj.data_key if field_obj.data_key is not None else attr_name
ret[key] = value
return ret
if not self.dump_serializers:
accessor = (
None
) # type: typing.Optional[typing.Callable[[typing.Any, str, typing.Any], typing.Any]]
if self.get_attribute != self.default_get_attribute:
accessor = self.get_attribute

for field_name, field_obj in self.dump_fields.items():
key = (
field_obj.data_key if field_obj.data_key is not None else field_name
)
self.dump_serializers[key] = field_obj.get_serializer(
dsimidzija marked this conversation as resolved.
Show resolved Hide resolved
field_name, accessor
)

source_obj = [None] # typing: typing.List[typing.Any]

if not many:
source_obj = [typing.cast(typing.Any, obj)]
elif many and obj is not None:
source_obj = typing.cast(typing.List[typing.Any], obj)

output = []
for current_obj in source_obj:
ret = self.dict_class()
for key, serializer in self.dump_serializers.items():
value = serializer(current_obj)
if value is missing:
continue
ret[key] = value
output.append(ret)

if not many:
return output[0]
return output

def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None):
"""Serialize an object to native Python data types according to this
Expand Down Expand Up @@ -1015,6 +1042,7 @@ def _init_fields(self) -> None:
self.fields = fields_dict
self.dump_fields = dump_fields
self.load_fields = load_fields
self.dump_serializers = self.dict_class()

def on_bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:
"""Hook to modify a field when it is bound to the `Schema`.
Expand Down