Skip to content

Commit

Permalink
update Supervised
Browse files Browse the repository at this point in the history
  • Loading branch information
chao chen committed Oct 25, 2021
1 parent f4065b1 commit 09f4ba8
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
LOG_DIR = 'log/'
MODEL_FILENAME = "model.ckpt"

DATASET_FOLDER = '../../benchmark_datasets/'
DATASET_FOLDER = '/home/cc/dm_data/'

# TRAIN
BATCH_NUM_QUERIES = 2
Expand All @@ -18,7 +18,7 @@
BASE_LEARNING_RATE = 0.000005
MOMENTUM = 0.9
OPTIMIZER = 'ADAM'
MAX_EPOCH = 20
MAX_EPOCH = 100

MARGIN_1 = 0.5
MARGIN_2 = 0.2
Expand Down
6 changes: 4 additions & 2 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
import importlib
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import sys
import torch
import torch.nn as nn
Expand Down Expand Up @@ -45,10 +45,12 @@ def evaluate():
print("ave_one_percent_recall:"+str(ave_one_percent_recall))


def evaluate_model(model,epoch,save=False):
def evaluate_model(model,optimizer,epoch,save=False):
if save:
torch.save({
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
}, cfg.LOG_DIR + "checkpoint.pth.tar")

#checkpoint = torch.load(cfg.LOG_DIR + "checkpoint.pth.tar")
Expand Down
2 changes: 1 addition & 1 deletion generating_queries/generate_training_tuples_cc_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
folders = []

# All runs are used for training (both full and partial)
index_list = range(10)
index_list = range(18)
print("Number of runs: "+str(len(index_list)))
for index in index_list:
folders.append(all_folders[index])
Expand Down
2 changes: 1 addition & 1 deletion train_pointnetvlad.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def train():
log_string('EVALUATING...')
cfg.OUTPUT_FILE = cfg.RESULTS_FOLDER + 'results_' + str(epoch) + '.txt'

eval_recall = evaluate.evaluate_model(model, epoch,True)
eval_recall = evaluate.evaluate_model(model,optimizer,epoch,True)
log_string('EVAL RECALL: %s' % str(eval_recall))

train_writer.add_scalar("Val Recall", eval_recall, epoch)
Expand Down

0 comments on commit 09f4ba8

Please sign in to comment.