Skip to content

Commit

Permalink
Make "propagate_unknown" behavior explicit
Browse files Browse the repository at this point in the history
This adds a new option, `propagate_unknown` which defaults to False.

It controls the behavior of `unknown` with respect to nested
structures. Anywhere that `unknown` can be set, `propagate_unknown`
can be set. That means it can be applied to a schema instance, a load
call, schema.Meta, or to fields.Nested .

When set, nested deserialize calls will get the same value for
`unknown` which their parent call got and they will receive
`propagate_unknown=True`.

The new flag is completely opt-in and therefore backwards
compatible with any current usages of marshmallow.
Once you opt in to this behavior on a schema, you don't need to
worry about making sure it's set by nested schemas that you use.

In the name of simplicity, this sacrifices a bit of flexibility. A
schema with `propagate_unknown=True, unknown=...` will override the
`unknown` settings on any of its child schemas.

Tests cover usage as a schema instantiation arg and as a load arg
for some simple data structures.
  • Loading branch information
sirosen committed Jul 17, 2020
1 parent 9f8b7ed commit 61d3057
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 55 deletions.
19 changes: 17 additions & 2 deletions src/marshmallow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,25 @@ def dump(self, obj, *, many: bool = None):
def dumps(self, obj, *, many: bool = None):
raise NotImplementedError

def load(self, data, *, many: bool = None, partial=None, unknown=None):
def load(
self,
data,
*,
many: bool = None,
partial=None,
unknown=None,
propagate_unknown=None
):
raise NotImplementedError

def loads(
self, json_data, *, many: bool = None, partial=None, unknown=None, **kwargs
self,
json_data,
*,
many: bool = None,
partial=None,
unknown=None,
propagate_unknown=None,
**kwargs
):
raise NotImplementedError
30 changes: 26 additions & 4 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ def __init__(
exclude: types.StrSequenceOrSet = (),
many: bool = False,
unknown: str = None,
propagate_unknown: bool = None,
**kwargs
):
# Raise error if only or exclude is passed as string, not list of strings
Expand All @@ -494,6 +495,7 @@ def __init__(
self.exclude = exclude
self.many = many
self.unknown = unknown
self.propagate_unknown = propagate_unknown
self._schema = None # Cached Schema instance
super().__init__(default=default, **kwargs)

Expand Down Expand Up @@ -571,18 +573,32 @@ def _test_collection(self, value):
if many and not utils.is_collection(value):
raise self.make_error("type", input=value, type=value.__class__.__name__)

def _load(self, value, data, partial=None, unknown=None):
def _load(self, value, data, partial=None, unknown=None, propagate_unknown=None):
try:
valid_data = self.schema.load(
value, unknown=unknown or self.unknown, partial=partial,
value,
unknown=unknown or self.unknown,
propagate_unknown=propagate_unknown
if propagate_unknown is not None
else self.propagate_unknown,
partial=partial,
)
except ValidationError as error:
raise ValidationError(
error.messages, valid_data=error.valid_data
) from error
return valid_data

def _deserialize(self, value, attr, data, partial=None, unknown=None, **kwargs):
def _deserialize(
self,
value,
attr,
data,
partial=None,
unknown=None,
propagate_unknown=None,
**kwargs
):
"""Same as :meth:`Field._deserialize` with additional ``partial`` argument.
:param bool|tuple partial: For nested schemas, the ``partial``
Expand All @@ -592,7 +608,13 @@ def _deserialize(self, value, attr, data, partial=None, unknown=None, **kwargs):
Add ``partial`` parameter.
"""
self._test_collection(value)
return self._load(value, data, partial=partial, unknown=unknown)
return self._load(
value,
data,
partial=partial,
unknown=unknown,
propagate_unknown=propagate_unknown,
)


class Pluck(Nested):
Expand Down
40 changes: 29 additions & 11 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def __init__(self, meta, ordered: bool = False):
self.load_only = getattr(meta, "load_only", ())
self.dump_only = getattr(meta, "dump_only", ())
self.unknown = getattr(meta, "unknown", RAISE)
self.propagate_unknown = getattr(meta, "propagate_unknown", False)
self.register = getattr(meta, "register", True)


Expand Down Expand Up @@ -372,7 +373,8 @@ def __init__(
load_only: types.StrSequenceOrSet = (),
dump_only: types.StrSequenceOrSet = (),
partial: typing.Union[bool, types.StrSequenceOrSet] = False,
unknown: str = None
unknown: str = None,
propagate_unknown: bool = None
):
# Raise error if only or exclude is passed as string, not list of strings
if only is not None and not is_collection(only):
Expand All @@ -389,6 +391,7 @@ def __init__(
self.dump_only = set(dump_only) or set(self.opts.dump_only)
self.partial = partial
self.unknown = unknown or self.opts.unknown
self.propagate_unknown = propagate_unknown or self.opts.propagate_unknown
self.context = context or {}
self._normalize_nested_options()
#: Dictionary mapping field_names -> :class:`Field` objects
Expand Down Expand Up @@ -592,6 +595,7 @@ def _deserialize(
many: bool = False,
partial=False,
unknown=RAISE,
propagate_unknown=False,
index=None
) -> typing.Union[_T, typing.List[_T]]:
"""Deserialize ``data``.
Expand Down Expand Up @@ -625,6 +629,7 @@ def _deserialize(
many=False,
partial=partial,
unknown=unknown,
propagate_unknown=propagate_unknown,
index=idx,
),
)
Expand All @@ -648,7 +653,7 @@ def _deserialize(
partial_is_collection and attr_name in partial
):
continue
d_kwargs = {}
d_kwargs = {} # type: typing.Dict[str, typing.Any]
# Allow partial loading of nested schemas.
if partial_is_collection:
prefix = field_name + "."
Expand All @@ -660,11 +665,9 @@ def _deserialize(
else:
d_kwargs["partial"] = partial

try:
if self.context["propagate_unknown_to_nested"]:
d_kwargs["unknown"] = unknown
except KeyError:
pass
if propagate_unknown:
d_kwargs["unknown"] = unknown
d_kwargs["propagate_unknown"] = True

getter = lambda val: field_obj.deserialize(
val, field_name, data, **d_kwargs
Expand Down Expand Up @@ -705,7 +708,8 @@ def load(
*,
many: bool = None,
partial: typing.Union[bool, types.StrSequenceOrSet] = None,
unknown: str = None
unknown: str = None,
propagate_unknown: bool = None
):
"""Deserialize a data structure to an object defined by this Schema's fields.
Expand All @@ -728,7 +732,12 @@ def load(
if invalid data are passed.
"""
return self._do_load(
data, many=many, partial=partial, unknown=unknown, postprocess=True
data,
many=many,
partial=partial,
unknown=unknown,
propagate_unknown=propagate_unknown,
postprocess=True,
)

def loads(
Expand All @@ -738,6 +747,7 @@ def loads(
many: bool = None,
partial: typing.Union[bool, types.StrSequenceOrSet] = None,
unknown: str = None,
propagate_unknown: bool = None,
**kwargs
):
"""Same as :meth:`load`, except it takes a JSON string as input.
Expand All @@ -761,7 +771,13 @@ def loads(
if invalid data are passed.
"""
data = self.opts.render_module.loads(json_data, **kwargs)
return self.load(data, many=many, partial=partial, unknown=unknown)
return self.load(
data,
many=many,
partial=partial,
unknown=unknown,
propagate_unknown=propagate_unknown,
)

def _run_validator(
self,
Expand Down Expand Up @@ -822,6 +838,7 @@ def _do_load(
many: bool = None,
partial: typing.Union[bool, types.StrSequenceOrSet] = None,
unknown: str = None,
propagate_unknown: bool = None,
postprocess: bool = True
):
"""Deserialize `data`, returning the deserialized result.
Expand All @@ -843,8 +860,8 @@ def _do_load(
error_store = ErrorStore()
errors = {} # type: typing.Dict[str, typing.List[str]]
many = self.many if many is None else bool(many)
self.context["propagate_unknown_to_nested"] = unknown is not None
unknown = unknown or self.unknown
propagate_unknown = propagate_unknown or self.propagate_unknown
if partial is None:
partial = self.partial
# Run preprocessors
Expand All @@ -868,6 +885,7 @@ def _do_load(
many=many,
partial=partial,
unknown=unknown,
propagate_unknown=propagate_unknown,
)
# Run field-level validation
self._invoke_field_validators(
Expand Down
90 changes: 52 additions & 38 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,56 +308,70 @@ class ShelfSchema(Schema):

@pytest.fixture
def data_nested_unknown(self):
return {
"spam": {"meat": "pork", "add-on": "eggs"},
}
return {"spam": {"meat": "pork", "add-on": "eggs"}}

@pytest.fixture
def multi_nested_data_with_unknown(self, data_nested_unknown):
return {
"can": data_nested_unknown,
return {"can": data_nested_unknown, "box": {"foo": "bar"}}

@pytest.mark.parametrize(
"schema_kwargs,load_kwargs",
[
({}, {"propagate_unknown": True, "unknown": INCLUDE}),
({"propagate_unknown": True}, {"unknown": INCLUDE}),
({"propagate_unknown": True, "unknown": INCLUDE}, {}),
({"unknown": INCLUDE}, {"propagate_unknown": True}),
],
)
def test_propagate_unknown_include(
self,
schema_kwargs,
load_kwargs,
data_nested_unknown,
multi_nested_data_with_unknown,
):
data = self.ShelfSchema(**schema_kwargs).load(
multi_nested_data_with_unknown, **load_kwargs
)
assert data == {
"can": {"spam": {"meat": "pork", "add-on": "eggs"}},
"box": {"foo": "bar"},
}

@pytest.mark.parametrize("schema_kw", ({}, {"unknown": INCLUDE}))
def test_raises_when_unknown_passed_to_first_level_nested(
self, schema_kw, data_nested_unknown,
):
with pytest.raises(ValidationError) as exc_info:
self.CanSchema(**schema_kw).load(data_nested_unknown)
assert exc_info.value.messages == {"spam": {"add-on": ["Unknown field."]}}
data = self.CanSchema(**schema_kwargs).load(data_nested_unknown, **load_kwargs)
assert data == {"spam": {"meat": "pork", "add-on": "eggs"}}

@pytest.mark.parametrize(
"load_kw,expected_data",
(
({"unknown": INCLUDE}, {"spam": {"meat": "pork", "add-on": "eggs"}}),
({"unknown": EXCLUDE}, {"spam": {"meat": "pork"}}),
),
"schema_kwargs,load_kwargs",
[
({}, {"propagate_unknown": True, "unknown": EXCLUDE}),
({"propagate_unknown": True}, {"unknown": EXCLUDE}),
({"propagate_unknown": True, "unknown": EXCLUDE}, {}),
({"unknown": EXCLUDE}, {"propagate_unknown": True}),
],
)
def test_processes_when_unknown_stated_directly(
self, load_kw, data_nested_unknown, expected_data,
def test_propagate_unknown_exclude(
self,
schema_kwargs,
load_kwargs,
data_nested_unknown,
multi_nested_data_with_unknown,
):
data = self.CanSchema().load(data_nested_unknown, **load_kw)
assert data == expected_data
data = self.ShelfSchema(**schema_kwargs).load(
multi_nested_data_with_unknown, **load_kwargs
)
assert data == {"can": {"spam": {"meat": "pork"}}}

@pytest.mark.parametrize(
"load_kw,expected_data",
(
(
{"unknown": INCLUDE},
{
"can": {"spam": {"meat": "pork", "add-on": "eggs"}},
"box": {"foo": "bar"},
},
),
({"unknown": EXCLUDE}, {"can": {"spam": {"meat": "pork"}}}),
),
)
def test_propagates_unknown_to_multi_nested_fields(
self, load_kw, expected_data, multi_nested_data_with_unknown,
data = self.CanSchema(**schema_kwargs).load(data_nested_unknown, **load_kwargs)
assert data == {"spam": {"meat": "pork"}}

@pytest.mark.parametrize("schema_kw", ({}, {"unknown": INCLUDE}))
def test_raises_when_unknown_passed_to_first_level_nested(
self, schema_kw, data_nested_unknown
):
data = self.ShelfSchema().load(multi_nested_data_with_unknown, **load_kw)
assert data == expected_data
with pytest.raises(ValidationError) as exc_info:
self.CanSchema(**schema_kw).load(data_nested_unknown)
assert exc_info.value.messages == {"spam": {"add-on": ["Unknown field."]}}


class TestListNested:
Expand Down

0 comments on commit 61d3057

Please sign in to comment.