113 changes: 113 additions & 0 deletions docs/source/dns_aid.md
@@ -0,0 +1,113 @@
# DNS-AID Service Discovery

Forge services can optionally register DNS-AID SVCB records on startup, enabling
peer discovery via DNS rather than hard-coded coordinator addresses.

## Installation

```bash
pip install forge[dns-aid]
```

## Configuration

DNS-AID requires **both** the `DNS_AID_ENABLED` environment variable and
the per-service `DnsAidConfig.enabled` flag to be true. This dual-guard
means the environment variable acts as a global kill switch.
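
As a minimal sketch of the dual guard described above (not the Forge implementation): the check passes only when both the environment variable and the per-service flag are true. `DnsAidConfigSketch`, `env_flag`, and `dns_aid_active` are hypothetical names, and the set of strings accepted as "true" is an assumption; the real parsing lives in `forge.env`.

```python
import os
from dataclasses import dataclass


@dataclass
class DnsAidConfigSketch:
    """Hypothetical stand-in for DnsAidConfig, reduced to the enable flag."""

    enabled: bool = False


def env_flag(name: str, default: bool = False) -> bool:
    """Parse a boolean environment variable (accepted spellings are an assumption)."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


def dns_aid_active(cfg) -> bool:
    """Return True only if the global env toggle AND the per-service flag are set."""
    if cfg is None:
        return False
    return env_flag("DNS_AID_ENABLED") and cfg.enabled
```

With this shape, unsetting `DNS_AID_ENABLED` disables DNS-AID for every service regardless of individual `enabled` flags.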

### Environment Variables

| Variable | Default | Description |
|----------|---------|-------------|
| `DNS_AID_ENABLED` | `false` | Global toggle. Must be `true` for any DNS-AID operations. |
| `DNS_AID_ZONE` | — | DNS zone suffix (e.g. `_agents.svc.cluster.local`) |
| `DNS_AID_SERVER` | — | DNS server address (e.g. `10.0.0.53`) |
| `DNS_AID_PORT` | `853` | DNS server port |
| `DNS_AID_BACKEND` | — | DNS backend (`route53`, `cloudflare`, `ddns`, `mock`, etc.) |

### Per-Service Configuration

Add `DnsAidConfig` to your actor options:

```python
from forge.controller import ForgeActor
from forge.types import DnsAidConfig

dns_cfg = DnsAidConfig(
    enabled=True,
    name="generator",          # DNS service name (default: lowercased class name)
    domain="forge.internal",   # DNS domain
    port=8080,                 # Externally reachable port (required)
    ttl=30,                    # Record TTL in seconds
    capabilities=["gpu:8"],    # Extra capabilities to advertise
    category="rl-training",    # Discovery category
)

service = await MyGenerator.options(
    num_replicas=4,
    procs=2,
    with_gpus=True,
    dns_aid=dns_cfg,
).as_service(model_path="...")
```

The `port` field is required when `enabled` is True. If it is unset,
registration is skipped and a warning is logged. Set it to the port that
external systems use to reach this service (e.g. a load balancer, gateway,
or sidecar proxy port). Monarch services communicate via actor RPC
internally, so there is no auto-detected listener port.

### OmegaConf YAML

```yaml
# Requires DNS_AID_ENABLED=true in the environment
generator:
  procs: 2
  num_replicas: 4
  with_gpus: true
  dns_aid:
    enabled: true
    name: generator
    domain: forge.internal
    port: 8080
    ttl: 30
    capabilities:
      - "gpu:8"
      - "shard_count:4"
```
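
For the configuration above, the published record name and capability list can be derived roughly as follows. This is a sketch that mirrors the logic of `_fqdn()` and `publish_service()` in this change; the helper names `dns_record_name` and `advertised_capabilities` are hypothetical and exist only for illustration.

```python
def dns_record_name(configured_name, class_name):
    """Record name: 'torchforge-' plus the configured name, or the lowercased class name."""
    service_name = configured_name or class_name.lower()
    return f"torchforge-{service_name}"


def advertised_capabilities(service_name, extra):
    """Built-in capabilities come first, followed by the ones from the config."""
    return ["framework:torchforge", f"role:{service_name}", *extra]


name = dns_record_name("generator", "MyGenerator")
caps = advertised_capabilities("generator", ["gpu:8", "shard_count:4"])
print(name)  # torchforge-generator
print(caps)  # ['framework:torchforge', 'role:generator', 'gpu:8', 'shard_count:4']
```

Peers therefore look up the service by its logical name (`generator`), not by the prefixed record name.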

## How It Works

1. **Startup**: After the service is fully initialized, `publish_service()` creates
a DNS-AID SVCB record advertising the service's hostname, port, role, and
capabilities.

2. **Discovery**: Other services can call `discover_peers()` to find registered
peers by name. Discovery retries with exponential backoff (5 attempts by
default, with delays starting at 0.5 s, doubling each time, and capped at 8 s)
to handle race conditions during cluster startup. Pass `retry_on_empty=False`
to return immediately when a lookup succeeds but finds no peers; lookups that
raise an error are still retried.

3. **Shutdown**: `unpublish_service()` removes the DNS record. This is best-effort;
if the process crashes, the record expires after the configured TTL (default 30s).
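
To make the retry behaviour in step 2 concrete, the sketch below computes the sleep schedule implied by the default arguments of `discover_peers()` (`max_attempts=5`, `initial_delay=0.5`, `backoff_factor=2.0`, `max_delay=8.0`). The helper `backoff_delays` is hypothetical; it only reproduces the delay arithmetic and performs no DNS calls.

```python
def backoff_delays(max_attempts=5, initial_delay=0.5, backoff_factor=2.0, max_delay=8.0):
    """Return the delays slept between attempts (none after the final attempt)."""
    delays = []
    delay = initial_delay
    for attempt in range(1, max_attempts + 1):
        if attempt < max_attempts:
            delays.append(delay)
            # Grow the delay geometrically, but never beyond max_delay.
            delay = min(delay * backoff_factor, max_delay)
    return delays


schedule = backoff_delays()
print(schedule)       # [0.5, 1.0, 2.0, 4.0]
print(sum(schedule))  # 7.5
```

With the defaults, a caller waits at most about 7.5 seconds in sleeps (plus query time) before discovery gives up and returns an empty list.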

## Peer Discovery Example

```python
from forge.controller.dns_aid import discover_peers
from forge.types import DnsAidConfig

cfg = DnsAidConfig(enabled=True, domain="forge.internal")

# Find all trainer services (retries if not yet registered)
trainers = await discover_peers("trainer", cfg)
for agent in trainers:
    print(f"Found trainer at {agent.target_host}:{agent.port}")

# Check once without retrying
trainers = await discover_peers("trainer", cfg, retry_on_empty=False)
```

## Backward Compatibility

DNS-AID is fully opt-in. When `DNS_AID_ENABLED` is unset or `false` (the default),
no DNS operations are performed and the `dns-aid` package does not need to be
installed. Existing deployments are completely unaffected.
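
Before turning on `DNS_AID_ENABLED`, a deployment may want to confirm that the optional dependency is importable. The following is one possible check using only the standard library; the function name `optional_package_available` is hypothetical and is not part of Forge.

```python
import importlib.util


def optional_package_available(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


if not optional_package_available("dns_aid"):
    print("dns-aid is not installed; run: pip install forge[dns-aid]")
```

This only checks that the module can be located; it does not validate the installed version against the `dns-aid>=0.12.0` requirement.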
1 change: 1 addition & 0 deletions pyproject.toml
@@ -48,6 +48,7 @@ dev = [
"pytest-asyncio",
"multiprocess",
]
dns-aid = ["dns-aid>=0.12.0"]
docs = [
"sphinx==7.2.6",
"pytorch-sphinx-theme2==0.1.0",
21 changes: 21 additions & 0 deletions src/forge/controller/actor.py
@@ -173,6 +173,27 @@ async def as_service(
        service_interface = ServiceInterface(service, cls)
        # Register this service with the provisioner so it can cleanly shut this down
        await register_service(service_interface)

        # DNS-AID registration (best-effort, after service is fully initialized)
        service_interface._dns_aid_cfg = cfg.dns_aid
        if cfg.dns_aid is not None:
            from forge.controller.dns_aid import is_dns_aid_enabled, publish_service

            if is_dns_aid_enabled(cfg.dns_aid):
                if cfg.dns_aid.port is None:
                    logger.warning(
                        "DNS-AID: dns_aid.port is not set, skipping registration. "
                        "Set DnsAidConfig(port=...) to the externally reachable port."
                    )
                else:
                    import socket as _socket

                    _hostname = _socket.gethostname()
                    _dns_name = cfg.dns_aid.name or cls.__name__.lower()
                    await publish_service(
                        _dns_name, _hostname, cfg.dns_aid.port, cfg.dns_aid
                    )

        return service_interface

    @endpoint
246 changes: 246 additions & 0 deletions src/forge/controller/dns_aid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""DNS-AID service discovery helpers for Forge services.

Provides publish, unpublish, and discover wrappers around the dns_aid library
with torchforge-specific defaults, dual enable guards, and retry logic.

All operations are best-effort: failures are logged but never raised, so
service startup and shutdown are not blocked by DNS issues.
"""

import asyncio
import logging
from typing import TYPE_CHECKING

from forge.env import DNS_AID_ENABLED

if TYPE_CHECKING:
    from forge.types import DnsAidConfig

logger = logging.getLogger(__name__)

# Cached dns_aid import. _dns_aid_module stays None until an import succeeds;
# _dns_aid_import_attempted records whether an import has already been tried.
_dns_aid_module = None
_dns_aid_import_attempted = False


def _get_forge_version() -> str:
    """Return the installed forge package version, or 'unknown'."""
    try:
        from importlib.metadata import version

        return version("forge")
    except Exception:
        return "unknown"


def _fqdn(service_name: str) -> str:
    """Build the canonical DNS-AID name for a forge service."""
    return f"torchforge-{service_name}"


def is_dns_aid_enabled(cfg: "DnsAidConfig | None") -> bool:
    """Check whether DNS-AID is enabled via both env var and config.

    Both ``DNS_AID_ENABLED`` environment variable and ``cfg.enabled`` must
    be true for DNS-AID operations to proceed.
    """
    if cfg is None:
        return False
    return bool(DNS_AID_ENABLED.get_value()) and cfg.enabled


def _try_import_dns_aid():
    """Lazily import dns_aid, returning the module or None.

    The result is cached so the warning is only logged once.
    """
    global _dns_aid_module, _dns_aid_import_attempted
    if _dns_aid_import_attempted:
        return _dns_aid_module

    _dns_aid_import_attempted = True
    try:
        import dns_aid

        _dns_aid_module = dns_aid
        return dns_aid
    except ImportError:
        logger.warning(
            "dns-aid package is not installed. "
            "Install with: pip install forge[dns-aid]"
        )
        _dns_aid_module = None
        return None


async def publish_service(
    service_name: str,
    hostname: str,
    port: int,
    cfg: "DnsAidConfig",
) -> bool:
    """Publish a Forge service as a DNS-AID SVCB record.

    Args:
        service_name: Logical name for the service (e.g. "generator").
        hostname: Host where the service is reachable.
        port: Port where the service is reachable.
        cfg: DNS-AID configuration.

    Returns:
        True if publish succeeded, False otherwise.
    """
    if not is_dns_aid_enabled(cfg):
        return False

    dns_aid = _try_import_dns_aid()
    if dns_aid is None:
        return False

    capabilities = [
        "framework:torchforge",
        f"role:{service_name}",
        *cfg.capabilities,
    ]
    dns_name = _fqdn(service_name)

    try:
        await dns_aid.publish(
            name=dns_name,
            domain=cfg.domain,
            protocol=cfg.protocol,
            endpoint=hostname,
            port=port,
            capabilities=capabilities,
            version=_get_forge_version(),
            description=f"Torchforge {service_name} service",
            category=cfg.category,
            ttl=cfg.ttl,
        )
        logger.info(
            f"DNS-AID: published {dns_name} "
            f"at {hostname}:{port} (domain={cfg.domain}, ttl={cfg.ttl}s)"
        )
        return True
    except Exception:
        logger.warning(
            f"DNS-AID: failed to publish {dns_name}",
            exc_info=True,
        )
        return False


async def unpublish_service(
    service_name: str,
    cfg: "DnsAidConfig",
) -> bool:
    """Remove a Forge service's DNS-AID record. Best-effort.

    Args:
        service_name: Logical name for the service.
        cfg: DNS-AID configuration.

    Returns:
        True if unpublish succeeded, False otherwise.
    """
    if not is_dns_aid_enabled(cfg):
        return False

    dns_aid = _try_import_dns_aid()
    if dns_aid is None:
        return False

    dns_name = _fqdn(service_name)
    try:
        await dns_aid.unpublish(
            name=dns_name,
            domain=cfg.domain,
            protocol=cfg.protocol,
        )
        logger.info(f"DNS-AID: unpublished {dns_name}")
        return True
    except Exception:
        logger.warning(
            f"DNS-AID: failed to unpublish {dns_name}",
            exc_info=True,
        )
        return False


async def discover_peers(
    service_name: str,
    cfg: "DnsAidConfig",
    max_attempts: int = 5,
    initial_delay: float = 0.5,
    backoff_factor: float = 2.0,
    max_delay: float = 8.0,
    retry_on_empty: bool = True,
) -> list:
    """Discover peer Forge services via DNS-AID with exponential backoff.

    Args:
        service_name: Name of the service to discover (e.g. "trainer").
        cfg: DNS-AID configuration.
        max_attempts: Maximum number of discovery attempts.
        initial_delay: Initial retry delay in seconds.
        backoff_factor: Multiplier for each subsequent retry delay.
        max_delay: Maximum retry delay in seconds.
        retry_on_empty: If True (default), retry when discovery succeeds but
            returns no agents. Set to False to return immediately on a
            successful-but-empty response.

    Returns:
        List of discovered AgentRecord objects, or empty list on failure.
    """
    if not is_dns_aid_enabled(cfg):
        return []

    dns_aid = _try_import_dns_aid()
    if dns_aid is None:
        return []

    dns_name = _fqdn(service_name)
    delay = initial_delay
    for attempt in range(1, max_attempts + 1):
        try:
            result = await dns_aid.discover(
                domain=cfg.domain,
                protocol=cfg.protocol,
                name=dns_name,
            )
            if result.agents:
                logger.info(
                    f"DNS-AID: discovered {len(result.agents)} peer(s) "
                    f"for {dns_name} (attempt {attempt})"
                )
                return result.agents

            if not retry_on_empty:
                logger.debug(f"DNS-AID: no peers found for {dns_name}")
                return []

        except Exception:
            logger.debug(
                f"DNS-AID: discover attempt {attempt}/{max_attempts} "
                f"for {dns_name} failed",
                exc_info=True,
            )

        if attempt < max_attempts:
            logger.debug(
                f"DNS-AID: retrying discovery in {delay:.1f}s "
                f"(attempt {attempt}/{max_attempts})"
            )
            await asyncio.sleep(delay)
            delay = min(delay * backoff_factor, max_delay)

    logger.warning(
        f"DNS-AID: failed to discover {dns_name} after {max_attempts} attempts"
    )
    return []