Skip to content

Commit 9e9615b

Browse files
authored
Merge pull request #75 from dreadnode/feat/handle-docker-networks
feat: add automatic Docker service endpoint resolution
2 parents 34bd4e7 + 2b7badd commit 9e9615b

2 files changed

Lines changed: 109 additions & 5 deletions

File tree

dreadnode/main.py

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import inspect
33
import os
44
import random
5+
import socket
56
import typing as t
67
from dataclasses import dataclass
78
from datetime import datetime, timezone
89
from pathlib import Path
9-
from urllib.parse import urljoin, urlparse, urlunparse
10+
from urllib.parse import ParseResult, urljoin, urlparse, urlunparse
1011

1112
import coolname # type: ignore [import-untyped]
1213
import logfire
@@ -32,7 +33,14 @@
3233
ENV_SERVER,
3334
ENV_SERVER_URL,
3435
)
35-
from dreadnode.metric import Metric, MetricAggMode, MetricDict, Scorer, ScorerCallable, T
36+
from dreadnode.metric import (
37+
Metric,
38+
MetricAggMode,
39+
MetricDict,
40+
Scorer,
41+
ScorerCallable,
42+
T,
43+
)
3644
from dreadnode.task import P, R, Task
3745
from dreadnode.tracing.exporters import (
3846
FileExportConfig,
@@ -54,7 +62,7 @@
5462
JsonDict,
5563
JsonValue,
5664
)
57-
from dreadnode.util import clean_str, handle_internal_errors
65+
from dreadnode.util import clean_str, handle_internal_errors, logger
5866
from dreadnode.version import VERSION
5967

6068
if t.TYPE_CHECKING:
@@ -128,6 +136,97 @@ def __init__(
128136

129137
self._initialized = False
130138

139+
@staticmethod
140+
def _resolve_endpoint(endpoint: str | None) -> str | None:
141+
"""Automatically resolve endpoints based on environment
142+
143+
Args:
144+
endpoint: The endpoint URL to resolve.
145+
146+
Returns:
147+
str: The resolved endpoint URL.
148+
149+
Raises:
150+
ValueError: If the endpoint URL is invalid.
151+
"""
152+
if not endpoint:
153+
return None
154+
parsed = urlparse(endpoint)
155+
156+
# If it's a real domain (has dots), use as-is
157+
if not parsed.hostname:
158+
raise ValueError(f"Invalid endpoint URL: {endpoint}")
159+
160+
if "." in parsed.hostname:
161+
return endpoint
162+
163+
# If it's a service name, try to resolve it
164+
if Dreadnode._is_docker_service_name(parsed.hostname):
165+
return Dreadnode._resolve_docker_service(endpoint, parsed)
166+
167+
return endpoint
168+
169+
@staticmethod
170+
def _is_docker_service_name(hostname: str) -> bool:
171+
"""Check if this looks like a Docker service name
172+
173+
Args:
174+
hostname: The hostname to check.
175+
176+
Returns:
177+
bool: True if the hostname looks like a Docker service name, False otherwise.
178+
"""
179+
return bool(hostname and "." not in hostname and hostname != "localhost")
180+
181+
@staticmethod
182+
def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str:
183+
"""Try different resolution strategies for Docker services
184+
185+
Args:
186+
original_endpoint: The original endpoint URL.
187+
parsed: The parsed URL object.
188+
189+
Returns:
190+
str: The resolved endpoint URL.
191+
192+
Raises:
193+
RuntimeError: If no valid endpoint is found.
194+
"""
195+
strategies = [
196+
original_endpoint, # Try original first (works if running in same network)
197+
f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost
198+
f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop
199+
f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP
200+
]
201+
202+
for endpoint in strategies:
203+
if Dreadnode._test_connection(endpoint):
204+
logger.warning(
205+
f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'."
206+
)
207+
return str(endpoint)
208+
209+
# If nothing works, return original and let it fail with a helpful error
210+
raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.")
211+
212+
@staticmethod
213+
def _test_connection(endpoint: str) -> bool:
214+
"""Quick connectivity test
215+
216+
Args:
217+
endpoint: The endpoint URL to test.
218+
219+
Returns:
220+
bool: True if the connection is successful, False otherwise.
221+
"""
222+
try:
223+
parsed = urlparse(endpoint)
224+
socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1)
225+
except Exception: # noqa: BLE001
226+
return False
227+
228+
return True
229+
131230
def configure(
132231
self,
133232
*,
@@ -261,12 +360,13 @@ def initialize(self) -> None:
261360
# )
262361

263362
credentials = self._api.get_user_data_credentials()
363+
resolved_endpoint = self._resolve_endpoint(credentials.endpoint)
264364
self._fs = S3FileSystem(
265365
key=credentials.access_key_id,
266366
secret=credentials.secret_access_key,
267367
token=credentials.session_token,
268368
client_kwargs={
269-
"endpoint_url": credentials.endpoint,
369+
"endpoint_url": resolved_endpoint,
270370
"region_name": credentials.region,
271371
},
272372
)
@@ -1002,7 +1102,10 @@ def log_metric(
10021102
value
10031103
if isinstance(value, Metric)
10041104
else Metric(
1005-
float(value), step, timestamp or datetime.now(timezone.utc), attributes or {}
1105+
float(value),
1106+
step,
1107+
timestamp or datetime.now(timezone.utc),
1108+
attributes or {},
10061109
)
10071110
)
10081111
return target.log_metric(name, metric, origin=origin, mode=mode)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,5 @@ skip-magic-trailing-comma = false
122122
"tests/**/*.py" = [
123123
"INP001", # namespace not required for pytest
124124
"S101", # asserts allowed in tests...
125+
"SLF001", # allow access to private members
125126
]

0 commit comments

Comments
 (0)