diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index 64a6b2333d..faf73465ff 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -50,11 +50,16 @@ def compress_templates( if remove_mean: templates_array -= templates_array.mean(axis=(1, 2))[:, None, None] - temporal, singular, spatial = np.linalg.svd(templates_array, full_matrices=False) - # Keep only the strongest components - temporal = temporal[:, :, :approx_rank].astype(np.float32) - singular = singular[:, :approx_rank].astype(np.float32) - spatial = spatial[:, :approx_rank, :].astype(np.float32) + num_templates, num_samples, num_channels = templates_array.shape + temporal = np.zeros((num_templates, num_samples, approx_rank), dtype=np.float32) + spatial = np.zeros((num_templates, approx_rank, num_channels), dtype=np.float32) + singular = np.zeros((num_templates, approx_rank), dtype=np.float32) + + for i in range(num_templates): + i_temporal, i_singular, i_spatial = np.linalg.svd(templates_array[i], full_matrices=False) + temporal[i, :, : min(approx_rank, num_channels)] = i_temporal[:, :approx_rank] + spatial[i, : min(approx_rank, num_channels), :] = i_spatial[:approx_rank, :] + singular[i, : min(approx_rank, num_channels)] = i_singular[:approx_rank] if return_new_templates: templates_array = np.matmul(temporal * singular[:, np.newaxis, :], spatial)