-
Notifications
You must be signed in to change notification settings - Fork 3
/
main_TCQE_TFLEX.py
1473 lines (1274 loc) · 68.3 KB
/
main_TCQE_TFLEX.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
@date: 2021/10/26
@description: null
"""
import gc
from collections import defaultdict
import logging
import os
import sys
from typing import List, Dict, Tuple, Optional, Union, Set
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import expression
from ComplexTemporalQueryData import ICEWS05_15, ICEWS14, GDELT, ComplexTemporalQueryDatasetCachePath, TemporalComplexQueryData, TYPE_train_queries_answers, groups
from ComplexTemporalQueryDataloader import TestDataset, TrainDataset
from expression.ParamSchema import is_entity, is_relation, is_timestamp
from expression.TFLEX_DSL import is_to_predict_entity_set, query_contains_union_and_we_should_use_DNF
from toolbox.data.dataloader import SingledirectionalOneShotIterator
from toolbox.evaluate.GatherMetric import AverageMeter
from toolbox.exp.Experiment import Experiment
from toolbox.exp.OutputSchema import OutputSchema
from toolbox.utils.KGArgsParser import KGEArgParser
from toolbox.utils.Log import Log
from toolbox.utils.Progbar import Progbar
from toolbox.utils.RandomSeeds import set_seeds
from dataclasses import dataclass, field
import os
import sys
from toolbox.utils.KGArgs import OutputArguments
from toolbox.utils.KGArgsParser import KGEArgParser
@dataclass
class ModelArguments:
    """Model hyper-parameter defaults (each field carries its CLI help text)."""

    entity_dim: int = field(
        default=800,
        metadata=dict(help="The dimension of the entity embeddings."),
    )
    hidden_dim: int = field(
        default=800,
        metadata=dict(help="embedding dimension"),
    )
    input_dropout: float = field(
        default=0.1,
        metadata=dict(help="Input layer dropout."),
    )
    gamma: float = field(
        default=15.0,
        metadata=dict(help="margin in the loss"),
    )
    center_reg: float = field(
        default=0.02,
        metadata=dict(
            help="center_reg for ConE, center_reg balances the in_cone dist and out_cone dist"
        ),
    )
@dataclass
class DataArguments:
    """Where the data lives and which dataset to load."""

    data_home: str = field(
        default="data",
        metadata=dict(help="The folder path to dataset."),
    )
    dataset: str = field(
        default="ICEWS14",
        metadata=dict(help="Which dataset to use: ICEWS14, ICEWS05_15, GDELT."),
    )
@dataclass
class ExperimentArguments:
    """Experiment-level settings."""

    name: str = field(
        default="TFLEX",
        metadata=dict(help="Name of the experiment."),
    )
@dataclass
class TrainingArguments:
    """Run-phase switches, training schedule and data-loader settings.

    Field order is the positional order of the generated ``__init__`` and is
    therefore kept stable.
    """

    # --- which phases to run ----------------------------------------------
    do_train: bool = field(default=True, metadata=dict(help="Whether to run training."))
    do_valid: bool = field(default=True, metadata=dict(help="Whether to run on the dev set."))
    do_test: bool = field(default=True, metadata=dict(help="Whether to run on the test set."))
    seed: int = field(default=42, metadata=dict(help="random seed for initialization"))

    # --- 1. training, available only if do_train is True -------------------
    resume: bool = field(default=False, metadata=dict(help="Resume from output directory."))
    resume_by_score: float = field(
        default=0.0,
        metadata=dict(help="Resume by score from output directory. Resume best if it is 0. Default: 0"),
    )
    start_step: int = field(default=0, metadata=dict(help="start step."))
    max_steps: int = field(default=200001, metadata=dict(help="Number of steps."))
    # NOTE(review): the help text of every_valid_step / every_test_step is the
    # same as max_steps and looks copy-pasted; confirm the intended wording.
    every_valid_step: int = field(default=10000, metadata=dict(help="Number of steps."))
    every_test_step: int = field(default=10000, metadata=dict(help="Number of steps."))
    negative_sample_size: int = field(
        default=128,
        metadata=dict(help="negative entities sampled per query"),
    )
    lr: float = field(default=0.0001, metadata=dict(help="Learning rate."))
    train_tasks: str = field(
        default=(
            "Pe,Pe2,Pe3,e2i,e3i,"
            "Pt,aPt,bPt,Pe_Pt,Pt_sPe_Pt,Pt_oPe_Pt,t2i,t3i,"
            "e2i_N,e3i_N,Pe_e2i_Pe_NPe,e2i_PeN,e2i_NPe,"
            "t2i_N,t3i_N,Pe_t2i_PtPe_NPt,t2i_PtN,t2i_NPt"
        ),
        metadata=dict(help="the tasks for training"),
    )
    train_all: bool = field(
        default=False,
        metadata=dict(help="if training all, it will use all tasks in data.train_queries_answers"),
    )
    train_batch_size: int = field(default=512, metadata=dict(help="for training: batch size"))
    train_shuffle: bool = field(default=True, metadata=dict(help="for training: shuffle data"))
    train_drop_last: bool = field(default=True, metadata=dict(help="for training: drop last batch"))
    train_num_workers: int = field(default=1, metadata=dict(help="for training: number of workers"))
    train_pin_memory: bool = field(default=False, metadata=dict(help="for training: pin memory"))
    train_device: str = field(default="cuda:0", metadata=dict(help="choice: cuda:0, cuda:1, cpu."))

    # --- 2. validation / testing, available only if do_valid or do_test ----
    test_tasks: str = field(default="Pe,Pt,Pe2,Pe3", metadata=dict(help="for testing: the tasks"))
    test_all: bool = field(
        default=False,
        metadata=dict(help="if testing all, it will use all tasks in data.test_queries_answers"),
    )
    test_batch_size: int = field(default=8, metadata=dict(help="for testing: batch size"))
    test_shuffle: bool = field(default=False, metadata=dict(help="for testing: shuffle data"))
    test_drop_last: bool = field(default=False, metadata=dict(help="for testing: drop last batch"))
    test_num_workers: int = field(default=1, metadata=dict(help="for testing: number of workers"))
    test_pin_memory: bool = field(default=False, metadata=dict(help="for testing: pin memory"))
    test_device: str = field(default="cuda:0", metadata=dict(help="choice: cuda:0, cuda:1, cpu."))
# Query structures are identified by their query-name string.
QueryStructure = str
# A token: (feature, logic, time_feature, time_logic).
TYPE_token = Tuple[(torch.Tensor,) * 4]
# Scale of the feature range; convert_to_feature maps into [-L, L].
L: int = 1
def convert_to_logic(x):
    """Squash ``x`` into (0, 1) for the logic part of a token (``sigmoid(2x)``)."""
    return torch.sigmoid(x * 2)
def convert_to_feature(x):
    """Squash ``x`` into [-L, L] with ``tanh`` for the feature part of a token."""
    squashed = torch.tanh(x)
    return squashed * L
def convert_to_time_feature(x):
    """Squash ``x`` into [-L, L] with ``tanh`` for the time-feature part of a token."""
    squashed = torch.tanh(x)
    return squashed * L
def convert_to_time_logic(x):
    """Squash ``x`` into (0, 1) for the time-logic part of a token (``sigmoid(2x)``)."""
    return torch.sigmoid(x * 2)
class EntityProjection(nn.Module):
    """MLP that projects a (query, relation, timestamp) triple of tokens to a token.

    The three input tokens are added part by part, the four sums are
    concatenated into a ``4 * dim`` vector, passed through ``num_layers``
    ReLU layers (``layer1`` .. ``layerN``) and mapped back to ``4 * dim`` by
    ``layer0``. The result is split into the four token parts, each squashed
    into its range by the matching ``convert_to_*`` helper.
    """

    def __init__(self, dim, hidden_dim=800, num_layers=2, drop=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # NOTE(review): constructed but never applied in forward(), so `drop`
        # currently has no effect.
        self.dropout = nn.Dropout(drop)
        packed_dim = 4 * dim
        # Creation order (layer1, layer0, layer2..layerN) fixes both the random
        # initial weights and the parameter registration order; keep it.
        self.layer1 = nn.Linear(packed_dim, hidden_dim)
        self.layer0 = nn.Linear(hidden_dim, packed_dim)
        for depth in range(2, num_layers + 1):
            setattr(self, f"layer{depth}", nn.Linear(hidden_dim, hidden_dim))
        for depth in range(num_layers + 1):
            nn.init.xavier_uniform_(getattr(self, f"layer{depth}").weight)

    def forward(self,
                q_feature, q_logic, q_time_feature, q_time_logic,
                r_feature, r_logic, r_time_feature, r_time_logic,
                t_feature, t_logic, t_time_feature, t_time_logic):
        hidden = torch.cat(
            (
                q_feature + r_feature + t_feature,
                q_logic + r_logic + t_logic,
                q_time_feature + r_time_feature + t_time_feature,
                q_time_logic + r_time_logic + t_time_logic,
            ),
            dim=-1,
        )
        for depth in range(1, self.num_layers + 1):
            hidden = F.relu(getattr(self, f"layer{depth}")(hidden))
        feature, logic, time_feature, time_logic = self.layer0(hidden).chunk(4, dim=-1)
        return (
            convert_to_feature(feature),
            convert_to_logic(logic),
            convert_to_time_feature(time_feature),
            convert_to_time_logic(time_logic),
        )
class TimeProjection(nn.Module):
    """MLP that projects a (subject, relation, object) triple of tokens to a token.

    Same architecture as the entity projection: part-wise sum of the three
    tokens, ``num_layers`` ReLU layers, output layer ``layer0`` back to
    ``4 * dim``, then split into the four token parts and squashed.
    """

    def __init__(self, dim, hidden_dim=800, num_layers=2, drop=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        # NOTE(review): never applied in forward(); `drop` has no effect.
        self.dropout = nn.Dropout(drop)
        width = dim * 4
        # Keep creation order layer1, layer0, layer2..N (RNG draws, param order).
        self.layer1 = nn.Linear(width, hidden_dim)
        self.layer0 = nn.Linear(hidden_dim, width)
        for idx in range(2, num_layers + 1):
            setattr(self, "layer{}".format(idx), nn.Linear(hidden_dim, hidden_dim))
        for idx in range(0, num_layers + 1):
            nn.init.xavier_uniform_(getattr(self, "layer{}".format(idx)).weight)

    def forward(self,
                q1_feature, q1_logic, q1_time_feature, q1_time_logic,
                r_feature, r_logic, r_time_feature, r_time_logic,
                q2_feature, q2_logic, q2_time_feature, q2_time_logic):
        summed = [
            q1_feature + r_feature + q2_feature,
            q1_logic + r_logic + q2_logic,
            q1_time_feature + r_time_feature + q2_time_feature,
            q1_time_logic + r_time_logic + q2_time_logic,
        ]
        out = torch.cat(summed, dim=-1)
        for idx in range(1, self.num_layers + 1):
            out = F.relu(getattr(self, "layer{}".format(idx))(out))
        out = self.layer0(out)
        feature, logic, time_feature, time_logic = torch.chunk(out, 4, dim=-1)
        return (
            convert_to_feature(feature),
            convert_to_logic(logic),
            convert_to_time_feature(time_feature),
            convert_to_time_logic(time_logic),
        )
class EntityIntersection(nn.Module):
    """Intersection of N tokens stacked along dim 0 (used by ``And``/``And3``).

    Feature and time-feature parts are pooled with a learned softmax attention
    over the operand axis (dim 0); logic and time-logic parts take the
    element-wise minimum over that axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.feature_layer_1 = nn.Linear(2 * dim, dim)
        self.feature_layer_2 = nn.Linear(dim, dim)
        self.time_feature_layer_1 = nn.Linear(2 * dim, dim)
        self.time_feature_layer_2 = nn.Linear(dim, dim)
        for layer in (self.feature_layer_1, self.feature_layer_2,
                      self.time_feature_layer_1, self.time_feature_layer_2):
            nn.init.xavier_uniform_(layer.weight)

    @staticmethod
    def _pool(layer_1, layer_2, values, gates):
        # attention scores from [values, gates]; softmax over the operand axis
        scores = layer_2(F.relu(layer_1(torch.cat([values, gates], dim=-1))))
        return torch.sum(F.softmax(scores, dim=0) * values, dim=0)

    def forward(self, feature, logic, time_feature, time_logic):
        # inputs: N x B x d
        pooled_feature = self._pool(self.feature_layer_1, self.feature_layer_2, feature, logic)
        pooled_time_feature = self._pool(
            self.time_feature_layer_1, self.time_feature_layer_2, time_feature, time_logic)
        return (
            pooled_feature,
            logic.min(dim=0).values,
            pooled_time_feature,
            time_logic.min(dim=0).values,
        )
class TemporalIntersection(nn.Module):
    """Intersection of N tokens stacked along dim 0 (used by ``TimeAnd``/``TimeAnd3``).

    Feature and time-feature parts are pooled with a learned softmax attention
    over the operand axis (dim 0); logic and time-logic parts take the
    element-wise minimum over that axis.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.feature_layer_1 = nn.Linear(2 * dim, dim)
        self.feature_layer_2 = nn.Linear(dim, dim)
        self.time_feature_layer_1 = nn.Linear(2 * dim, dim)
        self.time_feature_layer_2 = nn.Linear(dim, dim)
        for layer in (self.feature_layer_1, self.feature_layer_2,
                      self.time_feature_layer_1, self.time_feature_layer_2):
            nn.init.xavier_uniform_(layer.weight)

    @staticmethod
    def _pool(layer_1, layer_2, values, gates):
        # attention scores from [values, gates]; softmax over the operand axis
        scores = layer_2(F.relu(layer_1(torch.cat([values, gates], dim=-1))))
        return torch.sum(F.softmax(scores, dim=0) * values, dim=0)

    def forward(self, feature, logic, time_feature, time_logic):
        # inputs: N x B x d
        pooled_feature = self._pool(self.feature_layer_1, self.feature_layer_2, feature, logic)
        pooled_time_feature = self._pool(
            self.time_feature_layer_1, self.time_feature_layer_2, time_feature, time_logic)
        return (
            pooled_feature,
            logic.min(dim=0).values,
            pooled_time_feature,
            time_logic.min(dim=0).values,
        )
class EntityNegation(nn.Module):
    """Negation of the entity parts of a token (used by ``Not``).

    The feature part is re-mapped by a 2-layer MLP over ``[feature, logic]``
    and the logic part is complemented (``1 - logic``); the time parts pass
    through unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.feature_layer_1 = nn.Linear(2 * dim, dim)
        self.feature_layer_2 = nn.Linear(dim, dim)
        for layer in (self.feature_layer_1, self.feature_layer_2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, feature, logic, time_feature, time_logic):
        hidden = F.relu(self.feature_layer_1(torch.cat([feature, logic], dim=-1)))
        negated_feature = self.feature_layer_2(hidden)
        negated_logic = 1 - logic
        return negated_feature, negated_logic, time_feature, time_logic
class TemporalNegation(nn.Module):
    """Negation of the time parts of a token (used by ``TimeNot``).

    The time-feature part is re-mapped by a 2-layer MLP over
    ``[time_feature, time_logic]`` and the time-logic part is complemented;
    the entity parts pass through unchanged. (The layers keep their
    ``feature_layer_*`` names so existing checkpoints still load.)
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.feature_layer_1 = nn.Linear(2 * dim, dim)
        self.feature_layer_2 = nn.Linear(dim, dim)
        for layer in (self.feature_layer_1, self.feature_layer_2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, feature, logic, time_feature, time_logic):
        hidden = F.relu(self.feature_layer_1(torch.cat([time_feature, time_logic], dim=-1)))
        negated_time_feature = self.feature_layer_2(hidden)
        negated_time_logic = 1 - time_logic
        return feature, logic, negated_time_feature, negated_time_logic
def scale_feature(feature):
    """Shift a feature tensor by one half-period ``L`` (wrap-around).

    Implements ``f' = (f + 2L) % (2L) - L`` for ``f`` in ``[-L, L]`` (``L = 1``):
    non-negative entries become ``f - L`` and negative entries become ``f + L``.

    Returns a new tensor and leaves ``feature`` unmodified. The previous
    implementation assigned through boolean masks, i.e. it wrote into the
    caller's tensor: ``TemporalNext`` passes its input ``time_feature``
    straight in, so the incoming token was silently altered, and an in-place
    write on a leaf tensor that requires grad raises in autograd. Values and
    gradients (identity w.r.t. ``feature``) are unchanged.
    """
    return torch.where(feature >= 0, feature - L, feature + L)
class TemporalBefore(nn.Module):
    """Parameter-free 'before' operator on the time parts of a token.

    The time feature is shifted by ``-(L + time_logic) / 2`` and wrapped with
    ``scale_feature``; the time logic becomes ``(L - time_logic) / 2``.
    Entity parts pass through unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, feature, logic, time_feature, time_logic):
        shifted = time_feature - L / 2 - time_logic / 2
        return feature, logic, scale_feature(shifted), (L - time_logic) / 2
class TemporalAfter(nn.Module):
    """Parameter-free 'after' operator on the time parts of a token.

    The time feature is shifted by ``+(L + time_logic) / 2`` and wrapped with
    ``scale_feature``; the time logic becomes ``(L - time_logic) / 2``.
    Entity parts pass through unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, feature, logic, time_feature, time_logic):
        shifted = time_feature + L / 2 + time_logic / 2
        return feature, logic, scale_feature(shifted), (L - time_logic) / 2
class TemporalNext(nn.Module):
    """Parameter-free 'next' operator on the time parts of a token.

    Wraps the time feature with ``scale_feature`` and complements the time
    logic (``1 - time_logic``); entity parts pass through unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feature, logic, time_feature, time_logic):
        return feature, logic, scale_feature(time_feature), 1 - time_logic
class EntityUnion(nn.Module):
    """Union of N tokens stacked along dim 0 (used by ``Or``).

    Feature and time-feature parts are pooled with a learned softmax attention
    over the operand axis. The entity logic takes the element-wise max
    (union) while the time logic takes the element-wise min (the time
    constraint is intersected).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.feature_layer_1 = nn.Linear(dim * 2, dim)
        self.feature_layer_2 = nn.Linear(dim, dim)
        self.time_feature_layer_1 = nn.Linear(dim * 2, dim)
        self.time_feature_layer_2 = nn.Linear(dim, dim)
        for name in ("feature_layer_1", "feature_layer_2",
                     "time_feature_layer_1", "time_feature_layer_2"):
            nn.init.xavier_uniform_(getattr(self, name).weight)

    def forward(self, feature, logic, time_feature, time_logic):
        # inputs: N x B x d
        scores = self.feature_layer_2(F.relu(self.feature_layer_1(torch.cat([feature, logic], dim=-1))))
        pooled_feature = (F.softmax(scores, dim=0) * feature).sum(dim=0)

        time_scores = self.time_feature_layer_2(
            F.relu(self.time_feature_layer_1(torch.cat([time_feature, time_logic], dim=-1))))
        pooled_time_feature = (F.softmax(time_scores, dim=0) * time_feature).sum(dim=0)

        union_logic = logic.max(dim=0).values
        intersected_time_logic = time_logic.min(dim=0).values
        return pooled_feature, union_logic, pooled_time_feature, intersected_time_logic
class TemporalUnion(nn.Module):
    """Union of N tokens stacked along dim 0 (used by ``TimeOr``).

    Feature and time-feature parts are pooled with a learned softmax attention
    over the operand axis. The entity logic takes the element-wise min (the
    entity constraint is intersected) while the time logic takes the
    element-wise max (union).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.feature_layer_1 = nn.Linear(dim * 2, dim)
        self.feature_layer_2 = nn.Linear(dim, dim)
        self.time_feature_layer_1 = nn.Linear(dim * 2, dim)
        self.time_feature_layer_2 = nn.Linear(dim, dim)
        for name in ("feature_layer_1", "feature_layer_2",
                     "time_feature_layer_1", "time_feature_layer_2"):
            nn.init.xavier_uniform_(getattr(self, name).weight)

    def forward(self, feature, logic, time_feature, time_logic):
        # inputs: N x B x d
        scores = self.feature_layer_2(F.relu(self.feature_layer_1(torch.cat([feature, logic], dim=-1))))
        pooled_feature = (F.softmax(scores, dim=0) * feature).sum(dim=0)

        time_scores = self.time_feature_layer_2(
            F.relu(self.time_feature_layer_1(torch.cat([time_feature, time_logic], dim=-1))))
        pooled_time_feature = (F.softmax(time_scores, dim=0) * time_feature).sum(dim=0)

        intersected_logic = logic.min(dim=0).values
        union_time_logic = time_logic.max(dim=0).values
        return pooled_feature, intersected_logic, pooled_time_feature, union_time_logic
class TFLEX(nn.Module):
def __init__(self,
             nentity, nrelation, ntimestamp, hidden_dim, gamma,
             center_reg=None, drop: float = 0.
             ):
    """Create the embedding tables, the neural logical operators and the parser.

    Args:
        nentity, nrelation, ntimestamp: number of rows in the entity, relation
            and timestamp embedding tables.
        hidden_dim: width shared by entity, relation and timestamp embeddings.
        gamma: stored as the frozen parameter ``self.gamma``; together with
            ``epsilon`` it sets ``embedding_range = (gamma + epsilon) / hidden_dim``.
        center_reg: stored as ``self.cen``.
        drop: forwarded to the entity and time projection MLPs.
    """
    super(TFLEX, self).__init__()
    self.nentity = nentity
    self.nrelation = nrelation
    # Kept alongside nentity / nrelation (it was previously accepted but dropped).
    self.ntimestamp = ntimestamp
    self.hidden_dim = hidden_dim
    self.entity_dim = hidden_dim
    self.relation_dim = hidden_dim
    self.timestamp_dim = hidden_dim
    # Entities and timestamps only have a feature table; their other token
    # parts are zeros. Relations learn all four parts.
    # Creation order below fixes the random init draws; do not reorder.
    self.entity_feature_embedding = nn.Embedding(nentity, self.entity_dim)
    self.timestamp_feature_embedding = nn.Embedding(ntimestamp, self.timestamp_dim)
    self.relation_feature_embedding = nn.Embedding(nrelation, self.relation_dim)
    self.relation_logic_embedding = nn.Embedding(nrelation, self.relation_dim)
    self.relation_time_feature_embedding = nn.Embedding(nrelation, self.relation_dim)
    self.relation_time_logic_embedding = nn.Embedding(nrelation, self.relation_dim)
    self.entity_projection = EntityProjection(hidden_dim, drop=drop)
    self.entity_intersection = EntityIntersection(hidden_dim)
    self.entity_union = EntityUnion(hidden_dim)
    self.entity_negation = EntityNegation(hidden_dim)
    self.time_projection = TimeProjection(hidden_dim, drop=drop)
    self.time_intersection = TemporalIntersection(hidden_dim)
    self.time_union = TemporalUnion(hidden_dim)
    self.time_negation = TemporalNegation(hidden_dim)
    self.time_before = TemporalBefore(hidden_dim)
    self.time_after = TemporalAfter(hidden_dim)
    self.time_next = TemporalNext()
    self.epsilon = 2.0
    self.gamma = nn.Parameter(torch.Tensor([gamma]), requires_grad=False)
    self.embedding_range = nn.Parameter(
        torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]),
        requires_grad=False)
    embedding_range = self.embedding_range.item()
    self.modulus = nn.Parameter(torch.Tensor([0.5 * embedding_range]), requires_grad=True)
    self.cen = center_reg
    self.parser = self.build_parser()
def build_neural_ops(self):
    """Return the operator table handed to ``expression.NeuralParser``.

    Every operator consumes and produces tokens, i.e. 4-tuples
    ``(feature, logic, time_feature, time_logic)``. N-ary set operators stack
    the matching parts of their operands along a new leading dimension and
    pass the four stacked tensors to the corresponding module; unary and
    projection operators unpack their tokens into positional arguments.
    Modules are looked up on ``self`` each time an operator is called.
    """

    def parts(token):
        # Enforce the 4-part token shape (raises ValueError otherwise).
        feature, logic, time_feature, time_logic = token
        return feature, logic, time_feature, time_logic

    def stacked(*tokens):
        # [(f1, l1, tf1, tl1), (f2, ...), ...] -> (stack(f*), stack(l*), stack(tf*), stack(tl*))
        unpacked = [parts(token) for token in tokens]
        return tuple(torch.stack([token[k] for token in unpacked]) for k in range(4))

    def And(q1, q2):
        return self.entity_intersection(*stacked(q1, q2))

    def And3(q1, q2, q3):
        return self.entity_intersection(*stacked(q1, q2, q3))

    def Or(q1, q2):
        return self.entity_union(*stacked(q1, q2))

    def Not(q):
        return self.entity_negation(*parts(q))

    def TimeNot(q):
        return self.time_negation(*parts(q))

    def EntityProjection2(e1, r1, t1):
        return self.entity_projection(*parts(e1), *parts(r1), *parts(t1))

    def TimeProjection2(e1, r1, e2):
        return self.time_projection(*parts(e1), *parts(r1), *parts(e2))

    def TimeAnd(q1, q2):
        return self.time_intersection(*stacked(q1, q2))

    def TimeAnd3(q1, q2, q3):
        return self.time_intersection(*stacked(q1, q2, q3))

    def TimeOr(q1, q2):
        return self.time_union(*stacked(q1, q2))

    def TimeBefore(q):
        return self.time_before(*parts(q))

    def TimeAfter(q):
        return self.time_after(*parts(q))

    def TimeNext(q):
        return self.time_next(*parts(q))

    def beforePt(e1, r1, e2):
        return TimeBefore(TimeProjection2(e1, r1, e2))

    def afterPt(e1, r1, e2):
        return TimeAfter(TimeProjection2(e1, r1, e2))

    return {
        "And": And,
        "And3": And3,
        "Or": Or,
        "Not": Not,
        "EntityProjection": EntityProjection2,
        "TimeProjection": TimeProjection2,
        "TimeAnd": TimeAnd,
        "TimeAnd3": TimeAnd3,
        "TimeOr": TimeOr,
        "TimeNot": TimeNot,
        "TimeBefore": TimeBefore,
        "TimeAfter": TimeAfter,
        "TimeNext": TimeNext,
        "afterPt": afterPt,
        "beforePt": beforePt,
    }
def build_parser(self):
    """Wrap this model's neural operator table in an ``expression.NeuralParser``."""
    return expression.NeuralParser(self.build_neural_ops())
def init(self):
    """Re-draw every embedding table uniformly from [-embedding_range, embedding_range]."""
    bound = self.embedding_range.item()
    tables = (
        self.entity_feature_embedding,
        self.timestamp_feature_embedding,
        self.relation_feature_embedding,
        self.relation_logic_embedding,
        self.relation_time_feature_embedding,
        self.relation_time_logic_embedding,
    )
    # Order matters: each call draws from the global RNG.
    for table in tables:
        nn.init.uniform_(tensor=table.weight.data, a=-bound, b=bound)
def scale(self, embedding):
    """Divide by ``embedding_range`` so values drawn by ``init()`` land in [-1, 1]."""
    return torch.div(embedding, self.embedding_range)
def entity_feature(self, idx):
    """Feature part for entity id(s) ``idx``: scaled embedding squashed into [-L, L]."""
    raw = self.entity_feature_embedding(idx)
    return convert_to_feature(self.scale(raw))
def timestamp_feature(self, idx):
    """Time-feature part for timestamp id(s) ``idx``: scaled embedding squashed into [-L, L]."""
    raw = self.timestamp_feature_embedding(idx)
    return convert_to_time_feature(self.scale(raw))
def entity_token(self, idx) -> TYPE_token:
    """Token for entity id(s) ``idx``: learned feature part, all other parts zero."""
    feature = self.entity_feature(idx)
    # zeros_like already allocates on feature's device. Three distinct tensors
    # are created so an in-place update of one part cannot leak into another.
    logic, time_feature, time_logic = (torch.zeros_like(feature) for _ in range(3))
    return feature, logic, time_feature, time_logic
def relation_token(self, idx) -> TYPE_token:
    """Token for relation id(s) ``idx``: all four parts come from learned tables."""
    def part(convert, table):
        return convert(self.scale(table(idx)))

    return (
        part(convert_to_feature, self.relation_feature_embedding),
        part(convert_to_logic, self.relation_logic_embedding),
        part(convert_to_time_feature, self.relation_time_feature_embedding),
        part(convert_to_time_logic, self.relation_time_logic_embedding),
    )
def timestamp_token(self, idx) -> TYPE_token:
    """Token for timestamp id(s) ``idx``: learned time-feature part, all other parts zero."""
    time_feature = self.timestamp_feature(idx)
    # Three distinct zero tensors on time_feature's device/dtype.
    feature, logic, time_logic = (torch.zeros_like(time_feature) for _ in range(3))
    return feature, logic, time_feature, time_logic
def embed_args(self, query_args: List[str], query_tensor: torch.Tensor) -> Tuple[TYPE_token, ...]:
    """Turn each argument column of ``query_tensor`` into a token.

    Args:
        query_args: argument names; ``is_entity`` / ``is_relation`` /
            ``is_timestamp`` decide which token builder each one uses.
        query_tensor: id tensor whose column ``i`` (``query_tensor[:, i]``)
            holds the ids for ``query_args[i]``.

    Returns:
        One token per argument, in ``query_args`` order. (The previous
        annotation claimed a single token.)

    Raises:
        ValueError: if an argument name is not an entity, relation or
            timestamp. ``ValueError`` subclasses ``Exception``, so existing
            broad handlers still catch it.
    """
    embedding_of_args = []
    for i, arg_name in enumerate(query_args):
        ids = query_tensor[:, i]
        if is_entity(arg_name):
            token = self.entity_token(ids)
        elif is_relation(arg_name):
            token = self.relation_token(ids)
        elif is_timestamp(arg_name):
            token = self.timestamp_token(ids)
        else:
            raise ValueError("Unknown Args %s" % arg_name)
        embedding_of_args.append(token)
    return tuple(embedding_of_args)
def forward(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict):
    """``nn.Module`` entry point: delegates unchanged to ``forward_FLEX``."""
    result = self.forward_FLEX(positive_sample, negative_sample, subsampling_weight,
                               batch_queries_dict, batch_idxs_dict)
    return result
def forward_FLEX(self,
                 positive_answer: Optional[torch.Tensor],
                 negative_answer: Optional[torch.Tensor],
                 subsampling_weight: Optional[torch.Tensor],
                 grouped_query: Dict[QueryStructure, torch.Tensor],
                 grouped_idxs: Dict[QueryStructure, List[List[int]]]):
    """Embed grouped queries and score them against positive / negative answers.

    positive_answer: None or (B, )
    negative_answer: None or (B, N)
    subsampling_weight: None or (B, )

    Returns ``(positive_scores, negative_scores, subsampling_weight, all_idxs)``.
    Scores are concatenated in the order: plain entity queries, plain
    timestamp queries, DNF entity queries, DNF timestamp queries. ``all_idxs``
    lists the sample indices in that same order, and ``subsampling_weight``
    (when given) is re-indexed by it. A score entry is ``None`` when the
    corresponding answer tensor is ``None``.
    """
    # 1. Embed all queries into the low-dimensional space.
    embedded = self.batch_predict(grouped_query, grouped_idxs)
    (idxs_e, predict_e), (idxs_t, predict_t), \
        (union_idxs_e, union_predict_e), (union_idxs_t, union_predict_t) = embedded

    # (idxs, prediction, predict_entity, DNF_predict) per category, in output order.
    score_plan = (
        (idxs_e, predict_e, True, False),
        (idxs_t, predict_t, False, False),
        (union_idxs_e, union_predict_e, True, True),
        (union_idxs_t, union_predict_t, False, True),
    )

    all_idxs = idxs_e + idxs_t + union_idxs_e + union_idxs_t
    if subsampling_weight is not None:
        subsampling_weight = subsampling_weight[all_idxs]

    def score_against(answer):
        # Score one answer tensor for every category and concatenate.
        if answer is None:
            return None
        per_category = [
            self.scoring_to_answers_by_idxs(
                idxs, answer, predict, predict_entity=entity_query, DNF_predict=dnf_query)
            for idxs, predict, entity_query, dnf_query in score_plan
        ]
        return torch.cat(per_category, dim=0)

    # 2. Score positive answers.
    positive_scores = score_against(positive_answer)
    # 3. Score negative answers.
    negative_scores = score_against(negative_answer)
    return positive_scores, negative_scores, subsampling_weight, all_idxs
def single_predict(self, query_structure: QueryStructure,
                   query_tensor: torch.Tensor) -> Union[TYPE_token, Tuple[TYPE_token, TYPE_token]]:
    """Embed one group of same-structure queries and run its neural program.

    Args:
        query_structure: the query name (``QueryStructure`` is ``str``), as in
            ``batch_predict``. A legacy ``(query_name, query_args)`` pair is
            still accepted.
        query_tensor: argument ids; ``embed_args`` reads column ``i`` for
            argument ``i``.

    Returns:
        One token, or, for queries rewritten to DNF, a pair of tokens.
    """
    if isinstance(query_structure, (tuple, list)):
        query_name, query_args = query_structure
    else:
        # Unpacking a plain string into two names (the previous behaviour)
        # split its characters ("Pe" -> "P", "e") or raised for other lengths.
        # Resolve argument names through the parser, as batch_predict does.
        query_name = query_structure
        query_args = self.parser.fast_args(query_name)
    if query_contains_union_and_we_should_use_DNF(query_name):
        # transform to DNF
        func = self.parser.fast_function(query_name + "_DNF")
        embedding_of_args = self.embed_args(query_args, query_tensor)
        predict_1, predict_2 = func(*embedding_of_args)  # tuple[B x dt, B x dt]
        return predict_1, predict_2
    # other queries are evaluated directly
    func = self.parser.fast_function(query_name)
    embedding_of_args = self.embed_args(query_args, query_tensor)  # [B x dt]*L
    return func(*embedding_of_args)  # B x dt
def batch_predict(self, grouped_query: Dict[QueryStructure, torch.Tensor],
                  grouped_idxs: Dict[QueryStructure, List[List[int]]]):
    """Evaluate every query group and bucket the predictions by answer type.

    Each key of ``grouped_query`` is a query name whose arguments are resolved
    with ``self.parser.fast_args``. Groups for which
    ``query_contains_union_and_we_should_use_DNF`` is true run through the
    ``<name>_DNF`` program (two predictions per query); all other groups run
    through ``<name>`` (one prediction). ``is_to_predict_entity_set`` decides
    whether a group goes to an entity bucket or a timestamp bucket.

    Args:
        grouped_query: query name -> query argument ids, (B, L) per group.
        grouped_idxs: query name -> global positions of that group's rows in the
            ungrouped batch from the dataloader.

    Returns:
        Four ``(idxs, predictions)`` pairs in the order: entity, timestamp,
        union-entity, union-timestamp. ``idxs`` concatenates ``grouped_idxs``
        of every group in the bucket, in iteration order.

        * Plain buckets: if ``idxs`` is non-empty, ``predictions`` is a tuple of
          four tensors, each the batch-concatenation of one token component with
          a singleton axis inserted at dim 1, i.e. (B, 1, d); otherwise it is
          the raw (normally empty) list of predictions.
        * Union buckets: if ``idxs`` is non-empty, a tuple of four (B, 2, d)
          tensors whose dim 1 holds the two DNF branches; otherwise ``None``.
    """
    def cat_to_tensor(token_list: List[TYPE_token]) -> TYPE_token:
        # For each of the four token components, concatenate along the batch
        # axis and add a branch axis at dim 1: (B_i, d) pieces -> (B, 1, d).
        return tuple(
            torch.cat([token[part] for token in token_list], dim=0).unsqueeze(1)
            for part in range(4)
        )

    # Every bucket is keyed by "does this group predict an entity set?".
    plain_idxs = {True: [], False: []}
    plain_preds = {True: [], False: []}
    union_idxs = {True: [], False: []}
    union_preds_1 = {True: [], False: []}
    union_preds_2 = {True: [], False: []}

    for query_name in grouped_query:
        query_args = self.parser.fast_args(query_name)
        query_tensor = grouped_query[query_name]  # (B, L)
        query_idxs = grouped_idxs[query_name]  # global row positions of this group
        use_dnf = query_contains_union_and_we_should_use_DNF(query_name)
        func = self.parser.fast_function(query_name + "_DNF" if use_dnf else query_name)
        embedding_of_args = self.embed_args(query_args, query_tensor)
        if use_dnf:
            predict_1, predict_2 = func(*embedding_of_args)  # one token per disjunct
            to_entity = bool(is_to_predict_entity_set(query_name))
            union_preds_1[to_entity].append(predict_1)
            union_preds_2[to_entity].append(predict_2)
            union_idxs[to_entity].extend(query_idxs)
        else:
            predict = func(*embedding_of_args)
            to_entity = bool(is_to_predict_entity_set(query_name))
            plain_preds[to_entity].append(predict)
            plain_idxs[to_entity].extend(query_idxs)

    def plain_bucket(to_entity):
        idxs, preds = plain_idxs[to_entity], plain_preds[to_entity]
        if len(idxs) > 0:
            return idxs, cat_to_tensor(preds)  # (B, 1, d) * 4
        return idxs, preds

    def union_bucket(to_entity):
        idxs = union_idxs[to_entity]
        if len(idxs) == 0:
            return idxs, None
        first = cat_to_tensor(union_preds_1[to_entity])  # (B, 1, d) * 4
        second = cat_to_tensor(union_preds_2[to_entity])  # (B, 1, d) * 4
        merged = tuple(torch.cat([a, b], dim=1) for a, b in zip(first, second))  # (B, 2, d) * 4
        return idxs, merged

    return plain_bucket(True), plain_bucket(False), union_bucket(True), union_bucket(False)
def grouped_predict(self, grouped_query: Dict[QueryStructure, torch.Tensor],
                    grouped_answer: Dict[QueryStructure, torch.Tensor]) -> Dict[QueryStructure, torch.Tensor]:
    """Score candidate answers separately for every query structure.

    Returns a dict such as ``{"Pe": scores}`` where each value is the
    ``forward_predict`` output for that structure, (B, N) with N the number of
    candidate answers in ``grouped_answer[structure]``.

    The number of candidate answers differs between entity-predicting and
    timestamp-predicting structures (and generally between structures), so the
    per-structure results cannot be aligned and merged into one tensor.
    Intended for validation/testing rather than training.
    """
    return {
        query_structure: self.forward_predict(
            query_structure,
            grouped_query[query_structure],  # (B, L) query argument ids
            grouped_answer[query_structure],  # (B, N) candidate answer ids
        )
        for query_structure in grouped_query
    }
def forward_predict(
        self, query_structure: QueryStructure, query_tensor: torch.Tensor, answer: torch.Tensor) -> torch.Tensor:
    """Score candidate answers for one group of queries that share a structure.

    Args:
        query_structure: the query name; its arguments come from
            ``self.parser.fast_args``.
        query_tensor: query argument ids, (B, L).
        answer: candidate answer ids to score, (B, N).

    Returns:
        The ``scoring_to_answers`` result, (B, N). For DNF (union) queries both
        disjunct predictions are stacked on dim 1 and passed with
        ``DNF_predict=True``.
    """
    query_name = query_structure
    query_args = self.parser.fast_args(query_name)
    use_dnf = bool(query_contains_union_and_we_should_use_DNF(query_name))
    program = self.parser.fast_function(query_name + "_DNF" if use_dnf else query_name)
    embedding_of_args = self.embed_args(query_args, query_tensor)  # (B, d) * L
    if use_dnf:
        predict_1, predict_2 = program(*embedding_of_args)
        # Pair the two disjuncts component-wise: (B, d) -> (B, 2, d).
        query_token = tuple(torch.stack([p1, p2], dim=1) for p1, p2 in zip(predict_1, predict_2))
    else:
        # Single prediction: add a branch axis, (B, d) -> (B, 1, d).
        query_token = tuple(part.unsqueeze(dim=1) for part in program(*embedding_of_args))
    to_entity = bool(is_to_predict_entity_set(query_name))
    return self.scoring_to_answers(answer, query_token, predict_entity=to_entity, DNF_predict=use_dnf)
def scoring_to_answers_by_idxs(
        self, all_idxs, answer: torch.Tensor, q: TYPE_token, predict_entity=True, DNF_predict=False):
    """Gather the answer rows selected by ``all_idxs`` and score them against ``q``.

    Args:
        all_idxs: row indices into ``answer`` (a Python list; presumably one
            entry, or one single-element list, per query row -- TODO confirm).
        answer: answer ids for the whole batch, e.g. (B_all,) for positives or
            (B_all, N) for negatives.
        q: predicted query token, each component (B, 1, dt) or (B, 2, dt).
        predict_entity: score against entity embeddings if true, timestamp
            embeddings otherwise.
        DNF_predict: whether ``q`` carries two DNF branches.

    Returns:
        (B, N) scores, or an empty 1-D tensor on ``self.embedding_range``'s
        device when ``all_idxs`` is empty.
    """
    if len(all_idxs) == 0:
        # Nothing in this bucket; torch.cat still accepts an empty 1-D tensor.
        empty = torch.Tensor([])
        return empty.to(self.embedding_range.device)
    picked = answer[all_idxs]
    # Flatten everything after the first axis so the ids are always 2-D (B, N).
    picked = picked.view(picked.size(0), -1)
    return self.scoring_to_answers(picked, q, predict_entity, DNF_predict)
def scoring_to_answers(self, answer_ids: torch.Tensor, q: TYPE_token, predict_entity=True, DNF_predict=False):
    """Score every candidate answer id against the predicted query token.

    Args:
        answer_ids: candidate ids, (B, N).
        q: predicted query token, each component (B, 1, dt) or (B, 2, dt).
        predict_entity: embed with ``self.entity_feature`` and score with
            ``self.scoring_entity`` if true; otherwise use
            ``self.timestamp_feature`` / ``self.scoring_timestamp``.
        DNF_predict: if true, reduce over the branch axis (dim 1) with max;
            otherwise squeeze the singleton branch axis.

    Returns:
        (B, N) scores.
    """
    # Insert an answer axis so each branch broadcasts against all N candidates:
    # (B, 1, dt) -> (B, 1, 1, dt), (B, 2, dt) -> (B, 2, 1, dt).
    query = tuple(part.unsqueeze(dim=2) for part in q)
    if predict_entity:
        embed, score_fn = self.entity_feature, self.scoring_entity
    else:
        embed, score_fn = self.timestamp_feature, self.scoring_timestamp
    candidates = embed(answer_ids).unsqueeze(dim=1)  # (B, 1, N, d)
    scores = score_fn(candidates, query)  # (B, 1, N) or (B, 2, N)
    return scores.max(dim=1).values if DNF_predict else scores.squeeze(dim=1)
def distance_between_entity_and_query(self, entity_feature, query_feature, query_logic):
    """Distance from each candidate entity to the query interval.

    Per dimension the query describes the interval
    ``[query_feature - query_logic, query_feature + query_logic]`` (assuming
    ``query_logic`` is non-negative).

    * inner part: ``min(|entity - query_feature|, query_logic)``, the offset
      from the interval centre capped at ``query_logic``;
    * outer part: distance to the nearer interval edge, forced to 0 in the
      dimensions where the entity lies strictly inside the interval.

    Both parts are L1-reduced over the last axis and combined as
    ``outer + self.cen * inner``.

    Shapes as used by ``scoring_to_answers``:
        entity_feature: (B, 1, N, d)
        query_feature, query_logic: (B, 1, 1, d) or (B, 2, 1, d)
    Returns:
        (B, 1, N) or (B, 2, N) distances.
    """
    to_center = entity_feature - query_feature
    to_lower = entity_feature - (query_feature - query_logic)
    to_upper = entity_feature - (query_feature + query_logic)

    abs_to_center = torch.abs(to_center)
    inside = abs_to_center < query_logic
    inner = torch.minimum(abs_to_center, query_logic)
    # Out-of-place masking: inside the interval only the inner part counts.
    outer = torch.minimum(torch.abs(to_lower), torch.abs(to_upper)).masked_fill(inside, 0.)

    return torch.norm(outer, p=1, dim=-1) + self.cen * torch.norm(inner, p=1, dim=-1)
def distance_between_timestamp_and_query(self, timestamp_feature, time_feature, time_logic):
    """Distance from each candidate timestamp to the query's time interval.

    Per dimension the time part of the query describes the interval
    ``[time_feature - time_logic, time_feature + time_logic]`` (assuming
    ``time_logic`` is non-negative).

    * inner part: ``min(|timestamp - time_feature|, time_logic)``;
    * outer part: distance to the nearer interval edge, zeroed in the
      dimensions where the timestamp lies strictly inside the interval.

    Result: ``L1(outer) + self.cen * L1(inner)`` over the last axis.

    Shapes as used by ``scoring_to_answers``:
        timestamp_feature: (B, 1, N, d)
        time_feature, time_logic: (B, 1, 1, d) or (B, 2, 1, d)
    Returns:
        (B, 1, N) or (B, 2, N) distances.
    """
    gap_center = torch.abs(timestamp_feature - time_feature)
    gap_left = torch.abs(timestamp_feature - (time_feature - time_logic))
    gap_right = torch.abs(timestamp_feature - (time_feature + time_logic))

    within = gap_center < time_logic
    nearest_edge = torch.minimum(gap_left, gap_right)
    # Only distances outside the interval contribute to the outer term.
    outer = torch.where(within, torch.zeros_like(nearest_edge), nearest_edge)
    inner = torch.minimum(gap_center, time_logic)

    outer_norm = torch.norm(outer, p=1, dim=-1)
    inner_norm = torch.norm(inner, p=1, dim=-1)
    return outer_norm + self.cen * inner_norm
def scoring_entity(self, entity_feature, q: TYPE_token):
    """Score entity embeddings against the entity part of a query token.

    ``q`` unpacks into (feature, logic, time_feature, time_logic); only the
    first two are used. Score is ``self.gamma - distance * self.modulus``,
    with the distance from ``distance_between_entity_and_query``.
    """
    query_feature, query_logic, _, _ = q
    entity_distance = self.distance_between_entity_and_query(entity_feature, query_feature, query_logic)
    return self.gamma - entity_distance * self.modulus
def scoring_timestamp(self, timestamp_feature, q: TYPE_token):
    """Score timestamp embeddings against the time part of a query token.

    ``q`` unpacks into (feature, logic, time_feature, time_logic); only the
    last two are used. Score is ``self.gamma - distance * self.modulus``,
    with the distance from ``distance_between_timestamp_and_query``.
    """
    _, _, query_time_feature, query_time_logic = q
    time_distance = self.distance_between_timestamp_and_query(
        timestamp_feature, query_time_feature, query_time_logic)
    return self.gamma - time_distance * self.modulus
class MyExperiment(Experiment):
def __init__(self, output: OutputSchema, data: TemporalComplexQueryData, model, args: TrainingArguments):
super(MyExperiment, self).__init__(output)
self.debug(f"{locals()}")
self.metric_log_store.add_hyper(args, "args")
self.model_param_store.save_scripts([__file__])
entity_count = data.entity_count
timestamp_count = data.timestamp_count
self.groups = groups
# 1. build train dataset
if args.do_train:
if args.train_all:
data.load_cache(["train_queries_answers"])
else:
data.train_queries_answers = data.load_cache_by_tasks(args.train_tasks.split(","), "train")
train_queries_answers = data.train_queries_answers
train_path_queries: TYPE_train_queries_answers = {}
train_other_queries: TYPE_train_queries_answers = {}
path_list = ["Pe", "Pt", "Pe2", 'Pe3']
for query_structure_name in train_queries_answers:
if query_structure_name in path_list:
train_path_queries[query_structure_name] = train_queries_answers[query_structure_name]
else:
train_other_queries[query_structure_name] = train_queries_answers[query_structure_name]
self.log("Training info:")
self.log(str({
query_structure_name: len(train_queries_answers[query_structure_name]['queries_answers'])
for query_structure_name in train_queries_answers
}))
del train_queries_answers
del data.train_queries_answers
gc.collect()
train_path_iterator = SingledirectionalOneShotIterator(DataLoader(
TrainDataset(train_path_queries, entity_count, timestamp_count, args.negative_sample_size),
batch_size=args.train_batch_size,
num_workers=args.train_num_workers,
pin_memory=args.train_pin_memory,
shuffle=args.train_shuffle,
drop_last=args.train_drop_last,
collate_fn=TrainDataset.collate_fn
))
if len(train_other_queries) > 0:
train_other_iterator = SingledirectionalOneShotIterator(DataLoader(