Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diracx-cli/src/diracx/cli/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ async def logout():
# Revoke refresh token
try:
await api.auth.revoke_refresh_token_by_refresh_token(
client_id=api.client_id, refresh_token=credentials.refresh_token
client_id=api.client_id, token=credentials.refresh_token
)
except Exception as e:
print(f"Error revoking the refresh token {e!r}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
build_auth_initiate_authorization_flow_request,
build_auth_initiate_device_flow_request,
build_auth_revoke_refresh_token_by_jti_request,
build_auth_revoke_refresh_token_by_refresh_token_request,
build_auth_userinfo_request,
build_config_serve_config_request,
build_jobs_add_heartbeat_request,
Expand Down Expand Up @@ -293,6 +292,7 @@ def __init__(self, *args, **kwargs) -> None:
raise_if_not_implemented(
self.__class__,
[
"revoke_refresh_token_by_refresh_token",
"get_oidc_token",
],
)
Expand Down Expand Up @@ -583,59 +583,6 @@ async def get_refresh_tokens(self, **kwargs: Any) -> List[Any]:

return deserialized # type: ignore

@distributed_trace_async
async def revoke_refresh_token_by_refresh_token(self, *, refresh_token: str, client_id: str, **kwargs: Any) -> str:
"""Revoke Refresh Token By Refresh Token.

Revoke a refresh token.

:keyword refresh_token: Required.
:paramtype refresh_token: str
:keyword client_id: Required.
:paramtype client_id: str
:return: str
:rtype: str
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map: MutableMapping = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
409: ResourceExistsError,
304: ResourceNotModifiedError,
}
error_map.update(kwargs.pop("error_map", {}) or {})

_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[str] = kwargs.pop("cls", None)

_request = build_auth_revoke_refresh_token_by_refresh_token_request(
refresh_token=refresh_token,
client_id=client_id,
headers=_headers,
params=_params,
)
_request.url = self._client.format_url(_request.url)

_stream = False
pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

deserialized = self._deserialize("str", pipeline_response.http_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore

return deserialized # type: ignore

@distributed_trace_async
async def revoke_refresh_token_by_jti(self, jti: str, **kwargs: Any) -> str:
"""Revoke Refresh Token By Jti.
Expand Down
2 changes: 2 additions & 0 deletions diracx-client/src/diracx/client/_generated/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._models import ( # type: ignore
BodyAuthGetOidcToken,
BodyAuthGetOidcTokenGrantType,
BodyAuthRevokeRefreshTokenByRefreshToken,
BodyJobsRescheduleJobs,
BodyJobsUnassignBulkJobsSandboxes,
GroupInfo,
Expand Down Expand Up @@ -65,6 +66,7 @@
__all__ = [
"BodyAuthGetOidcToken",
"BodyAuthGetOidcTokenGrantType",
"BodyAuthRevokeRefreshTokenByRefreshToken",
"BodyJobsRescheduleJobs",
"BodyJobsUnassignBulkJobsSandboxes",
"GroupInfo",
Expand Down
40 changes: 40 additions & 0 deletions diracx-client/src/diracx/client/_generated/models/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,46 @@ class BodyAuthGetOidcTokenGrantType(_serialization.Model):
"""OAuth2 Grant type."""


class BodyAuthRevokeRefreshTokenByRefreshToken(_serialization.Model):
"""Body_auth_revoke_refresh_token_by_refresh_token.

All required parameters must be populated in order to send to server.

:ivar token: The refresh token to revoke. Required.
:vartype token: str
:ivar token_type_hint: Hint for the type of token being revoked.
:vartype token_type_hint: str
:ivar client_id: The client ID of the application requesting the revocation.
:vartype client_id: str
"""

_validation = {
"token": {"required": True},
}

_attribute_map = {
"token": {"key": "token", "type": "str"},
"token_type_hint": {"key": "token_type_hint", "type": "str"},
"client_id": {"key": "client_id", "type": "str"},
}

def __init__(
self, *, token: str, token_type_hint: Optional[str] = None, client_id: str = "myDIRACClientID", **kwargs: Any
) -> None:
"""
:keyword token: The refresh token to revoke. Required.
:paramtype token: str
:keyword token_type_hint: Hint for the type of token being revoked.
:paramtype token_type_hint: str
:keyword client_id: The client ID of the application requesting the revocation.
:paramtype client_id: str
"""
super().__init__(**kwargs)
self.token = token
self.token_type_hint = token_type_hint
self.client_id = client_id


class BodyJobsRescheduleJobs(_serialization.Model):
"""Body_jobs_reschedule_jobs.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,27 +175,6 @@ def build_auth_get_refresh_tokens_request(**kwargs: Any) -> HttpRequest:
return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs)


def build_auth_revoke_refresh_token_by_refresh_token_request( # pylint: disable=name-too-long
*, refresh_token: str, client_id: str, **kwargs: Any
) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

accept = _headers.pop("Accept", "application/json")

# Construct URL
_url = "/api/auth/revoke"

# Construct parameters
_params["refresh_token"] = _SERIALIZER.query("refresh_token", refresh_token, "str")
_params["client_id"] = _SERIALIZER.query("client_id", client_id, "str")

# Construct headers
_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs)


def build_auth_revoke_refresh_token_by_jti_request( # pylint: disable=name-too-long
jti: str, **kwargs: Any
) -> HttpRequest:
Expand Down Expand Up @@ -816,6 +795,7 @@ def __init__(self, *args, **kwargs) -> None:
raise_if_not_implemented(
self.__class__,
[
"revoke_refresh_token_by_refresh_token",
"get_oidc_token",
],
)
Expand Down Expand Up @@ -1104,59 +1084,6 @@ def get_refresh_tokens(self, **kwargs: Any) -> List[Any]:

return deserialized # type: ignore

@distributed_trace
def revoke_refresh_token_by_refresh_token(self, *, refresh_token: str, client_id: str, **kwargs: Any) -> str:
"""Revoke Refresh Token By Refresh Token.

Revoke a refresh token.

:keyword refresh_token: Required.
:paramtype refresh_token: str
:keyword client_id: Required.
:paramtype client_id: str
:return: str
:rtype: str
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map: MutableMapping = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
409: ResourceExistsError,
304: ResourceNotModifiedError,
}
error_map.update(kwargs.pop("error_map", {}) or {})

_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[str] = kwargs.pop("cls", None)

_request = build_auth_revoke_refresh_token_by_refresh_token_request(
refresh_token=refresh_token,
client_id=client_id,
headers=_headers,
params=_params,
)
_request.url = self._client.format_url(_request.url)

_stream = False
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

deserialized = self._deserialize("str", pipeline_response.http_response)

if cls:
return cls(pipeline_response, deserialized, {}) # type: ignore

return deserialized # type: ignore

@distributed_trace
def revoke_refresh_token_by_jti(self, jti: str, **kwargs: Any) -> str:
"""Revoke Refresh Token By Jti.
Expand Down
30 changes: 27 additions & 3 deletions diracx-client/src/diracx/client/patches/auth/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

from __future__ import annotations
from ast import Dict

__all__ = [
"AuthOperations",
Expand All @@ -19,15 +20,15 @@
_models,
AuthOperations as _AuthOperations,
)
from .common import prepare_request, handle_response
from .common import handle_revoke_response, prepare_oidc_request, handle_oidc_response, prepare_revoke_request


class AuthOperations(_AuthOperations):
@distributed_trace_async
async def get_oidc_token(
self, device_code: str, client_id: str, **kwargs
) -> TokenResponse | _models.DeviceFlowErrorResponse:
request = prepare_request(
request = prepare_oidc_request(
device_code=device_code,
client_id=client_id,
format_url=self._client.format_url,
Expand All @@ -39,7 +40,30 @@ async def get_oidc_token(
)
)

response = handle_response(pipeline_response, self._deserialize)
response = handle_oidc_response(pipeline_response, self._deserialize)
if isinstance(response, _models.DeviceFlowErrorResponse):
return response
return TokenResponse.model_validate(response.as_dict())

@distributed_trace_async
async def revoke_refresh_token_by_refresh_token(
self,
*,
token: str,
client_id: str,
token_type_hint: str = "refresh_token",
**kwargs,
) -> str:
request = prepare_revoke_request(
token=token,
client_id=client_id,
token_type_hint=token_type_hint,
format_url=self._client.format_url,
)

pipeline_response: PipelineResponse = (
await self._client._pipeline.run( # pylint: disable=protected-access
request, stream=False, **kwargs
)
)
return handle_revoke_response(pipeline_response, self._deserialize)
45 changes: 37 additions & 8 deletions diracx-client/src/diracx/client/patches/auth/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

__all__ = [
"prepare_request",
"handle_response",
"prepare_oidc_request",
"handle_oidc_response",
]

from typing import Any
Expand All @@ -18,20 +18,24 @@
from ..._generated.operations._operations import _SERIALIZER


def build_token_request(**kwargs: Any) -> HttpRequest:
def build_request(**kwargs: Any) -> HttpRequest:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})

accept = _headers.pop("Accept", "application/json")

_url = "/api/auth/token"
_url = kwargs.pop("url")

_headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs)
_method = kwargs.pop("method")

return HttpRequest(method=_method, url=_url, headers=_headers, **kwargs)

def prepare_request(device_code, client_id, format_url) -> HttpRequest:
request = build_token_request(

def prepare_oidc_request(device_code, client_id, format_url) -> HttpRequest:
request = build_request(
method="POST",
url="/api/auth/token",
data={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
Expand All @@ -42,7 +46,21 @@ def prepare_request(device_code, client_id, format_url) -> HttpRequest:
return request


def handle_response(
def prepare_revoke_request(token, client_id, token_type_hint, format_url) -> HttpRequest:
request = build_request(
method="POST",
url="/api/auth/revoke",
data={
"token": token,
"client_id": client_id,
"token_type_hint": token_type_hint,
}
)
request.url = format_url(request.url)
return request


def handle_oidc_response(
pipeline_response: PipelineResponse, deserialize
) -> TokenResponse | DeviceFlowErrorResponse:
response = pipeline_response.http_response
Expand All @@ -54,3 +72,14 @@ def handle_response(
else:
map_error(status_code=response.status_code, response=response, error_map={})
raise HttpResponseError(response=response)


def handle_revoke_response(
pipeline_response: PipelineResponse, deserialize
) -> str:
response = pipeline_response.http_response

if response.status_code != 200:
map_error(status_code=response.status_code, response=response, error_map={})
raise HttpResponseError(response=response)
return deserialize("str", pipeline_response)
Loading
Loading