-
Notifications
You must be signed in to change notification settings - Fork 1
/
testing.py
24 lines (21 loc) · 870 Bytes
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from utils import str_to_class
from tqdm import tqdm
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
def get_tester(test_loader, model, save_dir, device):
    """Resolve the ``Tester`` class by name and build an instance of it.

    The class is looked up at runtime with ``str_to_class('Tester', 'testing')``
    and called positionally with the four arguments below.

    Args:
        test_loader: Forwarded to the tester constructor.
        model: Forwarded to the tester constructor.
        save_dir: Forwarded to the tester constructor.
        device: Forwarded to the tester constructor.

    Returns:
        The constructed tester, or ``None`` if the constructed object is falsy.
    """
    tester_cls = str_to_class('Tester', 'testing')
    tester = tester_cls(test_loader, model, save_dir, device)
    if not tester:
        return None
    return tester
class Tester:
    """Run a model over a test loader and pass each prediction to its visualizer.

    Args:
        test_loader: Iterable of model inputs. When it supports ``len()``, the
            length is used as the progress-bar total.
        model: Callable invoked as ``model(input_data)``. It must also expose
            ``backbone_visualization(input_data=..., predicted=..., save_dir=...,
            idx=...)``.
        save_dir: Forwarded unchanged to ``model.backbone_visualization``.
        device: Stored as ``self.device``; not used by ``run``.
    """

    def __init__(self, test_loader, model, save_dir, device):
        self.device = device
        self.save_dir = save_dir
        self.test_loader = test_loader
        self.model = model
        # NOTE(review): this writer is never used inside this class and is never
        # closed. With default arguments, tensorboardX presumably creates a new
        # run directory under ./runs on construction. Confirm whether any caller
        # relies on ``tester.writer`` before removing it.
        self.writer = SummaryWriter()

    def run(self):
        """Predict on every item of ``self.test_loader`` and visualize the result.

        For each item, ``self.model(input_data)`` is called. The input, the
        prediction, ``self.save_dir`` and the item's zero-based index are then
        forwarded to ``self.model.backbone_visualization``.

        Returns:
            None.
        """
        # Wrap the loader itself, not the enumerate object: enumerate has no
        # __len__, so tqdm(enumerate(...)) could never show a total or ETA.
        for idx, input_data in enumerate(tqdm(self.test_loader)):
            # NOTE(review): inference runs without torch.no_grad() and without
            # model.eval(). If model is an nn.Module, confirm this is intended.
            # Inputs are also not moved to self.device here; presumably the
            # loader or the model handles placement.
            out = self.model(input_data)
            self.model.backbone_visualization(
                input_data=input_data,
                predicted=out,
                save_dir=self.save_dir,
                idx=idx,
            )