Skip to content

Commit d925f83

Browse files
committed
[Feat] Asu: use host-pinned memory with dual CPU/device addresses for transport buffers
## Purpose

Switch ASU send/flag buffers from plain host memory to host-pinned memory so that CPU code packs SQEs through the local mapping while HCOMM/RDMA uses the device-visible mapping of the same allocation.

## Modifications

1. BufferManager: allocate host-pinned memory via aclrtMallocHost + aclrtHostRegisterV2, and obtain the device pointer via aclrtHostGetDevicePointer. ScatterGatherEntry gains device_addr. RegisterMemory uses the device address for host-pinned regions.
2. AsuTransportImpl: send/flag buffers use HOST_PINNED instead of HOST.
3. asu_submit_flow: pass device_addr to SendIoBatch.
4. sqe_request: use flagBuffer.device_addr for response_buffer_addr.
5. Move IsTransportBufferReady from asu_submit_flow to buffer_manager.
6. Tests: added host-pinned dual-address and device_addr assertions.

## Test

- buffer_manager_test: HostPinnedRegistersDeviceAddress.
- asu_submit_flow_test: BuildSubBatchSendBuffersUsesHostPinnedDeviceAddresses.
- sqe_request_test: packed response address matches device_addr.
1 parent 85c5094 commit d925f83

14 files changed

Lines changed: 624 additions & 103 deletions

ucm/shared/trans/ascend/ascend_buffer.cc

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,39 @@
2323
* */
2424
#include "ascend_buffer.h"
2525
#include <acl/acl.h>
26+
#include <limits>
2627
#include <sys/mman.h>
2728
#include "logger/logger.h"
2829

2930
namespace UC::Trans {
3031

32+
namespace {
33+
34+
constexpr std::uintptr_t HOST_REGISTER_PAGE_SIZE = 4096;
35+
36+
void FreeHostMemory(void* host)
37+
{
38+
auto ret = aclrtFreeHost(host);
39+
if (ret != ACL_SUCCESS) { UC_ERROR("Failed to free host memory addr={} ret={}", host, ret); }
40+
}
41+
42+
// Rounds `ptr` up to the next multiple of `alignment` (any non-zero value, not
// only powers of two). An already aligned pointer is returned unchanged.
// An alignment of 0 carries no constraint, so `ptr` is returned as-is instead
// of dividing by zero.
void* AlignUp(void* ptr, std::uintptr_t alignment)
{
    if (alignment == 0) { return ptr; }

    const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
    const auto remainder = addr % alignment;
    if (remainder == 0) { return ptr; }

    // Add only the missing distance, so the arithmetic wraps only when the
    // aligned address itself is not representable.
    return reinterpret_cast<void*>(addr + (alignment - remainder));
}
47+
48+
void ReleaseHostPinnedMemory(void* registeredHost, void* allocatedHost)
49+
{
50+
auto ret = aclrtHostUnregister(registeredHost);
51+
if (ret != ACL_SUCCESS) {
52+
UC_ERROR("Failed to unregister host-pinned memory addr={} ret={}", registeredHost, ret);
53+
}
54+
FreeHostMemory(allocatedHost);
55+
}
56+
57+
} // namespace
58+
3159
class HostHugePages : public std::enable_shared_from_this<HostHugePages> {
3260
struct ConstructorKey {};
3361
static constexpr auto HUGE_PAGE_SIZE = 2UL << 20;
@@ -124,6 +152,35 @@ std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer(size_t size)
124152
return nullptr;
125153
}
126154

155+
std::shared_ptr<void> Trans::AscendBuffer::MakeHostPinnedBuffer(size_t size, void** pDevice)
156+
{
157+
if (pDevice) { *pDevice = nullptr; }
158+
159+
constexpr auto kMaxSize = std::numeric_limits<size_t>::max();
160+
if (size > kMaxSize - (HOST_REGISTER_PAGE_SIZE - 1)) { return nullptr; }
161+
162+
void* allocatedHost = nullptr;
163+
const auto allocationSize = size + HOST_REGISTER_PAGE_SIZE - 1;
164+
auto ret = aclrtMallocHost(&allocatedHost, allocationSize);
165+
if (ret != ACL_SUCCESS) { return nullptr; }
166+
167+
void* host = AlignUp(allocatedHost, HOST_REGISTER_PAGE_SIZE);
168+
169+
void* device = nullptr;
170+
auto status = Buffer::RegisterHostBuffer(host, size, &device);
171+
if (status.Failure()) {
172+
UC_ERROR("Failed to register host-pinned memory addr={} size={} status={}", host, size,
173+
status);
174+
FreeHostMemory(allocatedHost);
175+
return nullptr;
176+
}
177+
178+
if (pDevice) { *pDevice = device; }
179+
return std::shared_ptr<void>(host, [allocatedHost](void* registeredHost) {
180+
ReleaseHostPinnedMemory(registeredHost, allocatedHost);
181+
});
182+
}
183+
127184
std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer4DirectIo(size_t size)
128185
{
129186
try {

ucm/shared/trans/ascend/ascend_buffer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class AscendBuffer : public ReservedBuffer {
3232
public:
3333
std::shared_ptr<void> MakeDeviceBuffer(size_t size) override;
3434
std::shared_ptr<void> MakeHostBuffer(size_t size) override;
35+
std::shared_ptr<void> MakeHostPinnedBuffer(size_t size, void** pDevice = nullptr) override;
3536
std::shared_ptr<void> MakeHostBuffer4DirectIo(size_t size) override;
3637
};
3738

ucm/shared/trans/buffer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class Buffer {
3939

4040
virtual std::shared_ptr<void> MakeHostBuffer(size_t size) = 0;
4141
virtual std::shared_ptr<void> MakeHostBuffer4DirectIo(size_t size) = 0;
42+
virtual std::shared_ptr<void> MakeHostPinnedBuffer(size_t size, void** pDevice = nullptr) = 0;
4243
virtual Status MakeHostBuffers(size_t size, size_t number) = 0;
4344
virtual std::shared_ptr<void> GetHostBuffer(size_t size) = 0;
4445

ucm/shared/trans/detail/reserved_buffer.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ class ReservedBuffer : public Buffer {
7878
return this->MakeHostBuffer(size);
7979
}
8080

81+
std::shared_ptr<void> MakeHostPinnedBuffer(size_t size, void** pDevice = nullptr) override
82+
{
83+
if (pDevice) { *pDevice = nullptr; }
84+
return this->MakeHostBuffer(size);
85+
}
86+
8187
Status MakeHostBuffers(size_t size, size_t number) override
8288
{
8389
auto totalSize = size * number;

ucm/transport/kv/asu/test/case/buffer_manager_test.cc

Lines changed: 107 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ TEST_F(BufferManagerTest, InitWithZeroSlotNum)
7171
ASSERT_EQ(status.code, StatusCode::INVALID_ARGUMENT);
7272
}
7373

74+
// Init must succeed for a HOST pool of 10 slots with a 1000-byte slot capacity.
TEST_F(BufferManagerTest, InitHostWithUnalignedSlotCapacity)
{
    constexpr int kSlotCapacity = 1000;
    constexpr int kSlotCount = 10;

    BufferManager manager;
    auto initStatus = manager.Init("test_buffer", MemoryType::HOST, kSlotCapacity, kSlotCount);
    ASSERT_TRUE(initStatus.ok()) << initStatus.message;
}
80+
81+
// Init must succeed for an ASCEND_DEVICE pool of 10 slots with a 1000-byte
// slot capacity.
TEST_F(BufferManagerTest, InitDeviceWithUnalignedSlotCapacity)
{
    constexpr int kSlotCapacity = 1000;
    constexpr int kSlotCount = 10;

    BufferManager manager;
    auto initStatus =
        manager.Init("test_buffer", MemoryType::ASCEND_DEVICE, kSlotCapacity, kSlotCount);
    ASSERT_TRUE(initStatus.ok()) << initStatus.message;
}
87+
7488
TEST_F(BufferManagerTest, DoubleInit)
7589
{
7690
BufferManager mgr;
@@ -123,12 +137,13 @@ TEST_F(BufferManagerTest, SingleAllocateAndFree)
123137
ScatterGatherEntry sge;
124138
status = mgr.Allocate(64, sge);
125139
ASSERT_TRUE(status.ok()) << status.message;
126-
ASSERT_NE(sge.addr, 0);
140+
ASSERT_NE(sge.local_addr, 0);
127141
ASSERT_EQ(sge.length, 64);
128142
ASSERT_EQ(sge.tokenId, 0);
129143
ASSERT_NE(sge.slot_index, UINT32_MAX);
144+
ASSERT_EQ(sge.memory_type, MemoryType::HOST);
130145

131-
auto* ptr = reinterpret_cast<void*>(sge.addr);
146+
auto* ptr = reinterpret_cast<void*>(sge.local_addr);
132147
std::memset(ptr, 0xAB, 64);
133148

134149
status = mgr.Free(sge.slot_index);
@@ -147,12 +162,12 @@ TEST_F(BufferManagerTest, MultipleAllocatesAndFrees)
147162
for (int i = 0; i < kCount; ++i) {
148163
status = mgr.Allocate(128, sges[i]);
149164
ASSERT_TRUE(status.ok()) << "Failed at i=" << i << ": " << status.message;
150-
ASSERT_NE(sges[i].addr, 0);
151-
std::memset(reinterpret_cast<void*>(sges[i].addr), i, 128);
165+
ASSERT_NE(sges[i].local_addr, 0);
166+
std::memset(reinterpret_cast<void*>(sges[i].local_addr), i, 128);
152167
}
153168

154169
for (int i = 0; i < kCount; ++i) {
155-
auto* data = reinterpret_cast<unsigned char*>(sges[i].addr);
170+
auto* data = reinterpret_cast<unsigned char*>(sges[i].local_addr);
156171
for (int j = 0; j < 128; ++j) { ASSERT_EQ(data[j], static_cast<unsigned char>(i)); }
157172
}
158173

@@ -191,10 +206,61 @@ TEST_F(BufferManagerTest, AllocateFullSlotSize)
191206
status = mgr.Allocate(1024, sge);
192207
ASSERT_TRUE(status.ok()) << status.message;
193208
ASSERT_EQ(sge.length, 1024);
209+
}
210+
211+
// A request for exactly the configured slot capacity (4160 bytes) must succeed
// and report that length in the returned entry.
TEST_F(BufferManagerTest, AllocateFull4160ByteSlotCapacity)
{
    constexpr int kSlotCapacity = 4160;
    constexpr int kSlotCount = 10;

    BufferManager manager;
    auto initStatus = manager.Init("test_buffer", MemoryType::HOST, kSlotCapacity, kSlotCount);
    ASSERT_TRUE(initStatus.ok());

    ScatterGatherEntry entry;
    auto allocStatus = manager.Allocate(kSlotCapacity, entry);
    ASSERT_TRUE(allocStatus.ok()) << allocStatus.message;
    ASSERT_EQ(entry.length, kSlotCapacity);
}
196222

197-
mgr.Free(sge.slot_index);
223+
// A request one byte larger than the configured slot capacity (4160 bytes)
// must be rejected with INVALID_ARGUMENT.
TEST_F(BufferManagerTest, AllocateExceeds4160ByteSlotCapacity)
{
    constexpr int kSlotCapacity = 4160;
    constexpr int kSlotCount = 10;

    BufferManager manager;
    auto initStatus = manager.Init("test_buffer", MemoryType::HOST, kSlotCapacity, kSlotCount);
    ASSERT_TRUE(initStatus.ok());

    ScatterGatherEntry entry;
    auto allocStatus = manager.Allocate(kSlotCapacity + 1, entry);
    ASSERT_FALSE(allocStatus.ok());
    ASSERT_EQ(allocStatus.code, StatusCode::INVALID_ARGUMENT);
}
234+
235+
// For HOST, HOST_PINNED and ASCEND_DEVICE pools with a 4160-byte slot capacity,
// two consecutive allocations must be exactly 4160 bytes apart in both the
// local and the device address.
TEST_F(BufferManagerTest, AllMemoryTypesUseAlignedSlotStride)
{
    const MemoryType memoryTypes[] = {MemoryType::HOST, MemoryType::HOST_PINNED,
                                      MemoryType::ASCEND_DEVICE};

    for (const auto memoryType : memoryTypes) {
        BufferManager manager;
        auto initStatus = manager.Init("test_buffer", memoryType, 4160, 2);
        ASSERT_TRUE(initStatus.ok()) << initStatus.message;

        ScatterGatherEntry firstSlot;
        ScatterGatherEntry secondSlot;
        ASSERT_TRUE(manager.Allocate(4160, firstSlot).ok());
        ASSERT_TRUE(manager.Allocate(4160, secondSlot).ok());

        ASSERT_EQ(secondSlot.local_addr - firstSlot.local_addr, 4160);
        ASSERT_EQ(secondSlot.device_addr - firstSlot.device_addr, 4160);
    }
}
250+
251+
// With a 71-byte HOST_PINNED slot capacity the entry still reports length 71,
// while consecutive slots are expected to be 128 bytes apart in both address
// spaces.
TEST_F(BufferManagerTest, FlagBufferCapacity71Uses128ByteStride)
{
    BufferManager manager;
    auto initStatus = manager.Init("flag_buffer", MemoryType::HOST_PINNED, 71, 2);
    ASSERT_TRUE(initStatus.ok()) << initStatus.message;

    ScatterGatherEntry firstSlot;
    ScatterGatherEntry secondSlot;
    ASSERT_TRUE(manager.Allocate(71, firstSlot).ok());
    ASSERT_TRUE(manager.Allocate(71, secondSlot).ok());

    ASSERT_EQ(firstSlot.length, 71);
    ASSERT_EQ(secondSlot.local_addr - firstSlot.local_addr, 128);
    ASSERT_EQ(secondSlot.device_addr - firstSlot.device_addr, 128);
}
199265

200266
TEST_F(BufferManagerTest, ReuseAfterFree)
@@ -212,7 +278,7 @@ TEST_F(BufferManagerTest, ReuseAfterFree)
212278
ScatterGatherEntry sge2;
213279
status = mgr.Allocate(64, sge2);
214280
ASSERT_TRUE(status.ok());
215-
ASSERT_EQ(sge2.addr, sge1.addr);
281+
ASSERT_EQ(sge2.local_addr, sge1.local_addr);
216282
ASSERT_EQ(sge2.slot_index, sge1.slot_index);
217283

218284
mgr.Free(sge2.slot_index);
@@ -233,7 +299,7 @@ TEST_F(BufferManagerTest, ConcurrentAllocateAndFree)
233299
auto s = mgr.Allocate(64, sge);
234300
ASSERT_TRUE(s.ok()) << "Thread " << thread_id << " op " << i << ": " << s.message;
235301

236-
std::memset(reinterpret_cast<void*>(sge.addr), thread_id, 64);
302+
std::memset(reinterpret_cast<void*>(sge.local_addr), thread_id, 64);
237303

238304
s = mgr.Free(sge.slot_index);
239305
ASSERT_TRUE(s.ok()) << s.message;
@@ -260,10 +326,10 @@ TEST_F(BufferManagerTest, ConcurrentStressTest)
260326
auto s = mgr.Allocate(128, sge);
261327
ASSERT_TRUE(s.ok());
262328

263-
std::memset(reinterpret_cast<void*>(sge.addr), thread_id, 128);
329+
std::memset(reinterpret_cast<void*>(sge.local_addr), thread_id, 128);
264330

265331
for (int j = 0; j < 128; ++j) {
266-
ASSERT_EQ(reinterpret_cast<unsigned char*>(sge.addr)[j], thread_id);
332+
ASSERT_EQ(reinterpret_cast<unsigned char*>(sge.local_addr)[j], thread_id);
267333
}
268334

269335
s = mgr.Free(sge.slot_index);
@@ -286,7 +352,7 @@ TEST_F(BufferManagerTest, FreeZeroesMemory)
286352
status = mgr.Allocate(64, sge1);
287353
ASSERT_TRUE(status.ok());
288354

289-
auto* ptr = reinterpret_cast<uint8_t*>(sge1.addr);
355+
auto* ptr = reinterpret_cast<uint8_t*>(sge1.local_addr);
290356
std::memset(ptr, 0xAB, 1024);
291357

292358
status = mgr.Free(sge1.slot_index);
@@ -295,10 +361,10 @@ TEST_F(BufferManagerTest, FreeZeroesMemory)
295361
ScatterGatherEntry sge2;
296362
status = mgr.Allocate(64, sge2);
297363
ASSERT_TRUE(status.ok());
298-
ASSERT_EQ(sge2.addr, sge1.addr);
364+
ASSERT_EQ(sge2.local_addr, sge1.local_addr);
299365
ASSERT_EQ(sge2.slot_index, sge1.slot_index);
300366

301-
auto* ptr2 = reinterpret_cast<uint8_t*>(sge2.addr);
367+
auto* ptr2 = reinterpret_cast<uint8_t*>(sge2.local_addr);
302368
for (size_t i = 0; i < 1024; ++i) {
303369
ASSERT_EQ(ptr2[i], 0) << "byte " << i << " not zeroed after free";
304370
}
@@ -395,6 +461,33 @@ TEST_F(BufferManagerTest, InitWithProviderRegistersMemory)
395461
ASSERT_NE(provider.lastAddr, 0);
396462
ASSERT_EQ(provider.lastSize, 1024 * 10);
397463
ASSERT_EQ(mgr.GetTokenId(), 42);
464+
465+
ScatterGatherEntry sge;
466+
ASSERT_TRUE(mgr.Allocate(64, sge).ok());
467+
ASSERT_EQ(sge.local_addr, sge.device_addr);
468+
}
469+
470+
// A HOST_PINNED pool initialised with a provider must register its memory once
// as MEM_DEVICE using the device address, and hand out entries whose local and
// device addresses are both non-zero and distinct.
TEST_F(BufferManagerTest, HostPinnedRegistersDeviceAddress)
{
    StubTransProvider stubProvider;

    BufferManager manager;
    auto initStatus =
        manager.Init("test_rdma_pinned", MemoryType::HOST_PINNED, 4096, 1, &stubProvider);
    ASSERT_TRUE(initStatus.ok()) << initStatus.message;
    ASSERT_EQ(stubProvider.registerCount, 1);
    ASSERT_EQ(stubProvider.lastMemType, TransProvider::MemType::MEM_DEVICE);

    ScatterGatherEntry entry;
    ASSERT_TRUE(manager.Allocate(64, entry).ok());
    ASSERT_NE(entry.local_addr, 0);
    ASSERT_NE(entry.device_addr, 0);
    ASSERT_NE(entry.local_addr, entry.device_addr);
    ASSERT_EQ(entry.local_addr % 4096, 0);
    ASSERT_EQ(stubProvider.lastAddr, entry.device_addr);

    // The CPU writes through local_addr while HCOMM and remote RDMA use device_addr.
    auto* cpuView = reinterpret_cast<unsigned char*>(entry.local_addr);
    std::memset(cpuView, 0x5A, entry.length);
    ASSERT_EQ(*cpuView, 0x5A);
}
399492

400493
TEST_F(BufferManagerTest, InitWithProviderAllocateReturnsTokenId)

ucm/transport/kv/asu/trans/include/asu_transport/asu_transport.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ struct TransportConfig {
7272
bool preconnect{true};
7373
bool bindCqPoller{true};
7474

75+
// Slot sizes are caller-visible capacities; BufferManager computes the
76+
// aligned physical stride used for allocation and memory registration.
7577
std::size_t sendBufferSlotSize{4160};
7678
std::size_t sendBufferSlotNum{128};
77-
std::size_t flagBufferSlotSize{128};
79+
// Maximum memory required by a batch store/retrieve response flag buffer.
80+
std::size_t flagBufferSlotSize{71};
7881
std::size_t flagBufferSlotNum{4096};
7982
std::size_t asuBatchLoadIoNum{110};
8083
std::size_t asuBatchStoreIoNum{110};

ucm/transport/kv/asu/trans/src/asu_submit_flow.cpp

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,24 +130,31 @@ Status AsuTransportImpl::BuildSubBatchSendBuffers(
130130
continue;
131131
}
132132

133-
if (subBatchContext.flagBuffer.addr == 0 || subBatchContext.flagBuffer.length == 0) {
134-
const auto subBatchStatus =
135-
Status::Error(StatusCode::NOT_INITIALIZED, "sub-batch flag buffer is not ready");
133+
if (subBatchContext.channel == nullptr ||
134+
!IsTransportBufferReady(subBatchContext.sendSge) ||
135+
!IsTransportBufferReady(subBatchContext.flagBuffer)) {
136+
const auto subBatchStatus = Status::Error(StatusCode::NOT_INITIALIZED,
137+
"sub-batch transport buffers are not ready");
136138
UC_ERROR(
137-
"Sub-batch flag buffer is not ready index={} cid={} flag_addr={} flag_length={}",
138-
index, subBatchContext.cid, subBatchContext.flagBuffer.addr,
139-
subBatchContext.flagBuffer.length);
139+
"Sub-batch transport buffers are not ready index={} cid={} channel={} "
140+
"send_local_addr={} send_device_addr={} send_length={} send_slot={} "
141+
"flag_local_addr={} flag_device_addr={} flag_length={} flag_slot={}",
142+
index, subBatchContext.cid, subBatchContext.channel != nullptr,
143+
subBatchContext.sendSge.local_addr, subBatchContext.sendSge.device_addr,
144+
subBatchContext.sendSge.length, subBatchContext.sendSge.slot_index,
145+
subBatchContext.flagBuffer.local_addr, subBatchContext.flagBuffer.device_addr,
146+
subBatchContext.flagBuffer.length, subBatchContext.flagBuffer.slot_index);
140147
SetSubBatchSendFailed(subBatchContext, subBatchStatus);
141148
if (status.ok()) { status = subBatchStatus; }
142149
ReleaseSubBatchResources(subBatchContext);
143150
continue;
144151
}
145152

146-
ioBatches.push_back(
147-
TransProvider::SendIoBatch{subBatchContext.channel->GetConnection(),
148-
reinterpret_cast<void*>(subBatchContext.sendSge.addr),
149-
reinterpret_cast<void*>(subBatchContext.flagBuffer.addr),
150-
subBatchContext.sendSge.length});
153+
ioBatches.push_back(TransProvider::SendIoBatch{
154+
subBatchContext.channel->GetConnection(),
155+
reinterpret_cast<void*>(subBatchContext.sendSge.device_addr),
156+
reinterpret_cast<void*>(subBatchContext.flagBuffer.device_addr),
157+
subBatchContext.sendSge.length});
151158
subBatchIndexes.emplace_back(index);
152159
}
153160

0 commit comments

Comments
 (0)