Skip to content

Commit ed78497

Browse files
committed
Merge remote-tracking branch 'origin/main' into type-annotations
2 parents b880e2e + a267b0f commit ed78497

6 files changed

Lines changed: 24 additions & 9 deletions

File tree

.github/workflows/tests_full.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
sudo apt-get update
2424
sudo apt-get install curl -y # required by coveralls
2525
sudo apt-get install git -y
26-
python -m pip install torch
26+
python -m pip install torch~=2.4.1
2727
python -m pip install --upgrade pip setuptools wheel
2828
export FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
2929
python -m pip install -e .[tests]

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Currently, the flash attention dependency cannot be installed without torch bein
5959
Until the flash attention developers fix this, we have to run
6060

6161
```sh
62-
pip install torch
62+
pip install torch~=2.4.1
6363
```
6464
beforehand.
6565

@@ -75,7 +75,7 @@ pip install -e .
7575
To install Modalities via pip, run
7676

7777
```sh
78-
pip install torch
78+
pip install torch~=2.4.1
7979
pip install modalities
8080
```
8181

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ description = "Modalities, a PyTorch-native framework for distributed and reprod
66
readme = "README.md"
77
dependencies = [
88
"numpy<2.0",
9-
"torch>=2.3",
9+
"torch~=2.4.1",
1010
"packaging",
1111
"tqdm",
1212
"pyyaml",

src/modalities/checkpointing/checkpoint_saving_strategies.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import dataclasses
12
from abc import ABC, abstractmethod
23
from typing import Optional
34

@@ -71,15 +72,15 @@ def get_checkpoint_instruction(
7172
save_current = True
7273

7374
if self.k > 0:
74-
self.saved_step_checkpoints = [training_progress] + self.saved_step_checkpoints
75+
self.saved_step_checkpoints = [dataclasses.replace(training_progress)] + self.saved_step_checkpoints
7576
if len(self.saved_step_checkpoints) > self.k:
7677
# Delete oldest checkpoint
7778
checkpoints_to_delete = [self.saved_step_checkpoints[-1]]
7879
self.saved_step_checkpoints = self.saved_step_checkpoints[:-1]
7980
elif self.k == 0:
8081
save_current = False
8182
elif self.k == -1:
82-
self.saved_step_checkpoints = [training_progress] + self.saved_step_checkpoints
83+
self.saved_step_checkpoints = [dataclasses.replace(training_progress)] + self.saved_step_checkpoints
8384

8485
return CheckpointingInstruction(save_current=save_current, checkpoints_to_delete=checkpoints_to_delete)
8586

src/modalities/dataloader/dataloader.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
from typing import Iterable, Optional
22

33
from torch.utils.data import Dataset, DistributedSampler, Sampler
4-
from torch.utils.data.dataloader import DataLoader, T_co, _collate_fn_t, _worker_init_fn_t
4+
from torch.utils.data.dataloader import DataLoader, _collate_fn_t, _worker_init_fn_t
5+
6+
try: # torch <= 2.4
7+
from torch.utils.data.dataloader import T_co
8+
except ImportError: # torch >= 2.5
9+
from torch.utils.data.dataloader import _T_co as T_co
510

611
from modalities.dataloader.samplers import ResumableBatchSampler
712

tests/checkpointing/test_checkpoint_strategies.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# k value is 0. No deletion of checkpoints.
1313
(0, [], [], False),
1414
# k value is 2, but there is currently only one checkpoint. Hence, no deletion.
15-
(2, [1], [], True),
15+
(2, [TrainingProgress(1, 1, 20, 20)], [], True),
1616
# k value is -1, therefore we want to keep all checkpoints without any deletion
1717
(
1818
-1,
@@ -25,12 +25,21 @@
2525
def test_checkpoint_strategy_k(
2626
k: int, saved_instances: list[TrainingProgress], checkpoints_to_delete: list[int], save_current: bool
2727
) -> None:
28+
num_seen_steps_current_run = 10
2829
training_progress = TrainingProgress(
29-
num_seen_steps_current_run=10, num_seen_tokens_current_run=10, num_target_steps=20, num_target_tokens=40
30+
num_seen_steps_current_run=num_seen_steps_current_run,
31+
num_seen_tokens_current_run=10,
32+
num_target_steps=20,
33+
num_target_tokens=40,
3034
)
3135
checkpoint_strategy = SaveKMostRecentCheckpointsStrategy(k=k)
3236
checkpoint_strategy.saved_step_checkpoints = saved_instances
3337
checkpoint_instruction = checkpoint_strategy.get_checkpoint_instruction(training_progress=training_progress)
3438

3539
assert checkpoint_instruction.checkpoints_to_delete == checkpoints_to_delete
3640
assert checkpoint_instruction.save_current == save_current
41+
42+
# make sure that modifying the training progress externally does not affect saved_step_checkpoints
43+
if k != 0 and save_current:
44+
training_progress.num_seen_steps_current_run = 100
45+
assert checkpoint_strategy.saved_step_checkpoints[0].num_seen_steps_current_run == num_seen_steps_current_run

0 commit comments

Comments
 (0)