Skip to content

Commit e4b78c7

Browse files
Fix UB in float-to-unsigned scalar type conversions
Signed-off-by: shivansh023023 <singhshivansh023@gmail.com>
1 parent e6391e9 commit e4b78c7

4 files changed

Lines changed: 126 additions & 3 deletions

File tree

repro_uint8_bug.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Minimal reproduction of the uint8 type closure conversion bug."""
2+
import warp as wp
3+
4+
wp.init()
5+
6+
def create_type_closure_scalar(scalar_type):
7+
@wp.kernel
8+
def k(input: float, expected: float):
9+
x = scalar_type(input)
10+
wp.expect_eq(float(x), expected)
11+
return k
12+
13+
# These work fine (int, float closures)
14+
type_closure_kernel_int = create_type_closure_scalar(int)
15+
type_closure_kernel_float = create_type_closure_scalar(float)
16+
17+
# This is the broken one
18+
type_closure_kernel_uint8 = create_type_closure_scalar(wp.uint8)
19+
20+
print("Testing int closure...")
21+
wp.launch(type_closure_kernel_int, dim=1, inputs=[-1.5, -1.0], device="cpu")
22+
wp.synchronize()
23+
print(" PASSED")
24+
25+
print("Testing float closure...")
26+
wp.launch(type_closure_kernel_float, dim=1, inputs=[-1.5, -1.5], device="cpu")
27+
wp.synchronize()
28+
print(" PASSED")
29+
30+
print("Testing uint8 closure...")
31+
try:
32+
wp.launch(type_closure_kernel_uint8, dim=1, inputs=[-1.5, 255.0], device="cpu")
33+
wp.synchronize()
34+
print(" PASSED")
35+
except Exception as e:
36+
print(f" FAILED with exception: {type(e).__name__}: {e}")

warp/_src/builtins.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1007,8 +1007,17 @@ def get_diag_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,
10071007

10081008
# scalar type constructors between all storage / compute types
10091009
scalar_types_all = [*scalar_types, bool, int, float]
1010+
1011+
unsigned_int_types = (uint8, uint16, uint32, uint64)
1012+
float_src_types = {float16: "float16", float32: "float32", float64: "float64", float: "float32"}
1013+
10101014
for t in scalar_types_all:
10111015
for u in scalar_types_all:
1016+
# Use safe cast for float -> unsigned to avoid C++ UB
1017+
safe_native = None
1018+
if t in unsigned_int_types and u in float_src_types:
1019+
safe_native = f"{float_src_types[u]}_to_{t.__name__}"
1020+
10121021
add_builtin(
10131022
t.__name__,
10141023
input_types={"a": u},
@@ -1017,7 +1026,8 @@ def get_diag_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,
10171026
hidden=True,
10181027
group="Scalar Math",
10191028
export=False,
1020-
namespace="wp::" if t is not bool else "",
1029+
namespace="wp::" if t is not bool and not safe_native else "",
1030+
native_func=safe_native if safe_native else t.__name__,
10211031
)
10221032

10231033

warp/native/builtin.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,46 @@ typedef uint64_t uint64;
105105
typedef const char* str;
106106

107107

108+
// Float-to-unsigned conversions: cast through int64 to avoid C++ UB
109+
// (C++ 7.3.11: float -> unsigned is UB when truncated value is negative)
110+
template <typename F> CUDA_CALLABLE inline int64 safe_float_to_int64(F x)
111+
{
112+
if (!(x == x))
113+
return 0;
114+
constexpr F min_int64 = static_cast<F>(-9223372036854775808.0); // -2^63
115+
constexpr F max_overflow = static_cast<F>(9223372036854775808.0); // 2^63
116+
if (x < min_int64)
117+
return -9223372036854775807LL - 1LL;
118+
if (x >= max_overflow)
119+
return 9223372036854775807LL;
120+
return static_cast<int64>(x);
121+
}
122+
123+
template <typename F> CUDA_CALLABLE inline uint64 safe_float_to_uint64(F x)
124+
{
125+
if (!(x == x))
126+
return 0;
127+
if (x <= 0.0)
128+
return static_cast<uint64>(safe_float_to_int64(x));
129+
constexpr F pow2_63 = static_cast<F>(9223372036854775808.0); // 2^63
130+
constexpr F overflow_uint64 = static_cast<F>(18446744073709551616.0); // 2^64
131+
if (x >= overflow_uint64)
132+
return 18446744073709551615ULL;
133+
if (x >= pow2_63)
134+
return static_cast<uint64>(safe_float_to_int64(x - pow2_63)) + 9223372036854775808ULL;
135+
return static_cast<uint64>(safe_float_to_int64(x));
136+
}
137+
138+
CUDA_CALLABLE inline uint8 float32_to_uint8(float32 x) { return static_cast<uint8>(safe_float_to_int64(x)); }
139+
CUDA_CALLABLE inline uint8 float64_to_uint8(float64 x) { return static_cast<uint8>(safe_float_to_int64(x)); }
140+
CUDA_CALLABLE inline uint16 float32_to_uint16(float32 x) { return static_cast<uint16>(safe_float_to_int64(x)); }
141+
CUDA_CALLABLE inline uint16 float64_to_uint16(float64 x) { return static_cast<uint16>(safe_float_to_int64(x)); }
142+
CUDA_CALLABLE inline uint32 float32_to_uint32(float32 x) { return static_cast<uint32>(safe_float_to_int64(x)); }
143+
CUDA_CALLABLE inline uint32 float64_to_uint32(float64 x) { return static_cast<uint32>(safe_float_to_int64(x)); }
144+
CUDA_CALLABLE inline uint64 float32_to_uint64(float32 x) { return safe_float_to_uint64(x); }
145+
CUDA_CALLABLE inline uint64 float64_to_uint64(float64 x) { return safe_float_to_uint64(x); }
146+
147+
108148
struct half;
109149

110150
CUDA_CALLABLE half float_to_half(float x);
@@ -182,6 +222,12 @@ static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
182222

183223
typedef half float16;
184224

225+
// Handle float16 source
226+
CUDA_CALLABLE inline uint8 float16_to_uint8(float16 x) { return float32_to_uint8(float32(x)); }
227+
CUDA_CALLABLE inline uint16 float16_to_uint16(float16 x) { return float32_to_uint16(float32(x)); }
228+
CUDA_CALLABLE inline uint32 float16_to_uint32(float16 x) { return float32_to_uint32(float32(x)); }
229+
CUDA_CALLABLE inline uint64 float16_to_uint64(float16 x) { return float32_to_uint64(float32(x)); }
230+
185231
// Approximate division/reciprocal intrinsics
186232
#if defined(__CUDA_ARCH__)
187233

@@ -337,6 +383,19 @@ template <typename T> CUDA_CALLABLE inline void adj_float16(T x, T& adj_x, float
337383
template <typename T> CUDA_CALLABLE inline void adj_float32(T x, T& adj_x, float32 adj_ret) { adj_x += T(adj_ret); }
338384
template <typename T> CUDA_CALLABLE inline void adj_float64(T x, T& adj_x, float64 adj_ret) { adj_x += T(adj_ret); }
339385

386+
// Adjoint stubs for safe float-to-unsigned casts (no-op since they are cast functions)
387+
template <typename T> CUDA_CALLABLE inline void adj_float32_to_uint8(T, T&, uint8) { }
388+
template <typename T> CUDA_CALLABLE inline void adj_float64_to_uint8(T, T&, uint8) { }
389+
template <typename T> CUDA_CALLABLE inline void adj_float16_to_uint8(T, T&, uint8) { }
390+
template <typename T> CUDA_CALLABLE inline void adj_float32_to_uint16(T, T&, uint16) { }
391+
template <typename T> CUDA_CALLABLE inline void adj_float64_to_uint16(T, T&, uint16) { }
392+
template <typename T> CUDA_CALLABLE inline void adj_float16_to_uint16(T, T&, uint16) { }
393+
template <typename T> CUDA_CALLABLE inline void adj_float32_to_uint32(T, T&, uint32) { }
394+
template <typename T> CUDA_CALLABLE inline void adj_float64_to_uint32(T, T&, uint32) { }
395+
template <typename T> CUDA_CALLABLE inline void adj_float16_to_uint32(T, T&, uint32) { }
396+
template <typename T> CUDA_CALLABLE inline void adj_float32_to_uint64(T, T&, uint64) { }
397+
template <typename T> CUDA_CALLABLE inline void adj_float64_to_uint64(T, T&, uint64) { }
398+
template <typename T> CUDA_CALLABLE inline void adj_float16_to_uint64(T, T&, uint64) { }
340399

341400
#define kEps 0.0f
342401

warp/tests/test_codegen_instancing.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,13 +1091,31 @@ def k(input: float, expected: float):
10911091
type_closure_kernel_uint8 = create_type_closure_scalar(wp.uint8)
10921092

10931093

1094+
def create_type_closure_scalar_f64(scalar_type):
1095+
@wp.kernel
1096+
def k(input: wp.float64, expected: wp.float64):
1097+
x = scalar_type(input)
1098+
wp.expect_eq(wp.float64(x), expected)
1099+
1100+
return k
1101+
1102+
1103+
type_closure_kernel_uint64_f64 = create_type_closure_scalar_f64(wp.uint64)
1104+
1105+
10941106
def test_type_closure_scalar(test, device):
10951107
with wp.ScopedDevice(device):
10961108
wp.launch(type_closure_kernel_int, dim=1, inputs=[-1.5, -1.0])
10971109
wp.launch(type_closure_kernel_float, dim=1, inputs=[-1.5, -1.5])
10981110

1099-
# FIXME: a problem with type conversions breaks this case
1100-
# wp.launch(type_closure_kernel_uint8, dim=1, inputs=[-1.5, 255.0])
1111+
wp.launch(type_closure_kernel_uint8, dim=1, inputs=[-1.5, 255.0])
1112+
wp.launch(type_closure_kernel_uint8, dim=1, inputs=[-0.1, 0.0])
1113+
wp.launch(type_closure_kernel_uint8, dim=1, inputs=[255.1, 255.0])
1114+
wp.launch(type_closure_kernel_uint8, dim=1, inputs=[128.0, 128.0])
1115+
wp.launch(type_closure_kernel_uint8, dim=1, inputs=[-100.0, 156.0])
1116+
1117+
# Test boundary cases for uint64 truncation safety with float64 precision
1118+
wp.launch(type_closure_kernel_uint64_f64, dim=1, inputs=[9223372036854774784.0, 9223372036854774784.0])
11011119

11021120

11031121
# =======================================================================

0 commit comments

Comments
 (0)