Skip to content

Commit

Permalink
Refactor and move the run/execution scripts to the train.py and test.…
Browse files Browse the repository at this point in the history
…py modules
  • Loading branch information
vbelis committed Aug 30, 2022
1 parent 7605037 commit 675c3d2
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 199 deletions.
50 changes: 0 additions & 50 deletions kernel_machines/run_testing

This file was deleted.

147 changes: 0 additions & 147 deletions kernel_machines/run_training

This file was deleted.

55 changes: 54 additions & 1 deletion kernel_machines/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import numpy as np

import util
import argparse
import data_processing
from terminal_enhancer import tcols


def main(args):
def main(args: dict):
_, test_loader = data_processing.get_data(args)
test_features, test_labels = test_loader[0], test_loader[1]
sig_fold, bkg_fold = data_processing.get_kfold_data(
Expand Down Expand Up @@ -37,3 +38,55 @@ def main(args):
)
np.save(output_path + "sig_scores.npy", score_sig)
np.save(output_path + "bkg_scores.npy", score_bkg)

def get_arguments() -> dict:
"""
Parses command line arguments and gives back a dictionary.
Returns: Dictionary with the arguments
"""

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--sig_path",
type=str,
required=True,
help="Path to the signal/anomaly dataset (.h5 format).",
)
parser.add_argument(
"--bkg_path",
type=str,
required=True,
help="Path to the QCD background dataset (.h5 format).",
)
parser.add_argument(
"--test_bkg_path",
type=str,
required=True,
help="Path to the background testing dataset (.h5 format).",
)
parser.add_argument(
"--model", type=str, required=True, help="The folder path of the QSVM model."
)
parser.add_argument(
"--ntest", type=int, default=720, help="Number of test events for the QSVM."
)
parser.add_argument(
"--kfolds", type=int, default=5, help="Number of k-validation/test folds used."
)
args = parser.parse_args()

args = {
"sig_path": args.sig_path,
"bkg_path": args.bkg_path,
"test_bkg_path": args.test_bkg_path,
"model": args.model,
"ntest": args.ntest,
"kfolds": args.kfolds,
}
return args


if __name__ == "__main__":
args = get_arguments()
main(args)

Loading

0 comments on commit 675c3d2

Please sign in to comment.