Skip to content

Commit bc0a03b

Browse files
Merge pull request #119762 from AxeZhan/PollUntilContextCancel
wait.PollUntilContextCancel immediately executes condition once Kubernetes-commit: 227d1b2357d93a6884addccb50122df16674ca95
2 parents 16d50e6 + 5916a9f commit bc0a03b

File tree

2 files changed

+97
-55
lines changed

2 files changed

+97
-55
lines changed

pkg/util/wait/loop.go

+18-20
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
4040
var timeCh <-chan time.Time
4141
doneCh := ctx.Done()
4242

43+
if !sliding {
44+
timeCh = t.C()
45+
}
46+
4347
// if immediate is true the condition is
4448
// guaranteed to be executed at least once,
4549
// if we haven't requested immediate execution, delay once
@@ -50,17 +54,27 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
5054
}(); err != nil || ok {
5155
return err
5256
}
53-
} else {
57+
}
58+
59+
if sliding {
5460
timeCh = t.C()
61+
}
62+
63+
for {
64+
65+
// Wait for either the context to be cancelled or the next invocation be called
5566
select {
5667
case <-doneCh:
5768
return ctx.Err()
5869
case <-timeCh:
5970
}
60-
}
6171

62-
for {
63-
// checking ctx.Err() is slightly faster than checking a select
72+
// IMPORTANT: Because there is no channel priority selection in golang
73+
// it is possible for very short timers to "win" the race in the previous select
74+
// repeatedly even when the context has been canceled. We therefore must
75+
// explicitly check for context cancellation on every loop and exit if true to
76+
// guarantee that we don't invoke condition more than once after context has
77+
// been cancelled.
6478
if err := ctx.Err(); err != nil {
6579
return err
6680
}
@@ -77,21 +91,5 @@ func loopConditionUntilContext(ctx context.Context, t Timer, immediate, sliding
7791
if sliding {
7892
t.Next()
7993
}
80-
81-
if timeCh == nil {
82-
timeCh = t.C()
83-
}
84-
85-
// NOTE: b/c there is no priority selection in golang
86-
// it is possible for this to race, meaning we could
87-
// trigger t.C and doneCh, and t.C select falls through.
88-
// In order to mitigate we re-check doneCh at the beginning
89-
// of every loop to guarantee at-most one extra execution
90-
// of condition.
91-
select {
92-
case <-doneCh:
93-
return ctx.Err()
94-
case <-timeCh:
95-
}
9694
}
9795
}

pkg/util/wait/loop_test.go

+79-35
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ func Test_loopConditionUntilContext_semantic(t *testing.T) {
9999
cancelContextAfter int
100100
attemptsExpected int
101101
errExpected error
102+
timer Timer
102103
}{
103104
{
104105
name: "condition successful is only one attempt",
@@ -203,45 +204,88 @@ func Test_loopConditionUntilContext_semantic(t *testing.T) {
203204
attemptsExpected: 0,
204205
errExpected: context.DeadlineExceeded,
205206
},
207+
{
208+
name: "context canceled before the second execution and immediate",
209+
immediate: true,
210+
context: func() (context.Context, context.CancelFunc) {
211+
return context.WithTimeout(context.Background(), time.Second)
212+
},
213+
callback: func(attempts int) (bool, error) {
214+
return false, nil
215+
},
216+
attemptsExpected: 1,
217+
errExpected: context.DeadlineExceeded,
218+
timer: Backoff{Duration: 2 * time.Second}.Timer(),
219+
},
220+
{
221+
name: "immediate and long duration of condition and sliding false",
222+
immediate: true,
223+
sliding: false,
224+
context: func() (context.Context, context.CancelFunc) {
225+
return context.WithTimeout(context.Background(), time.Second)
226+
},
227+
callback: func(attempts int) (bool, error) {
228+
if attempts >= 4 {
229+
return true, nil
230+
}
231+
time.Sleep(time.Second / 5)
232+
return false, nil
233+
},
234+
attemptsExpected: 4,
235+
timer: Backoff{Duration: time.Second / 5, Jitter: 0.001}.Timer(),
236+
},
237+
{
238+
name: "immediate and long duration of condition and sliding true",
239+
immediate: true,
240+
sliding: true,
241+
context: func() (context.Context, context.CancelFunc) {
242+
return context.WithTimeout(context.Background(), time.Second)
243+
},
244+
callback: func(attempts int) (bool, error) {
245+
if attempts >= 4 {
246+
return true, nil
247+
}
248+
time.Sleep(time.Second / 5)
249+
return false, nil
250+
},
251+
errExpected: context.DeadlineExceeded,
252+
attemptsExpected: 3,
253+
timer: Backoff{Duration: time.Second / 5, Jitter: 0.001}.Timer(),
254+
},
206255
}
207256

208257
for _, test := range tests {
209-
for _, immediate := range []bool{true, false} {
210-
t.Run(fmt.Sprintf("immediate=%t", immediate), func(t *testing.T) {
211-
for _, sliding := range []bool{true, false} {
212-
t.Run(fmt.Sprintf("sliding=%t", sliding), func(t *testing.T) {
213-
t.Run(test.name, func(t *testing.T) {
214-
contextFn := test.context
215-
if contextFn == nil {
216-
contextFn = defaultContext
217-
}
218-
ctx, cancel := contextFn()
219-
defer cancel()
220-
221-
timer := Backoff{Duration: time.Microsecond}.Timer()
222-
attempts := 0
223-
err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) {
224-
attempts++
225-
defer func() {
226-
if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
227-
cancel()
228-
}
229-
}()
230-
return test.callback(attempts)
231-
})
232-
233-
if test.errExpected != err {
234-
t.Errorf("expected error: %v but got: %v", test.errExpected, err)
235-
}
236-
237-
if test.attemptsExpected != attempts {
238-
t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
239-
}
240-
})
241-
})
242-
}
258+
t.Run(test.name, func(t *testing.T) {
259+
contextFn := test.context
260+
if contextFn == nil {
261+
contextFn = defaultContext
262+
}
263+
ctx, cancel := contextFn()
264+
defer cancel()
265+
266+
timer := test.timer
267+
if timer == nil {
268+
timer = Backoff{Duration: time.Microsecond}.Timer()
269+
}
270+
attempts := 0
271+
err := loopConditionUntilContext(ctx, timer, test.immediate, test.sliding, func(_ context.Context) (bool, error) {
272+
attempts++
273+
defer func() {
274+
if test.cancelContextAfter > 0 && test.cancelContextAfter == attempts {
275+
cancel()
276+
}
277+
}()
278+
return test.callback(attempts)
243279
})
244-
}
280+
281+
if test.errExpected != err {
282+
t.Errorf("expected error: %v but got: %v", test.errExpected, err)
283+
}
284+
285+
if test.attemptsExpected != attempts {
286+
t.Errorf("expected attempts count: %d but got: %d", test.attemptsExpected, attempts)
287+
}
288+
})
245289
}
246290
}
247291

0 commit comments

Comments
 (0)