diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index b60c3dddc5..55c6a56bb7 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -5,6 +5,12 @@
Improvements ðŸ›
+* `quantum.extract` canonicalization now looks through a `quantum.insert` at a distinct
+ static index, rewriting the extract to read from the register feeding the insert. This
+ removes the false data dependency between wires that act on different qubits of the same
+ register.
+ [(#2965)](https://github.com/PennyLaneAI/catalyst/pull/2965)
+
* The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry
instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_`
with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_` with a stable
@@ -358,6 +364,7 @@ Lillian Frederiksen,
Sengthai Heng,
David Ittah,
Christina Lee,
+Rylan Malarchick,
Mehrdad Malekmohammadi,
River McCubbin,
Shuli Shu,
diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp
index 736d9f2d4b..f6f74fd422 100644
--- a/mlir/lib/Quantum/IR/QuantumOps.cpp
+++ b/mlir/lib/Quantum/IR/QuantumOps.cpp
@@ -162,6 +162,7 @@ LogicalResult ExtractOp::canonicalize(ExtractOp extract, mlir::PatternRewriter &
bool staticallyEqual = bothStatic && extract.getIdxAttrAttr() == insert.getIdxAttrAttr();
bool dynamicallyEqual = bothDynamic && extract.getIdx() == insert.getIdx();
+ bool staticallyDistinct = bothStatic && extract.getIdxAttrAttr() != insert.getIdxAttrAttr();
bool inSameBlock = extract->getBlock() == insert->getBlock();
if ((staticallyEqual || dynamicallyEqual) && inSameBlock) {
@@ -169,6 +170,14 @@ LogicalResult ExtractOp::canonicalize(ExtractOp extract, mlir::PatternRewriter &
rewriter.replaceOp(insert, insert.getInQreg());
return success();
}
+
+ bool insertHasNonExtractUser = llvm::any_of(
+ insert.getResult().getUsers(), [](Operation *user) { return !isa(user); });
+ if (staticallyDistinct && inSameBlock && insertHasNonExtractUser) {
+ rewriter.modifyOpInPlace(extract,
+ [&] { extract.getQregMutable().assign(insert.getInQreg()); });
+ return success();
+ }
}
return failure();
}
diff --git a/mlir/test/QRef/SemanticConversion/TestFlatCircuits.mlir b/mlir/test/QRef/SemanticConversion/TestFlatCircuits.mlir
index aaced1b2c6..b80b61a40a 100644
--- a/mlir/test/QRef/SemanticConversion/TestFlatCircuits.mlir
+++ b/mlir/test/QRef/SemanticConversion/TestFlatCircuits.mlir
@@ -280,23 +280,15 @@ func.func @test_namedobs_op() -> (!quantum.obs, !quantum.obs) attributes {quantu
// CHECK: [[CNOT:%.+]]:2 = quantum.custom "CNOT"() [[q0]], [[q1]] : !quantum.bit, !quantum.bit
qref.custom "CNOT"() %q0, %q1 : !qref.bit, !qref.bit
- // COM: TODO: improve canonicalization patterns to recognize inverse extract-insert pairs where
- // COM: inserts are delayed past guaranteed distinct extracts (or vice versa), via statically
- // COM: different indices
+ // CHECK: [[obs_x:%.+]] = quantum.namedobs [[CNOT]]#0[ PauliX] : !quantum.obs
// CHECK: [[insert0:%.+]] = quantum.insert [[qreg]][ 0], [[CNOT]]#0 : !quantum.reg, !quantum.bit
- // CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 1], [[CNOT]]#1 : !quantum.reg, !quantum.bit
-
- // CHECK: [[extract:%.+]] = quantum.extract [[insert1]][ 0] : !quantum.reg -> !quantum.bit
- // CHECK: [[obs_x:%.+]] = quantum.namedobs [[extract]][ PauliX] : !quantum.obs
- // CHECK: [[insertX:%.+]] = quantum.insert [[insert1]][ 0], [[extract]] : !quantum.reg, !quantum.bit
%obs_x = qref.namedobs %q0 [ PauliX] : !quantum.obs
- // CHECK: [[extract:%.+]] = quantum.extract [[insertX]][ 1] : !quantum.reg -> !quantum.bit
- // CHECK: [[obs_z:%.+]] = quantum.namedobs [[extract]][ PauliZ] : !quantum.obs
- // CHECK: [[insertZ:%.+]] = quantum.insert [[insertX]][ 1], [[extract]] : !quantum.reg, !quantum.bit
+ // CHECK: [[obs_z:%.+]] = quantum.namedobs [[CNOT]]#1[ PauliZ] : !quantum.obs
+ // CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 1], [[CNOT]]#1 : !quantum.reg, !quantum.bit
%obs_z = qref.namedobs %q1 [ PauliZ] : !quantum.obs
- // CHECK: quantum.dealloc [[insertZ]] : !quantum.reg
+ // CHECK: quantum.dealloc [[insert1]] : !quantum.reg
qref.dealloc %a : !qref.reg<2>
// CHECK: return [[obs_x]], [[obs_z]]
diff --git a/mlir/test/QRef/SemanticConversion/TestPBC.mlir b/mlir/test/QRef/SemanticConversion/TestPBC.mlir
index 6f95c6378d..5cfd604d82 100644
--- a/mlir/test/QRef/SemanticConversion/TestPBC.mlir
+++ b/mlir/test/QRef/SemanticConversion/TestPBC.mlir
@@ -34,11 +34,9 @@ func.func @test_PPM_op(%angle: f64) -> (i1, i1, i1) attributes {quantum.node} {
// CHECK: [[q1:%.+]] = quantum.extract [[qreg]][ 1] : !quantum.reg -> !quantum.bit
// CHECK: [[m1:%.+]], [[m1_out_qubits:%.+]]:2 = pbc.ppm ["Z", "Y"] [[m0_out_qubit]], [[q1]] : i1, !quantum.bit, !quantum.bit
%m1 = pbc.ref.ppm ["Z", "Y"] %q0, %q1 : i1
- // CHECK: [[insert0:%.+]] = quantum.insert [[qreg]][ 0], [[m1_out_qubits]]#0 : !quantum.reg, !quantum.bit
- // CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 1], [[m1_out_qubits]]#1 : !quantum.reg, !quantum.bit
+ // CHECK: [[insert1:%.+]] = quantum.insert [[qreg]][ 1], [[m1_out_qubits]]#1 : !quantum.reg, !quantum.bit
- // CHECK: [[q0:%.+]] = quantum.extract [[insert1]][ 0] : !quantum.reg -> !quantum.bit
- // CHECK: [[m2:%.+]], [[m2_out_qubits:%.+]]:2 = pbc.ppm ["X", "Z"] [[q0]], [[qb]] : i1, !quantum.bit, !quantum.bit
+ // CHECK: [[m2:%.+]], [[m2_out_qubits:%.+]]:2 = pbc.ppm ["X", "Z"] [[m1_out_qubits]]#0, [[qb]] : i1, !quantum.bit, !quantum.bit
%m2 = pbc.ref.ppm ["X", "Z"] %q0, %qb : i1
// CHECK: [[insert2:%.+]] = quantum.insert [[insert1]][ 0], [[m2_out_qubits]]#0 : !quantum.reg, !quantum.bit
diff --git a/mlir/test/QRef/SemanticConversion/TestSubroutines.mlir b/mlir/test/QRef/SemanticConversion/TestSubroutines.mlir
index 55ae6f8777..d3c2b6139c 100644
--- a/mlir/test/QRef/SemanticConversion/TestSubroutines.mlir
+++ b/mlir/test/QRef/SemanticConversion/TestSubroutines.mlir
@@ -54,16 +54,13 @@ func.func @main(%arg0: i64, %arg1: f64) -> (!quantum.obs, !quantum.obs) attribut
// CHECK-SAME: (f64, !quantum.bit, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit)
func.call @test_extract_before_call(%r2, %r_dyn, %arg1) : (!qref.reg<2>, !qref.reg>, f64) -> ()
func.call @test_extract_before_call(%r2, %r_dyn, %arg1) : (!qref.reg<2>, !qref.reg>, f64) -> ()
- // CHECK: [[insert_20:%.+]] = quantum.insert [[r2]][ 0], [[second_call]]#0 : !quantum.reg, !quantum.bit
- // CHECK: [[insert_21:%.+]] = quantum.insert [[insert_20]][ 1], [[second_call]]#1 : !quantum.reg, !quantum.bit
+ // CHECK: [[insert_21:%.+]] = quantum.insert [[r2]][ 1], [[second_call]]#1 : !quantum.reg, !quantum.bit
// CHECK: [[insert_dyn1:%.+]] = quantum.insert [[r_dyn]][ 1], [[second_call]]#2 : !quantum.reg, !quantum.bit
-
- // CHECK: [[q20:%.+]] = quantum.extract [[insert_21]][ 0] : !quantum.reg -> !quantum.bit
- // CHECK: [[obs_q:%.+]] = quantum.compbasis qubits [[q20]] : !quantum.obs
+ // CHECK: [[obs_q:%.+]] = quantum.compbasis qubits [[second_call]]#0 : !quantum.obs
%q20 = qref.get %r2[0] : !qref.reg<2> -> !qref.bit
%obs_q = qref.compbasis qubits %q20 : !quantum.obs
- // CHECK: [[insert_r2:%.+]] = quantum.insert [[insert_21]][ 0], [[q20]] : !quantum.reg, !quantum.bit
+ // CHECK: [[insert_r2:%.+]] = quantum.insert [[insert_21]][ 0], [[second_call]]#0 : !quantum.reg, !quantum.bit
// CHECK: [[obs_r:%.+]] = quantum.compbasis qreg [[insert_dyn1]] : !quantum.obs
%obs_r = qref.compbasis (qreg %r_dyn : !qref.reg>) : !quantum.obs
diff --git a/mlir/test/Quantum/CanonicalizationTest.mlir b/mlir/test/Quantum/CanonicalizationTest.mlir
index eb5e7f2378..25a8f72af3 100644
--- a/mlir/test/Quantum/CanonicalizationTest.mlir
+++ b/mlir/test/Quantum/CanonicalizationTest.mlir
@@ -71,15 +71,14 @@ func.func @test_extract_insert_fold(%r1: !quantum.reg, %i: i64) -> !quantum.reg
return %r3 : !quantum.reg
}
-// CHECK-LABEL: test_extract_insert_no_fold_static
-func.func @test_extract_insert_no_fold_static(%r1: !quantum.reg, %i1: i64, %i2: i64) -> !quantum.reg {
- // CHECK: quantum.extract
- // CHECK: quantum.insert
+// CHECK-LABEL: test_extract_insert_distinct_static_folds
+func.func @test_extract_insert_distinct_static_folds(%r1: !quantum.reg, %i1: i64, %i2: i64) -> !quantum.reg {
+ // CHECK: [[Q:%.+]] = quantum.extract %arg0[ 0]
+ // CHECK: [[R2:%.+]] = quantum.insert %arg0[ 1], [[Q]]
+ // CHECK: quantum.insert [[R2]][%arg2], [[Q]]
%q1 = quantum.extract %r1[0] : !quantum.reg -> !quantum.bit
%r2 = quantum.insert %r1[1], %q1 : !quantum.reg, !quantum.bit
- // CHECK: quantum.extract
- // CHECK: quantum.insert
%q2 = quantum.extract %r2[0] : !quantum.reg -> !quantum.bit
%r3 = quantum.insert %r2[%i1], %q2 : !quantum.reg, !quantum.bit
@@ -90,6 +89,38 @@ func.func @test_extract_insert_no_fold_static(%r1: !quantum.reg, %i1: i64, %i2:
return %r4 : !quantum.reg
}
+// CHECK-LABEL: test_extract_no_redirect_when_insert_is_leaf
+func.func @test_extract_no_redirect_when_insert_is_leaf(%r0: !quantum.reg) -> !quantum.bit {
+ // CHECK: [[Q0:%.+]] = quantum.extract %arg0[ 0]
+ // CHECK: [[X:%.+]] = quantum.custom "PauliX"() [[Q0]]
+ // CHECK: [[INS:%.+]] = quantum.insert %arg0[ 0], [[X]]
+ // CHECK: [[Q1:%.+]] = quantum.extract [[INS]][ 1]
+ // CHECK: return [[Q1]]
+ %q0 = quantum.extract %r0[0] : !quantum.reg -> !quantum.bit
+ %x = quantum.custom "PauliX"() %q0 : !quantum.bit
+ %r1 = quantum.insert %r0[0], %x : !quantum.reg, !quantum.bit
+ %q1 = quantum.extract %r1[1] : !quantum.reg -> !quantum.bit
+ return %q1 : !quantum.bit
+}
+
+// CHECK-LABEL: test_extract_through_insert_distinct_index
+func.func @test_extract_through_insert_distinct_index(%r0: !quantum.reg) -> !quantum.reg {
+ // CHECK: [[Q0:%.+]] = quantum.extract %arg0[ 0]
+ // CHECK: [[X:%.+]] = quantum.custom "PauliX"() [[Q0]]
+ // CHECK: [[INS0:%.+]] = quantum.insert %arg0[ 0], [[X]]
+ // CHECK: [[Q1:%.+]] = quantum.extract %arg0[ 1]
+ // CHECK: [[Z:%.+]] = quantum.custom "PauliZ"() [[Q1]]
+ // CHECK: [[INS1:%.+]] = quantum.insert [[INS0]][ 1], [[Z]]
+ // CHECK: return [[INS1]]
+ %q0 = quantum.extract %r0[0] : !quantum.reg -> !quantum.bit
+ %x = quantum.custom "PauliX"() %q0 : !quantum.bit
+ %r1 = quantum.insert %r0[0], %x : !quantum.reg, !quantum.bit
+ %q1 = quantum.extract %r1[1] : !quantum.reg -> !quantum.bit
+ %z = quantum.custom "PauliZ"() %q1 : !quantum.bit
+ %r2 = quantum.insert %r1[1], %z : !quantum.reg, !quantum.bit
+ return %r2 : !quantum.reg
+}
+
// CHECK-LABEL: test_extract_insert_constant
func.func @test_extract_insert_constant(%r1: !quantum.reg) -> !quantum.reg {
// CHECK-NOT: arith.constant
diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir
index b7abe6f203..7cae100802 100644
--- a/mlir/test/Quantum/DecomposeLoweringTest.mlir
+++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir
@@ -749,7 +749,9 @@ module @different_qreg_values{
// CHECK: [[q0:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit
// CHECK: [[H:%.+]] = quantum.custom "Hadamard"() [[q0]] : !quantum.bit
// CHECK: [[H_insert:%.+]] = quantum.insert [[reg]][ 0], [[H]] : !quantum.reg, !quantum.bit
- // CHECK: [[full_insert:%.+]] = quantum.insert [[H_insert]][ 1], [[q1]] : !quantum.reg, !quantum.bit
+ // CHECK: [[q2_pre:%.+]] = quantum.extract [[reg]][ 2] : !quantum.reg -> !quantum.bit
+ // CHECK: [[insert_2:%.+]] = quantum.insert [[H_insert]][ 2], [[q2_pre]] : !quantum.reg, !quantum.bit
+ // CHECK: [[full_insert:%.+]] = quantum.insert [[insert_2]][ 1], [[q1]] : !quantum.reg, !quantum.bit
%0 = quantum.alloc( 3) : !quantum.reg
%1 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit
%2 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit