From 5c75a16f1935a2beea16a343c90f163d52d2341b Mon Sep 17 00:00:00 2001 From: Amirreza Razmjoo <40331272+amirrazmjoo@users.noreply.github.com> Date: Wed, 10 Apr 2024 11:43:16 +0200 Subject: [PATCH] device mismatch solved for truncated SVD The output of truncated_svd should ideally reside in the same device as M. However, in certain cases where the tensor is low rank, the output has been observed to be on the 'cpu' device instead. --- tntorch/round.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tntorch/round.py b/tntorch/round.py index 8b273ac..aa5de68 100644 --- a/tntorch/round.py +++ b/tntorch/round.py @@ -131,10 +131,10 @@ def truncated_svd( # NOTE: Special case: M = zero -> rank is 1 if batch: if svd[1].max() < 1e-13: - return torch.zeros([batch_size, M.shape[1], 1]), torch.zeros([batch_size, 1, M.shape[2]]) + return torch.zeros([batch_size, M.shape[1], 1]).to(M.device), torch.zeros([batch_size, 1, M.shape[2]]).to(M.device) else: if svd[1][0] < 1e-13: - return torch.zeros([M.shape[0], 1]), torch.zeros([1, M.shape[1]]) + return torch.zeros([M.shape[0], 1]).to(M.device), torch.zeros([1, M.shape[1]]).to(M.device) S = svd[1]**2