@@ -51,6 +51,7 @@ function decomposition_invariant_conditional_gradient(
51
51
traj_data= [],
52
52
timeout= Inf ,
53
53
lazy= false ,
54
+ use_strong_lazy = false ,
54
55
linesearch_workspace= nothing ,
55
56
lazy_tolerance= 2.0 ,
56
57
extra_vertex_storage= nothing ,
@@ -169,14 +170,26 @@ function decomposition_invariant_conditional_gradient(
169
170
170
171
if lazy
171
172
d, v, v_index, a, away_index, phi, step_type =
172
- lazy_dicg_step (x, gradient, lmo, pre_computed_set, phi, epsilon, d;)
173
+ lazy_standard_dicg_step (
174
+ x,
175
+ gradient,
176
+ lmo,
177
+ pre_computed_set,
178
+ phi,
179
+ epsilon,
180
+ d;
181
+ strong_lazification = use_strong_lazy,
182
+ lazy_tolerance = lazy_tolerance,
183
+ )
173
184
else # non-lazy, call the simple and modified
174
185
v = compute_extreme_point (lmo, gradient, lazy= lazy)
175
186
dual_gap = fast_dot (gradient, x) - fast_dot (gradient, v)
176
187
phi = dual_gap
177
188
a = compute_inface_extreme_point (lmo, NegatingArray (gradient), x; lazy= lazy)
189
+ d = muladd_memory_mode (memory_mode, d, a, v)
190
+ step_type = ST_PAIRWISE
178
191
end
179
- d = muladd_memory_mode (memory_mode, d, a, v)
192
+
180
193
gamma_max = dicg_maximum_step (lmo, d, x)
181
194
gamma = perform_line_search (
182
195
line_search,
@@ -190,7 +203,7 @@ function decomposition_invariant_conditional_gradient(
190
203
linesearch_workspace,
191
204
memory_mode,
192
205
)
193
-
206
+
194
207
if lazy
195
208
idx = findfirst (x -> x == v, pre_computed_set)
196
209
if idx != = nothing
@@ -282,6 +295,7 @@ function blended_decomposition_invariant_conditional_gradient(
282
295
lazy= false ,
283
296
linesearch_workspace= nothing ,
284
297
lazy_tolerance= 2.0 ,
298
+ extra_vertex_storage = nothing ,
285
299
)
286
300
287
301
if ! is_decomposition_invariant_oracle (lmo)
@@ -353,6 +367,15 @@ function blended_decomposition_invariant_conditional_gradient(
353
367
phi = primal
354
368
gamma = one (phi)
355
369
370
+ if lazy
371
+ if extra_vertex_storage === nothing
372
+ v = compute_extreme_point (lmo, gradient, lazy = lazy)
373
+ pre_computed_set = [v]
374
+ else
375
+ pre_computed_set = extra_vertex_storage
376
+ end
377
+ end
378
+
356
379
if linesearch_workspace === nothing
357
380
linesearch_workspace = build_linesearch_workspace (line_search, x, gradient)
358
381
end
@@ -385,7 +408,18 @@ function blended_decomposition_invariant_conditional_gradient(
385
408
end
386
409
387
410
if lazy
388
- error (" not implemented yet" )
411
+ d, v, v_index, a, away_index, phi, step_type =
412
+ lazy_blended_dicg_step (
413
+ x,
414
+ gradient,
415
+ lmo,
416
+ pre_computed_set,
417
+ phi,
418
+ epsilon,
419
+ d;
420
+ strong_lazification = use_strong_lazy,
421
+ lazy_tolerance = lazy_tolerance,
422
+ )
389
423
else # non-lazy, call the simple and modified
390
424
a = compute_inface_extreme_point (lmo, NegatingArray (gradient), x; lazy= lazy)
391
425
v_inface = compute_inface_extreme_point (lmo, gradient, x; lazy= lazy)
@@ -404,6 +438,11 @@ function blended_decomposition_invariant_conditional_gradient(
404
438
gamma_max = one (phi)
405
439
end
406
440
end
441
+ if step_type == ST_REGULAR
442
+ gamma_max = one (phi)
443
+ else
444
+ gamma_max = dicg_maximum_step (lmo, d, x)
445
+ end
407
446
gamma = perform_line_search (
408
447
line_search,
409
448
t,
@@ -475,60 +514,168 @@ function blended_decomposition_invariant_conditional_gradient(
475
514
return (x= x, v= v, primal= primal, dual_gap= dual_gap, traj_data= traj_data)
476
515
end
477
516
478
- function lazy_dicg_step (
517
+ """
518
+ Search for both lazified FW vertex and in-face vetex in strong version.
519
+ Otherwise, only search for the lazified FW vertex.
520
+ """
521
+ function lazy_standard_dicg_step (
479
522
x,
480
523
gradient,
481
524
lmo,
482
525
pre_computed_set,
483
526
phi,
484
527
epsilon,
485
528
d;
486
- use_extra_vertex_storage= false ,
487
- extra_vertex_storage= nothing ,
488
- lazy_tolerance= 2.0 ,
489
- memory_mode:: MemoryEmphasis = InplaceEmphasis (),
529
+ strong_lazification = false ,
530
+ lazy_tolerance = 2.0 ,
531
+ memory_mode:: MemoryEmphasis = InplaceEmphasis (),
490
532
)
491
533
v_local, v_local_loc, val, a_local, a_local_loc, valM =
492
- pre_computed_set_argminmax (pre_computed_set, gradient)
493
- step_type = ST_REGULAR
534
+ pre_computed_set_argminmax (lmo, pre_computed_set, gradient, x; strong_lazification = strong_lazification )
535
+ step_type = ST_PAIRWISE
494
536
away_index = nothing
495
537
fw_index = nothing
496
538
grad_dot_x = fast_dot (x, gradient)
497
539
grad_dot_a_local = valM
498
-
499
- # Do lazy pairwise step
500
540
grad_dot_lazy_fw_vertex = val
501
541
502
- if grad_dot_a_local - grad_dot_lazy_fw_vertex >= phi / lazy_tolerance &&
503
- grad_dot_a_local - grad_dot_lazy_fw_vertex >= epsilon
542
+ if strong_lazification
543
+ a_taken = a_local
544
+ grad_dot_a_taken = grad_dot_a_local
545
+ # in-face LMO is called directly
546
+ else
547
+ a_taken = compute_inface_extreme_point (lmo, NegatingArray (gradient), x)
548
+ grad_dot_a_taken = fast_dot (gradient, a_taken)
549
+ end
550
+
551
+ # Do lazy pairwise step
552
+ if grad_dot_a_taken - grad_dot_lazy_fw_vertex >= phi &&
553
+ grad_dot_a_taken - grad_dot_lazy_fw_vertex >= epsilon
504
554
step_type = ST_LAZY
505
555
v = v_local
506
- a = a_local
556
+ a = a_taken
507
557
d = muladd_memory_mode (memory_mode, d, a, v)
508
558
fw_index = v_local_loc
509
559
else
510
560
v = compute_extreme_point (lmo, gradient)
511
561
grad_dot_v = fast_dot (gradient, v)
512
- # Do lazy inface_point
513
- if grad_dot_a_local - grad_dot_v >= phi / lazy_tolerance &&
514
- grad_dot_a_local - grad_dot_v >= epsilon
515
- step_type = ST_LAZY
516
- a = a_local
517
- away_index = a_local_loc
562
+ dual_gap = grad_dot_x - grad_dot_v
563
+
564
+ if grad_dot_a_taken - grad_dot_v >= phi/ lazy_tolerance &&
565
+ grad_dot_a_taken - grad_dot_v >= epsilon
566
+ a = a_taken
567
+ d = muladd_memory_mode (memory_mode, d, a, v)
568
+ step_type = strong_lazification ? ST_LAZY : ST_PAIRWISE
569
+ away_index = strong_lazification ? a_local_loc : nothing
570
+ elseif dual_gap >= phi / lazy_tolerance
571
+ if strong_lazification
572
+ a = compute_inface_extreme_point (lmo, NegatingArray (gradient), x)
573
+ else
574
+ a = a_taken
575
+ end
576
+ d = muladd_memory_mode (memory_mode, d, a, v)
577
+ # lower our expectation
518
578
else
519
- a = compute_inface_extreme_point (lmo, NegatingArray (gradient), x)
579
+ step_type = ST_DUALSTEP
580
+ phi = min (dual_gap, phi / 2.0 )
581
+ a = a_taken
582
+ d = zeros (length (x))
520
583
end
584
+ end
521
585
522
- # Real dual gap promises enough progress.
523
- grad_dot_fw_vertex = fast_dot (v, gradient)
524
- dual_gap = grad_dot_x - grad_dot_fw_vertex
586
+ return d, v, fw_index, a, away_index, phi, step_type
587
+ end
525
588
526
- if dual_gap >= phi / lazy_tolerance
527
- d = muladd_memory_mode (memory_mode, d, a, v)
528
- # Lower our expectation for progress.
589
+ """
590
+ Lazification for Blended DICG.
591
+ Search for in-face vertex and local FW vertex only in strong version.
592
+ """
593
+ function lazy_blended_dicg_step (
594
+ x,
595
+ gradient,
596
+ lmo,
597
+ pre_computed_set,
598
+ phi,
599
+ epsilon,
600
+ d;
601
+ strong_lazification = false ,
602
+ lazy_tolerance = 2.0 ,
603
+ memory_mode:: MemoryEmphasis = InplaceEmphasis (),
604
+ )
605
+ v_local, v_local_loc, val, a_local, a_local_loc, valM =
606
+ pre_computed_set_argminmax (lmo, pre_computed_set, gradient, x; strong_lazification = strong_lazification)
607
+ step_type = ST_PAIRWISE
608
+ away_index = nothing
609
+ fw_index = nothing
610
+ grad_dot_x = fast_dot (x, gradient)
611
+ grad_dot_a_local = valM
612
+ grad_dot_lazy_fw_vertex = val
613
+
614
+ if strong_lazification
615
+ a_taken = a_local
616
+ v_taken = v_local
617
+ grad_dot_a_taken = grad_dot_a_local
618
+ grad_dot_v_taken = grad_dot_lazy_fw_vertex
619
+ else
620
+ a_taken = compute_inface_extreme_point (lmo, NegatingArray (gradient), x)
621
+ v_taken = compute_inface_extreme_point (lmo, gradient, x)
622
+ grad_dot_a_taken = fast_dot (gradient, a_taken)
623
+ grad_dot_v_taken = fast_dot (gradient, v_taken)
624
+ end
625
+
626
+ # Do lazy pairwise step
627
+ if grad_dot_a_taken - grad_dot_v_taken >= phi &&
628
+ grad_dot_a_taken - grad_dot_v_taken >= epsilon
629
+ step_type = ST_LAZY
630
+ v = v_taken
631
+ a = a_taken
632
+ d = muladd_memory_mode (memory_mode, d, a, v)
633
+ fw_index = v_local_loc
634
+ away_index = a_local_loc
635
+ else
636
+ if strong_lazification
637
+ v_inface = compute_inface_extreme_point (lmo, gradient)
638
+ grad_dot_v_inface = fast_dot (gradient, v_inface)
639
+
640
+ if grad_dot_a_taken - grad_dot_v_inface >= phi &&
641
+ grad_dot_a_taken - grad_dot_v_inface >= epsilon
642
+ step_type = ST_LAZY
643
+ v = v_inface
644
+ a = a_taken
645
+ d = muladd_memory_mode (memory_mode, d, a, v)
646
+ away_index = a_local_loc
647
+ end
529
648
else
530
- d = muladd_memory_mode (memory_mode, d, a, v)
531
- phi = min (dual_gap, phi / 2.0 )
649
+ v_inface = v_taken
650
+ grad_dot_v_inface = grad_dot_v_taken
651
+ end
652
+
653
+ if step_type != = ST_LAZY
654
+ v = compute_extreme_point (lmo, gradient)
655
+ grad_dot_v = fast_dot (gradient, v)
656
+ dual_gap = grad_dot_x - grad_dot_v
657
+ if dual_gap >= phi / lazy_tolerance
658
+
659
+ if strong_lazification
660
+ a_taken = compute_inface_extreme_point (lmo, NegatingArray (gradient), x)
661
+ grad_dot_a_taken = fast_dot (gradient, a_taken)
662
+ end
663
+
664
+ if grad_dot_a_taken - grad_dot_v_inface >= grad_dot_x - grad_dot_v / lazy_tolerance
665
+ step_type = ST_PAIRWISE
666
+ a = a_taken
667
+ d = muladd_memory_mode (memory_mode, d, a, v_inface)
668
+ else
669
+ step_type = ST_REGULAR
670
+ a = x
671
+ d = muladd_memory_mode (memory_mode, d, x, v)
672
+ end
673
+ else
674
+ step_type = ST_DUALSTEP
675
+ phi = min (dual_gap, phi / 2.0 )
676
+ a = a_taken
677
+ d = zeros (length (x))
678
+ end
532
679
end
533
680
end
534
681
return d, v, fw_index, a, away_index, phi, step_type
0 commit comments