-
Notifications
You must be signed in to change notification settings - Fork 0
/
option.py
105 lines (71 loc) · 2.85 KB
/
option.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# coding=utf-8
class Option(object):
    """Hyper-parameter bundle shared by the embedding models.

    Apart from ``model``, ``vocab_size``, ``dp_ratio`` and (for doc2vecc)
    ``sentence_sample``, every attribute is copied verbatim from an attribute
    of ``config_holder``.

    Args:
        config_holder: configuration object. It must expose ``embed_dim``,
            ``window``, ``num_neg``, ``alpha``, ``epochs``, ``batch_size``,
            ``countThreshold``, ``train_method``, ``sample``,
            ``stat_interval``, ``check_interval``, ``loss``, ``concat`` and,
            unless ``model == "doc2vecc"``, ``ss``.
        vocab: vocabulary; only ``len(vocab)`` is used.
        model: name of the model these options belong to.
    """

    def __init__(self, config_holder, vocab, model="unknown"):
        cfg = config_holder
        self.model = model

        self.embed_dim = cfg.embed_dim        # word embedding dimension
        self.embed_dim_doc = cfg.embed_dim    # doc embedding: same config value
        self.window_size = cfg.window         # context window size
        self.num_samples = cfg.num_neg        # negatives per example
        self.learning_rate = cfg.alpha        # initial learning rate
        self.epochs_to_train = cfg.epochs     # number of training epochs
        self.negative_sample_size = cfg.num_neg  # alias of num_samples
        self.batch_size = cfg.batch_size      # examples per training step
        self.min_count = cfg.countThreshold   # min word count to enter vocab
        self.train_method = cfg.train_method  # training algorithm
        self.subsample = cfg.sample           # subsampling threshold
        self.vocab_size = len(vocab)
        self.statistics_interval = cfg.stat_interval  # stats print period
        self.checkpoint_interval = cfg.check_interval  # checkpoint period
        self.loss = cfg.loss                  # loss used by the model
        self.concat = cfg.concat

        if self.model == "doc2vecc":
            # doc2vecc always uses 10; cfg.ss is deliberately not read.
            self.sentence_sample = 10
        else:
            self.sentence_sample = cfg.ss

        # Fixed at 0.5; presumably a dropout ratio -- confirm with model code.
        self.dp_ratio = 0.5
class BasicOption(object):
    """Hyper-parameters for the basic (linear) model.

    Args:
        name: accepted by the constructor but not stored or used.
        num_feature: dimension of the features each training example has.
        num_class: number of classes to predict.
        lr: learning rate (alpha) of the model.
        epochs: number of training epochs.
        interval: how often to print statistics.
    """

    def __init__(self, name, num_feature, num_class, lr, epochs, interval):
        # NOTE(review): ``name`` is ignored and ``model`` is always "linear";
        # confirm whether callers expect ``name`` to be recorded here.
        self.model = "linear"
        # Problem shape.
        self.num_feature, self.num_class = num_feature, num_class
        # Training schedule.
        self.lr, self.num_epochs = lr, epochs
        self.statistics_interval = interval
class GloVeOption(Option):
    """Options for the GloVe model.

    Builds the shared ``Option`` fields with ``model="GloVe"`` and adds two
    GloVe-specific constants. Going by their names, these are presumably the
    ``x_max`` cap and ``alpha`` exponent of GloVe's weighting function. The
    code that consumes them is not visible here, so confirm before relying on
    that.

    Args:
        config_holder: configuration object, forwarded to ``Option``.
        vocab: vocabulary, forwarded to ``Option``.
        cooccurrence_cap: stored as ``self.cooccurrence_cap``. Defaults to
            0.1, the value this class previously hard-coded.
        scaling_factor: stored as ``self.scaling_factor``. Defaults to 0.1,
            the value this class previously hard-coded.
    """

    def __init__(self, config_holder, vocab, cooccurrence_cap=0.1,
                 scaling_factor=0.1):
        super(GloVeOption, self).__init__(config_holder, vocab, "GloVe")
        # NOTE(review): the GloVe paper uses x_max=100 and alpha=3/4.
        # If co-occurrence counts here are >= 1, a cap of 0.1 would saturate
        # the weighting at its maximum for every observed pair. Confirm these
        # defaults are intended; they are kept unchanged for compatibility.
        self.cooccurrence_cap = cooccurrence_cap
        self.scaling_factor = scaling_factor