Skip to content

Commit

Permalink
Issu #404 Add automatic process graph validation, when backend suppor…
Browse files Browse the repository at this point in the history
…ts it
  • Loading branch information
JohanKJSchreurs committed Oct 3, 2023
1 parent d8aae54 commit 748a52e
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 2 deletions.
41 changes: 40 additions & 1 deletion openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@
DEFAULT_TIMEOUT_SYNCHRONOUS_EXECUTE = 30 * 60


# TODO: remove temporary constant that is intended for refactoring
# constant for refactoring to switch default validation of process graph on or off.
VALIDATE_PROCESS_GRAPH_BY_DEFAULT = True


class RestApiConnection:
"""Base connection class implementing generic REST API request functionality"""

Expand Down Expand Up @@ -1052,7 +1057,14 @@ def validate_process_graph(self, process_graph: dict) -> List[dict]:
:param process_graph: (flat) dict representing process graph
:return: list of errors (dictionaries with "code" and "message" fields)
"""
request = {"process_graph": process_graph}
# TODO: sometimes process_graph is already in the graph. Should we really *always* add it?
# Was getting errors in some new unit tests because of the double process_graph but
# perhaps the error is really not here but somewhere else that adds process_graph
# when it should not? Still needs to be confirmed.
if "process_graph" not in process_graph:
request = {"process_graph": process_graph}
else:
request = process_graph
return self.post(path="/validation", json=request, expected_status=200).json()["errors"]

@property
Expand Down Expand Up @@ -1474,12 +1486,27 @@ def _build_request_with_process_graph(self, process_graph: Union[dict, FlatGraph
result["process"] = process_graph
return result

def _warn_if_process_graph_invalid(self, process_graph: Union[dict, FlatGraphableMixin, str, Path]):
if not self.capabilities().supports_endpoint("/validation", "POST"):
return

graph = as_flat_graph(process_graph)
if "process_graph" not in graph:
graph = {"process_graph": graph}

validation_errors = self.validate_process_graph(process_graph=graph)
if validation_errors:
_log.warning(
"Process graph is not valid. Validation errors:\n" + "\n".join(e["message"] for e in validation_errors)
)

# TODO: unify `download` and `execute` better: e.g. `download` always writes to disk, `execute` returns result (raw or as JSON decoded dict)
def download(
self,
graph: Union[dict, FlatGraphableMixin, str, Path],
outputfile: Union[Path, str, None] = None,
timeout: Optional[int] = None,
validate: bool = VALIDATE_PROCESS_GRAPH_BY_DEFAULT,
) -> Union[None, bytes]:
"""
Downloads the result of a process graph synchronously,
Expand All @@ -1491,6 +1518,9 @@ def download(
:param outputfile: output file
:param timeout: timeout to wait for response
"""
if validate:
self._warn_if_process_graph_invalid(process_graph=graph)

request = self._build_request_with_process_graph(process_graph=graph)
response = self.post(
path="/result",
Expand All @@ -1511,6 +1541,7 @@ def execute(
self,
process_graph: Union[dict, str, Path],
timeout: Optional[int] = None,
validate: bool = VALIDATE_PROCESS_GRAPH_BY_DEFAULT,
):
"""
Execute a process graph synchronously and return the result (assumed to be JSON).
Expand All @@ -1519,6 +1550,9 @@ def execute(
or as local file path or URL
:return: parsed JSON response
"""
if validate:
self._warn_if_process_graph_invalid(process_graph=process_graph)

req = self._build_request_with_process_graph(process_graph=process_graph)
return self.post(
path="/result",
Expand All @@ -1536,6 +1570,7 @@ def create_job(
plan: Optional[str] = None,
budget: Optional[float] = None,
additional: Optional[dict] = None,
validate: bool = VALIDATE_PROCESS_GRAPH_BY_DEFAULT,
) -> BatchJob:
"""
Create a new job from given process graph on the back-end.
Expand All @@ -1550,6 +1585,10 @@ def create_job(
:return: Created job
"""
# TODO move all this (BatchJob factory) logic to BatchJob?

if validate:
self._warn_if_process_graph_invalid(process_graph=process_graph)

req = self._build_request_with_process_graph(
process_graph=process_graph,
**dict_no_none(title=title, description=description, plan=plan, budget=budget)
Expand Down
34 changes: 33 additions & 1 deletion tests/rest/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import contextlib
import re
import typing
from typing import List, Optional
from unittest import mock

import pytest
import time_machine

from openeo.rest._testing import DummyBackend
import openeo
from openeo.rest._testing import DummyBackend, build_capabilities
from openeo.rest.connection import Connection

API_URL = "https://oeo.test/"
Expand Down Expand Up @@ -87,3 +89,33 @@ def con120(requests_mock):
@pytest.fixture
def dummy_backend(requests_mock, con100) -> DummyBackend:
yield DummyBackend(requests_mock=requests_mock, connection=con100)


def _setup_connection(api_version, requests_mock, build_capabilities_kwargs: Optional[dict] = None) -> Connection:
# TODO: make this more reusable?
requests_mock.get(API_URL, json=build_capabilities(api_version=api_version, **(build_capabilities_kwargs or {})))
requests_mock.get(
API_URL + "file_formats",
json={
"output": {
"GTiff": {"gis_data_types": ["raster"]},
"netCDF": {"gis_data_types": ["raster"]},
"csv": {"gis_data_types": ["table"]},
}
},
)
requests_mock.get(
API_URL + "udf_runtimes",
json={
"Python": {"type": "language", "default": "3", "versions": {"3": {"libraries": {}}}},
"R": {"type": "language", "default": "4", "versions": {"4": {"libraries": {}}}},
},
)

return openeo.connect(API_URL)


@pytest.fixture
def connection_with_pgvalidation(api_version, requests_mock) -> Connection:
"""Connection fixture to a backend of given version with some image collections."""
return _setup_connection(api_version, requests_mock, build_capabilities_kwargs={"validation": True})
113 changes: 113 additions & 0 deletions tests/rest/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
RestApiConnection,
connect,
paginate,
VALIDATE_PROCESS_GRAPH_BY_DEFAULT,
)
from openeo.rest.vectorcube import VectorCube
from openeo.util import ContextTimer
Expand Down Expand Up @@ -3179,23 +3180,50 @@ class TestExecute:
PG_JSON_1 = '{"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": true}}'
PG_JSON_2 = '{"process_graph": {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": true}}}'

PG_INVALID_DICT_INNER = {
"loadcollection1": {
"process_id": "load_collection",
"arguments": {"id": "S2", "spatial_extent": None, "temporal_extent": None},
"result": True,
}
}
PG_INVALID_DICT_OUTER = {"process_graph": PG_INVALID_DICT_INNER}
PG_INVALID_INNER = json.dumps(PG_INVALID_DICT_INNER)
PG_INVALID_OUTER = json.dumps(PG_INVALID_DICT_OUTER)

# Dummy `POST /result` handlers
def _post_result_handler_tiff(self, response: requests.Request, context):
pg = response.json()["process"]["process_graph"]
assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
return b"TIFF data"

def _post_result_handler_tiff_invalid_pg(self, response: requests.Request, context):
pg = response.json()["process"]["process_graph"]
assert pg == self.PG_INVALID_DICT_INNER
return b"TIFF data"

def _post_result_handler_json(self, response: requests.Request, context):
pg = response.json()["process"]["process_graph"]
assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
return {"answer": 8}

def _post_result_handler_json_invalid_pg(self, response: requests.Request, context):
pg = response.json()["process"]["process_graph"]
assert pg == self.PG_INVALID_DICT_INNER
return {"answer": 8}

def _post_jobs_handler_json(self, response: requests.Request, context):
pg = response.json()["process"]["process_graph"]
assert pg == {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
context.headers["OpenEO-Identifier"] = "j-123"
return b""

def _post_jobs_handler_json_invalid_pg(self, response: requests.Request, context):
pg = response.json()["process"]["process_graph"]
assert pg == self.PG_INVALID_DICT_INNER
context.headers["OpenEO-Identifier"] = "j-123"
return b""

@pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2])
def test_download_pg_json(self, requests_mock, tmp_path, pg_json: str):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
Expand All @@ -3206,6 +3234,28 @@ def test_download_pg_json(self, requests_mock, tmp_path, pg_json: str):
conn.download(pg_json, outputfile=output)
assert output.read_bytes() == b"TIFF data"

@pytest.mark.parametrize("pg_json", [PG_INVALID_INNER, PG_INVALID_OUTER])
def test_download_pg_json_with_invalid_pg(
self, requests_mock, connection_with_pgvalidation, tmp_path, pg_json: str, caplog
):
caplog.set_level(logging.WARNING)
requests_mock.post(API_URL + "result", content=self._post_result_handler_tiff_invalid_pg)

validation_errors = [{"code": "Invalid", "message": "Invalid process graph"}]

def validation(request, context):
assert request.json() == self.PG_INVALID_DICT_OUTER
return {"errors": validation_errors}

m = requests_mock.post(API_URL + "validation", json=validation)

output = tmp_path / "result.tiff"
connection_with_pgvalidation.download(pg_json, outputfile=output, validate=True)

assert output.read_bytes() == b"TIFF data"
assert caplog.messages == ["Process graph is not valid. Validation errors:\nInvalid process graph"]
assert m.call_count == 1

@pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2])
def test_execute_pg_json(self, requests_mock, pg_json: str):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
Expand All @@ -3215,6 +3265,24 @@ def test_execute_pg_json(self, requests_mock, pg_json: str):
result = conn.execute(pg_json)
assert result == {"answer": 8}

@pytest.mark.parametrize("pg_json", [PG_INVALID_INNER, PG_INVALID_OUTER])
def test_execute_pg_json_with_invalid_pg(self, requests_mock, connection_with_pgvalidation, pg_json: str, caplog):
caplog.set_level(logging.WARNING)
requests_mock.post(API_URL + "result", json=self._post_result_handler_json_invalid_pg)

validation_errors = [{"code": "Invalid", "message": "Invalid process graph"}]

def validation(request, context):
assert request.json() == self.PG_INVALID_DICT_OUTER
return {"errors": validation_errors}

m = requests_mock.post(API_URL + "validation", json=validation)

result = connection_with_pgvalidation.execute(pg_json, validate=True)
assert result == {"answer": 8}
assert caplog.messages == ["Process graph is not valid. Validation errors:\nInvalid process graph"]
assert m.call_count == 1

@pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2])
def test_create_job_pg_json(self, requests_mock, pg_json: str):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
Expand All @@ -3224,6 +3292,26 @@ def test_create_job_pg_json(self, requests_mock, pg_json: str):
job = conn.create_job(pg_json)
assert job.job_id == "j-123"

@pytest.mark.parametrize("pg_json", [PG_INVALID_INNER, PG_INVALID_OUTER])
def test_create_job_pg_json_with_invalid_pg(
self, requests_mock, connection_with_pgvalidation, pg_json: str, caplog
):
caplog.set_level(logging.WARNING)
requests_mock.post(API_URL + "jobs", status_code=201, content=self._post_jobs_handler_json_invalid_pg)

validation_errors = [{"code": "Invalid", "message": "Invalid process graph"}]

def validation(request, context):
assert request.json() == self.PG_INVALID_DICT_OUTER
return {"errors": validation_errors}

m = requests_mock.post(API_URL + "validation", json=validation)

job = connection_with_pgvalidation.create_job(pg_json, validate=True)
assert job.job_id == "j-123"
assert caplog.messages == ["Process graph is not valid. Validation errors:\nInvalid process graph"]
assert m.call_count == 1

@pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2])
@pytest.mark.parametrize("path_factory", [str, Path])
def test_download_pg_json_file(self, requests_mock, tmp_path, pg_json: str, path_factory):
Expand All @@ -3238,6 +3326,31 @@ def test_download_pg_json_file(self, requests_mock, tmp_path, pg_json: str, path
conn.download(json_file, outputfile=output)
assert output.read_bytes() == b"TIFF data"

@pytest.mark.parametrize("pg_json", [PG_INVALID_INNER, PG_INVALID_OUTER])
@pytest.mark.parametrize("path_factory", [str, Path])
def test_download_pg_json_file_with_invalid_pg(
self, requests_mock, connection_with_pgvalidation, tmp_path, pg_json: str, path_factory, caplog
):
caplog.set_level(logging.WARNING)
requests_mock.post(API_URL + "result", content=self._post_result_handler_tiff_invalid_pg)

validation_errors = [{"code": "Invalid", "message": "Invalid process graph"}]

def validation(request, context):
assert request.json() == self.PG_INVALID_DICT_OUTER
return {"errors": validation_errors}

m = requests_mock.post(API_URL + "validation", json=validation)

json_file = tmp_path / "input.json"
json_file.write_text(pg_json)
json_file = path_factory(json_file)

output = tmp_path / "result.tiff"
connection_with_pgvalidation.download(json_file, outputfile=output, validate=True)
assert caplog.messages == ["Process graph is not valid. Validation errors:\nInvalid process graph"]
assert m.call_count == 1

@pytest.mark.parametrize("pg_json", [PG_JSON_1, PG_JSON_2])
@pytest.mark.parametrize("path_factory", [str, Path])
def test_execute_pg_json_file(self, requests_mock, pg_json: str, tmp_path, path_factory):
Expand Down

0 comments on commit 748a52e

Please sign in to comment.