This repository was archived by the owner on Sep 21, 2025. It is now read-only.
forked from Colin-b/httpx_auth
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcommon.py
More file actions
133 lines (106 loc) · 4.44 KB
/
common.py
File metadata and controls
133 lines (106 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import abc
from collections.abc import Mapping
from typing import Callable, Generator, Optional, Union
from urllib.parse import parse_qs, urlsplit, urlunsplit, urlencode
import httpx
from httpx_auth._errors import GrantNotProvided, InvalidGrantRequest
from httpx_auth._oauth2.browser import DisplaySettings
from httpx_auth._oauth2.tokens import TokenMemoryCache
def _add_parameters(initial_url: str, extra_parameters: dict) -> str:
"""
Add parameters to a URL and return the new URL.
:param initial_url:
:param extra_parameters: dictionary of parameters names and value.
:return: the new URL containing parameters.
"""
scheme, netloc, path, query_string, fragment = urlsplit(initial_url)
query_params = parse_qs(query_string)
query_params.update(
{
parameter_name: [parameter_value]
for parameter_name, parameter_value in extra_parameters.items()
}
)
new_query_string = urlencode(query_params, doseq=True)
return urlunsplit((scheme, netloc, path, new_query_string, fragment))
def _pop_parameter(url: str, query_parameter_name: str) -> (str, Optional[str]):
"""
Remove and return parameter of an URL.
:param url: The URL containing (or not) the parameter.
:param query_parameter_name: The query parameter to pop.
:return: The new URL (without this parameter) and the parameter value (None if not found).
"""
scheme, netloc, path, query_string, fragment = urlsplit(url)
query_params = parse_qs(query_string)
parameter_value = query_params.pop(query_parameter_name, None)
new_query_string = urlencode(query_params, doseq=True)
return (
urlunsplit((scheme, netloc, path, new_query_string, fragment)),
parameter_value,
)
def _get_query_parameter(url: str, param_name: str) -> Optional[str]:
scheme, netloc, path, query_string, fragment = urlsplit(url)
query_params = parse_qs(query_string)
all_values = query_params.get(param_name)
return all_values[0] if all_values else None
def _content_from_response(response: httpx.Response) -> dict:
    """
    Decode the body of a response into a dictionary.

    A response whose ``content-type`` header is exactly
    ``text/html; charset=utf-8`` is parsed as ``key=value`` pairs separated by
    ``&``. Any other response is decoded as JSON.

    :param response: The response to decode.
    :return: The decoded content. For the ``key=value`` format, values are kept
        as raw strings (no percent-decoding), pairs without ``=`` are skipped,
        and a repeated key keeps its last value.
    """
    content_type = response.headers.get("content-type")
    if content_type == "text/html; charset=utf-8":
        content = {}
        for key_value in response.text.split("&"):
            # Split on the first "=" only: values such as base64-padded tokens
            # ("abc==") contain "=" themselves and must not be discarded.
            key, separator, value = key_value.partition("=")
            if separator:
                content[key] = value
        return content
    return response.json()
def request_new_grant_with_post(
    url: str, data, grant_name: str, headers: Mapping[str, str]
) -> Generator[
    httpx.Request, httpx.Response, tuple[str, Optional[int], Optional[str]]
]:
    """
    Request a new grant by sending a POST request and reading its response.

    This is a generator meant to be driven by an httpx auth flow: it yields the
    request to send and receives the matching response back through ``send``.

    :param url: URL the POST request is sent to.
    :param data: Request body, passed as ``data`` to ``httpx.Request``.
    :param grant_name: Name of the response field holding the grant.
    :param headers: Headers sent with the request.
    :return: A tuple of the grant value, the ``expires_in`` field and the
        ``refresh_token`` field of the decoded response. The last two are None
        when absent. ``expires_in`` is returned as found in the response, without
        conversion, so it is a string when the body was parsed as ``key=value``
        pairs.
    :raises InvalidGrantRequest: if the response has an error status.
    :raises GrantNotProvided: if the response has no non-empty ``grant_name`` field.
    """
    response = yield httpx.Request("post", url, data=data, headers=headers)
    if response.is_error:
        # As described in https://tools.ietf.org/html/rfc6749#section-5.2
        raise InvalidGrantRequest(response)

    content = _content_from_response(response)
    token = content.get(grant_name)
    if not token:
        raise GrantNotProvided(grant_name, content)
    return token, content.get("expires_in"), content.get("refresh_token")
class OAuth2:
    """
    Namespace for OAuth2 objects held at class level and shared by every user
    of this class.
    """

    # One cache instance for the whole process; OAuth2BaseAuth.auth_flow reads
    # tokens from OAuth2.token_cache.
    token_cache: TokenMemoryCache = TokenMemoryCache()
    # Settings object from the browser module; how it is consumed is not
    # visible here (presumably by browser-based flows, TODO confirm).
    display: DisplaySettings = DisplaySettings()
class OAuth2BaseAuth(abc.ABC, httpx.Auth):
    """
    Base class for OAuth2 authentications that send a token in a request header.

    Tokens are looked up in the shared ``OAuth2.token_cache`` under ``state``.
    Subclasses implement :meth:`request_new_token` to obtain a token when the
    cache has none.
    """

    def __init__(
        self,
        state: str,
        early_expiry: float,
        header_name: str,
        header_value: str,
        refresh_token: Optional[Callable] = None,
    ) -> None:
        """
        :param state: Key under which this authentication's token is cached.
        :param early_expiry: Forwarded to the token cache as ``early_expiry``
            (presumably seconds before real expiry at which a token is
            considered expired; TODO confirm against TokenMemoryCache).
        :param header_name: Name of the header that receives the token.
        :param header_value: ``str.format`` template for the header value. It
            must contain ``{token}``.
        :param refresh_token: Forwarded to the token cache as
            ``on_expired_token``, or None.
        :raises ValueError: if ``header_value`` does not contain ``{token}``.
        """
        if "{token}" not in header_value:
            # ValueError subclasses Exception, so existing handlers still catch
            # it. The message is unchanged because callers may compare it.
            raise ValueError("header_value parameter must contains {token}.")
        self.state = state
        self.early_expiry = early_expiry
        self.header_name = header_name
        self.header_value = header_value
        self.refresh_token = refresh_token
        # httpx.Auth flag: have httpx load response bodies before handing the
        # responses back into auth_flow (presumably so token endpoint responses
        # can be parsed).
        self.requires_response_body = True

    def auth_flow(
        self, request: httpx.Request
    ) -> Generator[httpx.Request, httpx.Response, None]:
        """
        Obtain a token and attach it to ``request`` before yielding it.

        The token comes from ``OAuth2.token_cache.get_token``. That call receives
        :meth:`request_new_token` as ``on_missing_token`` and ``refresh_token``
        as ``on_expired_token``. Any requests those callbacks yield are sent
        first, through ``yield from``.
        """
        token = yield from OAuth2.token_cache.get_token(
            self.state,
            early_expiry=self.early_expiry,
            on_missing_token=self.request_new_token,
            on_expired_token=self.refresh_token,
        )
        self._update_user_request(request, token)
        yield request

    @abc.abstractmethod
    def request_new_token(
        self,
    ) -> Generator[
        httpx.Request, httpx.Response, Union[tuple[str, str], tuple[str, str, int]]
    ]:
        """
        Generator obtaining a new token, yielding any requests it must send.

        It returns a 2- or 3-tuple, per the annotation. The elements are
        presumably (state, token[, expires_in]); TODO confirm with
        TokenMemoryCache.
        """
        pass  # pragma: no cover

    def _update_user_request(self, request: httpx.Request, token: str) -> None:
        """
        Set the ``header_name`` header of ``request`` to ``header_value``, with
        ``{token}`` replaced by ``token``.
        """
        # Any other "{...}" placeholder in header_value makes str.format raise.
        request.headers[self.header_name] = self.header_value.format(token=token)