Skip to content

Commit 8ac8e0d

Browse files
yiheng-wang-nvpre-commit-ci[bot]
andauthoredJan 27, 2025··
update pydicom reader to enable gpu load (#8283)
Related to #8241 . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent bc9352e commit 8ac8e0d

File tree

2 files changed

+222
-55
lines changed

2 files changed

+222
-55
lines changed
 

‎monai/data/image_reader.py

+170-49
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,10 @@ class PydicomReader(ImageReader):
418418
If provided, only the matched files will be included. For example, to include the file name
419419
"image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`.
420420
Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
421+
to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
422+
Default is False. CuPy and Kvikio are required for this option.
423+
In practical use, it's recommended to add a warm up call before the actual loading.
424+
A related tutorial will be prepared in the future, and the document will be updated accordingly.
421425
kwargs: additional args for `pydicom.dcmread` API. more details about available args:
422426
https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
423427
If the `get_data` function will be called
@@ -434,6 +438,7 @@ def __init__(
434438
prune_metadata: bool = True,
435439
label_dict: dict | None = None,
436440
fname_regex: str = "",
441+
to_gpu: bool = False,
437442
**kwargs,
438443
):
439444
super().__init__()
@@ -444,6 +449,33 @@ def __init__(
444449
self.prune_metadata = prune_metadata
445450
self.label_dict = label_dict
446451
self.fname_regex = fname_regex
452+
if to_gpu and (not has_cp or not has_kvikio):
453+
warnings.warn(
454+
"PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
455+
)
456+
to_gpu = False
457+
458+
if to_gpu:
459+
self.warmup_kvikio()
460+
461+
self.to_gpu = to_gpu
462+
463+
def warmup_kvikio(self):
464+
"""
465+
Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
466+
This can accelerate the data loading process when `to_gpu` is set to True.
467+
"""
468+
if has_cp and has_kvikio:
469+
a = cp.arange(100)
470+
with tempfile.NamedTemporaryFile() as tmp_file:
471+
tmp_file_name = tmp_file.name
472+
f = kvikio.CuFile(tmp_file_name, "w")
473+
f.write(a)
474+
f.close()
475+
476+
b = cp.empty_like(a)
477+
f = kvikio.CuFile(tmp_file_name, "r")
478+
f.read(b)
447479

448480
def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool:
449481
"""
@@ -475,12 +507,15 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
475507
img_ = []
476508

477509
filenames: Sequence[PathLike] = ensure_tuple(data)
510+
self.filenames = list(filenames)
478511
kwargs_ = self.kwargs.copy()
512+
if self.to_gpu:
513+
kwargs["defer_size"] = "100 KB"
479514
kwargs_.update(kwargs)
480515

481516
self.has_series = False
482517

483-
for name in filenames:
518+
for i, name in enumerate(filenames):
484519
name = f"{name}"
485520
if Path(name).is_dir():
486521
# read DICOM series
@@ -489,20 +524,28 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
489524
else:
490525
series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)]
491526
slices = []
527+
loaded_slc_names = []
492528
for slc in series_slcs:
493529
try:
494530
slices.append(pydicom.dcmread(fp=slc, **kwargs_))
531+
loaded_slc_names.append(slc)
495532
except pydicom.errors.InvalidDicomError as e:
496533
warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2)
497-
img_.append(slices if len(slices) > 1 else slices[0])
498534
if len(slices) > 1:
499535
self.has_series = True
536+
img_.append(slices)
537+
self.filenames[i] = loaded_slc_names # type: ignore
538+
else:
539+
img_.append(slices[0]) # type: ignore
540+
self.filenames[i] = loaded_slc_names[0] # type: ignore
500541
else:
501542
ds = pydicom.dcmread(fp=name, **kwargs_)
502-
img_.append(ds)
503-
return img_ if len(filenames) > 1 else img_[0]
543+
img_.append(ds) # type: ignore
544+
if len(filenames) == 1:
545+
return img_[0]
546+
return img_
504547

505-
def _combine_dicom_series(self, data: Iterable):
548+
def _combine_dicom_series(self, data: Iterable, filenames: Sequence[PathLike]):
506549
"""
507550
Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new
508551
dimension as the last dimension.
@@ -522,28 +565,27 @@ def _combine_dicom_series(self, data: Iterable):
522565
"""
523566
slices: list = []
524567
# for a dicom series
525-
for slc_ds in data:
568+
for slc_ds, filename in zip(data, filenames):
526569
if hasattr(slc_ds, "InstanceNumber"):
527-
slices.append(slc_ds)
570+
slices.append((slc_ds, filename))
528571
else:
529-
warnings.warn(f"slice: {slc_ds.filename} does not have InstanceNumber tag, skip it.")
530-
slices = sorted(slices, key=lambda s: s.InstanceNumber)
531-
572+
warnings.warn(f"slice: {filename} does not have InstanceNumber tag, skip it.")
573+
slices = sorted(slices, key=lambda s: s[0].InstanceNumber)
532574
if len(slices) == 0:
533575
raise ValueError("the input does not have valid slices.")
534576

535-
first_slice = slices[0]
577+
first_slice, first_filename = slices[0]
536578
average_distance = 0.0
537-
first_array = self._get_array_data(first_slice)
579+
first_array = self._get_array_data(first_slice, first_filename)
538580
shape = first_array.shape
539-
spacing = getattr(first_slice, "PixelSpacing", [1.0, 1.0, 1.0])
581+
spacing = getattr(first_slice, "PixelSpacing", [1.0] * len(shape))
540582
prev_pos = getattr(first_slice, "ImagePositionPatient", (0.0, 0.0, 0.0))[2]
541583
stack_array = [first_array]
542584
for idx in range(1, len(slices)):
543-
slc_array = self._get_array_data(slices[idx])
585+
slc_array = self._get_array_data(slices[idx][0], slices[idx][1])
544586
slc_shape = slc_array.shape
545-
slc_spacing = getattr(slices[idx], "PixelSpacing", (1.0, 1.0, 1.0))
546-
slc_pos = getattr(slices[idx], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2]
587+
slc_spacing = getattr(slices[idx][0], "PixelSpacing", [1.0] * len(shape))
588+
slc_pos = getattr(slices[idx][0], "ImagePositionPatient", (0.0, 0.0, float(idx)))[2]
547589
if not np.allclose(slc_spacing, spacing):
548590
warnings.warn(f"the list contains slices that have different spacings {spacing} and {slc_spacing}.")
549591
if shape != slc_shape:
@@ -555,11 +597,14 @@ def _combine_dicom_series(self, data: Iterable):
555597
if len(slices) > 1:
556598
average_distance /= len(slices) - 1
557599
spacing.append(average_distance)
558-
stack_array = np.stack(stack_array, axis=-1)
600+
if self.to_gpu:
601+
stack_array = cp.stack(stack_array, axis=-1)
602+
else:
603+
stack_array = np.stack(stack_array, axis=-1)
559604
stack_metadata = self._get_meta_dict(first_slice)
560605
stack_metadata["spacing"] = np.asarray(spacing)
561-
if hasattr(slices[-1], "ImagePositionPatient"):
562-
stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1].ImagePositionPatient)
606+
if hasattr(slices[-1][0], "ImagePositionPatient"):
607+
stack_metadata["lastImagePositionPatient"] = np.asarray(slices[-1][0].ImagePositionPatient)
563608
stack_metadata[MetaKeys.SPATIAL_SHAPE] = shape + (len(slices),)
564609
else:
565610
stack_array = stack_array[0]
@@ -597,29 +642,35 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
597642
if self.has_series is True:
598643
# a list, all objects within a list belong to one dicom series
599644
if not isinstance(data[0], list):
600-
dicom_data.append(self._combine_dicom_series(data))
645+
# input is a dir, self.filenames is a list of list of filenames
646+
dicom_data.append(self._combine_dicom_series(data, self.filenames[0])) # type: ignore
601647
# a list of list, each inner list represents a dicom series
602648
else:
603-
for series in data:
604-
dicom_data.append(self._combine_dicom_series(series))
649+
for i, series in enumerate(data):
650+
dicom_data.append(self._combine_dicom_series(series, self.filenames[i])) # type: ignore
605651
else:
606652
# a single pydicom dataset object
607653
if not isinstance(data, list):
608654
data = [data]
609-
for d in data:
655+
for i, d in enumerate(data):
610656
if hasattr(d, "SegmentSequence"):
611-
data_array, metadata = self._get_seg_data(d)
657+
data_array, metadata = self._get_seg_data(d, self.filenames[i])
612658
else:
613-
data_array = self._get_array_data(d)
659+
data_array = self._get_array_data(d, self.filenames[i])
614660
metadata = self._get_meta_dict(d)
615661
metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape
616662
dicom_data.append((data_array, metadata))
617663

664+
# TODO: the actual type is list[np.ndarray | cp.ndarray]
665+
# should figure out how to define correct types without having cupy not found error
666+
# https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
618667
img_array: list[np.ndarray] = []
619668
compatible_meta: dict = {}
620669

621670
for data_array, metadata in ensure_tuple(dicom_data):
622-
img_array.append(np.ascontiguousarray(np.swapaxes(data_array, 0, 1) if self.swap_ij else data_array))
671+
if self.swap_ij:
672+
data_array = cp.swapaxes(data_array, 0, 1) if self.to_gpu else np.swapaxes(data_array, 0, 1)
673+
img_array.append(cp.ascontiguousarray(data_array) if self.to_gpu else np.ascontiguousarray(data_array))
623674
affine = self._get_affine(metadata, self.affine_lps_to_ras)
624675
metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS
625676
if self.swap_ij:
@@ -641,7 +692,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
641692

642693
_copy_compatible_dict(metadata, compatible_meta)
643694

644-
return _stack_images(img_array, compatible_meta), compatible_meta
695+
return _stack_images(img_array, compatible_meta, to_cupy=self.to_gpu), compatible_meta
645696

646697
def _get_meta_dict(self, img) -> dict:
647698
"""
@@ -713,7 +764,7 @@ def _get_affine(self, metadata: dict, lps_to_ras: bool = True):
713764
affine = orientation_ras_lps(affine)
714765
return affine
715766

716-
def _get_frame_data(self, img) -> Iterator:
767+
def _get_frame_data(self, img, filename, array_data) -> Iterator:
717768
"""
718769
yield frames and description from the segmentation image.
719770
This function is adapted from Highdicom:
@@ -751,48 +802,54 @@ def _get_frame_data(self, img) -> Iterator:
751802
"""
752803

753804
if not hasattr(img, "PerFrameFunctionalGroupsSequence"):
754-
raise NotImplementedError(
755-
f"To read dicom seg: {img.filename}, 'PerFrameFunctionalGroupsSequence' is required."
756-
)
805+
raise NotImplementedError(f"To read dicom seg: {filename}, 'PerFrameFunctionalGroupsSequence' is required.")
757806

758807
frame_seg_nums = []
759808
for f in img.PerFrameFunctionalGroupsSequence:
760809
if not hasattr(f, "SegmentIdentificationSequence"):
761810
raise NotImplementedError(
762-
f"To read dicom seg: {img.filename}, 'SegmentIdentificationSequence' is required for each frame."
811+
f"To read dicom seg: {filename}, 'SegmentIdentificationSequence' is required for each frame."
763812
)
764813
frame_seg_nums.append(int(f.SegmentIdentificationSequence[0].ReferencedSegmentNumber))
765814

766-
frame_seg_nums_arr = np.array(frame_seg_nums)
815+
frame_seg_nums_arr = cp.array(frame_seg_nums) if self.to_gpu else np.array(frame_seg_nums)
767816

768817
seg_descriptions = {int(f.SegmentNumber): f for f in img.SegmentSequence}
769818

770-
for i in np.unique(frame_seg_nums_arr):
771-
indices = np.where(frame_seg_nums_arr == i)[0]
772-
yield (img.pixel_array[indices, ...], seg_descriptions[i])
819+
for i in np.unique(frame_seg_nums_arr) if not self.to_gpu else cp.unique(frame_seg_nums_arr):
820+
indices = np.where(frame_seg_nums_arr == i)[0] if not self.to_gpu else cp.where(frame_seg_nums_arr == i)[0]
821+
yield (array_data[indices, ...], seg_descriptions[i])
773822

774-
def _get_seg_data(self, img):
823+
def _get_seg_data(self, img, filename):
775824
"""
776825
Get the array data and metadata of the segmentation image.
777826
778827
Aegs:
779828
img: a Pydicom dataset object that has attribute "SegmentSequence".
829+
filename: the file path of the image.
780830
781831
"""
782832

783833
metadata = self._get_meta_dict(img)
784834
n_classes = len(img.SegmentSequence)
785-
spatial_shape = list(img.pixel_array.shape)
835+
array_data = self._get_array_data(img, filename)
836+
spatial_shape = list(array_data.shape)
786837
spatial_shape[0] = spatial_shape[0] // n_classes
787838

788839
if self.label_dict is not None:
789840
metadata["labels"] = self.label_dict
790-
all_segs = np.zeros([*spatial_shape, len(self.label_dict)])
841+
if self.to_gpu:
842+
all_segs = cp.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
843+
else:
844+
all_segs = np.zeros([*spatial_shape, len(self.label_dict)], dtype=array_data.dtype)
791845
else:
792846
metadata["labels"] = {}
793-
all_segs = np.zeros([*spatial_shape, n_classes])
847+
if self.to_gpu:
848+
all_segs = cp.zeros([*spatial_shape, n_classes], dtype=array_data.dtype)
849+
else:
850+
all_segs = np.zeros([*spatial_shape, n_classes], dtype=array_data.dtype)
794851

795-
for i, (frames, description) in enumerate(self._get_frame_data(img)):
852+
for i, (frames, description) in enumerate(self._get_frame_data(img, filename, array_data)):
796853
segment_label = getattr(description, "SegmentLabel", f"label_{i}")
797854
class_name = getattr(description, "SegmentDescription", segment_label)
798855
if class_name not in metadata["labels"].keys():
@@ -840,19 +897,79 @@ def _get_seg_data(self, img):
840897

841898
return all_segs, metadata
842899

843-
def _get_array_data(self, img):
900+
def _get_array_data_from_gpu(self, img, filename):
901+
"""
902+
Get the raw array data of the image. This function is used when `to_gpu` is set to True.
903+
904+
Args:
905+
img: a Pydicom dataset object.
906+
filename: the file path of the image.
907+
908+
"""
909+
rows = getattr(img, "Rows", None)
910+
columns = getattr(img, "Columns", None)
911+
bits_allocated = getattr(img, "BitsAllocated", None)
912+
samples_per_pixel = getattr(img, "SamplesPerPixel", 1)
913+
number_of_frames = getattr(img, "NumberOfFrames", 1)
914+
pixel_representation = getattr(img, "PixelRepresentation", 1)
915+
916+
if rows is None or columns is None or bits_allocated is None:
917+
warnings.warn(
918+
f"dicom data: {filename} does not have Rows, Columns or BitsAllocated, falling back to CPU loading."
919+
)
920+
921+
if not hasattr(img, "pixel_array"):
922+
raise ValueError(f"dicom data: {filename} does not have pixel_array.")
923+
data = img.pixel_array
924+
925+
return data
926+
927+
if bits_allocated == 8:
928+
dtype = cp.int8 if pixel_representation == 1 else cp.uint8
929+
elif bits_allocated == 16:
930+
dtype = cp.int16 if pixel_representation == 1 else cp.uint16
931+
elif bits_allocated == 32:
932+
dtype = cp.int32 if pixel_representation == 1 else cp.uint32
933+
else:
934+
raise ValueError("Unsupported BitsAllocated value")
935+
936+
bytes_per_pixel = bits_allocated // 8
937+
total_pixels = rows * columns * samples_per_pixel * number_of_frames
938+
expected_pixel_data_length = total_pixels * bytes_per_pixel
939+
940+
pixel_data_tag = pydicom.tag.Tag(0x7FE0, 0x0010)
941+
if pixel_data_tag not in img:
942+
raise ValueError(f"dicom data: {filename} does not have pixel data.")
943+
944+
offset = img.get_item(pixel_data_tag, keep_deferred=True).value_tell
945+
946+
with kvikio.CuFile(filename, "r") as f:
947+
buffer = cp.empty(expected_pixel_data_length, dtype=cp.int8)
948+
f.read(buffer, expected_pixel_data_length, offset)
949+
950+
new_shape = (number_of_frames, rows, columns) if number_of_frames > 1 else (rows, columns)
951+
data = buffer.view(dtype).reshape(new_shape)
952+
953+
return data
954+
955+
def _get_array_data(self, img, filename):
844956
"""
845957
Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
846-
will be rescaled. The output data has the dtype np.float32 if the rescaling is applied.
958+
will be rescaled. The output data has the dtype float32 if the rescaling is applied.
847959
848960
Args:
849961
img: a Pydicom dataset object.
962+
filename: the file path of the image.
850963
851964
"""
852965
# process Dicom series
853-
if not hasattr(img, "pixel_array"):
854-
raise ValueError(f"dicom data: {img.filename} does not have pixel_array.")
855-
data = img.pixel_array
966+
967+
if self.to_gpu:
968+
data = self._get_array_data_from_gpu(img, filename)
969+
else:
970+
if not hasattr(img, "pixel_array"):
971+
raise ValueError(f"dicom data: {filename} does not have pixel_array.")
972+
data = img.pixel_array
856973

857974
slope, offset = 1.0, 0.0
858975
rescale_flag = False
@@ -862,8 +979,14 @@ def _get_array_data(self, img):
862979
if hasattr(img, "RescaleIntercept"):
863980
offset = img.RescaleIntercept
864981
rescale_flag = True
982+
865983
if rescale_flag:
866-
data = data.astype(np.float32) * slope + offset
984+
if self.to_gpu:
985+
slope = cp.asarray(slope, dtype=cp.float32)
986+
offset = cp.asarray(offset, dtype=cp.float32)
987+
data = data.astype(cp.float32) * slope + offset
988+
else:
989+
data = data.astype(np.float32) * slope + offset
867990

868991
return data
869992

@@ -884,8 +1007,6 @@ class NibabelReader(ImageReader):
8841007
Default is False. CuPy and Kvikio are required for this option.
8851008
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
8861009
and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
887-
In practical use, it's recommended to add a warm up call before the actual loading.
888-
A related tutorial will be prepared in the future, and the document will be updated accordingly.
8891010
kwargs: additional args for `nibabel.load` API. more details about available args:
8901011
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
8911012

‎tests/test_load_image.py

+52-6
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,16 @@ def get_data(self, _obj):
168168
# test reader consistency between PydicomReader and ITKReader on dicom data
169169
TEST_CASE_22 = ["tests/testing_data/CT_DICOM"]
170170

171+
# test pydicom gpu reader
172+
TEST_CASE_GPU_5 = [{"reader": "PydicomReader", "to_gpu": True}, "tests/testing_data/CT_DICOM", (16, 16, 4), (16, 16, 4)]
173+
174+
TEST_CASE_GPU_6 = [
175+
{"reader": "PydicomReader", "ensure_channel_first": True, "force": True, "to_gpu": True},
176+
"tests/testing_data/CT_DICOM",
177+
(16, 16, 4),
178+
(1, 16, 16, 4),
179+
]
180+
171181
TESTS_META = []
172182
for track_meta in (False, True):
173183
TESTS_META.append([{}, (128, 128, 128), track_meta])
@@ -242,16 +252,17 @@ def test_nibabel_reader_gpu(self, input_param, filenames, expected_shape):
242252

243253
@parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_8_1, TEST_CASE_9])
244254
def test_itk_reader(self, input_param, filenames, expected_shape):
245-
test_image = np.random.rand(128, 128, 128)
255+
test_image = torch.randint(0, 256, (128, 128, 128), dtype=torch.uint8).numpy()
256+
print("Test image value range:", test_image.min(), test_image.max())
246257
with tempfile.TemporaryDirectory() as tempdir:
247258
for i, name in enumerate(filenames):
248259
filenames[i] = os.path.join(tempdir, name)
249-
itk_np_view = itk.image_view_from_array(test_image)
250-
itk.imwrite(itk_np_view, filenames[i])
260+
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
251261
result = LoadImage(image_only=True, **input_param)(filenames)
252-
self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image.nii.gz"))
253-
diag = torch.as_tensor(np.diag([-1, -1, 1, 1]))
254-
np.testing.assert_allclose(result.affine, diag)
262+
ext = "".join(Path(name).suffixes)
263+
self.assertEqual(result.meta["filename_or_obj"], os.path.join(tempdir, "test_image" + ext))
264+
self.assertEqual(result.meta["space"], "RAS")
265+
assert_allclose(result.affine, torch.eye(4))
255266
self.assertTupleEqual(result.shape, expected_shape)
256267

257268
@parameterized.expand([TEST_CASE_10, TEST_CASE_11, TEST_CASE_12, TEST_CASE_19, TEST_CASE_20, TEST_CASE_21])
@@ -271,6 +282,26 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, e
271282
)
272283
self.assertTupleEqual(result.shape, expected_np_shape)
273284

285+
@SkipIfNoModule("pydicom")
286+
@SkipIfNoModule("cupy")
287+
@SkipIfNoModule("kvikio")
288+
@parameterized.expand([TEST_CASE_GPU_5, TEST_CASE_GPU_6])
289+
def test_pydicom_gpu_reader(self, input_param, filenames, expected_shape, expected_np_shape):
290+
result = LoadImage(image_only=True, **input_param)(filenames)
291+
self.assertEqual(result.meta["filename_or_obj"], f"{Path(filenames)}")
292+
assert_allclose(
293+
result.affine,
294+
torch.tensor(
295+
[
296+
[-0.488281, 0.0, 0.0, 125.0],
297+
[0.0, -0.488281, 0.0, 128.100006],
298+
[0.0, 0.0, 68.33333333, -99.480003],
299+
[0.0, 0.0, 0.0, 1.0],
300+
]
301+
),
302+
)
303+
self.assertTupleEqual(result.shape, expected_np_shape)
304+
274305
def test_no_files(self):
275306
with self.assertRaisesRegex(RuntimeError, "list index out of range"): # fname_regex excludes everything
276307
LoadImage(image_only=True, reader="PydicomReader", fname_regex=r"^(?!.*).*")("tests/testing_data/CT_DICOM")
@@ -317,6 +348,21 @@ def test_dicom_reader_consistency(self, filenames):
317348
np.testing.assert_allclose(pydicom_result, itk_result)
318349
np.testing.assert_allclose(pydicom_result.affine, itk_result.affine)
319350

351+
@SkipIfNoModule("pydicom")
352+
@SkipIfNoModule("cupy")
353+
@SkipIfNoModule("kvikio")
354+
@parameterized.expand([TEST_CASE_22])
355+
def test_pydicom_reader_gpu_cpu_consistency(self, filenames):
356+
gpu_param = {"reader": "PydicomReader", "to_gpu": True}
357+
cpu_param = {"reader": "PydicomReader", "to_gpu": False}
358+
for affine_flag in [True, False]:
359+
gpu_param["affine_lps_to_ras"] = affine_flag
360+
cpu_param["affine_lps_to_ras"] = affine_flag
361+
gpu_result = LoadImage(image_only=True, **gpu_param)(filenames)
362+
cpu_result = LoadImage(image_only=True, **cpu_param)(filenames)
363+
np.testing.assert_allclose(gpu_result.cpu(), cpu_result)
364+
np.testing.assert_allclose(gpu_result.affine.cpu(), cpu_result.affine)
365+
320366
def test_dicom_reader_consistency_single(self):
321367
itk_param = {"reader": "ITKReader"}
322368
pydicom_param = {"reader": "PydicomReader"}

0 commit comments

Comments
 (0)
Please sign in to comment.