Skip to content

Commit 75e5dbe

Browse files
committed
use explicit isinstance() check instead of try/except around to_index_lambda()
1 parent 2e9fedd commit 75e5dbe

1 file changed

Lines changed: 13 additions & 6 deletions

File tree

pytato/analysis/__init__.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,19 @@
4141
Array,
4242
ArrayOrScalar,
4343
Concatenate,
44+
DataWrapper,
4445
DictOfNamedArrays,
4546
Einsum,
4647
IndexBase,
4748
IndexLambda,
4849
IndexRemappingBase,
4950
InputArgumentBase,
5051
NamedArray,
52+
Placeholder,
5153
ShapeType,
5254
Stack,
5355
)
54-
from pytato.diagnostic import CannotBeLoweredToIndexLambda
56+
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
5557
from pytato.function import Call, FunctionDefinition, NamedCallResult
5658
from pytato.scalar_expr import (
5759
FlopCounter as ScalarFlopCounter,
@@ -75,7 +77,6 @@
7577

7678
import pytools.tag
7779

78-
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
7980
from pytato.loopy import LoopyCall
8081

8182
__doc__ = """
@@ -776,10 +777,16 @@ def combine(self, *args: int) -> int:
776777
return sum(args)
777778

778779
def _get_own_flop_count(self, expr: Array) -> int:
779-
try:
780-
nflops = self.scalar_flop_counter(to_index_lambda(expr).expr)
781-
except CannotBeLoweredToIndexLambda:
782-
nflops = 0
780+
if isinstance(
781+
expr,
782+
(
783+
DataWrapper,
784+
Placeholder,
785+
NamedArray,
786+
DistributedRecv,
787+
DistributedSendRefHolder)):
788+
return 0
789+
nflops = self.scalar_flop_counter(to_index_lambda(expr).expr)
783790
if not isinstance(nflops, int):
784791
from pytato.scalar_expr import InputGatherer as ScalarInputGatherer
785792
var_names: set[str] = set(ScalarInputGatherer()(nflops))

0 commit comments

Comments
 (0)