Skip to content

Commit df4fa13

Browse files
Add retry utility for handling function retries (#278)
* add retry function * make timeout optional * review docstring * add pytest * remove logging_callback and add optional keyword arguments for on_failed * allow for return values * restructure retry * break pytest into multiple short tests, use `pytest.raises` --------- Co-authored-by: Falko Schindler <[email protected]>
1 parent 02d2c14 commit df4fa13

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

rosys/run.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import uuid
88
from collections.abc import Callable, Generator
99
from contextlib import contextmanager
10+
from dataclasses import dataclass
1011
from functools import wraps
12+
from inspect import signature
1113
from pathlib import Path
12-
from typing import ParamSpec, TypeVar
14+
from typing import Any, ParamSpec, TypeVar
1315

1416
from nicegui import run
1517

@@ -143,3 +145,37 @@ def tear_down() -> None:
143145
_kill(process)
144146
running_sh_processes.clear()
145147
log.info('teardown complete.')
148+
149+
150+
@dataclass(slots=True, kw_only=True, frozen=True)
151+
class OnFailedArguments:
152+
attempt: int
153+
max_attempts: int
154+
155+
156+
async def retry(func: Callable, *,
157+
max_attempts: int = 3,
158+
max_timeout: float | None = None,
159+
on_failed: Callable | None = None) -> Any:
160+
"""Call a function repeatedly until it succeeds or reaches the maximum number of attempts.
161+
162+
:param func: A function to retry
163+
:param max_attempts: Maximum number of attempts
164+
:param max_timeout: Optional maximum time in seconds to wait per attempt
165+
:param on_failed: Optional callback to execute after each failed attempt (optional argument of type ``OnFailedArguments``)
166+
:return: Result of the called function
167+
:raises RuntimeError: If all attempts fail
168+
"""
169+
for attempt in range(max_attempts):
170+
try:
171+
return await asyncio.wait_for(func(), timeout=max_timeout)
172+
except Exception:
173+
if on_failed is None:
174+
continue
175+
if signature(on_failed).parameters:
176+
result = on_failed(OnFailedArguments(attempt=attempt, max_attempts=max_attempts))
177+
else:
178+
result = on_failed()
179+
if asyncio.iscoroutinefunction(on_failed):
180+
await result
181+
raise RuntimeError(f'Running {func.__name__} failed.')

tests/test_run.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
3+
from rosys import run
4+
5+
6+
@pytest.mark.asyncio
7+
async def test_retry():
8+
async def func() -> str:
9+
events.append('call')
10+
return 'success'
11+
12+
events: list[str] = []
13+
result = await run.retry(func)
14+
assert result == 'success'
15+
assert events == ['call']
16+
17+
18+
@pytest.mark.asyncio
19+
async def test_retry_failed():
20+
async def func() -> None:
21+
events.append('call')
22+
raise ValueError()
23+
24+
events: list[str] = []
25+
with pytest.raises(RuntimeError):
26+
await run.retry(func)
27+
assert events == ['call'] * 3
28+
29+
30+
@pytest.mark.asyncio
31+
async def test_retry_failed_with_on_failed():
32+
async def func() -> None:
33+
events.append('call')
34+
raise ValueError()
35+
36+
events: list[str] = []
37+
with pytest.raises(RuntimeError):
38+
await run.retry(func, on_failed=lambda: events.append('failed'))
39+
assert events == ['call', 'failed'] * 3
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_retry_failed_with_on_failed_and_args():
44+
async def func() -> None:
45+
events.append('call')
46+
raise ValueError()
47+
48+
events: list[str] = []
49+
with pytest.raises(RuntimeError):
50+
await run.retry(func, on_failed=lambda args: events.append(f'failed {args.attempt}/{args.max_attempts}'))
51+
assert events == ['call', 'failed 0/3',
52+
'call', 'failed 1/3',
53+
'call', 'failed 2/3']

0 commit comments

Comments
 (0)