-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpretrain_kg_embedding.py
1912 lines (1645 loc) · 75.2 KB
/
pretrain_kg_embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import torch.multiprocessing as mp
import torch as th
import os
import numpy as np
import torch.nn.init as INIT
import torch.multiprocessing as mp
from torch.multiprocessing import Queue
from _thread import start_new_thread
import traceback
from functools import wraps
import dgl
import dgl.backend as F
from dgl.base import NID, EID
import torch.nn as nn
import torch.nn.functional as functional
import logging
import time
import math
import os
import csv
import argparse
import json
import numpy as np
import math
import numpy as np
import scipy as sp
import os
import sys
import pickle
import time
# Not referenced in the visible code; presumably the default batch size for
# inference — confirm against callers.
DEFAULT_INFER_BATCHSIZE = 2048
# Added to gamma when computing the uniform init range in KEModel:
# emb_init = (gamma + EMB_INIT_EPS) / hidden_dim.
EMB_INIT_EPS = 2.0

# Short alias so the scoring code can call logsigmoid(...) directly.
logsigmoid = functional.logsigmoid


def none(x):
    """Identity helper: return ``x`` unchanged."""
    return x


def norm(x, p):
    """Return the p-norm of tensor ``x`` raised to the power ``p``."""
    return x.norm(p=p) ** p


def get_scalar(x):
    """Detach a one-element tensor and return its value as a Python number."""
    return x.detach().item()


def reshape(arr, x, y):
    """Return a ``(x, y)`` view of ``arr`` (``Tensor.view``, no copy)."""
    return arr.view(x, y)


def cuda(arr, gpu):
    """Return ``arr`` copied onto CUDA device ``gpu``."""
    return arr.cuda(gpu)
class CommonArgParser(argparse.ArgumentParser):
    """Command-line options shared by the knowledge-graph-embedding scripts.

    Long help texts are written as adjacent string literals. Python joins
    those with no separator, so every fragment that continues on the next
    line ends with a space to keep words from running together in ``--help``.
    """

    def __init__(self):
        super(CommonArgParser, self).__init__()
        # NOTE: KEModel in this file only builds a score function for
        # 'TransE', 'TransE_l1' and 'TransE_l2'; the other choices are kept
        # for interface compatibility with DGL-KE.
        self.add_argument('--model_name', default='TransE',
                          choices=['TransE', 'TransE_l1', 'TransE_l2', 'TransR',
                                   'RESCAL', 'DistMult', 'ComplEx', 'RotatE'],
                          help='The models provided by DGL-KE.')
        self.add_argument('--data_path', type=str, default='data/DRKG',
                          help='The path of the directory where DGL-KE loads knowledge graph data.')
        self.add_argument('--dataset', type=str, default='DRKG',
                          help='The name of the builtin knowledge graph. Currently, the builtin knowledge '
                               'graphs include FB15k, FB15k-237, wn18, wn18rr and Freebase. '
                               'DGL-KE automatically downloads the knowledge graph and keeps it under data_path.')
        self.add_argument('--format', type=str, default='udd_hrt',
                          help='The format of the dataset. For builtin knowledge graphs, '
                               'the format should be built_in. For users\' own knowledge graphs, '
                               'it needs to be raw_udd_{htr} or udd_{htr}.')
        self.add_argument('--data_files', type=str, default=None, nargs='+',
                          help='A list of data file names. This is used if users want to train KGE '
                               'on their own datasets. If the format is raw_udd_{htr}, '
                               'users need to provide train_file [valid_file] [test_file]. '
                               'If the format is udd_{htr}, users need to provide '
                               'entity_file relation_file train_file [valid_file] [test_file]. '
                               'In both cases, valid_file and test_file are optional.')
        self.add_argument('--delimiter', type=str, default='\t',
                          help='Delimiter used in data files. Note all files should use the same delimiter.')
        self.add_argument('--save_path', type=str, default='ckpts',
                          help='The path of the directory where models and logs are saved.')
        self.add_argument('--no_save_emb', action='store_true',
                          help='Disable saving the embeddings under save_path.')
        self.add_argument('--max_step', type=int, default=80000,
                          help='The maximal number of steps to train the model. '
                               'A step trains the model with a batch of data.')
        self.add_argument('--batch_size', type=int, default=1024,
                          help='The batch size for training.')
        self.add_argument('--batch_size_eval', type=int, default=8,
                          help='The batch size used for validation and test.')
        self.add_argument('--neg_sample_size', type=int, default=256,
                          help='The number of negative samples we use for each positive sample in the training.')
        self.add_argument('--neg_deg_sample', action='store_true',
                          help='Construct negative samples proportional to vertex degree in the training. '
                               'When this option is turned on, the number of negative samples per positive edge '
                               'will be doubled. Half of the negative samples are generated uniformly while '
                               'the other half are generated proportional to vertex degree.')
        self.add_argument('--neg_deg_sample_eval', action='store_true',
                          help='Construct negative samples proportional to vertex degree in the evaluation.')
        self.add_argument('--neg_sample_size_eval', type=int, default=-1,
                          help='The number of negative samples we use to evaluate a positive sample.')
        self.add_argument('--eval_percent', type=float, default=1,
                          help='Randomly sample some percentage of edges for evaluation.')
        self.add_argument('--no_eval_filter', action='store_true',
                          help='Disable filtering positive edges from randomly constructed negative edges '
                               'for evaluation.')
        self.add_argument('-log', '--log_interval', type=int, default=1000,
                          help='Print runtime of different components every x steps.')
        self.add_argument('--eval_interval', type=int, default=10000,
                          help='Print evaluation results on the validation dataset every x steps '
                               'if validation is turned on.')
        self.add_argument('--test', action='store_true',
                          help='Evaluate the model on the test set after the model is trained.')
        self.add_argument('--num_proc', type=int, default=1,
                          help='The number of processes to train the model in parallel. '
                               'In multi-GPU training, the number of processes by default is set to match '
                               'the number of GPUs. '
                               'If set explicitly, the number of processes needs to be divisible by the '
                               'number of GPUs.')
        self.add_argument('--num_thread', type=int, default=1,
                          help='The number of CPU threads to train the model in each process. '
                               'This argument is used for multiprocessing training.')
        self.add_argument('--force_sync_interval', type=int, default=-1,
                          help='We force a synchronization between processes every x steps for '
                               'multiprocessing training. This potentially stabilizes the training process '
                               'to get a better performance. For multiprocessing training, it is set to '
                               '1000 by default.')
        self.add_argument('--hidden_dim', type=int, default=400,
                          help='The embedding size of relation and entity.')
        self.add_argument('--lr', type=float, default=0.01,
                          help='The learning rate. DGL-KE uses Adagrad to optimize the model parameters.')
        self.add_argument('-g', '--gamma', type=float, default=12.0,
                          help='The margin value in the score function. It is used by TransX and RotatE.')
        self.add_argument('-de', '--double_ent', action='store_true',
                          help='Double entity dim for complex number. It is used by RotatE.')
        self.add_argument('-dr', '--double_rel', action='store_true',
                          help='Double relation dim for complex number.')
        self.add_argument('-adv', '--neg_adversarial_sampling', action='store_true',
                          help='Indicate whether to use negative adversarial sampling. '
                               'It will weight negative samples with higher scores more.')
        self.add_argument('-a', '--adversarial_temperature', default=1.0, type=float,
                          help='The temperature used for negative adversarial sampling.')
        self.add_argument('-rc', '--regularization_coef', type=float, default=0.000002,
                          help='The coefficient for regularization.')
        self.add_argument('-rn', '--regularization_norm', type=int, default=3,
                          help='Norm used in regularization.')
class ArgParser(CommonArgParser):
    """Training options: CommonArgParser plus device placement and relation partitioning.

    As in CommonArgParser, each help fragment that continues on the next line
    ends with a space because adjacent string literals are joined with no
    separator.
    """

    def __init__(self):
        super(ArgParser, self).__init__()
        self.add_argument('--gpu', type=int, default=[-1], nargs='+',
                          help='A list of gpu ids, e.g. 0 1 2 4')
        self.add_argument('--mix_cpu_gpu', action='store_true',
                          help='Training a knowledge graph embedding model with both CPUs and GPUs. '
                               'The embeddings are stored in CPU memory and the training is performed in GPUs. '
                               'This is usually used for training large knowledge graph embeddings.')
        self.add_argument('--valid', action='store_true',
                          help='Evaluate the model on the validation set in the training.')
        self.add_argument('--rel_part', action='store_true',
                          help='Enable relation partitioning for multi-GPU training.')
        self.add_argument('--async_update', action='store_true',
                          help='Allow asynchronous update on node embedding for multi-GPU training. '
                               'This overlaps CPU and GPU computation to speed up.')
def get_device(args):
    """Return the torch device selected by the first entry of ``args.gpu``.

    A negative id (the ``--gpu`` default is ``[-1]``) selects the CPU; any
    other id ``k`` selects ``cuda:k``.
    """
    first_gpu = args.gpu[0]
    if first_gpu < 0:
        return th.device('cpu')
    return th.device('cuda:' + str(first_gpu))
def save_model(args, model, emap_file=None, rmap_file=None):
    """Save the model's embeddings plus a ``config.json`` describing the run.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options. ``save_path`` is the output directory (created,
        including missing parents, if needed). ``dataset`` is passed to
        ``model.save_emb``, and the fields listed in the dict below are
        written to ``config.json``.
    model : KEModel
        Model whose ``save_emb(path, dataset)`` writes the embedding files.
    emap_file, rmap_file : str, optional
        Entity / relation id-mapping file names recorded in the config.
    """
    # makedirs(exist_ok=True) creates missing parent directories and does not
    # fail if the directory already exists (e.g. created by another process).
    os.makedirs(args.save_path, exist_ok=True)
    print('Save model to {}'.format(args.save_path))
    model.save_emb(args.save_path, args.dataset)
    # We need to save the model configurations as well.
    config = {'dataset': args.dataset,
              'model': args.model_name,
              'emb_size': args.hidden_dim,
              'max_train_step': args.max_step,
              'batch_size': args.batch_size,
              'neg_sample_size': args.neg_sample_size,
              'lr': args.lr,
              'gamma': args.gamma,
              'double_ent': args.double_ent,
              'double_rel': args.double_rel,
              'neg_adversarial_sampling': args.neg_adversarial_sampling,
              'adversarial_temperature': args.adversarial_temperature,
              'regularization_coef': args.regularization_coef,
              'regularization_norm': args.regularization_norm,
              'emap_file': emap_file,
              'rmap_file': rmap_file}
    conf_file = os.path.join(args.save_path, 'config.json')
    with open(conf_file, 'w') as outfile:
        json.dump(config, outfile, indent=4)
def prepare_save_path(args):
    """Create a fresh run directory under ``args.save_path`` and point ``args.save_path`` at it.

    The run directory is named ``<model_name>_<dataset>_<n>``. ``n`` starts at
    the number of existing entries with that prefix and is increased until
    the name is unused, so an earlier run is never reused even when some
    numbered directories were removed.

    Parameters
    ----------
    args : argparse.Namespace
        Needs ``save_path``, ``model_name`` and ``dataset``. ``save_path`` is
        overwritten with the path of the new run directory.
    """
    os.makedirs(args.save_path, exist_ok=True)
    prefix = '{}_{}_'.format(args.model_name, args.dataset)
    n = len([x for x in os.listdir(args.save_path) if x.startswith(prefix)])
    run_dir = os.path.join(args.save_path, prefix + str(n))
    # A plain count collides with an existing name when the numbering has gaps.
    while os.path.exists(run_dir):
        n += 1
        run_dir = os.path.join(args.save_path, prefix + str(n))
    args.save_path = run_dir
    os.makedirs(args.save_path, exist_ok=True)
def thread_wrapped_func(func):
    """Wrap ``func`` so that it runs in a fresh thread; for torch.multiprocessing.Process targets.

    With this wrapper we can use OMP threads in subprocesses;
    otherwise, OMP_NUM_THREADS=1 is mandatory.

    The wrapper blocks until the thread finishes and returns ``func``'s
    result. If ``func`` raises an ``Exception``, the caller gets an exception
    of the same class whose message is the formatted traceback, chained to
    the original. If that class cannot be built from a single message, the
    original exception is re-raised. Other ``BaseException`` subclasses
    (e.g. ``SystemExit``) are re-raised unchanged.

    How to use::

        @thread_wrapped_func
        def func_to_wrap(args ...):
            ...
    """
    @wraps(func)
    def decorated_function(*args, **kwargs):
        queue = Queue()

        def _queue_result():
            exception, trace, res = None, None, None
            try:
                res = func(*args, **kwargs)
            except BaseException as e:
                # Catch everything: if nothing is put on the queue, the caller
                # below would block in queue.get() forever.
                exception = e
                trace = traceback.format_exc()
            queue.put((res, exception, trace))

        start_new_thread(_queue_result, ())
        result, exception, trace = queue.get()
        if exception is None:
            return result
        if not isinstance(exception, Exception):
            raise exception
        try:
            wrapped = exception.__class__(trace)
        except Exception:
            # Some exception classes need more than one constructor argument.
            wrapped = None
        if wrapped is None:
            raise exception
        raise wrapped from exception
    return decorated_function
@thread_wrapped_func
def async_update(args, emb, queue):
    """Apply queued sparse Adagrad updates to an embedding table until told to stop.

    A trainer process pushes ``(grad_indices, grad_values, gpu_id)`` requests
    onto ``queue`` (see ``ExternalEmbedding.update``). For each request this
    loop:

    1. adds the row-wise mean squared gradient to ``emb.state_sum``;
    2. adds ``-lr * grad / sqrt(state_sum)`` to the touched rows of ``emb.emb``.

    A request whose ``grad_indices`` is ``None`` ends the loop.

    Parameters
    ----------
    args :
        Global configs; ``args.num_thread`` sets the torch thread count.
    emb : ExternalEmbedding
        The entity embeddings; ``emb.args.lr`` is the step size.
    queue :
        The request queue.
    """
    th.set_num_threads(args.num_thread)
    while True:
        idx, grads, gpu_id = queue.get()
        step_size = emb.args.lr
        if idx is None:
            return
        with th.no_grad():
            sq_mean = (grads * grads).mean(1)
            target = emb.state_sum.device
            if idx.device != target:
                idx = idx.to(target)
            if sq_mean.device != target:
                sq_mean = sq_mean.to(target)
            emb.state_sum.index_add_(0, idx, sq_mean)
            # Advanced indexing returns a copy, so the in-place sqrt_ below
            # leaves state_sum untouched.
            denom = emb.state_sum[idx]
            if gpu_id >= 0:
                denom = denom.cuda(gpu_id)
            denom = denom.sqrt_().add_(1e-10).unsqueeze(1)
            delta = -step_size * grads / denom
            if delta.device != target:
                delta = delta.to(target)
            emb.emb.index_add_(0, idx, delta)
class ExternalEmbedding:
    """Sparse embedding table for a knowledge graph, updated with sparse Adagrad.

    It is used to store both entity embeddings and relation embeddings. Rows
    looked up through ``__call__`` with ``trace=True`` are recorded in
    ``self.trace``. ``update`` then applies their gradients in place and
    clears the trace.

    Parameters
    ----------
    args :
        Global configs (``args.gpu`` and ``args.lr`` are read).
    num : int
        Number of embeddings.
    dim : int
        Embedding dimension size.
    device : th.device
        Device to store the embedding.
    """

    def __init__(self, args, num, dim, device):
        self.gpu = args.gpu
        self.args = args
        self.num = num
        # (idx, leaf tensor) pairs recorded by __call__ when trace=True.
        self.trace = []
        self.emb = th.empty(num, dim, dtype=th.float32, device=device)
        # Per-row Adagrad accumulator of mean squared gradients.
        self.state_sum = self.emb.new().resize_(self.emb.size(0)).zero_()
        self.state_step = 0
        self.has_cross_rel = False
        # queue used by asynchronous update
        self.async_q = None
        # asynchronous update process
        self.async_p = None

    def init(self, emb_init):
        """Initialize the embeddings uniformly and reset the Adagrad state.

        Parameters
        ----------
        emb_init : float
            The initial embedding range is [-emb_init, emb_init].
        """
        INIT.uniform_(self.emb, -emb_init, emb_init)
        INIT.zeros_(self.state_sum)

    def setup_cross_rels(self, cross_rels, global_emb):
        """Mark ``cross_rels`` as rows whose master copy lives in ``global_emb``.

        Marked rows are refreshed from ``global_emb`` on lookup and updated
        directly in ``global_emb`` by ``update``.
        """
        cpu_bitmap = th.zeros((self.num,), dtype=th.bool)
        for rel in cross_rels:
            cpu_bitmap[rel] = 1
        self.cpu_bitmap = cpu_bitmap
        self.has_cross_rel = True
        self.global_emb = global_emb

    def get_noncross_idx(self, idx):
        """Return the entries of ``idx`` that are not cross-partition rows."""
        cpu_mask = self.cpu_bitmap[idx]
        gpu_mask = ~cpu_mask
        return idx[gpu_mask]

    def share_memory(self):
        """Use torch.Tensor.share_memory_() to allow cross-process tensor access."""
        self.emb.share_memory_()
        self.state_sum.share_memory_()

    def __call__(self, idx, gpu_id=-1, trace=True):
        """Return the rows ``idx`` of the table.

        Parameters
        ----------
        idx : th.tensor
            Slicing index
        gpu_id : int
            Which gpu to put sliced data in (negative: leave on the table's device).
        trace : bool
            If True, return a detached copy that requires grad and record it
            for ``update``. This is required in training.
            If False, do not trace the computation.
            Default: True
        """
        if self.has_cross_rel:
            # Pull the latest values of cross-partition rows from the global table.
            cpu_idx = idx.cpu()
            cpu_mask = self.cpu_bitmap[cpu_idx]
            cpu_idx = cpu_idx[cpu_mask]
            cpu_idx = th.unique(cpu_idx)
            if cpu_idx.shape[0] != 0:
                cpu_emb = self.global_emb.emb[cpu_idx]
                self.emb[cpu_idx] = cpu_emb.cuda(gpu_id)
        s = self.emb[idx]
        if gpu_id >= 0:
            s = s.cuda(gpu_id)
        # During the training, we need to trace the computation.
        # In this case, we need to record the computation path and compute the gradients.
        if trace:
            data = s.clone().detach().requires_grad_(True)
            self.trace.append((idx, data))
        else:
            data = s
        return data

    def update(self, gpu_id=-1):
        """Apply a sparse Adagrad step to every row recorded in ``self.trace``.

        For each traced lookup, the row-wise mean squared gradient is added to
        ``state_sum``, and ``-lr * grad / sqrt(state_sum)`` is added to the
        rows. If an async update process was created, the request is queued
        for it instead. The trace is cleared afterwards.

        Parameters
        ----------
        gpu_id : int
            Which gpu to accelerate the calculation. if -1 is provided, cpu is used.
        """
        self.state_step += 1
        with th.no_grad():
            for idx, data in self.trace:
                grad = data.grad.data
                clr = self.args.lr
                # clr = self.args.lr / (1 + (self.state_step - 1) * group['lr_decay'])
                # the update is non-linear so indices must be unique
                grad_indices = idx
                grad_values = grad
                if self.async_q is not None:
                    grad_indices.share_memory_()
                    grad_values.share_memory_()
                    self.async_q.put((grad_indices, grad_values, gpu_id))
                else:
                    grad_sum = (grad_values * grad_values).mean(1)
                    device = self.state_sum.device
                    if device != grad_indices.device:
                        grad_indices = grad_indices.to(device)
                    if device != grad_sum.device:
                        grad_sum = grad_sum.to(device)
                    if self.has_cross_rel:
                        # Cross-partition rows are also updated in the global table.
                        cpu_mask = self.cpu_bitmap[grad_indices]
                        cpu_idx = grad_indices[cpu_mask]
                        if cpu_idx.shape[0] > 0:
                            cpu_grad = grad_values[cpu_mask]
                            cpu_sum = grad_sum[cpu_mask].cpu()
                            cpu_idx = cpu_idx.cpu()
                            self.global_emb.state_sum.index_add_(0, cpu_idx, cpu_sum)
                            std = self.global_emb.state_sum[cpu_idx]
                            if gpu_id >= 0:
                                std = std.cuda(gpu_id)
                            std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
                            tmp = (-clr * cpu_grad / std_values)
                            tmp = tmp.cpu()
                            self.global_emb.emb.index_add_(0, cpu_idx, tmp)
                    self.state_sum.index_add_(0, grad_indices, grad_sum)
                    std = self.state_sum[grad_indices]  # _sparse_mask
                    if gpu_id >= 0:
                        std = std.cuda(gpu_id)
                    std_values = std.sqrt_().add_(1e-10).unsqueeze(1)
                    tmp = (-clr * grad_values / std_values)
                    if tmp.device != device:
                        tmp = tmp.to(device)
                    # TODO(zhengda) the overhead is here.
                    self.emb.index_add_(0, grad_indices, tmp)
        self.trace = []

    def create_async_update(self):
        """Set up the async update subprocess."""
        # maxsize=1: put() blocks while an earlier request is still unconsumed.
        self.async_q = Queue(1)
        self.async_p = mp.Process(target=async_update, args=(self.args, self, self.async_q))
        self.async_p.start()

    def finish_async_update(self):
        """Send the stop request to the async update subprocess and wait for it to exit."""
        self.async_q.put((None, None, None))
        self.async_p.join()

    def curr_emb(self):
        """Return all traced rows concatenated along dim 0."""
        data = [data for _, data in self.trace]
        return th.cat(data, 0)

    def save(self, path, name):
        """Save embeddings to ``<path>/<name>.npy``.

        Parameters
        ----------
        path : str
            Directory to save the embedding.
        name : str
            Embedding name.
        """
        file_name = os.path.join(path, name + '.npy')
        np.save(file_name, self.emb.cpu().detach().numpy())

    def load(self, path, name):
        """Load embeddings from ``<path>/<name>.npy``.

        The array is converted to float32 and placed on the device the table
        currently lives on, so a GPU-resident table stays on the GPU next to
        ``state_sum``.

        Parameters
        ----------
        path : str
            Directory to load the embedding.
        name : str
            Embedding name.
        """
        file_name = os.path.join(path, name + '.npy')
        self.emb = th.as_tensor(np.load(file_name), dtype=th.float32, device=self.emb.device)
def batched_l2_dist(a, b):
    """Pairwise Euclidean distances between two batches of row vectors.

    ``a`` is (B, N, D) and ``b`` is (B, M, D), as required by ``th.baddbmm``;
    the result is (B, N, M). It uses ``||x - y||^2 = ||y||^2 - 2 x.y + ||x||^2``
    with a single batched matmul. Squared distances are clamped at 1e-30
    before the square root, so round-off below zero cannot yield NaN.
    """
    sq_a = a.norm(dim=-1).pow(2)
    sq_b = b.norm(dim=-1).pow(2)
    sq_dist = th.baddbmm(sq_b.unsqueeze(-2), a, b.transpose(-2, -1), alpha=-2)
    sq_dist.add_(sq_a.unsqueeze(-1))
    return sq_dist.clamp_min_(1e-30).sqrt_()
def batched_l1_dist(a, b):
    """Pairwise L1 (Manhattan) distances between two batches of row vectors, via ``th.cdist``."""
    return th.cdist(a, b, p=1)
class TransEScore(nn.Module):
    """TransE score function: ``gamma - ||head + rel - tail||_p``.

    Paper link: https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data

    Parameters
    ----------
    gamma : float
        Margin from which the distance is subtracted.
    dist_func : str
        'l1' selects the L1 distance; any other value selects L2 (the default).
    """

    def __init__(self, gamma, dist_func='l2'):
        super(TransEScore, self).__init__()
        self.gamma = gamma
        use_l1 = dist_func == 'l1'
        # Batched distance for chunked negative scoring, and the matching norm
        # order used for positive edges and infer().
        self.neg_dist_func = batched_l1_dist if use_l1 else batched_l2_dist
        self.dist_ord = 1 if use_l1 else 2

    def edge_func(self, edges):
        """Score each edge from its source, relation and destination embeddings."""
        offset = edges.src['emb'] + edges.data['emb'] - edges.dst['emb']
        return {'score': self.gamma - th.norm(offset, p=self.dist_ord, dim=-1)}

    def infer(self, head_emb, rel_emb, tail_emb):
        """Score every (head, relation, tail) combination.

        For 2-D inputs of shape (num_*, dim) the result has shape
        (num_heads, num_rels, num_tails).
        """
        translated = head_emb.unsqueeze(1) + rel_emb.unsqueeze(0)
        offset = translated.unsqueeze(2) - tail_emb.unsqueeze(0).unsqueeze(0)
        return self.gamma - th.norm(offset, p=self.dist_ord, dim=-1)

    def prepare(self, g, gpu_id, trace=False):
        """No per-batch preparation is needed for TransE."""

    def create_neg_prepare(self, neg_head):
        """Return a hook that passes head/tail embeddings through unchanged."""
        def passthrough(rel_id, num_chunks, head, tail, gpu_id, trace=False):
            return head, tail
        return passthrough

    def forward(self, g):
        """Compute the ``'score'`` edge feature for every edge of ``g``."""
        g.apply_edges(self.edge_func)

    def update(self, gpu_id=-1):
        """TransE has no parameters of its own; nothing to update."""

    def reset_parameters(self):
        """TransE has no parameters of its own; nothing to reset."""

    def save(self, path, name):
        """TransE has no parameters of its own; nothing to save."""

    def load(self, path, name):
        """TransE has no parameters of its own; nothing to load."""

    def create_neg(self, neg_head):
        """Return a chunked scorer for corrupted heads (``neg_head``) or corrupted tails.

        The returned function takes
        ``(heads, relations, tails, num_chunks, chunk_size, neg_sample_size)``.
        It compares each of the ``chunk_size`` positive edges in a chunk with
        all ``neg_sample_size`` candidate entities of that chunk, returning
        ``gamma - dist`` with shape (num_chunks, chunk_size, neg_sample_size).
        """
        margin = self.gamma

        def score_neg_heads(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
            dim = heads.shape[1]
            # h + r ~ t  <=>  h ~ t - r, so candidate heads are compared with t - r.
            anchors = (tails - relations).reshape(num_chunks, chunk_size, dim)
            candidates = heads.reshape(num_chunks, neg_sample_size, dim)
            return margin - self.neg_dist_func(anchors, candidates)

        def score_neg_tails(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
            dim = heads.shape[1]
            anchors = (heads + relations).reshape(num_chunks, chunk_size, dim)
            candidates = tails.reshape(num_chunks, neg_sample_size, dim)
            return margin - self.neg_dist_func(anchors, candidates)

        return score_neg_heads if neg_head else score_neg_tails
class KEModel(object):
""" DGL Knowledge Embedding Model.
Parameters
----------
args:
Global configs.
model_name : str
Which KG model to use, including 'TransE_l1', 'TransE_l2', 'TransR',
'RESCAL', 'DistMult', 'ComplEx', 'RotatE'
n_entities : int
Num of entities.
n_relations : int
Num of relations.
hidden_dim : int
Dimetion size of embedding.
gamma : float
Gamma for score function.
double_entity_emb : bool
If True, entity embedding size will be 2 * hidden_dim.
Default: False
double_relation_emb : bool
If True, relation embedding size will be 2 * hidden_dim.
Default: False
"""
def __init__(self, args, model_name, n_entities, n_relations, hidden_dim, gamma,
             double_entity_emb=False, double_relation_emb=False):
    """Create the embedding tables and the score function, then initialize them.

    Parameters are described in the class docstring. ``args`` must also
    provide ``gpu``, ``mix_cpu_gpu``, ``strict_rel_part`` and
    ``soft_rel_part``. The last two are not added by ArgParser in this file,
    so the caller presumably sets them — TODO confirm.

    Raises
    ------
    ValueError
        If ``model_name`` is not 'TransE', 'TransE_l2' or 'TransE_l1' (the only
        score functions implemented here, although the argument parser
        accepts more names).
    """
    super(KEModel, self).__init__()
    # Resolve the score function first so an unsupported name fails before
    # any (potentially very large) embedding table is allocated.
    if model_name in ('TransE', 'TransE_l2'):
        dist_func = 'l2'
    elif model_name == 'TransE_l1':
        dist_func = 'l1'
    else:
        raise ValueError('Unsupported model_name {!r}; only TransE, TransE_l1 and '
                         'TransE_l2 are implemented.'.format(model_name))
    self.args = args
    self.n_entities = n_entities
    self.n_relations = n_relations
    self.model_name = model_name
    self.hidden_dim = hidden_dim
    self.eps = EMB_INIT_EPS
    # Embeddings are initialized uniformly in [-emb_init, emb_init].
    self.emb_init = (gamma + self.eps) / hidden_dim
    entity_dim = 2 * hidden_dim if double_entity_emb else hidden_dim
    relation_dim = 2 * hidden_dim if double_relation_emb else hidden_dim
    self.entity_dim = entity_dim
    self.rel_dim = relation_dim
    device = get_device(args)
    # With mix_cpu_gpu the tables are kept in CPU memory.
    table_device = F.cpu() if args.mix_cpu_gpu else device
    self.entity_emb = ExternalEmbedding(args, n_entities, entity_dim, table_device)
    self.strict_rel_part = args.strict_rel_part
    self.soft_rel_part = args.soft_rel_part
    if not self.strict_rel_part and not self.soft_rel_part:
        self.relation_emb = ExternalEmbedding(args, n_relations, relation_dim, table_device)
    else:
        # Partitioned training: only a CPU global table is created here; the
        # per-process relation_emb is created by prepare_relation/load_relation.
        self.global_relation_emb = ExternalEmbedding(args, n_relations, relation_dim, F.cpu())
    self.score_func = TransEScore(gamma, dist_func)
    self.head_neg_score = self.score_func.create_neg(True)
    self.tail_neg_score = self.score_func.create_neg(False)
    self.head_neg_prepare = self.score_func.create_neg_prepare(True)
    self.tail_neg_prepare = self.score_func.create_neg_prepare(False)
    self.reset_parameters()
def share_memory(self):
    """Put the embedding tensors in shared memory so other processes can access them.

    With relation partitioning the global relation table is shared;
    otherwise the regular relation table is.
    """
    self.entity_emb.share_memory()
    partitioned = self.strict_rel_part or self.soft_rel_part
    rel_table = self.global_relation_emb if partitioned else self.relation_emb
    rel_table.share_memory()
    if self.model_name == 'TransR':
        self.score_func.share_memory()
def save_emb(self, path, dataset):
    """Save the entity, relation and score-function parameters.

    Files are named with the prefix ``<dataset>_<model_name>``. With
    relation partitioning the global relation table is the one saved.

    Parameters
    ----------
    path : str
        Directory to save the model.
    dataset : str
        Dataset name as prefix to the saved embeddings.
    """
    prefix = dataset + '_' + self.model_name
    self.entity_emb.save(path, prefix + '_entity')
    if self.strict_rel_part or self.soft_rel_part:
        rel_table = self.global_relation_emb
    else:
        rel_table = self.relation_emb
    rel_table.save(path, prefix + '_relation')
    self.score_func.save(path, prefix)
def load_emb(self, path, dataset):
    """Load parameters written by ``save_emb``.

    Relations are read back into the same table ``save_emb`` writes them
    from. With relation partitioning that is ``global_relation_emb``, because
    ``relation_emb`` is not created until ``prepare_relation`` or
    ``load_relation`` runs.

    Parameters
    ----------
    path : str
        Directory to load the model.
    dataset : str
        Dataset name as prefix to the saved embeddings.
    """
    prefix = dataset + '_' + self.model_name
    self.entity_emb.load(path, prefix + '_entity')
    if self.strict_rel_part or self.soft_rel_part:
        self.global_relation_emb.load(path, prefix + '_relation')
    else:
        self.relation_emb.load(path, prefix + '_relation')
    self.score_func.load(path, prefix)
def reset_parameters(self):
    """Re-initialize the entity table, the score function and the active relation table."""
    self.entity_emb.init(self.emb_init)
    self.score_func.reset_parameters()
    partitioned = self.strict_rel_part or self.soft_rel_part
    rel_table = self.global_relation_emb if partitioned else self.relation_emb
    rel_table.init(self.emb_init)
def predict_score(self, g):
    """Run the score function on ``g`` and return its per-edge ``'score'`` feature.

    Parameters
    ----------
    g : DGLGraph
        Graph holding positive edges.

    Returns
    -------
    tensor
        The positive score
    """
    self.score_func(g)
    positive_scores = g.edata['score']
    return positive_scores
def predict_neg_score(self, pos_g, neg_g, to_device=None, gpu_id=-1, trace=False,
                      neg_deg_sample=False):
    """Calculate the negative score.

    Depending on ``neg_g.neg_head``, either the heads or the tails of the
    positive edges are replaced by the sampled negative entities. The score
    function's chunked negative scorer then compares each positive edge with
    every negative entity of its chunk.

    Parameters
    ----------
    pos_g : DGLGraph
        Graph holding positive edges. ``ndata['emb']`` and ``edata['emb']``
        must already be filled (as ``forward``/``forward_test`` do).
    neg_g : DGLGraph
        Graph holding negative edges. Attributes read: ``num_chunks``,
        ``chunk_size``, ``neg_sample_size``, ``neg_head``, ``head_nid`` /
        ``tail_nid`` and ``ndata['id']``.
    to_device : func
        Function to move data into device.
    gpu_id : int
        Which gpu to move data to.
    trace : bool
        If True, trace the computation. This is required in training.
        If False, do not trace the computation.
        Default: False
    neg_deg_sample : bool
        If True, we use the head and tail nodes of the positive edges to
        construct negative edges.
        Default: False

    Returns
    -------
    tensor
        The negative score. With ``neg_deg_sample`` it is multiplied by a
        0/1 mask of shape (num_chunks, chunk_size, neg_sample_size) that
        zeroes each edge's own positive entity. ``neg_g.neg_sample_size``
        is also overwritten with the enlarged sample size.
    """
    num_chunks = neg_g.num_chunks
    chunk_size = neg_g.chunk_size
    neg_sample_size = neg_g.neg_sample_size
    # Sized for the enlarged sample (chunk_size + neg_sample_size) used by
    # neg_deg_sample; it is only applied when neg_deg_sample is True.
    mask = F.ones((num_chunks, chunk_size * (neg_sample_size + chunk_size)),
                  dtype=F.float32, ctx=F.context(pos_g.ndata['emb']))
    if neg_g.neg_head:
        neg_head_ids = neg_g.ndata['id'][neg_g.head_nid]
        neg_head = self.entity_emb(neg_head_ids, gpu_id, trace)
        head_ids, tail_ids = pos_g.all_edges(order='eid')
        if to_device is not None and gpu_id >= 0:
            tail_ids = to_device(tail_ids, gpu_id)
        tail = pos_g.ndata['emb'][tail_ids]
        rel = pos_g.edata['emb']
        # When we train a batch, we could use the head nodes of the positive edges to
        # construct negative edges. We construct a negative edge between a positive head
        # node and every positive tail node.
        # When we construct negative edges like this, we know there is one positive
        # edge for a positive head node among the negative edges. We need to mask
        # them.
        if neg_deg_sample:
            head = pos_g.ndata['emb'][head_ids]
            head = head.reshape(num_chunks, chunk_size, -1)
            neg_head = neg_head.reshape(num_chunks, neg_sample_size, -1)
            # The chunk's own positive heads come first among the candidates.
            neg_head = F.cat([head, neg_head], 1)
            neg_sample_size = chunk_size + neg_sample_size
            # Edge i's own head is candidate i, i.e. flat position
            # i * (neg_sample_size + 1) of the chunk's (chunk_size, neg_sample_size) block.
            mask[:, 0::(neg_sample_size + 1)] = 0
        neg_head = neg_head.reshape(num_chunks * neg_sample_size, -1)
        neg_head, tail = self.head_neg_prepare(pos_g.edata['id'], num_chunks, neg_head, tail, gpu_id, trace)
        neg_score = self.head_neg_score(neg_head, rel, tail,
                                        num_chunks, chunk_size, neg_sample_size)
    else:
        neg_tail_ids = neg_g.ndata['id'][neg_g.tail_nid]
        neg_tail = self.entity_emb(neg_tail_ids, gpu_id, trace)
        head_ids, tail_ids = pos_g.all_edges(order='eid')
        if to_device is not None and gpu_id >= 0:
            head_ids = to_device(head_ids, gpu_id)
        head = pos_g.ndata['emb'][head_ids]
        rel = pos_g.edata['emb']
        # This is negative edge construction similar to the above.
        if neg_deg_sample:
            tail = pos_g.ndata['emb'][tail_ids]
            tail = tail.reshape(num_chunks, chunk_size, -1)
            neg_tail = neg_tail.reshape(num_chunks, neg_sample_size, -1)
            neg_tail = F.cat([tail, neg_tail], 1)
            neg_sample_size = chunk_size + neg_sample_size
            mask[:, 0::(neg_sample_size + 1)] = 0
        neg_tail = neg_tail.reshape(num_chunks * neg_sample_size, -1)
        head, neg_tail = self.tail_neg_prepare(pos_g.edata['id'], num_chunks, head, neg_tail, gpu_id, trace)
        neg_score = self.tail_neg_score(head, rel, neg_tail,
                                        num_chunks, chunk_size, neg_sample_size)
    if neg_deg_sample:
        # Callers (e.g. forward) read the enlarged size back from neg_g.
        neg_g.neg_sample_size = neg_sample_size
        mask = mask.reshape(num_chunks, chunk_size, neg_sample_size)
        return neg_score * mask
    else:
        return neg_score
def forward_test(self, pos_g, neg_g, logs, gpu_id=-1):
    """Score a batch of test edges and append per-edge ranking metrics to ``logs``.

    Each positive edge is ranked among its negative edges. One dict with
    'MRR', 'MR', 'HITS@1', 'HITS@3' and 'HITS@10' is appended per edge.

    Parameters
    ----------
    pos_g : DGLGraph
        Graph holding positive edges.
    neg_g : DGLGraph
        Graph holding negative edges.
    logs : List
        Where to put results in.
    gpu_id : int
        Which gpu to accelerate the calculation. if -1 is provided, cpu is used.
    """
    pos_g.ndata['emb'] = self.entity_emb(pos_g.ndata['id'], gpu_id, False)
    pos_g.edata['emb'] = self.relation_emb(pos_g.edata['id'], gpu_id, False)
    self.score_func.prepare(pos_g, gpu_id, False)

    n_edges = pos_g.number_of_edges()
    positive = reshape(logsigmoid(self.predict_score(pos_g)), n_edges, -1)
    raw_negative = self.predict_neg_score(pos_g, neg_g, to_device=cuda,
                                          gpu_id=gpu_id, trace=False,
                                          neg_deg_sample=self.args.neg_deg_sample_eval)
    negative = reshape(logsigmoid(raw_negative), n_edges, -1)

    # Filtered setting: add the per-negative-edge bias from neg_g. Presumably
    # it pushes negatives that are actually true edges out of contention —
    # confirm against the sampler.
    if self.args.eval_filter:
        bias = reshape(neg_g.edata['bias'], n_edges, -1)
        if gpu_id >= 0:
            bias = cuda(bias, gpu_id)
        negative += bias

    # Rank = 1 + number of negative edges scoring at least as high as the positive.
    ranks = F.asnumpy(F.sum(negative >= positive, dim=1) + 1)
    for rank in ranks:
        logs.append({
            'MRR': 1.0 / rank,
            'MR': float(rank),
            'HITS@1': 1.0 if rank <= 1 else 0.0,
            'HITS@3': 1.0 if rank <= 3 else 0.0,
            'HITS@10': 1.0 if rank <= 10 else 0.0
        })
# @profile
def forward(self, pos_g, neg_g, gpu_id=-1):
    """Do the forward pass for one training batch and compute the loss.

    Entity and relation rows are looked up with ``trace=True`` so that
    ``update`` can later apply their gradients.

    Parameters
    ----------
    pos_g : DGLGraph
        Graph holding positive edges.
    neg_g : DGLGraph
        Graph holding negative edges.
    gpu_id : int
        Which gpu to accelerate the calculation. if -1 is provided, cpu is used.

    Returns
    -------
    tensor
        loss value: ``-(mean logsigmoid(pos) + mean negative term) / 2``,
        plus the regularization term when it is enabled.
    dict
        loss info: 'pos_loss', 'neg_loss', 'loss' (recorded before
        regularization is added) and, when enabled, 'regularization'.
    """
    pos_g.ndata['emb'] = self.entity_emb(pos_g.ndata['id'], gpu_id, True)
    pos_g.edata['emb'] = self.relation_emb(pos_g.edata['id'], gpu_id, True)
    self.score_func.prepare(pos_g, gpu_id, True)
    pos_score = self.predict_score(pos_g)
    pos_score = logsigmoid(pos_score)
    if gpu_id >= 0:
        neg_score = self.predict_neg_score(pos_g, neg_g, to_device=cuda,
                                           gpu_id=gpu_id, trace=True,
                                           neg_deg_sample=self.args.neg_deg_sample)
    else:
        neg_score = self.predict_neg_score(pos_g, neg_g, trace=True,
                                           neg_deg_sample=self.args.neg_deg_sample)
    # predict_neg_score may enlarge neg_g.neg_sample_size (neg_deg_sample),
    # so it is read only after that call.
    neg_score = reshape(neg_score, -1, neg_g.neg_sample_size)
    # Adversarial sampling
    if self.args.neg_adversarial_sampling:
        # The softmax weights are detached: they rescale each negative term
        # but receive no gradient themselves.
        neg_score = F.sum(F.softmax(neg_score * self.args.adversarial_temperature, dim=1).detach()
                          * logsigmoid(-neg_score), dim=1)
    else:
        neg_score = F.mean(logsigmoid(-neg_score), dim=1)
    # subsampling weight
    # TODO: add subsampling to new sampler
    # if self.args.non_uni_weight:
    #     subsampling_weight = pos_g.edata['weight']
    #     pos_score = (pos_score * subsampling_weight).sum() / subsampling_weight.sum()
    #     neg_score = (neg_score * subsampling_weight).sum() / subsampling_weight.sum()
    # else:
    pos_score = pos_score.mean()
    neg_score = neg_score.mean()
    # compute loss
    loss = -(pos_score + neg_score) / 2
    log = {'pos_loss': - get_scalar(pos_score),
           'neg_loss': - get_scalar(neg_score),
           'loss': get_scalar(loss)}
    # regularization: TODO(zihao)
    # TODO: only reg ent&rel embeddings. other params to be added.
    if self.args.regularization_coef > 0.0 and self.args.regularization_norm > 0:
        coef, nm = self.args.regularization_coef, self.args.regularization_norm
        # Penalizes only the rows looked up (traced) in this batch.
        reg = coef * (norm(self.entity_emb.curr_emb(), nm) + norm(self.relation_emb.curr_emb(), nm))
        log['regularization'] = get_scalar(reg)
        loss = loss + reg
    return loss, log
def update(self, gpu_id=-1):
    """Apply pending gradients to the entity table, the relation table and the score function, in that order.

    gpu_id : int
        Which gpu to accelerate the calculation. if -1 is provided, cpu is used.
    """
    for component in ('entity_emb', 'relation_emb', 'score_func'):
        getattr(self, component).update(gpu_id)
def prepare_relation(self, device=None):
    """Create and initialize this process's local relation table (multi-process multi-gpu training).

    device : th.device
        Which device (GPU) to put relation embeddings in.
    """
    local_rel = ExternalEmbedding(self.args, self.n_relations, self.rel_dim, device)
    self.relation_emb = local_rel
    local_rel.init(self.emb_init)
    if self.model_name == 'TransR':
        projection = ExternalEmbedding(self.args, self.n_relations,
                                       self.entity_dim * self.rel_dim, device)
        self.score_func.prepare_local_emb(projection)
        self.score_func.reset_parameters()
def prepare_cross_rels(self, cross_rels):
    """Mark ``cross_rels`` in the local relation table as backed by the global relation table."""
    local_rel = self.relation_emb
    local_rel.setup_cross_rels(cross_rels, self.global_relation_emb)
    if self.model_name == 'TransR':
        self.score_func.prepare_cross_rels(cross_rels)
def writeback_relation(self, rank=0, rel_parts=None):
    """Copy this process's relation rows back into the global relation table.

    Used in multi-process multi-gpu training.

    rank : int
        Process id; selects ``rel_parts[rank]``.
    rel_parts : List of tensor
        List of tensors storing the edge types of each partition. It is
        indexed unconditionally, so it must be given despite the ``None``
        default.
    """
    owned = rel_parts[rank]
    if self.soft_rel_part:
        # Cross-partition rows are updated directly in the global table by
        # ExternalEmbedding.update, so only the non-cross rows are copied.
        owned = self.relation_emb.get_noncross_idx(owned)
    host_copy = F.copy_to(self.relation_emb.emb, F.cpu())
    self.global_relation_emb.emb[owned] = host_copy[owned]
    if self.model_name == 'TransR':
        self.score_func.writeback_local_emb(owned)
def load_relation(self, device=None):
    """Build this process's local relation table as a copy of the global one.

    Used in multi-process multi-gpu training.

    device : th.device
        Which device (GPU) to put relation embeddings in.
    """
    local_rel = ExternalEmbedding(self.args, self.n_relations, self.rel_dim, device)
    self.relation_emb = local_rel
    local_rel.emb = F.copy_to(self.global_relation_emb.emb, device)
    if self.model_name == 'TransR':
        projection = ExternalEmbedding(self.args, self.n_relations,
                                       self.entity_dim * self.rel_dim, device)
        self.score_func.load_local_emb(projection)
def create_async_update(self):
    """Start the asynchronous update subprocess for the entity table."""
    entity_table = self.entity_emb
    entity_table.create_async_update()
def finish_async_update(self):
    """Stop the entity table's asynchronous update subprocess and wait for it."""
    entity_table = self.entity_emb
    entity_table.finish_async_update()
def pull_model(self, client, pos_g, neg_g):
with th.no_grad():
entity_id = F.cat(seq=[pos_g.ndata['id'], neg_g.ndata['id']], dim=0)
relation_id = pos_g.edata['id']
entity_id = F.tensor(np.unique(F.asnumpy(entity_id)))
relation_id = F.tensor(np.unique(F.asnumpy(relation_id)))
l2g = client.get_local2global()
global_entity_id = l2g[entity_id]