Skip to content

Commit ba692ee

Browse files
committed
fix typing for CachedMapper.rec/rec_function_definition
1 parent 0fcc46c commit ba692ee

4 files changed

Lines changed: 28 additions & 178 deletions

File tree

.basedpyright/baseline.json

Lines changed: 0 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,46 +1825,6 @@
18251825
"lineCount": 1
18261826
}
18271827
},
1828-
{
1829-
"code": "reportUnknownVariableType",
1830-
"range": {
1831-
"startColumn": 12,
1832-
"endColumn": 13,
1833-
"lineCount": 1
1834-
}
1835-
},
1836-
{
1837-
"code": "reportUnknownMemberType",
1838-
"range": {
1839-
"startColumn": 16,
1840-
"endColumn": 26,
1841-
"lineCount": 1
1842-
}
1843-
},
1844-
{
1845-
"code": "reportUnknownVariableType",
1846-
"range": {
1847-
"startColumn": 16,
1848-
"endColumn": 22,
1849-
"lineCount": 1
1850-
}
1851-
},
1852-
{
1853-
"code": "reportUnknownVariableType",
1854-
"range": {
1855-
"startColumn": 16,
1856-
"endColumn": 22,
1857-
"lineCount": 1
1858-
}
1859-
},
1860-
{
1861-
"code": "reportUnknownVariableType",
1862-
"range": {
1863-
"startColumn": 19,
1864-
"endColumn": 25,
1865-
"lineCount": 1
1866-
}
1867-
},
18681828
{
18691829
"code": "reportAny",
18701830
"range": {
@@ -6657,38 +6617,6 @@
66576617
"lineCount": 1
66586618
}
66596619
},
6660-
{
6661-
"code": "reportUnknownVariableType",
6662-
"range": {
6663-
"startColumn": 12,
6664-
"endColumn": 18,
6665-
"lineCount": 1
6666-
}
6667-
},
6668-
{
6669-
"code": "reportUnknownMemberType",
6670-
"range": {
6671-
"startColumn": 21,
6672-
"endColumn": 49,
6673-
"lineCount": 1
6674-
}
6675-
},
6676-
{
6677-
"code": "reportAttributeAccessIssue",
6678-
"range": {
6679-
"startColumn": 26,
6680-
"endColumn": 49,
6681-
"lineCount": 1
6682-
}
6683-
},
6684-
{
6685-
"code": "reportUnknownVariableType",
6686-
"range": {
6687-
"startColumn": 15,
6688-
"endColumn": 44,
6689-
"lineCount": 1
6690-
}
6691-
},
66926620
{
66936621
"code": "reportUnreachable",
66946622
"range": {
@@ -6713,22 +6641,6 @@
67136641
"lineCount": 1
67146642
}
67156643
},
6716-
{
6717-
"code": "reportUnknownMemberType",
6718-
"range": {
6719-
"startColumn": 43,
6720-
"endColumn": 53,
6721-
"lineCount": 1
6722-
}
6723-
},
6724-
{
6725-
"code": "reportUnknownArgumentType",
6726-
"range": {
6727-
"startColumn": 43,
6728-
"endColumn": 82,
6729-
"lineCount": 1
6730-
}
6731-
},
67326644
{
67336645
"code": "reportImplicitOverride",
67346646
"range": {
@@ -6737,22 +6649,6 @@
67376649
"lineCount": 1
67386650
}
67396651
},
6740-
{
6741-
"code": "reportUnknownMemberType",
6742-
"range": {
6743-
"startColumn": 24,
6744-
"endColumn": 54,
6745-
"lineCount": 1
6746-
}
6747-
},
6748-
{
6749-
"code": "reportUnknownArgumentType",
6750-
"range": {
6751-
"startColumn": 24,
6752-
"endColumn": 83,
6753-
"lineCount": 1
6754-
}
6755-
},
67566652
{
67576653
"code": "reportUnusedParameter",
67586654
"range": {
@@ -7353,22 +7249,6 @@
73537249
"lineCount": 1
73547250
}
73557251
},
7356-
{
7357-
"code": "reportUnknownMemberType",
7358-
"range": {
7359-
"startColumn": 43,
7360-
"endColumn": 53,
7361-
"lineCount": 1
7362-
}
7363-
},
7364-
{
7365-
"code": "reportUnknownArgumentType",
7366-
"range": {
7367-
"startColumn": 43,
7368-
"endColumn": 78,
7369-
"lineCount": 1
7370-
}
7371-
},
73727252
{
73737253
"code": "reportIncompatibleMethodOverride",
73747254
"range": {
@@ -8169,38 +8049,6 @@
81698049
"lineCount": 1
81708050
}
81718051
},
8172-
{
8173-
"code": "reportUnknownVariableType",
8174-
"range": {
8175-
"startColumn": 12,
8176-
"endColumn": 18,
8177-
"lineCount": 1
8178-
}
8179-
},
8180-
{
8181-
"code": "reportUnknownMemberType",
8182-
"range": {
8183-
"startColumn": 21,
8184-
"endColumn": 31,
8185-
"lineCount": 1
8186-
}
8187-
},
8188-
{
8189-
"code": "reportUnknownArgumentType",
8190-
"range": {
8191-
"startColumn": 49,
8192-
"endColumn": 55,
8193-
"lineCount": 1
8194-
}
8195-
},
8196-
{
8197-
"code": "reportUnknownArgumentType",
8198-
"range": {
8199-
"startColumn": 43,
8200-
"endColumn": 49,
8201-
"lineCount": 1
8202-
}
8203-
},
82048052
{
82058053
"code": "reportImplicitOverride",
82068054
"range": {

pytato/analysis/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from pytato.function import Call, FunctionDefinition, NamedCallResult
5353
from pytato.transform import (
5454
ArrayOrNames,
55+
CachedMapper,
5556
CachedWalkMapper,
5657
CombineMapper,
5758
Mapper,
@@ -720,10 +721,9 @@ def rec(self, expr: ArrayOrNames) -> int:
720721
try:
721722
return self._cache_retrieve(inputs)
722723
except KeyError:
723-
# Intentionally going to Mapper instead of super() to avoid
724-
# double caching when subclasses of CachedMapper override rec,
725-
# see https://github.com/inducer/pytato/pull/585
726-
s = Mapper.rec(self, expr)
724+
# Using super(CachedMapper, self) instead of super() to bypass
725+
# CachedMapper.rec and avoid double caching
726+
s = super(CachedMapper, self).rec(expr)
727727
if (
728728
isinstance(expr, Array)
729729
and (

pytato/transform/__init__.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -280,15 +280,15 @@ def rec_function_definition(
280280
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
281281
) -> FunctionResultT:
282282
"""Call the mapper method of *expr* and return the result."""
283-
method: Callable[..., FunctionResultT] | None
284-
285283
try:
286-
method = self.map_function_definition # type: ignore[attr-defined]
284+
method_name = "map_function_definition"
285+
method: Callable[..., FunctionResultT] = cast(
286+
"Callable[..., FunctionResultT]",
287+
getattr(self, method_name))
287288
except AttributeError:
288289
raise ValueError(
289290
f"{type(self).__name__} lacks a mapper method for functions.") from None
290291

291-
assert method is not None
292292
return method(expr, *args, **kwargs)
293293

294294
@overload
@@ -523,10 +523,11 @@ def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
523523
try:
524524
return self._cache_retrieve(inputs)
525525
except KeyError:
526-
# Intentionally going to Mapper instead of super() to avoid
527-
# double caching when subclasses of CachedMapper override rec,
528-
# see https://github.com/inducer/pytato/pull/585
529-
return self._cache_add(inputs, Mapper.rec(self, expr, *args, **kwargs))
526+
# Reminder: If overriding this in a subclass and reimplementing the cache
527+
# lookup logic there, must use super(CachedMapper, self) instead of
528+
# super() below to avoid double caching,
529+
# see https://github.com/inducer/pytato/pull/585.
530+
return self._cache_add(inputs, super().rec(expr, *args, **kwargs))
530531

531532
def rec_function_definition(
532533
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
@@ -535,11 +536,12 @@ def rec_function_definition(
535536
try:
536537
return self._function_cache_retrieve(inputs)
537538
except KeyError:
539+
# Reminder: If overriding this in a subclass and reimplementing the cache
540+
# lookup logic there, must use super(CachedMapper, self) instead of
541+
# super() below to avoid double caching,
542+
# see https://github.com/inducer/pytato/pull/585.
538543
return self._function_cache_add(
539-
# Intentionally going to Mapper instead of super() to avoid
540-
# double caching when subclasses of CachedMapper override rec,
541-
# see https://github.com/inducer/pytato/pull/585
542-
inputs, Mapper.rec_function_definition(self, expr, *args, **kwargs))
544+
inputs, super().rec_function_definition(expr, *args, **kwargs))
543545

544546
def clone_for_callee(
545547
self, function: FunctionDefinition) -> Self:
@@ -1970,10 +1972,10 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
19701972
try:
19711973
return self._cache_retrieve(inputs)
19721974
except KeyError:
1973-
# Intentionally going to Mapper instead of super() to avoid
1974-
# double caching when subclasses of CachedMapper override rec,
1975-
# see https://github.com/inducer/pytato/pull/585
1976-
return self._cache_add(inputs, Mapper.rec(self, self.map_fn(expr)))
1975+
# Using super(CachedMapper, self) instead of super() to bypass
1976+
# CachedMapper.rec and avoid double caching
1977+
return self._cache_add(inputs,
1978+
super(CachedMapper, self).rec(self.map_fn(expr)))
19771979

19781980
# }}}
19791981

pytato/transform/metadata.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from pytato.transform import (
8888
ArrayOrNames,
8989
ArrayOrNamesOrFunctionDefTc,
90+
CachedMapper,
9091
CopyMapper,
9192
Mapper,
9293
TransformMapperCache,
@@ -517,15 +518,14 @@ def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
517518
try:
518519
return self._cache_retrieve(inputs)
519520
except KeyError:
520-
# Intentionally going to Mapper instead of super() to avoid
521-
# double caching when subclasses of CachedMapper override rec,
522-
# see https://github.com/inducer/pytato/pull/585
523-
result = Mapper.rec(self, expr)
521+
# Using super(CachedMapper, self) instead of super() to bypass
522+
# CachedMapper.rec and avoid double caching
523+
result = super(CachedMapper, self).rec(expr)
524524
if not isinstance(
525525
expr, AbstractResultWithNamedArrays | DistributedSendRefHolder):
526526
assert isinstance(expr, Array)
527-
# type-ignore reason: passed "ArrayOrNames"; expected "Array"
528-
result = self._attach_tags(expr, result) # type: ignore[arg-type]
527+
assert isinstance(result, Array)
528+
result = self._attach_tags(expr, result)
529529
return self._cache_add(inputs, result)
530530

531531
def map_named_call_result(self, expr: NamedCallResult) -> Array:

0 commit comments

Comments
 (0)