diff --git a/csrc/include/aiter_hip_common.h b/csrc/include/aiter_hip_common.h index 9afcaa4ae8..7790e3d978 100644 --- a/csrc/include/aiter_hip_common.h +++ b/csrc/include/aiter_hip_common.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include #ifdef AITER_EMBEDDED_HSA_HEADER #include AITER_EMBEDDED_HSA_HEADER #endif @@ -194,6 +196,14 @@ class AiterAsmKernelFast nullptr, nullptr, nullptr); + // Verify registration succeeded. __hipRegisterFunction returns void so + // we probe via hipGetFuncBySymbol — if it returns null the runtime silently + // rejected the kernel (e.g. resource limits, arch mismatch). This runs + // once per kernel variant at init time, not on every launch. + hipFunction_t probe = nullptr; + (void)hipGetFuncBySymbol(&probe, reinterpret_cast(this)); + AITER_CHECK(probe != nullptr, + "kernel registration failed for '", kernel_name, "'."); } public: @@ -218,8 +228,7 @@ class AiterAsmKernelFast HIP_LAUNCH_PARAM_END}; hipFunction_t kernel_func = nullptr; // TODO Ask runtime folks to provide an API for hipLaunchKernel with extra arg - // Don't error check here. - // Failure to load the func would cause hipModuleLaunchKernel to fail anyways. + // Don't error check here — registration is validated once in init(). (void)hipGetFuncBySymbol(&kernel_func, reinterpret_cast(this)); HIP_CALL_LAUNCH(hipModuleLaunchKernel(kernel_func, @@ -242,6 +251,85 @@ class AiterAsmKernel: private AiterAsmKernelFast private: std::unique_ptr hsaco_data; + static void validate_hsaco_lds(const char* kernel_name, + const std::string& path, + const char* data, + size_t size) + { + // The AMDGPU metadata is stored as msgpack in an ELF .note section — not as + // raw ASCII. We scan for the amdhsa kernel descriptor's group_segment_fixed_size + // field in the binary kernel descriptor block instead. In the AMDHSA kernel + // descriptor (64 bytes at the start of the .text symbol), byte offset 0 is a + // uint32_t group_segment_fixed_size. However the most reliable approach without + // linking a msgpack parser is to search the raw ELF for the msgpack integer that + // follows the "group_segment_fixed_size" msgpack string key. + // + // Msgpack format for the key: 0xd9 or 0xda or + // fixstr (0xa0|len) . The value follows immediately as a msgpack uint. + // We do a byte-level search: find "group_segment_fixed_size" as raw bytes, then + // decode the msgpack uint that follows. + + static const char key[] = "group_segment_fixed_size"; + static const size_t key_len = sizeof(key) - 1; + + const char* p = data; + const char* end = data + size; + + while(p < end) + { + // Find the key bytes in the blob + const char* found = std::search(p, end, key, key + key_len); + if(found == end) + return; + + // The msgpack value follows the key bytes. Skip any msgpack string prefix + // byte (the byte before 'g' in "group_...") was already part of the key + // detection; now p points right after the key. + const char* vp = found + key_len; + if(vp >= end) + return; + + // Decode the following msgpack integer + uint64_t declared_lds = 0; + uint8_t tag = static_cast(*vp); + if(tag <= 0x7f) // positive fixint + declared_lds = tag; + else if(tag == 0xcc && vp + 1 < end) // uint8 + declared_lds = static_cast(vp[1]); + else if(tag == 0xcd && vp + 2 < end) // uint16 big-endian + declared_lds = (static_cast(static_cast(vp[1])) << 8) | + static_cast(static_cast(vp[2])); + else if(tag == 0xce && vp + 4 < end) // uint32 big-endian + declared_lds = (static_cast(static_cast(vp[1])) << 24) | + (static_cast(static_cast(vp[2])) << 16) | + (static_cast(static_cast(vp[3])) << 8) | + static_cast(static_cast(vp[4])); + else + { + p = found + 1; + continue; // not a uint we recognise; keep searching + } + + if(declared_lds == 0) + { + p = found + 1; + continue; + } + + hipDevice_t dev; + int max_lds = 0; + HIP_CALL(hipGetDevice(&dev)); + HIP_CALL(hipDeviceGetAttribute(&max_lds, hipDeviceAttributeMaxSharedMemoryPerBlock, dev)); + + AITER_CHECK(static_cast(declared_lds) <= static_cast(max_lds), + "kernel '", kernel_name, "' in ", path.c_str(), + ": group_segment_fixed_size=", static_cast(declared_lds), + " exceeds device LDS limit=", max_lds, + " bytes. Rebuild the .co with a smaller LDS allocation."); + return; // validated OK + } + } + const void* load_hsaco_file(const char* kernel_name, const char* hsaco_path) { const char* AITER_ASM_DIR = std::getenv("AITER_ASM_DIR"); @@ -261,6 +349,9 @@ class AiterAsmKernel: private AiterAsmKernelFast file.seekg(0, std::ios::beg); AITER_CHECK( file.read(hsaco_data.get(), file_size), "failed to read ", full_path.c_str()); + + validate_hsaco_lds(kernel_name, full_path, hsaco_data.get(), file_size); + return hsaco_data.get(); } else