diff --git a/docs/usage/configuration.md b/docs/usage/configuration.md index 3e06e5a42..0e95cc016 100644 --- a/docs/usage/configuration.md +++ b/docs/usage/configuration.md @@ -84,6 +84,23 @@ If `trust_env` is set to `True`, githubkit (httpx) will look for the environment If you want to set a proxy for client programmatically, you can pass a proxy URL to the `proxy` option. See [httpx's proxies documentation](https://www.python-httpx.org/advanced/proxies/) for more information. +### `transport`, `async_transport` + +These two options let you provide a custom [HTTPX transport](https://www.python-httpx.org/advanced/transports/) for the underlying HTTP client. + +They accept instances of the following types: + +- `httpx.BaseTransport` (sync transport) — pass via the `transport` option. +- `httpx.AsyncBaseTransport` (async transport) — pass via the `async_transport` option. + +When provided, githubkit will forward the transport to create the client. This is useful for: + +- providing a custom network implementation; +- injecting test-only transports (for example `httpx.MockTransport`) to stub responses in unit tests; +- using alternative transports provided by HTTPX or third parties. + +Note that if you pass `None` to the option, the default transport will be created by HTTPX. + ### `cache_strategy` The `cache_strategy` option defines how to cache the tokens or http responses. You can provide a githubkit built-in cache strategy or a custom one that implements the `BaseCacheStrategy` interface. By default, githubkit uses the `MemCacheStrategy` to cache the data in memory. diff --git a/docs/usage/unit-test.md b/docs/usage/unit-test.md index 71c26e3c0..7ef4fa6d2 100644 --- a/docs/usage/unit-test.md +++ b/docs/usage/unit-test.md @@ -1,6 +1,10 @@ # Unit Test -If you are using githubkit in your business logic, you may want to mock the github API in your unit tests. You can custom the response by mocking the `request`/`arequest` method of the `GitHub` class. Here is an example of how to mock githubkit's API calls: +If you are using githubkit in your business logic, you may want to mock the github API in your unit tests. There are two ways to reach this. + +## Mocking the API Calls + +If you can't provide a githubkit test client to your business logic, you can mock the `request`/`arequest` method of the `GitHub` class to custom the response. Here is an example of how to mock githubkit's API calls: === "Sync" @@ -106,3 +110,66 @@ If you are using githubkit in your business logic, you may want to mock the gith 1. Example function you want to test, which calls the GitHub API. 2. other request parameters including headers, json, etc. 3. When the request is made, return a fake response + +## Using a Test Transport + +You can also create a test client with mock transport and provide it to your business logic. Here is an example: + +=== "Sync" + + ```python + import json + from pathlib import Path + + import httpx + import pytest + + from githubkit import GitHub + from githubkit.versions.latest.models import FullRepository + + FAKE_RESPONSE = json.loads(Path("fake_response.json").read_text()) + + def target_sync_func(github: GitHub): + resp = github.rest.repos.get("owner", "repo") + return resp.parsed_data + + def mock_transport_handler(request: httpx.Request) -> httpx.Response: + if request.method == "GET" and request.url.path == "/repos/owner/repo": + return httpx.Response(status_code=200, json=FAKE_RESPONSE) + raise RuntimeError(f"Unexpected request: {request.method} {request.url.path}") + + def test_sync_mock(): + g = GitHub("xxxxx", transport=httpx.MockTransport(mock_transport_handler)) + repo = target_sync_func(g) + assert isinstance(repo, FullRepository) + ``` + +=== "Async" + + ```python + import json + from pathlib import Path + + import httpx + import pytest + + from githubkit import GitHub + from githubkit.versions.latest.models import FullRepository + + FAKE_RESPONSE = json.loads(Path("fake_response.json").read_text()) + + async def target_async_func(github: GitHub): + resp = await github.rest.repos.async_get("owner", "repo") + return resp.parsed_data + + def mock_transport_handler(request: httpx.Request) -> httpx.Response: + if request.method == "GET" and request.url.path == "/repos/owner/repo": + return httpx.Response(status_code=200, json=FAKE_RESPONSE) + raise RuntimeError(f"Unexpected request: {request.method} {request.url.path}") + + @pytest.mark.anyio + async def test_async_mock(): + g = GitHub("xxxxx", async_transport=httpx.MockTransport(mock_transport_handler)) + repo = await target_async_func(g) + assert isinstance(repo, FullRepository) + ``` diff --git a/githubkit/config.py b/githubkit/config.py index 63d25b01a..0560d67cf 100644 --- a/githubkit/config.py +++ b/githubkit/config.py @@ -24,6 +24,8 @@ class Config: ssl_verify: Union[bool, "ssl.SSLContext"] trust_env: bool # effects the `httpx` proxy and ssl cert proxy: Optional[ProxyTypes] + transport: Optional[httpx.BaseTransport] + async_transport: Optional[httpx.AsyncBaseTransport] cache_strategy: BaseCacheStrategy http_cache: bool throttler: BaseThrottler @@ -113,6 +115,8 @@ def get_config( ssl_verify: Union[bool, "ssl.SSLContext"] = True, trust_env: bool = True, proxy: Optional[ProxyTypes] = None, + transport: Optional[httpx.BaseTransport] = None, + async_transport: Optional[httpx.AsyncBaseTransport] = None, cache_strategy: Optional[BaseCacheStrategy] = None, http_cache: bool = True, throttler: Optional[BaseThrottler] = None, @@ -129,6 +133,8 @@ def get_config( ssl_verify, trust_env, proxy, + transport, + async_transport, build_cache_strategy(cache_strategy), http_cache, build_throttler(throttler), diff --git a/githubkit/core.py b/githubkit/core.py index 84f609a64..b9b4f5067 100644 --- a/githubkit/core.py +++ b/githubkit/core.py @@ -88,6 +88,8 @@ def __init__( ssl_verify: Union[bool, "ssl.SSLContext"] = ..., trust_env: bool = True, proxy: Optional[ProxyTypes] = None, + transport: Optional[httpx.BaseTransport] = None, + async_transport: Optional[httpx.AsyncBaseTransport] = None, cache_strategy: Optional[BaseCacheStrategy] = None, http_cache: bool = True, throttler: Optional[BaseThrottler] = None, @@ -110,6 +112,8 @@ def __init__( ssl_verify: Union[bool, "ssl.SSLContext"] = ..., trust_env: bool = True, proxy: Optional[ProxyTypes] = None, + transport: Optional[httpx.BaseTransport] = None, + async_transport: Optional[httpx.AsyncBaseTransport] = None, cache_strategy: Optional[BaseCacheStrategy] = None, http_cache: bool = True, throttler: Optional[BaseThrottler] = None, @@ -132,6 +136,8 @@ def __init__( ssl_verify: Union[bool, "ssl.SSLContext"] = ..., trust_env: bool = True, proxy: Optional[ProxyTypes] = None, + transport: Optional[httpx.BaseTransport] = None, + async_transport: Optional[httpx.AsyncBaseTransport] = None, cache_strategy: Optional[BaseCacheStrategy] = None, http_cache: bool = True, throttler: Optional[BaseThrottler] = None, @@ -153,6 +159,8 @@ def __init__( ssl_verify: Union[bool, "ssl.SSLContext"] = True, trust_env: bool = True, proxy: Optional[ProxyTypes] = None, + transport: Optional[httpx.BaseTransport] = None, + async_transport: Optional[httpx.AsyncBaseTransport] = None, cache_strategy: Optional[BaseCacheStrategy] = None, http_cache: bool = True, throttler: Optional[BaseThrottler] = None, @@ -174,6 +182,8 @@ def __init__( ssl_verify=ssl_verify, trust_env=trust_env, proxy=proxy, + transport=transport, + async_transport=async_transport, cache_strategy=cache_strategy, http_cache=http_cache, throttler=throttler, @@ -241,11 +251,14 @@ def _create_sync_client(self) -> httpx.Client: if self.config.http_cache: return hishel.CacheClient( **self._get_client_defaults(), + transport=self.config.transport, storage=self.config.cache_strategy.get_hishel_storage(), controller=self.config.cache_strategy.get_hishel_controller(), ) - return httpx.Client(**self._get_client_defaults()) + return httpx.Client( + **self._get_client_defaults(), transport=self.config.transport + ) # get or create sync client @contextmanager @@ -263,11 +276,14 @@ def _create_async_client(self) -> httpx.AsyncClient: if self.config.http_cache: return hishel.AsyncCacheClient( **self._get_client_defaults(), + transport=self.config.async_transport, storage=self.config.cache_strategy.get_async_hishel_storage(), controller=self.config.cache_strategy.get_hishel_controller(), ) - return httpx.AsyncClient(**self._get_client_defaults()) + return httpx.AsyncClient( + **self._get_client_defaults(), transport=self.config.async_transport + ) # get or create async client @asynccontextmanager diff --git a/githubkit/github.py b/githubkit/github.py index 0335483e2..6779afce0 100644 --- a/githubkit/github.py +++ b/githubkit/github.py @@ -79,6 +79,8 @@ def __init__( ssl_verify: Union[bool, "ssl.SSLContext"] = ..., trust_env: bool = True, proxy: Optional[ProxyTypes] = None, + transport: Optional[httpx.BaseTransport] = None, + async_transport: Optional[httpx.AsyncBaseTransport] = None, cache_strategy: Optional["BaseCacheStrategy"] = None, http_cache: bool = True, throttler: Optional["BaseThrottler"] = None, @@ -101,6 +103,8 @@ def __init__( ssl_verify: Union[bool, "ssl.SSLContext"] = ..., trust_env: bool = True, proxy: Optional[ProxyTypes] = None, + transport: Optional[httpx.BaseTransport] = None, + async_transport: Optional[httpx.AsyncBaseTransport] = None, cache_strategy: Optional["BaseCacheStrategy"] = None, http_cache: bool = True, throttler: Optional["BaseThrottler"] = None, @@ -123,6 +127,8 @@ def __init__( ssl_verify: Union[bool, "ssl.SSLContext"] = ..., trust_env: bool = True, proxy: Optional[ProxyTypes] = None, + transport: Optional[httpx.BaseTransport] = None, + async_transport: Optional[httpx.AsyncBaseTransport] = None, cache_strategy: Optional["BaseCacheStrategy"] = None, http_cache: bool = True, throttler: Optional["BaseThrottler"] = None, diff --git a/tests/test_unit_test/test_mock_transport.py b/tests/test_unit_test/test_mock_transport.py new file mode 100644 index 000000000..d25da6282 --- /dev/null +++ b/tests/test_unit_test/test_mock_transport.py @@ -0,0 +1,39 @@ +import json +from pathlib import Path + +import httpx +import pytest + +from githubkit import GitHub +from githubkit.versions.latest.models import FullRepository + +FAKE_RESPONSE = json.loads((Path(__file__).parent / "fake_response.json").read_text()) + + +def target_sync_func(github: GitHub): + resp = github.rest.repos.get("owner", "repo") + return resp.parsed_data + + +def mock_transport_handler(request: httpx.Request) -> httpx.Response: + if request.method == "GET" and request.url.path == "/repos/owner/repo": + return httpx.Response(status_code=200, json=FAKE_RESPONSE) + raise RuntimeError(f"Unexpected request: {request.method} {request.url.path}") + + +def test_sync_mock(): + g = GitHub("xxxxx", transport=httpx.MockTransport(mock_transport_handler)) + repo = target_sync_func(g) + assert isinstance(repo, FullRepository) + + +async def target_async_func(github: GitHub): + resp = await github.rest.repos.async_get("owner", "repo") + return resp.parsed_data + + +@pytest.mark.anyio +async def test_async_mock(): + g = GitHub("xxxxx", async_transport=httpx.MockTransport(mock_transport_handler)) + repo = await target_async_func(g) + assert isinstance(repo, FullRepository)