Skip to content

Commit cba6542

Browse files
authored
Fix double-lookups into Scope and map (#9062)
In several places we were checking something was in a scope, and if so, looking it up again to actually do something with it. We should use ::find (or ::shallow_find for a mutable reference into a Scope) No functional impact.
1 parent f862433 commit cba6542

7 files changed

Lines changed: 45 additions & 32 deletions

src/CodeGen_D3D12Compute_Dev.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -684,8 +684,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
684684
bool shared_promotion_required = false;
685685
string promotion_str = "";
686686
if (groupshared_allocations.contains(op->name)) {
687-
internal_assert(allocations.contains(op->name));
688-
Type promoted_type = allocations.get(op->name).type;
687+
const auto *alloc = allocations.find(op->name);
688+
internal_assert(alloc);
689+
Type promoted_type = alloc->type;
689690
if (promoted_type != op->value.type()) {
690691
shared_promotion_required = true;
691692
// NOTE(marcos): might need to resort to StoragePackUnpack::pack_store() here

src/CompilerLogger.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -271,11 +271,13 @@ std::ostream &JSONCompilerLogger::emit_to_stream(std::ostream &o) {
271271
}
272272

273273
// If these are present, emit them, even if value is zero
274-
if (compilation_time.count(Phase::HalideLowering)) {
275-
emit_key_value(o, indent, "compilation_time_halide_lowering", compilation_time[Phase::HalideLowering]);
274+
if (auto it = compilation_time.find(Phase::HalideLowering);
275+
it != compilation_time.end()) {
276+
emit_key_value(o, indent, "compilation_time_halide_lowering", it->second);
276277
}
277-
if (compilation_time.count(Phase::LLVM)) {
278-
emit_key_value(o, indent, "compilation_time_llvm", compilation_time[Phase::LLVM]);
278+
if (auto it = compilation_time.find(Phase::LLVM);
279+
it != compilation_time.end()) {
280+
emit_key_value(o, indent, "compilation_time_llvm", it->second);
279281
}
280282

281283
if (!matched_simplifier_rules.empty()) {

src/HexagonOptimize.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,8 +1521,8 @@ class EliminateInterleaves : public IRMutator {
15211521
}
15221522

15231523
if (const Load *load = x.as<Load>()) {
1524-
if (buffers.contains(load->name)) {
1525-
BufferState &state = buffers.ref(load->name);
1524+
if (auto *state_ptr = buffers.shallow_find(load->name)) {
1525+
BufferState &state = *state_ptr;
15261526
if (state != BufferState::NotInterleaved) {
15271527
state = BufferState::Interleaved;
15281528
return x;
@@ -1814,23 +1814,27 @@ class EliminateInterleaves : public IRMutator {
18141814
op->func, op->value_index, op->image, op->param);
18151815
// Add the interleave back to the result of the call.
18161816
return native_interleave(expr);
1817-
} else if (deinterleaving_alts.find(op->name) != deinterleaving_alts.end() && hvx_target >= deinterleaving_alts[op->name].first &&
1818-
1817+
} else if (auto it = deinterleaving_alts.find(op->name);
1818+
it != deinterleaving_alts.end() &&
1819+
hvx_target >= it->second.first &&
18191820
yields_removable_interleave(args)) {
18201821
// This call has a deinterleaving alternative, and the
18211822
// arguments are interleaved, so we should use the
18221823
// alternative instead.
18231824
for (Expr &i : args) {
18241825
i = remove_interleave(i);
18251826
}
1826-
return Call::make(op->type, deinterleaving_alts[op->name].second, args, op->call_type);
1827-
} else if (interleaving_alts.count(op->name) && hvx_target >= interleaving_alts[op->name].first && is_native_deinterleave(args[0])) {
1827+
return Call::make(op->type, it->second.second, args, op->call_type);
1828+
} else if (auto it = interleaving_alts.find(op->name);
1829+
it != interleaving_alts.end() &&
1830+
hvx_target >= it->second.first &&
1831+
is_native_deinterleave(args[0])) {
18281832
// This is an interleaving alternative with a
18291833
// deinterleave, which can be generated when we
18301834
// deinterleave storage. Revert back to the interleaving
18311835
// op so we can remove the deinterleave.
18321836
Expr arg = args[0].as<Call>()->args[0];
1833-
return Call::make(op->type, interleaving_alts[op->name].second, {arg}, op->call_type,
1837+
return Call::make(op->type, it->second.second, {arg}, op->call_type,
18341838
op->func, op->value_index, op->image, op->param);
18351839
} else if (changed) {
18361840
return Call::make(op->type, op->name, args, op->call_type,
@@ -1937,7 +1941,7 @@ class EliminateInterleaves : public IRMutator {
19371941
}
19381942

19391943
Expr visit(const Load *op) override {
1940-
if (buffers.contains(op->name)) {
1944+
if (auto *buf_state = buffers.shallow_find(op->name)) {
19411945
if ((op->type.lanes() * op->type.bits()) % (native_vector_bits * 2) == 0) {
19421946
// This is a double vector load, we might be able to
19431947
// deinterleave the storage of this buffer.
@@ -1959,8 +1963,7 @@ class EliminateInterleaves : public IRMutator {
19591963
} else {
19601964
// This is not a double vector load, so we can't
19611965
// deinterleave the storage of this buffer.
1962-
BufferState &state = buffers.ref(op->name);
1963-
state = BufferState::NotInterleaved;
1966+
*buf_state = BufferState::NotInterleaved;
19641967
}
19651968
}
19661969
Expr expr = IRMutator::visit(op);
@@ -2224,16 +2227,18 @@ class SyncronizationBarriers : public IRMutator {
22242227
// Creates entry in sync map for the stmt requiring a
22252228
// scatter-release instruction before it.
22262229
void check_hazard(const string &name) {
2227-
if (in_flight.find(name) == in_flight.end()) {
2230+
auto it = in_flight.find(name);
2231+
if (it == in_flight.end()) {
22282232
return;
22292233
}
22302234
// Sync Needed. Add the scatter-release before the first different For
22312235
// loop lock between the curr_path and the hazard src location.
2232-
size_t min_size = std::min(in_flight[name].size(), curr_path.size());
2236+
const auto &flight_path = it->second;
2237+
size_t min_size = std::min(flight_path.size(), curr_path.size());
22332238
size_t i = 0;
22342239
// Find the first different For loop block.
22352240
for (; i < min_size; i++) {
2236-
if (in_flight[name][i] != curr_path[i]) {
2241+
if (flight_path[i] != curr_path[i]) {
22372242
break;
22382243
}
22392244
}
@@ -2266,9 +2271,9 @@ class SyncronizationBarriers : public IRMutator {
22662271
curr = &s;
22672272
Stmt new_s = IRMutator::mutate(s);
22682273
// Wrap the stmt with scatter-release if any hazard was detected.
2269-
if (sync.find(&s) != sync.end()) {
2274+
if (auto it = sync.find(&s); it != sync.end()) {
22702275
Stmt scatter_sync =
2271-
Evaluate::make(Call::make(Int(32), Call::hvx_scatter_release, {sync[&s]}, Call::Intrinsic));
2276+
Evaluate::make(Call::make(Int(32), Call::hvx_scatter_release, {it->second}, Call::Intrinsic));
22722277
return Block::make(scatter_sync, new_s);
22732278
}
22742279
return new_s;

src/SlidingWindow.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,11 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
270270
for (int i = 0; i < func.dimensions(); i++) {
271271
// Look up the region required of this function's last stage
272272
string var = prefix + func_args[i];
273-
internal_assert(scope.contains(var + ".min") && scope.contains(var + ".max"));
274-
Expr min_req = scope.get(var + ".min");
275-
Expr max_req = scope.get(var + ".max");
273+
const auto *min_val = scope.find(var + ".min");
274+
const auto *max_val = scope.find(var + ".max");
275+
internal_assert(min_val && max_val);
276+
Expr min_req = *min_val;
277+
Expr max_req = *max_val;
276278
min_req = expand_expr(min_req, scope);
277279
max_req = expand_expr(max_req, scope);
278280

src/SplitTuples.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ class SplitTuples : public IRMutator {
6060
if (op->types.size() > 1) {
6161
// If there is a corresponding HoistedStorage node record the new number of
6262
// realizes.
63-
if (hoisted_tuple_count.count(op->name)) {
64-
hoisted_tuple_count[op->name] = op->types.size();
63+
if (auto it = hoisted_tuple_count.find(op->name);
64+
it != hoisted_tuple_count.end()) {
65+
it->second = op->types.size();
6566
}
6667
// Make a nested set of realize nodes for each tuple element
6768
Stmt body = mutate(op->body);

src/StorageFlattening.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -504,8 +504,9 @@ class HoistStorage : public IRMutator {
504504
}
505505

506506
Stmt visit(const Allocate *op) override {
507-
if (hoisted_storages_map.count(op->name) > 0) {
508-
HoistedStorageData &hoisted_storage_data = hoisted_storages[hoisted_storages_map[op->name]];
507+
if (auto it = hoisted_storages_map.find(op->name);
508+
it != hoisted_storages_map.end()) {
509+
HoistedStorageData &hoisted_storage_data = hoisted_storages[it->second];
509510

510511
auto expand_and_bound = [&](Expr e) {
511512
// Iterate from innermost outwards

src/VectorizeLoops.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,8 @@ class VectorSubs : public IRMutator {
541541
}
542542

543543
Expr visit(const Variable *op) override {
544-
if (replacements.count(op->name) > 0) {
545-
return replacements[op->name];
544+
if (auto it = replacements.find(op->name); it != replacements.end()) {
545+
return it->second;
546546
} else if (scope.contains(op->name)) {
547547
string widened_name = get_widened_var_name(op->name);
548548
return Variable::make(vector_scope.get(widened_name).type(), widened_name);
@@ -994,11 +994,12 @@ class VectorSubs : public IRMutator {
994994
// them according to the current loop level.
995995
for (const auto &[var, val] : containing_lets) {
996996
// Skip if this var wasn't vectorized.
997-
if (!scope.contains(var)) {
997+
const auto *scope_val = scope.find(var);
998+
if (!scope_val) {
998999
continue;
9991000
}
10001001
string vectorized_name = get_widened_var_name(var);
1001-
Expr vectorized_value = mutate(scope.get(var));
1002+
Expr vectorized_value = mutate(*scope_val);
10021003
vector_scope.push(vectorized_name, vectorized_value);
10031004
}
10041005

0 commit comments

Comments
 (0)