Skip to content

Commit be2a463

Browse files
authored
Fix the interplay between TUDataset and pre_transform that modify node features (#4669)
* fix num node attrs * changelog * typo
1 parent 6156650 commit be2a463

File tree

3 files changed

+34
-26
lines changed

3 files changed

+34
-26
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2323
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
2424
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
2525
### Changed
26+
- Fixed the interplay between `TUDataset` and `pre_transform` that modify node features ([#4669](https://github.com/pyg-team/pytorch_geometric/pull/4669))
2627
- Make use of the `pyg_sphinx_theme` documentation template ([#4664](https://github.com/pyg-team/pyg-lib/pull/4664), [#4667](https://github.com/pyg-team/pyg-lib/pull/4667))
2728
- Refactored reading molecular positions from sdf file for qm9 datasets ([4654](https://github.com/pyg-team/pytorch_geometric/pull/4654))
2829
- Fixed `MLP.jittable()` bug in case `return_emb=True` ([#4645](https://github.com/pyg-team/pytorch_geometric/pull/4645), [#4648](https://github.com/pyg-team/pytorch_geometric/pull/4648))

torch_geometric/datasets/tu_dataset.py

+16-22
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,16 @@ def __init__(self, root: str, name: str,
121121
self.name = name
122122
self.cleaned = cleaned
123123
super().__init__(root, transform, pre_transform, pre_filter)
124-
self.data, self.slices = torch.load(self.processed_paths[0])
124+
125+
out = torch.load(self.processed_paths[0])
126+
if not isinstance(out, tuple) and len(out) != 3:
127+
raise RuntimeError(
128+
"The 'data' object was created by an older version of PyG. "
129+
"If this error occurred while loading an already existing "
130+
"dataset, remove the 'processed/' directory in the dataset's "
131+
"root folder and try again.")
132+
self.data, self.slices, self.sizes = out
133+
125134
if self.data.x is not None and not use_node_attr:
126135
num_node_attributes = self.num_node_attributes
127136
self.data.x = self.data.x[:, num_node_attributes:]
@@ -141,34 +150,19 @@ def processed_dir(self) -> str:
141150

142151
@property
143152
def num_node_labels(self) -> int:
144-
if self.data.x is None:
145-
return 0
146-
for i in range(self.data.x.size(1)):
147-
x = self.data.x[:, i:]
148-
if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all():
149-
return self.data.x.size(1) - i
150-
return 0
153+
return self.sizes['num_node_labels']
151154

152155
@property
153156
def num_node_attributes(self) -> int:
154-
if self.data.x is None:
155-
return 0
156-
return self.data.x.size(1) - self.num_node_labels
157+
return self.sizes['num_node_attributes']
157158

158159
@property
159160
def num_edge_labels(self) -> int:
160-
if self.data.edge_attr is None:
161-
return 0
162-
for i in range(self.data.edge_attr.size(1)):
163-
if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0):
164-
return self.data.edge_attr.size(1) - i
165-
return 0
161+
return self.sizes['num_edge_labels']
166162

167163
@property
168164
def num_edge_attributes(self) -> int:
169-
if self.data.edge_attr is None:
170-
return 0
171-
return self.data.edge_attr.size(1) - self.num_edge_labels
165+
return self.sizes['num_edge_attributes']
172166

173167
@property
174168
def raw_file_names(self) -> List[str]:
@@ -189,7 +183,7 @@ def download(self):
189183
os.rename(osp.join(folder, self.name), self.raw_dir)
190184

191185
def process(self):
192-
self.data, self.slices = read_tu_data(self.raw_dir, self.name)
186+
self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)
193187

194188
if self.pre_filter is not None:
195189
data_list = [self.get(idx) for idx in range(len(self))]
@@ -201,7 +195,7 @@ def process(self):
201195
data_list = [self.pre_transform(data) for data in data_list]
202196
self.data, self.slices = self.collate(data_list)
203197

204-
torch.save((self.data, self.slices), self.processed_paths[0])
198+
torch.save((self.data, self.slices, sizes), self.processed_paths[0])
205199

206200
def __repr__(self) -> str:
207201
return f'{self.name}({len(self)})'

torch_geometric/io/tu.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ def read_tu_data(folder, prefix):
2424
edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1
2525
batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1
2626

27-
node_attributes = node_labels = None
27+
node_attributes = torch.empty((batch.size(0), 0))
2828
if 'node_attributes' in names:
2929
node_attributes = read_file(folder, prefix, 'node_attributes')
30+
31+
node_labels = torch.empty((batch.size(0), 0))
3032
if 'node_labels' in names:
3133
node_labels = read_file(folder, prefix, 'node_labels', torch.long)
3234
if node_labels.dim() == 1:
@@ -35,11 +37,12 @@ def read_tu_data(folder, prefix):
3537
node_labels = node_labels.unbind(dim=-1)
3638
node_labels = [F.one_hot(x, num_classes=-1) for x in node_labels]
3739
node_labels = torch.cat(node_labels, dim=-1).to(torch.float)
38-
x = cat([node_attributes, node_labels])
3940

40-
edge_attributes, edge_labels = None, None
41+
edge_attributes = torch.empty((edge_index.size(1), 0))
4142
if 'edge_attributes' in names:
4243
edge_attributes = read_file(folder, prefix, 'edge_attributes')
44+
45+
edge_labels = torch.empty((edge_index.size(1), 0))
4346
if 'edge_labels' in names:
4447
edge_labels = read_file(folder, prefix, 'edge_labels', torch.long)
4548
if edge_labels.dim() == 1:
@@ -48,6 +51,8 @@ def read_tu_data(folder, prefix):
4851
edge_labels = edge_labels.unbind(dim=-1)
4952
edge_labels = [F.one_hot(e, num_classes=-1) for e in edge_labels]
5053
edge_labels = torch.cat(edge_labels, dim=-1).to(torch.float)
54+
55+
x = cat([node_attributes, node_labels])
5156
edge_attr = cat([edge_attributes, edge_labels])
5257

5358
y = None
@@ -65,7 +70,14 @@ def read_tu_data(folder, prefix):
6570
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
6671
data, slices = split(data, batch)
6772

68-
return data, slices
73+
sizes = {
74+
'num_node_attributes': node_attributes.size(-1),
75+
'num_node_labels': node_labels.size(-1),
76+
'num_edge_attributes': edge_attributes.size(-1),
77+
'num_edge_labels': edge_labels.size(-1),
78+
}
79+
80+
return data, slices, sizes
6981

7082

7183
def read_file(folder, prefix, name, dtype=None):
@@ -75,6 +87,7 @@ def read_file(folder, prefix, name, dtype=None):
7587

7688
def cat(seq):
7789
seq = [item for item in seq if item is not None]
90+
seq = [item for item in seq if item.numel() > 0]
7891
seq = [item.unsqueeze(-1) if item.dim() == 1 else item for item in seq]
7992
return torch.cat(seq, dim=-1) if len(seq) > 0 else None
8093

0 commit comments

Comments
 (0)