Skip to content

Commit 47386be

Browse files
authored
Merge pull request #122 from dreadnode/users/raja/add-refresh-s3-token
feat: Add automatic credential refresh for S3 storage operations
2 parents 777abc2 + d68cc4b commit 47386be

12 files changed

Lines changed: 190 additions & 705 deletions

File tree

docs/sdk/api.mdx

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -728,26 +728,41 @@ def get_user(self) -> UserResponse:
728728
### get\_user\_data\_credentials
729729

730730
```python
731-
get_user_data_credentials() -> UserDataCredentials
731+
get_user_data_credentials(
732+
duration: int = DEFAULT_FS_CREDENTIAL_DURATION,
733+
) -> UserDataCredentials
732734
```
733735

734736
Retrieves user data credentials for secondary storage access.
735737

738+
**Parameters:**
739+
740+
* **`duration`**
741+
(`int`, default:
742+
`DEFAULT_FS_CREDENTIAL_DURATION`
743+
)
744+
–Credential lifetime in seconds (default: 4 hours)
745+
736746
**Returns:**
737747

738748
* `UserDataCredentials`
739749
–The user data credentials object.
740750

741751
<Accordion title="Source code in dreadnode/api/client.py" icon="code">
742752
```python
743-
def get_user_data_credentials(self) -> UserDataCredentials:
753+
def get_user_data_credentials(
754+
self, duration: int = DEFAULT_FS_CREDENTIAL_DURATION
755+
) -> UserDataCredentials:
744756
"""
745757
Retrieves user data credentials for secondary storage access.
746758
759+
Args:
760+
duration: Credential lifetime in seconds (default: 4 hours)
761+
747762
Returns:
748763
The user data credentials object.
749764
"""
750-
response = self.request("GET", "/user-data/credentials")
765+
response = self._request("GET", "/user-data/credentials", params={"duration": duration})
751766
return UserDataCredentials(**response.json())
752767
```
753768

docs/sdk/artifact.mdx

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,10 @@ ArtifactStorage
244244
---------------
245245

246246
```python
247-
ArtifactStorage(file_system: AbstractFileSystem)
247+
ArtifactStorage(
248+
file_system: AbstractFileSystem,
249+
credential_refresher: Callable[[], bool] | None = None,
250+
)
248251
```
249252

250253
Storage for artifacts with efficient handling of large files and directories.
@@ -260,17 +263,28 @@ Initialize artifact storage with a file system and prefix path.
260263
* **`file_system`**
261264
(`AbstractFileSystem`)
262265
–FSSpec-compatible file system
266+
* **`credential_refresher`**
267+
(`Callable[[], bool] | None`, default:
268+
`None`
269+
)
270+
–Optional function to refresh credentials when they are about to expire
263271

264272
<Accordion title="Source code in dreadnode/artifact/storage.py" icon="code">
265273
```python
266-
def __init__(self, file_system: fsspec.AbstractFileSystem):
274+
def __init__(
275+
self,
276+
file_system: fsspec.AbstractFileSystem,
277+
credential_refresher: t.Callable[[], bool] | None = None,
278+
):
267279
"""
268280
Initialize artifact storage with a file system and prefix path.
269281
270282
Args:
271283
file_system: FSSpec-compatible file system
284+
credential_refresher: Optional function to refresh credentials when they are about to expire
272285
"""
273286
self._file_system = file_system
287+
self._credential_refresher = credential_refresher
274288
```
275289

276290

@@ -464,6 +478,7 @@ Store a file in the storage system, using multipart upload for large files.
464478

465479
<Accordion title="Source code in dreadnode/artifact/storage.py" icon="code">
466480
```python
481+
@with_credential_refresh
467482
def store_file(self, file_path: Path, target_key: str) -> str:
468483
"""
469484
Store a file in the storage system, using multipart upload for large files.

docs/sdk/main.mdx

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def __init__(
6565
self._fs_prefix: str = ".dreadnode/storage/"
6666

6767
self._initialized = False
68+
self._credentials: UserDataCredentials | None = None
69+
self._credentials_expiry: datetime | None = None
6870
```
6971

7072

@@ -380,6 +382,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan:
380382
tracer=self._get_tracer(),
381383
file_system=self._fs,
382384
prefix_path=self._fs_prefix,
385+
credential_refresher=self._refresh_storage_credentials if self._credentials else None,
383386
)
384387
```
385388

@@ -523,19 +526,21 @@ def initialize(self) -> None:
523526
# )
524527
# )
525528
# )
526-
527-
credentials = self._api.get_user_data_credentials()
528-
resolved_endpoint = resolve_endpoint(credentials.endpoint)
529+
self._credentials = self._api.get_user_data_credentials(
530+
duration=DEFAULT_FS_CREDENTIAL_DURATION
531+
)
532+
self._credentials_expiry = self._credentials.expiration
533+
resolved_endpoint = resolve_endpoint(self._credentials.endpoint)
529534
self._fs = S3FileSystem(
530-
key=credentials.access_key_id,
531-
secret=credentials.secret_access_key,
532-
token=credentials.session_token,
535+
key=self._credentials.access_key_id,
536+
secret=self._credentials.secret_access_key,
537+
token=self._credentials.session_token,
533538
client_kwargs={
534539
"endpoint_url": resolved_endpoint,
535-
"region_name": credentials.region,
540+
"region_name": self._credentials.region,
536541
},
537542
)
538-
self._fs_prefix = f"{credentials.bucket}/{credentials.prefix}/"
543+
self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/"
539544

540545
self._logfire = logfire.configure(
541546
local=not self.is_default,
@@ -1723,6 +1728,7 @@ def run(
17231728
file_system=self._fs,
17241729
prefix_path=self._fs_prefix,
17251730
autolog=autolog,
1731+
credential_refresher=self._refresh_storage_credentials if self._credentials else None,
17261732
)
17271733
```
17281734

docs/sdk/metric.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ Metric
3131
Metric(
3232
value: float,
3333
step: int = 0,
34-
timestamp: datetime = lambda: datetime.now(
35-
timezone.utc
34+
timestamp: datetime = (
35+
lambda: datetime.now(timezone.utc)
3636
)(),
3737
attributes: JsonDict = dict(),
3838
)

dreadnode/api/client.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@
3636
process_run,
3737
process_task,
3838
)
39-
from dreadnode.constants import DEFAULT_MAX_POLL_TIME, DEFAULT_POLL_INTERVAL
39+
from dreadnode.constants import (
40+
DEFAULT_FS_CREDENTIAL_DURATION,
41+
DEFAULT_MAX_POLL_TIME,
42+
DEFAULT_POLL_INTERVAL,
43+
)
4044
from dreadnode.util import logger
4145
from dreadnode.version import VERSION
4246

@@ -517,12 +521,17 @@ def export_timeseries(
517521

518522
# User data access
519523

520-
def get_user_data_credentials(self) -> UserDataCredentials:
524+
def get_user_data_credentials(
525+
self, duration: int = DEFAULT_FS_CREDENTIAL_DURATION
526+
) -> UserDataCredentials:
521527
"""
522528
Retrieves user data credentials for secondary storage access.
523529
530+
Args:
531+
duration: Credential lifetime in seconds (default: 4 hours)
532+
524533
Returns:
525534
The user data credentials object.
526535
"""
527-
response = self.request("GET", "/user-data/credentials")
536+
response = self._request("GET", "/user-data/credentials", params={"duration": duration})
528537
return UserDataCredentials(**response.json())

dreadnode/artifact/storage.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
"""
55

66
import hashlib
7+
import typing as t
78
from pathlib import Path
89

910
import fsspec # type: ignore[import-untyped]
1011

12+
from dreadnode.storage_utils import with_credential_refresh
1113
from dreadnode.util import logger
1214

1315
CHUNK_SIZE = 8 * 1024 * 1024 # 8MB
@@ -22,15 +24,27 @@ class ArtifactStorage:
2224
- Batch uploads for directories handled by fsspec
2325
"""
2426

25-
def __init__(self, file_system: fsspec.AbstractFileSystem):
27+
def __init__(
28+
self,
29+
file_system: fsspec.AbstractFileSystem,
30+
credential_refresher: t.Callable[[], bool] | None = None,
31+
):
2632
"""
2733
Initialize artifact storage with a file system and prefix path.
2834
2935
Args:
3036
file_system: FSSpec-compatible file system
37+
credential_refresher: Optional function to refresh credentials when they are about to expire
3138
"""
3239
self._file_system = file_system
40+
self._credential_refresher = credential_refresher
3341

42+
def _refresh_credentials_if_needed(self) -> None:
43+
"""Refresh credentials if refresher is available."""
44+
if self._credential_refresher:
45+
self._credential_refresher()
46+
47+
@with_credential_refresh
3448
def store_file(self, file_path: Path, target_key: str) -> str:
3549
"""
3650
Store a file in the storage system, using multipart upload for large files.

dreadnode/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,7 @@
5656
# allow overriding the user config file via env variable
5757
os.getenv("DREADNODE_USER_CONFIG_FILE") or pathlib.Path.home() / ".dreadnode" / "config"
5858
)
59+
60+
# Default values for the file system credential management
61+
DEFAULT_FS_CREDENTIAL_DURATION = 14400 # 4 hours in seconds
62+
FS_CREDENTIAL_REFRESH_BUFFER = 300 # 5 minutes in seconds

dreadnode/main.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from dreadnode.api.client import ApiClient
2727
from dreadnode.config import UserConfig
2828
from dreadnode.constants import (
29+
DEFAULT_FS_CREDENTIAL_DURATION,
2930
DEFAULT_SERVER_URL,
3031
ENV_API_KEY,
3132
ENV_API_TOKEN,
@@ -35,6 +36,7 @@
3536
ENV_PROJECT,
3637
ENV_SERVER,
3738
ENV_SERVER_URL,
39+
FS_CREDENTIAL_REFRESH_BUFFER,
3840
)
3941
from dreadnode.metric import (
4042
Metric,
@@ -64,7 +66,7 @@
6466
Inherited,
6567
JsonValue,
6668
)
67-
from dreadnode.util import clean_str, handle_internal_errors, resolve_endpoint
69+
from dreadnode.util import clean_str, handle_internal_errors, logger, resolve_endpoint
6870
from dreadnode.version import VERSION
6971

7072
if t.TYPE_CHECKING:
@@ -73,6 +75,8 @@
7375
from opentelemetry.sdk.trace import SpanProcessor
7476
from opentelemetry.trace import Tracer
7577

78+
from dreadnode.api.models import UserDataCredentials
79+
7680

7781
ToObject = t.Literal["task-or-run", "run"]
7882

@@ -137,6 +141,8 @@ def __init__(
137141
self._fs_prefix: str = ".dreadnode/storage/"
138142

139143
self._initialized = False
144+
self._credentials: UserDataCredentials | None = None
145+
self._credentials_expiry: datetime | None = None
140146

141147
def _get_profile_server(self, profile: str | None = None) -> str | None:
142148
with contextlib.suppress(Exception):
@@ -347,19 +353,21 @@ def initialize(self) -> None:
347353
# )
348354
# )
349355
# )
350-
351-
credentials = self._api.get_user_data_credentials()
352-
resolved_endpoint = resolve_endpoint(credentials.endpoint)
356+
self._credentials = self._api.get_user_data_credentials(
357+
duration=DEFAULT_FS_CREDENTIAL_DURATION
358+
)
359+
self._credentials_expiry = self._credentials.expiration
360+
resolved_endpoint = resolve_endpoint(self._credentials.endpoint)
353361
self._fs = S3FileSystem(
354-
key=credentials.access_key_id,
355-
secret=credentials.secret_access_key,
356-
token=credentials.session_token,
362+
key=self._credentials.access_key_id,
363+
secret=self._credentials.secret_access_key,
364+
token=self._credentials.session_token,
357365
client_kwargs={
358366
"endpoint_url": resolved_endpoint,
359-
"region_name": credentials.region,
367+
"region_name": self._credentials.region,
360368
},
361369
)
362-
self._fs_prefix = f"{credentials.bucket}/{credentials.prefix}/"
370+
self._fs_prefix = f"{self._credentials.bucket}/{self._credentials.prefix}/"
363371

364372
self._logfire = logfire.configure(
365373
local=not self.is_default,
@@ -406,6 +414,45 @@ def api(self, *, server: str | None = None, token: str | None = None) -> ApiClie
406414

407415
return self._api
408416

417+
def _refresh_storage_credentials(self) -> bool:
418+
"""Refresh storage credentials if they are about to expire."""
419+
if not self._api or not self._credentials:
420+
return False
421+
422+
now = datetime.now(timezone.utc)
423+
424+
if (
425+
self._credentials_expiry is None
426+
or (self._credentials_expiry - now).total_seconds() < FS_CREDENTIAL_REFRESH_BUFFER
427+
):
428+
try:
429+
logger.info("Refreshing storage credentials")
430+
self._credentials = self._api.get_user_data_credentials(
431+
duration=DEFAULT_FS_CREDENTIAL_DURATION
432+
)
433+
self._credentials_expiry = self._credentials.expiration
434+
435+
resolved_endpoint = resolve_endpoint(self._credentials.endpoint)
436+
self._fs = S3FileSystem(
437+
key=self._credentials.access_key_id,
438+
secret=self._credentials.secret_access_key,
439+
token=self._credentials.session_token,
440+
client_kwargs={
441+
"endpoint_url": resolved_endpoint,
442+
"region_name": self._credentials.region,
443+
},
444+
)
445+
logger.info(
446+
f"Storage credentials refreshed, valid until {self._credentials_expiry}"
447+
)
448+
return True # noqa: TRY300
449+
450+
except Exception as e: # noqa: BLE001
451+
logger.error(f"Failed to refresh storage credentials: {e}")
452+
return False
453+
454+
return True
455+
409456
def _get_tracer(self, *, is_span_tracer: bool = True) -> "Tracer":
410457
return self._logfire._tracer_provider.get_tracer( # noqa: SLF001
411458
self.otel_scope,
@@ -778,6 +825,7 @@ def run(
778825
file_system=self._fs,
779826
prefix_path=self._fs_prefix,
780827
autolog=autolog,
828+
credential_refresher=self._refresh_storage_credentials if self._credentials else None,
781829
)
782830

783831
def get_run_context(self) -> RunContext:
@@ -824,6 +872,7 @@ def continue_run(self, run_context: RunContext) -> RunSpan:
824872
tracer=self._get_tracer(),
825873
file_system=self._fs,
826874
prefix_path=self._fs_prefix,
875+
credential_refresher=self._refresh_storage_credentials if self._credentials else None,
827876
)
828877

829878
def tag(self, *tag: str, to: ToObject = "task-or-run") -> None:

0 commit comments

Comments
 (0)