Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ enzyme=v0.0.238

# For a custom PL version, update the package version here and at
# 'doc/requirements.txt'
pennylane=0.46.0.dev24
pennylane=0.46.0.dev38
Comment thread
mudit2812 marked this conversation as resolved.

# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
Expand Down
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

<h3>Improvements 🛠</h3>

* The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators
without non-lowerable arguments.
Comment thread
albi3ro marked this conversation as resolved.
[(#2969)](https://github.com/PennyLaneAI/catalyst/pull/2969/)

* The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry
instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_<N>`
with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_<N>` with a stable
Expand Down
2 changes: 1 addition & 1 deletion doc/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ lxml_html_clean
--extra-index-url https://test.pypi.org/simple/
pennylane-lightning-kokkos==0.46.0-dev10
pennylane-lightning==0.46.0-dev10
pennylane==0.46.0.dev24
pennylane==0.46.0.dev38
Comment thread
mudit2812 marked this conversation as resolved.
25 changes: 25 additions & 0 deletions frontend/catalyst/from_plxpr/qfunc_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pennylane.capture.primitives import measure_prim as plxpr_measure_prim
from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim
from pennylane.capture.primitives import quantum_subroutine_prim, transform_prim
from pennylane.core.operator.operator2 import operator_p
Comment thread
albi3ro marked this conversation as resolved.
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.measurements import CountsMP
from pennylane.wires import AbstractQubit, is_abstract_qubit
Expand All @@ -45,6 +46,7 @@
qref_measure_in_basis_p,
qref_measure_p,
qref_namedobs_p,
qref_operator_p,
qref_pauli_measure_p,
qref_pauli_rot_p,
qref_qinst_p,
Expand Down Expand Up @@ -290,6 +292,29 @@ def __call__(self, jaxpr, *args):
return self.eval(jaxpr.jaxpr, jaxpr.consts, *args)


@PLxPRToQuantumJaxprInterpreter.register_primitive(operator_p)
def _handle_operator(self, *args, op_cls, hybrid_lens, hybrid_trees, **kwargs):

if hybrid_lens or hybrid_trees or op_cls.static_argnames:
# only support compilable_argnames for the moment
raise NotImplementedError

wire_inputs = args[len(op_cls.dynamic_argnames) :]
new_wires = [
w if is_abstract_qubit(w) else qref_get_p.bind(self.init_qreg, w) for w in wire_inputs
]

qref_operator_p.bind(
*args[: len(op_cls.dynamic_argnames)],
*new_wires,
op_cls=op_cls,
hybrid_lens=hybrid_lens,
hybrid_trees=hybrid_trees,
**kwargs,
)
return []


# pylint: disable=unused-argument, too-many-arguments
def _qubit_unitary_bind_call(
*invals, op, qubits_len, params_len, ctrl_len, adjoint, hyperparameters
Expand Down
62 changes: 62 additions & 0 deletions frontend/catalyst/from_plxpr/qref_jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@
)
from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim
from pennylane.pytrees import unflatten
from pennylane.wires import AbstractQubit

from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval

# TODO: remove after jax v0.7.2 upgrade
# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between
# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not
Expand Down Expand Up @@ -72,6 +75,7 @@
MeasureOp,
MultiRZOp,
NamedObsOp,
OperatorOp,
PauliRotOp,
PCPhaseOp,
QubitUnitaryOp,
Expand Down Expand Up @@ -164,6 +168,7 @@ class MeasurementPlane(Enum):
qref_compbasis_p = Primitive("qref_compbasis")
qref_namedobs_p = Primitive("qref_namedobs")
qref_hermitian_p = Primitive("qref_hermitian")
qref_operator_p = Primitive("qref_operator")


#
Expand Down Expand Up @@ -798,6 +803,62 @@ def _qref_named_obs_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value,
return NamedObsOp(result_type, qubit, obsId).results


qref_operator_p.multiple_results = True


@qref_operator_p.def_abstract_eval
def _qref_operator_p_abstract_eval(*args, **kwargs):
return []


def _qref_operator_p_lowering(
jax_ctx: mlir.LoweringRuleContext,
*args,
op_cls,
hybrid_lens,
hybrid_trees,
wire_lens,
**static_data,
):
params = args[: len(op_cls.dynamic_argnames)]
qubits = args[len(op_cls.dynamic_argnames) :]

name_attr = get_mlir_attribute_from_pyval(op_cls.__name__)

repack_static_data = {k: unflatten(*v) for k, v in static_data.items()}
processed_static_data = get_mlir_attribute_from_pyval(repack_static_data)
Comment thread
albi3ro marked this conversation as resolved.

param_map = {
name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames)
Comment thread
dime10 marked this conversation as resolved.
}
processed_param_map = get_mlir_attribute_from_pyval(param_map)

qubit_map = {}
ind = 0
for name, size in zip(op_cls.wire_argnames, wire_lens):
qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind + size)))
ind += size

processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map)

OperatorOp(
op_name=name_attr,
params=params,
qubits=qubits,
qreg=None,
forward_args=[],
Comment thread
albi3ro marked this conversation as resolved.
ctrl_qubits=[],
ctrl_values=[],
adjoint=False,
UID=None,
arr_qubit_indices=[],
param_map=processed_param_map,
static_data=processed_static_data,
qubit_map=processed_qubit_map,
)
return []


#
# hermitian observable
#
Expand All @@ -821,6 +882,7 @@ def _qref_hermitian_lowering(jax_ctx: mlir.LoweringRuleContext, matrix: ir.Value


CUSTOM_LOWERING_RULES = (
(qref_operator_p, _qref_operator_p_lowering),
(qref_alloc_p, _qref_alloc_lowering),
(qref_dealloc_p, _qref_dealloc_lowering),
(qref_get_p, _qref_get_lowering),
Expand Down
3 changes: 3 additions & 0 deletions frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ def get_mlir_attribute_from_pyval(value):

attr = None
match value:
case ir.Attribute():
attr = value

case bool():
attr = ir.BoolAttr.get(value)

Expand Down
156 changes: 156 additions & 0 deletions frontend/test/lit/test_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright 2022-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for operator in Catalyst."""

# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring

# RUN: %PYTHON %s | FileCheck %s

import numpy as np
import pennylane as qp


class NoParams(qp.core.Operator2):

def __init__(self, wires):
super().__init__(wires=wires)


@qp.qjit(target="mlir", capture=True)
@qp.qnode(qp.device("null.qubit", wires=2))
def c_no_params():
# CHECK: [[q0:%.+]] = qref.get {{%.+}}
# CHECK: qref.operator "NoParams"() qubits([[q0]])
# CHECK: static_data = {}
# CHECK: param_map = {} qubit_map = {wires = [0]}
NoParams(wires=0)

# CHECK: [[q1:%.+]] = qref.get {{%.+}}
# CHECK: [[q2:%.+]] = qref.get {{%.+}}
# CHECK: qref.operator "NoParams"() qubits([[q1]], [[q2]])
# CHECK: static_data = {}
# CHECK: param_map = {} qubit_map = {wires = [0, 1]}
NoParams(wires=(0, 1))
return qp.state()


print(c_no_params.mlir)


class SingleParam(qp.core.Operator2):

dynamic_argnames = ("x",)

def __init__(self, x, wires):
super().__init__(x, wires=wires)


@qp.qjit(target="mlir", capture=True)
@qp.qnode(qp.device("null.qubit", wires=3))
def c_single_param(x: float):

# CHECK: [[q0:%.+]] = qref.get {{%.+}}

# CHECK: qref.operator "SingleParam"({{%.+}}: tensor<f64>) qubits([[q0]])
# CHECK: static_data = {}
# CHECK: param_map = {x = [0]} qubit_map = {wires = [0]}
SingleParam(x, 0)
Comment thread
albi3ro marked this conversation as resolved.

# CHECK: [[q1:%.+]] = qref.get {{%.+}}
# CHECK: [[q2:%.+]] = qref.get {{%.+}}
# CHECK: qref.operator "SingleParam"({{%.+}}: tensor<4x4xf64>) qubits([[q1]], [[q2]])
# CHECK: static_data = {}
# CHECK: param_map = {x = [0]} qubit_map = {wires = [0, 1]}
SingleParam(np.eye(4), (1, 2))

return qp.state()


print(c_single_param.mlir)


class CompilableData(qp.core.Operator2):

compilable_argnames = ("a", "b", "thing")

def __init__(self, a, b, thing, wires):
super().__init__(a=a, b=b, thing=thing, wires=wires)


@qp.qjit(capture=True, target="mlir")
@qp.qnode(qp.device("null.qubit", wires=3))
def c_compilable():
# CHECK: [[q1:%.+]] = qref.get {{%.+}}
# CHECK: [[q2:%.+]] = qref.get {{%.+}}
# CHECK: qref.operator "CompilableData"() qubits([[q1]], [[q2]])
# CHECK: static_data = {a = true, b = "some string", thing = [1, true, "string"]}
# CHECK: param_map = {} qubit_map = {wires = [0, 1]}

CompilableData(True, "some string", (1, True, "string"), wires=(0, 1))

return qp.state()


print(c_compilable.mlir)


class MultipleRegisters(qp.core.Operator2):

wire_argnames = ("reg1", "reg2")

def __init__(self, reg1, reg2):
super().__init__(reg1=reg1, reg2=reg2)


@qp.qjit(capture=True, target="mlir")
@qp.qnode(qp.device("null.qubit", wires=5))
def c_multiple_registers():
# CHECK: [[q0:%.+]] = qref.get {{%.+}}
# CHECK: [[q2:%.+]] = qref.get {{%.+}}
# CHECK: [[q3:%.+]] = qref.get {{%.+}}
# CHECK: [[q4:%.+]] = qref.get {{%.+}}

# CHECK: qref.operator "MultipleRegisters"() qubits([[q0]], [[q2]], [[q3]], [[q4]])
# CHECK: static_data = {}
# CHECK: param_map = {} qubit_map = {reg1 = [0], reg2 = [1, 2, 3]}
MultipleRegisters(0, (2, 3, 4))
return qp.state()


print(c_multiple_registers.mlir)


class MultiParams(qp.core.Operator2):

dynamic_argnames = ("a", "b", "c")

# note also having non-standard order with dynamic inputs after wires
def __init__(self, wires, a, b, c):
super().__init__(wires, a, b, c)


@qp.qjit(capture=True, target="mlir")
@qp.qnode(qp.device("null.qubit", wires=1))
def c_multi_params():
# CHECK: [[q0:%.+]] = qref.get {{%.+}}

# pylint: disable=line-too-long
# CHECK: qref.operator "MultiParams"({{%.+}}: tensor<f64>, {{%.+}}: tensor<4x2x1xf64>, {{%.+}}: tensor<3xi64>) qubits([[q0]])
# CHECK: static_data = {}
# CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {wires = [0]}
MultiParams(0, 0.5, c=np.array([1, 2, 3]), b=np.zeros((4, 2, 1)))
return qp.state()


print(c_multi_params.mlir)
Loading
Loading