Skip to content

Commit f861120

Browse files
committed
Fix TCP PeerAgent advertised host
1 parent 20ee6bf commit f861120

4 files changed

Lines changed: 335 additions & 4 deletions

File tree

dlslime/dlslime/peer_agent/_agent.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@
88
from __future__ import annotations
99

1010
import inspect
11+
import ipaddress
1112
import json
1213
import os
14+
import socket
1315
import threading
1416
import time
1517
from concurrent.futures import ThreadPoolExecutor
1618
from dataclasses import dataclass
1719
from typing import Any, Dict, List, Optional, Set, Tuple, Union
20+
from urllib.parse import urlparse
1821

1922
try:
2023
import httpx
@@ -407,6 +410,45 @@ def _local_address(self) -> str:
407410
return str(address)
408411
return ""
409412

413+
def _ctrl_host(self) -> str:
414+
ctrl_url = self.ctrl_url
415+
if not ctrl_url.startswith(("http://", "https://")):
416+
ctrl_url = f"http://{ctrl_url}"
417+
return urlparse(ctrl_url).hostname or ""
418+
419+
@staticmethod
420+
def _is_loopback_host(host: str) -> bool:
421+
if host in {"localhost", "::1"}:
422+
return True
423+
try:
424+
return ipaddress.ip_address(host).is_loopback
425+
except ValueError:
426+
return False
427+
428+
@staticmethod
429+
def _local_ip_for_remote(remote_host: str) -> str:
430+
if not remote_host:
431+
return ""
432+
try:
433+
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
434+
sock.connect((remote_host, 80))
435+
return sock.getsockname()[0]
436+
except OSError:
437+
return ""
438+
439+
def _resolve_tcp_local_host(self, local_host: Optional[str]) -> str:
440+
if not local_host:
441+
discovered = self._local_address()
442+
if discovered and not self._is_loopback_host(discovered):
443+
return discovered
444+
ctrl_host = self._ctrl_host()
445+
if ctrl_host and not self._is_loopback_host(ctrl_host):
446+
routed = self._local_ip_for_remote(ctrl_host)
447+
if routed:
448+
return routed
449+
return discovered or "0.0.0.0"
450+
return str(local_host)
451+
410452
def _normalize_link_type(self, link_type: Optional[str]) -> str:
411453
if not link_type:
412454
return "UNKNOWN"
@@ -960,7 +1002,17 @@ def _ensure_connection_from_meta(
9601002
) -> DirectedConnection:
9611003
"""Create local connection state from a peer's directed request."""
9621004
if str(meta.get("transport") or "rdma").lower() == "tcp":
963-
local_key: ResourceKey = TcpResourceKey()
1005+
with self._connections_lock:
1006+
peer_connections = [
1007+
conn
1008+
for conn in self._connections.values()
1009+
if conn.peer_alias == peer_alias
1010+
]
1011+
if len(peer_connections) == 1:
1012+
return peer_connections[0]
1013+
local_key: ResourceKey = TcpResourceKey(
1014+
host=self._resolve_tcp_local_host(None), port=0
1015+
)
9641016
peer_key: ResourceKey = TcpResourceKey()
9651017
return self._get_or_create_connection(
9661018
peer_alias,
@@ -1156,6 +1208,14 @@ def _mark_connection_connected(self, conn_id: str) -> None:
11561208
self._connected_peers.add(conn_id)
11571209
self._connected_peers_cond.notify_all()
11581210

1211+
def _mark_connection_failed(self, conn_id: str) -> None:
1212+
with self._connections_lock:
1213+
conn = self._connections.get(conn_id)
1214+
if conn is not None:
1215+
conn.mark_failed()
1216+
with self._connected_peers_lock:
1217+
self._connected_peers_cond.notify_all()
1218+
11591219
def _has_notified_peer(self, conn_id: str) -> bool:
11601220
"""True iff we've already sent qp_ready for this connection."""
11611221
with self._notified_peers_lock:
@@ -1183,16 +1243,18 @@ def connect_to(
11831243
qp_num: Optional[int] = 1,
11841244
min_bw: Optional[str] = None,
11851245
transport: str = "rdma",
1186-
local_host: str = "0.0.0.0",
1246+
local_host: Optional[str] = None,
11871247
local_port: int = 0,
11881248
) -> PeerConnection:
11891249
"""Start connecting to a peer and return a connection handle.
11901250
11911251
``transport='rdma'`` (default) selects the RDMA path, picking a NIC
11921252
from the discovered local topology. ``transport='tcp'`` selects the
11931253
TCP path: ``local_host``/``local_port`` configure the local bind
1194-
(port 0 = OS-assigned), and ``ib_port``/``qp_num``/``local_device``/
1195-
``peer_device`` are ignored.
1254+
(port 0 = OS-assigned). If ``local_host`` is left as the default
1255+
``None``, PeerAgent publishes its discovered local host address instead
1256+
so remote peers do not receive ``0.0.0.0`` in endpoint_info.
1257+
``ib_port``/``qp_num``/``local_device``/``peer_device`` are ignored.
11961258
"""
11971259
if not isinstance(peer_alias, str) or not peer_alias:
11981260
raise TypeError("connect_to() requires a non-empty peer alias string")
@@ -1213,6 +1275,7 @@ def connect_to(
12131275
t0 = time.perf_counter()
12141276

12151277
if transport_norm == "tcp":
1278+
local_host = self._resolve_tcp_local_host(local_host)
12161279
local_key: ResourceKey = TcpResourceKey(
12171280
host=local_host, port=int(local_port)
12181281
)
@@ -1306,6 +1369,11 @@ def _wait_connected(self, conn_id: str, timeout_sec: float = 60.0) -> None:
13061369
f"+{(time.perf_counter() - t0) * 1000:.3f}ms"
13071370
)
13081371
return
1372+
with self._connections_lock:
1373+
conn = self._connections.get(conn_id)
1374+
state = conn.state if conn is not None else "unknown"
1375+
if state == "failed":
1376+
raise RuntimeError(f"Connection {conn_id!r} failed")
13091377
remaining = deadline - time.monotonic()
13101378
if remaining <= 0:
13111379
raise TimeoutError(

dlslime/dlslime/peer_agent/_mailbox.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,15 @@ def _try_connect_peer_inner(
373373
# D. Complete RDMA handshake
374374
t_d = time.perf_counter()
375375
endpoint.connect(peer_qp_info)
376+
if hasattr(endpoint, "is_connected") and not endpoint.is_connected():
377+
logger.warning(
378+
"StreamMailbox %s: endpoint.connect(%s) failed for endpoint_info=%s",
379+
self._agent.alias,
380+
peer,
381+
peer_qp_info,
382+
)
383+
self._agent._mark_connection_failed(conn_id)
384+
return
376385
# Stash for one-sided ops on transports (TCP) where remote MR info
377386
# rides on the endpoint_info JSON instead of a separate Redis record.
378387
conn.peer_endpoint_info = peer_qp_info
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
#!/usr/bin/env python3
2+
"""TCP one-sided read through two PeerAgent processes.
3+
4+
The initiator reads bytes directly out of a memory region published by the
5+
target, then notifies the target so both processes can shut down cleanly.
6+
7+
Prerequisites:
8+
1. Start NanoCtrl: cd NanoCtrl && cargo run --release
9+
2. Redis must be reachable.
10+
3. dlslime built with BUILD_TCP=ON (default).
11+
12+
Usage:
13+
python p2p_tcp_rc_read_peer_agent_two_process.py --role target
14+
python p2p_tcp_rc_read_peer_agent_two_process.py --role initiator
15+
python p2p_tcp_rc_read_peer_agent_two_process.py --role target --ctrl_address http://host:4479
16+
python p2p_tcp_rc_read_peer_agent_two_process.py --role initiator --ctrl_address http://host:4479
17+
python p2p_tcp_rc_read_peer_agent_two_process.py --role target --local_host 10.0.0.2
18+
python p2p_tcp_rc_read_peer_agent_two_process.py --role initiator --local_host 10.0.0.3
19+
"""
20+
21+
import argparse
22+
import ctypes
23+
from typing import Optional
24+
25+
from dlslime import PeerAgent
26+
27+
28+
INITIATOR_ALIAS = "tcp_read_initiator"
29+
TARGET_ALIAS = "tcp_read_target"
30+
PAYLOAD = b"one-sided-read-via-peer-agent"
31+
DONE = b"done"
32+
33+
34+
def _buffer_from_bytes(data: bytes) -> ctypes.Array:
35+
buf = ctypes.create_string_buffer(len(data))
36+
ctypes.memmove(ctypes.addressof(buf), data, len(data))
37+
return buf
38+
39+
40+
def run_initiator(
41+
ctrl_url: str, local_host: Optional[str], local_port: int, connect_timeout: float
42+
) -> None:
43+
agent = PeerAgent(ctrl_url=ctrl_url, alias=INITIATOR_ALIAS)
44+
try:
45+
buf_a = ctypes.create_string_buffer(64)
46+
addr_a = ctypes.addressof(buf_a)
47+
48+
# Register the local MR before connect_to so endpoint_info() carries it
49+
# when the rendezvous fires. Required for one-sided ops on TCP.
50+
agent.register_memory_region("buf_a", addr_a, 0, 64)
51+
52+
if TARGET_ALIAS not in agent.list_agents():
53+
raise RuntimeError(f"target agent {TARGET_ALIAS!r} is not running")
54+
55+
resolved_host = agent._resolve_tcp_local_host(local_host)
56+
print(f"[initiator] TCP local host: {resolved_host}:{local_port}")
57+
conn = agent.connect_to(
58+
TARGET_ALIAS,
59+
transport="tcp",
60+
local_host=local_host,
61+
local_port=local_port,
62+
)
63+
conn.wait(timeout=connect_timeout)
64+
local_info = conn.endpoint.endpoint_info()
65+
print(
66+
f"[initiator] Connected over TCP: {INITIATOR_ALIAS} -> {TARGET_ALIAS} "
67+
f"(local={local_info.get('host')}:{local_info.get('port')})"
68+
)
69+
70+
done_buf = _buffer_from_bytes(DONE)
71+
72+
ep = conn.endpoint
73+
if not ep.is_connected():
74+
raise RuntimeError(
75+
"TCP endpoint is not connected; make sure both processes publish "
76+
"a peer-reachable --local_host"
77+
)
78+
peer_info = conn.peer_endpoint_info
79+
assert peer_info is not None
80+
remote_mr_info = peer_info.get("mr_info", {}).get("buf_b")
81+
if remote_mr_info is None:
82+
raise RuntimeError(
83+
f"{TARGET_ALIAS!r} is connected but did not publish buf_b; "
84+
"start the target process first"
85+
)
86+
87+
h_local = ep.register_memory_region("buf_a_loc", addr_a, 0, 64)
88+
h_remote = ep.register_remote_memory_region("buf_b_rem", remote_mr_info)
89+
90+
st = ep.read([(h_local, h_remote, 0, 0, len(PAYLOAD))]).wait()
91+
assert st == 0, f"read failed: {st}"
92+
assert (
93+
bytes(buf_a[: len(PAYLOAD)]) == PAYLOAD
94+
), "initiator did not read bytes"
95+
print(
96+
"[initiator] target->initiator one-sided read = "
97+
f"{bytes(buf_a[: len(PAYLOAD)])!r} ok"
98+
)
99+
100+
st = agent.send(
101+
TARGET_ALIAS, (ctypes.addressof(done_buf), 0, len(DONE))
102+
).wait()
103+
assert st == 0, f"done send failed: {st}"
104+
print("[initiator] Sent completion notice.")
105+
finally:
106+
agent.shutdown()
107+
108+
109+
def run_target(
110+
ctrl_url: str, local_host: Optional[str], local_port: int, connect_timeout: float
111+
) -> None:
112+
agent = PeerAgent(ctrl_url=ctrl_url, alias=TARGET_ALIAS)
113+
try:
114+
buf_b = ctypes.create_string_buffer(64)
115+
addr_b = ctypes.addressof(buf_b)
116+
117+
# Pre-fill the target buffer with the payload the initiator reads out.
118+
ctypes.memmove(addr_b, PAYLOAD, len(PAYLOAD))
119+
120+
# Register MRs before connect_to so endpoint_info() carries them when
121+
# the rendezvous fires. Required for one-sided ops on TCP.
122+
agent.register_memory_region("buf_b", addr_b, 0, 64)
123+
124+
resolved_host = agent._resolve_tcp_local_host(local_host)
125+
print(f"[target] TCP local host: {resolved_host}:{local_port}")
126+
conn = agent.connect_to(
127+
INITIATOR_ALIAS,
128+
transport="tcp",
129+
local_host=local_host,
130+
local_port=local_port,
131+
)
132+
conn.wait(timeout=connect_timeout)
133+
local_info = conn.endpoint.endpoint_info()
134+
print(
135+
f"[target] Connected over TCP: {TARGET_ALIAS} -> {INITIATOR_ALIAS} "
136+
f"(local={local_info.get('host')}:{local_info.get('port')})"
137+
)
138+
139+
done_buf = ctypes.create_string_buffer(len(DONE))
140+
141+
if not conn.endpoint.is_connected():
142+
raise RuntimeError(
143+
"TCP endpoint is not connected; make sure both processes publish "
144+
"a peer-reachable --local_host"
145+
)
146+
print("[target] Published buffer; waiting for initiator completion.")
147+
148+
st = agent.recv(
149+
INITIATOR_ALIAS, (ctypes.addressof(done_buf), 0, len(DONE))
150+
).wait()
151+
assert st == 0, f"done recv failed: {st}"
152+
assert bytes(done_buf[: len(DONE)]) == DONE, "initiator did not send done"
153+
print("[target] Initiator completed read; shutting down.")
154+
finally:
155+
agent.shutdown()
156+
157+
158+
if __name__ == "__main__":
159+
parser = argparse.ArgumentParser(
160+
description="TCP one-sided read through two PeerAgent processes"
161+
)
162+
parser.add_argument(
163+
"--role",
164+
choices=("initiator", "target"),
165+
required=True,
166+
help="Process role to run",
167+
)
168+
parser.add_argument(
169+
"--ctrl_address",
170+
"--ctrl",
171+
default="http://127.0.0.1:4479",
172+
help="NanoCtrl URL",
173+
)
174+
parser.add_argument(
175+
"--local_host",
176+
default=None,
177+
help="Local TCP host/IP to publish to the peer; defaults to discovered host",
178+
)
179+
parser.add_argument(
180+
"--local_port",
181+
type=int,
182+
default=0,
183+
help="Local TCP port to bind; 0 lets the OS choose",
184+
)
185+
parser.add_argument(
186+
"--connect_timeout",
187+
type=float,
188+
default=60.0,
189+
help="Seconds to wait for the peer-agent TCP connection",
190+
)
191+
args = parser.parse_args()
192+
if args.role == "initiator":
193+
run_initiator(
194+
args.ctrl_address, args.local_host, args.local_port, args.connect_timeout
195+
)
196+
else:
197+
run_target(
198+
args.ctrl_address, args.local_host, args.local_port, args.connect_timeout
199+
)

0 commit comments

Comments
 (0)