diff --git a/src/ocrd_network/cli/client.py b/src/ocrd_network/cli/client.py index fe87d01fbf..c2b6c500d1 100644 --- a/src/ocrd_network/cli/client.py +++ b/src/ocrd_network/cli/client.py @@ -1,9 +1,9 @@ import click +from json import dumps, loads from typing import Optional - from ocrd.decorators import parameter_option -from ocrd_network import Client from ocrd_utils import DEFAULT_METS_BASENAME +from ..client import Client @click.group('client') @@ -43,7 +43,7 @@ def processing_cli(): @click.option('--callback-url') @click.option('--agent-type', default='worker') @click.option('-b', '--block-till-job-end', default=False) -def send_processing_request( +def send_processing_job_request( address: Optional[str], processor_name: str, mets: str, @@ -73,13 +73,13 @@ def send_processing_request( req_params["result_queue_name"] = result_queue_name if callback_url: req_params["callback_url"] = callback_url - client = Client(server_addr_processing=address) - response = client.send_processing_request(processor_name=processor_name, req_params=req_params) - processing_job_id = response.get('job_id', None) + processing_job_id = client.send_processing_job_request( + processor_name=processor_name, req_params=loads(dumps(req_params))) + assert processing_job_id print(f"Processing job id: {processing_job_id}") if block_till_job_end: - pass + client.poll_job_status_till_timeout_fail_or_success(processing_job_id) @client_cli.group('workflow') diff --git a/src/ocrd_network/client.py b/src/ocrd_network/client.py index 688651f7de..333ec7fa44 100644 --- a/src/ocrd_network/client.py +++ b/src/ocrd_network/client.py @@ -1,8 +1,8 @@ -from json import dumps, loads -from requests import post as requests_post + from ocrd_utils import config, getLogger, LOG_FORMAT -from .constants import NETWORK_PROTOCOLS +from .client_utils import ( + poll_job_status_till_timeout_fail_or_success, post_ps_processing_request, verify_server_protocol) class Client: @@ -10,19 +10,13 @@ def __init__(self, server_addr_processing: str = config.OCRD_NETWORK_SERVER_ADDR self.log = getLogger(f"ocrd_network.client") self.server_addr_processing = server_addr_processing verify_server_protocol(self.server_addr_processing) + self.polling_tries = 900 + self.polling_wait = 30 - def send_processing_request(self, processor_name: str, req_params: dict): - req_url = f"{self.server_addr_processing}/processor/{processor_name}" - req_headers = {"Content-Type": "application/json; charset=utf-8"} - req_json = loads(dumps(req_params)) - self.log.info(f"Sending processing request to: {req_url}") - self.log.debug(req_json) - response = requests_post(url=req_url, headers=req_headers, json=req_json) - return response.json() - + def poll_job_status_till_timeout_fail_or_success(self, job_id: str) -> str: + return poll_job_status_till_timeout_fail_or_success( + ps_server_host=self.server_addr_processing, job_id=job_id, tries=self.polling_tries, wait=self.polling_wait) -def verify_server_protocol(address: str): - for protocol in NETWORK_PROTOCOLS: - if address.startswith(protocol): - return - raise ValueError(f"Wrong/Missing protocol in the server address: {address}, must be one of: {NETWORK_PROTOCOLS}") + def send_processing_job_request(self, processor_name: str, req_params: dict) -> str: + return post_ps_processing_request( + ps_server_host=self.server_addr_processing, processor=processor_name, job_input=req_params) diff --git a/src/ocrd_network/client_utils.py b/src/ocrd_network/client_utils.py index 96fab03372..651dc5cf6b 100644 --- a/src/ocrd_network/client_utils.py +++ b/src/ocrd_network/client_utils.py @@ -1,6 +1,6 @@ from requests import get as request_get, post as request_post from time import sleep -from .constants import JobState +from .constants import JobState, NETWORK_PROTOCOLS def _poll_endpoint_status(ps_server_host: str, job_id: str, job_type: str, tries: int, wait: int): @@ -29,7 +29,7 @@ def poll_wf_status_till_timeout_fail_or_success(ps_server_host: str, job_id: str def get_ps_processing_job_status(ps_server_host: str, processing_job_id: str) -> str: request_url = f"{ps_server_host}/processor/job/{processing_job_id}" - response = request_get(url=request_url, headers={"accept": "application/json"}) + response = request_get(url=request_url, headers={"accept": "application/json; charset=utf-8"}) assert response.status_code == 200, f"Processing server: {request_url}, {response.status_code}" job_state = response.json()["state"] assert job_state @@ -38,7 +38,7 @@ def get_ps_processing_job_status(ps_server_host: str, processing_job_id: str) -> def get_ps_workflow_job_status(ps_server_host: str, workflow_job_id: str) -> str: request_url = f"{ps_server_host}/workflow/job-simple/{workflow_job_id}" - response = request_get(url=request_url, headers={"accept": "application/json"}) + response = request_get(url=request_url, headers={"accept": "application/json; charset=utf-8"}) assert response.status_code == 200, f"Processing server: {request_url}, {response.status_code}" job_state = response.json()["state"] assert job_state @@ -47,7 +47,11 @@ def get_ps_workflow_job_status(ps_server_host: str, workflow_job_id: str) -> str def post_ps_processing_request(ps_server_host: str, processor: str, job_input: dict) -> str: request_url = f"{ps_server_host}/processor/run/{processor}" - response = request_post(url=request_url, headers={"accept": "application/json"}, json=job_input) + response = request_post( + url=request_url, + headers={"accept": "application/json; charset=utf-8"}, + json=job_input + ) assert response.status_code == 200, f"Processing server: {request_url}, {response.status_code}" processing_job_id = response.json()["job_id"] assert processing_job_id @@ -58,10 +62,20 @@ def post_ps_processing_request(ps_server_host: str, processor: str, job_input: d def post_ps_workflow_request(ps_server_host: str, path_to_wf: str, path_to_mets: str) -> str: request_url = f"{ps_server_host}/workflow/run?mets_path={path_to_mets}&page_wise=True" response = request_post( - url=request_url, headers={"accept": "application/json"}, files={"workflow": open(path_to_wf, "rb")}) + url=request_url, + headers={"accept": "application/json; charset=utf-8"}, + files={"workflow": open(path_to_wf, "rb")} + ) # print(response.json()) # print(response.__dict__) assert response.status_code == 200, f"Processing server: {request_url}, {response.status_code}" wf_job_id = response.json()["job_id"] assert wf_job_id return wf_job_id + + +def verify_server_protocol(address: str): + for protocol in NETWORK_PROTOCOLS: + if address.startswith(protocol): + return + raise ValueError(f"Wrong/Missing protocol in the server address: {address}, must be one of: {NETWORK_PROTOCOLS}") diff --git a/tests/network/test_integration_6_client.py b/tests/network/test_integration_6_client.py index 55168d8322..c1fe5ab260 100644 --- a/tests/network/test_integration_6_client.py +++ b/tests/network/test_integration_6_client.py @@ -1,9 +1,7 @@ -from click.testing import CliRunner - from src.ocrd_network.constants import AgentType, JobState from tests.base import assets from tests.network.config import test_config -from ocrd_network.cli.client import client_cli +from ocrd_network.client import Client PROCESSING_SERVER_URL = test_config.PROCESSING_SERVER_URL @@ -11,20 +9,19 @@ def test_client_processing_processor(): workspace_root = "kant_aufklaerung_1784/data" path_to_mets = assets.path_to(f"{workspace_root}/mets.xml") - runner = CliRunner() - result = runner.invoke( - client_cli, - args=[ - "processing", "processor", "ocrd-dummy", - "--address", PROCESSING_SERVER_URL, - "--mets", path_to_mets, - "--input-file-grp", "OCR-D-IMG", - "--output-file-grp", "OCR-D-DUMMY-TEST-CLIENT", - "--agent-type", AgentType.PROCESSING_WORKER - ] - ) - # TODO: Do a better result check - assert result.output.count(f"{JobState.success}") == 1 + client = Client(server_addr_processing=PROCESSING_SERVER_URL) + req_params = { + "path_to_mets": path_to_mets, + "description": "OCR-D Network client request", + "input_file_grps": ["OCR-D-IMG"], + "output_file_grps": ["OCR-D-DUMMY-TEST-CLIENT"], + "parameters": {}, + "agent_type": AgentType.PROCESSING_WORKER + } + processing_job_id = client.send_processing_job_request(processor_name="ocrd-dummy", req_params=req_params) + assert processing_job_id + print(f"Processing job id: {processing_job_id}") + assert JobState.success == client.poll_job_status_till_timeout_fail_or_success(processing_job_id) def test_client_processing_workflow():