From 2fe5535b2030ddea52fc9584b43bb24193b1961e Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 16 Jun 2026 14:14:36 -0400 Subject: [PATCH 1/6] allow using AbstractArray for AOT compilation and specs --- doc/releases/changelog-dev.md | 3 +++ frontend/catalyst/jax_extras/patches.py | 5 ++++ frontend/catalyst/tracing/type_signatures.py | 5 +++- frontend/test/pytest/test_jit_behaviour.py | 21 +++++++++++++++++ frontend/test/pytest/test_specs.py | 24 ++++++++++++++++++++ 5 files changed, 57 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index fb61df48bf..aa1688837f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -14,6 +14,9 @@

Improvements 🛠

+* Adds the ability to use `pennylane.typing.AbstractArray` and `pennylane.wires.AbstractWires` as type hints for + AOT compilation and as arguments to `pennylane.specs` calculations. + * The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_` with their trip count. Loops with dynamic bounds appear as `dyn_for_loop_` with a stable diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index 42c72693ae..3c7519c255 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -47,6 +47,9 @@ from jax._src.sharding_impls import UnspecifiedValue from jax.core import AbstractValue, Tracer +from pennylane.typing import AbstractArray +from pennylane.wires import AbstractWires + from catalyst.utils.patching import DictPatchWrapper __all__ = ( @@ -99,6 +102,8 @@ def _drop_unused_vars2( def get_aval2(x): """An extended version of `jax.core.get_aval` which also accepts AbstractValues.""" # TODO: remove this patch when https://github.com/google/jax/pull/18579 is merged + if isinstance(x, (AbstractArray, AbstractWires)): + x = jax.core.ShapedArray(x.shape, x.dtype) if isinstance(x, AbstractValue): return x elif isinstance(x, Tracer): diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index d0ef9903a9..081e6c1ef4 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -28,6 +28,9 @@ from jax.core import AbstractValue from jax.tree_util import tree_flatten, tree_unflatten +from pennylane.typing import AbstractArray +from pennylane.wires import AbstractWires + from catalyst.jax_extras import get_aval2 from catalyst.utils.exceptions import CompileError from catalyst.utils.patching import Patcher @@ -57,7 +60,7 @@ def params_are_annotated(fn: Callable): are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations) if not are_annotated: return False - return all(isinstance(annotation, (type, AbstractValue)) for annotation in annotations) + return all(isinstance(annotation, (type, AbstractValue, AbstractArray, AbstractWires)) for annotation in annotations) def get_type_annotations(fn: Callable): diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py index 399c6cec53..96a86150e2 100644 --- a/frontend/test/pytest/test_jit_behaviour.py +++ b/frontend/test/pytest/test_jit_behaviour.py @@ -863,6 +863,27 @@ def workflow(phi: float): # pylint: disable=function-redefined class TestDefaultAvailableIR: + + def test_AbstractArray_AbstractWires_AOT(self): + """Test that AbstractArray and AbstractWires can be used to specify the input shapes for AOT compilation.""" + + @qp.qjit(capture=True) + @qp.qnode(qp.device('lightning.qubit', wires=4)) + def c(x : qp.typing.AbstractArray((3, ), float), wires : qp.wires.Wires[4]): + @qp.for_loop(x.shape[0]) + def loop(i): + qp.RX(x[i], wires[i]) + + @qp.for_loop(wires.shape[0]) + def loop2(i): + qp.X(i) + + loop() + loop2() + return qp.expval(qp.Z(0)) + + assert c.mlir + def test_mlir(self): """Test mlir.""" diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index ef60955956..b07f8e5356 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -1441,5 +1441,29 @@ def test_marker(self, simple_circuit, capture_mode): check_specs_same(actual, expected) + +def test_abstract_array_inputs(): + """Test that AbstractArray and AbstractWires can be used with specs when level!= device.""" + + @qp.qjit(capture=True) + @qp.qnode(qp.device('lightning.qubit', wires=4)) + def c(x, wires): + @qp.for_loop(x.shape[0]) + def loop(i): + qp.RX(x[i], wires[i]) + + @qp.for_loop(wires.shape[0]) + def loop2(i): + qp.X(i) + + loop() + loop2() + return qp.expval(qp.Z(0)) + + s = qp.specs(c, level=0)(qp.typing.AbstractArray((3,), float), qp.wires.Wires[3]) + assert s.resources.gate_types['PauliX'] == 3 + assert s.resources.gate_types['RX'] == 3 + + if __name__ == "__main__": pytest.main(["-x", __file__]) From 2c0dbbd84135bc2a82453a24103990aa7ee73c8d Mon Sep 17 00:00:00 2001 From: Christina Lee Date: Tue, 16 Jun 2026 14:18:41 -0400 Subject: [PATCH 2/6] Apply suggestion from @albi3ro --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index aa1688837f..84786c78ca 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -16,6 +16,7 @@ * Adds the ability to use `pennylane.typing.AbstractArray` and `pennylane.wires.AbstractWires` as type hints for AOT compilation and as arguments to `pennylane.specs` calculations. + [(#2953)](https://github.com/PennyLaneAI/catalyst/pull/2953) * The `ResourceAnalysis` pass now reports each loop body and each subroutine as its own entry instead of folding their gate counts into the caller. Loops with constant bounds appear as `for_loop_` From 72cff28bce657851060ed728ae78f1008d768879 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 16 Jun 2026 14:20:58 -0400 Subject: [PATCH 3/6] black --- frontend/catalyst/tracing/type_signatures.py | 5 ++++- frontend/test/pytest/test_jit_behaviour.py | 4 ++-- frontend/test/pytest/test_specs.py | 7 +++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 081e6c1ef4..0d35334569 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -60,7 +60,10 @@ def params_are_annotated(fn: Callable): are_annotated = all(annotation is not inspect.Parameter.empty for annotation in annotations) if not are_annotated: return False - return all(isinstance(annotation, (type, AbstractValue, AbstractArray, AbstractWires)) for annotation in annotations) + return all( + isinstance(annotation, (type, AbstractValue, AbstractArray, AbstractWires)) + for annotation in annotations + ) def get_type_annotations(fn: Callable): diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py index 96a86150e2..a9030d65fb 100644 --- a/frontend/test/pytest/test_jit_behaviour.py +++ b/frontend/test/pytest/test_jit_behaviour.py @@ -868,8 +868,8 @@ def test_AbstractArray_AbstractWires_AOT(self): """Test that AbstractArray and AbstractWires can be used to specify the input shapes for AOT compilation.""" @qp.qjit(capture=True) - @qp.qnode(qp.device('lightning.qubit', wires=4)) - def c(x : qp.typing.AbstractArray((3, ), float), wires : qp.wires.Wires[4]): + @qp.qnode(qp.device("lightning.qubit", wires=4)) + def c(x: qp.typing.AbstractArray((3,), float), wires: qp.wires.Wires[4]): @qp.for_loop(x.shape[0]) def loop(i): qp.RX(x[i], wires[i]) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index b07f8e5356..8b59ca0dfd 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -1441,12 +1441,11 @@ def test_marker(self, simple_circuit, capture_mode): check_specs_same(actual, expected) - def test_abstract_array_inputs(): """Test that AbstractArray and AbstractWires can be used with specs when level!= device.""" @qp.qjit(capture=True) - @qp.qnode(qp.device('lightning.qubit', wires=4)) + @qp.qnode(qp.device("lightning.qubit", wires=4)) def c(x, wires): @qp.for_loop(x.shape[0]) def loop(i): @@ -1461,8 +1460,8 @@ def loop2(i): return qp.expval(qp.Z(0)) s = qp.specs(c, level=0)(qp.typing.AbstractArray((3,), float), qp.wires.Wires[3]) - assert s.resources.gate_types['PauliX'] == 3 - assert s.resources.gate_types['RX'] == 3 + assert s.resources.gate_types["PauliX"] == 3 + assert s.resources.gate_types["RX"] == 3 if __name__ == "__main__": From 5defbee2de518d354e4c7c707aec0215fd57f76e Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 16 Jun 2026 14:23:37 -0400 Subject: [PATCH 4/6] isort --- frontend/catalyst/jax_extras/patches.py | 1 - frontend/catalyst/tracing/type_signatures.py | 1 - 2 files changed, 2 deletions(-) diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index 3c7519c255..1732bfbcbe 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -46,7 +46,6 @@ from jax._src.pjit import _out_type, _pjit_forwarding, jit_p from jax._src.sharding_impls import UnspecifiedValue from jax.core import AbstractValue, Tracer - from pennylane.typing import AbstractArray from pennylane.wires import AbstractWires diff --git a/frontend/catalyst/tracing/type_signatures.py b/frontend/catalyst/tracing/type_signatures.py index 0d35334569..3baf493b60 100644 --- a/frontend/catalyst/tracing/type_signatures.py +++ b/frontend/catalyst/tracing/type_signatures.py @@ -27,7 +27,6 @@ from jax._src.pjit import _flat_axes_specs from jax.core import AbstractValue from jax.tree_util import tree_flatten, tree_unflatten - from pennylane.typing import AbstractArray from pennylane.wires import AbstractWires From 0ca227ffe1f46f16f8eed5f8898049e4449dd3a1 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 16 Jun 2026 14:24:37 -0400 Subject: [PATCH 5/6] line too long --- frontend/test/pytest/test_jit_behaviour.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/test_jit_behaviour.py b/frontend/test/pytest/test_jit_behaviour.py index a9030d65fb..f654be79d9 100644 --- a/frontend/test/pytest/test_jit_behaviour.py +++ b/frontend/test/pytest/test_jit_behaviour.py @@ -865,7 +865,8 @@ def workflow(phi: float): # pylint: disable=function-redefined class TestDefaultAvailableIR: def test_AbstractArray_AbstractWires_AOT(self): - """Test that AbstractArray and AbstractWires can be used to specify the input shapes for AOT compilation.""" + """Test that AbstractArray and AbstractWires can be used to specify the input + shapes for AOT compilation.""" @qp.qjit(capture=True) @qp.qnode(qp.device("lightning.qubit", wires=4)) From cb31e741c92083cbf0349ce59f5d78d078319698 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 16 Jun 2026 15:06:34 -0400 Subject: [PATCH 6/6] bump pl version --- .dep-versions | 2 +- doc/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.dep-versions b/.dep-versions index 4ca08f1d5c..5a91175c2e 100644 --- a/.dep-versions +++ b/.dep-versions @@ -8,7 +8,7 @@ enzyme=v0.0.238 # For a custom PL version, update the package version here and at # 'doc/requirements.txt' -pennylane=0.46.0.dev24 +pennylane=0.46.0.dev32 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' diff --git a/doc/requirements.txt b/doc/requirements.txt index c6721a731a..ecc30dd747 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -34,4 +34,4 @@ lxml_html_clean --extra-index-url https://test.pypi.org/simple/ pennylane-lightning-kokkos==0.46.0-dev10 pennylane-lightning==0.46.0-dev10 -pennylane==0.46.0.dev24 +pennylane==0.46.0.dev32