Skip to content

Commit dca438e

Browse files
authored
Merge pull request #110 from Balandat/fix_deprecated_sparse_tensor
Update deprecated sparse tensor construction
2 parents 2789d18 + baa44cf commit dca438e

1 file changed

Lines changed: 7 additions & 6 deletions

File tree

linear_operator/utils/interpolation.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,13 @@ def left_t_interp(interp_indices, interp_values, rhs, output_dim):
6363
device=interp_values.device,
6464
)
6565
size = torch.Size((batch_size, output_dim, num_data * num_interp))
66-
type_name = summing_matrix_values.type().split(".")[-1] # e.g. FloatTensor
67-
if interp_values.is_cuda:
68-
cls = getattr(torch.cuda.sparse, type_name)
69-
else:
70-
cls = getattr(torch.sparse, type_name)
71-
summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
66+
summing_matrix = torch.sparse_coo_tensor(
67+
summing_matrix_indices,
68+
summing_matrix_values,
69+
size,
70+
dtype=summing_matrix_values.dtype,
71+
device=summing_matrix_values.device,
72+
)
7273

7374
# Sum up the values appropriately by performing sparse matrix multiplication
7475
values = values.reshape(batch_size, num_data * num_interp, num_cols)

0 commit comments

Comments
 (0)