@@ -377,7 +377,7 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
     See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
 
     Args:
-        x: Input tensor
+        x: Input tensor with shape BCHW[D]
         spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
         scale_factor: factor to rescale the spatial dimensions by, must be >=1
 
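For context on the new BCHW[D] wording: for a scale factor r, pixelshuffle rearranges a tensor with C * r**spatial_dims channels into one with C channels and each spatial dimension multiplied by r. A quick 2D shape check using PyTorch's built-in pixel_shuffle (the spatial_dims=2 case), shown as an illustration rather than the MONAI function itself:

    import torch
    import torch.nn.functional as F

    # BCHW example of the documented convention: for scale_factor r,
    # (B, C*r**2, H, W) -> (B, C, H*r, W*r).
    r = 2
    x = torch.randn(1, 4 * r**2, 8, 8)        # 16 input channels
    y = F.pixel_shuffle(x, upscale_factor=r)
    print(y.shape)                            # torch.Size([1, 4, 16, 16])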
@@ -423,7 +423,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
     See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".
 
     Args:
-        x: Input tensor
+        x: Input tensor with shape BCHW[D]
         spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
         scale_factor: factor to reduce the spatial dimensions by, must be >=1
 
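pixelunshuffle is the inverse rearrangement, so the same BCHW[D] note applies with each spatial dimension divided by the factor. For the 2D case this matches PyTorch's built-in pixel_unshuffle, which round-trips exactly with pixel_shuffle; a small sketch using the stock PyTorch ops rather than the MONAI wrappers:

    import torch
    import torch.nn.functional as F

    r = 2
    x = torch.randn(2, 3, 16, 16)                    # BCHW, H and W divisible by r
    down = F.pixel_unshuffle(x, downscale_factor=r)  # -> (2, 3*r**2, 8, 8) == (2, 12, 8, 8)
    up = F.pixel_shuffle(down, upscale_factor=r)     # -> (2, 3, 16, 16)
    assert torch.equal(up, x)                        # pure rearrangement, exact round-trip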
@@ -443,7 +443,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
 
     if any(d % factor != 0 for d in input_size[2:]):
         raise ValueError(
-            f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
+            f"All spatial dimensions must be divisible by factor {factor}. " f", spatial shape is: {input_size[2:]}"
         )
     output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]]
     reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], [])
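The hunk above only shows the shape bookkeeping: the divisibility check plus the output_size and reshaped_size lists (batch_size, channels, new_channels and factor are defined earlier in the function, outside this hunk). Below is a rough standalone sketch of how those sizes can drive an N-dimensional unshuffle; the helper name and the exact permutation order are illustrative assumptions, not MONAI's verbatim implementation:

    import torch

    def pixelunshuffle_sketch(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor:
        # Illustrative N-D unshuffle following the size bookkeeping in the hunk above.
        f = scale_factor
        input_size = list(x.size())
        batch_size, channels = input_size[:2]
        new_channels = channels * f**spatial_dims
        if any(d % f != 0 for d in input_size[2:]):
            raise ValueError(f"All spatial dimensions must be divisible by factor {f}, got {input_size[2:]}")
        output_size = [batch_size, new_channels] + [d // f for d in input_size[2:]]
        reshaped_size = [batch_size, channels] + sum([[d // f, f] for d in input_size[2:]], [])
        x = x.reshape(reshaped_size)
        # Gather every per-block factor axis next to the channels, then fold them in.
        factor_axes = [3 + 2 * i for i in range(spatial_dims)]   # the 'factor' axes
        block_axes = [2 + 2 * i for i in range(spatial_dims)]    # the 'd // factor' axes
        x = x.permute([0, 1] + factor_axes + block_axes)
        return x.reshape(output_size)

    # Shape check for a BCHWD input with scale_factor=2:
    x = torch.randn(1, 1, 4, 4, 4)
    print(pixelunshuffle_sketch(x, 3, 2).shape)   # torch.Size([1, 8, 2, 2, 2])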