Skip to content

Commit eba1d8c

Browse files
albi3rokipawaa
andauthored
add translation and lowering of OperatorOp (#2969)
**Context:** Now that `Operator2` has been added to Pennylane and can be captured into plxpr, it's time to lower it to MLIR. **Description of the Change:** **Benefits:** **Possible Drawbacks:** Had to switch op_name from being a `StrProp` to being a `StrAttr`. Seems to work now. **Related GitHub Issues:** [sc-121978] [sc-121484] --------- Co-authored-by: River McCubbin <river.mccubbin@xanadu.ai>
1 parent 4ed1e3f commit eba1d8c

14 files changed

Lines changed: 324 additions & 11 deletions

File tree

.dep-versions

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ enzyme=v0.0.238
88

99
# For a custom PL version, update the package version here and at
1010
# 'doc/requirements.txt'
11-
pennylane=0.46.0.dev24
11+
pennylane=0.46.0.dev38
1212

1313
# For a custom LQ/LK version, update the package version here and at
1414
# 'doc/requirements.txt'

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55

66
<h3>Improvements 🛠</h3>
77

8+
* The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators
9+
without non-lowerable arguments.
10+
[(#2969)](https://github.com/PennyLaneAI/catalyst/pull/2969/)
11+
812
* The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry
913
instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_<N>`
1014
with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_<N>` with a stable

doc/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ lxml_html_clean
3434
--extra-index-url https://test.pypi.org/simple/
3535
pennylane-lightning-kokkos==0.46.0-dev10
3636
pennylane-lightning==0.46.0-dev10
37-
pennylane==0.46.0.dev24
37+
pennylane==0.46.0.dev38

frontend/catalyst/from_plxpr/qfunc_interpreter.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from pennylane.capture.primitives import measure_prim as plxpr_measure_prim
3131
from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim
3232
from pennylane.capture.primitives import quantum_subroutine_prim, transform_prim
33+
from pennylane.core.operator.operator2 import operator_p
3334
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
3435
from pennylane.measurements import CountsMP
3536
from pennylane.wires import AbstractQubit, is_abstract_qubit
@@ -45,6 +46,7 @@
4546
qref_measure_in_basis_p,
4647
qref_measure_p,
4748
qref_namedobs_p,
49+
qref_operator_p,
4850
qref_pauli_measure_p,
4951
qref_pauli_rot_p,
5052
qref_qinst_p,
@@ -290,6 +292,29 @@ def __call__(self, jaxpr, *args):
290292
return self.eval(jaxpr.jaxpr, jaxpr.consts, *args)
291293

292294

295+
@PLxPRToQuantumJaxprInterpreter.register_primitive(operator_p)
296+
def _handle_operator(self, *args, op_cls, hybrid_lens, hybrid_trees, **kwargs):
297+
298+
if hybrid_lens or hybrid_trees or op_cls.static_argnames:
299+
# only support compilable_argnames for the moment
300+
raise NotImplementedError
301+
302+
wire_inputs = args[len(op_cls.dynamic_argnames) :]
303+
new_wires = [
304+
w if is_abstract_qubit(w) else qref_get_p.bind(self.init_qreg, w) for w in wire_inputs
305+
]
306+
307+
qref_operator_p.bind(
308+
*args[: len(op_cls.dynamic_argnames)],
309+
*new_wires,
310+
op_cls=op_cls,
311+
hybrid_lens=hybrid_lens,
312+
hybrid_trees=hybrid_trees,
313+
**kwargs,
314+
)
315+
return []
316+
317+
293318
# pylint: disable=unused-argument, too-many-arguments
294319
def _qubit_unitary_bind_call(
295320
*invals, op, qubits_len, params_len, ctrl_len, adjoint, hyperparameters

frontend/catalyst/from_plxpr/qref_jax_primitives.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@
2929
)
3030
from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp
3131
from pennylane.capture.primitives import adjoint_transform_prim as plxpr_adjoint_transform_prim
32+
from pennylane.pytrees import unflatten
3233
from pennylane.wires import AbstractQubit
3334

35+
from catalyst.jax_extras.lowering import get_mlir_attribute_from_pyval
36+
3437
# TODO: remove after jax v0.7.2 upgrade
3538
# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between
3639
# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not
@@ -72,6 +75,7 @@
7275
MeasureOp,
7376
MultiRZOp,
7477
NamedObsOp,
78+
OperatorOp,
7579
PauliRotOp,
7680
PCPhaseOp,
7781
QubitUnitaryOp,
@@ -164,6 +168,7 @@ class MeasurementPlane(Enum):
164168
qref_compbasis_p = Primitive("qref_compbasis")
165169
qref_namedobs_p = Primitive("qref_namedobs")
166170
qref_hermitian_p = Primitive("qref_hermitian")
171+
qref_operator_p = Primitive("qref_operator")
167172

168173

169174
#
@@ -798,6 +803,62 @@ def _qref_named_obs_lowering(jax_ctx: mlir.LoweringRuleContext, qubit: ir.Value,
798803
return NamedObsOp(result_type, qubit, obsId).results
799804

800805

806+
qref_operator_p.multiple_results = True
807+
808+
809+
@qref_operator_p.def_abstract_eval
810+
def _qref_operator_p_abstract_eval(*args, **kwargs):
811+
return []
812+
813+
814+
def _qref_operator_p_lowering(
815+
jax_ctx: mlir.LoweringRuleContext,
816+
*args,
817+
op_cls,
818+
hybrid_lens,
819+
hybrid_trees,
820+
wire_lens,
821+
**static_data,
822+
):
823+
params = args[: len(op_cls.dynamic_argnames)]
824+
qubits = args[len(op_cls.dynamic_argnames) :]
825+
826+
name_attr = get_mlir_attribute_from_pyval(op_cls.__name__)
827+
828+
repack_static_data = {k: unflatten(*v) for k, v in static_data.items()}
829+
processed_static_data = get_mlir_attribute_from_pyval(repack_static_data)
830+
831+
param_map = {
832+
name: ir.DenseI64ArrayAttr.get([ind]) for ind, name in enumerate(op_cls.dynamic_argnames)
833+
}
834+
processed_param_map = get_mlir_attribute_from_pyval(param_map)
835+
836+
qubit_map = {}
837+
ind = 0
838+
for name, size in zip(op_cls.wire_argnames, wire_lens):
839+
qubit_map[name] = ir.DenseI64ArrayAttr.get(list(range(ind, ind + size)))
840+
ind += size
841+
842+
processed_qubit_map = get_mlir_attribute_from_pyval(qubit_map)
843+
844+
OperatorOp(
845+
op_name=name_attr,
846+
params=params,
847+
qubits=qubits,
848+
qreg=None,
849+
forward_args=[],
850+
ctrl_qubits=[],
851+
ctrl_values=[],
852+
adjoint=False,
853+
UID=None,
854+
arr_qubit_indices=[],
855+
param_map=processed_param_map,
856+
static_data=processed_static_data,
857+
qubit_map=processed_qubit_map,
858+
)
859+
return []
860+
861+
801862
#
802863
# hermitian observable
803864
#
@@ -821,6 +882,7 @@ def _qref_hermitian_lowering(jax_ctx: mlir.LoweringRuleContext, matrix: ir.Value
821882

822883

823884
CUSTOM_LOWERING_RULES = (
885+
(qref_operator_p, _qref_operator_p_lowering),
824886
(qref_alloc_p, _qref_alloc_lowering),
825887
(qref_dealloc_p, _qref_dealloc_lowering),
826888
(qref_get_p, _qref_get_lowering),

frontend/catalyst/jax_extras/lowering.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,9 @@ def get_mlir_attribute_from_pyval(value):
202202

203203
attr = None
204204
match value:
205+
case ir.Attribute():
206+
attr = value
207+
205208
case bool():
206209
attr = ir.BoolAttr.get(value)
207210

frontend/test/lit/test_operator.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright 2022-2023 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for operator in Catalyst."""
15+
16+
# pylint: disable = useless-parent-delegation, missing-function-docstring, missing-class-docstring
17+
18+
# RUN: %PYTHON %s | FileCheck %s
19+
20+
import numpy as np
21+
import pennylane as qp
22+
23+
24+
class NoParams(qp.core.Operator2):
25+
26+
def __init__(self, wires):
27+
super().__init__(wires=wires)
28+
29+
30+
@qp.qjit(target="mlir", capture=True)
31+
@qp.qnode(qp.device("null.qubit", wires=2))
32+
def c_no_params():
33+
# CHECK: [[q0:%.+]] = qref.get {{%.+}}
34+
# CHECK: qref.operator "NoParams"() qubits([[q0]])
35+
# CHECK: static_data = {}
36+
# CHECK: param_map = {} qubit_map = {wires = [0]}
37+
NoParams(wires=0)
38+
39+
# CHECK: [[q1:%.+]] = qref.get {{%.+}}
40+
# CHECK: [[q2:%.+]] = qref.get {{%.+}}
41+
# CHECK: qref.operator "NoParams"() qubits([[q1]], [[q2]])
42+
# CHECK: static_data = {}
43+
# CHECK: param_map = {} qubit_map = {wires = [0, 1]}
44+
NoParams(wires=(0, 1))
45+
return qp.state()
46+
47+
48+
print(c_no_params.mlir)
49+
50+
51+
class SingleParam(qp.core.Operator2):
52+
53+
dynamic_argnames = ("x",)
54+
55+
def __init__(self, x, wires):
56+
super().__init__(x, wires=wires)
57+
58+
59+
@qp.qjit(target="mlir", capture=True)
60+
@qp.qnode(qp.device("null.qubit", wires=3))
61+
def c_single_param(x: float):
62+
63+
# CHECK: [[q0:%.+]] = qref.get {{%.+}}
64+
65+
# CHECK: qref.operator "SingleParam"({{%.+}}: tensor<f64>) qubits([[q0]])
66+
# CHECK: static_data = {}
67+
# CHECK: param_map = {x = [0]} qubit_map = {wires = [0]}
68+
SingleParam(x, 0)
69+
70+
# CHECK: [[q1:%.+]] = qref.get {{%.+}}
71+
# CHECK: [[q2:%.+]] = qref.get {{%.+}}
72+
# CHECK: qref.operator "SingleParam"({{%.+}}: tensor<4x4xf64>) qubits([[q1]], [[q2]])
73+
# CHECK: static_data = {}
74+
# CHECK: param_map = {x = [0]} qubit_map = {wires = [0, 1]}
75+
SingleParam(np.eye(4), (1, 2))
76+
77+
return qp.state()
78+
79+
80+
print(c_single_param.mlir)
81+
82+
83+
class CompilableData(qp.core.Operator2):
84+
85+
compilable_argnames = ("a", "b", "thing")
86+
87+
def __init__(self, a, b, thing, wires):
88+
super().__init__(a=a, b=b, thing=thing, wires=wires)
89+
90+
91+
@qp.qjit(capture=True, target="mlir")
92+
@qp.qnode(qp.device("null.qubit", wires=3))
93+
def c_compilable():
94+
# CHECK: [[q1:%.+]] = qref.get {{%.+}}
95+
# CHECK: [[q2:%.+]] = qref.get {{%.+}}
96+
# CHECK: qref.operator "CompilableData"() qubits([[q1]], [[q2]])
97+
# CHECK: static_data = {a = true, b = "some string", thing = [1, true, "string"]}
98+
# CHECK: param_map = {} qubit_map = {wires = [0, 1]}
99+
100+
CompilableData(True, "some string", (1, True, "string"), wires=(0, 1))
101+
102+
return qp.state()
103+
104+
105+
print(c_compilable.mlir)
106+
107+
108+
class MultipleRegisters(qp.core.Operator2):
109+
110+
wire_argnames = ("reg1", "reg2")
111+
112+
def __init__(self, reg1, reg2):
113+
super().__init__(reg1=reg1, reg2=reg2)
114+
115+
116+
@qp.qjit(capture=True, target="mlir")
117+
@qp.qnode(qp.device("null.qubit", wires=5))
118+
def c_multiple_registers():
119+
# CHECK: [[q0:%.+]] = qref.get {{%.+}}
120+
# CHECK: [[q2:%.+]] = qref.get {{%.+}}
121+
# CHECK: [[q3:%.+]] = qref.get {{%.+}}
122+
# CHECK: [[q4:%.+]] = qref.get {{%.+}}
123+
124+
# CHECK: qref.operator "MultipleRegisters"() qubits([[q0]], [[q2]], [[q3]], [[q4]])
125+
# CHECK: static_data = {}
126+
# CHECK: param_map = {} qubit_map = {reg1 = [0], reg2 = [1, 2, 3]}
127+
MultipleRegisters(0, (2, 3, 4))
128+
return qp.state()
129+
130+
131+
print(c_multiple_registers.mlir)
132+
133+
134+
class MultiParams(qp.core.Operator2):
135+
136+
dynamic_argnames = ("a", "b", "c")
137+
138+
# note also having non-standard order with dynamic inputs after wires
139+
def __init__(self, wires, a, b, c):
140+
super().__init__(wires, a, b, c)
141+
142+
143+
@qp.qjit(capture=True, target="mlir")
144+
@qp.qnode(qp.device("null.qubit", wires=1))
145+
def c_multi_params():
146+
# CHECK: [[q0:%.+]] = qref.get {{%.+}}
147+
148+
# pylint: disable=line-too-long
149+
# CHECK: qref.operator "MultiParams"({{%.+}}: tensor<f64>, {{%.+}}: tensor<4x2x1xf64>, {{%.+}}: tensor<3xi64>) qubits([[q0]])
150+
# CHECK: static_data = {}
151+
# CHECK: param_map = {a = [0], b = [1], c = [2]} qubit_map = {wires = [0]}
152+
MultiParams(0, 0.5, c=np.array([1, 2, 3]), b=np.zeros((4, 2, 1)))
153+
return qp.state()
154+
155+
156+
print(c_multi_params.mlir)

0 commit comments

Comments
 (0)