Skip to content

Commit 1c86b79

Browse files
committed
Fix some Mypy errors.
1 parent 95f9490 commit 1c86b79

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/torchjd/sparse/_aten_function_overrides/shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor:
207207
# Such a physical dimension already exists. Note that an alternative implementation would be
208208
# to simply always add the physical dimension, and squash it if it ends up being not needed.
209209
physicals = [t.physical for t in tensors_]
210-
pdim = indices[0][0]
210+
pdim = cast(int, indices[0, 0].item())
211211
new_strides = ref_tensor.strides
212212

213213
new_physical = aten.cat.default(physicals, dim=pdim)

src/torchjd/sparse/_structured_sparse_tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import itertools
22
import operator
3+
from collections.abc import Callable
34
from functools import wraps
45
from itertools import accumulate
56
from math import prod
@@ -10,7 +11,7 @@
1011

1112

1213
class StructuredSparseTensor(Tensor):
13-
_HANDLED_FUNCTIONS = dict()
14+
_HANDLED_FUNCTIONS = dict[Callable, Callable]()
1415

1516
@staticmethod
1617
def __new__(cls, physical: Tensor, strides: Tensor):

0 commit comments

Comments
 (0)