Skip to content

Commit a15fa64

Browse files
authored
fix: Avoid overwriting local contexts with retry decorator (#479)
* Avoid overwriting local contexts with retry decorator * Add reno release note
1 parent ee6a8f7 commit a15fa64

File tree

4 files changed

+139
-6
lines changed

4 files changed

+139
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Avoid overwriting local contexts when applying the retry decorator.

tenacity/__init__.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,19 @@ def wraps(self, f: WrappedFn) -> WrappedFn:
329329
f, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
330330
)
331331
def wrapped_f(*args: t.Any, **kw: t.Any) -> t.Any:
332-
return self(f, *args, **kw)
332+
# Always create a copy to prevent overwriting the local contexts when
333+
# calling the same wrapped functions multiple times in the same stack
334+
copy = self.copy()
335+
wrapped_f.statistics = copy.statistics # type: ignore[attr-defined]
336+
return copy(f, *args, **kw)
333337

334338
def retry_with(*args: t.Any, **kwargs: t.Any) -> WrappedFn:
335339
return self.copy(*args, **kwargs).wraps(f)
336340

337-
wrapped_f.retry = self # type: ignore[attr-defined]
341+
# Preserve attributes
342+
wrapped_f.retry = wrapped_f # type: ignore[attr-defined]
338343
wrapped_f.retry_with = retry_with # type: ignore[attr-defined]
344+
wrapped_f.statistics = {} # type: ignore[attr-defined]
339345

340346
return wrapped_f # type: ignore[return-value]
341347

tenacity/asyncio/__init__.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -175,18 +175,23 @@ async def __anext__(self) -> AttemptManager:
175175
raise StopAsyncIteration
176176

177177
def wraps(self, fn: WrappedFn) -> WrappedFn:
178-
fn = super().wraps(fn)
178+
wrapped = super().wraps(fn)
179179
# Ensure wrapper is recognized as a coroutine function.
180180

181181
@functools.wraps(
182182
fn, functools.WRAPPER_ASSIGNMENTS + ("__defaults__", "__kwdefaults__")
183183
)
184184
async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
185-
return await fn(*args, **kwargs)
185+
# Always create a copy to prevent overwriting the local contexts when
186+
# calling the same wrapped functions multiple times in the same stack
187+
copy = self.copy()
188+
async_wrapped.statistics = copy.statistics # type: ignore[attr-defined]
189+
return await copy(fn, *args, **kwargs)
186190

187191
# Preserve attributes
188-
async_wrapped.retry = fn.retry # type: ignore[attr-defined]
189-
async_wrapped.retry_with = fn.retry_with # type: ignore[attr-defined]
192+
async_wrapped.retry = async_wrapped # type: ignore[attr-defined]
193+
async_wrapped.retry_with = wrapped.retry_with # type: ignore[attr-defined]
194+
async_wrapped.statistics = {} # type: ignore[attr-defined]
190195

191196
return async_wrapped # type: ignore[return-value]
192197

tests/test_issue_478.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import asyncio
2+
import typing
3+
import unittest
4+
5+
from functools import wraps
6+
7+
from tenacity import RetryCallState, retry
8+
9+
10+
def asynctest(
11+
callable_: typing.Callable[..., typing.Any],
12+
) -> typing.Callable[..., typing.Any]:
13+
@wraps(callable_)
14+
def wrapper(*a: typing.Any, **kw: typing.Any) -> typing.Any:
15+
loop = asyncio.get_event_loop()
16+
return loop.run_until_complete(callable_(*a, **kw))
17+
18+
return wrapper
19+
20+
21+
MAX_RETRY_FIX_ATTEMPTS = 2
22+
23+
24+
class TestIssue478(unittest.TestCase):
25+
def test_issue(self) -> None:
26+
results = []
27+
28+
def do_retry(retry_state: RetryCallState) -> bool:
29+
outcome = retry_state.outcome
30+
assert outcome
31+
ex = outcome.exception()
32+
_subject_: str = retry_state.args[0]
33+
34+
if _subject_ == "Fix": # no retry on fix failure
35+
return False
36+
37+
if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS:
38+
return False
39+
40+
if ex:
41+
do_fix_work()
42+
return True
43+
44+
return False
45+
46+
@retry(reraise=True, retry=do_retry)
47+
def _do_work(subject: str) -> None:
48+
if subject == "Error":
49+
results.append(f"{subject} is not working")
50+
raise Exception(f"{subject} is not working")
51+
results.append(f"{subject} is working")
52+
53+
def do_any_work(subject: str) -> None:
54+
_do_work(subject)
55+
56+
def do_fix_work() -> None:
57+
_do_work("Fix")
58+
59+
try:
60+
do_any_work("Error")
61+
except Exception as exc:
62+
assert str(exc) == "Error is not working"
63+
else:
64+
assert False, "No exception caught"
65+
66+
assert results == [
67+
"Error is not working",
68+
"Fix is working",
69+
"Error is not working",
70+
]
71+
72+
@asynctest
73+
async def test_async(self) -> None:
74+
results = []
75+
76+
async def do_retry(retry_state: RetryCallState) -> bool:
77+
outcome = retry_state.outcome
78+
assert outcome
79+
ex = outcome.exception()
80+
_subject_: str = retry_state.args[0]
81+
82+
if _subject_ == "Fix": # no retry on fix failure
83+
return False
84+
85+
if retry_state.attempt_number >= MAX_RETRY_FIX_ATTEMPTS:
86+
return False
87+
88+
if ex:
89+
await do_fix_work()
90+
return True
91+
92+
return False
93+
94+
@retry(reraise=True, retry=do_retry)
95+
async def _do_work(subject: str) -> None:
96+
if subject == "Error":
97+
results.append(f"{subject} is not working")
98+
raise Exception(f"{subject} is not working")
99+
results.append(f"{subject} is working")
100+
101+
async def do_any_work(subject: str) -> None:
102+
await _do_work(subject)
103+
104+
async def do_fix_work() -> None:
105+
await _do_work("Fix")
106+
107+
try:
108+
await do_any_work("Error")
109+
except Exception as exc:
110+
assert str(exc) == "Error is not working"
111+
else:
112+
assert False, "No exception caught"
113+
114+
assert results == [
115+
"Error is not working",
116+
"Fix is working",
117+
"Error is not working",
118+
]

0 commit comments

Comments
 (0)