Skip to content

Commit 02f205c

Browse files
committed
[Feat] Asu: use host-pinned memory with dual CPU/device addresses for transport buffers
## Purpose

Switch ASU send/flag buffers from plain host memory to host-pinned memory so that CPU code packs SQEs through the local mapping while HCOMM/RDMA uses the device-visible mapping of the same allocation.

## Modifications

1. BufferManager: allocate host-pinned memory via aclrtMallocHost + aclrtHostRegisterV2, obtain the device pointer via aclrtHostGetDevicePointer. ScatterGatherEntry gains device_addr. RegisterMemory uses the device address for host-pinned regions.
2. AsuTransportImpl: send/flag buffers use HOST_PINNED instead of HOST.
3. asu_submit_flow: pass device_addr to SendIoBatch.
4. sqe_request: use flagBuffer.device_addr for response_buffer_addr.
5. Move IsTransportBufferReady from asu_submit_flow to buffer_manager.
6. Tests: added host-pinned dual-address and device_addr assertions.

## Test

- buffer_manager_test: HostPinnedRegistersDeviceAddress.
- asu_submit_flow_test: BuildSubBatchSendBuffersUsesHostPinnedDeviceAddresses.
- sqe_request_test: packed response address matches device_addr.
1 parent 85c5094 commit 02f205c

12 files changed

Lines changed: 608 additions & 103 deletions

File tree

ucm/shared/trans/ascend/ascend_buffer.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,27 @@
2828

2929
namespace UC::Trans {
3030

31+
namespace {
32+
33+
constexpr std::uintptr_t HOST_REGISTER_PAGE_SIZE = 4096;
34+
35+
void FreeHostMemory(void* host)
36+
{
37+
auto ret = aclrtFreeHost(host);
38+
if (ret != ACL_SUCCESS) { UC_ERROR("Failed to free host memory addr={} ret={}", host, ret); }
39+
}
40+
41+
void ReleaseHostPinnedMemory(void* host)
42+
{
43+
auto ret = aclrtHostUnregister(host);
44+
if (ret != ACL_SUCCESS) {
45+
UC_ERROR("Failed to unregister host-pinned memory addr={} ret={}", host, ret);
46+
}
47+
FreeHostMemory(host);
48+
}
49+
50+
} // namespace
51+
3152
class HostHugePages : public std::enable_shared_from_this<HostHugePages> {
3253
struct ConstructorKey {};
3354
static constexpr auto HUGE_PAGE_SIZE = 2UL << 20;
@@ -124,6 +145,33 @@ std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer(size_t size)
124145
return nullptr;
125146
}
126147

148+
std::shared_ptr<void> Trans::AscendBuffer::MakeHostPinnedBuffer(size_t size, void** pDevice)
149+
{
150+
if (pDevice) { *pDevice = nullptr; }
151+
152+
void* host = nullptr;
153+
auto ret = aclrtMallocHost(&host, size);
154+
if (ret != ACL_SUCCESS) { return nullptr; }
155+
156+
if (reinterpret_cast<std::uintptr_t>(host) % HOST_REGISTER_PAGE_SIZE != 0) {
157+
UC_ERROR("Host-pinned memory is not 4K page aligned addr={} size={}", host, size);
158+
FreeHostMemory(host);
159+
return nullptr;
160+
}
161+
162+
void* device = nullptr;
163+
auto status = Buffer::RegisterHostBuffer(host, size, &device);
164+
if (status.Failure()) {
165+
UC_ERROR("Failed to register host-pinned memory addr={} size={} status={}", host, size,
166+
status);
167+
FreeHostMemory(host);
168+
return nullptr;
169+
}
170+
171+
if (pDevice) { *pDevice = device; }
172+
return std::shared_ptr<void>(host, ReleaseHostPinnedMemory);
173+
}
174+
127175
std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer4DirectIo(size_t size)
128176
{
129177
try {

ucm/shared/trans/ascend/ascend_buffer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class AscendBuffer : public ReservedBuffer {
3232
public:
3333
std::shared_ptr<void> MakeDeviceBuffer(size_t size) override;
3434
std::shared_ptr<void> MakeHostBuffer(size_t size) override;
35+
std::shared_ptr<void> MakeHostPinnedBuffer(size_t size, void** pDevice = nullptr);
3536
std::shared_ptr<void> MakeHostBuffer4DirectIo(size_t size) override;
3637
};
3738

ucm/transport/kv/asu/test/case/buffer_manager_test.cc

Lines changed: 107 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ TEST_F(BufferManagerTest, InitWithZeroSlotNum)
7171
ASSERT_EQ(status.code, StatusCode::INVALID_ARGUMENT);
7272
}
7373

74+
// Init must succeed for a HOST buffer configured with a slot capacity of 1000 and
// 10 slots; per the test name, 1000 is treated as an unaligned capacity.
TEST_F(BufferManagerTest, InitHostWithUnalignedSlotCapacity)
75+
{
76+
BufferManager mgr;
77+
auto status = mgr.Init("test_buffer", MemoryType::HOST, 1000, 10);
78+
ASSERT_TRUE(status.ok()) << status.message;
79+
}
80+
81+
// Same check as the HOST variant, but for MemoryType::ASCEND_DEVICE: Init must
// succeed with a slot capacity of 1000 and 10 slots.
TEST_F(BufferManagerTest, InitDeviceWithUnalignedSlotCapacity)
82+
{
83+
BufferManager mgr;
84+
auto status = mgr.Init("test_buffer", MemoryType::ASCEND_DEVICE, 1000, 10);
85+
ASSERT_TRUE(status.ok()) << status.message;
86+
}
87+
7488
TEST_F(BufferManagerTest, DoubleInit)
7589
{
7690
BufferManager mgr;
@@ -123,12 +137,13 @@ TEST_F(BufferManagerTest, SingleAllocateAndFree)
123137
ScatterGatherEntry sge;
124138
status = mgr.Allocate(64, sge);
125139
ASSERT_TRUE(status.ok()) << status.message;
126-
ASSERT_NE(sge.addr, 0);
140+
ASSERT_NE(sge.local_addr, 0);
127141
ASSERT_EQ(sge.length, 64);
128142
ASSERT_EQ(sge.tokenId, 0);
129143
ASSERT_NE(sge.slot_index, UINT32_MAX);
144+
ASSERT_EQ(sge.memory_type, MemoryType::HOST);
130145

131-
auto* ptr = reinterpret_cast<void*>(sge.addr);
146+
auto* ptr = reinterpret_cast<void*>(sge.local_addr);
132147
std::memset(ptr, 0xAB, 64);
133148

134149
status = mgr.Free(sge.slot_index);
@@ -147,12 +162,12 @@ TEST_F(BufferManagerTest, MultipleAllocatesAndFrees)
147162
for (int i = 0; i < kCount; ++i) {
148163
status = mgr.Allocate(128, sges[i]);
149164
ASSERT_TRUE(status.ok()) << "Failed at i=" << i << ": " << status.message;
150-
ASSERT_NE(sges[i].addr, 0);
151-
std::memset(reinterpret_cast<void*>(sges[i].addr), i, 128);
165+
ASSERT_NE(sges[i].local_addr, 0);
166+
std::memset(reinterpret_cast<void*>(sges[i].local_addr), i, 128);
152167
}
153168

154169
for (int i = 0; i < kCount; ++i) {
155-
auto* data = reinterpret_cast<unsigned char*>(sges[i].addr);
170+
auto* data = reinterpret_cast<unsigned char*>(sges[i].local_addr);
156171
for (int j = 0; j < 128; ++j) { ASSERT_EQ(data[j], static_cast<unsigned char>(i)); }
157172
}
158173

@@ -191,10 +206,61 @@ TEST_F(BufferManagerTest, AllocateFullSlotSize)
191206
status = mgr.Allocate(1024, sge);
192207
ASSERT_TRUE(status.ok()) << status.message;
193208
ASSERT_EQ(sge.length, 1024);
209+
}
210+
211+
TEST_F(BufferManagerTest, AllocateFull4160ByteSlotCapacity)
212+
{
213+
BufferManager mgr;
214+
auto status = mgr.Init("test_buffer", MemoryType::HOST, 4160, 10);
215+
ASSERT_TRUE(status.ok());
194216

195-
std::memset(reinterpret_cast<void*>(sge.addr), 0xFF, 1024);
217+
ScatterGatherEntry sge;
218+
status = mgr.Allocate(4160, sge);
219+
ASSERT_TRUE(status.ok()) << status.message;
220+
ASSERT_EQ(sge.length, 4160);
221+
}
196222

197-
mgr.Free(sge.slot_index);
223+
// A request one byte larger than the configured slot capacity (4161 > 4160) must be
// rejected with INVALID_ARGUMENT.
TEST_F(BufferManagerTest, AllocateExceeds4160ByteSlotCapacity)
224+
{
225+
BufferManager mgr;
226+
auto status = mgr.Init("test_buffer", MemoryType::HOST, 4160, 10);
227+
ASSERT_TRUE(status.ok());
228+
229+
ScatterGatherEntry sge;
230+
// 4161 is the smallest size that no longer fits into a 4160-byte slot.
status = mgr.Allocate(4161, sge);
231+
ASSERT_FALSE(status.ok());
232+
ASSERT_EQ(status.code, StatusCode::INVALID_ARGUMENT);
233+
}
234+
235+
// For HOST, HOST_PINNED and ASCEND_DEVICE buffers alike, two full-capacity
// allocations from a two-slot manager must be exactly 4160 bytes apart in both the
// local and the device address.
// NOTE(review): this relies on the two allocations landing in neighbouring slots in
// ascending order - confirm against the BufferManager allocation policy.
TEST_F(BufferManagerTest, AllMemoryTypesUseAlignedSlotStride)
236+
{
237+
for (const auto type : {MemoryType::HOST, MemoryType::HOST_PINNED, MemoryType::ASCEND_DEVICE}) {
238+
BufferManager mgr;
239+
auto status = mgr.Init("test_buffer", type, 4160, 2);
240+
ASSERT_TRUE(status.ok()) << status.message;
241+
242+
ScatterGatherEntry first;
243+
ScatterGatherEntry second;
244+
ASSERT_TRUE(mgr.Allocate(4160, first).ok());
245+
ASSERT_TRUE(mgr.Allocate(4160, second).ok());
246+
// The distance between the two slots is the stride being checked.
ASSERT_EQ(second.local_addr - first.local_addr, 4160);
247+
ASSERT_EQ(second.device_addr - first.device_addr, 4160);
248+
}
249+
}
250+
251+
// A HOST_PINNED flag buffer with a 71-byte slot capacity must still report 71 as the
// allocation length while consecutive slots sit 128 bytes apart, in both the local
// and the device address.
// NOTE(review): this relies on the two allocations landing in neighbouring slots in
// ascending order - confirm against the BufferManager allocation policy.
TEST_F(BufferManagerTest, FlagBufferCapacity71Uses128ByteStride)
252+
{
253+
BufferManager mgr;
254+
auto status = mgr.Init("flag_buffer", MemoryType::HOST_PINNED, 71, 2);
255+
ASSERT_TRUE(status.ok()) << status.message;
256+
257+
ScatterGatherEntry first;
258+
ScatterGatherEntry second;
259+
ASSERT_TRUE(mgr.Allocate(71, first).ok());
260+
ASSERT_TRUE(mgr.Allocate(71, second).ok());
261+
// The caller-visible length stays at the requested 71 bytes ...
ASSERT_EQ(first.length, 71);
262+
// ... while the distance between the two slots is 128 bytes.
ASSERT_EQ(second.local_addr - first.local_addr, 128);
263+
ASSERT_EQ(second.device_addr - first.device_addr, 128);
198264
}
199265

200266
TEST_F(BufferManagerTest, ReuseAfterFree)
@@ -212,7 +278,7 @@ TEST_F(BufferManagerTest, ReuseAfterFree)
212278
ScatterGatherEntry sge2;
213279
status = mgr.Allocate(64, sge2);
214280
ASSERT_TRUE(status.ok());
215-
ASSERT_EQ(sge2.addr, sge1.addr);
281+
ASSERT_EQ(sge2.local_addr, sge1.local_addr);
216282
ASSERT_EQ(sge2.slot_index, sge1.slot_index);
217283

218284
mgr.Free(sge2.slot_index);
@@ -233,7 +299,7 @@ TEST_F(BufferManagerTest, ConcurrentAllocateAndFree)
233299
auto s = mgr.Allocate(64, sge);
234300
ASSERT_TRUE(s.ok()) << "Thread " << thread_id << " op " << i << ": " << s.message;
235301

236-
std::memset(reinterpret_cast<void*>(sge.addr), thread_id, 64);
302+
std::memset(reinterpret_cast<void*>(sge.local_addr), thread_id, 64);
237303

238304
s = mgr.Free(sge.slot_index);
239305
ASSERT_TRUE(s.ok()) << s.message;
@@ -260,10 +326,10 @@ TEST_F(BufferManagerTest, ConcurrentStressTest)
260326
auto s = mgr.Allocate(128, sge);
261327
ASSERT_TRUE(s.ok());
262328

263-
std::memset(reinterpret_cast<void*>(sge.addr), thread_id, 128);
329+
std::memset(reinterpret_cast<void*>(sge.local_addr), thread_id, 128);
264330

265331
for (int j = 0; j < 128; ++j) {
266-
ASSERT_EQ(reinterpret_cast<unsigned char*>(sge.addr)[j], thread_id);
332+
ASSERT_EQ(reinterpret_cast<unsigned char*>(sge.local_addr)[j], thread_id);
267333
}
268334

269335
s = mgr.Free(sge.slot_index);
@@ -286,7 +352,7 @@ TEST_F(BufferManagerTest, FreeZeroesMemory)
286352
status = mgr.Allocate(64, sge1);
287353
ASSERT_TRUE(status.ok());
288354

289-
auto* ptr = reinterpret_cast<uint8_t*>(sge1.addr);
355+
auto* ptr = reinterpret_cast<uint8_t*>(sge1.local_addr);
290356
std::memset(ptr, 0xAB, 1024);
291357

292358
status = mgr.Free(sge1.slot_index);
@@ -295,10 +361,10 @@ TEST_F(BufferManagerTest, FreeZeroesMemory)
295361
ScatterGatherEntry sge2;
296362
status = mgr.Allocate(64, sge2);
297363
ASSERT_TRUE(status.ok());
298-
ASSERT_EQ(sge2.addr, sge1.addr);
364+
ASSERT_EQ(sge2.local_addr, sge1.local_addr);
299365
ASSERT_EQ(sge2.slot_index, sge1.slot_index);
300366

301-
auto* ptr2 = reinterpret_cast<uint8_t*>(sge2.addr);
367+
auto* ptr2 = reinterpret_cast<uint8_t*>(sge2.local_addr);
302368
for (size_t i = 0; i < 1024; ++i) {
303369
ASSERT_EQ(ptr2[i], 0) << "byte " << i << " not zeroed after free";
304370
}
@@ -395,6 +461,33 @@ TEST_F(BufferManagerTest, InitWithProviderRegistersMemory)
395461
ASSERT_NE(provider.lastAddr, 0);
396462
ASSERT_EQ(provider.lastSize, 1024 * 10);
397463
ASSERT_EQ(mgr.GetTokenId(), 42);
464+
465+
ScatterGatherEntry sge;
466+
ASSERT_TRUE(mgr.Allocate(64, sge).ok());
467+
ASSERT_EQ(sge.local_addr, sge.device_addr);
468+
}
469+
470+
// With a provider attached, a HOST_PINNED buffer must be registered exactly once as
// MEM_DEVICE using the device address, and an allocated entry must expose two
// distinct non-zero addresses: a 4096-aligned local_addr and the device_addr that
// was handed to the provider.
TEST_F(BufferManagerTest, HostPinnedRegistersDeviceAddress)
471+
{
472+
StubTransProvider provider;
473+
474+
BufferManager mgr;
475+
auto status = mgr.Init("test_rdma_pinned", MemoryType::HOST_PINNED, 4096, 1, &provider);
476+
ASSERT_TRUE(status.ok()) << status.message;
477+
ASSERT_EQ(provider.registerCount, 1);
478+
ASSERT_EQ(provider.lastMemType, TransProvider::MemType::MEM_DEVICE);
479+
480+
ScatterGatherEntry sge;
481+
ASSERT_TRUE(mgr.Allocate(64, sge).ok());
482+
ASSERT_NE(sge.local_addr, 0);
483+
ASSERT_NE(sge.device_addr, 0);
484+
ASSERT_NE(sge.local_addr, sge.device_addr);
485+
ASSERT_EQ(sge.local_addr % 4096, 0);
486+
// With a single slot, the entry's device address is the address that was registered.
ASSERT_EQ(provider.lastAddr, sge.device_addr);
487+
488+
// The CPU writes through local_addr while HCOMM and remote RDMA use device_addr.
489+
std::memset(reinterpret_cast<void*>(sge.local_addr), 0x5A, sge.length);
490+
ASSERT_EQ(*reinterpret_cast<unsigned char*>(sge.local_addr), 0x5A);
398491
}
399492

400493
TEST_F(BufferManagerTest, InitWithProviderAllocateReturnsTokenId)

ucm/transport/kv/asu/trans/include/asu_transport/asu_transport.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ struct TransportConfig {
7272
bool preconnect{true};
7373
bool bindCqPoller{true};
7474

75+
// Slot sizes are caller-visible capacities; BufferManager computes the
76+
// aligned physical stride used for allocation and memory registration.
7577
std::size_t sendBufferSlotSize{4160};
7678
std::size_t sendBufferSlotNum{128};
77-
std::size_t flagBufferSlotSize{128};
79+
// Maximum memory required by a batch store/retrieve response flag buffer.
80+
std::size_t flagBufferSlotSize{71};
7881
std::size_t flagBufferSlotNum{4096};
7982
std::size_t asuBatchLoadIoNum{110};
8083
std::size_t asuBatchStoreIoNum{110};

ucm/transport/kv/asu/trans/src/asu_submit_flow.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,24 +130,31 @@ Status AsuTransportImpl::BuildSubBatchSendBuffers(
130130
continue;
131131
}
132132

133-
if (subBatchContext.flagBuffer.addr == 0 || subBatchContext.flagBuffer.length == 0) {
134-
const auto subBatchStatus =
135-
Status::Error(StatusCode::NOT_INITIALIZED, "sub-batch flag buffer is not ready");
133+
if (subBatchContext.channel == nullptr ||
134+
!IsTransportBufferReady(subBatchContext.sendSge) ||
135+
!IsTransportBufferReady(subBatchContext.flagBuffer)) {
136+
const auto subBatchStatus = Status::Error(StatusCode::NOT_INITIALIZED,
137+
"sub-batch transport buffers are not ready");
136138
UC_ERROR(
137-
"Sub-batch flag buffer is not ready index={} cid={} flag_addr={} flag_length={}",
138-
index, subBatchContext.cid, subBatchContext.flagBuffer.addr,
139-
subBatchContext.flagBuffer.length);
139+
"Sub-batch transport buffers are not ready index={} cid={} channel={} "
140+
"send_local_addr={} send_device_addr={} send_length={} send_slot={} "
141+
"flag_local_addr={} flag_device_addr={} flag_length={} flag_slot={}",
142+
index, subBatchContext.cid, subBatchContext.channel != nullptr,
143+
subBatchContext.sendSge.local_addr, subBatchContext.sendSge.device_addr,
144+
subBatchContext.sendSge.length, subBatchContext.sendSge.slot_index,
145+
subBatchContext.flagBuffer.local_addr, subBatchContext.flagBuffer.device_addr,
146+
subBatchContext.flagBuffer.length, subBatchContext.flagBuffer.slot_index);
140147
SetSubBatchSendFailed(subBatchContext, subBatchStatus);
141148
if (status.ok()) { status = subBatchStatus; }
142149
ReleaseSubBatchResources(subBatchContext);
143150
continue;
144151
}
145152

146-
ioBatches.push_back(
147-
TransProvider::SendIoBatch{subBatchContext.channel->GetConnection(),
148-
reinterpret_cast<void*>(subBatchContext.sendSge.addr),
149-
reinterpret_cast<void*>(subBatchContext.flagBuffer.addr),
150-
subBatchContext.sendSge.length});
153+
ioBatches.push_back(TransProvider::SendIoBatch{
154+
subBatchContext.channel->GetConnection(),
155+
reinterpret_cast<void*>(subBatchContext.sendSge.device_addr),
156+
reinterpret_cast<void*>(subBatchContext.flagBuffer.device_addr),
157+
subBatchContext.sendSge.length});
151158
subBatchIndexes.emplace_back(index);
152159
}
153160

0 commit comments

Comments
 (0)