Skip to content

Commit e2d772f

Browse files
Extend macros with rand to support custom samplers (#1210)
* Extend macros with rand to support custom samplers * Fix tests * Update Project.toml Co-authored-by: Yuto Horikawa <[email protected]> * Update test/arraymath.jl Co-authored-by: Yuto Horikawa <[email protected]> * Update src/SVector.jl Co-authored-by: Yuto Horikawa <[email protected]> * Update src/SMatrix.jl Co-authored-by: Yuto Horikawa <[email protected]> * Update `SArray` macro * fix reported issue; support rng in SArray and SMatrix * Code suggestions for #1210 (#1213) * move `ex.args[2] isa Integer` * split `if` block * simplify :zeros and :ones * refactor :rand * refactor :randn and :randexp * update comments * add _isnonnegvec * update with `_isnonnegvec` * add `_isnonnegvec(args, n)` method to check the size of `args` * fix `@SArray` for `@SArray rand(rng,T,dim)` etc. * update comments * update `@SVector` macro * update `@SMatrix` * update `@SVector` * update `@SArray` * introduce `fargs` variable * avoid `_isnonnegvec` in `static_matrix_gen` * avoid `_isnonnegvec` in `static_vector_gen` * remove unnecessary `_isnonnegvec` * add `_rng()` function * update tests on `@SVector` macro * update tests on `@MVector` macro * organize test/MMatrix.jl and test/SMatrix.jl * organize test/MMatrix.jl and test/SMatrix.jl * update with broken tests * organize test/MMatrix.jl and test/SMatrix.jl for `rand*` functions * fix around `broken` key for `@test` macro * fix zero-length tests * update `test/SArray.jl` to match `test/MArray.jl` * update tests for `@SArray ones` etc. * add supports for `@SArray ones(3-1,2)` etc. * move block for `fill` * update macro `@SArray rand(rng,2,3)` to use ordinary dispatches * update around `@SArray randn` etc. * remove unnecessary dollars * simplify `@SArray fill` * add `@testset "expand_error"` * update tests for `@SArray rand(...)` etc. * fix bug in `rand*_with_Val` * cleanup tests * update macro `@SMatrix rand(rng,2,3)` to use ordinary dispatches * update macro `@SVector rand(rng,3)` to use ordinary dispatches * move block for `fill` * simplify `_randexp_with_Val` --------- Co-authored-by: Yuto Horikawa <[email protected]>
1 parent 3fd8fb9 commit e2d772f

13 files changed

+826
-179
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.8.2"
3+
version = "1.9.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/SArray.jl

+94-19
Original file line numberDiff line numberDiff line change
@@ -142,22 +142,65 @@ function parse_cat_ast(ex::Expr)
142142
cat_any(Val(maxdim), Val(catdim), nargs)
143143
end
144144

145+
#=
146+
For example,
147+
* `@SArray rand(2, 3, 4)`
148+
* `@SArray rand(rng, 3, 4)`
149+
will be expanded to the following.
150+
* `_rand_with_Val(SArray, 2, 3, _int2val(2), _int2val(3), Val((4,)))`
151+
* `_rand_with_Val(SArray, 2, 3, _int2val(rng), _int2val(3), Val((4,)))`
152+
The function `_int2val` is required to avoid the following case.
153+
* `_rand_with_Val(SArray, 2, 3, Val(2), Val(3), Val((4,)))`
154+
* `_rand_with_Val(SArray, 2, 3, Val(rng), Val(3), Val((4,)))`
155+
Mutable object such as `rng` cannot be type parameter, and `Val(rng)` throws an error.
156+
=#
157+
_int2val(x::Int) = Val(x)
158+
_int2val(::Any) = nothing
159+
# @SArray zeros(...)
160+
_zeros_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = zeros(SA{Tuple{n1, ns...}})
161+
_zeros_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = zeros(SA{Tuple{ns...}, T})
162+
# @SArray ones(...)
163+
_ones_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = ones(SA{Tuple{n1, ns...}})
164+
_ones_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = ones(SA{Tuple{ns...}, T})
165+
# @SArray rand(...)
166+
_rand_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = rand(SA{Tuple{n1,n2,ns...}})
167+
_rand_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, T, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
168+
_rand_with_Val(::Type{SA}, sampler, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, ns...), SA{Tuple{n1, ns...}, Random.gentype(sampler)})
169+
_rand_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(rng, Float64, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
170+
_rand_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, T, Size(ns...), SA{Tuple{ns...}, T})
171+
_rand_with_Val(::Type{SA}, rng::AbstractRNG, sampler, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, sampler, Size(ns...), SA{Tuple{ns...}, Random.gentype(sampler)})
172+
# @SArray randn(...)
173+
_randn_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randn(SA{Tuple{n1,n2,ns...}})
174+
_randn_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
175+
_randn_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
176+
_randn_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randn(rng, Size(ns...), SA{Tuple{ns...}, T})
177+
# @SArray randexp(...)
178+
_randexp_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randexp(SA{Tuple{n1,n2,ns...}})
179+
_randexp_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
180+
_randexp_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
181+
_randexp_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randexp(rng, Size(ns...), SA{Tuple{ns...}, T})
182+
145183
escall(args) = Iterators.map(esc, args)
184+
function _isnonnegvec(args)
185+
length(args) == 0 && return false
186+
all(isa.(args, Integer)) && return all(args .≥ 0)
187+
return false
188+
end
146189
function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
147190
if !isa(ex, Expr)
148191
error("Bad input for @$SA")
149192
end
150193
head = ex.head
151194
if head === :vect # vector
152-
return :($SA{$Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
195+
return :($SA{Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
153196
elseif head === :ref # typed, vector
154-
return :($SA{$Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
197+
return :($SA{Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
155198
elseif head === :typed_vcat || head === :typed_hcat || head === :typed_ncat # typed, cat
156199
args = parse_cat_ast(ex)
157-
return :($SA{$Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
200+
return :($SA{Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
158201
elseif head === :vcat || head === :hcat || head === :ncat # untyped, cat
159202
args = parse_cat_ast(ex)
160-
return :($SA{$Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
203+
return :($SA{Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
161204
elseif head === :comprehension
162205
if length(ex.args) != 1
163206
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
@@ -173,7 +216,7 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
173216
return quote
174217
let
175218
f($(escall(rng_args)...)) = $(esc(ex.args[1]))
176-
$SA{$Tuple{$(size(exprs)...)}}($tuple($(exprs...)))
219+
$SA{Tuple{$(size(exprs)...)}}($tuple($(exprs...)))
177220
end
178221
end
179222
elseif head === :typed_comprehension
@@ -192,26 +235,58 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
192235
return quote
193236
let
194237
f($(escall(rng_args)...)) = $(esc(ex.args[1]))
195-
$SA{$Tuple{$(size(exprs)...)},$T}($tuple($(exprs...)))
238+
$SA{Tuple{$(size(exprs)...)},$T}($tuple($(exprs...)))
196239
end
197240
end
198241
elseif head === :call
199242
f = ex.args[1]
200-
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
201-
if length(ex.args) == 1
202-
f === :zeros || f === :ones || error("@$SA got bad expression: $(ex)")
203-
return :($f($SA{$Tuple{},$Float64}))
204-
end
205-
return quote
206-
if isa($(esc(ex.args[2])), DataType)
207-
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
208-
else
209-
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
210-
end
243+
fargs = ex.args[2:end]
244+
if f === :zeros || f === :ones
245+
_f_with_Val = Symbol(:_, f, :_with_Val)
246+
if length(fargs) == 0
247+
# for calls like `zeros()`
248+
return :($f($SA{Tuple{},$Float64}))
249+
elseif _isnonnegvec(fargs)
250+
# for calls like `zeros(dims...)`
251+
return :($f($SA{Tuple{$(escall(fargs)...)}}))
252+
else
253+
# for calls like `zeros(type)`
254+
# for calls like `zeros(type, dims...)`
255+
return :($_f_with_Val($SA, $(esc(fargs[1])), Val($(esc(fargs[1]))), Val(tuple($(escall(fargs[2:end])...)))))
211256
end
212257
elseif f === :fill
213-
length(ex.args) == 1 && error("@$SA got bad expression: $(ex)")
214-
return :($f($(esc(ex.args[2])), $SA{$Tuple{$(escall(ex.args[3:end])...)}}))
258+
# for calls like `fill(value, dims...)`
259+
return :($f($(esc(fargs[1])), $SA{Tuple{$(escall(fargs[2:end])...)}}))
260+
elseif f === :rand || f === :randn || f === :randexp
261+
_f_with_Val = Symbol(:_, f, :_with_Val)
262+
if length(fargs) == 0
263+
# No support for `@SArray rand()`
264+
error("@$SA got bad expression: $(ex)")
265+
elseif _isnonnegvec(fargs)
266+
# for calls like `rand(dims...)`
267+
return :($f($SA{Tuple{$(escall(fargs)...)}}))
268+
elseif length(fargs) 2
269+
# for calls like `rand(dim1, dim2, dims...)`
270+
# for calls like `rand(type, dim1, dims...)`
271+
# for calls like `rand(sampler, dim1, dims...)`
272+
# for calls like `rand(rng, dim1, dims...)`
273+
# for calls like `rand(rng, type, dims...)`
274+
# for calls like `rand(rng, sampler, dims...)`
275+
# for calls like `randn(dim1, dim2, dims...)`
276+
# for calls like `randn(type, dim1, dims...)`
277+
# for calls like `randn(rng, dim1, dims...)`
278+
# for calls like `randn(rng, type, dims...)`
279+
# for calls like `randexp(dim1, dim2, dims...)`
280+
# for calls like `randexp(type, dim1, dims...)`
281+
# for calls like `randexp(rng, dim1, dims...)`
282+
# for calls like `randexp(rng, type, dims...)`
283+
return :($_f_with_Val($SA, $(esc(fargs[1])), $(esc(fargs[2])), _int2val($(esc(fargs[1]))), _int2val($(esc(fargs[2]))), Val(tuple($(escall(fargs[3:end])...)))))
284+
elseif length(fargs) == 1
285+
# for calls like `rand(dim)`
286+
return :($f($SA{Tuple{$(escall(fargs)...)}}))
287+
else
288+
error("@$SA got bad expression: $(ex)")
289+
end
215290
else
216291
error("@$SA only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
217292
end

src/SMatrix.jl

+54-10
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,21 @@ function check_matrix_size(x::Tuple, T = :S)
1515
x1, x2
1616
end
1717

18+
# @SMatrix rand(...)
19+
_rand_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2})
20+
_rand_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, T, Size(n1, n2), SM{n1, n2, T})
21+
_rand_with_Val(::Type{SM}, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)})
22+
_rand_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2, T})
23+
_rand_with_Val(::Type{SM}, rng::AbstractRNG, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(rng, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)})
24+
# @SMatrix randn(...)
25+
_randn_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2})
26+
_randn_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randn(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T})
27+
_randn_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2, T})
28+
# @SMatrix randexp(...)
29+
_randexp_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2})
30+
_randexp_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randexp(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T})
31+
_randexp_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2, T})
32+
1833
function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM}
1934
if !isa(ex, Expr)
2035
error("Bad input for @$SM")
@@ -69,22 +84,51 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM
6984
end
7085
elseif head === :call
7186
f = ex.args[1]
72-
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
73-
if length(ex.args) == 3
74-
return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base
75-
elseif length(ex.args) == 4
76-
return :($f($SM{$(escall(ex.args[[3,4,2]])...)}))
87+
fargs = ex.args[2:end]
88+
if f === :zeros || f === :ones
89+
if length(fargs) == 2
90+
# for calls like `zeros(dim1, dim2)`
91+
return :($f($SM{$(escall(fargs)...)}))
92+
elseif length(fargs[2:end]) == 2
93+
# for calls like `zeros(type, dim1, dim2)`
94+
return :($f($SM{$(escall(fargs[2:end])...), $(esc(fargs[1]))}))
7795
else
78-
error("@$SM expected a 2-dimensional array expression")
96+
error("@$SM got bad expression: $(ex)")
7997
end
80-
elseif ex.args[1] === :fill
81-
if length(ex.args) == 4
82-
return :($f($(esc(ex.args[2])), $SM{$(escall(ex.args[3:4])...)}))
98+
elseif f === :fill
99+
# for calls like `fill(value, dim1, dim2)`
100+
if length(fargs[2:end]) == 2
101+
return :($f($(esc(fargs[1])), $SM{$(escall(fargs[2:end])...)}))
83102
else
84103
error("@$SM expected a 2-dimensional array expression")
85104
end
105+
elseif f === :rand || f === :randn || f === :randexp
106+
_f_with_Val = Symbol(:_, f, :_with_Val)
107+
if length(fargs) == 2
108+
# for calls like `rand(dim1, dim2)`
109+
# for calls like `randn(dim1, dim2)`
110+
# for calls like `randexp(dim1, dim2)`
111+
return :($f($SM{$(escall(fargs)...)}))
112+
elseif length(fargs) == 3
113+
# for calls like `rand(rng, dim1, dim2)`
114+
# for calls like `rand(type, dim1, dim2)`
115+
# for calls like `rand(sampler, dim1, dim2)`
116+
# for calls like `randn(rng, dim1, dim2)`
117+
# for calls like `randn(type, dim1, dim2)`
118+
# for calls like `randexp(rng, dim1, dim2)`
119+
# for calls like `randexp(type, dim1, dim2)`
120+
return :($_f_with_Val($SM, $(esc(fargs[1])), Val($(esc(fargs[2]))), Val($(esc(fargs[3])))))
121+
elseif length(fargs) == 4
122+
# for calls like `rand(rng, type, dim1, dim2)`
123+
# for calls like `rand(rng, sampler, dim1, dim2)`
124+
# for calls like `randn(rng, type, dim1, dim2)`
125+
# for calls like `randexp(rng, type, dim1, dim2)`
126+
return :($_f_with_Val($SM, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3]))), Val($(esc(fargs[4])))))
127+
else
128+
error("@$SM got bad expression: $(ex)")
129+
end
86130
else
87-
error("@$SM only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
131+
error("@$SM only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
88132
end
89133
else
90134
error("Bad input for @$SM")

src/SVector.jl

+54-10
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,21 @@ function check_vector_length(x::Tuple, T = :S)
1616
length(x) >= 1 ? x[1] : 1
1717
end
1818

19+
# @SVector rand(...)
20+
_rand_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = rand(rng, SV{n})
21+
_rand_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, T, Size(n), SV{n, T})
22+
_rand_with_Val(::Type{SV}, sampler, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, sampler, Size(n), SV{n, Random.gentype(sampler)})
23+
_rand_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = rand(rng, SV{n, T})
24+
_rand_with_Val(::Type{SV}, rng::AbstractRNG, sampler, ::Val{n}) where {SV, n} = _rand(rng, sampler, Size(n), SV{n, Random.gentype(sampler)})
25+
# @SVector randn(...)
26+
_randn_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randn(rng, SV{n})
27+
_randn_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randn(Random.GLOBAL_RNG, Size(n), SV{n, T})
28+
_randn_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randn(rng, SV{n, T})
29+
# @SVector randexp(...)
30+
_randexp_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randexp(rng, SV{n})
31+
_randexp_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randexp(Random.GLOBAL_RNG, Size(n), SV{n, T})
32+
_randexp_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randexp(rng, SV{n, T})
33+
1934
function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV}
2035
if !isa(ex, Expr)
2136
error("Bad input for @$SV")
@@ -74,22 +89,51 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV
7489
end
7590
elseif head === :call
7691
f = ex.args[1]
77-
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
78-
if length(ex.args) == 2
79-
return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base
80-
elseif length(ex.args) == 3
81-
return :($f($SV{$(escall(ex.args[3:-1:2])...)}))
92+
fargs = ex.args[2:end]
93+
if f === :zeros || f === :ones
94+
if length(fargs) == 1
95+
# for calls like `zeros(dim)`
96+
return :($f($SV{$(esc(fargs[1]))}))
97+
elseif length(fargs) == 2
98+
# for calls like `zeros(type, dim)`
99+
return :($f($SV{$(esc(fargs[2])), $(esc(fargs[1]))}))
82100
else
83-
error("@$SV expected a 1-dimensional array expression")
101+
error("@$SV got bad expression: $(ex)")
84102
end
85-
elseif ex.args[1] === :fill
86-
if length(ex.args) == 3
87-
return :($f($(esc(ex.args[2])), $SV{$(esc(ex.args[3]))}))
103+
elseif f === :fill
104+
# for calls like `fill(value, dim)`
105+
if length(fargs) == 2
106+
return :($f($(esc(fargs[1])), $SV{$(esc(fargs[2]))}))
88107
else
89108
error("@$SV expected a 1-dimensional array expression")
90109
end
110+
elseif f === :rand || f === :randn || f === :randexp
111+
_f_with_Val = Symbol(:_, f, :_with_Val)
112+
if length(fargs) == 1
113+
# for calls like `rand(dim)`
114+
# for calls like `randn(dim)`
115+
# for calls like `randexp(dim)`
116+
return :($f($SV{$(escall(fargs)...)}))
117+
elseif length(fargs) == 2
118+
# for calls like `rand(rng, dim)`
119+
# for calls like `rand(type, dim)`
120+
# for calls like `rand(sampler, dim)`
121+
# for calls like `randn(rng, dim)`
122+
# for calls like `randn(type, dim)`
123+
# for calls like `randexp(rng, dim)`
124+
# for calls like `randexp(type, dim)`
125+
return :($_f_with_Val($SV, $(esc(fargs[1])), Val($(esc(fargs[2])))))
126+
elseif length(fargs) == 3
127+
# for calls like `rand(rng, type, dim)`
128+
# for calls like `rand(rng, sampler, dim)`
129+
# for calls like `randn(rng, type, dim)`
130+
# for calls like `randexp(rng, type, dim)`
131+
return :($_f_with_Val($SV, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3])))))
132+
else
133+
error("@$SV got bad expression: $(ex)")
134+
end
91135
else
92-
error("@$SV only supports the zeros(), ones(), rand(), randn() and randexp() functions.")
136+
error("@$SV only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
93137
end
94138
else
95139
error("Use @$SV [a,b,c], @$SV Type[a,b,c] or a comprehension like @$SV [f(i) for i = i_min:i_max]")

src/arraymath.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ end
8080

8181
@inline rand(rng::AbstractRNG, range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(rng, range, Size(SA), SA)
8282
@inline rand(range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(Random.GLOBAL_RNG, range, Size(SA), SA)
83-
@generated function _rand(rng::AbstractRNG, range::AbstractArray, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
84-
v = [:(rand(rng, range)) for i = 1:prod(s)]
83+
@generated function _rand(rng::AbstractRNG, X, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
84+
v = [:(rand(rng, X)) for i = 1:prod(s)]
8585
return quote
8686
@_inline_meta
8787
$SA(tuple($(v...)))

0 commit comments

Comments
 (0)