# input.py
"""
The value of alpha is determined empirically.

In the experiments described in the paper, the hyperparameter alpha is
searched over the set {0, 0.01, 0.1, 1, 10, 100} on five fine-grained
datasets, and the value that gives the best performance on these datasets
is chosen.
"""
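# Hypothetical helper, not part of the original method: a minimal sketch of
# the empirical alpha search described in the module docstring. It assumes the
# caller supplies an `evaluate(alpha)` callable that trains/evaluates the model
# with that alpha and returns a single score where higher is better (for
# example, accuracy averaged over the datasets; this aggregation is an
# assumption, since the paper only says the choice is based on performance).
# The names `ALPHA_CANDIDATES` and `select_alpha` are illustrative.
ALPHA_CANDIDATES = (0.0, 0.01, 0.1, 1.0, 10.0, 100.0)


def select_alpha(evaluate, candidates=ALPHA_CANDIDATES):
    """Grid-search alpha over `candidates`.

    Returns:
        tuple: (best_alpha, scores) where `scores` maps each candidate
        alpha to the score returned by `evaluate`.
    """
    # Score every candidate once.
    scores = {alpha: evaluate(alpha) for alpha in candidates}
    # max() returns the first maximal element, so earlier candidates win ties.
    best_alpha = max(candidates, key=lambda alpha: scores[alpha])
    return best_alpha, scores


# Example usage with a toy scoring function (illustration only):
#   best, scores = select_alpha(lambda a: -abs(a - 1.0))
#   best is 1.0, and scores holds one entry per candidate.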
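def aggregate_features(global_feature, local_feature, alpha):
    """
    Aggregate global and local features as a weighted sum.

    Computes ``global_feature + alpha * local_feature``.

    Args:
        global_feature (torch.Tensor): Global feature tensor.
        local_feature (torch.Tensor): Local feature tensor; must have the same
            shape as ``global_feature`` or be broadcast-compatible with it.
        alpha (float): Weight that balances the contribution of the local
            feature relative to the global feature.

    Returns:
        torch.Tensor: Aggregated feature tensor.
    """
    # alpha scales only the local term; alpha = 0 keeps just the global feature.
    aggregated_feature = global_feature + alpha * local_feature
    return aggregated_feature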