Skip to content

Commit 1c11601

Browse files
authored
Fix MyPY (#19560)
### Summary From the pytorch code, create_symintnode will always return a symint when the sym input is symbolic. Mypy doesn't know this. Assert is an idiomatic way to express the narrowing to mypy. This is also in test code, so low risk. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani
1 parent a7fb87e commit 1c11601

4 files changed

Lines changed: 4 additions & 0 deletions

File tree

backends/arm/test/misc/test_tosa_dialect_resize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def _make_symint(
2222
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
2323
) -> torch.SymInt:
2424
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
25+
assert isinstance(symint, torch.SymInt)
2526
shape_env.constrain_symbol_range(
2627
symint.node.expr, compiler_min=min, compiler_max=max
2728
)

backends/arm/test/misc/test_tosa_dialect_shape_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def _make_symint(
2323
) -> torch.SymInt:
2424
"""Create a symbolic dimension backed by the provided ShapeEnv."""
2525
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
26+
assert isinstance(symint, torch.SymInt)
2627
symbol_expr = symint.node.expr
2728
shape_env.constrain_symbol_range(symbol_expr, compiler_min=min, compiler_max=max)
2829
return symint

backends/arm/test/passes/test_rewrite_upsample_pass.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def _make_symint(
1414
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
1515
) -> torch.SymInt:
1616
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
17+
assert isinstance(symint, torch.SymInt)
1718
shape_env.constrain_symbol_range(
1819
symint.node.expr, compiler_min=min, compiler_max=max
1920
)

backends/arm/test/passes/test_symbolic_value_range.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _make_shape_env(
2020
) -> tuple[ShapeEnv, torch.SymInt]:
2121
shape_env = ShapeEnv()
2222
symint = shape_env.create_symintnode(sympy.Symbol(symbol_name), hint=hint)
23+
assert isinstance(symint, torch.SymInt)
2324
shape_env.constrain_symbol_range(
2425
symint.node.expr,
2526
compiler_min=compiler_min,

0 commit comments

Comments
 (0)