Skip to content

Commit a9c6789

Browse files
feat: Add pilot registration (secret-exchange)
1 parent 09cc45c commit a9c6789

32 files changed

Lines changed: 2213 additions & 76 deletions

File tree

diracx-client/src/diracx/client/patches/pilots/aio.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from azure.core.tracing.decorator_async import distributed_trace_async
1717

1818
from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations
19+
from ..._generated.models._models import PilotCredentialsInfo
1920
from .common import (
2021
make_search_body,
2122
make_summary_body,
@@ -43,7 +44,7 @@ async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]
4344
return await super().summary(**make_summary_body(**kwargs))
4445

4546
@distributed_trace_async
46-
async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None:
47+
async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> list[PilotCredentialsInfo] | None:
4748
"""TODO"""
4849
return await super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs))
4950

diracx-client/src/diracx/client/patches/pilots/common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ class AddPilotStampsBody(TypedDict, total=False):
9999
pilot_references: dict[str, str]
100100
pilot_status: PilotStatus
101101
vo: str
102+
generate_secrets: bool
103+
pilot_secret_use_count_max: int | None
102104

103105
class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ...
104106

@@ -112,7 +114,7 @@ def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> Underl
112114
for key in AddPilotStampsBody.__optional_keys__:
113115
if key not in kwargs:
114116
continue
115-
key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo"], key)
117+
key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo", "generate_secrets", "pilot_secret_use_count_max"], key)
116118
value = kwargs.pop(key)
117119
if value is not None:
118120
body[key] = value

diracx-client/src/diracx/client/patches/pilots/sync.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from azure.core.tracing.decorator import distributed_trace
1717

1818
from ..._generated.operations._operations import PilotsOperations as _PilotsOperations
19+
from ..._generated.models._models import PilotCredentialsInfo
1920
from .common import (
2021
make_search_body,
2122
make_summary_body,
@@ -43,7 +44,7 @@ def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]:
4344
return super().summary(**make_summary_body(**kwargs))
4445

4546
@distributed_trace
46-
def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None:
47+
def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> list[PilotCredentialsInfo] | None:
4748
"""TODO"""
4849
return super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs))
4950

diracx-core/src/diracx/core/exceptions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def __init__(self, job_id, detail: str = ""):
9999
)
100100

101101

102+
class BadTokenError(DiracError): ...
103+
104+
102105
class NotReadyError(DiracError):
103106
"""Tried to access a value which is asynchronously loaded but not yet available."""
104107

@@ -113,3 +116,15 @@ class PilotAlreadyExistsError(DiracError):
113116

114117
class PilotAlreadyAssociatedWithJobError(DiracError):
115118
"""We can't associate a pilot with the same job twice."""
119+
120+
121+
class SecretHasExpiredError(DiracError):
122+
"""If a secret expired."""
123+
124+
125+
class SecretNotFoundError(DiracError):
126+
"""If a secret not found."""
127+
128+
129+
class BadPilotCredentialsError(DiracError):
130+
"""If a pilot tries to auth with another pilot's credentials."""

diracx-core/src/diracx/core/models.py

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55

66
from __future__ import annotations
77

8+
import uuid as std_uuid
89
from datetime import datetime
9-
from enum import StrEnum
10-
from typing import Literal, Optional
10+
from enum import StrEnum, auto
11+
from typing import Any, Literal, Optional
1112

12-
from pydantic import BaseModel, Field
13+
from pydantic import BaseModel, Field, GetCoreSchemaHandler, GetJsonSchemaHandler
14+
from pydantic_core import CoreSchema, core_schema
1315
from typing_extensions import TypedDict
16+
from uuid_utils import UUID as _UUID
1417

1518

1619
class ScalarSearchOperator(StrEnum):
@@ -37,7 +40,7 @@ class ScalarSearchSpec(TypedDict):
3740
class VectorSearchSpec(TypedDict):
3841
parameter: str
3942
operator: VectorSearchOperator
40-
values: list[str] | list[int]
43+
values: list[str] | list[int] | list[bytes]
4144

4245

4346
SearchSpec = ScalarSearchSpec | VectorSearchSpec
@@ -179,6 +182,29 @@ class SandboxUploadResponse(BaseModel):
179182
fields: dict[str, str] = {}
180183

181184

185+
class UUID(_UUID):
186+
"""Subclass of uuid_utils.UUID to add pydantic support."""
187+
188+
@classmethod
189+
def __get_pydantic_core_schema__(
190+
cls, source_type: Any, handler: GetCoreSchemaHandler
191+
) -> CoreSchema:
192+
"""Use the stdlib uuid.UUID schema for validation and serialization."""
193+
std_schema = handler(std_uuid.UUID)
194+
195+
def to_uuid_utils(u: std_uuid.UUID) -> UUID:
196+
return cls(str(u))
197+
198+
return core_schema.no_info_after_validator_function(to_uuid_utils, std_schema)
199+
200+
@classmethod
201+
def __get_pydantic_json_schema__(
202+
cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
203+
) -> dict[str, Any]:
204+
"""Return the stdlib uuid.UUID schema for JSON serialization."""
205+
return handler(core_schema)
206+
207+
182208
class GrantType(StrEnum):
183209
"""Grant types for OAuth2."""
184210

@@ -213,9 +239,14 @@ class OpenIDConfiguration(TypedDict):
213239
code_challenge_methods_supported: list[str]
214240

215241

216-
class TokenPayload(TypedDict):
242+
class BaseTokenPayload(TypedDict):
243+
"""This class helps having pilot and user tokens without code duplication."""
244+
217245
jti: str
218246
exp: datetime
247+
248+
249+
class TokenPayload(BaseTokenPayload):
219250
dirac_policies: dict
220251

221252

@@ -308,3 +339,56 @@ class PilotStatus(StrEnum):
308339
ABORTED = "Aborted"
309340
#: Cannot get information about the pilot status:
310341
UNKNOWN = "Unknown"
342+
343+
344+
class PilotSecretConstraints(TypedDict, total=False):
345+
VOs: list[str] # Authorize only a list of VOs
346+
PilotStamps: list[str] # Authorize only a list of stamps
347+
Sites: list[str] # Authorize only a list of sites
348+
# ...
349+
# We can add constraints here
350+
351+
352+
class TokenType(StrEnum):
353+
# Pilot token
354+
PILOT_TOKEN = auto()
355+
# User token
356+
USER_TOKEN = auto()
357+
358+
359+
class PilotSecretsInfo(BaseModel):
360+
pilot_secret: str
361+
pilot_secret_expires_in: int
362+
363+
364+
class PilotAccessTokenPayload(BaseTokenPayload):
365+
sub: str
366+
vo: str
367+
iss: str
368+
pilot_stamp: str
369+
370+
371+
class PilotInfo(BaseModel):
372+
pilot_stamp: str
373+
vo: str
374+
sub: str
375+
376+
377+
class PilotRefreshTokenPayload(BaseTokenPayload):
378+
legacy_exchange: bool
379+
380+
381+
class PilotCredentialsInfo(PilotSecretsInfo):
382+
pilot_stamp: str
383+
384+
385+
class PilotAuthCredentials(TypedDict):
386+
pilot_stamp: str
387+
pilot_secret: str
388+
389+
390+
class VacuumPilotAuth(PilotAuthCredentials):
391+
vo: str
392+
grid_type: str
393+
grid_site: str
394+
status: str

diracx-core/src/diracx/core/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ class AuthSettings(ServiceSettingsBase):
163163
token_allowed_algorithms: list[str] = ["RS256", "EdDSA"] # noqa: S105
164164
access_token_expire_minutes: int = 20
165165
refresh_token_expire_minutes: int = 60
166+
pilot_secret_expire_seconds: int = 3600
167+
pilot_refresh_token_expire_hours: int = 168
166168

167169
available_properties: set[SecurityProperty] = Field(
168170
default_factory=SecurityProperty.available_properties

diracx-core/src/diracx/core/utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from uuid import UUID
4+
35
__all__ = [
46
"dotenv_files_from_environment",
57
"serialize_credentials",
@@ -19,7 +21,7 @@
1921
from concurrent.futures import Future, ThreadPoolExecutor, wait
2022
from datetime import datetime, timedelta, timezone
2123
from pathlib import Path
22-
from typing import Any, AsyncIterable, TypeVar
24+
from typing import Any, AsyncIterable, Mapping, TypeVar, cast
2325

2426
from cachetools import Cache, TTLCache
2527

@@ -271,3 +273,31 @@ async def batched_async(
271273
if strict and len(batch) != n:
272274
raise ValueError("batched(): incomplete batch")
273275
yield tuple(batch)
276+
277+
278+
def extract_timestamp_from_uuid7(uuid_str: str) -> datetime:
279+
u = UUID(uuid_str)
280+
ts_bytes = u.bytes[0:6] # First 48 bits = timestamp in ms
281+
timestamp_ms = int.from_bytes(ts_bytes, byteorder="big")
282+
# Convert into seconds then to datetime
283+
return datetime.fromtimestamp(timestamp_ms / 1000, timezone.utc)
284+
285+
286+
T_DICTS = TypeVar("T_DICTS", bound=Mapping[str, Any])
287+
288+
289+
def recursive_dict_merge(x: T_DICTS, y: T_DICTS) -> T_DICTS:
290+
result: dict[str, Any] = dict(x)
291+
292+
for k, v in y.items():
293+
if k in result:
294+
if isinstance(result[k], dict) and isinstance(v, dict):
295+
result[k] = recursive_dict_merge(result[k], v)
296+
elif isinstance(result[k], list) and isinstance(v, list):
297+
result[k] = result[k] + v
298+
else:
299+
result[k] = v
300+
else:
301+
result[k] = v
302+
303+
return cast(T_DICTS, result)

0 commit comments

Comments
 (0)