Skip to content

Commit f1abf92

Browse files
committed
Update imports and type annotations in vision.py and contrastive_pretraining.py
1 parent c8e6afb commit f1abf92

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

mmlearn/modules/encoders/vision.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from transformers.modeling_outputs import BaseModelOutput
1414

1515
from mmlearn import hf_utils
16-
from mmlearn.datasets.core.modalities import Modalities, Modality
16+
from mmlearn.datasets.core.modalities import Modalities
1717
from mmlearn.datasets.processors.masking import apply_masks
1818
from mmlearn.datasets.processors.transforms import (
1919
repeat_interleave_batch,
@@ -137,13 +137,13 @@ def forward(self, inputs: Dict[str, Any]) -> BaseModelOutput:
137137
)
138138

139139
def get_intermediate_layers(
140-
self, inputs: Dict[Union[str, Modality], Any], n: int = 1
140+
self, inputs: Dict[str, Any], n: int = 1
141141
) -> List[torch.Tensor]:
142142
"""Get the output of the intermediate layers.
143143
144144
Parameters
145145
----------
146-
inputs : Dict[Union[str, Modality], Any]
146+
inputs : Dict[str, Any]
147147
The input data. The `image` will be expected under the `Modalities.RGB` key.
148148
n : int, default=1
149149
The number of intermediate layers to return.

mmlearn/tasks/contrastive_pretraining.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ def validation_step(
497497
498498
Parameters
499499
----------
500-
batch : Dict[Union[str, Modality], Any]
500+
batch : Dict[str, torch.Tensor]
501501
The batch of data to process.
502502
batch_idx : int
503503
The index of the batch.
@@ -524,7 +524,7 @@ def test_step(
524524
525525
Parameters
526526
----------
527-
batch : Dict[Union[str, Modality], Any]
527+
batch : Dict[str, torch.Tensor]
528528
The batch of data to process.
529529
batch_idx : int
530530
The index of the batch.
@@ -644,7 +644,7 @@ def _shared_eval_step(
644644
645645
Parameters
646646
----------
647-
batch : Dict[Union[str, Modality], Any]
647+
batch : Dict[str, torch.Tensor]
648648
The batch of data to process.
649649
batch_idx : int
650650
The index of the batch.

0 commit comments

Comments
 (0)