Skip to content

Commit

Permalink
Store Job and tomato version in NetCDF files. (#109)
Browse files Browse the repository at this point in the history
* Store job metadata in netcdf.

* Add metadata tests.

* Annotate groups with component metadata.

* tomato_Component

* More docs
  • Loading branch information
PeterKraus authored Nov 20, 2024
1 parent fc47ee2 commit b334c93
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 57 deletions.
18 changes: 14 additions & 4 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,23 @@ Each *job* stores its data and logs in its own *job* folder, which is a subfolde
Note that a *pipeline* dashboard functionality is planned for a future version of ``tomato``.


Final job data
**************
Final job data and metadata
***************************
By default, all data in the *job* folder is processed to create a NetCDF file. The NetCDF files can be read using :func:`xaray.open_datatree`, returning a :class:`xarray.DataTree`.

In the root node of the :class:`~xarray.DataTree`, a copy of the full *payload* is included, serialised as a json :class:`str`. Additionally, execution-specific metadata, such as the *pipeline* ``name``, and *job* submission/execution/completion time are stored on the root node, too.
In the root node of the :class:`~xarray.DataTree`, the :obj:`attrs` dictionary contains all **tomato**-relevant metadata. This currently includes:

The child nodes of the :class:`~xarray.DataTree` contain the actual data from each *pipeline* *component*, unit-annotated using the CF Metadata Conventions. The node names correspond to the ``role`` that *component* fullfils in a *pipeline*.
- ``tomato_version`` which is the version of **tomato** used to create the NetCDF file,
- ``tomato_Job`` which is the *job* object serialised as a json :class:`str`, containing the full *payload*, sample information, as well as *job* submission/execution/completion time.

The child nodes of the :class:`~xarray.DataTree` contain:

- the actual data from each *pipeline* *component*, unit-annotated using the CF Metadata Conventions. The node names correspond to the ``role`` that *component* fullfils in a *pipeline*.
- a ``tomato_Component`` entry in the :obj:`attrs` object, which is the *component* object serialised as a json :class:`str`, containing information about the *device* address and channel that define the *component*, the *driver* and *device* names, as well as the *component* capabilities.

.. note::

The ``tomato_Job`` and ``tomato_Component`` entries can be converted back to the source objects using :func:`tomato.models.Job.model_validate_json` and :func:`tomato.models.Component.model_validate_json`, respectively.

.. note::

Expand Down
4 changes: 3 additions & 1 deletion src/tomato/daemon/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,13 @@ def job(msg: dict, daemon: Daemon) -> Reply:
daemon.jobs[jobid] = Job(id=jobid, **msg.get("params", {}))
logger.info("received job %d", jobid)
daemon.nextjob += 1
ret = daemon.jobs[jobid]
else:
for k, v in msg.get("params", {}).items():
logger.debug("setting job parameter %s.%s to %s", jobid, k, v)
setattr(daemon.jobs[jobid], k, v)
cjob = daemon.jobs[jobid]
ret = cjob
if cjob.status in {"c"}:
daemon.jobs[jobid] = CompletedJob(
id=cjob.id,
Expand All @@ -246,7 +248,7 @@ def job(msg: dict, daemon: Daemon) -> Reply:
jobpath=cjob.jobpath,
respath=cjob.respath,
)
return Reply(success=True, msg="job updated", data=daemon.jobs[jobid])
return Reply(success=True, msg="job updated", data=ret)


def driver(msg: dict, daemon: Daemon) -> Reply:
Expand Down
21 changes: 16 additions & 5 deletions src/tomato/daemon/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import pickle
import logging
import xarray as xr
import importlib.metadata
from pathlib import Path
from tomato.models import Daemon
from tomato.models import Daemon, Job

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,22 +43,32 @@ def load(daemon: Daemon):
daemon.status = "running"


def merge_netcdfs(jobpath: Path, outpath: Path):
def merge_netcdfs(job: Job, snapshot=False):
"""
Merges the individual pickled :class:`xr.Datasets` of each Component found in
`jobpath` into a single :class:`xr.DataTree`, which is then stored in the NetCDF file,
Merges the individual pickled :class:`xr.Datasets` of each Component found in :obj:`job.jobpath`
into a single :class:`xr.DataTree`, which is then stored in the NetCDF file,
using the Component `role` as the group label.
"""
logger = logging.getLogger(f"{__name__}.merge_netcdf")
logger.debug("opening datasets")
datasets = []
for fn in jobpath.glob("*.pkl"):
logger.debug(f"{job=}")
logger.debug(f"{job.jobpath=}")
for fn in Path(job.jobpath).glob("*.pkl"):
with pickle.load(fn.open("rb")) as ds:
datasets.append(ds)
logger.debug("creating a DataTree from %d groups", len(datasets))
dt = xr.DataTree.from_dict({ds.attrs["role"]: ds for ds in datasets})
logger.debug(f"{dt=}")
root_attrs = {
"tomato_version": importlib.metadata.version("tomato"),
"tomato_Job": job.model_dump_json(),
}
dt.attrs = root_attrs
outpath = job.snappath if snapshot else job.respath
logger.debug("saving DataTree into '%s'", outpath)
dt.to_netcdf(outpath, engine="h5netcdf")
logger.debug(f"{dt=}")


def data_to_pickle(ds: xr.Dataset, path: Path, role: str):
Expand Down
47 changes: 27 additions & 20 deletions src/tomato/daemon/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
import psutil

from tomato.daemon.io import merge_netcdfs, data_to_pickle
from tomato.models import Pipeline, Daemon, Component, Device, Driver
from tomato.models import Pipeline, Daemon, Component, Device, Driver, Job
from dgbowl_schemas.tomato import to_payload
from dgbowl_schemas.tomato.payload import Payload, Task
from dgbowl_schemas.tomato.payload import Task

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,7 +111,7 @@ def manage_running_pips(daemon: Daemon, req):
proc = psutil.Process(pid=job.pid)
kill_tomato_job(proc)
logger.info(f"job {job.id} with pid {job.pid} was terminated successfully")
merge_netcdfs(Path(job.jobpath), Path(job.respath))
merge_netcdfs(job)
reset = True
params = dict(status="cd")
# dead jobs marked as running (status == 'r') should be cleared
Expand Down Expand Up @@ -259,7 +259,6 @@ def manager(port: int, timeout: int = 500):
def lazy_pirate(
pyobj: Any, retries: int, timeout: int, address: str, context: zmq.Context
) -> Any:
logger.debug("Here")
req = context.socket(zmq.REQ)
req.connect(address)
poller = zmq.Poller()
Expand Down Expand Up @@ -384,19 +383,27 @@ def tomato_job() -> None:
respath = outpath / f"{prefix}.nc"
snappath = outpath / f"snapshot.{jobid}.nc"
params = dict(respath=str(respath), snappath=str(snappath), jobpath=str(jobpath))
lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
ret = lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
if ret.success is False:
logger.error("could not set job status for unknown reason")
return 1
job: Job = ret.data

logger.info("handing off to 'job_main_loop'")
logger.info("==============================")
job_main_loop(context, args.port, payload, pip, jobpath, snappath, logpath)
job_main_loop(context, args.port, job, pip, logpath)
logger.info("==============================")

merge_netcdfs(jobpath, respath)
logger.info("job finished successfully")
job.completed_at = str(datetime.now(timezone.utc))
job.status = "c"

logger.info("job finished successfully, attempting to set status to 'c'")
params = dict(status="c", completed_at=str(datetime.now(timezone.utc)))
logger.info("writing final data to a NetCDF file")
merge_netcdfs(job)

logger.info("attempting to set job status to 'c'")
params = dict(status=job.status, completed_at=job.completed_at)
ret = lazy_pirate(pyobj=dict(cmd="job", id=jobid, params=params), **pkwargs)
logger.debug(f"{ret=}")
if ret.success is False:
logger.error("could not set job status for unknown reason")
return 1
Expand Down Expand Up @@ -438,7 +445,7 @@ def job_thread(

kwargs = dict(address=component.address, channel=component.channel)

datapath = jobpath / f"{component.role}.pkl"
datapath = Path(jobpath) / f"{component.role}.pkl"
logger.debug("distributing tasks:")
for task in tasks:
logger.debug(f"{task=}")
Expand All @@ -463,7 +470,9 @@ def job_thread(
ret = req.recv_pyobj()
if ret.success:
logger.debug("pickling received data")
data_to_pickle(ret.data, datapath, role=component.role)
ds = ret.data
ds.attrs["tomato_Component"] = component.model_dump_json()
data_to_pickle(ds, datapath, role=component.role)
t0 += device.pollrate

logger.debug("polling component '%s' for task completion", component.role)
Expand All @@ -489,10 +498,8 @@ def job_thread(
def job_main_loop(
context: zmq.Context,
port: int,
payload: Payload,
job: Job,
pipname: str,
jobpath: Path,
snappath: Path,
logpath: Path,
) -> None:
"""
Expand All @@ -507,7 +514,7 @@ def job_main_loop(

while True:
req.send_pyobj(dict(cmd="status", sender=sender))
daemon = req.recv_pyobj().data
daemon: Daemon = req.recv_pyobj().data
if all([drv.port is not None for drv in daemon.drvs.values()]):
break
else:
Expand All @@ -519,7 +526,7 @@ def job_main_loop(

# collate steps by role
plan = {}
for step in payload.method:
for step in job.payload.method:
if step.component_tag not in plan:
plan[step.component_tag] = []
plan[step.component_tag].append(step)
Expand All @@ -540,20 +547,20 @@ def job_main_loop(
logger.debug(" driver=%s", driver)
threads[component.role] = Thread(
target=job_thread,
args=(tasks, component, device, driver, jobpath, logpath),
args=(tasks, component, device, driver, job.jobpath, logpath),
name="job-thread",
)
threads[component.role].start()

# wait until threads join or we're killed
snapshot = payload.settings.snapshot
snapshot = job.payload.settings.snapshot
t0 = time.perf_counter()
while True:
logger.debug("tick")
tN = time.perf_counter()
if snapshot is not None and tN - t0 > snapshot.frequency:
logger.debug("creating snapshot")
merge_netcdfs(jobpath, snappath)
merge_netcdfs(job, snapshot=True)
t0 += snapshot.frequency
joined = [proc.is_alive() is False for proc in threads.values()]
if all(joined):
Expand Down
48 changes: 26 additions & 22 deletions src/tomato/driverinterface_1_0/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class Attr(BaseModel):

class ModelInterface(metaclass=ABCMeta):
"""
An abstract base class specifying the a driver interface.
An abstract base class specifying the driver interface.
Individual driver modules should expose a :class:`DriverInterface` which inherits
from this abstract class. Only the methods of this class should be used to interact
Expand All @@ -58,9 +58,13 @@ class ModelInterface(metaclass=ABCMeta):
class DeviceManager(metaclass=ABCMeta):
"""
An abstract base class specifying a manager for an individual component.
This class should handle determining attributes and capabilities of the component,
the reading/writing of those attributes, processing of tasks, and caching and
returning of task data.
"""

driver: super
driver: "ModelInterface"
"""The parent :class:`DriverInterface` instance."""

data: dict[str, list]
Expand All @@ -70,7 +74,7 @@ class DeviceManager(metaclass=ABCMeta):
"""Lock object for thread-safe data manipulation."""

key: tuple
"""The key in :obj:`driver.devmap` referring to this object."""
"""The key in :obj:`self.driver.devmap` referring to this object."""

thread: Thread
"""The worker :class:`Thread`."""
Expand Down Expand Up @@ -203,8 +207,8 @@ def reset(self, **kwargs) -> None:

def CreateDeviceManager(self, key, **kwargs):
"""
A factory function which is used to pass this :class:`ModelInterface` to the new
:class:`DeviceManager` instance.
A factory function which is used to pass this instance of the :class:`ModelInterface`
to the new :class:`DeviceManager` instance.
"""
return self.DeviceManager(self, key, **kwargs)

Expand Down Expand Up @@ -272,20 +276,6 @@ def dev_reset(self, key: tuple, **kwargs: dict) -> Reply:
msg=f"component {key!r} reset successfully",
)

@in_devmap
def attrs(self, key: tuple, **kwargs: dict) -> Reply:
"""
Query available :class:`Attrs` on the specified device component.
Pass-through to the :func:`DeviceManager.attrs` function.
"""
ret = self.devmap[key].attrs(**kwargs)
return Reply(
success=True,
msg=f"attrs of component {key!r} are: {ret}",
data=ret,
)

@in_devmap
def dev_set_attr(self, attr: str, val: Any, key: tuple, **kwargs: dict) -> Reply:
"""
Expand All @@ -307,7 +297,7 @@ def dev_get_attr(self, attr: str, key: tuple, **kwargs: dict) -> Reply:
Get value of the :class:`Attr` from the specified device component.
Pass-through to the :func:`DeviceManager.get_attr` function. Units are not
returned; those can be queried for all :class:`Attrs` using :func:`attrs`.
returned; those can be queried for all :class:`Attrs` using :func:`self.attrs`.
"""
ret = self.devmap[key].get_attr(attr=attr, **kwargs)
Expand All @@ -322,8 +312,8 @@ def dev_status(self, key: tuple, **kwargs: dict) -> Reply:
"""
Get the status report from the specified device component.
Iterates over all :class:`Attrs` on the component that have `status=True` and
returns their values in a :class:`dict`.
Iterates over all :class:`Attrs` on the component that have ``status=True`` and
returns their values in the :obj:`Reply.data` as a :class:`dict`.
"""
ret = {}
for k, attr in self.devmap[key].attrs(key=key, **kwargs).items():
Expand Down Expand Up @@ -474,3 +464,17 @@ def capabilities(self, key: tuple, **kwargs) -> Reply:
msg=f"capabilities supported by component {key!r} are: {ret}",
data=ret,
)

@in_devmap
def attrs(self, key: tuple, **kwargs: dict) -> Reply:
"""
Query available :class:`Attrs` on the specified device component.
Pass-through to the :func:`DeviceManager.attrs` function.
"""
ret = self.devmap[key].attrs(**kwargs)
return Reply(
success=True,
msg=f"attrs of component {key!r} are: {ret}",
data=ret,
)
7 changes: 4 additions & 3 deletions src/tomato/ketchup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from dgbowl_schemas.tomato import to_payload

from tomato.daemon.io import merge_netcdfs
from tomato.models import Reply, Daemon
from tomato.models import Reply, Daemon, Job

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -288,15 +288,16 @@ def snapshot(
Success: snapshot for job [3] created successfully
"""
jobs = status.data.jobs
jobs: list[Job] = status.data.jobs
for jobid in jobids:
if jobid not in jobs:
return Reply(success=False, msg=f"job {jobid} does not exist")
if jobs[jobid].status in {"q", "qw"}:
return Reply(success=False, msg=f"job {jobid} is still queued")

for jobid in jobids:
merge_netcdfs(Path(jobs[jobid].jobpath), Path(f"snapshot.{jobid}.nc"))
jobs[jobid].snappath = Path(f"snapshot.{jobid}.nc")
merge_netcdfs(jobs[jobid], snapshot=True)
if len(jobids) > 1:
msg = f"snapshot for jobs {jobids} created successfully"
else:
Expand Down
Loading

0 comments on commit b334c93

Please sign in to comment.