-
Notifications
You must be signed in to change notification settings - Fork 184
/
Copy pathstdlib_io_npy_load.fypp
533 lines (461 loc) · 18.4 KB
/
stdlib_io_npy_load.fypp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
! SPDX-Identifier: MIT
#:include "common.fypp"
#:set RANKS = range(1, MAXRANK + 1)
#:set KINDS_TYPES = REAL_KINDS_TYPES + INT_KINDS_TYPES + CMPLX_KINDS_TYPES
!> Implementation of loading npy files into multidimensional arrays
submodule (stdlib_io_npy) stdlib_io_npy_load
use stdlib_error, only : error_stop
use stdlib_strings, only : to_string, starts_with
implicit none
contains
#:for k1, t1 in KINDS_TYPES
#:for rank in RANKS
!> Load a ${rank}$-dimensional array from a npy file.
!>
!> The file is opened read-only as an unformatted stream, its header is
!> decoded with `get_descriptor`, and the stored type and rank are compared
!> against the type and rank of `array` before the data is read.
!> If `iostat` is absent, any failure terminates the program via `error_stop`.
module subroutine load_npy_${t1[0]}$${k1}$_${rank}$(filename, array, iostat, iomsg)
    !> Name of the npy file to load from
    character(len=*), intent(in) :: filename
    !> Array to be loaded from the npy file
    ${t1}$, allocatable, intent(out) :: array${ranksuffix(rank)}$
    !> Error status of loading, zero on success
    integer, intent(out), optional :: iostat
    !> Associated error message in case of non-zero status code
    character(len=:), allocatable, intent(out), optional :: iomsg

    !> Type string expected in the `descr` field for this specialization
    character(len=*), parameter :: vtype = type_${t1[0]}$${k1}$
    !> Rank expected in the `shape` field for this specialization
    integer, parameter :: rank = ${rank}$

    integer :: io, stat
    character(len=:), allocatable :: msg

    ! status="old" prevents a missing file from being created as a side effect,
    ! action="read" allows loading from files without write permission.
    open(newunit=io, file=filename, form="unformatted", access="stream", &
        & status="old", action="read", iostat=stat)
    if (stat /= 0) then
        ! The unit is not connected, neither read from nor close it
        msg = "File could not be opened for reading"
    else
        catch: block
            character(len=:), allocatable :: this_type
            integer, allocatable :: vshape(:)

            call get_descriptor(io, filename, this_type, vshape, stat, msg)
            if (stat /= 0) exit catch

            if (this_type /= vtype) then
                stat = 1
                msg = "File '"//filename//"' contains data of type '"//this_type//"', "//&
                    & "but expected '"//vtype//"'"
                exit catch
            end if

            if (size(vshape) /= rank) then
                stat = 1
                msg = "File '"//filename//"' contains data of rank "//&
                    & to_string(size(vshape))//", but expected "//&
                    & to_string(rank)
                exit catch
            end if

            call allocator(array, vshape, stat)
            if (stat /= 0) then
                msg = "Failed to allocate array of type '"//vtype//"' "//&
                    & "with total size of "//to_string(product(vshape))
                exit catch
            end if

            read(io, iostat=stat) array
        end block catch
        close(io)
    end if

    if (present(iostat)) then
        iostat = stat
    else if (stat /= 0) then
        if (allocated(msg)) then
            call error_stop("Failed to read array from file '"//filename//"'"//nl//&
                & msg)
        else
            call error_stop("Failed to read array from file '"//filename//"'")
        end if
    end if

    if (present(iomsg).and.allocated(msg)) call move_alloc(msg, iomsg)
contains

    !> Wrapped intrinsic allocate to create an allocation from a shape array
    subroutine allocator(array, vshape, stat)
        !> Instance of the array to be allocated
        ${t1}$, allocatable, intent(out) :: array${ranksuffix(rank)}$
        !> Dimensions to allocate for
        integer, intent(in) :: vshape(:)
        !> Status of allocate
        integer, intent(out) :: stat

        ! The extents are spelled out one by one by the preprocessor because
        ! allocate needs one extent per dimension
        allocate(array( &
        #:for i in range(rank-1)
        &    vshape(${i+1}$), &
        #:endfor
        &    vshape(${rank}$)), &
        &    stat=stat)
    end subroutine allocator

end subroutine load_npy_${t1[0]}$${k1}$_${rank}$
#:endfor
#:endfor
!> Read the npy header from a binary file and retrieve the descriptor string.
!>
!> Reads the eight byte preamble, the little-endian header length (two bytes
!> for major version 1, four bytes otherwise) and the dictionary that follows,
!> then hands the dictionary to `parse_descriptor`. When the file is not in
!> Fortran order the returned shape is reversed.
subroutine get_descriptor(io, filename, vtype, vshape, stat, msg)
    !> Unformatted, stream accessed unit
    integer, intent(in) :: io
    !> Filename for error reporting
    character(len=*), intent(in) :: filename
    !> Type of data saved in npy file
    character(len=:), allocatable, intent(out) :: vtype
    !> Shape descriptor of the data saved in npy file
    integer, allocatable, intent(out) :: vshape(:)
    !> Status of operation
    integer, intent(out) :: stat
    !> Associated error message in case of non-zero status
    character(len=:), allocatable, intent(out) :: msg

    integer :: major, header_len, i
    character(len=:), allocatable :: dict
    character(len=8) :: header
    character :: buf(4)
    logical :: fortran_order

    read(io, iostat=stat) header
    if (stat /= 0) return

    call parse_header(header, major, stat, msg)
    if (stat /= 0) return

    read(io, iostat=stat) buf(1:merge(4, 2, major > 1))
    if (stat /= 0) return

    if (major > 1) then
        ! Reject lengths whose most significant byte would overflow the integer
        if (ichar(buf(4)) > huge(header_len) / 256**3) then
            stat = 1
            msg = "Descriptor length exceeds the supported range"
            return
        end if
        ! Four byte little-endian length: byte k carries the weight 256**(k-1)
        header_len = ichar(buf(1)) &
            &      + ichar(buf(2)) * 256 &
            &      + ichar(buf(3)) * 256**2 &
            &      + ichar(buf(4)) * 256**3
    else
        ! Two byte little-endian length
        header_len = ichar(buf(1)) &
            &      + ichar(buf(2)) * 256
    end if

    ! An empty descriptor cannot hold the terminating newline checked below
    if (header_len < 1) then
        stat = 1
        msg = "Descriptor length does not match"
        return
    end if

    allocate(character(header_len) :: dict, stat=stat)
    if (stat /= 0) return

    read(io, iostat=stat) dict
    if (stat /= 0) return

    if (dict(header_len:header_len) /= nl) then
        stat = 1
        msg = "Descriptor length does not match"
        return
    end if

    if (scan(dict, achar(0)) > 0) then
        stat = 1
        msg = "Nul byte not allowed in descriptor string"
        return
    end if

    ! Drop the terminating newline and the blank padding in front of it
    call parse_descriptor(trim(dict(:len(dict)-1)), filename, &
        & vtype, fortran_order, vshape, stat, msg)
    if (stat /= 0) return

    if (.not.fortran_order) then
        vshape = [(vshape(i), i = size(vshape), 1, -1)]
    end if
end subroutine get_descriptor
!> Parse the first eight bytes of the npy header to verify the data.
!>
!> Byte 1 is compared with `magic_number`, bytes 2-6 with `magic_string`,
!> byte 7 is the major and byte 8 the minor format version. Major versions
!> 1, 2 and 3 with minor version 0 are accepted.
subroutine parse_header(header, major, stat, msg)
    !> Header of the binary file
    character(len=*), intent(in) :: header
    !> Major version of the npy format
    integer, intent(out) :: major
    !> Status of operation
    integer, intent(out) :: stat
    !> Associated error message in case of non-zero status
    character(len=:), allocatable, intent(out) :: msg

    integer :: minor

    ! stat is intent(out): it must be defined on the success path as well,
    ! otherwise the caller tests an undefined value
    stat = 0

    if (header(1:1) /= magic_number) then
        stat = 1
        msg = "Expected z'93' but got z'"//to_string(ichar(header(1:1)))//"' "//&
            & "as first byte"
        return
    end if

    if (header(2:6) /= magic_string) then
        stat = 1
        msg = "Expected identifier '"//magic_string//"'"
        return
    end if

    major = ichar(header(7:7))
    if (.not.any(major == [1, 2, 3])) then
        stat = 1
        msg = "Unsupported format major version number '"//to_string(major)//"'"
        return
    end if

    minor = ichar(header(8:8))
    if (minor /= 0) then
        stat = 1
        msg = "Unsupported format version "// &
            & "'"//to_string(major)//"."//to_string(minor)//"'"
        return
    end if
end subroutine parse_header
!> Parse the descriptor in the npy header. This routine implements a minimal
!> non-recursive parser for serialized Python dictionaries.
!>
!> Only the keys `descr`, `fortran_order` and `shape` are accepted. A key that
!> appears twice, an unknown key, or a missing key yields a non-zero `stat`
!> together with a message pointing at the offending part of `input`.
subroutine parse_descriptor(input, filename, vtype, fortran_order, vshape, stat, msg)
    !> Input string to parse as descriptor
    character(len=*), intent(in) :: input
    !> Filename for error reporting
    character(len=*), intent(in) :: filename
    !> Type of the data stored, retrieved from field `descr`
    character(len=:), allocatable, intent(out) :: vtype
    !> Whether the data is in left layout, retrieved from field `fortran_order`
    logical, intent(out) :: fortran_order
    !> Shape of the stored data, retrieved from field `shape`
    integer, allocatable, intent(out) :: vshape(:)
    !> Status of operation
    integer, intent(out) :: stat
    !> Associated error message in case of non-zero status
    character(len=:), allocatable, intent(out) :: msg

    !> Kinds of token produced by `get_token`
    enum, bind(c)
        enumerator :: invalid, string, lbrace, rbrace, comma, colon, &
            lparen, rparen, bool, literal, space
    end enum

    !> Token given by its first and last offset in the input and its kind
    type :: token_type
        integer :: first, last, kind
    end type token_type

    integer :: pos
    character(len=:), allocatable :: key
    type(token_type) :: token, last
    logical :: has_descr, has_shape, has_fortran_order

    has_descr = .false.
    has_shape = .false.
    has_fortran_order = .false.
    pos = 0
    call next_token(input, pos, token, [lbrace], stat, msg)
    if (stat /= 0) return

    ! Start as if a comma had just been read: a key may follow, a comma may not
    last = token_type(pos, pos, comma)
    entries: do while (pos < len(input))
        call get_token(input, pos, token)
        select case(token%kind)
        case(space)
            continue
        case(comma)
            if (token%kind == last%kind) then
                stat = 1
                msg = make_message(filename, input, token%first, token%last, &
                    & "Comma cannot appear at this point")
                return
            end if
            last = token
        case(rbrace)
            exit entries
        case(string)
            if (token%kind == last%kind) then
                stat = 1
                msg = make_message(filename, input, token%first, token%last, &
                    & "String cannot appear at this point")
                return
            end if
            last = token

            ! Strip the surrounding quotes
            key = input(token%first+1:token%last-1)
            call next_token(input, pos, token, [colon], stat, msg)
            if (stat /= 0) return

            if ((key == "descr" .and. has_descr) &
                & .or. (key == "fortran_order" .and. has_fortran_order) &
                & .or. (key == "shape" .and. has_shape)) then
                stat = 1
                msg = make_message(filename, input, last%first, last%last, &
                    & "Duplicate entry for '"//key//"' found")
                return
            end if

            select case(key)
            case("descr")
                call next_token(input, pos, token, [string], stat, msg)
                if (stat /= 0) return

                vtype = input(token%first+1:token%last-1)
                has_descr = .true.

            case("fortran_order")
                call next_token(input, pos, token, [bool], stat, msg)
                if (stat /= 0) return

                fortran_order = input(token%first:token%last) == "True"
                has_fortran_order = .true.

            case("shape")
                call parse_tuple(input, pos, vshape, stat, msg)
                ! Without this check a later successful call would reset stat
                ! and the malformed shape would go unnoticed
                if (stat /= 0) return

                has_shape = .true.

            case default
                stat = 1
                msg = make_message(filename, input, last%first, last%last, &
                    & "Invalid entry '"//key//"' in dictionary encountered")
                return
            end select
        case default
            stat = 1
            msg = make_message(filename, input, token%first, token%last, &
                & "Invalid token encountered")
            return
        end select
    end do entries

    ! All checks are run, the message of the last failing one is reported
    if (.not.has_descr) then
        stat = 1
        msg = make_message(filename, input, 1, pos, &
            & "Dictionary does not contain required entry 'descr'")
    end if

    if (.not.has_shape) then
        stat = 1
        msg = make_message(filename, input, 1, pos, &
            & "Dictionary does not contain required entry 'shape'")
    end if

    if (.not.has_fortran_order) then
        stat = 1
        msg = make_message(filename, input, 1, pos, &
            & "Dictionary does not contain required entry 'fortran_order'")
    end if

contains

    !> Create an error message which shows the input and marks the
    !> offsets `first` to `last` with carets
    function make_message(filename, input, first, last, message) result(str)
        !> Filename for context
        character(len=*), intent(in) :: filename
        !> Input string to parse
        character(len=*), intent(in) :: input
        !> Offset in the input
        integer, intent(in) :: first, last
        !> Error message
        character(len=*), intent(in) :: message
        !> Final output message
        character(len=:), allocatable :: str

        character(len=*), parameter :: nl = new_line('a')

        str = message // nl // &
            & " --> " // filename // ":1:" // to_string(first) // "-" // to_string(last) // nl // &
            & "  |" // nl // &
            & "1 | " // input // nl // &
            & "  |" // repeat(" ", first) // repeat("^", last - first + 1) // nl // &
            & "  |"
    end function make_message

    !> Parse a tuple of integers into an array of integers
    subroutine parse_tuple(input, pos, tuple, stat, msg)
        !> Input string to parse
        character(len=*), intent(in) :: input
        !> Offset in the input, will be advanced after reading
        integer, intent(inout) :: pos
        !> Array representing tuple of integers
        integer, allocatable, intent(out) :: tuple(:)
        !> Status of operation
        integer, intent(out) :: stat
        !> Associated error message in case of non-zero status
        character(len=:), allocatable, intent(out) :: msg

        type(token_type) :: token
        integer :: last, itmp
        logical :: closed

        allocate(tuple(0), stat=stat)
        if (stat /= 0) return

        call next_token(input, pos, token, [lparen], stat, msg)
        if (stat /= 0) return

        closed = .false.
        ! Start as if a comma had just been read: a number may follow, a comma may not
        last = comma
        items: do while (pos < len(input))
            call get_token(input, pos, token)
            select case(token%kind)
            case(space)
                continue
            case(literal)
                if (token%kind == last) then
                    stat = 1
                    msg = make_message(filename, input, token%first, token%last, &
                        & "Invalid token encountered")
                    return
                end if
                last = token%kind
                read(input(token%first:token%last), *, iostat=stat) itmp
                if (stat /= 0) then
                    msg = make_message(filename, input, token%first, token%last, &
                        & "Invalid integer literal encountered")
                    return
                end if
                tuple = [tuple, itmp]
            case(comma)
                if (token%kind == last) then
                    stat = 1
                    msg = make_message(filename, input, token%first, token%last, &
                        & "Invalid token encountered")
                    return
                end if
                last = token%kind
            case(rparen)
                closed = .true.
                exit items
            case default
                stat = 1
                msg = make_message(filename, input, token%first, token%last, &
                    & "Invalid token encountered")
                return
            end select
        end do items

        ! The input ended before the closing parenthesis was found
        if (.not.closed) then
            stat = 1
            msg = make_message(filename, input, pos, pos, &
                & "Tuple is not terminated by ')'")
        end if
    end subroutine parse_tuple

    !> Get the next allowed token, skipping over whitespace.
    !> Fails if the next non-space token is not allowed or the input ends first.
    subroutine next_token(input, pos, token, allowed_token, stat, msg)
        !> Input string to parse
        character(len=*), intent(in) :: input
        !> Current offset in the input string
        integer, intent(inout) :: pos
        !> Last token parsed
        type(token_type), intent(out) :: token
        !> Tokens allowed in the current context
        integer, intent(in) :: allowed_token(:)
        !> Status of operation
        integer, intent(out) :: stat
        !> Associated error message in case of non-zero status
        character(len=:), allocatable, intent(out) :: msg

        ! Assume failure until an allowed token has been found
        stat = 1

        do while (pos < len(input))
            call get_token(input, pos, token)
            if (token%kind == space) cycle

            if (any(token%kind == allowed_token)) then
                stat = 0
            else
                msg = make_message(filename, input, token%first, token%last, &
                    & "Invalid token encountered")
            end if
            return
        end do

        ! Only reached if no token other than whitespace was left in the input
        msg = make_message(filename, input, len(input), len(input), &
            & "Unexpected end of input")
    end subroutine next_token

    !> Tokenize input string.
    !> The caller has to ensure `pos < len(input)` on entry; on return `pos`
    !> is the offset of the last character of the token and never exceeds
    !> `len(input)`.
    subroutine get_token(input, pos, token)
        !> Input string to tokenize
        character(len=*), intent(in) :: input
        !> Offset in input string, will be advanced
        integer, intent(inout) :: pos
        !> Returned token from the next position
        type(token_type), intent(out) :: token

        character(len=*), parameter :: digits = "0123456789"
        character :: quote
        integer :: offset, stop_at

        pos = pos + 1
        select case(input(pos:pos))
        case("""", "'")
            quote = input(pos:pos)
            offset = index(input(pos+1:), quote)
            if (offset > 0) then
                token = token_type(pos, pos + offset, string)
                pos = pos + offset
            else
                ! Unterminated string: flag the remaining input as invalid
                ! rather than leaving the end of the token undefined
                token = token_type(pos, len(input), invalid)
                pos = len(input)
            end if
        case("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")
            ! Offset of the first character which is not a digit, zero if
            ! the digits run up to the end of the input
            offset = verify(input(pos:), digits)
            if (offset > 0) then
                stop_at = pos + offset - 2
            else
                stop_at = len(input)
            end if
            token = token_type(pos, stop_at, literal)
            pos = stop_at
        case("T")
            if (starts_with(input(pos:), "True")) then
                token = token_type(pos, pos+3, bool)
                pos = pos + 3
            else
                token = token_type(pos, pos, invalid)
            end if
        case("F")
            if (starts_with(input(pos:), "False")) then
                token = token_type(pos, pos+4, bool)
                pos = pos + 4
            else
                token = token_type(pos, pos, invalid)
            end if
        case("{")
            token = token_type(pos, pos, lbrace)
        case("}")
            token = token_type(pos, pos, rbrace)
        case(",")
            token = token_type(pos, pos, comma)
        case(":")
            token = token_type(pos, pos, colon)
        case("(")
            token = token_type(pos, pos, lparen)
        case(")")
            token = token_type(pos, pos, rparen)
        case(" ", nl)
            token = token_type(pos, pos, space)
        case default
            token = token_type(pos, pos, invalid)
        end select
    end subroutine get_token

end subroutine parse_descriptor
end submodule stdlib_io_npy_load