diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml old mode 100644 new mode 100755 index c85ba504..cc8c9dab --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -55,7 +55,8 @@ jobs: cmake \ ninja-build \ libnuma-dev \ - libibverbs-dev + libibverbs-dev \ + libasio-dev - name: Install build dependencies run: | @@ -66,14 +67,17 @@ jobs: env: CMAKE_ARGS: >- -DBUILD_RDMA=ON + -DBUILD_TCP=ON -DBUILD_PYTHON=ON -DBUILD_NVLINK=OFF -DBUILD_TORCH_PLUGIN=OFF -DBUILD_ASCEND_DIRECT=OFF -DBUILD_TEST=OFF + -DUSE_CUDA=OFF # The native wheel lives under dlslime/ (dlslime-ctrl/ is a separate Rust crate). run: python -m build --wheel --outdir dist dlslime + - name: Install wheel smoke test run: | python -m pip install dist/dlslime-*.whl --no-deps @@ -128,6 +132,8 @@ jobs: docker exec "${container_name}" bash -lc ' set -euxo pipefail + apt install -y libasio-dev + cd /workspace export PIP_CONFIG_FILE=/dev/null export PIP_INDEX_URL=https://mirrors.aliyun.com/pypi/simple/ @@ -138,6 +144,7 @@ jobs: export HTTPS_PROXY=http://127.0.0.1:7897 export SLIME_VISIBLE_DEVICES=mlx5_0,mlx5_1,mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7 unset PIP_EXTRA_INDEX_URL + export DLSLIME_TCP_TEST_CUDA=ON python -m pip install -U pip pytest python -m pip show dlslime || true python -m pip uninstall -y dlslime || true diff --git a/dlslime/CMakeLists.txt b/dlslime/CMakeLists.txt index 6d987c9e..78de0cbd 100644 --- a/dlslime/CMakeLists.txt +++ b/dlslime/CMakeLists.txt @@ -13,6 +13,7 @@ slime_option(USE_MACA "USE in MACA Platform" OFF) slime_option(BUILD_NVLINK "Build NVLINK" OFF) slime_option(BUILD_ASCEND_DIRECT "Build Ascend direct transport" OFF) +slime_option(BUILD_TCP "Build TCP transport" ON) # Slime options for custom python wrapper slime_option(BUILD_PYTHON "Build python wrapper" OFF) diff --git a/dlslime/bench/python/tcp_bench_spmd.py b/dlslime/bench/python/tcp_bench_spmd.py new file mode 100755 index 00000000..63c695ff --- /dev/null +++ b/dlslime/bench/python/tcp_bench_spmd.py @@ 
-0,0 +1,358 @@ +"""# TCP Remote Write Benchmark + +## Node 0 +torchrun --master-addr 10.130.8.145 --master-port 6006 \ + --nnodes 2 --nproc-per-node 1 --node-rank 0 bench/python/tcp_bench_spmd.py \ + --transfer-engine dlslime --batch-size 94 --num-iteration 10 --num-concurrency 8 + +## Node 1 +torchrun --master-addr 10.130.8.145 --master-port 6006 \ + --nnodes 2 --nproc-per-node 1 --node-rank 1 bench/python/tcp_bench_spmd.py \ + --transfer-engine dlslime --batch-size 94 --num-iteration 10 --num-concurrency 8 +""" + +import argparse +import csv +import os +import socket + +import torch +import torch.distributed as dist +from tabulate import tabulate +from torch.distributed import distributed_c10d + +parser = argparse.ArgumentParser() +parser.add_argument("--batch-size", type=int, default=1) +parser.add_argument("--size", nargs="+", type=int, default=[n for n in range(8, 20)]) +parser.add_argument("--num-concurrency", type=int, default=16) +parser.add_argument("--num-iteration", type=int, default=100) +parser.add_argument("--num-warmup-iteration", type=int, default=10) +parser.add_argument("--opcode", type=str, choices=["read", "write"], default="write") +parser.add_argument( + "--save-csv", action="store_true", help="Save benchmark results to CSV file" +) +parser.add_argument( + "--csv-filename", type=str, default="./output.csv", help="Filename for CSV output" +) +parser.add_argument( + "--transfer-engine", + choices=["dlslime", "mooncake"], + type=str, + default="dlslime", +) + +args = parser.parse_args() + + +def set_env_when_no_default(env_name, value): + env_value = os.environ.get(env_name, value) or value + os.environ[env_name] = env_value + return env_value + + +def get_local_ip(): + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + local_ip = s.getsockname()[0] + s.close() + return local_ip + + +local_ip = get_local_ip() + +# Get SPMD Info +rank = int(os.environ["RANK"]) +local_rank = int(os.environ["LOCAL_RANK"]) +world_size = 
int(os.environ["WORLD_SIZE"]) +local_world_size = nnodes = int(os.environ["LOCAL_WORLD_SIZE"]) +master_addr = os.environ["MASTER_ADDR"] +master_port = os.environ["MASTER_PORT"] +npros_per_rank = local_world_size +assert world_size % 2 == 0 +num_channels = world_size // 2 +# target rank for initiator rank +peer_rank = (rank + num_channels) % world_size + +if rank < num_channels: + role = "initiator" +else: + role = "target" + + +def rank_0_print(*args): + if rank == 0: + print(*args) + + +rank_0_print( + f"rank [{0}, {world_size // 2}) for initiator, " + f"rank [{world_size // 2}, {world_size}) for target." +) +rank_0_print( + f"{rank=}, {peer_rank=}, {world_size=}, {npros_per_rank=}, {master_addr=}, {master_port=}" +) +rank_0_print(f"Local_ip: {local_ip}") +rank_0_print(f"mode: {args.opcode}") +rank_0_print(f"batch size: {args.batch_size}") +rank_0_print(f"num concurrency: {args.num_concurrency}") + +rank_0_print(f"benchmarking transfer engine: {args.transfer_engine}") +# import Python Package +if args.transfer_engine == "dlslime": + import dlslime + from dlslime import TcpEndpoint + +elif args.transfer_engine == "mooncake": + from mooncake.engine import TransferEngine as MooncakeTransferEngine + +dist.init_process_group("cpu:gloo,cuda:nccl") +transfer_group = dist.new_group(list(range(world_size)), backend="cuda:nccl") + +# TODO: for AFD Benchmark +initiator_group = dist.new_group(list(range(num_channels)), backend="cpu:gloo") +target_group = dist.new_group(list(range(num_channels, world_size)), backend="cpu:gloo") + +# Setting Info +if args.transfer_engine == "mooncake": + mooncake_endpoint_info = {"local_ip": local_ip, "kv_table": {}, "endpoint": []} + +if args.opcode != "write": + raise ValueError("Immediate data can only be used with write operations.") + +if args.transfer_engine == "dlslime": + tcp_endpoint = TcpEndpoint(f"{local_ip}", 22500 + local_rank) +elif args.transfer_engine == "mooncake": + engine = MooncakeTransferEngine() + result = 
engine.initialize( + f"{local_ip}:{22500+local_rank}", "P2PHANDSHAKE", "tcp", None + ) + mooncake_endpoint_info = { + "local_ip": local_ip, + "kv_table": {}, + "endpoint": engine.get_rpc_port(), + } + tcp_endpoint = engine + +torch.cuda.set_device(local_rank) + +max_numel = 2 << max(args.size) +max_ttensor = torch.ones([max_numel], device=f"cuda:{local_rank}") +ttensors = [max_ttensor[: 2 << rawsize] for rawsize in args.size] +max_mr_key = 0 +dlslime_mr_name = "bench_mr" # named MR for the DLSlime handle model +dlslime_local_handle = None +dlslime_remote_handle = None +print(local_rank) +torch.cuda.synchronize() + +if args.transfer_engine == "dlslime": + dlslime_local_handle = tcp_endpoint.register_memory_region( + dlslime_mr_name, + max_ttensor.data_ptr(), + int(max_ttensor.storage_offset()), + max_ttensor.numel() * max_ttensor.itemsize, + ) +elif args.transfer_engine == "mooncake": + result = tcp_endpoint.register_memory( + max_ttensor.data_ptr() + max_ttensor.storage_offset(), + max_ttensor.numel() * max_ttensor.itemsize, + ) + if result != 0: + raise RuntimeError(f"Failed to register memory region: {result}") + mooncake_endpoint_info["kv_table"][max_mr_key] = ( + max_ttensor.data_ptr() + max_ttensor.storage_offset(), + max_ttensor.numel() * max_ttensor.itemsize, + ) + + +if rank == 0: + print("exchanging endpoint info... ") +all_endpoint_info = [{} for _ in range(world_size)] +if args.transfer_engine == "dlslime": + dist.all_gather_object(all_endpoint_info, tcp_endpoint.endpoint_info()) +elif args.transfer_engine == "mooncake": + dist.all_gather_object(all_endpoint_info, mooncake_endpoint_info) + +if rank == 0: + print("endpoint exchanged") + +if args.transfer_engine == "dlslime": + # endpoint connect + tcp_endpoint.connect(all_endpoint_info[(rank + num_channels) % world_size]) + # Resolve the peer's published MR to a local remote-handle. 
Both sides + # register so either can initiate read/write; only the initiator role + # actually uses the handle in this benchmark. + peer_mr_info = all_endpoint_info[peer_rank]["mr_info"][dlslime_mr_name] + dlslime_remote_handle = tcp_endpoint.register_remote_memory_region( + dlslime_mr_name, peer_mr_info + ) +elif args.transfer_engine == "mooncake": + # construct connect lazily + pass + +start_event = torch.cuda.Event(enable_timing=True) +end_event = torch.cuda.Event(enable_timing=True) + + +def transfer_batch_concurrency_dlslime( + role, opcode, local_handle, remote_handle, tensor, batch_size, num_concurrency +): + fn = tcp_endpoint.async_read if opcode == "read" else tcp_endpoint.async_write + if role == "initiator": + slots = [] + for concurrent_id in range(num_concurrency): + assign = [ + fn( + [ + ( + local_handle, + remote_handle, + 0, + 0, + tensor.numel() * tensor.itemsize, + ) + for _ in range(batch_size) + ], + ) + ] + slots.extend(assign) + + for slot in slots: + slot.wait() + + +def transfer_batch_concurrency_mooncake( + role, opcode, mr_key, tensor, batch_size, num_concurrency +): + print(f"tcp_endpoint:{tcp_endpoint}") + # assert opcode == 'read' + if role == "initiator": + all_batch_ids_to_wait = [] + for concurrent_id in range(num_concurrency): + batch_id = tcp_endpoint.batch_transfer_async_write( + f"{all_endpoint_info[peer_rank]['local_ip']}:{all_endpoint_info[peer_rank]['endpoint']}", + [ + all_endpoint_info[rank]["kv_table"][mr_key][0] + for _ in range(batch_size) + ], + [ + all_endpoint_info[peer_rank]["kv_table"][mr_key][0] + for _ in range(batch_size) + ], + [tensor.numel() * tensor.itemsize for _ in range(batch_size)], + ) + if batch_id == 0: + print("error for transport") + all_batch_ids_to_wait.append(batch_id) + result = tcp_endpoint.get_batch_transfer_status(all_batch_ids_to_wait) + if result != 0: + print(f"transport failure, batch IDs: {all_batch_ids_to_wait}") + + +n_runs = args.num_concurrency +benchmark_data = [] +for idx, 
(rawsize, ttensor) in enumerate(zip(args.size, ttensors)): + rank_0_print(f"benchmark s={ttensor.numel() * ttensor.itemsize / 1024}K") + size = 2 << rawsize + total_time = 0.0 + + def _run_one_iteration(): + if args.transfer_engine == "dlslime": + transfer_batch_concurrency_dlslime( + role, + args.opcode, + dlslime_local_handle, + dlslime_remote_handle, + ttensor, + args.batch_size, + args.num_concurrency, + ) + elif args.transfer_engine == "mooncake": + transfer_batch_concurrency_mooncake( + role, + args.opcode, + max_mr_key, + ttensor, + args.batch_size, + args.num_concurrency, + ) + torch.cuda.synchronize() + + # Warmup (excluded from timing). Barrier after so all ranks start the + # timed region together and the first measured iteration isn't paying + # cold-start cost. + for _ in range(args.num_warmup_iteration): + _run_one_iteration() + torch.cuda.synchronize() + dist.barrier() + + start_event.record() + for iter_id in range(args.num_iteration): + _run_one_iteration() + end_event.record() + torch.cuda.synchronize() + dist.barrier() + elapsed_time = start_event.elapsed_time(end_event) + total_time += elapsed_time + + if rank < num_channels: + size_bytes = ttensor.numel() * ttensor.itemsize + total_transport = ( + n_runs * size * ttensor.itemsize * args.num_iteration * args.batch_size + ) + avg_latency = total_time / args.num_iteration / n_runs + + bandwidth = torch.tensor(total_transport / total_time / 1e3) + dist.all_reduce(bandwidth, group=initiator_group) + bandwidth = int(bandwidth) + + benchmark_data.append( + [ + args.transfer_engine, + num_channels, + f"{size_bytes:,}", # noqa: E231 + f"{args.batch_size}", # noqa: E231 + f"{args.num_concurrency}", # noqa: E231 + f"{total_transport:,}", # noqa: E231 + f"{avg_latency:.3f}", # noqa: E231 + f"{bandwidth:.3f}", # noqa: E231 + ] + ) + + rank_0_print( + [ + args.transfer_engine, + num_channels, + f"{size_bytes:,}", # noqa: E231 + f"{args.batch_size}", # noqa: E231 + f"{args.num_concurrency}", # noqa: E231 + 
f"{total_transport:,}", # noqa: E231 + f"{avg_latency:.3f}", # noqa: E231 + f"{bandwidth:.3f}", # noqa: E231 + ] + ) + +dist.barrier() + +if rank == 0: + headers = [ + "Transfer Engine", + "#Channels", + "Message Size (bytes)", + "Batch Size", + "Num Concurrency", + "Total Transport (bytes)", + "Avg Latency(ms)", + "Bandwidth(MB/s)", + ] + print("\nBenchmark Results:") + print(tabulate(benchmark_data, headers=headers, tablefmt="github")) + if args.save_csv: + with open(args.csv_filename, "w", newline="") as f: + writer = csv.writer(f) + if f.tell() == 0: + writer.writerow(headers) + writer.writerows(benchmark_data) + print(f"CSV saved to {args.csv_filename}") + +dist.destroy_process_group() diff --git a/dlslime/dlslime/csrc/CMakeLists.txt b/dlslime/dlslime/csrc/CMakeLists.txt index 0947f3c1..045f74ad 100644 --- a/dlslime/dlslime/csrc/CMakeLists.txt +++ b/dlslime/dlslime/csrc/CMakeLists.txt @@ -27,6 +27,10 @@ if(BUILD_RDMA) target_link_libraries(dlslime INTERFACE _slime_rdma) endif() +if(BUILD_TCP) + target_link_libraries(dlslime INTERFACE _slime_tcp) +endif() + # rpc/ has no independent C++ consumers (the session API is all # pybind11). It is compiled straight into _slime_c.so by the python # subdirectory below. Keep the source files here as a marker so future diff --git a/dlslime/dlslime/csrc/engine/CMakeLists.txt b/dlslime/dlslime/csrc/engine/CMakeLists.txt index c03c88c7..c9bffdf5 100755 --- a/dlslime/dlslime/csrc/engine/CMakeLists.txt +++ b/dlslime/dlslime/csrc/engine/CMakeLists.txt @@ -34,3 +34,7 @@ endif() if (BUILD_RDMA) add_subdirectory(rdma) endif() + +if (BUILD_TCP) + add_subdirectory(tcp) +endif() diff --git a/dlslime/dlslime/csrc/engine/tcp/CMakeLists.txt b/dlslime/dlslime/csrc/engine/tcp/CMakeLists.txt new file mode 100644 index 00000000..a65f16ce --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/CMakeLists.txt @@ -0,0 +1,50 @@ +# asio is header-only. Try find_package, fall back to manual detection. 
+find_package(asio QUIET) +if(NOT asio_FOUND) + if(EXISTS /usr/include/asio.hpp) + add_library(asio::asio INTERFACE IMPORTED) + target_include_directories(asio::asio INTERFACE /usr/include) + elseif(EXISTS /usr/include/boost/asio.hpp) + add_library(asio::asio INTERFACE IMPORTED) + target_include_directories(asio::asio INTERFACE /usr/include/boost) + target_compile_definitions(asio::asio INTERFACE ASIO_STANDALONE) + else() + message(FATAL_ERROR "asio not found. Install libasio-dev or boost.") + endif() +endif() + +add_library(_slime_tcp SHARED + tcp_memory_pool.cpp + tcp_connection_pool.cpp + tcp_context.cpp + tcp_session.cpp + tcp_endpoint.cpp +) + +target_compile_definitions(_slime_tcp PRIVATE ASIO_STANDALONE) + +# Workaround: asio awaitable.hpp uses std::exchange without including <utility> +target_compile_options(_slime_tcp PRIVATE -include utility) + +if (USE_CUDA) + find_package(CUDAToolkit REQUIRED) + target_compile_definitions(_slime_tcp PRIVATE USE_CUDA) + target_include_directories(_slime_tcp PRIVATE ${CUDAToolkit_INCLUDE_DIRS}) + target_link_libraries(_slime_tcp PUBLIC CUDA::cudart) +endif() + +target_link_libraries(_slime_tcp PUBLIC + asio::asio + _slime_device + _slime_engine +) + +set_target_properties(_slime_tcp PROPERTIES + BUILD_WITH_INSTALL_RPATH TRUE + INSTALL_RPATH "$ORIGIN" # literal $ORIGIN token for the dynamic loader; "${ORIGIN}" would be expanded by CMake as a variable +) + +install(TARGETS _slime_tcp + EXPORT dlslimeTargets + LIBRARY DESTINATION ${DLSLIME_INSTALL_PATH} +) diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp b/dlslime/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp new file mode 100644 index 00000000..9aec5a8e --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_connection_pool.cpp @@ -0,0 +1,148 @@ +#include "tcp_connection_pool.h" + +#include + +#include "dlslime/csrc/logging.h" + +namespace dlslime { +namespace tcp { + +using tcp = asio::ip::tcp; + +std::shared_ptr TcpConnectionPool::getConnection(const std::string& host, uint16_t port) +{ + ConnKey key{host, port}; + + { + std::lock_guard lk(mu_); + auto it = 
pool_.find(key); + if (it != pool_.end()) { + for (auto& c : it->second) { + if (!c->in_use && c->socket.is_open()) { + c->in_use = true; + c->last_used = std::chrono::steady_clock::now(); + return c; + } + } + } + } + + tcp::resolver resolver(io_ctx_); + auto endpoints = resolver.resolve(host, std::to_string(port)); + tcp::socket sock(io_ctx_); + asio::error_code ec; + asio::connect(sock, endpoints, ec); + if (ec) { + SLIME_LOG_WARN("TcpConnectionPool: connect to ", host, ":", port, " failed: ", ec.message()); + return nullptr; + } + sock.set_option(tcp::no_delay(true)); + + auto conn = std::make_shared(std::move(sock), host, port); + { + std::lock_guard lk(mu_); + // Remove idle connection + cleanupIdleConnections(false); + + auto& q = pool_[key]; + for (auto q_i = q.begin(); q_i != q.end();) { + auto& c = *q_i; + if (!c->in_use) { + if (c->socket.is_open()) { + c->in_use = true; + c->last_used = std::chrono::steady_clock::now(); + asio::error_code ign; + conn->socket.close(ign); + return c; + } + else { + // Remove dead connection + q_i = q.erase(q_i); + continue; + } + } + q_i++; + } + q.push_back(conn); + } + return conn; +} + +void TcpConnectionPool::returnConnection(std::shared_ptr conn) +{ + if (!conn) + return; + ConnKey key{conn->host, conn->port}; + + std::lock_guard lk(mu_); + auto it = pool_.find(key); + if (it != pool_.end()) { + auto& q = it->second; + for (auto qi = q.begin(); qi != q.end(); ++qi) + if (*qi == conn) { + if (conn->socket.is_open()) { + conn->in_use = false; + conn->last_used = std::chrono::steady_clock::now(); + } + else { + q.erase(qi); + } + break; + } + if (q.empty()) + pool_.erase(it); + return; + } + + // Connection not found in pool (temporary), close it. 
+ if (conn->socket.is_open()) { + asio::error_code ec; + conn->socket.close(ec); + if (ec) + SLIME_LOG_WARN( + "TcpConnectionPool: close temp conn ", conn->host, ":", conn->port, " failed: ", ec.message()); + } +} + +void TcpConnectionPool::cleanupIdleConnections(bool lock) +{ + auto now = std::chrono::steady_clock::now(); + std::unique_lock<std::mutex> lk(mu_, std::defer_lock); // lives until the function returns, so the whole sweep is covered + if (lock) lk.lock(); // lock == false: the caller already holds mu_ + for (auto it = pool_.begin(); it != pool_.end();) { + auto& q = it->second; + while (!q.empty()) { + auto& c = q.back(); + if (!c->in_use) { + auto idle = std::chrono::duration_cast(now - c->last_used).count(); + if (idle > kIdleTimeout.count()) { + asio::error_code ign; + c->socket.close(ign); + q.pop_back(); + continue; + } + } + break; + } + if (q.empty()) + it = pool_.erase(it); + else + ++it; + } +} + +void TcpConnectionPool::clear() +{ + std::lock_guard lk(mu_); + for (auto& [_, q] : pool_) + // force close + for (auto& c : q) { + c->in_use = false; + asio::error_code ign; + c->socket.close(ign); + } + pool_.clear(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_connection_pool.h b/dlslime/dlslime/csrc/engine/tcp/tcp_connection_pool.h new file mode 100644 index 00000000..878314f1 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_connection_pool.h @@ -0,0 +1,65 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace dlslime { +namespace tcp { + +struct PooledConnection { + asio::ip::tcp::socket socket; + std::string host; + uint16_t port{0}; + std::chrono::steady_clock::time_point last_used; + bool in_use{true}; + + PooledConnection(asio::ip::tcp::socket s, std::string h, uint16_t p): + socket(std::move(s)), host(std::move(h)), port(p), last_used(std::chrono::steady_clock::now()) + { + } +}; + +// Keyed by (host, port). Thread-safe. 
+// States: IDLE (in deque, in_use=false) / ACTIVE (checked out) / RESERVED +class TcpConnectionPool { +public: + static constexpr std::chrono::seconds kIdleTimeout{300}; + + explicit TcpConnectionPool(asio::io_context& io_ctx): io_ctx_(io_ctx) {} + + std::shared_ptr getConnection(const std::string& host, uint16_t port); + + void returnConnection(std::shared_ptr conn); + + void cleanupIdleConnections(bool lock = true); + void clear(); + +private: + struct ConnKey { + std::string host; + uint16_t port; + bool operator==(const ConnKey& o) const + { + return host == o.host && port == o.port; + } + }; + struct ConnKeyHash { + size_t operator()(const ConnKey& k) const + { + return std::hash{}(k.host) ^ (std::hash{}(k.port) << 1); + } + }; + + asio::io_context& io_ctx_; + std::mutex mu_; + std::unordered_map>, ConnKeyHash> pool_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_context.cpp b/dlslime/dlslime/csrc/engine/tcp/tcp_context.cpp new file mode 100644 index 00000000..f1d61e3d --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_context.cpp @@ -0,0 +1,30 @@ +#include "tcp_context.h" + +namespace dlslime { +namespace tcp { + +TcpContext::TcpContext() +{ + // Keep io_context alive even when there's no work posted yet. 
+ auto work = asio::make_work_guard(io_ctx_); + io_thread_ = std::thread([this, w = std::move(work)]() { io_ctx_.run(); }); +} + +TcpContext::~TcpContext() +{ + shutdown(); +} + +void TcpContext::shutdown() +{ + if (!running_) + return; + running_ = false; + io_ctx_.stop(); + if (io_thread_.joinable()) + io_thread_.join(); + conn_pool_.clear(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_context.h b/dlslime/dlslime/csrc/engine/tcp/tcp_context.h new file mode 100644 index 00000000..3fd44308 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_context.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include +#include + +#include "tcp_connection_pool.h" + +namespace dlslime { +namespace tcp { + +// Process-level shared resource holder. +// Multiple TcpEndpoints can share one TcpContext to run on a single +// io_context thread, reducing thread count. +// +// For sync wrappers: sync_send() = async_send() + future.wait() +// — the io_context drives the I/O, the caller thread just blocks. 
+class TcpContext { +public: + TcpContext(); + ~TcpContext(); + + TcpContext(const TcpContext&) = delete; + TcpContext& operator=(const TcpContext&) = delete; + + asio::io_context& io_context() + { + return io_ctx_; + } + TcpConnectionPool& conn_pool() + { + return conn_pool_; + } + + void shutdown(); + +private: + asio::io_context io_ctx_; + std::thread io_thread_; + TcpConnectionPool conn_pool_{io_ctx_}; + bool running_{true}; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_endpoint.cpp b/dlslime/dlslime/csrc/engine/tcp/tcp_endpoint.cpp new file mode 100644 index 00000000..b91136f5 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_endpoint.cpp @@ -0,0 +1,447 @@ +#include "tcp_endpoint.h" + +#include +#include + +#include +#include + +#include "dlslime/csrc/logging.h" + +#ifdef USE_CUDA +#include +#endif + +namespace dlslime { +namespace tcp { + +using tcp = asio::ip::tcp; + +// ── helpers ───────────────────────────────────────────── + +static void hdr_hton(SessionHeader& h) +{ + h.size = htole64(h.size); + h.addr = htole64(h.addr); +} + +#ifdef USE_CUDA +static bool is_cuda_memory(const void* addr) +{ + cudaPointerAttributes attr; + auto st = cudaPointerGetAttributes(&attr, addr); + return (st == cudaSuccess && attr.type == cudaMemoryTypeDevice); +} +#endif + +// ── RecvMatcher factory ──────────────────────────────── + +ServerSession::RecvMatcher TcpEndpoint::make_recv_matcher() +{ + std::weak_ptr weak = shared_from_this(); + return [weak]() -> RecvSlot { + auto self = weak.lock(); + if (!self) + return {}; + std::lock_guard lk(self->recv_mu_); + if (self->pending_recvs_.empty()) + return {}; + auto pr = std::move(self->pending_recvs_.front()); + self->pending_recvs_.pop_front(); + + RecvSlot slot{pr.op_state->user_buffer, pr.op_state->user_length, pr.op_state, {}, pr.exact_size}; +#ifdef USE_CUDA + if (pr.cuda_dst) { + slot.buffer = reinterpret_cast(pr.staging_buf.get()); + slot.post_read = [buf = 
std::shared_ptr(std::move(pr.staging_buf)), + dst = pr.cuda_dst, + len = pr.op_state->user_length]() { + auto cu_err = cudaMemcpy(reinterpret_cast(dst), buf.get(), len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) + SLIME_LOG_ERROR("cudaMemcpy H2D (recv): ", cudaGetErrorString(cu_err)); + }; + } +#endif + return slot; + }; +} + +// ── Constructor ──────────────────────────────────────── + +TcpEndpoint::TcpEndpoint(const std::string& ip, uint16_t port): + own_ctx_(std::make_unique()), + acceptor_(own_ctx_->io_context()), + local_pool_(std::make_shared()), + remote_pool_(std::make_shared()), + local_host_(ip) +{ + ctx_ = own_ctx_.get(); + local_port_ = port; + start_io(); +} + +TcpEndpoint::~TcpEndpoint() +{ + shutdown(); +} + +void TcpEndpoint::start_io() +{ + asio::error_code ec; + auto addr = asio::ip::make_address(local_host_); + auto ep = tcp::endpoint(addr, local_port_); + acceptor_.open(ep.protocol()); + acceptor_.set_option(tcp::acceptor::reuse_address(true)); + acceptor_.bind(ep, ec); + if (ec) { + SLIME_LOG_ERROR("acceptor_.bind failed ", local_host_, ":", local_port_, " ERROR:", ec.message()); + } + acceptor_.listen(64); + if (local_port_ == 0) { + asio::error_code ec; + local_port_ = acceptor_.local_endpoint(ec).port(); + } + + do_accept(); +} + +// ── do_accept ─────────────────────────────────────────── + +void TcpEndpoint::do_accept() +{ + if (!running_.load(std::memory_order_acquire)) + return; + acceptor_.async_accept([this](asio::error_code ec, tcp::socket sock) { + if (ec) { + if (ec != asio::error::operation_aborted) + SLIME_LOG_WARN("TcpEndpoint accept: ", ec.message()); + return; + } + sock.set_option(tcp::no_delay(true)); + auto session = std::make_shared(std::move(sock), local_pool_.get(), make_recv_matcher()); + session->start(); + do_accept(); + }); +} + +// ── endpoint_info / connect ───────────────────────────── + +json TcpEndpoint::endpoint_info() const +{ + return {{"host", local_host_}, {"port", local_port_}, {"mr_info", 
local_pool_->mr_info()}}; +} + +json TcpEndpoint::mr_info() const +{ + return local_pool_->mr_info(); +} + +void TcpEndpoint::connect(const json& remote_endpoint_info) +{ + auto host = remote_endpoint_info.value("host", ""); + auto port = static_cast(remote_endpoint_info.value("port", 0)); + + // Verify reachability before accepting the peer identity. + auto conn = ctx_->conn_pool().getConnection(host, port); + if (!conn) { + SLIME_LOG_WARN("TcpEndpoint::connect: cannot reach ", host, ":", port); + return; + } + + peer_host_ = host; + peer_port_ = port; + if (remote_endpoint_info.contains("mr_info")) { + for (const auto& [name, info] : remote_endpoint_info["mr_info"].items()) + remote_pool_->register_remote_memory_region(info, name); + } + connected_.store(true, std::memory_order_release); + ctx_->conn_pool().returnConnection(std::move(conn)); +} + +// ── memory registration ───────────────────────────────── + +int32_t TcpEndpoint::register_memory_region(const std::string& name, uintptr_t ptr, uintptr_t offset, size_t length) +{ + return local_pool_->register_memory_region(ptr + offset, length, name); +} + +int32_t TcpEndpoint::register_remote_memory_region(const std::string& name, const json& mr_info) +{ + return remote_pool_->register_remote_memory_region(mr_info, name); +} + +// ── async_send ────────────────────────────────────────── +// chunk_tuple_t = (src_ptr, offset, length) — raw pointers, no MR lookup. 
+ +std::shared_ptr TcpEndpoint::async_send(const chunk_tuple_t& chunk, int64_t /*timeout_ms*/) +{ + uintptr_t src = std::get<0>(chunk) + std::get<1>(chunk); + size_t len = std::get<2>(chunk); + + auto conn = ctx_->conn_pool().getConnection(peer_host_, peer_port_); + auto op = TcpOpState::create(); + op->signal->reset_all(); + + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + return std::make_shared(op); + } + + SessionHeader hdr{len, 0, OP_SEND}; + auto& pool = ctx_->conn_pool(); + + auto* send_ptr = reinterpret_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(send_ptr)) { + auto* buf = new char[len]; + auto cu_err = cudaMemcpy(buf, send_ptr, len, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_send cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + pool.returnConnection(conn); + return std::make_shared(op); + } + send_ptr = buf; + is_cuda = true; + } +#endif + + auto session = std::make_shared( + std::move(conn->socket), [op, conn, &pool, send_ptr, is_cuda](asio::error_code ec) { + if (ec) + SLIME_LOG_WARN("async_send: ", ec.message()); + op->completion_status.store(ec ? TCP_FAILED : TCP_SUCCESS, std::memory_order_release); + if (op->signal) + op->signal->set_comm_done(0); + pool.returnConnection(conn); +#ifdef USE_CUDA + if (is_cuda) + delete[] send_ptr; +#endif + }); + session->start_write(hdr, send_ptr); + + return std::make_shared(op); +} + +// ── async_recv ────────────────────────────────────────── +// chunk_tuple_t = (dst_ptr, offset, length) — raw pointers, no MR lookup. 
+ +std::shared_ptr TcpEndpoint::async_recv(const chunk_tuple_t& chunk, bool exact_size) +{ + auto op = TcpOpState::create(); + op->signal->reset_all(); + uintptr_t dst = std::get<0>(chunk) + std::get<1>(chunk); + size_t length = std::get<2>(chunk); + op->user_buffer = dst; + op->user_length = length; + + PendingRecv pr{op, nullptr, 0, exact_size}; +#ifdef USE_CUDA + if (is_cuda_memory(reinterpret_cast(dst))) { + auto* buf = new char[length]; + pr.staging_buf.reset(buf); + pr.cuda_dst = dst; + op->user_buffer = reinterpret_cast(buf); + } +#endif + + { + std::lock_guard lk(recv_mu_); + pending_recvs_.push_back(std::move(pr)); + } + + return std::make_shared(op); +} + +// ── async_read ────────────────────────────────────────── +// Each assign creates an independent ClientSession; all share one OpState. +// Future.wait() blocks until every session has signalled its bit. + +std::shared_ptr TcpEndpoint::async_read(const std::vector& assign, + int64_t /*timeout_ms*/) +{ + if (assign.empty()) + throw std::runtime_error("TcpEndpoint::async_read: empty assignment"); + + size_t N = assign.size(); + auto op = TcpOpState::create(); + op->signal->reset_all(); + op->expected_mask = (N < 32) ? 
(1u << N) - 1 : 0xFFFFFFFFu; + op->completion_status.store(TCP_SUCCESS, std::memory_order_release); + op->completion_mask.store(0, std::memory_order_release); + + auto& pool = ctx_->conn_pool(); + + for (size_t i = 0; i < N; i++) { + const auto& a = assign[i]; + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); + uint64_t remote_off = std::get<2>(a); + uint64_t local_off = std::get<3>(a); + size_t length = std::get<4>(a); + + auto local = local_pool_->get_mr_fast(local_h); + auto remote = remote_pool_->get_remote_mr_fast(remote_h); + if (local.length == 0 || remote.length == 0) + throw std::runtime_error("TcpEndpoint::async_read: invalid MR handle"); + + uintptr_t local_dst = local.addr + local_off; + SessionHeader hdr{length, remote.addr + remote_off, OP_READ}; + + auto conn = pool.getConnection(peer_host_, peer_port_); + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->set_comm_done(i); + continue; + } + + auto* read_dst = reinterpret_cast(local_dst); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(read_dst)) { + read_dst = new char[length]; + is_cuda = true; + } +#endif + + auto session = std::make_shared( + std::move(conn->socket), + [op, conn, i, &pool, read_dst, is_cuda, real_dst = local_dst, len = length](asio::error_code ec) { + if (ec) { + SLIME_LOG_WARN("async_read session ", i, ": ", ec.message()); + op->completion_status.store(TCP_FAILED, std::memory_order_release); + } +#ifdef USE_CUDA + if (!ec && is_cuda) { + auto cu_err = cudaMemcpy(reinterpret_cast(real_dst), read_dst, len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_read cudaMemcpy H2D: ", cudaGetErrorString(cu_err)); + op->completion_status.store(TCP_FAILED, std::memory_order_release); + } + } + if (is_cuda) + delete[] read_dst; +#endif + if (op->signal) + op->signal->set_comm_done(i); + pool.returnConnection(conn); + }); + session->start_read(hdr, 
read_dst); + } + + return std::make_shared(op); +} + +// ── async_write ───────────────────────────────────────── +// Each assign creates an independent ClientSession; all share one OpState. + +std::shared_ptr TcpEndpoint::async_write(const std::vector& assign, + int64_t /*timeout_ms*/) +{ + if (assign.empty()) + throw std::runtime_error("TcpEndpoint::async_write: empty assignment"); + + size_t N = assign.size(); + auto op = TcpOpState::create(); + op->signal->reset_all(); + op->expected_mask = (N < 32) ? (1u << N) - 1 : 0xFFFFFFFFu; + op->completion_status.store(TCP_SUCCESS, std::memory_order_release); + op->completion_mask.store(0, std::memory_order_release); + + auto& pool = ctx_->conn_pool(); + + for (size_t i = 0; i < N; i++) { + const auto& a = assign[i]; + int32_t local_h = static_cast(std::get<0>(a)); + int32_t remote_h = static_cast(std::get<1>(a)); + uint64_t remote_off = std::get<2>(a); + uint64_t local_off = std::get<3>(a); + size_t length = std::get<4>(a); + + auto local = local_pool_->get_mr_fast(local_h); + auto remote = remote_pool_->get_remote_mr_fast(remote_h); + if (local.length == 0 || remote.length == 0) + throw std::runtime_error("TcpEndpoint::async_write: invalid MR handle"); + + uintptr_t src = local.addr + local_off; + SessionHeader hdr{length, remote.addr + remote_off, OP_WRITE}; + + auto conn = pool.getConnection(peer_host_, peer_port_); + if (!conn) { + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->set_comm_done(i); + continue; + } + + auto* send_ptr = reinterpret_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(send_ptr)) { + auto* buf = new char[length]; + auto cu_err = cudaMemcpy(buf, send_ptr, length, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("async_write cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + op->completion_status.store(TCP_FAILED, std::memory_order_release); + op->signal->force_complete(); + 
pool.returnConnection(conn); + return std::make_shared(op); + } + send_ptr = buf; + is_cuda = true; + } +#endif + + auto session = std::make_shared( + std::move(conn->socket), [op, conn, i, &pool, send_ptr, is_cuda](asio::error_code ec) { + if (ec) { + SLIME_LOG_WARN("async_write session ", i, ": ", ec.message()); + op->completion_status.store(TCP_FAILED, std::memory_order_release); + } + if (op->signal) + op->signal->set_comm_done(i); + pool.returnConnection(conn); +#ifdef USE_CUDA + if (is_cuda) + delete[] send_ptr; +#endif + }); + session->start_write(hdr, send_ptr); + } + + return std::make_shared(op); +} + +// ── shutdown ──────────────────────────────────────────── + +void TcpEndpoint::shutdown() +{ + bool expected = true; + if (!running_.compare_exchange_strong(expected, false)) + return; + + connected_.store(false, std::memory_order_release); + acceptor_.close(); + + { + std::lock_guard lk(recv_mu_); + for (auto& pr : pending_recvs_) { + if (pr.op_state && pr.op_state->signal) { + pr.op_state->completion_status.store(TCP_CLOSED, std::memory_order_release); + pr.op_state->signal->force_complete(); + } + } + pending_recvs_.clear(); + } + + if (own_ctx_) + own_ctx_->shutdown(); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_endpoint.h b/dlslime/dlslime/csrc/engine/tcp/tcp_endpoint.h new file mode 100644 index 00000000..928e5601 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_endpoint.h @@ -0,0 +1,111 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dlslime/csrc/common/json.hpp" +#include "dlslime/csrc/engine/assignment.h" +#include "tcp_connection_pool.h" +#include "tcp_context.h" +#include "tcp_future.h" +#include "tcp_header.h" +#include "tcp_memory_pool.h" +#include "tcp_op_state.h" +#include "tcp_session.h" + +namespace dlslime { +namespace tcp { + +using json = nlohmann::json; + +class TcpEndpoint: public std::enable_shared_from_this { 
+public: + static constexpr int64_t kDefaultTimeoutMs = 30000; + + explicit TcpEndpoint(const std::string& ip = "0.0.0.0", uint16_t port = 0); + + TcpEndpoint(TcpContext& ctx, const std::string& ip = "0.0.0.0", uint16_t port = 0) = delete; + + ~TcpEndpoint(); + + TcpEndpoint(const TcpEndpoint&) = delete; + TcpEndpoint& operator=(const TcpEndpoint&) = delete; + + // ── Connection ────────────────────────────────────── + json endpoint_info() const; + void connect(const json& remote_endpoint_info); + void shutdown(); + + // ── Memory ────────────────────────────────────────── + int32_t register_memory_region(const std::string& name, uintptr_t ptr, uintptr_t offset, size_t length); + int32_t register_remote_memory_region(const std::string& name, const json& mr_info); + json mr_info() const; + + // ── Async I/O (all return Future immediately; I/O runs on io_context thread) ── + + std::shared_ptr async_send(const chunk_tuple_t& chunk, int64_t timeout_ms = kDefaultTimeoutMs); + + std::shared_ptr async_recv(const chunk_tuple_t& chunk, bool exact_size = false); + + std::shared_ptr async_read(const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); + + std::shared_ptr async_write(const std::vector& assign, + int64_t timeout_ms = kDefaultTimeoutMs); + + // ── Accessors ─────────────────────────────────────── + void setId(int64_t id) + { + id_.store(id, std::memory_order_relaxed); + } + int64_t getId() const + { + return id_.load(std::memory_order_relaxed); + } + bool is_connected() const + { + return connected_.load(std::memory_order_acquire); + } + +private: + void start_io(); + void do_accept(); + ServerSession::RecvMatcher make_recv_matcher(); + + // ── identity ──────────────────────────────────────── + std::atomic id_{-1}; + std::string local_host_{"0.0.0.0"}; + uint16_t local_port_{0}; + std::string peer_host_; + uint16_t peer_port_{0}; + std::atomic connected_{false}; + + // ── asio core ─────────────────────────────────────── + TcpContext* ctx_{nullptr}; 
+ std::unique_ptr own_ctx_; + asio::ip::tcp::acceptor acceptor_; + std::atomic running_{true}; + + // ── memory ────────────────────────────────────────── + std::shared_ptr local_pool_; + std::shared_ptr remote_pool_; + + // ── recv matching ─────────────────────────────────── + struct PendingRecv { + std::shared_ptr op_state; + std::unique_ptr staging_buf; + uintptr_t cuda_dst{0}; + bool exact_size{false}; + }; + std::mutex recv_mu_; + std::deque pending_recvs_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_future.h b/dlslime/dlslime/csrc/engine/tcp/tcp_future.h new file mode 100644 index 00000000..947f71f2 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_future.h @@ -0,0 +1,77 @@ +#pragma once + +#include +#include +#include +#include + +#include "dlslime/csrc/common/pause.h" +#include "dlslime/csrc/device/device_future.h" +#include "tcp_op_state.h" + +namespace dlslime { +namespace tcp { + +class TcpFuture: public DeviceFuture { +public: + explicit TcpFuture(std::shared_ptr op): op_state_(std::move(op)) + { + if (!op_state_) + throw std::runtime_error("TcpFuture: null op_state"); + } + + int32_t wait() const override + { + if (op_state_->signal) + op_state_->signal->wait_comm_done_cpu(op_state_->expected_mask); + return op_state_->completion_status.load(std::memory_order_acquire); + } + + // Block up to timeout_ms milliseconds. Returns true iff completed; + // writes status to *out. On timeout the operation is still in-flight. 
+ bool wait_for(int64_t timeout_ms, int32_t* out) const + { + auto deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(timeout_ms); + while (true) { + if (op_state_->signal) { + uint32_t m = op_state_->signal->get_comm_done_mask(); + if ((m & op_state_->expected_mask) == op_state_->expected_mask) { + if (out) + *out = op_state_->completion_status.load(std::memory_order_acquire); + return true; + } + } + if (std::chrono::steady_clock::now() >= deadline) { + if (op_state_->signal) { + uint32_t m = op_state_->signal->get_comm_done_mask(); + if ((m & op_state_->expected_mask) == op_state_->expected_mask) { + if (out) + *out = op_state_->completion_status.load(std::memory_order_acquire); + return true; + } + } + return false; + } + machnet_pause(); + } + } + +protected: + std::shared_ptr op_state_; +}; + +class TcpSendFuture: public TcpFuture { +public: + using TcpFuture::TcpFuture; +}; +class TcpRecvFuture: public TcpFuture { +public: + using TcpFuture::TcpFuture; +}; +class TcpReadWriteFuture: public TcpFuture { +public: + using TcpFuture::TcpFuture; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_header.h b/dlslime/dlslime/csrc/engine/tcp/tcp_header.h new file mode 100644 index 00000000..3c09d395 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_header.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace dlslime { +namespace tcp { + +// 17-byte wire header, referenced from Mooncake SessionHeader. 
+// offset size field +// 0 8 size payload byte count (htole64 / le64toh) +// 8 8 addr remote buffer virtual address +// 16 1 opcode SEND=0x00 READ=0x01 WRITE=0x02 + +#pragma pack(push, 1) +struct SessionHeader { + uint64_t size; + uint64_t addr; + uint8_t opcode; +}; +#pragma pack(pop) + +static_assert(sizeof(SessionHeader) == 17, "SessionHeader must be 17 bytes"); + +enum OpCode : uint8_t { + OP_SEND = 0x00, // header + payload → peer recv matches + OP_READ = 0x01, // header only → peer reads local memory → sends data back + OP_WRITE = 0x02, // header + payload → peer writes to local memory +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp b/dlslime/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp new file mode 100644 index 00000000..8e7d1f7b --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_memory_pool.cpp @@ -0,0 +1,146 @@ +#include "tcp_memory_pool.h" + +#include "dlslime/csrc/logging.h" + +namespace dlslime { +namespace tcp { + +// ── local MR ──────────────────────────────────────────── + +int32_t TcpMemoryPool::register_memory_region(uintptr_t addr, size_t length, const std::string& name) +{ + + if (name.empty()) { + SLIME_LOG_WARN("TcpMemoryPool: empty name rejected"); + return -1; + } + if (name_to_handle_.find(name) != name_to_handle_.end()) { + SLIME_LOG_WARN("TcpMemoryPool: duplicate name '", name, "' rejected"); + return -1; + } + + auto pit = ptr_to_handle_.find(addr); + if (pit != ptr_to_handle_.end()) { + int32_t h = pit->second; + if (h >= 0 && static_cast(h) < handle_to_mr_.size() && handle_to_mr_[h].addr == addr + && handle_to_mr_[h].length >= length) { + name_to_handle_[name] = h; + return h; + } + } + + int32_t h = static_cast(handle_to_mr_.size()); + handle_to_mr_.push_back({addr, length}); + ptr_to_handle_[addr] = h; + name_to_handle_[name] = h; + return h; +} + +int32_t TcpMemoryPool::unregister_memory_region(int32_t handle) +{ + if (handle < 0 || static_cast(handle) >= 
handle_to_mr_.size()) + return -1; + + auto& mr = handle_to_mr_[handle]; + ptr_to_handle_.erase(mr.addr); + + // Remove all name→handle entries pointing to this handle. + for (auto it = name_to_handle_.begin(); it != name_to_handle_.end();) { + if (it->second == handle) + it = name_to_handle_.erase(it); + else + ++it; + } + + mr = {}; + return 0; +} + +// ── remote MR ─────────────────────────────────────────── + +int32_t TcpMemoryPool::register_remote_memory_region(const json& mr_info, std::optional name) +{ + + std::string mr_name = name.value_or(mr_info.value("name", "")); + + if (!mr_name.empty()) { + auto it = remote_name_to_handle_.find(mr_name); + if (it != remote_name_to_handle_.end()) { + int32_t h = it->second; + auto& rm = remote_handle_to_mr_[h]; + rm.addr = mr_info.value("addr", 0UL); + rm.length = mr_info.value("length", 0UL); + return h; + } + } + + int32_t h = static_cast(remote_handle_to_mr_.size()); + remote_handle_to_mr_.push_back({mr_info.value("addr", 0UL), mr_info.value("length", 0UL)}); + remote_handle_to_name_.push_back(mr_name); + if (!mr_name.empty()) + remote_name_to_handle_[mr_name] = h; + return h; +} + +int32_t TcpMemoryPool::unregister_remote_memory_region(int32_t handle) +{ + if (handle < 0 || static_cast(handle) >= remote_handle_to_mr_.size()) + return -1; + auto& s = remote_handle_to_name_[handle]; + if (!s.empty()) + remote_name_to_handle_.erase(s); + remote_handle_to_mr_[handle] = {}; + s.clear(); + return 0; +} + +// ── fast lookup ───────────────────────────────────────── + +TcpMr TcpMemoryPool::get_mr_fast(int32_t handle) const +{ + if (handle < 0 || static_cast(handle) >= handle_to_mr_.size()) + return {}; + return handle_to_mr_[handle]; +} + +TcpMr TcpMemoryPool::get_remote_mr_fast(int32_t handle) const +{ + if (handle < 0 || static_cast(handle) >= remote_handle_to_mr_.size()) + return {}; + return remote_handle_to_mr_[handle]; +} + +int32_t TcpMemoryPool::get_mr_handle(const std::string& name) const +{ + auto it = 
name_to_handle_.find(name); + return it != name_to_handle_.end() ? it->second : -1; +} + +int32_t TcpMemoryPool::get_remote_mr_handle(const std::string& name) const +{ + auto it = remote_name_to_handle_.find(name); + return it != remote_name_to_handle_.end() ? it->second : -1; +} + +// ── serialization ─────────────────────────────────────── + +json TcpMemoryPool::mr_info() const +{ + json j = json::object(); + for (const auto& [name, h] : name_to_handle_) + if (h >= 0 && static_cast(h) < handle_to_mr_.size() && handle_to_mr_[h].length > 0) + j[name] = handle_to_mr_[h].json_info(name); + return j; +} + +json TcpMemoryPool::remote_mr_info() const +{ + json j = json::object(); + for (const auto& [name, h] : remote_name_to_handle_) + if (h >= 0 && static_cast(h) < remote_handle_to_mr_.size() && remote_handle_to_mr_[h].length > 0) + j[name] = remote_handle_to_mr_[h].json_info(name); + return j; +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_memory_pool.h b/dlslime/dlslime/csrc/engine/tcp/tcp_memory_pool.h new file mode 100644 index 00000000..279fe207 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_memory_pool.h @@ -0,0 +1,60 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "dlslime/csrc/common/json.hpp" + +namespace dlslime { +namespace tcp { + +using json = nlohmann::json; + +struct TcpMr { + uintptr_t addr{0}; + size_t length{0}; + + json json_info(const std::string& name) const + { + return {{"name", name}, {"addr", addr}, {"length", length}}; + } +}; + +// Pure-bookkeeping pool. No hardware registration needed for TCP. +class TcpMemoryPool { +public: + TcpMemoryPool() = default; + + // name must be non-empty and unique; returns -1 on violation. + int32_t register_memory_region(uintptr_t addr, size_t length, const std::string& name); + int32_t unregister_memory_region(int32_t handle); + + // remote MR — name is optional (may come from peer's mr_info). 
+ int32_t register_remote_memory_region(const json& mr_info, std::optional name = std::nullopt); + int32_t unregister_remote_memory_region(int32_t handle); + + TcpMr get_mr_fast(int32_t handle) const; + TcpMr get_remote_mr_fast(int32_t handle) const; + int32_t get_mr_handle(const std::string& name) const; + int32_t get_remote_mr_handle(const std::string& name) const; + + json mr_info() const; + json remote_mr_info() const; + +private: + // local MRs + std::unordered_map name_to_handle_; + std::unordered_map ptr_to_handle_; + std::vector handle_to_mr_; + + // remote MRs + std::unordered_map remote_name_to_handle_; + std::vector remote_handle_to_mr_; + std::vector remote_handle_to_name_; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_op_state.h b/dlslime/dlslime/csrc/engine/tcp/tcp_op_state.h new file mode 100644 index 00000000..b10eb0b3 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_op_state.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include + +#include "dlslime/csrc/device/device_api.h" +#include "dlslime/csrc/device/signal.h" + +namespace dlslime { +namespace tcp { + +enum Status : int32_t { + TCP_SUCCESS = 0, + TCP_FAILED = -1, + TCP_TIMEOUT = -2, + TCP_CLOSED = -3, +}; + +// One per in-flight operation. The io_context thread (or caller for sync +// ops) signals completion via DeviceSignal; the future's wait() spins on +// wait_comm_done_cpu(). 
+struct TcpOpState { + std::shared_ptr signal; + uint32_t expected_mask{1}; + std::atomic completion_mask{0}; + std::atomic completion_status{TCP_SUCCESS}; + + uintptr_t user_buffer{0}; + size_t user_length{0}; + size_t bytes_copied{0}; + + static std::shared_ptr create() + { + auto s = std::make_shared(); + s->signal = dlslime::device::createSignal(false); + return s; + } +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_session.cpp b/dlslime/dlslime/csrc/engine/tcp/tcp_session.cpp new file mode 100644 index 00000000..88c12968 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_session.cpp @@ -0,0 +1,271 @@ +#include "tcp_session.h" + +#include + +#include +#include +#include +#include + +#include "dlslime/csrc/logging.h" + +#ifdef USE_CUDA +#include +#endif + +namespace dlslime { +namespace tcp { + +// ── helpers ───────────────────────────────────────────── + +static void hdr_to_net(SessionHeader& hdr) +{ + hdr.size = htole64(hdr.size); + hdr.addr = htole64(hdr.addr); +} + +static void hdr_to_host(SessionHeader& hdr) +{ + hdr.size = le64toh(hdr.size); + hdr.addr = le64toh(hdr.addr); +} + +static bool is_fatal(asio::error_code ec) +{ + return ec && ec != asio::error::eof; +} + +#ifdef USE_CUDA +static bool is_cuda_memory(const void* addr) +{ + cudaPointerAttributes attr; + auto st = cudaPointerGetAttributes(&attr, addr); + return (st == cudaSuccess && attr.type == cudaMemoryTypeDevice); +} +#endif + +// ── ServerSession ─────────────────────────────────────── + +ServerSession::ServerSession(asio::ip::tcp::socket socket, TcpMemoryPool* local_pool, RecvMatcher recv_matcher): + socket_(std::move(socket)), local_pool_(local_pool), recv_matcher_(std::move(recv_matcher)) +{ +} + +void ServerSession::start() +{ + readHeader(); +} + +void ServerSession::readHeader() +{ + auto self = shared_from_this(); + header_ = {}; + asio::async_read(socket_, asio::buffer(&header_, sizeof(header_)), [this, self](asio::error_code ec, 
size_t /*n*/) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::readHeader ", ec.message()); + return; + } + hdr_to_host(header_); + dispatch(); + }); +} + +void ServerSession::dispatch() +{ + switch (header_.opcode) { + + case OP_SEND: { + if (header_.size == 0) { + readHeader(); + return; + } + auto slot = recv_matcher_(); + if (!slot.buffer || slot.length == 0) { + SLIME_LOG_WARN("ServerSession: OP_SEND with no pending recv"); + readHeader(); + return; + } + if (slot.exact_size && header_.size != slot.length) { + SLIME_LOG_WARN("ServerSession: size mismatch, send ", header_.size, " != recv ", slot.length); + if (slot.op_state) { + slot.op_state->completion_status.store(TCP_FAILED, std::memory_order_release); + if (slot.op_state->signal) + slot.op_state->signal->set_comm_done(0); + } + readHeader(); + return; + } + + // Always drain the full send payload from the wire. If recv buffer + // is smaller, read into a temp buffer then copy what fits. + size_t n_read = static_cast(header_.size); + size_t n_copy = std::min(n_read, slot.length); + auto* dst = reinterpret_cast(slot.buffer); + bool overflow = false; + + if (header_.size > slot.length) { + dst = new char[n_read]; + overflow = true; + } + + auto self = shared_from_this(); + asio::async_read(socket_, + asio::buffer(dst, n_read), + [this, self, slot, n_copy, dst, overflow](asio::error_code ec, size_t /*rn*/) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession SEND read: ", ec.message()); + if (overflow) + delete[] dst; + return; + } + if (overflow) { + std::memcpy(reinterpret_cast(slot.buffer), dst, n_copy); + delete[] dst; + } + if (slot.post_read) + slot.post_read(); + if (slot.op_state) { + slot.op_state->bytes_copied = n_copy; + slot.op_state->completion_status.store(TCP_SUCCESS, std::memory_order_release); + if (slot.op_state->signal) + slot.op_state->signal->set_comm_done(0); + } + readHeader(); + }); + break; + } + + case OP_WRITE: + if (header_.size == 0) { + readHeader(); 
+ return; + } + readBody(reinterpret_cast(header_.addr), header_.size); + break; + + case OP_READ: { + uintptr_t addr = static_cast(header_.addr); + size_t sz = static_cast(header_.size); + if (sz == 0) { + readHeader(); + return; + } + writeBody(reinterpret_cast(addr), sz); + break; + } + + default: + SLIME_LOG_WARN("ServerSession: unknown opcode ", static_cast(header_.opcode)); + readHeader(); + break; + } +} + +void ServerSession::readBody(void* dst, size_t len) +{ + auto* ptr = static_cast(dst); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(dst)) { + ptr = new char[len]; + is_cuda = true; + } +#endif + + auto self = shared_from_this(); + asio::async_read(socket_, + asio::buffer(ptr, len), + [this, self, real_addr = reinterpret_cast(dst), len, is_cuda, ptr](asio::error_code ec, + size_t /*n*/) { + if (ec) { + if (is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::readBody ", ec.message()); + if (is_cuda) + delete[] ptr; + return; + } +#ifdef USE_CUDA + if (is_cuda) { + auto cu_err = + cudaMemcpy(reinterpret_cast(real_addr), ptr, len, cudaMemcpyHostToDevice); + if (cu_err != cudaSuccess) + SLIME_LOG_ERROR("readBody cudaMemcpy H2D: ", cudaGetErrorString(cu_err)); + delete[] ptr; + } +#endif + readHeader(); + }); +} + +void ServerSession::writeBody(const void* src, size_t len) +{ + auto* ptr = static_cast(src); + bool is_cuda = false; +#ifdef USE_CUDA + if (is_cuda_memory(src)) { + auto* buf = new char[len]; + auto cu_err = cudaMemcpy(buf, src, len, cudaMemcpyDeviceToHost); + if (cu_err != cudaSuccess) { + SLIME_LOG_ERROR("writeBody cudaMemcpy D2H: ", cudaGetErrorString(cu_err)); + delete[] buf; + ptr = static_cast(src); + } + else { + ptr = buf; + is_cuda = true; + } + } +#endif + + auto self = shared_from_this(); + asio::async_write(socket_, asio::buffer(ptr, len), [this, self, is_cuda, ptr](asio::error_code ec, size_t /*n*/) { + if (is_cuda) + delete[] ptr; + if (ec && is_fatal(ec)) + SLIME_LOG_WARN("ServerSession::writeBody ", ec.message()); + 
readHeader(); + }); +} + +// ── ClientSession ─────────────────────────────────────── + +ClientSession::ClientSession(asio::ip::tcp::socket sock, DoneCallback on_done): + socket_(std::move(sock)), on_done_(std::move(on_done)) +{ +} + +void ClientSession::start_write(const SessionHeader& hdr, const void* payload) +{ + auto self = shared_from_this(); + SessionHeader net = hdr; + hdr_to_net(net); + std::array bufs = {asio::buffer(&net, sizeof(net)), asio::buffer(payload, hdr.size)}; + asio::async_write(socket_, bufs, [this, self](asio::error_code ec, size_t) { + if (on_done_) + on_done_(ec); + }); +} + +void ClientSession::start_read(const SessionHeader& hdr, void* dst) +{ + auto self = shared_from_this(); + hdr_ = hdr; + SessionHeader net = hdr; + hdr_to_net(net); + asio::async_write(socket_, asio::buffer(&net, sizeof(net)), [this, self, dst](asio::error_code ec, size_t) { + if (ec) { + if (on_done_) + on_done_(ec); + return; + } + asio::async_read(socket_, asio::buffer(dst, hdr_.size), [this, self](asio::error_code ec, size_t) { + if (on_done_) + on_done_(ec); + }); + }); +} + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/engine/tcp/tcp_session.h b/dlslime/dlslime/csrc/engine/tcp/tcp_session.h new file mode 100644 index 00000000..ec8ecaf7 --- /dev/null +++ b/dlslime/dlslime/csrc/engine/tcp/tcp_session.h @@ -0,0 +1,71 @@ +#pragma once + +#include +#include +#include +#include + +#include "tcp_header.h" +#include "tcp_memory_pool.h" +#include "tcp_op_state.h" + +namespace dlslime { +namespace tcp { + +class TcpConnectionPool; + +struct RecvSlot { + uintptr_t buffer{0}; + size_t length{0}; + std::shared_ptr op_state; + std::function post_read; + bool exact_size{false}; // reject send size != recv size +}; + +// ── ServerSession: handles incoming requests on one persistent connection ── +// +// Lifecycle: start() → readHeader → dispatch → readBody/writeBody → readHeader ↻ +class ServerSession: public std::enable_shared_from_this { 
+public: + using RecvMatcher = std::function; + + ServerSession(asio::ip::tcp::socket socket, TcpMemoryPool* local_pool, RecvMatcher recv_matcher); + + void start(); + +private: + void readHeader(); + void dispatch(); + void readBody(void* dst, size_t len); + void writeBody(const void* src, size_t len); + + asio::ip::tcp::socket socket_; + TcpMemoryPool* local_pool_; + RecvMatcher recv_matcher_; + SessionHeader header_{}; +}; + +// ── ClientSession: drives one outbound I/O operation ───── +// +// Lifecycle: construct → start_write/start_read → on_done → self-destruct +// Does NOT own OpState or PooledConnection — only drives the I/O and reports ec. +class ClientSession: public std::enable_shared_from_this { +public: + using DoneCallback = std::function; + + ClientSession(asio::ip::tcp::socket sock, DoneCallback on_done); + + // Write header + payload to socket (gather async_write). + void start_write(const SessionHeader& hdr, const void* payload); + + // Write OP_READ header → read raw response into dst. 
+ void start_read(const SessionHeader& hdr, void* dst); + +private: + asio::ip::tcp::socket socket_; + DoneCallback on_done_; + SessionHeader hdr_{}; +}; + +} // namespace tcp +} // namespace dlslime diff --git a/dlslime/dlslime/csrc/python/CMakeLists.txt b/dlslime/dlslime/csrc/python/CMakeLists.txt index 389be03a..1584babc 100755 --- a/dlslime/dlslime/csrc/python/CMakeLists.txt +++ b/dlslime/dlslime/csrc/python/CMakeLists.txt @@ -67,6 +67,11 @@ if (BUILD_ASCEND_DIRECT) ) endif() +if (BUILD_TCP) + target_compile_definitions(_slime_c PRIVATE BUILD_TCP) + list(APPEND _slime_c_link_libraries _slime_tcp) +endif() + # Ops moved to NanoCCL - link to NanoCCL if needed # if (BUILD_INTRA_OPS OR BUILD_INTER_OPS) # if (BUILD_INTRA_OPS) diff --git a/dlslime/dlslime/csrc/python/bind.cpp b/dlslime/dlslime/csrc/python/bind.cpp index c0abf314..9c52988f 100644 --- a/dlslime/dlslime/csrc/python/bind.cpp +++ b/dlslime/dlslime/csrc/python/bind.cpp @@ -27,6 +27,12 @@ #include "dlslime/csrc/engine/ascend_direct/ascend_remote_memory_pool.h" #endif +#ifdef BUILD_TCP +#include "dlslime/csrc/engine/tcp/tcp_endpoint.h" +#include "dlslime/csrc/engine/tcp/tcp_future.h" +#include "dlslime/csrc/engine/tcp/tcp_memory_pool.h" +#endif + #include "dlslime/csrc/device/signal.h" #ifdef BUILD_RDMA @@ -89,6 +95,12 @@ namespace py = pybind11; #define BUILD_RPC_ENABLED false #endif +#ifdef BUILD_TCP +#define BUILD_TCP_ENABLED true +#else +#define BUILD_TCP_ENABLED false +#endif + // Ops moved to NanoCCL #define BUILD_INTRA_OPS_ENABLED false #define BUILD_INTER_OPS_ENABLED false @@ -102,6 +114,7 @@ PYBIND11_MODULE(_slime_c, m) EXPOSE_BUILD_FLAG(m, BUILD_INTRA_OPS); EXPOSE_BUILD_FLAG(m, BUILD_INTER_OPS); EXPOSE_BUILD_FLAG(m, BUILD_RPC); + EXPOSE_BUILD_FLAG(m, BUILD_TCP); m.def("discover_topology", &dlslime::topology::discoverTopology, @@ -512,6 +525,122 @@ PYBIND11_MODULE(_slime_c, m) py::call_guard()); #endif +#ifdef BUILD_TCP + // ========================================================================= + 
// TCP Transport + // ========================================================================= + py::class_>(m, "SlimeTcpSendFuture") + .def("wait", &dlslime::tcp::TcpSendFuture::wait, py::call_guard()) + .def( + "wait_for", + [](const dlslime::tcp::TcpSendFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) + ms = 0; + if (self.wait_for(ms, &st)) + return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>(m, "SlimeTcpRecvFuture") + .def("wait", &dlslime::tcp::TcpRecvFuture::wait, py::call_guard()) + .def( + "wait_for", + [](const dlslime::tcp::TcpRecvFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) + ms = 0; + if (self.wait_for(ms, &st)) + return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>( + m, "SlimeTcpReadWriteFuture") + .def("wait", &dlslime::tcp::TcpReadWriteFuture::wait, py::call_guard()) + .def( + "wait_for", + [](const dlslime::tcp::TcpReadWriteFuture& self, double sec) -> py::object { + int32_t st = 0; + int64_t ms = static_cast(sec * 1000.0); + if (ms < 0) + ms = 0; + if (self.wait_for(ms, &st)) + return py::cast(st); + return py::none(); + }, + py::arg("timeout_seconds")); + + py::class_>(m, "TcpMemoryPool") + .def(py::init<>()) + .def("register_memory_region", + &dlslime::tcp::TcpMemoryPool::register_memory_region, + py::arg("addr"), + py::arg("length"), + py::arg("name")) + .def( + "register_remote_memory_region", + [](dlslime::tcp::TcpMemoryPool& self, const json& mr_info, py::object name_obj) { + std::optional name; + if (!name_obj.is_none()) + name = name_obj.cast(); + return self.register_remote_memory_region(mr_info, name); + }, + py::arg("mr_info"), + py::arg("name") = py::none()) + .def("get_mr_handle", &dlslime::tcp::TcpMemoryPool::get_mr_handle, py::arg("name")) + .def("mr_info", &dlslime::tcp::TcpMemoryPool::mr_info); + + 
py::class_>(m, "TcpEndpoint") + .def(py::init(), py::arg("ip") = "0.0.0.0", py::arg("port") = 0) + .def("connect", + &dlslime::tcp::TcpEndpoint::connect, + py::arg("remote_info"), + py::call_guard()) + .def("endpoint_info", &dlslime::tcp::TcpEndpoint::endpoint_info) + .def("mr_info", &dlslime::tcp::TcpEndpoint::mr_info) + .def("shutdown", &dlslime::tcp::TcpEndpoint::shutdown, py::call_guard()) + .def("is_connected", &dlslime::tcp::TcpEndpoint::is_connected) + .def("register_memory_region", + &dlslime::tcp::TcpEndpoint::register_memory_region, + py::arg("name"), + py::arg("data_ptr"), + py::arg("offset"), + py::arg("length"), + py::call_guard()) + .def("register_remote_memory_region", + &dlslime::tcp::TcpEndpoint::register_remote_memory_region, + py::arg("name"), + py::arg("mr_info"), + py::call_guard()) + .def("async_send", + py::overload_cast(&dlslime::tcp::TcpEndpoint::async_send), + py::arg("chunk"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::call_guard()) + .def("async_recv", + &dlslime::tcp::TcpEndpoint::async_recv, + py::arg("chunk"), + py::arg("exact_size") = false, + py::call_guard()) + .def("async_read", + py::overload_cast&, int64_t>( + &dlslime::tcp::TcpEndpoint::async_read), + py::arg("assign"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::call_guard()) + .def("async_write", + py::overload_cast&, int64_t>( + &dlslime::tcp::TcpEndpoint::async_write), + py::arg("assign"), + py::arg("timeout_ms") = dlslime::tcp::TcpEndpoint::kDefaultTimeoutMs, + py::call_guard()); +#endif // BUILD_TCP + // ========================================================================= // Observability (always available, independent of backend) // ========================================================================= diff --git a/dlslime/tests/python/test_tcp.py b/dlslime/tests/python/test_tcp.py new file mode 100755 index 00000000..7d6874d2 --- /dev/null +++ b/dlslime/tests/python/test_tcp.py @@ -0,0 +1,722 
"""Two-endpoint tests for the dlslime TCP transport Python bindings.

Each test builds two ``TcpEndpoint`` objects in the same process and drives
them from two threads (see ``_sync_run``). Run under pytest, or directly as a
script; the script entry point runs a subset of the tests and exits non-zero
if any of them fails.
"""

import ctypes
import inspect
import os
import socket
import sys
import threading
import time

from dlslime import TcpEndpoint, TcpMemoryPool

# ── optional torch / CUDA support ────────────────────────

_HAS_TORCH = False
_HAS_CUDA = False

try:
    import torch
except Exception:
    # Best effort: any failure to import torch only disables the torch tests.
    # The name is still bound so that later references do not raise NameError.
    torch = None
else:
    _HAS_TORCH = True
    try:
        _HAS_CUDA = torch.cuda.is_available()
    except Exception:
        _HAS_CUDA = False

# CUDA tests can be switched off explicitly through the environment.
_CUDA_FORCE_OFF = os.environ.get("DLSLIME_TCP_TEST_CUDA", "").lower() in (
    "0",
    "false",
    "no",
    "off",
)


def _torch_skip():
    """Return True when torch could not be imported."""
    return not _HAS_TORCH


def _cuda_skip():
    """Return True when CUDA tests must not run (forced off or unavailable)."""
    if _CUDA_FORCE_OFF:
        return True
    return not _HAS_CUDA


def _torch_unavailable(test_name):
    """Print a SKIP line and return True when torch-based tests cannot run."""
    if _torch_skip():
        print(f"{test_name} SKIP (torch not available)", flush=True)
        return True
    return False


# ── test harness ─────────────────────────────────────────


def _sync_run(name, fn_a, fn_b, timeout=120, raise_on_error=False):
    """Run ``fn_a`` and ``fn_b`` concurrently in two threads and report.

    Both workers meet at a barrier before calling their function so that the
    two sides start together.

    Args:
        name: Label used in the printed result and in error messages.
        fn_a: Zero-argument callable for side A.
        fn_b: Zero-argument callable for side B.
        timeout: Total seconds both workers get to finish; ``None`` waits
            indefinitely.
        raise_on_error: When True, a worker failure raises instead of only
            being printed, so that test runners see the failure.

    Returns:
        True when both workers finished without raising, otherwise False
        (only reachable when ``raise_on_error`` is False).

    Raises:
        RuntimeError: If a worker is still alive after ``timeout``, or if a
            worker raised and ``raise_on_error`` is True.
    """
    errors = []
    barrier = threading.Barrier(2)

    def wrap(fn):
        try:
            barrier.wait(10)
            fn()
        except Exception as exc:
            # Collected here; reported (and optionally re-raised) by the caller thread.
            errors.append(exc)

    workers = [
        threading.Thread(target=wrap, args=(fn,), daemon=False)
        for fn in (fn_a, fn_b)
    ]
    for worker in workers:
        worker.start()

    # One shared deadline: the two joins together never exceed ``timeout``.
    deadline = None if timeout is None else time.monotonic() + timeout
    for worker in workers:
        if deadline is None:
            worker.join()
        else:
            worker.join(max(0.0, deadline - time.monotonic()))

    if any(worker.is_alive() for worker in workers):
        raise RuntimeError(f"{name} FAIL!{timeout}s timeout!")
    if errors:
        print(f"{name} FAIL {errors}", flush=True)
        if raise_on_error:
            raise RuntimeError(f"{name} FAIL {errors}") from errors[0]
        return False
    print(f"{name} SUCC ", flush=True)
    return True


def _check_status(label, status):
    """Raise ``RuntimeError("<label>: <status>")`` unless ``status`` is 0."""
    if status != 0:
        raise RuntimeError(f"{label}: {status}")


def _poll(predicate, attempts, interval=0.5):
    """Poll ``predicate`` up to ``attempts`` times, sleeping between tries.

    Returns the result of one final evaluation after the last sleep, so the
    total wait is at most ``attempts * interval`` seconds.
    """
    for _ in range(attempts):
        if predicate():
            return True
        time.sleep(interval)
    return bool(predicate())


# ── ctypes-based tests ───────────────────────────────────


def test_async_send_recv(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """Ping-pong: A sends "hello", B replies "world" into A's buffer at offset 5."""
    buf_a = ctypes.create_string_buffer(128)
    buf_b = ctypes.create_string_buffer(128)
    addr_a = ctypes.addressof(buf_a)
    addr_b = ctypes.addressof(buf_b)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()

    def run_a():
        ep_a.connect(info_b)
        ctypes.memmove(addr_a, b"hello", 5)
        # NOTE(review): presumably gives the peer time to connect and post
        # its recv before the send is issued — confirm whether still needed.
        time.sleep(5)
        _check_status("send", ep_a.async_send((addr_a, 0, 5)).wait())
        _check_status("recv", ep_a.async_recv((addr_a, 5, 5)).wait())
        if bytes(buf_a[5:10]) != b"world":
            raise RuntimeError(f"data: {bytes(buf_a[5:10])}")
        ep_a.shutdown()

    def run_b():
        ep_b.connect(info_a)
        _check_status("recv", ep_b.async_recv((addr_b, 0, 5)).wait())
        if bytes(buf_b[:5]) != b"hello":
            raise RuntimeError(f"data: {bytes(buf_b[:5])}")
        ctypes.memmove(addr_b, b"world", 5)
        time.sleep(5)
        _check_status("send", ep_b.async_send((addr_b, 0, 5)).wait())
        ep_b.shutdown()

    _sync_run("test_async_send_recv", run_a, run_b, timeout=240, raise_on_error=True)


def test_async_send2recv(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """One-way transfer: A sends three bytes, B receives and verifies them."""
    buf_a = ctypes.create_string_buffer(32)
    buf_b = ctypes.create_string_buffer(32)
    addr_a = ctypes.addressof(buf_a)
    addr_b = ctypes.addressof(buf_b)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()

    def run_a():
        ep_a.connect(info_b)
        ctypes.memmove(addr_a, b"one", 3)
        time.sleep(5)
        _check_status("send", ep_a.async_send((addr_a, 0, 3)).wait())
        ep_a.shutdown()

    def run_b():
        ep_b.connect(info_a)
        _check_status("recv", ep_b.async_recv((addr_b, 0, 3)).wait())
        if bytes(buf_b[:3]) != b"one":
            raise RuntimeError(f"data: {bytes(buf_b[:3])}")
        ep_b.shutdown()

    # The label now matches the function name (it used to read "..._recv_one").
    _sync_run("test_async_send2recv", run_a, run_b, raise_on_error=True)


def test_async_write(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """One-sided write: A writes into B's registered region, B polls for it."""
    buf_a = ctypes.create_string_buffer(256)
    buf_b = ctypes.create_string_buffer(256)
    addr_a = ctypes.addressof(buf_a)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    h_a = ep_a.register_memory_region("a", addr_a, 0, 256)
    # B's region must be registered before endpoint_info() so it is advertised.
    ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 256)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()
    h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"])

    test_data = b"hello_from_a"
    size = len(test_data)

    def run_a():
        ep_a.connect(info_b)
        ctypes.memmove(addr_a, test_data, size)
        _check_status("write", ep_a.async_write([(h_a, h_br, 0, 0, size)]).wait())
        ep_a.shutdown()

    def run_b():
        ep_b.connect(info_a)
        if not _poll(lambda: bytes(buf_b[:size]) == test_data, attempts=50):
            raise RuntimeError(f"B write not received in {50 * 0.5}s")
        ep_b.shutdown()

    _sync_run("test_async_write", run_a, run_b, raise_on_error=True)


def test_async_read(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """One-sided read: A reads B's pre-filled registered region and verifies it."""
    buf_a = ctypes.create_string_buffer(256)
    buf_b = ctypes.create_string_buffer(256)
    addr_a = ctypes.addressof(buf_a)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    h_a = ep_a.register_memory_region("a", addr_a, 0, 256)
    ep_b.register_memory_region("b", ctypes.addressof(buf_b), 0, 256)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()
    h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"])

    test_data = b"hello_from_b"
    size = len(test_data)
    ctypes.memmove(ctypes.addressof(buf_b), test_data, size)

    reader_done = threading.Event()

    def run_a():
        try:
            ep_a.connect(info_b)
            _check_status("read", ep_a.async_read([(h_a, h_br, 0, 0, size)]).wait())
            if bytes(buf_a[:size]) != test_data:
                raise RuntimeError("read data mismatch")
            ep_a.shutdown()
        finally:
            reader_done.set()

    def run_b():
        ep_b.connect(info_a)
        # Keep B alive until A has finished reading; 25 s is the upper bound.
        reader_done.wait(25)
        ep_b.shutdown()

    _sync_run("test_async_read", run_a, run_b, raise_on_error=True)


# ── skip test ──
# The tests from here up to the torch section are not run by the
# ``__main__`` block at the bottom of this file.


def test_recv_timeout(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """``wait_for`` on a recv that never gets data must return None."""
    buf_a = ctypes.create_string_buffer(32)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)

    def run_b():
        ep_b.connect(ep_a.endpoint_info())
        time.sleep(1.0)
        ep_b.shutdown()

    def run_a():
        ep_a.connect(ep_b.endpoint_info())
        fut = ep_a.async_recv((ctypes.addressof(buf_a), 0, 5))
        result = fut.wait_for(0.3)
        if result is not None:
            raise RuntimeError(f"expected None, got {result}")
        ep_a.shutdown()

    _sync_run("test_recv_timeout", run_a, run_b, raise_on_error=True)


def test_send_timeout_ms(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """A send with an explicit ``timeout_ms`` completes with status 0."""
    buf_a = ctypes.create_string_buffer(64)
    buf_b = ctypes.create_string_buffer(64)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)

    def run_b():
        ep_b.connect(ep_a.endpoint_info())
        _check_status("recv", ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait())
        ep_b.shutdown()

    def run_a():
        ep_a.connect(ep_b.endpoint_info())
        ctypes.memmove(ctypes.addressof(buf_a), b"world", 5)
        status = ep_a.async_send(
            (ctypes.addressof(buf_a), 0, 5), timeout_ms=10000
        ).wait()
        _check_status("send", status)
        ep_a.shutdown()

    _sync_run("test_send_timeout_ms", run_a, run_b, raise_on_error=True)


def test_default_timeout(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """A send that relies on the default timeout completes with status 0."""
    buf_a = ctypes.create_string_buffer(32)
    buf_b = ctypes.create_string_buffer(32)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)

    def run_b():
        ep_b.connect(ep_a.endpoint_info())
        _check_status("recv", ep_b.async_recv((ctypes.addressof(buf_b), 0, 5)).wait())
        ep_b.shutdown()

    def run_a():
        ep_a.connect(ep_b.endpoint_info())
        ctypes.memmove(ctypes.addressof(buf_a), b"test!", 5)
        _check_status("send", ep_a.async_send((ctypes.addressof(buf_a), 0, 5)).wait())
        ep_a.shutdown()

    _sync_run("test_default_timeout", run_a, run_b, raise_on_error=True)


def test_exact_size_mismatch(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """With ``exact_size=True`` a 4-byte recv of an 8-byte message yields -1."""
    buf_a = ctypes.create_string_buffer(32)
    buf_b = ctypes.create_string_buffer(32)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)

    def run_b():
        ep_b.connect(ep_a.endpoint_info())
        st = ep_b.async_recv((ctypes.addressof(buf_b), 0, 4), exact_size=True).wait()
        if st != -1:
            raise RuntimeError(f"expected TCP_FAILED(-1), got {st}")
        ep_b.shutdown()

    def run_a():
        ep_a.connect(ep_b.endpoint_info())
        ctypes.memmove(ctypes.addressof(buf_a), b"overflow", 8)
        _check_status("send", ep_a.async_send((ctypes.addressof(buf_a), 0, 8)).wait())
        ep_a.shutdown()

    _sync_run("test_exact_size_mismatch", run_a, run_b, raise_on_error=True)


def test_overflow_truncate(
    port_a: int = 0, port_b: int = 0, ip_a: str = "0.0.0.0", ip_b: str = "0.0.0.0"
):
    """A short recv keeps the first bytes; the next message still arrives intact."""
    buf_a = ctypes.create_string_buffer(64)
    buf_b = ctypes.create_string_buffer(64)
    addr_a = ctypes.addressof(buf_a)
    addr_b = ctypes.addressof(buf_b)

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)

    def run_b():
        ep_b.connect(ep_a.endpoint_info())
        _check_status("recv1", ep_b.async_recv((addr_b, 0, 4)).wait())
        if bytes(buf_b[:4]) != b"LONG":
            raise RuntimeError(f"truncated: {bytes(buf_b[:4])}")
        _check_status("recv2", ep_b.async_recv((addr_b, 4, 5)).wait())
        if bytes(buf_b[4:9]) != b"HELLO":
            raise RuntimeError(f"follow-up: {bytes(buf_b[4:9])}")
        ep_b.shutdown()

    def run_a():
        ep_a.connect(ep_b.endpoint_info())
        ctypes.memmove(addr_a, b"LONGDATA", 8)
        _check_status("send1", ep_a.async_send((addr_a, 0, 8)).wait())
        ctypes.memmove(addr_a, b"HELLO", 5)
        _check_status("send2", ep_a.async_send((addr_a, 0, 5)).wait())
        ep_a.shutdown()

    _sync_run("test_overflow_truncate", run_a, run_b, raise_on_error=True)


def test_mr_name_validation():
    """Empty and duplicate memory-region names are rejected with -1."""
    ep = TcpEndpoint(port=0)
    buf = ctypes.create_string_buffer(32)
    addr = ctypes.addressof(buf)

    h = ep.register_memory_region("valid", addr, 0, 32)
    if h < 0:
        raise RuntimeError(f"valid name: {h}")

    h = ep.register_memory_region("", addr, 0, 32)
    if h != -1:
        raise RuntimeError(f"empty name should return -1, got {h}")

    h = ep.register_memory_region("valid", addr, 0, 32)
    if h != -1:
        raise RuntimeError(f"duplicate name should return -1, got {h}")

    ep.shutdown()


def test_connect_unreachable(port: int = 10015):
    """Connecting to a port nobody listens on must leave the endpoint unconnected.

    Args:
        port: Local port for the endpoint; overridable when 10015 is in use.
    """
    ep = TcpEndpoint(port=port)
    unreachable = {"host": "127.0.0.1", "port": 65535, "mr_info": {}}
    ep.connect(unreachable)
    if ep.is_connected():
        raise RuntimeError("should not be connected")
    ep.shutdown()


# ── parameterized torch tests (device="cpu" or "cuda") ──


def _resolve_dtype(dtype):
    """Map the ``None`` default to ``torch.float32``.

    The default is resolved lazily so that this module still imports when
    torch is not installed.
    """
    return torch.float32 if dtype is None else dtype


def _element_size(dtype):
    """Return the size in bytes of one element of ``dtype``."""
    return torch.empty(0, dtype=dtype).element_size()


def _make_tensor(shape, device, dtype, **kw):
    """Return a random-normal tensor of ``shape`` on ``device``.

    ``device`` may be a ``torch.device`` or anything ``torch.device()``
    accepts; extra keyword arguments are forwarded to ``torch.randn``.
    """
    target = device if isinstance(device, torch.device) else torch.device(device)
    return torch.randn(shape, dtype=dtype, device=target, **kw)


def test_torch_send_recv(
    port_a: int = 0,
    port_b: int = 0,
    device="cpu",
    dtype=None,
    ip_a: str = "0.0.0.0",
    ip_b: str = "0.0.0.0",
):
    """Round-trip: A send full → B recv → B send slice → A recv."""
    if _torch_unavailable("test_torch_send_recv"):
        return
    dtype = _resolve_dtype(dtype)

    n_elems, slice_len = 32, 5
    t_a = _make_tensor(n_elems, device, dtype)
    t_b = _make_tensor(n_elems, device, dtype)
    expected = t_a.clone()
    # Byte counts follow the actual dtype instead of assuming 4-byte elements.
    elem = t_a.element_size()
    n_bytes = n_elems * elem
    sl_bytes = slice_len * elem

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()

    def run_a():
        ep_a.connect(info_b)
        time.sleep(5)
        _check_status("send", ep_a.async_send((t_a.data_ptr(), 0, n_bytes)).wait())
        _check_status(
            "recv", ep_a.async_recv((t_a.data_ptr(), 10 * elem, sl_bytes)).wait()
        )
        if not torch.equal(t_a[10:15], t_b[20:25]):
            raise RuntimeError("slice mismatch")
        ep_a.shutdown()

    def run_b():
        ep_b.connect(info_a)
        _check_status("recv", ep_b.async_recv((t_b.data_ptr(), 0, n_bytes)).wait())
        if not torch.equal(expected, t_b):
            raise RuntimeError("full tensor mismatch")
        time.sleep(5)
        _check_status(
            "send", ep_b.async_send((t_b.data_ptr(), 20 * elem, sl_bytes)).wait()
        )
        ep_b.shutdown()

    _sync_run(f"test_torch_send_recv_{device}", run_a, run_b, 120, raise_on_error=True)


def test_torch_write(
    port_a: int = 0,
    port_b: int = 0,
    device="cpu",
    dtype=None,
    ip_a: str = "0.0.0.0",
    ip_b: str = "0.0.0.0",
):
    """One-sided write: A async_write → B verifies data received."""
    if _torch_unavailable("test_torch_write"):
        return
    dtype = _resolve_dtype(dtype)

    n_elems = 64
    t_a = _make_tensor(n_elems, device, dtype)
    t_b = _make_tensor(n_elems, device, dtype)
    expected = t_a.clone()
    n_bytes = n_elems * t_a.element_size()

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    h_a = ep_a.register_memory_region("a", t_a.data_ptr(), 0, n_bytes)
    ep_b.register_memory_region("b", t_b.data_ptr(), 0, n_bytes)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()
    h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"])

    def run_a():
        ep_a.connect(info_b)
        _check_status("write", ep_a.async_write([(h_a, h_br, 0, 0, n_bytes)]).wait())
        ep_a.shutdown()

    def run_b():
        ep_b.connect(info_a)
        if not _poll(lambda: torch.equal(expected, t_b), attempts=40):
            raise RuntimeError("write data not received")
        ep_b.shutdown()

    _sync_run(f"test_torch_write_{device}", run_a, run_b, raise_on_error=True)


def test_torch_read(
    port_a: int = 0,
    port_b: int = 0,
    device="cpu",
    dtype=None,
    ip_a: str = "0.0.0.0",
    ip_b: str = "0.0.0.0",
):
    """One-sided read: B buffer pre-filled, A async_read and verifies."""
    if _torch_unavailable("test_torch_read"):
        return
    dtype = _resolve_dtype(dtype)

    n_elems = 64
    t_a = _make_tensor(n_elems, device, dtype)
    t_b = _make_tensor(n_elems, device, dtype)
    expected = t_b.clone()
    n_bytes = n_elems * t_a.element_size()

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    h_a = ep_a.register_memory_region("a", t_a.data_ptr(), 0, n_bytes)
    ep_b.register_memory_region("b", t_b.data_ptr(), 0, n_bytes)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()
    h_br = ep_a.register_remote_memory_region("rb", info_b["mr_info"]["b"])

    reader_done = threading.Event()

    def run_a():
        try:
            ep_a.connect(info_b)
            _check_status("read", ep_a.async_read([(h_a, h_br, 0, 0, n_bytes)]).wait())
            if not torch.equal(t_a, expected):
                raise RuntimeError("read data mismatch")
            ep_a.shutdown()
        finally:
            reader_done.set()

    def run_b():
        ep_b.connect(info_a)
        # Stay up until A has finished reading; 20 s is the upper bound.
        reader_done.wait(20)
        ep_b.shutdown()

    _sync_run(f"test_torch_read_{device}", run_a, run_b, raise_on_error=True)


def test_torch_write_batch(
    port_a: int = 0,
    port_b: int = 0,
    device="cpu",
    dtype=None,
    n_batch=4,
    ip_a: str = "0.0.0.0",
    ip_b: str = "0.0.0.0",
):
    """One async_write with multiple assignments.

    Assignment ``i`` copies the single element at index ``i`` of A's tensor
    ``i`` into the same position of B's tensor ``i``.
    """
    if _torch_unavailable("test_torch_write_batch"):
        return
    dtype = _resolve_dtype(dtype)

    dsize = _element_size(dtype)
    n_elems = 64
    t_a_batch = [_make_tensor(n_elems, device, dtype) for _ in range(n_batch)]
    t_b_batch = [_make_tensor(n_elems, device, dtype) for _ in range(n_batch)]
    expected_batch = [t.clone() for t in t_a_batch]
    n_bytes = n_elems * dsize

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    h_a_batch = [
        ep_a.register_memory_region(f"a_{i}", t.data_ptr(), 0, n_bytes)
        for i, t in enumerate(t_a_batch)
    ]
    for i, t in enumerate(t_b_batch):
        ep_b.register_memory_region(f"b_{i}", t.data_ptr(), 0, n_bytes)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()
    h_br_batch = [
        ep_a.register_remote_memory_region(f"rb_{i}", info_b["mr_info"][f"b_{i}"])
        for i in range(n_batch)
    ]

    def run_a():
        ep_a.connect(info_b)
        assigns = [
            (h_a_batch[i], h_br_batch[i], i * dsize, i * dsize, dsize)
            for i in range(n_batch)
        ]
        _check_status("write batch", ep_a.async_write(assigns).wait())
        ep_a.shutdown()

    def run_b():
        ep_b.connect(info_a)
        # Timing kept as in the original test: fixed waits, then a single
        # check per assignment.
        time.sleep(3)
        for i in range(n_batch):
            time.sleep(2)
            if not torch.equal(t_b_batch[i][i], expected_batch[i][i]):
                raise RuntimeError(f"batch {i}: mismatch")
        ep_b.shutdown()

    _sync_run(f"test_torch_write_batch_{device}", run_a, run_b, raise_on_error=True)


def test_torch_read_batch(
    port_a: int = 0,
    port_b: int = 0,
    device="cpu",
    dtype=None,
    n_batch=4,
    ip_a: str = "0.0.0.0",
    ip_b: str = "0.0.0.0",
):
    """One async_read with multiple assignments.

    Assignment ``i`` copies the single element at index ``i`` of B's tensor
    ``i`` into the same position of A's tensor ``i``. A verifies the data
    right after the read completes, as ``test_torch_read`` does.
    """
    if _torch_unavailable("test_torch_read_batch"):
        return
    dtype = _resolve_dtype(dtype)

    dsize = _element_size(dtype)
    n_elems = 64
    t_a_batch = [_make_tensor(n_elems, device, dtype) for _ in range(n_batch)]
    t_b_batch = [_make_tensor(n_elems, device, dtype) for _ in range(n_batch)]
    expected_batch = [t.clone() for t in t_b_batch]
    n_bytes = n_elems * dsize

    ep_a = TcpEndpoint(ip=ip_a, port=port_a)
    ep_b = TcpEndpoint(ip=ip_b, port=port_b)
    h_a_batch = [
        ep_a.register_memory_region(f"a_{i}", t.data_ptr(), 0, n_bytes)
        for i, t in enumerate(t_a_batch)
    ]
    for i, t in enumerate(t_b_batch):
        ep_b.register_memory_region(f"b_{i}", t.data_ptr(), 0, n_bytes)
    info_a = ep_a.endpoint_info()
    info_b = ep_b.endpoint_info()
    h_br_batch = [
        ep_a.register_remote_memory_region(f"rb_{i}", info_b["mr_info"][f"b_{i}"])
        for i in range(n_batch)
    ]

    reader_done = threading.Event()

    def run_a():
        try:
            ep_a.connect(info_b)
            assigns = [
                (h_a_batch[i], h_br_batch[i], i * dsize, i * dsize, dsize)
                for i in range(n_batch)
            ]
            _check_status("read batch", ep_a.async_read(assigns).wait())
            for i in range(n_batch):
                if not torch.equal(t_a_batch[i][i], expected_batch[i][i]):
                    raise RuntimeError(f"batch {i}: mismatch")
            ep_a.shutdown()
        finally:
            reader_done.set()

    def run_b():
        ep_b.connect(info_a)
        # Upper bound equals the time the original fixed sleeps added up to.
        reader_done.wait(3 + 2 * n_batch)
        ep_b.shutdown()

    _sync_run(f"test_torch_read_batch_{device}", run_a, run_b, raise_on_error=True)


# ── main ─────────────────────────────────────────────────


def _count_port_params(fn):
    """Return how many of ``port_a`` / ``port_b`` appear in ``fn``'s signature."""
    params = inspect.signature(fn).parameters
    return sum(1 for key in ("port_a", "port_b") if key in params)


def _get_ip_default(fn, ip_key, fallback="0.0.0.0"):
    """Return the default of parameter ``ip_key`` of ``fn``, or ``fallback``."""
    param = inspect.signature(fn).parameters.get(ip_key)
    if param is None or param.default is inspect.Parameter.empty:
        return fallback
    return param.default


def _port_free(ip, port):
    """Return True when a TCP socket can currently be bound to ``(ip, port)``."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        try:
            probe.bind((ip, port))
        except OSError:
            return False
    return True


def _find_free_ports(count, ips):
    """Return ``count`` distinct free TCP ports, one per entry of ``ips``.

    All probe sockets stay open until every port has been picked so the
    operating system cannot hand out the same port twice. The ports are
    released before returning, so another process may still grab one.
    """
    probes = []
    try:
        for ip in ips[:count]:
            probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            probes.append(probe)
            probe.bind((ip, 0))
        return [probe.getsockname()[1] for probe in probes]
    finally:
        for probe in probes:
            probe.close()


def _alloc_port_kwargs(fn, **overrides):
    """Build ``port_a`` / ``port_b`` keyword arguments for test function ``fn``.

    Ports given in ``overrides`` are checked for availability; the remaining
    ones are allocated dynamically.

    Raises:
        RuntimeError: If an explicitly requested port is already occupied.
    """
    n = _count_port_params(fn)
    if n == 0:
        return {}

    result = {}
    pending = []  # (port_key, ip) pairs needing dynamic allocation

    for side in ["a", "b"][:n]:
        port_key = f"port_{side}"
        ip_key = f"ip_{side}"
        ip = overrides.get(ip_key, _get_ip_default(fn, ip_key))

        if port_key not in overrides:
            pending.append((port_key, ip))
            continue

        if not _port_free(ip, overrides[port_key]):
            raise RuntimeError(
                f"Port {overrides[port_key]} on {ip} is occupied "
                f"({fn.__name__}, {port_key}={overrides[port_key]})"
            )
        result[port_key] = overrides[port_key]

    if pending:
        ports = _find_free_ports(len(pending), [ip for _, ip in pending])
        for (port_key, _), port in zip(pending, ports):
            result[port_key] = port

    return result


if __name__ == "__main__":
    _failures = []

    def _run(label, fn, **kwargs):
        # Keep going after a failure so one run reports every broken test.
        try:
            fn(**kwargs)
        except Exception as exc:
            _failures.append(f"{label}: {exc}")

    _ctypes_tests = [
        test_async_send_recv,
        test_async_send2recv,
        test_async_write,
        test_async_read,
    ]
    for fn in _ctypes_tests:
        _run(fn.__name__, fn, port_a=0, port_b=0)

    if not _torch_skip():
        device_list = ["cpu", "cuda"]
        if _cuda_skip():
            print("No Cuda, Cpu Only", flush=True)
            device_list = ["cpu"]

        _torch_tests = [
            test_torch_send_recv,
            test_torch_write,
            test_torch_read,
            test_torch_write_batch,
            test_torch_read_batch,
        ]
        for dev in device_list:
            for fn in _torch_tests:
                _run(f"{fn.__name__}_{dev}", fn, device=dev, port_a=0, port_b=0)

    if _failures:
        print(f"{len(_failures)} test(s) failed:", flush=True)
        for line in _failures:
            print(f"  {line}", flush=True)
        sys.exit(1)