|
7 | 7 | import inspect |
8 | 8 | import textwrap |
9 | 9 | from dataclasses import dataclass |
10 | | -from typing import Annotated, Any, TypeVar, Union, Literal, Optional, Protocol |
| 10 | +from typing import Annotated, Any, TypeVar, Union, Literal, Optional, Protocol, get_origin |
11 | 11 |
|
12 | 12 | from cuda.tile._memory_model import MemoryOrder, MemoryScope |
13 | 13 | from cuda.tile._execution import function, stub |
@@ -1028,16 +1028,30 @@ class ListAnnotation: |
1028 | 1028 | """A ``typing.Annotated`` metadata class for list parameters. |
1029 | 1029 |
|
1030 | 1030 | Attributes: |
1031 | | - element: Annotation for the list's element type. Currently must be an |
1032 | | - :class:`ArrayAnnotation`. |
| 1031 | + element: Annotation for the list's element type. Must be an |
| 1032 | + :class:`ArrayAnnotation` or an ``Annotated`` type whose metadata |
| 1033 | + contains an :class:`ArrayAnnotation` (e.g. :data:`IndexedWithInt64`). |
1033 | 1034 | """ |
1034 | 1035 | element: Any |
1035 | 1036 |
|
1036 | 1037 | def __post_init__(self): |
1037 | | - if not isinstance(self.element, ArrayAnnotation): |
| 1038 | + element = self.element |
| 1039 | + if get_origin(element) is Annotated: |
| 1040 | + array_anns = [m for m in element.__metadata__ if isinstance(m, ArrayAnnotation)] |
| 1041 | + if not array_anns: |
| 1042 | + raise TypeError( |
| 1043 | + f"`element` must contain an ArrayAnnotation," |
| 1044 | + f" but no ArrayAnnotation found in {element!r}") |
| 1045 | + if len(array_anns) > 1: |
| 1046 | + raise TypeError( |
| 1047 | + f"`element` must contain exactly one ArrayAnnotation," |
| 1048 | + f" but found multiple in {element!r}: {array_anns!r}") |
| 1049 | + element = array_anns[0] |
| 1050 | + object.__setattr__(self, 'element', element) |
| 1051 | + if not isinstance(element, ArrayAnnotation): |
1038 | 1052 | raise TypeError( |
1039 | 1053 | f"`element` must be an ArrayAnnotation," |
1040 | | - f" got {type(self.element).__name__}") |
| 1054 | + f" got {type(element).__name__}") |
1041 | 1055 |
|
1042 | 1056 |
|
1043 | 1057 | IndexedWithInt64 = Annotated[T, ArrayAnnotation(index_dtype=int64)] |
|
0 commit comments