-
Notifications
You must be signed in to change notification settings - Fork 8
/
model_args.py
51 lines (40 loc) · 1.29 KB
/
model_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# -------------------
# Model, data and training hyper-parameters, collected in one module-level
# ``args`` dict (presumably imported by the model / training code and read by
# key).
# NOTE(review): the "dyn_matrix_and_centralit_*" keys look like a typo for
# "centrality"; they are kept verbatim because callers look them up by these
# exact strings.
args = {}
args['use_cuda'] = True
args['encoder_size'] = 64
args['decoder_size'] = 128
# Input / output sequence lengths.
# NOTE(review): with the values set further down, in_length equals
# t_hist // skip_factor + 1 (16) and out_length equals t_fut // skip_factor
# (25); presumably these must be kept in sync -- confirm before changing one.
args['in_length'] = 16
args['out_length'] = 25
args['grid_size'] = (13, 3)
args['dyn_embedding_size'] = 32
args['dyn_matrix_and_centralit_input'] = 39
args['input_embedding_size'] = 32
args['num_lat_classes'] = 3
args['num_lon_classes'] = 3
args['train_flag'] = True
args['dyn_matrix_and_centralit_output'] = 32
# Dimensionality of the input:
#   2D (X and Y, or R and Theta)
#   3D (adds velocity as a third dimension)
args['input_dim'] = 3
# Whether the intention module is used.
args['intention_module'] = True
# Choose the pooling mechanism
# -----------------------------
# Each mechanism below adds its own extra hyper-parameters; any other value
# adds no extra keys.
args['pooling'] = 'polar'
if args['pooling'] == 'slstm':
    args['kernel_size'] = (4, 3)
elif args['pooling'] == 'cslstm':
    args['soc_conv_depth'] = 64
    args['conv_3x1_depth'] = 16
elif args['pooling'] in ('sgan', 'polar'):
    # 'sgan' and 'polar' share the same extra settings.
    args['bottleneck_dim'] = 256
    args['sgan_batch_norm'] = False
# Parameters for the ngsimDataset class (utils.py) and the HighdDataset
# class (utils_HighD.py).
args['t_hist'] = 30
args['t_fut'] = 50
args['skip_factor'] = 2  # d_s
# Number of pre-training and training epochs.
args['pretrainEpochs'] = 6
args['trainEpochs'] = 5
# Prediction horizon used in evaluation
args['pred_horiz'] = 5