From 3db93ce8e151f298609061a3e0589cf485baa8aa Mon Sep 17 00:00:00 2001 From: tisalon Date: Wed, 15 Jan 2025 09:18:08 +0100 Subject: [PATCH 01/67] Add new pixel unshuffle for SubPixelDownsample class --- monai/networks/utils.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 1b4cb220ae..0e198da7a9 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -49,6 +49,7 @@ "normal_init", "icnr_init", "pixelshuffle", + "pixelunshuffle", "eval_mode", "train_mode", "get_state_dict", @@ -411,6 +412,36 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch return x +def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch.Tensor: + """ + Apply pixel unshuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`. + Inverse operation of pixelshuffle. + + Args: + x: Input tensor + spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D + scale_factor: factor to reduce the spatial dimensions by, must be >=1 + + Returns: + Unshuffled version of `x`. + """ + dim, factor = spatial_dims, scale_factor + input_size = list(x.size()) + batch_size, channels = input_size[:2] + + output_channels = channels * (factor**dim) + output_spatial = [d // factor for d in input_size[2:]] + output_size = [batch_size, output_channels] + output_spatial + + x = x.reshape([batch_size, channels] + [factor] * dim + output_spatial) + + indices = list(range(2, 2 + 2 * dim)) + indices = indices[dim:] + indices[:dim] + permute_indices = [0, 1] + indices + + x = x.permute(permute_indices).reshape(output_size) + return x + @contextmanager def eval_mode(*nets: nn.Module): """ From 9693e04eff628416061606a332c2f849f05316ea Mon Sep 17 00:00:00 2001 From: tisalon Date: Wed, 15 Jan 2025 09:18:20 +0100 Subject: [PATCH 02/67] Add unit test for pixelunshuffle --- tests/test_pixelunshuffle.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/test_pixelunshuffle.py diff --git a/tests/test_pixelunshuffle.py b/tests/test_pixelunshuffle.py new file mode 100644 index 0000000000..df40415721 --- /dev/null +++ b/tests/test_pixelunshuffle.py @@ -0,0 +1,35 @@ +import unittest +import torch +from monai.networks.utils import pixelunshuffle, pixelshuffle + +class TestPixelUnshuffle(unittest.TestCase): + + def test_2d_basic(self): + x = torch.randn(2, 4, 16, 16) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + self.assertEqual(out.shape, (2, 16, 8, 8)) + + def test_3d_basic(self): + x = torch.randn(2, 4, 16, 16, 16) + out = pixelunshuffle(x, spatial_dims=3, scale_factor=2) + self.assertEqual(out.shape, (2, 32, 8, 8, 8)) + + def test_inverse_pixelshuffle(self): + x = torch.randn(2, 4, 16, 16) + shuffled = pixelshuffle(x, spatial_dims=2, scale_factor=2) + unshuffled = pixelunshuffle(shuffled, spatial_dims=2, scale_factor=2) + torch.testing.assert_close(x, unshuffled) + + def test_compare_torch_pixel_unshuffle(self): + x = torch.randn(2, 4, 16, 16) + monai_out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + torch_out = torch.pixel_unshuffle(x, downscale_factor=2) + torch.testing.assert_close(monai_out, torch_out) + + def test_invalid_scale(self): + x = torch.randn(2, 4, 15, 15) + with self.assertRaises(RuntimeError): + pixelunshuffle(x, spatial_dims=2, scale_factor=2) + +if __name__ == "__main__": + unittest.main() From a89f2995f600952167014f7bafceecbd243d18fb Mon Sep 17 00:00:00 2001 From: tisalon Date: Wed, 15 Jan 2025 09:18:37 +0100 Subject: [PATCH 03/67] Add DownSample Modes --- monai/utils/enums.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 1fbf3ffa05..74ae829afd 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -24,6 +24,7 @@ "SplineMode", "InterpolateMode", "UpsampleMode", + "DownsampleMode", "BlendMode", "PytorchPadMode", "NdimageMode", @@ -181,6 +182,18 @@ class UpsampleMode(StrEnum): PIXELSHUFFLE = "pixelshuffle" +class DownsampleMode(StrEnum): + """ + See also: :py:class:`monai.networks.blocks.UpSample` + """ + + CONV = "conv" # e.g. using strided convolution + CONVGROUP = "convgroup" # e.g. using grouped strided convolution + PIXELUNSHUFFLE = "pixelunshuffle" + MAXPOOL = "maxpool" + AVGPOOL = "avgpool" + + class BlendMode(StrEnum): """ See also: :py:class:`monai.data.utils.compute_importance_map` From 450691f225d1860cd331f055daee506f398b0df6 Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 13:53:37 +0100 Subject: [PATCH 04/67] expand pixelunshuffle for 3D --- monai/networks/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 0e198da7a9..1bbb7daa0f 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -436,7 +436,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor x = x.reshape([batch_size, channels] + [factor] * dim + output_spatial) indices = list(range(2, 2 + 2 * dim)) - indices = indices[dim:] + indices[:dim] + indices = indices[:dim] + indices[dim:] permute_indices = [0, 1] + indices x = x.permute(permute_indices).reshape(output_size) From d0920d85a6cb334cc8434f2affe9d2f4bc293bbb Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 13:54:07 +0100 Subject: [PATCH 05/67] increase testing for pixelunshuffle --- tests/test_pixelunshuffle.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/test_pixelunshuffle.py b/tests/test_pixelunshuffle.py index df40415721..9450aaf273 100644 --- a/tests/test_pixelunshuffle.py +++ b/tests/test_pixelunshuffle.py @@ -14,17 +14,21 @@ def test_3d_basic(self): out = pixelunshuffle(x, spatial_dims=3, scale_factor=2) self.assertEqual(out.shape, (2, 32, 8, 8, 8)) - def test_inverse_pixelshuffle(self): - x = torch.randn(2, 4, 16, 16) - shuffled = pixelshuffle(x, spatial_dims=2, scale_factor=2) - unshuffled = pixelunshuffle(shuffled, spatial_dims=2, scale_factor=2) - torch.testing.assert_close(x, unshuffled) + def test_non_square_input(self): + x = torch.arange(192).reshape(1, 2, 12, 8) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + torch.testing.assert_close(out, torch.pixel_unshuffle(x, 2)) - def test_compare_torch_pixel_unshuffle(self): - x = torch.randn(2, 4, 16, 16) - monai_out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) - torch_out = torch.pixel_unshuffle(x, downscale_factor=2) - torch.testing.assert_close(monai_out, torch_out) + def test_different_scale_factor(self): + x = torch.arange(360).reshape(1, 2, 12, 15) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=3) + torch.testing.assert_close(out, torch.pixel_unshuffle(x, 3)) + + def test_inverse_operation(self): + x = torch.arange(4096).reshape(1, 8, 8, 8, 8) + shuffled = pixelshuffle(x, spatial_dims=3, scale_factor=2) + unshuffled = pixelunshuffle(shuffled, spatial_dims=3, scale_factor=2) + torch.testing.assert_close(x, unshuffled) def test_invalid_scale(self): x = torch.randn(2, 4, 15, 15) From 1a48d4de61fbe239b8f9f2321c25e988a1ec550b Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 13:55:14 +0100 Subject: [PATCH 06/67] expand pixelunshuffle for 3D images --- monai/networks/utils.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 1bbb7daa0f..946bb6b824 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -417,29 +417,41 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor Apply pixel unshuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`. Inverse operation of pixelshuffle. + See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution + Using an Efficient Sub-Pixel Convolutional Neural Network." + + See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". + Args: x: Input tensor spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D scale_factor: factor to reduce the spatial dimensions by, must be >=1 Returns: - Unshuffled version of `x`. + Unshuffled version of `x` with shape (B, C*(r**d), H/r, W/r) for 2D + or (B, C*(r**d), D/r, H/r, W/r) for 3D, where r is the scale_factor + and d is spatial_dims. + + Raises: + ValueError: When spatial dimensions are not divisible by scale_factor """ dim, factor = spatial_dims, scale_factor input_size = list(x.size()) batch_size, channels = input_size[:2] - - output_channels = channels * (factor**dim) - output_spatial = [d // factor for d in input_size[2:]] - output_size = [batch_size, output_channels] + output_spatial + scale_factor_mult = factor**dim + new_channels = channels * scale_factor_mult - x = x.reshape([batch_size, channels] + [factor] * dim + output_spatial) - - indices = list(range(2, 2 + 2 * dim)) - indices = indices[:dim] + indices[dim:] - permute_indices = [0, 1] + indices + if any(d % factor != 0 for d in input_size[2:]): + raise ValueError( + f"All spatial dimensions must be divisible by factor {factor}. " + f"Got spatial dimensions: {input_size[2:]}" + ) + output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]] + reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], []) - x = x.permute(permute_indices).reshape(output_size) + permute_indices = [0, 1] + [(2 * i + 3) for i in range(spatial_dims)] + [(2 * i + 2) for i in range(spatial_dims)] + x=x.reshape(reshaped_size).permute(permute_indices) + x=x.reshape(output_size) return x @contextmanager From fe47807da520b1dd0f93dc7fc719eb0e092296ff Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 14:28:36 +0100 Subject: [PATCH 07/67] add SubpixelDownsample and tests --- monai/networks/blocks/downsample.py | 88 ++++++++++++++++++++++++++++- tests/test_downsample_block.py | 69 +++++++++++++++++++++- 2 files changed, 154 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 2a6a60ff8a..37c2f9f93c 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -16,9 +16,11 @@ import torch import torch.nn as nn -from monai.networks.layers.factories import Pool -from monai.utils import ensure_tuple_rep +from monai.networks.layers.factories import Conv, Pool +from monai.networks.utils import pixelunshuffle +from monai.utils import InterpolateMode, DownsampleMode, ensure_tuple_rep, look_up_option +__all__ = ["MaxAvgPool", "DownSample", "SubpixelDownsample"] class MaxAvgPool(nn.Module): """ @@ -61,3 +63,85 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]). """ return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1) + + +class SubpixelDownsample(nn.Module): + """ + Downsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images. + The module consists of two parts. First, a convolutional layer is employed + to adjust the number of channels. Secondly, a pixel unshuffle manipulation + rearranges the spatial information into channel space, effectively reducing + spatial dimensions while increasing channel depth. + + The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions + from (B, C, H*r, W*r) to (B, C*r², H, W). + Example: (1, 1, 4, 4) with r=2 becomes (1, 4, 2, 2). + + See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution + Using a nEfficient Sub-Pixel Convolutional Neural Network." + + The pixel unshuffle mechanism is the inverse operation of: + https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int | None, + out_channels: int | None = None, + scale_factor: int = 2, + conv_block: nn.Module | str | None = "default", + bias: bool = True, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of channels of the input image. + out_channels: optional number of channels of the output image. + scale_factor: factor to reduce the spatial dimensions by. Defaults to 2. + conv_block: a conv block to adjust channels before downsampling. Defaults to None. + - When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized. + - When ``conv_block`` is an ``nn.module``, + please ensure the input number of channels matches requirements. + bias: whether to have a bias term in the default conv_block. Defaults to True. + """ + super().__init__() + + if scale_factor <= 0: + raise ValueError(f"The `scale_factor` multiplier must be an integer greater than 0, got {scale_factor}.") + + self.dimensions = spatial_dims + self.scale_factor = scale_factor + + if conv_block == "default": + if not in_channels: + raise ValueError("in_channels need to be specified.") + out_channels = out_channels or in_channels + self.conv_block = Conv[Conv.CONV, self.dimensions]( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=bias + ) + elif conv_block is None: + self.conv_block = nn.Identity() + else: + self.conv_block = conv_block + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...). + Returns: + Tensor with reduced spatial dimensions and increased channel depth. + """ + x = self.conv_block(x) + if not all(d % self.scale_factor == 0 for d in x.shape[2:]): + raise ValueError( + f"All spatial dimensions {x.shape[2:]} must be evenly " + f"divisible by scale_factor {self.scale_factor}" + ) + x = pixelunshuffle(x, self.dimensions, self.scale_factor) + return x \ No newline at end of file diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index 34afa248ad..0b81bc9fe7 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -13,11 +13,19 @@ import unittest +import os +import sys + +# Add project root to Python path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, project_root) + + import torch from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks import MaxAvgPool +from monai.networks.blocks import MaxAvgPool, SubpixelDownsample, SubpixelUpsample TEST_CASES = [ [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 @@ -35,6 +43,11 @@ ], ] +TEST_CASES_SUBPIXEL = [ + [{"spatial_dims": 2, "in_channels": 1, "scale_factor": 2}, (1, 1, 8, 8), (1, 4, 4, 4)], + [{"spatial_dims": 3, "in_channels": 2, "scale_factor": 2}, (1, 2, 8, 8, 8), (1, 16, 4, 4, 4)], + [{"spatial_dims": 1, "in_channels": 3, "scale_factor": 2}, (1, 3, 8), (1, 6, 4)], +] class TestMaxAvgPool(unittest.TestCase): @@ -46,5 +59,59 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) +class TestSubpixelDownsample(unittest.TestCase): + + @parameterized.expand(TEST_CASES_SUBPIXEL) + def test_shape(self, input_param, input_shape, expected_shape): + downsampler = SubpixelDownsample(**input_param) + with eval_mode(downsampler): + result = downsampler(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_predefined_tensor(self): + test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4) + test_tensor = test_tensor.unsqueeze(0) + + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) + with eval_mode(downsampler): + result = downsampler(test_tensor) + self.assertEqual(result.shape, (1, 16, 2, 2)) + self.assertTrue(torch.all(result[0, 0:3] == 0)) + self.assertTrue(torch.all(result[0, 4:7] == 1)) + self.assertTrue(torch.all(result[0, 8:11] == 2)) + self.assertTrue(torch.all(result[0, 12:15] == 3)) + + def test_reconstruction_2D(self): + input_tensor = torch.randn(1, 1, 4, 4) + down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) + up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) + with eval_mode(down), eval_mode(up): + downsampled = down(input_tensor) + reconstructed = up(downsampled) + self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) + + def test_reconstruction_3D(self): + input_tensor = torch.randn(1, 1, 4, 4, 4) + down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None) + up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) + with eval_mode(down), eval_mode(up): + downsampled = down(input_tensor) + reconstructed = up(downsampled) + self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) + + + def test_invalid_spatial_size(self): + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2) + with self.assertRaises(ValueError): + downsampler(torch.randn(1, 1, 3, 4)) + + def test_custom_conv_block(self): + custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1) + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv) + with eval_mode(downsampler): + result = downsampler(torch.randn(1, 1, 4, 4)) + self.assertEqual(result.shape, (1, 8, 2, 2)) + + if __name__ == "__main__": unittest.main() From 86155cd967560bb2883655f944a00be44dd38b9c Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 14:50:35 +0100 Subject: [PATCH 08/67] Add DownSample Class --- tests/test_downsample_block.py | 69 +++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index 0b81bc9fe7..df5c6076ce 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -13,19 +13,11 @@ import unittest -import os -import sys - -# Add project root to Python path -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.insert(0, project_root) - - import torch from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks import MaxAvgPool, SubpixelDownsample, SubpixelUpsample +from monai.networks.blocks import MaxAvgPool, SubpixelDownsample, SubpixelUpsample, DownSample TEST_CASES = [ [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 @@ -49,6 +41,35 @@ [{"spatial_dims": 1, "in_channels": 3, "scale_factor": 2}, (1, 3, 8), (1, 6, 4)], ] +TEST_CASES_DOWNSAMPLE = [ + [ + {"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, + (1, 4, 16, 16), + (1, 4, 8, 8), + ], + [ + {"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, + (1, 4, 16, 16), + (1, 8, 8, 8), + ], + [ + {"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, + (1, 2, 16, 16, 16), + (1, 2, 8, 8, 8), + ], + [ + {"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, + (1, 4, 16, 16), + (1, 4, 8, 8), + ], + [ + {"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, + (1, 1, 16, 16), + (1, 4, 8, 8), + ], +] + + class TestMaxAvgPool(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -113,5 +134,35 @@ def test_custom_conv_block(self): self.assertEqual(result.shape, (1, 8, 2, 2)) +class TestDownSample(unittest.TestCase): + @parameterized.expand(TEST_CASES_DOWNSAMPLE) + def test_shape(self, input_param, input_shape, expected_shape): + net = DownSample(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_pre_post_conv(self): + net = DownSample( + spatial_dims=2, + in_channels=4, + out_channels=8, + mode="maxpool", + pre_conv="default", + post_conv=torch.nn.Conv2d(8, 16, 1), + ) + with eval_mode(net): + result = net(torch.randn(1, 4, 16, 16)) + self.assertEqual(result.shape, (1, 16, 8, 8)) + + def test_invalid_mode(self): + with self.assertRaises(ValueError): + DownSample(spatial_dims=2, in_channels=4, mode="invalid") + + def test_missing_channels(self): + with self.assertRaises(ValueError): + DownSample(spatial_dims=2, mode="conv") + + if __name__ == "__main__": unittest.main() From 137a7f21e6125fb9f76b7799f7104ba62466d89b Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 14:50:44 +0100 Subject: [PATCH 09/67] Add tests for Downsample --- monai/networks/blocks/downsample.py | 167 ++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 37c2f9f93c..bda3aeb961 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -65,6 +65,173 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.cat([self.max_pool(x), self.avg_pool(x)], dim=1) +class DownSample(nn.Sequential): + """ + Downsamples data by `scale_factor`. + + Supported modes are: + + - "conv": uses a strided convolution for learnable downsampling. + - "convgroup": uses a grouped strided convolution for efficient feature reduction. + - "nontrainable": uses :py:class:`torch.nn.Upsample` with inverse scale factor. + - "pixelunshuffle": uses :py:class:`monai.networks.blocks.PixelUnshuffle` for channel-space rearrangement. + + This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``. + Please check the link below for more details: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms + + This module can optionally take a pre-convolution + (often used to map the number of features from `in_channels` to `out_channels`). + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int | None = None, + out_channels: int | None = None, + scale_factor: Sequence[float] | float = 2, + kernel_size: Sequence[float] | float | None = None, + mode: str = "conv", # conv, convgroup, nontrainable, pixelunshuffle + pre_conv: nn.Module | str | None = "default", + post_conv: nn.Module | None = None, + bias: bool = True, + ) -> None: + """ + Downsamples data by `scale_factor`. + Supported modes are: + + - "conv": uses a strided convolution for learnable downsampling. + - "convgroup": uses a grouped strided convolution for efficient feature reduction. + - "maxpool": uses maxpooling for non-learnable downsampling. + - "avgpool": uses average pooling for non-learnable downsampling. + - "pixelunshuffle": uses :py:class:`monai.networks.blocks.SubpixelDownsample`. + + This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``. + Please check the link below for more details: + https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms + + This module can optionally take a pre-convolution and post-convolution + (often used to map the number of features from `in_channels` to `out_channels`). + + Args: + spatial_dims: number of spatial dimensions of the input image. + in_channels: number of channels of the input image. + out_channels: number of channels of the output image. Defaults to `in_channels`. + scale_factor: multiplier for spatial size reduction. Has to match input size if it is a tuple. Defaults to 2. + kernel_size: kernel size used during convolutions. Defaults to `scale_factor`. + mode: {``"conv"``, ``"convgroup"``, ``"maxpool"``, ``"avgpool"``, ``"pixelunshuffle"``}. Defaults to ``"conv"``. + pre_conv: a conv block applied before downsampling. Defaults to "default". + When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized. + Only used in the "maxpool", "avgpool" or "pixelunshuffle" modes. + post_conv: a conv block applied after downsampling. Defaults to None. Only used in the "maxpool" and "avgpool" modes. + bias: whether to have a bias term in the default preconv and conv layers. Defaults to True. + """ + super().__init__() + + scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims) + down_mode = look_up_option(mode, DownsampleMode) + + if not kernel_size: + kernel_size_ = scale_factor_ + padding = 0 + else: + kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims) + padding = tuple((k - 1) // 2 for k in kernel_size_) + + if down_mode == "conv": + if not in_channels: + raise ValueError("in_channels needs to be specified in conv mode") + self.add_module( + "conv", + Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=out_channels or in_channels, + kernel_size=kernel_size_, + stride=scale_factor_, + padding=padding, + bias=bias, + ), + ) + elif down_mode == "convgroup": + if not in_channels: + raise ValueError("in_channels needs to be specified") + if out_channels is None: + out_channels = in_channels + groups = in_channels if out_channels % in_channels == 0 else 1 + self.add_module( + "convgroup", + Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size_, + stride=scale_factor_, + padding=padding, + groups=groups, + bias=bias, + ), + ) + elif down_mode == DownsampleMode.MAXPOOL: + if pre_conv == "default" and (out_channels != in_channels): + if not in_channels: + raise ValueError("in_channels needs to be specified") + self.add_module( + "preconv", + Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=out_channels or in_channels, + kernel_size=1, + bias=bias + ), + ) + self.add_module( + "maxpool", + Pool[Pool.MAX, spatial_dims]( + kernel_size=kernel_size_, + stride=scale_factor_, + padding=padding + ) + ) + if post_conv: + self.add_module("postconv", post_conv) + + elif down_mode == DownsampleMode.AVGPOOL: + if pre_conv == "default" and (out_channels != in_channels): + if not in_channels: + raise ValueError("in_channels needs to be specified") + self.add_module( + "preconv", + Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=out_channels or in_channels, + kernel_size=1, + bias=bias + ), + ) + self.add_module( + "avgpool", + Pool[Pool.AVG, spatial_dims]( + kernel_size=kernel_size_, + stride=scale_factor_, + padding=padding + ) + ) + if post_conv: + self.add_module("postconv", post_conv) + + elif down_mode == "pixelunshuffle": + self.add_module( + "pixelunshuffle", + SubpixelDownsample( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + scale_factor=scale_factor_[0], + conv_block=pre_conv, + bias=bias, + ), + ) + + class SubpixelDownsample(nn.Module): """ Downsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images. From fb17baf6a4370af0b99a098b2cca9dd35010e628 Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 14:50:59 +0100 Subject: [PATCH 10/67] add exports to __init__ --- monai/networks/blocks/__init__.py | 2 +- monai/utils/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 499caf2e0f..dd1c7a256a 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -20,7 +20,7 @@ from .crossattention import CrossAttentionBlock from .denseblock import ConvDenseBlock, DenseBlock from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock -from .downsample import MaxAvgPool +from .downsample import MaxAvgPool, DownSample, SubpixelDownsample from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .encoder import BaseEncoder from .fcn import FCN, GCN, MCFCN, Refine diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 8f2f400b5d..eb8bd451f8 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -62,6 +62,7 @@ TraceStatusKeys, TransformBackends, UpsampleMode, + DownsampleMode, Weight, WSIPatchKeys, ) From 5ff0baa783de08354b87ef37a487526452636103 Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 15:16:03 +0100 Subject: [PATCH 11/67] Include test to compare with Conv + unshuffle from original restormer --- tests/test_downsample_block.py | 41 ++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index df5c6076ce..be0ac4a9b4 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -13,6 +13,15 @@ import unittest +import os +import sys + +# Add project root to Python path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.insert(0, project_root) + + + import torch from parameterized import parameterized @@ -155,6 +164,38 @@ def test_pre_post_conv(self): result = net(torch.randn(1, 4, 16, 16)) self.assertEqual(result.shape, (1, 16, 8, 8)) + def test_pixelunshuffle_equivalence(self): + class DownSample_local(torch.nn.Module): + def __init__(self, n_feat: int): + super().__init__() + self.conv = torch.nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False) + self.pixelunshuffle = torch.nn.PixelUnshuffle(2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + return self.pixelunshuffle(x) + n_feat = 2 + x = torch.randn(1, n_feat, 64, 64) + + fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False) + + monai_down = DownSample( + spatial_dims=2, + in_channels=n_feat, + out_channels=n_feat//2, + mode="pixelunshuffle", + pre_conv=fix_weight_conv + ) + + local_down = DownSample_local(n_feat) + local_down.conv.weight.data = fix_weight_conv.weight.data.clone() + + with eval_mode(monai_down), eval_mode(local_down): + out_monai = monai_down(x) + out_local = local_down(x) + + self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5)) + def test_invalid_mode(self): with self.assertRaises(ValueError): DownSample(spatial_dims=2, in_channels=4, mode="invalid") From 2566db179645bd447723bd2ed1813b7d6a6c4fdf Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 16 Jan 2025 15:16:51 +0100 Subject: [PATCH 12/67] remove relative imports --- tests/test_downsample_block.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index be0ac4a9b4..1775e5609f 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -11,19 +11,10 @@ from __future__ import annotations -import unittest - -import os -import sys - -# Add project root to Python path -project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) -sys.path.insert(0, project_root) - - +from collections.abc import Sequence import torch -from parameterized import parameterized +import torch.nn as nn from monai.networks import eval_mode from monai.networks.blocks import MaxAvgPool, SubpixelDownsample, SubpixelUpsample, DownSample From ac4047b26234803c102c555b8fb8715b2bfc5b8e Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 17 Jan 2025 12:51:30 +0100 Subject: [PATCH 13/67] Create restormer with Downsampler/Upsampler using monai implementation --- monai/networks/nets/restormer.py | 359 +++++++++++++++++++++++++++++++ 1 file changed, 359 insertions(+) create mode 100644 monai/networks/nets/restormer.py diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py new file mode 100644 index 0000000000..9563dd5958 --- /dev/null +++ b/monai/networks/nets/restormer.py @@ -0,0 +1,359 @@ +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../'))) + + + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from monai.networks.blocks.upsample import UpSample, UpsampleMode +from monai.networks.blocks.downsample import DownSample, DownsampleMode +from monai.networks.layers.factories import Norm +from einops import rearrange + +class FeedForward(nn.Module): + """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism. + Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.""" + def __init__(self, dim: int, ffn_expansion_factor: float, bias: bool): + super().__init__() + hidden_features = int(dim * ffn_expansion_factor) + self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) + self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, + stride=1, padding=1, groups=hidden_features*2, bias=bias) + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + return self.project_out(F.gelu(x1) * x2) + +class Attention(nn.Module): + """Multi-DConv Head Transposed Self-Attention (MDTA) Differs from standard self-attention + by operating on feature channels instead of spatial dimensions. Incorporates depth-wise + convolutions for local mixing before attention, achieving linear complexity vs quadratic + in vanilla attention.""" + def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False): + super().__init__() + if flash_attention and not hasattr(F, 'scaled_dot_product_attention'): + raise ValueError("Flash attention not available") + + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.flash_attention = flash_attention + self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, + padding=1, groups=dim*3, bias=bias) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + self._attention_fn = self._get_attention_fn() + + def _get_attention_fn(self): + if self.flash_attention: + return self._flash_attention + return self._normal_attention + def _flash_attention(self, q, k, v): + """Flash attention implementation using scaled dot-product attention.""" + scale = float(self.temperature.mean()) + out = F.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + dropout_p=0.0, + is_causal=False + ) + return out + + def _normal_attention(self, q, k, v): + """Attention matrix multiplication with depth-wise convolutions.""" + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + return attn @ v + def forward(self, x): + """Forward pass for MDTA attention. + 1. Apply depth-wise convolutions to Q, K, V + 2. Reshape Q, K, V for multi-head attention + 3. Compute attention matrix using flash or normal attention + 4. Reshape and project out attention output""" + b,c,h,w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q,k,v = qkv.chunk(3, dim=1) + + q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + out = self._attention_fn(q, k, v) + out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) + out = self.project_out(out) + return out + + +class TransformerBlock(nn.Module): + """Basic transformer unit combining MDTA and GDFN with skip connections. + Unlike standard transformers that use LayerNorm, this block uses Instance Norm + for better adaptation to image restoration tasks.""" + + def __init__(self, dim: int, num_heads: int, ffn_expansion_factor: float, + bias: bool, LayerNorm_type: str, flash_attention: bool = False): + super().__init__() + use_bias = LayerNorm_type != 'BiasFree' + self.norm1 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias) + self.attn = Attention(dim, num_heads, bias, flash_attention) + self.norm2 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias) + self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + #print(f'x shape in transformer block: {x.shape}') + + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x + + +class OverlapPatchEmbed(nn.Module): + """Initial feature extraction using overlapped convolutions. + Unlike standard patch embeddings that use non-overlapping patches, + this approach maintains spatial continuity through 3x3 convolutions.""" + + def __init__(self, in_c: int = 3, embed_dim: int = 48, bias: bool = False): + super().__init__() + self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, + stride=1, padding=1, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.proj(x) + + + +class DownSample_local(nn.Module): + """Downsampling module that halves spatial dimensions while doubling channels. + Uses PixelUnshuffle for efficient feature map manipulation.""" + + def __init__(self, n_feat: int): + super().__init__() + self.body = nn.Sequential( + nn.Conv2d(n_feat, n_feat//2, kernel_size=3, + stride=1, padding=1, bias=False), + nn.PixelUnshuffle(2) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.body(x) + + + +class Restormer_new(nn.Module): + """Restormer: Efficient Transformer for High-Resolution Image Restoration. + + Implements a U-Net style architecture with transformer blocks, combining: + - Multi-scale feature processing through progressive down/upsampling + - Efficient attention via MDTA blocks + - Local feature mixing through GDFN + - Skip connections for preserving spatial details + + Architecture: + - Encoder: Progressive feature downsampling with increasing channels + - Latent: Deep feature processing at lowest resolution + - Decoder: Progressive upsampling with skip connections + - Refinement: Final feature enhancement + """ + def __init__(self, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[1, 1, 1, 1], + heads=[1, 1, 1, 1], + num_refinement_blocks=4, + ffn_expansion_factor=2.66, + bias=False, + LayerNorm_type='WithBias', + dual_pixel_task=False, + flash_attention=False): + super().__init__() + """Initialize Restormer model. + + Args: + inp_channels: Number of input image channels + out_channels: Number of output image channels + dim: Base feature dimension + num_blocks: Number of transformer blocks at each scale + num_refinement_blocks: Number of final refinement blocks + heads: Number of attention heads at each scale + ffn_expansion_factor: Expansion factor for feed-forward network + bias: Whether to use bias in convolutions + LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree') + dual_pixel_task: Enable dual-pixel specific processing + flash_attention: Use flash attention if available + """ + # Check input parameters + assert len(num_blocks) > 1, "Number of blocks must be greater than 1" + assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal" + assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0" + + # Initial feature extraction + self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + self.encoder_levels = nn.ModuleList() + self.downsamples = nn.ModuleList() + self.decoder_levels = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.reduce_channels = nn.ModuleList() + num_steps = len(num_blocks) - 1 + self.num_steps = num_steps + + # Define encoder levels + for n in range(num_steps): + current_dim = dim * 2**n + self.encoder_levels.append( + nn.Sequential(*[ + TransformerBlock( + dim=current_dim, + num_heads=heads[n], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + flash_attention=flash_attention + ) for _ in range(num_blocks[n]) + ]) + ) + print(f' Encoder layer {n}') + print(f'input channels to the downsampler: {current_dim}') + print(f'output channels from the downsampler: {current_dim//2}') + self.downsamples.append( + #DownSample_local(current_dim) + DownSample( + spatial_dims=2, + in_channels=current_dim, + out_channels=current_dim//2, + mode=DownsampleMode.PIXELUNSHUFFLE, + scale_factor=2, + bias=bias, + ) + ) + + # Define latent space + latent_dim = dim * 2**num_steps + self.latent = nn.Sequential(*[ + TransformerBlock( + dim=latent_dim, + num_heads=heads[num_steps], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + flash_attention=flash_attention + ) for _ in range(num_blocks[num_steps]) + ]) + + # Define decoder levels + for n in reversed(range(num_steps)): + current_dim = dim * 2**n + next_dim = dim * 2**(n+1) + self.upsamples.append( + UpSample( + spatial_dims=2, + in_channels=next_dim, + out_channels=(next_dim//2), + mode=UpsampleMode.PIXELSHUFFLE, + scale_factor=2, + bias=bias, + apply_pad_pool=False + ) + ) + + # Reduce channel layers to deal with skip connections + if n != 0: + self.reduce_channels.append( + nn.Conv2d(next_dim, current_dim, kernel_size=1, bias=bias) + ) + decoder_dim = current_dim + else: + decoder_dim = next_dim + + self.decoder_levels.append( + nn.Sequential(*[ + TransformerBlock( + dim=decoder_dim, + num_heads=heads[n], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + flash_attention=flash_attention + ) for _ in range(num_blocks[n]) + ]) + ) + + # Final refinement and output + self.refinement = nn.Sequential(*[ + TransformerBlock( + dim=decoder_dim, + num_heads=heads[0], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + flash_attention=flash_attention + ) for _ in range(num_refinement_blocks) + ]) + self.dual_pixel_task = dual_pixel_task + if self.dual_pixel_task: + self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + + self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) + + #print(f'downsample layer in new model: {self.downsamples}') + print(f'======================') + print(f'======================') + #print(f'upsamples layer in new model: {self.upsamples}') + + def forward(self, x): + """Forward pass of Restormer. + Processes input through encoder-decoder architecture with skip connections. + Args: + inp_img: Input image tensor of shape (B, C, H, W) + + Returns: + Restored image tensor of shape (B, C, H, W) + """ + assert x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2 ** self.num_steps, "Input dimensions should be larger than 2^number_of_step" + + # Patch embedding + x = self.patch_embed(x) + skip_connections = [] + + # Encoding path + for idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)): + print(f'image shape at input: {x.shape}') + x = encoder(x) + skip_connections.append(x) + print(f'x shape in new model encoder: {x.shape}') + x = downsample(x) + print(f'x shape in new model downsample: {x.shape}') + + # Latent space + x = self.latent(x) + + # Decoding path + for idx in range(len(self.decoder_levels)): + x = self.upsamples[idx](x) + x = torch.concat([x, skip_connections[-(idx + 1)]], 1) + if idx < len(self.decoder_levels) - 1: + x = self.reduce_channels[idx](x) + x = self.decoder_levels[idx](x) + + # Final refinement + x = self.refinement(x) + + if self.dual_pixel_task: + x = x + self.skip_conv(skip_connections[0]) + x = self.output(x) + else: + x = self.output(x) + + return x + + + From 2b74270f3a04944be7c3af999e20ff123d92c0b2 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 17 Jan 2025 15:19:58 +0100 Subject: [PATCH 14/67] Add channel attention block --- monai/networks/blocks/cablock.py | 167 +++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 monai/networks/blocks/cablock.py diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py new file mode 100644 index 0000000000..adcdb33911 --- /dev/null +++ b/monai/networks/blocks/cablock.py @@ -0,0 +1,167 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + + +import torch +import torch.nn as nn +import torch.nn.functional as F +from monai.networks.blocks.convolutions import Convolution +from einops import rearrange + +__all__ = ["FeedForward"] + + +class FeedForward(nn.Module): + """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism. + Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.""" + + def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool): + super().__init__() + hidden_features = int(dim * ffn_expansion_factor) + + self.project_in = Convolution( + spatial_dims=spatial_dims, + in_channels=dim, + out_channels=hidden_features*2, + kernel_size=1, + bias=bias, + conv_only=True + ) + + self.dwconv = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_features*2, + out_channels=hidden_features*2, + kernel_size=3, + strides=1, + padding=1, + groups=hidden_features*2, + bias=bias, + conv_only=True + ) + + self.project_out = Convolution( + spatial_dims=spatial_dims, + in_channels=hidden_features, + out_channels=dim, + kernel_size=1, + bias=bias, + conv_only=True + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + return self.project_out(F.gelu(x1) * x2) + + +class CABlock(nn.Module): + """Multi-DConv Head Transposed Self-Attention (MDTA) Differs from standard self-attention + by operating on feature channels instead of spatial dimensions. Incorporates depth-wise + convolutions for local mixing before attention, achieving linear complexity vs quadratic + in vanilla attention.""" + def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False): + super().__init__() + if flash_attention and not hasattr(F, 'scaled_dot_product_attention'): + raise ValueError("Flash attention not available") + if spatial_dims > 3: + raise ValueError(f"Only 2D and 3D inputs are supported. Got spatial_dims={spatial_dims}") + self.spatial_dims = spatial_dims + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + self.flash_attention = flash_attention + + self.qkv = Convolution( + spatial_dims=spatial_dims, + in_channels=dim, + out_channels=dim*3, + kernel_size=1, + bias=bias, + conv_only=True + ) + + self.qkv_dwconv = Convolution( + spatial_dims=spatial_dims, + in_channels=dim*3, + out_channels=dim*3, + kernel_size=3, + strides=1, + padding=1, + groups=dim*3, + bias=bias, + conv_only=True + ) + + self.project_out = Convolution( + spatial_dims=spatial_dims, + in_channels=dim, + out_channels=dim, + kernel_size=1, + bias=bias, + conv_only=True + ) + + self._attention_fn = self._get_attention_fn() + def _get_attention_fn(self): + if self.flash_attention: + return self._flash_attention + return self._normal_attention + def _flash_attention(self, q, k, v): + """Flash attention implementation using scaled dot-product attention.""" + scale = float(self.temperature.mean()) + out = F.scaled_dot_product_attention( + q, + k, + v, + scale=scale, + dropout_p=0.0, + is_causal=False + ) + return out + + def _normal_attention(self, q, k, v): + """Attention matrix multiplication with depth-wise convolutions.""" + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + return attn @ v + def forward(self, x): + """Forward pass for MDTA attention. + 1. Apply depth-wise convolutions to Q, K, V + 2. Reshape Q, K, V for multi-head attention + 3. Compute attention matrix using flash or normal attention + 4. Reshape and project out attention output""" + spatial_dims = x.shape[2:] + + # Project and mix + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + # Select attention + if self.spatial_dims == 2: + qkv_to_multihead = 'b (head c) h w -> b head c (h w)' + multihead_to_qkv = 'b head c (h w) -> b (head c) h w' + else: # dims == 3 + qkv_to_multihead = 'b (head c) d h w -> b head c (d h w)' + multihead_to_qkv = 'b head c (d h w) -> b (head c) d h w' + + # Reconstruct and project feature map + q = rearrange(q, qkv_to_multihead, head=self.num_heads) + k = rearrange(k, qkv_to_multihead, head=self.num_heads) + v = rearrange(v, qkv_to_multihead, head=self.num_heads) + + q = torch.nn.functional.normalize(q, dim=-1) + k = torch.nn.functional.normalize(k, dim=-1) + + out = self._attention_fn(q, k, v) + out = rearrange(out, multihead_to_qkv, head=self.num_heads, **dict(zip(['h','w'] if self.spatial_dims==2 else ['d','h','w'], spatial_dims))) + + return self.project_out(out) \ No newline at end of file From 9b745338a583f1523f9c291c906d46bf12008131 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 17 Jan 2025 15:21:30 +0100 Subject: [PATCH 15/67] add assembled restormer with MONAI convs for 3D --- monai/networks/nets/restormer.py | 210 +++++++++++-------------------- 1 file changed, 70 insertions(+), 140 deletions(-) diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index 9563dd5958..d7e91ed5fc 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -1,116 +1,41 @@ -import os -import sys - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../'))) - - - +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations import torch import torch.nn as nn -import torch.nn.functional as F + + from monai.networks.blocks.upsample import UpSample, UpsampleMode from monai.networks.blocks.downsample import DownSample, DownsampleMode from monai.networks.layers.factories import Norm -from einops import rearrange +from monai.networks.blocks.cablock import FeedForward, CABlock +from monai.networks.blocks.convolutions import Convolution -class FeedForward(nn.Module): - """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism. - Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.""" - def __init__(self, dim: int, ffn_expansion_factor: float, bias: bool): - super().__init__() - hidden_features = int(dim * ffn_expansion_factor) - self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) - self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, - stride=1, padding=1, groups=hidden_features*2, bias=bias) - self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.project_in(x) - x1, x2 = self.dwconv(x).chunk(2, dim=1) - return self.project_out(F.gelu(x1) * x2) -class Attention(nn.Module): - """Multi-DConv Head Transposed Self-Attention (MDTA) Differs from standard self-attention - by operating on feature channels instead of spatial dimensions. Incorporates depth-wise - convolutions for local mixing before attention, achieving linear complexity vs quadratic - in vanilla attention.""" - def __init__(self, dim: int, num_heads: int, bias: bool, flash_attention: bool = False): - super().__init__() - if flash_attention and not hasattr(F, 'scaled_dot_product_attention'): - raise ValueError("Flash attention not available") - - self.num_heads = num_heads - self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) - self.flash_attention = flash_attention - self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) - self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, - padding=1, groups=dim*3, bias=bias) - self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) - self._attention_fn = self._get_attention_fn() - - def _get_attention_fn(self): - if self.flash_attention: - return self._flash_attention - return self._normal_attention - def _flash_attention(self, q, k, v): - """Flash attention implementation using scaled dot-product attention.""" - scale = float(self.temperature.mean()) - out = F.scaled_dot_product_attention( - q, - k, - v, - scale=scale, - dropout_p=0.0, - is_causal=False - ) - return out - - def _normal_attention(self, q, k, v): - """Attention matrix multiplication with depth-wise convolutions.""" - attn = (q @ k.transpose(-2, -1)) * self.temperature - attn = attn.softmax(dim=-1) - return attn @ v - def forward(self, x): - """Forward pass for MDTA attention. - 1. Apply depth-wise convolutions to Q, K, V - 2. Reshape Q, K, V for multi-head attention - 3. Compute attention matrix using flash or normal attention - 4. Reshape and project out attention output""" - b,c,h,w = x.shape - - qkv = self.qkv_dwconv(self.qkv(x)) - q,k,v = qkv.chunk(3, dim=1) - - q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) - k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) - v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) - - out = self._attention_fn(q, k, v) - out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) - out = self.project_out(out) - return out - - -class TransformerBlock(nn.Module): +class MDTATransformerBlock(nn.Module): """Basic transformer unit combining MDTA and GDFN with skip connections. Unlike standard transformers that use LayerNorm, this block uses Instance Norm for better adaptation to image restoration tasks.""" - def __init__(self, dim: int, num_heads: int, ffn_expansion_factor: float, + def __init__(self, spatial_dims: int, dim: int, num_heads: int, ffn_expansion_factor: float, bias: bool, LayerNorm_type: str, flash_attention: bool = False): super().__init__() use_bias = LayerNorm_type != 'BiasFree' self.norm1 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias) - self.attn = Attention(dim, num_heads, bias, flash_attention) + self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention) self.norm2 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias) - self.ffn = FeedForward(dim, ffn_expansion_factor, bias) + self.ffn = FeedForward(spatial_dims, dim, ffn_expansion_factor, bias) def forward(self, x: torch.Tensor) -> torch.Tensor: - #print(f'x shape in transformer block: {x.shape}') - x = x + self.attn(self.norm1(x)) x = x + self.ffn(self.norm2(x)) return x @@ -121,31 +46,21 @@ class OverlapPatchEmbed(nn.Module): Unlike standard patch embeddings that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions.""" - def __init__(self, in_c: int = 3, embed_dim: int = 48, bias: bool = False): + def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False): super().__init__() - self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, - stride=1, padding=1, bias=bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.proj(x) - - - -class DownSample_local(nn.Module): - """Downsampling module that halves spatial dimensions while doubling channels. - Uses PixelUnshuffle for efficient feature map manipulation.""" - - def __init__(self, n_feat: int): - super().__init__() - self.body = nn.Sequential( - nn.Conv2d(n_feat, n_feat//2, kernel_size=3, - stride=1, padding=1, bias=False), - nn.PixelUnshuffle(2) + self.proj = Convolution( + spatial_dims=spatial_dims, + in_channels=in_c, + out_channels=embed_dim, + kernel_size=3, + strides=1, + padding=1, + bias=bias, + conv_only=True ) def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.body(x) - + return self.proj(x) class Restormer_new(nn.Module): @@ -164,6 +79,7 @@ class Restormer_new(nn.Module): - Refinement: Final feature enhancement """ def __init__(self, + spatial_dims=2, inp_channels=3, out_channels=3, dim=48, @@ -197,7 +113,7 @@ def __init__(self, assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0" # Initial feature extraction - self.patch_embed = OverlapPatchEmbed(inp_channels, dim) + self.patch_embed = OverlapPatchEmbed(spatial_dims, inp_channels, dim) self.encoder_levels = nn.ModuleList() self.downsamples = nn.ModuleList() self.decoder_levels = nn.ModuleList() @@ -205,13 +121,15 @@ def __init__(self, self.reduce_channels = nn.ModuleList() num_steps = len(num_blocks) - 1 self.num_steps = num_steps + self.spatial_dims=spatial_dims # Define encoder levels for n in range(num_steps): current_dim = dim * 2**n self.encoder_levels.append( nn.Sequential(*[ - TransformerBlock( + MDTATransformerBlock( + spatial_dims=spatial_dims, dim=current_dim, num_heads=heads[n], ffn_expansion_factor=ffn_expansion_factor, @@ -221,13 +139,10 @@ def __init__(self, ) for _ in range(num_blocks[n]) ]) ) - print(f' Encoder layer {n}') - print(f'input channels to the downsampler: {current_dim}') - print(f'output channels from the downsampler: {current_dim//2}') + self.downsamples.append( - #DownSample_local(current_dim) DownSample( - spatial_dims=2, + spatial_dims=self.spatial_dims, in_channels=current_dim, out_channels=current_dim//2, mode=DownsampleMode.PIXELUNSHUFFLE, @@ -239,7 +154,8 @@ def __init__(self, # Define latent space latent_dim = dim * 2**num_steps self.latent = nn.Sequential(*[ - TransformerBlock( + MDTATransformerBlock( + spatial_dims=spatial_dims, dim=latent_dim, num_heads=heads[num_steps], ffn_expansion_factor=ffn_expansion_factor, @@ -255,7 +171,7 @@ def __init__(self, next_dim = dim * 2**(n+1) self.upsamples.append( UpSample( - spatial_dims=2, + spatial_dims=self.spatial_dims, in_channels=next_dim, out_channels=(next_dim//2), mode=UpsampleMode.PIXELSHUFFLE, @@ -268,15 +184,23 @@ def __init__(self, # Reduce channel layers to deal with skip connections if n != 0: self.reduce_channels.append( - nn.Conv2d(next_dim, current_dim, kernel_size=1, bias=bias) + Convolution( + spatial_dims=self.spatial_dims, + in_channels=next_dim, + out_channels=current_dim, + kernel_size=1, + bias=bias, + conv_only=True ) + ) decoder_dim = current_dim else: decoder_dim = next_dim self.decoder_levels.append( nn.Sequential(*[ - TransformerBlock( + MDTATransformerBlock( + spatial_dims=spatial_dims, dim=decoder_dim, num_heads=heads[n], ffn_expansion_factor=ffn_expansion_factor, @@ -289,7 +213,8 @@ def __init__(self, # Final refinement and output self.refinement = nn.Sequential(*[ - TransformerBlock( + MDTATransformerBlock( + spatial_dims=spatial_dims, dim=decoder_dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, @@ -300,14 +225,25 @@ def __init__(self, ]) self.dual_pixel_task = dual_pixel_task if self.dual_pixel_task: - self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) + self.skip_conv = Convolution( + spatial_dims=self.spatial_dims, + in_channels=dim, + out_channels=int(dim*2**1), + kernel_size=1, + bias=bias, + conv_only=True + ) - self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) - - #print(f'downsample layer in new model: {self.downsamples}') - print(f'======================') - print(f'======================') - #print(f'upsamples layer in new model: {self.upsamples}') + self.output = Convolution( + spatial_dims=self.spatial_dims, + in_channels=int(dim*2**1), + out_channels=out_channels, + kernel_size=3, + strides=1, + padding=1, + bias=bias, + conv_only=True + ) def forward(self, x): """Forward pass of Restormer. @@ -326,12 +262,9 @@ def forward(self, x): # Encoding path for idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)): - print(f'image shape at input: {x.shape}') x = encoder(x) skip_connections.append(x) - print(f'x shape in new model encoder: {x.shape}') x = downsample(x) - print(f'x shape in new model downsample: {x.shape}') # Latent space x = self.latent(x) @@ -354,6 +287,3 @@ def forward(self, x): x = self.output(x) return x - - - From 1ab34f66d248e4ae5cb63e2b1031bce92b86e0ef Mon Sep 17 00:00:00 2001 From: tisalon Date: Mon, 20 Jan 2025 15:42:59 +0100 Subject: [PATCH 16/67] restormer adapted for 2D/3D --- monai/networks/nets/restormer.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index d7e91ed5fc..1f6ad9be27 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -30,9 +30,9 @@ def __init__(self, spatial_dims: int, dim: int, num_heads: int, ffn_expansion_fa bias: bool, LayerNorm_type: str, flash_attention: bool = False): super().__init__() use_bias = LayerNorm_type != 'BiasFree' - self.norm1 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias) + self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias) self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention) - self.norm2 = Norm[Norm.INSTANCE, 2](dim, affine=use_bias) + self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias) self.ffn = FeedForward(spatial_dims, dim, ffn_expansion_factor, bias) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -122,10 +122,12 @@ def __init__(self, num_steps = len(num_blocks) - 1 self.num_steps = num_steps self.spatial_dims=spatial_dims - + spatial_multiplier = 2**(spatial_dims - 1) + # Define encoder levels for n in range(num_steps): - current_dim = dim * 2**n + current_dim = dim * (2)**(n) + next_dim=current_dim//spatial_multiplier self.encoder_levels.append( nn.Sequential(*[ MDTATransformerBlock( @@ -144,7 +146,7 @@ def __init__(self, DownSample( spatial_dims=self.spatial_dims, in_channels=current_dim, - out_channels=current_dim//2, + out_channels=next_dim, mode=DownsampleMode.PIXELUNSHUFFLE, scale_factor=2, bias=bias, @@ -152,7 +154,7 @@ def __init__(self, ) # Define latent space - latent_dim = dim * 2**num_steps + latent_dim = dim * (2)**(num_steps) self.latent = nn.Sequential(*[ MDTATransformerBlock( spatial_dims=spatial_dims, @@ -167,13 +169,13 @@ def __init__(self, # Define decoder levels for n in reversed(range(num_steps)): - current_dim = dim * 2**n - next_dim = dim * 2**(n+1) + current_dim = dim * (2)**(n) + next_dim = dim * (2)**(n+1) self.upsamples.append( UpSample( spatial_dims=self.spatial_dims, in_channels=next_dim, - out_channels=(next_dim//2), + out_channels=(current_dim), mode=UpsampleMode.PIXELSHUFFLE, scale_factor=2, bias=bias, @@ -228,15 +230,14 @@ def __init__(self, self.skip_conv = Convolution( spatial_dims=self.spatial_dims, in_channels=dim, - out_channels=int(dim*2**1), + out_channels=dim*2, kernel_size=1, bias=bias, conv_only=True ) - self.output = Convolution( spatial_dims=self.spatial_dims, - in_channels=int(dim*2**1), + in_channels=dim*2, out_channels=out_channels, kernel_size=3, strides=1, @@ -244,7 +245,6 @@ def __init__(self, bias=bias, conv_only=True ) - def forward(self, x): """Forward pass of Restormer. Processes input through encoder-decoder architecture with skip connections. From 4f4c62cfd2d71386bbf4eca8c3ba239aa4dd5a3b Mon Sep 17 00:00:00 2001 From: tisalon Date: Mon, 20 Jan 2025 16:19:03 +0100 Subject: [PATCH 17/67] Add unit test for CABlock and the FeedForward layers --- tests/test_CABlock.py | 154 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 tests/test_CABlock.py diff --git a/tests/test_CABlock.py b/tests/test_CABlock.py new file mode 100644 index 0000000000..7d1997e044 --- /dev/null +++ b/tests/test_CABlock.py @@ -0,0 +1,154 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import sys + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))) + + +import unittest + + +import unittest +from unittest import skipUnless +import torch +import numpy as np +from parameterized import parameterized +from monai.networks import eval_mode +from monai.networks.blocks.cablock import CABlock, FeedForward +from tests.utils import assert_allclose, SkipIfBeforePyTorchVersion +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + + +TEST_CASES_CAB = [] +for spatial_dims in [2, 3]: + for dim in [32, 64, 128]: + for num_heads in [2, 4, 8]: + for bias in [True, False]: + test_case = [ + { + "spatial_dims": spatial_dims, + "dim": dim, + "num_heads": num_heads, + "bias": bias, + "flash_attention": False + }, + (2, dim, *([16] * spatial_dims)), + (2, dim, *([16] * spatial_dims)) + ] + TEST_CASES_CAB.append(test_case) + + +TEST_CASES_FEEDFORWARD = [ + # Test different spatial dims, dimensions and expansion factors + [{"spatial_dims": 2, "dim": 64, "ffn_expansion_factor": 2.0, "bias": True}, (2, 64, 32, 32)], + [{"spatial_dims": 3, "dim": 128, "ffn_expansion_factor": 1.5, "bias": False}, (2, 128, 16, 16, 16)], + [{"spatial_dims": 2, "dim": 256, "ffn_expansion_factor": 1.0, "bias": True}, (1, 256, 64, 64)], +] + + +class TestFeedForward(unittest.TestCase): + + @parameterized.expand(TEST_CASES_FEEDFORWARD) + def test_shape(self, input_param, input_shape): + net = FeedForward(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, input_shape) + + def test_gating_mechanism(self): + net = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True) + x = torch.ones(1, 32, 16, 16) + out = net(x) + self.assertNotEqual(torch.sum(out), torch.sum(x)) + + + +class TestCABlock(unittest.TestCase): + + @parameterized.expand(TEST_CASES_CAB) + def test_shape(self, input_param, input_shape, expected_shape): + net = CABlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_invalid_spatial_dims(self): + with self.assertRaises(ValueError): + CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True) + + @SkipIfBeforePyTorchVersion((2, 0)) + def test_flash_attention(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) + x = torch.randn(2, 64, 32, 32).to(device) + output = block(x) + self.assertEqual(output.shape, x.shape) + + def test_temperature_parameter(self): + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) + self.assertTrue(isinstance(block.temperature, torch.nn.Parameter)) + self.assertEqual(block.temperature.shape, (4, 1, 1)) + + def test_qkv_transformation_2d(self): + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) + x = torch.randn(2, 64, 32, 32) + qkv = block.qkv(x) + self.assertEqual(qkv.shape, (2, 192, 32, 32)) + + def test_qkv_transformation_3d(self): + block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True) + x = torch.randn(2, 64, 16, 16, 16) + qkv = block.qkv(x) + self.assertEqual(qkv.shape, (2, 192, 16, 16, 16)) + + @SkipIfBeforePyTorchVersion((2, 0)) + def test_flash_vs_normal_attention(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) + block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device) + + block_normal.load_state_dict(block_flash.state_dict()) + + x = torch.randn(2, 64, 32, 32).to(device) + with torch.no_grad(): + out_flash = block_flash(x) + out_normal = block_normal(x) + + assert_allclose(out_flash, out_normal, atol=1e-4) + + def test_deterministic_small_input(self): + block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False) + with torch.no_grad(): + block.qkv.conv.weight.data.fill_(1.0) + block.qkv_dwconv.conv.weight.data.fill_(1.0) + block.temperature.data.fill_(1.0) + block.project_out.conv.weight.data.fill_(1.0) + + x = torch.tensor([ + [[[1.0, 2.0], + [3.0, 4.0]], + [[5.0, 6.0], + [7.0, 8.0]]]], + dtype=torch.float32) + + output = block(x) + # Channel attention: sum([1..8]) * (qkv_conv=1) * (dwconv=1) * (attn_weights=1) * (proj=1) = 36 * 2 = 72 + expected = torch.full_like(x, 72.0) + + assert_allclose(output, expected, atol=1e-6) +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 068688f50bb7e11a61f8411539cc368bdf1b6937 Mon Sep 17 00:00:00 2001 From: tisalon Date: Mon, 20 Jan 2025 16:29:15 +0100 Subject: [PATCH 18/67] remove relative imports --- tests/test_CABlock.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_CABlock.py b/tests/test_CABlock.py index 7d1997e044..a3b172fd4e 100644 --- a/tests/test_CABlock.py +++ b/tests/test_CABlock.py @@ -11,14 +11,6 @@ from __future__ import annotations -import os -import sys - -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))) - - -import unittest - import unittest from unittest import skipUnless From e2e1070a602ae1f5452f8c1b3ae51bb640600b7b Mon Sep 17 00:00:00 2001 From: tisalon Date: Mon, 20 Jan 2025 16:33:48 +0100 Subject: [PATCH 19/67] rename restormer --- monai/networks/nets/restormer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index 1f6ad9be27..540068e5e6 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -63,7 +63,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) -class Restormer_new(nn.Module): +class Restormer(nn.Module): """Restormer: Efficient Transformer for High-Resolution Image Restoration. Implements a U-Net style architecture with transformer blocks, combining: From 35c7ee48e0c71f0979e181ce36a654c2683880c3 Mon Sep 17 00:00:00 2001 From: tisalon Date: Mon, 20 Jan 2025 16:39:08 +0100 Subject: [PATCH 20/67] add unit test restormer --- tests/test_restormer.py | 135 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tests/test_restormer.py diff --git a/tests/test_restormer.py b/tests/test_restormer.py new file mode 100644 index 0000000000..1e9caab477 --- /dev/null +++ b/tests/test_restormer.py @@ -0,0 +1,135 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + + +import unittest +import torch +from parameterized import parameterized +from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer +from monai.networks import eval_mode + +TEST_CASES_TRANSFORMER = [ + # [spatial_dims, dim, num_heads, ffn_factor, bias, norm_type, flash_attn, input_shape] + [2, 48, 8, 2.66, True, "WithBias", False, (2, 48, 64, 64)], + [2, 96, 8, 2.66, False, "BiasFree", False, (2, 96, 32, 32)], + [3, 48, 4, 2.66, True, "WithBias", False, (2, 48, 32, 32, 32)], + [3, 96, 8, 2.66, False, "BiasFree", True, (2, 96, 16, 16, 16)], +] + +TEST_CASES_PATCHEMBED = [ + # spatial_dims, in_c, embed_dim, input_shape, expected_shape + [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)], + [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)], + [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)], + [3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)], +] + +RESTORMER_CONFIGS = [ + # 2-level architecture + {"num_blocks": [1, 1], "heads": [1, 1]}, + {"num_blocks": [2, 1], "heads": [2, 1]}, + # 3-level architecture + {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]}, + {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]}, +] + +TEST_CASES_RESTORMER = [] +for config in RESTORMER_CONFIGS: + # 2D cases + TEST_CASES_RESTORMER.extend([ + [ + { + "spatial_dims": 2, + "inp_channels": 1, + "out_channels": 1, + "dim": 48, + "num_blocks": config["num_blocks"], + "heads": config["heads"], + "num_refinement_blocks": 2, + "ffn_expansion_factor": 1.5 + }, + (2, 1, 64, 64), + (2, 1, 64, 64) + ], + # 3D cases + [ + { + "spatial_dims": 3, + "inp_channels": 1, + "out_channels": 1, + "dim": 48, + "num_blocks": config["num_blocks"], + "heads": config["heads"], + "num_refinement_blocks": 2, + "ffn_expansion_factor": 1.5 + }, + (2, 1, 32, 32, 32), + (2, 1, 32, 32, 32) + ] + ]) + + +class TestMDTATransformerBlock(unittest.TestCase): + + @parameterized.expand(TEST_CASES_TRANSFORMER) + def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flash, shape): + block = MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=dim, + num_heads=heads, + ffn_expansion_factor=ffn_factor, + bias=bias, + LayerNorm_type=norm_type, + flash_attention=flash + ) + with eval_mode(block): + x = torch.randn(shape) + output = block(x) + self.assertEqual(output.shape, x.shape) + + +class TestOverlapPatchEmbed(unittest.TestCase): + + @parameterized.expand(TEST_CASES_PATCHEMBED) + def test_shape(self, spatial_dims, in_c, embed_dim, input_shape, expected_shape): + net = OverlapPatchEmbed( + spatial_dims=spatial_dims, + in_c=in_c, + embed_dim=embed_dim + ) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +class TestRestormer(unittest.TestCase): + + @parameterized.expand(TEST_CASES_RESTORMER) + def test_shape(self, input_param, input_shape, expected_shape): + net = Restormer(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_small_input_error_2d(self): + net = Restormer(spatial_dims=2, inp_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8)) + + def test_small_input_error_3d(self): + net = Restormer(spatial_dims=3, inp_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8, 8)) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From d8cb6c13b59f86b5fd65e255b0cc5e55cb298279 Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 23 Jan 2025 15:32:31 +0100 Subject: [PATCH 21/67] Update documentation and imports for CABlock and FeedForward; add Downsample class alias --- docs/source/networks.rst | 25 +++++++++++++++++++++++++ monai/networks/blocks/__init__.py | 10 +++++++++- monai/networks/blocks/cablock.py | 6 +++--- monai/networks/blocks/downsample.py | 7 +++++-- 4 files changed, 42 insertions(+), 6 deletions(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index e2e509a99b..05825c3c18 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -109,6 +109,16 @@ Blocks .. autoclass:: SABlock :members: +`CABlock Block` +~~~~~~~~~~~~~~~ +.. autoclass:: CABlock + :members: + +`FeedForward Block` +~~~~~~~~~~~~~~~~~~~ +.. autoclass:: FeedForward + :members: + `Squeeze-and-Excitation` ~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ChannelSELayer @@ -173,6 +183,16 @@ Blocks .. autoclass:: Subpixelupsample .. autoclass:: SubpixelUpSample +`Downsampling` +~~~~~~~~~~~~~~ +.. autoclass:: DownSample + :members: +.. autoclass:: DownSample +.. autoclass:: SubpixelDownSample + :members: +.. autoclass:: Subpixeldownsample +.. autoclass:: SubpixelDownSample + `Registration Residual Conv Block` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: RegistrationResidualConvBlock @@ -625,6 +645,11 @@ Nets .. autoclass:: ViT :members: +`Restormer` +~~~~~~~~~~~ +.. autoclass:: restormer + :members: + `ViTAutoEnc` ~~~~~~~~~~~~ .. autoclass:: ViTAutoEnc diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index dd1c7a256a..3390067cb5 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -20,7 +20,14 @@ from .crossattention import CrossAttentionBlock from .denseblock import ConvDenseBlock, DenseBlock from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock -from .downsample import MaxAvgPool, DownSample, SubpixelDownsample +from .downsample import ( + MaxAvgPool, + DownSample, + Downsample, + SubpixelDownsample, + SubpixelDownSample, + Subpixeldownsample +) from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .encoder import BaseEncoder from .fcn import FCN, GCN, MCFCN, Refine @@ -32,6 +39,7 @@ from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock +from .cablock import CABlock, FeedForward from .spade_norm import SPADE from .spatialattention import SpatialAttentionBlock from .squeeze_and_excitation import ( diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py index adcdb33911..a8f8e0b243 100644 --- a/monai/networks/blocks/cablock.py +++ b/monai/networks/blocks/cablock.py @@ -17,7 +17,7 @@ from monai.networks.blocks.convolutions import Convolution from einops import rearrange -__all__ = ["FeedForward"] +__all__ = ["FeedForward", "CABlock"] class FeedForward(nn.Module): @@ -65,10 +65,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class CABlock(nn.Module): - """Multi-DConv Head Transposed Self-Attention (MDTA) Differs from standard self-attention + """Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention by operating on feature channels instead of spatial dimensions. Incorporates depth-wise convolutions for local mixing before attention, achieving linear complexity vs quadratic - in vanilla attention.""" + in vanilla attention. Based on SW Zamir, et al., 2022 """ def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False): super().__init__() if flash_attention and not hasattr(F, 'scaled_dot_product_attention'): diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index bda3aeb961..517fcffa85 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -20,7 +20,7 @@ from monai.networks.utils import pixelunshuffle from monai.utils import InterpolateMode, DownsampleMode, ensure_tuple_rep, look_up_option -__all__ = ["MaxAvgPool", "DownSample", "SubpixelDownsample"] +__all__ = ["MaxAvgPool", "DownSample", "Downsample", "SubpixelDownsample", "SubpixelDownSample", "Subpixeldownsample"] class MaxAvgPool(nn.Module): """ @@ -311,4 +311,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: f"divisible by scale_factor {self.scale_factor}" ) x = pixelunshuffle(x, self.dimensions, self.scale_factor) - return x \ No newline at end of file + return x + +Downsample = DownSample +SubpixelDownSample = Subpixeldownsample = SubpixelDownsample \ No newline at end of file From 6d96816d3936b4faac6d559bbc290ab0a5589397 Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 23 Jan 2025 15:39:28 +0100 Subject: [PATCH 22/67] Add licence to pixel_unshuffle test --- tests/test_pixelunshuffle.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_pixelunshuffle.py b/tests/test_pixelunshuffle.py index 9450aaf273..48deaa9ee9 100644 --- a/tests/test_pixelunshuffle.py +++ b/tests/test_pixelunshuffle.py @@ -1,4 +1,18 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + import unittest + import torch from monai.networks.utils import pixelunshuffle, pixelshuffle From 8a688fb57926232f73240fcb1f12805e2de72ded Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 23 Jan 2025 15:49:44 +0100 Subject: [PATCH 23/67] Refactor imports and clean up whitespace in utils and test files and pass ./runtests.sh -f -u --net --coverage --- monai/networks/blocks/__init__.py | 11 +- monai/networks/blocks/cablock.py | 100 ++++++------ monai/networks/blocks/downsample.py | 48 ++---- monai/networks/nets/restormer.py | 239 +++++++++++++++------------- monai/networks/utils.py | 12 +- monai/utils/__init__.py | 2 +- monai/utils/enums.py | 4 +- tests/test_CABlock.py | 39 +++-- tests/test_downsample_block.py | 62 +++----- tests/test_pixelunshuffle.py | 7 +- tests/test_restormer.py | 88 +++++----- 11 files changed, 292 insertions(+), 320 deletions(-) diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py index 3390067cb5..22af82d316 100644 --- a/monai/networks/blocks/__init__.py +++ b/monai/networks/blocks/__init__.py @@ -15,19 +15,13 @@ from .activation import GEGLU, MemoryEfficientSwish, Mish, Swish from .aspp import SimpleASPP from .backbone_fpn_utils import BackboneWithFPN +from .cablock import CABlock, FeedForward from .convolutions import Convolution, ResidualUnit from .crf import CRF from .crossattention import CrossAttentionBlock from .denseblock import ConvDenseBlock, DenseBlock from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock -from .downsample import ( - MaxAvgPool, - DownSample, - Downsample, - SubpixelDownsample, - SubpixelDownSample, - Subpixeldownsample -) +from .downsample import DownSample, Downsample, MaxAvgPool, SubpixelDownsample, SubpixelDownSample, Subpixeldownsample from .dynunet_block import UnetBasicBlock, UnetOutBlock, UnetResBlock, UnetUpBlock, get_output_padding, get_padding from .encoder import BaseEncoder from .fcn import FCN, GCN, MCFCN, Refine @@ -39,7 +33,6 @@ from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock from .segresnet_block import ResBlock from .selfattention import SABlock -from .cablock import CABlock, FeedForward from .spade_norm import SPADE from .spatialattention import SpatialAttentionBlock from .squeeze_and_excitation import ( diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py index a8f8e0b243..6f5eddc780 100644 --- a/monai/networks/blocks/cablock.py +++ b/monai/networks/blocks/cablock.py @@ -10,13 +10,13 @@ # limitations under the License. from __future__ import annotations - import torch import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks.convolutions import Convolution from einops import rearrange +from monai.networks.blocks.convolutions import Convolution + __all__ = ["FeedForward", "CABlock"] @@ -27,26 +27,26 @@ class FeedForward(nn.Module): def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool): super().__init__() hidden_features = int(dim * ffn_expansion_factor) - + self.project_in = Convolution( spatial_dims=spatial_dims, in_channels=dim, - out_channels=hidden_features*2, + out_channels=hidden_features * 2, kernel_size=1, bias=bias, - conv_only=True + conv_only=True, ) self.dwconv = Convolution( spatial_dims=spatial_dims, - in_channels=hidden_features*2, - out_channels=hidden_features*2, + in_channels=hidden_features * 2, + out_channels=hidden_features * 2, kernel_size=3, strides=1, padding=1, - groups=hidden_features*2, + groups=hidden_features * 2, bias=bias, - conv_only=True + conv_only=True, ) self.project_out = Convolution( @@ -55,7 +55,7 @@ def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bia out_channels=dim, kernel_size=1, bias=bias, - conv_only=True + conv_only=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -69,9 +69,10 @@ class CABlock(nn.Module): by operating on feature channels instead of spatial dimensions. Incorporates depth-wise convolutions for local mixing before attention, achieving linear complexity vs quadratic in vanilla attention. Based on SW Zamir, et al., 2022 """ + def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False): super().__init__() - if flash_attention and not hasattr(F, 'scaled_dot_product_attention'): + if flash_attention and not hasattr(F, "scaled_dot_product_attention"): raise ValueError("Flash attention not available") if spatial_dims > 3: raise ValueError(f"Only 2D and 3D inputs are supported. Got spatial_dims={spatial_dims}") @@ -79,53 +80,38 @@ def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_att self.num_heads = num_heads self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) self.flash_attention = flash_attention - + self.qkv = Convolution( - spatial_dims=spatial_dims, - in_channels=dim, - out_channels=dim*3, - kernel_size=1, - bias=bias, - conv_only=True + spatial_dims=spatial_dims, in_channels=dim, out_channels=dim * 3, kernel_size=1, bias=bias, conv_only=True ) - + self.qkv_dwconv = Convolution( spatial_dims=spatial_dims, - in_channels=dim*3, - out_channels=dim*3, + in_channels=dim * 3, + out_channels=dim * 3, kernel_size=3, strides=1, padding=1, - groups=dim*3, + groups=dim * 3, bias=bias, - conv_only=True + conv_only=True, ) - + self.project_out = Convolution( - spatial_dims=spatial_dims, - in_channels=dim, - out_channels=dim, - kernel_size=1, - bias=bias, - conv_only=True + spatial_dims=spatial_dims, in_channels=dim, out_channels=dim, kernel_size=1, bias=bias, conv_only=True ) - + self._attention_fn = self._get_attention_fn() + def _get_attention_fn(self): if self.flash_attention: return self._flash_attention return self._normal_attention + def _flash_attention(self, q, k, v): """Flash attention implementation using scaled dot-product attention.""" - scale = float(self.temperature.mean()) - out = F.scaled_dot_product_attention( - q, - k, - v, - scale=scale, - dropout_p=0.0, - is_causal=False - ) + scale = float(self.temperature.mean()) + out = F.scaled_dot_product_attention(q, k, v, scale=scale, dropout_p=0.0, is_causal=False) return out def _normal_attention(self, q, k, v): @@ -133,35 +119,41 @@ def _normal_attention(self, q, k, v): attn = (q @ k.transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) return attn @ v + def forward(self, x): - """Forward pass for MDTA attention. + """Forward pass for MDTA attention. 1. Apply depth-wise convolutions to Q, K, V 2. Reshape Q, K, V for multi-head attention 3. Compute attention matrix using flash or normal attention 4. Reshape and project out attention output""" spatial_dims = x.shape[2:] - - # Project and mix + + # Project and mix qkv = self.qkv_dwconv(self.qkv(x)) q, k, v = qkv.chunk(3, dim=1) - + # Select attention if self.spatial_dims == 2: - qkv_to_multihead = 'b (head c) h w -> b head c (h w)' - multihead_to_qkv = 'b head c (h w) -> b (head c) h w' + qkv_to_multihead = "b (head c) h w -> b head c (h w)" + multihead_to_qkv = "b head c (h w) -> b (head c) h w" else: # dims == 3 - qkv_to_multihead = 'b (head c) d h w -> b head c (d h w)' - multihead_to_qkv = 'b head c (d h w) -> b (head c) d h w' - + qkv_to_multihead = "b (head c) d h w -> b head c (d h w)" + multihead_to_qkv = "b head c (d h w) -> b (head c) d h w" + # Reconstruct and project feature map q = rearrange(q, qkv_to_multihead, head=self.num_heads) k = rearrange(k, qkv_to_multihead, head=self.num_heads) v = rearrange(v, qkv_to_multihead, head=self.num_heads) - + q = torch.nn.functional.normalize(q, dim=-1) k = torch.nn.functional.normalize(k, dim=-1) - + out = self._attention_fn(q, k, v) - out = rearrange(out, multihead_to_qkv, head=self.num_heads, **dict(zip(['h','w'] if self.spatial_dims==2 else ['d','h','w'], spatial_dims))) - - return self.project_out(out) \ No newline at end of file + out = rearrange( + out, + multihead_to_qkv, + head=self.num_heads, + **dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)), + ) + + return self.project_out(out) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 517fcffa85..9e33f7fdce 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -18,10 +18,11 @@ from monai.networks.layers.factories import Conv, Pool from monai.networks.utils import pixelunshuffle -from monai.utils import InterpolateMode, DownsampleMode, ensure_tuple_rep, look_up_option +from monai.utils import DownsampleMode, InterpolateMode, ensure_tuple_rep, look_up_option __all__ = ["MaxAvgPool", "DownSample", "Downsample", "SubpixelDownsample", "SubpixelDownSample", "Subpixeldownsample"] + class MaxAvgPool(nn.Module): """ Downsample with both maxpooling and avgpooling, @@ -127,7 +128,7 @@ def __init__( bias: whether to have a bias term in the default preconv and conv layers. Defaults to True. """ super().__init__() - + scale_factor_ = ensure_tuple_rep(scale_factor, spatial_dims) down_mode = look_up_option(mode, DownsampleMode) @@ -177,19 +178,11 @@ def __init__( self.add_module( "preconv", Conv[Conv.CONV, spatial_dims]( - in_channels=in_channels, - out_channels=out_channels or in_channels, - kernel_size=1, - bias=bias + in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias ), ) self.add_module( - "maxpool", - Pool[Pool.MAX, spatial_dims]( - kernel_size=kernel_size_, - stride=scale_factor_, - padding=padding - ) + "maxpool", Pool[Pool.MAX, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding) ) if post_conv: self.add_module("postconv", post_conv) @@ -201,19 +194,11 @@ def __init__( self.add_module( "preconv", Conv[Conv.CONV, spatial_dims]( - in_channels=in_channels, - out_channels=out_channels or in_channels, - kernel_size=1, - bias=bias + in_channels=in_channels, out_channels=out_channels or in_channels, kernel_size=1, bias=bias ), ) self.add_module( - "avgpool", - Pool[Pool.AVG, spatial_dims]( - kernel_size=kernel_size_, - stride=scale_factor_, - padding=padding - ) + "avgpool", Pool[Pool.AVG, spatial_dims](kernel_size=kernel_size_, stride=scale_factor_, padding=padding) ) if post_conv: self.add_module("postconv", post_conv) @@ -239,9 +224,9 @@ class SubpixelDownsample(nn.Module): to adjust the number of channels. Secondly, a pixel unshuffle manipulation rearranges the spatial information into channel space, effectively reducing spatial dimensions while increasing channel depth. - - The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions - from (B, C, H*r, W*r) to (B, C*r², H, W). + + The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions + from (B, C, H*r, W*r) to (B, C*r², H, W). Example: (1, 1, 4, 4) with r=2 becomes (1, 4, 2, 2). See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution @@ -285,12 +270,7 @@ def __init__( raise ValueError("in_channels need to be specified.") out_channels = out_channels or in_channels self.conv_block = Conv[Conv.CONV, self.dimensions]( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - stride=1, - padding=1, - bias=bias + in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=bias ) elif conv_block is None: self.conv_block = nn.Identity() @@ -307,11 +287,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv_block(x) if not all(d % self.scale_factor == 0 for d in x.shape[2:]): raise ValueError( - f"All spatial dimensions {x.shape[2:]} must be evenly " - f"divisible by scale_factor {self.scale_factor}" + f"All spatial dimensions {x.shape[2:]} must be evenly " f"divisible by scale_factor {self.scale_factor}" ) x = pixelunshuffle(x, self.dimensions, self.scale_factor) return x + Downsample = DownSample -SubpixelDownSample = Subpixeldownsample = SubpixelDownsample \ No newline at end of file +SubpixelDownSample = Subpixeldownsample = SubpixelDownsample diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index 540068e5e6..3d3d9565d4 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -13,23 +13,30 @@ import torch import torch.nn as nn - -from monai.networks.blocks.upsample import UpSample, UpsampleMode +from monai.networks.blocks.cablock import CABlock, FeedForward +from monai.networks.blocks.convolutions import Convolution from monai.networks.blocks.downsample import DownSample, DownsampleMode +from monai.networks.blocks.upsample import UpSample, UpsampleMode from monai.networks.layers.factories import Norm -from monai.networks.blocks.cablock import FeedForward, CABlock -from monai.networks.blocks.convolutions import Convolution class MDTATransformerBlock(nn.Module): """Basic transformer unit combining MDTA and GDFN with skip connections. Unlike standard transformers that use LayerNorm, this block uses Instance Norm for better adaptation to image restoration tasks.""" - - def __init__(self, spatial_dims: int, dim: int, num_heads: int, ffn_expansion_factor: float, - bias: bool, LayerNorm_type: str, flash_attention: bool = False): + + def __init__( + self, + spatial_dims: int, + dim: int, + num_heads: int, + ffn_expansion_factor: float, + bias: bool, + LayerNorm_type: str, + flash_attention: bool = False, + ): super().__init__() - use_bias = LayerNorm_type != 'BiasFree' + use_bias = LayerNorm_type != "BiasFree" self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias) self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention) self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias) @@ -45,7 +52,7 @@ class OverlapPatchEmbed(nn.Module): """Initial feature extraction using overlapped convolutions. Unlike standard patch embeddings that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions.""" - + def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False): super().__init__() self.proj = Convolution( @@ -56,7 +63,7 @@ def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: strides=1, padding=1, bias=bias, - conv_only=True + conv_only=True, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -65,35 +72,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Restormer(nn.Module): """Restormer: Efficient Transformer for High-Resolution Image Restoration. - + Implements a U-Net style architecture with transformer blocks, combining: - Multi-scale feature processing through progressive down/upsampling - Efficient attention via MDTA blocks - Local feature mixing through GDFN - Skip connections for preserving spatial details - + Architecture: - Encoder: Progressive feature downsampling with increasing channels - Latent: Deep feature processing at lowest resolution - Decoder: Progressive upsampling with skip connections - Refinement: Final feature enhancement """ - def __init__(self, - spatial_dims=2, - inp_channels=3, - out_channels=3, - dim=48, - num_blocks=[1, 1, 1, 1], - heads=[1, 1, 1, 1], - num_refinement_blocks=4, - ffn_expansion_factor=2.66, - bias=False, - LayerNorm_type='WithBias', - dual_pixel_task=False, - flash_attention=False): + + def __init__( + self, + spatial_dims=2, + inp_channels=3, + out_channels=3, + dim=48, + num_blocks=[1, 1, 1, 1], + heads=[1, 1, 1, 1], + num_refinement_blocks=4, + ffn_expansion_factor=2.66, + bias=False, + LayerNorm_type="WithBias", + dual_pixel_task=False, + flash_attention=False, + ): super().__init__() """Initialize Restormer model. - + Args: inp_channels: Number of input image channels out_channels: Number of output image channels @@ -111,7 +121,7 @@ def __init__(self, assert len(num_blocks) > 1, "Number of blocks must be greater than 1" assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal" assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0" - + # Initial feature extraction self.patch_embed = OverlapPatchEmbed(spatial_dims, inp_channels, dim) self.encoder_levels = nn.ModuleList() @@ -119,70 +129,76 @@ def __init__(self, self.decoder_levels = nn.ModuleList() self.upsamples = nn.ModuleList() self.reduce_channels = nn.ModuleList() - num_steps = len(num_blocks) - 1 + num_steps = len(num_blocks) - 1 self.num_steps = num_steps - self.spatial_dims=spatial_dims - spatial_multiplier = 2**(spatial_dims - 1) - + self.spatial_dims = spatial_dims + spatial_multiplier = 2 ** (spatial_dims - 1) + # Define encoder levels for n in range(num_steps): - current_dim = dim * (2)**(n) - next_dim=current_dim//spatial_multiplier + current_dim = dim * (2) ** (n) + next_dim = current_dim // spatial_multiplier self.encoder_levels.append( - nn.Sequential(*[ - MDTATransformerBlock( - spatial_dims=spatial_dims, - dim=current_dim, - num_heads=heads[n], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - flash_attention=flash_attention - ) for _ in range(num_blocks[n]) - ]) + nn.Sequential( + *[ + MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=current_dim, + num_heads=heads[n], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + flash_attention=flash_attention, + ) + for _ in range(num_blocks[n]) + ] + ) ) self.downsamples.append( DownSample( - spatial_dims=self.spatial_dims, - in_channels=current_dim, - out_channels=next_dim, - mode=DownsampleMode.PIXELUNSHUFFLE, - scale_factor=2, - bias=bias, + spatial_dims=self.spatial_dims, + in_channels=current_dim, + out_channels=next_dim, + mode=DownsampleMode.PIXELUNSHUFFLE, + scale_factor=2, + bias=bias, ) ) # Define latent space - latent_dim = dim * (2)**(num_steps) - self.latent = nn.Sequential(*[ - MDTATransformerBlock( - spatial_dims=spatial_dims, - dim=latent_dim, - num_heads=heads[num_steps], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - flash_attention=flash_attention - ) for _ in range(num_blocks[num_steps]) - ]) + latent_dim = dim * (2) ** (num_steps) + self.latent = nn.Sequential( + *[ + MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=latent_dim, + num_heads=heads[num_steps], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + flash_attention=flash_attention, + ) + for _ in range(num_blocks[num_steps]) + ] + ) # Define decoder levels for n in reversed(range(num_steps)): - current_dim = dim * (2)**(n) - next_dim = dim * (2)**(n+1) + current_dim = dim * (2) ** (n) + next_dim = dim * (2) ** (n + 1) self.upsamples.append( UpSample( - spatial_dims=self.spatial_dims, - in_channels=next_dim, - out_channels=(current_dim), - mode=UpsampleMode.PIXELSHUFFLE, - scale_factor=2, - bias=bias, - apply_pad_pool=False - ) + spatial_dims=self.spatial_dims, + in_channels=next_dim, + out_channels=(current_dim), + mode=UpsampleMode.PIXELSHUFFLE, + scale_factor=2, + bias=bias, + apply_pad_pool=False, ) - + ) + # Reduce channel layers to deal with skip connections if n != 0: self.reduce_channels.append( @@ -192,69 +208,78 @@ def __init__(self, out_channels=current_dim, kernel_size=1, bias=bias, - conv_only=True + conv_only=True, ) ) decoder_dim = current_dim else: decoder_dim = next_dim - + self.decoder_levels.append( - nn.Sequential(*[ - MDTATransformerBlock( - spatial_dims=spatial_dims, - dim=decoder_dim, - num_heads=heads[n], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - flash_attention=flash_attention - ) for _ in range(num_blocks[n]) - ]) + nn.Sequential( + *[ + MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=decoder_dim, + num_heads=heads[n], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + flash_attention=flash_attention, + ) + for _ in range(num_blocks[n]) + ] + ) ) # Final refinement and output - self.refinement = nn.Sequential(*[ - MDTATransformerBlock( - spatial_dims=spatial_dims, - dim=decoder_dim, - num_heads=heads[0], - ffn_expansion_factor=ffn_expansion_factor, - bias=bias, - LayerNorm_type=LayerNorm_type, - flash_attention=flash_attention - ) for _ in range(num_refinement_blocks) - ]) + self.refinement = nn.Sequential( + *[ + MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=decoder_dim, + num_heads=heads[0], + ffn_expansion_factor=ffn_expansion_factor, + bias=bias, + LayerNorm_type=LayerNorm_type, + flash_attention=flash_attention, + ) + for _ in range(num_refinement_blocks) + ] + ) self.dual_pixel_task = dual_pixel_task if self.dual_pixel_task: self.skip_conv = Convolution( spatial_dims=self.spatial_dims, in_channels=dim, - out_channels=dim*2, + out_channels=dim * 2, kernel_size=1, bias=bias, - conv_only=True + conv_only=True, ) self.output = Convolution( spatial_dims=self.spatial_dims, - in_channels=dim*2, + in_channels=dim * 2, out_channels=out_channels, kernel_size=3, strides=1, padding=1, bias=bias, - conv_only=True + conv_only=True, ) + def forward(self, x): """Forward pass of Restormer. Processes input through encoder-decoder architecture with skip connections. Args: inp_img: Input image tensor of shape (B, C, H, W) - + Returns: Restored image tensor of shape (B, C, H, W) """ - assert x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2 ** self.num_steps, "Input dimensions should be larger than 2^number_of_step" + assert ( + x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2**self.num_steps + ), "Input dimensions should be larger than 2^number_of_step" # Patch embedding x = self.patch_embed(x) @@ -268,15 +293,15 @@ def forward(self, x): # Latent space x = self.latent(x) - + # Decoding path for idx in range(len(self.decoder_levels)): - x = self.upsamples[idx](x) + x = self.upsamples[idx](x) x = torch.concat([x, skip_connections[-(idx + 1)]], 1) if idx < len(self.decoder_levels) - 1: - x = self.reduce_channels[idx](x) + x = self.reduce_channels[idx](x) x = self.decoder_levels[idx](x) - + # Final refinement x = self.refinement(x) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 946bb6b824..46d6fc0825 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -439,21 +439,21 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor input_size = list(x.size()) batch_size, channels = input_size[:2] scale_factor_mult = factor**dim - new_channels = channels * scale_factor_mult + new_channels = channels * scale_factor_mult if any(d % factor != 0 for d in input_size[2:]): raise ValueError( - f"All spatial dimensions must be divisible by factor {factor}. " - f"Got spatial dimensions: {input_size[2:]}" + f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}" ) output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]] reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], []) - + permute_indices = [0, 1] + [(2 * i + 3) for i in range(spatial_dims)] + [(2 * i + 2) for i in range(spatial_dims)] - x=x.reshape(reshaped_size).permute(permute_indices) - x=x.reshape(output_size) + x = x.reshape(reshaped_size).permute(permute_indices) + x = x.reshape(output_size) return x + @contextmanager def eval_mode(*nets: nn.Module): """ diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index eb8bd451f8..3efc9b5e7f 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -29,6 +29,7 @@ CommonKeys, CompInitMode, DiceCEReduction, + DownsampleMode, EngineStatsKeys, FastMRIKeys, ForwardMode, @@ -62,7 +63,6 @@ TraceStatusKeys, TransformBackends, UpsampleMode, - DownsampleMode, Weight, WSIPatchKeys, ) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 74ae829afd..50d1b28302 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -190,8 +190,8 @@ class DownsampleMode(StrEnum): CONV = "conv" # e.g. using strided convolution CONVGROUP = "convgroup" # e.g. using grouped strided convolution PIXELUNSHUFFLE = "pixelunshuffle" - MAXPOOL = "maxpool" - AVGPOOL = "avgpool" + MAXPOOL = "maxpool" + AVGPOOL = "avgpool" class BlendMode(StrEnum): diff --git a/tests/test_CABlock.py b/tests/test_CABlock.py index a3b172fd4e..82d813b1ad 100644 --- a/tests/test_CABlock.py +++ b/tests/test_CABlock.py @@ -11,16 +11,17 @@ from __future__ import annotations - import unittest from unittest import skipUnless -import torch + import numpy as np +import torch from parameterized import parameterized + from monai.networks import eval_mode from monai.networks.blocks.cablock import CABlock, FeedForward -from tests.utils import assert_allclose, SkipIfBeforePyTorchVersion from monai.utils import optional_import +from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose einops, has_einops = optional_import("einops") @@ -36,10 +37,10 @@ "dim": dim, "num_heads": num_heads, "bias": bias, - "flash_attention": False + "flash_attention": False, }, - (2, dim, *([16] * spatial_dims)), - (2, dim, *([16] * spatial_dims)) + (2, dim, *([16] * spatial_dims)), + (2, dim, *([16] * spatial_dims)), ] TEST_CASES_CAB.append(test_case) @@ -53,31 +54,30 @@ class TestFeedForward(unittest.TestCase): - + @parameterized.expand(TEST_CASES_FEEDFORWARD) def test_shape(self, input_param, input_shape): net = FeedForward(**input_param) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, input_shape) - + def test_gating_mechanism(self): net = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True) x = torch.ones(1, 32, 16, 16) out = net(x) self.assertNotEqual(torch.sum(out), torch.sum(x)) - class TestCABlock(unittest.TestCase): - + @parameterized.expand(TEST_CASES_CAB) def test_shape(self, input_param, input_shape, expected_shape): net = CABlock(**input_param) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - + def test_invalid_spatial_dims(self): with self.assertRaises(ValueError): CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True) @@ -112,14 +112,14 @@ def test_flash_vs_normal_attention(self): device = "cuda" if torch.cuda.is_available() else "cpu" block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device) - + block_normal.load_state_dict(block_flash.state_dict()) - + x = torch.randn(2, 64, 32, 32).to(device) with torch.no_grad(): out_flash = block_flash(x) out_normal = block_normal(x) - + assert_allclose(out_flash, out_normal, atol=1e-4) def test_deterministic_small_input(self): @@ -130,17 +130,14 @@ def test_deterministic_small_input(self): block.temperature.data.fill_(1.0) block.project_out.conv.weight.data.fill_(1.0) - x = torch.tensor([ - [[[1.0, 2.0], - [3.0, 4.0]], - [[5.0, 6.0], - [7.0, 8.0]]]], - dtype=torch.float32) + x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float32) output = block(x) # Channel attention: sum([1..8]) * (qkv_conv=1) * (dwconv=1) * (attn_weights=1) * (proj=1) = 36 * 2 = 72 expected = torch.full_like(x, 72.0) assert_allclose(output, expected, atol=1e-6) + + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index 1775e5609f..d6ce33ca5e 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -11,13 +11,15 @@ from __future__ import annotations +import unittest from collections.abc import Sequence import torch import torch.nn as nn +from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks import MaxAvgPool, SubpixelDownsample, SubpixelUpsample, DownSample +from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample TEST_CASES = [ [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 @@ -42,31 +44,11 @@ ] TEST_CASES_DOWNSAMPLE = [ - [ - {"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, - (1, 4, 16, 16), - (1, 4, 8, 8), - ], - [ - {"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, - (1, 4, 16, 16), - (1, 8, 8, 8), - ], - [ - {"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, - (1, 2, 16, 16, 16), - (1, 2, 8, 8, 8), - ], - [ - {"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, - (1, 4, 16, 16), - (1, 4, 8, 8), - ], - [ - {"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, - (1, 1, 16, 16), - (1, 4, 8, 8), - ], + [{"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, (1, 4, 16, 16), (1, 4, 8, 8)], + [{"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, (1, 4, 16, 16), (1, 8, 8, 8)], + [{"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, (1, 2, 16, 16, 16), (1, 2, 8, 8, 8)], + [{"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, (1, 4, 16, 16), (1, 4, 8, 8)], + [{"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, (1, 1, 16, 16), (1, 4, 8, 8)], ] @@ -81,14 +63,14 @@ def test_shape(self, input_param, input_shape, expected_shape): class TestSubpixelDownsample(unittest.TestCase): - + @parameterized.expand(TEST_CASES_SUBPIXEL) def test_shape(self, input_param, input_shape, expected_shape): downsampler = SubpixelDownsample(**input_param) with eval_mode(downsampler): result = downsampler(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - + def test_predefined_tensor(self): test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4) test_tensor = test_tensor.unsqueeze(0) @@ -120,7 +102,6 @@ def test_reconstruction_3D(self): reconstructed = up(downsampled) self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) - def test_invalid_spatial_size(self): downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2) with self.assertRaises(ValueError): @@ -141,7 +122,7 @@ def test_shape(self, input_param, input_shape, expected_shape): with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - + def test_pre_post_conv(self): net = DownSample( spatial_dims=2, @@ -159,32 +140,33 @@ def test_pixelunshuffle_equivalence(self): class DownSample_local(torch.nn.Module): def __init__(self, n_feat: int): super().__init__() - self.conv = torch.nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False) + self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) self.pixelunshuffle = torch.nn.PixelUnshuffle(2) - + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) return self.pixelunshuffle(x) + n_feat = 2 x = torch.randn(1, n_feat, 64, 64) - - fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False) - + + fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) + monai_down = DownSample( spatial_dims=2, in_channels=n_feat, - out_channels=n_feat//2, + out_channels=n_feat // 2, mode="pixelunshuffle", - pre_conv=fix_weight_conv + pre_conv=fix_weight_conv, ) - + local_down = DownSample_local(n_feat) local_down.conv.weight.data = fix_weight_conv.weight.data.clone() - + with eval_mode(monai_down), eval_mode(local_down): out_monai = monai_down(x) out_local = local_down(x) - + self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5)) def test_invalid_mode(self): diff --git a/tests/test_pixelunshuffle.py b/tests/test_pixelunshuffle.py index 48deaa9ee9..106dbe4d03 100644 --- a/tests/test_pixelunshuffle.py +++ b/tests/test_pixelunshuffle.py @@ -14,7 +14,9 @@ import unittest import torch -from monai.networks.utils import pixelunshuffle, pixelshuffle + +from monai.networks.utils import pixelshuffle, pixelunshuffle + class TestPixelUnshuffle(unittest.TestCase): @@ -24,7 +26,7 @@ def test_2d_basic(self): self.assertEqual(out.shape, (2, 16, 8, 8)) def test_3d_basic(self): - x = torch.randn(2, 4, 16, 16, 16) + x = torch.randn(2, 4, 16, 16, 16) out = pixelunshuffle(x, spatial_dims=3, scale_factor=2) self.assertEqual(out.shape, (2, 32, 8, 8, 8)) @@ -49,5 +51,6 @@ def test_invalid_scale(self): with self.assertRaises(RuntimeError): pixelunshuffle(x, spatial_dims=2, scale_factor=2) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_restormer.py b/tests/test_restormer.py index 1e9caab477..bf1fa83eac 100644 --- a/tests/test_restormer.py +++ b/tests/test_restormer.py @@ -11,12 +11,13 @@ from __future__ import annotations - import unittest + import torch from parameterized import parameterized -from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer + from monai.networks import eval_mode +from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer TEST_CASES_TRANSFORMER = [ # [spatial_dims, dim, num_heads, ffn_factor, bias, norm_type, flash_attn, input_shape] @@ -46,41 +47,43 @@ TEST_CASES_RESTORMER = [] for config in RESTORMER_CONFIGS: # 2D cases - TEST_CASES_RESTORMER.extend([ - [ - { - "spatial_dims": 2, - "inp_channels": 1, - "out_channels": 1, - "dim": 48, - "num_blocks": config["num_blocks"], - "heads": config["heads"], - "num_refinement_blocks": 2, - "ffn_expansion_factor": 1.5 - }, - (2, 1, 64, 64), - (2, 1, 64, 64) - ], - # 3D cases + TEST_CASES_RESTORMER.extend( [ - { - "spatial_dims": 3, - "inp_channels": 1, - "out_channels": 1, - "dim": 48, - "num_blocks": config["num_blocks"], - "heads": config["heads"], - "num_refinement_blocks": 2, - "ffn_expansion_factor": 1.5 - }, - (2, 1, 32, 32, 32), - (2, 1, 32, 32, 32) + [ + { + "spatial_dims": 2, + "inp_channels": 1, + "out_channels": 1, + "dim": 48, + "num_blocks": config["num_blocks"], + "heads": config["heads"], + "num_refinement_blocks": 2, + "ffn_expansion_factor": 1.5, + }, + (2, 1, 64, 64), + (2, 1, 64, 64), + ], + # 3D cases + [ + { + "spatial_dims": 3, + "inp_channels": 1, + "out_channels": 1, + "dim": 48, + "num_blocks": config["num_blocks"], + "heads": config["heads"], + "num_refinement_blocks": 2, + "ffn_expansion_factor": 1.5, + }, + (2, 1, 32, 32, 32), + (2, 1, 32, 32, 32), + ], ] - ]) + ) class TestMDTATransformerBlock(unittest.TestCase): - + @parameterized.expand(TEST_CASES_TRANSFORMER) def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flash, shape): block = MDTATransformerBlock( @@ -90,7 +93,7 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flas ffn_expansion_factor=ffn_factor, bias=bias, LayerNorm_type=norm_type, - flash_attention=flash + flash_attention=flash, ) with eval_mode(block): x = torch.randn(shape) @@ -99,37 +102,34 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flas class TestOverlapPatchEmbed(unittest.TestCase): - + @parameterized.expand(TEST_CASES_PATCHEMBED) def test_shape(self, spatial_dims, in_c, embed_dim, input_shape, expected_shape): - net = OverlapPatchEmbed( - spatial_dims=spatial_dims, - in_c=in_c, - embed_dim=embed_dim - ) + net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_c=in_c, embed_dim=embed_dim) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - + class TestRestormer(unittest.TestCase): - + @parameterized.expand(TEST_CASES_RESTORMER) def test_shape(self, input_param, input_shape, expected_shape): net = Restormer(**input_param) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - + def test_small_input_error_2d(self): net = Restormer(spatial_dims=2, inp_channels=1, out_channels=1) with self.assertRaises(AssertionError): net(torch.randn(1, 1, 8, 8)) - + def test_small_input_error_3d(self): net = Restormer(spatial_dims=3, inp_channels=1, out_channels=1) with self.assertRaises(AssertionError): net(torch.randn(1, 1, 8, 8, 8)) + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From acb818d37fcdbd0383df164acfe75218055d2185 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Jan 2025 14:51:47 +0000 Subject: [PATCH 24/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/downsample.py | 2 +- tests/test_CABlock.py | 2 -- tests/test_downsample_block.py | 2 -- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 9e33f7fdce..721738eedb 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -18,7 +18,7 @@ from monai.networks.layers.factories import Conv, Pool from monai.networks.utils import pixelunshuffle -from monai.utils import DownsampleMode, InterpolateMode, ensure_tuple_rep, look_up_option +from monai.utils import DownsampleMode, ensure_tuple_rep, look_up_option __all__ = ["MaxAvgPool", "DownSample", "Downsample", "SubpixelDownsample", "SubpixelDownSample", "Subpixeldownsample"] diff --git a/tests/test_CABlock.py b/tests/test_CABlock.py index 82d813b1ad..fb981ad829 100644 --- a/tests/test_CABlock.py +++ b/tests/test_CABlock.py @@ -12,9 +12,7 @@ from __future__ import annotations import unittest -from unittest import skipUnless -import numpy as np import torch from parameterized import parameterized diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index d6ce33ca5e..27baa0ff08 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -12,10 +12,8 @@ from __future__ import annotations import unittest -from collections.abc import Sequence import torch -import torch.nn as nn from parameterized import parameterized from monai.networks import eval_mode From 6352ba94e98478bf708ec200616d51b037cec715 Mon Sep 17 00:00:00 2001 From: tisalon Date: Thu, 23 Jan 2025 16:32:03 +0100 Subject: [PATCH 25/67] DCO Remediation Commit for tisalon I, tisalon , hereby add my Signed-off-by to this commit: 3db93ce8e151f298609061a3e0589cf485baa8aa I, tisalon , hereby add my Signed-off-by to this commit: 9693e04eff628416061606a332c2f849f05316ea I, tisalon , hereby add my Signed-off-by to this commit: a89f2995f600952167014f7bafceecbd243d18fb I, tisalon , hereby add my Signed-off-by to this commit: 450691f225d1860cd331f055daee506f398b0df6 I, tisalon , hereby add my Signed-off-by to this commit: d0920d85a6cb334cc8434f2affe9d2f4bc293bbb I, tisalon , hereby add my Signed-off-by to this commit: 1a48d4de61fbe239b8f9f2321c25e988a1ec550b I, tisalon , hereby add my Signed-off-by to this commit: fe47807da520b1dd0f93dc7fc719eb0e092296ff I, tisalon , hereby add my Signed-off-by to this commit: 86155cd967560bb2883655f944a00be44dd38b9c I, tisalon , hereby add my Signed-off-by to this commit: 137a7f21e6125fb9f76b7799f7104ba62466d89b I, tisalon , hereby add my Signed-off-by to this commit: fb17baf6a4370af0b99a098b2cca9dd35010e628 I, tisalon , hereby add my Signed-off-by to this commit: 5ff0baa783de08354b87ef37a487526452636103 I, tisalon , hereby add my Signed-off-by to this commit: 2566db179645bd447723bd2ed1813b7d6a6c4fdf I, tisalon , hereby add my Signed-off-by to this commit: ac4047b26234803c102c555b8fb8715b2bfc5b8e I, tisalon , hereby add my Signed-off-by to this commit: 2b74270f3a04944be7c3af999e20ff123d92c0b2 I, tisalon , hereby add my Signed-off-by to this commit: 9b745338a583f1523f9c291c906d46bf12008131 I, tisalon , hereby add my Signed-off-by to this commit: 1ab34f66d248e4ae5cb63e2b1031bce92b86e0ef I, tisalon , hereby add my Signed-off-by to this commit: 4f4c62cfd2d71386bbf4eca8c3ba239aa4dd5a3b I, tisalon , hereby add my Signed-off-by to this commit: 068688f50bb7e11a61f8411539cc368bdf1b6937 I, tisalon , hereby add my Signed-off-by to this commit: e2e1070a602ae1f5452f8c1b3ae51bb640600b7b I, tisalon , hereby add my Signed-off-by to this commit: 35c7ee48e0c71f0979e181ce36a654c2683880c3 I, tisalon , hereby add my Signed-off-by to this commit: d8cb6c13b59f86b5fd65e255b0cc5e55cb298279 I, tisalon , hereby add my Signed-off-by to this commit: 6d96816d3936b4faac6d559bbc290ab0a5589397 I, tisalon , hereby add my Signed-off-by to this commit: 8a688fb57926232f73240fcb1f12805e2de72ded Signed-off-by: tisalon --- monai/networks/blocks/cablock.py | 2 +- monai/networks/nets/restormer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py index 6f5eddc780..b92e56c747 100644 --- a/monai/networks/blocks/cablock.py +++ b/monai/networks/blocks/cablock.py @@ -120,7 +120,7 @@ def _normal_attention(self, q, k, v): attn = attn.softmax(dim=-1) return attn @ v - def forward(self, x): + def forward(self, x) -> torch.Tensor: """Forward pass for MDTA attention. 1. Apply depth-wise convolutions to Q, K, V 2. Reshape Q, K, V for multi-head attention diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index 3d3d9565d4..c21fa3db76 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -268,7 +268,7 @@ def __init__( conv_only=True, ) - def forward(self, x): + def forward(self, x) -> torch.Tensor: """Forward pass of Restormer. Processes input through encoder-decoder architecture with skip connections. Args: From c7b1af479867f274fd67cec9c8337fe4adea89f4 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 24 Jan 2025 09:19:17 +0100 Subject: [PATCH 26/67] add optional_import to downsample block test Signed-off-by: tisalon --- tests/test_downsample_block.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_downsample_block.py b/tests/test_downsample_block.py index 27baa0ff08..5e660510d4 100644 --- a/tests/test_downsample_block.py +++ b/tests/test_downsample_block.py @@ -18,6 +18,9 @@ from monai.networks import eval_mode from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") TEST_CASES = [ [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 @@ -82,7 +85,7 @@ def test_predefined_tensor(self): self.assertTrue(torch.all(result[0, 8:11] == 2)) self.assertTrue(torch.all(result[0, 12:15] == 3)) - def test_reconstruction_2D(self): + def test_reconstruction_2d(self): input_tensor = torch.randn(1, 1, 4, 4) down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) @@ -91,7 +94,7 @@ def test_reconstruction_2D(self): reconstructed = up(downsampled) self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) - def test_reconstruction_3D(self): + def test_reconstruction_3d(self): input_tensor = torch.randn(1, 1, 4, 4, 4) down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None) up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) @@ -135,7 +138,7 @@ def test_pre_post_conv(self): self.assertEqual(result.shape, (1, 16, 8, 8)) def test_pixelunshuffle_equivalence(self): - class DownSample_local(torch.nn.Module): + class DownSampleLocal(torch.nn.Module): def __init__(self, n_feat: int): super().__init__() self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) @@ -158,7 +161,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: pre_conv=fix_weight_conv, ) - local_down = DownSample_local(n_feat) + local_down = DownSampleLocal(n_feat) local_down.conv.weight.data = fix_weight_conv.weight.data.clone() with eval_mode(monai_down), eval_mode(local_down): From 8faa5da5fa5ac514e93d7c2f313927971e0e9b9d Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 13:21:06 +0100 Subject: [PATCH 27/67] rename args and fix imports --- monai/networks/nets/restormer.py | 44 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index c21fa3db76..489ef1984b 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -32,11 +32,11 @@ def __init__( num_heads: int, ffn_expansion_factor: float, bias: bool, - LayerNorm_type: str, + layer_norm_type: str = "BiasFree", flash_attention: bool = False, ): super().__init__() - use_bias = LayerNorm_type != "BiasFree" + use_bias = layer_norm_type != "BiasFree" self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias) self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention) self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias) @@ -88,19 +88,19 @@ class Restormer(nn.Module): def __init__( self, - spatial_dims=2, - inp_channels=3, - out_channels=3, - dim=48, - num_blocks=[1, 1, 1, 1], - heads=[1, 1, 1, 1], - num_refinement_blocks=4, - ffn_expansion_factor=2.66, - bias=False, - LayerNorm_type="WithBias", - dual_pixel_task=False, - flash_attention=False, - ): + spatial_dims: int = 2, + inp_channels: int = 3, + out_channels: int = 3, + dim: int = 48, + num_blocks: tuple[int, ...] = (1, 1, 1, 1), + heads: tuple[int, ...] = (1, 1, 1, 1), + num_refinement_blocks: int = 4, + ffn_expansion_factor: float = 2.66, + bias: bool = False, + layer_norm_type: str = "WithBias", + dual_pixel_task: bool = False, + flash_attention: bool = False, + ) -> None: super().__init__() """Initialize Restormer model. @@ -113,14 +113,14 @@ def __init__( heads: Number of attention heads at each scale ffn_expansion_factor: Expansion factor for feed-forward network bias: Whether to use bias in convolutions - LayerNorm_type: Type of normalization ('WithBias' or 'BiasFree') + layer_norm_type: Type of normalization ('WithBias' or 'BiasFree') dual_pixel_task: Enable dual-pixel specific processing flash_attention: Use flash attention if available """ # Check input parameters assert len(num_blocks) > 1, "Number of blocks must be greater than 1" assert len(num_blocks) == len(heads), "Number of blocks and heads must be equal" - assert all([n > 0 for n in num_blocks]), "Number of blocks must be greater than 0" + assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0" # Initial feature extraction self.patch_embed = OverlapPatchEmbed(spatial_dims, inp_channels, dim) @@ -147,7 +147,7 @@ def __init__( num_heads=heads[n], ffn_expansion_factor=ffn_expansion_factor, bias=bias, - LayerNorm_type=LayerNorm_type, + layer_norm_type=layer_norm_type, flash_attention=flash_attention, ) for _ in range(num_blocks[n]) @@ -176,7 +176,7 @@ def __init__( num_heads=heads[num_steps], ffn_expansion_factor=ffn_expansion_factor, bias=bias, - LayerNorm_type=LayerNorm_type, + layer_norm_type=layer_norm_type, flash_attention=flash_attention, ) for _ in range(num_blocks[num_steps]) @@ -224,7 +224,7 @@ def __init__( num_heads=heads[n], ffn_expansion_factor=ffn_expansion_factor, bias=bias, - LayerNorm_type=LayerNorm_type, + layer_norm_type=layer_norm_type, flash_attention=flash_attention, ) for _ in range(num_blocks[n]) @@ -241,7 +241,7 @@ def __init__( num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, - LayerNorm_type=LayerNorm_type, + layer_norm_type=layer_norm_type, flash_attention=flash_attention, ) for _ in range(num_refinement_blocks) @@ -286,7 +286,7 @@ def forward(self, x) -> torch.Tensor: skip_connections = [] # Encoding path - for idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)): + for _idx, (encoder, downsample) in enumerate(zip(self.encoder_levels, self.downsamples)): x = encoder(x) skip_connections.append(x) x = downsample(x) From be899588719d11b795715a305c7f1eaaf378d024 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Wed, 15 Jan 2025 15:08:47 +0800 Subject: [PATCH 28/67] Using LocalStore in Zarr v3 (#8299) Fixes #8298 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- tests/test_zarr_avg_merger.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index de7fad48da..a52dbceb4c 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -19,11 +19,18 @@ from torch.nn.functional import pad from monai.inferers import ZarrAvgMerger -from monai.utils import optional_import +from monai.utils import get_package_version, optional_import, version_geq from tests.utils import assert_allclose np.seterr(divide="ignore", invalid="ignore") zarr, has_zarr = optional_import("zarr") +if has_zarr: + if version_geq(get_package_version("zarr"), "3.0.0"): + directory_store = zarr.storage.LocalStore("test.zarr") + else: + directory_store = zarr.storage.DirectoryStore("test.zarr") +else: + directory_store = None numcodecs, has_numcodecs = optional_import("numcodecs") TENSOR_4x4 = torch.randint(low=0, high=255, size=(2, 3, 4, 4), dtype=torch.float32) @@ -154,7 +161,7 @@ # explicit directory store TEST_CASE_10_DIRECTORY_STORE = [ - dict(merged_shape=TENSOR_4x4.shape, store=zarr.storage.DirectoryStore("test.zarr")), + dict(merged_shape=TENSOR_4x4.shape, store=directory_store), [ (TENSOR_4x4[..., :2, :2], (0, 0)), (TENSOR_4x4[..., :2, 2:], (0, 2)), From c17938bed7678ad53b3447c50671943868c5db57 Mon Sep 17 00:00:00 2001 From: advcu <65158236+advcu987@users.noreply.github.com> Date: Mon, 20 Jan 2025 07:26:06 +0100 Subject: [PATCH 29/67] 8267 fix normalize intensity (#8286) Fixes #8267 . ### Description Fix channel-wise intensity normalization for integer type inputs. ### Types of changes - [ ] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: advcu987 Signed-off-by: advcu <65158236+advcu987@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/transforms/intensity/array.py | 4 ++++ tests/test_normalize_intensity.py | 21 +++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 20000c52c4..8fe658ad3e 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -821,6 +821,7 @@ class NormalizeIntensity(Transform): mean and std on each channel separately. When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should be the number of image channels if they are not None. + If the input is not of floating point type, it will be converted to float32 Args: subtrahend: the amount to subtract by (usually the mean). @@ -907,6 +908,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: if self.divisor is not None and len(self.divisor) != len(img): raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.") + if not img.dtype.is_floating_point: + img, *_ = convert_data_type(img, dtype=torch.float32) + for i, d in enumerate(img): img[i] = self._normalize( # type: ignore d, diff --git a/tests/test_normalize_intensity.py b/tests/test_normalize_intensity.py index 72ebf579e1..7efd0d83e5 100644 --- a/tests/test_normalize_intensity.py +++ b/tests/test_normalize_intensity.py @@ -108,6 +108,27 @@ def test_channel_wise(self, im_type): normalized = normalizer(input_data) assert_allclose(normalized, im_type(expected), type_test="tensor") + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise_int(self, im_type): + normalizer = NormalizeIntensity(nonzero=True, channel_wise=True) + input_data = im_type(torch.arange(1, 25).reshape(2, 3, 4)) + expected = np.array( + [ + [ + [-1.593255, -1.3035723, -1.0138896, -0.7242068], + [-0.4345241, -0.1448414, 0.1448414, 0.4345241], + [0.7242068, 1.0138896, 1.3035723, 1.593255], + ], + [ + [-1.593255, -1.3035723, -1.0138896, -0.7242068], + [-0.4345241, -0.1448414, 0.1448414, 0.4345241], + [0.7242068, 1.0138896, 1.3035723, 1.593255], + ], + ] + ) + normalized = normalizer(input_data) + assert_allclose(normalized, im_type(expected), type_test="tensor", rtol=1e-7, atol=1e-7) # tolerance + @parameterized.expand([[p] for p in TEST_NDARRAYS]) def test_value_errors(self, im_type): input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])) From 64613a7ec9ebba567da7ebcae210a166a81d4495 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Tue, 21 Jan 2025 08:25:39 +0800 Subject: [PATCH 30/67] Fix bundle download error from ngc source (#8307) Fixes #8306 This previous api has been deprecated, update based on: https://docs.ngc.nvidia.com/api/?urls.primaryName=Private%20Artifacts%20(Models)%20API#/artifact-file-controller/downloadAllArtifactFiles ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 131c78008b..5089f0c045 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -174,7 +174,7 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam def _get_ngc_bundle_url(model_name: str, version: str) -> str: - return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip" + return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/files" def _get_ngc_private_base_url(repo: str) -> str: @@ -218,6 +218,21 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str: return name +def _get_all_download_files(request_url: str, headers: dict | None = None) -> list[str]: + if not has_requests: + raise ValueError("requests package is required, please install it.") + headers = {} if headers is None else headers + response = requests_get(request_url, headers=headers) + response.raise_for_status() + model_info = json.loads(response.text) + + if not isinstance(model_info, dict) or "modelFiles" not in model_info: + raise ValueError("The data is not a dictionary or it does not have the key 'modelFiles'.") + + model_files = model_info["modelFiles"] + return [f["path"] for f in model_files] + + def _download_from_ngc( download_path: Path, filename: str, @@ -229,12 +244,12 @@ def _download_from_ngc( # ensure prefix is contained filename = _add_ngc_prefix(filename, prefix=prefix) url = _get_ngc_bundle_url(model_name=filename, version=version) - filepath = download_path / f"{filename}_v{version}.zip" if remove_prefix: filename = _remove_ngc_prefix(filename, prefix=remove_prefix) - extract_path = download_path / f"{filename}" - download_url(url=url, filepath=filepath, hash_val=None, progress=progress) - extractall(filepath=filepath, output_dir=extract_path, has_base=True) + filepath = download_path / filename + filepath.mkdir(parents=True, exist_ok=True) + for file in _get_all_download_files(url): + download_url(url=f"{url}/{file}", filepath=f"{filepath}/{file}", hash_val=None, progress=progress) def _download_from_ngc_private( From 5643d4aecc4961afaec016bcac67d387015067e7 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 24 Jan 2025 23:00:48 +0800 Subject: [PATCH 31/67] Fix deprecated usage in zarr (#8313) Fixes #8298 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/inferers/merger.py | 23 +++++++++++++++++++---- tests/test_zarr_avg_merger.py | 7 ++++--- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/monai/inferers/merger.py b/monai/inferers/merger.py index d01d334142..1344207e18 100644 --- a/monai/inferers/merger.py +++ b/monai/inferers/merger.py @@ -15,12 +15,13 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from contextlib import nullcontext +from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any import numpy as np import torch -from monai.utils import ensure_tuple_size, optional_import, require_pkg +from monai.utils import ensure_tuple_size, get_package_version, optional_import, require_pkg, version_geq if TYPE_CHECKING: import zarr @@ -233,7 +234,7 @@ def __init__( store: zarr.storage.Store | str = "merged.zarr", value_store: zarr.storage.Store | str | None = None, count_store: zarr.storage.Store | str | None = None, - compressor: str = "default", + compressor: str | None = None, value_compressor: str | None = None, count_compressor: str | None = None, chunks: Sequence[int] | bool = True, @@ -246,8 +247,22 @@ def __init__( self.value_dtype = value_dtype self.count_dtype = count_dtype self.store = store - self.value_store = zarr.storage.TempStore() if value_store is None else value_store - self.count_store = zarr.storage.TempStore() if count_store is None else count_store + self.tmpdir: TemporaryDirectory | None + if version_geq(get_package_version("zarr"), "3.0.0"): + if value_store is None: + self.tmpdir = TemporaryDirectory() + self.value_store = zarr.storage.LocalStore(self.tmpdir.name) + else: + self.value_store = value_store + if count_store is None: + self.tmpdir = TemporaryDirectory() + self.count_store = zarr.storage.LocalStore(self.tmpdir.name) + else: + self.count_store = count_store + else: + self.tmpdir = None + self.value_store = zarr.storage.TempStore() if value_store is None else value_store + self.count_store = zarr.storage.TempStore() if count_store is None else count_store self.chunks = chunks self.compressor = compressor self.value_compressor = value_compressor diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index a52dbceb4c..3c89e4fb03 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -287,15 +287,16 @@ class ZarrAvgMergerTests(unittest.TestCase): ] ) def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): + codec_reg = numcodecs.registry.codec_registry if "compressor" in arguments: if arguments["compressor"] != "default": - arguments["compressor"] = zarr.codec_registry[arguments["compressor"].lower()]() + arguments["compressor"] = codec_reg[arguments["compressor"].lower()]() if "value_compressor" in arguments: if arguments["value_compressor"] != "default": - arguments["value_compressor"] = zarr.codec_registry[arguments["value_compressor"].lower()]() + arguments["value_compressor"] = codec_reg[arguments["value_compressor"].lower()]() if "count_compressor" in arguments: if arguments["count_compressor"] != "default": - arguments["count_compressor"] = zarr.codec_registry[arguments["count_compressor"].lower()]() + arguments["count_compressor"] = codec_reg[arguments["count_compressor"].lower()]() merger = ZarrAvgMerger(**arguments) for pl in patch_locations: merger.aggregate(pl[0], pl[1]) From 595674aa2d2a4a9ecbfd5201ac44735404640100 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Tue, 28 Jan 2025 07:54:49 +0800 Subject: [PATCH 32/67] update pydicom reader to enable gpu load (#8283) Related to #8241 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- monai/data/image_reader.py | 219 ++++++++++++++++++++++++++++--------- tests/test_load_image.py | 58 +++++++++- 2 files changed, 222 insertions(+), 55 deletions(-) diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 5bc38f69ea..003ec2cf0b 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -418,6 +418,10 @@ class PydicomReader(ImageReader): If provided, only the matched files will be included. For example, to include the file name "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`. Set it to `None` to use `pydicom.misc.is_dicom` to match valid files. + to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading. + Default is False. CuPy and Kvikio are required for this option. + In practical use, it's recommended to add a warm up call before the actual loading. + A related tutorial will be prepared in the future, and the document will be updated accordingly. kwargs: additional args for `pydicom.dcmread` API. more details about available args: https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html If the `get_data` function will be called @@ -434,6 +438,7 @@ def __init__( prune_metadata: bool = True, label_dict: dict | None = None, fname_regex: str = "", + to_gpu: bool = False, **kwargs, ): super().__init__() @@ -444,6 +449,33 @@ def __init__( self.prune_metadata = prune_metadata self.label_dict = label_dict self.fname_regex = fname_regex + if to_gpu and (not has_cp or not has_kvikio): + warnings.warn( + "PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading." + ) + to_gpu = False + + if to_gpu: + self.warmup_kvikio() + + self.to_gpu = to_gpu + + def warmup_kvikio(self): + """ + Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc. + This can accelerate the data loading process when `to_gpu` is set to True. + """ + if has_cp and has_kvikio: + a = cp.arange(100) + with tempfile.NamedTemporaryFile() as tmp_file: + tmp_file_name = tmp_file.name + f = kvikio.CuFile(tmp_file_name, "w") + f.write(a) + f.close() + + b = cp.empty_like(a) + f = kvikio.CuFile(tmp_file_name, "r") + f.read(b) def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ @@ -475,12 +507,15 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): img_ = [] filenames: Sequence[PathLike] = ensure_tuple(data) + self.filenames = list(filenames) kwargs_ = self.kwargs.copy() + if self.to_gpu: + kwargs["defer_size"] = "100 KB" kwargs_.update(kwargs) self.has_series = False - for name in filenames: + for i, name in enumerate(filenames): name = f"{name}" if Path(name).is_dir(): # read DICOM series @@ -489,20 +524,28 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): else: series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)] slices = [] + loaded_slc_names = [] for slc in series_slcs: try: slices.append(pydicom.dcmread(fp=slc, **kwargs_)) + loaded_slc_names.append(slc) except pydicom.errors.InvalidDicomError as e: warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2) - img_.append(slices if len(slices) > 1 else slices[0]) if len(slices) > 1: self.has_series = True + img_.append(slices) + self.filenames[i] = loaded_slc_names # type: ignore + else: + img_.append(slices[0]) # type: ignore + self.filenames[i] = loaded_slc_names[0] # type: ignore else: ds = pydicom.dcmread(fp=name, **kwargs_) - img_.append(ds) - return img_ if len(filenames) > 1 else img_[0] + img_.append(ds) # type: ignore + if len(filenames) == 1: + return img_[0] + return img_ - def _combine_dicom_series(self, data: Iterable): + def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]): """ Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new dimension as the last dimension. @@ -522,28 +565,27 @@ def _combine_dicom_series(self, data: Iterable): """ slices: list = [] # for a dicom series - for slc_ds in data: + for slc_ds, filename in zip(data, filenames): if hasattr(slc_ds, "InstanceNumber"): - slices.append(slc_ds) + slices.append((slc_ds, filename)) else: - warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.") - slices = sorted(slices, key=lambda s: s.InstanceNumber) - + warnings.warn(f"slice: {filename} does not have InstanceNumber tag, skip it.") + slices = sorted(slices, key=lambda s: s[0].InstanceNumber) if len(slices) == 0: raise ValueError("the input does not have valid slices.") - first_slice = slices[0] + first_slice, first_filename = slices[0] average_distance = 0.0 - first_array = self._get_array_data(first_slice) + first_array = self._get_array_data(first_slice, first_filename) shape = first_array.shape - spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0]) + spacing = getattr(first_slice, "PixelSpacing", [1.0] * len(shape)) prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2] stack_array = [first_array] for idx in range(1, len(slices)): - slc_array = self._get_array_data(slices[idx]) + slc_array = self._get_array_data(slices[idx][0], slices[idx][1]) slc_shape = slc_array.shape - slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0)) - slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2] + slc_spacing = getattr(slices[idx][0], "PixelSpacing", [1.0] * len(shape)) + slc_pos = getattr(slices[idx][0], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2] if not np.allclose(slc_spacing, spacing): warnings.warn(f"the list contains slices that have different spacings {spacing} and {slc_spacing}.") if shape != slc_shape: @@ -555,11 +597,14 @@ def _combine_dicom_series(self, data: Iterable): if len(slices) > 1: average_distance /= len(slices) - 1 spacing.append(average_distance) - stack_array = np.stack(stack_array, axis=-1) + if self.to_gpu: + stack_array = cp.stack(stack_array, axis=-1) + else: + stack_array = np.stack(stack_array, axis=-1) stack_metadata = self._get_meta_dict(first_slice) stack_metadata["spacing"] = np.asarray(spacing) - if hasattr(slices[-1], "ImagePositionPatient"): - stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1].ImagePositionPatient) + if hasattr(slices[-1][0], "ImagePositionPatient"): + stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1][0].ImagePositionPatient) stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),) else: stack_array = stack_array[0] @@ -597,29 +642,35 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: if self.has_series is True: # a list, all objects within a list belong to one dicom series if not isinstance(data[0], list): - dicom_data.append(self._combine_dicom_series(data)) + # input is a dir, self.filenames is a list of list of filenames + dicom_data.append(self._combine_dicom_series(data, self.filenames[0])) # type: ignore # a list of list, each inner list represents a dicom series else: - for series in data: - dicom_data.append(self._combine_dicom_series(series)) + for i, series in enumerate(data): + dicom_data.append(self._combine_dicom_series(series, self.filenames[i])) # type: ignore else: # a single pydicom dataset object if not isinstance(data, list): data = [data] - for d in data: + for i, d in enumerate(data): if hasattr(d, "SegmentSequence"): - data_array, metadata = self._get_seg_data(d) + data_array, metadata = self._get_seg_data(d, self.filenames[i]) else: - data_array = self._get_array_data(d) + data_array = self._get_array_data(d, self.filenames[i]) metadata = self._get_meta_dict(d) metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape dicom_data.append((data_array, metadata)) + # TODO: the actual type is list[np.ndarray | cp.ndarray] + # should figure out how to define correct types without having cupy not found error + # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 img_array: list[np.ndarray] = [] compatible_meta: dict = {} for data_array, metadata in ensure_tuple(dicom_data): - img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array)) + if self.swap_ij: + data_array = cp.swapaxes(data_array, 0, 1) if self.to_gpu else np.swapaxes(data_array, 0, 1) + img_array.append(cp.ascontiguousarray(data_array) if self.to_gpu else np.ascontiguousarray(data_array)) affine = self._get_affine(metadata, self.affine_lps_to_ras) metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS if self.swap_ij: @@ -641,7 +692,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: _copy_compatible_dict(metadata, compatible_meta) - return _stack_images(img_array, compatible_meta), compatible_meta + return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta def _get_meta_dict(self, img) -> dict: """ @@ -713,7 +764,7 @@ def _get_affine(self, metadata: dict, lps_to_ras: bool = True): affine = orientation_ras_lps(affine) return affine - def _get_frame_data(self, img) -> Iterator: + def _get_frame_data(self, img, filename, array_data) -> Iterator: """ yield frames and description from the segmentation image. This function is adapted from Highdicom: @@ -751,48 +802,54 @@ def _get_frame_data(self, img) -> Iterator: """ if not hasattr(img, "PerFrameFunctionalGroupsSequence"): - raise NotImplementedError( - f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required." - ) + raise NotImplementedError(f"To read dicom seg: {filename}, 'PerFrameFunctionalGroupsSequence' is required.") frame_seg_nums = [] for f in img.PerFrameFunctionalGroupsSequence: if not hasattr(f, "SegmentIdentificationSequence"): raise NotImplementedError( - f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame." + f"To read dicom seg: {filename}, 'SegmentIdentificationSequence' is required for each frame." ) frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber)) - frame_seg_nums_arr = np.array(frame_seg_nums) + frame_seg_nums_arr = cp.array(frame_seg_nums) if self.to_gpu else np.array(frame_seg_nums) seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence} - for i in np.unique(frame_seg_nums_arr): - indices = np.where(frame_seg_nums_arr == i)[0] - yield (img.pixel_array[indices, ...], seg_descriptions[i]) + for i in np.unique(frame_seg_nums_arr) if not self.to_gpu else cp.unique(frame_seg_nums_arr): + indices = np.where(frame_seg_nums_arr == i)[0] if not self.to_gpu else cp.where(frame_seg_nums_arr == i)[0] + yield (array_data[indices, ...], seg_descriptions[i]) - def _get_seg_data(self, img): + def _get_seg_data(self, img, filename): """ Get the array data and metadata of the segmentation image. Aegs: img: a Pydicom dataset object that has attribute "SegmentSequence". + filename: the file path of the image. """ metadata = self._get_meta_dict(img) n_classes = len(img.SegmentSequence) - spatial_shape = list(img.pixel_array.shape) + array_data = self._get_array_data(img, filename) + spatial_shape = list(array_data.shape) spatial_shape[0] = spatial_shape[0] // n_classes if self.label_dict is not None: metadata["labels"] = self.label_dict - all_segs = np.zeros([*spatial_shape, len(self.label_dict)]) + if self.to_gpu: + all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype) + else: + all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype) else: metadata["labels"] = {} - all_segs = np.zeros([*spatial_shape, n_classes]) + if self.to_gpu: + all_segs = cp.zeros([*spatial_shape, n_classes], dtype=array_data.dtype) + else: + all_segs = np.zeros([*spatial_shape, n_classes], dtype=array_data.dtype) - for i, (frames, description) in enumerate(self._get_frame_data(img)): + for i, (frames, description) in enumerate(self._get_frame_data(img, filename, array_data)): segment_label = getattr(description, "SegmentLabel", f"label_{i}") class_name = getattr(description, "SegmentDescription", segment_label) if class_name not in metadata["labels"].keys(): @@ -840,19 +897,79 @@ def _get_seg_data(self, img): return all_segs, metadata - def _get_array_data(self, img): + def _get_array_data_from_gpu(self, img, filename): + """ + Get the raw array data of the image. This function is used when `to_gpu` is set to True. + + Args: + img: a Pydicom dataset object. + filename: the file path of the image. + + """ + rows = getattr(img, "Rows", None) + columns = getattr(img, "Columns", None) + bits_allocated = getattr(img, "BitsAllocated", None) + samples_per_pixel = getattr(img, "SamplesPerPixel", 1) + number_of_frames = getattr(img, "NumberOfFrames", 1) + pixel_representation = getattr(img, "PixelRepresentation", 1) + + if rows is None or columns is None or bits_allocated is None: + warnings.warn( + f"dicom data: {filename} does not have Rows, Columns or BitsAllocated, falling back to CPU loading." + ) + + if not hasattr(img, "pixel_array"): + raise ValueError(f"dicom data: {filename} does not have pixel_array.") + data = img.pixel_array + + return data + + if bits_allocated == 8: + dtype = cp.int8 if pixel_representation == 1 else cp.uint8 + elif bits_allocated == 16: + dtype = cp.int16 if pixel_representation == 1 else cp.uint16 + elif bits_allocated == 32: + dtype = cp.int32 if pixel_representation == 1 else cp.uint32 + else: + raise ValueError("Unsupported BitsAllocated value") + + bytes_per_pixel = bits_allocated // 8 + total_pixels = rows * columns * samples_per_pixel * number_of_frames + expected_pixel_data_length = total_pixels * bytes_per_pixel + + pixel_data_tag = pydicom.tag.Tag(0x7FE0, 0x0010) + if pixel_data_tag not in img: + raise ValueError(f"dicom data: {filename} does not have pixel data.") + + offset = img.get_item(pixel_data_tag, keep_deferred=True).value_tell + + with kvikio.CuFile(filename, "r") as f: + buffer = cp.empty(expected_pixel_data_length, dtype=cp.int8) + f.read(buffer, expected_pixel_data_length, offset) + + new_shape = (number_of_frames, rows, columns) if number_of_frames > 1 else (rows, columns) + data = buffer.view(dtype).reshape(new_shape) + + return data + + def _get_array_data(self, img, filename): """ Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data - will be rescaled. The output data has the dtype np.float32 if the rescaling is applied. + will be rescaled. The output data has the dtype float32 if the rescaling is applied. Args: img: a Pydicom dataset object. + filename: the file path of the image. """ # process Dicom series - if not hasattr(img, "pixel_array"): - raise ValueError(f"dicom data: {img.filename} does not have pixel_array.") - data = img.pixel_array + + if self.to_gpu: + data = self._get_array_data_from_gpu(img, filename) + else: + if not hasattr(img, "pixel_array"): + raise ValueError(f"dicom data: {filename} does not have pixel_array.") + data = img.pixel_array slope, offset = 1.0, 0.0 rescale_flag = False @@ -862,8 +979,14 @@ def _get_array_data(self, img): if hasattr(img, "RescaleIntercept"): offset = img.RescaleIntercept rescale_flag = True + if rescale_flag: - data = data.astype(np.float32) * slope + offset + if self.to_gpu: + slope = cp.asarray(slope, dtype=cp.float32) + offset = cp.asarray(offset, dtype=cp.float32) + data = data.astype(cp.float32) * slope + offset + else: + data = data.astype(np.float32) * slope + offset return data @@ -884,8 +1007,6 @@ class NibabelReader(ImageReader): Default is False. CuPy and Kvikio are required for this option. Note: For compressed NIfTI files, some operations may still be performed on CPU memory, and the acceleration may not be significant. In some cases, it may be slower than loading on CPU. - In practical use, it's recommended to add a warm up call before the actual loading. - A related tutorial will be prepared in the future, and the document will be updated accordingly. kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 498b9972b4..07acf7c179 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -168,6 +168,16 @@ def get_data(self, _obj): # test reader consistency between PydicomReader and ITKReader on dicom data TEST_CASE_22 = ["tests/testing_data/CT_DICOM"] +# test pydicom gpu reader +TEST_CASE_GPU_5 = [{"reader": "PydicomReader", "to_gpu": True}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)] + +TEST_CASE_GPU_6 = [ + {"reader": "PydicomReader", "ensure_channel_first": True, "force": True, "to_gpu": True}, + "tests/testing_data/CT_DICOM", + (16, 16, 4), + (1, 16, 16, 4), +] + TESTS_META = [] for track_meta in (False, True): TESTS_META.append([{}, (128, 128, 128), track_meta]) @@ -242,16 +252,17 @@ def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape): @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9]) def test_itk_reader(self, input_param, filenames, expected_shape): - test_image = np.random.rand(128, 128, 128) + test_image = torch.randint(0, 256, (128, 128, 128), dtype=torch.uint8).numpy() + print("Test image value range:", test_image.min(), test_image.max()) with tempfile.TemporaryDirectory() as tempdir: for i, name in enumerate(filenames): filenames[i] = os.path.join(tempdir, name) - itk_np_view = itk.image_view_from_array(test_image) - itk.imwrite(itk_np_view, filenames[i]) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i]) result = LoadImage(image_only=True, **input_param)(filenames) - self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz")) - diag = torch.as_tensor(np.diag([-1, -1, 1, 1])) - np.testing.assert_allclose(result.affine, diag) + ext = "".join(Path(name).suffixes) + self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext)) + self.assertEqual(result.meta["space"], "RAS") + assert_allclose(result.affine, torch.eye(4)) self.assertTupleEqual(result.shape, expected_shape) @parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12, TEST_CASE_19, TEST_CASE_20, TEST_CASE_21]) @@ -271,6 +282,26 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, e ) self.assertTupleEqual(result.shape, expected_np_shape) + @SkipIfNoModule("pydicom") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + @parameterized.expand([TEST_CASE_GPU_5, TEST_CASE_GPU_6]) + def test_pydicom_gpu_reader(self, input_param, filenames, expected_shape, expected_np_shape): + result = LoadImage(image_only=True, **input_param)(filenames) + self.assertEqual(result.meta["filename_or_obj"], f"{Path(filenames)}") + assert_allclose( + result.affine, + torch.tensor( + [ + [-0.488281, 0.0, 0.0, 125.0], + [0.0, -0.488281, 0.0, 128.100006], + [0.0, 0.0, 68.33333333, -99.480003], + [0.0, 0.0, 0.0, 1.0], + ] + ), + ) + self.assertTupleEqual(result.shape, expected_np_shape) + def test_no_files(self): with self.assertRaisesRegex(RuntimeError, "list index out of range"): # fname_regex excludes everything LoadImage(image_only=True, reader="PydicomReader", fname_regex=r"^(?!.*).*")("tests/testing_data/CT_DICOM") @@ -317,6 +348,21 @@ def test_dicom_reader_consistency(self, filenames): np.testing.assert_allclose(pydicom_result, itk_result) np.testing.assert_allclose(pydicom_result.affine, itk_result.affine) + @SkipIfNoModule("pydicom") + @SkipIfNoModule("cupy") + @SkipIfNoModule("kvikio") + @parameterized.expand([TEST_CASE_22]) + def test_pydicom_reader_gpu_cpu_consistency(self, filenames): + gpu_param = {"reader": "PydicomReader", "to_gpu": True} + cpu_param = {"reader": "PydicomReader", "to_gpu": False} + for affine_flag in [True, False]: + gpu_param["affine_lps_to_ras"] = affine_flag + cpu_param["affine_lps_to_ras"] = affine_flag + gpu_result = LoadImage(image_only=True, **gpu_param)(filenames) + cpu_result = LoadImage(image_only=True, **cpu_param)(filenames) + np.testing.assert_allclose(gpu_result.cpu(), cpu_result) + np.testing.assert_allclose(gpu_result.affine.cpu(), cpu_result.affine) + def test_dicom_reader_consistency_single(self): itk_param = {"reader": "ITKReader"} pydicom_param = {"reader": "PydicomReader"} From c775393da53c52a9a05d094ebf000f4e798d3617 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Mon, 3 Feb 2025 05:03:17 +0000 Subject: [PATCH 33/67] Zarr compression tests only with versions before 3.0 (#8319) Fixes #8298. ### Description This includes the tests for the `compressor` argument when testing with Zarr before version 3.0 when this argument was deprecated. A fix to upgrade the version of `pycln` used is also included. The version of PyTorch is also fixed to below 2.6 to avoid issues with misuse of `torch.load` which must be addressed later. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Eric Kerfoot --- .pre-commit-config.yaml | 2 +- monai/data/meta_tensor.py | 5 ++++ monai/utils/jupyter_utils.py | 2 +- monai/visualize/img2tensorboard.py | 4 +-- requirements-dev.txt | 2 +- requirements.txt | 2 +- tests/test_zarr_avg_merger.py | 45 +++++++++++++++--------------- 7 files changed, 34 insertions(+), 28 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2a57fbf31a..9621a1fe95 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -66,7 +66,7 @@ repos: )$ - repo: https://github.com/hadialqattan/pycln - rev: v2.4.0 + rev: v2.5.0 hooks: - id: pycln args: [--config=pyproject.toml] diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index c4c491e1b9..6425bc0a4f 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -607,3 +607,8 @@ def print_verbose(self) -> None: print(self) if self.meta is not None: print(self.meta.__repr__()) + + +# needed in later versions of Pytorch to indicate the class is safe for serialisation +if hasattr(torch.serialization, "add_safe_globals"): + torch.serialization.add_safe_globals([MetaTensor]) diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index b1b43a6767..c93e93dcb9 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -234,7 +234,7 @@ def plot_engine_status( def _get_loss_from_output( - output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor + output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor, ) -> torch.Tensor: """Returns a single value from the network output, which is a dict or tensor.""" diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index 677640bd04..fd328f2c7a 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -65,11 +65,11 @@ def _image3_animated_gif( img_str = b"" for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]: img_str += b_data - img_str += b"\x21\xFF\x0B\x4E\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2E\x30\x03\x01\x00\x00\x00" + img_str += b"\x21\xff\x0b\x4e\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2e\x30\x03\x01\x00\x00\x00" for i in ims: for b_data in PIL.GifImagePlugin.getdata(i): img_str += b_data - img_str += b"\x3B" + img_str += b"\x3b" summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str) diff --git a/requirements-dev.txt b/requirements-dev.txt index bffe304df4..c9730ee651 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -18,7 +18,7 @@ pep8-naming pycodestyle pyflakes black>=22.12 -isort>=5.1 +isort>=5.1, <6.0 ruff pytype>=2020.6.1; platform_system != "Windows" types-setuptools diff --git a/requirements.txt b/requirements.txt index e184322c13..85e7312f5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch>=1.9 +torch>=1.9,<2.6 numpy>=1.24,<2.0 diff --git a/tests/test_zarr_avg_merger.py b/tests/test_zarr_avg_merger.py index 3c89e4fb03..64e8fbde71 100644 --- a/tests/test_zarr_avg_merger.py +++ b/tests/test_zarr_avg_merger.py @@ -260,32 +260,33 @@ TENSOR_4x4, ] +ALL_TESTS = [ + TEST_CASE_0_DEFAULT_DTYPE, + TEST_CASE_1_DEFAULT_DTYPE, + TEST_CASE_2_DEFAULT_DTYPE, + TEST_CASE_3_DEFAULT_DTYPE, + TEST_CASE_4_DEFAULT_DTYPE, + TEST_CASE_5_VALUE_DTYPE, + TEST_CASE_6_COUNT_DTYPE, + TEST_CASE_7_COUNT_VALUE_DTYPE, + TEST_CASE_8_DTYPE, + TEST_CASE_9_LARGER_SHAPE, + TEST_CASE_10_DIRECTORY_STORE, + TEST_CASE_11_MEMORY_STORE, + TEST_CASE_12_CHUNKS, + TEST_CASE_16_WITH_LOCK, + TEST_CASE_17_WITHOUT_LOCK, +] + +# add compression tests only when using Zarr version before 3.0 +if not version_geq(get_package_version("zarr"), "3.0.0"): + ALL_TESTS += [TEST_CASE_13_COMPRESSOR_LZ4, TEST_CASE_14_COMPRESSOR_PICKLE, TEST_CASE_15_COMPRESSOR_LZMA] + @unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)") class ZarrAvgMergerTests(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0_DEFAULT_DTYPE, - TEST_CASE_1_DEFAULT_DTYPE, - TEST_CASE_2_DEFAULT_DTYPE, - TEST_CASE_3_DEFAULT_DTYPE, - TEST_CASE_4_DEFAULT_DTYPE, - TEST_CASE_5_VALUE_DTYPE, - TEST_CASE_6_COUNT_DTYPE, - TEST_CASE_7_COUNT_VALUE_DTYPE, - TEST_CASE_8_DTYPE, - TEST_CASE_9_LARGER_SHAPE, - TEST_CASE_10_DIRECTORY_STORE, - TEST_CASE_11_MEMORY_STORE, - TEST_CASE_12_CHUNKS, - TEST_CASE_13_COMPRESSOR_LZ4, - TEST_CASE_14_COMPRESSOR_PICKLE, - TEST_CASE_15_COMPRESSOR_LZMA, - TEST_CASE_16_WITH_LOCK, - TEST_CASE_17_WITHOUT_LOCK, - ] - ) + @parameterized.expand(ALL_TESTS) def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected): codec_reg = numcodecs.registry.codec_registry if "compressor" in arguments: From 091887b7aabc0490065e5d0053f22e2475cf0c42 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 13:44:45 +0100 Subject: [PATCH 34/67] Clarify input tensor shape in pixelshuffle and pixelunshuffle functions and simplify ValueError message in pixelunshuffle --- monai/networks/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 46d6fc0825..378a713b98 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -377,7 +377,7 @@ def pixelshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> torch See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". Args: - x: Input tensor + x: Input tensor with shape BCHW[D] spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D scale_factor: factor to rescale the spatial dimensions by, must be >=1 @@ -423,7 +423,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". Args: - x: Input tensor + x: Input tensor with shape BCHW[D] spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D scale_factor: factor to reduce the spatial dimensions by, must be >=1 @@ -443,7 +443,7 @@ def pixelunshuffle(x: torch.Tensor, spatial_dims: int, scale_factor: int) -> tor if any(d % factor != 0 for d in input_size[2:]): raise ValueError( - f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}" + f"All spatial dimensions must be divisible by factor {factor}. " f", spatial shape is: {input_size[2:]}" ) output_size = [batch_size, new_channels] + [d // factor for d in input_size[2:]] reshaped_size = [batch_size, channels] + sum([[d // factor, factor] for d in input_size[2:]], []) From 5d162d02236331a70dae1a74bfff49a6ddeab655 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 13:56:48 +0100 Subject: [PATCH 35/67] Refactor downsample mode checks to use enum values for clarity --- monai/networks/blocks/downsample.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 721738eedb..db9867e15f 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -139,7 +139,7 @@ def __init__( kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims) padding = tuple((k - 1) // 2 for k in kernel_size_) - if down_mode == "conv": + if down_mode == DownsampleMode.CONV: if not in_channels: raise ValueError("in_channels needs to be specified in conv mode") self.add_module( @@ -153,7 +153,7 @@ def __init__( bias=bias, ), ) - elif down_mode == "convgroup": + elif down_mode == DownsampleMode.CONVGROUP: if not in_channels: raise ValueError("in_channels needs to be specified") if out_channels is None: @@ -203,7 +203,7 @@ def __init__( if post_conv: self.add_module("postconv", post_conv) - elif down_mode == "pixelunshuffle": + elif down_mode == DownsampleMode.PIXELUNSHUFFLE: self.add_module( "pixelunshuffle", SubpixelDownsample( From f520e99c3a3533c7c848caf0208e70dfefd4dfc0 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 14:13:59 +0100 Subject: [PATCH 36/67] fix optiona import --- monai/networks/blocks/cablock.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py index b92e56c747..b38e89722a 100644 --- a/monai/networks/blocks/cablock.py +++ b/monai/networks/blocks/cablock.py @@ -13,7 +13,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange +from monai.utils import optional_import + +rearrange, _ = optional_import("einops", name="rearrange") from monai.networks.blocks.convolutions import Convolution From 39d1edfb2adece2f703cd173ab206d4edf253f15 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 14:41:52 +0100 Subject: [PATCH 37/67] Refactor layer normalization parameters for consistency and clarity in Restormer model and update assert in forward layer to support 3D images --- monai/networks/nets/restormer.py | 39 ++++++++++++++++---------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index 489ef1984b..7d7c133756 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -32,14 +32,13 @@ def __init__( num_heads: int, ffn_expansion_factor: float, bias: bool, - layer_norm_type: str = "BiasFree", + layer_norm_use_bias: bool = False, flash_attention: bool = False, ): super().__init__() - use_bias = layer_norm_type != "BiasFree" - self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias) + self.norm1 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias) self.attn = CABlock(spatial_dims, dim, num_heads, bias, flash_attention) - self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=use_bias) + self.norm2 = Norm[Norm.INSTANCE, spatial_dims](dim, affine=layer_norm_use_bias) self.ffn = FeedForward(spatial_dims, dim, ffn_expansion_factor, bias) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -53,11 +52,11 @@ class OverlapPatchEmbed(nn.Module): Unlike standard patch embeddings that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions.""" - def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False): + def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False): super().__init__() self.proj = Convolution( spatial_dims=spatial_dims, - in_channels=in_c, + in_channels=in_channels, out_channels=embed_dim, kernel_size=3, strides=1, @@ -89,7 +88,7 @@ class Restormer(nn.Module): def __init__( self, spatial_dims: int = 2, - inp_channels: int = 3, + in_channels: int = 3, out_channels: int = 3, dim: int = 48, num_blocks: tuple[int, ...] = (1, 1, 1, 1), @@ -97,7 +96,7 @@ def __init__( num_refinement_blocks: int = 4, ffn_expansion_factor: float = 2.66, bias: bool = False, - layer_norm_type: str = "WithBias", + layer_norm_use_bias: str = True, dual_pixel_task: bool = False, flash_attention: bool = False, ) -> None: @@ -105,7 +104,7 @@ def __init__( """Initialize Restormer model. Args: - inp_channels: Number of input image channels + in_channels: Number of input image channels out_channels: Number of output image channels dim: Base feature dimension num_blocks: Number of transformer blocks at each scale @@ -113,7 +112,7 @@ def __init__( heads: Number of attention heads at each scale ffn_expansion_factor: Expansion factor for feed-forward network bias: Whether to use bias in convolutions - layer_norm_type: Type of normalization ('WithBias' or 'BiasFree') + layer_norm_use_bias: Whether to use bias in layer normalization. Default is True. dual_pixel_task: Enable dual-pixel specific processing flash_attention: Use flash attention if available """ @@ -123,7 +122,7 @@ def __init__( assert all(n > 0 for n in num_blocks), "Number of blocks must be greater than 0" # Initial feature extraction - self.patch_embed = OverlapPatchEmbed(spatial_dims, inp_channels, dim) + self.patch_embed = OverlapPatchEmbed(spatial_dims, in_channels, dim) self.encoder_levels = nn.ModuleList() self.downsamples = nn.ModuleList() self.decoder_levels = nn.ModuleList() @@ -147,7 +146,7 @@ def __init__( num_heads=heads[n], ffn_expansion_factor=ffn_expansion_factor, bias=bias, - layer_norm_type=layer_norm_type, + layer_norm_use_bias=layer_norm_use_bias, flash_attention=flash_attention, ) for _ in range(num_blocks[n]) @@ -176,7 +175,7 @@ def __init__( num_heads=heads[num_steps], ffn_expansion_factor=ffn_expansion_factor, bias=bias, - layer_norm_type=layer_norm_type, + layer_norm_use_bias=layer_norm_use_bias, flash_attention=flash_attention, ) for _ in range(num_blocks[num_steps]) @@ -224,7 +223,7 @@ def __init__( num_heads=heads[n], ffn_expansion_factor=ffn_expansion_factor, bias=bias, - layer_norm_type=layer_norm_type, + layer_norm_use_bias=layer_norm_use_bias, flash_attention=flash_attention, ) for _ in range(num_blocks[n]) @@ -241,7 +240,7 @@ def __init__( num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, - layer_norm_type=layer_norm_type, + layer_norm_use_bias=layer_norm_use_bias, flash_attention=flash_attention, ) for _ in range(num_refinement_blocks) @@ -272,14 +271,14 @@ def forward(self, x) -> torch.Tensor: """Forward pass of Restormer. Processes input through encoder-decoder architecture with skip connections. Args: - inp_img: Input image tensor of shape (B, C, H, W) + inp_img: Input image tensor of shape (B, C, H, W, [D]) Returns: - Restored image tensor of shape (B, C, H, W) + Restored image tensor of shape (B, C, H, W, [D]) """ - assert ( - x.shape[-1] > 2 ** self.num_steps and x.shape[-2] > 2**self.num_steps - ), "Input dimensions should be larger than 2^number_of_step" + assert all( + x.shape[-i] > 2 ** self.num_steps for i in range(1, self.spatial_dims + 1) + ), "All spatial dimensions should be larger than 2^number_of_step" # Patch embedding x = self.patch_embed(x) From 5b3d4e1176e901cd46de9fbde2c672d0f69ba9bd Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 14:50:27 +0100 Subject: [PATCH 38/67] Enhance documentation for MDTATransformerBlock, OverlapPatchEmbed and Restormer class. --- monai/networks/nets/restormer.py | 45 ++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index 7d7c133756..2dab8d2e04 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -23,7 +23,17 @@ class MDTATransformerBlock(nn.Module): """Basic transformer unit combining MDTA and GDFN with skip connections. Unlike standard transformers that use LayerNorm, this block uses Instance Norm - for better adaptation to image restoration tasks.""" + for better adaptation to image restoration tasks. + + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + dim: Number of input channels + num_heads: Number of attention heads + ffn_expansion_factor: Expansion factor for feed-forward network + bias: Whether to use bias in attention layers + layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to False. + flash_attention: Whether to use flash attention optimization. Defaults to False. + """ def __init__( self, @@ -50,7 +60,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class OverlapPatchEmbed(nn.Module): """Initial feature extraction using overlapped convolutions. Unlike standard patch embeddings that use non-overlapping patches, - this approach maintains spatial continuity through 3x3 convolutions.""" + this approach maintains spatial continuity through 3x3 convolutions. + + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + in_channels: Number of input channels + embed_dim: Dimension of embedded features. Defaults to 48. + bias: Whether to use bias in convolution layer. Defaults to False. + """ def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False): super().__init__() @@ -104,17 +121,23 @@ def __init__( """Initialize Restormer model. Args: + spatial_dims: Number of spatial dimensions (2D or 3D) in_channels: Number of input image channels out_channels: Number of output image channels - dim: Base feature dimension - num_blocks: Number of transformer blocks at each scale - num_refinement_blocks: Number of final refinement blocks - heads: Number of attention heads at each scale - ffn_expansion_factor: Expansion factor for feed-forward network - bias: Whether to use bias in convolutions - layer_norm_use_bias: Whether to use bias in layer normalization. Default is True. - dual_pixel_task: Enable dual-pixel specific processing - flash_attention: Use flash attention if available + dim: Base feature dimension. Defaults to 48. + num_blocks: Number of transformer blocks at each scale. Defaults to (1,1,1,1). + heads: Number of attention heads at each scale. Defaults to (1,1,1,1). + num_refinement_blocks: Number of final refinement blocks. Defaults to 4. + ffn_expansion_factor: Expansion factor for feed-forward network. Defaults to 2.66. + bias: Whether to use bias in convolutions. Defaults to False. + layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to True. + dual_pixel_task: Enable dual-pixel specific processing. Defaults to False. + flash_attention: Use flash attention if available. Defaults to False. + + Note: + The number of blocks must be greater than 1 + The length of num_blocks and heads must be equal + All values in num_blocks must be greater than 0 """ # Check input parameters assert len(num_blocks) > 1, "Number of blocks must be greater than 1" From 1683b14dcd8e3a2fa47dafaff41de1b2c6b4d0ee Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 14:57:44 +0100 Subject: [PATCH 39/67] run ./runtests.sh --autofix to check formatting --- monai/networks/blocks/cablock.py | 4 ++-- monai/networks/nets/restormer.py | 4 ++-- monai/utils/jupyter_utils.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py index b38e89722a..9d50724fa9 100644 --- a/monai/networks/blocks/cablock.py +++ b/monai/networks/blocks/cablock.py @@ -13,12 +13,12 @@ import torch import torch.nn as nn import torch.nn.functional as F + +from monai.networks.blocks.convolutions import Convolution from monai.utils import optional_import rearrange, _ = optional_import("einops", name="rearrange") -from monai.networks.blocks.convolutions import Convolution - __all__ = ["FeedForward", "CABlock"] diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index 2dab8d2e04..fbe677f69f 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -61,7 +61,7 @@ class OverlapPatchEmbed(nn.Module): """Initial feature extraction using overlapped convolutions. Unlike standard patch embeddings that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions. - + Args: spatial_dims: Number of spatial dimensions (2D or 3D) in_channels: Number of input channels @@ -300,7 +300,7 @@ def forward(self, x) -> torch.Tensor: Restored image tensor of shape (B, C, H, W, [D]) """ assert all( - x.shape[-i] > 2 ** self.num_steps for i in range(1, self.spatial_dims + 1) + x.shape[-i] > 2**self.num_steps for i in range(1, self.spatial_dims + 1) ), "All spatial dimensions should be larger than 2^number_of_step" # Patch embedding diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index c93e93dcb9..b1b43a6767 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -234,7 +234,7 @@ def plot_engine_status( def _get_loss_from_output( - output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor, + output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor ) -> torch.Tensor: """Returns a single value from the network output, which is a dict or tensor.""" From 232be1cf8e26c0c5826e38113281d952bfc23003 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 15:03:36 +0100 Subject: [PATCH 40/67] Refactor OverlapPatchEmbed to inherit from Convolution and streamline forward method --- monai/networks/nets/restormer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index fbe677f69f..e48d7edd68 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -57,7 +57,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class OverlapPatchEmbed(nn.Module): +class OverlapPatchEmbed(Convolution): """Initial feature extraction using overlapped convolutions. Unlike standard patch embeddings that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions. @@ -70,8 +70,7 @@ class OverlapPatchEmbed(nn.Module): """ def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, bias: bool = False): - super().__init__() - self.proj = Convolution( + super().__init__( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=embed_dim, @@ -82,8 +81,8 @@ def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, conv_only=True, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.proj(x) +def forward(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x) class Restormer(nn.Module): From d1df8e631389a3a66f79af4bf76da78c499e7a9d Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 15:05:52 +0100 Subject: [PATCH 41/67] Enhance documentation for FeedForward and CABlock classes, adding argument descriptions and error handling details. --- monai/networks/blocks/cablock.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py index 9d50724fa9..63f533e626 100644 --- a/monai/networks/blocks/cablock.py +++ b/monai/networks/blocks/cablock.py @@ -24,7 +24,14 @@ class FeedForward(nn.Module): """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism. - Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.""" + Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection. + + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + dim: Number of input channels + ffn_expansion_factor: Factor to expand hidden features dimension + bias: Whether to use bias in convolution layers + """ def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bias: bool): super().__init__() @@ -70,7 +77,19 @@ class CABlock(nn.Module): """Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention by operating on feature channels instead of spatial dimensions. Incorporates depth-wise convolutions for local mixing before attention, achieving linear complexity vs quadratic - in vanilla attention. Based on SW Zamir, et al., 2022 """ + in vanilla attention. Based on SW Zamir, et al., 2022 + + Args: + spatial_dims: Number of spatial dimensions (2D or 3D) + dim: Number of input channels + num_heads: Number of attention heads + bias: Whether to use bias in convolution layers + flash_attention: Whether to use flash attention optimization. Defaults to False. + + Raises: + ValueError: If flash attention is not available in current PyTorch version + ValueError: If spatial_dims is greater than 3 + """ def __init__(self, spatial_dims, dim: int, num_heads: int, bias: bool, flash_attention: bool = False): super().__init__() From 78ce56b11bfca78db73566317c5eab145a5c9475 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 15:07:36 +0100 Subject: [PATCH 42/67] code formatting --- monai/networks/blocks/cablock.py | 4 ++-- monai/networks/nets/restormer.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py index 63f533e626..72e4cc68d0 100644 --- a/monai/networks/blocks/cablock.py +++ b/monai/networks/blocks/cablock.py @@ -25,7 +25,7 @@ class FeedForward(nn.Module): """Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism. Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection. - + Args: spatial_dims: Number of spatial dimensions (2D or 3D) dim: Number of input channels @@ -78,7 +78,7 @@ class CABlock(nn.Module): by operating on feature channels instead of spatial dimensions. Incorporates depth-wise convolutions for local mixing before attention, achieving linear complexity vs quadratic in vanilla attention. Based on SW Zamir, et al., 2022 - + Args: spatial_dims: Number of spatial dimensions (2D or 3D) dim: Number of input channels diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index e48d7edd68..72ef0f4edc 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -81,6 +81,7 @@ def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, conv_only=True, ) + def forward(self, x: torch.Tensor) -> torch.Tensor: return super().forward(x) From 64b203debdb15ca6f2ec03531ac254e3f5cc1601 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 15:37:09 +0100 Subject: [PATCH 43/67] Update args naming in unit restormer test for consistency with suggested changes Signed-off-by: tisalon --- tests/test_restormer.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/test_restormer.py b/tests/test_restormer.py index bf1fa83eac..f34b8c053c 100644 --- a/tests/test_restormer.py +++ b/tests/test_restormer.py @@ -10,7 +10,10 @@ # limitations under the License. from __future__ import annotations +import os +import sys +sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")) import unittest import torch @@ -20,15 +23,15 @@ from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer TEST_CASES_TRANSFORMER = [ - # [spatial_dims, dim, num_heads, ffn_factor, bias, norm_type, flash_attn, input_shape] - [2, 48, 8, 2.66, True, "WithBias", False, (2, 48, 64, 64)], - [2, 96, 8, 2.66, False, "BiasFree", False, (2, 96, 32, 32)], - [3, 48, 4, 2.66, True, "WithBias", False, (2, 48, 32, 32, 32)], - [3, 96, 8, 2.66, False, "BiasFree", True, (2, 96, 16, 16, 16)], + # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape] + [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)], + [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)], + [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)], + [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)], ] TEST_CASES_PATCHEMBED = [ - # spatial_dims, in_c, embed_dim, input_shape, expected_shape + # spatial_dims, in_channels, embed_dim, input_shape, expected_shape [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)], [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)], [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)], @@ -52,7 +55,7 @@ [ { "spatial_dims": 2, - "inp_channels": 1, + "in_channels": 1, "out_channels": 1, "dim": 48, "num_blocks": config["num_blocks"], @@ -67,9 +70,9 @@ [ { "spatial_dims": 3, - "inp_channels": 1, + "in_channels": 1, "out_channels": 1, - "dim": 48, + "dim": 16, "num_blocks": config["num_blocks"], "heads": config["heads"], "num_refinement_blocks": 2, @@ -85,14 +88,14 @@ class TestMDTATransformerBlock(unittest.TestCase): @parameterized.expand(TEST_CASES_TRANSFORMER) - def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flash, shape): + def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): block = MDTATransformerBlock( spatial_dims=spatial_dims, dim=dim, num_heads=heads, ffn_expansion_factor=ffn_factor, bias=bias, - LayerNorm_type=norm_type, + layer_norm_use_bias=layer_norm_use_bias, flash_attention=flash, ) with eval_mode(block): @@ -104,8 +107,8 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, norm_type, flas class TestOverlapPatchEmbed(unittest.TestCase): @parameterized.expand(TEST_CASES_PATCHEMBED) - def test_shape(self, spatial_dims, in_c, embed_dim, input_shape, expected_shape): - net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_c=in_c, embed_dim=embed_dim) + def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape): + net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) @@ -121,12 +124,12 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) def test_small_input_error_2d(self): - net = Restormer(spatial_dims=2, inp_channels=1, out_channels=1) + net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) with self.assertRaises(AssertionError): net(torch.randn(1, 1, 8, 8)) def test_small_input_error_3d(self): - net = Restormer(spatial_dims=3, inp_channels=1, out_channels=1) + net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) with self.assertRaises(AssertionError): net(torch.randn(1, 1, 8, 8, 8)) From ce158865f10f7cb7de7c1b244a2021ee9025dcf8 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 15:44:11 +0100 Subject: [PATCH 44/67] Fix optional import --- tests/test_CABlock.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_CABlock.py b/tests/test_CABlock.py index fb981ad829..1ac7a2927f 100644 --- a/tests/test_CABlock.py +++ b/tests/test_CABlock.py @@ -21,7 +21,7 @@ from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose -einops, has_einops = optional_import("einops") +rearrange, _ = optional_import("einops", name="rearrange") TEST_CASES_CAB = [] From 30fad17d5868b3ace5dc321102ae20450ad6f827 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 15:52:55 +0100 Subject: [PATCH 45/67] require einops for all tests --- tests/test_CABlock.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_CABlock.py b/tests/test_CABlock.py index 1ac7a2927f..4b36f2e72c 100644 --- a/tests/test_CABlock.py +++ b/tests/test_CABlock.py @@ -12,6 +12,7 @@ from __future__ import annotations import unittest +from unittest import skipUnless import torch from parameterized import parameterized @@ -21,7 +22,7 @@ from monai.utils import optional_import from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose -rearrange, _ = optional_import("einops", name="rearrange") +einops, has_einops = optional_import("einops") TEST_CASES_CAB = [] @@ -70,17 +71,20 @@ def test_gating_mechanism(self): class TestCABlock(unittest.TestCase): @parameterized.expand(TEST_CASES_CAB) + @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): net = CABlock(**input_param) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + @skipUnless(has_einops, "Requires einops") def test_invalid_spatial_dims(self): with self.assertRaises(ValueError): CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True) @SkipIfBeforePyTorchVersion((2, 0)) + @skipUnless(has_einops, "Requires einops") def test_flash_attention(self): device = "cuda" if torch.cuda.is_available() else "cpu" block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) @@ -88,17 +92,20 @@ def test_flash_attention(self): output = block(x) self.assertEqual(output.shape, x.shape) + @skipUnless(has_einops, "Requires einops") def test_temperature_parameter(self): block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) self.assertTrue(isinstance(block.temperature, torch.nn.Parameter)) self.assertEqual(block.temperature.shape, (4, 1, 1)) + @skipUnless(has_einops, "Requires einops") def test_qkv_transformation_2d(self): block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) x = torch.randn(2, 64, 32, 32) qkv = block.qkv(x) self.assertEqual(qkv.shape, (2, 192, 32, 32)) + @skipUnless(has_einops, "Requires einops") def test_qkv_transformation_3d(self): block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True) x = torch.randn(2, 64, 16, 16, 16) @@ -106,6 +113,7 @@ def test_qkv_transformation_3d(self): self.assertEqual(qkv.shape, (2, 192, 16, 16, 16)) @SkipIfBeforePyTorchVersion((2, 0)) + @skipUnless(has_einops, "Requires einops") def test_flash_vs_normal_attention(self): device = "cuda" if torch.cuda.is_available() else "cpu" block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) @@ -120,6 +128,7 @@ def test_flash_vs_normal_attention(self): assert_allclose(out_flash, out_normal, atol=1e-4) + @skipUnless(has_einops, "Requires einops") def test_deterministic_small_input(self): block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False) with torch.no_grad(): From 1079d8c69c31776126b500d9477743891f1df87d Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 16:04:21 +0100 Subject: [PATCH 46/67] require einops also for test_restormer Signed-off-by: tisalon --- tests/test_restormer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_restormer.py b/tests/test_restormer.py index f34b8c053c..3724564b0e 100644 --- a/tests/test_restormer.py +++ b/tests/test_restormer.py @@ -15,12 +15,17 @@ sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")) import unittest +from unittest import skipUnless + import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") TEST_CASES_TRANSFORMER = [ # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape] @@ -86,7 +91,8 @@ class TestMDTATransformerBlock(unittest.TestCase): - + + @skipUnless(has_einops, "Requires einops") @parameterized.expand(TEST_CASES_TRANSFORMER) def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): block = MDTATransformerBlock( @@ -116,6 +122,7 @@ def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected class TestRestormer(unittest.TestCase): + @skipUnless(has_einops, "Requires einops") @parameterized.expand(TEST_CASES_RESTORMER) def test_shape(self, input_param, input_shape, expected_shape): net = Restormer(**input_param) @@ -123,11 +130,13 @@ def test_shape(self, input_param, input_shape, expected_shape): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) + @skipUnless(has_einops, "Requires einops") def test_small_input_error_2d(self): net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) with self.assertRaises(AssertionError): net(torch.randn(1, 1, 8, 8)) + @skipUnless(has_einops, "Requires einops") def test_small_input_error_3d(self): net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) with self.assertRaises(AssertionError): From b2b3ddf6a953fb8aa81a0d806586b37e81e282ee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Feb 2025 15:05:00 +0000 Subject: [PATCH 47/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_restormer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_restormer.py b/tests/test_restormer.py index 3724564b0e..d9029a1a51 100644 --- a/tests/test_restormer.py +++ b/tests/test_restormer.py @@ -91,7 +91,7 @@ class TestMDTATransformerBlock(unittest.TestCase): - + @skipUnless(has_einops, "Requires einops") @parameterized.expand(TEST_CASES_TRANSFORMER) def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): From 174e968f192426a5c02c1f8949518b0921410ab2 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 16:05:17 +0100 Subject: [PATCH 48/67] remove relative impots Signed-off-by: tisalon --- tests/test_restormer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_restormer.py b/tests/test_restormer.py index d9029a1a51..e10642f247 100644 --- a/tests/test_restormer.py +++ b/tests/test_restormer.py @@ -10,10 +10,7 @@ # limitations under the License. from __future__ import annotations -import os -import sys -sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")) import unittest from unittest import skipUnless From e15a8156c960b8d06b510bb8494e8a777ef84773 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 16:22:13 +0100 Subject: [PATCH 49/67] fix capitalisation in DownSample documentation networks.rts Signed-off-by: tisalon --- docs/source/networks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 05825c3c18..d76a9f72e3 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -187,7 +187,7 @@ Blocks ~~~~~~~~~~~~~~ .. autoclass:: DownSample :members: -.. autoclass:: DownSample +.. autoclass:: Downsample .. autoclass:: SubpixelDownSample :members: .. autoclass:: Subpixeldownsample From d53d97d8d679d9454950e01fb02cd9f492a95a9f Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 16:26:05 +0100 Subject: [PATCH 50/67] fix capitalisation in SubpixelDownsample documentation Signed-off-by: tisalon --- docs/source/networks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/networks.rst b/docs/source/networks.rst index d76a9f72e3..3c8ea725a9 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -188,7 +188,7 @@ Blocks .. autoclass:: DownSample :members: .. autoclass:: Downsample -.. autoclass:: SubpixelDownSample +.. autoclass:: SubpixelDownsample :members: .. autoclass:: Subpixeldownsample .. autoclass:: SubpixelDownSample From cae7d96d688c1129b720698de72417b417edd637 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 16:29:36 +0100 Subject: [PATCH 51/67] formatting Signed-off-by: tisalon --- tests/test_restormer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_restormer.py b/tests/test_restormer.py index e10642f247..5b0ad86f6e 100644 --- a/tests/test_restormer.py +++ b/tests/test_restormer.py @@ -14,7 +14,6 @@ import unittest from unittest import skipUnless - import torch from parameterized import parameterized From a0afee5d95c8d41ac6b654696334cf3c79d53842 Mon Sep 17 00:00:00 2001 From: tisalon Date: Fri, 7 Feb 2025 16:53:10 +0100 Subject: [PATCH 52/67] update docstring to mention 2D and 3D cases Signed-off-by: tisalon --- monai/networks/blocks/downsample.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index db9867e15f..5ddaadf653 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -226,7 +226,8 @@ class SubpixelDownsample(nn.Module): spatial dimensions while increasing channel depth. The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions - from (B, C, H*r, W*r) to (B, C*r², H, W). + from (B, C, H*r, W*r) to (B, C*r², H, W) for 2D images or from (B, C, H*r, W*r, D*r) to (B, C*r³, H, W, D) in 3D case. + Example: (1, 1, 4, 4) with r=2 becomes (1, 4, 2, 2). See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution @@ -246,15 +247,18 @@ def __init__( bias: bool = True, ) -> None: """ + Downsamples data by rearranging spatial information into channel space. + This reduces spatial dimensions while increasing channel depth. + Args: spatial_dims: number of spatial dimensions of the input image. in_channels: number of channels of the input image. out_channels: optional number of channels of the output image. scale_factor: factor to reduce the spatial dimensions by. Defaults to 2. conv_block: a conv block to adjust channels before downsampling. Defaults to None. - - When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized. - - When ``conv_block`` is an ``nn.module``, - please ensure the input number of channels matches requirements. + When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized. + When ``conv_block`` is an ``nn.module``, + please ensure the input number of channels matches requirements. bias: whether to have a bias term in the default conv_block. Defaults to True. """ super().__init__() From 529e90bcdb57534c76f4cea77f0cd2007a89c87f Mon Sep 17 00:00:00 2001 From: tisalon Date: Sun, 9 Feb 2025 11:52:46 +0100 Subject: [PATCH 53/67] Update type annotations and doctring --- monai/networks/blocks/downsample.py | 2 +- monai/networks/nets/restormer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index 5ddaadf653..ae962287a9 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -249,7 +249,7 @@ def __init__( """ Downsamples data by rearranging spatial information into channel space. This reduces spatial dimensions while increasing channel depth. - + Args: spatial_dims: number of spatial dimensions of the input image. in_channels: number of channels of the input image. diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index 72ef0f4edc..b59150ad4d 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -113,7 +113,7 @@ def __init__( num_refinement_blocks: int = 4, ffn_expansion_factor: float = 2.66, bias: bool = False, - layer_norm_use_bias: str = True, + layer_norm_use_bias: bool = True, dual_pixel_task: bool = False, flash_attention: bool = False, ) -> None: From c109029ad4f73bcada08cc2cb50c7d45cf1d9566 Mon Sep 17 00:00:00 2001 From: tisalon Date: Sun, 9 Feb 2025 11:53:05 +0100 Subject: [PATCH 54/67] remove problematic unit test --- tests/test_pixelunshuffle.py | 6 ------ tests/test_restormer.py | 4 ++++ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_pixelunshuffle.py b/tests/test_pixelunshuffle.py index 106dbe4d03..922c25e37b 100644 --- a/tests/test_pixelunshuffle.py +++ b/tests/test_pixelunshuffle.py @@ -46,11 +46,5 @@ def test_inverse_operation(self): unshuffled = pixelunshuffle(shuffled, spatial_dims=3, scale_factor=2) torch.testing.assert_close(x, unshuffled) - def test_invalid_scale(self): - x = torch.randn(2, 4, 15, 15) - with self.assertRaises(RuntimeError): - pixelunshuffle(x, spatial_dims=2, scale_factor=2) - - if __name__ == "__main__": unittest.main() diff --git a/tests/test_restormer.py b/tests/test_restormer.py index 5b0ad86f6e..90adfa09e1 100644 --- a/tests/test_restormer.py +++ b/tests/test_restormer.py @@ -91,6 +91,8 @@ class TestMDTATransformerBlock(unittest.TestCase): @skipUnless(has_einops, "Requires einops") @parameterized.expand(TEST_CASES_TRANSFORMER) def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): + if flash and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") block = MDTATransformerBlock( spatial_dims=spatial_dims, dim=dim, @@ -121,6 +123,8 @@ class TestRestormer(unittest.TestCase): @skipUnless(has_einops, "Requires einops") @parameterized.expand(TEST_CASES_RESTORMER) def test_shape(self, input_param, input_shape, expected_shape): + if input_param.get('flash_attention', False) and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") net = Restormer(**input_param) with eval_mode(net): result = net(torch.randn(input_shape)) From 19c30f7f9b89611393860d9180a8a394ce2866ff Mon Sep 17 00:00:00 2001 From: tisalon Date: Sun, 9 Feb 2025 11:56:36 +0100 Subject: [PATCH 55/67] remove problematic unit test --- tests/test_pixelunshuffle.py | 1 + tests/test_restormer.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_pixelunshuffle.py b/tests/test_pixelunshuffle.py index 922c25e37b..49b61440e5 100644 --- a/tests/test_pixelunshuffle.py +++ b/tests/test_pixelunshuffle.py @@ -46,5 +46,6 @@ def test_inverse_operation(self): unshuffled = pixelunshuffle(shuffled, spatial_dims=3, scale_factor=2) torch.testing.assert_close(x, unshuffled) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_restormer.py b/tests/test_restormer.py index 90adfa09e1..ab08d84390 100644 --- a/tests/test_restormer.py +++ b/tests/test_restormer.py @@ -123,7 +123,7 @@ class TestRestormer(unittest.TestCase): @skipUnless(has_einops, "Requires einops") @parameterized.expand(TEST_CASES_RESTORMER) def test_shape(self, input_param, input_shape, expected_shape): - if input_param.get('flash_attention', False) and not torch.cuda.is_available(): + if input_param.get("flash_attention", False) and not torch.cuda.is_available(): self.skipTest("Flash attention requires CUDA") net = Restormer(**input_param) with eval_mode(net): From 55da640d69c6a1db4e3a9e22e024adb1f6e3512f Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Sat, 1 Mar 2025 11:00:26 +0100 Subject: [PATCH 56/67] relocate test in the correct place --- tests/networks/blocks/test_CABlock.py | 150 ++++++++++++++++++ .../networks/blocks/test_downsample_block.py | 50 ++++++ tests/networks/nets/test_restormer.py | 147 +++++++++++++++++ tests/networks/utils/test_pixelunshuffle.py | 51 ++++++ 4 files changed, 398 insertions(+) create mode 100644 tests/networks/blocks/test_CABlock.py create mode 100644 tests/networks/blocks/test_downsample_block.py create mode 100644 tests/networks/nets/test_restormer.py create mode 100644 tests/networks/utils/test_pixelunshuffle.py diff --git a/tests/networks/blocks/test_CABlock.py b/tests/networks/blocks/test_CABlock.py new file mode 100644 index 0000000000..42531131c5 --- /dev/null +++ b/tests/networks/blocks/test_CABlock.py @@ -0,0 +1,150 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks.cablock import CABlock, FeedForward +from monai.utils import optional_import +from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose + +einops, has_einops = optional_import("einops") + + +TEST_CASES_CAB = [] +for spatial_dims in [2, 3]: + for dim in [32, 64, 128]: + for num_heads in [2, 4, 8]: + for bias in [True, False]: + test_case = [ + { + "spatial_dims": spatial_dims, + "dim": dim, + "num_heads": num_heads, + "bias": bias, + "flash_attention": False, + }, + (2, dim, *([16] * spatial_dims)), + (2, dim, *([16] * spatial_dims)), + ] + TEST_CASES_CAB.append(test_case) + + +TEST_CASES_FEEDFORWARD = [ + # Test different spatial dims, dimensions and expansion factors + [{"spatial_dims": 2, "dim": 64, "ffn_expansion_factor": 2.0, "bias": True}, (2, 64, 32, 32)], + [{"spatial_dims": 3, "dim": 128, "ffn_expansion_factor": 1.5, "bias": False}, (2, 128, 16, 16, 16)], + [{"spatial_dims": 2, "dim": 256, "ffn_expansion_factor": 1.0, "bias": True}, (1, 256, 64, 64)], +] + + +class TestFeedForward(unittest.TestCase): + + @parameterized.expand(TEST_CASES_FEEDFORWARD) + def test_shape(self, input_param, input_shape): + net = FeedForward(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, input_shape) + + def test_gating_mechanism(self): + net = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True) + x = torch.ones(1, 32, 16, 16) + out = net(x) + self.assertNotEqual(torch.sum(out), torch.sum(x)) + + +class TestCABlock(unittest.TestCase): + + @parameterized.expand(TEST_CASES_CAB) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + net = CABlock(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + @skipUnless(has_einops, "Requires einops") + def test_invalid_spatial_dims(self): + with self.assertRaises(ValueError): + CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True) + + @SkipIfBeforePyTorchVersion((2, 0)) + @skipUnless(has_einops, "Requires einops") + def test_flash_attention(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) + x = torch.randn(2, 64, 32, 32).to(device) + output = block(x) + self.assertEqual(output.shape, x.shape) + + @skipUnless(has_einops, "Requires einops") + def test_temperature_parameter(self): + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) + self.assertTrue(isinstance(block.temperature, torch.nn.Parameter)) + self.assertEqual(block.temperature.shape, (4, 1, 1)) + + @skipUnless(has_einops, "Requires einops") + def test_qkv_transformation_2d(self): + block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) + x = torch.randn(2, 64, 32, 32) + qkv = block.qkv(x) + self.assertEqual(qkv.shape, (2, 192, 32, 32)) + + @skipUnless(has_einops, "Requires einops") + def test_qkv_transformation_3d(self): + block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True) + x = torch.randn(2, 64, 16, 16, 16) + qkv = block.qkv(x) + self.assertEqual(qkv.shape, (2, 192, 16, 16, 16)) + + @SkipIfBeforePyTorchVersion((2, 0)) + @skipUnless(has_einops, "Requires einops") + def test_flash_vs_normal_attention(self): + device = "cuda" if torch.cuda.is_available() else "cpu" + block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) + block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device) + + block_normal.load_state_dict(block_flash.state_dict()) + + x = torch.randn(2, 64, 32, 32).to(device) + with torch.no_grad(): + out_flash = block_flash(x) + out_normal = block_normal(x) + + assert_allclose(out_flash, out_normal, atol=1e-4) + + @skipUnless(has_einops, "Requires einops") + def test_deterministic_small_input(self): + block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False) + with torch.no_grad(): + block.qkv.conv.weight.data.fill_(1.0) + block.qkv_dwconv.conv.weight.data.fill_(1.0) + block.temperature.data.fill_(1.0) + block.project_out.conv.weight.data.fill_(1.0) + + x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float32) + + output = block(x) + # Channel attention: sum([1..8]) * (qkv_conv=1) * (dwconv=1) * (attn_weights=1) * (proj=1) = 36 * 2 = 72 + expected = torch.full_like(x, 72.0) + + assert_allclose(output, expected, atol=1e-6) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/networks/blocks/test_downsample_block.py b/tests/networks/blocks/test_downsample_block.py new file mode 100644 index 0000000000..34afa248ad --- /dev/null +++ b/tests/networks/blocks/test_downsample_block.py @@ -0,0 +1,50 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.blocks import MaxAvgPool + +TEST_CASES = [ + [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 + [{"spatial_dims": 1, "kernel_size": 4}, (16, 4, 63), (16, 8, 15)], # 4-channel 1D, batch 16 + [{"spatial_dims": 1, "kernel_size": 4, "padding": 1}, (16, 4, 63), (16, 8, 16)], # 4-channel 1D, batch 16 + [ # 4-channel 3D, batch 16 + {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": True}, + (16, 4, 32, 24, 48), + (16, 8, 11, 8, 16), + ], + [ # 1-channel 3D, batch 16 + {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": False}, + (16, 1, 32, 24, 48), + (16, 2, 10, 8, 16), + ], +] + + +class TestMaxAvgPool(unittest.TestCase): + + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_shape, expected_shape): + net = MaxAvgPool(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py new file mode 100644 index 0000000000..ab08d84390 --- /dev/null +++ b/tests/networks/nets/test_restormer.py @@ -0,0 +1,147 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") + +TEST_CASES_TRANSFORMER = [ + # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape] + [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)], + [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)], + [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)], + [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)], +] + +TEST_CASES_PATCHEMBED = [ + # spatial_dims, in_channels, embed_dim, input_shape, expected_shape + [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)], + [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)], + [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)], + [3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)], +] + +RESTORMER_CONFIGS = [ + # 2-level architecture + {"num_blocks": [1, 1], "heads": [1, 1]}, + {"num_blocks": [2, 1], "heads": [2, 1]}, + # 3-level architecture + {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]}, + {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]}, +] + +TEST_CASES_RESTORMER = [] +for config in RESTORMER_CONFIGS: + # 2D cases + TEST_CASES_RESTORMER.extend( + [ + [ + { + "spatial_dims": 2, + "in_channels": 1, + "out_channels": 1, + "dim": 48, + "num_blocks": config["num_blocks"], + "heads": config["heads"], + "num_refinement_blocks": 2, + "ffn_expansion_factor": 1.5, + }, + (2, 1, 64, 64), + (2, 1, 64, 64), + ], + # 3D cases + [ + { + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 1, + "dim": 16, + "num_blocks": config["num_blocks"], + "heads": config["heads"], + "num_refinement_blocks": 2, + "ffn_expansion_factor": 1.5, + }, + (2, 1, 32, 32, 32), + (2, 1, 32, 32, 32), + ], + ] + ) + + +class TestMDTATransformerBlock(unittest.TestCase): + + @skipUnless(has_einops, "Requires einops") + @parameterized.expand(TEST_CASES_TRANSFORMER) + def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): + if flash and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") + block = MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=dim, + num_heads=heads, + ffn_expansion_factor=ffn_factor, + bias=bias, + layer_norm_use_bias=layer_norm_use_bias, + flash_attention=flash, + ) + with eval_mode(block): + x = torch.randn(shape) + output = block(x) + self.assertEqual(output.shape, x.shape) + + +class TestOverlapPatchEmbed(unittest.TestCase): + + @parameterized.expand(TEST_CASES_PATCHEMBED) + def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape): + net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + +class TestRestormer(unittest.TestCase): + + @skipUnless(has_einops, "Requires einops") + @parameterized.expand(TEST_CASES_RESTORMER) + def test_shape(self, input_param, input_shape, expected_shape): + if input_param.get("flash_attention", False) and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") + net = Restormer(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + @skipUnless(has_einops, "Requires einops") + def test_small_input_error_2d(self): + net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8)) + + @skipUnless(has_einops, "Requires einops") + def test_small_input_error_3d(self): + net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8, 8)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/networks/utils/test_pixelunshuffle.py b/tests/networks/utils/test_pixelunshuffle.py new file mode 100644 index 0000000000..49b61440e5 --- /dev/null +++ b/tests/networks/utils/test_pixelunshuffle.py @@ -0,0 +1,51 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.networks.utils import pixelshuffle, pixelunshuffle + + +class TestPixelUnshuffle(unittest.TestCase): + + def test_2d_basic(self): + x = torch.randn(2, 4, 16, 16) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + self.assertEqual(out.shape, (2, 16, 8, 8)) + + def test_3d_basic(self): + x = torch.randn(2, 4, 16, 16, 16) + out = pixelunshuffle(x, spatial_dims=3, scale_factor=2) + self.assertEqual(out.shape, (2, 32, 8, 8, 8)) + + def test_non_square_input(self): + x = torch.arange(192).reshape(1, 2, 12, 8) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) + torch.testing.assert_close(out, torch.pixel_unshuffle(x, 2)) + + def test_different_scale_factor(self): + x = torch.arange(360).reshape(1, 2, 12, 15) + out = pixelunshuffle(x, spatial_dims=2, scale_factor=3) + torch.testing.assert_close(out, torch.pixel_unshuffle(x, 3)) + + def test_inverse_operation(self): + x = torch.arange(4096).reshape(1, 8, 8, 8, 8) + shuffled = pixelshuffle(x, spatial_dims=3, scale_factor=2) + unshuffled = pixelunshuffle(shuffled, spatial_dims=3, scale_factor=2) + torch.testing.assert_close(x, unshuffled) + + +if __name__ == "__main__": + unittest.main() From 3c2dbc61f26c16d411558f9adb0fd92c5a65569b Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Sat, 1 Mar 2025 17:37:32 +0100 Subject: [PATCH 57/67] Add DownSampleBlock missing tests, Signed-off-by: Santiago Cano-Muniz , I, Cano-Muniz, Santiago , hereby add my Signed-off-by to this commit: 55da640d69c6a1db4e3a9e22e024adb1f6e3512f --- .../networks/blocks/test_downsample_block.py | 135 +++++++++++++++++- 1 file changed, 134 insertions(+), 1 deletion(-) diff --git a/tests/networks/blocks/test_downsample_block.py b/tests/networks/blocks/test_downsample_block.py index 34afa248ad..5e660510d4 100644 --- a/tests/networks/blocks/test_downsample_block.py +++ b/tests/networks/blocks/test_downsample_block.py @@ -17,7 +17,10 @@ from parameterized import parameterized from monai.networks import eval_mode -from monai.networks.blocks import MaxAvgPool +from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample +from monai.utils import optional_import + +einops, has_einops = optional_import("einops") TEST_CASES = [ [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 @@ -35,6 +38,20 @@ ], ] +TEST_CASES_SUBPIXEL = [ + [{"spatial_dims": 2, "in_channels": 1, "scale_factor": 2}, (1, 1, 8, 8), (1, 4, 4, 4)], + [{"spatial_dims": 3, "in_channels": 2, "scale_factor": 2}, (1, 2, 8, 8, 8), (1, 16, 4, 4, 4)], + [{"spatial_dims": 1, "in_channels": 3, "scale_factor": 2}, (1, 3, 8), (1, 6, 4)], +] + +TEST_CASES_DOWNSAMPLE = [ + [{"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, (1, 4, 16, 16), (1, 4, 8, 8)], + [{"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, (1, 4, 16, 16), (1, 8, 8, 8)], + [{"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, (1, 2, 16, 16, 16), (1, 2, 8, 8, 8)], + [{"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, (1, 4, 16, 16), (1, 4, 8, 8)], + [{"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, (1, 1, 16, 16), (1, 4, 8, 8)], +] + class TestMaxAvgPool(unittest.TestCase): @@ -46,5 +63,121 @@ def test_shape(self, input_param, input_shape, expected_shape): self.assertEqual(result.shape, expected_shape) +class TestSubpixelDownsample(unittest.TestCase): + + @parameterized.expand(TEST_CASES_SUBPIXEL) + def test_shape(self, input_param, input_shape, expected_shape): + downsampler = SubpixelDownsample(**input_param) + with eval_mode(downsampler): + result = downsampler(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_predefined_tensor(self): + test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4) + test_tensor = test_tensor.unsqueeze(0) + + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) + with eval_mode(downsampler): + result = downsampler(test_tensor) + self.assertEqual(result.shape, (1, 16, 2, 2)) + self.assertTrue(torch.all(result[0, 0:3] == 0)) + self.assertTrue(torch.all(result[0, 4:7] == 1)) + self.assertTrue(torch.all(result[0, 8:11] == 2)) + self.assertTrue(torch.all(result[0, 12:15] == 3)) + + def test_reconstruction_2d(self): + input_tensor = torch.randn(1, 1, 4, 4) + down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) + up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) + with eval_mode(down), eval_mode(up): + downsampled = down(input_tensor) + reconstructed = up(downsampled) + self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) + + def test_reconstruction_3d(self): + input_tensor = torch.randn(1, 1, 4, 4, 4) + down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None) + up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) + with eval_mode(down), eval_mode(up): + downsampled = down(input_tensor) + reconstructed = up(downsampled) + self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) + + def test_invalid_spatial_size(self): + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2) + with self.assertRaises(ValueError): + downsampler(torch.randn(1, 1, 3, 4)) + + def test_custom_conv_block(self): + custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1) + downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv) + with eval_mode(downsampler): + result = downsampler(torch.randn(1, 1, 4, 4)) + self.assertEqual(result.shape, (1, 8, 2, 2)) + + +class TestDownSample(unittest.TestCase): + @parameterized.expand(TEST_CASES_DOWNSAMPLE) + def test_shape(self, input_param, input_shape, expected_shape): + net = DownSample(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_pre_post_conv(self): + net = DownSample( + spatial_dims=2, + in_channels=4, + out_channels=8, + mode="maxpool", + pre_conv="default", + post_conv=torch.nn.Conv2d(8, 16, 1), + ) + with eval_mode(net): + result = net(torch.randn(1, 4, 16, 16)) + self.assertEqual(result.shape, (1, 16, 8, 8)) + + def test_pixelunshuffle_equivalence(self): + class DownSampleLocal(torch.nn.Module): + def __init__(self, n_feat: int): + super().__init__() + self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) + self.pixelunshuffle = torch.nn.PixelUnshuffle(2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + return self.pixelunshuffle(x) + + n_feat = 2 + x = torch.randn(1, n_feat, 64, 64) + + fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) + + monai_down = DownSample( + spatial_dims=2, + in_channels=n_feat, + out_channels=n_feat // 2, + mode="pixelunshuffle", + pre_conv=fix_weight_conv, + ) + + local_down = DownSampleLocal(n_feat) + local_down.conv.weight.data = fix_weight_conv.weight.data.clone() + + with eval_mode(monai_down), eval_mode(local_down): + out_monai = monai_down(x) + out_local = local_down(x) + + self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5)) + + def test_invalid_mode(self): + with self.assertRaises(ValueError): + DownSample(spatial_dims=2, in_channels=4, mode="invalid") + + def test_missing_channels(self): + with self.assertRaises(ValueError): + DownSample(spatial_dims=2, mode="conv") + + if __name__ == "__main__": unittest.main() From f17e06e1522597650c2368ddf0edf219c9c00cd9 Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Sat, 8 Mar 2025 20:55:25 +0100 Subject: [PATCH 58/67] Re-order skipUnless in test_restormer.py, Signed-off-by: Cano-Muniz, Santiago --- tests/networks/nets/test_restormer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py index ab08d84390..9b54b7a765 100644 --- a/tests/networks/nets/test_restormer.py +++ b/tests/networks/nets/test_restormer.py @@ -88,8 +88,8 @@ class TestMDTATransformerBlock(unittest.TestCase): - @skipUnless(has_einops, "Requires einops") @parameterized.expand(TEST_CASES_TRANSFORMER) + @skipUnless(has_einops, "Requires einops") def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): if flash and not torch.cuda.is_available(): self.skipTest("Flash attention requires CUDA") @@ -111,6 +111,7 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_ class TestOverlapPatchEmbed(unittest.TestCase): @parameterized.expand(TEST_CASES_PATCHEMBED) + @skipUnless(has_einops, "Requires einops") def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape): net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim) with eval_mode(net): @@ -120,8 +121,8 @@ def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected class TestRestormer(unittest.TestCase): - @skipUnless(has_einops, "Requires einops") @parameterized.expand(TEST_CASES_RESTORMER) + @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): if input_param.get("flash_attention", False) and not torch.cuda.is_available(): self.skipTest("Flash attention requires CUDA") From 4573ec9247497564383db79363b4ab0b88de0de7 Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Sat, 8 Mar 2025 20:57:36 +0100 Subject: [PATCH 59/67] Clarify comments for RESTORMER_CONFIGS in test_restormer.py, I, Cano-Muniz, Santiago , hereby add my Signed-off-by to this commit: 3c2dbc61f26c16d411558f9adb0fd92c5a65569b Signed-off-by: Cano-Muniz, Santiago --- tests/networks/nets/test_restormer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py index 9b54b7a765..f0466caa2d 100644 --- a/tests/networks/nets/test_restormer.py +++ b/tests/networks/nets/test_restormer.py @@ -40,10 +40,10 @@ ] RESTORMER_CONFIGS = [ - # 2-level architecture + # 2-level architecture test {"num_blocks": [1, 1], "heads": [1, 1]}, {"num_blocks": [2, 1], "heads": [2, 1]}, - # 3-level architecture + # 3-level architecture test {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]}, {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]}, ] From 8c564aaabcd6e3a25e7f5fc8cf886e3909c49a1c Mon Sep 17 00:00:00 2001 From: tisalon Date: Sat, 8 Mar 2025 21:33:32 +0100 Subject: [PATCH 60/67] Remove duplicated test_CABlock.py as part of codebase cleanup. In addition, solve DCO: DCO Remediation Commit for tisalon I, tisalon , hereby add my Signed-off-by to this commit: 8faa5da5fa5ac514e93d7c2f313927971e0e9b9d I, tisalon , hereby add my Signed-off-by to this commit: 091887b7aabc0490065e5d0053f22e2475cf0c42 I, tisalon , hereby add my Signed-off-by to this commit: 5d162d02236331a70dae1a74bfff49a6ddeab655 I, tisalon , hereby add my Signed-off-by to this commit: f520e99c3a3533c7c848caf0208e70dfefd4dfc0 I, tisalon , hereby add my Signed-off-by to this commit: 39d1edfb2adece2f703cd173ab206d4edf253f15 I, tisalon , hereby add my Signed-off-by to this commit: 5b3d4e1176e901cd46de9fbde2c672d0f69ba9bd I, tisalon , hereby add my Signed-off-by to this commit: 1683b14dcd8e3a2fa47dafaff41de1b2c6b4d0ee I, tisalon , hereby add my Signed-off-by to this commit: 232be1cf8e26c0c5826e38113281d952bfc23003 I, tisalon , hereby add my Signed-off-by to this commit: d1df8e631389a3a66f79af4bf76da78c499e7a9d I, tisalon , hereby add my Signed-off-by to this commit: 78ce56b11bfca78db73566317c5eab145a5c9475 I, tisalon , hereby add my Signed-off-by to this commit: ce158865f10f7cb7de7c1b244a2021ee9025dcf8 I, tisalon , hereby add my Signed-off-by to this commit: 30fad17d5868b3ace5dc321102ae20450ad6f827 I, tisalon , hereby add my Signed-off-by to this commit: 529e90bcdb57534c76f4cea77f0cd2007a89c87f I, tisalon , hereby add my Signed-off-by to this commit: c109029ad4f73bcada08cc2cb50c7d45cf1d9566 I, tisalon , hereby add my Signed-off-by to this commit: 19c30f7f9b89611393860d9180a8a394ce2866ff Signed-off-by: tisalon --- tests/test_CABlock.py | 150 ------------------------------------------ 1 file changed, 150 deletions(-) delete mode 100644 tests/test_CABlock.py diff --git a/tests/test_CABlock.py b/tests/test_CABlock.py deleted file mode 100644 index 4b36f2e72c..0000000000 --- a/tests/test_CABlock.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest -from unittest import skipUnless - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.blocks.cablock import CABlock, FeedForward -from monai.utils import optional_import -from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose - -einops, has_einops = optional_import("einops") - - -TEST_CASES_CAB = [] -for spatial_dims in [2, 3]: - for dim in [32, 64, 128]: - for num_heads in [2, 4, 8]: - for bias in [True, False]: - test_case = [ - { - "spatial_dims": spatial_dims, - "dim": dim, - "num_heads": num_heads, - "bias": bias, - "flash_attention": False, - }, - (2, dim, *([16] * spatial_dims)), - (2, dim, *([16] * spatial_dims)), - ] - TEST_CASES_CAB.append(test_case) - - -TEST_CASES_FEEDFORWARD = [ - # Test different spatial dims, dimensions and expansion factors - [{"spatial_dims": 2, "dim": 64, "ffn_expansion_factor": 2.0, "bias": True}, (2, 64, 32, 32)], - [{"spatial_dims": 3, "dim": 128, "ffn_expansion_factor": 1.5, "bias": False}, (2, 128, 16, 16, 16)], - [{"spatial_dims": 2, "dim": 256, "ffn_expansion_factor": 1.0, "bias": True}, (1, 256, 64, 64)], -] - - -class TestFeedForward(unittest.TestCase): - - @parameterized.expand(TEST_CASES_FEEDFORWARD) - def test_shape(self, input_param, input_shape): - net = FeedForward(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, input_shape) - - def test_gating_mechanism(self): - net = FeedForward(spatial_dims=2, dim=32, ffn_expansion_factor=2.0, bias=True) - x = torch.ones(1, 32, 16, 16) - out = net(x) - self.assertNotEqual(torch.sum(out), torch.sum(x)) - - -class TestCABlock(unittest.TestCase): - - @parameterized.expand(TEST_CASES_CAB) - @skipUnless(has_einops, "Requires einops") - def test_shape(self, input_param, input_shape, expected_shape): - net = CABlock(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - @skipUnless(has_einops, "Requires einops") - def test_invalid_spatial_dims(self): - with self.assertRaises(ValueError): - CABlock(spatial_dims=4, dim=64, num_heads=4, bias=True) - - @SkipIfBeforePyTorchVersion((2, 0)) - @skipUnless(has_einops, "Requires einops") - def test_flash_attention(self): - device = "cuda" if torch.cuda.is_available() else "cpu" - block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) - x = torch.randn(2, 64, 32, 32).to(device) - output = block(x) - self.assertEqual(output.shape, x.shape) - - @skipUnless(has_einops, "Requires einops") - def test_temperature_parameter(self): - block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) - self.assertTrue(isinstance(block.temperature, torch.nn.Parameter)) - self.assertEqual(block.temperature.shape, (4, 1, 1)) - - @skipUnless(has_einops, "Requires einops") - def test_qkv_transformation_2d(self): - block = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True) - x = torch.randn(2, 64, 32, 32) - qkv = block.qkv(x) - self.assertEqual(qkv.shape, (2, 192, 32, 32)) - - @skipUnless(has_einops, "Requires einops") - def test_qkv_transformation_3d(self): - block = CABlock(spatial_dims=3, dim=64, num_heads=4, bias=True) - x = torch.randn(2, 64, 16, 16, 16) - qkv = block.qkv(x) - self.assertEqual(qkv.shape, (2, 192, 16, 16, 16)) - - @SkipIfBeforePyTorchVersion((2, 0)) - @skipUnless(has_einops, "Requires einops") - def test_flash_vs_normal_attention(self): - device = "cuda" if torch.cuda.is_available() else "cpu" - block_flash = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=True).to(device) - block_normal = CABlock(spatial_dims=2, dim=64, num_heads=4, bias=True, flash_attention=False).to(device) - - block_normal.load_state_dict(block_flash.state_dict()) - - x = torch.randn(2, 64, 32, 32).to(device) - with torch.no_grad(): - out_flash = block_flash(x) - out_normal = block_normal(x) - - assert_allclose(out_flash, out_normal, atol=1e-4) - - @skipUnless(has_einops, "Requires einops") - def test_deterministic_small_input(self): - block = CABlock(spatial_dims=2, dim=2, num_heads=1, bias=False) - with torch.no_grad(): - block.qkv.conv.weight.data.fill_(1.0) - block.qkv_dwconv.conv.weight.data.fill_(1.0) - block.temperature.data.fill_(1.0) - block.project_out.conv.weight.data.fill_(1.0) - - x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]], dtype=torch.float32) - - output = block(x) - # Channel attention: sum([1..8]) * (qkv_conv=1) * (dwconv=1) * (attn_weights=1) * (proj=1) = 36 * 2 = 72 - expected = torch.full_like(x, 72.0) - - assert_allclose(output, expected, atol=1e-6) - - -if __name__ == "__main__": - unittest.main() From 3e013fe45e7ee1e2a046b81fee7e9344fc591e18 Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Sat, 8 Mar 2025 22:04:12 +0100 Subject: [PATCH 61/67] Refactor test cases in test_restormer.py to conditionally define classes based on einops availability, and solve last DCO issue: I, Cano-Muniz, Santiago , hereby add my Signed-off-by to this commit: f17e06e1522597650c2368ddf0edf219c9c00cd9 Signed-off-by: Cano-Muniz, Santiago --- tests/networks/nets/test_restormer.py | 98 +++++++++++++++------------ 1 file changed, 53 insertions(+), 45 deletions(-) diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py index f0466caa2d..adad6e1f9a 100644 --- a/tests/networks/nets/test_restormer.py +++ b/tests/networks/nets/test_restormer.py @@ -86,26 +86,31 @@ ) -class TestMDTATransformerBlock(unittest.TestCase): - - @parameterized.expand(TEST_CASES_TRANSFORMER) - @skipUnless(has_einops, "Requires einops") - def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): - if flash and not torch.cuda.is_available(): - self.skipTest("Flash attention requires CUDA") - block = MDTATransformerBlock( - spatial_dims=spatial_dims, - dim=dim, - num_heads=heads, - ffn_expansion_factor=ffn_factor, - bias=bias, - layer_norm_use_bias=layer_norm_use_bias, - flash_attention=flash, - ) - with eval_mode(block): - x = torch.randn(shape) - output = block(x) - self.assertEqual(output.shape, x.shape) +if has_einops: + class TestMDTATransformerBlock(unittest.TestCase): + + @parameterized.expand(TEST_CASES_TRANSFORMER) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): + if flash and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") + block = MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=dim, + num_heads=heads, + ffn_expansion_factor=ffn_factor, + bias=bias, + layer_norm_use_bias=layer_norm_use_bias, + flash_attention=flash, + ) + with eval_mode(block): + x = torch.randn(shape) + output = block(x) + self.assertEqual(output.shape, x.shape) +else: + class TestMDTATransformerBlock(unittest.TestCase): + def test_placeholder(self): + self.skipTest("Einops module not available") class TestOverlapPatchEmbed(unittest.TestCase): @@ -118,31 +123,34 @@ def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) - -class TestRestormer(unittest.TestCase): - - @parameterized.expand(TEST_CASES_RESTORMER) - @skipUnless(has_einops, "Requires einops") - def test_shape(self, input_param, input_shape, expected_shape): - if input_param.get("flash_attention", False) and not torch.cuda.is_available(): - self.skipTest("Flash attention requires CUDA") - net = Restormer(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - @skipUnless(has_einops, "Requires einops") - def test_small_input_error_2d(self): - net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) - with self.assertRaises(AssertionError): - net(torch.randn(1, 1, 8, 8)) - - @skipUnless(has_einops, "Requires einops") - def test_small_input_error_3d(self): - net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) - with self.assertRaises(AssertionError): - net(torch.randn(1, 1, 8, 8, 8)) - +if has_einops: + class TestRestormer(unittest.TestCase): + + @parameterized.expand(TEST_CASES_RESTORMER) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + if input_param.get("flash_attention", False) and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") + net = Restormer(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + @skipUnless(has_einops, "Requires einops") + def test_small_input_error_2d(self): + net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8)) + + @skipUnless(has_einops, "Requires einops") + def test_small_input_error_3d(self): + net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8, 8)) +else: + class TestRestormer(unittest.TestCase): + def test_placeholder(self): + self.skipTest("Einops module not available") if __name__ == "__main__": unittest.main() From 06be2ef54f0c250c9920fe977dc4e312fb09e252 Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Sat, 8 Mar 2025 22:20:43 +0100 Subject: [PATCH 62/67] formatting error in line 237. Solved by updating black from 24.10.0 to black 25.1.0, Signed-off-by: Cano-Muniz, Santiago --- monai/utils/jupyter_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index b1b43a6767..c93e93dcb9 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -234,7 +234,7 @@ def plot_engine_status( def _get_loss_from_output( - output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor + output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor, ) -> torch.Tensor: """Returns a single value from the network output, which is a dict or tensor.""" From c02d794795b4a252946be904535e2024d155ee9f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 Mar 2025 21:21:32 +0000 Subject: [PATCH 63/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/networks/nets/test_restormer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py index adad6e1f9a..7259766bd0 100644 --- a/tests/networks/nets/test_restormer.py +++ b/tests/networks/nets/test_restormer.py @@ -150,7 +150,7 @@ def test_small_input_error_3d(self): else: class TestRestormer(unittest.TestCase): def test_placeholder(self): - self.skipTest("Einops module not available") + self.skipTest("Einops module not available") if __name__ == "__main__": unittest.main() From 7f46ac5feefa1940a0afe95fedc5b3dea1f42ea4 Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Tue, 11 Mar 2025 20:51:49 +0100 Subject: [PATCH 64/67] Remove duplicated tests and place the order of the decorators (skipUnless lower so that it interpreted first). Signed-off-by: Cano-Muniz, Santiago --- tests/integration/test_downsample_block.py | 183 --------------------- tests/networks/nets/test_restormer.py | 106 ++++++------ tests/test_pixelunshuffle.py | 51 ------ tests/test_restormer.py | 147 ----------------- 4 files changed, 49 insertions(+), 438 deletions(-) delete mode 100644 tests/integration/test_downsample_block.py delete mode 100644 tests/test_pixelunshuffle.py delete mode 100644 tests/test_restormer.py diff --git a/tests/integration/test_downsample_block.py b/tests/integration/test_downsample_block.py deleted file mode 100644 index 5e660510d4..0000000000 --- a/tests/integration/test_downsample_block.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.blocks import DownSample, MaxAvgPool, SubpixelDownsample, SubpixelUpsample -from monai.utils import optional_import - -einops, has_einops = optional_import("einops") - -TEST_CASES = [ - [{"spatial_dims": 2, "kernel_size": 2}, (7, 4, 64, 48), (7, 8, 32, 24)], # 4-channel 2D, batch 7 - [{"spatial_dims": 1, "kernel_size": 4}, (16, 4, 63), (16, 8, 15)], # 4-channel 1D, batch 16 - [{"spatial_dims": 1, "kernel_size": 4, "padding": 1}, (16, 4, 63), (16, 8, 16)], # 4-channel 1D, batch 16 - [ # 4-channel 3D, batch 16 - {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": True}, - (16, 4, 32, 24, 48), - (16, 8, 11, 8, 16), - ], - [ # 1-channel 3D, batch 16 - {"spatial_dims": 3, "kernel_size": 3, "ceil_mode": False}, - (16, 1, 32, 24, 48), - (16, 2, 10, 8, 16), - ], -] - -TEST_CASES_SUBPIXEL = [ - [{"spatial_dims": 2, "in_channels": 1, "scale_factor": 2}, (1, 1, 8, 8), (1, 4, 4, 4)], - [{"spatial_dims": 3, "in_channels": 2, "scale_factor": 2}, (1, 2, 8, 8, 8), (1, 16, 4, 4, 4)], - [{"spatial_dims": 1, "in_channels": 3, "scale_factor": 2}, (1, 3, 8), (1, 6, 4)], -] - -TEST_CASES_DOWNSAMPLE = [ - [{"spatial_dims": 2, "in_channels": 4, "mode": "conv"}, (1, 4, 16, 16), (1, 4, 8, 8)], - [{"spatial_dims": 2, "in_channels": 4, "out_channels": 8, "mode": "convgroup"}, (1, 4, 16, 16), (1, 8, 8, 8)], - [{"spatial_dims": 3, "in_channels": 2, "mode": "maxpool"}, (1, 2, 16, 16, 16), (1, 2, 8, 8, 8)], - [{"spatial_dims": 2, "in_channels": 4, "mode": "avgpool"}, (1, 4, 16, 16), (1, 4, 8, 8)], - [{"spatial_dims": 2, "in_channels": 1, "mode": "pixelunshuffle"}, (1, 1, 16, 16), (1, 4, 8, 8)], -] - - -class TestMaxAvgPool(unittest.TestCase): - - @parameterized.expand(TEST_CASES) - def test_shape(self, input_param, input_shape, expected_shape): - net = MaxAvgPool(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - -class TestSubpixelDownsample(unittest.TestCase): - - @parameterized.expand(TEST_CASES_SUBPIXEL) - def test_shape(self, input_param, input_shape, expected_shape): - downsampler = SubpixelDownsample(**input_param) - with eval_mode(downsampler): - result = downsampler(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - def test_predefined_tensor(self): - test_tensor = torch.arange(4).view(4, 1, 1).repeat(1, 4, 4) - test_tensor = test_tensor.unsqueeze(0) - - downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) - with eval_mode(downsampler): - result = downsampler(test_tensor) - self.assertEqual(result.shape, (1, 16, 2, 2)) - self.assertTrue(torch.all(result[0, 0:3] == 0)) - self.assertTrue(torch.all(result[0, 4:7] == 1)) - self.assertTrue(torch.all(result[0, 8:11] == 2)) - self.assertTrue(torch.all(result[0, 12:15] == 3)) - - def test_reconstruction_2d(self): - input_tensor = torch.randn(1, 1, 4, 4) - down = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=None) - up = SubpixelUpsample(spatial_dims=2, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) - with eval_mode(down), eval_mode(up): - downsampled = down(input_tensor) - reconstructed = up(downsampled) - self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) - - def test_reconstruction_3d(self): - input_tensor = torch.randn(1, 1, 4, 4, 4) - down = SubpixelDownsample(spatial_dims=3, in_channels=1, scale_factor=2, conv_block=None) - up = SubpixelUpsample(spatial_dims=3, in_channels=4, scale_factor=2, conv_block=None, apply_pad_pool=False) - with eval_mode(down), eval_mode(up): - downsampled = down(input_tensor) - reconstructed = up(downsampled) - self.assertTrue(torch.allclose(input_tensor, reconstructed, rtol=1e-5)) - - def test_invalid_spatial_size(self): - downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2) - with self.assertRaises(ValueError): - downsampler(torch.randn(1, 1, 3, 4)) - - def test_custom_conv_block(self): - custom_conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1) - downsampler = SubpixelDownsample(spatial_dims=2, in_channels=1, scale_factor=2, conv_block=custom_conv) - with eval_mode(downsampler): - result = downsampler(torch.randn(1, 1, 4, 4)) - self.assertEqual(result.shape, (1, 8, 2, 2)) - - -class TestDownSample(unittest.TestCase): - @parameterized.expand(TEST_CASES_DOWNSAMPLE) - def test_shape(self, input_param, input_shape, expected_shape): - net = DownSample(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - def test_pre_post_conv(self): - net = DownSample( - spatial_dims=2, - in_channels=4, - out_channels=8, - mode="maxpool", - pre_conv="default", - post_conv=torch.nn.Conv2d(8, 16, 1), - ) - with eval_mode(net): - result = net(torch.randn(1, 4, 16, 16)) - self.assertEqual(result.shape, (1, 16, 8, 8)) - - def test_pixelunshuffle_equivalence(self): - class DownSampleLocal(torch.nn.Module): - def __init__(self, n_feat: int): - super().__init__() - self.conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) - self.pixelunshuffle = torch.nn.PixelUnshuffle(2) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.conv(x) - return self.pixelunshuffle(x) - - n_feat = 2 - x = torch.randn(1, n_feat, 64, 64) - - fix_weight_conv = torch.nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False) - - monai_down = DownSample( - spatial_dims=2, - in_channels=n_feat, - out_channels=n_feat // 2, - mode="pixelunshuffle", - pre_conv=fix_weight_conv, - ) - - local_down = DownSampleLocal(n_feat) - local_down.conv.weight.data = fix_weight_conv.weight.data.clone() - - with eval_mode(monai_down), eval_mode(local_down): - out_monai = monai_down(x) - out_local = local_down(x) - - self.assertTrue(torch.allclose(out_monai, out_local, rtol=1e-5)) - - def test_invalid_mode(self): - with self.assertRaises(ValueError): - DownSample(spatial_dims=2, in_channels=4, mode="invalid") - - def test_missing_channels(self): - with self.assertRaises(ValueError): - DownSample(spatial_dims=2, mode="conv") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py index 7259766bd0..77a5e47b4c 100644 --- a/tests/networks/nets/test_restormer.py +++ b/tests/networks/nets/test_restormer.py @@ -40,10 +40,10 @@ ] RESTORMER_CONFIGS = [ - # 2-level architecture test + # 2-level architecture {"num_blocks": [1, 1], "heads": [1, 1]}, {"num_blocks": [2, 1], "heads": [2, 1]}, - # 3-level architecture test + # 3-level architecture {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]}, {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]}, ] @@ -86,71 +86,63 @@ ) -if has_einops: - class TestMDTATransformerBlock(unittest.TestCase): - - @parameterized.expand(TEST_CASES_TRANSFORMER) - @skipUnless(has_einops, "Requires einops") - def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): - if flash and not torch.cuda.is_available(): - self.skipTest("Flash attention requires CUDA") - block = MDTATransformerBlock( - spatial_dims=spatial_dims, - dim=dim, - num_heads=heads, - ffn_expansion_factor=ffn_factor, - bias=bias, - layer_norm_use_bias=layer_norm_use_bias, - flash_attention=flash, - ) - with eval_mode(block): - x = torch.randn(shape) - output = block(x) - self.assertEqual(output.shape, x.shape) -else: - class TestMDTATransformerBlock(unittest.TestCase): - def test_placeholder(self): - self.skipTest("Einops module not available") +class TestMDTATransformerBlock(unittest.TestCase): + + @parameterized.expand(TEST_CASES_TRANSFORMER) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): + if flash and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") + block = MDTATransformerBlock( + spatial_dims=spatial_dims, + dim=dim, + num_heads=heads, + ffn_expansion_factor=ffn_factor, + bias=bias, + layer_norm_use_bias=layer_norm_use_bias, + flash_attention=flash, + ) + with eval_mode(block): + x = torch.randn(shape) + output = block(x) + self.assertEqual(output.shape, x.shape) class TestOverlapPatchEmbed(unittest.TestCase): @parameterized.expand(TEST_CASES_PATCHEMBED) - @skipUnless(has_einops, "Requires einops") def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape): net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim) with eval_mode(net): result = net(torch.randn(input_shape)) self.assertEqual(result.shape, expected_shape) -if has_einops: - class TestRestormer(unittest.TestCase): - - @parameterized.expand(TEST_CASES_RESTORMER) - @skipUnless(has_einops, "Requires einops") - def test_shape(self, input_param, input_shape, expected_shape): - if input_param.get("flash_attention", False) and not torch.cuda.is_available(): - self.skipTest("Flash attention requires CUDA") - net = Restormer(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - @skipUnless(has_einops, "Requires einops") - def test_small_input_error_2d(self): - net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) - with self.assertRaises(AssertionError): - net(torch.randn(1, 1, 8, 8)) - - @skipUnless(has_einops, "Requires einops") - def test_small_input_error_3d(self): - net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) - with self.assertRaises(AssertionError): - net(torch.randn(1, 1, 8, 8, 8)) -else: - class TestRestormer(unittest.TestCase): - def test_placeholder(self): - self.skipTest("Einops module not available") + +class TestRestormer(unittest.TestCase): + + @parameterized.expand(TEST_CASES_RESTORMER) + @skipUnless(has_einops, "Requires einops") + def test_shape(self, input_param, input_shape, expected_shape): + if input_param.get("flash_attention", False) and not torch.cuda.is_available(): + self.skipTest("Flash attention requires CUDA") + net = Restormer(**input_param) + with eval_mode(net): + result = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + @skipUnless(has_einops, "Requires einops") + def test_small_input_error_2d(self): + net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8)) + + @skipUnless(has_einops, "Requires einops") + def test_small_input_error_3d(self): + net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) + with self.assertRaises(AssertionError): + net(torch.randn(1, 1, 8, 8, 8)) + if __name__ == "__main__": - unittest.main() + print(f'has_einops: {has_einops}') + unittest.main() \ No newline at end of file diff --git a/tests/test_pixelunshuffle.py b/tests/test_pixelunshuffle.py deleted file mode 100644 index 49b61440e5..0000000000 --- a/tests/test_pixelunshuffle.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch - -from monai.networks.utils import pixelshuffle, pixelunshuffle - - -class TestPixelUnshuffle(unittest.TestCase): - - def test_2d_basic(self): - x = torch.randn(2, 4, 16, 16) - out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) - self.assertEqual(out.shape, (2, 16, 8, 8)) - - def test_3d_basic(self): - x = torch.randn(2, 4, 16, 16, 16) - out = pixelunshuffle(x, spatial_dims=3, scale_factor=2) - self.assertEqual(out.shape, (2, 32, 8, 8, 8)) - - def test_non_square_input(self): - x = torch.arange(192).reshape(1, 2, 12, 8) - out = pixelunshuffle(x, spatial_dims=2, scale_factor=2) - torch.testing.assert_close(out, torch.pixel_unshuffle(x, 2)) - - def test_different_scale_factor(self): - x = torch.arange(360).reshape(1, 2, 12, 15) - out = pixelunshuffle(x, spatial_dims=2, scale_factor=3) - torch.testing.assert_close(out, torch.pixel_unshuffle(x, 3)) - - def test_inverse_operation(self): - x = torch.arange(4096).reshape(1, 8, 8, 8, 8) - shuffled = pixelshuffle(x, spatial_dims=3, scale_factor=2) - unshuffled = pixelunshuffle(shuffled, spatial_dims=3, scale_factor=2) - torch.testing.assert_close(x, unshuffled) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_restormer.py b/tests/test_restormer.py deleted file mode 100644 index ab08d84390..0000000000 --- a/tests/test_restormer.py +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest -from unittest import skipUnless - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer -from monai.utils import optional_import - -einops, has_einops = optional_import("einops") - -TEST_CASES_TRANSFORMER = [ - # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape] - [2, 48, 8, 2.66, True, True, False, (2, 48, 64, 64)], - [2, 96, 8, 2.66, False, False, False, (2, 96, 32, 32)], - [3, 48, 4, 2.66, True, True, False, (2, 48, 32, 32, 32)], - [3, 96, 8, 2.66, False, False, True, (2, 96, 16, 16, 16)], -] - -TEST_CASES_PATCHEMBED = [ - # spatial_dims, in_channels, embed_dim, input_shape, expected_shape - [2, 1, 48, (2, 1, 64, 64), (2, 48, 64, 64)], - [2, 3, 96, (2, 3, 32, 32), (2, 96, 32, 32)], - [3, 1, 48, (2, 1, 32, 32, 32), (2, 48, 32, 32, 32)], - [3, 4, 64, (2, 4, 16, 16, 16), (2, 64, 16, 16, 16)], -] - -RESTORMER_CONFIGS = [ - # 2-level architecture - {"num_blocks": [1, 1], "heads": [1, 1]}, - {"num_blocks": [2, 1], "heads": [2, 1]}, - # 3-level architecture - {"num_blocks": [1, 1, 1], "heads": [1, 1, 1]}, - {"num_blocks": [2, 1, 1], "heads": [2, 1, 1]}, -] - -TEST_CASES_RESTORMER = [] -for config in RESTORMER_CONFIGS: - # 2D cases - TEST_CASES_RESTORMER.extend( - [ - [ - { - "spatial_dims": 2, - "in_channels": 1, - "out_channels": 1, - "dim": 48, - "num_blocks": config["num_blocks"], - "heads": config["heads"], - "num_refinement_blocks": 2, - "ffn_expansion_factor": 1.5, - }, - (2, 1, 64, 64), - (2, 1, 64, 64), - ], - # 3D cases - [ - { - "spatial_dims": 3, - "in_channels": 1, - "out_channels": 1, - "dim": 16, - "num_blocks": config["num_blocks"], - "heads": config["heads"], - "num_refinement_blocks": 2, - "ffn_expansion_factor": 1.5, - }, - (2, 1, 32, 32, 32), - (2, 1, 32, 32, 32), - ], - ] - ) - - -class TestMDTATransformerBlock(unittest.TestCase): - - @skipUnless(has_einops, "Requires einops") - @parameterized.expand(TEST_CASES_TRANSFORMER) - def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): - if flash and not torch.cuda.is_available(): - self.skipTest("Flash attention requires CUDA") - block = MDTATransformerBlock( - spatial_dims=spatial_dims, - dim=dim, - num_heads=heads, - ffn_expansion_factor=ffn_factor, - bias=bias, - layer_norm_use_bias=layer_norm_use_bias, - flash_attention=flash, - ) - with eval_mode(block): - x = torch.randn(shape) - output = block(x) - self.assertEqual(output.shape, x.shape) - - -class TestOverlapPatchEmbed(unittest.TestCase): - - @parameterized.expand(TEST_CASES_PATCHEMBED) - def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected_shape): - net = OverlapPatchEmbed(spatial_dims=spatial_dims, in_channels=in_channels, embed_dim=embed_dim) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - -class TestRestormer(unittest.TestCase): - - @skipUnless(has_einops, "Requires einops") - @parameterized.expand(TEST_CASES_RESTORMER) - def test_shape(self, input_param, input_shape, expected_shape): - if input_param.get("flash_attention", False) and not torch.cuda.is_available(): - self.skipTest("Flash attention requires CUDA") - net = Restormer(**input_param) - with eval_mode(net): - result = net(torch.randn(input_shape)) - self.assertEqual(result.shape, expected_shape) - - @skipUnless(has_einops, "Requires einops") - def test_small_input_error_2d(self): - net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) - with self.assertRaises(AssertionError): - net(torch.randn(1, 1, 8, 8)) - - @skipUnless(has_einops, "Requires einops") - def test_small_input_error_3d(self): - net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) - with self.assertRaises(AssertionError): - net(torch.randn(1, 1, 8, 8, 8)) - - -if __name__ == "__main__": - unittest.main() From baf75415577695761a40a9e0b2eefb8fd138de1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Mar 2025 19:52:24 +0000 Subject: [PATCH 65/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/networks/nets/test_restormer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py index 77a5e47b4c..2a611346d6 100644 --- a/tests/networks/nets/test_restormer.py +++ b/tests/networks/nets/test_restormer.py @@ -145,4 +145,4 @@ def test_small_input_error_3d(self): if __name__ == "__main__": print(f'has_einops: {has_einops}') - unittest.main() \ No newline at end of file + unittest.main() From aeebc8902be36fe12f339caa90c8899dd6597bd3 Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Tue, 11 Mar 2025 21:00:11 +0100 Subject: [PATCH 66/67] Remove debug print statement for einops availability in test_restormer.py. Signed-off-by: Cano-Muniz, Santiago --- tests/networks/nets/test_restormer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/networks/nets/test_restormer.py b/tests/networks/nets/test_restormer.py index 2a611346d6..c4ef3fc9ca 100644 --- a/tests/networks/nets/test_restormer.py +++ b/tests/networks/nets/test_restormer.py @@ -144,5 +144,4 @@ def test_small_input_error_3d(self): if __name__ == "__main__": - print(f'has_einops: {has_einops}') unittest.main() From 7342b84ef6f40ee797993145d6c46e9bf2c8afcd Mon Sep 17 00:00:00 2001 From: "Cano-Muniz, Santiago" Date: Tue, 11 Mar 2025 22:24:05 +0100 Subject: [PATCH 67/67] Address mypy suggestions for type annotations in cablock.py, downsample.py, restormer.py and test_downsample_block.py. Signed-off-by: Cano-Muniz, Santiago --- monai/networks/blocks/cablock.py | 8 +++++--- monai/networks/blocks/downsample.py | 17 +++++++++-------- monai/networks/nets/restormer.py | 13 +++++++------ tests/networks/blocks/test_downsample_block.py | 3 ++- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/monai/networks/blocks/cablock.py b/monai/networks/blocks/cablock.py index 72e4cc68d0..afaf96805d 100644 --- a/monai/networks/blocks/cablock.py +++ b/monai/networks/blocks/cablock.py @@ -10,6 +10,8 @@ # limitations under the License. from __future__ import annotations +from typing import cast + import torch import torch.nn as nn import torch.nn.functional as F @@ -70,7 +72,7 @@ def __init__(self, spatial_dims: int, dim: int, ffn_expansion_factor: float, bia def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.project_in(x) x1, x2 = self.dwconv(x).chunk(2, dim=1) - return self.project_out(F.gelu(x1) * x2) + return cast(torch.Tensor, self.project_out(F.gelu(x1) * x2)) class CABlock(nn.Module): @@ -141,7 +143,7 @@ def _normal_attention(self, q, k, v): attn = attn.softmax(dim=-1) return attn @ v - def forward(self, x) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass for MDTA attention. 1. Apply depth-wise convolutions to Q, K, V 2. Reshape Q, K, V for multi-head attention @@ -177,4 +179,4 @@ def forward(self, x) -> torch.Tensor: **dict(zip(["h", "w"] if self.spatial_dims == 2 else ["d", "h", "w"], spatial_dims)), ) - return self.project_out(out) + return cast(torch.Tensor, self.project_out(out)) diff --git a/monai/networks/blocks/downsample.py b/monai/networks/blocks/downsample.py index ae962287a9..3291ef0f0e 100644 --- a/monai/networks/blocks/downsample.py +++ b/monai/networks/blocks/downsample.py @@ -92,7 +92,7 @@ def __init__( out_channels: int | None = None, scale_factor: Sequence[float] | float = 2, kernel_size: Sequence[float] | float | None = None, - mode: str = "conv", # conv, convgroup, nontrainable, pixelunshuffle + mode: DownsampleMode | str = DownsampleMode.CONV, pre_conv: nn.Module | str | None = "default", post_conv: nn.Module | None = None, bias: bool = True, @@ -101,11 +101,11 @@ def __init__( Downsamples data by `scale_factor`. Supported modes are: - - "conv": uses a strided convolution for learnable downsampling. - - "convgroup": uses a grouped strided convolution for efficient feature reduction. - - "maxpool": uses maxpooling for non-learnable downsampling. - - "avgpool": uses average pooling for non-learnable downsampling. - - "pixelunshuffle": uses :py:class:`monai.networks.blocks.SubpixelDownsample`. + - DownsampleMode.CONV: uses a strided convolution for learnable downsampling. + - DownsampleMode.CONVGROUP: uses a grouped strided convolution for efficient feature reduction. + - DownsampleMode.MAXPOOL: uses maxpooling for non-learnable downsampling. + - DownsampleMode.AVGPOOL: uses average pooling for non-learnable downsampling. + - DownsampleMode.PIXELUNSHUFFLE: uses :py:class:`monai.networks.blocks.SubpixelDownsample`. This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``. Please check the link below for more details: @@ -120,7 +120,8 @@ def __init__( out_channels: number of channels of the output image. Defaults to `in_channels`. scale_factor: multiplier for spatial size reduction. Has to match input size if it is a tuple. Defaults to 2. kernel_size: kernel size used during convolutions. Defaults to `scale_factor`. - mode: {``"conv"``, ``"convgroup"``, ``"maxpool"``, ``"avgpool"``, ``"pixelunshuffle"``}. Defaults to ``"conv"``. + mode: {``DownsampleMode.CONV``, ``DownsampleMode.CONVGROUP``, ``DownsampleMode.MAXPOOL``, ``DownsampleMode.AVGPOOL``, + ``DownsampleMode.PIXELUNSHUFFLE``}. Defaults to ``DownsampleMode.CONV``. pre_conv: a conv block applied before downsampling. Defaults to "default". When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized. Only used in the "maxpool", "avgpool" or "pixelunshuffle" modes. @@ -134,7 +135,7 @@ def __init__( if not kernel_size: kernel_size_ = scale_factor_ - padding = 0 + padding = ensure_tuple_rep(0, spatial_dims) else: kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims) padding = tuple((k - 1) // 2 for k in kernel_size_) diff --git a/monai/networks/nets/restormer.py b/monai/networks/nets/restormer.py index b59150ad4d..02b4b6f28b 100644 --- a/monai/networks/nets/restormer.py +++ b/monai/networks/nets/restormer.py @@ -15,9 +15,10 @@ from monai.networks.blocks.cablock import CABlock, FeedForward from monai.networks.blocks.convolutions import Convolution -from monai.networks.blocks.downsample import DownSample, DownsampleMode -from monai.networks.blocks.upsample import UpSample, UpsampleMode +from monai.networks.blocks.downsample import DownSample +from monai.networks.blocks.upsample import UpSample from monai.networks.layers.factories import Norm +from monai.utils.enums import DownsampleMode, UpsampleMode class MDTATransformerBlock(nn.Module): @@ -81,9 +82,9 @@ def __init__(self, spatial_dims: int, in_channels: int = 3, embed_dim: int = 48, conv_only=True, ) - -def forward(self, x: torch.Tensor) -> torch.Tensor: - return super().forward(x) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = super().forward(x) + return x class Restormer(nn.Module): @@ -290,7 +291,7 @@ def __init__( conv_only=True, ) - def forward(self, x) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass of Restormer. Processes input through encoder-decoder architecture with skip connections. Args: diff --git a/tests/networks/blocks/test_downsample_block.py b/tests/networks/blocks/test_downsample_block.py index 5e660510d4..993c2865d8 100644 --- a/tests/networks/blocks/test_downsample_block.py +++ b/tests/networks/blocks/test_downsample_block.py @@ -146,7 +146,8 @@ def __init__(self, n_feat: int): def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) - return self.pixelunshuffle(x) + x = self.pixelunshuffle(x) + return x n_feat = 2 x = torch.randn(1, n_feat, 64, 64)