From 6806e2c5a79a59459c1005345d26a3491ec5c33f Mon Sep 17 00:00:00 2001 From: sengineer0 Date: Tue, 7 Mar 2023 00:03:22 +0700 Subject: [PATCH] Add dry_run option for mark datasource as dump success --- biothings/hub/dataload/dumper.py | 48 +++++++++++++++++++------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/biothings/hub/dataload/dumper.py b/biothings/hub/dataload/dumper.py index 6022fec06..b49b5b2b8 100644 --- a/biothings/hub/dataload/dumper.py +++ b/biothings/hub/dataload/dumper.py @@ -12,6 +12,7 @@ import subprocess import time from concurrent.futures import ProcessPoolExecutor +from copy import deepcopy from datetime import datetime, timezone from functools import partial from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union @@ -268,7 +269,8 @@ def prepare_src_dump(self): self.src_dump = get_src_dump() self.src_doc = self.src_dump.find_one({'_id': self.src_name}) or {} - def register_status(self, status, transient=False, **extra): + def register_status(self, status, transient=False, dry_run=False, **extra): + src_doc = deepcopy(self.src_doc) try: # if status is "failed" and depending on where it failed, # we may not be able to get the new_data_folder (if dumper didn't reach @@ -281,20 +283,18 @@ def register_status(self, status, transient=False, **extra): # it has not been set by the dumper before while exploring # remote site. maybe we're just running post step ? # back-compatibility; use "release" at root level if not found under "download" - release = self.src_doc.get("download", {}).get("release") or self.src_doc.get( - "release" - ) + release = src_doc.get("download", {}).get("release") or src_doc.get("release") self.logger.error( "No release set, assuming: data_folder: %s, release: %s" % (data_folder, release) ) # make sure to remove old "release" field to get back on track for field in ["release", "data_folder"]: - if self.src_doc.get(field): + if src_doc.get(field): self.logger.warning( "Found '%s'='%s' at root level, convert to new format" - % (field, self.src_doc[field]) + % (field, src_doc[field]) ) - self.src_doc.pop(field) + src_doc.pop(field) current_download_info = { '_id': self.src_name, @@ -312,7 +312,7 @@ def register_status(self, status, transient=False, **extra): last_success = current_download_info["download"]["started_at"] else: # If failed, we will get the last_success from the last download instead. - last_download_info = self.src_doc.setdefault("download", {}) + last_download_info = src_doc.setdefault("download", {}) last_success = last_download_info.get("last_success", None) if not last_success and last_download_info.get("status") == 'success': # If last_success from the last download doesn't exist or is None, and last @@ -321,18 +321,22 @@ def register_status(self, status, transient=False, **extra): if last_success: current_download_info["download"]["last_success"] = last_success - self.src_doc.update(current_download_info) + src_doc.update(current_download_info) # only register time when it's a final state if transient: - self.src_doc["download"]["pid"] = os.getpid() + src_doc["download"]["pid"] = os.getpid() else: - self.src_doc["download"]["time"] = timesofar(self.t0) + src_doc["download"]["time"] = timesofar(self.t0) if "download" in extra: - self.src_doc["download"].update(extra["download"]) + src_doc["download"].update(extra["download"]) else: - self.src_doc.update(extra) - self.src_dump.save(self.src_doc) + src_doc.update(extra) + + # when dry run, we should not change the src_doc, and src_dump + if not dry_run: + self.src_doc = deepcopy(src_doc) + self.src_dump.save(src_doc) async def dump(self, steps=None, force=False, job_manager=None, check_only=False, **kwargs): ''' @@ -423,13 +427,18 @@ def postdumped(f): if self.client: self.release_client() - def mark_success(self): + def mark_success(self, dry_run=True): ''' Mark the datasource as successful dumped. It's useful in case the datasource is unstable, and need to be manually downloaded. ''' - self.register_status("success") + self.register_status("success", dry_run=dry_run) self.logger.info("Done!") + result = { + "_id": self.src_doc["_id"], + "download": self.src_doc["download"], + } + return result def get_predicates(self): """ @@ -1452,7 +1461,8 @@ def dump_src( logging.error("Error while dumping '%s': %s" % (src, e)) raise - def mark_success(self, src): + def mark_success(self, src, dry_run=True): + result = [] if src in self.register: klasses = self.register[src] else: @@ -1461,8 +1471,8 @@ def mark_success(self, src): ) for _, klass in enumerate(klasses): inst = self.create_instance(klass) - inst.mark_success() - + result.append(inst.mark_success(dry_run=dry_run)) + return result def call(self, src, method_name, *args, **kwargs): """