Skip to content

Commit c67eea6

Browse files
ruff
1 parent da547b6 commit c67eea6

2 files changed

Lines changed: 37 additions & 25 deletions

File tree

loopy/statistics.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,30 @@
3131
from dataclasses import dataclass, replace
3232
from enum import Enum, auto as enum_auto
3333
from functools import cached_property, partial
34+
from typing import (
35+
TYPE_CHECKING,
36+
Any,
37+
Callable,
38+
ClassVar,
39+
Generic,
40+
Iterable,
41+
TypeVar,
42+
Union,
43+
cast,
44+
)
3445

3546
from immutabledict import immutabledict
36-
from typing import TYPE_CHECKING, ClassVar
3747

3848
import islpy as isl
39-
import pymbolic.primitives as p
4049
from islpy import PwQPolynomial, dim_type
4150
from pymbolic.mapper import CombineMapper
42-
from pymbolic.typing import ArithmeticExpressionT
4351
from pytools import memoize_method
44-
from pytools.tag import Tag
4552

4653
import loopy as lp
4754
from loopy.diagnostic import LoopyError, warn_with_kernel
4855
from loopy.kernel import LoopKernel
49-
from loopy.kernel.array import ArrayBase
5056
from loopy.kernel.data import AddressSpace, MultiAssignmentBase
5157
from loopy.kernel.function_interface import CallableKernel
52-
from loopy.kernel.instruction import InstructionBase
5358
from loopy.symbolic import (
5459
CoefficientCollector,
5560
Reduction,
@@ -58,12 +63,19 @@
5863
flatten,
5964
)
6065
from loopy.translation_unit import ConcreteCallablesTable, TranslationUnit
61-
from loopy.types import LoopyType
62-
from loopy.typing import Expression, ExpressionT, auto
6366

6467

6568
if TYPE_CHECKING:
66-
from collections.abc import Sequence
69+
from collections.abc import Mapping, Sequence
70+
71+
import pymbolic.primitives as p
72+
from pymbolic.typing import ArithmeticExpressionT
73+
from pytools.tag import Tag
74+
75+
from loopy.kernel.array import ArrayBase
76+
from loopy.kernel.instruction import InstructionBase
77+
from loopy.types import LoopyType
78+
from loopy.typing import Expression, ExpressionT, auto
6779

6880

6981
__doc__ = """
@@ -245,7 +257,7 @@ def __add__(self, other: ToCountMap[CountT]) -> ToCountMap[CountT]:
245257
result[k] = self.count_map.get(k, 0) + v
246258
return self.copy(count_map=result)
247259

248-
def __radd__(self, other: Union[int, ToCountMap[CountT]]) -> ToCountMap[CountT]:
260+
def __radd__(self, other: int | ToCountMap[CountT]) -> ToCountMap[CountT]:
249261
if other != 0:
250262
raise ValueError("ToCountMap: Attempted to add ToCountMap "
251263
"to {} {}. ToCountMap may only be added to "
@@ -487,7 +499,7 @@ def to_bytes(self) -> ToCountMap[CountT]:
487499
new_count_map = {}
488500

489501
for key, val in self.count_map.items():
490-
new_count_map[key] = int(key.dtype.itemsize) * val # type: ignore[union-attr] # noqa: E501
502+
new_count_map[key] = int(key.dtype.itemsize) * val # type: ignore[union-attr]
491503

492504
return self.copy(new_count_map)
493505

@@ -821,7 +833,7 @@ class MemAccess:
821833
A :class:`frozenset` of tags to the operation.
822834
"""
823835

824-
address_space: AddressSpace | Type[auto] | None = None
836+
address_space: AddressSpace | type[auto] | None = None
825837
dtype: LoopyType | None = None
826838
lid_strides: Mapping[int, Expression] | None = None
827839
gid_strides: Mapping[int, Expression] | None = None
@@ -1127,7 +1139,7 @@ def map_product(
11271139
kernel_name=self.knl.name): self.one})
11281140
+ self.rec(child, tags)
11291141
for child in expr.children
1130-
if not is_zero(cast(ArithmeticExpressionT, child) + 1)) + \
1142+
if not is_zero(cast("ArithmeticExpressionT", child) + 1)) + \
11311143
self.new_poly_map({Op(dtype=self.type_inf(expr),
11321144
op_type=OpType.MUL,
11331145
tags=tags,
@@ -1159,7 +1171,7 @@ def map_power(self, expr: p.Power, tags: frozenset[Tag]) -> ToCountPolynomialMap
11591171
+ self.rec(expr.exponent, tags)
11601172

11611173
def map_left_shift(
1162-
self, expr: Union[p.LeftShift, p.RightShift], tags: frozenset[Tag]
1174+
self, expr: p.LeftShift | p.RightShift, tags: frozenset[Tag]
11631175
) -> ToCountPolynomialMap:
11641176
return self.new_poly_map({Op(dtype=self.type_inf(expr),
11651177
op_type=OpType.SHIFT,
@@ -1181,7 +1193,7 @@ def map_bitwise_not(
11811193
+ self.rec(expr.child, tags)
11821194

11831195
def map_bitwise_or(
1184-
self, expr: Union[p.BitwiseOr, p.BitwiseAnd, p.BitwiseXor],
1196+
self, expr: p.BitwiseOr | p.BitwiseAnd | p.BitwiseXor,
11851197
tags: frozenset[Tag]) -> ToCountPolynomialMap:
11861198
return self.new_poly_map({Op(dtype=self.type_inf(expr),
11871199
op_type=OpType.BITWISE,
@@ -1202,7 +1214,7 @@ def map_if(self, expr: p.If, tags: frozenset[Tag]) -> ToCountPolynomialMap:
12021214
+ self.rec(expr.else_, tags)
12031215

12041216
def map_min(
1205-
self, expr: Union[p. Min, p.Max], tags: frozenset[Tag]
1217+
self, expr: p.Min | p.Max, tags: frozenset[Tag]
12061218
) -> ToCountPolynomialMap:
12071219
return self.new_poly_map({Op(dtype=self.type_inf(expr),
12081220
op_type=OpType.MAXMIN,
@@ -1847,7 +1859,7 @@ def _get_op_map_for_single_kernel(
18471859
if isinstance(insn, (CallInstruction, Assignment)):
18481860
ops = op_counter(insn.assignees) + op_counter(insn.expression)
18491861
for key, val in ops.count_map.items():
1850-
key = cast(Op, key)
1862+
key = cast("Op", key)
18511863
count = _get_insn_count(knl, callables_table, insn.id,
18521864
subgroup_size, count_redundant_work,
18531865
key.count_granularity)
@@ -1931,7 +1943,7 @@ def get_op_map(
19311943
if len(t_unit.entrypoints) > 1:
19321944
raise LoopyError("Must provide entrypoint")
19331945

1934-
entrypoint = next(iter(program.entrypoints))
1946+
entrypoint = next(iter(t_unit.entrypoints))
19351947

19361948
assert entrypoint in t_unit.entrypoints
19371949

@@ -2052,7 +2064,7 @@ def _get_mem_access_map_for_single_kernel(
20522064
).with_set_attributes(read_write=AccessDirection.WRITE)
20532065

20542066
for key, val in insn_access_map.count_map.items():
2055-
key = cast(MemAccess, key)
2067+
key = cast("MemAccess", key)
20562068
count = _get_insn_count(knl, callables_table, insn.id,
20572069
subgroup_size, count_redundant_work,
20582070
key.count_granularity)
@@ -2162,7 +2174,7 @@ def get_mem_access_map(
21622174
if len(t_unit.entrypoints) > 1:
21632175
raise LoopyError("Must provide entrypoint")
21642176

2165-
entrypoint = next(iter(program.entrypoints))
2177+
entrypoint = next(iter(t_unit.entrypoints))
21662178

21672179
assert entrypoint in t_unit.entrypoints
21682180

@@ -2295,7 +2307,7 @@ def get_synchronization_map(
22952307
if len(t_unit.entrypoints) > 1:
22962308
raise LoopyError("Must provide entrypoint")
22972309

2298-
entrypoint = next(iter(program.entrypoints))
2310+
entrypoint = next(iter(t_unit.entrypoints))
22992311

23002312
assert entrypoint in t_unit.entrypoints
23012313
from loopy.preprocess import infer_unknown_types, preprocess_program
@@ -2360,7 +2372,7 @@ def gather_access_footprints(
23602372
if len(t_unit.entrypoints) > 1:
23612373
raise LoopyError("Must provide entrypoint")
23622374

2363-
entrypoint = next(iter(program.entrypoints))
2375+
entrypoint = next(iter(t_unit.entrypoints))
23642376

23652377
assert entrypoint in t_unit.entrypoints
23662378

loopy/symbolic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
LoopyError,
8888
UnableToDetermineAccessRangeError,
8989
)
90-
from loopy.typing import Expression, not_none
90+
from loopy.typing import Expression, ExpressionT, not_none
9191

9292

9393
if TYPE_CHECKING:
@@ -283,7 +283,7 @@ def map_tagged_expression(self, expr, *args, **kwargs):
283283
return
284284

285285
self.rec(expr.expr, *args, **kwargs)
286-
286+
287287
def map_literal(self, expr, *args: P.args, **kwargs: P.kwargs) -> None:
288288
self.visit(expr, *args, **kwargs)
289289

@@ -363,7 +363,7 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase):
363363
class CombineMapper(CombineMapperBase[ResultT, P]):
364364
def map_tagged_expression(self, expr, *args, **kwargs):
365365
return self.rec(expr.expr, *args, **kwargs)
366-
366+
367367
def map_reduction(self, expr, *args: P.args, **kwargs: P.kwargs):
368368
return self.rec(expr.expr, *args, **kwargs)
369369

0 commit comments

Comments
 (0)