49
49
from ..casting import (floor_log2 , type_info , OK_FLOATS , shared_range )
50
50
51
51
from ..deprecator import ExpiredDeprecationError
52
+ from ..optpkg import optional_package
52
53
53
54
from numpy .testing import (assert_array_almost_equal ,
54
55
assert_array_equal )
55
56
import pytest
56
57
57
58
from nibabel .testing import nullcontext , assert_dt_equal , assert_allclose_safely , suppress_warnings
58
59
60
+ pyzstd , HAVE_ZSTD , _ = optional_package ("pyzstd" )
61
+
59
62
#: convenience variables for numpy types
60
63
FLOAT_TYPES = np .sctypes ['float' ]
61
64
COMPLEX_TYPES = np .sctypes ['complex' ]
68
71
def test__is_compressed_fobj ():
69
72
# _is_compressed helper function
70
73
with InTemporaryDirectory ():
71
- for ext , opener , compressed in (('' , open , False ),
72
- ('.gz' , gzip .open , True ),
73
- ('.bz2' , BZ2File , True )):
74
+ file_openers = [('' , open , False ),
75
+ ('.gz' , gzip .open , True ),
76
+ ('.bz2' , BZ2File , True )]
77
+ if HAVE_ZSTD :
78
+ file_openers += [('.zst' , pyzstd .ZstdFile , True )]
79
+ for ext , opener , compressed in file_openers :
74
80
fname = 'test.bin' + ext
75
81
for mode in ('wb' , 'rb' ):
76
82
fobj = opener (fname , mode )
@@ -88,12 +94,15 @@ def make_array(n, bytes):
88
94
arr .flags .writeable = True
89
95
return arr
90
96
91
- # Check whether file, gzip file, bz2 file reread memory from cache
97
+ # Check whether file, gzip file, bz2, zst file reread memory from cache
92
98
fname = 'test.bin'
93
99
with InTemporaryDirectory ():
100
+ openers = [open , gzip .open , BZ2File ]
101
+ if HAVE_ZSTD :
102
+ openers += [pyzstd .ZstdFile ]
94
103
for n , opener in itertools .product (
95
104
(256 , 1024 , 2560 , 25600 ),
96
- ( open , gzip . open , BZ2File ) ):
105
+ openers ):
97
106
in_arr = np .arange (n , dtype = dtype )
98
107
# Write array to file
99
108
fobj_w = opener (fname , 'wb' )
@@ -230,7 +239,10 @@ def test_array_from_file_openers():
230
239
dtype = np .dtype (np .float32 )
231
240
in_arr = np .arange (24 , dtype = dtype ).reshape (shape )
232
241
with InTemporaryDirectory ():
233
- for ext , offset in itertools .product (('' , '.gz' , '.bz2' ),
242
+ extensions = ['' , '.gz' , '.bz2' ]
243
+ if HAVE_ZSTD :
244
+ extensions += ['.zst' ]
245
+ for ext , offset in itertools .product (extensions ,
234
246
(0 , 5 , 10 )):
235
247
fname = 'test.bin' + ext
236
248
with Opener (fname , 'wb' ) as out_buf :
@@ -251,9 +263,12 @@ def test_array_from_file_reread():
251
263
offset = 9
252
264
fname = 'test.bin'
253
265
with InTemporaryDirectory ():
266
+ openers = [open , gzip .open , bz2 .BZ2File , BytesIO ]
267
+ if HAVE_ZSTD :
268
+ openers += [pyzstd .ZstdFile ]
254
269
for shape , opener , dtt , order in itertools .product (
255
270
((64 ,), (64 , 65 ), (64 , 65 , 66 )),
256
- ( open , gzip . open , bz2 . BZ2File , BytesIO ) ,
271
+ openers ,
257
272
(np .int16 , np .float32 ),
258
273
('F' , 'C' )):
259
274
n_els = np .prod (shape )
@@ -901,7 +916,9 @@ def test_write_zeros():
901
916
def test_seek_tell ():
902
917
# Test seek tell routine
903
918
bio = BytesIO ()
904
- in_files = bio , 'test.bin' , 'test.gz' , 'test.bz2'
919
+ in_files = [bio , 'test.bin' , 'test.gz' , 'test.bz2' ]
920
+ if HAVE_ZSTD :
921
+ in_files += ['test.zst' ]
905
922
start = 10
906
923
end = 100
907
924
diff = end - start
@@ -920,9 +937,12 @@ def test_seek_tell():
920
937
fobj .write (b'\x01 ' * start )
921
938
assert fobj .tell () == start
922
939
# Files other than BZ2Files can seek forward on write, leaving
923
- # zeros in their wake. BZ2Files can't seek when writing, unless
924
- # we enable the write0 flag to seek_tell
925
- if not write0 and in_file == 'test.bz2' : # Can't seek write in bz2
940
+ # zeros in their wake. BZ2Files can't seek when writing,
941
+ # unless we enable the write0 flag to seek_tell
942
+ # ZstdFiles also does not support seek forward on write
943
+ if (not write0 and
944
+ (in_file == 'test.bz2' or
945
+ in_file == 'test.zst' )): # Can't seek write in bz2, zst
926
946
# write the zeros by hand for the read test below
927
947
fobj .write (b'\x00 ' * diff )
928
948
else :
@@ -946,7 +966,10 @@ def test_seek_tell():
946
966
# Check we have the expected written output
947
967
with ImageOpener (in_file , 'rb' ) as fobj :
948
968
assert fobj .read () == b'\x01 ' * start + b'\x00 ' * diff + b'\x02 ' * tail
949
- for in_file in ('test2.gz' , 'test2.bz2' ):
969
+ input_files = ['test2.gz' , 'test2.bz2' ]
970
+ if HAVE_ZSTD :
971
+ input_files += ['test2.zst' ]
972
+ for in_file in input_files :
950
973
# Check failure of write seek backwards
951
974
with ImageOpener (in_file , 'wb' ) as fobj :
952
975
fobj .write (b'g' * 10 )
0 commit comments