Skip to content

Commit

Permalink
added flag to save json log file and improved trajectory preparation
Browse files Browse the repository at this point in the history
  • Loading branch information
gwirn committed Jun 26, 2024
1 parent e3b651d commit cd6f347
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
35 changes: 26 additions & 9 deletions src/molearn/data/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def read_traj(self) -> None:
top0 = None
ucell0 = None
for ct, (trp, top) in enumerate(zip(self.traj_path, self.topo_path)):
print(f"\tLoading {os.path.basename(trp)}")
if self.verbose:
print(f"\tLoading {os.path.basename(trp)}")
loaded = None
try:
# do not enforce a topology file for these formats
Expand All @@ -171,24 +172,38 @@ def read_traj(self) -> None:
loaded.unitcell_vectors = ucell0
multi_traj.append(loaded)
self.traj = md.join(multi_traj)
# Recenter and apply periodic boundary
if self.image_mol:
try:
if self.verbose:
print("Imaging faild - retrying with supplying anchor molecules")
self.traj.image_molecules(inplace=True)
except ValueError:
try:
self.traj.image_molecules(
inplace=True,
anchor_molecules=[set(self.traj.topology.residue(0).atoms)],
)
except ValueError as e:
print(
f"Unable to image molecule due to '{e}' - will just recenter it"
)
self.traj.superpose(self.traj[0])
# maybe not needed after image_molecules
self.traj.center_coordinates()
# converts ELEMENT names from eg "Cd" -> "C" to avoid later complications
topo_table, topo_bonds = self.traj.topology.to_dataframe()
topo_table["element"] = topo_table["element"].apply(
lambda x: x if len(x.strip()) <= 1 else x.strip()[0]
)
if self.verbose:
print("Saving new topology")
self.traj.topology = md.Topology.from_dataframe(topo_table, topo_bonds)
# save new topology
self.traj[0].save_pdb(
os.path.join(self.outpath, f"./{self.traj_name}_NEW_TOPO.pdb")
)
# Recenter and apply periodic boundary
if self.image_mol:
try:
self.traj.image_molecules(inplace=True)
except ValueError as e:
print(f"Unable to image molecule due to '{e}' - will just recenter it")
# maybe not needed after image_molecules
self.traj.center_coordinates()

n_frames = self.traj.n_frames
# index that separates the training dataset from the test dataset
self.test_border = int(n_frames * (1.0 - self.test_size))
Expand All @@ -205,6 +220,8 @@ def read_traj(self) -> None:
atom_indices = [
a.index for a in train_traj.topology.atoms if a.element.symbol != "H"
]
if self.verbose:
print("Calculating disance matrix")
# distance matrix between all frames
self.traj_dists = np.empty((n_train_frames, n_train_frames))
for i in range(n_train_frames):
Expand Down
29 changes: 18 additions & 11 deletions src/molearn/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ class Trainer:
"""

def __init__(self, device=None, log_filename="log_file.dat"):
def __init__(self, device=None, log_filename="log_file.dat", json_log=False):
"""
:param torch.Device device: if not given, it will be determined automatically based on torch.cuda.is_available()
:param str log_filename: (default: 'default_log_filename.json') file used to log outputs to
:param bool json_log: True to use json.dump to save the log file
"""
if not device:
self.device = (
Expand All @@ -52,6 +53,7 @@ def __init__(self, device=None, log_filename="log_file.dat"):
self.verbose = True
self.log_filename = "default_log_filename.csv"
self.scheduler_key = None
self.json_log = json_log

def get_network_summary(self):
"""
Expand Down Expand Up @@ -151,17 +153,22 @@ def log(self, log_dict, verbose=None):
print(f"{k: <{max_key_len+1}}: {v:.6f}")
print()

# create header if file doesn't exist => first epoch
if not os.path.isfile(self.log_filename):
with open(self.log_filename, "a") as f:
f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n")
if not self.json_log:
# create header if file doesn't exist => first epoch
if not os.path.isfile(self.log_filename):
with open(self.log_filename, "a") as f:
f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n")

with open(self.log_filename, "a") as f:
# just try to format if it is not a Failure
if "Failure" not in log_dict.values():
f.write(f"{','.join([str(v) for v in log_dict.values()])}\n")
else:
dump = json.dumps(log_dict)
with open(self.log_filename, "a") as f:
# just try to format if it is not a Failure
if "Failure" not in log_dict.values():
f.write(f"{','.join([str(v) for v in log_dict.values()])}\n")
else:
dump = json.dumps(log_dict)
f.write(dump + "\n")
else:
dump = json.dumps(log_dict)
with open(self.log_filename, "a") as f:
f.write(dump + "\n")

def scheduler_step(self, logs):
Expand Down

0 comments on commit cd6f347

Please sign in to comment.