Skip to content

Commit

Permalink
Added initial MongoFileTrials methods to allow for restarts
Browse files Browse the repository at this point in the history
  • Loading branch information
Cmurilochem committed Feb 21, 2024
1 parent 2cecb0a commit 08e16c4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 7 deletions.
16 changes: 11 additions & 5 deletions n3fit/src/n3fit/hyper_optimization/hyper_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,16 @@ def hyper_scan_wrapper(replica_path_set, model_trainer, hyperscanner, max_evals=
# Initialize seed for hyperopt
trials.rstate = np.random.default_rng(HYPEROPT_SEED)

# For sequential hyperopt restarts, reset the state of `FileTrials` saved in the pickle file
if not hyperscanner.parallel_hyperopt and hyperscanner.restart_hyperopt:
pickle_file_to_load = f"{replica_path_set}/tries.pkl"
log.info("Restarting hyperopt run using the pickle file %s", pickle_file_to_load)
trials = FileTrials.from_pkl(pickle_file_to_load)
if hyperscanner.restart_hyperopt:
# For parallel hyperopt restarts, extract the database tar file
if hyperscanner.parallel_hyperopt:
log.info("Restarting hyperopt run using the MongoDB database %s", trials.db_name)
trials.extract_mongodb_database()
else:
# For sequential hyperopt restarts, reset the state of `FileTrials` saved in the pickle file
pickle_file_to_load = f"{replica_path_set}/tries.pkl"
log.info("Restarting hyperopt run using the pickle file %s", pickle_file_to_load)
trials = FileTrials.from_pkl(pickle_file_to_load)

# Call to hyperopt.fmin
fmin_args = dict(
Expand All @@ -154,6 +159,7 @@ def hyper_scan_wrapper(replica_path_set, model_trainer, hyperscanner, max_evals=
trials.start_mongo_workers()
best = hyperopt.fmin(**fmin_args, show_progressbar=True, max_queue_len=trials.num_workers)
trials.stop_mongo_workers()
trials.compress_mongodb_database()
else:
best = hyperopt.fmin(**fmin_args, show_progressbar=False, trials_save_file=trials.pkl_file)
return hyperscanner.space_eval(best)
Expand Down
38 changes: 36 additions & 2 deletions n3fit/src/n3fit/hyper_optimization/mongofiletrials.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Hyperopt trial object for parallel hyperoptimization with MongoDB.
Data are fetched from MongoDB databases and stored in the form of json files within the nnfit folder
"""
import glob
import json
import logging
import os
Expand Down Expand Up @@ -100,6 +101,7 @@ def __init__(

self._store_trial = False
self._json_file = replica_path / "tries.json"
self.database_tar_file = replica_path / f"{self.db_name}.tar.gz"
self._parameters = parameters
self._rstate = None
self._dynamic_trials = []
Expand Down Expand Up @@ -203,7 +205,39 @@ def stop_mongo_workers(self):
worker.terminate()
worker.wait()
log.info(f"Stopped mongo worker {self.workers.index(worker)+1}/{self.num_workers}")
except Exception as e:
except Exception as err:
log.error(
f"Failed to stop mongo worker {self.workers.index(worker)+1}/{self.num_workers}: {e}"
f"Failed to stop mongo worker {self.workers.index(worker)+1}/{self.num_workers}: {err}"
)

def compress_mongodb_database(self):
    """Save the MongoDB database as a gzipped tar file.

    The archive is written to ``self.database_tar_file`` and contains the
    database directory ``self.db_name`` plus any files matching ``65*`` in
    the current working directory (presumably MongoDB journal/metadata
    files — TODO confirm against the database layout).

    Raises
    ------
    FileNotFoundError
        If neither the database directory nor any ``65*`` files exist.
    RuntimeError
        If the ``tar`` command fails.
    """
    # Check that there is something to archive before invoking tar.
    # NOTE: the previous version passed the whole boolean expression into
    # os.path.exists() due to a misplaced parenthesis; the intended check
    # is on the directory itself (and the auxiliary 65* files).
    if not os.path.exists(f"{self.db_name}") and not glob.glob('65*'):
        raise FileNotFoundError(
            f"The MongoDB database directory '{self.db_name}' does not exist. "
            "Ensure it has been initiated correctly and it is in your path."
        )
    # create the tar.gz file; `-z` actually gzip-compresses the archive,
    # matching the `.tar.gz` name (extraction auto-detects compression)
    try:
        log.info(f"Compressing MongoDB database into {self.database_tar_file}")
        subprocess.run(
            ['tar', '-czvf', f'{self.database_tar_file}', f'{self.db_name}'] + glob.glob('65*'),
            check=True,
        )
    except subprocess.CalledProcessError as err:
        raise RuntimeError(f"Error compressing the database: {err}") from err

def extract_mongodb_database(self):
    """Untar the MongoDB database archive for use in restarts.

    Extracts ``self.database_tar_file`` into the current working
    directory, recreating the database state saved previously by the
    compress step. ``tar -x`` auto-detects gzip compression on read, so
    both compressed and plain archives are handled.

    Raises
    ------
    FileNotFoundError
        If the tar file does not exist.
    RuntimeError
        If the ``tar`` command fails.
    """
    # check that the database tar file exists before calling tar
    if not os.path.exists(f"{self.database_tar_file}"):
        raise FileNotFoundError(
            f"The MongoDB database tar file '{self.database_tar_file}' does not exist."
        )
    # extract the archive in the current working directory
    try:
        log.info(f"Extracting MongoDB database from {self.database_tar_file}")
        subprocess.run(['tar', '-xvf', f'{self.database_tar_file}'], check=True)
    except subprocess.CalledProcessError as err:
        # chain the cause explicitly so the tar failure is preserved
        raise RuntimeError(f"Error extracting the database: {err}") from err

0 comments on commit 08e16c4

Please sign in to comment.