Skip to content

Commit 4cccc45

Browse files
committed
use a new mode for kernel import: 'usePrecompiledAndBakedKernel'
1 parent 58db9c1 commit 4cccc45

5 files changed

Lines changed: 115 additions & 15 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ build/
2525
result.xml
2626
UnitTest/bitcodes/*.fatbin
2727
Test/SimpleD3D12/cache/**
28+
29+
ParallelPrimitives/cache/KernelArgs.h
30+
ParallelPrimitives/cache/Kernels.h
31+
ParallelPrimitives/cache/oro_compiled_kernels.h

Orochi/OrochiUtils.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,41 @@ oroFunction OrochiUtils::getFunctionFromString( oroDevice device, const char* so
558558
return f;
559559
}
560560

561+
// Loads a kernel function from a precompiled binary that is embedded in the executable as a byte array,
// instead of reading it from a file like 'getFunctionFromPrecompiledBinary' does.
//   precompData     : pointer to the baked binary image, handed as-is to oroModuleLoadData.
//   dataSizeInBytes : size of the image in bytes; only used to reject an empty image ( oroModuleLoadData takes no size ).
//   funcName        : name of the kernel to look up in the loaded module.
// Returns the function handle, or nullptr if the image is missing/empty, the module cannot be loaded,
// or the function is not found in the module.
// Successful lookups are cached in m_kernelMap.
// NOTE: the cache key is derived from funcName only, not from the binary, so two different images
// exposing the same function name share one cache entry.
oroFunction OrochiUtils::getFunctionFromPrecompiledBinary_asData( const unsigned char* precompData, size_t dataSizeInBytes, const std::string& funcName )
{
	std::lock_guard<std::recursive_mutex> lock( m_mutex );

	const std::string cacheName = OrochiUtilsImpl::getCacheName( "___BAKED_BIN___", funcName );
	const auto cached = m_kernelMap.find( cacheName );
	if( cached != m_kernelMap.end() )
	{
		return cached->second.function;
	}

	// an empty image means that nothing was baked: fail here with a clear message instead of inside the driver.
	if( precompData == nullptr || dataSizeInBytes == 0 )
	{
		printf( "getFunctionFromPrecompiledBinary_asData FAILED (no baked precomp data): %s\n", funcName.c_str() );
		return nullptr;
	}

	oroModule module = nullptr;
	oroError e = oroModuleLoadData( &module, precompData );
	if( e != oroSuccess )
	{
		// add some verbose info to help debugging missing data
		printf( "oroModuleLoadData FAILED (error = %d) loading baked precomp data: %s\n", static_cast<int>( e ), funcName.c_str() );
		return nullptr;
	}

	oroFunction functionOut{};
	e = oroModuleGetFunction( &functionOut, module, funcName.c_str() );
	if( e != oroSuccess )
	{
		// add some verbose info to help debugging missing data
		printf( "oroModuleGetFunction FAILED (error = %d) loading baked precomp data: %s\n", static_cast<int>( e ), funcName.c_str() );

		// the module is not stored in the cache on this path, so it has to be released here or it leaks.
		oroModuleUnload( module );
		return nullptr;
	}

	auto& entry = m_kernelMap[cacheName];
	entry.function = functionOut;
	entry.module = module;

	return functionOut;
}
595+
561596
oroFunction OrochiUtils::getFunctionFromPrecompiledBinary( const std::string& path, const std::string& funcName )
562597
{
563598
std::lock_guard<std::recursive_mutex> lock( m_mutex );

Orochi/OrochiUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ class OrochiUtils
6969

7070
oroFunction getFunctionFromPrecompiledBinary( const std::string& path, const std::string& funcName );
7171

72+
// this function is like 'getFunctionFromPrecompiledBinary' but instead of giving a path to a file, we give the data directly.
73+
// ( use the script convert_binary_to_array.py to convert the .hipfb to a C-array. )
74+
// Parameters: 'data' points to the precompiled binary image, 'dataSizeInBytes' is its size in bytes, 'funcName' is the kernel to fetch.
// Returns a null function if the module cannot be loaded from 'data' or if 'funcName' is not found in it.
oroFunction getFunctionFromPrecompiledBinary_asData( const unsigned char* data, size_t dataSizeInBytes, const std::string& funcName );
75+
7276
oroFunction getFunctionFromFile( oroDevice device, const char* path, const char* funcName, std::vector<const char*>* opts );
7377
oroFunction getFunctionFromString( oroDevice device, const char* source, const char* path, const char* funcName, std::vector<const char*>* opts, int numHeaders, const char** headers, const char** includeNames );
7478
oroFunction getFunction( oroDevice device, const char* code, const char* path, const char* funcName, std::vector<const char*>* opts, int numHeaders = 0, const char** headers = 0, const char** includeNames = 0, oroModule* loadedModule = 0 );

ParallelPrimitives/RadixSort.cpp

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,47 @@
4040
#include <dlfcn.h>
4141
#endif
4242

43-
namespace
44-
{
45-
#if defined( ORO_PRECOMPILED )
46-
constexpr auto useBitCode = true;
43+
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
44+
#include <ParallelPrimitives/cache/oro_compiled_kernels.h> // generate this header with 'convert_binary_to_array.py'
4745
#else
48-
constexpr auto useBitCode = false;
46+
const unsigned char oro_compiled_kernels_h[] = "";
47+
const size_t oro_compiled_kernels_h_size = 0;
4948
#endif
5049

51-
#if defined( ORO_PP_LOAD_FROM_STRING )
52-
constexpr auto useBakeKernel = true;
53-
#else
54-
constexpr auto useBakeKernel = false;
55-
static const char* hip_RadixSortKernels = nullptr;
56-
namespace hip
50+
namespace
5751
{
58-
static const char** RadixSortKernelsArgs = nullptr;
59-
static const char** RadixSortKernelsIncludes = nullptr;
60-
} // namespace hip
52+
53+
// when both ORO_PRECOMPILED and ORO_PP_LOAD_FROM_STRING are defined, the 'usePrecompiledAndBakedKernel' mode is activated.
54+
#if defined( ORO_PRECOMPILED ) && defined( ORO_PP_LOAD_FROM_STRING )
55+
56+
// this flag means that we bake the precompiled kernels
57+
constexpr auto usePrecompiledAndBakedKernel = true;
58+
59+
constexpr auto useBitCode = false;
60+
constexpr auto useBakeKernel = false;
61+
62+
#else
63+
64+
constexpr auto usePrecompiledAndBakedKernel = false;
65+
66+
#if defined( ORO_PRECOMPILED )
67+
constexpr auto useBitCode = true; // this flag means we use the bitcode file
68+
#else
69+
constexpr auto useBitCode = false;
70+
#endif
71+
72+
#if defined( ORO_PP_LOAD_FROM_STRING )
73+
constexpr auto useBakeKernel = true; // this flag means we use the HIP source code embedded in the binary ( as a string )
74+
#else
75+
constexpr auto useBakeKernel = false;
76+
static const char* hip_RadixSortKernels = nullptr;
77+
namespace hip
78+
{
79+
static const char** RadixSortKernelsArgs = nullptr;
80+
static const char** RadixSortKernelsIncludes = nullptr;
81+
} // namespace hip
82+
#endif
83+
6184
#endif
6285

6386
static_assert( !( useBitCode && useBakeKernel ), "useBitCode and useBakeKernel cannot coexist" );
@@ -211,9 +234,14 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
211234
opts.push_back( sort_block_size_param.c_str() );
212235
opts.push_back( sort_num_warps_param.c_str() );
213236

237+
214238
for( const auto& record : records )
215239
{
216-
if constexpr( useBakeKernel )
240+
if constexpr( usePrecompiledAndBakedKernel )
241+
{
242+
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromPrecompiledBinary_asData(oro_compiled_kernels_h, oro_compiled_kernels_h_size, record.kernelName.c_str() );
243+
}
244+
else if constexpr( useBakeKernel )
217245
{
218246
oroFunctions[record.kernelType] = m_oroutils.getFunctionFromString( m_device, hip_RadixSortKernels, currentKernelPath.c_str(), record.kernelName.c_str(), &opts, 1, hip::RadixSortKernelsArgs, hip::RadixSortKernelsIncludes );
219247
}
@@ -231,6 +259,8 @@ void RadixSort::compileKernels( const std::string& kernelPath, const std::string
231259
printKernelInfo( record.kernelName, oroFunctions[record.kernelType] );
232260
}
233261
}
262+
263+
return;
234264
}
235265

236266
int RadixSort::calculateWGsToExecute( const int blockSize ) const noexcept

scripts/convert_binary_to_array.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# convert_binary_to_array.py
2+
import sys
3+
from pathlib import Path
4+
5+
def binary_to_c_array(bin_file, array_name):
    """Render the contents of a binary file as C/C++ source text.

    The returned text defines ``const unsigned char <array_name>[]`` holding
    every byte of ``bin_file`` as a hex literal, followed by
    ``const size_t <array_name>_size`` set to ``sizeof(<array_name>)``.

    Args:
        bin_file: path of the binary file to read.
        array_name: C identifier to use for the generated array.

    Returns:
        The generated C source as a string.

    Raises:
        ValueError: if ``bin_file`` is empty. An empty initializer list would
            declare a zero-sized array, which does not compile.
        OSError: if ``bin_file`` cannot be opened or read.
    """
    with open(bin_file, 'rb') as f:
        binary_data = f.read()

    # Fail at generation time rather than emitting a header that cannot compile.
    if not binary_data:
        raise ValueError(f"'{bin_file}' is empty: cannot generate a C array from it")

    hex_array = ', '.join(f'0x{b:02x}' for b in binary_data)
    c_array = f'const unsigned char {array_name}[] = {{\n {hex_array}\n}};\n'
    c_array += f'const size_t {array_name}_size = sizeof({array_name});\n'
    return c_array
13+
14+
# Command-line entry point:
#   convert_binary_to_array.py <input_binary_file> <output_header_file>
# Writes a header that defines the input file's bytes as a C array named after the header file.
if __name__ == "__main__":
    import re  # only needed when run as a script

    if len(sys.argv) != 3:
        print(f"Usage: {sys.argv[0]} <input_binary_file> <output_header_file>")
        sys.exit(1)

    bin_file = sys.argv[1]
    header_file_path = sys.argv[2]
    header_file = Path(header_file_path).name

    # The array is named after the header file. Every character that is not valid in a
    # C identifier ('.', '-', spaces, ...) becomes '_', and a leading digit gets a '_' prefix,
    # so e.g. 'oro_compiled_kernels.h' still maps to 'oro_compiled_kernels_h'.
    array_name = re.sub(r'[^0-9A-Za-z_]', '_', header_file)
    if array_name[:1].isdigit():
        array_name = '_' + array_name

    c_array = binary_to_c_array(bin_file, array_name)
    with open(header_file_path, 'w') as f:
        # name the script that actually generates this file
        f.write("// generated by convert_binary_to_array.py\n")
        f.write(c_array)

0 commit comments

Comments
 (0)