Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
- id: check-toml

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.14.0
rev: v0.15.11
hooks:
- id: ruff
args:
Expand All @@ -15,7 +15,7 @@ repos:
- id: ruff-format

- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.9.0
rev: 0.11.7
hooks:
- id: uv-lock
- id: uv-export
Expand Down
6 changes: 3 additions & 3 deletions fastapi_oauth20/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ async def __call__(
detail=error if error is not None else None,
)

kwargs = {'code': code}
kwargs: dict[str, str] = {'code': code}

try:
sig = inspect.signature(self.client.get_access_token)
params = sig.parameters

if 'redirect_uri' in params:
if 'redirect_uri' in params and self.redirect_uri is not None:
kwargs['redirect_uri'] = self.redirect_uri

if 'code_verifier' in params:
if 'code_verifier' in params and code_verifier is not None:
kwargs['code_verifier'] = code_verifier

access_token = await self.client.get_access_token(**kwargs)
Expand Down
10 changes: 8 additions & 2 deletions fastapi_oauth20/clients/github.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any
import json

from typing import Any, cast

import httpx

Expand Down Expand Up @@ -43,7 +45,11 @@ async def get_userinfo(self, access_token: str) -> dict[str, Any]:
if email is None:
response = await client.get(f'{self.userinfo_endpoint}/emails')
self.raise_httpx_oauth20_errors(response)
emails = self.get_json_result(response, err_class=GetUserInfoError)
try:
emails = cast(list[dict[str, Any]], response.json())
except json.JSONDecodeError as e:
raise GetUserInfoError('Result serialization failed.', response) from e

email = next((email['email'] for email in emails if email.get('primary')), emails[0]['email'])
result['email'] = email

Expand Down
4 changes: 2 additions & 2 deletions fastapi_oauth20/clients/weixin_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def get_authorization_url(
state: str | None = None,
scope: list[str] | None = None,
**kwargs,
) -> str:
) -> str: # ty:ignore[invalid-method-override]
"""
Generate WeChat OAuth2 authorization URL.

Expand Down Expand Up @@ -62,7 +62,7 @@ async def get_authorization_url(

return f'{self.authorize_endpoint}?{urlencode(params)}#wechat_redirect'

async def get_access_token(self, code: str) -> dict[str, Any]:
async def get_access_token(self, code: str) -> dict[str, Any]: # ty:ignore[invalid-method-override]
"""
Exchange authorization code for access token using WeChat's GET method.

Expand Down
4 changes: 2 additions & 2 deletions fastapi_oauth20/clients/weixin_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def get_authorization_url(
state: str | None = None,
scope: list[str] | None = None,
**kwargs,
) -> str:
) -> str: # ty:ignore[invalid-method-override]
"""
Generate WeChat Open Platform OAuth2 authorization URL.

Expand All @@ -58,7 +58,7 @@ async def get_authorization_url(

return f'{self.authorize_endpoint}?{urlencode(params)}#wechat_redirect'

async def get_access_token(self, code: str) -> dict[str, Any]:
async def get_access_token(self, code: str) -> dict[str, Any]: # ty:ignore[invalid-method-override]
"""
Exchange authorization code for access token using WeChat's GET method.

Expand Down
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ repository = "https://github.com/fastapi-practices/fastapi_oauth20"
[dependency-groups]
dev = [
"click==8.2.1",
"fastapi>=0.119.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"respx>=0.22.0",
"ty>=0.0.1a23",
"fastapi>=0.136.0",
"pytest>=9.0.3",
"pytest-asyncio>=1.3.0",
"pytest-cov>=7.1.0",
"respx>=0.23.1",
"ty>=0.0.31",
]
lint = [
"prek>=0.3.9",
Expand Down
42 changes: 22 additions & 20 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
# This file was autogenerated by uv via the following command:
# uv export -o requirements.txt --no-hashes
-e .
annotated-doc==0.0.4
# via fastapi
annotated-types==0.7.0
# via pydantic
anyio==4.11.0
anyio==4.13.0
# via
# httpx
# starlette
backports-asyncio-runner==1.2.0 ; python_full_version < '3.11'
# via pytest-asyncio
certifi==2025.10.5
certifi==2026.2.25
# via
# httpcore
# httpx
Expand All @@ -18,13 +20,13 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via
# click
# pytest
coverage==7.11.1
coverage==7.13.5
# via pytest-cov
exceptiongroup==1.3.0 ; python_full_version < '3.11'
exceptiongroup==1.3.1 ; python_full_version < '3.11'
# via
# anyio
# pytest
fastapi==0.119.0
fastapi==0.136.0
h11==0.16.0
# via httpcore
httpcore==1.0.9
Expand All @@ -37,37 +39,35 @@ idna==3.11
# via
# anyio
# httpx
iniconfig==2.1.0
iniconfig==2.3.0
# via pytest
packaging==25.0
packaging==26.1
# via pytest
pluggy==1.6.0
# via
# pytest
# pytest-cov
prek==0.3.9
pydantic==2.12.2
pydantic==2.13.2
# via fastapi
pydantic-core==2.41.4
pydantic-core==2.46.2
# via pydantic
pygments==2.19.2
pygments==2.20.0
# via pytest
pytest==8.4.2
pytest==9.0.3
# via
# pytest-asyncio
# pytest-cov
pytest-asyncio==1.2.0
pytest-cov==7.0.0
respx==0.22.0
sniffio==1.3.1
# via anyio
starlette==0.48.0
pytest-asyncio==1.3.0
pytest-cov==7.1.0
respx==0.23.1
starlette==1.0.0
# via fastapi
tomli==2.3.0 ; python_full_version <= '3.11'
tomli==2.4.1 ; python_full_version <= '3.11'
# via
# coverage
# pytest
ty==0.0.1a23
ty==0.0.31
typing-extensions==4.15.0
# via
# anyio
Expand All @@ -79,4 +79,6 @@ typing-extensions==4.15.0
# starlette
# typing-inspection
typing-inspection==0.4.2
# via pydantic
# via
# fastapi
# pydantic
22 changes: 22 additions & 0 deletions tests/clients/test_github.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ async def test_get_userinfo_success_without_email(self, github_client):
assert result['login'] == mock_user_data['login']
assert result['email'] == 'test@example.com'

@pytest.mark.asyncio
@respx.mock
async def test_get_userinfo_without_primary_email_uses_first_email(self, github_client):
mock_user_data = create_mock_user_data('github', email=None)
mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data)
emails_data = [
{'email': 'fallback@example.com', 'primary': False},
{'email': 'secondary@example.com', 'primary': False},
]
respx.get(GITHUB_EMAILS_URL).mock(return_value=httpx.Response(200, json=emails_data))
result = await github_client.get_userinfo(TEST_ACCESS_TOKEN)
assert result['email'] == 'fallback@example.com'

@pytest.mark.asyncio
@respx.mock
async def test_get_userinfo_with_different_access_token(self, github_client):
Expand Down Expand Up @@ -134,6 +147,15 @@ async def test_get_userinfo_invalid_json(self, github_client):
with pytest.raises(GetUserInfoError):
await github_client.get_userinfo(TEST_ACCESS_TOKEN)

@pytest.mark.asyncio
@respx.mock
async def test_get_userinfo_emails_invalid_json(self, github_client):
mock_user_data = create_mock_user_data('github', email=None)
mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data)
respx.get(GITHUB_EMAILS_URL).mock(return_value=httpx.Response(200, text='invalid json'))
with pytest.raises(GetUserInfoError):
await github_client.get_userinfo(TEST_ACCESS_TOKEN)

@pytest.mark.asyncio
@respx.mock
async def test_get_userinfo_rate_limit(self, github_client):
Expand Down
Loading
Loading