diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c503ffb84..426d3a82b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -22,6 +22,7 @@ env: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} GOOGLE_CREDENTIALS: ${{ secrets.GOOGLE_CREDENTIALS }} + HF_TOKEN: ${{ secrets.HF_TOKEN }} jobs: checks: diff --git a/CHANGELOG.md b/CHANGELOG.md index a8976a91e..4c4e44fcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Distributed checkpoint writes now clone each tensor before serialization to avoid accidentally writing the full backing storage of a view/shared tensor, with a guard that raises `OLMoCheckpointError` if a written tensor is unexpectedly larger than its `nbytes`. - Fixed LM in-loop evaluator data-order drift across repeated runs by resetting loader bookkeeping before each pass and making deterministic reshuffling the default. - Fixed Qwen3 implementation to match HuggingFace by applying RoPE in the input dtype (bf16) rather than upcasting to fp32. +- Fixed HF model conversion for Llama, Qwen3, and Gemma so that converted checkpoints roundtrip correctly. - Fixed Beaker secret existence check to use the case-insensitive HTTP endpoint, avoiding spurious "secret not found" errors when secret names differ only in case. - Fixed `Transformer.init_weights` so that under interleaved pipeline parallelism (e.g. `Interleaved1F1B`, `InterleavedZeroBubble`) the multiple model chunks owned by a single rank no longer initialize to identical parameters. Adds a `model_part_idx` kwarg incorporated into the seed as `model_part_idx * pp_size`. - Disabled `torch.compile` tracing through `TEAttentionBackend.forward`, whose Python/pybind setup is not Dynamo-safe. 
diff --git a/src/olmo_core/nn/hf/config.py b/src/olmo_core/nn/hf/config.py index fd0f90df6..811543d79 100644 --- a/src/olmo_core/nn/hf/config.py +++ b/src/olmo_core/nn/hf/config.py @@ -79,7 +79,7 @@ def _get_flex_olmo_config(model: MoETransformer) -> PretrainedConfig: rms_norm_eps=block.feed_forward_norm.eps, num_experts_per_tok=block.feed_forward_moe.router.top_k, num_experts=block.feed_forward_moe.router.num_experts, - tie_word_embeddings=False, + tie_word_embeddings=model.tie_word_embeddings, ) @@ -133,7 +133,7 @@ def get_hf_config(model: Transformer) -> PretrainedConfig: "bos_token_id": None, "eos_token_id": None, "rms_norm_eps": first_block.feed_forward_norm.eps, - "tie_word_embeddings": False, + "tie_word_embeddings": model.tie_word_embeddings, } # The OLMo 3 model family is identical to the OLMo 2 model family, except: @@ -387,7 +387,7 @@ def get_hybrid_hf_config( "attention_bias": attn.w_out.bias is not None, "attention_dropout": 0.0, "rms_norm_eps": attn_block.feed_forward_norm.eps, # todo: revisit - "tie_word_embeddings": False, + "tie_word_embeddings": model.tie_word_embeddings, # Hybrid layer configuration "layer_types": layer_types, # GDN (linear attention) parameters diff --git a/src/olmo_core/nn/hf/convert.py b/src/olmo_core/nn/hf/convert.py index 18b1b9f8d..4b275413c 100644 --- a/src/olmo_core/nn/hf/convert.py +++ b/src/olmo_core/nn/hf/convert.py @@ -309,7 +309,53 @@ f"model.layers.{LAYER}.mlp.gate.weight", unflatten_dim=(0, (TemplatePlaceholder.EXPERT, -1)), ), - } + }, + "llama": { + f"blocks.{LAYER}.attention_norm.weight": StateMappingTemplate( + f"blocks.{LAYER}.attention_norm.weight", + f"model.layers.{LAYER}.input_layernorm.weight", + state_type=StateType.weight, + ), + f"blocks.{LAYER}.feed_forward_norm.weight": StateMappingTemplate( + f"blocks.{LAYER}.feed_forward_norm.weight", + f"model.layers.{LAYER}.post_attention_layernorm.weight", + state_type=StateType.weight, + ), + }, + "qwen3": { + f"blocks.{LAYER}.attention_norm.weight": 
StateMappingTemplate( + f"blocks.{LAYER}.attention_norm.weight", + f"model.layers.{LAYER}.input_layernorm.weight", + state_type=StateType.weight, + ), + f"blocks.{LAYER}.feed_forward_norm.weight": StateMappingTemplate( + f"blocks.{LAYER}.feed_forward_norm.weight", + f"model.layers.{LAYER}.post_attention_layernorm.weight", + state_type=StateType.weight, + ), + }, + "gemma3_text": { + f"blocks.{LAYER}.attention_norm.weight": StateMappingTemplate( + f"blocks.{LAYER}.attention_norm.weight", + f"model.layers.{LAYER}.input_layernorm.weight", + state_type=StateType.weight, + ), + f"blocks.{LAYER}.post_attention_norm.weight": StateMappingTemplate( + f"blocks.{LAYER}.post_attention_norm.weight", + f"model.layers.{LAYER}.post_attention_layernorm.weight", + state_type=StateType.weight, + ), + f"blocks.{LAYER}.feed_forward_norm.weight": StateMappingTemplate( + f"blocks.{LAYER}.feed_forward_norm.weight", + f"model.layers.{LAYER}.pre_feedforward_layernorm.weight", + state_type=StateType.weight, + ), + f"blocks.{LAYER}.post_feed_forward_norm.weight": StateMappingTemplate( + f"blocks.{LAYER}.post_feed_forward_norm.weight", + f"model.layers.{LAYER}.post_feedforward_layernorm.weight", + state_type=StateType.weight, + ), + }, } @@ -427,6 +473,9 @@ def convert_state_from_hf( if model_type == "gemma3_text": converted_state = _apply_gemma3_norm_transform(converted_state) + if config.tie_word_embeddings: + converted_state["lm_head.w_out.weight"] = converted_state["embeddings.weight"] + return converted_state @@ -456,6 +505,29 @@ def get_converter_to_hf(model_type: str | None = None) -> StateConverter: return _get_converter_to_hf(model_type) +def _apply_gemma3_norm_inverse_transform(state: Dict[str, Any]) -> Dict[str, Any]: + """ + Inverse of :func:`_apply_gemma3_norm_transform`: subtracts 1 from norm weights + so that an OLMo-core checkpoint can be exported back into HF Gemma 3 format. 
+ """ + norm_patterns = [ + "input_layernorm.weight", + "post_attention_layernorm.weight", + "pre_feedforward_layernorm.weight", + "post_feedforward_layernorm.weight", + "model.norm.weight", + "q_norm.weight", + "k_norm.weight", + ] + + for key, value in state.items(): + if any(pattern in key for pattern in norm_patterns): + if isinstance(value, torch.Tensor): + state[key] = value - 1.0 + + return state + + @beta_feature def convert_state_to_hf( config: PretrainedConfig, olmo_core_state: Dict[str, Any] @@ -468,9 +540,14 @@ def convert_state_to_hf( :class:`DTensor` or :class:`ShardedTensor` """ - converter = _get_converter_to_hf(getattr(config, "model_type", None)) + converter = _get_converter_to_hf(config.model_type) + + converted_state = _convert_state(config, olmo_core_state, converter) - return _convert_state(config, olmo_core_state, converter) + if config.model_type == "gemma3_text": + converted_state = _apply_gemma3_norm_inverse_transform(converted_state) + + return converted_state # --------------------------------------------------------------------------- diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index 3626f3559..4ec9fbf97 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -328,8 +328,13 @@ class TransformerConfig(ModelConfig): block_pattern: Optional[List[str]] = None block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None embed_scale: Optional[float] = None + tie_word_embeddings: bool = False def __post_init__(self): + if self.tie_word_embeddings and self.name == TransformerType.normalized: + raise OLMoConfigurationError( + "Tying word embeddings is not supported with the normalized transformer" + ) validate_block_resolution_config( n_layers=self.n_layers, block=self.block, @@ -380,6 +385,7 @@ def build( block_overrides=self.block_overrides, block_pattern=self.block_pattern, embed_scale=self.embed_scale, + 
tie_word_embeddings=self.tie_word_embeddings, ) elif self.name == TransformerType.normalized: assert self.embedding_norm is None @@ -414,6 +420,7 @@ def build( embedding_init_std=self.embedding_init_std, block_overrides=self.block_overrides, block_pattern=self.block_pattern, + tie_word_embeddings=self.tie_word_embeddings, ) else: raise NotImplementedError(self.name) @@ -466,6 +473,10 @@ def num_params(self) -> int: # LM head. num_params += self.lm_head.num_params(self.d_model, self.vocab_size) + # The LM head weight is shared with the embeddings when tied. + if self.tie_word_embeddings: + num_params -= self.d_model * self.vocab_size + return num_params @property @@ -487,6 +498,10 @@ def num_active_params(self) -> int: # LM head. num_active_params += self.lm_head.num_params(self.d_model, self.vocab_size) + # The LM head weight is shared with the embeddings when tied. + if self.tie_word_embeddings: + num_active_params -= self.d_model * self.vocab_size + return num_active_params @property @@ -1301,6 +1316,7 @@ def qwen3_0_6B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": feed_forward=FeedForwardConfig( hidden_size=3072, bias=False, dtype=kwargs.get("dtype", DType.float32) ), + tie_word_embeddings=kwargs.pop("tie_word_embeddings", True), **kwargs, ) @@ -1337,6 +1353,7 @@ def qwen3_1_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": feed_forward=FeedForwardConfig( hidden_size=6144, bias=False, dtype=kwargs.get("dtype", DType.float32) ), + tie_word_embeddings=kwargs.pop("tie_word_embeddings", True), **kwargs, ) @@ -1373,6 +1390,7 @@ def qwen3_4B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": feed_forward=FeedForwardConfig( hidden_size=9728, bias=False, dtype=kwargs.get("dtype", DType.float32) ), + tie_word_embeddings=kwargs.pop("tie_word_embeddings", True), **kwargs, ) diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 81c95ed31..65b26f5c6 100644 --- a/src/olmo_core/nn/transformer/model.py +++ 
b/src/olmo_core/nn/transformer/model.py @@ -46,7 +46,7 @@ from ..buffer_cache import BufferCache from ..functional import l2_normalize from ..layer_norm import LayerNormConfig -from ..lm_head import LMHeadConfig, LMOutputWithLoss +from ..lm_head import LMHeadConfig, LMLossImplementation, LMOutputWithLoss from ..moe import MoEBase from ..rope import RoPEBuffers, RotaryEmbeddingBase from ..utils import selective_checkpointing_context_fn @@ -117,6 +117,7 @@ def __init__( block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None, block_pattern: Optional[List[str]] = None, embed_scale: Optional[float] = None, + tie_word_embeddings: bool = False, ): super().__init__() @@ -160,6 +161,10 @@ def __init__( d_model=d_model, vocab_size=vocab_size, init_device=init_device ) + self.tie_word_embeddings = tie_word_embeddings + if tie_word_embeddings: + self._tie_weights() + self.init_device = init_device self.init_method = InitMethod(init_method) self.init_seed = init_seed @@ -183,6 +188,15 @@ def __init__( self.num_params self.num_non_embedding_params + def _tie_weights(self) -> None: + if self.embeddings is None or self.lm_head is None: + raise OLMoConfigurationError( + "Cannot tie word embeddings without both embeddings and an LM head" + ) + if self.lm_head.w_out.bias is not None: + raise OLMoConfigurationError("Cannot tie word embeddings when the LM head uses a bias") + self.lm_head.w_out.weight = self.embeddings.weight + def _validate_block(self, block: TransformerBlockBase) -> TransformerBlockBase: return block @@ -295,6 +309,10 @@ def init_weights( generator=generator, ) + # Re-establish weight tying since `to_empty` above allocates fresh storage. + if self.tie_word_embeddings: + self._tie_weights() + for block in self.blocks.values(): # This might fail if it's wrapped. 
# assert isinstance(block, TransformerBlock) @@ -345,7 +363,7 @@ def init_weights( if max_seq_len is not None and att.rope is not None: att.rope.warmup_cache(max_seq_len, device) - if self.lm_head is not None: + if self.lm_head is not None and not self.tie_word_embeddings: self.init_method.init_final_w_out( self.lm_head.w_out, d_model=self.d_model, @@ -616,6 +634,16 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None): :param loss_parallel: Set to ``True`` if parallelizing the loss function as well. :param float8_enabled: Set this to ``True`` if training with float8 linear layers. """ + if self.tie_word_embeddings and ( + self.lm_head is None + or self.lm_head.loss_implementation == LMLossImplementation.fused_linear + ): + raise NotImplementedError( + "Tensor parallelism with tied word embeddings requires the default loss " + "implementation; the fused-linear loss replicates the LM head weight, which is " + "incompatible with the vocab-sharded embedding." + ) + if float8_enabled is None: float8_enabled = self.fp8_enabled elif not float8_enabled and self.fp8_enabled: @@ -646,6 +674,12 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None): if self.lm_head is not None: self.lm_head.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate())) + # The embedding (RowwiseParallel) and the LM head (ColwiseParallel) both shard their + # weight along the vocab dimension, so re-point the head at the embedding's sharded + # parameter to restore the tie that `parallelize_module` broke. + if self.tie_word_embeddings and self.embeddings is not None and self.lm_head is not None: + self._tie_weights() + self._tp_enabled = True self._tp_mesh = tp_mesh @@ -831,7 +865,9 @@ def apply_fsdp( mp_policy=mp_policy, ) - if self.embeddings is not None: + # When weights are tied the embeddings and LM head share a parameter, so they must + # stay in the same FSDP group (the root) rather than being sharded separately. 
+ if self.embeddings is not None and not self.tie_word_embeddings: fully_shard( self.embeddings, reshard_after_forward=reshard_after_forward, @@ -843,7 +879,7 @@ def apply_fsdp( if wrapping_strategy != TransformerDataParallelWrappingStrategy.blocks: if self.embedding_norm is not None: fully_shard(self.embedding_norm, **fsdp_config) - if self.lm_head is not None: + if self.lm_head is not None and not self.tie_word_embeddings: fully_shard(self.lm_head, reshard_after_forward=False, **fsdp_config) fully_shard(self, reshard_after_forward=reshard_after_forward, **fsdp_config) diff --git a/src/olmo_core/train/train_module/transformer/config.py b/src/olmo_core/train/train_module/transformer/config.py index ab60836e8..b76f7d3db 100644 --- a/src/olmo_core/train/train_module/transformer/config.py +++ b/src/olmo_core/train/train_module/transformer/config.py @@ -93,6 +93,15 @@ def split_model( self, model: Transformer, *, pp_mesh: DeviceMesh, device: torch.device ) -> Tuple[List[PipelineStage], List[Transformer]]: split_points = self.get_split_points(model.n_layers) + num_stages = len(split_points) + 1 + + if num_stages > 1 and model.tie_word_embeddings: + raise NotImplementedError( + "Pipeline parallelism with tied word embeddings is not supported: the input " + "embeddings and LM head are placed on different pipeline stages, so they cannot " + "share a weight." 
+ ) + pp_rank = pp_mesh.get_local_rank() def build_stage( @@ -128,7 +137,6 @@ def build_stage( ) return stage, model_chunk - num_stages = len(split_points) + 1 stage_idx = pp_rank stages = [] diff --git a/src/test/nn/hf/convert_test.py b/src/test/nn/hf/convert_test.py index 66ae0e730..bfcfd3d4b 100644 --- a/src/test/nn/hf/convert_test.py +++ b/src/test/nn/hf/convert_test.py @@ -1,6 +1,6 @@ import pytest import torch -from transformers import Olmo2Config +from transformers import AutoConfig, AutoModelForCausalLM, Olmo2Config from olmo_core.nn.hf.convert import convert_state_from_hf, convert_state_to_hf @@ -118,6 +118,32 @@ def test_convert_state_from_hf_and_flatten(): ) +def test_convert_state_from_hf_ties_word_embeddings(): + hf_config = AutoConfig.for_model( + "qwen3", + vocab_size=64, + hidden_size=16, + intermediate_size=32, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + max_position_embeddings=64, + tie_word_embeddings=True, + ) + + # A tied HF checkpoint omits `lm_head.weight`. 
def test_llama_tiny_roundtrip_pre_norm():
    """
    Convert per-layer norm weights for a tiny Llama config to HF format and check
    that they land under the expected HF layer-norm keys.
    """
    hf_config = AutoConfig.for_model(
        "llama",
        vocab_size=64,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        max_position_embeddings=64,
    )
    num_layers = hf_config.num_hidden_layers
    width = hf_config.hidden_size

    # Give every layer distinct fill values so a swapped mapping would be detected.
    olmo_core_state = {}
    for layer in range(num_layers):
        olmo_core_state.update(
            {
                f"blocks.{layer}.attention_norm.weight": torch.full((width,), 1.0 + layer),
                f"blocks.{layer}.feed_forward_norm.weight": torch.full((width,), 100.0 + layer),
            }
        )

    hf_roundtrip = convert_state_to_hf(hf_config, olmo_core_state)

    for layer in range(num_layers):
        hf_prefix = f"model.layers.{layer}"
        torch.testing.assert_close(
            hf_roundtrip[f"{hf_prefix}.input_layernorm.weight"],
            olmo_core_state[f"blocks.{layer}.attention_norm.weight"],
        )
        torch.testing.assert_close(
            hf_roundtrip[f"{hf_prefix}.post_attention_layernorm.weight"],
            olmo_core_state[f"blocks.{layer}.feed_forward_norm.weight"],
        )
        # No post-feed-forward norm key should appear in the converted Llama state.
        assert f"{hf_prefix}.post_feedforward_layernorm.weight" not in hf_roundtrip
AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) + hf_model.eval() + + input_ids = torch.randint( + 0, hf_config.vocab_size, (1, 8), generator=torch.Generator().manual_seed(0) + ) + with torch.no_grad(): + ref_logits = hf_model(input_ids).logits + ref_logprobs = torch.log_softmax(ref_logits, dim=-1) + + hf_state = {k: v.detach().cpu().clone() for k, v in hf_model.state_dict().items()} + oc_state = convert_state_from_hf(hf_config, hf_state, model_type=model_type) + hf_roundtrip = convert_state_to_hf(hf_config, oc_state) + + hf_model.load_state_dict(hf_roundtrip, strict=True) + hf_model.eval() + with torch.no_grad(): + rt_logits = hf_model(input_ids).logits + rt_logprobs = torch.log_softmax(rt_logits, dim=-1) + + torch.testing.assert_close(rt_logprobs, ref_logprobs, rtol=1e-5, atol=1e-5) diff --git a/src/test/nn/transformer/model_test.py b/src/test/nn/transformer/model_test.py index b7e685e4b..d174eeff2 100644 --- a/src/test/nn/transformer/model_test.py +++ b/src/test/nn/transformer/model_test.py @@ -20,6 +20,7 @@ build_world_mesh, ) from olmo_core.distributed.utils import get_full_tensor, get_world_size +from olmo_core.exceptions import OLMoConfigurationError from olmo_core.nn.attention import ( AttentionBackendName, AttentionConfig, @@ -56,6 +57,9 @@ run_distributed_test, ) from olmo_core.testing.utils import FLA_MARKS, has_fla +from olmo_core.train.train_module.transformer.config import ( + TransformerPipelineParallelConfig, +) from olmo_core.utils import get_default_device, seed_all log = logging.getLogger(__name__) @@ -679,3 +683,119 @@ def test_qwen3_builder_configs(config_builder, expected_d_model): num_actual_params = sum(p.numel() for p in model.parameters()) assert config.num_params == num_actual_params assert model.num_params == num_actual_params + + +@pytest.mark.parametrize( + "config_builder, expected_tie", + [ + pytest.param(TransformerConfig.qwen3_0_6B, True, id="qwen3_0_6B"), + pytest.param(TransformerConfig.qwen3_1_7B, True, 
def test_tied_word_embeddings_share_weight_after_init():
    """
    Build a small Qwen3 0.6B-style model on CPU, initialize it, and check that the
    LM head and the embeddings still share one parameter and that the parameter
    counts agree with the real number of parameters.
    """
    cfg = TransformerConfig.qwen3_0_6B(vocab_size=128, n_layers=2)
    cpu = torch.device("cpu")

    transformer = cfg.build(init_device="cpu")
    transformer.init_weights(device=cpu)

    assert transformer.tie_word_embeddings
    # The tie must survive `init_weights`, which calls `to_empty`.
    assert transformer.lm_head.w_out.weight is transformer.embeddings.weight

    # The shared weight is only counted once.
    actual_count = sum(param.numel() for param in transformer.parameters())
    assert cfg.num_params == actual_count
    assert transformer.num_params == actual_count
def run_fsdp_tied_word_embeddings():
    """
    Worker body for the FSDP tied-embeddings test (launched via ``run_distributed_test``).

    Builds a tied-embedding model on the meta device, applies FSDP, initializes the
    weights, and checks that the embeddings and the LM head still share one parameter
    and that a backward pass produces a gradient on it.
    """
    device = get_default_device()
    model = TransformerConfig.llama2_271M(
        vocab_size=16_000, n_layers=2, fused_ops=False, tie_word_embeddings=True
    ).build(init_device="meta")

    model.apply_fsdp()
    model.init_weights(device=device, max_seq_len=512)

    assert model.tie_word_embeddings
    embedding_weight = model.embeddings.weight
    assert isinstance(embedding_weight, DTensor)
    # The embeddings and LM head are not sharded into separate FSDP groups when tied, so they
    # stay in the root group and keep sharing one parameter.
    assert model.lm_head.w_out.weight is embedding_weight

    batch = get_transformer_inputs().to(device)
    model(input_ids=batch).sum().backward()

    assert model.embeddings.weight.grad is not None