Skip to content

Commit

Permalink
Merge pull request #70 from calico/testbranch
Browse files Browse the repository at this point in the history
Allow specification of validation loss check interval
  • Loading branch information
georgiaschmitt authored Oct 21, 2021
2 parents 0c03379 + fa7db6f commit 824a707
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions solo/solo.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def main():
batch_key = params.get("batch_key", None)
batch_size = params.get("batch_size", 128)
valid_pct = params.get("valid_pct", 0.1)
check_val_every_n_epoch = params.get("check_val_every_n_epoch", 5)
learning_rate = params.get("learning_rate", 1e-3)
stopping_params = {"patience": params.get("patience", 8), "min_delta": 0}

Expand Down Expand Up @@ -231,11 +232,10 @@ def main():
"lr_min": 1e-4,
"lr_scheduler_metric": "reconstruction_loss_validation",
}

vae.train(
max_epochs=2000,
validation_size=valid_pct,
check_val_every_n_epoch=5,
check_val_every_n_epoch=check_val_every_n_epoch,
plan_kwargs=plan_kwargs,
callbacks=scvi_callbacks,
)
Expand Down

0 comments on commit 824a707

Please sign in to comment.