From 85089c3c1d4ba7b700ef74d9b9581346c5cd826a Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Tue, 10 Oct 2023 18:18:48 +0200 Subject: [PATCH] Issue #195 use openeo-processes 2.0 based process reg for 1.2 api requests --- openeo_driver/ProcessGraphDeserializer.py | 197 +++++++++++++++------- openeo_driver/processes.py | 51 ++++-- tests/test_processes.py | 37 ++++ 3 files changed, 213 insertions(+), 72 deletions(-) diff --git a/openeo_driver/ProcessGraphDeserializer.py b/openeo_driver/ProcessGraphDeserializer.py index 33dfeeb9..8e1710fb 100644 --- a/openeo_driver/ProcessGraphDeserializer.py +++ b/openeo_driver/ProcessGraphDeserializer.py @@ -11,7 +11,7 @@ import time import warnings from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, Sequence import geopandas as gpd import numpy as np @@ -55,8 +55,15 @@ CollectionNotFoundException, ) from openeo_driver.processes import ProcessRegistry, ProcessSpec, DEFAULT_NAMESPACE, ProcessArgs -from openeo_driver.save_result import JSONResult, SaveResult, AggregatePolygonResult, NullResult, \ - to_save_result, AggregatePolygonSpatialResult, MlModelResult +from openeo_driver.save_result import ( + JSONResult, + SaveResult, + AggregatePolygonResult, + NullResult, + to_save_result, + AggregatePolygonSpatialResult, + MlModelResult, +) from openeo_driver.specs import SPECS_ROOT, read_spec from openeo_driver.util.date_math import month_shift from openeo_driver.util.geometry import geojson_to_geometry, geojson_to_multipolygon, spatial_extent_union @@ -66,8 +73,12 @@ _log = logging.getLogger(__name__) # Set up process registries (version dependent) -process_registry_100 = ProcessRegistry(spec_root=SPECS_ROOT / 'openeo-processes/1.x', argument_names=["args", "env"]) -# TODO #195 support openeo-processes 2.x + +# Process registry based on 1.x version of openeo-processes, to be used with api_version 1.0 and 1.1 
+process_registry_100 = ProcessRegistry(spec_root=SPECS_ROOT / "openeo-processes/1.x", argument_names=["args", "env"]) + +# Process registry based on 2.x version of openeo-processes, to be used starting with api_version 1.2 +process_registry_2xx = ProcessRegistry(spec_root=SPECS_ROOT / "openeo-processes/2.x", argument_names=["args", "env"]) def _add_standard_processes(process_registry: ProcessRegistry, process_ids: List[str]): @@ -119,6 +130,7 @@ def wrapped(args: dict, env: EvalEnv): } _add_standard_processes(process_registry_100, _OPENEO_PROCESSES_PYTHON_WHITELIST) +_add_standard_processes(process_registry_2xx, _OPENEO_PROCESSES_PYTHON_WHITELIST) # Type hint alias for a "process function": @@ -127,8 +139,32 @@ def wrapped(args: dict, env: EvalEnv): def process(f: ProcessFunction) -> ProcessFunction: - """Decorator for registering a process function in the process registries""" + """ + Decorator for registering a process function in the process registries. + To be used as shortcut for all simple cases of + + @process_registry_100.add_function + @process_registry_2xx.add_function + def foo(args, env): + ... + """ process_registry_100.add_function(f) + process_registry_2xx.add_function(f) + return f + + +def simple_function(f: Callable) -> Callable: + """ + Decorator for registering a simple function (normal arguments instead of `args, env`) in the process registries. + To be used as shortcut for all simple cases of + + @process_registry_100.add_simple_function + @process_registry_2xx.add_simple_function + def foo(data, pattern): + ... 
+ """ + process_registry_100.add_simple_function(f) + process_registry_2xx.add_simple_function(f) return f @@ -137,6 +173,7 @@ def non_standard_process(spec: ProcessSpec) -> Callable[[ProcessFunction], Proce def decorator(f: ProcessFunction) -> ProcessFunction: process_registry_100.add_function(f=f, spec=spec.to_dict_100()) + process_registry_2xx.add_function(f=f, spec=spec.to_dict_100()) return f return decorator @@ -145,27 +182,29 @@ def decorator(f: ProcessFunction) -> ProcessFunction: def custom_process(f: ProcessFunction): """Decorator for custom processes (e.g. in custom_processes.py).""" process_registry_100.add_hidden(f) + process_registry_2xx.add_hidden(f) return f def custom_process_from_process_graph( - process_spec: Union[dict, Path], - process_registry: ProcessRegistry = process_registry_100, - namespace: str = DEFAULT_NAMESPACE + process_spec: Union[dict, Path], + process_registries: Sequence[ProcessRegistry] = (process_registry_100, process_registry_2xx), + namespace: str = DEFAULT_NAMESPACE, ): """ Register a custom process from a process spec containing a "process_graph" definition :param process_spec: process spec dict or path to a JSON file, containing keys like "id", "process_graph", "parameter" - :param process_registry: process registry to register to + :param process_registries: process registries to register to """ # TODO: option to hide process graph for (public) listing if isinstance(process_spec, Path): process_spec = load_json(process_spec) process_id = process_spec["id"] process_function = _process_function_from_process_graph(process_spec) - process_registry.add_function(process_function, name=process_id, spec=process_spec, namespace=namespace) + for process_registry in process_registries: + process_registry.add_function(process_function, name=process_id, spec=process_spec, namespace=namespace) def _process_function_from_process_graph(process_spec: dict) -> ProcessFunction: @@ -188,7 +227,7 @@ def process_function(args: dict, env: 
EvalEnv): return process_function -def _register_fallback_implementations_by_process_graph(process_registry: ProcessRegistry = process_registry_100): +def _register_fallback_implementations_by_process_graph(process_registry: ProcessRegistry): """ Register process functions for (yet undefined) processes that have a process graph based fallback implementation in their spec @@ -197,7 +236,7 @@ def _register_fallback_implementations_by_process_graph(process_registry: Proces spec = process_registry.load_predefined_spec(name) if "process_graph" in spec and not process_registry.contains(name): _log.info(f"Registering fallback implementation of {name!r} by process graph ({process_registry})") - custom_process_from_process_graph(process_spec=spec, process_registry=process_registry) + custom_process_from_process_graph(process_spec=spec, process_registries=[process_registry]) # Some (env) string constants to simplify code navigation @@ -217,8 +256,13 @@ class SimpleProcessing(Processing): def get_process_registry(self, api_version: Union[str, ComparableVersion]) -> ProcessRegistry: # Lazy load registry. 
- assert ComparableVersion("1.0.0").or_higher(api_version) - spec = 'openeo-processes/1.x' + api_version = ComparableVersion(api_version) + if api_version.at_least("1.2.0"): + spec = "openeo-processes/2.x" + elif api_version.at_least("1.0.0"): + spec = "openeo-processes/1.x" + else: + raise OpenEOApiException(message=f"No process support for openEO version {api_version}") if spec not in self._registry_cache: registry = ProcessRegistry(spec_root=SPECS_ROOT / spec, argument_names=["args", "env"]) _add_standard_processes(registry, _OPENEO_PROCESSES_PYTHON_WHITELIST) @@ -245,9 +289,10 @@ class ConcreteProcessing(Processing): """ def get_process_registry(self, api_version: Union[str, ComparableVersion]) -> ProcessRegistry: - if ComparableVersion("1.0.0").or_higher(api_version): + if ComparableVersion(api_version).at_least("1.2.0"): + return process_registry_2xx + elif ComparableVersion(api_version).at_least("1.0.0"): return process_registry_100 - # TODO #195 support openeo-processes 2.x else: raise OpenEOApiException(message=f"No process support for openEO version {api_version}") @@ -304,7 +349,7 @@ def evaluate( Converts the json representation of a (part of a) process graph into the corresponding Python data cube. """ - if 'version' not in env: + if "version" not in env: _log.warning("No version in `evaluate()` env. 
Blindly assuming 1.0.0.") env = env.push({"version": "1.0.0"}) @@ -637,7 +682,7 @@ def vector_buffer(args: Dict, env: EvalEnv) -> dict: return mapping(poly_buff_latlon[0]) if len(poly_buff_latlon) == 1 else mapping(poly_buff_latlon) -@process_registry_100.add_function +@process def apply_neighborhood(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: data_cube = args.get_required("data", expected_type=DriverDataCube) process = args.get_deep("process", "process_graph", expected_type=dict) @@ -646,6 +691,7 @@ def apply_neighborhood(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: context = args.get_optional("context", default=None) return data_cube.apply_neighborhood(process=process, size=size, overlap=overlap, env=env, context=context) + @process def apply_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: data_cube = args.get_required("data", expected_type=(DriverDataCube, DriverVectorCube)) @@ -687,6 +733,7 @@ def save_result(args: Dict, env: EvalEnv) -> SaveResult: @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/save_ml_model.json")) def save_ml_model(args: dict, env: EvalEnv) -> MlModelResult: data: DriverMlModel = extract_arg(args, "data", process_id="save_ml_model") if not isinstance(data, DriverMlModel): @@ -698,6 +745,7 @@ def save_ml_model(args: dict, env: EvalEnv) -> MlModelResult: @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/load_ml_model.json")) def load_ml_model(args: dict, env: EvalEnv) -> DriverMlModel: job_id = extract_arg(args, "id") return env.backend_implementation.load_ml_model(job_id) @@ -714,7 +762,7 @@ def apply(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: return data_cube.apply(process=apply_pg, context=context, env=env) 
-@process_registry_100.add_function +@process def reduce_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: data_cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube) reduce_pg = args.get_deep("reducer", "process_graph", expected_type=dict) @@ -725,8 +773,12 @@ def reduce_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: return data_cube.reduce_dimension(reducer=reduce_pg, dimension=dimension, context=context, env=env) -@process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/chunk_polygon.json")) -def chunk_polygon(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: +@process_registry_100.add_function( + spec=read_spec("openeo-processes/experimental/chunk_polygon.json"), name="chunk_polygon" +) +@process_registry_100.add_function(spec=read_spec("openeo-processes/2.x/proposals/apply_polygon.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/apply_polygon.json")) +def apply_polygon(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: # TODO #229 deprecate this process and promote the "apply_polygon" name. 
# See https://github.com/Open-EO/openeo-processes/issues/287, https://github.com/Open-EO/openeo-processes/pull/298 data_cube = args.get_required("data", expected_type=DriverDataCube) @@ -765,8 +817,8 @@ def chunk_polygon(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/fit_class_random_forest.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/fit_class_random_forest.json")) def fit_class_random_forest(args: dict, env: EvalEnv) -> DriverMlModel: - # Keep it simple for dry run if env.get(ENV_DRY_RUN_TRACER): return DriverMlModel() @@ -815,17 +867,21 @@ def fit_class_random_forest(args: dict, env: EvalEnv) -> DriverMlModel: @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_random_forest.json")) -def predict_random_forest(args: dict, env: EvalEnv) -> SaveResult: - pass +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_random_forest.json")) +def predict_random_forest(args: dict, env: EvalEnv): + raise NotImplementedError @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_catboost.json")) -def predict_catboost(args: dict, env: EvalEnv) -> SaveResult: - pass +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_catboost.json")) +def predict_catboost(args: dict, env: EvalEnv): + raise NotImplementedError + @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/predict_probabilities.json")) -def predict_probabilities(args: dict, env: EvalEnv) -> SaveResult: - pass +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/predict_probabilities.json")) +def predict_probabilities(args: dict, env: EvalEnv): + raise NotImplementedError @process @@ -838,7 +894,7 @@ def add_dimension(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: ) 
-@process_registry_100.add_function +@process def drop_dimension(args: dict, env: EvalEnv) -> DriverDataCube: data_cube = extract_arg(args, 'data') if not isinstance(data_cube, DriverDataCube): @@ -849,7 +905,7 @@ def drop_dimension(args: dict, env: EvalEnv) -> DriverDataCube: return data_cube.drop_dimension(name=extract_arg(args, 'name')) -@process_registry_100.add_function +@process def dimension_labels(args: dict, env: EvalEnv) -> DriverDataCube: data_cube = extract_arg(args, 'data') if not isinstance(data_cube, DriverDataCube): @@ -859,7 +915,8 @@ def dimension_labels(args: dict, env: EvalEnv) -> DriverDataCube: ) return data_cube.dimension_labels(dimension=extract_arg(args, 'dimension')) -@process_registry_100.add_function + +@process def rename_dimension(args: dict, env: EvalEnv) -> DriverDataCube: data_cube = extract_arg(args, 'data') if not isinstance(data_cube, DriverDataCube): @@ -869,7 +926,8 @@ def rename_dimension(args: dict, env: EvalEnv) -> DriverDataCube: ) return data_cube.rename_dimension(source=extract_arg(args, 'source'),target=extract_arg(args, 'target')) -@process_registry_100.add_function + +@process def rename_labels(args: dict, env: EvalEnv) -> DriverDataCube: data_cube = extract_arg(args, 'data') if not isinstance(data_cube, DriverDataCube): @@ -902,7 +960,7 @@ def aggregate_temporal(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: ) -@process_registry_100.add_function +@process def aggregate_temporal_period(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: data_cube = args.get_required("data", expected_type=DriverDataCube) period = args.get_required("period") @@ -989,7 +1047,7 @@ def _period_to_intervals(start, end, period) -> List[Tuple[pd.Timestamp, pd.Time return intervals -@process_registry_100.add_function +@process def aggregate_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: cube = args.get_required("data", expected_type=DriverDataCube) reduce_pg = args.get_deep("reducer", "process_graph", expected_type=dict) @@ 
-1023,7 +1081,7 @@ def aggregate_spatial(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: return cube.aggregate_spatial(geometries=geoms, reducer=reduce_pg, target_dimension=target_dimension) -@process_registry_100.add_function +@process def mask(args: dict, env: EvalEnv) -> DriverDataCube: cube = extract_arg(args, 'data') if not isinstance(cube, DriverDataCube): @@ -1035,7 +1093,7 @@ def mask(args: dict, env: EvalEnv) -> DriverDataCube: return cube.mask(mask=mask, replacement=replacement) -@process_registry_100.add_function +@process def mask_polygon(args: dict, env: EvalEnv) -> DriverDataCube: mask = extract_arg(args, 'mask') replacement = args.get('replacement', None) @@ -1134,7 +1192,7 @@ def filter_bbox(args: Dict, env: EvalEnv) -> DriverDataCube: return cube.filter_bbox(**spatial_extent) -@process_registry_100.add_function +@process def filter_spatial(args: Dict, env: EvalEnv) -> DriverDataCube: cube = extract_arg(args, 'data') geometries = extract_arg(args, 'geometries') @@ -1332,7 +1390,7 @@ def linear_scale_range(args: dict, env: EvalEnv) -> DriverDataCube: return image_collection.linear_scale_range(inputMin, inputMax, outputMin, outputMax) -@process_registry_100.add_function +@process def constant(args: dict, env: EvalEnv): return args["x"] @@ -1353,6 +1411,7 @@ def recurse(graph): recurse(process_graph) return children_node_types + def flatten_children_node_names(process_graph: Union[dict, list]): children_node_names = set() @@ -1493,6 +1552,7 @@ def read_vector(args: Dict, env: EvalEnv) -> DelayedVector: @process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/load_uploaded_files.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/load_uploaded_files.json")) def load_uploaded_files(args: ProcessArgs, env: EvalEnv) -> Union[DriverVectorCube, DriverDataCube]: # TODO #114 EP-3981 process name is still under discussion https://github.com/Open-EO/openeo-processes/issues/322 paths = 
args.get_required("paths", expected_type=list) @@ -1541,6 +1601,7 @@ def to_vector_cube(args: Dict, env: EvalEnv): @process_registry_100.add_function(spec=read_spec("openeo-processes/2.x/proposals/load_geojson.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/load_geojson.json")) def load_geojson(args: ProcessArgs, env: EvalEnv) -> DriverVectorCube: data = args.get_required( "data", @@ -1556,6 +1617,7 @@ def load_geojson(args: ProcessArgs, env: EvalEnv) -> DriverVectorCube: @process_registry_100.add_function(spec=read_spec("openeo-processes/2.x/proposals/load_url.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/load_url.json")) def load_url(args: ProcessArgs, env: EvalEnv) -> DriverVectorCube: # TODO: Follow up possible `load_url` changes https://github.com/Open-EO/openeo-processes/issues/450 ? url = args.get_required("url", expected_type=str, validator=re.compile("^https?://").match) @@ -1800,6 +1862,7 @@ def atmospheric_correction(args: ProcessArgs, env: EvalEnv) -> DriverDataCube: @process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/sar_backscatter.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/sar_backscatter.json")) def sar_backscatter(args: ProcessArgs, env: EvalEnv): cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube) kwargs = args.get_subset( @@ -1818,6 +1881,7 @@ def sar_backscatter(args: ProcessArgs, env: EvalEnv): @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/resolution_merge.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/resolution_merge.json")) def resolution_merge(args: ProcessArgs, env: EvalEnv): cube: DriverDataCube = args.get_required("data", expected_type=DriverDataCube) kwargs = args.get_subset(names=["method", "high_resolution_bands", "low_resolution_bands", "options"]) @@ -1835,6 +1899,7 
@@ def discard_result(args: Dict, env: EvalEnv): @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/mask_scl_dilation.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/mask_scl_dilation.json")) def mask_scl_dilation(args: Dict, env: EvalEnv): cube: DriverDataCube = extract_arg(args, 'data') if not isinstance(cube, DriverDataCube): @@ -1851,6 +1916,7 @@ def mask_scl_dilation(args: Dict, env: EvalEnv): @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/to_scl_dilation_mask.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/to_scl_dilation_mask.json")) def to_scl_dilation_mask(args: Dict, env: EvalEnv): cube: DriverDataCube = extract_arg(args, "data") if not isinstance(cube, DriverDataCube): @@ -1871,6 +1937,7 @@ def to_scl_dilation_mask(args: Dict, env: EvalEnv): @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/mask_l1c.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/experimental/mask_l1c.json")) def mask_l1c(args: Dict, env: EvalEnv): cube: DriverDataCube = extract_arg(args, 'data') if not isinstance(cube, DriverDataCube): @@ -1886,15 +1953,21 @@ def mask_l1c(args: Dict, env: EvalEnv): custom_process_from_process_graph(read_spec("openeo-processes/1.x/proposals/ard_normalized_radar_backscatter.json")) + @process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/array_append.json")) -def array_append(args: Dict, env: EvalEnv) -> str: - pass +@process_registry_2xx.add_function +def array_append(args: ProcessArgs, env: EvalEnv) -> list: + raise NotImplementedError + @process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/array_interpolate_linear.json")) -def array_interpolate_linear(args: Dict, env: EvalEnv) -> str: - pass +@process_registry_2xx.add_function +def array_interpolate_linear(args: ProcessArgs, env: EvalEnv) -> 
list: + raise NotImplementedError + @process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/date_shift.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/date_shift.json")) def date_shift(args: ProcessArgs, env: EvalEnv) -> str: date = rfc3339.parse_date_or_datetime(args.get_required("date", expected_type=str)) value = int(args.get_required("value", expected_type=int)) @@ -1908,21 +1981,25 @@ def date_shift(args: ProcessArgs, env: EvalEnv) -> str: @process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/array_concat.json")) -def array_concat(args: dict, env: EvalEnv) -> list: - array1 = extract_arg(args, "array1") - array2 = extract_arg(args, "array2") +@process_registry_2xx.add_function +def array_concat(args: ProcessArgs, env: EvalEnv) -> list: + array1 = args.get_required(name="array1", expected_type=list) + array2 = args.get_required(name="array2", expected_type=list) return list(array1) + list(array2) @process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/array_create.json")) -def array_create(args: dict, env: EvalEnv) -> list: - data = extract_arg(args, "data") - repeat = args.get("repeat", 1) - if not isinstance(repeat, int) or repeat < 1: - raise ProcessParameterInvalidException( - parameter="repeat", process="array_create", - reason="The `repeat` parameter should be an integer of at least value 1." - ) +@process_registry_2xx.add_function +def array_create(args: ProcessArgs, env: EvalEnv) -> list: + data = args.get_required("data", expected_type=list) + repeat = args.get_optional( + name="repeat", + default=1, + expected_type=int, + validator=ProcessArgs.validator_generic( + lambda v: v >= 1, error_message="The `repeat` parameter should be an integer of at least value 1." 
+ ), + ) return list(data) * repeat @@ -1953,7 +2030,9 @@ def load_result(args: dict, env: EvalEnv) -> DriverDataCube: return env.backend_implementation.load_result(job_id=job_id, user_id=user.user_id if user is not None else None, load_params=load_params, env=env) + @process_registry_100.add_function(spec=read_spec("openeo-processes/1.x/proposals/inspect.json")) +@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/inspect.json")) def inspect(args: dict, env: EvalEnv): data = extract_arg(args,"data") message = args.get("message","") @@ -1966,7 +2045,7 @@ def inspect(args: dict, env: EvalEnv): _log.log(level=logging.getLevelName(level.upper()), msg=data_message) return data -@process_registry_100.add_simple_function +@simple_function def text_begins(data: str, pattern: str, case_sensitive: bool = True) -> Union[bool, None]: if data is None: return None @@ -1976,7 +2055,7 @@ def text_begins(data: str, pattern: str, case_sensitive: bool = True) -> Union[b return data.startswith(pattern) -@process_registry_100.add_simple_function +@simple_function def text_contains(data: str, pattern: str, case_sensitive: bool = True) -> Union[bool, None]: if data is None: return None @@ -1986,7 +2065,7 @@ def text_contains(data: str, pattern: str, case_sensitive: bool = True) -> Union return pattern in data -@process_registry_100.add_simple_function +@simple_function def text_ends(data: str, pattern: str, case_sensitive: bool = True) -> Union[bool, None]: if data is None: return None @@ -2006,6 +2085,7 @@ def text_merge( @process_registry_100.add_simple_function(spec=read_spec("openeo-processes/2.x/text_concat.json")) +@process_registry_2xx.add_simple_function def text_concat( data: List[Union[str, int, float, bool, None]], separator: str = "", @@ -2014,6 +2094,7 @@ def text_concat( @process_registry_100.add_function(spec=read_spec("openeo-processes/experimental/load_stac.json")) 
+@process_registry_2xx.add_function(spec=read_spec("openeo-processes/2.x/proposals/load_stac.json")) def load_stac(args: Dict, env: EvalEnv) -> DriverDataCube: url = extract_arg(args, "url", process_id="load_stac") @@ -2043,9 +2124,11 @@ def load_stac(args: Dict, env: EvalEnv) -> DriverDataCube: @process_registry_100.add_simple_function(name="if") +@process_registry_2xx.add_simple_function(name="if") def if_(value: Union[bool, None], accept, reject=None): return accept if value else reject # Finally: register some fallback implementation if possible _register_fallback_implementations_by_process_graph(process_registry_100) +_register_fallback_implementations_by_process_graph(process_registry_2xx) diff --git a/openeo_driver/processes.py b/openeo_driver/processes.py index b97309b9..e3aaea60 100644 --- a/openeo_driver/processes.py +++ b/openeo_driver/processes.py @@ -57,6 +57,7 @@ def returns(self, description: str, schema: dict) -> 'ProcessSpec': def to_dict_040(self) -> dict: """Generate process spec as (JSON-able) dictionary (API 0.4.0 style).""" + # TODO #47 drop this if len(self._parameters) == 0: warnings.warn("Process with no parameters") assert self._returns is not None @@ -142,7 +143,8 @@ def add_process(self, name: str, function: Callable = None, spec: dict = None, n if self.contains(name, namespace): raise ProcessRegistryException(f"Process {name!r} already defined in namespace {namespace!r}") if spec: - assert name == spec['id'] + if name != spec["id"]: + raise ProcessRegistryException(f"Process {name!r} has unexpected id {spec['id']!r}") if function and self._argument_names: sig = inspect.signature(function) arg_names = [n for n, p in sig.parameters.items() if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD] @@ -188,7 +190,7 @@ def add_simple_function( self, f: Optional[Callable] = None, name: Optional[str] = None, spec: Optional[dict] = None ): """ - Register a simple function that uses normal arguments instead of `args: dict, env: EvalEnv`: + Register 
a simple function that uses normal arguments instead of `args: ProcessArgs, env: EvalEnv`: wrap it in a wrapper that automatically extracts these arguments :param f: :param name: process_id (when guessing from `f.__name__` doesn't work) @@ -265,8 +267,9 @@ def get_function(self, name: str, namespace: str = DEFAULT_NAMESPACE) -> Callabl return self._get(name, namespace).function -# Type annotation for an argument value +# Type annotation aliases ArgumentValue = Any +Validator = Callable[[Any], bool] class ProcessArgs(dict): @@ -294,7 +297,7 @@ def get_required( name: str, *, expected_type: Optional[Union[type, Tuple[type, ...]]] = None, - validator: Optional[Callable[[Any], bool]] = None, + validator: Optional[Validator] = None, ) -> ArgumentValue: """ Get a required argument by name. @@ -315,7 +318,7 @@ def _check_value( name: str, value: Any, expected_type: Optional[Union[type, Tuple[type, ...]]] = None, - validator: Optional[Callable[[Any], bool]] = None, + validator: Optional[Validator] = None, ): if expected_type: if not isinstance(value, expected_type): @@ -323,9 +326,9 @@ def _check_value( parameter=name, process=self.process_id, reason=f"Expected {expected_type} but got {type(value)}." ) if validator: + reason = None try: valid = validator(value) - reason = "Failed validation." except OpenEOApiException: # Preserve original OpenEOApiException raise @@ -333,7 +336,9 @@ def _check_value( valid = False reason = str(e) if not valid: - raise ProcessParameterInvalidException(parameter=name, process=self.process_id, reason=reason) + raise ProcessParameterInvalidException( + parameter=name, process=self.process_id, reason=reason or "Failed validation." 
+ ) def get_optional( self, @@ -341,7 +346,7 @@ def get_optional( default: Union[Any, Callable[[], Any]] = None, *, expected_type: Optional[Union[type, Tuple[type, ...]]] = None, - validator: Optional[Callable[[Any], bool]] = None, + validator: Optional[Validator] = None, ) -> ArgumentValue: """ Get an optional argument with default @@ -364,7 +369,7 @@ def get_deep( self, *steps: str, expected_type: Optional[Union[type, Tuple[type, ...]]] = None, - validator: Optional[Callable[[Any], bool]] = None, + validator: Optional[Validator] = None, ) -> ArgumentValue: """ Walk recursively through a dictionary to get to a value. @@ -431,10 +436,26 @@ def get_enum(self, name: str, options: Collection[ArgumentValue]) -> ArgumentVal return value @staticmethod - def validator_one_of(options: list, show_value: bool = True): - """Build a validator function that check that the value is in given list""" + def validator_generic(condition: Callable[[Any], bool], error_message: str) -> Validator: + """ + Build validator function based on a condition (another validator) + and a custom error message when validation returns False. + (supports interpolation of actual value with "{actual}"). + """ def validator(value): + valid = condition(value) + if not valid: + raise ValueError(error_message.format(actual=value)) + return valid + + return validator + + @staticmethod + def validator_one_of(options: list, show_value: bool = True) -> Validator: + """Build a validator function that check that the value is in given list""" + + def validator(value) -> bool: if value not in options: if show_value: message = f"Must be one of {options!r} but got {value!r}." 
@@ -446,7 +467,7 @@ def validator(value): return validator @staticmethod - def validator_file_format(formats: Union[List[str], Dict[str, dict]]): + def validator_file_format(formats: Union[List[str], Dict[str, dict]]) -> Validator: """ Build validator for input/output format (case-insensitive check) @@ -455,7 +476,7 @@ def validator_file_format(formats: Union[List[str], Dict[str, dict]]): formats = list(formats) options = set(f.lower() for f in formats) - def validator(value: str): + def validator(value: str) -> bool: if value.lower() not in options: raise OpenEOApiException( message=f"Invalid file format {value!r}. Allowed formats: {', '.join(formats)}", @@ -469,10 +490,10 @@ def validator(value: str): @staticmethod def validator_geojson_dict( allowed_types: Optional[Collection[str]] = None, - ): + ) -> Validator: """Build validator to verify that provided structure looks like a GeoJSON-style object""" - def validator(value): + def validator(value) -> bool: issues = validate_geojson_basic(value=value, allowed_types=allowed_types, raise_exception=False) if issues: raise ValueError(f"Invalid GeoJSON: {', '.join(issues)}.") diff --git a/tests/test_processes.py b/tests/test_processes.py index 0a360183..663954fd 100644 --- a/tests/test_processes.py +++ b/tests/test_processes.py @@ -615,6 +615,43 @@ def test_get_enum(self): ): _ = args.get_enum("color", options=["R", "G", "B"]) + def test_validator_generic(self): + args = ProcessArgs({"size": 11}, process_id="wibble") + + validator = ProcessArgs.validator_generic(lambda v: v > 1, error_message="Should be stricly positive.") + value = args.get_required("size", expected_type=int, validator=validator) + assert value == 11 + + validator = ProcessArgs.validator_generic(lambda v: v % 2 == 0, error_message="Should be even.") + with pytest.raises( + ProcessParameterInvalidException, + match=re.escape("The value passed for parameter 'size' in process 'wibble' is invalid: Should be even."), + ): + _ = args.get_required("size", 
expected_type=int, validator=validator) + + validator = ProcessArgs.validator_generic( + lambda v: v % 2 == 0, error_message="Should be even but got {actual}." + ) + with pytest.raises( + ProcessParameterInvalidException, + match=re.escape( + "The value passed for parameter 'size' in process 'wibble' is invalid: Should be even but got 11." + ), + ): + _ = args.get_required("size", expected_type=int, validator=validator) + + def test_validator_one_of(self): + args = ProcessArgs({"color": "red", "size": 5}, process_id="wibble") + with pytest.raises( + ProcessParameterInvalidException, + match=re.escape( + "The value passed for parameter 'color' in process 'wibble' is invalid: Must be one of ['yellow', 'violet'] but got 'red'." + ), + ): + _ = args.get_required( + "color", expected_type=str, validator=ProcessArgs.validator_one_of(["yellow", "violet"]) + ) + def test_validator_geojson_dict(self): polygon = {"type": "Polygon", "coordinates": [[1, 2]]} args = ProcessArgs({"geometry": polygon, "color": "red"}, process_id="wibble")