Skip to content

Commit

Permalink
Issue #332 implement job listing pagination in ElasticJobRegistry
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Dec 5, 2024
1 parent 9c5ddcf commit 4875725
Show file tree
Hide file tree
Showing 7 changed files with 352 additions and 32 deletions.
2 changes: 1 addition & 1 deletion openeo_driver/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.120.0a1"
__version__ = "0.120.1a1"
22 changes: 17 additions & 5 deletions openeo_driver/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import List, Union, NamedTuple, Dict, Optional, Callable, Iterable

import flask
import werkzeug.datastructures

import openeo_driver.util.view_helpers
from openeo.capabilities import ComparableVersion
Expand Down Expand Up @@ -393,20 +392,33 @@ class JobListing:
__slots__ = ["_jobs", "_next_parameters"]

def __init__(self, jobs: List[BatchJobMetadata], next_parameters: Optional[dict] = None):
"""
:param jobs: list of job metadata constructs
:param next_parameters: dictionary of URL parameters to add to `GET /jobs` URL to link to next page
"""
self._jobs = jobs
self._next_parameters = next_parameters

def to_response_dict(self, url_for: Callable[[dict], str], api_version: ComparableVersion) -> dict:
"""Produce `GET /jobs` response data, to be JSONified."""
def to_response_dict(self, build_url: Callable[[dict], str], api_version: ComparableVersion = None) -> dict:
"""
Produce `GET /jobs` response data, to be JSONified.
:param build_url: function to generate a paginated URL from given pagination related parameters,
e.g. `lambda params: flask.url_for(".list_jobs", **params, _external=True)`
"""
links = []
if self._next_parameters:
links.append({"rel": "next", "href": url_for(self._next_parameters)})
links.append({"rel": "next", "href": build_url(self._next_parameters)})

return {
"jobs": [m.to_api_dict(full=False, api_version=api_version) for m in self._jobs],
"links": links,
}

def __len__(self) -> int:
return len(self._jobs)


class BatchJobs(MicroService):
"""
Expand Down Expand Up @@ -445,7 +457,7 @@ def get_user_jobs(
self,
user_id: str,
limit: Optional[int] = None,
request_parameters: Optional[werkzeug.datastructures.MultiDict] = None,
request_parameters: Optional[dict] = None,
# TODO #332 settle on returning just `JobListing` and eliminate other options/code paths.
) -> Union[List[BatchJobMetadata], dict, JobListing]:
"""
Expand Down
6 changes: 2 additions & 4 deletions openeo_driver/dummy/dummy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import flask
import numpy
import openeo.udf
import werkzeug.datastructures
import xarray
from openeo.api.logs import normalize_log_level
from openeo.internal.process_graph_visitor import ProcessGraphVisitor
Expand Down Expand Up @@ -44,7 +43,7 @@
UserDefinedProcessMetadata,
)
from openeo_driver.config import OpenEoBackendConfig
from openeo_driver.constants import STAC_EXTENSION
from openeo_driver.constants import JOB_STATUS, STAC_EXTENSION
from openeo_driver.datacube import DriverDataCube, DriverMlModel, DriverVectorCube
from openeo_driver.datastructs import StacAsset
from openeo_driver.delayed_vector import DelayedVector
Expand All @@ -56,7 +55,6 @@
PermissionsInsufficientException,
ProcessGraphNotFoundException,
)
from openeo_driver.constants import JOB_STATUS
from openeo_driver.ProcessGraphDeserializer import ConcreteProcessing
from openeo_driver.save_result import (
AggregatePolygonResult,
Expand Down Expand Up @@ -688,7 +686,7 @@ def get_user_jobs(
self,
user_id: str,
limit: Optional[int] = None,
request_parameters: Optional[werkzeug.datastructures.MultiDict] = None,
request_parameters: Optional[dict] = None,
) -> JobListing:
jobs: List[BatchJobMetadata] = [v for (k, v) in self._job_registry.items() if k[0] == user_id]
next_parameters = None
Expand Down
179 changes: 167 additions & 12 deletions openeo_driver/jobregistry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import argparse
import dataclasses
import datetime
import json
import logging
import os
Expand All @@ -9,6 +11,7 @@
import textwrap
import time
import typing
import urllib.parse
from decimal import Decimal
from typing import Any, Dict, List, Optional, Sequence, Union

Expand All @@ -19,10 +22,12 @@
from openeo.util import TimingLogger, repr_truncate, rfc3339

import openeo_driver._version
from openeo_driver.backend import BatchJobMetadata, JobListing
from openeo_driver.config import get_backend_config
from openeo_driver.constants import JOB_STATUS
from openeo_driver.errors import InternalException, JobNotFoundException
from openeo_driver.util.auth import ClientCredentials, ClientCredentialsAccessTokenHelper
from openeo_driver.util.http import UrlSafeStructCodec
from openeo_driver.util.logging import ExtraLoggingFilter
from openeo_driver.utils import generate_unique_id

Expand Down Expand Up @@ -139,8 +144,14 @@ def set_results_metadata(self, job_id: str, costs: Optional[float], usage: dict,
raise NotImplementedError

def list_user_jobs(
self, user_id: str, fields: Optional[List[str]] = None
) -> List[JobDict]:
self,
user_id: str,
*,
fields: Optional[List[str]] = None,
limit: Optional[int] = None,
request_parameters: Optional[dict] = None,
# TODO #332 settle on returning just `JobListing` and eliminate other options/code paths.
) -> Union[JobListing, List[JobDict]]:
"""
List all jobs of a user
Expand Down Expand Up @@ -172,6 +183,50 @@ def list_active_jobs(
raise NotImplementedError


def ejr_job_info_to_metadata(job_info: JobDict, full: bool = True) -> BatchJobMetadata:
    """
    Convert job info dict (from JobRegistryInterface) to BatchJobMetadata.

    :param job_info: job document; "job_id" and "status" are required keys,
        all other fields are looked up with `.get()` and may be missing.
    :param full: when False, "process" and "job_options" are left out (set to None).
    :return: BatchJobMetadata built from the given job info
    :raises KeyError: if "job_id" or "status" is missing from `job_info`
    """
    # TODO: make this a classmethod in a more appropriate place, e.g. JobRegistryInterface?

    # `dict.get(key, {})` only falls back to `{}` when the key is absent:
    # also guard against an explicit `"results_metadata": None` in the job document.
    results_metadata = job_info.get("results_metadata") or {}

    def map_safe(prop: str, f):
        # Truthiness based: falsy values (None, empty string, 0) give None without calling `f`.
        value = job_info.get(prop)
        return f(value) if value else None

    def get_results_metadata(result_metadata_prop: str):
        return results_metadata.get(result_metadata_prop)

    def map_results_metadata_safe(result_metadata_prop: str, f):
        # Only a missing/None value is skipped here (unlike `map_safe`).
        value = get_results_metadata(result_metadata_prop)
        return f(value) if value is not None else None

    def seconds_to_timedelta(seconds) -> datetime.timedelta:
        return datetime.timedelta(seconds=seconds)

    return BatchJobMetadata(
        id=job_info["job_id"],
        status=job_info["status"],
        created=map_safe("created", rfc3339.parse_datetime),
        process=job_info.get("process") if full else None,
        job_options=job_info.get("job_options") if full else None,
        title=job_info.get("title"),
        description=job_info.get("description"),
        updated=map_safe("updated", rfc3339.parse_datetime),
        started=map_safe("started", rfc3339.parse_datetime),
        finished=map_safe("finished", rfc3339.parse_datetime),
        memory_time_megabyte=map_safe("memory_time_megabyte_seconds", seconds_to_timedelta),
        cpu_time=map_safe("cpu_time_seconds", seconds_to_timedelta),
        geometry=get_results_metadata("geometry"),
        bbox=get_results_metadata("bbox"),
        start_datetime=map_results_metadata_safe("start_datetime", rfc3339.parse_datetime),
        end_datetime=map_results_metadata_safe("end_datetime", rfc3339.parse_datetime),
        instruments=get_results_metadata("instruments"),
        epsg=get_results_metadata("epsg"),
        links=get_results_metadata("links"),
        usage=job_info.get("usage"),
        costs=job_info.get("costs"),
        proj_shape=get_results_metadata("proj:shape"),
        proj_bbox=get_results_metadata("proj:bbox"),
    )


class EjrError(Exception):
"""Elastic Job Registry error (base class)."""

Expand Down Expand Up @@ -239,6 +294,8 @@ class ElasticJobRegistry(JobRegistryInterface):

_REQUEST_TIMEOUT = 20

PAGINATION_URL_PARAM = "page"

logger = logging.getLogger(f"{__name__}.elastic")

def __init__(
Expand Down Expand Up @@ -284,7 +341,9 @@ def _do_request(
self,
method: str,
path: str,
*,
json: Union[dict, list, None] = None,
params: Optional[dict] = None,
use_auth: bool = True,
expected_status: int = 200,
log_response_errors: bool = True,
Expand All @@ -298,14 +357,16 @@ def _do_request(
headers["Authorization"] = f"Bearer {access_token}"

url = url_join(self._api_url, path)
self.logger.debug(f"Doing EJR request `{method} {url}` {headers.keys()=}")
self.logger.debug(f"Doing EJR request `{method} {url}` {params=} {headers.keys()=}")
if self._debug_show_curl:
# TODO: add params to curl command
curl_command = self._as_curl(method=method, url=url, data=json, headers=headers)
self.logger.debug(f"Equivalent curl command: {curl_command}")
try:
do_request = lambda: self._session.request(
method=method,
url=url,
params=params,
json=json,
headers=headers,
timeout=self._REQUEST_TIMEOUT,
Expand Down Expand Up @@ -529,16 +590,60 @@ def _search(self, query: dict, fields: Optional[List[str]] = None) -> List[JobDi
fields = set(fields or [])
# Make sure to include some basic fields by default
fields.update(["job_id", "user_id", "created", "status", "updated"])
query = {
body = {
"query": query,
"_source": list(fields),
}
self.logger.debug(f"Doing search with query {json.dumps(body)}")
return self._do_request("POST", "/jobs/search", json=body, retry=True)

@dataclasses.dataclass(frozen=True)
class PaginatedSearchResult:
    """Immutable container for the outcome of a single paginated job search request."""

    # Job documents of the requested page (the "jobs" field of the search response).
    jobs: List[JobDict]
    # Pagination info (the "pagination" field of the search response).
    pagination: dict

def _search_paginated(
self,
query: dict,
*,
fields: Optional[List[str]] = None,
page_size: Optional[int] = None,
page_number: Optional[int] = None,
) -> PaginatedSearchResult:
fields = set(fields or [])
# Make sure to include some basic fields by default
# TODO: avoid duplication of this default field set
fields.update(["job_id", "user_id", "created", "status", "updated"])
params = {}
if page_size:
params["size"] = page_size
if page_number:
params["page"] = page_number
body = {
"query": query,
"_source": list(fields),
}
self.logger.debug(f"Doing search with query {json.dumps(query)}")
return self._do_request("POST", "/jobs/search", query, retry=True)
self.logger.debug(f"Doing search with query {json.dumps(body)} and {params=}")
response = self._do_request("POST", "/jobs/search/paginated", params=params, json=body, retry=True)
# Response structure:
# {
# "jobs": [list of job docs],
#     "pagination": {"previous": "size=5&page=1", "next": "size=5&page=3"}
# }
return self.PaginatedSearchResult(
jobs=response.get("jobs", []),
pagination=response.get("pagination", {}),
)

def list_user_jobs(
self, user_id: Optional[str], fields: Optional[List[str]] = None
) -> List[JobDict]:
self,
user_id: Optional[str],
*,
fields: Optional[List[str]] = None,
limit: Optional[int] = None,
request_parameters: Optional[dict] = None,
# TODO #332 settle on returning just `JobListing` and eliminate other options/code paths.
) -> Union[JobListing, List[JobDict]]:
query = {
"bool": {
"filter": [
Expand All @@ -547,7 +652,38 @@ def list_user_jobs(
]
}
}
return self._search(query=query, fields=fields)

if limit:
# Do paginated search
# TODO: make this the one and only code path
url_safe_codec = UrlSafeStructCodec(signature_field="_usc")
page_number = None
page_params = (request_parameters or {}).get(self.PAGINATION_URL_PARAM)
if page_params:
# Extract page number
page_params = url_safe_codec.decode(page_params)
if isinstance(page_params, str):
# TODO: this is the old code path where page params were encoded as a URL query string
# parse as URL params
page_params = urllib.parse.parse_qs(page_params)
assert limit == int(page_params["size"][0])
page_number = int(page_params["page"][0])
elif isinstance(page_params, dict):
# TODO: this dict handling should be (is?) the only code path?
assert limit == page_params["size"]
page_number = page_params["page"]
else:
raise ValueError(page_params)

data = self._search_paginated(query=query, fields=fields, page_size=limit, page_number=page_number)
return JobListing(
jobs=[ejr_job_info_to_metadata(j, full=False) for j in data.jobs],
next_parameters={self.PAGINATION_URL_PARAM: url_safe_codec.encode(data.pagination.get("next"))},
)
else:
# Deprecated non-paginated search
# TODO: eliminate this code path
return self._search(query=query, fields=fields)

def list_active_jobs(
self,
Expand Down Expand Up @@ -644,6 +780,12 @@ def _parse_cli(self) -> argparse.Namespace:
cli_list_user = subparsers.add_parser("list-user", help="List jobs for given user.")
cli_list_user.add_argument("--backend-id", help="Backend id to filter on.")
cli_list_user.add_argument("user_id", help="User id to filter on.")
cli_list_user.add_argument("--limit", type=int, default=None, help="Page size of job listing")
cli_list_user.add_argument(
"--pagination",
default=None,
help=f"Pagination URL fragment from 'next' link. E.g. '/jobs?{ElasticJobRegistry.PAGINATION_URL_PARAM}=InNpemU9N...'",
)
cli_list_user.set_defaults(func=self.list_user_jobs)

cli_create = subparsers.add_parser("create", help="Create a new job.")
Expand Down Expand Up @@ -738,10 +880,23 @@ def health_check(self, args: argparse.Namespace):
def list_user_jobs(self, args: argparse.Namespace):
user_id = args.user_id
ejr = self._get_job_registry(cli_args=args)
# TODO: option to return more fields?
jobs = ejr.list_user_jobs(user_id=user_id, fields=["started", "finished", "title"])
request_parameters = {}
if args.pagination:
request_parameters[ElasticJobRegistry.PAGINATION_URL_PARAM] = args.pagination.split("?")[-1]
jobs = ejr.list_user_jobs(
user_id=user_id,
# TODO: option to return more fields?
fields=["started", "finished", "title"],
limit=args.limit,
request_parameters=request_parameters,
)
print(f"Found {len(jobs)} jobs for user {user_id!r}:")
pprint.pp(jobs)
if isinstance(jobs, JobListing):
pprint.pp(jobs.to_response_dict(build_url=lambda d: "/jobs?" + urllib.parse.urlencode(d)))
elif isinstance(jobs, list):
pprint.pp(jobs)
else:
raise ValueError(jobs)

def list_active_jobs(self, args: argparse.Namespace):
ejr = self._get_job_registry(cli_args=args)
Expand Down
2 changes: 1 addition & 1 deletion openeo_driver/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ def list_jobs(user: User):
extra = {k: listing[k] for k in ["federation:missing"] if k in listing}
elif isinstance(listing, JobListing):
data = listing.to_response_dict(
url_for=lambda params: flask.url_for(".list_jobs", **params, _external=True),
build_url=lambda params: flask.url_for(".list_jobs", **params, _external=True),
api_version=requested_api_version(),
)
return flask.jsonify(data)
Expand Down
Loading

0 comments on commit 4875725

Please sign in to comment.