diff --git a/docs/source/dns_aid.md b/docs/source/dns_aid.md new file mode 100644 index 000000000..c49666ec1 --- /dev/null +++ b/docs/source/dns_aid.md @@ -0,0 +1,113 @@ +# DNS-AID Service Discovery + +Forge services can optionally register DNS-AID SVCB records on startup, enabling +peer discovery via DNS rather than hard-coded coordinator addresses. + +## Installation + +```bash +pip install forge[dns-aid] +``` + +## Configuration + +DNS-AID requires **both** the `DNS_AID_ENABLED` environment variable and +the per-service `DnsAidConfig.enabled` flag to be true. This dual-guard +means the environment variable acts as a global kill switch. + +### Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `DNS_AID_ENABLED` | `false` | Global toggle. Must be `true` for any DNS-AID operations. | +| `DNS_AID_ZONE` | — | DNS zone suffix (e.g. `_agents.svc.cluster.local`) | +| `DNS_AID_SERVER` | — | DNS server address (e.g. `10.0.0.53`) | +| `DNS_AID_PORT` | `853` | DNS server port | +| `DNS_AID_BACKEND` | — | DNS backend (`route53`, `cloudflare`, `ddns`, `mock`, etc.) | + +### Per-Service Configuration + +Add `DnsAidConfig` to your actor options: + +```python +from forge.controller import ForgeActor +from forge.types import DnsAidConfig + +dns_cfg = DnsAidConfig( + enabled=True, + name="generator", # DNS service name (default: class name) + domain="forge.internal", # DNS domain + port=8080, # Externally reachable port (required) + ttl=30, # Record TTL in seconds + capabilities=["gpu:8"], # Extra capabilities to advertise + category="rl-training", # Discovery category +) + +service = await MyGenerator.options( + num_replicas=4, + procs=2, + with_gpus=True, + dns_aid=dns_cfg, +).as_service(model_path="...") +``` + +The `port` field is required when `enabled` is True. It should be set to +the port that external systems use to reach this service (e.g. a load +balancer, gateway, or sidecar proxy port). 
Monarch services communicate +via actor RPC internally, so there is no auto-detected listener port. + +### OmegaConf YAML + +```yaml +# Requires DNS_AID_ENABLED=true in the environment +generator: + procs: 2 + num_replicas: 4 + with_gpus: true + dns_aid: + enabled: true + name: generator + domain: forge.internal + port: 8080 + ttl: 30 + capabilities: + - "gpu:8" + - "shard_count:4" +``` + +## How It Works + +1. **Startup**: After the service is fully initialized, `publish_service()` creates + a DNS-AID SVCB record advertising the service's hostname, port, role, and + capabilities. + +2. **Discovery**: Other services can call `discover_peers()` to find registered + peers by name. Discovery retries with exponential backoff (max 5 attempts) + to handle race conditions during cluster startup. Pass `retry_on_empty=False` + if you want to return immediately when no peers are found. + +3. **Shutdown**: `unpublish_service()` removes the DNS record. This is best-effort; + if the process crashes, the record expires after the configured TTL (default 30s). + +## Peer Discovery Example + +```python +from forge.controller.dns_aid import discover_peers +from forge.types import DnsAidConfig + +cfg = DnsAidConfig(enabled=True, domain="forge.internal") + +# Find all trainer services (retries if not yet registered) +trainers = await discover_peers("trainer", cfg) +for agent in trainers: + print(f"Found trainer at {agent.target_host}:{agent.port}") + +# Check once without retrying +trainers = await discover_peers("trainer", cfg, retry_on_empty=False) +``` + +## Backward Compatibility + +DNS-AID is fully opt-in. When `DNS_AID_ENABLED` is unset or `false` (the default), +no DNS operations are performed and the `dns-aid` package does not need to be +installed. Existing deployments are completely unaffected. 
diff --git a/pyproject.toml b/pyproject.toml index 0a5439fdc..7f2c2321c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dev = [ "pytest-asyncio", "multiprocess", ] +dns-aid = ["dns-aid>=0.12.0"] docs = [ "sphinx==7.2.6", "pytorch-sphinx-theme2==0.1.0", diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index 796677b22..d35622207 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -173,6 +173,27 @@ async def as_service( service_interface = ServiceInterface(service, cls) # Register this service with the provisioner so it can cleanly shut this down await register_service(service_interface) + + # DNS-AID registration (best-effort, after service is fully initialized) + service_interface._dns_aid_cfg = cfg.dns_aid + if cfg.dns_aid is not None: + from forge.controller.dns_aid import is_dns_aid_enabled, publish_service + + if is_dns_aid_enabled(cfg.dns_aid): + if cfg.dns_aid.port is None: + logger.warning( + "DNS-AID: dns_aid.port is not set, skipping registration. " + "Set DnsAidConfig(port=...) to the externally reachable port." + ) + else: + import socket as _socket + + _hostname = _socket.gethostname() + _dns_name = cfg.dns_aid.name or cls.__name__.lower() + await publish_service( + _dns_name, _hostname, cfg.dns_aid.port, cfg.dns_aid + ) + return service_interface @endpoint diff --git a/src/forge/controller/dns_aid.py b/src/forge/controller/dns_aid.py new file mode 100644 index 000000000..6b13a689b --- /dev/null +++ b/src/forge/controller/dns_aid.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""DNS-AID service discovery helpers for Forge services. + +Provides publish, unpublish, and discover wrappers around the dns_aid library +with torchforge-specific defaults, dual enable guards, and retry logic. 
+ +All operations are best-effort: failures are logged but never raised, so +service startup and shutdown are not blocked by DNS issues. +""" + +import asyncio +import logging +from typing import TYPE_CHECKING + +from forge.env import DNS_AID_ENABLED + +if TYPE_CHECKING: + from forge.types import DnsAidConfig + +logger = logging.getLogger(__name__) + +# Cached dns_aid import state: _dns_aid_module is the imported module, or None if no +# import has succeeded; _dns_aid_import_attempted records whether an import was tried. +_dns_aid_module = None +_dns_aid_import_attempted = False + + +def _get_forge_version() -> str: + """Return the installed forge package version, or 'unknown'.""" + try: + from importlib.metadata import version + + return version("forge") + except Exception: + return "unknown" + + +def _fqdn(service_name: str) -> str: + """Build the canonical DNS-AID name for a forge service.""" + return f"torchforge-{service_name}" + + +def is_dns_aid_enabled(cfg: "DnsAidConfig | None") -> bool: + """Check whether DNS-AID is enabled via both env var and config. + + Both ``DNS_AID_ENABLED`` environment variable and ``cfg.enabled`` must + be true for DNS-AID operations to proceed. + """ + if cfg is None: + return False + return bool(DNS_AID_ENABLED.get_value()) and cfg.enabled + + +def _try_import_dns_aid(): + """Lazily import dns_aid, returning the module or None. + + The result is cached so the warning is only logged once. + """ + global _dns_aid_module, _dns_aid_import_attempted + if _dns_aid_import_attempted: + return _dns_aid_module + + _dns_aid_import_attempted = True + try: + import dns_aid + + _dns_aid_module = dns_aid + return dns_aid + except ImportError: + logger.warning( + "dns-aid package is not installed. " + "Install with: pip install forge[dns-aid]" + ) + _dns_aid_module = None + return None + + +async def publish_service( + service_name: str, + hostname: str, + port: int, + cfg: "DnsAidConfig", +) -> bool: + """Publish a Forge service as a DNS-AID SVCB record. + + Args: + service_name: Logical name for the service (e.g. "generator"). 
+ hostname: Host where the service is reachable. + port: Port where the service is reachable. + cfg: DNS-AID configuration. + + Returns: + True if publish succeeded, False otherwise. + """ + if not is_dns_aid_enabled(cfg): + return False + + dns_aid = _try_import_dns_aid() + if dns_aid is None: + return False + + capabilities = [ + "framework:torchforge", + f"role:{service_name}", + *cfg.capabilities, + ] + dns_name = _fqdn(service_name) + + try: + await dns_aid.publish( + name=dns_name, + domain=cfg.domain, + protocol=cfg.protocol, + endpoint=hostname, + port=port, + capabilities=capabilities, + version=_get_forge_version(), + description=f"Torchforge {service_name} service", + category=cfg.category, + ttl=cfg.ttl, + ) + logger.info( + f"DNS-AID: published {dns_name} " + f"at {hostname}:{port} (domain={cfg.domain}, ttl={cfg.ttl}s)" + ) + return True + except Exception: + logger.warning( + f"DNS-AID: failed to publish {dns_name}", + exc_info=True, + ) + return False + + +async def unpublish_service( + service_name: str, + cfg: "DnsAidConfig", +) -> bool: + """Remove a Forge service's DNS-AID record. Best-effort. + + Args: + service_name: Logical name for the service. + cfg: DNS-AID configuration. + + Returns: + True if unpublish succeeded, False otherwise. 
+ """ + if not is_dns_aid_enabled(cfg): + return False + + dns_aid = _try_import_dns_aid() + if dns_aid is None: + return False + + dns_name = _fqdn(service_name) + try: + await dns_aid.unpublish( + name=dns_name, + domain=cfg.domain, + protocol=cfg.protocol, + ) + logger.info(f"DNS-AID: unpublished {dns_name}") + return True + except Exception: + logger.warning( + f"DNS-AID: failed to unpublish {dns_name}", + exc_info=True, + ) + return False + + +async def discover_peers( + service_name: str, + cfg: "DnsAidConfig", + max_attempts: int = 5, + initial_delay: float = 0.5, + backoff_factor: float = 2.0, + max_delay: float = 8.0, + retry_on_empty: bool = True, +) -> list: + """Discover peer Forge services via DNS-AID with exponential backoff. + + Args: + service_name: Name of the service to discover (e.g. "trainer"). + cfg: DNS-AID configuration. + max_attempts: Maximum number of discovery attempts. + initial_delay: Initial retry delay in seconds. + backoff_factor: Multiplier for each subsequent retry delay. + max_delay: Maximum retry delay in seconds. + retry_on_empty: If True (default), retry when discovery succeeds but + returns no agents. Set to False to return immediately on a + successful-but-empty response. + + Returns: + List of discovered AgentRecord objects, or empty list on failure. 
+ """ + if not is_dns_aid_enabled(cfg): + return [] + + dns_aid = _try_import_dns_aid() + if dns_aid is None: + return [] + + dns_name = _fqdn(service_name) + delay = initial_delay + for attempt in range(1, max_attempts + 1): + try: + result = await dns_aid.discover( + domain=cfg.domain, + protocol=cfg.protocol, + name=dns_name, + ) + if result.agents: + logger.info( + f"DNS-AID: discovered {len(result.agents)} peer(s) " + f"for {dns_name} (attempt {attempt})" + ) + return result.agents + + if not retry_on_empty: + logger.debug(f"DNS-AID: no peers found for {dns_name}") + return [] + + except Exception: + logger.debug( + f"DNS-AID: discover attempt {attempt}/{max_attempts} " + f"for {dns_name} failed", + exc_info=True, + ) + + if attempt < max_attempts: + logger.debug( + f"DNS-AID: retrying discovery in {delay:.1f}s " + f"(attempt {attempt}/{max_attempts})" + ) + await asyncio.sleep(delay) + delay = min(delay * backoff_factor, max_delay) + + logger.warning( + f"DNS-AID: failed to discover {dns_name} after {max_attempts} attempts" + ) + return [] diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 235835a15..4b90f335f 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -538,6 +538,16 @@ async def shutdown_all_allocations(self): logger.info( f"Shutting down {len(self._registered_services)} service(s) and {len(self._registered_actors)} actor(s)..." 
) + # --- DNS-AID deregistration (best-effort, before service shutdown) --- + from forge.controller.dns_aid import is_dns_aid_enabled, unpublish_service + + for service in self._registered_services: + dns_cfg = service._dns_aid_cfg + if dns_cfg is not None and is_dns_aid_enabled(dns_cfg): + dns_name = dns_cfg.name or service.actor_def.__name__.lower() + # unpublish_service is already best-effort (never raises) + await unpublish_service(dns_name, dns_cfg) + # --- ServiceInterface --- for service in reversed(self._registered_services): try: diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index c64d5c3f3..24a5a6e15 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -177,6 +177,9 @@ class ServiceInterface: def __init__(self, _service, actor_def): self._service = _service self.actor_def = actor_def + self._dns_aid_cfg = ( + None # Set by ForgeActor.as_service() if DNS-AID is configured + ) # Dynamically create ServiceEndpoint objects for user's actor endpoints # Inspect the actor_def directly to find endpoints diff --git a/src/forge/env.py b/src/forge/env.py index b698b8013..781bec73f 100644 --- a/src/forge/env.py +++ b/src/forge/env.py @@ -106,6 +106,13 @@ def get_value(self) -> Any: ) +DNS_AID_ENABLED = EnvVar( + name="DNS_AID_ENABLED", + default=False, + description="Enable DNS-AID service discovery for forge services.", +) + + def all_env_vars() -> list[EnvVar]: """Retrieves all registered environment variable names.""" env_vars = [] diff --git a/src/forge/types.py b/src/forge/types.py index 5c8059d86..a930a3f02 100644 --- a/src/forge/types.py +++ b/src/forge/types.py @@ -60,6 +60,34 @@ class ProcessConfig: mesh_name: str | None = None +@dataclass +class DnsAidConfig: + """Configuration for DNS-AID service discovery. + + Args: + enabled: Whether DNS-AID registration is enabled for this service. + name: Override DNS service name. 
Defaults to the actor class name (lowercased). + domain: DNS domain for registration. + protocol: DNS-AID protocol identifier. + port: Port to advertise in the DNS record. This should be the port that + external systems use to reach this service (e.g. a load balancer or + gateway port). Required when enabled is True — Monarch services + communicate via actor RPC, so there is no auto-detected listener port. + ttl: Time-to-live in seconds for DNS records. Dead workers expire after this. + capabilities: Additional capabilities to advertise. + category: Service category for discovery filtering. + """ + + enabled: bool = False + name: str | None = None + domain: str = "forge.internal" + protocol: str = "mcp" + port: int | None = None + ttl: int = 30 + capabilities: list[str] = field(default_factory=list) + category: str = "rl-training" + + @dataclass class ServiceConfig: """The configuration for a Forge service. @@ -84,6 +112,7 @@ class ServiceConfig: replica_max_concurrent_requests: int = 10 return_first_rank_result: bool = True mesh_name: str | None = None + dns_aid: DnsAidConfig | None = None def to_process_config(self) -> ProcessConfig: """Extract ProcessConfig from this ServiceConfig. diff --git a/tests/unit_tests/test_dns_aid.py b/tests/unit_tests/test_dns_aid.py new file mode 100644 index 000000000..d703af148 --- /dev/null +++ b/tests/unit_tests/test_dns_aid.py @@ -0,0 +1,440 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from forge.controller.dns_aid import ( + _fqdn, + _try_import_dns_aid, + discover_peers, + is_dns_aid_enabled, + publish_service, + unpublish_service, +) +from forge.types import DnsAidConfig + + +@pytest.fixture(autouse=True) +def _reset_dns_aid_import_cache(): + """Reset the cached import state between tests.""" + import forge.controller.dns_aid as mod + + mod._dns_aid_import_attempted = False + mod._dns_aid_module = None + yield + mod._dns_aid_import_attempted = False + mod._dns_aid_module = None + + +# --- _fqdn --- + + +def test_fqdn(): + assert _fqdn("generator") == "torchforge-generator" + assert _fqdn("replay-buffer") == "torchforge-replay-buffer" + + +# --- is_dns_aid_enabled --- + + +def test_is_dns_aid_enabled_both_true(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True) + assert is_dns_aid_enabled(cfg) is True + + +def test_is_dns_aid_enabled_env_false(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "false") + cfg = DnsAidConfig(enabled=True) + assert is_dns_aid_enabled(cfg) is False + + +def test_is_dns_aid_enabled_config_false(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=False) + assert is_dns_aid_enabled(cfg) is False + + +def test_is_dns_aid_enabled_none_config(): + assert is_dns_aid_enabled(None) is False + + +# --- _try_import_dns_aid caching --- + + +def test_import_warning_only_once(monkeypatch): + """The missing-package warning should fire once, not on every call.""" + import forge.controller.dns_aid as mod + + with patch.dict("sys.modules", {"dns_aid": None}): + # Simulate ImportError by patching builtins + original_import = ( + __builtins__.__import__ + if hasattr(__builtins__, "__import__") + else __import__ + ) + + def fake_import(name, *args, **kwargs): + if name == "dns_aid": + raise ImportError("no dns_aid") + return original_import(name, *args, **kwargs) + + with 
patch("builtins.__import__", side_effect=fake_import): + result1 = _try_import_dns_aid() + result2 = _try_import_dns_aid() + + assert result1 is None + assert result2 is None + # Second call should have used cache, not re-imported + assert mod._dns_aid_import_attempted is True + + +# --- publish_service --- + + +@pytest.mark.asyncio +async def test_publish_service_success(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True, domain="test.internal", port=7860, ttl=60) + + mock_dns_aid = MagicMock() + mock_dns_aid.publish = AsyncMock() + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + result = await publish_service("generator", "host1", 7860, cfg) + + assert result is True + mock_dns_aid.publish.assert_called_once() + call_kwargs = mock_dns_aid.publish.call_args.kwargs + assert call_kwargs["name"] == "torchforge-generator" + assert call_kwargs["domain"] == "test.internal" + assert call_kwargs["endpoint"] == "host1" + assert call_kwargs["port"] == 7860 + assert call_kwargs["ttl"] == 60 + assert "framework:torchforge" in call_kwargs["capabilities"] + assert "role:generator" in call_kwargs["capabilities"] + + +@pytest.mark.asyncio +async def test_publish_service_with_extra_capabilities(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True, port=7861, capabilities=["gpu:8", "shard_count:4"]) + + mock_dns_aid = MagicMock() + mock_dns_aid.publish = AsyncMock() + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + await publish_service("trainer", "host2", 7861, cfg) + + call_kwargs = mock_dns_aid.publish.call_args.kwargs + caps = call_kwargs["capabilities"] + assert caps == ["framework:torchforge", "role:trainer", "gpu:8", "shard_count:4"] + + +@pytest.mark.asyncio +async def test_publish_service_failure_no_raise(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = 
DnsAidConfig(enabled=True, port=7860) + + mock_dns_aid = MagicMock() + mock_dns_aid.publish = AsyncMock(side_effect=ConnectionError("DNS unreachable")) + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + result = await publish_service("generator", "host1", 7860, cfg) + + assert result is False + + +@pytest.mark.asyncio +async def test_publish_skipped_when_disabled(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "false") + cfg = DnsAidConfig(enabled=True, port=7860) + + mock_dns_aid = MagicMock() + mock_dns_aid.publish = AsyncMock() + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + result = await publish_service("generator", "host1", 7860, cfg) + + assert result is False + mock_dns_aid.publish.assert_not_called() + + +@pytest.mark.asyncio +async def test_publish_uses_forge_version(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True, port=8080) + + mock_dns_aid = MagicMock() + mock_dns_aid.publish = AsyncMock() + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + with patch("forge.controller.dns_aid._get_forge_version", return_value="0.5.0"): + await publish_service("gen", "host", 8080, cfg) + + assert mock_dns_aid.publish.call_args.kwargs["version"] == "0.5.0" + + +# --- unpublish_service --- + + +@pytest.mark.asyncio +async def test_unpublish_service_success(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True, domain="test.internal") + + mock_dns_aid = MagicMock() + mock_dns_aid.unpublish = AsyncMock(return_value=True) + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + result = await unpublish_service("generator", cfg) + + assert result is True + mock_dns_aid.unpublish.assert_called_once_with( + name="torchforge-generator", + domain="test.internal", + protocol="mcp", + ) + + 
+@pytest.mark.asyncio +async def test_unpublish_service_best_effort(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True) + + mock_dns_aid = MagicMock() + mock_dns_aid.unpublish = AsyncMock(side_effect=RuntimeError("DNS timeout")) + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + result = await unpublish_service("generator", cfg) + + assert result is False + + +# --- discover_peers --- + + +@pytest.mark.asyncio +async def test_discover_peers_success(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True, domain="test.internal") + + mock_agent = MagicMock() + mock_result = MagicMock() + mock_result.agents = [mock_agent] + + mock_dns_aid = MagicMock() + mock_dns_aid.discover = AsyncMock(return_value=mock_result) + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + agents = await discover_peers("trainer", cfg) + + assert len(agents) == 1 + assert agents[0] is mock_agent + + +@pytest.mark.asyncio +async def test_discover_peers_retry_with_backoff(monkeypatch): + """Verify exponential backoff delays between retry attempts.""" + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True) + + success_result = MagicMock() + success_result.agents = [MagicMock()] + + mock_dns_aid = MagicMock() + mock_dns_aid.discover = AsyncMock( + side_effect=[ + ConnectionError("fail 1"), + ConnectionError("fail 2"), + success_result, + ] + ) + + mock_sleep = AsyncMock() + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + with patch("forge.controller.dns_aid.asyncio.sleep", mock_sleep): + agents = await discover_peers( + "trainer", cfg, initial_delay=1.0, backoff_factor=2.0, max_delay=10.0 + ) + + assert len(agents) == 1 + assert mock_dns_aid.discover.call_count == 3 + # Verify exponential backoff: 1.0s after first fail, 2.0s after second + assert 
mock_sleep.call_count == 2 + mock_sleep.assert_any_call(1.0) + mock_sleep.assert_any_call(2.0) + + +@pytest.mark.asyncio +async def test_discover_peers_all_retries_fail(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True) + + mock_dns_aid = MagicMock() + mock_dns_aid.discover = AsyncMock(side_effect=ConnectionError("always fails")) + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + with patch("forge.controller.dns_aid.asyncio.sleep", new_callable=AsyncMock): + agents = await discover_peers( + "trainer", cfg, max_attempts=3, initial_delay=0.01 + ) + + assert agents == [] + assert mock_dns_aid.discover.call_count == 3 + + +@pytest.mark.asyncio +async def test_discover_peers_retry_on_empty_true(monkeypatch): + """With retry_on_empty=True (default), empty results trigger retries.""" + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True) + + empty_result = MagicMock() + empty_result.agents = [] + success_result = MagicMock() + success_result.agents = [MagicMock()] + + mock_dns_aid = MagicMock() + mock_dns_aid.discover = AsyncMock(side_effect=[empty_result, success_result]) + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", return_value=mock_dns_aid + ): + with patch("forge.controller.dns_aid.asyncio.sleep", new_callable=AsyncMock): + agents = await discover_peers("trainer", cfg, retry_on_empty=True) + + assert len(agents) == 1 + assert mock_dns_aid.discover.call_count == 2 + + +@pytest.mark.asyncio +async def test_discover_peers_retry_on_empty_false(monkeypatch): + """With retry_on_empty=False, empty results return immediately.""" + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True) + + empty_result = MagicMock() + empty_result.agents = [] + + mock_dns_aid = MagicMock() + mock_dns_aid.discover = AsyncMock(return_value=empty_result) + + with patch( + "forge.controller.dns_aid._try_import_dns_aid", 
return_value=mock_dns_aid + ): + agents = await discover_peers("trainer", cfg, retry_on_empty=False) + + assert agents == [] + mock_dns_aid.discover.assert_called_once() + + +# --- Import guard --- + + +@pytest.mark.asyncio +async def test_dns_aid_import_missing(monkeypatch): + monkeypatch.setenv("DNS_AID_ENABLED", "true") + cfg = DnsAidConfig(enabled=True) + + with patch("forge.controller.dns_aid._try_import_dns_aid", return_value=None): + publish_result = await publish_service("gen", "host", 8080, cfg) + unpublish_result = await unpublish_service("gen", cfg) + discover_result = await discover_peers("gen", cfg) + + assert publish_result is False + assert unpublish_result is False + assert discover_result == [] + + +# --- Provisioner shutdown integration --- + + +@pytest.mark.asyncio +async def test_provisioner_shutdown_calls_unpublish(monkeypatch): + """Verify that shutdown_all_allocations unpublishes DNS-AID services.""" + monkeypatch.setenv("DNS_AID_ENABLED", "true") + + dns_cfg = DnsAidConfig(enabled=True, domain="test.internal", port=7860) + + # Build a minimal ServiceInterface-like object with the expected attributes + mock_service = MagicMock() + mock_service._dns_aid_cfg = dns_cfg + mock_service.actor_def.__name__ = "MyGenerator" + mock_service.shutdown = AsyncMock() + + mock_unpublish = AsyncMock(return_value=True) + + # shutdown_all_allocations() imports these helpers from forge.controller.dns_aid + # at call time, so they must be patched on that module; the provisioner module + # has no unpublish_service / is_dns_aid_enabled attributes for patch() to replace. + with patch("forge.controller.dns_aid.unpublish_service", mock_unpublish): + with patch( + "forge.controller.dns_aid.is_dns_aid_enabled", return_value=True + ): + with patch( + "forge.controller.provisioner.shutdown_context", new_callable=AsyncMock + ): + from forge.controller.provisioner import Provisioner + + provisioner = Provisioner.__new__(Provisioner) + provisioner._lock = __import__("asyncio").Lock() + provisioner._registered_services = [mock_service] + provisioner._registered_actors = [] + provisioner.launcher = None + + await provisioner.shutdown_all_allocations() + + mock_unpublish.assert_called_once_with("mygenerator", 
dns_cfg) + mock_service.shutdown.assert_called_once() + + +@pytest.mark.asyncio +async def test_provisioner_shutdown_skips_when_no_dns_cfg(monkeypatch): + """Services without DNS-AID config should not trigger unpublish.""" + mock_service = MagicMock() + mock_service._dns_aid_cfg = None + mock_service.shutdown = AsyncMock() + + mock_unpublish = AsyncMock() + + # shutdown_all_allocations() imports these helpers from forge.controller.dns_aid + # at call time, so they must be patched on that module; the provisioner module + # has no unpublish_service / is_dns_aid_enabled attributes for patch() to replace. + with patch("forge.controller.dns_aid.unpublish_service", mock_unpublish): + with patch( + "forge.controller.dns_aid.is_dns_aid_enabled", return_value=False + ): + with patch( + "forge.controller.provisioner.shutdown_context", new_callable=AsyncMock + ): + from forge.controller.provisioner import Provisioner + + provisioner = Provisioner.__new__(Provisioner) + provisioner._lock = __import__("asyncio").Lock() + provisioner._registered_services = [mock_service] + provisioner._registered_actors = [] + provisioner.launcher = None + + await provisioner.shutdown_all_allocations() + + mock_unpublish.assert_not_called() + mock_service.shutdown.assert_called_once()