|
| 1 | +from typing import Dict, Optional, Tuple, Type |
| 2 | + |
1 | 3 | import cebra |
2 | | -import torch |
| 4 | +import matplotlib |
3 | 5 | import numpy as np |
4 | | -from ..utils_allen import decoding_frames |
5 | | -from ..utils_hpc import decoding_pos_dir |
6 | | -from ..activations import get_activations_model |
7 | | -from .base import _BaseMetric |
8 | | -import cebra_lens.matplotlib as cebra_lens_matplotlib |
9 | 6 | import numpy.typing as npt |
10 | | -from typing import Dict, Type, Tuple, Optional |
11 | | -import torch.nn as nn |
12 | 7 | import sklearn.metrics |
| 8 | +import torch |
13 | 9 | import torch as pt |
14 | | -import matplotlib |
| 10 | +import torch.nn as nn |
| 11 | + |
| 12 | +import cebra_lens.matplotlib as cebra_lens_matplotlib |
| 13 | + |
| 14 | +from ..activations import get_activations_model |
| 15 | +from ..utils_allen import decoding_frames |
| 16 | +from ..utils_hpc import decoding_pos_dir |
| 17 | +from .base import _BaseMetric |
15 | 18 |
|
16 | 19 |
|
17 | 20 | def decoding( |
@@ -370,8 +373,10 @@ def plot( |
370 | 373 | ) |
371 | 374 |
|
372 | 375 | if self.output_only: |
373 | | - return cebra_lens_matplotlib.plot_decoding(results_dict, palette, self.dataset_label, |
374 | | - label, plot_error, ax) |
| 376 | + return cebra_lens_matplotlib.plot_decoding(results_dict, palette, |
| 377 | + self.dataset_label, |
| 378 | + label, plot_error, ax) |
375 | 379 | else: |
376 | | - return cebra_lens_matplotlib.plot_layer_decoding(results_dict, title, self.dataset_label, |
377 | | - label, plot_error, figsize) |
| 380 | + return cebra_lens_matplotlib.plot_layer_decoding( |
| 381 | + results_dict, title, self.dataset_label, label, plot_error, |
| 382 | + figsize) |
0 commit comments