We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6008052 commit 51f048dCopy full SHA for 51f048d
cebra/data/datasets.py
@@ -30,6 +30,7 @@
30
31
import cebra.data as cebra_data
32
import cebra.helper as cebra_helper
33
+from cebra.data.datatypes import Offset
34
35
36
class TensorDataset(cebra_data.SingleSessionDataset):
@@ -65,7 +66,7 @@ def __init__(self,
65
66
neural: Union[torch.Tensor, npt.NDArray],
67
continuous: Union[torch.Tensor, npt.NDArray] = None,
68
discrete: Union[torch.Tensor, npt.NDArray] = None,
- offset: int = 1,
69
+ offset: Offset = Offset(0, 1),
70
device: str = "cpu"):
71
super().__init__(device=device)
72
self.neural = self._to_tensor(neural, check_dtype="float").float()
0 commit comments