
Commit 6f120ff

Fixed an error in generated node features in StochasticBlockModelDataset (#4617)
* Fix an error in sbm node features (x) - If n_clusters_per_class in make_classification is bigger than 1, its output is not sorted with respect to the labels (y).
* Updated changelog

Co-authored-by: Matthias Fey <[email protected]>
1 parent 405b2ba commit 6f120ff

File tree

2 files changed (+4, -1 lines)


CHANGELOG.md

+1
@@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
 - Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
 ### Changed
+- The generated node features of `StochasticBlockModelDataset` are now ordered with respect to their labels ([#4617](https://github.com/pyg-team/pytorch_geometric/pull/4617))
 - Removed unnecessary colons and fixed typos in the documentation ([#4616](https://github.com/pyg-team/pytorch_geometric/pull/4616))
 - The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))
 - Fixed subclass behaviour of `process` and `download` in `Dataset` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586))

torch_geometric/datasets/sbm_dataset.py

+3 -1
@@ -1,6 +1,7 @@
 import os
 from typing import Callable, List, Optional, Union
 
+import numpy as np
 import torch
 from torch import Tensor
 
@@ -94,13 +95,14 @@ def process(self):
 
         x = None
         if self.num_channels is not None:
-            x, _ = make_classification(
+            x, y_not_sorted = make_classification(
                 n_samples=num_samples,
                 n_features=self.num_channels,
                 n_classes=num_classes,
                 weights=self.block_sizes / num_samples,
                 **self.kwargs,
             )
+            x = x[np.argsort(y_not_sorted)]
             x = torch.from_numpy(x).to(torch.float)
 
         y = torch.arange(num_classes).repeat_interleave(self.block_sizes)
