Skip to content

Commit e55674e

Browse files
lazied_dicg (#532)
1 parent dddb6bf commit e55674e

File tree

5 files changed

+302
-131
lines changed

5 files changed

+302
-131
lines changed

src/dicg.jl

+178-31
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ function decomposition_invariant_conditional_gradient(
5151
traj_data=[],
5252
timeout=Inf,
5353
lazy=false,
54+
use_strong_lazy = false,
5455
linesearch_workspace=nothing,
5556
lazy_tolerance=2.0,
5657
extra_vertex_storage=nothing,
@@ -169,14 +170,26 @@ function decomposition_invariant_conditional_gradient(
169170

170171
if lazy
171172
d, v, v_index, a, away_index, phi, step_type =
172-
lazy_dicg_step(x, gradient, lmo, pre_computed_set, phi, epsilon, d;)
173+
lazy_standard_dicg_step(
174+
x,
175+
gradient,
176+
lmo,
177+
pre_computed_set,
178+
phi,
179+
epsilon,
180+
d;
181+
strong_lazification = use_strong_lazy,
182+
lazy_tolerance = lazy_tolerance,
183+
)
173184
else # non-lazy, call the simple and modified
174185
v = compute_extreme_point(lmo, gradient, lazy=lazy)
175186
dual_gap = fast_dot(gradient, x) - fast_dot(gradient, v)
176187
phi = dual_gap
177188
a = compute_inface_extreme_point(lmo, NegatingArray(gradient), x; lazy=lazy)
189+
d = muladd_memory_mode(memory_mode, d, a, v)
190+
step_type = ST_PAIRWISE
178191
end
179-
d = muladd_memory_mode(memory_mode, d, a, v)
192+
180193
gamma_max = dicg_maximum_step(lmo, d, x)
181194
gamma = perform_line_search(
182195
line_search,
@@ -190,7 +203,7 @@ function decomposition_invariant_conditional_gradient(
190203
linesearch_workspace,
191204
memory_mode,
192205
)
193-
206+
194207
if lazy
195208
idx = findfirst(x -> x == v, pre_computed_set)
196209
if idx !== nothing
@@ -282,6 +295,7 @@ function blended_decomposition_invariant_conditional_gradient(
282295
lazy=false,
283296
linesearch_workspace=nothing,
284297
lazy_tolerance=2.0,
298+
extra_vertex_storage = nothing,
285299
)
286300

287301
if !is_decomposition_invariant_oracle(lmo)
@@ -353,6 +367,15 @@ function blended_decomposition_invariant_conditional_gradient(
353367
phi = primal
354368
gamma = one(phi)
355369

370+
if lazy
371+
if extra_vertex_storage === nothing
372+
v = compute_extreme_point(lmo, gradient, lazy = lazy)
373+
pre_computed_set = [v]
374+
else
375+
pre_computed_set = extra_vertex_storage
376+
end
377+
end
378+
356379
if linesearch_workspace === nothing
357380
linesearch_workspace = build_linesearch_workspace(line_search, x, gradient)
358381
end
@@ -385,7 +408,18 @@ function blended_decomposition_invariant_conditional_gradient(
385408
end
386409

387410
if lazy
388-
error("not implemented yet")
411+
d, v, v_index, a, away_index, phi, step_type =
412+
lazy_blended_dicg_step(
413+
x,
414+
gradient,
415+
lmo,
416+
pre_computed_set,
417+
phi,
418+
epsilon,
419+
d;
420+
strong_lazification = use_strong_lazy,
421+
lazy_tolerance = lazy_tolerance,
422+
)
389423
else # non-lazy, call the simple and modified
390424
a = compute_inface_extreme_point(lmo, NegatingArray(gradient), x; lazy=lazy)
391425
v_inface = compute_inface_extreme_point(lmo, gradient, x; lazy=lazy)
@@ -404,6 +438,11 @@ function blended_decomposition_invariant_conditional_gradient(
404438
gamma_max = one(phi)
405439
end
406440
end
441+
if step_type == ST_REGULAR
442+
gamma_max = one(phi)
443+
else
444+
gamma_max = dicg_maximum_step(lmo, d, x)
445+
end
407446
gamma = perform_line_search(
408447
line_search,
409448
t,
@@ -475,60 +514,168 @@ function blended_decomposition_invariant_conditional_gradient(
475514
return (x=x, v=v, primal=primal, dual_gap=dual_gap, traj_data=traj_data)
476515
end
477516

478-
function lazy_dicg_step(
517+
"""
518+
Search for both the lazified FW vertex and the in-face vertex in the strong version.
519+
Otherwise, only search for the lazified FW vertex.
520+
"""
521+
function lazy_standard_dicg_step(
479522
x,
480523
gradient,
481524
lmo,
482525
pre_computed_set,
483526
phi,
484527
epsilon,
485528
d;
486-
use_extra_vertex_storage=false,
487-
extra_vertex_storage=nothing,
488-
lazy_tolerance=2.0,
489-
memory_mode::MemoryEmphasis=InplaceEmphasis(),
529+
strong_lazification = false,
530+
lazy_tolerance = 2.0,
531+
memory_mode::MemoryEmphasis = InplaceEmphasis(),
490532
)
491533
v_local, v_local_loc, val, a_local, a_local_loc, valM =
492-
pre_computed_set_argminmax(pre_computed_set, gradient)
493-
step_type = ST_REGULAR
534+
pre_computed_set_argminmax(lmo, pre_computed_set, gradient, x; strong_lazification = strong_lazification)
535+
step_type = ST_PAIRWISE
494536
away_index = nothing
495537
fw_index = nothing
496538
grad_dot_x = fast_dot(x, gradient)
497539
grad_dot_a_local = valM
498-
499-
# Do lazy pairwise step
500540
grad_dot_lazy_fw_vertex = val
501541

502-
if grad_dot_a_local - grad_dot_lazy_fw_vertex >= phi / lazy_tolerance &&
503-
grad_dot_a_local - grad_dot_lazy_fw_vertex >= epsilon
542+
if strong_lazification
543+
a_taken = a_local
544+
grad_dot_a_taken = grad_dot_a_local
545+
# otherwise (non-strong version), the in-face LMO is called directly
546+
else
547+
a_taken = compute_inface_extreme_point(lmo, NegatingArray(gradient), x)
548+
grad_dot_a_taken = fast_dot(gradient, a_taken)
549+
end
550+
551+
# Do lazy pairwise step
552+
if grad_dot_a_taken - grad_dot_lazy_fw_vertex >= phi &&
553+
grad_dot_a_taken - grad_dot_lazy_fw_vertex >= epsilon
504554
step_type = ST_LAZY
505555
v = v_local
506-
a = a_local
556+
a = a_taken
507557
d = muladd_memory_mode(memory_mode, d, a, v)
508558
fw_index = v_local_loc
509559
else
510560
v = compute_extreme_point(lmo, gradient)
511561
grad_dot_v = fast_dot(gradient, v)
512-
# Do lazy inface_point
513-
if grad_dot_a_local - grad_dot_v >= phi / lazy_tolerance &&
514-
grad_dot_a_local - grad_dot_v >= epsilon
515-
step_type = ST_LAZY
516-
a = a_local
517-
away_index = a_local_loc
562+
dual_gap = grad_dot_x - grad_dot_v
563+
564+
if grad_dot_a_taken - grad_dot_v >= phi/lazy_tolerance &&
565+
grad_dot_a_taken - grad_dot_v >= epsilon
566+
a = a_taken
567+
d = muladd_memory_mode(memory_mode, d, a, v)
568+
step_type = strong_lazification ? ST_LAZY : ST_PAIRWISE
569+
away_index = strong_lazification ? a_local_loc : nothing
570+
elseif dual_gap >= phi / lazy_tolerance
571+
if strong_lazification
572+
a = compute_inface_extreme_point(lmo, NegatingArray(gradient), x)
573+
else
574+
a = a_taken
575+
end
576+
d = muladd_memory_mode(memory_mode, d, a, v)
577+
# otherwise, lower our expectation for progress
518578
else
519-
a = compute_inface_extreme_point(lmo, NegatingArray(gradient), x)
579+
step_type = ST_DUALSTEP
580+
phi = min(dual_gap, phi / 2.0)
581+
a = a_taken
582+
d = zeros(length(x))
520583
end
584+
end
521585

522-
# Real dual gap promises enough progress.
523-
grad_dot_fw_vertex = fast_dot(v, gradient)
524-
dual_gap = grad_dot_x - grad_dot_fw_vertex
586+
return d, v, fw_index, a, away_index, phi, step_type
587+
end
525588

526-
if dual_gap >= phi / lazy_tolerance
527-
d = muladd_memory_mode(memory_mode, d, a, v)
528-
#Lower our expectation for progress.
589+
"""
590+
Lazification for Blended DICG.
591+
In the strong version, both the in-face vertex and the local FW vertex are taken from the pre-computed set; otherwise they are computed with the in-face LMO.
592+
"""
593+
function lazy_blended_dicg_step(
594+
x,
595+
gradient,
596+
lmo,
597+
pre_computed_set,
598+
phi,
599+
epsilon,
600+
d;
601+
strong_lazification = false,
602+
lazy_tolerance = 2.0,
603+
memory_mode::MemoryEmphasis = InplaceEmphasis(),
604+
)
605+
v_local, v_local_loc, val, a_local, a_local_loc, valM =
606+
pre_computed_set_argminmax(lmo, pre_computed_set, gradient, x; strong_lazification = strong_lazification)
607+
step_type = ST_PAIRWISE
608+
away_index = nothing
609+
fw_index = nothing
610+
grad_dot_x = fast_dot(x, gradient)
611+
grad_dot_a_local = valM
612+
grad_dot_lazy_fw_vertex = val
613+
614+
if strong_lazification
615+
a_taken = a_local
616+
v_taken = v_local
617+
grad_dot_a_taken = grad_dot_a_local
618+
grad_dot_v_taken = grad_dot_lazy_fw_vertex
619+
else
620+
a_taken = compute_inface_extreme_point(lmo, NegatingArray(gradient), x)
621+
v_taken = compute_inface_extreme_point(lmo, gradient, x)
622+
grad_dot_a_taken = fast_dot(gradient, a_taken)
623+
grad_dot_v_taken = fast_dot(gradient, v_taken)
624+
end
625+
626+
# Do lazy pairwise step
627+
if grad_dot_a_taken - grad_dot_v_taken >= phi &&
628+
grad_dot_a_taken - grad_dot_v_taken >= epsilon
629+
step_type = ST_LAZY
630+
v = v_taken
631+
a = a_taken
632+
d = muladd_memory_mode(memory_mode, d, a, v)
633+
fw_index = v_local_loc
634+
away_index = a_local_loc
635+
else
636+
if strong_lazification
637+
v_inface = compute_inface_extreme_point(lmo, gradient)
638+
grad_dot_v_inface = fast_dot(gradient, v_inface)
639+
640+
if grad_dot_a_taken - grad_dot_v_inface >= phi &&
641+
grad_dot_a_taken - grad_dot_v_inface >= epsilon
642+
step_type = ST_LAZY
643+
v = v_inface
644+
a = a_taken
645+
d = muladd_memory_mode(memory_mode, d, a, v)
646+
away_index = a_local_loc
647+
end
529648
else
530-
d = muladd_memory_mode(memory_mode, d, a, v)
531-
phi = min(dual_gap, phi / 2.0)
649+
v_inface = v_taken
650+
grad_dot_v_inface = grad_dot_v_taken
651+
end
652+
653+
if step_type !== ST_LAZY
654+
v = compute_extreme_point(lmo, gradient)
655+
grad_dot_v = fast_dot(gradient, v)
656+
dual_gap = grad_dot_x - grad_dot_v
657+
if dual_gap >= phi / lazy_tolerance
658+
659+
if strong_lazification
660+
a_taken = compute_inface_extreme_point(lmo, NegatingArray(gradient), x)
661+
grad_dot_a_taken = fast_dot(gradient, a_taken)
662+
end
663+
664+
if grad_dot_a_taken - grad_dot_v_inface >= grad_dot_x - grad_dot_v / lazy_tolerance
665+
step_type = ST_PAIRWISE
666+
a = a_taken
667+
d = muladd_memory_mode(memory_mode, d, a, v_inface)
668+
else
669+
step_type = ST_REGULAR
670+
a = x
671+
d = muladd_memory_mode(memory_mode, d, x, v)
672+
end
673+
else
674+
step_type = ST_DUALSTEP
675+
phi = min(dual_gap, phi / 2.0)
676+
a = a_taken
677+
d = zeros(length(x))
678+
end
532679
end
533680
end
534681
return d, v, fw_index, a, away_index, phi, step_type

0 commit comments

Comments
 (0)