-
Notifications
You must be signed in to change notification settings - Fork 1
/
configuration.py
145 lines (118 loc) · 8.75 KB
/
configuration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import argparse
# Command-line options for the MAML + semi-supervised learning experiments.
#
# This section only declares the options on the module-level ``arg_parser``;
# no ``parse_args`` call is made here.
#
# Help strings that quote a default use ``%(default)s`` so the text printed by
# ``--help`` always matches the real default value.
#
# The descriptive sentence is passed as ``description=``.  Passing it
# positionally would set ``prog`` (the program name shown in usage and error
# messages), which is not what the text is.
arg_parser = argparse.ArgumentParser(
    description='MAML + SSL (SMI, PL, VAT and FixMatch(later))')

# General
arg_parser.add_argument('--seed', type=int, default=123)
# NOTE(review): default GPU index 7 is machine-specific -- confirm it is
# intended for every host this runs on.
arg_parser.add_argument('--gpu_id', default=7, type=int,
                        help='GPU available. index of gpu, if <0 then use cpu')
# NOTE(review): the default is an absolute path on one user's machine
# (.../miniimagenet_test); callers elsewhere presumably override it.
arg_parser.add_argument('--data_folder', type=str,
                        default='/home/cxl173430/data/DATASETS/miniimagenet_test',
                        help='Path to the folder the data is downloaded to.')
arg_parser.add_argument('--output_folder', type=str, default="output/",
                        help='Path to the output folder to save the model.')
arg_parser.add_argument('--dataset', type=str, default='miniimagenet',
                        choices=['mnist', 'omniglot', 'miniimagenet',
                                 'tieredimagenet', 'cifarfs', 'svhn'],
                        help='Name of the dataset (default: %(default)s).')
arg_parser.add_argument('--ratio', type=float, default=0.01,
                        help='ratio of labeled for each class in the task.')

# SSL Experimental Setting
arg_parser.add_argument('--scenario', type=str, default="woDistractor",
                        choices=["woDistractor", "distractor", "random", "imbalance"],
                        help="Different SS FSL approaches, including subset selection and baselines")
# Values listed in the trailing comments below are the ones named by the
# original author; they are not enforced through ``choices``.
arg_parser.add_argument('--ssl_algo', type=str, default='PLtopZperClass')  # "PL", "VAT", "SMI", "PLtopZ", "PLtopZperClass"
arg_parser.add_argument('--selection_option', type=str, default='cross')  # "same", "cross", "union"
arg_parser.add_argument('--type_smi', type=str, default='vanilla')  # "vanilla", "rank", "gain"
arg_parser.add_argument('--ssl_algo_meta_test', type=str, default='mamlTestLargeS')  # "no", "mamlTestLargeS"
arg_parser.add_argument('--coef_inner', type=float, default=-1,
                        help='coefficient of ssl loss function in the inner loop.')
arg_parser.add_argument('--coef_outer', type=float, default=-1,
                        help='coefficient of ssl loss function in the outer loop.')

# SMI
arg_parser.add_argument("--sf", type=str, default="fl2mi")
arg_parser.add_argument("--budget_s", type=int, default=25)  # 30 for 1-shot, 50 for 5-shot, for support set
arg_parser.add_argument("--budget_q", type=int, default=75)  # for query set
# NOTE(review): the original trailing comment here also read "for query set",
# apparently copied from --budget_q; the option itself is a string.
arg_parser.add_argument("--embedding_type", type=str, default="gradients")

# Task definition (N-way, k-shot) and query-set sizes
arg_parser.add_argument('--num_ways', type=int, default=5,
                        help='Number of classes per task (N in "N-way").')
arg_parser.add_argument('--num_shots', type=int, default=1,
                        help='Number of examples per class for support set (k in "k-shot").')
arg_parser.add_argument('--num_shots_test_meta_train', type=int, default=15,
                        help='Number of examples per class for query set. If negative, same as `--num_shots`(default:15).')
arg_parser.add_argument('--num_shots_test_meta_test', type=int, default=15,
                        help='Number of examples per class for query set. If negative, same as `--num_shots`(default:15).')

# unlabeled part, for with and without distractor
# (author's suggested values: 200 for SVHN, 300 for MNIST, 20 for miniImagenet)
arg_parser.add_argument('--num_shots_unlabeled', type=int, default=50,
                        help='Number of unlabeled example per class.')
arg_parser.add_argument('--num_shots_unlabeled_evaluate', type=int, default=50,
                        help='Number of unlabeled example per class during meta-validation/test.')

# unlabeled part, Scenario: distractor.
arg_parser.add_argument('--num_classes_distractor', type=int, default=0,
                        help='Number of distractor classes.')
arg_parser.add_argument('--num_shots_distractor', type=int, default=0,
                        help='Number of unlabeled example per distractor class during meta-training.')
arg_parser.add_argument('--num_shots_distractor_eval', type=int, default=0,
                        help='Number of unlabeled example per distractor class during meta-validation/test.')

# unlabeled part, for random selection
# NOTE(review): the help text quotes dataset-specific "defaults" (600/900)
# while the declared default is 100 -- confirm whether they are applied
# elsewhere or are only recommendations.
arg_parser.add_argument('--num_unlabel_total', type=int, default=100,
                        help='Num of unlabeled examples totally for each task. (default: 600 for SVHN, 900 for MNIST).')
arg_parser.add_argument('--num_unlabel_total_evaluate', type=int, default=100,
                        help='Number of unlabeled examples totally for each task during meta-validation/test.')

# CNN Model
arg_parser.add_argument('--hidden_size', type=int, default=64,
                        help='Number of channels in each convolution layer of the VGG network (default: 64).')

# Optimization
arg_parser.add_argument('--first_order', action='store_true',
                        help='Use the first order approximation, do not use higher-order derivatives during meta-optimization.')
arg_parser.add_argument('--batch_size', type=int, default=1,
                        help='Number of tasks in a batch of tasks for meta-training.')
arg_parser.add_argument('--batch_size_val', type=int, default=1,
                        help='Number of tasks in a batch of tasks for meta-validation.')
arg_parser.add_argument('--batch_size_test', type=int, default=1,  # todo: check this later
                        help='Number of tasks in a batch of tasks for meta-test.')
arg_parser.add_argument('--num_epochs', type=int, default=400,
                        help='Number of epochs of meta-training (default: %(default)s).')
arg_parser.add_argument('--num_batches', type=int, default=100,
                        help='Number of batch of tasks per epoch (default: 100).')
# NOTE(review): help text is identical to --num_batches; presumably this is
# the evaluation-time count -- confirm against the caller.
arg_parser.add_argument('--num_batches_eval', type=int, default=100,
                        help='Number of batch of tasks per epoch (default: 100).')
arg_parser.add_argument('--step_size', type=float, default=0.001,
                        help='Size of the fast adaptation step, ie. learning rate in the '
                             'gradient descent update (default: %(default)s).')
arg_parser.add_argument('--swn_lr', type=float, default=0.001,
                        help='learning rate for swn (default 0.001).')
arg_parser.add_argument('--meta_lr', type=float, default=0.0001,
                        help='Learning rate for the meta-optimizer (optimization of the outer '
                             'loss). The default optimizer is Adam (default: %(default)s).')
arg_parser.add_argument('--num_steps', type=int, default=5,
                        help='Number of fast adaptation steps, ie. gradient descent updates (default: %(default)s).')
arg_parser.add_argument('--num_steps_evaluate', type=int, default=10,
                        help='Number of fast adaptation steps in valid/test, ie. gradient descent updates.')

# PL with threshold or top Z
arg_parser.add_argument('--pl_threshold', type=float, default=0,
                        help='The threshold used in the PL algorithm in the inner loop. (Default: 0)')
arg_parser.add_argument('--pl_threshold_outer', type=float, default=0,
                        help='The threshold used in the PL algorithm in the outer loop. (Default: 0)')
arg_parser.add_argument("--pl_num_topz", type=int, default=25,
                        help='The number of examples which have top Z probability logits in the inner loop')
arg_parser.add_argument("--pl_num_topz_outer", type=int, default=75,
                        help='The number of examples which have top Z probability logits in the outer loop')
arg_parser.add_argument("--pl_batch_size", type=int, default=80)

# Miscellaneous
arg_parser.add_argument('--num_workers', type=int, default=1,
                        help='Number of workers to use for data-loading (default: 1).')
arg_parser.add_argument('--verbose', action='store_true')

# debugging purpose
arg_parser.add_argument('--select_true_label', action='store_true')  # False: pl, True: true label
arg_parser.add_argument('--no_outer_selection', action='store_true',
                        help='whether outer loop has selection or not')
arg_parser.add_argument("--interval_val", type=int, default=1)
arg_parser.add_argument("--WARMSTART_EPOCH", type=int, default=100)
arg_parser.add_argument("--resume", action='store_true')

# Temporary options for LST
arg_parser.add_argument('--WARM_inner', type=int, default=1)
arg_parser.add_argument('--re_train_step', type=int, default=2)
arg_parser.add_argument('--WARM_inner_test', type=int, default=2)
arg_parser.add_argument('--re_train_step_test', type=int, default=5)
# NOTE(review): the default is spelled 'withOUTSelection' while the author's
# comment names 'withoutSelection'; kept byte-identical because consumers may
# compare against the exact string -- confirm which spelling is expected.
arg_parser.add_argument('--selection_option_LST', type=str, default='withOUTSelection')  # withoutSelection
arg_parser.add_argument('--inStepsSet', type=str, default='small')  # large
arg_parser.add_argument("--in_select_ty", type=str, default='continue')
arg_parser.add_argument('--num_unlabel_many', type=int, default=50)
arg_parser.add_argument('--num_unlabel_less', type=int, default=25)