diff --git a/frontend/catalyst/from_plxpr/qref_jax_primitives.py b/frontend/catalyst/from_plxpr/qref_jax_primitives.py index c47cdd1d99..085c4d9bdb 100644 --- a/frontend/catalyst/from_plxpr/qref_jax_primitives.py +++ b/frontend/catalyst/from_plxpr/qref_jax_primitives.py @@ -29,7 +29,6 @@ ) from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim -from pennylane.pytrees import unflatten from pennylane.wires import AbstractQubit from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval @@ -50,6 +49,8 @@ from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp from catalyst.utils.patching import Patcher +from .qref_operator2_primitives import qref_operator_p, _qref_operator_p_lowering + with Patcher( ( _ods_cext, @@ -75,7 +76,6 @@ MeasureOp, MultiRZOp, NamedObsOp, - OperatorOp, PauliRotOp, PCPhaseOp, QubitUnitaryOp, @@ -168,7 +168,7 @@ class MeasurementPlane(Enum): qref_compbasis_p = Primitive("qref_compbasis") qref_namedobs_p = Primitive("qref_namedobs") qref_hermitian_p = Primitive("qref_hermitian") -qref_operator_p = Primitive("qref_operator") + # @@ -803,61 +803,6 @@ def _qref_named_obs_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value, return NamedObsOp(result_type, qubit, obsId).results -qref_operator_p.multiple_results = True - - -@qref_operator_p.def_abstract_eval -def _qref_operator_p_abstract_eval(*args, **kwargs): - return [] - - -def _qref_operator_p_lowering( - jax_ctx: mlir.LoweringRuleContext, - *args, - op_cls, - hybrid_lens, - hybrid_trees, - wire_lens, - **static_data, -): - params = args[: len(op_cls.dynamic_argnames)] - qubits = args[len(op_cls.dynamic_argnames) :] - - name_attr = get_mlir_attribute_from_pyval(op_cls.__name__) - - repack_static_data = {k: unflatten(*v) for k, v in static_data.items()} - processed_static_data = get_mlir_attribute_from_pyval(repack_static_data) - - param_map = { - name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames) - } - processed_param_map = get_mlir_attribute_from_pyval(param_map) - - qubit_map = {} - ind = 0 - for name, size in zip(op_cls.wire_argnames, wire_lens): - qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind + size))) - ind += size - - processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map) - - OperatorOp( - op_name=name_attr, - params=params, - qubits=qubits, - qreg=None, - forward_args=[], - ctrl_qubits=[], - ctrl_values=[], - adjoint=False, - UID=None, - arr_qubit_indices=[], - param_map=processed_param_map, - static_data=processed_static_data, - qubit_map=processed_qubit_map, - ) - return [] - # # hermitian observable diff --git a/frontend/catalyst/from_plxpr/qref_operator2_primitives.py b/frontend/catalyst/from_plxpr/qref_operator2_primitives.py new file mode 100644 index 0000000000..7f554f2ed4 --- /dev/null +++ b/frontend/catalyst/from_plxpr/qref_operator2_primitives.py @@ -0,0 +1,319 @@ +# Copyright 2026 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This module contains JAX-compatible quantum primitives to support the lowering +of quantum operations, measurements, and observables to reference semantics JAXPR. +""" +# pylint: disable=unused-argument +from jax._src.lib.mlir import ir +from jax.extend.core import Primitive +from jax.interpreters import mlir +from jaxlib.mlir._mlir_libs import _mlir as _ods_cext +from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp +from pennylane.pytrees import unflatten + +from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval + +# TODO: remove after jax v0.7.2 upgrade +# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between +# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not +# yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed +# once JAX updates to a compatible MLIR version +# pylint: disable=ungrouped-imports +from catalyst.jax_extras.patches import mock_attributes +from catalyst.jax_primitives import ( + extract_scalar, + safe_cast_to_f64, +) +from catalyst.utils.extra_bindings import TensorExtractOp +from catalyst.utils.patching import Patcher + +with Patcher( + ( + _ods_cext, + "globals", + mock_attributes( + # pylint: disable=c-extension-no-member + _ods_cext.globals, + {"register_traceback_file_exclusion": lambda x: None}, + ), + ), +): + from mlir_quantum.dialects.qref import ( + CustomOp, + GlobalPhaseOp, + MultiRZOp, + OperatorOp, + PauliRotOp, + PCPhaseOp, + QubitUnitaryOp, + ) + + +_SPECIAL_LOWERINGS = {} + + +def _register_special_lowering(op_name): + def decorator(f): + _SPECIAL_LOWERINGS[op_name] = f + return f + + return decorator + +qref_operator_p = Primitive("qref_operator") +qref_operator_p.multiple_results = True + + +@qref_operator_p.def_abstract_eval +def _qref_operator_p_abstract_eval(*args, **kwargs): + return [] + + +def _is_custom_op(op_cls, params): + if op_cls.static_argnames or op_cls.hybrid_argnames or op_cls.compilable_argnames: + return False + if op_cls.wire_argnames != ("wires",): + return False + return all(p.shape == () and "float" in p.dtype.name for p in params) + + +def _qref_operator_p_lowering( + jax_ctx: mlir.LoweringRuleContext, + *args, + op_cls, + **kwargs, +): + ctx = jax_ctx.module_context.context + ctx.allow_unregistered_dialects = True + if op_cls.__name__ in _SPECIAL_LOWERINGS: + return _SPECIAL_LOWERINGS[op_cls.__name__]( + jax_ctx, + *args, + op_cls=op_cls, + **kwargs + ) + hybrid_lens = kwargs.pop("hybrid_lens") + hybrid_trees = kwargs.pop("hybrid_trees") + wire_lens = kwargs.pop("wire_lens") + params = args[: len(op_cls.dynamic_argnames)] + qubits = args[len(op_cls.dynamic_argnames) :] + + name_attr = get_mlir_attribute_from_pyval(op_cls.__name__) + + repack_static_data = {k: unflatten(*v) for k, v in kwargs.items()} + processed_static_data = get_mlir_attribute_from_pyval(repack_static_data) + + param_map = { + name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames) + } + processed_param_map = get_mlir_attribute_from_pyval(param_map) + + qubit_map = {} + ind = 0 + for name, size in zip(op_cls.wire_argnames, wire_lens): + qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind + size))) + ind += size + + processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map) + + ctrl_qubits = [] + ctrl_values = [] + adjoint = False + + if _is_custom_op(op_cls, params): + params = [extract_scalar(safe_cast_to_f64(p, op_cls), op_cls) for p in params] + CustomOp( + params=params, + qubits=qubits, + gate_name=name_attr, + ctrl_qubits=ctrl_qubits, + ctrl_values=ctrl_values, + adjoint=adjoint, + ) + else: + OperatorOp( + op_name=name_attr, + params=params, + qubits=qubits, + qreg=None, + forward_args=[], + ctrl_qubits=[], + ctrl_values=[], + adjoint=False, + UID=None, + arr_qubit_indices=[], + param_map=processed_param_map, + static_data=processed_static_data, + qubit_map=processed_qubit_map, + ) + return [] + + + + +@_register_special_lowering("MultiRZ") +def _multirz_lowering( + jax_ctx: mlir.LoweringRuleContext, + *args, + op_cls, + hybrid_lens, + hybrid_trees, + wire_lens +): + theta = (extract_scalar(safe_cast_to_f64(args[0], "MultiRZ"), "MultiRZ"),) + qubits = args[1:] + MultiRZOp( + theta=theta, + qubits=qubits, + ctrl_qubits=[], + ctrl_values=[], + adjoint=False, + ) + return [] + + +@_register_special_lowering("PCPhase") +def _pcphase_lowering( + jax_ctx: mlir.LoweringRuleContext, + *args, + op_cls, + hybrid_lens, + hybrid_trees, + wire_lens, +): + qubits = args[2:] + PCPhaseOp( + theta=extract_scalar(safe_cast_to_f64(args[0], "PCPhase"), "PCPhase"), + dim=extract_scalar(safe_cast_to_f64(args[0], "PCPhase"), "PCPhase"), + qubits=qubits, + ctrl_qubits=[], + ctrl_values=[], + adjoint=False, + ) + return () + + +@_register_special_lowering("GlobalPhase") +def _special_gphase_lowering( + jax_ctx: mlir.LoweringRuleContext, + *args, + op_cls, + hybrid_lens, + hybrid_trees, + wire_lens, +): + GlobalPhaseOp( + angle=extract_scalar(safe_cast_to_f64(args[0], "GlobalPhase"), "GlobalPhase"), + ctrl_qubits=[], + ctrl_values=[], + adjoint=False, + ) + return () + + +@_register_special_lowering("QubitUnitary") +def _special_unitary_lowering( + jax_ctx: mlir.LoweringRuleContext, + matrix, + *qubits, + op_cls, + hybrid_lens, + hybrid_trees, + wire_lens, +): + ctrl_qubits = [] + ctrl_values = [] + + for q in qubits: + assert ir.OpaqueType.isinstance(q.type) + assert ir.OpaqueType(q.type).dialect_namespace == "qref" + assert ir.OpaqueType(q.type).data == "bit" + + matrix_type = matrix.type + is_tensor = ir.RankedTensorType.isinstance(matrix_type) + shape = ir.RankedTensorType(matrix_type).shape if is_tensor else None + is_2d_tensor = len(shape) == 2 if is_tensor else False + if not is_2d_tensor: + raise TypeError("QubitUnitary must be a 2 dimensional tensor.") + + possibly_complex_type = ir.RankedTensorType(matrix_type).element_type + is_complex = ir.ComplexType.isinstance(possibly_complex_type) + is_f64_type = False + + if is_complex: + complex_type = ir.ComplexType(possibly_complex_type) + possibly_f64_type = complex_type.element_type + is_f64_type = ir.F64Type.isinstance(possibly_f64_type) + + is_complex_f64_type = is_complex and is_f64_type + if not is_complex_f64_type: + f64_type = ir.F64Type.get() + complex_f64_type = ir.ComplexType.get(f64_type) + tensor_complex_f64_type = ir.RankedTensorType.get(shape, complex_f64_type) + matrix = StableHLOConvertOp(tensor_complex_f64_type, matrix).result + + ctrl_values_i1 = [ + TensorExtractOp(ir.IntegerType.get_signless(1), v, []).result for v in ctrl_values + ] + + QubitUnitaryOp( + matrix=matrix, + qubits=qubits, + ctrl_qubits=ctrl_qubits, + ctrl_values=ctrl_values_i1, + adjoint=False, + ) + + return () + + +@_register_special_lowering("PauliRot") +def _special_paulirot_lowering( + jax_ctx: mlir.LoweringRuleContext, + angle, + *qubits, + op_cls, + hybrid_lens, + hybrid_trees, + wire_lens, + pauli_word, +): + pauli_word = unflatten(*pauli_word) + ctrl_qubits = [] + ctrl_values = [] + + for q in qubits: + assert ir.OpaqueType.isinstance(q.type) + assert ir.OpaqueType(q.type).dialect_namespace == "qref" + assert ir.OpaqueType(q.type).data == "bit" + + angle = safe_cast_to_f64(angle, "PauliRot") + angle = extract_scalar(angle, "PauliRot") + assert ir.F64Type.isinstance(angle.type) + + pauli_word = ir.ArrayAttr.get([ir.StringAttr.get(p) for p in pauli_word]) + + ctrl_values_i1 = [ + TensorExtractOp(ir.IntegerType.get_signless(1), v, []).result for v in ctrl_values + ] + + PauliRotOp( + angle=angle, + pauli_product=pauli_word, + qubits=qubits, + ctrl_qubits=ctrl_qubits, + ctrl_values=ctrl_values_i1, + adjoint=False, + ) + + return () diff --git a/frontend/test/lit/test_operator.py b/frontend/test/lit/test_operator.py index 12bafa1f1a..b2ce026243 100644 --- a/frontend/test/lit/test_operator.py +++ b/frontend/test/lit/test_operator.py @@ -23,8 +23,11 @@ class NoParams(qp.core.Operator2): - def __init__(self, wires): - super().__init__(wires=wires) + # have to use different wire argnames or will in up CustomOp + wire_argnames = ("reg",) + + def __init__(self, reg): + super().__init__(reg=reg) @qp.qjit(target="mlir", capture=True) @@ -33,27 +36,51 @@ def c_no_params(): # CHECK: [[q0:%.+]] = qref.get {{%.+}} # CHECK: qref.operator "NoParams"() qubits([[q0]]) # CHECK: static_data = {} - # CHECK: param_map = {} qubit_map = {wires = [0]} - NoParams(wires=0) + # CHECK: param_map = {} qubit_map = {reg = [0]} + NoParams(reg=0) # CHECK: [[q1:%.+]] = qref.get {{%.+}} # CHECK: [[q2:%.+]] = qref.get {{%.+}} # CHECK: qref.operator "NoParams"() qubits([[q1]], [[q2]]) # CHECK: static_data = {} - # CHECK: param_map = {} qubit_map = {wires = [0, 1]} - NoParams(wires=(0, 1)) + # CHECK: param_map = {} qubit_map = {reg = [0, 1]} + NoParams(reg=(0, 1)) return qp.state() print(c_no_params.mlir) +class NoParamsCustomOp(qp.core.Operator2): + + def __init__(self, wires): + super().__init__(wires=wires) + + +@qp.qjit(target="mlir", capture=True) +@qp.qnode(qp.device("null.qubit", wires=2)) +def c_no_params_custom(): + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + # CHECK: qref.custom "NoParamsCustomOp"() [[q0]] : !qref.bit + NoParamsCustomOp(wires=0) + + # CHECK: [[q1:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: qref.custom "NoParamsCustomOp"() [[q1]], [[q2]] : !qref.bit, !qref.bit + NoParamsCustomOp(wires=(0, 1)) + return qp.state() + + +print(c_no_params_custom.mlir) + + class SingleParam(qp.core.Operator2): dynamic_argnames = ("x",) + wire_argnames = ("reg",) - def __init__(self, x, wires): - super().__init__(x, wires=wires) + def __init__(self, x, reg): + super().__init__(x, reg=reg) @qp.qjit(target="mlir", capture=True) @@ -64,14 +91,14 @@ def c_single_param(x: float): # CHECK: qref.operator "SingleParam"({{%.+}}: tensor) qubits([[q0]]) # CHECK: static_data = {} - # CHECK: param_map = {x = [0]} qubit_map = {wires = [0]} + # CHECK: param_map = {x = [0]} qubit_map = {reg = [0]} SingleParam(x, 0) # CHECK: [[q1:%.+]] = qref.get {{%.+}} # CHECK: [[q2:%.+]] = qref.get {{%.+}} # CHECK: qref.operator "SingleParam"({{%.+}}: tensor<4x4xf64>) qubits([[q1]], [[q2]]) # CHECK: static_data = {} - # CHECK: param_map = {x = [0]} qubit_map = {wires = [0, 1]} + # CHECK: param_map = {x = [0]} qubit_map = {reg = [0, 1]} SingleParam(np.eye(4), (1, 2)) return qp.state() @@ -80,6 +107,33 @@ def c_single_param(x: float): print(c_single_param.mlir) +class SingleParamCustomOp(qp.core.Operator2): + + dynamic_argnames = ("x",) + + def __init__(self, x, wires): + super().__init__(x, wires=wires) + + +@qp.qjit(target="mlir", capture=True) +@qp.qnode(qp.device("null.qubit", wires=3)) +def c_single_param_custom(x: float): + + # CHECK: [[q0:%.+]] = qref.get {{%.+}} + # CHECK: qref.custom "SingleParamCustomOp"({{%.+}}: tensor) [[q0]] -> !qref.bit + SingleParamCustomOp(x, 0) + + # CHECK: [[q1:%.+]] = qref.get {{%.+}} + # CHECK: [[q2:%.+]] = qref.get {{%.+}} + # CHECK: qref.custom "SingleParamCustomOp"({{%.+}}: tensor) [[q1]], [[q2]]) -> !qref.bit !qref.bit + SingleParamCustomOp(0.5, (1, 2)) + + return qp.state() + + +print(c_single_param_custom.mlir) + + class CompilableData(qp.core.Operator2): compilable_argnames = ("a", "b", "thing") @@ -134,10 +188,11 @@ def c_multiple_registers(): class MultiParams(qp.core.Operator2): dynamic_argnames = ("a", "b", "c") + wire_argnames = ("reg", ) # note also having non-standard order with dynamic inputs after wires - def __init__(self, wires, a, b, c): - super().__init__(wires, a, b, c) + def __init__(self, reg, a, b, c): + super().__init__(reg, a, b, c) @qp.qjit(capture=True, target="mlir") @@ -148,7 +203,7 @@ def c_multi_params(): # pylint: disable=line-too-long # CHECK: qref.operator "MultiParams"({{%.+}}: tensor, {{%.+}}: tensor<4x2x1xf64>, {{%.+}}: tensor<3xi64>) qubits([[q0]]) # CHECK: static_data = {} - # CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {wires = [0]} + # CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {reg = [0]} MultiParams(0, 0.5, c=np.array([1, 2, 3]), b=np.zeros((4, 2, 1))) return qp.state()