Skip to content

Commit 88298b3

Browse files
committed
[Feat] Asu: use host-pinned transport buffers
1 parent d0b4363 commit 88298b3

10 files changed

Lines changed: 437 additions & 99 deletions

File tree

ucm/transport/kv/asu/test/case/buffer_manager_test.cc

Lines changed: 107 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ TEST_F(BufferManagerTest, InitWithZeroSlotNum)
7171
ASSERT_EQ(status.code, StatusCode::INVALID_ARGUMENT);
7272
}
7373

74+
TEST_F(BufferManagerTest, InitHostWithUnalignedSlotCapacity)
75+
{
76+
BufferManager mgr;
77+
auto status = mgr.Init("test_buffer", MemoryType::HOST, 1000, 10);
78+
ASSERT_TRUE(status.ok()) << status.message;
79+
}
80+
81+
TEST_F(BufferManagerTest, InitDeviceWithUnalignedSlotCapacity)
82+
{
83+
BufferManager mgr;
84+
auto status = mgr.Init("test_buffer", MemoryType::ASCEND_DEVICE, 1000, 10);
85+
ASSERT_TRUE(status.ok()) << status.message;
86+
}
87+
7488
TEST_F(BufferManagerTest, DoubleInit)
7589
{
7690
BufferManager mgr;
@@ -123,12 +137,12 @@ TEST_F(BufferManagerTest, SingleAllocateAndFree)
123137
ScatterGatherEntry sge;
124138
status = mgr.Allocate(64, sge);
125139
ASSERT_TRUE(status.ok()) << status.message;
126-
ASSERT_NE(sge.addr, 0);
140+
ASSERT_NE(sge.local_addr, 0);
127141
ASSERT_EQ(sge.length, 64);
128142
ASSERT_EQ(sge.tokenId, 0);
129143
ASSERT_NE(sge.slot_index, UINT32_MAX);
130144

131-
auto* ptr = reinterpret_cast<void*>(sge.addr);
145+
auto* ptr = reinterpret_cast<void*>(sge.local_addr);
132146
std::memset(ptr, 0xAB, 64);
133147

134148
status = mgr.Free(sge.slot_index);
@@ -147,12 +161,12 @@ TEST_F(BufferManagerTest, MultipleAllocatesAndFrees)
147161
for (int i = 0; i < kCount; ++i) {
148162
status = mgr.Allocate(128, sges[i]);
149163
ASSERT_TRUE(status.ok()) << "Failed at i=" << i << ": " << status.message;
150-
ASSERT_NE(sges[i].addr, 0);
151-
std::memset(reinterpret_cast<void*>(sges[i].addr), i, 128);
164+
ASSERT_NE(sges[i].local_addr, 0);
165+
std::memset(reinterpret_cast<void*>(sges[i].local_addr), i, 128);
152166
}
153167

154168
for (int i = 0; i < kCount; ++i) {
155-
auto* data = reinterpret_cast<unsigned char*>(sges[i].addr);
169+
auto* data = reinterpret_cast<unsigned char*>(sges[i].local_addr);
156170
for (int j = 0; j < 128; ++j) { ASSERT_EQ(data[j], static_cast<unsigned char>(i)); }
157171
}
158172

@@ -191,10 +205,61 @@ TEST_F(BufferManagerTest, AllocateFullSlotSize)
191205
status = mgr.Allocate(1024, sge);
192206
ASSERT_TRUE(status.ok()) << status.message;
193207
ASSERT_EQ(sge.length, 1024);
208+
}
209+
210+
TEST_F(BufferManagerTest, AllocateFull4160ByteSlotCapacity)
211+
{
212+
BufferManager mgr;
213+
auto status = mgr.Init("test_buffer", MemoryType::HOST, 4160, 10);
214+
ASSERT_TRUE(status.ok());
194215

195-
std::memset(reinterpret_cast<void*>(sge.addr), 0xFF, 1024);
216+
ScatterGatherEntry sge;
217+
status = mgr.Allocate(4160, sge);
218+
ASSERT_TRUE(status.ok()) << status.message;
219+
ASSERT_EQ(sge.length, 4160);
220+
}
196221

197-
mgr.Free(sge.slot_index);
222+
TEST_F(BufferManagerTest, AllocateExceeds4160ByteSlotCapacity)
223+
{
224+
BufferManager mgr;
225+
auto status = mgr.Init("test_buffer", MemoryType::HOST, 4160, 10);
226+
ASSERT_TRUE(status.ok());
227+
228+
ScatterGatherEntry sge;
229+
status = mgr.Allocate(4161, sge);
230+
ASSERT_FALSE(status.ok());
231+
ASSERT_EQ(status.code, StatusCode::INVALID_ARGUMENT);
232+
}
233+
234+
TEST_F(BufferManagerTest, AllMemoryTypesUseAlignedSlotStride)
235+
{
236+
for (const auto type : {MemoryType::HOST, MemoryType::HOST_PINNED, MemoryType::ASCEND_DEVICE}) {
237+
BufferManager mgr;
238+
auto status = mgr.Init("test_buffer", type, 4160, 2);
239+
ASSERT_TRUE(status.ok()) << status.message;
240+
241+
ScatterGatherEntry first;
242+
ScatterGatherEntry second;
243+
ASSERT_TRUE(mgr.Allocate(4160, first).ok());
244+
ASSERT_TRUE(mgr.Allocate(4160, second).ok());
245+
ASSERT_EQ(second.local_addr - first.local_addr, 4224);
246+
ASSERT_EQ(second.device_addr - first.device_addr, 4224);
247+
}
248+
}
249+
250+
TEST_F(BufferManagerTest, FlagBufferCapacity71Uses128ByteStride)
251+
{
252+
BufferManager mgr;
253+
auto status = mgr.Init("flag_buffer", MemoryType::HOST_PINNED, 71, 2);
254+
ASSERT_TRUE(status.ok()) << status.message;
255+
256+
ScatterGatherEntry first;
257+
ScatterGatherEntry second;
258+
ASSERT_TRUE(mgr.Allocate(71, first).ok());
259+
ASSERT_TRUE(mgr.Allocate(71, second).ok());
260+
ASSERT_EQ(first.length, 71);
261+
ASSERT_EQ(second.local_addr - first.local_addr, 128);
262+
ASSERT_EQ(second.device_addr - first.device_addr, 128);
198263
}
199264

200265
TEST_F(BufferManagerTest, ReuseAfterFree)
@@ -212,7 +277,7 @@ TEST_F(BufferManagerTest, ReuseAfterFree)
212277
ScatterGatherEntry sge2;
213278
status = mgr.Allocate(64, sge2);
214279
ASSERT_TRUE(status.ok());
215-
ASSERT_EQ(sge2.addr, sge1.addr);
280+
ASSERT_EQ(sge2.local_addr, sge1.local_addr);
216281
ASSERT_EQ(sge2.slot_index, sge1.slot_index);
217282

218283
mgr.Free(sge2.slot_index);
@@ -233,7 +298,7 @@ TEST_F(BufferManagerTest, ConcurrentAllocateAndFree)
233298
auto s = mgr.Allocate(64, sge);
234299
ASSERT_TRUE(s.ok()) << "Thread " << thread_id << " op " << i << ": " << s.message;
235300

236-
std::memset(reinterpret_cast<void*>(sge.addr), thread_id, 64);
301+
std::memset(reinterpret_cast<void*>(sge.local_addr), thread_id, 64);
237302

238303
s = mgr.Free(sge.slot_index);
239304
ASSERT_TRUE(s.ok()) << s.message;
@@ -260,10 +325,10 @@ TEST_F(BufferManagerTest, ConcurrentStressTest)
260325
auto s = mgr.Allocate(128, sge);
261326
ASSERT_TRUE(s.ok());
262327

263-
std::memset(reinterpret_cast<void*>(sge.addr), thread_id, 128);
328+
std::memset(reinterpret_cast<void*>(sge.local_addr), thread_id, 128);
264329

265330
for (int j = 0; j < 128; ++j) {
266-
ASSERT_EQ(reinterpret_cast<unsigned char*>(sge.addr)[j], thread_id);
331+
ASSERT_EQ(reinterpret_cast<unsigned char*>(sge.local_addr)[j], thread_id);
267332
}
268333

269334
s = mgr.Free(sge.slot_index);
@@ -286,7 +351,7 @@ TEST_F(BufferManagerTest, FreeZeroesMemory)
286351
status = mgr.Allocate(64, sge1);
287352
ASSERT_TRUE(status.ok());
288353

289-
auto* ptr = reinterpret_cast<uint8_t*>(sge1.addr);
354+
auto* ptr = reinterpret_cast<uint8_t*>(sge1.local_addr);
290355
std::memset(ptr, 0xAB, 1024);
291356

292357
status = mgr.Free(sge1.slot_index);
@@ -295,10 +360,10 @@ TEST_F(BufferManagerTest, FreeZeroesMemory)
295360
ScatterGatherEntry sge2;
296361
status = mgr.Allocate(64, sge2);
297362
ASSERT_TRUE(status.ok());
298-
ASSERT_EQ(sge2.addr, sge1.addr);
363+
ASSERT_EQ(sge2.local_addr, sge1.local_addr);
299364
ASSERT_EQ(sge2.slot_index, sge1.slot_index);
300365

301-
auto* ptr2 = reinterpret_cast<uint8_t*>(sge2.addr);
366+
auto* ptr2 = reinterpret_cast<uint8_t*>(sge2.local_addr);
302367
for (size_t i = 0; i < 1024; ++i) {
303368
ASSERT_EQ(ptr2[i], 0) << "byte " << i << " not zeroed after free";
304369
}
@@ -393,8 +458,35 @@ TEST_F(BufferManagerTest, InitWithProviderRegistersMemory)
393458
ASSERT_EQ(provider.registerCount, 1);
394459
ASSERT_EQ(provider.lastMemType, TransProvider::MemType::MEM_HOST);
395460
ASSERT_NE(provider.lastAddr, 0);
396-
ASSERT_EQ(provider.lastSize, 1024 * 10);
461+
ASSERT_EQ(provider.lastSize, 1088 * 10);
397462
ASSERT_EQ(mgr.GetTokenId(), 42);
463+
464+
ScatterGatherEntry sge;
465+
ASSERT_TRUE(mgr.Allocate(64, sge).ok());
466+
ASSERT_EQ(sge.local_addr, sge.device_addr);
467+
}
468+
469+
TEST_F(BufferManagerTest, HostPinnedRegistersDeviceAddress)
470+
{
471+
StubTransProvider provider;
472+
473+
BufferManager mgr;
474+
auto status = mgr.Init("test_rdma_pinned", MemoryType::HOST_PINNED, 4096, 1, &provider);
475+
ASSERT_TRUE(status.ok()) << status.message;
476+
ASSERT_EQ(provider.registerCount, 1);
477+
ASSERT_EQ(provider.lastMemType, TransProvider::MemType::MEM_DEVICE);
478+
479+
ScatterGatherEntry sge;
480+
ASSERT_TRUE(mgr.Allocate(64, sge).ok());
481+
ASSERT_NE(sge.local_addr, 0);
482+
ASSERT_NE(sge.device_addr, 0);
483+
ASSERT_NE(sge.local_addr, sge.device_addr);
484+
ASSERT_EQ(sge.local_addr % 4096, 0);
485+
ASSERT_EQ(provider.lastAddr, sge.device_addr);
486+
487+
// The CPU writes through local_addr while HCOMM and remote RDMA use device_addr.
488+
std::memset(reinterpret_cast<void*>(sge.local_addr), 0x5A, sge.length);
489+
ASSERT_EQ(*reinterpret_cast<unsigned char*>(sge.local_addr), 0x5A);
398490
}
399491

400492
TEST_F(BufferManagerTest, InitWithProviderAllocateReturnsTokenId)

ucm/transport/kv/asu/trans/include/asu_transport/asu_transport.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ struct TransportConfig {
7272
bool preconnect{true};
7373
bool bindCqPoller{true};
7474

75+
// Slot sizes are caller-visible capacities; BufferManager computes the
76+
// aligned physical stride used for allocation and memory registration.
7577
std::size_t sendBufferSlotSize{4160};
7678
std::size_t sendBufferSlotNum{128};
77-
std::size_t flagBufferSlotSize{128};
79+
// Maximum memory required by a batch store/retrieve response flag buffer.
80+
std::size_t flagBufferSlotSize{71};
7881
std::size_t flagBufferSlotNum{4096};
7982
std::size_t asuBatchLoadIoNum{110};
8083
std::size_t asuBatchStoreIoNum{110};

ucm/transport/kv/asu/trans/src/asu_submit_flow.cpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ void SetSubBatchSendFailed(TransportSubBatchContext& subBatchContext, const Stat
4747
subBatchContext.status = status;
4848
}
4949

50+
bool IsTransportBufferReady(const ScatterGatherEntry& sge)
51+
{
52+
return sge.local_addr != 0 && sge.device_addr != 0 && sge.length != 0 &&
53+
sge.slot_index != UINT32_MAX;
54+
}
55+
5056
} // namespace
5157

5258
Status AsuTransportImpl::SubmitTaskRequests(const TransportTaskContext& ctx,
@@ -130,13 +136,20 @@ Status AsuTransportImpl::BuildSubBatchSendBuffers(
130136
continue;
131137
}
132138

133-
if (subBatchContext.flagBuffer.addr == 0 || subBatchContext.flagBuffer.length == 0) {
134-
const auto subBatchStatus =
135-
Status::Error(StatusCode::NOT_INITIALIZED, "sub-batch flag buffer is not ready");
139+
if (subBatchContext.channel == nullptr ||
140+
!IsTransportBufferReady(subBatchContext.sendSge) ||
141+
!IsTransportBufferReady(subBatchContext.flagBuffer)) {
142+
const auto subBatchStatus = Status::Error(
143+
StatusCode::NOT_INITIALIZED, "sub-batch transport buffers are not ready");
136144
UC_ERROR(
137-
"Sub-batch flag buffer is not ready index={} cid={} flag_addr={} flag_length={}",
138-
index, subBatchContext.cid, subBatchContext.flagBuffer.addr,
139-
subBatchContext.flagBuffer.length);
145+
"Sub-batch transport buffers are not ready index={} cid={} channel={} "
146+
"send_local_addr={} send_device_addr={} send_length={} send_slot={} "
147+
"flag_local_addr={} flag_device_addr={} flag_length={} flag_slot={}",
148+
index, subBatchContext.cid, subBatchContext.channel != nullptr,
149+
subBatchContext.sendSge.local_addr, subBatchContext.sendSge.device_addr,
150+
subBatchContext.sendSge.length, subBatchContext.sendSge.slot_index,
151+
subBatchContext.flagBuffer.local_addr, subBatchContext.flagBuffer.device_addr,
152+
subBatchContext.flagBuffer.length, subBatchContext.flagBuffer.slot_index);
140153
SetSubBatchSendFailed(subBatchContext, subBatchStatus);
141154
if (status.ok()) { status = subBatchStatus; }
142155
ReleaseSubBatchResources(subBatchContext);
@@ -145,8 +158,8 @@ Status AsuTransportImpl::BuildSubBatchSendBuffers(
145158

146159
ioBatches.push_back(
147160
TransProvider::SendIoBatch{subBatchContext.channel->GetConnection(),
148-
reinterpret_cast<void*>(subBatchContext.sendSge.addr),
149-
reinterpret_cast<void*>(subBatchContext.flagBuffer.addr),
161+
reinterpret_cast<void*>(subBatchContext.sendSge.device_addr),
162+
reinterpret_cast<void*>(subBatchContext.flagBuffer.device_addr),
150163
subBatchContext.sendSge.length});
151164
subBatchIndexes.emplace_back(index);
152165
}

ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,14 @@ Status AsuTransportImpl::Init(const TransportConfig& config)
9595

9696
connManager_->StartRecoverLoop();
9797

98-
auto status =
99-
sendBufferManager_.Init("asu send buffer", MemoryType::HOST, config_.sendBufferSlotSize,
100-
config_.sendBufferSlotNum, transProvider_.get());
98+
auto status = sendBufferManager_.Init("asu send buffer", MemoryType::HOST_PINNED,
99+
config_.sendBufferSlotSize, config_.sendBufferSlotNum,
100+
transProvider_.get());
101101
if (!status.ok()) { return status; }
102102

103-
status =
104-
flagBufferManager_.Init("asu flag buffer", MemoryType::HOST, config_.flagBufferSlotSize,
105-
config_.flagBufferSlotNum, transProvider_.get());
103+
status = flagBufferManager_.Init("asu flag buffer", MemoryType::HOST_PINNED,
104+
config_.flagBufferSlotSize, config_.flagBufferSlotNum,
105+
transProvider_.get());
106106
if (!status.ok()) { return status; }
107107
protocolManager_ = std::make_unique<ProtocolManager>();
108108

@@ -427,7 +427,7 @@ void AsuTransportImpl::PollTaskCompletions(const TransportTaskContextPtr& ctx)
427427

428428
std::uint16_t completedCid = 0;
429429
if (const auto status = protocolManager_->PollResponseCid(
430-
reinterpret_cast<void*>(subBatchContext.flagBuffer.addr), completedCid);
430+
reinterpret_cast<void*>(subBatchContext.flagBuffer.local_addr), completedCid);
431431
!status.ok()) {
432432
continue;
433433
}
@@ -436,7 +436,7 @@ void AsuTransportImpl::PollTaskCompletions(const TransportTaskContextPtr& ctx)
436436
KvResponse response;
437437
const auto batchNumber = static_cast<std::uint16_t>(subBatchContext.entryStatus.size());
438438
if (const auto status = protocolManager_->UnpackResponse(
439-
reinterpret_cast<void*>(subBatchContext.flagBuffer.addr),
439+
reinterpret_cast<void*>(subBatchContext.flagBuffer.local_addr),
440440
ToKvOpcode(subBatchContext.opType), batchNumber, response);
441441
!status.ok()) {
442442
std::fill(subBatchContext.entryStatus.begin(), subBatchContext.entryStatus.end(),

0 commit comments

Comments
 (0)