Skip to content

Commit

Permalink
adding version works better for logs now, it is saved as an int rathe…
Browse files Browse the repository at this point in the history
…r than a string
  • Loading branch information
Samuel Musson committed Oct 26, 2023
1 parent c4fb49f commit 20081d5
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/molearn/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ def prepare_logs(self, log_filename, log_folder=None):
if log_folder is not None:
if not os.path.exists(log_folder):
os.mkdir(log_folder)
self.log_filename = log_folder+'/'+self.log_filename
if hasattr(self, "_repeat") and self._repeat >0:
self.log_filename = f'{log_folder}/{self._repeat}_{self.log_filename}'
else:
self.log_filename = f'{log_folder}/{self.log_filename}'


def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None, allow_grad_in_valid=False):
Expand All @@ -166,6 +169,7 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre
:param bool verbose: (default: None) set trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename
'''
self.get_repeat(checkpoint_folder)
self.prepare_logs(log_filename if log_filename is not None else self.log_filename, log_folder)
#if log_filename is not None:
# self.log_filename = log_filename
Expand Down Expand Up @@ -382,7 +386,6 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'
:param str checkpoint_folder: The folder in which to save the checkpoint.
:param str loss_key: (default: 'valid_loss') The key with which to get loss from valid_logs.
'''
self.get_repeat(checkpoint_folder)
valid_loss = valid_logs[loss_key]
if not os.path.exists(checkpoint_folder):
os.mkdir(checkpoint_folder)
Expand All @@ -394,11 +397,11 @@ def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'
'atoms': self._data.atoms,
'std': self.std,
'mean': self.mean},
f'{checkpoint_folder}/last{self._repeat}.ckpt')
f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt')

if self.best is None or self.best > valid_loss:
filename = f'{checkpoint_folder}/checkpoint{self._repeat}_epoch{epoch}_loss{valid_loss}.ckpt'
shutil.copyfile(f'{checkpoint_folder}/last{self._repeat}.ckpt', filename)
filename = f'{checkpoint_folder}/checkpoint{f"_{self._repeat}" if self._repeat>0 else ""}_epoch{epoch}_loss{valid_loss}.ckpt'
shutil.copyfile(f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt', filename)
if self.best is not None:
os.remove(self.best_name)
self.best_name = filename
Expand Down Expand Up @@ -441,14 +444,12 @@ def get_repeat(self, checkpoint_folder):
if not os.path.exists(checkpoint_folder):
os.mkdir(checkpoint_folder)
if not hasattr(self, '_repeat'):
_repeat = 0
self._repeat = f'_{_repeat}' if _repeat>0 else ''
self._repeat = 0
for i in range(1000):
if not os.path.exists(checkpoint_folder+f'/last{self._repeat}.ckpt'):
if not os.path.exists(checkpoint_folder+f'/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt'):
break#os.mkdir(checkpoint_folder)
else:
_repeat += 1
self._repeat = f'_{_repeat}' if _repeat>0 else ''
self._repeat += 1
else:
raise Exception('Something went wrong, you surely havnt done 1000 repeats?')

Expand Down

0 comments on commit 20081d5

Please sign in to comment.