# SPDX-FileCopyrightText: 2022 Hynek Schlawack <hs@ox.cx>
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import datetime as dt
import sys
from collections.abc import Callable
from dataclasses import dataclass, replace
from functools import wraps
from inspect import iscoroutinefunction
from types import TracebackType
from typing import AsyncIterator, Awaitable, Iterator, TypedDict, TypeVar
import tenacity as _t
from ._config import CONFIG, _Config
from .instrumentation._data import RetryDetails, guess_name
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
try:
from sniffio import current_async_library
except ImportError: # pragma: no cover -- we always have sniffio in tests
def current_async_library() -> str:
return "asyncio"
async def _smart_sleep(delay: float) -> None:
io = current_async_library()
if io == "asyncio":
import asyncio
await asyncio.sleep(delay)
elif io == "trio":
import trio
await trio.sleep(delay)
else: # pragma: no cover
msg = f"Unknown async library: {io!r}"
raise RuntimeError(msg)
T = TypeVar("T")
P = ParamSpec("P")
[docs]
def retry_context(
on: type[Exception] | tuple[type[Exception], ...],
attempts: int | None = 10,
timeout: float | dt.timedelta | None = 45.0,
wait_initial: float | dt.timedelta = 0.1,
wait_max: float | dt.timedelta = 5.0,
wait_jitter: float | dt.timedelta = 1.0,
wait_exp_base: float = 2.0,
) -> _RetryContextIterator:
"""
Iterator that yields context managers that can be used to retry code
blocks.
Arguments have the same meaning as for :func:`stamina.retry`.
.. versionadded:: 23.1.0
.. versionadded:: 23.3.0 `Trio <https://trio.readthedocs.io/>`_ support.
"""
return _RetryContextIterator.from_params(
on=on,
attempts=attempts,
timeout=timeout,
wait_initial=wait_initial,
wait_max=wait_max,
wait_jitter=wait_jitter,
wait_exp_base=wait_exp_base,
name="<context block>",
args=(),
kw={},
)
[docs]
class Attempt:
"""
A context manager that can be used to retry code blocks.
Instances are yielded by the :func:`stamina.retry_context` iterator.
.. versionadded:: 23.2.0
"""
__slots__ = ("_t_attempt",)
_t_attempt: _t.AttemptManager
def __init__(self, attempt: _t.AttemptManager):
self._t_attempt = attempt
def __repr__(self) -> str:
return f"<Attempt num={self.num}>"
@property
def num(self) -> int:
"""
The number of the current attempt.
"""
return self._t_attempt.retry_state.attempt_number # type: ignore[no-any-return]
def __enter__(self) -> None:
return self._t_attempt.__enter__() # type: ignore[no-any-return]
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
return self._t_attempt.__exit__( # type: ignore[no-any-return]
exc_type, exc_value, traceback
)
class RetryKWs(TypedDict):
attempts: int | None
timeout: float | dt.timedelta | None
wait_initial: float | dt.timedelta
wait_max: float | dt.timedelta
wait_jitter: float | dt.timedelta
wait_exp_base: float
class BaseRetryingCaller:
"""
Simple base class that transforms retry parameters into a dictionary that
can be `**`-passed into `retry_context`.
"""
__slots__ = ("_context_kws",)
_context_kws: RetryKWs
def __init__(
self,
attempts: int | None = 10,
timeout: float | dt.timedelta | None = 45.0,
wait_initial: float | dt.timedelta = 0.1,
wait_max: float | dt.timedelta = 5.0,
wait_jitter: float | dt.timedelta = 1.0,
wait_exp_base: float = 2.0,
):
self._context_kws = {
"attempts": attempts,
"timeout": timeout,
"wait_initial": wait_initial,
"wait_max": wait_max,
"wait_jitter": wait_jitter,
"wait_exp_base": wait_exp_base,
}
def __repr__(self) -> str:
kws = ", ".join(
f"{k}={self._context_kws[k]!r}" # type: ignore[literal-required]
for k in sorted(self._context_kws)
if k != "on"
)
return f"<{self.__class__.__name__}({kws})>"
[docs]
class RetryingCaller(BaseRetryingCaller):
"""
Call your callables with retries.
Arguments have the same meaning as for :func:`stamina.retry`.
Tip:
Instances of ``RetryingCaller`` may be reused because they internally
create a new :func:`retry_context` iterator on each call.
.. versionadded:: 24.2.0
"""
[docs]
def __call__(
self,
on: type[Exception] | tuple[type[Exception], ...],
callable_: Callable[P, T],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> T:
r"""
Call ``callable_(*args, **kw)`` with retries if *on* is raised.
Args:
on: Exception(s) to retry on.
callable\_: Callable to call.
args: Positional arguments to pass to *callable_*.
kw: Keyword arguments to pass to *callable_*.
"""
for attempt in retry_context(on, **self._context_kws):
with attempt:
return callable_(*args, **kwargs)
raise SystemError("unreachable") # noqa: EM101
[docs]
def on(
self, on: type[Exception] | tuple[type[Exception], ...], /
) -> BoundRetryingCaller:
"""
Create a new instance of :class:`BoundRetryingCaller` with the same
parameters, but bound to a specific exception type.
.. versionadded:: 24.2.0
"""
# This should be a `functools.partial`, but unfortunately it's
# impossible to provide a nicely typed API with it, so we use a
# separate class.
return BoundRetryingCaller(self, on)
[docs]
class BoundRetryingCaller:
"""
Same as :class:`RetryingCaller`, but pre-bound to a specific exception
type.
Caution:
Returned by :meth:`RetryingCaller.on` -- do not instantiate directly.
.. versionadded:: 24.2.0
"""
__slots__ = ("_caller", "_on")
_caller: RetryingCaller
_on: type[Exception] | tuple[type[Exception], ...]
def __init__(
self,
caller: RetryingCaller,
on: type[Exception] | tuple[type[Exception], ...],
):
self._caller = caller
self._on = on
def __repr__(self) -> str:
return (
f"<BoundRetryingCaller({guess_name(self._on)}, {self._caller!r})>"
)
[docs]
def __call__(
self, callable_: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs
) -> T:
"""
Same as :func:`RetryingCaller.__call__`, except retry on the exception
that is bound to this instance.
"""
return self._caller(self._on, callable_, *args, **kwargs)
[docs]
class AsyncRetryingCaller(BaseRetryingCaller):
"""
Same as :class:`RetryingCaller`, but for async callables.
.. versionadded:: 24.2.0
"""
[docs]
async def __call__(
self,
on: type[Exception] | tuple[type[Exception], ...],
callable_: Callable[P, Awaitable[T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> T:
"""
Same as :meth:`RetryingCaller.__call__`, but *callable_* is awaited.
"""
async for attempt in retry_context(on, **self._context_kws):
with attempt:
return await callable_(*args, **kwargs)
raise SystemError("unreachable") # noqa: EM101
[docs]
def on(
self, on: type[Exception] | tuple[type[Exception], ...], /
) -> BoundAsyncRetryingCaller:
"""
Create a new instance of :class:`BoundAsyncRetryingCaller` with the
same parameters, but bound to a specific exception type.
.. versionadded:: 24.2.0
"""
return BoundAsyncRetryingCaller(self, on)
[docs]
class BoundAsyncRetryingCaller:
"""
Same as :class:`BoundRetryingCaller`, but for async callables.
Caution:
Returned by :meth:`AsyncRetryingCaller.on` -- do not instantiate
directly.
.. versionadded:: 24.2.0
"""
__slots__ = ("_caller", "_on")
_caller: AsyncRetryingCaller
_on: type[Exception] | tuple[type[Exception], ...]
def __init__(
self,
caller: AsyncRetryingCaller,
on: type[Exception] | tuple[type[Exception], ...],
):
self._caller = caller
self._on = on
def __repr__(self) -> str:
return f"<BoundAsyncRetryingCaller({guess_name(self._on)}, {self._caller!r})>"
[docs]
async def __call__(
self,
callable_: Callable[P, Awaitable[T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> T:
"""
Same as :func:`AsyncRetryingCaller.__call__`, except retry on the
exception that is bound to this instance.
"""
return await self._caller(self._on, callable_, *args, **kwargs)
_STOP_NO_RETRY = _t.stop_after_attempt(1)
class _LazyNoAsyncRetry:
"""
Allows us a free null object pattern using non-retries and avoid None.
"""
__slots__ = ()
def __aiter__(self) -> _t.AsyncRetrying:
return _t.AsyncRetrying(
reraise=True, stop=_STOP_NO_RETRY, sleep=_smart_sleep
).__aiter__()
_LAZY_NO_ASYNC_RETRY = _LazyNoAsyncRetry()
@dataclass
class _RetryContextIterator:
__slots__ = ("_t_kw", "_t_a_retrying", "_name", "_args", "_kw")
_t_kw: dict[str, object]
_t_a_retrying: _t.AsyncRetrying
_name: str
_args: tuple[object, ...]
_kw: dict[str, object]
@classmethod
def from_params(
cls,
on: type[Exception] | tuple[type[Exception], ...],
attempts: int | None,
timeout: float | dt.timedelta | None,
wait_initial: float | dt.timedelta,
wait_max: float | dt.timedelta,
wait_jitter: float | dt.timedelta,
wait_exp_base: float,
name: str,
args: tuple[object, ...],
kw: dict[str, object],
) -> _RetryContextIterator:
return cls(
_name=name,
_args=args,
_kw=kw,
_t_kw={
"retry": _t.retry_if_exception_type(on),
"wait": _t.wait_exponential_jitter(
initial=(
wait_initial.total_seconds()
if isinstance(wait_initial, dt.timedelta)
else wait_initial
),
max=(
wait_max.total_seconds()
if isinstance(wait_max, dt.timedelta)
else wait_max
),
exp_base=wait_exp_base,
jitter=(
wait_jitter.total_seconds()
if isinstance(wait_jitter, dt.timedelta)
else wait_jitter
),
),
"stop": _make_stop(
attempts=attempts,
timeout=(
timeout.total_seconds()
if isinstance(timeout, dt.timedelta)
else timeout
),
),
"reraise": True,
},
_t_a_retrying=_LAZY_NO_ASYNC_RETRY,
)
def with_name(
self, name: str, args: tuple[object, ...], kw: dict[str, object]
) -> _RetryContextIterator:
"""
Recreate ourselves with a new name and arguments.
"""
return replace(self, _name=name, _args=args, _kw=kw)
def __iter__(self) -> Iterator[Attempt]:
if not CONFIG.is_active:
for r in _t.Retrying(reraise=True, stop=_STOP_NO_RETRY):
yield Attempt(r)
return
for r in _t.Retrying(
before_sleep=_make_before_sleep(
self._name, CONFIG, self._args, self._kw
),
**self._t_kw,
):
yield Attempt(r)
def __aiter__(self) -> AsyncIterator[Attempt]:
if CONFIG.is_active:
self._t_a_retrying = _t.AsyncRetrying(
sleep=_smart_sleep,
before_sleep=_make_before_sleep(
self._name, CONFIG, self._args, self._kw
),
**self._t_kw,
)
self._t_a_retrying = self._t_a_retrying.__aiter__()
return self
async def __anext__(self) -> Attempt:
return Attempt(await self._t_a_retrying.__anext__())
def _make_before_sleep(
name: str,
config: _Config,
args: tuple[object, ...],
kw: dict[str, object],
) -> Callable[[_t.RetryCallState], None]:
"""
Create a `before_sleep` callback function that runs our `RetryHook`s with
the necessary arguments.
"""
last_idle_for = 0.0
def before_sleep(rcs: _t.RetryCallState) -> None:
nonlocal last_idle_for
wait_for = rcs.idle_for - last_idle_for
details = RetryDetails(
name=name,
retry_num=rcs.attempt_number,
wait_for=wait_for,
waited_so_far=rcs.idle_for - wait_for,
caused_by=rcs.outcome.exception(),
args=args,
kwargs=kw,
)
for hook in config.on_retry:
hook(details)
last_idle_for = rcs.idle_for
return before_sleep
def _make_stop(*, attempts: int | None, timeout: float | None) -> _t.stop_base:
"""
Combine *attempts* and *timeout* into one stop condition.
"""
stops = []
if attempts:
stops.append(_t.stop_after_attempt(attempts))
if timeout:
stops.append(_t.stop_after_delay(timeout))
if len(stops) > 1:
return _t.stop_any(*stops)
if not stops:
return _t.stop_never
return stops[0]
[docs]
def retry(
*,
on: type[Exception] | tuple[type[Exception], ...],
attempts: int | None = 10,
timeout: float | dt.timedelta | None = 45.0,
wait_initial: float | dt.timedelta = 0.1,
wait_max: float | dt.timedelta = 5.0,
wait_jitter: float | dt.timedelta = 1.0,
wait_exp_base: float = 2.0,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
r"""
Retry if one of configured exceptions are raised.
The backoff delays between retries grow exponentially plus a random jitter.
The backoff for retry attempt number *attempt* is computed as:
.. keep in-sync with docs/motivation.md
.. math::
min(wait\_max, wait\_initial * wait\_exp\_base^{attempt - 1} + random(0, wait\_jitter))
Since :math:`x^0` is always 1, the first backoff is within the interval
:math:`[wait\_initial,wait\_initial+wait\_jitter]`. Thus, with default
values between 0.1 and 1.1 seconds.
If all retries fail, the *last* exception is let through.
All float-based time parameters are in seconds.
Args:
on:
An Exception or a tuple of Exceptions on which the decorated
callable will be retried. There is no default -- you *must* pass
this explicitly.
attempts:
Maximum total number of attempts. Can be combined with *timeout*.
timeout:
Maximum total time for all retries. Can be combined with
*attempts*.
wait_initial: Minimum backoff before the *first* retry.
wait_max: Maximum backoff time between retries at any time.
wait_jitter:
Maximum *jitter* that is added to retry back-off delays (the actual
jitter added is a random number between 0 and *wait_jitter*)
wait_exp_base: The exponential base used to compute the retry backoff.
.. versionchanged:: 23.1.0
All time-related parameters can now be specified as a
:class:`datetime.timedelta`.
.. versionadded:: 23.3.0 `Trio <https://trio.readthedocs.io/>`_ support.
"""
retry_ctx = _RetryContextIterator.from_params(
on=on,
attempts=attempts,
timeout=timeout,
wait_initial=wait_initial,
wait_max=wait_max,
wait_jitter=wait_jitter,
wait_exp_base=wait_exp_base,
name="<unknown>",
args=(),
kw={},
)
def retry_decorator(wrapped: Callable[P, T]) -> Callable[P, T]:
name = guess_name(wrapped)
if not iscoroutinefunction(wrapped):
@wraps(wrapped)
def sync_inner(*args: P.args, **kw: P.kwargs) -> T: # type: ignore[return]
for attempt in retry_ctx.with_name( # noqa: RET503
name, args, kw
):
with attempt:
return wrapped(*args, **kw)
return sync_inner
@wraps(wrapped)
async def async_inner(*args: P.args, **kw: P.kwargs) -> T: # type: ignore[return]
async for attempt in retry_ctx.with_name( # noqa: RET503
name, args, kw
):
with attempt:
return await wrapped(*args, **kw) # type: ignore[no-any-return]
return async_inner # type: ignore[return-value]
return retry_decorator