|
26 | 26 | from dreadnode.api.client import ApiClient |
27 | 27 | from dreadnode.config import UserConfig |
28 | 28 | from dreadnode.constants import ( |
| 29 | + DEFAULT_FS_CREDENTIAL_DURATION, |
29 | 30 | DEFAULT_SERVER_URL, |
30 | 31 | ENV_API_KEY, |
31 | 32 | ENV_API_TOKEN, |
|
35 | 36 | ENV_PROJECT, |
36 | 37 | ENV_SERVER, |
37 | 38 | ENV_SERVER_URL, |
| 39 | + FS_CREDENTIAL_REFRESH_BUFFER, |
38 | 40 | ) |
39 | 41 | from dreadnode.metric import ( |
40 | 42 | Metric, |
|
64 | 66 | Inherited, |
65 | 67 | JsonValue, |
66 | 68 | ) |
67 | | -from dreadnode.util import clean_str, handle_internal_errors, resolve_endpoint |
| 69 | +from dreadnode.util import clean_str, handle_internal_errors, logger, resolve_endpoint |
68 | 70 | from dreadnode.version import VERSION |
69 | 71 |
|
70 | 72 | if t.TYPE_CHECKING: |
|
73 | 75 | from opentelemetry.sdk.trace import SpanProcessor |
74 | 76 | from opentelemetry.trace import Tracer |
75 | 77 |
|
| 78 | + from dreadnode.api.models import UserDataCredentials |
| 79 | + |
76 | 80 |
|
77 | 81 | ToObject = t.Literal["task-or-run", "run"] |
78 | 82 |
|
@@ -137,6 +141,8 @@ def __init__( |
137 | 141 | self._fs_prefix: str = ".dreadnode/storage/" |
138 | 142 |
|
139 | 143 | self._initialized = False |
| 144 | + self._credentials: UserDataCredentials | None = None |
| 145 | + self._credentials_expiry: datetime | None = None |
140 | 146 |
|
141 | 147 | def _get_profile_server(self, profile: str | None = None) -> str | None: |
142 | 148 | with contextlib.suppress(Exception): |
@@ -347,19 +353,21 @@ def initialize(self) -> None: |
347 | 353 | # ) |
348 | 354 | # ) |
349 | 355 | # ) |
350 | | - |
351 | | - credentials = self._api.get_user_data_credentials() |
352 | | - resolved_endpoint = resolve_endpoint(credentials.endpoint) |
| 356 | + self._credentials = self._api.get_user_data_credentials( |
| 357 | + duration=DEFAULT_FS_CREDENTIAL_DURATION |
| 358 | + ) |
| 359 | + self._credentials_expiry = self._credentials.expiration |
| 360 | + resolved_endpoint = resolve_endpoint(self._credentials.endpoint) |
353 | 361 | self._fs = S3FileSystem( |
354 | | - key=credentials.access_key_id, |
355 | | - secret=credentials.secret_access_key, |
356 | | - token=credentials.session_token, |
| 362 | + key=self._credentials.access_key_id, |
| 363 | + secret=self._credentials.secret_access_key, |
| 364 | + token=self._credentials.session_token, |
357 | 365 | client_kwargs={ |
358 | 366 | "endpoint_url": resolved_endpoint, |
359 | | - "region_name": credentials.region, |
| 367 | + "region_name": self._credentials.region, |
360 | 368 | }, |
361 | 369 | ) |
362 | | - self._fs_prefix = f"{credentials.bucket}/{credentials.prefix}/" |
| 370 | + self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/" |
363 | 371 |
|
364 | 372 | self._logfire = logfire.configure( |
365 | 373 | local=not self.is_default, |
@@ -406,6 +414,45 @@ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClie |
406 | 414 |
|
407 | 415 | return self._api |
408 | 416 |
|
def _refresh_storage_credentials(self) -> bool:
    """Refresh the storage credentials if they are missing an expiry or about to expire.

    New credentials are requested from the API and a new ``S3FileSystem`` is
    built from them. Instance state (``_credentials``, ``_credentials_expiry``,
    ``_fs``, ``_fs_prefix``) is only replaced once every step has succeeded, so
    a failed refresh leaves the previous, mutually consistent state in place
    and the next call will try again.

    Returns:
        True if the current credentials are still outside the refresh buffer,
        or if they were refreshed successfully. False if there is no API
        client / no credentials to refresh, or if the refresh attempt failed.
    """
    if not self._api or not self._credentials:
        return False

    expiry = self._credentials_expiry
    if expiry is not None:
        # NOTE(review): a timezone-naive expiry is assumed to be UTC here so the
        # subtraction below cannot raise TypeError - confirm against the API model.
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)

        remaining = (expiry - datetime.now(timezone.utc)).total_seconds()
        if remaining >= FS_CREDENTIAL_REFRESH_BUFFER:
            # Still comfortably valid; nothing to do.
            return True

    try:
        logger.info("Refreshing storage credentials")
        credentials = self._api.get_user_data_credentials(
            duration=DEFAULT_FS_CREDENTIAL_DURATION
        )
        # Build the replacement filesystem into a local first: if anything
        # below raises, the old credentials/expiry/filesystem stay untouched.
        file_system = S3FileSystem(
            key=credentials.access_key_id,
            secret=credentials.secret_access_key,
            token=credentials.session_token,
            client_kwargs={
                "endpoint_url": resolve_endpoint(credentials.endpoint),
                "region_name": credentials.region,
            },
        )
        fs_prefix = f"{credentials.bucket}/{credentials.prefix}/"
    except Exception as e:  # noqa: BLE001
        # Best-effort refresh: report the failure to the caller via the return value.
        logger.error(f"Failed to refresh storage credentials: {e}")
        return False

    # Commit all pieces together so expiry, credentials and filesystem never disagree.
    self._credentials = credentials
    self._credentials_expiry = credentials.expiration
    self._fs = file_system
    # Same derivation as initialize() uses, so later runs see a matching prefix.
    self._fs_prefix = fs_prefix

    logger.info(f"Storage credentials refreshed, valid until {self._credentials_expiry}")
    return True
| 455 | + |
409 | 456 | def _get_tracer(self, *, is_span_tracer: bool = True) -> "Tracer": |
410 | 457 | return self._logfire._tracer_provider.get_tracer( # noqa: SLF001 |
411 | 458 | self.otel_scope, |
@@ -778,6 +825,7 @@ def run( |
778 | 825 | file_system=self._fs, |
779 | 826 | prefix_path=self._fs_prefix, |
780 | 827 | autolog=autolog, |
| 828 | + credential_refresher=self._refresh_storage_credentials if self._credentials else None, |
781 | 829 | ) |
782 | 830 |
|
783 | 831 | def get_run_context(self) -> RunContext: |
@@ -824,6 +872,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan: |
824 | 872 | tracer=self._get_tracer(), |
825 | 873 | file_system=self._fs, |
826 | 874 | prefix_path=self._fs_prefix, |
| 875 | + credential_refresher=self._refresh_storage_credentials if self._credentials else None, |
827 | 876 | ) |
828 | 877 |
|
829 | 878 | def tag(self, *tag: str, to: ToObject = "task-or-run") -> None: |
|
0 commit comments