Skip to content

Commit 4dea496

Browse files
committed
feat: implemented llama3 weight init tests
1 parent 5f5e616 commit 4dea496

2 files changed

Lines changed: 95 additions & 4 deletions

File tree

tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ app_state_raw:
177177
component_key: app_state
178178
variant_key: raw
179179
config:
180-
model:
180+
model:
181181
instance_key: initialized_model
182182
pass_type: BY_REFERENCE
183183
optimizer:
@@ -288,7 +288,7 @@ optimizer:
288288
eps: 1e-8
289289
weight_decay: 1e-1
290290
weight_decay_groups_excluded: [embedding, layernorm]
291-
wrapped_model:
291+
wrapped_model:
292292
instance_key: initialized_model
293293
pass_type: BY_REFERENCE
294294

tests/test_initialization_fsdpx.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,16 @@
1818
from torch.distributed.fsdp import StateDictType
1919

2020
from modalities.__main__ import Main
21-
from modalities.config.config import ProcessGroupBackendType
22-
from modalities.config.pydantic_if_types import PydanticFSDP1ModuleType, PydanticFSDP2ModuleType
21+
from modalities.config.component_factory import ComponentFactory
22+
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
23+
from modalities.config.pydantic_if_types import (
24+
PydanticFSDP1ModuleType,
25+
PydanticFSDP2ModuleType,
26+
PydanticPytorchModuleType,
27+
)
28+
from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block
29+
from modalities.registry.components import COMPONENTS
30+
from modalities.registry.registry import Registry
2331
from tests.end2end_tests.custom_components import MultiProcessingCudaEnv
2432

2533

@@ -493,3 +501,86 @@ def _get_fdsp2_state_dict(model: FSDP2) -> dict[str, Any]:
493501
model=model, optimizers=[], options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
494502
)[0]
495503
return model_state
504+
505+
506+
class TestLlama3LikeInitialization:
    """Checks the weight statistics of the model built from the llama3 initialization test config.

    The model is instantiated on the fly from ``test_yaml_configs/llama3_config_initalization.yaml``
    (once with and once without bias) and the standard deviation, value range and, where
    applicable, the mean of the embedding, lm_head, attention and SwiGLU parameters are compared
    against the expected values.
    """

    @pytest.mark.parametrize("has_bias", [True, False])
    def test_llama3_like_initialization(self, has_bias: bool):
        """Build the model and verify the initialization of every checked parameter group.

        Args:
            has_bias: Value written into ``model_raw.config.bias`` before the model is built.
        """
        config_file_path = Path(__file__).parent / "test_yaml_configs/llama3_config_initalization.yaml"
        # NOTE(review): presumably mirrors the number of layers in the YAML config - verify.
        n_layer = 4
        model = self._get_components(config_file_path=config_file_path, has_bias=has_bias)
        # _test_lm_head requires the embedding dimension; previously it was called without it,
        # which raised a TypeError before any lm_head check ran.
        # NOTE(review): assumes lm_head is nn.Linear-like with weight of shape
        # (out_features, in_features), so the last dimension is the embedding size - confirm.
        n_emb = model.transformer.lm_head.weight.shape[-1]

        self._test_wte(model=model)
        self._test_lm_head(model=model, n_emb=n_emb)

        for block in model.transformer.h:
            self._test_qkv_proj(gpt2_block=block, has_bias=has_bias)
            self._test_c_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer)
            self._test_swiglu_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer)

    def _get_components(self, config_file_path: Path, has_bias: bool) -> GPT2LLM:
        """Load the config, override the bias flag and build the ``initialized_model`` component.

        Args:
            config_file_path: Path to the YAML config to load.
            has_bias: Value written into ``config_dict["model_raw"]["config"]["bias"]``.

        Returns:
            The ``initialized_model`` component produced by the component factory.
        """
        config_dict = load_app_config_dict(
            config_file_path=config_file_path,
        )
        config_dict["model_raw"]["config"]["bias"] = has_bias
        registry = Registry(COMPONENTS)
        component_factory = ComponentFactory(registry=registry)

        # Only the initialized model is requested from the component factory.
        class ComponentsInstantiationModel(BaseModel):
            initialized_model: PydanticPytorchModuleType

        components: ComponentsInstantiationModel = component_factory.build_components(
            config_dict=config_dict, components_model_type=ComponentsInstantiationModel
        )
        return components.initialized_model

    @staticmethod
    def _assert_std_and_bounds(tensor: torch.Tensor, expected_std: float, bound: float = 2) -> None:
        """Assert that ``tensor`` has the expected std (abs tol 1e-3) and lies within ``[-bound, bound]``."""
        # .item() yields plain Python floats so pytest.approx compares scalars, not tensors.
        assert tensor.std().item() == pytest.approx(expected_std, abs=1e-3)
        assert tensor.max().item() <= bound
        assert tensor.min().item() >= -bound

    def _test_wte(self, model: GPT2LLM):
        """Check that the token embedding weights have std ~1 and mean ~0."""
        weight = model.transformer.wte.weight
        assert weight.std().item() == pytest.approx(1, abs=1e-3)
        assert weight.mean().item() == pytest.approx(0, abs=1e-3)

    def _test_lm_head(self, model: GPT2LLM, n_emb: int):
        """Check that the lm_head weights have std ~1/sqrt(n_emb), mean ~0 and lie within 3/sqrt(n_emb).

        Args:
            model: Model whose ``transformer.lm_head`` is inspected.
            n_emb: Embedding dimension used to derive the expected std and bounds.
        """
        weight = model.transformer.lm_head.weight
        expected_std = 1 / math.sqrt(n_emb)
        self._assert_std_and_bounds(weight, expected_std=expected_std, bound=3 / math.sqrt(n_emb))
        assert weight.mean().item() == pytest.approx(0, abs=1e-3)

    def _test_qkv_proj(self, gpt2_block: GPT2Block, has_bias: bool):
        """Check the q/k/v projections: std ~0.02 and values within [-2, 2] (bias too if present)."""
        layers = (gpt2_block.attn.q_attn, gpt2_block.attn.k_attn, gpt2_block.attn.v_attn)
        for layer in layers:
            self._assert_std_and_bounds(layer.weight, expected_std=0.02)
            if has_bias:
                assert layer.bias is not None
                self._assert_std_and_bounds(layer.bias, expected_std=0.02)

    def _test_c_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int):
        """Check the attention output projection: std ~0.02/sqrt(2*n_layer), values within [-2, 2]."""
        layer = gpt2_block.attn.c_proj
        expected_std = 0.02 / math.sqrt(2 * n_layer)
        self._assert_std_and_bounds(layer.weight, expected_std=expected_std)

        if has_bias:
            assert layer.bias is not None
            self._assert_std_and_bounds(layer.bias, expected_std=expected_std)

    def _test_swiglu_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int):
        """Check the SwiGLU projections of the block's MLP.

        ``V`` and ``W_2`` are expected to have std ~0.02/sqrt(2*n_layer), ``W`` std ~0.02;
        all weights must lie within [-2, 2] and, if bias is enabled, all biases must be zero.
        """
        scaled_std = 0.02 / math.sqrt(2 * n_layer)
        expected = (
            (gpt2_block.mlp.V, scaled_std),
            (gpt2_block.mlp.W_2, scaled_std),
            (gpt2_block.mlp.W, 0.02),
        )
        for layer, expected_std in expected:
            self._assert_std_and_bounds(layer.weight, expected_std=expected_std)

            if has_bias:
                # all zero bias
                assert layer.bias is not None
                assert torch.all(layer.bias == 0)

0 commit comments

Comments
 (0)