3131from dataclasses import dataclass , replace
3232from enum import Enum , auto as enum_auto
3333from functools import cached_property , partial
34+ from typing import (
35+ TYPE_CHECKING ,
36+ Any ,
37+ Callable ,
38+ ClassVar ,
39+ Generic ,
40+ Iterable ,
41+ TypeVar ,
42+ Union ,
43+ cast ,
44+ )
3445
3546from immutabledict import immutabledict
36- from typing import TYPE_CHECKING , ClassVar
3747
3848import islpy as isl
39- import pymbolic .primitives as p
4049from islpy import PwQPolynomial , dim_type
4150from pymbolic .mapper import CombineMapper
42- from pymbolic .typing import ArithmeticExpressionT
4351from pytools import memoize_method
44- from pytools .tag import Tag
4552
4653import loopy as lp
4754from loopy .diagnostic import LoopyError , warn_with_kernel
4855from loopy .kernel import LoopKernel
49- from loopy .kernel .array import ArrayBase
5056from loopy .kernel .data import AddressSpace , MultiAssignmentBase
5157from loopy .kernel .function_interface import CallableKernel
52- from loopy .kernel .instruction import InstructionBase
5358from loopy .symbolic import (
5459 CoefficientCollector ,
5560 Reduction ,
5863 flatten ,
5964)
6065from loopy .translation_unit import ConcreteCallablesTable , TranslationUnit
61- from loopy .types import LoopyType
62- from loopy .typing import Expression , ExpressionT , auto
6366
6467
6568if TYPE_CHECKING :
66- from collections .abc import Sequence
69+ from collections .abc import Mapping , Sequence
70+
71+ import pymbolic .primitives as p
72+ from pymbolic .typing import ArithmeticExpressionT
73+ from pytools .tag import Tag
74+
75+ from loopy .kernel .array import ArrayBase
76+ from loopy .kernel .instruction import InstructionBase
77+ from loopy .types import LoopyType
78+ from loopy .typing import Expression , ExpressionT , auto
6779
6880
6981__doc__ = """
@@ -245,7 +257,7 @@ def __add__(self, other: ToCountMap[CountT]) -> ToCountMap[CountT]:
245257 result [k ] = self .count_map .get (k , 0 ) + v
246258 return self .copy (count_map = result )
247259
248- def __radd__ (self , other : Union [ int , ToCountMap [CountT ] ]) -> ToCountMap [CountT ]:
260+ def __radd__ (self , other : int | ToCountMap [CountT ]) -> ToCountMap [CountT ]:
249261 if other != 0 :
250262 raise ValueError ("ToCountMap: Attempted to add ToCountMap "
251263 "to {} {}. ToCountMap may only be added to "
@@ -487,7 +499,7 @@ def to_bytes(self) -> ToCountMap[CountT]:
487499 new_count_map = {}
488500
489501 for key , val in self .count_map .items ():
490- new_count_map [key ] = int (key .dtype .itemsize ) * val # type: ignore[union-attr] # noqa: E501
502+ new_count_map [key ] = int (key .dtype .itemsize ) * val # type: ignore[union-attr]
491503
492504 return self .copy (new_count_map )
493505
@@ -821,7 +833,7 @@ class MemAccess:
821833 A :class:`frozenset` of tags to the operation.
822834 """
823835
824- address_space : AddressSpace | Type [auto ] | None = None
836+ address_space : AddressSpace | type [auto ] | None = None
825837 dtype : LoopyType | None = None
826838 lid_strides : Mapping [int , Expression ] | None = None
827839 gid_strides : Mapping [int , Expression ] | None = None
@@ -1127,7 +1139,7 @@ def map_product(
11271139 kernel_name = self .knl .name ): self .one })
11281140 + self .rec (child , tags )
11291141 for child in expr .children
1130- if not is_zero (cast (ArithmeticExpressionT , child ) + 1 )) + \
1142+ if not is_zero (cast (" ArithmeticExpressionT" , child ) + 1 )) + \
11311143 self .new_poly_map ({Op (dtype = self .type_inf (expr ),
11321144 op_type = OpType .MUL ,
11331145 tags = tags ,
@@ -1159,7 +1171,7 @@ def map_power(self, expr: p.Power, tags: frozenset[Tag]) -> ToCountPolynomialMap
11591171 + self .rec (expr .exponent , tags )
11601172
11611173 def map_left_shift (
1162- self , expr : Union [ p .LeftShift , p .RightShift ] , tags : frozenset [Tag ]
1174+ self , expr : p .LeftShift | p .RightShift , tags : frozenset [Tag ]
11631175 ) -> ToCountPolynomialMap :
11641176 return self .new_poly_map ({Op (dtype = self .type_inf (expr ),
11651177 op_type = OpType .SHIFT ,
@@ -1181,7 +1193,7 @@ def map_bitwise_not(
11811193 + self .rec (expr .child , tags )
11821194
11831195 def map_bitwise_or (
1184- self , expr : Union [ p .BitwiseOr , p .BitwiseAnd , p .BitwiseXor ] ,
1196+ self , expr : p .BitwiseOr | p .BitwiseAnd | p .BitwiseXor ,
11851197 tags : frozenset [Tag ]) -> ToCountPolynomialMap :
11861198 return self .new_poly_map ({Op (dtype = self .type_inf (expr ),
11871199 op_type = OpType .BITWISE ,
@@ -1202,7 +1214,7 @@ def map_if(self, expr: p.If, tags: frozenset[Tag]) -> ToCountPolynomialMap:
12021214 + self .rec (expr .else_ , tags )
12031215
12041216 def map_min (
1205- self , expr : Union [ p . Min , p .Max ] , tags : frozenset [Tag ]
1217+ self , expr : p . Min | p .Max , tags : frozenset [Tag ]
12061218 ) -> ToCountPolynomialMap :
12071219 return self .new_poly_map ({Op (dtype = self .type_inf (expr ),
12081220 op_type = OpType .MAXMIN ,
@@ -1847,7 +1859,7 @@ def _get_op_map_for_single_kernel(
18471859 if isinstance (insn , (CallInstruction , Assignment )):
18481860 ops = op_counter (insn .assignees ) + op_counter (insn .expression )
18491861 for key , val in ops .count_map .items ():
1850- key = cast (Op , key )
1862+ key = cast ("Op" , key )
18511863 count = _get_insn_count (knl , callables_table , insn .id ,
18521864 subgroup_size , count_redundant_work ,
18531865 key .count_granularity )
@@ -1931,7 +1943,7 @@ def get_op_map(
19311943 if len (t_unit .entrypoints ) > 1 :
19321944 raise LoopyError ("Must provide entrypoint" )
19331945
1934- entrypoint = next (iter (program .entrypoints ))
1946+ entrypoint = next (iter (t_unit .entrypoints ))
19351947
19361948 assert entrypoint in t_unit .entrypoints
19371949
@@ -2052,7 +2064,7 @@ def _get_mem_access_map_for_single_kernel(
20522064 ).with_set_attributes (read_write = AccessDirection .WRITE )
20532065
20542066 for key , val in insn_access_map .count_map .items ():
2055- key = cast (MemAccess , key )
2067+ key = cast (" MemAccess" , key )
20562068 count = _get_insn_count (knl , callables_table , insn .id ,
20572069 subgroup_size , count_redundant_work ,
20582070 key .count_granularity )
@@ -2162,7 +2174,7 @@ def get_mem_access_map(
21622174 if len (t_unit .entrypoints ) > 1 :
21632175 raise LoopyError ("Must provide entrypoint" )
21642176
2165- entrypoint = next (iter (program .entrypoints ))
2177+ entrypoint = next (iter (t_unit .entrypoints ))
21662178
21672179 assert entrypoint in t_unit .entrypoints
21682180
@@ -2295,7 +2307,7 @@ def get_synchronization_map(
22952307 if len (t_unit .entrypoints ) > 1 :
22962308 raise LoopyError ("Must provide entrypoint" )
22972309
2298- entrypoint = next (iter (program .entrypoints ))
2310+ entrypoint = next (iter (t_unit .entrypoints ))
22992311
23002312 assert entrypoint in t_unit .entrypoints
23012313 from loopy .preprocess import infer_unknown_types , preprocess_program
@@ -2360,7 +2372,7 @@ def gather_access_footprints(
23602372 if len (t_unit .entrypoints ) > 1 :
23612373 raise LoopyError ("Must provide entrypoint" )
23622374
2363- entrypoint = next (iter (program .entrypoints ))
2375+ entrypoint = next (iter (t_unit .entrypoints ))
23642376
23652377 assert entrypoint in t_unit .entrypoints
23662378
0 commit comments