diff --git a/pyproject.toml b/pyproject.toml index 52b35e3..f3557c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "aiohttp[speedups]~=3.8", + "aiohttp-retry~=2.9", "gql[aiohttp,requests]>=4,<5", "pyjwt[crypto]~=2.8", "requests~=2.31", diff --git a/src/simple_github/client.py b/src/simple_github/client.py index 1ae5c9e..f49edd8 100644 --- a/src/simple_github/client.py +++ b/src/simple_github/client.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any from aiohttp import ClientResponse, ClientSession +from aiohttp_retry import ExponentialRetry, RetryClient from gql import Client as GqlClient from gql import gql from gql.client import ReconnectingAsyncClientSession, SyncClientSession @@ -12,6 +13,8 @@ from gql.transport.requests import RequestsHTTPTransport from requests import Response as RequestsResponse from requests import Session +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry if TYPE_CHECKING: from simple_github.auth import Auth @@ -131,6 +134,13 @@ def _get_requests_session(self) -> Session: assert session.transport.session return session.transport.session + def _get_retry_session(self) -> Session: + session = self._get_requests_session() + retry = Retry(total=5, backoff_factor=1, status_forcelist={500, 502, 503, 504}) + adapter = HTTPAdapter(max_retries=retry) + session.mount("https://", adapter) + return session + def request(self, method: str, query: str, **kwargs) -> RequestsResponse: """Make a request to Github's REST API. @@ -144,7 +154,7 @@ def request(self, method: str, query: str, **kwargs) -> RequestsResponse: Dict: The JSON result of the request. """ url = f"{GITHUB_API_ENDPOINT}/{query.lstrip('/')}" - session = self._get_requests_session() + session = self._get_retry_session() with session.request(method, url, **kwargs) as resp: return resp @@ -281,6 +291,13 @@ async def _get_aiohttp_session(self) -> ClientSession: assert session.transport.session return session.transport.session + async def _get_retry_client(self) -> RetryClient: + session = await self._get_aiohttp_session() + return RetryClient( + client_session=session, + retry_options=ExponentialRetry(attempts=5), + ) + async def request(self, method: str, query: str, **kwargs: Any) -> ClientResponse: """Make a request to Github's REST API. @@ -294,8 +311,8 @@ async def request(self, method: str, query: str, **kwargs: Any) -> ClientRespons Dict: The JSON result of the request. """ url = f"{GITHUB_API_ENDPOINT}/{query.lstrip('/')}" - session = await self._get_aiohttp_session() - return await session.request(method, url, **kwargs) + client = await self._get_retry_client() + return await client.request(method, url, **kwargs) async def get(self, query: str, **kwargs: Any) -> ClientResponse: """Make a GET request to Github's REST API. diff --git a/test/test_client.py b/test/test_client.py index 3c372f3..b071509 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1,3 +1,5 @@ +from unittest import mock + import pytest import pytest_asyncio from aiohttp import ClientResponseError @@ -142,6 +144,8 @@ async def test_async_client_rest(aioresponses, async_client): }, data="null", allow_redirects=True, + # internal aiohttp-retry tracking data + trace_request_ctx=mock.ANY, ) aioresponses.get(url, status=401) @@ -150,6 +154,19 @@ async def test_async_client_rest(aioresponses, async_client): resp.raise_for_status() +@pytest.mark.asyncio +async def test_async_client_retries_on_5xx(aioresponses, async_client): + client = async_client + url = f"{GITHUB_API_ENDPOINT}/octocat" + + aioresponses.get(url, status=502) + aioresponses.get(url, status=200, payload={"answer": 42}) + + resp = await client.get("/octocat") + result = await resp.json() + assert result == {"answer": 42} + + def test_sync_client_rest(responses, sync_client): client = sync_client url = f"{GITHUB_API_ENDPOINT}/octocat" @@ -187,6 +204,18 @@ def test_sync_client_rest(responses, sync_client): resp.raise_for_status() +def test_sync_client_retries_on_5xx(responses, sync_client): + client = sync_client + url = f"{GITHUB_API_ENDPOINT}/octocat" + + responses.get(url, status=502) + responses.get(url, status=200, json={"answer": 42}) + + resp = client.get("/octocat") + result = resp.json() + assert result == {"answer": 42} + + @pytest.mark.asyncio async def test_async_client_rest_with_text(aioresponses, async_client): client = async_client diff --git a/uv.lock b/uv.lock index 34d03d3..c543872 100644 --- a/uv.lock +++ b/uv.lock @@ -156,6 +156,18 @@ speedups = [ { name = "brotlicffi", marker = "platform_python_implementation != 'CPython'" }, ] +[[package]] +name = "aiohttp-retry" +version = "2.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/61/ebda4d8e3d8cfa1fd3db0fb428db2dd7461d5742cea35178277ad180b033/aiohttp_retry-2.9.1.tar.gz", hash = "sha256:8eb75e904ed4ee5c2ec242fefe85bf04240f685391c4879d8f541d6028ff01f1", size = 13608, upload-time = "2024-11-06T10:44:54.574Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/99/84ba7273339d0f3dfa57901b846489d2e5c2cd731470167757f1935fffbd/aiohttp_retry-2.9.1-py3-none-any.whl", hash = "sha256:66d2759d1921838256a05a3f80ad7e724936f083e35be5abb5e16eed6be6dc54", size = 9981, upload-time = "2024-11-06T10:44:52.917Z" }, +] + [[package]] name = "aioresponses" version = "0.7.8" @@ -1615,6 +1627,7 @@ version = "3.0.0" source = { editable = "." } dependencies = [ { name = "aiohttp", extra = ["speedups"] }, + { name = "aiohttp-retry" }, { name = "gql", extra = ["aiohttp", "requests"] }, { name = "pyjwt", extra = ["crypto"] }, { name = "requests" }, @@ -1639,6 +1652,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aiohttp", extras = ["speedups"], specifier = "~=3.8" }, + { name = "aiohttp-retry", specifier = "~=2.9" }, { name = "gql", extras = ["aiohttp", "requests"], specifier = ">=4,<5" }, { name = "pyjwt", extras = ["crypto"], specifier = "~=2.8" }, { name = "requests", specifier = "~=2.31" },