From 8b97377e115ba15ace382af48a69f0c40b120c7e Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Tue, 30 May 2023 18:36:13 +0200 Subject: [PATCH] Issue #195 #196 use official spec of `text_concat` also support custom spec in `add_simple_function` --- openeo_driver/ProcessGraphDeserializer.py | 11 +++-------- openeo_driver/processes.py | 13 +++++++++---- tests/test_processes.py | 21 +++++++++++++++++++-- tests/test_views_execute.py | 1 + 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/openeo_driver/ProcessGraphDeserializer.py b/openeo_driver/ProcessGraphDeserializer.py index 36739285..eab4bbde 100644 --- a/openeo_driver/ProcessGraphDeserializer.py +++ b/openeo_driver/ProcessGraphDeserializer.py @@ -2068,16 +2068,11 @@ def text_merge( return str(separator).join(str(d) for d in data) -# TODO #195 #196 use official spec instead of custom openeo-processes/experimental/text_concat.json -@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/text_concat.json")) +@process_registry_100.add_simple_function(spec=read_spec("openeo-processes/2.x/text_concat.json")) def text_concat( - args: Dict, - env: EvalEnv - # data: List[Union[str, int, float, bool, None]], - # separator: Union[str, int, float, bool, None] = "" + data: List[Union[str, int, float, bool, None]], + separator: str = "", ) -> str: - data = extract_arg(args, "data") - separator = args.get("separator", "") return str(separator).join(str(d) for d in data) diff --git a/openeo_driver/processes.py b/openeo_driver/processes.py index d6081fc7..4a318e35 100644 --- a/openeo_driver/processes.py +++ b/openeo_driver/processes.py @@ -178,17 +178,20 @@ def add_function( ) return f - def add_simple_function(self, f: Callable = None, name: str = None): + def add_simple_function( + self, f: Optional[Callable] = None, name: Optional[str] = None, spec: Optional[dict] = None + ): """ Register a simple function that uses normal arguments instead of `args: dict, env: EvalEnv`: wrap it in a wrapper that automatically extracts these arguments :param f: :param name: process_id (when guessing from `f.__name__` doesn't work) + :param spec: optional spec dict :return: """ if f is None: # Called as parameterized decorator - return functools.partial(self.add_simple_function, name=name) + return functools.partial(self.add_simple_function, name=name, spec=spec) process_id = name or f.__name__ # Detect arguments without and with defaults @@ -201,18 +204,20 @@ def add_simple_function(self, f: Callable = None, name: str = None): else: defaults[param.name] = param.default - # TODO: avoid this local import, e.g. by encapsulating all extrac_ functions in some kind of ProcessArgs object + # TODO: avoid this local import, e.g. by encapsulating all extract_ functions in some kind of ProcessArgs object from openeo_driver.ProcessGraphDeserializer import extract_arg # TODO: can we generalize this assumption? assert self._argument_names == ["args", "env"] + # TODO: option to also pass `env: EvalEnv` to `f`? def wrapped(args: dict, env: EvalEnv): kwargs = {a: extract_arg(args, a, process_id=process_id) for a in required} kwargs.update({a: args.get(a, d) for a, d in defaults.items()}) return f(**kwargs) - self.add_process(name=process_id, function=wrapped, spec=self.load_predefined_spec(process_id)) + spec = spec or self.load_predefined_spec(process_id) + self.add_process(name=process_id, function=wrapped, spec=spec) return f def add_hidden(self, f: Callable, name: str = None, namespace: str = DEFAULT_NAMESPACE): diff --git a/tests/test_processes.py b/tests/test_processes.py index da8cd259..f55627de 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -392,7 +392,7 @@ def add(x: int, y: int = 100): assert process(args={"x": 2, "y": 3}, env=None) == 5 assert process(args={"x": 2}, env=None) == 102 - with pytest.raises(ProcessParameterRequiredException): + with pytest.raises(ProcessParameterRequiredException, match="Process 'add' parameter 'x' is required."): _ = process(args={}, env=None) @@ -408,5 +408,22 @@ def if_(value, accept, reject=None): assert process(args={"value": True, "accept": 3}, env=None) == 3 assert process(args={"value": False, "accept": 3}, env=None) is None assert process(args={"value": False, "accept": 3, "reject": 5}, env=None) == 5 - with pytest.raises(ProcessParameterRequiredException): + with pytest.raises(ProcessParameterRequiredException, match="Process 'if' parameter 'value' is required."): + _ = process(args={}, env=None) + + +def test_process_registry_add_simple_function_with_spec(): + reg = ProcessRegistry(argument_names=["args", "env"]) + + @reg.add_simple_function(spec={"id": "something_custom"}) + def something_custom(x: int, y: int = 123): + return x + y + + process = reg.get_function("something_custom") + + assert process(args={"x": 5, "y": 3}, env=None) == 8 + assert process(args={"x": 5}, env=None) == 128 + with pytest.raises( + ProcessParameterRequiredException, match="Process 'something_custom' parameter 'x' is required." + ): _ = process(args={}, env=None) diff --git a/tests/test_views_execute.py b/tests/test_views_execute.py index b9790379..3417e6ca 100644 --- a/tests/test_views_execute.py +++ b/tests/test_views_execute.py @@ -2449,6 +2449,7 @@ def test_execute_no_cube_logic(api100, process_graph, expected): ("text_ends", {"data": "FooBar", "pattern": "Foo"}, False), ("text_ends", {"data": "FooBar", "pattern": "bar"}, False), ("text_ends", {"data": "FooBar", "pattern": "bar", "case_sensitive": False}, True), + # TODO: `text_merge` is deprecated (in favor of `text_concat`) ("text_merge", {"data": ["foo", "bar"]}, "foobar"), ("text_merge", {"data": ["foo", "bar"], "separator": "--"}, "foo--bar"), ("text_merge", {"data": [1, 2, 3], "separator": "/"}, "1/2/3"),