Skip to content

Commit 24aa859

Browse files
Potential fix for pull request finding
in plot_and_save_heatmap, turn everything to numpy before plotting Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 9722282 commit 24aa859

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

pyvisim/_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,15 @@ def plot_and_save_heatmap(
147147
:param show: whether to display the plot
148148
:param save_fig_path: Path to save the figure
149149
"""
150-
figsize = (len(matrix) * 0.7, len(matrix) * 0.7) if figsize is None else figsize
150+
if isinstance(matrix, list):
151+
matrix = np.array(matrix)
152+
elif isinstance(matrix, torch.Tensor):
153+
matrix = matrix.detach().cpu().numpy()
154+
155+
figsize = (
156+
matrix.shape[1] * 0.7,
157+
matrix.shape[0] * 0.7,
158+
) if figsize is None else figsize
151159
plt.figure(figsize=figsize)
152160
sns.heatmap(
153161
matrix,

0 commit comments

Comments
 (0)