diff --git a/docs/source/examples/plot_dcca_multi.py b/docs/source/examples/plot_dcca_multi.py index c462cced..490a52e6 100644 --- a/docs/source/examples/plot_dcca_multi.py +++ b/docs/source/examples/plot_dcca_multi.py @@ -40,7 +40,7 @@ dcca_mcca = DCCA( latent_dimensions=LATENT_DIMS, encoders=[encoder_1, encoder_2], - objective=objectives.MCCA, + objective=objectives.MCCALoss, ) trainer_mcca = pl.Trainer(max_epochs=EPOCHS, enable_checkpointing=False, enable_model_summary=False,enable_progress_bar=False) trainer_mcca.fit(dcca_mcca, train_loader, val_loader) @@ -54,7 +54,7 @@ dcca_gcca = DCCA( latent_dimensions=LATENT_DIMS, encoders=[encoder_1, encoder_2], - objective=objectives.GCCA, + objective=objectives.GCCALoss, ) trainer_gcca = pl.Trainer(max_epochs=EPOCHS, enable_checkpointing=False, enable_model_summary=False,enable_progress_bar=False) trainer_gcca.fit(dcca_gcca, train_loader, val_loader)