# Source code for restgdf.resilience._retry

"""Stamina-based retry wrapper implementing AsyncHTTPSession (BL-31)."""

from __future__ import annotations

import inspect
from typing import Any

import aiohttp
import stamina

from restgdf._config import ResilienceConfig
from restgdf._logging import get_logger
from restgdf.errors import (
    RateLimitError,
    RestgdfResponseError,
    RestgdfTimeoutError,
    TransportError,
)
from restgdf.resilience._errors import _parse_retry_after
from restgdf.resilience._limiter import CooldownRegistry, LimiterRegistry, _service_root


# Module logger, obtained through restgdf's shared ``get_logger`` helper.
_log = get_logger("retry")

# HTTP statuses that are retried instead of being surfaced immediately:
# 429 (Too Many Requests) plus 500, 502, 503 and 504.  Note that 501 is not
# part of the set.
_RETRYABLE_STATUS: frozenset[int] = frozenset((429, 500, 502, 503, 504))


class _ResponseCtx:
    """Thin async-context-manager wrapping an already-resolved response.

    Entering yields the wrapped response unchanged.  Exiting hands the
    response back through its ``release()`` method when it has one, so a
    response that is discarded (e.g. before a retry) or whose ``async with``
    block has finished does not keep its connection checked out.  Objects
    without ``release()`` are left untouched.  Attribute lookups that miss on
    the wrapper are forwarded to the wrapped response.
    """

    __slots__ = ("_resp",)

    def __init__(self, resp: Any) -> None:
        self._resp = resp

    async def __aenter__(self) -> Any:
        return self._resp

    async def __aexit__(self, *args: Any) -> None:
        release = getattr(self._resp, "release", None)
        if release is None:
            return
        # Accept both a plain ``release()`` and one returning an awaitable
        # (e.g. async test doubles) so no coroutine is left un-awaited.
        outcome = release()
        if inspect.isawaitable(outcome):
            await outcome

    def __getattr__(self, name: str) -> Any:
        # ``__getattr__`` only runs after normal lookup fails.  If the
        # ``_resp`` slot itself is unset (instance created via ``__new__``, as
        # copy/pickle do), forwarding would call ``__getattr__("_resp")``
        # again and recurse until RecursionError; fail cleanly instead.
        if name == "_resp":
            raise AttributeError(name)
        return getattr(self._resp, name)


class ResilientSession:
    """Retry + rate-limit adapter wrapping an inner AsyncHTTPSession.

    If ``config.enabled`` is false, ``get`` and ``post`` delegate straight to
    the inner session.  Otherwise they return a ``_RetriedCtx`` that runs the
    request through the retry loop using this session's cooldown registry
    and, only when ``config.rate_per_service_root_per_second`` is set, its
    limiter registry.
    """

    def __init__(
        self,
        inner: Any,
        config: ResilienceConfig,
    ) -> None:
        self._inner = inner
        self._config = config
        self._cooldown = CooldownRegistry()
        per_second = config.rate_per_service_root_per_second
        # Rate limiting is opt-in: no limiter unless a rate is configured.
        self._limiter: LimiterRegistry | None = (
            None if per_second is None else LimiterRegistry(per_second)
        )

    @property
    def closed(self) -> bool:
        """Report the inner session's ``closed`` flag."""
        return self._inner.closed

    async def close(self) -> None:
        """Close the wrapped inner session."""
        await self._inner.close()

    def get(self, url: str, **kwargs: Any) -> Any:
        """Issue a GET; retried unless resilience is disabled in the config."""
        if self._config.enabled:
            return self._retried_request("get", url, **kwargs)
        return self._inner.get(url, **kwargs)

    def post(self, url: str, **kwargs: Any) -> Any:
        """Issue a POST; retried unless resilience is disabled in the config."""
        if self._config.enabled:
            return self._retried_request("post", url, **kwargs)
        return self._inner.post(url, **kwargs)

    def _retried_request(self, method: str, url: str, **kwargs: Any) -> Any:
        """Build the awaitable / async-context wrapper for a retried call."""
        return _RetriedCtx(self, method, url, kwargs)

    def _reset_limiters(self) -> None:
        """Reset all limiter and cooldown state (for testing)."""
        self._cooldown = CooldownRegistry()
        registry = self._limiter
        if registry is not None:
            registry.reset()
class _RetriedCtx:
    """Dual-interface wrapper: works as ``await session.get(url)`` AND as
    ``async with session.get(url) as resp:``.

    Mirrors :class:`aiohttp.client._RequestContextManager` so
    :class:`ResilientSession` behaves identically to
    :class:`aiohttp.ClientSession` regardless of whether callers use the
    awaitable or async-context-manager pattern.

    :mod:`restgdf.utils._http` awaits the result of ``session.get`` /
    ``session.post`` directly, so this dual shape is required for the helper
    to work against a :class:`ResilientSession`-wrapped inner session.
    """

    __slots__ = ("_session", "_method", "_url", "_kwargs", "_resp", "_resp_ctx")

    def __init__(
        self,
        session: ResilientSession,
        method: str,
        url: str,
        kwargs: dict[str, Any],
    ) -> None:
        self._session = session
        self._method = method
        self._url = url
        self._kwargs = kwargs
        # Filled in by ``_run``: the entered response and the context that
        # owns it (exited again in ``__aexit__``).
        self._resp: Any = None
        self._resp_ctx: Any = None

    async def _run(self) -> Any:
        """Perform the retried request and remember the entered response."""
        self._resp_ctx, self._resp = await _do_retried_request(
            self._session._inner,
            self._session._config,
            self._method,
            self._url,
            self._kwargs,
            limiter=self._session._limiter,
            cooldown=self._session._cooldown,
        )
        return self._resp

    async def __aenter__(self) -> Any:
        return await self._run()

    async def __aexit__(self, *args: Any) -> None:
        # Only exit a context that ``_run`` actually entered.
        if self._resp_ctx is not None:
            await self._resp_ctx.__aexit__(*args)

    def __await__(self) -> Any:
        # The awaitable form returns the response without exiting its
        # context; from here the caller owns it.
        return self._run().__await__()


class _RetryableHTTPError(Exception):
    """Internal sentinel for stamina retry loop.

    Raised for any status in ``_RETRYABLE_STATUS`` and translated into a
    public restgdf error by ``_do_retried_request`` once retries run out.
    """

    def __init__(self, status: int, headers: dict[str, str] | None = None) -> None:
        self.status = status
        self.headers = headers or {}


def _header_value(headers: dict[str, str], name: str, default: str = "") -> str:
    """Return header *name* from *headers*, matching the name case-insensitively.

    HTTP field names are case-insensitive, but ``dict(resp.headers)`` is a
    plain ``dict``, so ``headers.get("Retry-After")`` alone would miss a
    server that sends ``retry-after``.  An exact-key hit is tried first.
    """
    exact = headers.get(name)
    if exact is not None:
        return exact
    wanted = name.casefold()
    for key, value in headers.items():
        if str(key).casefold() == wanted:
            return value
    return default


async def _do_retried_request(
    inner: Any,
    config: ResilienceConfig,
    method: str,
    url: str,
    kwargs: dict[str, Any],
    *,
    limiter: LimiterRegistry | None = None,
    cooldown: CooldownRegistry | None = None,
) -> tuple[Any, Any]:
    """Execute request with stamina retry, token-bucket, and cooldown.

    Each attempt first waits out any cooldown recorded for the URL's service
    root, then acquires a token from that root's limiter (when configured),
    then dispatches ``inner.<method>(url, **kwargs)``.  Statuses in
    ``_RETRYABLE_STATUS``, ``aiohttp.ClientConnectorError`` and
    ``aiohttp.ServerTimeoutError`` are retried by stamina (up to 5 attempts,
    60 s timeout, exponential backoff from 0.5 s capped at 10 s, 1 s jitter).

    Returns:
        ``(ctx, resp)``: the entered context and response of the first attempt
        whose status is neither retryable nor 4xx.  The caller is responsible
        for awaiting ``ctx.__aexit__`` when done with ``resp``.

    Raises:
        RestgdfResponseError: immediately (no retry) on a non-retryable 4xx,
            or after retries are exhausted on a retryable 5xx.
        RateLimitError: after retries are exhausted on 429.
        TransportError: after retries are exhausted on a connection error.
        RestgdfTimeoutError: after retries are exhausted on a server timeout.
    """
    svc_root = _service_root(url)
    retry_on = (
        _RetryableHTTPError,
        aiohttp.ClientConnectorError,
        aiohttp.ServerTimeoutError,
    )

    @stamina.retry(
        on=retry_on,
        attempts=5,
        timeout=60.0,
        wait_initial=0.5,
        wait_max=10.0,
        wait_jitter=1.0,
    )
    async def _attempt() -> tuple[Any, Any]:
        # 429 cooldown: wait if a previous 429 set a deadline for this service
        if cooldown is not None:
            await cooldown.wait_if_cooling(svc_root)
        # Token-bucket rate limit
        if limiter is not None:
            await limiter.get(svc_root).acquire()

        # Connection errors and server timeouts propagate unchanged; both are
        # listed in ``retry_on`` so stamina retries them.
        dispatch = getattr(inner, method)
        ctx, resp = await _enter_request(dispatch(url, **kwargs))

        status = resp.status
        if status in _RETRYABLE_STATUS:
            headers = dict(getattr(resp, "headers", {}))
            try:
                # Set cooldown on 429 so the next retry waits
                if status == 429 and cooldown is not None:
                    ra = _parse_retry_after(_header_value(headers, "Retry-After"))
                    # NOTE(review): a parsed value of 0 is falsy and falls
                    # back to the configured default - confirm intended.
                    cd = (
                        min(ra, config.respect_retry_after_max_s)
                        if ra
                        else config.fallback_retry_after_seconds
                    )
                    cooldown.set_cooldown(svc_root, cd)
            finally:
                # Finish with the failed response before signalling a retry,
                # even if the cooldown bookkeeping above raised.
                await ctx.__aexit__(None, None, None)
            raise _RetryableHTTPError(status, headers)

        if 400 <= status < 500:
            await ctx.__aexit__(None, None, None)
            raise RestgdfResponseError(
                f"Client error ({status}) at {url}",
                model_name="",
                context=url,
                raw=None,
                url=url,
                status_code=status,
            )
        return ctx, resp

    try:
        return await _attempt()
    except _RetryableHTTPError as exc:
        if exc.status == 429:
            retry_after = _parse_retry_after(
                _header_value(exc.headers, "Retry-After")
            )
            raise RateLimitError(
                f"Rate limited (429) at {url}",
                retry_after=retry_after,
                url=url,
                status_code=429,
            ) from exc
        raise RestgdfResponseError(
            f"Server error ({exc.status}) at {url}",
            model_name="",
            context=url,
            raw=None,
            url=url,
            status_code=exc.status,
        ) from exc
    except aiohttp.ClientConnectorError as exc:
        raise TransportError(
            f"Connection failed for {url}",
            url=url,
            status_code=None,
        ) from exc
    except aiohttp.ServerTimeoutError as exc:
        raise RestgdfTimeoutError(
            f"Read timeout: {exc}",
            url=url,
            timeout_kind="read",
        ) from exc


async def _enter_request(result: Any) -> tuple[Any, Any]:
    """Normalize a session dispatch result to an entered async context.

    Shapes are checked in this order:

    1. awaitable: awaited, and the resolved response is wrapped in
       ``_ResponseCtx`` (an object that is both awaitable and an async
       context manager takes this branch);
    2. async context manager: entered directly;
    3. anything else: treated as the response itself and wrapped in
       ``_ResponseCtx``.

    Returns ``(ctx, resp)`` where ``ctx.__aexit__`` must later be awaited.
    """
    if inspect.isawaitable(result):
        response = await result
        ctx = _ResponseCtx(response)
        return ctx, await ctx.__aenter__()
    if hasattr(result, "__aenter__") and hasattr(result, "__aexit__"):
        return result, await result.__aenter__()
    ctx = _ResponseCtx(result)
    return ctx, await ctx.__aenter__()