2121#include < fstream>
2222#include < mutex>
2323#include < memory>
24+ #include < string_view>
25+ #include < algorithm>
2426#ifdef AITER_EMBEDDED_HSA_HEADER
2527#include AITER_EMBEDDED_HSA_HEADER
2628#endif
@@ -194,6 +196,14 @@ class AiterAsmKernelFast
194196 nullptr ,
195197 nullptr ,
196198 nullptr );
199+ // Verify registration succeeded. __hipRegisterFunction returns void so
200+ // we probe via hipGetFuncBySymbol — if it returns null the runtime silently
201+ // rejected the kernel (e.g. resource limits, arch mismatch). This runs
202+ // once per kernel variant at init time, not on every launch.
203+ hipFunction_t probe = nullptr ;
204+ (void )hipGetFuncBySymbol (&probe, reinterpret_cast <void *>(this ));
205+ AITER_CHECK (probe != nullptr ,
206+ " kernel registration failed for '" , kernel_name, " '." );
197207 }
198208
199209 public:
@@ -218,8 +228,7 @@ class AiterAsmKernelFast
218228 HIP_LAUNCH_PARAM_END};
219229 hipFunction_t kernel_func = nullptr ;
220230 // TODO Ask runtime folks to provide an API for hipLaunchKernel with extra arg
221- // Don't error check here.
222- // Failure to load the func would cause hipModuleLaunchKernel to fail anyways.
231+ // Don't error check here — registration is validated once in init().
223232 (void )hipGetFuncBySymbol (&kernel_func, reinterpret_cast <void *>(this ));
224233
225234 HIP_CALL_LAUNCH (hipModuleLaunchKernel (kernel_func,
@@ -242,6 +251,85 @@ class AiterAsmKernel: private AiterAsmKernelFast
242251 private:
243252 std::unique_ptr<char []> hsaco_data;
244253
254+ static void validate_hsaco_lds (const char * kernel_name,
255+ const std::string& path,
256+ const char * data,
257+ size_t size)
258+ {
259+ // The AMDGPU metadata is stored as msgpack in an ELF .note section — not as
260+ // raw ASCII. We scan for the amdhsa kernel descriptor's group_segment_fixed_size
261+ // field in the binary kernel descriptor block instead. In the AMDHSA kernel
262+ // descriptor (64 bytes at the start of the .text symbol), byte offset 0 is a
263+ // uint32_t group_segment_fixed_size. However the most reliable approach without
264+ // linking a msgpack parser is to search the raw ELF for the msgpack integer that
265+ // follows the "group_segment_fixed_size" msgpack string key.
266+ //
267+ // Msgpack format for the key: 0xd9 <len8> <bytes> or 0xda <len16> <bytes> or
268+ // fixstr (0xa0|len) <bytes>. The value follows immediately as a msgpack uint.
269+ // We do a byte-level search: find "group_segment_fixed_size" as raw bytes, then
270+ // decode the msgpack uint that follows.
271+
272+ static const char key[] = " group_segment_fixed_size" ;
273+ static const size_t key_len = sizeof (key) - 1 ;
274+
275+ const char * p = data;
276+ const char * end = data + size;
277+
278+ while (p < end)
279+ {
280+ // Find the key bytes in the blob
281+ const char * found = std::search (p, end, key, key + key_len);
282+ if (found == end)
283+ return ;
284+
285+ // The msgpack value follows the key bytes. Skip any msgpack string prefix
286+ // byte (the byte before 'g' in "group_...") was already part of the key
287+ // detection; now p points right after the key.
288+ const char * vp = found + key_len;
289+ if (vp >= end)
290+ return ;
291+
292+ // Decode the following msgpack integer
293+ uint64_t declared_lds = 0 ;
294+ uint8_t tag = static_cast <uint8_t >(*vp);
295+ if (tag <= 0x7f ) // positive fixint
296+ declared_lds = tag;
297+ else if (tag == 0xcc && vp + 1 < end) // uint8
298+ declared_lds = static_cast <uint8_t >(vp[1 ]);
299+ else if (tag == 0xcd && vp + 2 < end) // uint16 big-endian
300+ declared_lds = (static_cast <uint64_t >(static_cast <uint8_t >(vp[1 ])) << 8 ) |
301+ static_cast <uint64_t >(static_cast <uint8_t >(vp[2 ]));
302+ else if (tag == 0xce && vp + 4 < end) // uint32 big-endian
303+ declared_lds = (static_cast <uint64_t >(static_cast <uint8_t >(vp[1 ])) << 24 ) |
304+ (static_cast <uint64_t >(static_cast <uint8_t >(vp[2 ])) << 16 ) |
305+ (static_cast <uint64_t >(static_cast <uint8_t >(vp[3 ])) << 8 ) |
306+ static_cast <uint64_t >(static_cast <uint8_t >(vp[4 ]));
307+ else
308+ {
309+ p = found + 1 ;
310+ continue ; // not a uint we recognise; keep searching
311+ }
312+
313+ if (declared_lds == 0 )
314+ {
315+ p = found + 1 ;
316+ continue ;
317+ }
318+
319+ hipDevice_t dev;
320+ int max_lds = 0 ;
321+ HIP_CALL (hipGetDevice (&dev));
322+ HIP_CALL (hipDeviceGetAttribute (&max_lds, hipDeviceAttributeMaxSharedMemoryPerBlock, dev));
323+
324+ AITER_CHECK (static_cast <int64_t >(declared_lds) <= static_cast <int64_t >(max_lds),
325+ " kernel '" , kernel_name, " ' in " , path.c_str (),
326+ " : group_segment_fixed_size=" , static_cast <uint32_t >(declared_lds),
327+ " exceeds device LDS limit=" , max_lds,
328+ " bytes. Rebuild the .co with a smaller LDS allocation." );
329+ return ; // validated OK
330+ }
331+ }
332+
245333 const void * load_hsaco_file (const char * kernel_name, const char * hsaco_path)
246334 {
247335 const char * AITER_ASM_DIR = std::getenv (" AITER_ASM_DIR" );
@@ -261,6 +349,9 @@ class AiterAsmKernel: private AiterAsmKernelFast
261349 file.seekg (0 , std::ios::beg);
262350 AITER_CHECK (
263351 file.read (hsaco_data.get (), file_size), " failed to read " , full_path.c_str ());
352+
353+ validate_hsaco_lds (kernel_name, full_path, hsaco_data.get (), file_size);
354+
264355 return hsaco_data.get ();
265356 }
266357 else
0 commit comments