Skip to content

Commit f862433

Browse files
abadamsclaude
andauthored
cast<t>(immediate) should be make_const(t, immediate) (#9061)
Doing a Halide cast of a c++ int constructs an immediate Expr (e.g. IntImm) and then eagerly folds it to a different type of immediate in the Cast constructor. It's better to just construct the immediate using the desired type to begin with. Co-authored-by: Claude <noreply@anthropic.com>
1 parent c75d258 commit f862433

10 files changed

Lines changed: 33 additions & 33 deletions

src/AddAtomicMutex.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ class AddAtomicMutex : public IRMutator {
326326
}
327327

328328
if (const std::string *mutex_name = needs_mutex_allocation.find(producer_name)) {
329-
Expr extent = cast<uint64_t>(1); // uint64_t to handle LargeBuffers
329+
Expr extent = make_one(UInt(64)); // uint64_t to handle LargeBuffers
330330
for (const Expr &e : op->extents) {
331331
extent = extent * e;
332332
}
@@ -378,7 +378,7 @@ class AddAtomicMutex : public IRMutator {
378378
if (const std::string *mutex_name = needs_mutex_allocation.find(it->first)) {
379379
// All output buffers in a Tuple have the same extent.
380380
OutputImageParam output_buffer = Func(f).output_buffers()[0];
381-
Expr extent = cast<uint64_t>(1); // uint64_t to handle LargeBuffers
381+
Expr extent = make_one(UInt(64)); // uint64_t to handle LargeBuffers
382382
for (int i = 0; i < output_buffer.dimensions(); i++) {
383383
extent *= output_buffer.dim(i).extent();
384384
}

src/Associativity.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,8 +537,8 @@ void associativity_test() {
537537
Expr f_call_0 = Call::make(t, "f", {x_idx}, Call::CallType::Halide, FunctionPtr(), 0);
538538

539539
for (const Expr &e : {cast<uint8_t>(min(cast<uint16_t>(x) + y, 255)),
540-
select(x > 255 - y, cast<uint8_t>(255), y),
541-
select(x < -y, y, cast<uint8_t>(255)),
540+
select(x > 255 - y, make_const(UInt(8), 255), y),
541+
select(x < -y, y, make_const(UInt(8), 255)),
542542
saturating_add(x, y),
543543
saturating_add(y, x),
544544
saturating_cast<uint8_t>(widening_add(x, y))}) {

src/BoundaryConditions.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ Func constant_exterior(const Func &source, const Tuple &value,
4545
<< ") than dimensions (" << args.size()
4646
<< ") Func " << source.name() << " has.\n";
4747

48-
Expr out_of_bounds = cast<bool>(false);
48+
Expr out_of_bounds = Halide::Internal::make_zero(Bool());
4949
for (size_t i = 0; i < bounds.size(); i++) {
5050
const Var &arg_var = args[i];
5151
Expr min = bounds[i].min;

src/Bounds.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3623,22 +3623,22 @@ void bounds_test() {
36233623
scope.pop("x");
36243624

36253625
// Check some bitwise ops.
3626-
check(scope, (cast<uint8_t>(x) & cast<uint8_t>(7)), u8(0), u8(7));
3627-
check(scope, (cast<uint8_t>(3) & cast<uint8_t>(2)), u8(2), u8(2));
3628-
check(scope, (cast<uint8_t>(1) | cast<uint8_t>(2)), u8(3), u8(3));
3629-
check(scope, (cast<uint8_t>(3) ^ cast<uint8_t>(2)), u8(1), u8(1));
3630-
check(scope, (~cast<uint8_t>(3)), u8(0xfc), u8(0xfc));
3626+
check(scope, (cast<uint8_t>(x) & make_const(UInt(8), 7)), u8(0), u8(7));
3627+
check(scope, (make_const(UInt(8), 3) & make_const(UInt(8), 2)), u8(2), u8(2));
3628+
check(scope, (make_one(UInt(8)) | make_const(UInt(8), 2)), u8(3), u8(3));
3629+
check(scope, (make_const(UInt(8), 3) ^ make_const(UInt(8), 2)), u8(1), u8(1));
3630+
check(scope, (~make_const(UInt(8), 3)), u8(0xfc), u8(0xfc));
36313631
check(scope, cast<uint8_t>(x + 5) & cast<uint8_t>(x + 3), u8(0), u8(13));
36323632
check(scope, cast<int8_t>(x - 5) & cast<int8_t>(x + 3), i8(0), i8(13));
36333633
check(scope, cast<int8_t>(2 * x - 5) & cast<int8_t>(x - 3), i8(-128), i8(15));
36343634
check(scope, cast<uint8_t>(x + 5) | cast<uint8_t>(x + 3), u8(5), u8(255));
36353635
check(scope, cast<int8_t>(x + 5) | cast<int8_t>(x + 3), i8(3), i8(127));
36363636
check(scope, ~cast<uint8_t>(x), u8(-11), u8(-1));
3637-
check(scope, (cast<uint8_t>(x) >> cast<uint8_t>(1)), u8(0), u8(5));
3638-
check(scope, (cast<uint8_t>(10) >> cast<uint8_t>(1)), u8(5), u8(5));
3639-
check(scope, (cast<uint8_t>(x + 3) << cast<uint8_t>(1)), u8(6), u8(26));
3640-
check(scope, (cast<uint8_t>(x + 3) << cast<uint8_t>(7)), u8(0), u8(255)); // Overflows
3641-
check(scope, (cast<uint8_t>(5) << cast<uint8_t>(1)), u8(10), u8(10));
3637+
check(scope, (cast<uint8_t>(x) >> make_one(UInt(8))), u8(0), u8(5));
3638+
check(scope, (make_const(UInt(8), 10) >> make_one(UInt(8))), u8(5), u8(5));
3639+
check(scope, (cast<uint8_t>(x + 3) << make_one(UInt(8))), u8(6), u8(26));
3640+
check(scope, (cast<uint8_t>(x + 3) << make_const(UInt(8), 7)), u8(0), u8(255)); // Overflows
3641+
check(scope, (make_const(UInt(8), 5) << make_one(UInt(8))), u8(10), u8(10));
36423642
check(scope, (x << 12), 0, 10 << 12);
36433643
check(scope, x & 4095, 0, 10); // LHS known to be positive
36443644
check(scope, x & 123, 0, 10); // Doesn't have to be a precise bitmask
@@ -3712,27 +3712,27 @@ void bounds_test() {
37123712
u16(0), u16(4095));
37133713

37143714
check(scope,
3715-
cast<uint8_t>(clamp(cast<uint16_t>(x ^ y), cast<uint16_t>(0), cast<uint16_t>(128))),
3715+
cast<uint8_t>(clamp(cast<uint16_t>(x ^ y), make_zero(UInt(16)), make_const(UInt(16), 128))),
37163716
u8(0), u8(128));
37173717

37183718
Expr u8_1 = cast<uint8_t>(Load::make(Int(8), "buf", x, Buffer<>(), Parameter(), const_true(), ModulusRemainder()));
37193719
Expr u8_2 = cast<uint8_t>(Load::make(Int(8), "buf", x + 17, Buffer<>(), Parameter(), const_true(), ModulusRemainder()));
37203720
check(scope, cast<uint16_t>(u8_1) + cast<uint16_t>(u8_2),
37213721
u16(0), u16(255 * 2));
37223722

3723-
check(scope, saturating_cast<uint8_t>(clamp(x, 5, 10)), cast<uint8_t>(5), cast<uint8_t>(10));
3723+
check(scope, saturating_cast<uint8_t>(clamp(x, 5, 10)), make_const(UInt(8), 5), make_const(UInt(8), 10));
37243724
{
37253725
scope.push("x", Interval(UInt(32).min(), UInt(32).max()));
3726-
check(scope, saturating_cast<int32_t>(max(cast<uint32_t>(x), cast<uint32_t>(5))), cast<int32_t>(5), Int(32).max());
3726+
check(scope, saturating_cast<int32_t>(max(cast<uint32_t>(x), make_const(UInt(32), 5))), make_const(Int(32), 5), Int(32).max());
37273727
scope.pop("x");
37283728
}
37293729
{
37303730
Expr z = Variable::make(Float(32), "z");
3731-
scope.push("z", Interval(cast<float>(-1), cast<float>(1)));
3732-
check(scope, saturating_cast<int32_t>(z), cast<int32_t>(-1), cast<int32_t>(1));
3733-
check(scope, saturating_cast<double>(z), cast<double>(-1), cast<double>(1));
3734-
check(scope, saturating_cast<float16_t>(z), cast<float16_t>(-1), cast<float16_t>(1));
3735-
check(scope, saturating_cast<uint8_t>(z), cast<uint8_t>(0), cast<uint8_t>(1));
3731+
scope.push("z", Interval(make_const(Float(32), -1), make_one(Float(32))));
3732+
check(scope, saturating_cast<int32_t>(z), make_const(Int(32), -1), make_one(Int(32)));
3733+
check(scope, saturating_cast<double>(z), make_const(Float(64), -1), make_one(Float(64)));
3734+
check(scope, saturating_cast<float16_t>(z), make_const(Float(16), -1), make_one(Float(16)));
3735+
check(scope, saturating_cast<uint8_t>(z), make_zero(UInt(8)), make_one(UInt(8)));
37363736
scope.pop("z");
37373737
}
37383738
{

src/CodeGen_Hexagon.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ Stmt acquire_hvx_context(Stmt stmt, const Target &target) {
170170
// Modify the stmt to add a call to halide_qurt_hvx_lock, and
171171
// register a destructor to call halide_qurt_hvx_unlock.
172172
Stmt check_hvx_lock = call_halide_qurt_hvx_lock(target);
173-
Expr dummy_obj = reinterpret(Handle(), cast<uint64_t>(1));
173+
Expr dummy_obj = reinterpret(Handle(), make_one(UInt(64)));
174174
Expr hvx_unlock =
175175
Call::make(Handle(), Call::register_destructor,
176176
{Expr("halide_qurt_hvx_unlock_as_destructor"), dummy_obj},

src/CodeGen_LLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3989,7 +3989,7 @@ void CodeGen_LLVM::codegen_asserts(const vector<const AssertStmt *> &asserts) {
39893989

39903990
// Mix all the conditions together into a bitmask
39913991

3992-
Expr bitmask = cast<uint64_t>(1) << 63;
3992+
Expr bitmask = make_const(UInt(64), ((uint64_t)1) << 63);
39933993
for (size_t i = 0; i < asserts.size(); i++) {
39943994
bitmask = bitmask | (cast<uint64_t>(!asserts[i]->condition) << i);
39953995
}

src/Generator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2167,7 +2167,7 @@ void generator_test() {
21672167
const std::vector<uint64_t> a = {1, 2, 3, 4};
21682168
Var x;
21692169
Func fn_typed, fn_untyped;
2170-
fn_typed(x) = cast<int16_t>(38);
2170+
fn_typed(x) = make_const(Int(16), 38);
21712171
fn_untyped(x) = 32.f;
21722172
const std::vector<Func> fn_array = {fn_untyped, fn_untyped};
21732173

src/Generator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2065,7 +2065,7 @@ class GeneratorInput_Scalar : public GeneratorInputImpl<T, Expr> {
20652065
void set_estimate(const TBase &value) {
20662066
this->check_gio_access();
20672067
user_assert(value == nullptr) << "nullptr is the only valid estimate for Input<PointerType>";
2068-
Expr e = reinterpret(type_of<T2>(), cast<uint64_t>(0));
2068+
Expr e = reinterpret(type_of<T2>(), make_zero(UInt(64)));
20692069
for (Parameter &p : this->parameters_) {
20702070
p.set_estimate(e);
20712071
}

src/OffloadGPULoops.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,18 @@ class InjectGpuOffload : public IRMutator {
200200
arg_types_or_sizes.emplace_back(cast(target_size_t_type, i.is_buffer ? 8 : i.type.bytes()));
201201
}
202202

203-
arg_is_buffer.emplace_back(cast<uint8_t>(i.is_buffer));
203+
arg_is_buffer.emplace_back(make_const(UInt(8), (int)i.is_buffer));
204204
}
205205

206206
// nullptr-terminate the lists
207-
args.emplace_back(reinterpret(Handle(), cast<uint64_t>(0)));
207+
args.emplace_back(reinterpret(Handle(), make_zero(UInt(64))));
208208
if (runtime_run_takes_types) {
209209
internal_assert(sizeof(halide_type_t) == sizeof(uint32_t));
210-
arg_types_or_sizes.emplace_back(cast<uint32_t>(0));
210+
arg_types_or_sizes.emplace_back(make_zero(UInt(32)));
211211
} else {
212212
arg_types_or_sizes.emplace_back(cast(target_size_t_type, 0));
213213
}
214-
arg_is_buffer.emplace_back(cast<uint8_t>(0));
214+
arg_is_buffer.emplace_back(make_zero(UInt(8)));
215215

216216
debug(3) << "bounds.num_blocks[0] = " << bounds.num_blocks[0] << "\n";
217217
debug(3) << "bounds.num_blocks[1] = " << bounds.num_blocks[1] << "\n";

src/Profiling.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class InjectProfiling : public IRMutator {
201201

202202
Stmt unconditionally_set_current_func(int id) {
203203
Stmt s = Evaluate::make(Call::make(Int(32), "halide_profiler_set_current_func",
204-
{profiler_instance, id, reinterpret(Handle(), cast<uint64_t>(0))}, Call::Extern));
204+
{profiler_instance, id, reinterpret(Handle(), make_zero(UInt(64)))}, Call::Extern));
205205
return s;
206206
}
207207

@@ -210,7 +210,7 @@ class InjectProfiling : public IRMutator {
210210
return Evaluate::make(0);
211211
}
212212
most_recently_set_func = id;
213-
Expr last_arg = in_leaf_task ? profiler_local_sampling_token : reinterpret(Handle(), cast<uint64_t>(0));
213+
Expr last_arg = in_leaf_task ? profiler_local_sampling_token : reinterpret(Handle(), make_zero(UInt(64)));
214214
// This call gets inlined and becomes a single store instruction.
215215
Stmt s = Evaluate::make(Call::make(Int(32), "halide_profiler_set_current_func",
216216
{profiler_instance, id, last_arg}, Call::Extern));

0 commit comments

Comments
 (0)