Commit 9253ad1

refactor: Applied formatters.
1 parent e3d75c7 commit 9253ad1

5 files changed, 15 additions & 11 deletions
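All five files below were touched only mechanically: long calls re-wrapped across lines, a stray blank first line and one comment's spacing fixed, and one unused import dropped. The wrap point of roughly 120 characters suggests a formatter such as black or ruff configured with line-length 120, though the commit message does not name the tool.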

src/modalities/evaluator.py

Lines changed: 3 additions & 1 deletion
@@ -124,7 +124,9 @@ def evaluate(
             cumulated_loss[0] += batch_loss.item()  # sum up batch loss
             cumulated_loss[1] += 1
             batch_length_tensor = torch.tensor(len(batch)).to(device)
-            throughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
+            throughput_aggregator.add_value(
+                key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor
+            )

             Evaluator._publish_progress(
                 progress_publisher=self.progress_publisher,
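For readers without the full source: the aggregator exercised here is a keyed accumulator that the evaluation loop feeds with per-batch sample counts. A minimal sketch of such an interface, with hypothetical internals (only add_value(key=..., value=...) and ThroughputAggregationKeys.NUM_SAMPLES are visible in this diff):

from collections import defaultdict
from enum import Enum

import torch


class ThroughputAggregationKeys(Enum):
    NUM_SAMPLES = "num_samples"


class ThroughputAggregator:
    """Keyed accumulator; hypothetical reconstruction, not modalities' actual class."""

    def __init__(self) -> None:
        self._totals = defaultdict(lambda: torch.tensor(0))

    def add_value(self, key: ThroughputAggregationKeys, value: torch.Tensor) -> None:
        # Sum per-key values, e.g. the number of samples processed per step.
        self._totals[key] = self._totals[key] + value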

src/modalities/utils/mfu.py

Lines changed: 1 addition & 4 deletions
@@ -156,10 +156,7 @@ def __init__(
         wrapped_model: FSDPX,
         device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
     ):
-        self._num_params = get_total_number_of_trainable_parameters(
-            model=wrapped_model,
-            device_mesh=device_mesh
-        )
+        self._num_params = get_total_number_of_trainable_parameters(model=wrapped_model, device_mesh=device_mesh)
         self._n_layer = n_layer
         self._sequence_length = sequence_length
         self._n_embd = n_embd
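The constructor caches exactly the quantities needed for a theoretical FLOPs-per-token estimate. A minimal sketch of the standard approximation (the PaLM-style 6N + 12·L·H·Q term; whether mfu.py uses precisely this variant is not visible in this diff):

def theoretical_flops_per_token(num_params: int, n_layer: int, n_embd: int, sequence_length: int) -> int:
    # 6 * N: forward + backward matmul FLOPs per parameter per token.
    # 12 * L * H * Q: attention-score FLOPs, which scale with sequence
    # length rather than with parameter count.
    return 6 * num_params + 12 * n_layer * n_embd * sequence_length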

tests/fsdp2_parallelization/test_full_and_hybrid_sharding.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 import torch.multiprocessing as mp
 import yaml
 from pydantic import BaseModel
-from torch.distributed.fsdp import FSDPModule as FSDP2

 from modalities.__main__ import Main
 from modalities.config.config import ProcessGroupBackendType

tests/test_torch_compile.py

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,3 @@
-
 import copy

 import pytest
@@ -68,8 +67,12 @@ def test_get_compiled_model_compiles_blocks(gpt2_model):
     result_model = ModelFactory.get_compiled_model(gpt2_model, block_names, fullgraph=True)

     assert len(result_model.transformer.h) == 4, "Should still have four blocks"
-    for i, (original_block_idx, new_block_idx) in enumerate(zip(original_model.transformer.h, result_model.transformer.h)):
-        assert result_model.transformer.h[new_block_idx] is not original_model.transformer.h[original_block_idx], f"Block {i} should be a compiled version"
+    for i, (original_block_idx, new_block_idx) in enumerate(
+        zip(original_model.transformer.h, result_model.transformer.h)
+    ):
+        assert (
+            result_model.transformer.h[new_block_idx] is not original_model.transformer.h[original_block_idx]
+        ), f"Block {i} should be a compiled version"
         assert isinstance(result_model.transformer.h[new_block_idx], nn.Module), f"Block {i} should be an nn.Module"
     assert result_model.transformer.wte is original_wte, "Embedding layer should remain unchanged"
     assert result_model.transformer.lm_head is original_lm_head, "LM head should remain unchanged"
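The assertions above pin down block-wise compilation: every transformer block must be swapped for a compiled counterpart (its object identity changes), while the shared embedding and LM head must survive untouched. A sketch of how such a factory could work (the dotted-path lookup and function name are assumptions, not ModelFactory's actual implementation):

import torch
import torch.nn as nn


def compile_blocks(model: nn.Module, block_names: list[str], fullgraph: bool = False) -> nn.Module:
    """Hypothetical sketch: compile each child of the named ModuleLists in place."""
    for name in block_names:
        container = model.get_submodule(name)  # e.g. "transformer.h"
        for idx, block in enumerate(container):
            # torch.compile returns a wrapper module, so object identity
            # changes per block, which the `is not` assertions rely on.
            container[idx] = torch.compile(block, fullgraph=fullgraph)
    return model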

tests/utils/test_mfu.py

Lines changed: 5 additions & 2 deletions
@@ -313,7 +313,10 @@ def test_get_theoretical_flops_per_token(
         assert theoretical_flops_per_token == expected_theoretical_flops_per_token

     @staticmethod
-    @pytest.mark.skipif(torch.cuda.device_count() < 2 or not torch.cuda.get_device_name().startswith("NVIDIA A100"), reason="This test requires 2 A100 GPUs.")
+    @pytest.mark.skipif(
+        torch.cuda.device_count() < 2 or not torch.cuda.get_device_name().startswith("NVIDIA A100"),
+        reason="This test requires 2 A100 GPUs.",
+    )
     @pytest.mark.parametrize(
         "rdvz_port, relative_config_path, num_samples_per_second_per_gpu, expected_mfu",
         [
@@ -339,7 +342,7 @@ def test_compute_mfu(
         TestMFU._save_yaml_config(config_file_path=tmp_config_file_path, config=config_updated)

         # run the test in a distributed environment
-        world_size = 2 #torch.cuda.device_count()
+        world_size = 2  # torch.cuda.device_count()
         num_samples_per_second = num_samples_per_second_per_gpu * world_size
         mp.spawn(
             TestMFU._test_compute_mfu_thread,
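For context, the MFU figure the spawned workers compute boils down to achieved FLOP/s divided by aggregate peak FLOP/s, which is why the skipif above pins the hardware to A100s (dense bf16 peak of 312 TFLOP/s). A minimal sketch; the helper name and signature are illustrative, not modalities' API:

A100_BF16_PEAK_FLOPS = 312e12  # per-GPU dense bf16 peak


def compute_mfu(num_samples_per_second: float, sequence_length: int,
                flops_per_token: float, world_size: int) -> float:
    # Achieved FLOP/s across all ranks over the job's aggregate peak FLOP/s.
    achieved = num_samples_per_second * sequence_length * flops_per_token
    return achieved / (world_size * A100_BF16_PEAK_FLOPS)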
