4646from ..buffer_cache import BufferCache
4747from ..functional import l2_normalize
4848from ..layer_norm import LayerNormConfig
49- from ..lm_head import LMHeadConfig , LMOutputWithLoss
49+ from ..lm_head import LMHeadConfig , LMLossImplementation , LMOutputWithLoss
5050from ..moe import MoEBase
5151from ..rope import RoPEBuffers , RotaryEmbeddingBase
5252from ..utils import selective_checkpointing_context_fn
@@ -117,6 +117,7 @@ def __init__(
117117 block_overrides : Optional [Dict [int , TransformerBlockConfig ]] = None ,
118118 block_pattern : Optional [List [str ]] = None ,
119119 embed_scale : Optional [float ] = None ,
120+ tie_word_embeddings : bool = False ,
120121 ):
121122 super ().__init__ ()
122123
@@ -160,6 +161,10 @@ def __init__(
160161 d_model = d_model , vocab_size = vocab_size , init_device = init_device
161162 )
162163
164+ self .tie_word_embeddings = tie_word_embeddings
165+ if tie_word_embeddings :
166+ self ._tie_weights ()
167+
163168 self .init_device = init_device
164169 self .init_method = InitMethod (init_method )
165170 self .init_seed = init_seed
@@ -183,6 +188,15 @@ def __init__(
183188 self .num_params
184189 self .num_non_embedding_params
185190
191+ def _tie_weights (self ) -> None :
192+ if self .embeddings is None or self .lm_head is None :
193+ raise OLMoConfigurationError (
194+ "Cannot tie word embeddings without both embeddings and an LM head"
195+ )
196+ if self .lm_head .w_out .bias is not None :
197+ raise OLMoConfigurationError ("Cannot tie word embeddings when the LM head uses a bias" )
198+ self .lm_head .w_out .weight = self .embeddings .weight
199+
def _validate_block(self, block: TransformerBlockBase) -> TransformerBlockBase:
    """Return ``block`` unchanged.

    This implementation performs no checks; it simply hands back the
    block it was given.
    """
    return block
188202
@@ -295,6 +309,10 @@ def init_weights(
295309 generator = generator ,
296310 )
297311
312+ # Re-establish weight tying since `to_empty` above allocates fresh storage.
313+ if self .tie_word_embeddings :
314+ self ._tie_weights ()
315+
298316 for block in self .blocks .values ():
299317 # This might fail if it's wrapped.
300318 # assert isinstance(block, TransformerBlock)
@@ -345,7 +363,7 @@ def init_weights(
345363 if max_seq_len is not None and att .rope is not None :
346364 att .rope .warmup_cache (max_seq_len , device )
347365
348- if self .lm_head is not None :
366+ if self .lm_head is not None and not self . tie_word_embeddings :
349367 self .init_method .init_final_w_out (
350368 self .lm_head .w_out ,
351369 d_model = self .d_model ,
@@ -616,6 +634,16 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
616634 :param loss_parallel: Set to ``True`` if parallelizing the loss function as well.
617635 :param float8_enabled: Set this to ``True`` if training with float8 linear layers.
618636 """
637+ if self .tie_word_embeddings and (
638+ self .lm_head is None
639+ or self .lm_head .loss_implementation == LMLossImplementation .fused_linear
640+ ):
641+ raise NotImplementedError (
642+ "Tensor parallelism with tied word embeddings requires the default loss "
643+ "implementation; the fused-linear loss replicates the LM head weight, which is "
644+ "incompatible with the vocab-sharded embedding."
645+ )
646+
619647 if float8_enabled is None :
620648 float8_enabled = self .fp8_enabled
621649 elif not float8_enabled and self .fp8_enabled :
@@ -646,6 +674,12 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
646674 if self .lm_head is not None :
647675 self .lm_head .apply_tp (tp_mesh , input_layouts = (Shard (1 ), Replicate ()))
648676
677+ # The embedding (RowwiseParallel) and the LM head (ColwiseParallel) both shard their
678+ # weight along the vocab dimension, so re-point the head at the embedding's sharded
679+ # parameter to restore the tie that `parallelize_module` broke.
680+ if self .tie_word_embeddings and self .embeddings is not None and self .lm_head is not None :
681+ self ._tie_weights ()
682+
649683 self ._tp_enabled = True
650684 self ._tp_mesh = tp_mesh
651685
@@ -831,7 +865,9 @@ def apply_fsdp(
831865 mp_policy = mp_policy ,
832866 )
833867
834- if self .embeddings is not None :
868+ # When weights are tied the embeddings and LM head share a parameter, so they must
869+ # stay in the same FSDP group (the root) rather than being sharded separately.
870+ if self .embeddings is not None and not self .tie_word_embeddings :
835871 fully_shard (
836872 self .embeddings ,
837873 reshard_after_forward = reshard_after_forward ,
@@ -843,7 +879,7 @@ def apply_fsdp(
843879 if wrapping_strategy != TransformerDataParallelWrappingStrategy .blocks :
844880 if self .embedding_norm is not None :
845881 fully_shard (self .embedding_norm , ** fsdp_config )
846- if self .lm_head is not None :
882+ if self .lm_head is not None and not self . tie_word_embeddings :
847883 fully_shard (self .lm_head , reshard_after_forward = False , ** fsdp_config )
848884
849885 fully_shard (self , reshard_after_forward = reshard_after_forward , ** fsdp_config )
0 commit comments