Skip to content

Commit 5e316dd

Browse files
committed
Allow ListAnnotation to accept Annotated element types
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent 00e5933 commit 5e316dd

2 files changed

Lines changed: 20 additions & 7 deletions

File tree

src/cuda/tile/_stub.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import inspect
88
import textwrap
99
from dataclasses import dataclass
10-
from typing import Annotated, Any, TypeVar, Union, Literal, Optional, Protocol
10+
from typing import Annotated, Any, TypeVar, Union, Literal, Optional, Protocol, get_origin
1111

1212
from cuda.tile._memory_model import MemoryOrder, MemoryScope
1313
from cuda.tile._execution import function, stub
@@ -1028,16 +1028,30 @@ class ListAnnotation:
10281028
"""A ``typing.Annotated`` metadata class for list parameters.
10291029
10301030
Attributes:
1031-
element: Annotation for the list's element type. Currently must be an
1032-
:class:`ArrayAnnotation`.
1031+
element: Annotation for the list's element type. Must be an
1032+
:class:`ArrayAnnotation` or an ``Annotated`` type whose metadata
1033+
contains an :class:`ArrayAnnotation` (e.g. :data:`IndexedWithInt64`).
10331034
"""
10341035
element: Any
10351036

10361037
def __post_init__(self):
1037-
if not isinstance(self.element, ArrayAnnotation):
1038+
element = self.element
1039+
if get_origin(element) is Annotated:
1040+
array_anns = [m for m in element.__metadata__ if isinstance(m, ArrayAnnotation)]
1041+
if not array_anns:
1042+
raise TypeError(
1043+
f"`element` must contain an ArrayAnnotation,"
1044+
f" but no ArrayAnnotation found in {element!r}")
1045+
if len(array_anns) > 1:
1046+
raise TypeError(
1047+
f"`element` must contain exactly one ArrayAnnotation,"
1048+
f" but found multiple in {element!r}: {array_anns!r}")
1049+
element = array_anns[0]
1050+
object.__setattr__(self, 'element', element)
1051+
if not isinstance(element, ArrayAnnotation):
10381052
raise TypeError(
10391053
f"`element` must be an ArrayAnnotation,"
1040-
f" got {type(self.element).__name__}")
1054+
f" got {type(element).__name__}")
10411055

10421056

10431057
IndexedWithInt64 = Annotated[T, ArrayAnnotation(index_dtype=int64)]

test/test_list.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import cuda.tile as ct
1010
from cuda.tile._bytecode import BytecodeVersion
11-
from cuda.tile._datatype import int64
1211
from typing import Annotated
1312
from util import assert_equal
1413
from conftest import requires_tileiras
@@ -57,7 +56,7 @@ def test_add_list_of_arrays(kernel):
5756

5857

5958
ListOfArrayIndexedWithInt64 = Annotated[
60-
list, ct.ListAnnotation(element=ct.ArrayAnnotation(index_dtype=int64))
59+
list, ct.ListAnnotation(element=ct.IndexedWithInt64)
6160
]
6261

6362

0 commit comments

Comments
 (0)