Skip to content

Commit e13bb38

Browse files
committed
test: add dtype_dispatch tests
1 parent fda7ccb commit e13bb38

File tree

3 files changed

+180
-0
lines changed

3 files changed

+180
-0
lines changed

CMakeLists.txt

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,35 @@ link_infini_train_exe(test_precision_check)
204204
add_executable(test_lora test/lora/test_lora.cc)
205205
link_infini_train_exe(test_lora)
206206

207+
add_executable(test_dtype_dispatch test/dispatch/test_dtype_dispatch.cc)
208+
link_infini_train_exe(test_dtype_dispatch)
209+
210+
# Negative compile test: missing dtype registration must fail at compile time.
211+
set(DTYPE_DISPATCH_COMPILE_FAIL_SOURCE
212+
${PROJECT_SOURCE_DIR}/test/dispatch/test_dtype_dispatch_compile_fail.cc)
213+
214+
try_compile(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED
215+
${CMAKE_BINARY_DIR}/CMakeFiles/try_compile_dtype_dispatch_missing_map
216+
SOURCES ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}
217+
CMAKE_FLAGS
218+
"-DCMAKE_CXX_STANDARD=${CMAKE_CXX_STANDARD}"
219+
"-DCMAKE_CXX_STANDARD_REQUIRED=ON"
220+
"-DCMAKE_CXX_EXTENSIONS=OFF"
221+
"-DCMAKE_CXX_FLAGS=-I${PROJECT_SOURCE_DIR}"
222+
OUTPUT_VARIABLE DTYPE_DISPATCH_TRY_COMPILE_OUTPUT
223+
)
224+
225+
if(DTYPE_DISPATCH_COMPILE_UNEXPECTEDLY_SUCCEEDED)
226+
message(FATAL_ERROR
227+
"dtype dispatch compile-fail test unexpectedly succeeded.\n"
228+
"Source: ${DTYPE_DISPATCH_COMPILE_FAIL_SOURCE}\n"
229+
"Output:\n${DTYPE_DISPATCH_TRY_COMPILE_OUTPUT}")
230+
endif()
231+
232+
add_custom_target(test_dtype_dispatch_compile_fail
233+
COMMAND ${CMAKE_COMMAND} -E echo
234+
"dtype dispatch compile-fail check passed (missing dtype registration correctly fails to compile)."
235+
VERBATIM
236+
)
237+
238+
add_dependencies(test_dtype_dispatch test_dtype_dispatch_compile_fail)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#include <cstdlib>
2+
#include <iostream>
3+
#include <string>
4+
#include <type_traits>
5+
6+
#include "glog/logging.h"
7+
8+
#include "infini_train/include/datatype.h"
9+
#include "infini_train/include/dtype_dispatch.h"
10+
11+
#include "infini_train/src/core/runtime/cpu/cpu_dispatch.h"
12+
13+
using namespace infini_train;
14+
15+
// ============================================================================
16+
// Test 1: HasMappedType_v intercepts backends missing FP16 / BF16
17+
// ============================================================================
18+
19+
// A backend TypeMap that only registers kFLOAT32 — FP16 / BF16 are absent.
20+
template <DataType DType> struct LowPrecisionAbsentTypeMap;
21+
22+
template <> struct LowPrecisionAbsentTypeMap<DataType::kFLOAT32> {
23+
using type = float;
24+
};
25+
26+
static_assert(HasMappedType_v<LowPrecisionAbsentTypeMap, DataType::kFLOAT32>,
27+
"sanity: registered dtype must be detected as present");
28+
static_assert(!HasMappedType_v<LowPrecisionAbsentTypeMap, DataType::kFLOAT16>,
29+
"unregistered kFLOAT16 must be intercepted by HasMappedType_v");
30+
static_assert(!HasMappedType_v<LowPrecisionAbsentTypeMap, DataType::kBFLOAT16>,
31+
"unregistered kBFLOAT16 must be intercepted by HasMappedType_v");
32+
33+
// ============================================================================
34+
// Test 2: CpuTypeMap resolves FP16 / BF16 to framework scalar types
35+
// ============================================================================
36+
37+
static_assert(std::is_same_v<MappedType_t<core::cpu::CpuTypeMap, DataType::kFLOAT16>, FP16>,
38+
"CpuTypeMap<kFLOAT16> must resolve to framework FP16");
39+
static_assert(std::is_same_v<MappedType_t<core::cpu::CpuTypeMap, DataType::kBFLOAT16>, BF16>,
40+
"CpuTypeMap<kBFLOAT16> must resolve to framework BF16");
41+
42+
// ============================================================================
43+
// Test 3: Runtime dispatch of kFLOAT16 / kBFLOAT16
44+
// ============================================================================
45+
46+
void TestRuntimeDispatchLowPrecision() {
47+
std::cout << "\n=== Test 3: Runtime dispatch of kFLOAT16 / kBFLOAT16 ===" << std::endl;
48+
49+
// kFLOAT16 must dispatch to framework FP16
50+
bool called_fp16 = false;
51+
core::cpu::DispatchCpuFunc<DataType::kFLOAT16, DataType::kBFLOAT16>(
52+
DataType::kFLOAT16,
53+
[&called_fp16]<typename T>() {
54+
if constexpr (std::is_same_v<T, FP16>) {
55+
called_fp16 = true;
56+
}
57+
},
58+
"dispatch kFLOAT16");
59+
CHECK(called_fp16) << "DispatchCpuFunc did not invoke functor for kFLOAT16";
60+
61+
// kBFLOAT16 must dispatch to framework BF16
62+
bool called_bf16 = false;
63+
core::cpu::DispatchCpuFunc<DataType::kFLOAT16, DataType::kBFLOAT16>(
64+
DataType::kBFLOAT16,
65+
[&called_bf16]<typename T>() {
66+
if constexpr (std::is_same_v<T, BF16>) {
67+
called_bf16 = true;
68+
}
69+
},
70+
"dispatch kBFLOAT16");
71+
CHECK(called_bf16) << "DispatchCpuFunc did not invoke functor for kBFLOAT16";
72+
73+
std::cout << "Low-precision dispatch OK." << std::endl;
74+
}
75+
76+
// ============================================================================
77+
// Test 4: Runtime dispatch of a low-precision dtype outside AllowedDTypes
78+
// must fatal
79+
// ============================================================================
80+
81+
// Sub-process entry: tries to dispatch kFLOAT16 with only kFLOAT32 allowed.
82+
void TriggerRuntimeUnsupportedLowPrecisionFatal() {
83+
core::cpu::DispatchCpuFunc<DataType::kFLOAT32>(
84+
DataType::kFLOAT16,
85+
[]<typename T>() { (void)sizeof(T); },
86+
"intercept kFLOAT16 when only kFLOAT32 is allowed");
87+
}
88+
89+
void TestRuntimeInterceptLowPrecision(const char *argv0) {
90+
std::cout << "\n=== Test 4: Runtime intercept of kFLOAT16 outside AllowedDTypes ===" << std::endl;
91+
const std::string cmd = std::string(argv0) + " --expect-runtime-fatal > /dev/null 2>&1";
92+
const int status = std::system(cmd.c_str());
93+
CHECK_NE(status, 0) << "Expected non-zero exit when dispatching an unallowed low-precision dtype";
94+
std::cout << "Low-precision runtime intercept OK." << std::endl;
95+
}
96+
97+
// ============================================================================
98+
// Main
99+
// ============================================================================
100+
101+
int main(int argc, char *argv[]) {
102+
google::InitGoogleLogging(argv[0]);
103+
104+
if (argc > 1 && std::string(argv[1]) == "--expect-runtime-fatal") {
105+
TriggerRuntimeUnsupportedLowPrecisionFatal();
106+
return 0;
107+
}
108+
109+
std::cout << "========================================" << std::endl;
110+
std::cout << " Low-precision Dtype Dispatch Test Suite" << std::endl;
111+
std::cout << "========================================" << std::endl;
112+
113+
std::cout << "Compile-time checks: PASSED" << std::endl;
114+
115+
TestRuntimeDispatchLowPrecision();
116+
TestRuntimeInterceptLowPrecision(argv[0]);
117+
118+
std::cout << "\nAll low-precision dtype dispatch tests passed." << std::endl;
119+
return 0;
120+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include "infini_train/include/datatype.h"
2+
#include "infini_train/include/dtype_dispatch.h"
3+
4+
using namespace infini_train;
5+
6+
// ============================================================================
7+
// Compile-fail: dispatching an unregistered low-precision dtype must be
8+
// intercepted at compile time
9+
// ============================================================================
10+
11+
// Models a backend that has registered standard floating types but has NOT
12+
// yet provided a mapping for the low-precision dtypes FP16 / BF16.
13+
template <DataType DType> struct LowPrecisionMissingTypeMap;
14+
15+
template <> struct LowPrecisionMissingTypeMap<DataType::kFLOAT32> {
16+
using type = float;
17+
};
18+
19+
int main() {
20+
// Dispatching kFLOAT16 through LowPrecisionMissingTypeMap must trigger the
21+
// static_assert inside DispatchByTypeMap, failing this translation unit
22+
// before MappedType_t<TypeMap, kFLOAT16> is ever instantiated.
23+
DispatchByTypeMap<LowPrecisionMissingTypeMap, DataType::kFLOAT16>(
24+
DataType::kFLOAT16,
25+
[]<typename T>() { (void)sizeof(T); },
26+
"compile-fail: unregistered low-precision dtype");
27+
return 0;
28+
}

0 commit comments

Comments
 (0)