@@ -142,22 +142,65 @@ function parse_cat_ast(ex::Expr)
142
142
cat_any (Val (maxdim), Val (catdim), nargs)
143
143
end
144
144
145
+ #=
146
+ For example,
147
+ * `@SArray rand(2, 3, 4)`
148
+ * `@SArray rand(rng, 3, 4)`
149
+ will be expanded to the following.
150
+ * `_rand_with_Val(SArray, 2, 3, _int2val(2), _int2val(3), Val((4,)))`
151
+ * `_rand_with_Val(SArray, 2, 3, _int2val(rng), _int2val(3), Val((4,)))`
152
+ The function `_int2val` is required to avoid the following case.
153
+ * `_rand_with_Val(SArray, 2, 3, Val(2), Val(3), Val((4,)))`
154
+ * `_rand_with_Val(SArray, 2, 3, Val(rng), Val(3), Val((4,)))`
155
+ Mutable object such as `rng` cannot be type parameter, and `Val(rng)` throws an error.
156
+ =#
157
+ _int2val (x:: Int ) = Val (x)
158
+ _int2val (:: Any ) = nothing
159
+ # @SArray zeros(...)
160
+ _zeros_with_Val (:: Type{SA} , :: Int , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = zeros (SA{Tuple{n1, ns... }})
161
+ _zeros_with_Val (:: Type{SA} , T:: DataType , :: Val , :: Val{ns} ) where {SA, ns} = zeros (SA{Tuple{ns... }, T})
162
+ # @SArray ones(...)
163
+ _ones_with_Val (:: Type{SA} , :: Int , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = ones (SA{Tuple{n1, ns... }})
164
+ _ones_with_Val (:: Type{SA} , T:: DataType , :: Val , :: Val{ns} ) where {SA, ns} = ones (SA{Tuple{ns... }, T})
165
+ # @SArray rand(...)
166
+ _rand_with_Val (:: Type{SA} , :: Int , :: Int , :: Val{n1} , :: Val{n2} , :: Val{ns} ) where {SA, n1, n2, ns} = rand (SA{Tuple{n1,n2,ns... }})
167
+ _rand_with_Val (:: Type{SA} , T:: DataType , :: Int , :: Nothing , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = _rand (Random. GLOBAL_RNG, T, Size (n1, ns... ), SA{Tuple{n1, ns... }, T})
168
+ _rand_with_Val (:: Type{SA} , sampler, :: Int , :: Nothing , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = _rand (Random. GLOBAL_RNG, sampler, Size (n1, ns... ), SA{Tuple{n1, ns... }, Random. gentype (sampler)})
169
+ _rand_with_Val (:: Type{SA} , rng:: AbstractRNG , :: Int , :: Nothing , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = _rand (rng, Float64, Size (n1, ns... ), SA{Tuple{n1, ns... }, Float64})
170
+ _rand_with_Val (:: Type{SA} , rng:: AbstractRNG , T:: DataType , :: Nothing , :: Nothing , :: Val{ns} ) where {SA, ns} = _rand (rng, T, Size (ns... ), SA{Tuple{ns... }, T})
171
+ _rand_with_Val (:: Type{SA} , rng:: AbstractRNG , sampler, :: Nothing , :: Nothing , :: Val{ns} ) where {SA, ns} = _rand (rng, sampler, Size (ns... ), SA{Tuple{ns... }, Random. gentype (sampler)})
172
+ # @SArray randn(...)
173
+ _randn_with_Val (:: Type{SA} , :: Int , :: Int , :: Val{n1} , :: Val{n2} , :: Val{ns} ) where {SA, n1, n2, ns} = randn (SA{Tuple{n1,n2,ns... }})
174
+ _randn_with_Val (:: Type{SA} , T:: DataType , :: Int , :: Nothing , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = _randn (Random. GLOBAL_RNG, Size (n1, ns... ), SA{Tuple{n1, ns... }, T})
175
+ _randn_with_Val (:: Type{SA} , rng:: AbstractRNG , :: Int , :: Nothing , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = _randn (rng, Size (n1, ns... ), SA{Tuple{n1, ns... }, Float64})
176
+ _randn_with_Val (:: Type{SA} , rng:: AbstractRNG , T:: DataType , :: Nothing , :: Nothing , :: Val{ns} ) where {SA, ns} = _randn (rng, Size (ns... ), SA{Tuple{ns... }, T})
177
+ # @SArray randexp(...)
178
+ _randexp_with_Val (:: Type{SA} , :: Int , :: Int , :: Val{n1} , :: Val{n2} , :: Val{ns} ) where {SA, n1, n2, ns} = randexp (SA{Tuple{n1,n2,ns... }})
179
+ _randexp_with_Val (:: Type{SA} , T:: DataType , :: Int , :: Nothing , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = _randexp (Random. GLOBAL_RNG, Size (n1, ns... ), SA{Tuple{n1, ns... }, T})
180
+ _randexp_with_Val (:: Type{SA} , rng:: AbstractRNG , :: Int , :: Nothing , :: Val{n1} , :: Val{ns} ) where {SA, n1, ns} = _randexp (rng, Size (n1, ns... ), SA{Tuple{n1, ns... }, Float64})
181
+ _randexp_with_Val (:: Type{SA} , rng:: AbstractRNG , T:: DataType , :: Nothing , :: Nothing , :: Val{ns} ) where {SA, ns} = _randexp (rng, Size (ns... ), SA{Tuple{ns... }, T})
182
+
145
183
escall (args) = Iterators. map (esc, args)
184
+ function _isnonnegvec (args)
185
+ length (args) == 0 && return false
186
+ all (isa .(args, Integer)) && return all (args .≥ 0 )
187
+ return false
188
+ end
146
189
function static_array_gen (:: Type{SA} , @nospecialize (ex), mod:: Module ) where {SA}
147
190
if ! isa (ex, Expr)
148
191
error (" Bad input for @$SA " )
149
192
end
150
193
head = ex. head
151
194
if head === :vect # vector
152
- return :($ SA {$ Tuple{$(length(ex.args))}} ($ tuple ($ (escall (ex. args)... ))))
195
+ return :($ SA {Tuple{$(length(ex.args))}} ($ tuple ($ (escall (ex. args)... ))))
153
196
elseif head === :ref # typed, vector
154
- return :($ SA {$ Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))} ($ tuple ($ (escall (ex. args[2 : end ])... ))))
197
+ return :($ SA {Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))} ($ tuple ($ (escall (ex. args[2 : end ])... ))))
155
198
elseif head === :typed_vcat || head === :typed_hcat || head === :typed_ncat # typed, cat
156
199
args = parse_cat_ast (ex)
157
- return :($ SA {$ Tuple{$(size(args)...)},$(esc(ex.args[1]))} ($ tuple ($ (escall (args)... ))))
200
+ return :($ SA {Tuple{$(size(args)...)},$(esc(ex.args[1]))} ($ tuple ($ (escall (args)... ))))
158
201
elseif head === :vcat || head === :hcat || head === :ncat # untyped, cat
159
202
args = parse_cat_ast (ex)
160
- return :($ SA {$ Tuple{$(size(args)...)}} ($ tuple ($ (escall (args)... ))))
203
+ return :($ SA {Tuple{$(size(args)...)}} ($ tuple ($ (escall (args)... ))))
161
204
elseif head === :comprehension
162
205
if length (ex. args) != 1
163
206
error (" Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]" )
@@ -173,7 +216,7 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
173
216
return quote
174
217
let
175
218
f ($ (escall (rng_args). .. )) = $ (esc (ex. args[1 ]))
176
- $ SA {$ Tuple{$(size(exprs)...)}} ($ tuple ($ (exprs... )))
219
+ $ SA {Tuple{$(size(exprs)...)}} ($ tuple ($ (exprs... )))
177
220
end
178
221
end
179
222
elseif head === :typed_comprehension
@@ -192,26 +235,58 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
192
235
return quote
193
236
let
194
237
f ($ (escall (rng_args). .. )) = $ (esc (ex. args[1 ]))
195
- $ SA {$ Tuple{$(size(exprs)...)},$T} ($ tuple ($ (exprs... )))
238
+ $ SA {Tuple{$(size(exprs)...)},$T} ($ tuple ($ (exprs... )))
196
239
end
197
240
end
198
241
elseif head === :call
199
242
f = ex. args[1 ]
200
- if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
201
- if length (ex. args) == 1
202
- f === :zeros || f === :ones || error (" @$SA got bad expression: $(ex) " )
203
- return :($ f ($ SA{$ Tuple{},$ Float64}))
204
- end
205
- return quote
206
- if isa ($ (esc (ex. args[2 ])), DataType)
207
- $ f ($ SA{$ Tuple{$ (escall (ex. args[3 : end ])... )},$ (esc (ex. args[2 ]))})
208
- else
209
- $ f ($ SA{$ Tuple{$ (escall (ex. args[2 : end ])... )}})
210
- end
243
+ fargs = ex. args[2 : end ]
244
+ if f === :zeros || f === :ones
245
+ _f_with_Val = Symbol (:_ , f, :_with_Val )
246
+ if length (fargs) == 0
247
+ # for calls like `zeros()`
248
+ return :($ f ($ SA{Tuple{},$ Float64}))
249
+ elseif _isnonnegvec (fargs)
250
+ # for calls like `zeros(dims...)`
251
+ return :($ f ($ SA{Tuple{$ (escall (fargs)... )}}))
252
+ else
253
+ # for calls like `zeros(type)`
254
+ # for calls like `zeros(type, dims...)`
255
+ return :($ _f_with_Val ($ SA, $ (esc (fargs[1 ])), Val ($ (esc (fargs[1 ]))), Val (tuple ($ (escall (fargs[2 : end ])... )))))
211
256
end
212
257
elseif f === :fill
213
- length (ex. args) == 1 && error (" @$SA got bad expression: $(ex) " )
214
- return :($ f ($ (esc (ex. args[2 ])), $ SA{$ Tuple{$ (escall (ex. args[3 : end ])... )}}))
258
+ # for calls like `fill(value, dims...)`
259
+ return :($ f ($ (esc (fargs[1 ])), $ SA{Tuple{$ (escall (fargs[2 : end ])... )}}))
260
+ elseif f === :rand || f === :randn || f === :randexp
261
+ _f_with_Val = Symbol (:_ , f, :_with_Val )
262
+ if length (fargs) == 0
263
+ # No support for `@SArray rand()`
264
+ error (" @$SA got bad expression: $(ex) " )
265
+ elseif _isnonnegvec (fargs)
266
+ # for calls like `rand(dims...)`
267
+ return :($ f ($ SA{Tuple{$ (escall (fargs)... )}}))
268
+ elseif length (fargs) ≥ 2
269
+ # for calls like `rand(dim1, dim2, dims...)`
270
+ # for calls like `rand(type, dim1, dims...)`
271
+ # for calls like `rand(sampler, dim1, dims...)`
272
+ # for calls like `rand(rng, dim1, dims...)`
273
+ # for calls like `rand(rng, type, dims...)`
274
+ # for calls like `rand(rng, sampler, dims...)`
275
+ # for calls like `randn(dim1, dim2, dims...)`
276
+ # for calls like `randn(type, dim1, dims...)`
277
+ # for calls like `randn(rng, dim1, dims...)`
278
+ # for calls like `randn(rng, type, dims...)`
279
+ # for calls like `randexp(dim1, dim2, dims...)`
280
+ # for calls like `randexp(type, dim1, dims...)`
281
+ # for calls like `randexp(rng, dim1, dims...)`
282
+ # for calls like `randexp(rng, type, dims...)`
283
+ return :($ _f_with_Val ($ SA, $ (esc (fargs[1 ])), $ (esc (fargs[2 ])), _int2val ($ (esc (fargs[1 ]))), _int2val ($ (esc (fargs[2 ]))), Val (tuple ($ (escall (fargs[3 : end ])... )))))
284
+ elseif length (fargs) == 1
285
+ # for calls like `rand(dim)`
286
+ return :($ f ($ SA{Tuple{$ (escall (fargs)... )}}))
287
+ else
288
+ error (" @$SA got bad expression: $(ex) " )
289
+ end
215
290
else
216
291
error (" @$SA only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions." )
217
292
end
0 commit comments