Skip to content

Commit 8721270

Browse files
Fix TCP PeerAgent advertised host (#111)
* Fix TCP PeerAgent advertised host * lint * Update dlslime/dlslime/peer_agent/_agent.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update dlslime/dlslime/peer_agent/_agent.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update dlslime/examples/python/p2p_tcp_rc_read_peer_agent_two_process.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update dlslime/examples/python/p2p_tcp_rc_read_peer_agent_two_process.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Handle TCP endpoint connect exceptions * Update dlslime/dlslime/peer_agent/_agent.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 20ee6bf commit 8721270

4 files changed

Lines changed: 345 additions & 5 deletions

File tree

dlslime/dlslime/peer_agent/_agent.py

Lines changed: 77 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,13 +8,16 @@
88
from __future__ import annotations
99

1010
import inspect
11+
import ipaddress
1112
import json
1213
import os
14+
import socket
1315
import threading
1416
import time
1517
from concurrent.futures import ThreadPoolExecutor
1618
from dataclasses import dataclass
1719
from typing import Any, Dict, List, Optional, Set, Tuple, Union
20+
from urllib.parse import urlparse
1821

1922
try:
2023
import httpx
@@ -407,6 +410,50 @@ def _local_address(self) -> str:
407410
return str(address)
408411
return ""
409412

413+
def _ctrl_host(self) -> str:
    """Return the host component of ``self.ctrl_url``.

    A URL given without an ``http://`` / ``https://`` prefix (for example
    ``"host:4479"``) is treated as an HTTP URL so that ``urlparse`` places
    the host in the network-location part instead of reading it as a scheme.

    Returns:
        The hostname as reported by ``urlparse(...).hostname``, or ``""``
        when the URL is unset, blank, or cannot be parsed.
    """
    ctrl_url = str(self.ctrl_url or "").strip()
    if not ctrl_url:
        return ""
    # Compare the scheme case-insensitively: "HTTP://host" must not be
    # prefixed a second time, which would make "http" the parsed hostname.
    if not ctrl_url.lower().startswith(("http://", "https://")):
        ctrl_url = f"http://{ctrl_url}"
    try:
        return urlparse(ctrl_url).hostname or ""
    except ValueError:
        # urlparse rejects malformed network locations (e.g. an unbalanced
        # IPv6 bracket); report "no host" instead of propagating.
        return ""
418+
419+
@staticmethod
def _is_loopback_host(host: str) -> bool:
    """Return True when ``host`` names the local loopback interface.

    Accepts the literal name ``localhost`` (any letter case, optional
    trailing dot) and IPv4/IPv6 address literals, including bracketed IPv6
    (``[::1]``) and IPv4-mapped IPv6 (``::ffff:127.0.0.1``).  Any other
    value, including an empty one or a non-loopback hostname, yields False;
    hostnames are never resolved.
    """
    if not host:
        return False
    candidate = str(host).strip().lower()
    # URL-style IPv6 literals arrive wrapped in brackets.
    if candidate.startswith("[") and candidate.endswith("]"):
        candidate = candidate[1:-1]
    # Tolerate a fully-qualified trailing dot ("localhost.").
    candidate = candidate.rstrip(".")
    if candidate == "localhost":
        return True
    try:
        address = ipaddress.ip_address(candidate)
    except ValueError:
        # Not an address literal: an ordinary hostname is not loopback here.
        return False
    # Unwrap IPv4-mapped IPv6 explicitly so the answer does not depend on
    # how the running interpreter's ipaddress module treats mapped addresses.
    mapped = getattr(address, "ipv4_mapped", None)
    if mapped is not None:
        return mapped.is_loopback
    return address.is_loopback
427+
428+
@staticmethod
def _local_ip_for_remote(remote_host: str) -> str:
    """Return the local IP the OS would use to reach ``remote_host``.

    The host is resolved with ``getaddrinfo`` (so both IPv4 and IPv6 work),
    then a datagram socket is connected to each resolved address in turn and
    the locally bound address is read back with ``getsockname()``.  Port 80
    is only a placeholder required by the socket address.

    Returns:
        The local IP address as a string, or ``""`` when ``remote_host`` is
        empty, cannot be resolved, or none of its addresses is reachable.
    """
    if not remote_host:
        return ""
    try:
        candidates = socket.getaddrinfo(remote_host, 80, type=socket.SOCK_DGRAM)
    except (OSError, UnicodeError):
        # gaierror is an OSError; UnicodeError comes from IDNA encoding of
        # a malformed host name (empty or over-long label).
        return ""
    for family, socktype, proto, _canonname, sockaddr in candidates:
        try:
            with socket.socket(family, socktype, proto) as sock:
                # Connect to the already-resolved address rather than the
                # name, avoiding a second lookup that could disagree with
                # the address family chosen for this socket.
                sock.connect(sockaddr)
                local_ip = sock.getsockname()[0]
        except OSError:
            # This address family/route is unusable; try the next result.
            continue
        if local_ip:
            return local_ip
    return ""
443+
444+
def _resolve_tcp_local_host(self, local_host: Optional[str]) -> str:
    """Pick the host to use for a TCP endpoint's local resource key.

    Resolution order:
      1. ``local_host`` itself when the caller supplied a non-empty value.
      2. The address from ``self._local_address()`` if it is not loopback.
      3. The local IP routed towards the control host, when the control
         host is known and is not loopback.
      4. Whatever ``self._local_address()`` returned, else ``"0.0.0.0"``.
    """
    # An explicit, non-empty host from the caller always wins.
    if local_host:
        return str(local_host)

    discovered = self._local_address()
    if discovered and not self._is_loopback_host(discovered):
        return discovered

    # Only consult the route to the controller when it is a usable,
    # non-loopback host; otherwise there is nothing to route towards.
    ctrl_host = self._ctrl_host()
    ctrl_is_remote = bool(ctrl_host) and not self._is_loopback_host(ctrl_host)
    routed = self._local_ip_for_remote(ctrl_host) if ctrl_is_remote else ""

    return routed or discovered or "0.0.0.0"
456+
410457
def _normalize_link_type(self, link_type: Optional[str]) -> str:
411458
if not link_type:
412459
return "UNKNOWN"
@@ -960,7 +1007,17 @@ def _ensure_connection_from_meta(
9601007
) -> DirectedConnection:
9611008
"""Create local connection state from a peer's directed request."""
9621009
if str(meta.get("transport") or "rdma").lower() == "tcp":
963-
local_key: ResourceKey = TcpResourceKey()
1010+
with self._connections_lock:
1011+
peer_connections = [
1012+
conn
1013+
for conn in self._connections.values()
1014+
if conn.peer_alias == peer_alias and conn.transport == "tcp"
1015+
]
1016+
if len(peer_connections) == 1:
1017+
return peer_connections[0]
1018+
local_key: ResourceKey = TcpResourceKey(
1019+
host=self._resolve_tcp_local_host(None), port=0
1020+
)
9641021
peer_key: ResourceKey = TcpResourceKey()
9651022
return self._get_or_create_connection(
9661023
peer_alias,
@@ -1156,6 +1213,14 @@ def _mark_connection_connected(self, conn_id: str) -> None:
11561213
self._connected_peers.add(conn_id)
11571214
self._connected_peers_cond.notify_all()
11581215

1216+
def _mark_connection_failed(self, conn_id: str) -> None:
    """Flag the connection ``conn_id`` as failed and wake any waiters.

    Looks the connection up in ``self._connections`` and, if present, calls
    its ``mark_failed()``.  Afterwards every thread waiting on
    ``self._connected_peers_cond`` is notified, whether or not the
    connection was found.
    """
    # NOTE(review): the lookup is done under _connections_lock and
    # mark_failed() is called after releasing it — confirm this matches the
    # intended lock scope for connection state changes.
    with self._connections_lock:
        failed_conn = self._connections.get(conn_id)
    if failed_conn is not None:
        failed_conn.mark_failed()

    # Notify under the peers lock so waiters re-check the connection state.
    with self._connected_peers_lock:
        self._connected_peers_cond.notify_all()
1223+
11591224
def _has_notified_peer(self, conn_id: str) -> bool:
11601225
"""True iff we've already sent qp_ready for this connection."""
11611226
with self._notified_peers_lock:
@@ -1183,16 +1248,18 @@ def connect_to(
11831248
qp_num: Optional[int] = 1,
11841249
min_bw: Optional[str] = None,
11851250
transport: str = "rdma",
1186-
local_host: str = "0.0.0.0",
1251+
local_host: Optional[str] = None,
11871252
local_port: int = 0,
11881253
) -> PeerConnection:
11891254
"""Start connecting to a peer and return a connection handle.
11901255
11911256
``transport='rdma'`` (default) selects the RDMA path, picking a NIC
11921257
from the discovered local topology. ``transport='tcp'`` selects the
11931258
TCP path: ``local_host``/``local_port`` configure the local bind
1194-
(port 0 = OS-assigned), and ``ib_port``/``qp_num``/``local_device``/
1195-
``peer_device`` are ignored.
1259+
(port 0 = OS-assigned). If ``local_host`` is left as the default
1260+
``None``, PeerAgent publishes its discovered local host address instead
1261+
so remote peers do not receive ``0.0.0.0`` in endpoint_info.
1262+
``ib_port``/``qp_num``/``local_device``/``peer_device`` are ignored.
11961263
"""
11971264
if not isinstance(peer_alias, str) or not peer_alias:
11981265
raise TypeError("connect_to() requires a non-empty peer alias string")
@@ -1213,6 +1280,7 @@ def connect_to(
12131280
t0 = time.perf_counter()
12141281

12151282
if transport_norm == "tcp":
1283+
local_host = self._resolve_tcp_local_host(local_host)
12161284
local_key: ResourceKey = TcpResourceKey(
12171285
host=local_host, port=int(local_port)
12181286
)
@@ -1306,6 +1374,11 @@ def _wait_connected(self, conn_id: str, timeout_sec: float = 60.0) -> None:
13061374
f"+{(time.perf_counter() - t0) * 1000:.3f}ms"
13071375
)
13081376
return
1377+
with self._connections_lock:
1378+
conn = self._connections.get(conn_id)
1379+
state = conn.state if conn is not None else "unknown"
1380+
if state == "failed":
1381+
raise RuntimeError(f"Connection {conn_id!r} failed")
13091382
remaining = deadline - time.monotonic()
13101383
if remaining <= 0:
13111384
raise TimeoutError(

dlslime/dlslime/peer_agent/_mailbox.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -372,7 +372,28 @@ def _try_connect_peer_inner(
372372

373373
# D. Complete RDMA handshake
374374
t_d = time.perf_counter()
375-
endpoint.connect(peer_qp_info)
375+
try:
376+
endpoint.connect(peer_qp_info)
377+
except Exception as e:
378+
logger.warning(
379+
"StreamMailbox %s: endpoint.connect(%s) raised for "
380+
"endpoint_info=%s: %s",
381+
self._agent.alias,
382+
peer,
383+
peer_qp_info,
384+
e,
385+
)
386+
self._agent._mark_connection_failed(conn_id)
387+
return
388+
if hasattr(endpoint, "is_connected") and not endpoint.is_connected():
389+
logger.warning(
390+
"StreamMailbox %s: endpoint.connect(%s) failed for endpoint_info=%s",
391+
self._agent.alias,
392+
peer,
393+
peer_qp_info,
394+
)
395+
self._agent._mark_connection_failed(conn_id)
396+
return
376397
# Stash for one-sided ops on transports (TCP) where remote MR info
377398
# rides on the endpoint_info JSON instead of a separate Redis record.
378399
conn.peer_endpoint_info = peer_qp_info
dlslime/examples/python/p2p_tcp_rc_read_peer_agent_two_process.py

Lines changed: 193 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,193 @@
1+
#!/usr/bin/env python3
2+
"""TCP one-sided read through two PeerAgent processes.
3+
4+
The initiator reads bytes directly out of a memory region published by the
5+
target, then notifies the target so both processes can shut down cleanly.
6+
7+
Prerequisites:
8+
1. Start NanoCtrl: cd NanoCtrl && cargo run --release
9+
2. Redis must be reachable.
10+
3. dlslime built with BUILD_TCP=ON (default).
11+
12+
Usage:
13+
python p2p_tcp_rc_read_peer_agent_two_process.py --role target
14+
python p2p_tcp_rc_read_peer_agent_two_process.py --role initiator
15+
python p2p_tcp_rc_read_peer_agent_two_process.py --role target --ctrl_address http://host:4479
16+
python p2p_tcp_rc_read_peer_agent_two_process.py --role initiator --ctrl_address http://host:4479
17+
python p2p_tcp_rc_read_peer_agent_two_process.py --role target --local_host 10.0.0.2
18+
python p2p_tcp_rc_read_peer_agent_two_process.py --role initiator --local_host 10.0.0.3
19+
"""
20+
21+
import argparse
22+
import ctypes
23+
from typing import Optional
24+
25+
from dlslime import PeerAgent
26+
27+
28+
INITIATOR_ALIAS = "tcp_read_initiator"
29+
TARGET_ALIAS = "tcp_read_target"
30+
PAYLOAD = b"one-sided-read-via-peer-agent"
31+
DONE = b"done"
32+
33+
34+
def _buffer_from_bytes(data: bytes) -> ctypes.Array:
    """Return a ctypes char buffer of exactly ``len(data)`` bytes holding ``data``.

    The buffer has no extra terminating NUL byte: its size equals the
    payload length.
    """
    size = len(data)
    # Passing an explicit size keeps the buffer exactly as long as the
    # payload instead of the default len(data) + 1.
    return ctypes.create_string_buffer(data, size)
38+
39+
40+
def run_initiator(
    ctrl_url: str, local_host: Optional[str], local_port: int, connect_timeout: float
) -> None:
    """Run the initiator side: read the target's buffer, then signal completion.

    Args:
        ctrl_url: NanoCtrl URL handed to ``PeerAgent``.
        local_host: Local TCP host to pass to ``connect_to``; ``None`` lets
            PeerAgent choose.
        local_port: Local TCP port to pass to ``connect_to``.
        connect_timeout: Seconds to wait in ``conn.wait`` for the connection.

    Raises:
        RuntimeError: If the target agent is not listed, the endpoint is not
            connected, the peer published no ``buf_b`` region, or the read /
            completion send reports a non-zero status or wrong bytes.
    """
    buf_len = 64
    agent = PeerAgent(ctrl_url=ctrl_url, alias=INITIATOR_ALIAS)
    try:
        buf_a = ctypes.create_string_buffer(buf_len)
        addr_a = ctypes.addressof(buf_a)

        # Register the local MR before connect_to so endpoint_info() carries it
        # when the rendezvous fires. Required for one-sided ops on TCP.
        agent.register_memory_region("buf_a", addr_a, 0, buf_len)

        if TARGET_ALIAS not in agent.list_agents():
            raise RuntimeError(f"target agent {TARGET_ALIAS!r} is not running")

        print(f"[initiator] TCP local host: {local_host or 'default'}:{local_port}")
        conn = agent.connect_to(
            TARGET_ALIAS,
            transport="tcp",
            local_host=local_host,
            local_port=local_port,
        )
        conn.wait(timeout=connect_timeout)
        local_info = conn.endpoint.endpoint_info()
        print(
            f"[initiator] Connected over TCP: {INITIATOR_ALIAS} -> {TARGET_ALIAS} "
            f"(local={local_info.get('host')}:{local_info.get('port')})"
        )

        done_buf = _buffer_from_bytes(DONE)

        ep = conn.endpoint
        if not ep.is_connected():
            raise RuntimeError(
                "TCP endpoint is not connected; make sure both processes publish "
                "a peer-reachable --local_host"
            )

        # Explicit checks instead of ``assert`` so validation still runs
        # when Python is started with -O.
        peer_info = conn.peer_endpoint_info
        if peer_info is None:
            raise RuntimeError(
                f"{TARGET_ALIAS!r} is connected but no peer endpoint_info was recorded"
            )
        # Tolerate an explicit null "mr_info" as well as a missing key.
        remote_mr_info = (peer_info.get("mr_info") or {}).get("buf_b")
        if remote_mr_info is None:
            raise RuntimeError(
                f"{TARGET_ALIAS!r} is connected but did not publish buf_b; "
                "start the target process first"
            )

        h_local = ep.register_memory_region("buf_a_loc", addr_a, 0, buf_len)
        h_remote = ep.register_remote_memory_region("buf_b_rem", remote_mr_info)

        st = ep.read([(h_local, h_remote, 0, 0, len(PAYLOAD))]).wait()
        if st != 0:
            raise RuntimeError(f"read failed: {st}")
        received = bytes(buf_a[: len(PAYLOAD)])
        if received != PAYLOAD:
            raise RuntimeError("initiator did not read bytes")
        print("[initiator] target->initiator one-sided read = " f"{received!r} ok")

        st = agent.send(TARGET_ALIAS, (ctypes.addressof(done_buf), 0, len(DONE))).wait()
        if st != 0:
            raise RuntimeError(f"done send failed: {st}")
        print("[initiator] Sent completion notice.")
    finally:
        # Always release the agent, including on failure paths.
        agent.shutdown()
102+
103+
104+
def run_target(
    ctrl_url: str, local_host: Optional[str], local_port: int, connect_timeout: float
) -> None:
    """Run the target side: publish a pre-filled buffer and await completion.

    Args:
        ctrl_url: NanoCtrl URL handed to ``PeerAgent``.
        local_host: Local TCP host to pass to ``connect_to``; ``None`` lets
            PeerAgent choose.
        local_port: Local TCP port to pass to ``connect_to``.
        connect_timeout: Seconds to wait in ``conn.wait`` for the connection.

    Raises:
        RuntimeError: If the endpoint is not connected, or the completion
            receive reports a non-zero status or unexpected bytes.
    """
    buf_len = 64
    agent = PeerAgent(ctrl_url=ctrl_url, alias=TARGET_ALIAS)
    try:
        buf_b = ctypes.create_string_buffer(buf_len)
        addr_b = ctypes.addressof(buf_b)

        # Pre-fill the target buffer with the payload the initiator reads out.
        ctypes.memmove(addr_b, PAYLOAD, len(PAYLOAD))

        # Register MRs before connect_to so endpoint_info() carries them when
        # the rendezvous fires. Required for one-sided ops on TCP.
        agent.register_memory_region("buf_b", addr_b, 0, buf_len)

        print(f"[target] TCP local host: {local_host or 'default'}:{local_port}")
        conn = agent.connect_to(
            INITIATOR_ALIAS,
            transport="tcp",
            local_host=local_host,
            local_port=local_port,
        )
        conn.wait(timeout=connect_timeout)
        local_info = conn.endpoint.endpoint_info()
        print(
            f"[target] Connected over TCP: {TARGET_ALIAS} -> {INITIATOR_ALIAS} "
            f"(local={local_info.get('host')}:{local_info.get('port')})"
        )

        done_buf = ctypes.create_string_buffer(len(DONE))

        if not conn.endpoint.is_connected():
            raise RuntimeError(
                "TCP endpoint is not connected; make sure both processes publish "
                "a peer-reachable --local_host"
            )
        print("[target] Published buffer; waiting for initiator completion.")

        st = agent.recv(
            INITIATOR_ALIAS, (ctypes.addressof(done_buf), 0, len(DONE))
        ).wait()
        # Explicit checks instead of ``assert`` so validation still runs
        # when Python is started with -O.
        if st != 0:
            raise RuntimeError(f"done recv failed: {st}")
        if bytes(done_buf[: len(DONE)]) != DONE:
            raise RuntimeError("initiator did not send done")
        print("[target] Initiator completed read; shutting down.")
    finally:
        # Always release the agent, including on failure paths.
        agent.shutdown()
150+
151+
152+
if __name__ == "__main__":
    # Command-line entry point: parse the options, then run the chosen role.
    parser = argparse.ArgumentParser(
        description="TCP one-sided read through two PeerAgent processes"
    )
    parser.add_argument(
        "--role", choices=("initiator", "target"), required=True,
        help="Process role to run",
    )
    parser.add_argument(
        "--ctrl_address", "--ctrl", default="http://127.0.0.1:4479",
        help="NanoCtrl URL",
    )
    parser.add_argument(
        "--local_host", default=None,
        help="Local TCP host/IP to publish to the peer; defaults to discovered host",
    )
    parser.add_argument(
        "--local_port", type=int, default=0,
        help="Local TCP port to bind; 0 lets the OS choose",
    )
    parser.add_argument(
        "--connect_timeout", type=float, default=60.0,
        help="Seconds to wait for the peer-agent TCP connection",
    )
    args = parser.parse_args()

    # Both role functions share one signature, so select and call once.
    runner = run_initiator if args.role == "initiator" else run_target
    runner(args.ctrl_address, args.local_host, args.local_port, args.connect_timeout)

0 commit comments

Comments
 (0)