diff --git a/examples/analysis_example.py b/examples/analysis_example.py index aec3673..9e9cb70 100644 --- a/examples/analysis_example.py +++ b/examples/analysis_example.py @@ -7,6 +7,10 @@ def main(): + + # Note: running the code below within a function is necessary to ensure that + # multiprocessing (used to calculate DOPE and Ramachandran) runs correctly + print("> Loading network parameters...") fname = f"xbb_foldingnet_checkpoints{os.sep}checkpoint_no_optimizer_state_dict_epoch167_loss0.003259085263643.ckpt" @@ -15,8 +19,9 @@ def main(): checkpoint = torch.load(fname, map_location=device) net = AutoEncoder(**checkpoint["network_kwargs"]) net.load_state_dict(checkpoint["model_state_dict"]) + + # the network is currently on CPU. If GPU is available, move it there if torch.cuda.is_available(): - # otherwise net is still not on the GPU net.to(device) print("> Loading training data...")