diff --git a/CHANGELOG.md b/CHANGELOG.md index bd0d1e2f8305..17065c8be284 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- Fixed the interplay between `TUDataset` and `pre_transform` that modify node features ([#4669](https://github.com/pyg-team/pytorch_geometric/pull/4669)) - Make use of the `pyg_sphinx_theme` documentation template ([#4664](https://github.com/pyg-team/pyg-lib/pull/4664), [#4667](https://github.com/pyg-team/pyg-lib/pull/4667)) - Refactored reading molecular positions from sdf file for qm9 datasets ([4654](https://github.com/pyg-team/pytorch_geometric/pull/4654)) - Fixed `MLP.jittable()` bug in case `return_emb=True` ([#4645](https://github.com/pyg-team/pytorch_geometric/pull/4645), [#4648](https://github.com/pyg-team/pytorch_geometric/pull/4648)) diff --git a/torch_geometric/datasets/tu_dataset.py b/torch_geometric/datasets/tu_dataset.py index 4f948f207cdb..88e2af51af63 100644 --- a/torch_geometric/datasets/tu_dataset.py +++ b/torch_geometric/datasets/tu_dataset.py @@ -121,7 +121,16 @@ def __init__(self, root: str, name: str, self.name = name self.cleaned = cleaned super().__init__(root, transform, pre_transform, pre_filter) - self.data, self.slices = torch.load(self.processed_paths[0]) + + out = torch.load(self.processed_paths[0]) + if not isinstance(out, tuple) and len(out) != 3: + raise RuntimeError( + "The 'data' object was created by an older version of PyG. " + "If this error occurred while loading an already existing " + "dataset, remove the 'processed/' directory in the dataset's " + "root folder and try again.") + self.data, self.slices, self.sizes = out + if self.data.x is not None and not use_node_attr: num_node_attributes = self.num_node_attributes self.data.x = self.data.x[:, num_node_attributes:] @@ -141,34 +150,19 @@ def processed_dir(self) -> str: @property def num_node_labels(self) -> int: - if self.data.x is None: - return 0 - for i in range(self.data.x.size(1)): - x = self.data.x[:, i:] - if ((x == 0) | (x == 1)).all() and (x.sum(dim=1) == 1).all(): - return self.data.x.size(1) - i - return 0 + return self.sizes['num_node_labels'] @property def num_node_attributes(self) -> int: - if self.data.x is None: - return 0 - return self.data.x.size(1) - self.num_node_labels + return self.sizes['num_node_attributes'] @property def num_edge_labels(self) -> int: - if self.data.edge_attr is None: - return 0 - for i in range(self.data.edge_attr.size(1)): - if self.data.edge_attr[:, i:].sum() == self.data.edge_attr.size(0): - return self.data.edge_attr.size(1) - i - return 0 + return self.sizes['num_edge_labels'] @property def num_edge_attributes(self) -> int: - if self.data.edge_attr is None: - return 0 - return self.data.edge_attr.size(1) - self.num_edge_labels + return self.sizes['num_edge_attributes'] @property def raw_file_names(self) -> List[str]: @@ -189,7 +183,7 @@ def download(self): os.rename(osp.join(folder, self.name), self.raw_dir) def process(self): - self.data, self.slices = read_tu_data(self.raw_dir, self.name) + self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name) if self.pre_filter is not None: data_list = [self.get(idx) for idx in range(len(self))] @@ -201,7 +195,7 @@ def process(self): data_list = [self.pre_transform(data) for data in data_list] self.data, self.slices = self.collate(data_list) - torch.save((self.data, self.slices), self.processed_paths[0]) + torch.save((self.data, self.slices, sizes), self.processed_paths[0]) def __repr__(self) -> str: return f'{self.name}({len(self)})' diff --git a/torch_geometric/io/tu.py b/torch_geometric/io/tu.py index 0483c58fd9e9..4d85d8ea05cc 100644 --- a/torch_geometric/io/tu.py +++ b/torch_geometric/io/tu.py @@ -24,9 +24,11 @@ def read_tu_data(folder, prefix): edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1 batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1 - node_attributes = node_labels = None + node_attributes = torch.empty((batch.size(0), 0)) if 'node_attributes' in names: node_attributes = read_file(folder, prefix, 'node_attributes') + + node_labels = torch.empty((batch.size(0), 0)) if 'node_labels' in names: node_labels = read_file(folder, prefix, 'node_labels', torch.long) if node_labels.dim() == 1: @@ -35,11 +37,12 @@ def read_tu_data(folder, prefix): node_labels = node_labels.unbind(dim=-1) node_labels = [F.one_hot(x, num_classes=-1) for x in node_labels] node_labels = torch.cat(node_labels, dim=-1).to(torch.float) - x = cat([node_attributes, node_labels]) - edge_attributes, edge_labels = None, None + edge_attributes = torch.empty((edge_index.size(1), 0)) if 'edge_attributes' in names: edge_attributes = read_file(folder, prefix, 'edge_attributes') + + edge_labels = torch.empty((edge_index.size(1), 0)) if 'edge_labels' in names: edge_labels = read_file(folder, prefix, 'edge_labels', torch.long) if edge_labels.dim() == 1: @@ -48,6 +51,8 @@ def read_tu_data(folder, prefix): edge_labels = edge_labels.unbind(dim=-1) edge_labels = [F.one_hot(e, num_classes=-1) for e in edge_labels] edge_labels = torch.cat(edge_labels, dim=-1).to(torch.float) + + x = cat([node_attributes, node_labels]) edge_attr = cat([edge_attributes, edge_labels]) y = None @@ -65,7 +70,14 @@ def read_tu_data(folder, prefix): data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) data, slices = split(data, batch) - return data, slices + sizes = { + 'num_node_attributes': node_attributes.size(-1), + 'num_node_labels': node_labels.size(-1), + 'num_edge_attributes': edge_attributes.size(-1), + 'num_edge_labels': edge_labels.size(-1), + } + + return data, slices, sizes def read_file(folder, prefix, name, dtype=None): @@ -75,6 +87,7 @@ def read_file(folder, prefix, name, dtype=None): def cat(seq): seq = [item for item in seq if item is not None] + seq = [item for item in seq if item.numel() > 0] seq = [item.unsqueeze(-1) if item.dim() == 1 else item for item in seq] return torch.cat(seq, dim=-1) if len(seq) > 0 else None