diff --git a/README.md b/README.md index 8fb5895..46e7409 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,17 @@ pip install fastapi_oauth20 [fastapi oauth20](https://fastapi-practices.github.io/fastapi-oauth20/) +## Demo + +查看完整的示例项目:[fastapi-oauth20-demo](https://github.com/fastapi-practices/fastapi-oauth20-demo) + +该示例项目展示了如何在实际应用中使用 fastapi-oauth20,包括: + +- 多个 OAuth2 提供商的集成示例 +- 完整的授权流程实现 +- 用户信息获取和处理 +- 错误处理最佳实践 + ## Sponsor 如果这个项目对你有帮助,欢迎[请作者喝杯咖啡](https://wu-clan.github.io/sponsor/) ☕ diff --git a/docs/status.md b/docs/status.md index 1a63e75..2fd6f66 100644 --- a/docs/status.md +++ b/docs/status.md @@ -1,18 +1,16 @@ -下面展示了我们的计划,如果你有更多需求,请在仓库内创建 Issues,我们将尽力完成所有目标 +如果你有更多需求,请在仓库内创建 [Issues](https://github.com/fastapi-practices/fastapi-oauth20/issues) -## FINISHED - -- [x] [LinuxDo](clients/linuxdo.md) - [x] [GitHub](clients/github.md) +- [x] [Google](clients/google.md) +- [x] [LinuxDo](clients/linuxdo.md) - [x] [Gitee](clients/gitee.md) - [x] [开源中国](clients/oschina.md) - [x] [飞书](clients/feishu.md) -- [x] [Google](clients/google.md) +- [x] [微信小程序](clients/wechat_open.md) +- [x] [微信开放平台](clients/wechat_mp.md) ## TODO -- [ ] [微信小程序](clients/wechat_open.md) -- [ ] [微信开放平台](clients/wechat_mp.md) - [ ] [企业微信二维码登录](clients/wechat_work.md) - [ ] [钉钉](clients/dingtalk.md) - [ ] [QQ](clients/qq.md) diff --git a/docs/usage.md b/docs/usage.md index 997aa5a..13ed422 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -4,6 +4,22 @@ from fastapi_oauth20 import FastAPIOAuth20 本指南介绍如何将 FastAPI OAuth2.0 库与各种 OAuth2 提供程序一起使用。 +## 演示项目 + +在开始之前,强烈推荐查看我们的完整演示项目 + +**[fastapi-oauth20-demo](https://github.com/fastapi-practices/fastapi-oauth20-demo)** + +该演示项目包含: + +- 多个 OAuth2 提供商的集成示例 +- 详细的代码注释和实现说明 +- 生产环境的最佳实践 +- 错误处理示例 +- 可直接运行的完整应用 + +通过演示项目,你可以快速了解如何在真实应用中使用 fastapi-oauth20 + ## 基本用法 ### 1. 选择 OAuth2 提供商并初始化客户端 diff --git a/fastapi_oauth20/__init__.py b/fastapi_oauth20/__init__.py index 3847c3a..6ead953 100644 --- a/fastapi_oauth20/__init__.py +++ b/fastapi_oauth20/__init__.py @@ -6,5 +6,7 @@ from .clients.google import GoogleOAuth20 as GoogleOAuth20 from .clients.linuxdo import LinuxDoOAuth20 as LinuxDoOAuth20 from .clients.oschina import OSChinaOAuth20 as OSChinaOAuth20 +from .clients.weixin_mp import WeChatMpOAuth20 as WeChatMpOAuth20 +from .clients.weixin_open import WeChatOpenOAuth20 as WeChatOpenOAuth20 __version__ = '0.0.2' diff --git a/fastapi_oauth20/callback.py b/fastapi_oauth20/callback.py index efa5a0b..50043a4 100644 --- a/fastapi_oauth20/callback.py +++ b/fastapi_oauth20/callback.py @@ -1,3 +1,5 @@ +import inspect + from typing import Annotated, Any import httpx @@ -25,6 +27,7 @@ def __init__( :param detail: Error detail message describing what went wrong. :param headers: Additional HTTP headers to include in the error response. :param response: The original HTTP response that caused the error (if any). + :return: """ self.response = response super().__init__(status_code=status_code, detail=detail, headers=headers) @@ -44,6 +47,7 @@ def __init__( :param client: An OAuth2 client instance that inherits from OAuth20Base. :param redirect_uri: The full callback URL where the OAuth2 provider redirects after authorization. Must match the URL registered with the OAuth2 provider. + :return: """ self.client = client self.redirect_uri = redirect_uri @@ -64,6 +68,7 @@ async def __call__( :param state: The state parameter for CSRF protection (extracted from query parameters). :param code_verifier: PKCE code verifier if PKCE was used in the authorization request. :param error: Error parameter from OAuth2 provider if authorization was denied or failed. + :return: """ if code is None or error is not None: raise OAuth20AuthorizeCallbackError( @@ -71,12 +76,19 @@ async def __call__( detail=error if error is not None else None, ) + kwargs = {'code': code} + try: - access_token = await self.client.get_access_token( - code=code, - redirect_uri=self.redirect_uri, - code_verifier=code_verifier, - ) + sig = inspect.signature(self.client.get_access_token) + params = sig.parameters + + if 'redirect_uri' in params: + kwargs['redirect_uri'] = self.redirect_uri + + if 'code_verifier' in params: + kwargs['code_verifier'] = code_verifier + + access_token = await self.client.get_access_token(**kwargs) except OAuth20RequestError as e: raise OAuth20AuthorizeCallbackError( status_code=500, diff --git a/fastapi_oauth20/clients/feishu.py b/fastapi_oauth20/clients/feishu.py index a446241..e56194a 100644 --- a/fastapi_oauth20/clients/feishu.py +++ b/fastapi_oauth20/clients/feishu.py @@ -10,6 +10,7 @@ def __init__(self, client_id: str, client_secret: str): :param client_id: FeiShu app client ID from the FeiShu developer console. :param client_secret: FeiShu app client secret from the FeiShu developer console. + :return: """ super().__init__( client_id=client_id, diff --git a/fastapi_oauth20/clients/gitee.py b/fastapi_oauth20/clients/gitee.py index 69a69bb..45eb182 100644 --- a/fastapi_oauth20/clients/gitee.py +++ b/fastapi_oauth20/clients/gitee.py @@ -10,6 +10,7 @@ def __init__(self, client_id: str, client_secret: str): :param client_id: Gitee OAuth application client ID. :param client_secret: Gitee OAuth application client secret. + :return: """ super().__init__( client_id=client_id, diff --git a/fastapi_oauth20/clients/github.py b/fastapi_oauth20/clients/github.py index bcaf3ff..d723d4b 100644 --- a/fastapi_oauth20/clients/github.py +++ b/fastapi_oauth20/clients/github.py @@ -15,6 +15,7 @@ def __init__(self, client_id: str, client_secret: str): :param client_id: GitHub OAuth App client ID. :param client_secret: GitHub OAuth App client secret. + :return: """ super().__init__( client_id=client_id, diff --git a/fastapi_oauth20/clients/google.py b/fastapi_oauth20/clients/google.py index 97a5468..992b88c 100644 --- a/fastapi_oauth20/clients/google.py +++ b/fastapi_oauth20/clients/google.py @@ -10,6 +10,7 @@ def __init__(self, client_id: str, client_secret: str): :param client_id: Google OAuth 2.0 client ID from Google Cloud Console. :param client_secret: Google OAuth 2.0 client secret from Google Cloud Console. + :return: """ super().__init__( client_id=client_id, diff --git a/fastapi_oauth20/clients/linuxdo.py b/fastapi_oauth20/clients/linuxdo.py index 5b9e808..9e0d005 100644 --- a/fastapi_oauth20/clients/linuxdo.py +++ b/fastapi_oauth20/clients/linuxdo.py @@ -10,6 +10,7 @@ def __init__(self, client_id: str, client_secret: str): :param client_id: Linux.do OAuth application client ID. :param client_secret: Linux.do OAuth application client secret. + :return: """ super().__init__( client_id=client_id, diff --git a/fastapi_oauth20/clients/oschina.py b/fastapi_oauth20/clients/oschina.py index 9ad1189..c13edcd 100644 --- a/fastapi_oauth20/clients/oschina.py +++ b/fastapi_oauth20/clients/oschina.py @@ -10,6 +10,7 @@ def __init__(self, client_id: str, client_secret: str): :param client_id: OSChina OAuth application client ID. :param client_secret: OSChina OAuth application client secret. + :return: """ super().__init__( client_id=client_id, diff --git a/fastapi_oauth20/clients/weixin_mp.py b/fastapi_oauth20/clients/weixin_mp.py new file mode 100644 index 0000000..6f3bc71 --- /dev/null +++ b/fastapi_oauth20/clients/weixin_mp.py @@ -0,0 +1,139 @@ +from typing import Any +from urllib.parse import urlencode + +import httpx + +from fastapi_oauth20.errors import AccessTokenError, GetUserInfoError, RefreshTokenError +from fastapi_oauth20.oauth20 import OAuth20Base + + +class WeChatMpOAuth20(OAuth20Base): + """WeChat public platform OAuth2 client implementation.""" + + def __init__(self, client_id: str, client_secret: str): + """ + Initialize WeChat public platform OAuth2 client. + + :param client_id: AppID from the WeChat public platform developer console. + :param client_secret: AppSecret from the WeChat public platform developer console. + :return: + """ + super().__init__( + client_id=client_id, + client_secret=client_secret, + authorize_endpoint='https://open.weixin.qq.com/connect/oauth2/authorize', + access_token_endpoint='https://api.weixin.qq.com/sns/oauth2/access_token', + refresh_token_endpoint='https://api.weixin.qq.com/sns/oauth2/refresh_token', + userinfo_endpoint='https://api.weixin.qq.com/sns/userinfo', + default_scopes=['snsapi_userinfo'], + ) + + async def get_authorization_url( + self, + redirect_uri: str, + state: str | None = None, + scope: list[str] | None = None, + **kwargs, + ) -> str: + """ + Generate WeChat OAuth2 authorization URL. + + :param redirect_uri: The URL where WeChat will redirect after authorization. + :param state: An opaque value used to maintain state between request and callback. + :param scope: The list of OAuth scopes to request. Default is ['snsapi_userinfo']. + :param kwargs: Additional query parameters. + :return: + """ + params = { + 'appid': self.client_id, + 'redirect_uri': redirect_uri, + 'response_type': 'code', + } + + if state is not None: + params['state'] = state + + _scope = scope or self.default_scopes + if _scope is not None: + params['scope'] = ','.join(_scope) + + if kwargs: + params.update(kwargs) + + return f'{self.authorize_endpoint}?{urlencode(params)}#wechat_redirect' + + async def get_access_token(self, code: str) -> dict[str, Any]: + """ + Exchange authorization code for access token using WeChat's GET method. + + :param code: The authorization code received from WeChat callback. + :return: + """ + params = { + 'appid': self.client_id, + 'secret': self.client_secret, + 'code': code, + 'grant_type': 'authorization_code', + } + + async with httpx.AsyncClient() as client: + response = await client.get( + self.access_token_endpoint, + params=params, + headers=self.request_headers, + ) + self.raise_httpx_oauth20_errors(response) + result = self.get_json_result(response, err_class=AccessTokenError) + return result + + async def refresh_token(self, refresh_token: str) -> dict[str, Any]: + """ + Refresh access token using WeChat's GET method. + + :param refresh_token: The refresh token received from initial token exchange. + :return: + """ + if self.refresh_token_endpoint is None: + raise RefreshTokenError('The refresh token address is missing') + + params = { + 'appid': self.client_id, + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token, + } + + async with httpx.AsyncClient() as client: + response = await client.get( + self.refresh_token_endpoint, + params=params, + headers=self.request_headers, + ) + self.raise_httpx_oauth20_errors(response) + result = self.get_json_result(response, err_class=RefreshTokenError) + return result + + async def get_userinfo(self, access_token: str, openid: str | None = None) -> dict[str, Any]: + """ + Retrieve user information from WeChat API. + + :param access_token: Valid WeChat access token. + :param openid: User's OpenID. + :return: + """ + if openid is None: + raise GetUserInfoError('openid is required') + + params = { + 'access_token': access_token, + 'openid': openid, + 'lang': 'zh_CN', + } + + async with httpx.AsyncClient() as client: + response = await client.get( + self.userinfo_endpoint, + params=params, + ) + self.raise_httpx_oauth20_errors(response) + result = self.get_json_result(response, err_class=GetUserInfoError) + return result diff --git a/fastapi_oauth20/clients/weixin_open.py b/fastapi_oauth20/clients/weixin_open.py new file mode 100644 index 0000000..2e8d130 --- /dev/null +++ b/fastapi_oauth20/clients/weixin_open.py @@ -0,0 +1,135 @@ +from typing import Any +from urllib.parse import urlencode + +import httpx + +from fastapi_oauth20.errors import AccessTokenError, GetUserInfoError, RefreshTokenError +from fastapi_oauth20.oauth20 import OAuth20Base + + +class WeChatOpenOAuth20(OAuth20Base): + """WeChat open platform OAuth2 client implementation.""" + + def __init__(self, client_id: str, client_secret: str): + """ + Initialize WeChat open platform OAuth2 client. + + :param client_id: AppID from the WeChat open platform developer console. + :param client_secret: AppSecret from the WeChat open platform developer console. + :return: + """ + super().__init__( + client_id=client_id, + client_secret=client_secret, + authorize_endpoint='https://open.weixin.qq.com/connect/qrconnect', + access_token_endpoint='https://api.weixin.qq.com/sns/oauth2/access_token', + refresh_token_endpoint='https://api.weixin.qq.com/sns/oauth2/refresh_token', + userinfo_endpoint='https://api.weixin.qq.com/sns/userinfo', + default_scopes=['snsapi_login'], + ) + + async def get_authorization_url( + self, + redirect_uri: str, + state: str | None = None, + scope: list[str] | None = None, + **kwargs, + ) -> str: + """ + Generate WeChat Open Platform OAuth2 authorization URL. + + :param redirect_uri: The URL where WeChat will redirect after authorization. + :param state: An opaque value used to maintain state between request and callback. + :param scope: The list of OAuth scopes to request. Default is ['snsapi_login']. + :param kwargs: Additional query parameters. + :return: + """ + params = {'appid': self.client_id, 'redirect_uri': redirect_uri, 'response_type': 'code', 'lang': 'cn'} + + if state is not None: + params['state'] = state + + _scope = scope or self.default_scopes + if _scope is not None: + params['scope'] = ','.join(_scope) + + if kwargs: + params.update(kwargs) + + return f'{self.authorize_endpoint}?{urlencode(params)}#wechat_redirect' + + async def get_access_token(self, code: str) -> dict[str, Any]: + """ + Exchange authorization code for access token using WeChat's GET method. + + :param code: The authorization code received from WeChat callback. + :return: + """ + params = { + 'appid': self.client_id, + 'secret': self.client_secret, + 'code': code, + 'grant_type': 'authorization_code', + } + + async with httpx.AsyncClient() as client: + response = await client.get( + self.access_token_endpoint, + params=params, + headers=self.request_headers, + ) + self.raise_httpx_oauth20_errors(response) + result = self.get_json_result(response, err_class=AccessTokenError) + return result + + async def refresh_token(self, refresh_token: str) -> dict[str, Any]: + """ + Refresh access token using WeChat's GET method. + + :param refresh_token: The refresh token received from initial token exchange. + :return: + """ + if self.refresh_token_endpoint is None: + raise RefreshTokenError('The refresh token address is missing') + + params = { + 'appid': self.client_id, + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token, + } + + async with httpx.AsyncClient() as client: + response = await client.get( + self.refresh_token_endpoint, + params=params, + headers=self.request_headers, + ) + self.raise_httpx_oauth20_errors(response) + result = self.get_json_result(response, err_class=RefreshTokenError) + return result + + async def get_userinfo(self, access_token: str, openid: str | None = None) -> dict[str, Any]: + """ + Retrieve user information from WeChat Open Platform API. + + :param access_token: Valid WeChat access token. + :param openid: User's OpenID. If not provided, will attempt to extract from previous token response. + :return: + """ + if openid is None: + raise GetUserInfoError('openid is required') + + params = { + 'access_token': access_token, + 'openid': openid, + 'lang': 'zh_CN', + } + + async with httpx.AsyncClient() as client: + response = await client.get( + self.userinfo_endpoint, + params=params, + ) + self.raise_httpx_oauth20_errors(response) + result = self.get_json_result(response, err_class=GetUserInfoError) + return result diff --git a/fastapi_oauth20/errors.py b/fastapi_oauth20/errors.py index b5fb889..a46e52d 100644 --- a/fastapi_oauth20/errors.py +++ b/fastapi_oauth20/errors.py @@ -11,6 +11,7 @@ def __init__(self, msg: str) -> None: Initialize base OAuth2 error. :param msg: Human-readable error message describing the OAuth2 error. + :return: """ self.msg = msg super().__init__(msg) @@ -25,6 +26,7 @@ def __init__(self, msg: str, response: httpx.Response | None = None) -> None: :param msg: Human-readable error message describing the request error. :param response: The HTTP response object that caused the error (if available). + :return: """ self.response = response super().__init__(msg) diff --git a/fastapi_oauth20/oauth20.py b/fastapi_oauth20/oauth20.py index 2f62495..1614f5f 100644 --- a/fastapi_oauth20/oauth20.py +++ b/fastapi_oauth20/oauth20.py @@ -43,6 +43,7 @@ def __init__( :param default_scopes: Default list of OAuth scopes to request if none are specified. :param token_endpoint_basic_auth: Whether to use HTTP Basic Authentication for token endpoint requests. :param revoke_token_endpoint_basic_auth: Whether to use HTTP Basic Authentication for revoke endpoint requests. + :return: """ self.client_id = client_id self.client_secret = client_secret diff --git a/mkdocs.yml b/mkdocs.yml index 820254b..b0af919 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,8 +12,9 @@ nav: - 用法: usage.md - 客户端状态: status.md - 客户端申请: - - LinuxDo: clients/linuxdo.md - GitHub: clients/github.md + - Google: clients/google.md + - LinuxDo: clients/linuxdo.md - Gitee: clients/gitee.md - 开源中国: clients/oschina.md - 微信小程序: clients/wechat_mp.md @@ -22,7 +23,6 @@ nav: - 飞书: clients/feishu.md - 钉钉: clients/dingtalk.md - QQ: clients/qq.md - - Google: clients/google.md - 变更日志: changelog.md theme: name: material diff --git a/tests/clients/test_feishu.py b/tests/clients/test_feishu.py index 15f9e38..c263202 100644 --- a/tests/clients/test_feishu.py +++ b/tests/clients/test_feishu.py @@ -19,15 +19,11 @@ @pytest.fixture def feishu_client(): - """Create FeiShu OAuth2 client instance for testing.""" return FeiShuOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestFeiShuOAuth20: - """Test FeiShu OAuth2 client functionality.""" - - def test_feishu_client_initialization(self, feishu_client): - """Test FeiShu client initialization with correct parameters.""" + def test_client_initialization(self, feishu_client): assert feishu_client.client_id == TEST_CLIENT_ID assert feishu_client.client_secret == TEST_CLIENT_SECRET assert feishu_client.authorize_endpoint == 'https://passport.feishu.cn/suite/passport/oauth/authorize' @@ -39,94 +35,69 @@ def test_feishu_client_initialization(self, feishu_client): 'contact:user.email:readonly', ] - def test_feishu_client_initialization_with_custom_credentials(self): - """Test FeiShu client initialization with custom credentials.""" - client = FeiShuOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - assert client.client_id == TEST_CLIENT_ID - assert client.client_secret == TEST_CLIENT_SECRET - - def test_feishu_client_inheritance(self, feishu_client): - """Test that FeiShu client properly inherits from OAuth20Base.""" + def test_client_inheritance(self, feishu_client): assert isinstance(feishu_client, OAuth20Base) - def test_feishu_client_scopes_are_lists(self, feishu_client): - """Test that default scopes are properly configured as lists.""" + def test_client_scopes_are_lists(self, feishu_client): assert isinstance(feishu_client.default_scopes, list) assert len(feishu_client.default_scopes) == 3 assert all(isinstance(scope, str) for scope in feishu_client.default_scopes) - def test_feishu_client_endpoint_urls(self): - """Test that FeiShu client uses correct endpoint URLs.""" - client = FeiShuOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) - - # Test that endpoints are correctly set without hardcoding them in tests - assert client.authorize_endpoint.endswith('/suite/passport/oauth/authorize') - assert client.access_token_endpoint.endswith('/suite/passport/oauth/token') - assert client.refresh_token_endpoint.endswith('/suite/passport/oauth/authorize') - - # Test that all endpoints use the correct domain - for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]: + def test_client_endpoint_urls(self, feishu_client): + assert feishu_client.authorize_endpoint.endswith('/suite/passport/oauth/authorize') + assert feishu_client.access_token_endpoint.endswith('/suite/passport/oauth/token') + assert feishu_client.refresh_token_endpoint.endswith('/suite/passport/oauth/authorize') + for endpoint in [ + feishu_client.authorize_endpoint, + feishu_client.access_token_endpoint, + feishu_client.refresh_token_endpoint, + ]: assert 'passport.feishu.cn' in endpoint - def test_feishu_client_multiple_instances(self): - """Test that multiple FeiShu client instances work independently.""" + def test_client_multiple_instances(self): client1 = FeiShuOAuth20('client1', 'secret1') client2 = FeiShuOAuth20('client2', 'secret2') - assert client1.client_id != client2.client_id assert client1.client_secret != client2.client_secret assert client1.authorize_endpoint == client2.authorize_endpoint - assert client1.access_token_endpoint == client2.access_token_endpoint @pytest.mark.asyncio @respx.mock async def test_get_userinfo_success(self, feishu_client): - """Test successful user info retrieval from FeiShu API.""" mock_user_data = create_mock_user_data('feishu') mock_user_info_response(respx, FEISHU_USER_INFO_URL, mock_user_data) - result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_with_different_access_token(self, feishu_client): - """Test user info retrieval with different access tokens.""" mock_user_data = create_mock_user_data('feishu', user_id='user_789', name='Another User') mock_user_info_response(respx, FEISHU_USER_INFO_URL, mock_user_data) - result = await feishu_client.get_userinfo('different_token') assert result == mock_user_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_empty_response(self, feishu_client): - """Test handling of empty user info response.""" mock_user_info_response(respx, FEISHU_USER_INFO_URL, {}) - result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == {} @pytest.mark.asyncio @respx.mock async def test_get_userinfo_partial_data(self, feishu_client): - """Test handling of partial user info response.""" partial_data = {'user_id': 'test_user', 'name': 'Test User'} mock_user_info_response(respx, FEISHU_USER_INFO_URL, partial_data) - result = await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == partial_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_authorization_header(self, feishu_client): - """Test that authorization header is correctly formatted.""" mock_user_data = {'user_id': 'test_user'} route = mock_user_info_response(respx, FEISHU_USER_INFO_URL, mock_user_data) - await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) - - # Verify the request was made with correct authorization header assert route.called request = route.calls[0].request assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' @@ -134,35 +105,27 @@ async def test_get_userinfo_authorization_header(self, feishu_client): @pytest.mark.asyncio @respx.mock async def test_get_userinfo_http_error_401(self, feishu_client): - """Test handling of 401 HTTP error when getting user info.""" respx.get(FEISHU_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) - with pytest.raises(HTTPXOAuth20Error): await feishu_client.get_userinfo(INVALID_TOKEN) @pytest.mark.asyncio @respx.mock async def test_get_userinfo_http_error_403(self, feishu_client): - """Test handling of 403 HTTP error when getting user info.""" respx.get(FEISHU_USER_INFO_URL).mock(return_value=httpx.Response(403, text='Forbidden')) - with pytest.raises(HTTPXOAuth20Error): await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) @pytest.mark.asyncio @respx.mock async def test_get_userinfo_http_error_500(self, feishu_client): - """Test handling of 500 HTTP error when getting user info.""" respx.get(FEISHU_USER_INFO_URL).mock(return_value=httpx.Response(500, text='Internal Server Error')) - with pytest.raises(HTTPXOAuth20Error): await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) @pytest.mark.asyncio @respx.mock async def test_get_userinfo_invalid_json(self, feishu_client): - """Test handling of invalid JSON response.""" respx.get(FEISHU_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) - with pytest.raises(GetUserInfoError): await feishu_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_gitee.py b/tests/clients/test_gitee.py index bb03534..09a1fb5 100644 --- a/tests/clients/test_gitee.py +++ b/tests/clients/test_gitee.py @@ -19,15 +19,11 @@ @pytest.fixture def gitee_client(): - """Create Gitee OAuth2 client instance for testing.""" return GiteeOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestGiteeOAuth20: - """Test Gitee OAuth2 client functionality.""" - - def test_gitee_client_initialization(self, gitee_client): - """Test Gitee client initialization with correct parameters.""" + def test_client_initialization(self, gitee_client): assert gitee_client.client_id == TEST_CLIENT_ID assert gitee_client.client_secret == TEST_CLIENT_SECRET assert gitee_client.authorize_endpoint == 'https://gitee.com/oauth/authorize' @@ -35,92 +31,68 @@ def test_gitee_client_initialization(self, gitee_client): assert gitee_client.refresh_token_endpoint == 'https://gitee.com/oauth/token' assert gitee_client.default_scopes == ['user_info'] - def test_gitee_client_initialization_with_custom_credentials(self): - """Test Gitee client initialization with custom credentials.""" - client = GiteeOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - assert client.client_id == TEST_CLIENT_ID - assert client.client_secret == TEST_CLIENT_SECRET - - def test_gitee_client_inheritance(self, gitee_client): - """Test that Gitee client properly inherits from OAuth20Base.""" + def test_client_inheritance(self, gitee_client): assert isinstance(gitee_client, OAuth20Base) - def test_gitee_client_scopes_are_lists(self, gitee_client): - """Test that default scopes are properly configured as lists.""" + def test_client_scopes_are_lists(self, gitee_client): assert isinstance(gitee_client.default_scopes, list) assert len(gitee_client.default_scopes) == 1 assert all(isinstance(scope, str) for scope in gitee_client.default_scopes) - def test_gitee_client_endpoint_urls(self): - """Test that Gitee client uses correct endpoint URLs.""" - client = GiteeOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) - - # Test that endpoints are correctly set without hardcoding them in tests - assert client.authorize_endpoint.endswith('/oauth/authorize') - assert client.access_token_endpoint.endswith('/oauth/token') - assert client.refresh_token_endpoint.endswith('/oauth/token') - - # Test that all endpoints use the correct domain - for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]: + def test_client_endpoint_urls(self, gitee_client): + assert gitee_client.authorize_endpoint.endswith('/oauth/authorize') + assert gitee_client.access_token_endpoint.endswith('/oauth/token') + assert gitee_client.refresh_token_endpoint.endswith('/oauth/token') + for endpoint in [ + gitee_client.authorize_endpoint, + gitee_client.access_token_endpoint, + gitee_client.refresh_token_endpoint, + ]: assert 'gitee.com' in endpoint @pytest.mark.asyncio @respx.mock async def test_get_userinfo_success(self, gitee_client): - """Test successful user info retrieval from Gitee API.""" mock_user_data = create_mock_user_data('gitee') mock_user_info_response(respx, GITEE_USER_INFO_URL, mock_user_data) - result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_authorization_header(self, gitee_client): - """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user'} route = mock_user_info_response(respx, GITEE_USER_INFO_URL, mock_user_data) - await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) - - # Verify the request was made with correct authorization header assert route.called request = route.calls[0].request assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' - @pytest.mark.asyncio - @respx.mock - async def test_get_userinfo_http_error(self, gitee_client): - """Test handling of HTTP errors when getting user info.""" - respx.get(GITEE_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) - - with pytest.raises(HTTPXOAuth20Error): - await gitee_client.get_userinfo(INVALID_TOKEN) - @pytest.mark.asyncio @respx.mock async def test_get_userinfo_empty_response(self, gitee_client): - """Test handling of empty user info response.""" mock_user_info_response(respx, GITEE_USER_INFO_URL, {}) - result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == {} @pytest.mark.asyncio @respx.mock async def test_get_userinfo_partial_data(self, gitee_client): - """Test handling of partial user info response.""" partial_data = {'id': 123456, 'login': 'testuser'} mock_user_info_response(respx, GITEE_USER_INFO_URL, partial_data) - result = await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == partial_data + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error(self, gitee_client): + respx.get(GITEE_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) + with pytest.raises(HTTPXOAuth20Error): + await gitee_client.get_userinfo(INVALID_TOKEN) + @pytest.mark.asyncio @respx.mock async def test_get_userinfo_invalid_json(self, gitee_client): - """Test handling of invalid JSON response.""" respx.get(GITEE_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) - with pytest.raises(GetUserInfoError): await gitee_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_github.py b/tests/clients/test_github.py index 134a11a..ee2befa 100644 --- a/tests/clients/test_github.py +++ b/tests/clients/test_github.py @@ -14,86 +14,59 @@ mock_user_info_response, ) -GITHUB_TOKEN_URL = 'https://github.com/login/oauth/access_token' GITHUB_USER_INFO_URL = 'https://api.github.com/user' GITHUB_EMAILS_URL = 'https://api.github.com/user/emails' @pytest.fixture def github_client(): - """Create GitHub OAuth2 client instance for testing.""" return GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestGitHubOAuth20: - """Test GitHub OAuth2 client functionality.""" - - def test_github_client_initialization(self, github_client): - """Test GitHub client initialization with correct parameters.""" + def test_client_initialization(self, github_client): assert github_client.client_id == TEST_CLIENT_ID assert github_client.client_secret == TEST_CLIENT_SECRET assert github_client.authorize_endpoint == 'https://github.com/login/oauth/authorize' assert github_client.access_token_endpoint == 'https://github.com/login/oauth/access_token' assert github_client.default_scopes == ['user', 'user:email'] - def test_github_client_initialization_with_custom_credentials(self): - """Test GitHub client initialization with custom credentials.""" - client = GitHubOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - assert client.client_id == TEST_CLIENT_ID - assert client.client_secret == TEST_CLIENT_SECRET - - def test_github_client_inheritance(self, github_client): - """Test that GitHub client properly inherits from OAuth20Base.""" + def test_client_inheritance(self, github_client): assert isinstance(github_client, OAuth20Base) - def test_github_client_scopes_are_lists(self, github_client): - """Test that default scopes are properly configured as lists.""" + def test_client_scopes_are_lists(self, github_client): assert isinstance(github_client.default_scopes, list) assert len(github_client.default_scopes) == 2 assert all(isinstance(scope, str) for scope in github_client.default_scopes) - def test_github_client_endpoint_urls(self): - """Test that GitHub client uses correct endpoint URLs.""" - client = GitHubOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) - - # Test that endpoints are correctly set without hardcoding them in tests - assert client.authorize_endpoint.endswith('/login/oauth/authorize') - assert client.access_token_endpoint.endswith('/login/oauth/access_token') - - # Test that all endpoints use the correct domain - for endpoint in [client.authorize_endpoint, client.access_token_endpoint]: + def test_client_endpoint_urls(self, github_client): + assert github_client.authorize_endpoint.endswith('/login/oauth/authorize') + assert github_client.access_token_endpoint.endswith('/login/oauth/access_token') + for endpoint in [github_client.authorize_endpoint, github_client.access_token_endpoint]: assert 'github.com' in endpoint - def test_github_client_multiple_instances(self): - """Test that multiple GitHub client instances work independently.""" + def test_client_multiple_instances(self): client1 = GitHubOAuth20('client1', 'secret1') client2 = GitHubOAuth20('client2', 'secret2') - assert client1.client_id != client2.client_id assert client1.client_secret != client2.client_secret assert client1.authorize_endpoint == client2.authorize_endpoint - assert client1.access_token_endpoint == client2.access_token_endpoint @pytest.mark.asyncio @respx.mock async def test_get_userinfo_success_with_email(self, github_client): - """Test successful user info retrieval from GitHub API with email included.""" mock_user_data = create_mock_user_data('github') mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data) - result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_success_without_email(self, github_client): - """Test successful user info retrieval from GitHub API without email.""" mock_user_data = create_mock_user_data('github', email=None) mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data) - # Mock emails endpoint emails_data = [{'email': 'test@example.com', 'primary': True}] respx.get(GITHUB_EMAILS_URL).mock(return_value=httpx.Response(200, json=emails_data)) - result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) assert result['login'] == mock_user_data['login'] assert result['email'] == 'test@example.com' @@ -101,100 +74,73 @@ async def test_get_userinfo_success_without_email(self, github_client): @pytest.mark.asyncio @respx.mock async def test_get_userinfo_with_different_access_token(self, github_client): - """Test user info retrieval with different access tokens.""" mock_user_data = create_mock_user_data('github', id=789, login='different_user') mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data) - result = await github_client.get_userinfo('different_token') assert result == mock_user_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_authorization_header(self, github_client): - """Test that authorization header is correctly formatted.""" - mock_user_data = {'id': 'test_user', 'email': 'test@example.com'} # Include email to avoid emails endpoint call + mock_user_data = {'id': 'test_user', 'email': 'test@example.com'} route = mock_user_info_response(respx, GITHUB_USER_INFO_URL, mock_user_data) - await github_client.get_userinfo(TEST_ACCESS_TOKEN) - - # Verify the request was made with correct authorization header assert route.called request = route.calls[0].request assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_empty_response(self, github_client): + mock_user_info_response(respx, GITHUB_USER_INFO_URL, {}) + emails_data = [{'email': 'test@example.com', 'primary': True}] + respx.get(GITHUB_EMAILS_URL).mock(return_value=httpx.Response(200, json=emails_data)) + result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result['email'] == 'test@example.com' + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_partial_data(self, github_client): + partial_data = {'id': 123456, 'login': 'testuser', 'email': 'test@example.com'} + mock_user_info_response(respx, GITHUB_USER_INFO_URL, partial_data) + result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) + assert result == partial_data + @pytest.mark.asyncio @respx.mock async def test_get_userinfo_http_error_401(self, github_client): - """Test handling of 401 HTTP error when getting user info.""" respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) - with pytest.raises(HTTPXOAuth20Error): await github_client.get_userinfo(INVALID_TOKEN) @pytest.mark.asyncio @respx.mock async def test_get_userinfo_http_error_403(self, github_client): - """Test handling of 403 HTTP error when getting user info.""" respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(403, text='Forbidden')) - with pytest.raises(HTTPXOAuth20Error): await github_client.get_userinfo(TEST_ACCESS_TOKEN) @pytest.mark.asyncio @respx.mock async def test_get_userinfo_http_error_500(self, github_client): - """Test handling of 500 HTTP error when getting user info.""" respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(500, text='Internal Server Error')) - with pytest.raises(HTTPXOAuth20Error): await github_client.get_userinfo(TEST_ACCESS_TOKEN) @pytest.mark.asyncio @respx.mock async def test_get_userinfo_invalid_json(self, github_client): - """Test handling of invalid JSON response.""" respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) - with pytest.raises(GetUserInfoError): await github_client.get_userinfo(TEST_ACCESS_TOKEN) - @pytest.mark.asyncio - @respx.mock - async def test_get_userinfo_empty_response(self, github_client): - """Test handling of empty user info response.""" - mock_user_info_response(respx, GITHUB_USER_INFO_URL, {}) - # Mock emails endpoint since empty response will trigger email lookup - emails_data = [{'email': 'test@example.com', 'primary': True}] - respx.get(GITHUB_EMAILS_URL).mock(return_value=httpx.Response(200, json=emails_data)) - - result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) - assert result['email'] == 'test@example.com' - - @pytest.mark.asyncio - @respx.mock - async def test_get_userinfo_partial_data(self, github_client): - """Test handling of partial user info response.""" - partial_data = { - 'id': 123456, - 'login': 'testuser', - 'email': 'test@example.com', - } # Add email to avoid emails endpoint call - mock_user_info_response(respx, GITHUB_USER_INFO_URL, partial_data) - - result = await github_client.get_userinfo(TEST_ACCESS_TOKEN) - assert result == partial_data - @pytest.mark.asyncio @respx.mock async def test_get_userinfo_rate_limit(self, github_client): - """Test handling of GitHub API rate limit.""" - # GitHub rate limit response rate_limit_response = { 'message': 'API rate limit exceeded for xxx.xxx.xxx.xxx.', 'documentation_url': 'https://docs.github.com/rest/overview/rate-limits-for-the-rest-api', } - respx.get(GITHUB_USER_INFO_URL).mock(return_value=httpx.Response(403, json=rate_limit_response)) - with pytest.raises(HTTPXOAuth20Error): await github_client.get_userinfo(TEST_ACCESS_TOKEN) diff --git a/tests/clients/test_google.py b/tests/clients/test_google.py index 576765e..6a8d836 100644 --- a/tests/clients/test_google.py +++ b/tests/clients/test_google.py @@ -19,15 +19,11 @@ @pytest.fixture def google_client(): - """Create Google OAuth2 client instance for testing.""" return GoogleOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestGoogleOAuth20: - """Test Google OAuth2 client functionality.""" - - def test_google_client_initialization(self, google_client): - """Test Google client initialization with correct parameters.""" + def test_client_initialization(self, google_client): assert google_client.client_id == TEST_CLIENT_ID assert google_client.client_secret == TEST_CLIENT_SECRET assert google_client.authorize_endpoint == 'https://accounts.google.com/o/oauth2/v2/auth' @@ -36,105 +32,74 @@ def test_google_client_initialization(self, google_client): assert google_client.revoke_token_endpoint == 'https://accounts.google.com/o/oauth2/revoke' assert google_client.default_scopes == ['email', 'openid', 'profile'] - def test_google_client_initialization_with_custom_credentials(self): - """Test Google client initialization with custom credentials.""" - client = GoogleOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - assert client.client_id == TEST_CLIENT_ID - assert client.client_secret == TEST_CLIENT_SECRET - - def test_google_client_inheritance(self, google_client): - """Test that Google client properly inherits from OAuth20Base.""" + def test_client_inheritance(self, google_client): assert isinstance(google_client, OAuth20Base) - def test_google_client_scopes_are_lists(self, google_client): - """Test that default scopes are properly configured as lists.""" + def test_client_scopes_are_lists(self, google_client): assert isinstance(google_client.default_scopes, list) assert len(google_client.default_scopes) == 3 assert all(isinstance(scope, str) for scope in google_client.default_scopes) - def test_google_client_endpoint_urls(self): - """Test that Google client uses correct endpoint URLs.""" - client = GoogleOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) - - # Test that endpoints are correctly set without hardcoding them in tests - assert client.authorize_endpoint.endswith('/o/oauth2/v2/auth') - assert client.access_token_endpoint.endswith('/token') - assert client.refresh_token_endpoint.endswith('/token') - assert client.revoke_token_endpoint.endswith('/o/oauth2/revoke') - - # Test that all endpoints use the correct domains - assert 'accounts.google.com' in client.authorize_endpoint - assert 'accounts.google.com' in client.revoke_token_endpoint - assert 'oauth2.googleapis.com' in client.access_token_endpoint - assert 'oauth2.googleapis.com' in client.refresh_token_endpoint + def test_client_endpoint_urls(self, google_client): + assert google_client.authorize_endpoint.endswith('/o/oauth2/v2/auth') + assert google_client.access_token_endpoint.endswith('/token') + assert google_client.refresh_token_endpoint.endswith('/token') + assert google_client.revoke_token_endpoint.endswith('/o/oauth2/revoke') + assert 'accounts.google.com' in google_client.authorize_endpoint + assert 'accounts.google.com' in google_client.revoke_token_endpoint + assert 'oauth2.googleapis.com' in google_client.access_token_endpoint + assert 'oauth2.googleapis.com' in google_client.refresh_token_endpoint + + def test_client_multiple_instances(self): + client1 = GoogleOAuth20('client1', 'secret1') + client2 = GoogleOAuth20('client2', 'secret2') + assert client1.client_id != client2.client_id + assert client1.client_secret != client2.client_secret + assert client1.authorize_endpoint == client2.authorize_endpoint @pytest.mark.asyncio @respx.mock async def test_get_userinfo_success(self, google_client): - """Test successful user info retrieval from Google OAuth2 API.""" mock_user_data = create_mock_user_data('google') mock_user_info_response(respx, GOOGLE_USER_INFO_URL, mock_user_data) - result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_authorization_header(self, google_client): - """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user'} route = mock_user_info_response(respx, GOOGLE_USER_INFO_URL, mock_user_data) - await google_client.get_userinfo(TEST_ACCESS_TOKEN) - - # Verify the request was made with correct authorization header assert route.called request = route.calls[0].request assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' - @pytest.mark.asyncio - @respx.mock - async def test_get_userinfo_http_error(self, google_client): - """Test handling of HTTP errors when getting user info.""" - respx.get(GOOGLE_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) - - with pytest.raises(HTTPXOAuth20Error): - await google_client.get_userinfo(INVALID_TOKEN) - @pytest.mark.asyncio @respx.mock async def test_get_userinfo_empty_response(self, google_client): - """Test handling of empty user info response.""" mock_user_info_response(respx, GOOGLE_USER_INFO_URL, {}) - result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == {} @pytest.mark.asyncio @respx.mock async def test_get_userinfo_partial_data(self, google_client): - """Test handling of partial user info response.""" partial_data = {'id': '123456789', 'email': 'test@example.com'} mock_user_info_response(respx, GOOGLE_USER_INFO_URL, partial_data) - result = await google_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == partial_data + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error(self, google_client): + respx.get(GOOGLE_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) + with pytest.raises(HTTPXOAuth20Error): + await google_client.get_userinfo(INVALID_TOKEN) + @pytest.mark.asyncio @respx.mock async def test_get_userinfo_invalid_json(self, google_client): - """Test handling of invalid JSON response.""" respx.get(GOOGLE_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) - with pytest.raises(GetUserInfoError): await google_client.get_userinfo(TEST_ACCESS_TOKEN) - - def test_google_client_multiple_instances(self): - """Test that multiple Google client instances work independently.""" - client1 = GoogleOAuth20('client1', 'secret1') - client2 = GoogleOAuth20('client2', 'secret2') - - assert client1.client_id != client2.client_id - assert client1.client_secret != client2.client_secret - assert client1.authorize_endpoint == client2.authorize_endpoint - assert client1.access_token_endpoint == client2.access_token_endpoint diff --git a/tests/clients/test_linuxdo.py b/tests/clients/test_linuxdo.py index b944ff3..b419489 100644 --- a/tests/clients/test_linuxdo.py +++ b/tests/clients/test_linuxdo.py @@ -19,15 +19,11 @@ @pytest.fixture def linuxdo_client(): - """Create LinuxDo OAuth2 client instance for testing.""" return LinuxDoOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestLinuxDoOAuth20: - """Test LinuxDo OAuth2 client functionality.""" - - def test_linuxdo_client_initialization(self, linuxdo_client): - """Test LinuxDo client initialization with correct parameters.""" + def test_client_initialization(self, linuxdo_client): assert linuxdo_client.client_id == TEST_CLIENT_ID assert linuxdo_client.client_secret == TEST_CLIENT_SECRET assert linuxdo_client.authorize_endpoint == 'https://connect.linux.do/oauth2/authorize' @@ -36,53 +32,37 @@ def test_linuxdo_client_initialization(self, linuxdo_client): assert linuxdo_client.default_scopes is None assert linuxdo_client.token_endpoint_basic_auth is True - def test_linuxdo_client_initialization_with_custom_credentials(self): - """Test LinuxDo client initialization with custom credentials.""" - client = LinuxDoOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - assert client.client_id == TEST_CLIENT_ID - assert client.client_secret == TEST_CLIENT_SECRET - - def test_linuxdo_client_inheritance(self, linuxdo_client): - """Test that LinuxDo client properly inherits from OAuth20Base.""" + def test_client_inheritance(self, linuxdo_client): assert isinstance(linuxdo_client, OAuth20Base) - def test_linuxdo_client_basic_auth_enabled(self, linuxdo_client): - """Test that LinuxDo client has basic authentication enabled for token endpoint.""" + def test_client_basic_auth_enabled(self, linuxdo_client): assert linuxdo_client.token_endpoint_basic_auth is True - def test_linuxdo_client_endpoint_urls(self): - """Test that LinuxDo client uses correct endpoint URLs.""" - client = LinuxDoOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) - - # Test that endpoints are correctly set without hardcoding them in tests - assert client.authorize_endpoint.endswith('/oauth2/authorize') - assert client.access_token_endpoint.endswith('/oauth2/token') - assert client.refresh_token_endpoint.endswith('/oauth2/token') - - # Test that all endpoints use the correct domain - for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]: + def test_client_endpoint_urls(self, linuxdo_client): + assert linuxdo_client.authorize_endpoint.endswith('/oauth2/authorize') + assert linuxdo_client.access_token_endpoint.endswith('/oauth2/token') + assert linuxdo_client.refresh_token_endpoint.endswith('/oauth2/token') + for endpoint in [ + linuxdo_client.authorize_endpoint, + linuxdo_client.access_token_endpoint, + linuxdo_client.refresh_token_endpoint, + ]: assert 'connect.linux.do' in endpoint @pytest.mark.asyncio @respx.mock async def test_get_userinfo_success(self, linuxdo_client): - """Test successful user info retrieval from LinuxDo API.""" mock_user_data = create_mock_user_data('linuxdo') mock_user_info_response(respx, LINUXDO_USER_INFO_URL, mock_user_data) - result = await linuxdo_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_authorization_header(self, linuxdo_client): - """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user'} route = mock_user_info_response(respx, LINUXDO_USER_INFO_URL, mock_user_data) - await linuxdo_client.get_userinfo(TEST_ACCESS_TOKEN) - - # Verify the request was made with correct authorization header assert route.called request = route.calls[0].request assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' @@ -90,28 +70,19 @@ async def test_get_userinfo_authorization_header(self, linuxdo_client): @pytest.mark.asyncio @respx.mock async def test_get_userinfo_http_error(self, linuxdo_client): - """Test handling of HTTP errors when getting user info.""" respx.get(LINUXDO_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) - with pytest.raises(HTTPXOAuth20Error): await linuxdo_client.get_userinfo(INVALID_TOKEN) @pytest.mark.asyncio @respx.mock async def test_get_access_token_uses_basic_auth(self, linuxdo_client): - """Test that access token requests use HTTP Basic Authentication.""" mock_token_data = {'access_token': 'new_access_token'} - - # Mock the token endpoint and capture the request route = respx.post('https://connect.linux.do/oauth2/token').mock( return_value=httpx.Response(200, json=mock_token_data) ) - await linuxdo_client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback') - - # Verify BasicAuth was used assert route.called request = route.calls[0].request assert 'authorization' in request.headers - # Basic auth should be present assert request.headers['authorization'].startswith('Basic ') diff --git a/tests/clients/test_oschina.py b/tests/clients/test_oschina.py index ca91b14..cdd36ea 100644 --- a/tests/clients/test_oschina.py +++ b/tests/clients/test_oschina.py @@ -19,15 +19,11 @@ @pytest.fixture def oschina_client(): - """Create OSChina OAuth2 client instance for testing.""" return OSChinaOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) class TestOSChinaOAuth20: - """Test OSChina OAuth2 client functionality.""" - - def test_oschina_client_initialization(self, oschina_client): - """Test OSChina client initialization with correct parameters.""" + def test_client_initialization(self, oschina_client): assert oschina_client.client_id == TEST_CLIENT_ID assert oschina_client.client_secret == TEST_CLIENT_SECRET assert oschina_client.authorize_endpoint == 'https://www.oschina.net/action/oauth2/authorize' @@ -35,53 +31,37 @@ def test_oschina_client_initialization(self, oschina_client): assert oschina_client.refresh_token_endpoint == 'https://www.oschina.net/action/openapi/token' assert oschina_client.default_scopes is None - def test_oschina_client_initialization_with_custom_credentials(self): - """Test OSChina client initialization with custom credentials.""" - client = OSChinaOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) - assert client.client_id == TEST_CLIENT_ID - assert client.client_secret == TEST_CLIENT_SECRET - - def test_oschina_client_inheritance(self, oschina_client): - """Test that OSChina client properly inherits from OAuth20Base.""" + def test_client_inheritance(self, oschina_client): assert isinstance(oschina_client, OAuth20Base) - def test_oschina_client_no_default_scopes(self, oschina_client): - """Test that OSChina client has no default scopes configured.""" + def test_client_no_default_scopes(self, oschina_client): assert oschina_client.default_scopes is None - def test_oschina_client_endpoint_urls(self): - """Test that OSChina client uses correct endpoint URLs.""" - client = OSChinaOAuth20(TEST_CLIENT_ID, TEST_CLIENT_SECRET) - - # Test that endpoints are correctly set without hardcoding them in tests - assert client.authorize_endpoint.endswith('/action/oauth2/authorize') - assert client.access_token_endpoint.endswith('/action/openapi/token') - assert client.refresh_token_endpoint.endswith('/action/openapi/token') - - # Test that all endpoints use the correct domain - for endpoint in [client.authorize_endpoint, client.access_token_endpoint, client.refresh_token_endpoint]: + def test_client_endpoint_urls(self, oschina_client): + assert oschina_client.authorize_endpoint.endswith('/action/oauth2/authorize') + assert oschina_client.access_token_endpoint.endswith('/action/openapi/token') + assert oschina_client.refresh_token_endpoint.endswith('/action/openapi/token') + for endpoint in [ + oschina_client.authorize_endpoint, + oschina_client.access_token_endpoint, + oschina_client.refresh_token_endpoint, + ]: assert 'oschina.net' in endpoint @pytest.mark.asyncio @respx.mock async def test_get_userinfo_success(self, oschina_client): - """Test successful user info retrieval from OSChina API.""" mock_user_data = create_mock_user_data('oschina') mock_user_info_response(respx, OSCHINA_USER_INFO_URL, mock_user_data) - result = await oschina_client.get_userinfo(TEST_ACCESS_TOKEN) assert result == mock_user_data @pytest.mark.asyncio @respx.mock async def test_get_userinfo_authorization_header(self, oschina_client): - """Test that authorization header is correctly formatted.""" mock_user_data = {'id': 'test_user'} route = mock_user_info_response(respx, OSCHINA_USER_INFO_URL, mock_user_data) - await oschina_client.get_userinfo(TEST_ACCESS_TOKEN) - - # Verify the request was made with correct authorization header assert route.called request = route.calls[0].request assert request.headers['authorization'] == f'Bearer {TEST_ACCESS_TOKEN}' @@ -89,8 +69,6 @@ async def test_get_userinfo_authorization_header(self, oschina_client): @pytest.mark.asyncio @respx.mock async def test_get_userinfo_http_error(self, oschina_client): - """Test handling of HTTP errors when getting user info.""" respx.get(OSCHINA_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) - with pytest.raises(HTTPXOAuth20Error): await oschina_client.get_userinfo(INVALID_TOKEN) diff --git a/tests/clients/test_weixin_mp.py b/tests/clients/test_weixin_mp.py new file mode 100644 index 0000000..a7822af --- /dev/null +++ b/tests/clients/test_weixin_mp.py @@ -0,0 +1,207 @@ +import httpx +import pytest +import respx + +from fastapi_oauth20 import WeChatMpOAuth20 +from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error, RefreshTokenError +from fastapi_oauth20.oauth20 import OAuth20Base +from tests.conftest import ( + INVALID_TOKEN, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + create_mock_user_data, +) + +WECHAT_MP_USER_INFO_URL = 'https://api.weixin.qq.com/sns/userinfo' + + +@pytest.fixture +def wechat_mp_client(): + return WeChatMpOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +class TestWeChatMpOAuth20: + def test_client_initialization(self, wechat_mp_client): + assert wechat_mp_client.client_id == TEST_CLIENT_ID + assert wechat_mp_client.client_secret == TEST_CLIENT_SECRET + assert wechat_mp_client.authorize_endpoint == 'https://open.weixin.qq.com/connect/oauth2/authorize' + assert wechat_mp_client.access_token_endpoint == 'https://api.weixin.qq.com/sns/oauth2/access_token' + assert wechat_mp_client.refresh_token_endpoint == 'https://api.weixin.qq.com/sns/oauth2/refresh_token' + assert wechat_mp_client.userinfo_endpoint == 'https://api.weixin.qq.com/sns/userinfo' + assert wechat_mp_client.default_scopes == ['snsapi_userinfo'] + + def test_client_inheritance(self, wechat_mp_client): + assert isinstance(wechat_mp_client, OAuth20Base) + + def test_client_scopes_are_lists(self, wechat_mp_client): + assert isinstance(wechat_mp_client.default_scopes, list) + assert len(wechat_mp_client.default_scopes) == 1 + assert all(isinstance(scope, str) for scope in wechat_mp_client.default_scopes) + + def test_client_endpoint_urls(self, wechat_mp_client): + assert wechat_mp_client.authorize_endpoint.endswith('/connect/oauth2/authorize') + assert wechat_mp_client.access_token_endpoint.endswith('/sns/oauth2/access_token') + assert wechat_mp_client.refresh_token_endpoint.endswith('/sns/oauth2/refresh_token') + assert wechat_mp_client.userinfo_endpoint.endswith('/sns/userinfo') + assert 'open.weixin.qq.com' in wechat_mp_client.authorize_endpoint + assert 'api.weixin.qq.com' in wechat_mp_client.access_token_endpoint + + def test_client_multiple_instances(self): + client1 = WeChatMpOAuth20('client1', 'secret1') + client2 = WeChatMpOAuth20('client2', 'secret2') + assert client1.client_id != client2.client_id + assert client1.client_secret != client2.client_secret + assert client1.authorize_endpoint == client2.authorize_endpoint + + @pytest.mark.asyncio + async def test_get_authorization_url(self, wechat_mp_client): + redirect_uri = 'https://example.com/callback' + state = 'test_state' + url = await wechat_mp_client.get_authorization_url(redirect_uri=redirect_uri, state=state) + assert 'open.weixin.qq.com/connect/oauth2/authorize' in url + assert f'appid={TEST_CLIENT_ID}' in url + assert 'redirect_uri=https%3A%2F%2Fexample.com%2Fcallback' in url + assert f'state={state}' in url + assert 'response_type=code' in url + assert 'scope=snsapi_userinfo' in url + assert url.endswith('#wechat_redirect') + + @pytest.mark.asyncio + async def test_get_authorization_url_with_custom_scope(self, wechat_mp_client): + redirect_uri = 'https://example.com/callback' + url = await wechat_mp_client.get_authorization_url(redirect_uri=redirect_uri, scope=['snsapi_base']) + assert 'scope=snsapi_base' in url + + @pytest.mark.asyncio + async def test_get_authorization_url_query_parameters(self, wechat_mp_client): + url = await wechat_mp_client.get_authorization_url( + redirect_uri='https://example.com/callback', state='test_state' + ) + assert 'appid=' in url + assert 'redirect_uri=' in url + assert 'response_type=code' in url + assert 'scope=' in url + assert 'state=' in url + + @pytest.mark.asyncio + async def test_get_authorization_url_with_kwargs(self, wechat_mp_client): + url = await wechat_mp_client.get_authorization_url( + redirect_uri='https://example.com/callback', state='test_state', extra_param='extra_value' + ) + assert 'open.weixin.qq.com/connect/oauth2/authorize' in url + assert f'appid={TEST_CLIENT_ID}' in url + assert 'extra_param=extra_value' in url + + @pytest.mark.asyncio + @respx.mock + async def test_get_access_token_success(self, wechat_mp_client): + mock_token_data = { + 'access_token': TEST_ACCESS_TOKEN, + 'expires_in': 7200, + 'refresh_token': 'test_refresh_token', + 'openid': 'test_openid', + 'scope': 'snsapi_userinfo', + } + respx.get('https://api.weixin.qq.com/sns/oauth2/access_token').mock( + return_value=httpx.Response(200, json=mock_token_data) + ) + result = await wechat_mp_client.get_access_token(code='test_code') + assert result == mock_token_data + assert result['access_token'] == TEST_ACCESS_TOKEN + assert result['openid'] == 'test_openid' + + @pytest.mark.asyncio + @respx.mock + async def test_get_access_token_wechat_error_response(self, wechat_mp_client): + error_response = {'errcode': 40029, 'errmsg': 'invalid code'} + respx.get('https://api.weixin.qq.com/sns/oauth2/access_token').mock( + return_value=httpx.Response(200, json=error_response) + ) + result = await wechat_mp_client.get_access_token(code='invalid_code') + assert result == error_response + assert result['errcode'] == 40029 + + @pytest.mark.asyncio + @respx.mock + async def test_refresh_token_success(self, wechat_mp_client): + mock_token_data = { + 'access_token': 'new_access_token', + 'expires_in': 7200, + 'refresh_token': 'new_refresh_token', + 'openid': 'test_openid', + 'scope': 'snsapi_userinfo', + } + respx.get('https://api.weixin.qq.com/sns/oauth2/refresh_token').mock( + return_value=httpx.Response(200, json=mock_token_data) + ) + result = await wechat_mp_client.refresh_token(refresh_token='test_refresh_token') + assert result == mock_token_data + assert result['access_token'] == 'new_access_token' + + @pytest.mark.asyncio + async def test_refresh_token_without_endpoint(self): + client = WeChatMpOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + client.refresh_token_endpoint = None + with pytest.raises(RefreshTokenError, match='The refresh token address is missing'): + await client.refresh_token(refresh_token='test_refresh_token') + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success(self, wechat_mp_client): + mock_user_data = create_mock_user_data('wechat_mp') + respx.get(WECHAT_MP_USER_INFO_URL).mock(return_value=httpx.Response(200, json=mock_user_data)) + result = await wechat_mp_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_with_lang_parameter(self, wechat_mp_client): + mock_user_data = create_mock_user_data('wechat_mp') + route = respx.get(WECHAT_MP_USER_INFO_URL).mock(return_value=httpx.Response(200, json=mock_user_data)) + await wechat_mp_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + assert route.called + request = route.calls[0].request + assert 'lang=zh_CN' in str(request.url) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_wechat_error_response(self, wechat_mp_client): + error_response = {'errcode': 40001, 'errmsg': 'invalid credential'} + respx.get(WECHAT_MP_USER_INFO_URL).mock(return_value=httpx.Response(200, json=error_response)) + result = await wechat_mp_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + assert result == error_response + assert result['errcode'] == 40001 + + @pytest.mark.asyncio + async def test_get_userinfo_without_openid(self, wechat_mp_client): + with pytest.raises(GetUserInfoError, match='openid is required'): + await wechat_mp_client.get_userinfo(TEST_ACCESS_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_invalid_json(self, wechat_mp_client): + respx.get(WECHAT_MP_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) + with pytest.raises(GetUserInfoError): + await wechat_mp_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_401(self, wechat_mp_client): + respx.get(WECHAT_MP_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) + with pytest.raises(HTTPXOAuth20Error): + await wechat_mp_client.get_userinfo(INVALID_TOKEN, openid='test_openid') + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_403(self, wechat_mp_client): + respx.get(WECHAT_MP_USER_INFO_URL).mock(return_value=httpx.Response(403, text='Forbidden')) + with pytest.raises(HTTPXOAuth20Error): + await wechat_mp_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_500(self, wechat_mp_client): + respx.get(WECHAT_MP_USER_INFO_URL).mock(return_value=httpx.Response(500, text='Internal Server Error')) + with pytest.raises(HTTPXOAuth20Error): + await wechat_mp_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') diff --git a/tests/clients/test_weixin_open.py b/tests/clients/test_weixin_open.py new file mode 100644 index 0000000..f7e3097 --- /dev/null +++ b/tests/clients/test_weixin_open.py @@ -0,0 +1,217 @@ +import httpx +import pytest +import respx + +from fastapi_oauth20 import WeChatOpenOAuth20 +from fastapi_oauth20.errors import GetUserInfoError, HTTPXOAuth20Error, RefreshTokenError +from fastapi_oauth20.oauth20 import OAuth20Base +from tests.conftest import ( + INVALID_TOKEN, + TEST_ACCESS_TOKEN, + TEST_CLIENT_ID, + TEST_CLIENT_SECRET, + create_mock_user_data, +) + +WECHAT_OPEN_USER_INFO_URL = 'https://api.weixin.qq.com/sns/userinfo' + + +@pytest.fixture +def wechat_open_client(): + return WeChatOpenOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + + +class TestWeChatOpenOAuth20: + def test_client_initialization(self, wechat_open_client): + assert wechat_open_client.client_id == TEST_CLIENT_ID + assert wechat_open_client.client_secret == TEST_CLIENT_SECRET + assert wechat_open_client.authorize_endpoint == 'https://open.weixin.qq.com/connect/qrconnect' + assert wechat_open_client.access_token_endpoint == 'https://api.weixin.qq.com/sns/oauth2/access_token' + assert wechat_open_client.refresh_token_endpoint == 'https://api.weixin.qq.com/sns/oauth2/refresh_token' + assert wechat_open_client.userinfo_endpoint == 'https://api.weixin.qq.com/sns/userinfo' + assert wechat_open_client.default_scopes == ['snsapi_login'] + + def test_client_inheritance(self, wechat_open_client): + assert isinstance(wechat_open_client, OAuth20Base) + + def test_client_scopes_are_lists(self, wechat_open_client): + assert isinstance(wechat_open_client.default_scopes, list) + assert len(wechat_open_client.default_scopes) == 1 + assert all(isinstance(scope, str) for scope in wechat_open_client.default_scopes) + + def test_client_endpoint_urls(self, wechat_open_client): + assert wechat_open_client.authorize_endpoint.endswith('/connect/qrconnect') + assert wechat_open_client.access_token_endpoint.endswith('/sns/oauth2/access_token') + assert wechat_open_client.refresh_token_endpoint.endswith('/sns/oauth2/refresh_token') + assert wechat_open_client.userinfo_endpoint.endswith('/sns/userinfo') + assert 'open.weixin.qq.com' in wechat_open_client.authorize_endpoint + assert 'api.weixin.qq.com' in wechat_open_client.access_token_endpoint + + def test_client_multiple_instances(self): + client1 = WeChatOpenOAuth20('client1', 'secret1') + client2 = WeChatOpenOAuth20('client2', 'secret2') + assert client1.client_id != client2.client_id + assert client1.client_secret != client2.client_secret + assert client1.authorize_endpoint == client2.authorize_endpoint + + @pytest.mark.asyncio + async def test_get_authorization_url(self, wechat_open_client): + redirect_uri = 'https://example.com/callback' + state = 'test_state' + url = await wechat_open_client.get_authorization_url(redirect_uri=redirect_uri, state=state) + assert 'open.weixin.qq.com/connect/qrconnect' in url + assert f'appid={TEST_CLIENT_ID}' in url + assert 'redirect_uri=https%3A%2F%2Fexample.com%2Fcallback' in url + assert f'state={state}' in url + assert 'response_type=code' in url + assert 'scope=snsapi_login' in url + assert url.endswith('#wechat_redirect') + + @pytest.mark.asyncio + async def test_get_authorization_url_with_custom_scope(self, wechat_open_client): + redirect_uri = 'https://example.com/callback' + url = await wechat_open_client.get_authorization_url(redirect_uri=redirect_uri, scope=['snsapi_login']) + assert 'scope=snsapi_login' in url + + @pytest.mark.asyncio + async def test_get_authorization_url_with_lang_parameter(self, wechat_open_client): + url = await wechat_open_client.get_authorization_url( + redirect_uri='https://example.com/callback', state='test_state' + ) + assert 'lang=cn' in url + + @pytest.mark.asyncio + async def test_get_authorization_url_query_parameters(self, wechat_open_client): + url = await wechat_open_client.get_authorization_url( + redirect_uri='https://example.com/callback', state='test_state' + ) + assert 'appid=' in url + assert 'redirect_uri=' in url + assert 'response_type=code' in url + assert 'scope=' in url + assert 'state=' in url + assert 'lang=' in url + + @pytest.mark.asyncio + async def test_get_authorization_url_with_kwargs(self, wechat_open_client): + url = await wechat_open_client.get_authorization_url( + redirect_uri='https://example.com/callback', state='test_state', extra_param='extra_value' + ) + assert 'open.weixin.qq.com/connect/qrconnect' in url + assert f'appid={TEST_CLIENT_ID}' in url + assert 'extra_param=extra_value' in url + + @pytest.mark.asyncio + @respx.mock + async def test_get_access_token_success(self, wechat_open_client): + mock_token_data = { + 'access_token': TEST_ACCESS_TOKEN, + 'expires_in': 7200, + 'refresh_token': 'test_refresh_token', + 'openid': 'test_openid', + 'scope': 'snsapi_login', + 'unionid': 'test_unionid', + } + respx.get('https://api.weixin.qq.com/sns/oauth2/access_token').mock( + return_value=httpx.Response(200, json=mock_token_data) + ) + result = await wechat_open_client.get_access_token(code='test_code') + assert result == mock_token_data + assert result['access_token'] == TEST_ACCESS_TOKEN + assert result['openid'] == 'test_openid' + assert result['unionid'] == 'test_unionid' + + @pytest.mark.asyncio + @respx.mock + async def test_get_access_token_wechat_error_response(self, wechat_open_client): + error_response = {'errcode': 40029, 'errmsg': 'invalid code'} + respx.get('https://api.weixin.qq.com/sns/oauth2/access_token').mock( + return_value=httpx.Response(200, json=error_response) + ) + result = await wechat_open_client.get_access_token(code='invalid_code') + assert result == error_response + assert result['errcode'] == 40029 + + @pytest.mark.asyncio + @respx.mock + async def test_refresh_token_success(self, wechat_open_client): + mock_token_data = { + 'access_token': 'new_access_token', + 'expires_in': 7200, + 'refresh_token': 'new_refresh_token', + 'openid': 'test_openid', + 'scope': 'snsapi_login', + } + respx.get('https://api.weixin.qq.com/sns/oauth2/refresh_token').mock( + return_value=httpx.Response(200, json=mock_token_data) + ) + result = await wechat_open_client.refresh_token(refresh_token='test_refresh_token') + assert result == mock_token_data + assert result['access_token'] == 'new_access_token' + + @pytest.mark.asyncio + async def test_refresh_token_without_endpoint(self): + client = WeChatOpenOAuth20(client_id=TEST_CLIENT_ID, client_secret=TEST_CLIENT_SECRET) + client.refresh_token_endpoint = None + with pytest.raises(RefreshTokenError, match='The refresh token address is missing'): + await client.refresh_token(refresh_token='test_refresh_token') + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_success(self, wechat_open_client): + mock_user_data = create_mock_user_data('wechat_open') + respx.get(WECHAT_OPEN_USER_INFO_URL).mock(return_value=httpx.Response(200, json=mock_user_data)) + result = await wechat_open_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + assert result == mock_user_data + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_with_lang_parameter(self, wechat_open_client): + mock_user_data = create_mock_user_data('wechat_open') + route = respx.get(WECHAT_OPEN_USER_INFO_URL).mock(return_value=httpx.Response(200, json=mock_user_data)) + await wechat_open_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + assert route.called + request = route.calls[0].request + assert 'lang=zh_CN' in str(request.url) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_wechat_error_response(self, wechat_open_client): + error_response = {'errcode': 40001, 'errmsg': 'invalid credential'} + respx.get(WECHAT_OPEN_USER_INFO_URL).mock(return_value=httpx.Response(200, json=error_response)) + result = await wechat_open_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + assert result == error_response + assert result['errcode'] == 40001 + + @pytest.mark.asyncio + async def test_get_userinfo_without_openid(self, wechat_open_client): + with pytest.raises(GetUserInfoError, match='openid is required'): + await wechat_open_client.get_userinfo(TEST_ACCESS_TOKEN) + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_invalid_json(self, wechat_open_client): + respx.get(WECHAT_OPEN_USER_INFO_URL).mock(return_value=httpx.Response(200, text='invalid json')) + with pytest.raises(GetUserInfoError): + await wechat_open_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_401(self, wechat_open_client): + respx.get(WECHAT_OPEN_USER_INFO_URL).mock(return_value=httpx.Response(401, text='Unauthorized')) + with pytest.raises(HTTPXOAuth20Error): + await wechat_open_client.get_userinfo(INVALID_TOKEN, openid='test_openid') + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_403(self, wechat_open_client): + respx.get(WECHAT_OPEN_USER_INFO_URL).mock(return_value=httpx.Response(403, text='Forbidden')) + with pytest.raises(HTTPXOAuth20Error): + await wechat_open_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') + + @pytest.mark.asyncio + @respx.mock + async def test_get_userinfo_http_error_500(self, wechat_open_client): + respx.get(WECHAT_OPEN_USER_INFO_URL).mock(return_value=httpx.Response(500, text='Internal Server Error')) + with pytest.raises(HTTPXOAuth20Error): + await wechat_open_client.get_userinfo(TEST_ACCESS_TOKEN, openid='test_openid') diff --git a/tests/conftest.py b/tests/conftest.py index 4065f00..9e00fe8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,58 +6,77 @@ INVALID_TOKEN = 'invalid_token' TEST_STATE = 'test_state' +MOCK_USER_DATA = { + 'github': { + 'id': 123456, + 'login': 'testuser', + 'name': 'Test User', + 'email': 'test@example.com', + 'bio': 'Test bio', + 'location': 'Test Location', + }, + 'google': { + 'id': '123456789', + 'email': 'test@gmail.com', + 'name': 'Test User', + 'picture': 'https://lh3.googleusercontent.com/test.jpg', + }, + 'gitee': { + 'id': 123456, + 'login': 'testuser', + 'name': 'Test User', + 'email': 'test@example.com', + 'avatar_url': 'https://avatar.example.com/testuser.png', + }, + 'feishu': { + 'user_id': 'test_user_123', + 'employee_id': 'emp_456', + 'name': 'Test User', + 'email': 'test@example.com', + 'mobile': '13800000000', + }, + 'linuxdo': { + 'id': 123456, + 'username': 'testuser', + 'name': 'Test User', + 'email': 'test@example.com', + 'avatar_url': 'https://linux.do/avatar/testuser.png', + }, + 'oschina': { + 'id': 123456, + 'name': 'Test User', + 'email': 'test@example.com', + 'avatar': 'https://oschina.net/img/test.jpg', + }, + 'wechat_mp': { + 'openid': 'test_openid_mp', + 'nickname': 'Test User', + 'sex': 1, + 'province': 'Guangdong', + 'city': 'Shenzhen', + 'country': 'China', + 'headimgurl': 'https://thirdwx.qlogo.cn/test.jpg', + 'privilege': [], + }, + 'wechat_open': { + 'openid': 'test_openid_open', + 'nickname': 'Test User', + 'sex': 1, + 'province': 'Guangdong', + 'city': 'Shenzhen', + 'country': 'China', + 'headimgurl': 'https://thirdwx.qlogo.cn/test.jpg', + 'privilege': [], + 'unionid': 'test_unionid', + }, +} -def create_mock_user_data(provider_name: str, **overrides): - """Create mock user data for a specific provider with optional overrides.""" - MOCK_USER_DATA = { - 'github': { - 'id': 123456, - 'login': 'testuser', - 'name': 'Test User', - 'email': 'test@example.com', - 'bio': 'Test bio', - 'location': 'Test Location', - }, - 'google': { - 'id': '123456789', - 'email': 'test@gmail.com', - 'name': 'Test User', - 'picture': 'https://lh3.googleusercontent.com/test.jpg', - }, - 'gitee': { - 'id': 123456, - 'login': 'testuser', - 'name': 'Test User', - 'email': 'test@example.com', - 'avatar_url': 'https://avatar.example.com/testuser.png', - }, - 'feishu': { - 'user_id': 'test_user_123', - 'employee_id': 'emp_456', - 'name': 'Test User', - 'email': 'test@example.com', - 'mobile': '13800000000', - }, - 'linuxdo': { - 'id': 123456, - 'username': 'testuser', - 'name': 'Test User', - 'email': 'test@example.com', - 'avatar_url': 'https://linux.do/avatar/testuser.png', - }, - 'oschina': { - 'id': 123456, - 'name': 'Test User', - 'email': 'test@example.com', - 'avatar': 'https://oschina.net/img/test.jpg', - }, - } +def create_mock_user_data(provider_name: str, **overrides): base_data = MOCK_USER_DATA.get(provider_name, {}).copy() base_data.update(overrides) return base_data def mock_user_info_response(respx_mock, user_info_url: str, user_data: dict, status_code: int = 200): - """Mock user info endpoint response.""" return respx_mock.get(user_info_url).mock(return_value=httpx.Response(status_code, json=user_data)) diff --git a/tests/test_errors.py b/tests/test_errors.py index cb344f4..b40147a 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -11,14 +11,12 @@ def test_oauth20_request_error_basic(): - """Test basic OAuth20RequestError creation.""" error = OAuth20RequestError('Test error') assert str(error) == 'Test error' assert error.msg == 'Test error' def test_oauth20_request_error_with_response(): - """Test OAuth20RequestError with HTTP response.""" mock_response = httpx.Response(400) error = OAuth20RequestError('Bad request', mock_response) assert str(error) == 'Bad request' @@ -27,14 +25,12 @@ def test_oauth20_request_error_with_response(): def test_httpx_oauth20_error_basic(): - """Test basic HTTPXOAuth20Error creation.""" error = HTTPXOAuth20Error('HTTP error') assert str(error) == 'HTTP error' assert error.msg == 'HTTP error' def test_httpx_oauth20_error_with_response(): - """Test HTTPXOAuth20Error with HTTP response.""" mock_response = httpx.Response(404) error = HTTPXOAuth20Error('Not found', mock_response) assert str(error) == 'Not found' @@ -43,58 +39,47 @@ def test_httpx_oauth20_error_with_response(): def test_access_token_error(): - """Test AccessTokenError creation and inheritance.""" mock_response = httpx.Response(401) error = AccessTokenError('Invalid token', mock_response) - assert str(error) == 'Invalid token' assert isinstance(error, OAuth20RequestError) assert error.response is mock_response def test_refresh_token_error(): - """Test RefreshTokenError creation and inheritance.""" mock_response = httpx.Response(401) error = RefreshTokenError('Invalid refresh token', mock_response) - assert str(error) == 'Invalid refresh token' assert isinstance(error, OAuth20RequestError) assert error.response is mock_response def test_revoke_token_error(): - """Test RevokeTokenError creation and inheritance.""" mock_response = httpx.Response(400) error = RevokeTokenError('Revocation failed', mock_response) - assert str(error) == 'Revocation failed' assert isinstance(error, OAuth20RequestError) assert error.response is mock_response def test_get_userinfo_error(): - """Test GetUserInfoError creation and inheritance.""" mock_response = httpx.Response(403) error = GetUserInfoError('Access denied', mock_response) - assert str(error) == 'Access denied' assert isinstance(error, OAuth20RequestError) assert error.response is mock_response def test_error_inheritance_chain(): - """Test that all OAuth2 errors have proper inheritance.""" assert issubclass(AccessTokenError, OAuth20RequestError) assert issubclass(RefreshTokenError, OAuth20RequestError) assert issubclass(RevokeTokenError, OAuth20RequestError) assert issubclass(GetUserInfoError, OAuth20RequestError) - assert issubclass(HTTPXOAuth20Error, OAuth20RequestError) assert issubclass(OAuth20RequestError, Exception) def test_error_without_response(): - """Test error creation without HTTP response.""" error = AccessTokenError('Simple error') assert str(error) == 'Simple error' assert error.msg == 'Simple error' @@ -102,22 +87,18 @@ def test_error_without_response(): def test_error_catch_hierarchy(): - """Test that errors can be caught at different levels of hierarchy.""" mock_response = httpx.Response(400) - # Specific error type try: raise AccessTokenError('Access token error', mock_response) except AccessTokenError as e: assert str(e) == 'Access token error' - # Parent OAuth20RequestError type try: raise RefreshTokenError('Refresh token error', mock_response) except OAuth20RequestError as e: assert str(e) == 'Refresh token error' - # HTTPXOAuth20Error type try: raise HTTPXOAuth20Error('HTTPX error', mock_response) except HTTPXOAuth20Error as e: @@ -125,9 +106,7 @@ def test_error_catch_hierarchy(): def test_error_properties(): - """Test that error objects have expected properties.""" mock_response = httpx.Response(500) - error = RevokeTokenError('Server error', mock_response) assert hasattr(error, 'msg') assert hasattr(error, 'response') @@ -136,19 +115,15 @@ def test_error_properties(): def test_error_str_representation(): - """Test string representation of errors.""" - # Error without response error1 = AccessTokenError('Simple message') assert str(error1) == 'Simple message' - # Error with response mock_response = httpx.Response(404) error2 = GetUserInfoError('User not found', mock_response) assert str(error2) == 'User not found' def test_error_with_complex_message(): - """Test errors with complex or multi-line messages.""" complex_message = "Error: Invalid request\nDetails: Missing required parameter 'code'" error = OAuth20RequestError(complex_message) assert str(error) == complex_message diff --git a/tests/test_oauth20.py b/tests/test_oauth20.py index ed3a9cd..a2a8f3d 100644 --- a/tests/test_oauth20.py +++ b/tests/test_oauth20.py @@ -1,31 +1,23 @@ import json +from typing import Any from unittest.mock import Mock import httpx import pytest import respx -from fastapi_oauth20.errors import ( - AccessTokenError, - HTTPXOAuth20Error, - RefreshTokenError, - RevokeTokenError, -) +from fastapi_oauth20.errors import AccessTokenError, HTTPXOAuth20Error, RefreshTokenError, RevokeTokenError from fastapi_oauth20.oauth20 import OAuth20Base class MockOAuth20Client(OAuth20Base): - """Test implementation of OAuth20Base for testing purposes.""" - - async def get_userinfo(self, access_token: str) -> dict[str, any]: - """Mock implementation for testing.""" + async def get_userinfo(self, access_token: str) -> dict[str, Any]: return {'user_id': 'test_user', 'access_token': access_token} @pytest.fixture def oauth_client(): - """Create OAuth20Base client instance for testing.""" return MockOAuth20Client( client_id='test_client_id', client_secret='test_client_secret', @@ -39,7 +31,6 @@ def oauth_client(): def test_oauth_base_initialization(oauth_client): - """Test OAuth20Base initialization with all parameters.""" assert oauth_client.client_id == 'test_client_id' assert oauth_client.client_secret == 'test_client_secret' assert oauth_client.authorize_endpoint == 'https://example.com/oauth/authorize' @@ -53,7 +44,6 @@ def test_oauth_base_initialization(oauth_client): def test_oauth_base_initialization_minimal(): - """Test OAuth20Base initialization with minimal required parameters.""" client = MockOAuth20Client( client_id='test_id', client_secret='test_secret', @@ -61,18 +51,14 @@ def test_oauth_base_initialization_minimal(): access_token_endpoint='https://example.com/token', userinfo_endpoint='https://example.com/userinfo', ) - assert client.client_id == 'test_id' assert client.client_secret == 'test_secret' - assert client.authorize_endpoint == 'https://example.com/auth' - assert client.access_token_endpoint == 'https://example.com/token' assert client.refresh_token_endpoint is None assert client.revoke_token_endpoint is None assert client.default_scopes is None def test_oauth_base_initialization_with_basic_auth(): - """Test OAuth20Base initialization with basic authentication enabled.""" client = MockOAuth20Client( client_id='test_id', client_secret='test_secret', @@ -82,16 +68,38 @@ def test_oauth_base_initialization_with_basic_auth(): token_endpoint_basic_auth=True, revoke_token_endpoint_basic_auth=True, ) - assert client.token_endpoint_basic_auth is True assert client.revoke_token_endpoint_basic_auth is True +def test_concrete_implementation(): + client = OAuth20Base( + client_id='test', + client_secret='test', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', + ) + assert client.client_id == 'test' + assert client.userinfo_endpoint == 'https://example.com/userinfo' + + +@pytest.mark.asyncio +async def test_get_userinfo_implementation(): + client = MockOAuth20Client( + client_id='test', + client_secret='test', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', + ) + result = await client.get_userinfo('test_token') + assert result == {'user_id': 'test_user', 'access_token': 'test_token'} + + @pytest.mark.asyncio async def test_get_authorization_url_basic(oauth_client): - """Test basic authorization URL generation.""" url = await oauth_client.get_authorization_url(redirect_uri='https://example.com/callback') - assert 'https://example.com/oauth/authorize' in url assert 'client_id=test_client_id' in url assert 'redirect_uri=https%3A%2F%2Fexample.com%2Fcallback' in url @@ -101,43 +109,35 @@ async def test_get_authorization_url_basic(oauth_client): @pytest.mark.asyncio async def test_get_authorization_url_with_state(oauth_client): - """Test authorization URL generation with state parameter.""" url = await oauth_client.get_authorization_url( redirect_uri='https://example.com/callback', state='random_state_123' ) - assert 'state=random_state_123' in url @pytest.mark.asyncio async def test_get_authorization_url_with_custom_scope(oauth_client): - """Test authorization URL generation with custom scope.""" url = await oauth_client.get_authorization_url( redirect_uri='https://example.com/callback', scope=['read', 'delete'] ) - assert 'scope=read+delete' in url assert 'write' not in url @pytest.mark.asyncio async def test_get_authorization_url_with_pkce(oauth_client): - """Test authorization URL generation with PKCE parameters.""" url = await oauth_client.get_authorization_url( redirect_uri='https://example.com/callback', code_challenge='challenge_123', code_challenge_method='S256' ) - assert 'code_challenge=challenge_123' in url assert 'code_challenge_method=S256' in url @pytest.mark.asyncio async def test_get_authorization_url_with_extra_params(oauth_client): - """Test authorization URL generation with additional parameters.""" url = await oauth_client.get_authorization_url( redirect_uri='https://example.com/callback', access_type='offline', prompt='consent' ) - assert 'access_type=offline' in url assert 'prompt=consent' in url @@ -145,17 +145,13 @@ async def test_get_authorization_url_with_extra_params(oauth_client): @pytest.mark.asyncio @respx.mock async def test_get_access_token_success(oauth_client): - """Test successful access token exchange.""" mock_token_data = { 'access_token': 'new_access_token', 'token_type': 'Bearer', 'expires_in': 3600, 'refresh_token': 'refresh_token_123', } - - # Mock the token endpoint respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(200, json=mock_token_data)) - result = await oauth_client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback') assert result == mock_token_data @@ -163,17 +159,11 @@ async def test_get_access_token_success(oauth_client): @pytest.mark.asyncio @respx.mock async def test_get_access_token_with_code_verifier(oauth_client): - """Test access token exchange with PKCE code verifier.""" mock_token_data = {'access_token': 'new_access_token'} - - # Mock the token endpoint and capture the request route = respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(200, json=mock_token_data)) - await oauth_client.get_access_token( code='auth_code_123', redirect_uri='https://example.com/callback', code_verifier='verifier_123' ) - - # Verify the request was made with code_verifier assert route.called request_data = route.calls[0].request.content.decode() assert 'code_verifier=verifier_123' in request_data @@ -182,7 +172,6 @@ async def test_get_access_token_with_code_verifier(oauth_client): @pytest.mark.asyncio @respx.mock async def test_get_access_token_with_basic_auth(): - """Test access token exchange with HTTP Basic Authentication.""" client = MockOAuth20Client( client_id='test_id', client_secret='test_secret', @@ -191,29 +180,19 @@ async def test_get_access_token_with_basic_auth(): userinfo_endpoint='https://example.com/userinfo', token_endpoint_basic_auth=True, ) - mock_token_data = {'access_token': 'new_access_token'} - - # Mock the token endpoint route = respx.post('https://example.com/token').mock(return_value=httpx.Response(200, json=mock_token_data)) - await client.get_access_token(code='auth_code_123', redirect_uri='https://example.com/callback') - - # Verify BasicAuth was used assert route.called request = route.calls[0].request assert 'authorization' in request.headers - # Basic auth should be present assert request.headers['authorization'].startswith('Basic ') @pytest.mark.asyncio @respx.mock async def test_get_access_token_http_error(oauth_client): - """Test handling of HTTP errors during access token exchange.""" - # Mock HTTP error response respx.post('https://example.com/oauth/token').mock(return_value=httpx.Response(400, text='Bad Request')) - with pytest.raises(HTTPXOAuth20Error): await oauth_client.get_access_token(code='invalid_code', redirect_uri='https://example.com/callback') @@ -221,27 +200,42 @@ async def test_get_access_token_http_error(oauth_client): @pytest.mark.asyncio @respx.mock async def test_refresh_token_success(oauth_client): - """Test successful token refresh.""" mock_token_data = {'access_token': 'refreshed_access_token', 'token_type': 'Bearer', 'expires_in': 3600} - - # Mock the refresh endpoint respx.post('https://example.com/oauth/refresh').mock(return_value=httpx.Response(200, json=mock_token_data)) - result = await oauth_client.refresh_token('refresh_token_123') assert result == mock_token_data @pytest.mark.asyncio -async def test_refresh_token_missing_endpoint(): - """Test refresh token when refresh endpoint is not configured.""" +@respx.mock +async def test_refresh_token_with_basic_auth(): client = MockOAuth20Client( client_id='test_id', client_secret='test_secret', authorize_endpoint='https://example.com/auth', access_token_endpoint='https://example.com/token', userinfo_endpoint='https://example.com/userinfo', + refresh_token_endpoint='https://example.com/oauth/refresh', + token_endpoint_basic_auth=True, ) + mock_token_data = {'access_token': 'refreshed_access_token', 'token_type': 'Bearer', 'expires_in': 3600} + route = respx.post('https://example.com/oauth/refresh').mock(return_value=httpx.Response(200, json=mock_token_data)) + await client.refresh_token('refresh_token_123') + assert route.called + request = route.calls[0].request + assert 'authorization' in request.headers + assert request.headers['authorization'].startswith('Basic ') + +@pytest.mark.asyncio +async def test_refresh_token_missing_endpoint(): + client = MockOAuth20Client( + client_id='test_id', + client_secret='test_secret', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', + ) with pytest.raises(RefreshTokenError, match='refresh token address is missing'): await client.refresh_token('refresh_token_123') @@ -249,10 +243,7 @@ async def test_refresh_token_missing_endpoint(): @pytest.mark.asyncio @respx.mock async def test_refresh_token_http_error(oauth_client): - """Test handling of HTTP errors during token refresh.""" - # Mock HTTP error response respx.post('https://example.com/oauth/refresh').mock(return_value=httpx.Response(401, text='Unauthorized')) - with pytest.raises(HTTPXOAuth20Error): await oauth_client.refresh_token('invalid_refresh_token') @@ -260,40 +251,49 @@ async def test_refresh_token_http_error(oauth_client): @pytest.mark.asyncio @respx.mock async def test_revoke_token_success(oauth_client): - """Test successful token revocation.""" - # Mock successful revocation response respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(200, text='OK')) - - # Should not raise any exception for successful revocation await oauth_client.revoke_token('access_token_123') @pytest.mark.asyncio @respx.mock async def test_revoke_token_with_type_hint(oauth_client): - """Test token revocation with token type hint.""" - # Mock the revoke endpoint and capture the request route = respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(200, text='OK')) - await oauth_client.revoke_token(token='refresh_token_123', token_type_hint='refresh_token') - - # Verify token_type_hint was included in the request assert route.called request_data = route.calls[0].request.content.decode() assert 'token_type_hint=refresh_token' in request_data @pytest.mark.asyncio -async def test_revoke_token_missing_endpoint(): - """Test token revocation when revoke endpoint is not configured.""" +@respx.mock +async def test_revoke_token_with_basic_auth(): client = MockOAuth20Client( client_id='test_id', client_secret='test_secret', authorize_endpoint='https://example.com/auth', access_token_endpoint='https://example.com/token', userinfo_endpoint='https://example.com/userinfo', + revoke_token_endpoint='https://example.com/oauth/revoke', + revoke_token_endpoint_basic_auth=True, ) + route = respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(200, text='OK')) + await client.revoke_token('access_token_123') + assert route.called + request = route.calls[0].request + assert 'authorization' in request.headers + assert request.headers['authorization'].startswith('Basic ') + +@pytest.mark.asyncio +async def test_revoke_token_missing_endpoint(): + client = MockOAuth20Client( + client_id='test_id', + client_secret='test_secret', + authorize_endpoint='https://example.com/auth', + access_token_endpoint='https://example.com/token', + userinfo_endpoint='https://example.com/userinfo', + ) with pytest.raises(RevokeTokenError, match='revoke token address is missing'): await client.revoke_token('access_token_123') @@ -301,87 +301,42 @@ async def test_revoke_token_missing_endpoint(): @pytest.mark.asyncio @respx.mock async def test_revoke_token_http_error(oauth_client): - """Test handling of HTTP errors during token revocation.""" - # Mock HTTP error response respx.post('https://example.com/oauth/revoke').mock(return_value=httpx.Response(400, text='Bad Request')) - with pytest.raises(HTTPXOAuth20Error): await oauth_client.revoke_token('invalid_token') def test_raise_httpx_oauth20_errors_success(): - """Test successful HTTP response validation.""" mock_response = Mock() mock_response.raise_for_status.return_value = None - - # Should not raise any exception OAuth20Base.raise_httpx_oauth20_errors(mock_response) def test_raise_httpx_oauth20_errors_http_status_error(): - """Test handling of HTTP status errors.""" mock_response = Mock() mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( 'Not Found', request=None, response=mock_response ) - with pytest.raises(HTTPXOAuth20Error): OAuth20Base.raise_httpx_oauth20_errors(mock_response) def test_raise_httpx_oauth20_errors_network_error(): - """Test handling of network errors.""" - # Test with a mock response that will raise RequestError when raise_for_status is called mock_response = Mock() mock_response.raise_for_status.side_effect = httpx.RequestError('Network error') - with pytest.raises(HTTPXOAuth20Error): OAuth20Base.raise_httpx_oauth20_errors(mock_response) def test_get_json_result_success(): - """Test successful JSON result parsing.""" mock_response = Mock() mock_response.json.return_value = {'key': 'value'} - result = OAuth20Base.get_json_result(mock_response, err_class=AccessTokenError) assert result == {'key': 'value'} def test_get_json_result_invalid_json(): - """Test handling of invalid JSON response.""" mock_response = Mock() mock_response.json.side_effect = json.JSONDecodeError('Invalid JSON', '', 0) - with pytest.raises(AccessTokenError, match='Result serialization failed'): OAuth20Base.get_json_result(mock_response, err_class=AccessTokenError) - - -def test_concrete_implementation(): - """Test that OAuth20Base can be instantiated directly.""" - client = OAuth20Base( - client_id='test', - client_secret='test', - authorize_endpoint='https://example.com/auth', - access_token_endpoint='https://example.com/token', - userinfo_endpoint='https://example.com/userinfo', - ) - - assert client.client_id == 'test' - assert client.client_secret == 'test' - assert client.userinfo_endpoint == 'https://example.com/userinfo' - - -@pytest.mark.asyncio -async def test_get_userinfo_implementation(): - """Test that concrete implementation of get_userinfo works.""" - client = MockOAuth20Client( - client_id='test', - client_secret='test', - authorize_endpoint='https://example.com/auth', - access_token_endpoint='https://example.com/token', - userinfo_endpoint='https://example.com/userinfo', - ) - - result = await client.get_userinfo('test_token') - assert result == {'user_id': 'test_user', 'access_token': 'test_token'}