diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f80882306c8..c5b5b6d9ee6f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582)) - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581)) ### Changed +- Fixed `data.subgraph` generation for 0-dim tensors ([#4932](https://github.com/pyg-team/pytorch_geometric/pull/4932)) - Removed unnecssary inclusion of self-loops when sampling negative edges ([#4880](https://github.com/pyg-team/pytorch_geometric/pull/4880)) - Fixed `InMemoryDataset` inferring wrong `len` for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837)) - Fixed `Batch.separate` when using it for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837)) diff --git a/torch_geometric/data/storage.py b/torch_geometric/data/storage.py index 1b51571f277b..569c1de5585b 100644 --- a/torch_geometric/data/storage.py +++ b/torch_geometric/data/storage.py @@ -301,7 +301,7 @@ def is_node_attr(self, key: str) -> bool: cat_dim = self._parent().__cat_dim__(key, value, self) if not isinstance(value, Tensor): return False - if value.size(cat_dim) != self.num_nodes: + if value.dim() == 0 or value.size(cat_dim) != self.num_nodes: return False return True @@ -385,7 +385,7 @@ def is_edge_attr(self, key: str) -> bool: cat_dim = self._parent().__cat_dim__(key, value, self) if not isinstance(value, Tensor): return False - if value.size(cat_dim) != self.num_edges: + if value.dim() == 0 or value.size(cat_dim) != self.num_edges: return False return True @@ -466,7 +466,7 @@ def is_node_attr(self, key: str) -> bool: num_nodes, num_edges = self.num_nodes, self.num_edges if not isinstance(value, Tensor): return False - if value.size(cat_dim) != num_nodes: + if value.dim() == 0 or value.size(cat_dim) != num_nodes: return False if num_nodes != num_edges: return True @@ -479,7 +479,7 @@ def is_edge_attr(self, key: str) -> bool: num_nodes, num_edges = self.num_nodes, self.num_edges if not isinstance(value, Tensor): return False - if value.size(cat_dim) != num_edges: + if value.dim() == 0 or value.size(cat_dim) != num_edges: return False if num_nodes != num_edges: return True diff --git a/torch_geometric/loader/cluster.py b/torch_geometric/loader/cluster.py index 7e6c9392c8f1..5b86e9d3b508 100644 --- a/torch_geometric/loader/cluster.py +++ b/torch_geometric/loader/cluster.py @@ -60,17 +60,15 @@ def __init__(self, data, num_parts: int, recursive: bool = False, self.perm = perm def __permute_data__(self, data, node_idx, adj): - data = copy.copy(data) - N = data.num_nodes + out = copy.copy(data) + for key, value in data.items(): + if data.is_node_attr(key): + out[key] = value[node_idx] - for key, item in data: - if isinstance(item, torch.Tensor) and item.size(0) == N: - data[key] = item[node_idx] + out.edge_index = None + out.adj = adj - data.edge_index = None - data.adj = adj - - return data + return out def __len__(self): return self.partptr.numel() - 1