diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index b60c3dddc5..55c6a56bb7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -5,6 +5,12 @@

Improvements 🛠

+* `quantum.extract` canonicalization now looks through a `quantum.insert` at a distinct + static index, rewriting the extract to read from the register feeding the insert. This + removes the false data dependency between wires that act on different qubits of the same + register. + [(#2965)](https://github.com/PennyLaneAI/catalyst/pull/2965) + * The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_` with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_` with a stable @@ -358,6 +364,7 @@ Lillian Frederiksen, Sengthai Heng, David Ittah, Christina Lee, +Rylan Malarchick, Mehrdad Malekmohammadi, River McCubbin, Shuli Shu, diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 736d9f2d4b..f6f74fd422 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -162,6 +162,7 @@ LogicalResult ExtractOp::canonicalize(ExtractOp extract, mlir::PatternRewriter & bool staticallyEqual = bothStatic && extract.getIdxAttrAttr() == insert.getIdxAttrAttr(); bool dynamicallyEqual = bothDynamic && extract.getIdx() == insert.getIdx(); + bool staticallyDistinct = bothStatic && extract.getIdxAttrAttr() != insert.getIdxAttrAttr(); bool inSameBlock = extract->getBlock() == insert->getBlock(); if ((staticallyEqual || dynamicallyEqual) && inSameBlock) { @@ -169,6 +170,14 @@ LogicalResult ExtractOp::canonicalize(ExtractOp extract, mlir::PatternRewriter & rewriter.replaceOp(insert, insert.getInQreg()); return success(); } + + bool insertHasNonExtractUser = llvm::any_of( + insert.getResult().getUsers(), [](Operation *user) { return !isa(user); }); + if (staticallyDistinct && inSameBlock && insertHasNonExtractUser) { + rewriter.modifyOpInPlace(extract, + [&] { extract.getQregMutable().assign(insert.getInQreg()); }); + return success(); + } } return failure(); } diff --git a/mlir/test/QRef/SemanticConversion/TestFlatCircuits.mlir b/mlir/test/QRef/SemanticConversion/TestFlatCircuits.mlir index aaced1b2c6..b80b61a40a 100644 --- a/mlir/test/QRef/SemanticConversion/TestFlatCircuits.mlir +++ b/mlir/test/QRef/SemanticConversion/TestFlatCircuits.mlir @@ -280,23 +280,15 @@ func.func @test_namedobs_op() -> (!quantum.obs, !quantum.obs) attributes {quantu // CHECK: [[CNOT:%.+]]:2 = quantum.custom "CNOT"() [[q0]], [[q1]] : !quantum.bit, !quantum.bit qref.custom "CNOT"() %q0, %q1 : !qref.bit, !qref.bit - // COM: TODO: improve canonicalization patterns to recognize inverse extract-insert pairs where - // COM: inserts are delayed past guaranteed distinct extracts (or vice versa), via statically - // COM: different indices + // CHECK: [[obs_x:%.+]] = quantum.namedobs [[CNOT]]#0[ PauliX] : !quantum.obs // CHECK: [[insert0:%.+]] = quantum.insert [[qreg]][ 0], [[CNOT]]#0 : !quantum.reg, !quantum.bit - // CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 1], [[CNOT]]#1 : !quantum.reg, !quantum.bit - - // CHECK: [[extract:%.+]] = quantum.extract [[insert1]][ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[obs_x:%.+]] = quantum.namedobs [[extract]][ PauliX] : !quantum.obs - // CHECK: [[insertX:%.+]] = quantum.insert [[insert1]][ 0], [[extract]] : !quantum.reg, !quantum.bit %obs_x = qref.namedobs %q0 [ PauliX] : !quantum.obs - // CHECK: [[extract:%.+]] = quantum.extract [[insertX]][ 1] : !quantum.reg -> !quantum.bit - // CHECK: [[obs_z:%.+]] = quantum.namedobs [[extract]][ PauliZ] : !quantum.obs - // CHECK: [[insertZ:%.+]] = quantum.insert [[insertX]][ 1], [[extract]] : !quantum.reg, !quantum.bit + // CHECK: [[obs_z:%.+]] = quantum.namedobs [[CNOT]]#1[ PauliZ] : !quantum.obs + // CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 1], [[CNOT]]#1 : !quantum.reg, !quantum.bit %obs_z = qref.namedobs %q1 [ PauliZ] : !quantum.obs - // CHECK: quantum.dealloc [[insertZ]] : !quantum.reg + // CHECK: quantum.dealloc [[insert1]] : !quantum.reg qref.dealloc %a : !qref.reg<2> // CHECK: return [[obs_x]], [[obs_z]] diff --git a/mlir/test/QRef/SemanticConversion/TestPBC.mlir b/mlir/test/QRef/SemanticConversion/TestPBC.mlir index 6f95c6378d..5cfd604d82 100644 --- a/mlir/test/QRef/SemanticConversion/TestPBC.mlir +++ b/mlir/test/QRef/SemanticConversion/TestPBC.mlir @@ -34,11 +34,9 @@ func.func @test_PPM_op(%angle: f64) -> (i1, i1, i1) attributes {quantum.node} { // CHECK: [[q1:%.+]] = quantum.extract [[qreg]][ 1] : !quantum.reg -> !quantum.bit // CHECK: [[m1:%.+]], [[m1_out_qubits:%.+]]:2 = pbc.ppm ["Z", "Y"] [[m0_out_qubit]], [[q1]] : i1, !quantum.bit, !quantum.bit %m1 = pbc.ref.ppm ["Z", "Y"] %q0, %q1 : i1 - // CHECK: [[insert0:%.+]] = quantum.insert [[qreg]][ 0], [[m1_out_qubits]]#0 : !quantum.reg, !quantum.bit - // CHECK: [[insert1:%.+]] = quantum.insert [[insert0]][ 1], [[m1_out_qubits]]#1 : !quantum.reg, !quantum.bit + // CHECK: [[insert1:%.+]] = quantum.insert [[qreg]][ 1], [[m1_out_qubits]]#1 : !quantum.reg, !quantum.bit - // CHECK: [[q0:%.+]] = quantum.extract [[insert1]][ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[m2:%.+]], [[m2_out_qubits:%.+]]:2 = pbc.ppm ["X", "Z"] [[q0]], [[qb]] : i1, !quantum.bit, !quantum.bit + // CHECK: [[m2:%.+]], [[m2_out_qubits:%.+]]:2 = pbc.ppm ["X", "Z"] [[m1_out_qubits]]#0, [[qb]] : i1, !quantum.bit, !quantum.bit %m2 = pbc.ref.ppm ["X", "Z"] %q0, %qb : i1 // CHECK: [[insert2:%.+]] = quantum.insert [[insert1]][ 0], [[m2_out_qubits]]#0 : !quantum.reg, !quantum.bit diff --git a/mlir/test/QRef/SemanticConversion/TestSubroutines.mlir b/mlir/test/QRef/SemanticConversion/TestSubroutines.mlir index 55ae6f8777..d3c2b6139c 100644 --- a/mlir/test/QRef/SemanticConversion/TestSubroutines.mlir +++ b/mlir/test/QRef/SemanticConversion/TestSubroutines.mlir @@ -54,16 +54,13 @@ func.func @main(%arg0: i64, %arg1: f64) -> (!quantum.obs, !quantum.obs) attribut // CHECK-SAME: (f64, !quantum.bit, !quantum.bit, !quantum.bit) -> (!quantum.bit, !quantum.bit, !quantum.bit) func.call @test_extract_before_call(%r2, %r_dyn, %arg1) : (!qref.reg<2>, !qref.reg, f64) -> () func.call @test_extract_before_call(%r2, %r_dyn, %arg1) : (!qref.reg<2>, !qref.reg, f64) -> () - // CHECK: [[insert_20:%.+]] = quantum.insert [[r2]][ 0], [[second_call]]#0 : !quantum.reg, !quantum.bit - // CHECK: [[insert_21:%.+]] = quantum.insert [[insert_20]][ 1], [[second_call]]#1 : !quantum.reg, !quantum.bit + // CHECK: [[insert_21:%.+]] = quantum.insert [[r2]][ 1], [[second_call]]#1 : !quantum.reg, !quantum.bit // CHECK: [[insert_dyn1:%.+]] = quantum.insert [[r_dyn]][ 1], [[second_call]]#2 : !quantum.reg, !quantum.bit - - // CHECK: [[q20:%.+]] = quantum.extract [[insert_21]][ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[obs_q:%.+]] = quantum.compbasis qubits [[q20]] : !quantum.obs + // CHECK: [[obs_q:%.+]] = quantum.compbasis qubits [[second_call]]#0 : !quantum.obs %q20 = qref.get %r2[0] : !qref.reg<2> -> !qref.bit %obs_q = qref.compbasis qubits %q20 : !quantum.obs - // CHECK: [[insert_r2:%.+]] = quantum.insert [[insert_21]][ 0], [[q20]] : !quantum.reg, !quantum.bit + // CHECK: [[insert_r2:%.+]] = quantum.insert [[insert_21]][ 0], [[second_call]]#0 : !quantum.reg, !quantum.bit // CHECK: [[obs_r:%.+]] = quantum.compbasis qreg [[insert_dyn1]] : !quantum.obs %obs_r = qref.compbasis (qreg %r_dyn : !qref.reg) : !quantum.obs diff --git a/mlir/test/Quantum/CanonicalizationTest.mlir b/mlir/test/Quantum/CanonicalizationTest.mlir index eb5e7f2378..25a8f72af3 100644 --- a/mlir/test/Quantum/CanonicalizationTest.mlir +++ b/mlir/test/Quantum/CanonicalizationTest.mlir @@ -71,15 +71,14 @@ func.func @test_extract_insert_fold(%r1: !quantum.reg, %i: i64) -> !quantum.reg return %r3 : !quantum.reg } -// CHECK-LABEL: test_extract_insert_no_fold_static -func.func @test_extract_insert_no_fold_static(%r1: !quantum.reg, %i1: i64, %i2: i64) -> !quantum.reg { - // CHECK: quantum.extract - // CHECK: quantum.insert +// CHECK-LABEL: test_extract_insert_distinct_static_folds +func.func @test_extract_insert_distinct_static_folds(%r1: !quantum.reg, %i1: i64, %i2: i64) -> !quantum.reg { + // CHECK: [[Q:%.+]] = quantum.extract %arg0[ 0] + // CHECK: [[R2:%.+]] = quantum.insert %arg0[ 1], [[Q]] + // CHECK: quantum.insert [[R2]][%arg2], [[Q]] %q1 = quantum.extract %r1[0] : !quantum.reg -> !quantum.bit %r2 = quantum.insert %r1[1], %q1 : !quantum.reg, !quantum.bit - // CHECK: quantum.extract - // CHECK: quantum.insert %q2 = quantum.extract %r2[0] : !quantum.reg -> !quantum.bit %r3 = quantum.insert %r2[%i1], %q2 : !quantum.reg, !quantum.bit @@ -90,6 +89,38 @@ func.func @test_extract_insert_no_fold_static(%r1: !quantum.reg, %i1: i64, %i2: return %r4 : !quantum.reg } +// CHECK-LABEL: test_extract_no_redirect_when_insert_is_leaf +func.func @test_extract_no_redirect_when_insert_is_leaf(%r0: !quantum.reg) -> !quantum.bit { + // CHECK: [[Q0:%.+]] = quantum.extract %arg0[ 0] + // CHECK: [[X:%.+]] = quantum.custom "PauliX"() [[Q0]] + // CHECK: [[INS:%.+]] = quantum.insert %arg0[ 0], [[X]] + // CHECK: [[Q1:%.+]] = quantum.extract [[INS]][ 1] + // CHECK: return [[Q1]] + %q0 = quantum.extract %r0[0] : !quantum.reg -> !quantum.bit + %x = quantum.custom "PauliX"() %q0 : !quantum.bit + %r1 = quantum.insert %r0[0], %x : !quantum.reg, !quantum.bit + %q1 = quantum.extract %r1[1] : !quantum.reg -> !quantum.bit + return %q1 : !quantum.bit +} + +// CHECK-LABEL: test_extract_through_insert_distinct_index +func.func @test_extract_through_insert_distinct_index(%r0: !quantum.reg) -> !quantum.reg { + // CHECK: [[Q0:%.+]] = quantum.extract %arg0[ 0] + // CHECK: [[X:%.+]] = quantum.custom "PauliX"() [[Q0]] + // CHECK: [[INS0:%.+]] = quantum.insert %arg0[ 0], [[X]] + // CHECK: [[Q1:%.+]] = quantum.extract %arg0[ 1] + // CHECK: [[Z:%.+]] = quantum.custom "PauliZ"() [[Q1]] + // CHECK: [[INS1:%.+]] = quantum.insert [[INS0]][ 1], [[Z]] + // CHECK: return [[INS1]] + %q0 = quantum.extract %r0[0] : !quantum.reg -> !quantum.bit + %x = quantum.custom "PauliX"() %q0 : !quantum.bit + %r1 = quantum.insert %r0[0], %x : !quantum.reg, !quantum.bit + %q1 = quantum.extract %r1[1] : !quantum.reg -> !quantum.bit + %z = quantum.custom "PauliZ"() %q1 : !quantum.bit + %r2 = quantum.insert %r1[1], %z : !quantum.reg, !quantum.bit + return %r2 : !quantum.reg +} + // CHECK-LABEL: test_extract_insert_constant func.func @test_extract_insert_constant(%r1: !quantum.reg) -> !quantum.reg { // CHECK-NOT: arith.constant diff --git a/mlir/test/Quantum/DecomposeLoweringTest.mlir b/mlir/test/Quantum/DecomposeLoweringTest.mlir index b7abe6f203..7cae100802 100644 --- a/mlir/test/Quantum/DecomposeLoweringTest.mlir +++ b/mlir/test/Quantum/DecomposeLoweringTest.mlir @@ -749,7 +749,9 @@ module @different_qreg_values{ // CHECK: [[q0:%.+]] = quantum.extract [[reg]][ 0] : !quantum.reg -> !quantum.bit // CHECK: [[H:%.+]] = quantum.custom "Hadamard"() [[q0]] : !quantum.bit // CHECK: [[H_insert:%.+]] = quantum.insert [[reg]][ 0], [[H]] : !quantum.reg, !quantum.bit - // CHECK: [[full_insert:%.+]] = quantum.insert [[H_insert]][ 1], [[q1]] : !quantum.reg, !quantum.bit + // CHECK: [[q2_pre:%.+]] = quantum.extract [[reg]][ 2] : !quantum.reg -> !quantum.bit + // CHECK: [[insert_2:%.+]] = quantum.insert [[H_insert]][ 2], [[q2_pre]] : !quantum.reg, !quantum.bit + // CHECK: [[full_insert:%.+]] = quantum.insert [[insert_2]][ 1], [[q1]] : !quantum.reg, !quantum.bit %0 = quantum.alloc( 3) : !quantum.reg %1 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit %2 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit