Skip to content

Commit

Permalink
make_init method to create an entire __init__ method with control…
Browse files Browse the repository at this point in the history
… of which fields are injected. Fixes #14
  • Loading branch information
Sylvain MARIE committed Sep 26, 2019
1 parent 5c9b986 commit 6129bcd
Show file tree
Hide file tree
Showing 5 changed files with 631 additions and 239 deletions.
7 changes: 4 additions & 3 deletions pyfields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .core import field, inject_fields, MandatoryFieldInitError, UnsupportedOnNativeFieldError
from .core import field, MandatoryFieldInitError, UnsupportedOnNativeFieldError
from .init_makers import inject_fields, make_init

try:
# Distribution mode : import from _version.py generated by setuptools_scm during release
Expand All @@ -12,7 +13,7 @@
__all__ = [
'__version__',
# submodules
'core',
'core', 'init_makers',
# symbols
'field', 'inject_fields', 'MandatoryFieldInitError', 'UnsupportedOnNativeFieldError'
'field', 'inject_fields', 'make_init', 'MandatoryFieldInitError', 'UnsupportedOnNativeFieldError'
]
231 changes: 15 additions & 216 deletions pyfields/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import sys
from textwrap import dedent

from inspect import isfunction
from inspect import getmro
try:
from inspect import signature, Parameter
except ImportError:
from funcsigs import signature, Parameter

from makefun import with_signature, wraps
from makefun import with_signature
import sentinel

from valid8 import Validator, failure_raiser, ValidationError
Expand Down Expand Up @@ -628,219 +628,6 @@ def __delete__(self, obj):
delattr(obj, "_" + self.name)


def inject_fields(*fields # type: Field
):
"""
A decorator for `__init__` methods, to make them automatically expose arguments corresponding to all `*fields`.
It can be used with or without arguments. If the list of fields is empty, it means "all fields from the class".
The decorated `__init__` method should have an argument named `'fields'`. This argument will be injected with an
object so that users can manually execute the fields initialization. This is done with `fields.init()`.
>>> import sys, pytest
>>> if sys.version_info < (3, 6):
... pytest.skip('doctest skipped')
>>> class Wall(object):
... height = field(doc="Height of the wall in mm.")
... color = field(default='white', doc="Color of the wall.")
...
... @inject_fields(height, color)
... def __init__(self, fields):
... # initialize all fields received
... fields.init(self)
...
... def __repr__(self):
... return "Wall<height=%r, color=%r>" % (self.height, self.color)
...
>>> Wall()
Traceback (most recent call last):
...
TypeError: __init__() missing 1 required positional argument: 'height'
>>> Wall(1)
Wall<height=1, color='white'>
:param fields:
:return:
"""
if len(fields) == 1:
init_fun_candidate = fields[0]
if isfunction(init_fun_candidate):
# called without arguments: return the modified init function
return _apply_inject_fields(init_fun_candidate)

# called with arguments: return a decorator
return lambda init_fun: _apply_inject_fields(init_fun, fields)


class InjectedInitDescriptor(object):
"""
A class member descriptor for the __init__ method that we create with `@inject_fields`.
The first time people access `cls.__init__`, the actual method will be created and injected in the class.
This descriptor will then disappear and the class will behave normally.
Inspired by https://stackoverflow.com/a/3412743/7262247
"""
__slots__ = 'init_fun', 'fields'

def __init__(self, init_fun, fields=None):
self.init_fun = init_fun
self.fields = fields

# not useful and may slow things down anyway
# def __set_name__(self, owner, name):
# if name != '__init__':
# raise ValueError("this should not happen")

def __get__(self, obj, objtype):
if objtype is not None:
# <objtype>.__init__ has been accessed. Create the modified init
fields = self.fields
if fields is None:
fields = collect_all_fields(objtype, auto_set_names=not PY36)
elif not PY36:
# take this opportunity to apply all field names
collect_all_fields(objtype, include_inherited=False, auto_set_names=True)

new_init = create_init_function(self.init_fun, fields)

# replace it forever in the class
setattr(objtype, '__init__', new_init)

# return the new init
return new_init.__get__(obj, objtype)


_apply_inject_fields = InjectedInitDescriptor
# def _apply_inject_fields(init_fun, fields=None):
# """
# A decorator for the `__init__` function
#
# :param init_fun:
# :return:
# """
# if fields is None or len(fields) == 0:
# return InjectedInitDescriptor(init_fun)
# else:
# # explicit list of fields
# # Note: we can not return it directly because the name might not be available yet ! TODO even in python 3.6?
# # return create_init_function(init_fun, fields)
# return InjectedInitDescriptor(init_fun)


def create_init_function(init_fun,
fields # type: Iterable[Field]
):
"""
Creates the new init function that will replace `init_fun`.
:param init_fun:
:param fields:
:return:
"""
# read the existing signature of __init__
init_sig = signature(init_fun)
params = list(init_sig.parameters.values())

# find the index of the 'fields' parameter
for i, p in enumerate(params):
if p.name == 'fields':
# found
break
else:
# 'fields' not found: raise an error
try:
name = init_fun.__qualname__
except AttributeError:
name = init_fun.__name__
raise ValueError("Error applying `@inject_fields` on `%s%s`: "
"no 'fields' argument is available in the signature." % (name, init_sig))

# remove the fields parameter
del params[i]

# inject in the same position, all fields that should be included
# Note: preserve order as much as possible, but automatically place all mandatory fields first so that the
# signature is valid.
field_names = []
last_mandatory_idx = i
for _field in reversed(fields):
# Is this field optional ?
if _field.is_mandatory:
# mandatory
where_to_insert = i
last_mandatory_idx += 1
default = Parameter.empty
elif _field.is_default_factory:
# optional with a default value factory: place a specific symbol in the signature to indicate it
default = USE_FACTORY
where_to_insert = last_mandatory_idx
else:
# optional with a default value
default = _field.default
where_to_insert = last_mandatory_idx

# Are there annotations on the field ?
annotation = _field.annotation if _field.annotation is not EMPTY else Parameter.empty

# remember the list of field names for later use
field_names.append(_field.name)

# finally inject the new parameter in the signature
new_param = Parameter(_field.name, kind=Parameter.POSITIONAL_OR_KEYWORD, default=default, annotation=annotation)
params.insert(where_to_insert, new_param)

# finally replace the signature with the newly created one
new_sig = init_sig.replace(parameters=params)

# and create the new init method
@wraps(init_fun, new_sig=new_sig)
def init_fun_mod(*args, **kwargs):
"""
The `__init__` method generated for you when you use `@inject_fields` on your `__init__`
"""
# 1. remove all field values received from the outer signature
_fields = dict()
for f_name in field_names:
_fields[f_name] = kwargs.pop(f_name)

# 2. inject our special variable
kwargs['fields'] = FieldsForInit(**_fields)

# 3. call your __init__ method
return init_fun(*args, **kwargs)

return init_fun_mod


class FieldsForInit(object):
"""
The object that is injected in the users' `__init__` method as the `fields` argument,
when it has been decorated with `@inject_fields`.
All field values received from the generated `__init__` are available in `self.field_values`, and
a `init()` method allows users to perform the initialization per se.
"""
__slots__ = 'field_values'

def __init__(self, **init_field_values):
self.field_values = init_field_values

def init(self, obj):
"""
Initializes all fields on the provided object
:param obj:
:return:
"""
for field_name, field_value in self.field_values.items():
if field_value is not USE_FACTORY:
# init the field with the provided value or the injected default value
setattr(obj, field_name, field_value)
else:
# init the field with its factory
getattr(obj, field_name, field_value)


def collect_all_fields(cls,
include_inherited=True,
auto_set_names=False):
Expand All @@ -856,7 +643,7 @@ def collect_all_fields(cls,
"""
result = []
if include_inherited:
where = dir(cls)
where = ordereddir(cls)
else:
where = vars(cls)

Expand All @@ -879,6 +666,18 @@ def collect_all_fields(cls,
return result


def ordereddir(cls):
"""
since `dir` does not preserve order, lets have our own implementation
:param cls:
:return:
"""
for parent in getmro(cls):
for k in vars(parent):
yield k


def fix_field_names(cls):
"""
Fixes all field names at once on the given class
Expand Down
Loading

0 comments on commit 6129bcd

Please sign in to comment.