Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 93 additions & 2 deletions csrc/include/aiter_hip_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <fstream>
#include <mutex>
#include <memory>
#include <string_view>
#include <algorithm>
#ifdef AITER_EMBEDDED_HSA_HEADER
#include AITER_EMBEDDED_HSA_HEADER
#endif
Expand Down Expand Up @@ -194,6 +196,14 @@ class AiterAsmKernelFast
nullptr,
nullptr,
nullptr);
// Verify registration succeeded. __hipRegisterFunction returns void so
// we probe via hipGetFuncBySymbol — if it returns null the runtime silently
// rejected the kernel (e.g. resource limits, arch mismatch). This runs
// once per kernel variant at init time, not on every launch.
hipFunction_t probe = nullptr;
(void)hipGetFuncBySymbol(&probe, reinterpret_cast<void*>(this));
AITER_CHECK(probe != nullptr,
"kernel registration failed for '", kernel_name, "'.");
}

public:
Expand All @@ -218,8 +228,7 @@ class AiterAsmKernelFast
HIP_LAUNCH_PARAM_END};
hipFunction_t kernel_func = nullptr;
// TODO Ask runtime folks to provide an API for hipLaunchKernel with extra arg
// Don't error check here.
// Failure to load the func would cause hipModuleLaunchKernel to fail anyways.
// Don't error check here — registration is validated once in init().
(void)hipGetFuncBySymbol(&kernel_func, reinterpret_cast<void*>(this));

HIP_CALL_LAUNCH(hipModuleLaunchKernel(kernel_func,
Expand All @@ -242,6 +251,85 @@ class AiterAsmKernel: private AiterAsmKernelFast
private:
std::unique_ptr<char[]> hsaco_data;

static void validate_hsaco_lds(const char* kernel_name,
const std::string& path,
const char* data,
size_t size)
{
// The AMDGPU metadata is stored as msgpack in an ELF .note section — not as
// raw ASCII. We scan for the amdhsa kernel descriptor's group_segment_fixed_size
// field in the binary kernel descriptor block instead. In the AMDHSA kernel
// descriptor (64 bytes at the start of the .text symbol), byte offset 0 is a
// uint32_t group_segment_fixed_size. However the most reliable approach without
// linking a msgpack parser is to search the raw ELF for the msgpack integer that
// follows the "group_segment_fixed_size" msgpack string key.
//
// Msgpack format for the key: 0xd9 <len8> <bytes> or 0xda <len16> <bytes> or
// fixstr (0xa0|len) <bytes>. The value follows immediately as a msgpack uint.
// We do a byte-level search: find "group_segment_fixed_size" as raw bytes, then
// decode the msgpack uint that follows.

static const char key[] = "group_segment_fixed_size";
static const size_t key_len = sizeof(key) - 1;

const char* p = data;
const char* end = data + size;

while(p < end)
{
// Find the key bytes in the blob
const char* found = std::search(p, end, key, key + key_len);
if(found == end)
return;

// The msgpack value follows the key bytes. Skip any msgpack string prefix
// byte (the byte before 'g' in "group_...") was already part of the key
// detection; now p points right after the key.
const char* vp = found + key_len;
if(vp >= end)
return;

// Decode the following msgpack integer
uint64_t declared_lds = 0;
uint8_t tag = static_cast<uint8_t>(*vp);
if(tag <= 0x7f) // positive fixint
declared_lds = tag;
else if(tag == 0xcc && vp + 1 < end) // uint8
declared_lds = static_cast<uint8_t>(vp[1]);
else if(tag == 0xcd && vp + 2 < end) // uint16 big-endian
declared_lds = (static_cast<uint64_t>(static_cast<uint8_t>(vp[1])) << 8) |
static_cast<uint64_t>(static_cast<uint8_t>(vp[2]));
else if(tag == 0xce && vp + 4 < end) // uint32 big-endian
declared_lds = (static_cast<uint64_t>(static_cast<uint8_t>(vp[1])) << 24) |
(static_cast<uint64_t>(static_cast<uint8_t>(vp[2])) << 16) |
(static_cast<uint64_t>(static_cast<uint8_t>(vp[3])) << 8) |
static_cast<uint64_t>(static_cast<uint8_t>(vp[4]));
else
{
p = found + 1;
continue; // not a uint we recognise; keep searching
}

if(declared_lds == 0)
{
p = found + 1;
continue;
}

hipDevice_t dev;
int max_lds = 0;
HIP_CALL(hipGetDevice(&dev));
HIP_CALL(hipDeviceGetAttribute(&max_lds, hipDeviceAttributeMaxSharedMemoryPerBlock, dev));

AITER_CHECK(static_cast<int64_t>(declared_lds) <= static_cast<int64_t>(max_lds),
"kernel '", kernel_name, "' in ", path.c_str(),
": group_segment_fixed_size=", static_cast<uint32_t>(declared_lds),
" exceeds device LDS limit=", max_lds,
" bytes. Rebuild the .co with a smaller LDS allocation.");
return; // validated OK
}
}

const void* load_hsaco_file(const char* kernel_name, const char* hsaco_path)
{
const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
Expand All @@ -261,6 +349,9 @@ class AiterAsmKernel: private AiterAsmKernelFast
file.seekg(0, std::ios::beg);
AITER_CHECK(
file.read(hsaco_data.get(), file_size), "failed to read ", full_path.c_str());

validate_hsaco_lds(kernel_name, full_path, hsaco_data.get(), file_size);

return hsaco_data.get();
}
else
Expand Down
Loading