diff --git a/.dep-versions b/.dep-versions
index 4ca08f1d5c..eff2c832ec 100644
--- a/.dep-versions
+++ b/.dep-versions
@@ -8,7 +8,7 @@ enzyme=v0.0.238
# For a custom PL version, update the package version here and at
# 'doc/requirements.txt'
-pennylane=0.46.0.dev24
+pennylane=0.46.0.dev38
# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index c990c2ad88..f5b0045e82 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -5,6 +5,10 @@
Improvements ðŸ›
+* The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators
+ without non-lowerable arguments.
+ [(#2969)](https://github.com/PennyLaneAI/catalyst/pull/2969/)
+
* The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry
instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_`
with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_` with a stable
diff --git a/doc/requirements.txt b/doc/requirements.txt
index c6721a731a..3459767b3b 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -34,4 +34,4 @@ lxml_html_clean
--extra-index-url https://test.pypi.org/simple/
pennylane-lightning-kokkos==0.46.0-dev10
pennylane-lightning==0.46.0-dev10
-pennylane==0.46.0.dev24
+pennylane==0.46.0.dev38
diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py
index a1f8a8c098..0d51ee2705 100644
--- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py
+++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py
@@ -30,6 +30,7 @@
from pennylane.capture.primitives import measure_prim as plxpr_measure_prim
from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim
from pennylane.capture.primitives import quantum_subroutine_prim, transform_prim
+from pennylane.core.operator.operator2 import operator_p
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.measurements import CountsMP
from pennylane.wires import AbstractQubit, is_abstract_qubit
@@ -45,6 +46,7 @@
qref_measure_in_basis_p,
qref_measure_p,
qref_namedobs_p,
+ qref_operator_p,
qref_pauli_measure_p,
qref_pauli_rot_p,
qref_qinst_p,
@@ -290,6 +292,29 @@ def __call__(self, jaxpr, *args):
return self.eval(jaxpr.jaxpr, jaxpr.consts, *args)
+@PLxPRToQuantumJaxprInterpreter.register_primitive(operator_p)
+def _handle_operator(self, *args, op_cls, hybrid_lens, hybrid_trees, **kwargs):
+
+ if hybrid_lens or hybrid_trees or op_cls.static_argnames:
+ # only support compilable_argnames for the moment
+ raise NotImplementedError
+
+ wire_inputs = args[len(op_cls.dynamic_argnames) :]
+ new_wires = [
+ w if is_abstract_qubit(w) else qref_get_p.bind(self.init_qreg, w) for w in wire_inputs
+ ]
+
+ qref_operator_p.bind(
+ *args[: len(op_cls.dynamic_argnames)],
+ *new_wires,
+ op_cls=op_cls,
+ hybrid_lens=hybrid_lens,
+ hybrid_trees=hybrid_trees,
+ **kwargs,
+ )
+ return []
+
+
# pylint: disable=unused-argument, too-many-arguments
def _qubit_unitary_bind_call(
*invals, op, qubits_len, params_len, ctrl_len, adjoint, hyperparameters
diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py
index 4e3e94d9a4..c47cdd1d99 100644
--- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py
+++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py
@@ -29,8 +29,11 @@
)
from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim
+from pennylane.pytrees import unflatten
from pennylane.wires import AbstractQubit
+from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
+
# TODO: remove after jax v0.7.2 upgrade
# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between
# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not
@@ -72,6 +75,7 @@
MeasureOp,
MultiRZOp,
NamedObsOp,
+ OperatorOp,
PauliRotOp,
PCPhaseOp,
QubitUnitaryOp,
@@ -164,6 +168,7 @@ class MeasurementPlane(Enum):
qref_compbasis_p = Primitive("qref_compbasis")
qref_namedobs_p = Primitive("qref_namedobs")
qref_hermitian_p = Primitive("qref_hermitian")
+qref_operator_p = Primitive("qref_operator")
#
@@ -798,6 +803,62 @@ def _qref_named_obs_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value,
return NamedObsOp(result_type, qubit, obsId).results
+qref_operator_p.multiple_results = True
+
+
+@qref_operator_p.def_abstract_eval
+def _qref_operator_p_abstract_eval(*args, **kwargs):
+ return []
+
+
+def _qref_operator_p_lowering(
+ jax_ctx: mlir.LoweringRuleContext,
+ *args,
+ op_cls,
+ hybrid_lens,
+ hybrid_trees,
+ wire_lens,
+ **static_data,
+):
+ params = args[: len(op_cls.dynamic_argnames)]
+ qubits = args[len(op_cls.dynamic_argnames) :]
+
+ name_attr = get_mlir_attribute_from_pyval(op_cls.__name__)
+
+ repack_static_data = {k: unflatten(*v) for k, v in static_data.items()}
+ processed_static_data = get_mlir_attribute_from_pyval(repack_static_data)
+
+ param_map = {
+ name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames)
+ }
+ processed_param_map = get_mlir_attribute_from_pyval(param_map)
+
+ qubit_map = {}
+ ind = 0
+ for name, size in zip(op_cls.wire_argnames, wire_lens):
+ qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind + size)))
+ ind += size
+
+ processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map)
+
+ OperatorOp(
+ op_name=name_attr,
+ params=params,
+ qubits=qubits,
+ qreg=None,
+ forward_args=[],
+ ctrl_qubits=[],
+ ctrl_values=[],
+ adjoint=False,
+ UID=None,
+ arr_qubit_indices=[],
+ param_map=processed_param_map,
+ static_data=processed_static_data,
+ qubit_map=processed_qubit_map,
+ )
+ return []
+
+
#
# hermitian observable
#
@@ -821,6 +882,7 @@ def _qref_hermitian_lowering(jax_ctx: mlir.LoweringRuleContext, matrix: ir.Value
CUSTOM_LOWERING_RULES = (
+ (qref_operator_p, _qref_operator_p_lowering),
(qref_alloc_p, _qref_alloc_lowering),
(qref_dealloc_p, _qref_dealloc_lowering),
(qref_get_p, _qref_get_lowering),
diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py
index d372890229..3812dc97a0 100644
--- a/frontend/catalyst/jax_extras/lowering.py
+++ b/frontend/catalyst/jax_extras/lowering.py
@@ -202,6 +202,9 @@ def get_mlir_attribute_from_pyval(value):
attr = None
match value:
+ case ir.Attribute():
+ attr = value
+
case bool():
attr = ir.BoolAttr.get(value)
diff --git a/frontend/test/lit/test_operator.py b/frontend/test/lit/test_operator.py
new file mode 100644
index 0000000000..12bafa1f1a
--- /dev/null
+++ b/frontend/test/lit/test_operator.py
@@ -0,0 +1,156 @@
+# Copyright 2022-2023 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for operator in Catalyst."""
+
+# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring
+
+# RUN: %PYTHON %s | FileCheck %s
+
+import numpy as np
+import pennylane as qp
+
+
+class NoParams(qp.core.Operator2):
+
+ def __init__(self, wires):
+ super().__init__(wires=wires)
+
+
+@qp.qjit(target="mlir", capture=True)
+@qp.qnode(qp.device("null.qubit", wires=2))
+def c_no_params():
+ # CHECK: [[q0:%.+]] = qref.get {{%.+}}
+ # CHECK: qref.operator "NoParams"() qubits([[q0]])
+ # CHECK: static_data = {}
+ # CHECK: param_map = {} qubit_map = {wires = [0]}
+ NoParams(wires=0)
+
+ # CHECK: [[q1:%.+]] = qref.get {{%.+}}
+ # CHECK: [[q2:%.+]] = qref.get {{%.+}}
+ # CHECK: qref.operator "NoParams"() qubits([[q1]], [[q2]])
+ # CHECK: static_data = {}
+ # CHECK: param_map = {} qubit_map = {wires = [0, 1]}
+ NoParams(wires=(0, 1))
+ return qp.state()
+
+
+print(c_no_params.mlir)
+
+
+class SingleParam(qp.core.Operator2):
+
+ dynamic_argnames = ("x",)
+
+ def __init__(self, x, wires):
+ super().__init__(x, wires=wires)
+
+
+@qp.qjit(target="mlir", capture=True)
+@qp.qnode(qp.device("null.qubit", wires=3))
+def c_single_param(x: float):
+
+ # CHECK: [[q0:%.+]] = qref.get {{%.+}}
+
+ # CHECK: qref.operator "SingleParam"({{%.+}}: tensor) qubits([[q0]])
+ # CHECK: static_data = {}
+ # CHECK: param_map = {x = [0]} qubit_map = {wires = [0]}
+ SingleParam(x, 0)
+
+ # CHECK: [[q1:%.+]] = qref.get {{%.+}}
+ # CHECK: [[q2:%.+]] = qref.get {{%.+}}
+ # CHECK: qref.operator "SingleParam"({{%.+}}: tensor<4x4xf64>) qubits([[q1]], [[q2]])
+ # CHECK: static_data = {}
+ # CHECK: param_map = {x = [0]} qubit_map = {wires = [0, 1]}
+ SingleParam(np.eye(4), (1, 2))
+
+ return qp.state()
+
+
+print(c_single_param.mlir)
+
+
+class CompilableData(qp.core.Operator2):
+
+ compilable_argnames = ("a", "b", "thing")
+
+ def __init__(self, a, b, thing, wires):
+ super().__init__(a=a, b=b, thing=thing, wires=wires)
+
+
+@qp.qjit(capture=True, target="mlir")
+@qp.qnode(qp.device("null.qubit", wires=3))
+def c_compilable():
+ # CHECK: [[q1:%.+]] = qref.get {{%.+}}
+ # CHECK: [[q2:%.+]] = qref.get {{%.+}}
+ # CHECK: qref.operator "CompilableData"() qubits([[q1]], [[q2]])
+ # CHECK: static_data = {a = true, b = "some string", thing = [1, true, "string"]}
+ # CHECK: param_map = {} qubit_map = {wires = [0, 1]}
+
+ CompilableData(True, "some string", (1, True, "string"), wires=(0, 1))
+
+ return qp.state()
+
+
+print(c_compilable.mlir)
+
+
+class MultipleRegisters(qp.core.Operator2):
+
+ wire_argnames = ("reg1", "reg2")
+
+ def __init__(self, reg1, reg2):
+ super().__init__(reg1=reg1, reg2=reg2)
+
+
+@qp.qjit(capture=True, target="mlir")
+@qp.qnode(qp.device("null.qubit", wires=5))
+def c_multiple_registers():
+ # CHECK: [[q0:%.+]] = qref.get {{%.+}}
+ # CHECK: [[q2:%.+]] = qref.get {{%.+}}
+ # CHECK: [[q3:%.+]] = qref.get {{%.+}}
+ # CHECK: [[q4:%.+]] = qref.get {{%.+}}
+
+ # CHECK: qref.operator "MultipleRegisters"() qubits([[q0]], [[q2]], [[q3]], [[q4]])
+ # CHECK: static_data = {}
+ # CHECK: param_map = {} qubit_map = {reg1 = [0], reg2 = [1, 2, 3]}
+ MultipleRegisters(0, (2, 3, 4))
+ return qp.state()
+
+
+print(c_multiple_registers.mlir)
+
+
+class MultiParams(qp.core.Operator2):
+
+ dynamic_argnames = ("a", "b", "c")
+
+ # note also having non-standard order with dynamic inputs after wires
+ def __init__(self, wires, a, b, c):
+ super().__init__(wires, a, b, c)
+
+
+@qp.qjit(capture=True, target="mlir")
+@qp.qnode(qp.device("null.qubit", wires=1))
+def c_multi_params():
+ # CHECK: [[q0:%.+]] = qref.get {{%.+}}
+
+ # pylint: disable=line-too-long
+ # CHECK: qref.operator "MultiParams"({{%.+}}: tensor, {{%.+}}: tensor<4x2x1xf64>, {{%.+}}: tensor<3xi64>) qubits([[q0]])
+ # CHECK: static_data = {}
+ # CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {wires = [0]}
+ MultiParams(0, 0.5, c=np.array([1, 2, 3]), b=np.zeros((4, 2, 1)))
+ return qp.state()
+
+
+print(c_multi_params.mlir)
diff --git a/frontend/test/pytest/test_operator.py b/frontend/test/pytest/test_operator.py
new file mode 100644
index 0000000000..6558f5972d
--- /dev/null
+++ b/frontend/test/pytest/test_operator.py
@@ -0,0 +1,65 @@
+# Copyright 2022-2026 Xanadu Quantum Technologies Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for the new Operator2 class.
+"""
+
+# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring
+import pennylane as qp
+import pytest
+
+
+class DummyOp(qp.core.Operator2):
+
+ def __init__(self, wires):
+ super().__init__(wires=wires)
+
+
+def test_hybrid_not_supported_yet():
+ """Test that hybrid arguments are not yet supported."""
+
+ class OperatorArgument(qp.core.Operator2):
+
+ hybrid_argnames = ("op",)
+ wire_argnames = ()
+
+ def __init__(self, op):
+ super().__init__(op)
+
+ with pytest.raises(NotImplementedError):
+
+ @qp.qjit(capture=True)
+ @qp.qnode(qp.device("null.qubit", wires=3))
+ def c():
+ OperatorArgument(DummyOp(0))
+ return qp.state()
+
+
+def test_static_argnames():
+ """Test that static arguments are not yet supported."""
+
+ class StaticArgsOp(qp.core.Operator2):
+
+ static_argnames = ("thing",)
+
+ def __init__(self, thing, wires):
+ super().__init__(thing, wires)
+
+ with pytest.raises(NotImplementedError):
+
+ @qp.qjit(capture=True)
+ @qp.qnode(qp.device("null.qubit", wires=2))
+ def c():
+ StaticArgsOp("hello", 0)
+ return qp.state()
diff --git a/mlir/include/QRef/IR/QRefOps.td b/mlir/include/QRef/IR/QRefOps.td
index 075f871e2c..028916927f 100644
--- a/mlir/include/QRef/IR/QRefOps.td
+++ b/mlir/include/QRef/IR/QRefOps.td
@@ -412,7 +412,7 @@ def OperatorOp : UnitaryGate_Op<"operator", [ParametrizedGate, AttrSizedOperandS
}];
let arguments = (ins
- StringProp:$op_name,
+ StrAttr:$op_name,
Variadic:$params,
Variadic:$forward_args,
diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td
index 5b7db9bebf..8623922262 100644
--- a/mlir/include/Quantum/IR/QuantumOps.td
+++ b/mlir/include/Quantum/IR/QuantumOps.td
@@ -613,7 +613,7 @@ def OperatorOp : UnitaryGate_Op<"operator", [ParametrizedGate, NoMemoryEffect,
}];
let arguments = (ins
- StringProp:$op_name,
+ StrAttr:$op_name,
Variadic:$params,
Variadic:$forward_args,
diff --git a/mlir/lib/QRef/IR/QRefOps.cpp b/mlir/lib/QRef/IR/QRefOps.cpp
index f53e3be8cc..3e754fc03b 100644
--- a/mlir/lib/QRef/IR/QRefOps.cpp
+++ b/mlir/lib/QRef/IR/QRefOps.cpp
@@ -482,7 +482,7 @@ void OperatorOp::print(OpAsmPrinter &p)
// 5. Attribute Dictionary
SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map",
- "operandSegmentSizes"};
+ "operandSegmentSizes", "op_name"};
p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
p.increaseIndent();
@@ -594,8 +594,8 @@ ParseResult OperatorOp::parse(OpAsmParser &parser, OperationState &result)
if (parser.parseString(&opName)) {
return failure();
}
+ result.addAttribute("op_name", builder.getStringAttr(opName));
auto &opProperties = result.getOrAddProperties();
- opProperties.setOpName(opName);
// 2. Parse variadic params: (%arg0: type, ...)
SmallVector params;
diff --git a/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp b/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp
index d3475f46c2..689af5bb9c 100644
--- a/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp
+++ b/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp
@@ -1159,7 +1159,6 @@ void handleGate(IRRewriter &builder, qref::QuantumOperation rGateOp, QubitValueT
builder.getDenseI32ArrayAttr({nTargets, nCtrls, nQreg}));
// Properties are not handled via the generic attribute fields, so we set them separately.
- vGateOp.setOpName(rOperatorOp.getOpName());
vGateOp.setAdjoint(rOperatorOp.getAdjoint());
vGateOp.setUID(rOperatorOp.getUID());
}
diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp
index 736d9f2d4b..62977102c4 100644
--- a/mlir/lib/Quantum/IR/QuantumOps.cpp
+++ b/mlir/lib/Quantum/IR/QuantumOps.cpp
@@ -785,8 +785,8 @@ void OperatorOp::print(OpAsmPrinter &p)
}
// 5. Attribute Dictionary
- SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map",
- "operandSegmentSizes", "resultSegmentSizes"};
+ SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map",
+ "operandSegmentSizes", "resultSegmentSizes", "op_name"};
p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs);
p.increaseIndent();
@@ -899,7 +899,7 @@ ParseResult OperatorOp::parse(OpAsmParser &parser, OperationState &result)
return failure();
}
auto &opProperties = result.getOrAddProperties();
- opProperties.setOpName(opName);
+ result.addAttribute("op_name", builder.getStringAttr(opName));
// 2. Parse variadic params: (%arg0: type, ...)
SmallVector params;
diff --git a/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp b/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp
index 3dbbf0baf1..5c35f70782 100644
--- a/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp
+++ b/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp
@@ -368,7 +368,6 @@ void handleGate(IRRewriter &builder, quantum::QuantumOperation vGateOp, QubitVal
rGateOp->removeAttr("resultSegmentSizes");
// Properties are not handled via the generic attribute fields, so we set them separately.
- rGateOp.setOpName(vOperatorOp.getOpName());
rGateOp.setAdjoint(vOperatorOp.getAdjoint());
rGateOp.setUID(vOperatorOp.getUID());
}