|
2 | 2 | import inspect |
3 | 3 | import os |
4 | 4 | import random |
| 5 | +import socket |
5 | 6 | import typing as t |
6 | 7 | from dataclasses import dataclass |
7 | 8 | from datetime import datetime, timezone |
8 | 9 | from pathlib import Path |
9 | | -from urllib.parse import urljoin, urlparse, urlunparse |
| 10 | +from urllib.parse import ParseResult, urljoin, urlparse, urlunparse |
10 | 11 |
|
11 | 12 | import coolname # type: ignore [import-untyped] |
12 | 13 | import logfire |
|
32 | 33 | ENV_SERVER, |
33 | 34 | ENV_SERVER_URL, |
34 | 35 | ) |
35 | | -from dreadnode.metric import Metric, MetricAggMode, MetricDict, Scorer, ScorerCallable, T |
| 36 | +from dreadnode.metric import ( |
| 37 | + Metric, |
| 38 | + MetricAggMode, |
| 39 | + MetricDict, |
| 40 | + Scorer, |
| 41 | + ScorerCallable, |
| 42 | + T, |
| 43 | +) |
36 | 44 | from dreadnode.task import P, R, Task |
37 | 45 | from dreadnode.tracing.exporters import ( |
38 | 46 | FileExportConfig, |
|
54 | 62 | JsonDict, |
55 | 63 | JsonValue, |
56 | 64 | ) |
57 | | -from dreadnode.util import clean_str, handle_internal_errors |
| 65 | +from dreadnode.util import clean_str, handle_internal_errors, logger |
58 | 66 | from dreadnode.version import VERSION |
59 | 67 |
|
60 | 68 | if t.TYPE_CHECKING: |
@@ -128,6 +136,97 @@ def __init__( |
128 | 136 |
|
129 | 137 | self._initialized = False |
130 | 138 |
|
| 139 | + @staticmethod |
| 140 | + def _resolve_endpoint(endpoint: str | None) -> str | None: |
| 141 | + """Automatically resolve endpoints based on environment |
| 142 | +
|
| 143 | + Args: |
| 144 | + endpoint: The endpoint URL to resolve. |
| 145 | +
|
| 146 | + Returns: |
| 147 | + str: The resolved endpoint URL. |
| 148 | +
|
| 149 | + Raises: |
| 150 | + ValueError: If the endpoint URL is invalid. |
| 151 | + """ |
| 152 | + if not endpoint: |
| 153 | + return None |
| 154 | + parsed = urlparse(endpoint) |
| 155 | + |
| 156 | + # If it's a real domain (has dots), use as-is |
| 157 | + if not parsed.hostname: |
| 158 | + raise ValueError(f"Invalid endpoint URL: {endpoint}") |
| 159 | + |
| 160 | + if "." in parsed.hostname: |
| 161 | + return endpoint |
| 162 | + |
| 163 | + # If it's a service name, try to resolve it |
| 164 | + if Dreadnode._is_docker_service_name(parsed.hostname): |
| 165 | + return Dreadnode._resolve_docker_service(endpoint, parsed) |
| 166 | + |
| 167 | + return endpoint |
| 168 | + |
| 169 | + @staticmethod |
| 170 | + def _is_docker_service_name(hostname: str) -> bool: |
| 171 | + """Check if this looks like a Docker service name |
| 172 | +
|
| 173 | + Args: |
| 174 | + hostname: The hostname to check. |
| 175 | +
|
| 176 | + Returns: |
| 177 | + bool: True if the hostname looks like a Docker service name, False otherwise. |
| 178 | + """ |
| 179 | + return bool(hostname and "." not in hostname and hostname != "localhost") |
| 180 | + |
| 181 | + @staticmethod |
| 182 | + def _resolve_docker_service(original_endpoint: str, parsed: ParseResult) -> str: |
| 183 | + """Try different resolution strategies for Docker services |
| 184 | +
|
| 185 | + Args: |
| 186 | + original_endpoint: The original endpoint URL. |
| 187 | + parsed: The parsed URL object. |
| 188 | +
|
| 189 | + Returns: |
| 190 | + str: The resolved endpoint URL. |
| 191 | +
|
| 192 | + Raises: |
| 193 | + RuntimeError: If no valid endpoint is found. |
| 194 | + """ |
| 195 | + strategies = [ |
| 196 | + original_endpoint, # Try original first (works if running in same network) |
| 197 | + f"{parsed.scheme}://localhost:{parsed.port}", # Try localhost |
| 198 | + f"{parsed.scheme}://host.docker.internal:{parsed.port}", # Docker Desktop |
| 199 | + f"{parsed.scheme}://172.17.0.1:{parsed.port}", # Docker bridge IP |
| 200 | + ] |
| 201 | + |
| 202 | + for endpoint in strategies: |
| 203 | + if Dreadnode._test_connection(endpoint): |
| 204 | + logger.warning( |
| 205 | + f"Resolved Docker service for s3 connection '{parsed.hostname}' to '{endpoint}'." |
| 206 | + ) |
| 207 | + return str(endpoint) |
| 208 | + |
| 209 | + # If nothing works, return original and let it fail with a helpful error |
| 210 | + raise RuntimeError(f"Failed to connect to the Dreadnode Artifact storage at {endpoint}.") |
| 211 | + |
| 212 | + @staticmethod |
| 213 | + def _test_connection(endpoint: str) -> bool: |
| 214 | + """Quick connectivity test |
| 215 | +
|
| 216 | + Args: |
| 217 | + endpoint: The endpoint URL to test. |
| 218 | +
|
| 219 | + Returns: |
| 220 | + bool: True if the connection is successful, False otherwise. |
| 221 | + """ |
| 222 | + try: |
| 223 | + parsed = urlparse(endpoint) |
| 224 | + socket.create_connection((parsed.hostname, parsed.port or 443), timeout=1) |
| 225 | + except Exception: # noqa: BLE001 |
| 226 | + return False |
| 227 | + |
| 228 | + return True |
| 229 | + |
131 | 230 | def configure( |
132 | 231 | self, |
133 | 232 | *, |
@@ -261,12 +360,13 @@ def initialize(self) -> None: |
261 | 360 | # ) |
262 | 361 |
|
263 | 362 | credentials = self._api.get_user_data_credentials() |
| 363 | + resolved_endpoint = self._resolve_endpoint(credentials.endpoint) |
264 | 364 | self._fs = S3FileSystem( |
265 | 365 | key=credentials.access_key_id, |
266 | 366 | secret=credentials.secret_access_key, |
267 | 367 | token=credentials.session_token, |
268 | 368 | client_kwargs={ |
269 | | - "endpoint_url": credentials.endpoint, |
| 369 | + "endpoint_url": resolved_endpoint, |
270 | 370 | "region_name": credentials.region, |
271 | 371 | }, |
272 | 372 | ) |
@@ -1002,7 +1102,10 @@ def log_metric( |
1002 | 1102 | value |
1003 | 1103 | if isinstance(value, Metric) |
1004 | 1104 | else Metric( |
1005 | | - float(value), step, timestamp or datetime.now(timezone.utc), attributes or {} |
| 1105 | + float(value), |
| 1106 | + step, |
| 1107 | + timestamp or datetime.now(timezone.utc), |
| 1108 | + attributes or {}, |
1006 | 1109 | ) |
1007 | 1110 | ) |
1008 | 1111 | return target.log_metric(name, metric, origin=origin, mode=mode) |
|
0 commit comments