Skip to content

Commit a870372

Browse files
author
niklas
committed
spv/lift: canonicalize loop shortcut before phi materialization
Add a guarded CFG canonicalization pass in FuncLifting that rewrites a strict loop shortcut shape before dead-block pruning and OpPhi case collection. Matched shape: loop header branches to a body select with one empty pass-through arm into body merge, body merge branches to loop_continue, and loop_continue conditionally branches to loop header or loop merge on the same boolean condition (or its negation). Rewrite: retarget pass-through arm directly to loop merge, drop the body selection merge marker, and canonicalize loop_continue to an unconditional backedge while preserving only header-target phi payloads. Then recompute predecessor/use counts before block retention. This preserves CFG/phi invariants through explicit guards on terminator kinds, merge topology, predecessor counts, and payload arity, while eliminating a fragile structured-loop encoding that can mis-handle edge-carried values.
1 parent 710956c commit a870372

1 file changed

Lines changed: 330 additions & 0 deletions

File tree

src/spv/lift.rs

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,13 @@ enum Merge<L> {
342342
},
343343
}
344344

345+
/// Relationship between loop body condition and loop-continue condition.
///
/// Produced when comparing the boolean driving a loop body's selection with
/// the boolean driving the conditional branch of the loop's continue block.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum CondRelation {
    /// The continue condition always has the same value as the body condition.
    Same,
    /// The continue condition is always the logical negation of the body
    /// condition.
    Not,
}
351+
345352
impl<'a> NeedsIdsCollector<'a> {
346353
fn alloc_ids<E>(
347354
self,
@@ -508,6 +515,326 @@ impl FuncAt<'_, Node> {
508515
}
509516

510517
impl<'a> FuncLifting<'a> {
518+
/// Recompute incoming edge counts for each CFG point in `blocks`.
519+
///
520+
/// This has to run after control-flow rewrites and before dead-block
521+
/// pruning, to avoid stale predecessor counts.
522+
fn recompute_use_counts(
523+
blocks: &FxIndexMap<CfgPoint, BlockLifting<'a>>,
524+
use_counts: &mut FxHashMap<CfgPoint, usize>,
525+
) {
526+
use_counts.clear();
527+
use_counts.reserve(blocks.len());
528+
let all_edges = blocks.first().map(|(&entry_point, _)| entry_point).into_iter().chain(
529+
blocks.values().flat_map(|block| {
530+
block
531+
.terminator
532+
.merge
533+
.iter()
534+
.flat_map(|merge| {
535+
let (a, b) = match merge {
536+
Merge::Selection(a) => (a, None),
537+
Merge::Loop { loop_merge: a, loop_continue: b } => (a, Some(b)),
538+
};
539+
[a].into_iter().chain(b)
540+
})
541+
.chain(&block.terminator.targets)
542+
.copied()
543+
}),
544+
);
545+
for target in all_edges {
546+
*use_counts.entry(target).or_default() += 1;
547+
}
548+
}
549+
550+
/// Return `true` iff `point` is an empty pass-through block branching only
551+
/// to `target`.
552+
fn is_passthrough_branch_to(
553+
blocks: &FxIndexMap<CfgPoint, BlockLifting<'a>>,
554+
point: CfgPoint,
555+
target: CfgPoint,
556+
) -> bool {
557+
let Some(block) = blocks.get(&point) else {
558+
return false;
559+
};
560+
block.phis.is_empty()
561+
&& block.insts.is_empty()
562+
&& block.terminator.attrs == AttrSet::default()
563+
&& matches!(&*block.terminator.kind, cfg::ControlInstKind::Branch)
564+
&& block.terminator.inputs.is_empty()
565+
&& block.terminator.targets.as_slice() == [target]
566+
&& block.terminator.target_phi_values.keys().all(|&phi_target| phi_target == target)
567+
&& block.terminator.merge.is_none()
568+
}
569+
570+
fn is_const_opcode(cx: &Context, v: Value, opcode: spec::Opcode) -> bool {
571+
match v {
572+
Value::Const(c) => match &cx[c].kind {
573+
ConstKind::SpvInst { spv_inst_and_const_inputs } => {
574+
spv_inst_and_const_inputs.0.opcode == opcode
575+
}
576+
_ => false,
577+
},
578+
_ => false,
579+
}
580+
}
581+
582+
/// Determine whether `continue_cond` is equal to `body_cond` or its
583+
/// logical negation.
584+
fn continue_cond_relation(
585+
cx: &Context,
586+
func_def_body: &'a crate::FuncDefBody,
587+
continue_cond: Value,
588+
body_cond: Value,
589+
) -> Option<CondRelation> {
590+
if continue_cond == body_cond {
591+
return Some(CondRelation::Same);
592+
}
593+
594+
let wk = &spec::Spec::get().well_known;
595+
match continue_cond {
596+
Value::DataInstOutput(_) => None,
597+
598+
Value::ControlNodeOutput { control_node, output_idx } => {
599+
let control_node_def = func_def_body.at(control_node).def();
600+
let ControlNodeKind::Select { kind: SelectionKind::BoolCond, scrutinee, cases } =
601+
&control_node_def.kind
602+
else {
603+
return None;
604+
};
605+
if *scrutinee != body_cond || cases.len() != 2 {
606+
return None;
607+
}
608+
609+
let output_idx = output_idx as usize;
610+
let true_case_outputs = &func_def_body.at(cases[0]).def().outputs;
611+
let false_case_outputs = &func_def_body.at(cases[1]).def().outputs;
612+
if output_idx >= true_case_outputs.len() || output_idx >= false_case_outputs.len() {
613+
return None;
614+
}
615+
616+
let on_true = true_case_outputs[output_idx];
617+
let on_false = false_case_outputs[output_idx];
618+
if Self::is_const_opcode(cx, on_true, wk.OpConstantTrue)
619+
&& Self::is_const_opcode(cx, on_false, wk.OpConstantFalse)
620+
{
621+
Some(CondRelation::Same)
622+
} else if Self::is_const_opcode(cx, on_true, wk.OpConstantFalse)
623+
&& Self::is_const_opcode(cx, on_false, wk.OpConstantTrue)
624+
{
625+
Some(CondRelation::Not)
626+
} else {
627+
None
628+
}
629+
}
630+
631+
_ => None,
632+
}
633+
}
634+
635+
/// Rewrite `loop_continue` into an unconditional backedge while preserving
636+
/// only phi payloads for the loop header edge.
637+
fn rewrite_continue_as_unconditional_backedge(
638+
blocks: &mut FxIndexMap<CfgPoint, BlockLifting<'a>>,
639+
loop_continue: CfgPoint,
640+
) {
641+
let continue_block = blocks.get_mut(&loop_continue).unwrap();
642+
continue_block.terminator.kind = Cow::Owned(cfg::ControlInstKind::Branch);
643+
continue_block.terminator.inputs = [].into_iter().collect();
644+
let header_point = continue_block.terminator.targets[0];
645+
continue_block.terminator.targets = [header_point].into_iter().collect();
646+
continue_block.terminator.target_phi_values.retain(|&target, _| target == header_point);
647+
}
648+
649+
/// Canonicalize strict loop shortcut patterns:
650+
/// * loop header branches to a body select,
651+
/// * one body arm is an empty pass-through to body merge,
652+
/// * body merge branches to `loop_continue`,
653+
/// * `loop_continue` conditionally branches to header/merge.
654+
///
655+
/// Rewriting the pass-through arm directly to `loop_merge` avoids
656+
/// preserving this fragile shape in lifted CFG.
657+
fn canonicalize_loop_continue_shortcuts(
658+
cx: &Context,
659+
func_def_body: &'a crate::FuncDefBody,
660+
blocks: &mut FxIndexMap<CfgPoint, BlockLifting<'a>>,
661+
) {
662+
let mut loop_continue_shortcuts =
663+
SmallVec::<[(CfgPoint, CfgPoint, CfgPoint, CfgPoint, bool); 4]>::new();
664+
for (&header_point, header_block) in &*blocks {
665+
let header_term = &header_block.terminator;
666+
let Some(Merge::Loop { loop_merge, loop_continue }) = header_term.merge else {
667+
continue;
668+
};
669+
if header_term.attrs != AttrSet::default()
670+
|| !matches!(&*header_term.kind, cfg::ControlInstKind::Branch)
671+
|| !header_term.inputs.is_empty()
672+
|| header_term.targets.len() != 1
673+
|| !header_term.target_phi_values.is_empty()
674+
{
675+
continue;
676+
}
677+
678+
let body_point = header_term.targets[0];
679+
let Some(body_block) = blocks.get(&body_point) else {
680+
continue;
681+
};
682+
let body_term = &body_block.terminator;
683+
let Some(Merge::Selection(body_merge)) = body_term.merge else {
684+
continue;
685+
};
686+
if body_term.attrs != AttrSet::default()
687+
|| !matches!(
688+
&*body_term.kind,
689+
cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond)
690+
)
691+
|| body_term.inputs.len() != 1
692+
|| body_term.targets.len() != 2
693+
|| !body_term.target_phi_values.is_empty()
694+
{
695+
continue;
696+
}
697+
698+
let Some(continue_block) = blocks.get(&loop_continue) else {
699+
continue;
700+
};
701+
if !continue_block.phis.is_empty() || !continue_block.insts.is_empty() {
702+
continue;
703+
}
704+
let continue_term = &continue_block.terminator;
705+
if continue_term.attrs != AttrSet::default()
706+
|| !matches!(
707+
&*continue_term.kind,
708+
cfg::ControlInstKind::SelectBranch(SelectionKind::BoolCond)
709+
)
710+
|| continue_term.inputs.len() != 1
711+
|| continue_term.targets.as_slice() != [header_point, loop_merge]
712+
|| continue_term.target_phi_values.keys().any(|&target| target != header_point)
713+
|| continue_term.merge.is_some()
714+
{
715+
continue;
716+
}
717+
718+
let t0 = body_term.targets[0];
719+
let t1 = body_term.targets[1];
720+
let (_work_target, pass_target) =
721+
if Self::is_passthrough_branch_to(blocks, t0, body_merge) {
722+
(t1, t0)
723+
} else if Self::is_passthrough_branch_to(blocks, t1, body_merge) {
724+
(t0, t1)
725+
} else {
726+
continue;
727+
};
728+
let Some(cond_relation) = Self::continue_cond_relation(
729+
cx,
730+
func_def_body,
731+
continue_term.inputs[0],
732+
body_term.inputs[0],
733+
) else {
734+
continue;
735+
};
736+
let continue_routes_work_to_header = match cond_relation {
737+
CondRelation::Same => pass_target == t1,
738+
CondRelation::Not => pass_target == t0,
739+
};
740+
if !continue_routes_work_to_header {
741+
continue;
742+
}
743+
744+
let body_merge_preds: SmallVec<[CfgPoint; 4]> = blocks
745+
.iter()
746+
.filter_map(|(&point, block)| {
747+
block.terminator.targets.contains(&body_merge).then_some(point)
748+
})
749+
.collect();
750+
if body_merge_preds.len() != 2 || !body_merge_preds.contains(&pass_target) {
751+
continue;
752+
}
753+
let Some(other_body_merge_pred) =
754+
body_merge_preds.into_iter().find(|&point| point != pass_target)
755+
else {
756+
continue;
757+
};
758+
759+
let continue_pred_count = blocks
760+
.values()
761+
.filter(|block| block.terminator.targets.contains(&loop_continue))
762+
.count();
763+
if continue_pred_count != 1 {
764+
continue;
765+
}
766+
767+
let Some(loop_merge_block) = blocks.get(&loop_merge) else {
768+
continue;
769+
};
770+
if !loop_merge_block.phis.is_empty() {
771+
continue;
772+
}
773+
774+
let Some(body_merge_block) = blocks.get(&body_merge) else {
775+
continue;
776+
};
777+
let merge_term = &body_merge_block.terminator;
778+
if merge_term.attrs != AttrSet::default()
779+
|| !matches!(&*merge_term.kind, cfg::ControlInstKind::Branch)
780+
|| !merge_term.inputs.is_empty()
781+
|| merge_term.targets.as_slice() != [loop_continue]
782+
|| !merge_term.target_phi_values.is_empty()
783+
|| merge_term.merge.is_some()
784+
{
785+
continue;
786+
}
787+
788+
let body_merge_phi_count = body_merge_block.phis.len();
789+
let payload_arity_to = |source: CfgPoint, target: CfgPoint| {
790+
blocks
791+
.get(&source)
792+
.and_then(|block| block.terminator.target_phi_values.get(&target))
793+
.map_or(0, |values| values.len())
794+
};
795+
if payload_arity_to(pass_target, body_merge) != body_merge_phi_count
796+
|| payload_arity_to(other_body_merge_pred, body_merge) != body_merge_phi_count
797+
{
798+
continue;
799+
}
800+
801+
let header_phi_count = header_block.phis.len();
802+
let continue_payload_arity =
803+
continue_term.target_phi_values.get(&header_point).map_or(0, |values| values.len());
804+
if continue_payload_arity != header_phi_count {
805+
continue;
806+
}
807+
808+
loop_continue_shortcuts.push((
809+
body_point,
810+
pass_target,
811+
loop_merge,
812+
loop_continue,
813+
continue_routes_work_to_header,
814+
));
815+
}
816+
for (
817+
body_point,
818+
pass_target,
819+
loop_merge,
820+
loop_continue,
821+
rewrite_continue_to_unconditional_backedge,
822+
) in loop_continue_shortcuts
823+
{
824+
let body_block = blocks.get_mut(&body_point).unwrap();
825+
body_block.terminator.merge = None;
826+
for target in &mut body_block.terminator.targets {
827+
if *target == pass_target {
828+
*target = loop_merge;
829+
}
830+
}
831+
832+
if rewrite_continue_to_unconditional_backedge {
833+
Self::rewrite_continue_as_unconditional_backedge(blocks, loop_continue);
834+
}
835+
}
836+
}
837+
511838
fn from_func_decl<E>(
512839
cx: &Context,
513840
func_decl: &'a FuncDecl,
@@ -913,6 +1240,9 @@ impl<'a> FuncLifting<'a> {
9131240
}
9141241
}
9151242

1243+
Self::canonicalize_loop_continue_shortcuts(cx, func_def_body, &mut blocks);
1244+
Self::recompute_use_counts(&blocks, &mut use_counts);
1245+
9161246
// Remove now-unused blocks.
9171247
blocks.retain(|point, _| use_counts.get(point).is_some_and(|&count| count > 0));
9181248

0 commit comments

Comments
 (0)