@@ -342,6 +342,13 @@ enum Merge<L> {
342342 } ,
343343}
344344
345+ /// Relationship between loop body condition and loop-continue condition.
346+ #[ derive( Copy , Clone , PartialEq , Eq ) ]
347+ enum CondRelation {
348+ Same ,
349+ Not ,
350+ }
351+
345352impl < ' a > NeedsIdsCollector < ' a > {
346353 fn alloc_ids < E > (
347354 self ,
@@ -508,6 +515,326 @@ impl FuncAt<'_, Node> {
508515}
509516
510517impl < ' a > FuncLifting < ' a > {
518+ /// Recompute incoming edge counts for each CFG point in `blocks`.
519+ ///
520+ /// This has to run after control-flow rewrites and before dead-block
521+ /// pruning, to avoid stale predecessor counts.
522+ fn recompute_use_counts (
523+ blocks : & FxIndexMap < CfgPoint , BlockLifting < ' a > > ,
524+ use_counts : & mut FxHashMap < CfgPoint , usize > ,
525+ ) {
526+ use_counts. clear ( ) ;
527+ use_counts. reserve ( blocks. len ( ) ) ;
528+ let all_edges = blocks. first ( ) . map ( |( & entry_point, _) | entry_point) . into_iter ( ) . chain (
529+ blocks. values ( ) . flat_map ( |block| {
530+ block
531+ . terminator
532+ . merge
533+ . iter ( )
534+ . flat_map ( |merge| {
535+ let ( a, b) = match merge {
536+ Merge :: Selection ( a) => ( a, None ) ,
537+ Merge :: Loop { loop_merge : a, loop_continue : b } => ( a, Some ( b) ) ,
538+ } ;
539+ [ a] . into_iter ( ) . chain ( b)
540+ } )
541+ . chain ( & block. terminator . targets )
542+ . copied ( )
543+ } ) ,
544+ ) ;
545+ for target in all_edges {
546+ * use_counts. entry ( target) . or_default ( ) += 1 ;
547+ }
548+ }
549+
550+ /// Return `true` iff `point` is an empty pass-through block branching only
551+ /// to `target`.
552+ fn is_passthrough_branch_to (
553+ blocks : & FxIndexMap < CfgPoint , BlockLifting < ' a > > ,
554+ point : CfgPoint ,
555+ target : CfgPoint ,
556+ ) -> bool {
557+ let Some ( block) = blocks. get ( & point) else {
558+ return false ;
559+ } ;
560+ block. phis . is_empty ( )
561+ && block. insts . is_empty ( )
562+ && block. terminator . attrs == AttrSet :: default ( )
563+ && matches ! ( & * block. terminator. kind, cfg:: ControlInstKind :: Branch )
564+ && block. terminator . inputs . is_empty ( )
565+ && block. terminator . targets . as_slice ( ) == [ target]
566+ && block. terminator . target_phi_values . keys ( ) . all ( |& phi_target| phi_target == target)
567+ && block. terminator . merge . is_none ( )
568+ }
569+
570+ fn is_const_opcode ( cx : & Context , v : Value , opcode : spec:: Opcode ) -> bool {
571+ match v {
572+ Value :: Const ( c) => match & cx[ c] . kind {
573+ ConstKind :: SpvInst { spv_inst_and_const_inputs } => {
574+ spv_inst_and_const_inputs. 0 . opcode == opcode
575+ }
576+ _ => false ,
577+ } ,
578+ _ => false ,
579+ }
580+ }
581+
582+ /// Determine whether `continue_cond` is equal to `body_cond` or its
583+ /// logical negation.
584+ fn continue_cond_relation (
585+ cx : & Context ,
586+ func_def_body : & ' a crate :: FuncDefBody ,
587+ continue_cond : Value ,
588+ body_cond : Value ,
589+ ) -> Option < CondRelation > {
590+ if continue_cond == body_cond {
591+ return Some ( CondRelation :: Same ) ;
592+ }
593+
594+ let wk = & spec:: Spec :: get ( ) . well_known ;
595+ match continue_cond {
596+ Value :: DataInstOutput ( _) => None ,
597+
598+ Value :: ControlNodeOutput { control_node, output_idx } => {
599+ let control_node_def = func_def_body. at ( control_node) . def ( ) ;
600+ let ControlNodeKind :: Select { kind : SelectionKind :: BoolCond , scrutinee, cases } =
601+ & control_node_def. kind
602+ else {
603+ return None ;
604+ } ;
605+ if * scrutinee != body_cond || cases. len ( ) != 2 {
606+ return None ;
607+ }
608+
609+ let output_idx = output_idx as usize ;
610+ let true_case_outputs = & func_def_body. at ( cases[ 0 ] ) . def ( ) . outputs ;
611+ let false_case_outputs = & func_def_body. at ( cases[ 1 ] ) . def ( ) . outputs ;
612+ if output_idx >= true_case_outputs. len ( ) || output_idx >= false_case_outputs. len ( ) {
613+ return None ;
614+ }
615+
616+ let on_true = true_case_outputs[ output_idx] ;
617+ let on_false = false_case_outputs[ output_idx] ;
618+ if Self :: is_const_opcode ( cx, on_true, wk. OpConstantTrue )
619+ && Self :: is_const_opcode ( cx, on_false, wk. OpConstantFalse )
620+ {
621+ Some ( CondRelation :: Same )
622+ } else if Self :: is_const_opcode ( cx, on_true, wk. OpConstantFalse )
623+ && Self :: is_const_opcode ( cx, on_false, wk. OpConstantTrue )
624+ {
625+ Some ( CondRelation :: Not )
626+ } else {
627+ None
628+ }
629+ }
630+
631+ _ => None ,
632+ }
633+ }
634+
635+ /// Rewrite `loop_continue` into an unconditional backedge while preserving
636+ /// only phi payloads for the loop header edge.
637+ fn rewrite_continue_as_unconditional_backedge (
638+ blocks : & mut FxIndexMap < CfgPoint , BlockLifting < ' a > > ,
639+ loop_continue : CfgPoint ,
640+ ) {
641+ let continue_block = blocks. get_mut ( & loop_continue) . unwrap ( ) ;
642+ continue_block. terminator . kind = Cow :: Owned ( cfg:: ControlInstKind :: Branch ) ;
643+ continue_block. terminator . inputs = [ ] . into_iter ( ) . collect ( ) ;
644+ let header_point = continue_block. terminator . targets [ 0 ] ;
645+ continue_block. terminator . targets = [ header_point] . into_iter ( ) . collect ( ) ;
646+ continue_block. terminator . target_phi_values . retain ( |& target, _| target == header_point) ;
647+ }
648+
649+ /// Canonicalize strict loop shortcut patterns:
650+ /// * loop header branches to a body select,
651+ /// * one body arm is an empty pass-through to body merge,
652+ /// * body merge branches to `loop_continue`,
653+ /// * `loop_continue` conditionally branches to header/merge.
654+ ///
655+ /// Rewriting the pass-through arm directly to `loop_merge` avoids
656+ /// preserving this fragile shape in lifted CFG.
657+ fn canonicalize_loop_continue_shortcuts (
658+ cx : & Context ,
659+ func_def_body : & ' a crate :: FuncDefBody ,
660+ blocks : & mut FxIndexMap < CfgPoint , BlockLifting < ' a > > ,
661+ ) {
662+ let mut loop_continue_shortcuts =
663+ SmallVec :: < [ ( CfgPoint , CfgPoint , CfgPoint , CfgPoint , bool ) ; 4 ] > :: new ( ) ;
664+ for ( & header_point, header_block) in & * blocks {
665+ let header_term = & header_block. terminator ;
666+ let Some ( Merge :: Loop { loop_merge, loop_continue } ) = header_term. merge else {
667+ continue ;
668+ } ;
669+ if header_term. attrs != AttrSet :: default ( )
670+ || !matches ! ( & * header_term. kind, cfg:: ControlInstKind :: Branch )
671+ || !header_term. inputs . is_empty ( )
672+ || header_term. targets . len ( ) != 1
673+ || !header_term. target_phi_values . is_empty ( )
674+ {
675+ continue ;
676+ }
677+
678+ let body_point = header_term. targets [ 0 ] ;
679+ let Some ( body_block) = blocks. get ( & body_point) else {
680+ continue ;
681+ } ;
682+ let body_term = & body_block. terminator ;
683+ let Some ( Merge :: Selection ( body_merge) ) = body_term. merge else {
684+ continue ;
685+ } ;
686+ if body_term. attrs != AttrSet :: default ( )
687+ || !matches ! (
688+ & * body_term. kind,
689+ cfg:: ControlInstKind :: SelectBranch ( SelectionKind :: BoolCond )
690+ )
691+ || body_term. inputs . len ( ) != 1
692+ || body_term. targets . len ( ) != 2
693+ || !body_term. target_phi_values . is_empty ( )
694+ {
695+ continue ;
696+ }
697+
698+ let Some ( continue_block) = blocks. get ( & loop_continue) else {
699+ continue ;
700+ } ;
701+ if !continue_block. phis . is_empty ( ) || !continue_block. insts . is_empty ( ) {
702+ continue ;
703+ }
704+ let continue_term = & continue_block. terminator ;
705+ if continue_term. attrs != AttrSet :: default ( )
706+ || !matches ! (
707+ & * continue_term. kind,
708+ cfg:: ControlInstKind :: SelectBranch ( SelectionKind :: BoolCond )
709+ )
710+ || continue_term. inputs . len ( ) != 1
711+ || continue_term. targets . as_slice ( ) != [ header_point, loop_merge]
712+ || continue_term. target_phi_values . keys ( ) . any ( |& target| target != header_point)
713+ || continue_term. merge . is_some ( )
714+ {
715+ continue ;
716+ }
717+
718+ let t0 = body_term. targets [ 0 ] ;
719+ let t1 = body_term. targets [ 1 ] ;
720+ let ( _work_target, pass_target) =
721+ if Self :: is_passthrough_branch_to ( blocks, t0, body_merge) {
722+ ( t1, t0)
723+ } else if Self :: is_passthrough_branch_to ( blocks, t1, body_merge) {
724+ ( t0, t1)
725+ } else {
726+ continue ;
727+ } ;
728+ let Some ( cond_relation) = Self :: continue_cond_relation (
729+ cx,
730+ func_def_body,
731+ continue_term. inputs [ 0 ] ,
732+ body_term. inputs [ 0 ] ,
733+ ) else {
734+ continue ;
735+ } ;
736+ let continue_routes_work_to_header = match cond_relation {
737+ CondRelation :: Same => pass_target == t1,
738+ CondRelation :: Not => pass_target == t0,
739+ } ;
740+ if !continue_routes_work_to_header {
741+ continue ;
742+ }
743+
744+ let body_merge_preds: SmallVec < [ CfgPoint ; 4 ] > = blocks
745+ . iter ( )
746+ . filter_map ( |( & point, block) | {
747+ block. terminator . targets . contains ( & body_merge) . then_some ( point)
748+ } )
749+ . collect ( ) ;
750+ if body_merge_preds. len ( ) != 2 || !body_merge_preds. contains ( & pass_target) {
751+ continue ;
752+ }
753+ let Some ( other_body_merge_pred) =
754+ body_merge_preds. into_iter ( ) . find ( |& point| point != pass_target)
755+ else {
756+ continue ;
757+ } ;
758+
759+ let continue_pred_count = blocks
760+ . values ( )
761+ . filter ( |block| block. terminator . targets . contains ( & loop_continue) )
762+ . count ( ) ;
763+ if continue_pred_count != 1 {
764+ continue ;
765+ }
766+
767+ let Some ( loop_merge_block) = blocks. get ( & loop_merge) else {
768+ continue ;
769+ } ;
770+ if !loop_merge_block. phis . is_empty ( ) {
771+ continue ;
772+ }
773+
774+ let Some ( body_merge_block) = blocks. get ( & body_merge) else {
775+ continue ;
776+ } ;
777+ let merge_term = & body_merge_block. terminator ;
778+ if merge_term. attrs != AttrSet :: default ( )
779+ || !matches ! ( & * merge_term. kind, cfg:: ControlInstKind :: Branch )
780+ || !merge_term. inputs . is_empty ( )
781+ || merge_term. targets . as_slice ( ) != [ loop_continue]
782+ || !merge_term. target_phi_values . is_empty ( )
783+ || merge_term. merge . is_some ( )
784+ {
785+ continue ;
786+ }
787+
788+ let body_merge_phi_count = body_merge_block. phis . len ( ) ;
789+ let payload_arity_to = |source : CfgPoint , target : CfgPoint | {
790+ blocks
791+ . get ( & source)
792+ . and_then ( |block| block. terminator . target_phi_values . get ( & target) )
793+ . map_or ( 0 , |values| values. len ( ) )
794+ } ;
795+ if payload_arity_to ( pass_target, body_merge) != body_merge_phi_count
796+ || payload_arity_to ( other_body_merge_pred, body_merge) != body_merge_phi_count
797+ {
798+ continue ;
799+ }
800+
801+ let header_phi_count = header_block. phis . len ( ) ;
802+ let continue_payload_arity =
803+ continue_term. target_phi_values . get ( & header_point) . map_or ( 0 , |values| values. len ( ) ) ;
804+ if continue_payload_arity != header_phi_count {
805+ continue ;
806+ }
807+
808+ loop_continue_shortcuts. push ( (
809+ body_point,
810+ pass_target,
811+ loop_merge,
812+ loop_continue,
813+ continue_routes_work_to_header,
814+ ) ) ;
815+ }
816+ for (
817+ body_point,
818+ pass_target,
819+ loop_merge,
820+ loop_continue,
821+ rewrite_continue_to_unconditional_backedge,
822+ ) in loop_continue_shortcuts
823+ {
824+ let body_block = blocks. get_mut ( & body_point) . unwrap ( ) ;
825+ body_block. terminator . merge = None ;
826+ for target in & mut body_block. terminator . targets {
827+ if * target == pass_target {
828+ * target = loop_merge;
829+ }
830+ }
831+
832+ if rewrite_continue_to_unconditional_backedge {
833+ Self :: rewrite_continue_as_unconditional_backedge ( blocks, loop_continue) ;
834+ }
835+ }
836+ }
837+
511838 fn from_func_decl < E > (
512839 cx : & Context ,
513840 func_decl : & ' a FuncDecl ,
@@ -913,6 +1240,9 @@ impl<'a> FuncLifting<'a> {
9131240 }
9141241 }
9151242
1243+ Self :: canonicalize_loop_continue_shortcuts ( cx, func_def_body, & mut blocks) ;
1244+ Self :: recompute_use_counts ( & blocks, & mut use_counts) ;
1245+
9161246 // Remove now-unused blocks.
9171247 blocks. retain ( |point, _| use_counts. get ( point) . is_some_and ( |& count| count > 0 ) ) ;
9181248
0 commit comments