add translation and lowering of OperatorOp#2969
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #2969 +/- ##
=======================================
Coverage 96.97% 96.98%
=======================================
Files 166 166
Lines 19209 19247 +38
Branches 1788 1791 +3
=======================================
+ Hits 18628 18666 +38
Misses 429 429
Partials 152 152 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
| return 2 | ||
|
|
||
|
|
||
| def test_hybrid_not_supported_yet(): |
There was a problem hiding this comment.
A pattern I've seen is to write a test that we would like to pass in the future, and then xfail with a "not supported" type message. Not a requirement, but I think it's a nice way to indicate unsupported behaviour that will be added.
Co-authored-by: River McCubbin <river.mccubbin@xanadu.ai>
mudit2812
left a comment
There was a problem hiding this comment.
Looks great Christina, just one blocking comment regarding how we're lowering scalar parameters of qref.operator
|
|
||
| let arguments = (ins | ||
| StringProp:$op_name, | ||
| StrAttr:$op_name, |
There was a problem hiding this comment.
So only string properties are causing issue? All the other properties are fine? 🤔
There was a problem hiding this comment.
I guess we don't know yet because adjoint and UID are not part of the lowering yet
There was a problem hiding this comment.
We will also need to update hanlding of adjoint and UID when we start providing them, I think.
There was a problem hiding this comment.
Let me double check if there are any downstream effects from this since there were subtle differences in the APIs between attributes and props.
| @WorkflowInterpreter.register_primitive(operator_p) | ||
| def _error_on_operator(self, *args, op_cls, **kwargs): | ||
| raise ValueError(f"Operator {op_cls} must occur inside a qnode.") |
There was a problem hiding this comment.
Would be nice to support this in the future too. For example, in the case where we create an operator to get its matrix or to reuse it in multiple QNodes. That would be a new feature though, so not needed right now.
There was a problem hiding this comment.
Would such an operator A) have a primitive (i guess yes whenever capture is enabled)? and B) have its primitive evaluated (interpreted)?
To compute matrices and reuse across QNodes, we should already have the PL operator instance available anyway, right?
There was a problem hiding this comment.
@qp.qjit(capture=True)
def forgot_my_qnode():
qp.X(0)
return qp.state()
for example.
There was a problem hiding this comment.
But what do we do with such a function? I don't think that's one of the scenarios Mudit was describing 🤔
There was a problem hiding this comment.
I was describing something along the lines of:
@qp.qjit(capture=True)
def workflow():
op = qp.X(0)
@qp.qnode(dev)
def f(mat):
# Do something with mat
...
# Support for qp.apply with program capture is being added right now
qp.apply(op)
...
return qp.state()
res = f(op.matrix())
# Do something with res
return somethingThere was a problem hiding this comment.
Nice thanks! In this case, op is a PL Operator2 instance, so it can be freely used in PL functions. The generated equation should actually be removed upon use (qp.apply(op)) during the tracing process. When converting to catalyst / mlir, that equation would then already be gone right?
There was a problem hiding this comment.
@dime10 qp.apply would be the one case where we don't delete the old equation. apply is not a primitive that is consuming operators as data, it treats the operator as an instruction that is needs to insert in the program.
I think we can implement some blanket logic in from_plxpr that just doesn't bind a new primitive if there are operators outside qnodes. So we would still have an equation in plxpr, but not in catalyst jaxpr. I'm not sure how to do this more elegantly though if we want to get rid of from_plxpr eventually.
There was a problem hiding this comment.
apply is not a primitive that is consuming operators as data
That's true, but a good way to look at it imo is that it's just another function that generates an operator / instruction, like qp.adjoint, except that it is the "identity" function. This treats it the same as another PL operator functions and gives more predictable behaviour. Perhaps better revisited elsewhere though.
I think we can implement some blanket logic in from_plxpr that just doesn't bind a new primitive if there are operators outside qnodes. So we would still have an equation in plxpr, but not in catalyst jaxpr.
Makes sense, and is pretty much trivial 👍
I'm not sure how to do this more elegantly though if we want to get rid of from_plxpr eventually.
Hmm, regardless of what we do, at some point a lowering step will have to find a qubit value for it, and there won't be one (outside of a qnode/device/prior allocation), and we could ignore it there, but I definitely think it's better served via the first point.
| from pennylane.capture.primitives import measure_prim as plxpr_measure_prim | ||
| from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim | ||
| from pennylane.capture.primitives import quantum_subroutine_prim, transform_prim | ||
| from pennylane.core.operator.operator2 import operator_p |
There was a problem hiding this comment.
In the PR where I'm implementing, I'm also making operator_p visible from qp.capture.primitives, so we can update this after that change is made.
| from pennylane.core.operator.operator2 import operator_p | |
| # TODO: Change import path for operator_p after it is available in qp.capture.primitives | |
| from pennylane.core.operator.operator2 import operator_p |
There was a problem hiding this comment.
Generally skeptical of adding TODO's. I think it will be available soon enough I don't think we will forget.
There was a problem hiding this comment.
Sure. I added the change in my PR
| def c_single_param(x: float): | ||
|
|
||
| # CHECK: [[q0:%.+]] = qref.get {{%.+}} | ||
|
|
||
| # CHECK: qref.operator "SingleParam"({{%.+}}: tensor<f64>) qubits([[q0]]) | ||
| # CHECK: static_data = {} | ||
| # CHECK: param_map = {x = [0]} qubit_map = {wires = [0]} | ||
| SingleParam(x, 0) |
There was a problem hiding this comment.
Why is the input a tensor<f64> instead of just an f64?
There was a problem hiding this comment.
Oh it's because for CustomOp we explicitly add operations to extract a scalar from a 0-D tensor. It might be worth doing the same thing here. If any input operands are 0-D tensors, we extract the scalar from it and give that as an input.
There was a problem hiding this comment.
Had a chat offline. The plan is to discuss the necessity of this on Slack and handle it in a follow up if any changes are needed
There was a problem hiding this comment.
Yeah I think for the generic op leaving the 0-D tensors might be simpler. A bit of an annoying difference w.r.t. to JAX, but scalar tensors are a huge resource waste when lowered to CPU, so it's definitely better to have scalars eventually.
| def c_single_param(x: float): | ||
|
|
||
| # CHECK: [[q0:%.+]] = qref.get {{%.+}} | ||
|
|
||
| # CHECK: qref.operator "SingleParam"({{%.+}}: tensor<f64>) qubits([[q0]]) | ||
| # CHECK: static_data = {} | ||
| # CHECK: param_map = {x = [0]} qubit_map = {wires = [0]} | ||
| SingleParam(x, 0) |
There was a problem hiding this comment.
Had a chat offline. The plan is to discuss the necessity of this on Slack and handle it in a follow up if any changes are needed
|
I'd like to have a look over the changes before merging :) Is it correct that this is only part of the op that is supported? @albi3ro your goal is to a support in stages across multiple prs? |
So #2979 adds specialized lowerings for things that fit CustomOp and all the other specialized ops. In PennyLaneAI/pennylane#9729 and PennyLaneAI/pennylane#9730 Mudit is adding specialized capturing for adjoint and controlled. Following those PR's, we will be able to add lowerings for adjoint and ctrl. Then we will have a follow up for things with python-only data. |
| * The new `pennylane.core.Operator2` can now be lowered to MLIR with program capture for operators | ||
| without non-lowerable arguments. |
There was a problem hiding this comment.
These will probably all be pulled into single entry at sound point around operator2, but okay for now.
| @WorkflowInterpreter.register_primitive(operator_p) | ||
| def _error_on_operator(self, *args, op_cls, **kwargs): | ||
| raise ValueError(f"Operator {op_cls} must occur inside a qnode.") |
There was a problem hiding this comment.
Would such an operator A) have a primitive (i guess yes whenever capture is enabled)? and B) have its primitive evaluated (interpreted)?
To compute matrices and reuse across QNodes, we should already have the PL operator instance available anyway, right?
| name_attr = get_mlir_attribute_from_pyval(op_cls.__name__) | ||
|
|
||
| repack_static_data = {k: unflatten(*v) for k, v in static_data.items()} | ||
| processed_static_data = get_mlir_attribute_from_pyval(repack_static_data) |
There was a problem hiding this comment.
I wonder how we want to handle error handling / verification about what kind of data is supported to be lowered like this.
| op_name=name_attr, | ||
| params=params, | ||
| qubits=qubits, | ||
| forward_args=[], |
There was a problem hiding this comment.
We might want to add the register arg here for completeness (setting it to None so it's explicit we're not using it).
Context:
Now that
Operator2has been added to Pennylane and can be captured into plxpr, it's time to lower it to MLIR.Description of the Change:
Benefits:
Possible Drawbacks:
Had to switch op_name from being a
StrPropto being aStrAttr. Seems to work now.Related GitHub Issues:
[sc-121978] [sc-121484]