##############################
### memory_mode macro
##############################

macro memory_mode(memory_mode, ex)
    return esc(quote
        if $memory_mode isa InplaceEmphasis
            @. $ex
        else
            $ex
        end
    end)
end
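# Illustrative sketch (not part of the package API): the macro switches between a
# broadcasted in-place update and a plain allocating expression.
#
#     d = zeros(3); x = ones(3); v = fill(0.5, 3)
#     @memory_mode(InplaceEmphasis(), d = x - v)   # expands to `@. d = x - v`, mutates `d`
#     @memory_mode(OutplaceEmphasis(), d = x - v)  # plain `d = x - v`, rebinds `d` to a new array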
"""
muladd_memory_mode(memory_mode::MemoryEmphasis, d, x, v)
Performs `d = x - v` in-place or not depending on MemoryEmphasis
"""
function muladd_memory_mode(memory_mode::MemoryEmphasis, d, x, v)
@memory_mode(memory_mode, d = x - v)
end
"""
(memory_mode::MemoryEmphasis, x, gamma::Real, d)
Performs `x = x - gamma * d` in-place or not depending on MemoryEmphasis
"""
function muladd_memory_mode(memory_mode::MemoryEmphasis, x, gamma::Real, d)
@memory_mode(memory_mode, x = x - gamma * d)
end
"""
(memory_mode::MemoryEmphasis, storage, x, gamma::Real, d)
Performs `storage = x - gamma * d` in-place or not depending on MemoryEmphasis
"""
function muladd_memory_mode(memory_mode::MemoryEmphasis, storage, x, gamma::Real, d)
@memory_mode(memory_mode, storage = x - gamma * d)
end
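# Example (sketch): with `InplaceEmphasis()` the destination argument is overwritten,
# with `OutplaceEmphasis()` a new array is allocated instead.
#
#     x = ones(4); d = fill(0.25, 4); storage = similar(x)
#     muladd_memory_mode(InplaceEmphasis(), storage, x, 0.5, d)
#     # storage .= x .- 0.5 .* d, i.e. all entries become 0.875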
##############################################################
# simple benchmark of elementary costs of oracles and
# critical components
##############################################################

function benchmark_oracles(f, grad!, x_gen, lmo; k=100, nocache=true)
    x = x_gen()
    sv = sizeof(x) / 1024^2
    println("\nSize of single atom ($(eltype(x))): $sv MB\n")
    to = TimerOutput()
    @showprogress 1 "Testing f... " for i in 1:k
        x = x_gen()
        @timeit to "f" temp = f(x)
    end
    @showprogress 1 "Testing grad... " for i in 1:k
        x = x_gen()
        temp = similar(x)
        @timeit to "grad" grad!(temp, x)
    end
    @showprogress 1 "Testing lmo... " for i in 1:k
        x = x_gen()
        @timeit to "lmo" temp = compute_extreme_point(lmo, x)
    end
    @showprogress 1 "Testing dual gap... " for i in 1:k
        x = x_gen()
        gradient = collect(x)
        grad!(gradient, x)
        v = compute_extreme_point(lmo, gradient)
        @timeit to "dual gap" begin
            dual_gap = fast_dot(x, gradient) - fast_dot(v, gradient)
        end
    end
    @showprogress 1 "Testing update... (Emphasis: OutplaceEmphasis) " for i in 1:k
        x = x_gen()
        gradient = collect(x)
        grad!(gradient, x)
        v = compute_extreme_point(lmo, gradient)
        gamma = 1 / 2
        @timeit to "update (OutplaceEmphasis)" @memory_mode(
            OutplaceEmphasis(),
            x = (1 - gamma) * x + gamma * v
        )
    end
    @showprogress 1 "Testing update... (Emphasis: InplaceEmphasis) " for i in 1:k
        x = x_gen()
        gradient = collect(x)
        grad!(gradient, x)
        v = compute_extreme_point(lmo, gradient)
        gamma = 1 / 2
        # TODO: to be updated to the broadcast version once the ScaledHotVector data structure allows for it
        @timeit to "update (InplaceEmphasis)" @memory_mode(
            InplaceEmphasis(),
            x = (1 - gamma) * x + gamma * v
        )
    end
    if !nocache
        @showprogress 1 "Testing caching 100 points... " for i in 1:k
            @timeit to "caching 100 points" begin
                cache = [x_gen() for _ in 1:100]
                x = x_gen()
                gradient = collect(x)
                grad!(gradient, x)
                v = compute_extreme_point(lmo, gradient)
                gamma = 1 / 2
                test = (x -> fast_dot(x, gradient)).(cache)
                v = cache[argmin(test)]
                val = v in cache
            end
        end
    end
    print_timer(to)
    return nothing
end
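# Example (sketch, assuming an LMO implementing `compute_extreme_point`, e.g. the
# probability simplex oracle shipped with the package):
#
#     f(x) = norm(x)^2
#     grad!(g, x) = (g .= 2x)
#     lmo = ProbabilitySimplexOracle(1.0)
#     benchmark_oracles(f, grad!, () -> randn(1000), lmo; k=50)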
"""
_unsafe_equal(a, b)
Like `isequal` on arrays but without the checks. Assumes a and b have the same axes.
"""
function _unsafe_equal(a::Array, b::Array)
if a === b
return true
end
@inbounds for idx in eachindex(a)
if a[idx] != b[idx]
return false
end
end
return true
end
_unsafe_equal(a, b) = isequal(a, b)
function _unsafe_equal(a::SparseArrays.AbstractSparseArray, b::SparseArrays.AbstractSparseArray)
return a == b
end
fast_dot(A, B) = dot(A, B)

fast_dot(B::SparseArrays.SparseMatrixCSC, A::Matrix) = conj(fast_dot(A, B))

function fast_dot(A::Matrix{T1}, B::SparseArrays.SparseMatrixCSC{T2}) where {T1,T2}
    T = promote_type(T1, T2)
    (m, n) = size(A)
    if (m, n) != size(B)
        throw(DimensionMismatch("Size mismatch"))
    end
    s = zero(T)
    if m * n == 0
        return s
    end
    rows = SparseArrays.rowvals(B)
    vals = SparseArrays.nonzeros(B)
    @inbounds for j in 1:n
        for ridx in SparseArrays.nzrange(B, j)
            i = rows[ridx]
            v = vals[ridx]
            s += v * conj(A[i, j])
        end
    end
    return s
end

fast_dot(a, Q, b) = dot(a, Q, b)

function fast_dot(a::SparseArrays.AbstractSparseVector{<:Real}, Q::Diagonal{<:Real}, b::AbstractVector{<:Real})
    if a === b
        return _fast_quadratic_form_symmetric(a, Q)
    end
    d = Q.diag
    nzvals = SparseArrays.nonzeros(a)
    nzinds = SparseArrays.nonzeroinds(a)
    return sum(eachindex(nzvals); init=zero(eltype(a))) do nzidx
        nzvals[nzidx] * d[nzinds[nzidx]] * b[nzinds[nzidx]]
    end
end
function fast_dot(a::SparseArrays.AbstractSparseVector{<:Real}, Q::Diagonal{<:Real}, b::SparseArrays.AbstractSparseVector{<:Real})
    if a === b
        return _fast_quadratic_form_symmetric(a, Q)
    end
    n = length(a)
    if length(b) != n
        throw(
            DimensionMismatch("Vector a has a length $n but b has a length $(length(b))")
        )
    end
    anzind = SparseArrays.nonzeroinds(a)
    bnzind = SparseArrays.nonzeroinds(b)
    anzval = SparseArrays.nonzeros(a)
    bnzval = SparseArrays.nonzeros(b)
    s = zero(Base.promote_eltype(a, Q, b))
    if isempty(anzind) || isempty(bnzind)
        return s
    end
    a_idx = 1
    b_idx = 1
    a_idx_last = length(anzind)
    b_idx_last = length(bnzind)
    # go through the nonzero indices of a and b simultaneously
    @inbounds while a_idx <= a_idx_last && b_idx <= b_idx_last
        ia = anzind[a_idx]
        ib = bnzind[b_idx]
        if ia == ib
            s += dot(anzval[a_idx], Q.diag[ia], bnzval[b_idx])
            a_idx += 1
            b_idx += 1
        elseif ia < ib
            a_idx += 1
        else
            b_idx += 1
        end
    end
    return s
end

function _fast_quadratic_form_symmetric(a, Q)
    d = Q.diag
    if length(d) != length(a)
        throw(DimensionMismatch())
    end
    nzvals = SparseArrays.nonzeros(a)
    nzinds = SparseArrays.nonzeroinds(a)
    s = zero(Base.promote_eltype(a, Q))
    @inbounds for nzidx in eachindex(nzvals)
        s += nzvals[nzidx]^2 * d[nzinds[nzidx]]
    end
    return s
end
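# Sanity-check sketch: the specialized sparse methods agree with `LinearAlgebra.dot`.
#
#     a = SparseArrays.sprand(100, 0.1)
#     b = SparseArrays.sprand(100, 0.1)
#     Q = Diagonal(rand(100))
#     fast_dot(a, Q, b) ≈ dot(a, Q, b)  # true
#     fast_dot(a, Q, a) ≈ dot(a, Q, a)  # symmetric fast path, also true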
"""
trajectory_callback(storage)
Callback pushing the state at each iteration to the passed storage.
The state data is only the 5 first fields, usually:
`(t,primal,dual,dual_gap,time)`
"""
function trajectory_callback(storage)
return function push_trajectory!(data, args...)
return push!(storage, callback_state(data))
end
end
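# Example (sketch): collect the trajectory of a run by passing the callback to an
# algorithm accepting a `callback` keyword, e.g.:
#
#     traj = []
#     callback = trajectory_callback(traj)
#     frank_wolfe(f, grad!, lmo, x0; callback=callback)
#     # traj now holds one state tuple per iteration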
"""
momentum_iterate(iter::MomentumIterator) -> ρ
Method to implement for a type `MomentumIterator`.
Returns the next momentum value `ρ` and updates the iterator internal state.
"""
function momentum_iterate end
"""
ExpMomentumIterator{T}
Iterator for the momentum used in the variant of Stochastic Frank-Wolfe.
Momentum coefficients are the values of the iterator:
`ρ_t = 1 - num / (offset + t)^exp`
The state corresponds to the iteration count.
Source:
Stochastic Conditional Gradient Methods: From Convex Minimization to Submodular Maximization
Aryan Mokhtari, Hamed Hassani, Amin Karbasi, JMLR 2020.
"""
mutable struct ExpMomentumIterator{T}
exp::T
num::T
offset::T
iter::Int
end
ExpMomentumIterator() = ExpMomentumIterator(2 / 3, 4.0, 8.0, 0)
function momentum_iterate(em::ExpMomentumIterator)
em.iter += 1
return 1 - em.num / (em.offset + em.iter)^(em.exp)
end
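# Example: the first few momentum coefficients ρ_t with the default parameters.
#
#     em = ExpMomentumIterator()
#     [momentum_iterate(em) for _ in 1:3]
#     # ≈ [1 - 4/9^(2/3), 1 - 4/10^(2/3), 1 - 4/11^(2/3)]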
"""
ConstantMomentumIterator{T}
Iterator for momentum with a fixed damping value, always return the value and a dummy state.
"""
struct ConstantMomentumIterator{T}
v::T
end
momentum_iterate(em::ConstantMomentumIterator) = em.v
# batch sizes

"""
    batchsize_iterate(iter::BatchSizeIterator) -> b

Method to implement for a batch size iterator of type `BatchSizeIterator`.
Calling `batchsize_iterate` returns the next batch size and typically updates the internal state of `iter`.
"""
function batchsize_iterate end

"""
    ConstantBatchIterator(batch_size)

Batch iterator always returning a constant batch size.
"""
struct ConstantBatchIterator
    batch_size::Int
end

batchsize_iterate(cbi::ConstantBatchIterator) = cbi.batch_size

"""
    IncrementBatchIterator(starting_batch_size, max_batch_size, [increment = 1])

Batch size starting at `starting_batch_size` and incrementing by `increment` at every iteration,
capped at `max_batch_size`.
"""
mutable struct IncrementBatchIterator
    starting_batch_size::Int
    max_batch_size::Int
    increment::Int
    iter::Int
    maxreached::Bool
end

function IncrementBatchIterator(starting_batch_size::Int, max_batch_size::Int, increment::Int)
    return IncrementBatchIterator(starting_batch_size, max_batch_size, increment, 0, false)
end

function IncrementBatchIterator(starting_batch_size::Int, max_batch_size::Int)
    return IncrementBatchIterator(starting_batch_size, max_batch_size, 1, 0, false)
end

function batchsize_iterate(ibi::IncrementBatchIterator)
    if ibi.maxreached
        return ibi.max_batch_size
    end
    new_size = ibi.starting_batch_size + ibi.iter * ibi.increment
    ibi.iter += 1
    if new_size > ibi.max_batch_size
        ibi.maxreached = true
        return ibi.max_batch_size
    end
    return new_size
end
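# Example: batch sizes increase by the increment until they hit the cap.
#
#     ibi = IncrementBatchIterator(2, 5)
#     [batchsize_iterate(ibi) for _ in 1:6]
#     # == [2, 3, 4, 5, 5, 5]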
"""
Vertex storage to store dropped vertices or find a suitable direction in lazy settings.
The algorithm will look for at most `return_kth` suitable atoms before returning the best.
See [Extra-lazification with a vertex storage](@ref) for usage.
A vertex storage can be any type that implements two operations:
1. `Base.push!(storage, atom)` to add an atom to the storage.
Note that it is the storage type responsibility to ensure uniqueness of the atoms present.
2. `storage_find_argmin_vertex(storage, direction, lazy_threshold) -> (found, vertex)`
returning whether a vertex with sufficient progress was found and the vertex.
It is up to the storage to remove vertices (or not) when they have been picked up.
"""
struct DeletedVertexStorage{AT}
storage::Vector{AT}
return_kth::Int
end
DeletedVertexStorage(storage::Vector) = DeletedVertexStorage(storage, 1)
DeletedVertexStorage{AT}() where {AT} = DeletedVertexStorage(AT[])
function Base.push!(vertex_storage::DeletedVertexStorage{AT}, atom::AT) where {AT}
# do not push duplicates
if !any(v -> _unsafe_equal(atom, v), vertex_storage.storage)
push!(vertex_storage.storage, atom)
end
return vertex_storage
end
Base.length(storage::DeletedVertexStorage) = length(storage.storage)
"""
Computes the linear minimizer in the direction on the precomputed_set.
Precomputed_set stores the vertices computed as extreme points v in each iteration.
"""
function pre_computed_set_argminmax(lmo, pre_computed_set, direction, x; strong_lazification = false)
val = convert(eltype(direction), Inf)
valM = convert(eltype(direction), -Inf)
idx = -1
idxM = -1
for i in eachindex(pre_computed_set)
temp_val = fast_dot(pre_computed_set[i], direction)
if temp_val < val
val = temp_val
idx = i
end
if strong_lazification
if is_inface_feasible(lmo, pre_computed_set[i], x) && temp_val > valM
valM = temp_val
idxM = i
end
end
end
if idx == -1
error("Infinite minimum $val in the precomputed set. Does the gradient contain invalid (NaN / Inf) entries?")
end
v_local = pre_computed_set[idx]
a_local = idxM != -1 ? pre_computed_set[idxM] : nothing
return (v_local, idx, val, a_local, idxM, valM)
end
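# Example (sketch): without strong lazification, `lmo` and `x` are never touched,
# so placeholders work for illustration.
#
#     set = [[1.0, 0.0], [0.0, 1.0]]
#     direction = [0.3, -0.2]
#     v, idx, val, _, _, _ = pre_computed_set_argminmax(nothing, set, direction, nothing)
#     # v == [0.0, 1.0], idx == 2, val == -0.2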
"""
Give the vertex `v` in the storage that minimizes `s = direction ⋅ v` and whether `s` achieves
`s ≤ lazy_threshold`.
"""
function storage_find_argmin_vertex(vertex_storage::DeletedVertexStorage, direction, lazy_threshold)
if isempty(vertex_storage.storage)
return (false, nothing)
end
best_idx = 1
best_val = lazy_threshold
found_good = false
counter = 0
for (idx, atom) in enumerate(vertex_storage.storage)
s = dot(direction, atom)
if s < best_val
counter += 1
best_val = s
found_good = true
best_idx = idx
if counter ≥ vertex_storage.return_kth
return (found_good, vertex_storage.storage[best_idx])
end
end
end
return (found_good, vertex_storage.storage[best_idx])
end
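# Example (sketch): the search returns early once `return_kth` improving atoms were seen.
#
#     storage = DeletedVertexStorage([[1.0, 0.0], [0.0, 1.0]], 1)
#     direction = [0.5, -1.0]
#     found, v = storage_find_argmin_vertex(storage, direction, 0.0)
#     # found == true, v == [0.0, 1.0] since direction ⋅ v == -1.0 ≤ 0.0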
# temporary fix because argmin is broken on Julia 1.8
argmin_(v) = argmin(v)

function argmin_(v::SparseArrays.SparseVector{T}) where {T}
    if isempty(v.nzind)
        return 1
    end
    idx = -1
    val = T(Inf)
    for s_idx in eachindex(v.nzind)
        if v.nzval[s_idx] < val
            val = v.nzval[s_idx]
            idx = s_idx
        end
    end
    # if the minimum value is already negative or all indices were checked
    if val < 0 || length(v.nzind) == length(v)
        return v.nzind[idx]
    end
    # otherwise, the minimum is a structural zero: find the first one
    for idx in eachindex(v)
        if idx ∉ v.nzind
            return idx
        end
    end
    error("unreachable")
end
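# Example: on a sparse vector, a structural zero can be the argmin.
#
#     v = SparseArrays.sparsevec([2, 4], [3.0, 5.0], 6)
#     argmin_(v)  # == 1, the first structurally zero entry (0 < 3.0)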
"""
Given an array `array`, `NegatingArray` represents `-1 * array` lazily.
"""
struct NegatingArray{T, N, AT <: AbstractArray{T,N}} <: AbstractArray{T, N}
array::AT
function NegatingArray(array::AT) where {T, N, AT <: AbstractArray{T,N}}
return new{T, N, AT}(array)
end
end
Base.size(a::NegatingArray) = Base.size(a.array)
Base.getindex(a::NegatingArray, idxs...) = -Base.getindex(a.array, idxs...)
LinearAlgebra.dot(a1::NegatingArray, a2::NegatingArray) = dot(a1.array, a2.array)
LinearAlgebra.dot(a1::NegatingArray{T1, N}, a2::AbstractArray{T2, N}) where {T1, T2, N} = -dot(a1.array, a2)
LinearAlgebra.dot(a1::AbstractArray{T1, N}, a2::NegatingArray{T2, N}) where {T1, T2, N} = -dot(a1, a2.array)
Base.sum(a::NegatingArray) = -sum(a.array)
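# Example: dot products with a `NegatingArray` flip sign without materializing `-array`.
#
#     a = [1.0, 2.0]; na = NegatingArray(a)
#     na[2]        # == -2.0
#     dot(na, a)   # == -5.0
#     sum(na)      # == -3.0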
function weight_purge_threshold_default(::Type{T}) where {T<:AbstractFloat}
    return sqrt(eps(T) * Base.rtoldefault(T)) # around 1e-12 for Float64
end
weight_purge_threshold_default(::Type{T}) where {T<:Number} = Base.rtoldefault(T)