@@ -418,6 +418,10 @@ class PydicomReader(ImageReader):
418
418
If provided, only the matched files will be included. For example, to include the file name
419
419
"image_0001.dcm", the regular expression could be `".*image_(\\ d+).dcm"`. Default to `""`.
420
420
Set it to `None` to use `pydicom.misc.is_dicom` to match valid files.
421
+ to_gpu: If True, load the image into GPU memory using CuPy and Kvikio. This can accelerate data loading.
422
+ Default is False. CuPy and Kvikio are required for this option.
423
+ In practical use, it's recommended to add a warm up call before the actual loading.
424
+ A related tutorial will be prepared in the future, and the document will be updated accordingly.
421
425
kwargs: additional args for `pydicom.dcmread` API. more details about available args:
422
426
https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html
423
427
If the `get_data` function will be called
@@ -434,6 +438,7 @@ def __init__(
434
438
prune_metadata : bool = True ,
435
439
label_dict : dict | None = None ,
436
440
fname_regex : str = "" ,
441
+ to_gpu : bool = False ,
437
442
** kwargs ,
438
443
):
439
444
super ().__init__ ()
@@ -444,6 +449,33 @@ def __init__(
444
449
self .prune_metadata = prune_metadata
445
450
self .label_dict = label_dict
446
451
self .fname_regex = fname_regex
452
+ if to_gpu and (not has_cp or not has_kvikio ):
453
+ warnings .warn (
454
+ "PydicomReader: CuPy and/or Kvikio not installed for GPU loading, falling back to CPU loading."
455
+ )
456
+ to_gpu = False
457
+
458
+ if to_gpu :
459
+ self .warmup_kvikio ()
460
+
461
+ self .to_gpu = to_gpu
462
+
463
+ def warmup_kvikio (self ):
464
+ """
465
+ Warm up the Kvikio library to initialize the internal buffers, cuFile, GDS, etc.
466
+ This can accelerate the data loading process when `to_gpu` is set to True.
467
+ """
468
+ if has_cp and has_kvikio :
469
+ a = cp .arange (100 )
470
+ with tempfile .NamedTemporaryFile () as tmp_file :
471
+ tmp_file_name = tmp_file .name
472
+ f = kvikio .CuFile (tmp_file_name , "w" )
473
+ f .write (a )
474
+ f .close ()
475
+
476
+ b = cp .empty_like (a )
477
+ f = kvikio .CuFile (tmp_file_name , "r" )
478
+ f .read (b )
447
479
448
480
def verify_suffix (self , filename : Sequence [PathLike ] | PathLike ) -> bool :
449
481
"""
@@ -475,12 +507,15 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
475
507
img_ = []
476
508
477
509
filenames : Sequence [PathLike ] = ensure_tuple (data )
510
+ self .filenames = list (filenames )
478
511
kwargs_ = self .kwargs .copy ()
512
+ if self .to_gpu :
513
+ kwargs ["defer_size" ] = "100 KB"
479
514
kwargs_ .update (kwargs )
480
515
481
516
self .has_series = False
482
517
483
- for name in filenames :
518
+ for i , name in enumerate ( filenames ) :
484
519
name = f"{ name } "
485
520
if Path (name ).is_dir ():
486
521
# read DICOM series
@@ -489,20 +524,28 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs):
489
524
else :
490
525
series_slcs = [slc for slc in glob .glob (os .path .join (name , "*" )) if pydicom .misc .is_dicom (slc )]
491
526
slices = []
527
+ loaded_slc_names = []
492
528
for slc in series_slcs :
493
529
try :
494
530
slices .append (pydicom .dcmread (fp = slc , ** kwargs_ ))
531
+ loaded_slc_names .append (slc )
495
532
except pydicom .errors .InvalidDicomError as e :
496
533
warnings .warn (f"Failed to read { slc } with exception: \n { e } ." , stacklevel = 2 )
497
- img_ .append (slices if len (slices ) > 1 else slices [0 ])
498
534
if len (slices ) > 1 :
499
535
self .has_series = True
536
+ img_ .append (slices )
537
+ self .filenames [i ] = loaded_slc_names # type: ignore
538
+ else :
539
+ img_ .append (slices [0 ]) # type: ignore
540
+ self .filenames [i ] = loaded_slc_names [0 ] # type: ignore
500
541
else :
501
542
ds = pydicom .dcmread (fp = name , ** kwargs_ )
502
- img_ .append (ds )
503
- return img_ if len (filenames ) > 1 else img_ [0 ]
543
+ img_ .append (ds ) # type: ignore
544
+ if len (filenames ) == 1 :
545
+ return img_ [0 ]
546
+ return img_
504
547
505
- def _combine_dicom_series (self , data : Iterable ):
548
+ def _combine_dicom_series (self , data : Iterable , filenames : Sequence [ PathLike ] ):
506
549
"""
507
550
Combine dicom series (a list of pydicom dataset objects). Their data arrays will be stacked together at a new
508
551
dimension as the last dimension.
@@ -522,28 +565,27 @@ def _combine_dicom_series(self, data: Iterable):
522
565
"""
523
566
slices : list = []
524
567
# for a dicom series
525
- for slc_ds in data :
568
+ for slc_ds , filename in zip ( data , filenames ) :
526
569
if hasattr (slc_ds , "InstanceNumber" ):
527
- slices .append (slc_ds )
570
+ slices .append (( slc_ds , filename ) )
528
571
else :
529
- warnings .warn (f"slice: { slc_ds .filename } does not have InstanceNumber tag, skip it." )
530
- slices = sorted (slices , key = lambda s : s .InstanceNumber )
531
-
572
+ warnings .warn (f"slice: { filename } does not have InstanceNumber tag, skip it." )
573
+ slices = sorted (slices , key = lambda s : s [0 ].InstanceNumber )
532
574
if len (slices ) == 0 :
533
575
raise ValueError ("the input does not have valid slices." )
534
576
535
- first_slice = slices [0 ]
577
+ first_slice , first_filename = slices [0 ]
536
578
average_distance = 0.0
537
- first_array = self ._get_array_data (first_slice )
579
+ first_array = self ._get_array_data (first_slice , first_filename )
538
580
shape = first_array .shape
539
- spacing = getattr (first_slice , "PixelSpacing" , [1.0 , 1.0 , 1.0 ] )
581
+ spacing = getattr (first_slice , "PixelSpacing" , [1.0 ] * len ( shape ) )
540
582
prev_pos = getattr (first_slice , "ImagePositionPatient" , (0.0 , 0.0 , 0.0 ))[2 ]
541
583
stack_array = [first_array ]
542
584
for idx in range (1 , len (slices )):
543
- slc_array = self ._get_array_data (slices [idx ])
585
+ slc_array = self ._get_array_data (slices [idx ][ 0 ], slices [ idx ][ 1 ] )
544
586
slc_shape = slc_array .shape
545
- slc_spacing = getattr (slices [idx ], "PixelSpacing" , ( 1.0 , 1.0 , 1.0 ))
546
- slc_pos = getattr (slices [idx ], "ImagePositionPatient" , (0.0 , 0.0 , float (idx )))[2 ]
587
+ slc_spacing = getattr (slices [idx ][ 0 ] , "PixelSpacing" , [ 1.0 ] * len ( shape ))
588
+ slc_pos = getattr (slices [idx ][ 0 ] , "ImagePositionPatient" , (0.0 , 0.0 , float (idx )))[2 ]
547
589
if not np .allclose (slc_spacing , spacing ):
548
590
warnings .warn (f"the list contains slices that have different spacings { spacing } and { slc_spacing } ." )
549
591
if shape != slc_shape :
@@ -555,11 +597,14 @@ def _combine_dicom_series(self, data: Iterable):
555
597
if len (slices ) > 1 :
556
598
average_distance /= len (slices ) - 1
557
599
spacing .append (average_distance )
558
- stack_array = np .stack (stack_array , axis = - 1 )
600
+ if self .to_gpu :
601
+ stack_array = cp .stack (stack_array , axis = - 1 )
602
+ else :
603
+ stack_array = np .stack (stack_array , axis = - 1 )
559
604
stack_metadata = self ._get_meta_dict (first_slice )
560
605
stack_metadata ["spacing" ] = np .asarray (spacing )
561
- if hasattr (slices [- 1 ], "ImagePositionPatient" ):
562
- stack_metadata ["lastImagePositionPatient" ] = np .asarray (slices [- 1 ].ImagePositionPatient )
606
+ if hasattr (slices [- 1 ][ 0 ] , "ImagePositionPatient" ):
607
+ stack_metadata ["lastImagePositionPatient" ] = np .asarray (slices [- 1 ][ 0 ] .ImagePositionPatient )
563
608
stack_metadata [MetaKeys .SPATIAL_SHAPE ] = shape + (len (slices ),)
564
609
else :
565
610
stack_array = stack_array [0 ]
@@ -597,29 +642,35 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
597
642
if self .has_series is True :
598
643
# a list, all objects within a list belong to one dicom series
599
644
if not isinstance (data [0 ], list ):
600
- dicom_data .append (self ._combine_dicom_series (data ))
645
+ # input is a dir, self.filenames is a list of list of filenames
646
+ dicom_data .append (self ._combine_dicom_series (data , self .filenames [0 ])) # type: ignore
601
647
# a list of list, each inner list represents a dicom series
602
648
else :
603
- for series in data :
604
- dicom_data .append (self ._combine_dicom_series (series ))
649
+ for i , series in enumerate ( data ) :
650
+ dicom_data .append (self ._combine_dicom_series (series , self . filenames [ i ])) # type: ignore
605
651
else :
606
652
# a single pydicom dataset object
607
653
if not isinstance (data , list ):
608
654
data = [data ]
609
- for d in data :
655
+ for i , d in enumerate ( data ) :
610
656
if hasattr (d , "SegmentSequence" ):
611
- data_array , metadata = self ._get_seg_data (d )
657
+ data_array , metadata = self ._get_seg_data (d , self . filenames [ i ] )
612
658
else :
613
- data_array = self ._get_array_data (d )
659
+ data_array = self ._get_array_data (d , self . filenames [ i ] )
614
660
metadata = self ._get_meta_dict (d )
615
661
metadata [MetaKeys .SPATIAL_SHAPE ] = data_array .shape
616
662
dicom_data .append ((data_array , metadata ))
617
663
664
+ # TODO: the actual type is list[np.ndarray | cp.ndarray]
665
+ # should figure out how to define correct types without having cupy not found error
666
+ # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918
618
667
img_array : list [np .ndarray ] = []
619
668
compatible_meta : dict = {}
620
669
621
670
for data_array , metadata in ensure_tuple (dicom_data ):
622
- img_array .append (np .ascontiguousarray (np .swapaxes (data_array , 0 , 1 ) if self .swap_ij else data_array ))
671
+ if self .swap_ij :
672
+ data_array = cp .swapaxes (data_array , 0 , 1 ) if self .to_gpu else np .swapaxes (data_array , 0 , 1 )
673
+ img_array .append (cp .ascontiguousarray (data_array ) if self .to_gpu else np .ascontiguousarray (data_array ))
623
674
affine = self ._get_affine (metadata , self .affine_lps_to_ras )
624
675
metadata [MetaKeys .SPACE ] = SpaceKeys .RAS if self .affine_lps_to_ras else SpaceKeys .LPS
625
676
if self .swap_ij :
@@ -641,7 +692,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]:
641
692
642
693
_copy_compatible_dict (metadata , compatible_meta )
643
694
644
- return _stack_images (img_array , compatible_meta ), compatible_meta
695
+ return _stack_images (img_array , compatible_meta , to_cupy = self . to_gpu ), compatible_meta
645
696
646
697
def _get_meta_dict (self , img ) -> dict :
647
698
"""
@@ -713,7 +764,7 @@ def _get_affine(self, metadata: dict, lps_to_ras: bool = True):
713
764
affine = orientation_ras_lps (affine )
714
765
return affine
715
766
716
- def _get_frame_data (self , img ) -> Iterator :
767
+ def _get_frame_data (self , img , filename , array_data ) -> Iterator :
717
768
"""
718
769
yield frames and description from the segmentation image.
719
770
This function is adapted from Highdicom:
@@ -751,48 +802,54 @@ def _get_frame_data(self, img) -> Iterator:
751
802
"""
752
803
753
804
if not hasattr (img , "PerFrameFunctionalGroupsSequence" ):
754
- raise NotImplementedError (
755
- f"To read dicom seg: { img .filename } , 'PerFrameFunctionalGroupsSequence' is required."
756
- )
805
+ raise NotImplementedError (f"To read dicom seg: { filename } , 'PerFrameFunctionalGroupsSequence' is required." )
757
806
758
807
frame_seg_nums = []
759
808
for f in img .PerFrameFunctionalGroupsSequence :
760
809
if not hasattr (f , "SegmentIdentificationSequence" ):
761
810
raise NotImplementedError (
762
- f"To read dicom seg: { img . filename } , 'SegmentIdentificationSequence' is required for each frame."
811
+ f"To read dicom seg: { filename } , 'SegmentIdentificationSequence' is required for each frame."
763
812
)
764
813
frame_seg_nums .append (int (f .SegmentIdentificationSequence [0 ].ReferencedSegmentNumber ))
765
814
766
- frame_seg_nums_arr = np .array (frame_seg_nums )
815
+ frame_seg_nums_arr = cp . array ( frame_seg_nums ) if self . to_gpu else np .array (frame_seg_nums )
767
816
768
817
seg_descriptions = {int (f .SegmentNumber ): f for f in img .SegmentSequence }
769
818
770
- for i in np .unique (frame_seg_nums_arr ):
771
- indices = np .where (frame_seg_nums_arr == i )[0 ]
772
- yield (img . pixel_array [indices , ...], seg_descriptions [i ])
819
+ for i in np .unique (frame_seg_nums_arr ) if not self . to_gpu else cp . unique ( frame_seg_nums_arr ) :
820
+ indices = np .where (frame_seg_nums_arr == i )[0 ] if not self . to_gpu else cp . where ( frame_seg_nums_arr == i )[ 0 ]
821
+ yield (array_data [indices , ...], seg_descriptions [i ])
773
822
774
- def _get_seg_data (self , img ):
823
+ def _get_seg_data (self , img , filename ):
775
824
"""
776
825
Get the array data and metadata of the segmentation image.
777
826
778
827
Aegs:
779
828
img: a Pydicom dataset object that has attribute "SegmentSequence".
829
+ filename: the file path of the image.
780
830
781
831
"""
782
832
783
833
metadata = self ._get_meta_dict (img )
784
834
n_classes = len (img .SegmentSequence )
785
- spatial_shape = list (img .pixel_array .shape )
835
+ array_data = self ._get_array_data (img , filename )
836
+ spatial_shape = list (array_data .shape )
786
837
spatial_shape [0 ] = spatial_shape [0 ] // n_classes
787
838
788
839
if self .label_dict is not None :
789
840
metadata ["labels" ] = self .label_dict
790
- all_segs = np .zeros ([* spatial_shape , len (self .label_dict )])
841
+ if self .to_gpu :
842
+ all_segs = cp .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
843
+ else :
844
+ all_segs = np .zeros ([* spatial_shape , len (self .label_dict )], dtype = array_data .dtype )
791
845
else :
792
846
metadata ["labels" ] = {}
793
- all_segs = np .zeros ([* spatial_shape , n_classes ])
847
+ if self .to_gpu :
848
+ all_segs = cp .zeros ([* spatial_shape , n_classes ], dtype = array_data .dtype )
849
+ else :
850
+ all_segs = np .zeros ([* spatial_shape , n_classes ], dtype = array_data .dtype )
794
851
795
- for i , (frames , description ) in enumerate (self ._get_frame_data (img )):
852
+ for i , (frames , description ) in enumerate (self ._get_frame_data (img , filename , array_data )):
796
853
segment_label = getattr (description , "SegmentLabel" , f"label_{ i } " )
797
854
class_name = getattr (description , "SegmentDescription" , segment_label )
798
855
if class_name not in metadata ["labels" ].keys ():
@@ -840,19 +897,79 @@ def _get_seg_data(self, img):
840
897
841
898
return all_segs , metadata
842
899
843
- def _get_array_data (self , img ):
900
+ def _get_array_data_from_gpu (self , img , filename ):
901
+ """
902
+ Get the raw array data of the image. This function is used when `to_gpu` is set to True.
903
+
904
+ Args:
905
+ img: a Pydicom dataset object.
906
+ filename: the file path of the image.
907
+
908
+ """
909
+ rows = getattr (img , "Rows" , None )
910
+ columns = getattr (img , "Columns" , None )
911
+ bits_allocated = getattr (img , "BitsAllocated" , None )
912
+ samples_per_pixel = getattr (img , "SamplesPerPixel" , 1 )
913
+ number_of_frames = getattr (img , "NumberOfFrames" , 1 )
914
+ pixel_representation = getattr (img , "PixelRepresentation" , 1 )
915
+
916
+ if rows is None or columns is None or bits_allocated is None :
917
+ warnings .warn (
918
+ f"dicom data: { filename } does not have Rows, Columns or BitsAllocated, falling back to CPU loading."
919
+ )
920
+
921
+ if not hasattr (img , "pixel_array" ):
922
+ raise ValueError (f"dicom data: { filename } does not have pixel_array." )
923
+ data = img .pixel_array
924
+
925
+ return data
926
+
927
+ if bits_allocated == 8 :
928
+ dtype = cp .int8 if pixel_representation == 1 else cp .uint8
929
+ elif bits_allocated == 16 :
930
+ dtype = cp .int16 if pixel_representation == 1 else cp .uint16
931
+ elif bits_allocated == 32 :
932
+ dtype = cp .int32 if pixel_representation == 1 else cp .uint32
933
+ else :
934
+ raise ValueError ("Unsupported BitsAllocated value" )
935
+
936
+ bytes_per_pixel = bits_allocated // 8
937
+ total_pixels = rows * columns * samples_per_pixel * number_of_frames
938
+ expected_pixel_data_length = total_pixels * bytes_per_pixel
939
+
940
+ pixel_data_tag = pydicom .tag .Tag (0x7FE0 , 0x0010 )
941
+ if pixel_data_tag not in img :
942
+ raise ValueError (f"dicom data: { filename } does not have pixel data." )
943
+
944
+ offset = img .get_item (pixel_data_tag , keep_deferred = True ).value_tell
945
+
946
+ with kvikio .CuFile (filename , "r" ) as f :
947
+ buffer = cp .empty (expected_pixel_data_length , dtype = cp .int8 )
948
+ f .read (buffer , expected_pixel_data_length , offset )
949
+
950
+ new_shape = (number_of_frames , rows , columns ) if number_of_frames > 1 else (rows , columns )
951
+ data = buffer .view (dtype ).reshape (new_shape )
952
+
953
+ return data
954
+
955
+ def _get_array_data (self , img , filename ):
844
956
"""
845
957
Get the array data of the image. If `RescaleSlope` and `RescaleIntercept` are available, the raw array data
846
- will be rescaled. The output data has the dtype np. float32 if the rescaling is applied.
958
+ will be rescaled. The output data has the dtype float32 if the rescaling is applied.
847
959
848
960
Args:
849
961
img: a Pydicom dataset object.
962
+ filename: the file path of the image.
850
963
851
964
"""
852
965
# process Dicom series
853
- if not hasattr (img , "pixel_array" ):
854
- raise ValueError (f"dicom data: { img .filename } does not have pixel_array." )
855
- data = img .pixel_array
966
+
967
+ if self .to_gpu :
968
+ data = self ._get_array_data_from_gpu (img , filename )
969
+ else :
970
+ if not hasattr (img , "pixel_array" ):
971
+ raise ValueError (f"dicom data: { filename } does not have pixel_array." )
972
+ data = img .pixel_array
856
973
857
974
slope , offset = 1.0 , 0.0
858
975
rescale_flag = False
@@ -862,8 +979,14 @@ def _get_array_data(self, img):
862
979
if hasattr (img , "RescaleIntercept" ):
863
980
offset = img .RescaleIntercept
864
981
rescale_flag = True
982
+
865
983
if rescale_flag :
866
- data = data .astype (np .float32 ) * slope + offset
984
+ if self .to_gpu :
985
+ slope = cp .asarray (slope , dtype = cp .float32 )
986
+ offset = cp .asarray (offset , dtype = cp .float32 )
987
+ data = data .astype (cp .float32 ) * slope + offset
988
+ else :
989
+ data = data .astype (np .float32 ) * slope + offset
867
990
868
991
return data
869
992
@@ -884,8 +1007,6 @@ class NibabelReader(ImageReader):
884
1007
Default is False. CuPy and Kvikio are required for this option.
885
1008
Note: For compressed NIfTI files, some operations may still be performed on CPU memory,
886
1009
and the acceleration may not be significant. In some cases, it may be slower than loading on CPU.
887
- In practical use, it's recommended to add a warm up call before the actual loading.
888
- A related tutorial will be prepared in the future, and the document will be updated accordingly.
889
1010
kwargs: additional args for `nibabel.load` API. more details about available args:
890
1011
https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py
891
1012
0 commit comments