Skip to content

Commit c44ec59

Browse files
authored
Merge pull request #262 from Modalities/type-annotations
chore: use built-in types
2 parents a267b0f + ed78497 commit c44ec59

63 files changed

Lines changed: 302 additions & 330 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/modalities/__main__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import shutil
66
from datetime import datetime
77
from pathlib import Path
8-
from typing import List, Tuple, Type
8+
from typing import Type
99

1010
import click
1111
import click_pathlib
@@ -198,7 +198,7 @@ def entry_point_pack_encoded_data(config_path: FilePath):
198198
@data.command(name="merge_packed_data")
199199
@click.argument("src_paths", type=click.types.Path(exists=True, path_type=Path), nargs=-1, required=True)
200200
@click.argument("target_path", type=click.types.Path(file_okay=False, dir_okay=False, path_type=Path))
201-
def entry_point_merge_packed_data(src_paths: List[Path], target_path: Path):
201+
def entry_point_merge_packed_data(src_paths: list[Path], target_path: Path):
202202
"""Utility for merging different pbin-files into one.
203203
This is especially useful if different datasets were encoded at different points in time or if one encoding takes so long
204204
that the overall process was done in chunks.
@@ -207,7 +207,7 @@ def entry_point_merge_packed_data(src_paths: List[Path], target_path: Path):
207207
Specify an arbitrary number of pbin-files and/or directories containing such files as input.
208208
209209
Args:
210-
src_paths (List[Path]): List of paths to the pbin-files or directories containing such.
210+
src_paths (list[Path]): List of paths to the pbin-files or directories containing such.
211211
target_path (Path): The path to the merged pbin-file, that will be created.
212212
"""
213213
input_files = []
@@ -364,7 +364,7 @@ def get_logging_publishers(
364364
results_subscriber: MessageSubscriberIF[EvaluationResultBatch],
365365
global_rank: int,
366366
local_rank: int,
367-
) -> Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]:
367+
) -> tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]:
368368
"""Returns the logging publishers for the training.
369369
370370
These publishers are used to pass the evaluation results and the progress updates to the message broker.
@@ -377,7 +377,7 @@ def get_logging_publishers(
377377
local_rank (int): The local rank of the current process on the current node
378378
379379
Returns:
380-
Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation
380+
tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation
381381
result publisher and the progress publisher
382382
"""
383383
message_broker = MessageBroker()

src/modalities/batch.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass, field
3-
from typing import Dict, Optional
3+
from typing import Optional
44

55
import torch
66

@@ -32,8 +32,8 @@ class Batch(ABC):
3232
class DatasetBatch(Batch, TorchDeviceMixin):
3333
"""A batch of samples and its targets. Used to batch train a model."""
3434

35-
samples: Dict[str, torch.Tensor]
36-
targets: Dict[str, torch.Tensor]
35+
samples: dict[str, torch.Tensor]
36+
targets: dict[str, torch.Tensor]
3737
batch_dim: int = 0
3838

3939
def to(self, device: torch.device):
@@ -58,8 +58,8 @@ def __len__(self) -> int:
5858
class InferenceResultBatch(Batch, TorchDeviceMixin):
5959
"""Stores targets and predictions of an entire batch."""
6060

61-
targets: Dict[str, torch.Tensor]
62-
predictions: Dict[str, torch.Tensor]
61+
targets: dict[str, torch.Tensor]
62+
predictions: dict[str, torch.Tensor]
6363
batch_dim: int = 0
6464

6565
def to_cpu(self):
@@ -106,12 +106,12 @@ class EvaluationResultBatch(Batch):
106106

107107
dataloader_tag: str
108108
num_train_steps_done: int
109-
losses: Dict[str, ResultItem] = field(default_factory=dict)
110-
metrics: Dict[str, ResultItem] = field(default_factory=dict)
111-
throughput_metrics: Dict[str, ResultItem] = field(default_factory=dict)
109+
losses: dict[str, ResultItem] = field(default_factory=dict)
110+
metrics: dict[str, ResultItem] = field(default_factory=dict)
111+
throughput_metrics: dict[str, ResultItem] = field(default_factory=dict)
112112

113113
def __str__(self) -> str:
114-
def _round_result_item_dict(result_item_dict: Dict[str, ResultItem]) -> Dict[str, ResultItem]:
114+
def _round_result_item_dict(result_item_dict: dict[str, ResultItem]) -> dict[str, ResultItem]:
115115
rounded_result_item_dict = {}
116116
for k, item in result_item_dict.items():
117117
if item.decimal_places is not None:

src/modalities/checkpointing/checkpoint_saving.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from enum import Enum
2-
from typing import Dict
32

43
import torch.nn as nn
54
from torch.optim import Optimizer
@@ -43,7 +42,7 @@ def __init__(
4342
def save_checkpoint(
4443
self,
4544
training_progress: TrainingProgress,
46-
evaluation_result: Dict[str, EvaluationResultBatch],
45+
evaluation_result: dict[str, EvaluationResultBatch],
4746
model: nn.Module,
4847
optimizer: Optimizer,
4948
early_stoppping_criterion_fulfilled: bool = False,
@@ -53,7 +52,7 @@ def save_checkpoint(
5352
5453
Args:
5554
training_progress (TrainingProgress): The training progress.
56-
evaluation_result (Dict[str, EvaluationResultBatch]): The evaluation result.
55+
evaluation_result (dict[str, EvaluationResultBatch]): The evaluation result.
5756
model (nn.Module): The model to be saved.
5857
optimizer (Optimizer): The optimizer to be saved.
5958
early_stoppping_criterion_fulfilled (bool, optional):

src/modalities/checkpointing/checkpoint_saving_instruction.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from dataclasses import dataclass, field
2-
from typing import List
32

43
from modalities.training.training_progress import TrainingProgress
54

@@ -11,8 +10,8 @@ class CheckpointingInstruction:
1110
1211
Attributes:
1312
save_current (bool): Indicates whether to save the current checkpoint.
14-
checkpoints_to_delete (List[TrainingProgress]): List of checkpoint IDs to delete.
13+
checkpoints_to_delete (list[TrainingProgress]): List of checkpoint IDs to delete.
1514
"""
1615

1716
save_current: bool = False
18-
checkpoints_to_delete: List[TrainingProgress] = field(default_factory=list)
17+
checkpoints_to_delete: list[TrainingProgress] = field(default_factory=list)

src/modalities/checkpointing/checkpoint_saving_strategies.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import dataclasses
22
from abc import ABC, abstractmethod
3-
from typing import Dict, List, Optional
3+
from typing import Optional
44

55
from modalities.batch import EvaluationResultBatch
66
from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction
@@ -14,15 +14,15 @@ class CheckpointSavingStrategyIF(ABC):
1414
def get_checkpoint_instruction(
1515
self,
1616
training_progress: TrainingProgress,
17-
evaluation_result: Optional[Dict[str, EvaluationResultBatch]] = None,
17+
evaluation_result: Optional[dict[str, EvaluationResultBatch]] = None,
1818
early_stoppping_criterion_fulfilled: bool = False,
1919
) -> CheckpointingInstruction:
2020
"""
2121
Returns the checkpointing instruction.
2222
2323
Parameters:
2424
training_progress (TrainingProgress): The training progress.
25-
evaluation_result (Dict[str, EvaluationResultBatch] | None, optional):
25+
evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
2626
The evaluation result. Defaults to None.
2727
early_stoppping_criterion_fulfilled (bool, optional):
2828
Whether the early stopping criterion is fulfilled. Defaults to False.
@@ -46,29 +46,29 @@ def __init__(self, k: int = -1):
4646
Set to a positive integer to save the specified number of
4747
checkpoints, i.e., only the k most recent checkpoints are kept.
4848
"""
49-
self.saved_step_checkpoints: List[TrainingProgress] = []
49+
self.saved_step_checkpoints: list[TrainingProgress] = []
5050
self.k = k
5151

5252
def get_checkpoint_instruction(
5353
self,
5454
training_progress: TrainingProgress,
55-
evaluation_result: Dict[str, EvaluationResultBatch] | None = None,
55+
evaluation_result: dict[str, EvaluationResultBatch] | None = None,
5656
early_stoppping_criterion_fulfilled: bool = False,
5757
) -> CheckpointingInstruction:
5858
"""
5959
Generates a checkpointing instruction based on the given parameters.
6060
6161
Args:
6262
training_progress (TrainingProgress): The training progress.
63-
evaluation_result (Dict[str, EvaluationResultBatch] | None, optional):
63+
evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
6464
The evaluation result. Defaults to None.
6565
early_stoppping_criterion_fulfilled (bool, optional):
6666
Whether the early stopping criterion is fulfilled. Defaults to False.
6767
6868
Returns:
6969
CheckpointingInstruction: The generated checkpointing instruction.
7070
"""
71-
checkpoints_to_delete: List[TrainingProgress] = []
71+
checkpoints_to_delete: list[TrainingProgress] = []
7272
save_current = True
7373

7474
if self.k > 0:
@@ -101,15 +101,15 @@ def __init__(self, k: int):
101101
def get_checkpoint_instruction(
102102
self,
103103
training_progress: TrainingProgress,
104-
evaluation_result: Dict[str, EvaluationResultBatch] | None = None,
104+
evaluation_result: dict[str, EvaluationResultBatch] | None = None,
105105
early_stoppping_criterion_fulfilled: bool = False,
106106
) -> CheckpointingInstruction:
107107
"""
108108
Returns a CheckpointingInstruction object.
109109
110110
Args:
111111
training_progress (TrainingProgress): The training progress.
112-
evaluation_result (Dict[str, EvaluationResultBatch] | None, optional):
112+
evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
113113
The evaluation result. Defaults to None.
114114
early_stoppping_criterion_fulfilled (bool, optional):
115115
Whether the early stopping criterion is fulfilled. Defaults to False.

src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from pathlib import Path
2-
from typing import List
32

43
import torch
54
import torch.nn as nn
@@ -17,7 +16,7 @@ class FSDPCheckpointLoading(CheckpointLoadingIF):
1716
def __init__(
1817
self,
1918
global_rank: int,
20-
block_names: List[str],
19+
block_names: list[str],
2120
mixed_precision_settings: MixedPrecisionSettings,
2221
sharding_strategy: ShardingStrategy,
2322
):
@@ -26,7 +25,7 @@ def __init__(
2625
2726
Args:
2827
global_rank (int): The global rank of the process.
29-
block_names (List[str]): The names of the blocks.
28+
block_names (list[str]): The names of the blocks.
3029
mixed_precision_settings (MixedPrecisionSettings): The settings for mixed precision.
3130
sharding_strategy (ShardingStrategy): The sharding strategy.
3231

src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from enum import Enum
22
from pathlib import Path
3-
from typing import List
43

54
import torch
65
import torch.distributed as dist
@@ -124,7 +123,7 @@ def _save_checkpoint(self, model: FSDP, optimizer: Optimizer, training_progress:
124123
# leading to wrong throughput measurements.
125124
dist.barrier()
126125

127-
def _get_paths_to_delete(self, training_progress: TrainingProgress) -> List[Path]:
126+
def _get_paths_to_delete(self, training_progress: TrainingProgress) -> list[Path]:
128127
return [
129128
self._get_checkpointing_path(
130129
experiment_id=self.experiment_id,

src/modalities/config/component_factory.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Type, TypeVar, Union
1+
from typing import Any, Type, TypeVar
22

33
from pydantic import BaseModel
44

@@ -19,12 +19,12 @@ def __init__(self, registry: Registry) -> None:
1919
"""
2020
self.registry = registry
2121

22-
def build_components(self, config_dict: Dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild:
22+
def build_components(self, config_dict: dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild:
2323
"""Builds the components from a config dictionary. All components specified in `components_model_type`
2424
are built from the config dictionary in a recursive manner.
2525
2626
Args:
27-
config_dict (Dict): Dictionary with the configuration of the components.
27+
config_dict (dict): Dictionary with the configuration of the components.
2828
components_model_type (Type[BaseModelChild]): Base model type defining the components to be built.
2929
3030
Returns:
@@ -35,7 +35,7 @@ def build_components(self, config_dict: Dict, components_model_type: Type[BaseMo
3535
components = components_model_type(**component_dict)
3636
return components
3737

38-
def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[str, Any]:
38+
def _build_config(self, config_dict: dict, component_names: list[str]) -> dict[str, Any]:
3939
component_dict_filtered = {name: config_dict[name] for name in component_names}
4040
components, _ = self._build_component(
4141
current_component_config=component_dict_filtered,
@@ -47,10 +47,10 @@ def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[s
4747

4848
def _build_component(
4949
self,
50-
current_component_config: Union[Dict, List, Any],
51-
component_config: Union[Dict, List, Any],
52-
top_level_components: Dict[str, Any],
53-
traversal_path: List,
50+
current_component_config: dict | list | Any,
51+
component_config: dict | list | Any,
52+
top_level_components: dict[str, Any],
53+
traversal_path: list,
5454
) -> Any:
5555
# build sub components first via recursion
5656
if isinstance(current_component_config, dict):
@@ -130,16 +130,16 @@ def _build_component(
130130
return current_component_config, top_level_components
131131

132132
@staticmethod
133-
def _is_component_config(config_dict: Dict) -> bool:
133+
def _is_component_config(config_dict: dict) -> bool:
134134
# TODO instead of field checks, we should introduce an enum for the config type.
135135
return "component_key" in config_dict.keys()
136136

137137
@staticmethod
138-
def _is_reference_config(config_dict: Dict) -> bool:
138+
def _is_reference_config(config_dict: dict) -> bool:
139139
# TODO instead of field checks, we should introduce an enum for the config type.
140140
return {"instance_key", "pass_type"} == config_dict.keys()
141141

142-
def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: Dict) -> BaseModel:
142+
def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel:
143143
component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key)
144144
self._assert_valid_config_keys(
145145
component_key=component_key,
@@ -151,7 +151,7 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co
151151
return comp_config
152152

153153
def _assert_valid_config_keys(
154-
self, component_key: str, variant_key: str, config_dict: Dict, component_config_type: Type[BaseModelChild]
154+
self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild]
155155
) -> None:
156156
required_keys = []
157157
optional_keys = []
@@ -178,7 +178,7 @@ def _instantiate_component(self, component_key: str, variant_key: str, component
178178
return component
179179

180180
@staticmethod
181-
def _base_model_to_dict(base_model: BaseModel) -> Dict:
181+
def _base_model_to_dict(base_model: BaseModel) -> dict:
182182
# converts top level structure of base_model into dictionary while maintaining substructure
183183
output = {}
184184
for name, _ in base_model.model_fields.items():

0 commit comments

Comments
 (0)