-
Notifications
You must be signed in to change notification settings - Fork 0
/
nn_testonly.py
76 lines (58 loc) · 2.47 KB
/
nn_testonly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import yaml
import os
from PIL import Image
from tqdm import tqdm
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from nn_models import *
from nn_dataset import *
from nn_training import *
from nn_acc import *
from performance import *
import shutil
if __name__ == "__main__":
    # Test-only evaluation: load a previously trained encoder/decoder model,
    # run it on the test split of a (possibly different) dataset, and save
    # result / performance reports whose file names carry `tag`.
    #
    # Experiment note: different splitting yields different results
    # (test score more similar to train).

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fixed seed so repeated runs are reproducible (presumably this also makes
    # split_dataset deterministic if it uses torch's global RNG -- TODO confirm).
    torch.manual_seed(0)

    # Suffix for output files: results_<tag>.yml, performance_<tag>.yml/.pdf.
    tag = "normal_model_with_dropout_on_distortion"

    # Dataset to evaluate on (alternative kept for quick switching).
    dataset_name = "DAGDatasetDistorted100_100_5"
    # dataset_name = "DAGDataset100_100_5"
    if not os.path.exists(f"./datasets/{dataset_name}"):
        raise FileNotFoundError(f"Dataset {dataset_name} not found")

    # Trained model to evaluate (alternatives kept for quick switching).
    # decoder = "Building Mass Decoder"
    model_name = "model_DAGDataset300_100_5_20240402232909"
    # model_name = "model_DAGDatasetDistorted100_100_5_20240403191926"
    weights_path = f"./models/{model_name}.pth"
    # Fail fast, like the dataset check above, instead of only discovering a
    # missing checkpoint after the dataset and dataloaders have been built.
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Model weights {weights_path} not found")

    # Load metadata
    # ranges, parameter_output_mapping, decoders, switches, batch_cam_angles = load_metadata_for_inference(f"./models/{model_name}_meta.yml", need_full=True, decoder=decoder)
    ranges, parameter_output_mapping, decoders, switches, batch_cam_angles = load_metadata_for_inference(
        f"./models/{model_name}_meta.yml", need_full=True
    )

    # Load the dataset
    # dataset = DAGDatasetSingleDecoder(decoder, dataset_name)
    dataset = DAGDataset(dataset_name)
    # Only test_loader is used below. The positional fractions are presumably
    # (train, val, test), i.e. ~80% of the data is held out for testing --
    # TODO confirm against split_dataset.
    train_dataset, val_dataset, test_dataset = split_dataset(dataset, 0.1, 0.1, 0.8)
    train_loader, val_loader, test_loader = create_dataloaders_of(
        train_dataset, val_dataset, test_dataset, batch_size=32
    )

    # Load the model. Reuse weights_path so the file checked above is exactly
    # the file that gets loaded (previously the path string was rebuilt inline).
    encoder = Encoder()
    model = EncoderDecoderModel(encoder, decoders)
    # model = ManualEncoderDecoderModelBM()
    # NOTE(review): torch.load unpickles the checkpoint, so only load trusted
    # files; on torch >= 1.13 consider weights_only=True for a plain state dict.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    model.to(device)

    # Loss function
    criterion = EncDecsLoss(decoders, switches)
    # criterion = custom_loss

    # Inference
    results_name = f"results_{tag}.yml"
    test(model, weights_path, test_loader, criterion, ranges, results_save_path=results_name)
    acc_discrete(results_name)
    # Assumes acc_discrete() writes ./performance.yml and calculate_performance()
    # writes ./performance.pdf under fixed names; each is copied to a
    # tag-specific name so reports from differently tagged runs are kept apart.
    shutil.copyfile("./performance.yml", f"./performance_{tag}.yml")
    calculate_performance()
    shutil.copyfile("./performance.pdf", f"./performance_{tag}.pdf")