Skip to content

Commit d6e1b01

Browse files
AiterAsmKernel: add init-time sanity checks for .co registration (#3127)
Two checks added to AiterAsmKernelFast::init(), both running once per kernel variant at construction time (not on the launch hot path): 1. validate_hsaco_lds(): scans the raw .co ELF blob for group_segment_fixed_size via msgpack decode and compares against the device LDS limit (hipDeviceGetAttribute). Gives an actionable error before __hipRegisterFatBinary is called, e.g. on gfx942 (MI300X) the 64 KB limit would reject a .co built for gfx950 (MI355X, ~160 KB). 2. Registration probe: __hipRegisterFunction returns void, so a hipGetFuncBySymbol probe is used to detect silent rejection by the runtime (LDS limit exceeded, arch mismatch, corrupted binary, etc.).
1 parent 28fc36f commit d6e1b01

1 file changed

Lines changed: 93 additions & 2 deletions

File tree

csrc/include/aiter_hip_common.h

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#include <fstream>
2222
#include <mutex>
2323
#include <memory>
24+
#include <string_view>
25+
#include <algorithm>
2426
#ifdef AITER_EMBEDDED_HSA_HEADER
2527
#include AITER_EMBEDDED_HSA_HEADER
2628
#endif
@@ -194,6 +196,14 @@ class AiterAsmKernelFast
194196
nullptr,
195197
nullptr,
196198
nullptr);
199+
// Verify registration succeeded. __hipRegisterFunction returns void so
200+
// we probe via hipGetFuncBySymbol — if it returns null the runtime silently
201+
// rejected the kernel (e.g. resource limits, arch mismatch). This runs
202+
// once per kernel variant at init time, not on every launch.
203+
hipFunction_t probe = nullptr;
204+
(void)hipGetFuncBySymbol(&probe, reinterpret_cast<void*>(this));
205+
AITER_CHECK(probe != nullptr,
206+
"kernel registration failed for '", kernel_name, "'.");
197207
}
198208

199209
public:
@@ -218,8 +228,7 @@ class AiterAsmKernelFast
218228
HIP_LAUNCH_PARAM_END};
219229
hipFunction_t kernel_func = nullptr;
220230
// TODO Ask runtime folks to provide an API for hipLaunchKernel with extra arg
221-
// Don't error check here.
222-
// Failure to load the func would cause hipModuleLaunchKernel to fail anyways.
231+
// Don't error check here — registration is validated once in init().
223232
(void)hipGetFuncBySymbol(&kernel_func, reinterpret_cast<void*>(this));
224233

225234
HIP_CALL_LAUNCH(hipModuleLaunchKernel(kernel_func,
@@ -242,6 +251,85 @@ class AiterAsmKernel: private AiterAsmKernelFast
242251
private:
243252
std::unique_ptr<char[]> hsaco_data;
244253

254+
static void validate_hsaco_lds(const char* kernel_name,
255+
const std::string& path,
256+
const char* data,
257+
size_t size)
258+
{
259+
// The AMDGPU metadata is stored as msgpack in an ELF .note section — not as
260+
// raw ASCII. We scan for the amdhsa kernel descriptor's group_segment_fixed_size
261+
// field in the binary kernel descriptor block instead. In the AMDHSA kernel
262+
// descriptor (64 bytes at the start of the .text symbol), byte offset 0 is a
263+
// uint32_t group_segment_fixed_size. However the most reliable approach without
264+
// linking a msgpack parser is to search the raw ELF for the msgpack integer that
265+
// follows the "group_segment_fixed_size" msgpack string key.
266+
//
267+
// Msgpack format for the key: 0xd9 <len8> <bytes> or 0xda <len16> <bytes> or
268+
// fixstr (0xa0|len) <bytes>. The value follows immediately as a msgpack uint.
269+
// We do a byte-level search: find "group_segment_fixed_size" as raw bytes, then
270+
// decode the msgpack uint that follows.
271+
272+
static const char key[] = "group_segment_fixed_size";
273+
static const size_t key_len = sizeof(key) - 1;
274+
275+
const char* p = data;
276+
const char* end = data + size;
277+
278+
while(p < end)
279+
{
280+
// Find the key bytes in the blob
281+
const char* found = std::search(p, end, key, key + key_len);
282+
if(found == end)
283+
return;
284+
285+
// The msgpack value follows the key bytes. Skip any msgpack string prefix
286+
// byte (the byte before 'g' in "group_...") was already part of the key
287+
// detection; now p points right after the key.
288+
const char* vp = found + key_len;
289+
if(vp >= end)
290+
return;
291+
292+
// Decode the following msgpack integer
293+
uint64_t declared_lds = 0;
294+
uint8_t tag = static_cast<uint8_t>(*vp);
295+
if(tag <= 0x7f) // positive fixint
296+
declared_lds = tag;
297+
else if(tag == 0xcc && vp + 1 < end) // uint8
298+
declared_lds = static_cast<uint8_t>(vp[1]);
299+
else if(tag == 0xcd && vp + 2 < end) // uint16 big-endian
300+
declared_lds = (static_cast<uint64_t>(static_cast<uint8_t>(vp[1])) << 8) |
301+
static_cast<uint64_t>(static_cast<uint8_t>(vp[2]));
302+
else if(tag == 0xce && vp + 4 < end) // uint32 big-endian
303+
declared_lds = (static_cast<uint64_t>(static_cast<uint8_t>(vp[1])) << 24) |
304+
(static_cast<uint64_t>(static_cast<uint8_t>(vp[2])) << 16) |
305+
(static_cast<uint64_t>(static_cast<uint8_t>(vp[3])) << 8) |
306+
static_cast<uint64_t>(static_cast<uint8_t>(vp[4]));
307+
else
308+
{
309+
p = found + 1;
310+
continue; // not a uint we recognise; keep searching
311+
}
312+
313+
if(declared_lds == 0)
314+
{
315+
p = found + 1;
316+
continue;
317+
}
318+
319+
hipDevice_t dev;
320+
int max_lds = 0;
321+
HIP_CALL(hipGetDevice(&dev));
322+
HIP_CALL(hipDeviceGetAttribute(&max_lds, hipDeviceAttributeMaxSharedMemoryPerBlock, dev));
323+
324+
AITER_CHECK(static_cast<int64_t>(declared_lds) <= static_cast<int64_t>(max_lds),
325+
"kernel '", kernel_name, "' in ", path.c_str(),
326+
": group_segment_fixed_size=", static_cast<uint32_t>(declared_lds),
327+
" exceeds device LDS limit=", max_lds,
328+
" bytes. Rebuild the .co with a smaller LDS allocation.");
329+
return; // validated OK
330+
}
331+
}
332+
245333
const void* load_hsaco_file(const char* kernel_name, const char* hsaco_path)
246334
{
247335
const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR");
@@ -261,6 +349,9 @@ class AiterAsmKernel: private AiterAsmKernelFast
261349
file.seekg(0, std::ios::beg);
262350
AITER_CHECK(
263351
file.read(hsaco_data.get(), file_size), "failed to read ", full_path.c_str());
352+
353+
validate_hsaco_lds(kernel_name, full_path, hsaco_data.get(), file_size);
354+
264355
return hsaco_data.get();
265356
}
266357
else

0 commit comments

Comments
 (0)