Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
hy395 committed Dec 24, 2024
1 parent e8ac5aa commit ba65d74
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/baskerville/scripts/hound_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def main():
seqnn_model.restore(args.restore, trunk=args.trunk)
else:
seqnn_model.restore(args.restore, pretrain=True)

# head params
print(
"params in new head: %d"
Expand Down Expand Up @@ -281,7 +281,7 @@ def main():

if args.skip_train:
exit(0)

# train model
if args.keras_fit:
seqnn_trainer.fit_keras(seqnn_model)
Expand Down
4 changes: 3 additions & 1 deletion src/baskerville/seqnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,9 @@ def restore(self, model_file, head_i=0, trunk=False, pretrain=False):
if trunk:
self.model_trunk.load_weights(model_file)
elif pretrain:
self.models[head_i].load_weights(model_file, by_name=True, skip_mismatch=True)
self.models[head_i].load_weights(
model_file, by_name=True, skip_mismatch=True
)
self.model = self.models[head_i]
else:
self.models[head_i].load_weights(model_file)
Expand Down

0 comments on commit ba65d74

Please sign in to comment.