diff --git a/.dep-versions b/.dep-versions index 4ca08f1d5c..eff2c832ec 100644 --- a/.dep-versions +++ b/.dep-versions @@ -8,7 +8,7 @@ enzyme=v0.0.238 # For a custom PL version, update the package version here and at # 'doc/requirements.txt' -pennylane=0.46.0.dev24 +pennylane=0.46.0.dev38 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index c990c2ad88..f5b0045e82 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -5,6 +5,10 @@

Improvements 🛠

+* The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators + without non-lowerable arguments. + [(#2969)](https://github.com/PennyLaneAI/catalyst/pull/2969/) + * The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_` with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_` with a stable diff --git a/doc/requirements.txt b/doc/requirements.txt index c6721a731a..3459767b3b 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -34,4 +34,4 @@ lxml_html_clean --extra-index-url https://test.pypi.org/simple/ pennylane-lightning-kokkos==0.46.0-dev10 pennylane-lightning==0.46.0-dev10 -pennylane==0.46.0.dev24 +pennylane==0.46.0.dev38 diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index a1f8a8c098..0d51ee2705 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -30,6 +30,7 @@ from pennylane.capture.primitives import measure_prim as plxpr_measure_prim from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim from pennylane.capture.primitives import quantum_subroutine_prim, transform_prim +from pennylane.core.operator.operator2 import operator_p from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim from pennylane.measurements import CountsMP from pennylane.wires import AbstractQubit, is_abstract_qubit @@ -45,6 +46,7 @@ qref_measure_in_basis_p, qref_measure_p, qref_namedobs_p, + qref_operator_p, qref_pauli_measure_p, qref_pauli_rot_p, qref_qinst_p, @@ -290,6 +292,29 @@ def __call__(self, jaxpr, *args): return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) +@PLxPRToQuantumJaxprInterpreter.register_primitive(operator_p) +def _handle_operator(self, *args, op_cls, hybrid_lens, hybrid_trees, **kwargs): + + if hybrid_lens or hybrid_trees or op_cls.static_argnames: + # only support compilable_argnames for the moment + raise NotImplementedError + + wire_inputs = args[len(op_cls.dynamic_argnames) :] + new_wires = [ + w if is_abstract_qubit(w) else qref_get_p.bind(self.init_qreg, w) for w in wire_inputs + ] + + qref_operator_p.bind( + *args[: len(op_cls.dynamic_argnames)], + *new_wires, + op_cls=op_cls, + hybrid_lens=hybrid_lens, + hybrid_trees=hybrid_trees, + **kwargs, + ) + return [] + + # pylint: disable=unused-argument, too-many-arguments def _qubit_unitary_bind_call( *invals, op, qubits_len, params_len, ctrl_len, adjoint, hyperparameters diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index 4e3e94d9a4..c47cdd1d99 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -29,8 +29,11 @@ ) from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim +from pennylane.pytrees import unflatten from pennylane.wires import AbstractQubit +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval + # TODO: remove after jax v0.7.2 upgrade # Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between # Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not @@ -72,6 +75,7 @@ MeasureOp, MultiRZOp, NamedObsOp, + OperatorOp, PauliRotOp, PCPhaseOp, QubitUnitaryOp, @@ -164,6 +168,7 @@ class MeasurementPlane(Enum): qref_compbasis_p = Primitive("qref_compbasis") qref_namedobs_p = Primitive("qref_namedobs") qref_hermitian_p = Primitive("qref_hermitian") +qref_operator_p = Primitive("qref_operator") # @@ -798,6 +803,62 @@ def _qref_named_obs_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value, return NamedObsOp(result_type, qubit, obsId).results +qref_operator_p.multiple_results = True + + +@qref_operator_p.def_abstract_eval +def _qref_operator_p_abstract_eval(*args, **kwargs): + return [] + + +def _qref_operator_p_lowering( + jax_ctx: mlir.LoweringRuleContext, + *args, + op_cls, + hybrid_lens, + hybrid_trees, + wire_lens, + **static_data, +): + params = args[: len(op_cls.dynamic_argnames)] + qubits = args[len(op_cls.dynamic_argnames) :] + + name_attr = get_mlir_attribute_from_pyval(op_cls.__name__) + + repack_static_data = {k: unflatten(*v) for k, v in static_data.items()} + processed_static_data = get_mlir_attribute_from_pyval(repack_static_data) + + param_map = { + name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames) + } + processed_param_map = get_mlir_attribute_from_pyval(param_map) + + qubit_map = {} + ind = 0 + for name, size in zip(op_cls.wire_argnames, wire_lens): + qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind + size))) + ind += size + + processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map) + + OperatorOp( + op_name=name_attr, + params=params, + qubits=qubits, + qreg=None, + forward_args=[], + ctrl_qubits=[], + ctrl_values=[], + adjoint=False, + UID=None, + arr_qubit_indices=[], + param_map=processed_param_map, + static_data=processed_static_data, + qubit_map=processed_qubit_map, + ) + return [] + + # # hermitian observable # @@ -821,6 +882,7 @@ def _qref_hermitian_lowering(jax_ctx: mlir.LoweringRuleContext, matrix: ir.Value CUSTOM_LOWERING_RULES = ( + (qref_operator_p, _qref_operator_p_lowering), (qref_alloc_p, _qref_alloc_lowering), (qref_dealloc_p, _qref_dealloc_lowering), (qref_get_p, _qref_get_lowering), diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index d372890229..3812dc97a0 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -202,6 +202,9 @@ def get_mlir_attribute_from_pyval(value): attr = None match value: + case ir.Attribute(): + attr = value + case bool(): attr = ir.BoolAttr.get(value) diff --git a/frontend/test/lit/test_operator.py b/frontend/test/lit/test_operator.py new file mode 100644 index 0000000000..12bafa1f1a --- /dev/null +++ b/frontend/test/lit/test_operator.py @@ -0,0 +1,156 @@ +# Copyright 2022-2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for operator in Catalyst.""" + +# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring + +# RUN: %PYTHON %s | FileCheck %s + +import numpy as np +import pennylane as qp + + +class NoParams(qp.core.Operator2): + + def __init__(self, wires): + super().__init__(wires=wires) + + +@qp.qjit(target="mlir", capture=True) +@qp.qnode(qp.device("null.qubit", wires=2)) +def c_no_params(): + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + # CHECK: qref.operator "NoParams"() qubits([[q0]]) + # CHECK: static_data = {} + # CHECK: param_map = {} qubit_map = {wires = [0]} + NoParams(wires=0) + + # CHECK: [[q1:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: qref.operator "NoParams"() qubits([[q1]], [[q2]]) + # CHECK: static_data = {} + # CHECK: param_map = {} qubit_map = {wires = [0, 1]} + NoParams(wires=(0, 1)) + return qp.state() + + +print(c_no_params.mlir) + + +class SingleParam(qp.core.Operator2): + + dynamic_argnames = ("x",) + + def __init__(self, x, wires): + super().__init__(x, wires=wires) + + +@qp.qjit(target="mlir", capture=True) +@qp.qnode(qp.device("null.qubit", wires=3)) +def c_single_param(x: float): + + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + + # CHECK: qref.operator "SingleParam"({{%.+}}: tensor) qubits([[q0]]) + # CHECK: static_data = {} + # CHECK: param_map = {x = [0]} qubit_map = {wires = [0]} + SingleParam(x, 0) + + # CHECK: [[q1:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: qref.operator "SingleParam"({{%.+}}: tensor<4x4xf64>) qubits([[q1]], [[q2]]) + # CHECK: static_data = {} + # CHECK: param_map = {x = [0]} qubit_map = {wires = [0, 1]} + SingleParam(np.eye(4), (1, 2)) + + return qp.state() + + +print(c_single_param.mlir) + + +class CompilableData(qp.core.Operator2): + + compilable_argnames = ("a", "b", "thing") + + def __init__(self, a, b, thing, wires): + super().__init__(a=a, b=b, thing=thing, wires=wires) + + +@qp.qjit(capture=True, target="mlir") +@qp.qnode(qp.device("null.qubit", wires=3)) +def c_compilable(): + # CHECK: [[q1:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: qref.operator "CompilableData"() qubits([[q1]], [[q2]]) + # CHECK: static_data = {a = true, b = "some string", thing = [1, true, "string"]} + # CHECK: param_map = {} qubit_map = {wires = [0, 1]} + + CompilableData(True, "some string", (1, True, "string"), wires=(0, 1)) + + return qp.state() + + +print(c_compilable.mlir) + + +class MultipleRegisters(qp.core.Operator2): + + wire_argnames = ("reg1", "reg2") + + def __init__(self, reg1, reg2): + super().__init__(reg1=reg1, reg2=reg2) + + +@qp.qjit(capture=True, target="mlir") +@qp.qnode(qp.device("null.qubit", wires=5)) +def c_multiple_registers(): + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: [[q3:%.+]] = qref.get {{%.+}} + # CHECK: [[q4:%.+]] = qref.get {{%.+}} + + # CHECK: qref.operator "MultipleRegisters"() qubits([[q0]], [[q2]], [[q3]], [[q4]]) + # CHECK: static_data = {} + # CHECK: param_map = {} qubit_map = {reg1 = [0], reg2 = [1, 2, 3]} + MultipleRegisters(0, (2, 3, 4)) + return qp.state() + + +print(c_multiple_registers.mlir) + + +class MultiParams(qp.core.Operator2): + + dynamic_argnames = ("a", "b", "c") + + # note also having non-standard order with dynamic inputs after wires + def __init__(self, wires, a, b, c): + super().__init__(wires, a, b, c) + + +@qp.qjit(capture=True, target="mlir") +@qp.qnode(qp.device("null.qubit", wires=1)) +def c_multi_params(): + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + + # pylint: disable=line-too-long + # CHECK: qref.operator "MultiParams"({{%.+}}: tensor, {{%.+}}: tensor<4x2x1xf64>, {{%.+}}: tensor<3xi64>) qubits([[q0]]) + # CHECK: static_data = {} + # CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {wires = [0]} + MultiParams(0, 0.5, c=np.array([1, 2, 3]), b=np.zeros((4, 2, 1))) + return qp.state() + + +print(c_multi_params.mlir) diff --git a/frontend/test/pytest/test_operator.py b/frontend/test/pytest/test_operator.py new file mode 100644 index 0000000000..6558f5972d --- /dev/null +++ b/frontend/test/pytest/test_operator.py @@ -0,0 +1,65 @@ +# Copyright 2022-2026 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the new Operator2 class. +""" + +# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring +import pennylane as qp +import pytest + + +class DummyOp(qp.core.Operator2): + + def __init__(self, wires): + super().__init__(wires=wires) + + +def test_hybrid_not_supported_yet(): + """Test that hybrid arguments are not yet supported.""" + + class OperatorArgument(qp.core.Operator2): + + hybrid_argnames = ("op",) + wire_argnames = () + + def __init__(self, op): + super().__init__(op) + + with pytest.raises(NotImplementedError): + + @qp.qjit(capture=True) + @qp.qnode(qp.device("null.qubit", wires=3)) + def c(): + OperatorArgument(DummyOp(0)) + return qp.state() + + +def test_static_argnames(): + """Test that static arguments are not yet supported.""" + + class StaticArgsOp(qp.core.Operator2): + + static_argnames = ("thing",) + + def __init__(self, thing, wires): + super().__init__(thing, wires) + + with pytest.raises(NotImplementedError): + + @qp.qjit(capture=True) + @qp.qnode(qp.device("null.qubit", wires=2)) + def c(): + StaticArgsOp("hello", 0) + return qp.state() diff --git a/mlir/include/QRef/IR/QRefOps.td b/mlir/include/QRef/IR/QRefOps.td index 075f871e2c..028916927f 100644 --- a/mlir/include/QRef/IR/QRefOps.td +++ b/mlir/include/QRef/IR/QRefOps.td @@ -412,7 +412,7 @@ def OperatorOp : UnitaryGate_Op<"operator", [ParametrizedGate, AttrSizedOperandS }]; let arguments = (ins - StringProp:$op_name, + StrAttr:$op_name, Variadic:$params, Variadic:$forward_args, diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 5b7db9bebf..8623922262 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -613,7 +613,7 @@ def OperatorOp : UnitaryGate_Op<"operator", [ParametrizedGate, NoMemoryEffect, }]; let arguments = (ins - StringProp:$op_name, + StrAttr:$op_name, Variadic:$params, Variadic:$forward_args, diff --git a/mlir/lib/QRef/IR/QRefOps.cpp b/mlir/lib/QRef/IR/QRefOps.cpp index f53e3be8cc..3e754fc03b 100644 --- a/mlir/lib/QRef/IR/QRefOps.cpp +++ b/mlir/lib/QRef/IR/QRefOps.cpp @@ -482,7 +482,7 @@ void OperatorOp::print(OpAsmPrinter &p) // 5. Attribute Dictionary SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map", - "operandSegmentSizes"}; + "operandSegmentSizes", "op_name"}; p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); p.increaseIndent(); @@ -594,8 +594,8 @@ ParseResult OperatorOp::parse(OpAsmParser &parser, OperationState &result) if (parser.parseString(&opName)) { return failure(); } + result.addAttribute("op_name", builder.getStringAttr(opName)); auto &opProperties = result.getOrAddProperties(); - opProperties.setOpName(opName); // 2. Parse variadic params: (%arg0: type, ...) SmallVector params; diff --git a/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp b/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp index d3475f46c2..689af5bb9c 100644 --- a/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp +++ b/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp @@ -1159,7 +1159,6 @@ void handleGate(IRRewriter &builder, qref::QuantumOperation rGateOp, QubitValueT builder.getDenseI32ArrayAttr({nTargets, nCtrls, nQreg})); // Properties are not handled via the generic attribute fields, so we set them separately. - vGateOp.setOpName(rOperatorOp.getOpName()); vGateOp.setAdjoint(rOperatorOp.getAdjoint()); vGateOp.setUID(rOperatorOp.getUID()); } diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 736d9f2d4b..62977102c4 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -785,8 +785,8 @@ void OperatorOp::print(OpAsmPrinter &p) } // 5. Attribute Dictionary - SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map", - "operandSegmentSizes", "resultSegmentSizes"}; + SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map", + "operandSegmentSizes", "resultSegmentSizes", "op_name"}; p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); p.increaseIndent(); @@ -899,7 +899,7 @@ ParseResult OperatorOp::parse(OpAsmParser &parser, OperationState &result) return failure(); } auto &opProperties = result.getOrAddProperties(); - opProperties.setOpName(opName); + result.addAttribute("op_name", builder.getStringAttr(opName)); // 2. Parse variadic params: (%arg0: type, ...) SmallVector params; diff --git a/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp b/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp index 3dbbf0baf1..5c35f70782 100644 --- a/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp +++ b/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp @@ -368,7 +368,6 @@ void handleGate(IRRewriter &builder, quantum::QuantumOperation vGateOp, QubitVal rGateOp->removeAttr("resultSegmentSizes"); // Properties are not handled via the generic attribute fields, so we set them separately. - rGateOp.setOpName(vOperatorOp.getOpName()); rGateOp.setAdjoint(vOperatorOp.getAdjoint()); rGateOp.setUID(vOperatorOp.getUID()); }