-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
Copy pathhgt_dblp.py
88 lines (66 loc) · 2.61 KB
/
hgt_dblp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os.path as osp
import torch
import torch.nn.functional as F
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
from torch_geometric.nn import HGTConv, Linear
# Locate the dataset directory relative to this script file rather than the
# current working directory.
path = osp.join(
    osp.dirname(osp.realpath(__file__)),
    '../../data/DBLP',
)

# Give every 'conference' node a constant feature of ones via T.Constant:
dataset = DBLP(root=path, transform=T.Constant(node_types='conference'))
data = dataset[0]
class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer that classifies nodes of one type.

    Each node type is first projected to ``hidden_channels`` by its own
    lazily-initialised ``Linear(-1, hidden_channels)`` layer (the input size
    is inferred on the first forward pass). The result goes through
    ``num_layers`` ``HGTConv`` layers. A final linear layer then maps the
    embeddings of ``target_node_type`` nodes to ``out_channels`` logits.

    Args:
        hidden_channels (int): Width of the hidden node embeddings.
        out_channels (int): Number of output logits per target node.
        num_heads (int): Number of attention heads in each ``HGTConv``.
        num_layers (int): Number of stacked ``HGTConv`` layers.
        metadata (tuple, optional): ``(node_types, edge_types)`` of the
            graph, as returned by ``HeteroData.metadata()``. If ``None``,
            the metadata of the module-level ``data`` object is used, which
            keeps the original behaviour of this example.
        target_node_type (str, optional): Node type whose final embeddings
            are classified. Defaults to ``'author'``.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers,
                 metadata=None, target_node_type='author'):
        super().__init__()
        if metadata is None:
            # Backward-compatible fallback: the original constructor always
            # read the graph loaded at module level.
            metadata = data.metadata()
        node_types = metadata[0]
        self.target_node_type = target_node_type

        # One input projection per node type; -1 defers the input size to
        # the first forward pass.
        self.lin_dict = torch.nn.ModuleDict()
        for node_type in node_types:
            self.lin_dict[node_type] = Linear(-1, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HGTConv(hidden_channels, hidden_channels, metadata,
                           num_heads)
            self.convs.append(conv)

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Return logits for the ``target_node_type`` nodes.

        Args:
            x_dict (dict): Maps each node type to its input feature tensor.
            edge_index_dict (dict): Maps each edge type to its edge index.

        Returns:
            Tensor: Output of the final linear layer applied to the
            ``target_node_type`` embeddings after the last ``HGTConv``.
        """
        # Project every node type and apply ReLU in place.
        x_dict = {
            node_type: self.lin_dict[node_type](x).relu_()
            for node_type, x in x_dict.items()
        }
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
        return self.lin(x_dict[self.target_node_type])
model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)

# Prefer CUDA, then Intel XPU, and fall back to the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'xpu' if torch_geometric.is_xpu_available()
    else 'cpu'
)
data = data.to(device)
model = model.to(device)

# HGT's input projections are lazy ``Linear(-1, ...)`` layers. Their weights
# only get concrete shapes after one forward pass, so run a gradient-free
# pass before handing the parameters to the optimizer.
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)
def train():
    """Perform one full-batch training step on the author nodes.

    Runs the model over the whole graph and computes cross-entropy on the
    ``author`` nodes selected by ``train_mask``. Then takes one optimizer
    step.

    Returns:
        float: The training loss, computed before the parameter update.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(data.x_dict, data.edge_index_dict)
    authors = data['author']
    train_mask = authors.train_mask
    loss = F.cross_entropy(logits[train_mask], authors.y[train_mask])

    loss.backward()
    optimizer.step()
    return float(loss)
@torch.no_grad()
def test():
    """Measure prediction accuracy of the author nodes on every split.

    Returns:
        list[float]: Accuracy on the ``train_mask``, ``val_mask`` and
        ``test_mask`` author nodes, in that order.
    """
    model.eval()
    logits = model(data.x_dict, data.edge_index_dict)
    pred = logits.argmax(dim=-1)
    authors = data['author']

    def split_accuracy(mask):
        # Correct predictions divided by mask.sum(). This presumes a boolean
        # node mask; confirm against the dataset if that changes.
        num_correct = (pred[mask] == authors.y[mask]).sum()
        return float(num_correct / mask.sum())

    return [split_accuracy(authors[split])
            for split in ('train_mask', 'val_mask', 'test_mask')]
# Run 100 epochs. Each epoch takes one training step and then reports the
# loss and the per-split accuracies.
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    message = (
        f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
        f'Val: {val_acc:.4f}, Test: {test_acc:.4f}'
    )
    print(message)