Skip to content

Commit d032934

Browse files
committed
add runtime version of macro gendict so that @defvar(m,x[1:3]) and T=1:3;@defvar(m,x[T]) have the same semantics
1 parent 6b179d2 commit d032934

File tree

4 files changed

+168
-4
lines changed

4 files changed

+168
-4
lines changed

src/JuMPContainer.jl

+82-2
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ Base.isempty(d::JuMPContainer) = isempty(_innercontainer(d))
6969
# 0:K -- range with compile-time starting index
7070
# S -- general iterable set
7171
export @gendict
72-
macro gendict(instancename,T,idxpairs,idxsets...)
72+
macro gendict(instancename,T,idxsets...)
7373
N = length(idxsets)
7474
allranges = all(s -> (isexpr(s,:(:)) && length(s.args) == 2), idxsets)
7575
truearray = allranges && all(s -> s.args[1] == 1, idxsets)
@@ -165,12 +165,92 @@ macro gendict(instancename,T,idxpairs,idxsets...)
165165
=#
166166
else
167167
# JuMPDict
168+
escidxs = [esc(idxset) for idxset in idxsets]
168169
return :(
169-
$(esc(instancename)) = JuMPDict{$T,$N}()
170+
$(esc(instancename)) = (if is_unit_ranges($(escidxs...))
171+
runtime_gendict($T, $(escidxs...))
172+
else
173+
JuMPDict{$T,$N}()
174+
end)
170175
)
171176
end
172177
end
173178

179+
function runtime_gendict(T,idxsets...)
180+
N = length(idxsets)
181+
allranges = all(s -> (typeof(s) <: Range), idxsets)
182+
truearray = allranges && all(s -> (first(s) == 1), idxsets)
183+
if allranges
184+
if truearray
185+
return Array(T, [last(rng) for rng in idxsets]...)
186+
else
187+
typename = symbol(string("JuMPArray",gensym()))
188+
dictnames = Array(Symbol,N)
189+
# JuMPArray
190+
offset = Array(Int,N)
191+
for i in 1:N
192+
offset[i] = 1 - first(idxsets[i])
193+
end
194+
typecode = quote
195+
type $(typename){T} <: JuMPArray{T,$N}
196+
innerArray::Array{T,$N}
197+
meta::Dict{Symbol,Any}
198+
end
199+
end
200+
constrlhs = :($(typename)(innerArray::Array))
201+
constrrhs = :($(typename)(innerArray, Dict{Symbol,Any}()))
202+
getidxlhs = :(Base.getindex(d::$(typename)))
203+
setidxlhs = :(Base.setindex!(d::$(typename),val))
204+
getidxrhs = :(Base.getindex(d.innerArray))
205+
setidxrhs = :(Base.setindex!(d.innerArray,val))
206+
maplhs = :(Base.map(f::Function,d::$(typename)))
207+
maprhs = :($(typename)(map(f,d.innerArray),d.meta))
208+
wraplhs = :(JuMPContainer_from(d::$(typename),inner)) # helper function that wraps array into JuMPArray of similar type
209+
wraprhs = :($(typename)(inner))
210+
211+
nextidxlhs = :(_next_index(d::$(typename), k))
212+
# build up exprs for _next_index
213+
lidxsets = [ii => symbol(string("locidxset",ii)) for ii in 1:N]
214+
nextidxrhs = quote
215+
subidx = ind2sub(size(d), k)
216+
$(Expr(:tuple, [:(subidx[$ii] - $(offset[ii])) for ii in 1:N]...))
217+
end
218+
for i in 1:N
219+
varname = symbol(string("x",i))
220+
221+
push!(getidxlhs.args,:($varname))
222+
push!(setidxlhs.args,:($varname))
223+
224+
push!(getidxrhs.args,:(isa($varname, Int) ? $varname+$(offset[i]) : $varname ))
225+
push!(setidxrhs.args,:($varname+$(offset[i])))
226+
227+
end
228+
229+
badgetidxlhs = :(Base.getindex(d::$(typename),wrong...))
230+
badgetidxrhs = :(data = printdata(d);
231+
error("Wrong number of indices for ",data.name, ", expected ",length(data.indexsets)))
232+
233+
funcs = quote
234+
$constrlhs = $constrrhs
235+
$getidxlhs = $getidxrhs
236+
$setidxlhs = $setidxrhs
237+
$maplhs = $maprhs
238+
$badgetidxlhs = $badgetidxrhs
239+
$wraplhs = $wraprhs
240+
$nextidxlhs = $nextidxrhs
241+
end
242+
243+
eval(Expr(:toplevel, typecode))
244+
eval(Expr(:toplevel, funcs))
245+
246+
return eval(:($(typename)(Array($T, [length(idxset) for idxset in $idxsets]...))))
247+
end
248+
else
249+
error("Should not reach this point")
250+
end
251+
end
252+
253+
@generated is_unit_ranges(idxsets...) = :($(all(s -> s <: UnitRange{Int}, idxsets)))
174254
pushmeta!(x::JuMPContainer, sym::Symbol, val) = (x.meta[sym] = val)
175255
getmeta(x::JuMPContainer, sym::Symbol) = x.meta[sym]
176256

src/macros.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ function getloopedcode(c::Expr, code, condition, idxvars, idxsets, idxpairs, sym
119119
N = length(idxsets)
120120
mac = :($(esc(varname)) = JuMPDict{$(sym),$N}())
121121
else
122-
mac = Expr(:macrocall,symbol("@gendict"),esc(varname),sym,idxpairs,idxsets...)
122+
mac = Expr(:macrocall,symbol("@gendict"),esc(varname),sym,idxsets...)
123123
end
124124
return quote
125125
$mac
@@ -877,7 +877,7 @@ macro defConstrRef(var)
877877
idxsets = var.args[2:end]
878878
idxpairs = IndexPair[]
879879

880-
mac = Expr(:macrocall,symbol("@gendict"),varname,:ConstraintRef,idxpairs, idxsets...)
880+
mac = Expr(:macrocall,symbol("@gendict"), varname, :ConstraintRef, idxsets...)
881881
code = quote
882882
$(esc(mac))
883883
nothing

test/perf/JuMPArray-iteration.jl

+20
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,23 @@ bench(1)
1818
bench(100)
1919
bench(1000)
2020
bench(2000)
21+
22+
function bench_runtime(n)
23+
t1 = @elapsed begin
24+
m = Model()
25+
I, J = 1:n, 2:n
26+
@defVar(m, x[I,J])
27+
end
28+
t2 = @elapsed begin
29+
cntr = 0
30+
for (ii,jj,v) in x
31+
cntr += ii + jj + v.col
32+
end
33+
end
34+
t1, t2
35+
end
36+
37+
bench_runtime(1)
38+
bench_runtime(100)
39+
bench_runtime(1000)
40+
bench_runtime(2000)

test/perf/macro.jl

+64
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,67 @@ for N in [20,50,100]
6363
println(" N=$(N) min $(minimum(N2_times))")
6464
end
6565

66+
function test_linear_runtime(N)
67+
m = Model()
68+
I,J = 1:10N, 1:5N
69+
@defVar(m, x[I,J])
70+
K = 1:N
71+
@defVar(m, y[K,K,K])
72+
73+
for z in 1:10
74+
@addConstraint(m,
75+
9*y[1,1,1] - 5*y[N,N,N] -
76+
2*sum{ z*x[j,i*N], j=((z-1)*N+1):z*N, i=3:4} +
77+
sum{ i*(9*x[i,j] + 3*x[j,i]), i=N:2N, j=N:2N} +
78+
x[1,1] + x[10N,5N] + x[2N,1] +
79+
1*y[1,1,N] + 2*y[1,N,1] + 3*y[N,1,1] +
80+
y[N,N,N] - 2*y[N,N,N] + 3*y[N,N,N]
81+
<=
82+
sum{sum{sum{N*i*j*k*y[i,j,k] + x[i,j],k=1:N; i!=j && j!=k},j=1:N},i=1:N} +
83+
sum{sum{x[i,j], j=1:5N; j % i == 3}, i=1:10N; i <= N*z}
84+
)
85+
end
86+
end
87+
88+
function test_quad_runtime(N)
89+
m = Model()
90+
I,J = 1:10N, 1:5N
91+
@defVar(m, x[I,J])
92+
K = 1:N
93+
@defVar(m, y[K,K,K])
94+
95+
for z in 1:10
96+
@addConstraint(m,
97+
9*y[1,1,1] - 5*y[N,N,N] -
98+
2*sum{ z*x[j,i*N], j=((z-1)*N+1):z*N, i=3:4} +
99+
sum{ i*(9*x[i,j] + 3*x[j,i]), i=N:2N, j=N:2N} +
100+
x[1,1] + x[10N,5N] * x[2N,1] +
101+
1*y[1,1,N] * 2*y[1,N,1] + 3*y[N,1,1] +
102+
y[N,N,N] - 2*y[N,N,N] * 3*y[N,N,N]
103+
<=
104+
sum{sum{sum{N*i*j*k*y[i,j,k] * x[i,j],k=1:N; i!=j && j!=k},j=1:N},i=1:N} +
105+
sum{sum{x[i,j], j=1:5N; j % i == 3}, i=1:10N; i <= N*z}
106+
)
107+
end
108+
end
109+
110+
111+
# Warmup
112+
println("Test 2 (runtime)")
113+
test_linear_runtime(1)
114+
test_quad_runtime(1)
115+
for N in [20,50,100]
116+
println(" Running N=$(N)...")
117+
N1_times = {}
118+
N2_times = {}
119+
for iter in 1:10
120+
tic()
121+
test_linear_runtime(N)
122+
push!(N1_times, toq())
123+
tic()
124+
test_quad_runtime(N)
125+
push!(N2_times, toq())
126+
end
127+
println(" N=$(N) min $(minimum(N1_times))")
128+
println(" N=$(N) min $(minimum(N2_times))")
129+
end

0 commit comments

Comments
 (0)