diff --git a/.dep-versions b/.dep-versions index 4ca08f1d5c..5a91175c2e 100644 --- a/.dep-versions +++ b/.dep-versions @@ -8,7 +8,7 @@ enzyme=v0.0.238 # For a custom PL version, update the package version here and at # 'doc/requirements.txt' -pennylane=0.46.0.dev24 +pennylane=0.46.0.dev32 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index fb61df48bf..84786c78ca 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -14,6 +14,10 @@

Improvements 🛠

+* Adds the ability to use `pennylane.typing.AbstractArray` and `pennylane.wires.AbstractWires` as type hints for + AOT compilation and as arguments to `pennylane.specs` calculations. + [(#2953)](https://github.com/PennyLaneAI/catalyst/pull/2953) + * The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_` with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_` with a stable diff --git a/doc/requirements.txt b/doc/requirements.txt index c6721a731a..ecc30dd747 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -34,4 +34,4 @@ lxml_html_clean --extra-index-url https://test.pypi.org/simple/ pennylane-lightning-kokkos==0.46.0-dev10 pennylane-lightning==0.46.0-dev10 -pennylane==0.46.0.dev24 +pennylane==0.46.0.dev32 diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index 42c72693ae..1732bfbcbe 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -46,6 +46,8 @@ from jax._src.pjit import _out_type, _pjit_forwarding, jit_p from jax._src.sharding_impls import UnspecifiedValue from jax.core import AbstractValue, Tracer +from pennylane.typing import AbstractArray +from pennylane.wires import AbstractWires from catalyst.utils.patching import DictPatchWrapper @@ -99,6 +101,8 @@ def _drop_unused_vars2( def get_aval2(x): """An extended version of `jax.core.get_aval` which also accepts AbstractValues.""" # TODO: remove this patch when https://github.com/google/jax/pull/18579 is merged + if isinstance(x, (AbstractArray, AbstractWires)): + x = jax.core.ShapedArray(x.shape, x.dtype) if isinstance(x, AbstractValue): return x elif isinstance(x, Tracer): diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index d0ef9903a9..3baf493b60 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -27,6 +27,8 @@ from jax._src.pjit import _flat_axes_specs from jax.core import AbstractValue from jax.tree_util import tree_flatten, tree_unflatten +from pennylane.typing import AbstractArray +from pennylane.wires import AbstractWires from catalyst.jax_extras import get_aval2 from catalyst.utils.exceptions import CompileError @@ -57,7 +59,10 @@ def params_are_annotated(fn: Callable): are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations) if not are_annotated: return False - return all(isinstance(annotation, (type, AbstractValue)) for annotation in annotations) + return all( + isinstance(annotation, (type, AbstractValue, AbstractArray, AbstractWires)) + for annotation in annotations + ) def get_type_annotations(fn: Callable): diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py index 399c6cec53..f654be79d9 100644 --- a/frontend/test/pytest/test_jit_behaviour.py +++ b/frontend/test/pytest/test_jit_behaviour.py @@ -863,6 +863,28 @@ def workflow(phi: float): # pylint: disable=function-redefined class TestDefaultAvailableIR: + + def test_AbstractArray_AbstractWires_AOT(self): + """Test that AbstractArray and AbstractWires can be used to specify the input + shapes for AOT compilation.""" + + @qp.qjit(capture=True) + @qp.qnode(qp.device("lightning.qubit", wires=4)) + def c(x: qp.typing.AbstractArray((3,), float), wires: qp.wires.Wires[4]): + @qp.for_loop(x.shape[0]) + def loop(i): + qp.RX(x[i], wires[i]) + + @qp.for_loop(wires.shape[0]) + def loop2(i): + qp.X(i) + + loop() + loop2() + return qp.expval(qp.Z(0)) + + assert c.mlir + def test_mlir(self): """Test mlir.""" diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index ef60955956..8b59ca0dfd 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -1441,5 +1441,28 @@ def test_marker(self, simple_circuit, capture_mode): check_specs_same(actual, expected) +def test_abstract_array_inputs(): + """Test that AbstractArray and AbstractWires can be used with specs when level!= device.""" + + @qp.qjit(capture=True) + @qp.qnode(qp.device("lightning.qubit", wires=4)) + def c(x, wires): + @qp.for_loop(x.shape[0]) + def loop(i): + qp.RX(x[i], wires[i]) + + @qp.for_loop(wires.shape[0]) + def loop2(i): + qp.X(i) + + loop() + loop2() + return qp.expval(qp.Z(0)) + + s = qp.specs(c, level=0)(qp.typing.AbstractArray((3,), float), qp.wires.Wires[3]) + assert s.resources.gate_types["PauliX"] == 3 + assert s.resources.gate_types["RX"] == 3 + + if __name__ == "__main__": pytest.main(["-x", __file__])