Skip to content

Commit 34c7406

Browse files
committed
update asserts
1 parent ca5ed0a commit 34c7406

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

torch_geometric/nn/aggr/lstm.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,17 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
3737
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
3838
dim: int = -2) -> Tensor:
3939

40-
assert index is not None # TODO
41-
assert x.dim() == 2 and dim in [-2, 0]
40+
if index is None: # TODO
41+
raise NotImplementedError(f"'{self.__class__.__name__}' with "
42+
f"'ptr' not yet supported")
43+
44+
if x.dim() != 2:
45+
raise ValueError(f"'{self.__class__.__name__}' requires "
46+
f"two-dimensional inputs (got '{x.dim()}')")
47+
48+
if dim not in [-2, 0]:
49+
raise ValueError(f"'{self.__class__.__name__}' needs to perform "
50+
f"aggregation in first dimension (got '{dim}')")
4251

4352
x, _ = to_dense_batch(x, index, batch_size=dim_size)
4453
return self.lstm(x)[0][:, -1]

0 commit comments

Comments
 (0)