Skip to content

Commit 48125d3

Browse files
chore(internal): refactor authentication internals
1 parent e2553be commit 48125d3

4 files changed

Lines changed: 51 additions & 11 deletions

File tree

src/rollin/_base_client.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
)
6464
from ._utils import is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
6565
from ._compat import PYDANTIC_V1, model_copy, model_dump
66-
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
66+
from ._models import GenericModel, SecurityOptions, FinalRequestOptions, validate_type, construct_type
6767
from ._response import (
6868
APIResponse,
6969
BaseAPIResponse,
@@ -432,9 +432,27 @@ def _make_status_error(
432432
) -> _exceptions.APIStatusError:
433433
raise NotImplementedError()
434434

435+
def _auth_headers(
436+
self,
437+
security: SecurityOptions, # noqa: ARG002
438+
) -> dict[str, str]:
439+
return {}
440+
441+
def _auth_query(
442+
self,
443+
security: SecurityOptions, # noqa: ARG002
444+
) -> dict[str, str]:
445+
return {}
446+
447+
def _custom_auth(
448+
self,
449+
security: SecurityOptions, # noqa: ARG002
450+
) -> httpx.Auth | None:
451+
return None
452+
435453
def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
436454
custom_headers = options.headers or {}
437-
headers_dict = _merge_mappings(self.default_headers, custom_headers)
455+
headers_dict = _merge_mappings({**self._auth_headers(options.security), **self.default_headers}, custom_headers)
438456
self._validate_headers(headers_dict, custom_headers)
439457

440458
# headers are case-insensitive while dictionaries are not.
@@ -506,7 +524,7 @@ def _build_request(
506524
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
507525

508526
headers = self._build_headers(options, retries_taken=retries_taken)
509-
params = _merge_mappings(self.default_query, options.params)
527+
params = _merge_mappings({**self._auth_query(options.security), **self.default_query}, options.params)
510528
content_type = headers.get("Content-Type")
511529
files = options.files
512530

@@ -671,7 +689,6 @@ def default_headers(self) -> dict[str, str | Omit]:
671689
"Content-Type": "application/json",
672690
"User-Agent": self.user_agent,
673691
**self.platform_headers(),
674-
**self.auth_headers,
675692
**self._custom_headers,
676693
}
677694

@@ -990,8 +1007,9 @@ def request(
9901007
self._prepare_request(request)
9911008

9921009
kwargs: HttpxSendArgs = {}
993-
if self.custom_auth is not None:
994-
kwargs["auth"] = self.custom_auth
1010+
custom_auth = self._custom_auth(options.security)
1011+
if custom_auth is not None:
1012+
kwargs["auth"] = custom_auth
9951013

9961014
if options.follow_redirects is not None:
9971015
kwargs["follow_redirects"] = options.follow_redirects
@@ -1952,6 +1970,7 @@ def make_request_options(
19521970
idempotency_key: str | None = None,
19531971
timeout: float | httpx.Timeout | None | NotGiven = not_given,
19541972
post_parser: PostParser | NotGiven = not_given,
1973+
security: SecurityOptions | None = None,
19551974
) -> RequestOptions:
19561975
"""Create a dict of type RequestOptions without keys of NotGiven values."""
19571976
options: RequestOptions = {}
@@ -1977,6 +1996,9 @@ def make_request_options(
19771996
# internal
19781997
options["post_parser"] = post_parser # type: ignore
19791998

1999+
if security is not None:
2000+
options["security"] = security
2001+
19802002
return options
19812003

19822004

src/rollin/_client.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from ._utils import is_given, get_async_library
2323
from ._compat import cached_property
24+
from ._models import SecurityOptions
2425
from ._version import __version__
2526
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
2627
from ._exceptions import RollinError, APIStatusError
@@ -132,9 +133,14 @@ def with_streaming_response(self) -> RollinWithStreamedResponse:
132133
def qs(self) -> Querystring:
133134
return Querystring(array_format="comma")
134135

135-
@property
136136
@override
137-
def auth_headers(self) -> dict[str, str]:
137+
def _auth_headers(self, security: SecurityOptions) -> dict[str, str]:
138+
return {
139+
**(self._api_key_header if security.get("api_key_header", False) else {}),
140+
}
141+
142+
@property
143+
def _api_key_header(self) -> dict[str, str]:
138144
api_key = self.api_key
139145
return {"X-Api-Key": api_key}
140146

@@ -324,9 +330,14 @@ def with_streaming_response(self) -> AsyncRollinWithStreamedResponse:
324330
def qs(self) -> Querystring:
325331
return Querystring(array_format="comma")
326332

327-
@property
328333
@override
329-
def auth_headers(self) -> dict[str, str]:
334+
def _auth_headers(self, security: SecurityOptions) -> dict[str, str]:
335+
return {
336+
**(self._api_key_header if security.get("api_key_header", False) else {}),
337+
}
338+
339+
@property
340+
def _api_key_header(self) -> dict[str, str]:
330341
api_key = self.api_key
331342
return {"X-Api-Key": api_key}
332343

src/rollin/_models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,10 @@ def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]:
791791
return RootModel[type_] # type: ignore
792792

793793

794+
class SecurityOptions(TypedDict, total=False):
795+
api_key_header: bool
796+
797+
794798
class FinalRequestOptionsInput(TypedDict, total=False):
795799
method: Required[str]
796800
url: Required[str]
@@ -804,6 +808,7 @@ class FinalRequestOptionsInput(TypedDict, total=False):
804808
json_data: Body
805809
extra_json: AnyMapping
806810
follow_redirects: bool
811+
security: SecurityOptions
807812

808813

809814
@final
@@ -818,6 +823,7 @@ class FinalRequestOptions(pydantic.BaseModel):
818823
idempotency_key: Union[str, None] = None
819824
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
820825
follow_redirects: Union[bool, None] = None
826+
security: SecurityOptions = {"api_key_header": True}
821827

822828
content: Union[bytes, bytearray, IO[bytes], Iterable[bytes], AsyncIterable[bytes], None] = None
823829
# It should be noted that we cannot use `json` here as that would override

src/rollin/_types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from httpx import URL, Proxy, Timeout, Response, BaseTransport, AsyncBaseTransport
3737

3838
if TYPE_CHECKING:
39-
from ._models import BaseModel
39+
from ._models import BaseModel, SecurityOptions
4040
from ._response import APIResponse, AsyncAPIResponse
4141

4242
Transport = BaseTransport
@@ -121,6 +121,7 @@ class RequestOptions(TypedDict, total=False):
121121
extra_json: AnyMapping
122122
idempotency_key: str
123123
follow_redirects: bool
124+
security: SecurityOptions
124125

125126

126127
# Sentinel class used until PEP 0661 is accepted

0 commit comments

Comments
 (0)