From 034e5003c2084f46ed2e9699f1d6072a683c3e09 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 19 Jun 2026 16:10:15 -0400 Subject: [PATCH 01/16] add translation and lowering of OperatorOp --- .../catalyst/from_plxpr/qfunc_interpreter.py | 22 ++++++++++++++++ .../from_plxpr/qref_jax_primitives.py | 26 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index a1f8a8c098..fca616acd2 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -25,6 +25,7 @@ import pennylane as qp from jax._src.sharding_impls import UNSPECIFIED from pennylane.capture import PlxprInterpreter, pause +from pennylane.core.operator.operator2 import operator_p from pennylane.capture.primitives import cond_prim as pl_cond_prim from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim from pennylane.capture.primitives import measure_prim as plxpr_measure_prim @@ -51,6 +52,7 @@ qref_set_basis_state_p, qref_set_state_p, qref_unitary_p, + qref_operator_op, ) from catalyst.jax_primitives import ( counts_p, @@ -290,6 +292,26 @@ def __call__(self, jaxpr, *args): return self.eval(jaxpr.jaxpr, jaxpr.consts, *args) +@PLxPRToQuantumJaxprInterpreter.register_primitive(operator_p) +def _handle_operator(self, *args, op_cls, hybrid_lens, hybrid_trees, **kwargs): + + if hybrid_lens or hybrid_trees: + raise NotImplementedError + + wire_inputs = args[len(op_cls.dynamic_argnames):] + new_wires = [w if is_abstract_qubit(w) else qref_get_p.bind(self.init_qreg, w) for w in wire_inputs] + + qref_operator_op.bind( + *args[:len(op_cls.dynamic_argnames)], + *new_wires, + op_cls=op_cls, + hybrid_lens=hybrid_lens, + hybrid_trees=hybrid_trees, + **kwargs + ) + return [] + + # pylint: disable=unused-argument, too-many-arguments def _qubit_unitary_bind_call( *invals, op, qubits_len, params_len, ctrl_len, adjoint, hyperparameters diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index 4e3e94d9a4..c31768a786 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -72,6 +72,7 @@ MeasureOp, MultiRZOp, NamedObsOp, + OperatorOp, PauliRotOp, PCPhaseOp, QubitUnitaryOp, @@ -164,6 +165,7 @@ class MeasurementPlane(Enum): qref_compbasis_p = Primitive("qref_compbasis") qref_namedobs_p = Primitive("qref_namedobs") qref_hermitian_p = Primitive("qref_hermitian") +qref_operator_op = Primitive("qref_operator") # @@ -797,6 +799,29 @@ def _qref_named_obs_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value, return NamedObsOp(result_type, qubit, obsId).results +qref_operator_op.multiple_results = True + +@qref_operator_op.def_abstract_eval +def _qref_operator_op_abstract_eval(*args, **kwargs): + return [] + +def _operator_op_lowering(jax_ctx: mlir.LoweringRuleContext, *args, op_cls, hybrid_lens, hybrid_trees, wire_lens, **static_args): + params = args[:len(op_cls.dynamic_argnames)] + qubits = args[len(op_cls.dynamic_argnames):] + + name_attr = ir.StringAttr.get(op_cls.__name__) + + OperatorOp(op_name=name_attr, + params = params, + qubits = qubits, + forward_args=[], + ctrl_qubits=[], + ctrl_values=[], + adjoint=False, + UID=None, + arr_qubit_indices=[] + ) + return [] # # hermitian observable @@ -821,6 +846,7 @@ def _qref_hermitian_lowering(jax_ctx: mlir.LoweringRuleContext, matrix: ir.Value CUSTOM_LOWERING_RULES = ( + (qref_operator_op, _operator_op_lowering), (qref_alloc_p, _qref_alloc_lowering), (qref_dealloc_p, _qref_dealloc_lowering), (qref_get_p, _qref_get_lowering), From a9308fc8df8ab63ca57dad127d79bbd2808e299b Mon Sep 17 00:00:00 2001 From: albi3ro Date: Fri, 19 Jun 2026 17:29:03 -0400 Subject: [PATCH 02/16] sonnet's AI recommendation to fix python bindings issue --- mlir/python/dialects/qref.py | 45 ++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/mlir/python/dialects/qref.py b/mlir/python/dialects/qref.py index 5522c08da8..951de35d86 100644 --- a/mlir/python/dialects/qref.py +++ b/mlir/python/dialects/qref.py @@ -16,3 +16,48 @@ # pylint: disable=relative-beyond-top-level from ._qref_ops_gen import * # noqa: F401 +from ._qref_ops_gen import OperatorOp as _OperatorOpGen +from ._ods_common import ( + get_default_loc_context as _ods_get_default_loc_context, + get_op_results_or_values as _get_op_results_or_values, +) +from ._ods_common import _cext as _ods_cext +_ods_ir = _ods_cext.ir + + +class OperatorOp(_OperatorOpGen): + def __init__(self, op_name, params, forward_args, qubits, ctrl_qubits, + ctrl_values, arr_qubit_indices=None, adjoint=False, UID=None, *, + qreg=None, arr_ctrl_indices=None, arr_ctrl_values=None, + static_data=None, param_map=None, qubit_map=None, loc=None, ip=None): + operands = [] + attributes = {} + # op_name is a StringProp — must go in attributes + attributes["op_name"] = (op_name if isinstance(op_name, _ods_ir.Attribute) + else _ods_ir.StringAttr.get(op_name)) + operands.append(_get_op_results_or_values(params)) + operands.append(_get_op_results_or_values(forward_args)) + operands.append(_get_op_results_or_values(qubits)) + operands.append(_get_op_results_or_values(ctrl_qubits)) + operands.append(ctrl_values if ctrl_values is not None else []) + operands.append(_get_op_results_or_values(qreg)) + operands.append(arr_qubit_indices if arr_qubit_indices is not None else []) + operands.append(arr_ctrl_indices) + operands.append(arr_ctrl_values) + _ods_context = _ods_get_default_loc_context(loc) + if bool(adjoint): + attributes["adjoint"] = _ods_ir.UnitAttr.get(_ods_context) + if UID is not None: + attributes["UID"] = _ods_ir.IntegerAttr.get( + _ods_ir.IntegerType.get_signless(64), UID) + if static_data is not None: + attributes["static_data"] = static_data + if param_map is not None: + attributes["param_map"] = param_map + if qubit_map is not None: + attributes["qubit_map"] = qubit_map + # bypass _OperatorOpGen.__init__ and go directly to OpView + super(_OperatorOpGen, self).__init__( + self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, + self._ODS_RESULT_SEGMENTS, attributes=attributes, results=[], + operands=operands, successors=None, regions=None, loc=loc, ip=ip) \ No newline at end of file From 8599ec6219b754d73edf9eb78805b1860203a7c9 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 24 Jun 2026 10:02:36 -0400 Subject: [PATCH 03/16] switch to StrAttr --- mlir/include/QRef/IR/QRefOps.td | 2 +- mlir/include/Quantum/IR/QuantumOps.td | 2 +- mlir/lib/QRef/IR/QRefOps.cpp | 3 ++- mlir/lib/Quantum/IR/QuantumOps.cpp | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/include/QRef/IR/QRefOps.td b/mlir/include/QRef/IR/QRefOps.td index 075f871e2c..028916927f 100644 --- a/mlir/include/QRef/IR/QRefOps.td +++ b/mlir/include/QRef/IR/QRefOps.td @@ -412,7 +412,7 @@ def OperatorOp : UnitaryGate_Op<"operator", [ParametrizedGate, AttrSizedOperandS }]; let arguments = (ins - StringProp:$op_name, + StrAttr:$op_name, Variadic:$params, Variadic:$forward_args, diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 5b7db9bebf..8623922262 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -613,7 +613,7 @@ def OperatorOp : UnitaryGate_Op<"operator", [ParametrizedGate, NoMemoryEffect, }]; let arguments = (ins - StringProp:$op_name, + StrAttr:$op_name, Variadic:$params, Variadic:$forward_args, diff --git a/mlir/lib/QRef/IR/QRefOps.cpp b/mlir/lib/QRef/IR/QRefOps.cpp index f53e3be8cc..8f4e24cd9d 100644 --- a/mlir/lib/QRef/IR/QRefOps.cpp +++ b/mlir/lib/QRef/IR/QRefOps.cpp @@ -594,8 +594,9 @@ ParseResult OperatorOp::parse(OpAsmParser &parser, OperationState &result) if (parser.parseString(&opName)) { return failure(); } + result.addAttribute("op_name", builder.getStringAttr(opName)); auto &opProperties = result.getOrAddProperties(); - opProperties.setOpName(opName); + // 2. Parse variadic params: (%arg0: type, ...) SmallVector params; diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 736d9f2d4b..821003d195 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -899,7 +899,7 @@ ParseResult OperatorOp::parse(OpAsmParser &parser, OperationState &result) return failure(); } auto &opProperties = result.getOrAddProperties(); - opProperties.setOpName(opName); + result.addAttribute("op_name", builder.getStringAttr(opName)); // 2. Parse variadic params: (%arg0: type, ...) SmallVector params; From 5c4504e4f19b04edae3a539444cd0499dcb2fa87 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 24 Jun 2026 10:23:51 -0400 Subject: [PATCH 04/16] bump pl version --- .dep-versions | 2 +- doc/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.dep-versions b/.dep-versions index 4ca08f1d5c..eff2c832ec 100644 --- a/.dep-versions +++ b/.dep-versions @@ -8,7 +8,7 @@ enzyme=v0.0.238 # For a custom PL version, update the package version here and at # 'doc/requirements.txt' -pennylane=0.46.0.dev24 +pennylane=0.46.0.dev38 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' diff --git a/doc/requirements.txt b/doc/requirements.txt index c6721a731a..3459767b3b 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -34,4 +34,4 @@ lxml_html_clean --extra-index-url https://test.pypi.org/simple/ pennylane-lightning-kokkos==0.46.0-dev10 pennylane-lightning==0.46.0-dev10 -pennylane==0.46.0.dev24 +pennylane==0.46.0.dev38 From b23289ad514afd768691ae0b61233cee122b6858 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 24 Jun 2026 10:58:05 -0400 Subject: [PATCH 05/16] remove qref patch --- mlir/python/dialects/qref.py | 45 ------------------------------------ 1 file changed, 45 deletions(-) diff --git a/mlir/python/dialects/qref.py b/mlir/python/dialects/qref.py index 951de35d86..5522c08da8 100644 --- a/mlir/python/dialects/qref.py +++ b/mlir/python/dialects/qref.py @@ -16,48 +16,3 @@ # pylint: disable=relative-beyond-top-level from ._qref_ops_gen import * # noqa: F401 -from ._qref_ops_gen import OperatorOp as _OperatorOpGen -from ._ods_common import ( - get_default_loc_context as _ods_get_default_loc_context, - get_op_results_or_values as _get_op_results_or_values, -) -from ._ods_common import _cext as _ods_cext -_ods_ir = _ods_cext.ir - - -class OperatorOp(_OperatorOpGen): - def __init__(self, op_name, params, forward_args, qubits, ctrl_qubits, - ctrl_values, arr_qubit_indices=None, adjoint=False, UID=None, *, - qreg=None, arr_ctrl_indices=None, arr_ctrl_values=None, - static_data=None, param_map=None, qubit_map=None, loc=None, ip=None): - operands = [] - attributes = {} - # op_name is a StringProp — must go in attributes - attributes["op_name"] = (op_name if isinstance(op_name, _ods_ir.Attribute) - else _ods_ir.StringAttr.get(op_name)) - operands.append(_get_op_results_or_values(params)) - operands.append(_get_op_results_or_values(forward_args)) - operands.append(_get_op_results_or_values(qubits)) - operands.append(_get_op_results_or_values(ctrl_qubits)) - operands.append(ctrl_values if ctrl_values is not None else []) - operands.append(_get_op_results_or_values(qreg)) - operands.append(arr_qubit_indices if arr_qubit_indices is not None else []) - operands.append(arr_ctrl_indices) - operands.append(arr_ctrl_values) - _ods_context = _ods_get_default_loc_context(loc) - if bool(adjoint): - attributes["adjoint"] = _ods_ir.UnitAttr.get(_ods_context) - if UID is not None: - attributes["UID"] = _ods_ir.IntegerAttr.get( - _ods_ir.IntegerType.get_signless(64), UID) - if static_data is not None: - attributes["static_data"] = static_data - if param_map is not None: - attributes["param_map"] = param_map - if qubit_map is not None: - attributes["qubit_map"] = qubit_map - # bypass _OperatorOpGen.__init__ and go directly to OpView - super(_OperatorOpGen, self).__init__( - self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, - self._ODS_RESULT_SEGMENTS, attributes=attributes, results=[], - operands=operands, successors=None, regions=None, loc=loc, ip=ip) \ No newline at end of file From 6d4de2b446732617a011be7283b0326df4a6abc6 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 24 Jun 2026 12:54:03 -0400 Subject: [PATCH 06/16] formatting --- mlir/lib/QRef/IR/QRefOps.cpp | 3 +-- mlir/lib/Quantum/IR/QuantumOps.cpp | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/QRef/IR/QRefOps.cpp b/mlir/lib/QRef/IR/QRefOps.cpp index 8f4e24cd9d..3e754fc03b 100644 --- a/mlir/lib/QRef/IR/QRefOps.cpp +++ b/mlir/lib/QRef/IR/QRefOps.cpp @@ -482,7 +482,7 @@ void OperatorOp::print(OpAsmPrinter &p) // 5. Attribute Dictionary SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map", - "operandSegmentSizes"}; + "operandSegmentSizes", "op_name"}; p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); p.increaseIndent(); @@ -596,7 +596,6 @@ ParseResult OperatorOp::parse(OpAsmParser &parser, OperationState &result) } result.addAttribute("op_name", builder.getStringAttr(opName)); auto &opProperties = result.getOrAddProperties(); - // 2. Parse variadic params: (%arg0: type, ...) SmallVector params; diff --git a/mlir/lib/Quantum/IR/QuantumOps.cpp b/mlir/lib/Quantum/IR/QuantumOps.cpp index 821003d195..62977102c4 100644 --- a/mlir/lib/Quantum/IR/QuantumOps.cpp +++ b/mlir/lib/Quantum/IR/QuantumOps.cpp @@ -785,8 +785,8 @@ void OperatorOp::print(OpAsmPrinter &p) } // 5. Attribute Dictionary - SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map", - "operandSegmentSizes", "resultSegmentSizes"}; + SmallVector elidedAttrs = {"static_data", "param_map", "qubit_map", + "operandSegmentSizes", "resultSegmentSizes", "op_name"}; p.printOptionalAttrDict(getOperation()->getAttrs(), elidedAttrs); p.increaseIndent(); From daca7ae39dc2b802878a8b71fa12784f49899db9 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 24 Jun 2026 12:54:36 -0400 Subject: [PATCH 07/16] black --- .../catalyst/from_plxpr/qfunc_interpreter.py | 10 +++-- .../from_plxpr/qref_jax_primitives.py | 43 ++++++++++++------- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index fca616acd2..c7e3cb45a2 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -298,16 +298,18 @@ def _handle_operator(self, *args, op_cls, hybrid_lens, hybrid_trees, **kwargs): if hybrid_lens or hybrid_trees: raise NotImplementedError - wire_inputs = args[len(op_cls.dynamic_argnames):] - new_wires = [w if is_abstract_qubit(w) else qref_get_p.bind(self.init_qreg, w) for w in wire_inputs] + wire_inputs = args[len(op_cls.dynamic_argnames) :] + new_wires = [ + w if is_abstract_qubit(w) else qref_get_p.bind(self.init_qreg, w) for w in wire_inputs + ] qref_operator_op.bind( - *args[:len(op_cls.dynamic_argnames)], + *args[: len(op_cls.dynamic_argnames)], *new_wires, op_cls=op_cls, hybrid_lens=hybrid_lens, hybrid_trees=hybrid_trees, - **kwargs + **kwargs, ) return [] diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index c31768a786..4dd957288a 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -799,30 +799,43 @@ def _qref_named_obs_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value, return NamedObsOp(result_type, qubit, obsId).results + qref_operator_op.multiple_results = True + @qref_operator_op.def_abstract_eval def _qref_operator_op_abstract_eval(*args, **kwargs): return [] -def _operator_op_lowering(jax_ctx: mlir.LoweringRuleContext, *args, op_cls, hybrid_lens, hybrid_trees, wire_lens, **static_args): - params = args[:len(op_cls.dynamic_argnames)] - qubits = args[len(op_cls.dynamic_argnames):] - + +def _operator_op_lowering( + jax_ctx: mlir.LoweringRuleContext, + *args, + op_cls, + hybrid_lens, + hybrid_trees, + wire_lens, + **static_args, +): + params = args[: len(op_cls.dynamic_argnames)] + qubits = args[len(op_cls.dynamic_argnames) :] + name_attr = ir.StringAttr.get(op_cls.__name__) - - OperatorOp(op_name=name_attr, - params = params, - qubits = qubits, - forward_args=[], - ctrl_qubits=[], - ctrl_values=[], - adjoint=False, - UID=None, - arr_qubit_indices=[] - ) + + OperatorOp( + op_name=name_attr, + params=params, + qubits=qubits, + forward_args=[], + ctrl_qubits=[], + ctrl_values=[], + adjoint=False, + UID=None, + arr_qubit_indices=[], + ) return [] + # # hermitian observable # From 79e060e3ac936908de17847ff8b530d9914d5f77 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 24 Jun 2026 16:58:46 -0400 Subject: [PATCH 08/16] starting to add some tests --- frontend/catalyst/from_plxpr/from_plxpr.py | 4 + .../catalyst/from_plxpr/qfunc_interpreter.py | 3 +- .../from_plxpr/qref_jax_primitives.py | 23 ++- frontend/catalyst/jax_extras/lowering.py | 3 + frontend/test/lit/test_operator.py | 146 ++++++++++++++++++ 5 files changed, 176 insertions(+), 3 deletions(-) create mode 100644 frontend/test/lit/test_operator.py diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index f3b0366f57..bafceb02b0 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -25,6 +25,7 @@ import jax import pennylane as qp from jax.extend.core import ClosedJaxpr, Jaxpr +from pennylane.core.operator.operator2 import operator_p from pennylane.capture import PlxprInterpreter, qnode_prim from pennylane.capture.primitives import transform_prim from pennylane.transforms import decompose as pl_decompose @@ -407,6 +408,9 @@ def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwar # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) return next_eval.eval(inner_jaxpr, consts, *non_const_args) +@WorkflowInterpreter.register_primitive(operator_p) +def _error_on_operator(self, *args, op_cls, **kwargs): + raise ValueError(f"Operator {op_cls} must occur inside a qnode.") # pylint: disable=too-many-arguments @WorkflowInterpreter.register_primitive(transform_prim) diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index c7e3cb45a2..918139323c 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -295,7 +295,8 @@ def __call__(self, jaxpr, *args): @PLxPRToQuantumJaxprInterpreter.register_primitive(operator_p) def _handle_operator(self, *args, op_cls, hybrid_lens, hybrid_trees, **kwargs): - if hybrid_lens or hybrid_trees: + if hybrid_lens or hybrid_trees or op_cls.static_argnames: + # only support compilable_argnames for the moment raise NotImplementedError wire_inputs = args[len(op_cls.dynamic_argnames) :] diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index 4dd957288a..315003a187 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -30,6 +30,7 @@ from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim from pennylane.wires import AbstractQubit +from pennylane.pytrees import unflatten # TODO: remove after jax v0.7.2 upgrade # Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between @@ -38,6 +39,7 @@ # once JAX updates to a compatible MLIR version # pylint: disable=ungrouped-imports from catalyst.jax_extras.patches import mock_attributes +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval from catalyst.jax_primitives import ( AbstractObs, _named_obs_attribute, @@ -815,12 +817,26 @@ def _operator_op_lowering( hybrid_lens, hybrid_trees, wire_lens, - **static_args, + **static_data, ): params = args[: len(op_cls.dynamic_argnames)] qubits = args[len(op_cls.dynamic_argnames) :] - name_attr = ir.StringAttr.get(op_cls.__name__) + name_attr = get_mlir_attribute_from_pyval(op_cls.__name__) + + repack_static_data = {k: unflatten(*v) for k, v in static_data.items()} + processed_static_data = get_mlir_attribute_from_pyval(repack_static_data) + + param_map = {name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames)} + processed_param_map = get_mlir_attribute_from_pyval(param_map) + + qubit_map = {} + ind = 0 + for name, size in zip(op_cls.wire_argnames, wire_lens): + qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind+size))) + ind += size + + processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map) OperatorOp( op_name=name_attr, @@ -832,6 +848,9 @@ def _operator_op_lowering( adjoint=False, UID=None, arr_qubit_indices=[], + param_map=processed_param_map, + static_data=processed_static_data, + qubit_map=processed_qubit_map, ) return [] diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index d372890229..3812dc97a0 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -202,6 +202,9 @@ def get_mlir_attribute_from_pyval(value): attr = None match value: + case ir.Attribute(): + attr = value + case bool(): attr = ir.BoolAttr.get(value) diff --git a/frontend/test/lit/test_operator.py b/frontend/test/lit/test_operator.py new file mode 100644 index 0000000000..9260ddcb21 --- /dev/null +++ b/frontend/test/lit/test_operator.py @@ -0,0 +1,146 @@ +# Copyright 2022-2023 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for operator in Catalyst.""" + + +# RUN: %PYTHON %s | FileCheck %s + +import numpy as np +import pennylane as qp + +class NoParams(qp.core.Operator2): + + def __init__(self, wires): + super().__init__(wires=wires) + + +@qp.qjit(target="mlir", capture=True) +@qp.qnode(qp.device('null.qubit', wires=2)) +def c_no_params(): + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + # CHECK: qref.operator "NoParams"() qubits([[q0]]) + # CHECK static_data = {} + # CHECK param_map = {} qubit_map = {wires = [0]} + NoParams(wires=0) + + # CHECK: [[q1:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: qref.operator "NoParams"() qubits([[q1]], [[q2]]) + # CHECK: static_data = {} + # CHECK: param_map = {} qubit_map = {wires = [0, 1]} + NoParams(wires=(0,1)) + return qp.state() + +print(c_no_params.mlir) + +class SingleParam(qp.core.Operator2): + + dynamic_argnames = ("x", ) + + def __init__(self, x, wires): + super().__init__(x, wires=wires) + +@qp.qjit(target="mlir", capture=True) +@qp.qnode(qp.device('null.qubit', wires=3)) +def c_single_param(x : float): + + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + + # CHECK: qref.operator "SingleParam"({{%.+}}: tensor) qubits([[q0]]) + # CHECK: static_data = {} + # CHECK: param_map = {x = [0]} qubit_map = {wires = [0]} + SingleParam(x, 0) + + # CHECK: [[q1:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: qref.operator "SingleParam"({{%.+}}: tensor<4x4xf64>) qubits([[q1]], [[q2]]) + # CHECK: static_data = {} + # CHECK: param_map = {x = [0]} qubit_map = {wires = [0, 1]} + SingleParam(np.eye(4), (1,2)) + + return qp.state() + +print(c_single_param.mlir) + + +class CompilableData(qp.core.Operator2): + + compilable_argnames = ("a", "b", "thing") + + def __init__(self, a, b, thing, wires): + super().__init__(a=a, b=b, thing=thing, wires=wires) + +@qp.qjit(capture=True, target="mlir") +@qp.qnode(qp.device('null.qubit', wires=3)) +def c_compilable(): + # CHECK: [[q1:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: qref.operator "CompilableData"() qubits([[q1]], [[q2]]) + # CHECK: static_data = {a = true, b = "some string", thing = [1, true, "string"]} + # CHECK: param_map = {} qubit_map = {wires = [0, 1]} + + CompilableData(True, "some string", (1, True, "string"), wires=(0,1)) + + return qp.state() + + +print(c_compilable.mlir) + + +class MultipleRegisters(qp.core.Operator2): + + wire_argnames = ("reg1", "reg2") + + def __init__(self, reg1, reg2): + super().__init__(reg1=reg1, reg2=reg2) + +@qp.qjit(capture=True, target="mlir") +@qp.qnode(qp.device('null.qubit', wires=5)) +def c_multiple_registers(): + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: [[q3:%.+]] = qref.get {{%.+}} + # CHECK: [[q4:%.+]] = qref.get {{%.+}} + + # CHECK: qref.operator "MultipleRegisters"() qubits([[q0]], [[q2]], [[q3]], [[q4]]) + # CHECK: static_data = {} + # CHECK: param_map = {} qubit_map = {reg1 = [0], reg2 = [1, 2, 3]} + MultipleRegisters(0, (2,3,4)) + return qp.state() + + +print(c_multiple_registers.mlir) + + +class MultiParams(qp.core.Operator2): + + dynamic_argnames = ("a", "b", "c") + + # note also having non-standard order with dynamic inputs after wires + def __init__(self, wires, a, b, c): + super().__init__(wires, a, b, c) + + +@qp.qjit(capture=True, target="mlir") +@qp.qnode(qp.device('null.qubit', wires=1)) +def c_multi_params(): + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + + # CHECK: qref.operator "MultiParams"({{%.+}}: tensor, {{%.+}}: tensor<4x2x1xf64>, {{%.+}}: tensor<3xi64>) qubits([[q0]]) + # CHECK: static_data = {} + # CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {wires = [0]} + MultiParams(0, 0.5, c=np.array([1, 2 ,3]), b=np.zeros((4,2,1))) + return qp.state() + +print(c_multi_params.mlir) \ No newline at end of file From 98c88c4216a24dfddfc56bb9e340c1cd1713b50e Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 25 Jun 2026 10:28:02 -0400 Subject: [PATCH 09/16] tests, changelog, formatting --- doc/releases/changelog-dev.md | 4 ++ frontend/catalyst/from_plxpr/from_plxpr.py | 4 +- .../catalyst/from_plxpr/qfunc_interpreter.py | 4 +- .../from_plxpr/qref_jax_primitives.py | 11 +-- frontend/test/lit/test_operator.py | 39 +++++----- frontend/test/pytest/test_operator.py | 72 +++++++++++++++++++ .../Transforms/value_semantics_conversion.cpp | 1 - .../reference_semantics_conversion.cpp | 1 - 8 files changed, 111 insertions(+), 25 deletions(-) create mode 100644 frontend/test/pytest/test_operator.py diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index c990c2ad88..f5b0045e82 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -5,6 +5,10 @@

Improvements 🛠

+* The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators + without non-lowerable arguments. + [(#2969)](https://github.com/PennyLaneAI/catalyst/pull/2969/) + * The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_` with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_` with a stable diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index bafceb02b0..86bdc8ac88 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -25,9 +25,9 @@ import jax import pennylane as qp from jax.extend.core import ClosedJaxpr, Jaxpr -from pennylane.core.operator.operator2 import operator_p from pennylane.capture import PlxprInterpreter, qnode_prim from pennylane.capture.primitives import transform_prim +from pennylane.core.operator.operator2 import operator_p from pennylane.transforms import decompose as pl_decompose from catalyst.device import extract_backend_info @@ -408,10 +408,12 @@ def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwar # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) return next_eval.eval(inner_jaxpr, consts, *non_const_args) + @WorkflowInterpreter.register_primitive(operator_p) def _error_on_operator(self, *args, op_cls, **kwargs): raise ValueError(f"Operator {op_cls} must occur inside a qnode.") + # pylint: disable=too-many-arguments @WorkflowInterpreter.register_primitive(transform_prim) def handle_transform( diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index 918139323c..1f6b4af815 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -25,12 +25,12 @@ import pennylane as qp from jax._src.sharding_impls import UNSPECIFIED from pennylane.capture import PlxprInterpreter, pause -from pennylane.core.operator.operator2 import operator_p from pennylane.capture.primitives import cond_prim as pl_cond_prim from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim from pennylane.capture.primitives import measure_prim as plxpr_measure_prim from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim from pennylane.capture.primitives import quantum_subroutine_prim, transform_prim +from pennylane.core.operator.operator2 import operator_p from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim from pennylane.measurements import CountsMP from pennylane.wires import AbstractQubit, is_abstract_qubit @@ -46,13 +46,13 @@ qref_measure_in_basis_p, qref_measure_p, qref_namedobs_p, + qref_operator_op, qref_pauli_measure_p, qref_pauli_rot_p, qref_qinst_p, qref_set_basis_state_p, qref_set_state_p, qref_unitary_p, - qref_operator_op, ) from catalyst.jax_primitives import ( counts_p, diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index 315003a187..6535a85cfa 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -29,8 +29,10 @@ ) from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim -from pennylane.wires import AbstractQubit from pennylane.pytrees import unflatten +from pennylane.wires import AbstractQubit + +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval # TODO: remove after jax v0.7.2 upgrade # Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between @@ -39,7 +41,6 @@ # once JAX updates to a compatible MLIR version # pylint: disable=ungrouped-imports from catalyst.jax_extras.patches import mock_attributes -from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval from catalyst.jax_primitives import ( AbstractObs, _named_obs_attribute, @@ -827,13 +828,15 @@ def _operator_op_lowering( repack_static_data = {k: unflatten(*v) for k, v in static_data.items()} processed_static_data = get_mlir_attribute_from_pyval(repack_static_data) - param_map = {name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames)} + param_map = { + name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames) + } processed_param_map = get_mlir_attribute_from_pyval(param_map) qubit_map = {} ind = 0 for name, size in zip(op_cls.wire_argnames, wire_lens): - qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind+size))) + qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind + size))) ind += size processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map) diff --git a/frontend/test/lit/test_operator.py b/frontend/test/lit/test_operator.py index 9260ddcb21..0ae5fc2d25 100644 --- a/frontend/test/lit/test_operator.py +++ b/frontend/test/lit/test_operator.py @@ -13,12 +13,12 @@ # limitations under the License. """Tests for operator in Catalyst.""" - # RUN: %PYTHON %s | FileCheck %s import numpy as np import pennylane as qp + class NoParams(qp.core.Operator2): def __init__(self, wires): @@ -26,10 +26,10 @@ def __init__(self, wires): @qp.qjit(target="mlir", capture=True) -@qp.qnode(qp.device('null.qubit', wires=2)) +@qp.qnode(qp.device("null.qubit", wires=2)) def c_no_params(): # CHECK: [[q0:%.+]] = qref.get {{%.+}} - # CHECK: qref.operator "NoParams"() qubits([[q0]]) + # CHECK: qref.operator "NoParams"() qubits([[q0]]) # CHECK static_data = {} # CHECK param_map = {} qubit_map = {wires = [0]} NoParams(wires=0) @@ -39,21 +39,24 @@ def c_no_params(): # CHECK: qref.operator "NoParams"() qubits([[q1]], [[q2]]) # CHECK: static_data = {} # CHECK: param_map = {} qubit_map = {wires = [0, 1]} - NoParams(wires=(0,1)) + NoParams(wires=(0, 1)) return qp.state() + print(c_no_params.mlir) + class SingleParam(qp.core.Operator2): - dynamic_argnames = ("x", ) + dynamic_argnames = ("x",) def __init__(self, x, wires): super().__init__(x, wires=wires) + @qp.qjit(target="mlir", capture=True) -@qp.qnode(qp.device('null.qubit', wires=3)) -def c_single_param(x : float): +@qp.qnode(qp.device("null.qubit", wires=3)) +def c_single_param(x: float): # CHECK: [[q0:%.+]] = qref.get {{%.+}} @@ -67,10 +70,11 @@ def c_single_param(x : float): # CHECK: qref.operator "SingleParam"({{%.+}}: tensor<4x4xf64>) qubits([[q1]], [[q2]]) # CHECK: static_data = {} # CHECK: param_map = {x = [0]} qubit_map = {wires = [0, 1]} - SingleParam(np.eye(4), (1,2)) + SingleParam(np.eye(4), (1, 2)) return qp.state() + print(c_single_param.mlir) @@ -81,8 +85,9 @@ class CompilableData(qp.core.Operator2): def __init__(self, a, b, thing, wires): super().__init__(a=a, b=b, thing=thing, wires=wires) + @qp.qjit(capture=True, target="mlir") -@qp.qnode(qp.device('null.qubit', wires=3)) +@qp.qnode(qp.device("null.qubit", wires=3)) def c_compilable(): # CHECK: [[q1:%.+]] = qref.get {{%.+}} # CHECK: [[q2:%.+]] = qref.get {{%.+}} @@ -90,7 +95,7 @@ def c_compilable(): # CHECK: static_data = {a = true, b = "some string", thing = [1, true, "string"]} # CHECK: param_map = {} qubit_map = {wires = [0, 1]} - CompilableData(True, "some string", (1, True, "string"), wires=(0,1)) + CompilableData(True, "some string", (1, True, "string"), wires=(0, 1)) return qp.state() @@ -105,8 +110,9 @@ class MultipleRegisters(qp.core.Operator2): def __init__(self, reg1, reg2): super().__init__(reg1=reg1, reg2=reg2) + @qp.qjit(capture=True, target="mlir") -@qp.qnode(qp.device('null.qubit', wires=5)) +@qp.qnode(qp.device("null.qubit", wires=5)) def c_multiple_registers(): # CHECK: [[q0:%.+]] = qref.get {{%.+}} # CHECK: [[q2:%.+]] = qref.get {{%.+}} @@ -116,7 +122,7 @@ def c_multiple_registers(): # CHECK: qref.operator "MultipleRegisters"() qubits([[q0]], [[q2]], [[q3]], [[q4]]) # CHECK: static_data = {} # CHECK: param_map = {} qubit_map = {reg1 = [0], reg2 = [1, 2, 3]} - MultipleRegisters(0, (2,3,4)) + MultipleRegisters(0, (2, 3, 4)) return qp.state() @@ -130,17 +136,18 @@ class MultiParams(qp.core.Operator2): # note also having non-standard order with dynamic inputs after wires def __init__(self, wires, a, b, c): super().__init__(wires, a, b, c) - + @qp.qjit(capture=True, target="mlir") -@qp.qnode(qp.device('null.qubit', wires=1)) +@qp.qnode(qp.device("null.qubit", wires=1)) def c_multi_params(): # CHECK: [[q0:%.+]] = qref.get {{%.+}} # CHECK: qref.operator "MultiParams"({{%.+}}: tensor, {{%.+}}: tensor<4x2x1xf64>, {{%.+}}: tensor<3xi64>) qubits([[q0]]) # CHECK: static_data = {} # CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {wires = [0]} - MultiParams(0, 0.5, c=np.array([1, 2 ,3]), b=np.zeros((4,2,1))) + MultiParams(0, 0.5, c=np.array([1, 2, 3]), b=np.zeros((4, 2, 1))) return qp.state() -print(c_multi_params.mlir) \ No newline at end of file + +print(c_multi_params.mlir) diff --git a/frontend/test/pytest/test_operator.py b/frontend/test/pytest/test_operator.py new file mode 100644 index 0000000000..0ab4ac9668 --- /dev/null +++ b/frontend/test/pytest/test_operator.py @@ -0,0 +1,72 @@ +# Copyright 2022-2026 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pennylane as qp +import pytest + + +class DummyOp(qp.core.Operator2): + + def __init__(self, wires): + super().__init__(wires=wires) + + +def test_error_on_operator_outside_qnode(): + """Test that an error is raised for an Operator2 outside the qnode.""" + + with pytest.raises(ValueError, match="must occur inside a qnode."): + + @qp.qjit(capture=True, target="jaxpr") + def f(): + DummyOp(2) + return 2 + + +def test_hybrid_not_supported_yet(): + """Test that hybrid arguments are not yet supported.""" + + class OperatorArgument(qp.core.Operator2): + + hybrid_argnames = ("op",) + wire_argnames = () + + def __init__(self, op): + super().__init__(op) + + with pytest.raises(NotImplementedError): + + @qp.qjit(capture=True) + @qp.qnode(qp.device("null.qubit", wires=3)) + def c(): + OperatorArgument(DummyOp(0)) + return qp.state() + + +def test_static_argnames(): + """Test that static arguments are not yet supported.""" + + class StaticArgsOp(qp.core.Operator2): + + static_argnames = ("thing",) + + def __init__(self, thing, wires): + super().__init__(thing, wires) + + with pytest.raises(NotImplementedError): + + @qp.qjit(capture=True) + @qp.qnode(qp.device("null.qubit", wires=2)) + def c(): + StaticArgsOp("hello", 0) + return qp.state() diff --git a/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp b/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp index d3475f46c2..689af5bb9c 100644 --- a/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp +++ b/mlir/lib/QRef/Transforms/value_semantics_conversion.cpp @@ -1159,7 +1159,6 @@ void handleGate(IRRewriter &builder, qref::QuantumOperation rGateOp, QubitValueT builder.getDenseI32ArrayAttr({nTargets, nCtrls, nQreg})); // Properties are not handled via the generic attribute fields, so we set them separately. - vGateOp.setOpName(rOperatorOp.getOpName()); vGateOp.setAdjoint(rOperatorOp.getAdjoint()); vGateOp.setUID(rOperatorOp.getUID()); } diff --git a/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp b/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp index 3dbbf0baf1..5c35f70782 100644 --- a/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp +++ b/mlir/lib/Quantum/Transforms/reference_semantics_conversion.cpp @@ -368,7 +368,6 @@ void handleGate(IRRewriter &builder, quantum::QuantumOperation vGateOp, QubitVal rGateOp->removeAttr("resultSegmentSizes"); // Properties are not handled via the generic attribute fields, so we set them separately. - rGateOp.setOpName(vOperatorOp.getOpName()); rGateOp.setAdjoint(vOperatorOp.getAdjoint()); rGateOp.setUID(vOperatorOp.getUID()); } From 862375eddba2fd28fb2eaec3baa8cf7e7a3c2f0c Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 25 Jun 2026 10:29:48 -0400 Subject: [PATCH 10/16] pylint --- frontend/test/pytest/test_operator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_operator.py b/frontend/test/pytest/test_operator.py index 0ab4ac9668..c637bb0f94 100644 --- a/frontend/test/pytest/test_operator.py +++ b/frontend/test/pytest/test_operator.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" +Tests for the new Operator2 class. +""" +# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring import pennylane as qp import pytest From 6e4eaf761c9dd6b5d92dc3d0e30e93039b68c13d Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 25 Jun 2026 10:35:01 -0400 Subject: [PATCH 11/16] why didn't black work --- frontend/test/pytest/test_operator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_operator.py b/frontend/test/pytest/test_operator.py index c637bb0f94..557df7658d 100644 --- a/frontend/test/pytest/test_operator.py +++ b/frontend/test/pytest/test_operator.py @@ -14,7 +14,8 @@ """ Tests for the new Operator2 class. """ -# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring + +# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring import pennylane as qp import pytest From 78b4f9ee59c4316ea9923dfe23d7017175e8f535 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 25 Jun 2026 13:52:05 -0400 Subject: [PATCH 12/16] more pylint --- frontend/test/lit/test_operator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/frontend/test/lit/test_operator.py b/frontend/test/lit/test_operator.py index 0ae5fc2d25..abd2beb55d 100644 --- a/frontend/test/lit/test_operator.py +++ b/frontend/test/lit/test_operator.py @@ -13,6 +13,8 @@ # limitations under the License. """Tests for operator in Catalyst.""" +# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring + # RUN: %PYTHON %s | FileCheck %s import numpy as np @@ -143,6 +145,7 @@ def __init__(self, wires, a, b, c): def c_multi_params(): # CHECK: [[q0:%.+]] = qref.get {{%.+}} + # pylint: disable=line-too-long # CHECK: qref.operator "MultiParams"({{%.+}}: tensor, {{%.+}}: tensor<4x2x1xf64>, {{%.+}}: tensor<3xi64>) qubits([[q0]]) # CHECK: static_data = {} # CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {wires = [0]} From 99ec30bf5d33593657fe67b978fef3c603511639 Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Fri, 26 Jun 2026 13:52:16 -0400 Subject: [PATCH 13/16] Update frontend/test/lit/test_operator.py Co-authored-by: River McCubbin --- frontend/test/lit/test_operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/lit/test_operator.py b/frontend/test/lit/test_operator.py index abd2beb55d..12bafa1f1a 100644 --- a/frontend/test/lit/test_operator.py +++ b/frontend/test/lit/test_operator.py @@ -32,8 +32,8 @@ def __init__(self, wires): def c_no_params(): # CHECK: [[q0:%.+]] = qref.get {{%.+}} # CHECK: qref.operator "NoParams"() qubits([[q0]]) - # CHECK static_data = {} - # CHECK param_map = {} qubit_map = {wires = [0]} + # CHECK: static_data = {} + # CHECK: param_map = {} qubit_map = {wires = [0]} NoParams(wires=0) # CHECK: [[q1:%.+]] = qref.get {{%.+}} From 1fafe4082e389e8971cae4e5dae669ba25f4ed67 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Mon, 29 Jun 2026 13:55:22 -0400 Subject: [PATCH 14/16] rename to qref_operator_p --- frontend/catalyst/from_plxpr/qfunc_interpreter.py | 4 ++-- frontend/catalyst/from_plxpr/qref_jax_primitives.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index 1f6b4af815..0d51ee2705 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -46,7 +46,7 @@ qref_measure_in_basis_p, qref_measure_p, qref_namedobs_p, - qref_operator_op, + qref_operator_p, qref_pauli_measure_p, qref_pauli_rot_p, qref_qinst_p, @@ -304,7 +304,7 @@ def _handle_operator(self, *args, op_cls, hybrid_lens, hybrid_trees, **kwargs): w if is_abstract_qubit(w) else qref_get_p.bind(self.init_qreg, w) for w in wire_inputs ] - qref_operator_op.bind( + qref_operator_p.bind( *args[: len(op_cls.dynamic_argnames)], *new_wires, op_cls=op_cls, diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index 6535a85cfa..c756d0a291 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -168,7 +168,7 @@ class MeasurementPlane(Enum): qref_compbasis_p = Primitive("qref_compbasis") qref_namedobs_p = Primitive("qref_namedobs") qref_hermitian_p = Primitive("qref_hermitian") -qref_operator_op = Primitive("qref_operator") +qref_operator_p = Primitive("qref_operator") # @@ -803,15 +803,15 @@ def _qref_named_obs_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value, return NamedObsOp(result_type, qubit, obsId).results -qref_operator_op.multiple_results = True +qref_operator_p.multiple_results = True -@qref_operator_op.def_abstract_eval -def _qref_operator_op_abstract_eval(*args, **kwargs): +@qref_operator_p.def_abstract_eval +def _qref_operator_p_abstract_eval(*args, **kwargs): return [] -def _operator_op_lowering( +def _qref_operator_p_lowering( jax_ctx: mlir.LoweringRuleContext, *args, op_cls, @@ -881,7 +881,7 @@ def _qref_hermitian_lowering(jax_ctx: mlir.LoweringRuleContext, matrix: ir.Value CUSTOM_LOWERING_RULES = ( - (qref_operator_op, _operator_op_lowering), + (qref_operator_p, _qref_operator_p_lowering), (qref_alloc_p, _qref_alloc_lowering), (qref_dealloc_p, _qref_dealloc_lowering), (qref_get_p, _qref_get_lowering), From f2e0b73f11c7e1f4f15ac607549719b37b62710a Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 30 Jun 2026 10:25:38 -0400 Subject: [PATCH 15/16] responding to feedback --- frontend/catalyst/from_plxpr/from_plxpr.py | 6 ------ frontend/catalyst/from_plxpr/qref_jax_primitives.py | 1 + frontend/test/pytest/test_operator.py | 11 ----------- 3 files changed, 1 insertion(+), 17 deletions(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 86bdc8ac88..5df064e7f3 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -408,12 +408,6 @@ def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwar # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) return next_eval.eval(inner_jaxpr, consts, *non_const_args) - -@WorkflowInterpreter.register_primitive(operator_p) -def _error_on_operator(self, *args, op_cls, **kwargs): - raise ValueError(f"Operator {op_cls} must occur inside a qnode.") - - # pylint: disable=too-many-arguments @WorkflowInterpreter.register_primitive(transform_prim) def handle_transform( diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index c756d0a291..c47cdd1d99 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -845,6 +845,7 @@ def _qref_operator_p_lowering( op_name=name_attr, params=params, qubits=qubits, + qreg=None, forward_args=[], ctrl_qubits=[], ctrl_values=[], diff --git a/frontend/test/pytest/test_operator.py b/frontend/test/pytest/test_operator.py index 557df7658d..6558f5972d 100644 --- a/frontend/test/pytest/test_operator.py +++ b/frontend/test/pytest/test_operator.py @@ -26,17 +26,6 @@ def __init__(self, wires): super().__init__(wires=wires) -def test_error_on_operator_outside_qnode(): - """Test that an error is raised for an Operator2 outside the qnode.""" - - with pytest.raises(ValueError, match="must occur inside a qnode."): - - @qp.qjit(capture=True, target="jaxpr") - def f(): - DummyOp(2) - return 2 - - def test_hybrid_not_supported_yet(): """Test that hybrid arguments are not yet supported.""" From dd34ff668503959338fd3424e5b8010ece4df9e6 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 30 Jun 2026 10:34:57 -0400 Subject: [PATCH 16/16] remove unused import --- frontend/catalyst/from_plxpr/from_plxpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 5df064e7f3..f3b0366f57 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -27,7 +27,6 @@ from jax.extend.core import ClosedJaxpr, Jaxpr from pennylane.capture import PlxprInterpreter, qnode_prim from pennylane.capture.primitives import transform_prim -from pennylane.core.operator.operator2 import operator_p from pennylane.transforms import decompose as pl_decompose from catalyst.device import extract_backend_info @@ -408,6 +407,7 @@ def _handle_decompose_transform(self, inner_jaxpr, consts, non_const_args, tkwar # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args) return next_eval.eval(inner_jaxpr, consts, *non_const_args) + # pylint: disable=too-many-arguments @WorkflowInterpreter.register_primitive(transform_prim) def handle_transform(