Skip to content

Commit 091887b

Browse files
committedFeb 7, 2025
Clarify input tensor shape in pixelshuffle and pixelunshuffle functions and simplify ValueError message in pixelunshuffle
1 parent 61efefb commit 091887b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed
 

‎monai/networks/utils.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch
377377
See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
378378
379379
Args:
380-
x: Input tensor
380+
x: Input tensor with shape BCHW[D]
381381
spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
382382
scale_factor: factor to rescale the spatial dimensions by, must be >=1
383383
@@ -423,7 +423,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor
423423
See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
424424
425425
Args:
426-
x: Input tensor
426+
x: Input tensor with shape BCHW[D]
427427
spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
428428
scale_factor: factor to reduce the spatial dimensions by, must be >=1
429429
@@ -443,7 +443,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor
443443

444444
if any(d % factor != 0 for d in input_size[2:]):
445445
raise ValueError(
446-
f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
446+
f"All spatial dimensions must be divisible by factor {factor}. " f", spatial shape is: {input_size[2:]}"
447447
)
448448
output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]]
449449
reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], [])

0 commit comments

Comments
 (0)
Please sign in to comment.