diff --git a/n3fit/src/n3fit/hyper_optimization/hyper_scan.py b/n3fit/src/n3fit/hyper_optimization/hyper_scan.py
index 86a242fd8c..eea0425470 100644
--- a/n3fit/src/n3fit/hyper_optimization/hyper_scan.py
+++ b/n3fit/src/n3fit/hyper_optimization/hyper_scan.py
@@ -135,11 +135,16 @@ def hyper_scan_wrapper(replica_path_set, model_trainer, hyperscanner, max_evals=
     # Initialize seed for hyperopt
     trials.rstate = np.random.default_rng(HYPEROPT_SEED)
 
-    # For sequential hyperopt restarts, reset the state of `FileTrials` saved in the pickle file
-    if not hyperscanner.parallel_hyperopt and hyperscanner.restart_hyperopt:
-        pickle_file_to_load = f"{replica_path_set}/tries.pkl"
-        log.info("Restarting hyperopt run using the pickle file %s", pickle_file_to_load)
-        trials = FileTrials.from_pkl(pickle_file_to_load)
+    if hyperscanner.restart_hyperopt:
+        # For parallel hyperopt restarts, extract the database tar file
+        if hyperscanner.parallel_hyperopt:
+            log.info("Restarting hyperopt run using the MongoDB database %s", trials.db_name)
+            trials.extract_mongodb_database()
+        else:
+            # For sequential hyperopt restarts, reset the state of `FileTrials` saved in the pickle file
+            pickle_file_to_load = f"{replica_path_set}/tries.pkl"
+            log.info("Restarting hyperopt run using the pickle file %s", pickle_file_to_load)
+            trials = FileTrials.from_pkl(pickle_file_to_load)
 
     # Call to hyperopt.fmin
     fmin_args = dict(
@@ -154,6 +159,7 @@ def hyper_scan_wrapper(replica_path_set, model_trainer, hyperscanner, max_evals=
         trials.start_mongo_workers()
         best = hyperopt.fmin(**fmin_args, show_progressbar=True, max_queue_len=trials.num_workers)
         trials.stop_mongo_workers()
+        trials.compress_mongodb_database()
     else:
         best = hyperopt.fmin(**fmin_args, show_progressbar=False, trials_save_file=trials.pkl_file)
     return hyperscanner.space_eval(best)
diff --git a/n3fit/src/n3fit/hyper_optimization/mongofiletrials.py b/n3fit/src/n3fit/hyper_optimization/mongofiletrials.py
index 4ed0872f89..64f2a877f5 100644
--- a/n3fit/src/n3fit/hyper_optimization/mongofiletrials.py
+++ b/n3fit/src/n3fit/hyper_optimization/mongofiletrials.py
@@ -2,6 +2,7 @@
 Hyperopt trial object for parallel hyperoptimization with MongoDB.
 Data are fetched from MongoDB databases and stored in the form of json files within the nnfit folder
 """
+import glob
 import json
 import logging
 import os
@@ -100,6 +101,7 @@ def __init__(
         self._store_trial = False
         self._json_file = replica_path / "tries.json"
+        self.database_tar_file = replica_path / f"{self.db_name}.tar.gz"
         self._parameters = parameters
         self._rstate = None
         self._dynamic_trials = []
 
@@ -203,7 +205,39 @@ def stop_mongo_workers(self):
                 worker.terminate()
                 worker.wait()
                 log.info(f"Stopped mongo worker {self.workers.index(worker)+1}/{self.num_workers}")
-            except Exception as e:
+            except Exception as err:
                 log.error(
-                    f"Failed to stop mongo worker {self.workers.index(worker)+1}/{self.num_workers}: {e}"
+                    f"Failed to stop mongo worker {self.workers.index(worker)+1}/{self.num_workers}: {err}"
                 )
+
+    def compress_mongodb_database(self):
+        """Saves the MongoDB database as a tar file."""
+        # check that the database exists
+        if not os.path.exists(self.db_name) and not glob.glob('65*'):
+            raise FileNotFoundError(
+                f"The MongoDB database directory '{self.db_name}' does not exist. "
+                "Ensure it has been initialized correctly and is in your path."
+            )
+        # create the tar.gz file
+        try:
+            log.info(f"Compressing MongoDB database into {self.database_tar_file}")
+            subprocess.run(
+                ['tar', '-czvf', f'{self.database_tar_file}', f'{self.db_name}'] + glob.glob('65*'),
+                check=True,
+            )
+        except subprocess.CalledProcessError as err:
+            raise RuntimeError(f"Error compressing the database: {err}")
+
+    def extract_mongodb_database(self):
+        """Untars the MongoDB database for use in restarts."""
+        # check that the database tar file exists
+        if not os.path.exists(self.database_tar_file):
+            raise FileNotFoundError(
+                f"The MongoDB database tar file '{self.database_tar_file}' does not exist."
+            )
+        # extract the tar file
+        try:
+            log.info(f"Extracting MongoDB database from {self.database_tar_file}")
+            subprocess.run(['tar', '-xvf', f'{self.database_tar_file}'], check=True)
+        except subprocess.CalledProcessError as err:
+            raise RuntimeError(f"Error extracting the database: {err}")
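
Note: below is a minimal, self-contained sketch of the tar round-trip that
compress_mongodb_database/extract_mongodb_database perform, handy for testing the
archive step in isolation. It is not part of the patch; the names db_dir and
tar_path are hypothetical stand-ins for self.db_name and self.database_tar_file,
and the '65*' glob simply mirrors the one used above.

    import glob
    import subprocess
    from pathlib import Path

    def compress(db_dir: str, tar_path: Path) -> None:
        # Archive the database directory plus whatever '65*' matches in the
        # working directory. -z gzips the archive: plain -cvf would write an
        # uncompressed tar despite the .tar.gz name.
        members = [db_dir] + glob.glob('65*')
        subprocess.run(['tar', '-czvf', str(tar_path)] + members, check=True)

    def extract(tar_path: Path) -> None:
        # -x auto-detects gzip on read, so this also handles plain tars.
        subprocess.run(['tar', '-xvf', str(tar_path)], check=True)

    if __name__ == "__main__":
        # Hypothetical demo: archive a dummy database directory, then restore
        # it the way a restart would.
        Path("hyperopt-db").mkdir(exist_ok=True)
        compress("hyperopt-db", Path("hyperopt-db.tar.gz"))
        extract(Path("hyperopt-db.tar.gz"))

Since tar stores the members with relative paths, extraction must run from the
same working directory in which the database was created; this matches the
restart flow above, where extract_mongodb_database untars into the current
directory before the MongoDB workers are started.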