Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
bool shared_promotion_required = false;
string promotion_str = "";
if (groupshared_allocations.contains(op->name)) {
internal_assert(allocations.contains(op->name));
Type promoted_type = allocations.get(op->name).type;
const auto *alloc = allocations.find(op->name);
internal_assert(alloc);
Type promoted_type = alloc->type;
if (promoted_type != op->value.type()) {
shared_promotion_required = true;
// NOTE(marcos): might need to resort to StoragePackUnpack::pack_store() here
Expand Down
10 changes: 6 additions & 4 deletions src/CompilerLogger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,13 @@ std::ostream &JSONCompilerLogger::emit_to_stream(std::ostream &o) {
}

// If these are present, emit them, even if value is zero
if (compilation_time.count(Phase::HalideLowering)) {
emit_key_value(o, indent, "compilation_time_halide_lowering", compilation_time[Phase::HalideLowering]);
if (auto it = compilation_time.find(Phase::HalideLowering);
it != compilation_time.end()) {
emit_key_value(o, indent, "compilation_time_halide_lowering", it->second);
}
if (compilation_time.count(Phase::LLVM)) {
emit_key_value(o, indent, "compilation_time_llvm", compilation_time[Phase::LLVM]);
if (auto it = compilation_time.find(Phase::LLVM);
it != compilation_time.end()) {
emit_key_value(o, indent, "compilation_time_llvm", it->second);
}

if (!matched_simplifier_rules.empty()) {
Expand Down
35 changes: 20 additions & 15 deletions src/HexagonOptimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1521,8 +1521,8 @@ class EliminateInterleaves : public IRMutator {
}

if (const Load *load = x.as<Load>()) {
if (buffers.contains(load->name)) {
BufferState &state = buffers.ref(load->name);
if (auto *state_ptr = buffers.shallow_find(load->name)) {
BufferState &state = *state_ptr;
if (state != BufferState::NotInterleaved) {
state = BufferState::Interleaved;
return x;
Expand Down Expand Up @@ -1814,23 +1814,27 @@ class EliminateInterleaves : public IRMutator {
op->func, op->value_index, op->image, op->param);
// Add the interleave back to the result of the call.
return native_interleave(expr);
} else if (deinterleaving_alts.find(op->name) != deinterleaving_alts.end() && hvx_target >= deinterleaving_alts[op->name].first &&

} else if (auto it = deinterleaving_alts.find(op->name);
it != deinterleaving_alts.end() &&
hvx_target >= it->second.first &&
yields_removable_interleave(args)) {
// This call has a deinterleaving alternative, and the
// arguments are interleaved, so we should use the
// alternative instead.
for (Expr &i : args) {
i = remove_interleave(i);
}
return Call::make(op->type, deinterleaving_alts[op->name].second, args, op->call_type);
} else if (interleaving_alts.count(op->name) && hvx_target >= interleaving_alts[op->name].first && is_native_deinterleave(args[0])) {
return Call::make(op->type, it->second.second, args, op->call_type);
} else if (auto it = interleaving_alts.find(op->name);
it != interleaving_alts.end() &&
hvx_target >= it->second.first &&
is_native_deinterleave(args[0])) {
// This is an interleaving alternative with a
// deinterleave, which can be generated when we
// deinterleave storage. Revert back to the interleaving
// op so we can remove the deinterleave.
Expr arg = args[0].as<Call>()->args[0];
return Call::make(op->type, interleaving_alts[op->name].second, {arg}, op->call_type,
return Call::make(op->type, it->second.second, {arg}, op->call_type,
op->func, op->value_index, op->image, op->param);
} else if (changed) {
return Call::make(op->type, op->name, args, op->call_type,
Expand Down Expand Up @@ -1937,7 +1941,7 @@ class EliminateInterleaves : public IRMutator {
}

Expr visit(const Load *op) override {
if (buffers.contains(op->name)) {
if (auto *buf_state = buffers.shallow_find(op->name)) {
if ((op->type.lanes() * op->type.bits()) % (native_vector_bits * 2) == 0) {
// This is a double vector load, we might be able to
// deinterleave the storage of this buffer.
Expand All @@ -1959,8 +1963,7 @@ class EliminateInterleaves : public IRMutator {
} else {
// This is not a double vector load, so we can't
// deinterleave the storage of this buffer.
BufferState &state = buffers.ref(op->name);
state = BufferState::NotInterleaved;
*buf_state = BufferState::NotInterleaved;
}
}
Expr expr = IRMutator::visit(op);
Expand Down Expand Up @@ -2224,16 +2227,18 @@ class SyncronizationBarriers : public IRMutator {
// Creates entry in sync map for the stmt requiring a
// scatter-release instruction before it.
void check_hazard(const string &name) {
if (in_flight.find(name) == in_flight.end()) {
auto it = in_flight.find(name);
if (it == in_flight.end()) {
return;
}
// Sync Needed. Add the scatter-release before the first different For
// loop lock between the curr_path and the hazard src location.
size_t min_size = std::min(in_flight[name].size(), curr_path.size());
const auto &flight_path = it->second;
size_t min_size = std::min(flight_path.size(), curr_path.size());
size_t i = 0;
// Find the first different For loop block.
for (; i < min_size; i++) {
if (in_flight[name][i] != curr_path[i]) {
if (flight_path[i] != curr_path[i]) {
break;
}
}
Expand Down Expand Up @@ -2266,9 +2271,9 @@ class SyncronizationBarriers : public IRMutator {
curr = &s;
Stmt new_s = IRMutator::mutate(s);
// Wrap the stmt with scatter-release if any hazard was detected.
if (sync.find(&s) != sync.end()) {
if (auto it = sync.find(&s); it != sync.end()) {
Stmt scatter_sync =
Evaluate::make(Call::make(Int(32), Call::hvx_scatter_release, {sync[&s]}, Call::Intrinsic));
Evaluate::make(Call::make(Int(32), Call::hvx_scatter_release, {it->second}, Call::Intrinsic));
return Block::make(scatter_sync, new_s);
}
return new_s;
Expand Down
8 changes: 5 additions & 3 deletions src/SlidingWindow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,11 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator {
for (int i = 0; i < func.dimensions(); i++) {
// Look up the region required of this function's last stage
string var = prefix + func_args[i];
internal_assert(scope.contains(var + ".min") && scope.contains(var + ".max"));
Expr min_req = scope.get(var + ".min");
Expr max_req = scope.get(var + ".max");
const auto *min_val = scope.find(var + ".min");
const auto *max_val = scope.find(var + ".max");
internal_assert(min_val && max_val);
Expr min_req = *min_val;
Expr max_req = *max_val;
min_req = expand_expr(min_req, scope);
max_req = expand_expr(max_req, scope);

Expand Down
5 changes: 3 additions & 2 deletions src/SplitTuples.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ class SplitTuples : public IRMutator {
if (op->types.size() > 1) {
// If there is a corresponding HoistedStorage node record the new number of
// realizes.
if (hoisted_tuple_count.count(op->name)) {
hoisted_tuple_count[op->name] = op->types.size();
if (auto it = hoisted_tuple_count.find(op->name);
it != hoisted_tuple_count.end()) {
it->second = op->types.size();
}
// Make a nested set of realize nodes for each tuple element
Stmt body = mutate(op->body);
Expand Down
5 changes: 3 additions & 2 deletions src/StorageFlattening.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,9 @@ class HoistStorage : public IRMutator {
}

Stmt visit(const Allocate *op) override {
if (hoisted_storages_map.count(op->name) > 0) {
HoistedStorageData &hoisted_storage_data = hoisted_storages[hoisted_storages_map[op->name]];
if (auto it = hoisted_storages_map.find(op->name);
it != hoisted_storages_map.end()) {
HoistedStorageData &hoisted_storage_data = hoisted_storages[it->second];

auto expand_and_bound = [&](Expr e) {
// Iterate from innermost outwards
Expand Down
9 changes: 5 additions & 4 deletions src/VectorizeLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,8 +541,8 @@ class VectorSubs : public IRMutator {
}

Expr visit(const Variable *op) override {
if (replacements.count(op->name) > 0) {
return replacements[op->name];
if (auto it = replacements.find(op->name); it != replacements.end()) {
return it->second;
} else if (scope.contains(op->name)) {
string widened_name = get_widened_var_name(op->name);
return Variable::make(vector_scope.get(widened_name).type(), widened_name);
Expand Down Expand Up @@ -994,11 +994,12 @@ class VectorSubs : public IRMutator {
// them according to the current loop level.
for (const auto &[var, val] : containing_lets) {
// Skip if this var wasn't vectorized.
if (!scope.contains(var)) {
const auto *scope_val = scope.find(var);
if (!scope_val) {
continue;
}
string vectorized_name = get_widened_var_name(var);
Expr vectorized_value = mutate(scope.get(var));
Expr vectorized_value = mutate(*scope_val);
vector_scope.push(vectorized_name, vectorized_value);
}

Expand Down
Loading