diff --git a/config.py b/config.py
index 55289a2..deeaac1 100644
--- a/config.py
+++ b/config.py
@@ -7,7 +7,7 @@
 LOG_DIR = 'log/'
 MODEL_FILENAME = "model.ckpt"
 
-DATASET_FOLDER = '../../benchmark_datasets/'
+DATASET_FOLDER = '/home/cc/dm_data/'
 
 # TRAIN
 BATCH_NUM_QUERIES = 2
@@ -18,7 +18,7 @@
 BASE_LEARNING_RATE = 0.000005
 MOMENTUM = 0.9
 OPTIMIZER = 'ADAM'
-MAX_EPOCH = 20
+MAX_EPOCH = 100
 
 MARGIN_1 = 0.5
 MARGIN_2 = 0.2
diff --git a/evaluate.py b/evaluate.py
index 0b29a86..2a822b1 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -4,7 +4,7 @@
 import socket
 import importlib
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "2"
 import sys
 import torch
 import torch.nn as nn
@@ -45,10 +45,12 @@ def evaluate():
     print("ave_one_percent_recall:"+str(ave_one_percent_recall))
 
 
-def evaluate_model(model,epoch,save=False):
+def evaluate_model(model,optimizer,epoch,save=False):
     if save:
         torch.save({
             'state_dict': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'epoch': epoch,
         }, cfg.LOG_DIR + "checkpoint.pth.tar")
 
     #checkpoint = torch.load(cfg.LOG_DIR + "checkpoint.pth.tar")
diff --git a/generating_queries/generate_training_tuples_cc_baseline.py b/generating_queries/generate_training_tuples_cc_baseline.py
index 2ce0ba6..4bbd7ee 100644
--- a/generating_queries/generate_training_tuples_cc_baseline.py
+++ b/generating_queries/generate_training_tuples_cc_baseline.py
@@ -27,7 +27,7 @@
 folders = []
 
 # All runs are used for training (both full and partial)
-index_list = range(10)
+index_list = range(18)
 print("Number of runs: "+str(len(index_list)))
 for index in index_list:
     folders.append(all_folders[index])
diff --git a/train_pointnetvlad.py b/train_pointnetvlad.py
index 5c9202a..eff2090 100644
--- a/train_pointnetvlad.py
+++ b/train_pointnetvlad.py
@@ -202,7 +202,7 @@ def train():
 
         log_string('EVALUATING...')
         cfg.OUTPUT_FILE = cfg.RESULTS_FOLDER + 'results_' + str(epoch) + '.txt'
-        eval_recall = evaluate.evaluate_model(model, epoch,True)
+        eval_recall = evaluate.evaluate_model(model,optimizer,epoch,True)
         log_string('EVAL RECALL: %s' % str(eval_recall))
 
         train_writer.add_scalar("Val Recall", eval_recall, epoch)
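
With this change, the checkpoint written by evaluate_model(..., save=True) carries the optimizer state and the epoch counter alongside the model weights, which is what makes resuming training possible. Below is a minimal sketch of how a resume could look; the checkpoint keys ('state_dict', 'optimizer', 'epoch'), the path cfg.LOG_DIR + "checkpoint.pth.tar", and cfg.BASE_LEARNING_RATE come from the diff, while build_model() is a hypothetical stand-in for however train_pointnetvlad.py actually constructs the network.

    # Sketch only: resume training from the checkpoint saved by evaluate_model(..., save=True).
    import torch
    import config as cfg

    model = build_model()  # hypothetical: replace with the training script's model construction
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.BASE_LEARNING_RATE)

    checkpoint = torch.load(cfg.LOG_DIR + "checkpoint.pth.tar")
    model.load_state_dict(checkpoint['state_dict'])        # restore network weights
    optimizer.load_state_dict(checkpoint['optimizer'])      # restore optimizer state (e.g. Adam moments)
    starting_epoch = checkpoint['epoch'] + 1                 # continue from the epoch after the saved one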