Skip to content

Commit 839d5a1

Browse files
committed
[Refactor] Move host-pinned allocation to AscendBuffer, Support HBM transport staging buffers
1 parent c5cfef8 commit 839d5a1

7 files changed

Lines changed: 159 additions & 73 deletions

File tree

ucm/shared/trans/ascend/ascend_buffer.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,27 @@
2828

2929
namespace UC::Trans {
3030

31+
namespace {
32+
33+
constexpr std::uintptr_t HOST_REGISTER_PAGE_SIZE = 4096;
34+
35+
void FreeHostMemory(void* host)
36+
{
37+
auto ret = aclrtFreeHost(host);
38+
if (ret != ACL_SUCCESS) { UC_ERROR("Failed to free host memory addr={} ret={}", host, ret); }
39+
}
40+
41+
void ReleaseHostPinnedMemory(void* host)
42+
{
43+
auto ret = aclrtHostUnregister(host);
44+
if (ret != ACL_SUCCESS) {
45+
UC_ERROR("Failed to unregister host-pinned memory addr={} ret={}", host, ret);
46+
}
47+
FreeHostMemory(host);
48+
}
49+
50+
} // namespace
51+
3152
class HostHugePages : public std::enable_shared_from_this<HostHugePages> {
3253
struct ConstructorKey {};
3354
static constexpr auto HUGE_PAGE_SIZE = 2UL << 20;
@@ -124,6 +145,33 @@ std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer(size_t size)
124145
return nullptr;
125146
}
126147

148+
std::shared_ptr<void> Trans::AscendBuffer::MakeHostPinnedBuffer(size_t size, void** pDevice)
149+
{
150+
if (pDevice) { *pDevice = nullptr; }
151+
152+
void* host = nullptr;
153+
auto ret = aclrtMallocHost(&host, size);
154+
if (ret != ACL_SUCCESS) { return nullptr; }
155+
156+
if (reinterpret_cast<std::uintptr_t>(host) % HOST_REGISTER_PAGE_SIZE != 0) {
157+
UC_ERROR("Host-pinned memory is not 4K page aligned addr={} size={}", host, size);
158+
FreeHostMemory(host);
159+
return nullptr;
160+
}
161+
162+
void* device = nullptr;
163+
auto status = Buffer::RegisterHostBuffer(host, size, &device);
164+
if (status.Failure()) {
165+
UC_ERROR("Failed to register host-pinned memory addr={} size={} status={}", host, size,
166+
status);
167+
FreeHostMemory(host);
168+
return nullptr;
169+
}
170+
171+
if (pDevice) { *pDevice = device; }
172+
return std::shared_ptr<void>(host, ReleaseHostPinnedMemory);
173+
}
174+
127175
std::shared_ptr<void> Trans::AscendBuffer::MakeHostBuffer4DirectIo(size_t size)
128176
{
129177
try {

ucm/shared/trans/ascend/ascend_buffer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class AscendBuffer : public ReservedBuffer {
3232
public:
3333
std::shared_ptr<void> MakeDeviceBuffer(size_t size) override;
3434
std::shared_ptr<void> MakeHostBuffer(size_t size) override;
35+
std::shared_ptr<void> MakeHostPinnedBuffer(size_t size, void** pDevice = nullptr);
3536
std::shared_ptr<void> MakeHostBuffer4DirectIo(size_t size) override;
3637
};
3738

ucm/transport/kv/asu/test/case/buffer_manager_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ TEST_F(BufferManagerTest, SingleAllocateAndFree)
141141
ASSERT_EQ(sge.length, 64);
142142
ASSERT_EQ(sge.tokenId, 0);
143143
ASSERT_NE(sge.slot_index, UINT32_MAX);
144+
ASSERT_EQ(sge.memory_type, MemoryType::HOST);
144145

145146
auto* ptr = reinterpret_cast<void*>(sge.local_addr);
146147
std::memset(ptr, 0xAB, 64);

ucm/transport/kv/asu/trans/src/asu_transport_impl.cpp

Lines changed: 65 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@
2222
* SOFTWARE.
2323
* */
2424
#include "asu_transport_impl.h"
25+
#include <acl/acl.h>
2526
#include <algorithm>
27+
#include <array>
2628
#include <chrono>
29+
#include <cstdint>
2730
#include <memory>
31+
#include <string>
2832
#include <thread>
2933
#include <utility>
3034
#include "aicpu_trans_provider.h"
@@ -37,6 +41,29 @@
3741

3842
namespace UC::ASU {
3943

44+
namespace {
45+
46+
constexpr std::size_t kFlagBufferHeaderCopySize = kCqeDwordCount * sizeof(std::uint32_t);
47+
48+
Status CopyDeviceToHost(const ScatterGatherEntry& sge, void* host, std::size_t size,
49+
const char* name)
50+
{
51+
if (size > sge.length) {
52+
return Status::Error(StatusCode::INVALID_ARGUMENT,
53+
std::string(name) + ": copy size exceeds buffer length");
54+
}
55+
const auto ret = aclrtMemcpy(host, size, reinterpret_cast<void*>(sge.device_addr), size,
56+
ACL_MEMCPY_DEVICE_TO_HOST);
57+
if (ret != ACL_SUCCESS) {
58+
return Status::Error(StatusCode::INTERNAL_ERROR,
59+
std::string(name) + ": copy device memory to host failed ret=" +
60+
std::to_string(ret));
61+
}
62+
return Status::OK();
63+
}
64+
65+
} // namespace
66+
4067
AsuTransportImpl::~AsuTransportImpl() { Shutdown(); }
4168

4269
Status AsuTransportImpl::Init(const std::string& configPath)
@@ -425,23 +452,54 @@ void AsuTransportImpl::PollTaskCompletions(const TransportTaskContextPtr& ctx)
425452
for (auto& subBatchContext : ctx->subBatchContexts) {
426453
if (subBatchContext.state != TransportSubBatchState::PENDING) { continue; }
427454

455+
auto completeWithError = [this, &ctx, &subBatchContext](const Status& status) {
456+
std::fill(subBatchContext.entryStatus.begin(), subBatchContext.entryStatus.end(),
457+
status);
458+
CompleteSubBatch(*ctx, subBatchContext, status);
459+
};
460+
428461
std::uint16_t completedCid = 0;
429-
if (const auto status = protocolManager_->PollResponseCid(
430-
reinterpret_cast<void*>(subBatchContext.flagBuffer.local_addr), completedCid);
462+
const void* responseData = nullptr;
463+
std::array<std::uint8_t, kFlagBufferHeaderCopySize> flagHeader{};
464+
std::vector<std::uint8_t> flagBuffer;
465+
if (IsCpuAccessible(subBatchContext.flagBuffer.memory_type)) {
466+
responseData = reinterpret_cast<void*>(subBatchContext.flagBuffer.local_addr);
467+
} else {
468+
auto status = CopyDeviceToHost(subBatchContext.flagBuffer, flagHeader.data(),
469+
flagHeader.size(), "flag buffer header");
470+
if (!status.ok()) {
471+
// Without a readable header, this sub-batch cannot be polled or unpacked.
472+
completeWithError(status);
473+
continue;
474+
}
475+
responseData = flagHeader.data();
476+
}
477+
478+
if (const auto status = protocolManager_->PollResponseCid(responseData, completedCid);
431479
!status.ok()) {
432480
continue;
433481
}
434482
if (completedCid == 0 || completedCid != subBatchContext.cid) { continue; }
435483

484+
if (!IsCpuAccessible(subBatchContext.flagBuffer.memory_type)) {
485+
// The header matched; copy the full CQE before unpacking entry status.
486+
flagBuffer.resize(subBatchContext.flagBuffer.length);
487+
auto status = CopyDeviceToHost(subBatchContext.flagBuffer, flagBuffer.data(),
488+
flagBuffer.size(), "flag buffer");
489+
if (!status.ok()) {
490+
// The matched CQE cannot be decoded without the complete flag buffer.
491+
completeWithError(status);
492+
continue;
493+
}
494+
responseData = flagBuffer.data();
495+
}
496+
436497
KvResponse response;
437498
const auto batchNumber = static_cast<std::uint16_t>(subBatchContext.entryStatus.size());
438499
if (const auto status = protocolManager_->UnpackResponse(
439-
reinterpret_cast<void*>(subBatchContext.flagBuffer.local_addr),
440-
ToKvOpcode(subBatchContext.opType), batchNumber, response);
500+
responseData, ToKvOpcode(subBatchContext.opType), batchNumber, response);
441501
!status.ok()) {
442-
std::fill(subBatchContext.entryStatus.begin(), subBatchContext.entryStatus.end(),
443-
status);
444-
CompleteSubBatch(*ctx, subBatchContext, status);
502+
completeWithError(status);
445503
continue;
446504
}
447505

ucm/transport/kv/asu/trans/src/buffer_manager.cpp

Lines changed: 23 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,14 @@
2626
#include <cstdlib>
2727
#include <cstring>
2828
#include <limits>
29-
#include "logger.h"
3029
#include "trans/ascend/ascend_buffer.h"
3130

3231
namespace UC::ASU {
3332

34-
constexpr std::uintptr_t kHostRegisterPageSize = 4096;
3533
constexpr std::size_t kSlotSizeAlignment = 32;
3634
constexpr std::size_t kSlotPadding = 32;
3735
constexpr std::size_t kSlotAddressAlignment = 64;
3836

39-
void FreeHostMemory(void* addr)
40-
{
41-
auto ret = aclrtFreeHost(addr);
42-
if (ret != ACL_SUCCESS) { UC_ERROR("Failed to free host memory addr={} ret={}", addr, ret); }
43-
}
44-
45-
void ReleaseHostPinnedMemory(void* addr)
46-
{
47-
auto ret = aclrtHostUnregister(addr);
48-
if (ret != ACL_SUCCESS) {
49-
UC_ERROR("Failed to unregister host-pinned memory addr={} ret={}", addr, ret);
50-
return;
51-
}
52-
FreeHostMemory(addr);
53-
}
54-
5537
bool GetSlotStride(std::size_t capacity, std::size_t& stride)
5638
{
5739
// Keep one layout for every memory type: reserve ALIGN_UP(capacity, 32) + 32
@@ -77,16 +59,30 @@ Status BufferManager::BufferRegion::Create(MemoryType type, std::size_t size, Bu
7759
switch (type) {
7860
case MemoryType::HOST: {
7961
auto owner = ascendBuffer.MakeHostBuffer(size);
80-
if (!owner) { return AllocationFailed("host"); }
62+
if (!owner) {
63+
return Status::Error(StatusCode::INTERNAL_ERROR, "failed to allocate host memory");
64+
}
8165
// HOST has one CPU-visible address, which is also passed to the
8266
// provider when it registers the region as MEM_HOST.
8367
region = {owner, owner.get(), owner.get(), TransProvider::MemType::MEM_HOST};
8468
return Status::OK();
8569
}
86-
case MemoryType::HOST_PINNED: return MakeHostPinned(size, region);
70+
case MemoryType::HOST_PINNED: {
71+
void* deviceAddr = nullptr;
72+
auto owner = ascendBuffer.MakeHostPinnedBuffer(size, &deviceAddr);
73+
if (!owner) {
74+
return Status::Error(StatusCode::INTERNAL_ERROR,
75+
"failed to allocate host-pinned memory");
76+
}
77+
region = {owner, owner.get(), deviceAddr, TransProvider::MemType::MEM_DEVICE};
78+
return Status::OK();
79+
}
8780
case MemoryType::ASCEND_DEVICE: {
8881
auto owner = ascendBuffer.MakeDeviceBuffer(size);
89-
if (!owner) { return AllocationFailed("device"); }
82+
if (!owner) {
83+
return Status::Error(StatusCode::INTERNAL_ERROR,
84+
"failed to allocate device memory");
85+
}
9086
region = {owner, owner.get(), owner.get(), TransProvider::MemType::MEM_DEVICE};
9187
return Status::OK();
9288
}
@@ -102,51 +98,17 @@ void BufferManager::BufferRegion::Reset()
10298
providerMemType = TransProvider::MemType::MEM_HOST;
10399
}
104100

105-
Status BufferManager::BufferRegion::AllocationFailed(const char* type)
106-
{
107-
return Status::Error(StatusCode::INTERNAL_ERROR,
108-
std::string("failed to allocate ") + type + " memory");
109-
}
110-
111-
Status BufferManager::BufferRegion::MakeHostPinned(std::size_t size, BufferRegion& region)
112-
{
113-
void* hostAddr = nullptr;
114-
auto ret = aclrtMallocHost(&hostAddr, size);
115-
if (ret != ACL_SUCCESS) { return AllocationFailed("host-pinned"); }
116-
if (reinterpret_cast<std::uintptr_t>(hostAddr) % kHostRegisterPageSize != 0) {
117-
FreeHostMemory(hostAddr);
118-
return Status::Error(StatusCode::INTERNAL_ERROR,
119-
"host-pinned memory is not 4K page aligned");
120-
}
121-
122-
ret = aclrtHostRegisterV2(hostAddr, size, ACL_HOST_REG_MAPPED | ACL_HOST_REG_PINNED);
123-
if (ret != ACL_SUCCESS) {
124-
FreeHostMemory(hostAddr);
125-
return Status::Error(StatusCode::INTERNAL_ERROR,
126-
"failed to register host-pinned memory with ACL");
127-
}
128-
129-
void* deviceAddr = nullptr;
130-
ret = aclrtHostGetDevicePointer(hostAddr, &deviceAddr, 0);
131-
if (ret != ACL_SUCCESS) {
132-
ReleaseHostPinnedMemory(hostAddr);
133-
return Status::Error(StatusCode::INTERNAL_ERROR,
134-
"failed to get host-pinned device address");
135-
}
136-
137-
// The owner keeps the ACL registration alive until after HCOMM has
138-
// unregistered the region in BufferManager's destructor.
139-
auto owner = std::shared_ptr<void>(hostAddr, ReleaseHostPinnedMemory);
140-
region = {owner, hostAddr, deviceAddr, TransProvider::MemType::MEM_DEVICE};
141-
return Status::OK();
142-
}
143-
144101
bool IsTransportBufferReady(const ScatterGatherEntry& sge)
145102
{
146103
return sge.local_addr != 0 && sge.device_addr != 0 && sge.length != 0 &&
147104
sge.slot_index != UINT32_MAX;
148105
}
149106

107+
/**
 * Classify a memory type as CPU-accessible or not.
 *
 * @param type Memory type to classify.
 * @return true for MemoryType::HOST and MemoryType::HOST_PINNED, false for
 *         every other value.
 */
bool IsCpuAccessible(MemoryType type)
{
    switch (type) {
        case MemoryType::HOST:
        case MemoryType::HOST_PINNED: return true;
        default: return false;
    }
}
111+
150112
BufferManager::~BufferManager()
151113
{
152114
if (provider_ && memHandle_) {
@@ -262,6 +224,7 @@ Status BufferManager::Allocate(std::size_t size, ScatterGatherEntry& sge)
262224
sge.length = static_cast<std::uint32_t>(size);
263225
sge.tokenId = tokenId_;
264226
sge.slot_index = idx;
227+
sge.memory_type = memory_type_;
265228
return Status::OK();
266229
}
267230

ucm/transport/kv/asu/trans/src/buffer_manager.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ struct ScatterGatherEntry {
4242
std::uint32_t length{0};
4343
std::uint32_t tokenId{0};
4444
std::uint32_t slot_index{UINT32_MAX};
45+
MemoryType memory_type{MemoryType::HOST};
4546
};
4647

4748
bool IsTransportBufferReady(const ScatterGatherEntry& sge);
49+
bool IsCpuAccessible(MemoryType type);
4850

4951
class BufferManager {
5052
public:
@@ -75,10 +77,6 @@ class BufferManager {
7577
void* localAddr{nullptr};
7678
void* deviceAddr{nullptr};
7779
TransProvider::MemType providerMemType{TransProvider::MemType::MEM_HOST};
78-
79-
private:
80-
static Status AllocationFailed(const char* type);
81-
static Status MakeHostPinned(std::size_t size, BufferRegion& region);
8280
};
8381

8482
Status RegisterMemory();

ucm/transport/kv/asu/trans/src/sqe_request.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
* SOFTWARE.
2323
* */
2424
#include <algorithm>
25+
#include <acl/acl.h>
2526
#include <cctype>
27+
#include <cstdint>
2628
#include <memory>
2729
#include <string>
2830
#include <type_traits>
@@ -148,8 +150,23 @@ Status PackSubBatchRequest(ProtocolManager& protocolManager, BufferManager& send
148150
return SetSubBatchBuildFailed(subBatchContext, status);
149151
}
150152

151-
status = protocolManager.PackRequest(
152-
reinterpret_cast<void*>(subBatchContext.sendSge.local_addr), opcode, request);
153+
if (IsCpuAccessible(subBatchContext.sendSge.memory_type)) {
154+
status = protocolManager.PackRequest(
155+
reinterpret_cast<void*>(subBatchContext.sendSge.local_addr), opcode, request);
156+
} else {
157+
std::vector<std::uint8_t> staging(packedSize, 0);
158+
status = protocolManager.PackRequest(staging.data(), opcode, request);
159+
if (status.ok()) {
160+
const auto ret = aclrtMemcpy(reinterpret_cast<void*>(subBatchContext.sendSge.device_addr),
161+
packedSize, staging.data(), packedSize,
162+
ACL_MEMCPY_HOST_TO_DEVICE);
163+
if (ret != ACL_SUCCESS) {
164+
status = Status::Error(StatusCode::INTERNAL_ERROR,
165+
"copy packed SQE to device memory failed ret=" +
166+
std::to_string(ret));
167+
}
168+
}
169+
}
153170
if (!status.ok()) {
154171
UC_ERROR("Pack sub-batch request failed opcode={} cid={} code={} message={}",
155172
static_cast<int>(opcode), subBatchContext.cid, static_cast<int>(status.code),

0 commit comments

Comments
 (0)