diff --git a/.dep-versions b/.dep-versions
index 4ca08f1d5c..5a91175c2e 100644
--- a/.dep-versions
+++ b/.dep-versions
@@ -8,7 +8,7 @@ enzyme=v0.0.238
# For a custom PL version, update the package version here and at
# 'doc/requirements.txt'
-pennylane=0.46.0.dev24
+pennylane=0.46.0.dev32
# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index fb61df48bf..84786c78ca 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -14,6 +14,10 @@
Improvements ðŸ›
+* Adds the ability to use `pennylane.typing.AbstractArray` and `pennylane.wires.AbstractWires` as type hints for
+ AOT compilation and as arguments to `pennylane.specs` calculations.
+ [(#2953)](https://github.com/PennyLaneAI/catalyst/pull/2953)
+
* The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry
instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_`
with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_` with a stable
diff --git a/doc/requirements.txt b/doc/requirements.txt
index c6721a731a..ecc30dd747 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -34,4 +34,4 @@ lxml_html_clean
--extra-index-url https://test.pypi.org/simple/
pennylane-lightning-kokkos==0.46.0-dev10
pennylane-lightning==0.46.0-dev10
-pennylane==0.46.0.dev24
+pennylane==0.46.0.dev32
diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py
index 42c72693ae..1732bfbcbe 100644
--- a/frontend/catalyst/jax_extras/patches.py
+++ b/frontend/catalyst/jax_extras/patches.py
@@ -46,6 +46,8 @@
from jax._src.pjit import _out_type, _pjit_forwarding, jit_p
from jax._src.sharding_impls import UnspecifiedValue
from jax.core import AbstractValue, Tracer
+from pennylane.typing import AbstractArray
+from pennylane.wires import AbstractWires
from catalyst.utils.patching import DictPatchWrapper
@@ -99,6 +101,8 @@ def _drop_unused_vars2(
def get_aval2(x):
"""An extended version of `jax.core.get_aval` which also accepts AbstractValues."""
# TODO: remove this patch when https://github.com/google/jax/pull/18579 is merged
+ if isinstance(x, (AbstractArray, AbstractWires)):
+ x = jax.core.ShapedArray(x.shape, x.dtype)
if isinstance(x, AbstractValue):
return x
elif isinstance(x, Tracer):
diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py
index d0ef9903a9..3baf493b60 100644
--- a/frontend/catalyst/tracing/type_signatures.py
+++ b/frontend/catalyst/tracing/type_signatures.py
@@ -27,6 +27,8 @@
from jax._src.pjit import _flat_axes_specs
from jax.core import AbstractValue
from jax.tree_util import tree_flatten, tree_unflatten
+from pennylane.typing import AbstractArray
+from pennylane.wires import AbstractWires
from catalyst.jax_extras import get_aval2
from catalyst.utils.exceptions import CompileError
@@ -57,7 +59,10 @@ def params_are_annotated(fn: Callable):
are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations)
if not are_annotated:
return False
- return all(isinstance(annotation, (type, AbstractValue)) for annotation in annotations)
+ return all(
+ isinstance(annotation, (type, AbstractValue, AbstractArray, AbstractWires))
+ for annotation in annotations
+ )
def get_type_annotations(fn: Callable):
diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py
index 399c6cec53..f654be79d9 100644
--- a/frontend/test/pytest/test_jit_behaviour.py
+++ b/frontend/test/pytest/test_jit_behaviour.py
@@ -863,6 +863,28 @@ def workflow(phi: float): # pylint: disable=function-redefined
class TestDefaultAvailableIR:
+
+ def test_AbstractArray_AbstractWires_AOT(self):
+ """Test that AbstractArray and AbstractWires can be used to specify the input
+ shapes for AOT compilation."""
+
+ @qp.qjit(capture=True)
+ @qp.qnode(qp.device("lightning.qubit", wires=4))
+ def c(x: qp.typing.AbstractArray((3,), float), wires: qp.wires.Wires[4]):
+ @qp.for_loop(x.shape[0])
+ def loop(i):
+ qp.RX(x[i], wires[i])
+
+ @qp.for_loop(wires.shape[0])
+ def loop2(i):
+ qp.X(i)
+
+ loop()
+ loop2()
+ return qp.expval(qp.Z(0))
+
+ assert c.mlir
+
def test_mlir(self):
"""Test mlir."""
diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py
index ef60955956..8b59ca0dfd 100644
--- a/frontend/test/pytest/test_specs.py
+++ b/frontend/test/pytest/test_specs.py
@@ -1441,5 +1441,28 @@ def test_marker(self, simple_circuit, capture_mode):
check_specs_same(actual, expected)
+def test_abstract_array_inputs():
+ """Test that AbstractArray and AbstractWires can be used with specs when level!= device."""
+
+ @qp.qjit(capture=True)
+ @qp.qnode(qp.device("lightning.qubit", wires=4))
+ def c(x, wires):
+ @qp.for_loop(x.shape[0])
+ def loop(i):
+ qp.RX(x[i], wires[i])
+
+ @qp.for_loop(wires.shape[0])
+ def loop2(i):
+ qp.X(i)
+
+ loop()
+ loop2()
+ return qp.expval(qp.Z(0))
+
+ s = qp.specs(c, level=0)(qp.typing.AbstractArray((3,), float), qp.wires.Wires[3])
+ assert s.resources.gate_types["PauliX"] == 3
+ assert s.resources.gate_types["RX"] == 3
+
+
if __name__ == "__main__":
pytest.main(["-x", __file__])