Skip to content

Commit 0021abd

Browse files
Implement tied LM head & word embeddings for Qwen3 (#686)
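Implements tied LM head & word embeddings for Qwen3. The three sizes that Qwen ships tied (0.6B, 1.7B, 4B) now default to tying; 8B/14B/32B stay untied. The HF import path is tie-aware: when the HF config sets `tie_word_embeddings`, the LM head weight is filled in from the embedding weight, and exported HF configs now report the model's actual `tie_word_embeddings` setting instead of a hard-coded `False`. Tying is rejected for the normalized transformer and for pipeline parallelism with more than one stage, and under tensor parallelism it requires the default (non-fused-linear) loss implementation.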
1 parent b7c58c3 commit 0021abd

7 files changed

Lines changed: 229 additions & 62 deletions

File tree

src/olmo_core/nn/hf/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _get_flex_olmo_config(model: MoETransformer) -> PretrainedConfig:
7979
rms_norm_eps=block.feed_forward_norm.eps,
8080
num_experts_per_tok=block.feed_forward_moe.router.top_k,
8181
num_experts=block.feed_forward_moe.router.num_experts,
82-
tie_word_embeddings=False,
82+
tie_word_embeddings=model.tie_word_embeddings,
8383
)
8484

8585

@@ -133,7 +133,7 @@ def get_hf_config(model: Transformer) -> PretrainedConfig:
133133
"bos_token_id": None,
134134
"eos_token_id": None,
135135
"rms_norm_eps": first_block.feed_forward_norm.eps,
136-
"tie_word_embeddings": False,
136+
"tie_word_embeddings": model.tie_word_embeddings,
137137
}
138138

139139
# The OLMo 3 model family is identical to the OLMo 2 model family, except:
@@ -387,7 +387,7 @@ def get_hybrid_hf_config(
387387
"attention_bias": attn.w_out.bias is not None,
388388
"attention_dropout": 0.0,
389389
"rms_norm_eps": attn_block.feed_forward_norm.eps, # todo: revisit
390-
"tie_word_embeddings": False,
390+
"tie_word_embeddings": model.tie_word_embeddings,
391391
# Hybrid layer configuration
392392
"layer_types": layer_types,
393393
# GDN (linear attention) parameters

src/olmo_core/nn/hf/convert.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,9 @@ def convert_state_from_hf(
473473
if model_type == "gemma3_text":
474474
converted_state = _apply_gemma3_norm_transform(converted_state)
475475

476+
if config.tie_word_embeddings:
477+
converted_state["lm_head.w_out.weight"] = converted_state["embeddings.weight"]
478+
476479
return converted_state
477480

478481

@@ -537,12 +540,11 @@ def convert_state_to_hf(
537540
:class:`DTensor` or :class:`ShardedTensor`
538541
"""
539542

540-
model_type = getattr(config, "model_type", None)
541-
converter = _get_converter_to_hf(model_type)
543+
converter = _get_converter_to_hf(config.model_type)
542544

543545
converted_state = _convert_state(config, olmo_core_state, converter)
544546

545-
if model_type == "gemma3_text":
547+
if config.model_type == "gemma3_text":
546548
converted_state = _apply_gemma3_norm_inverse_transform(converted_state)
547549

548550
return converted_state

src/olmo_core/nn/transformer/config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,13 @@ class TransformerConfig(ModelConfig):
328328
block_pattern: Optional[List[str]] = None
329329
block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None
330330
embed_scale: Optional[float] = None
331+
tie_word_embeddings: bool = False
331332

332333
def __post_init__(self):
334+
if self.tie_word_embeddings and self.name == TransformerType.normalized:
335+
raise OLMoConfigurationError(
336+
"Tying word embeddings is not supported with the normalized transformer"
337+
)
333338
validate_block_resolution_config(
334339
n_layers=self.n_layers,
335340
block=self.block,
@@ -380,6 +385,7 @@ def build(
380385
block_overrides=self.block_overrides,
381386
block_pattern=self.block_pattern,
382387
embed_scale=self.embed_scale,
388+
tie_word_embeddings=self.tie_word_embeddings,
383389
)
384390
elif self.name == TransformerType.normalized:
385391
assert self.embedding_norm is None
@@ -414,6 +420,7 @@ def build(
414420
embedding_init_std=self.embedding_init_std,
415421
block_overrides=self.block_overrides,
416422
block_pattern=self.block_pattern,
423+
tie_word_embeddings=self.tie_word_embeddings,
417424
)
418425
else:
419426
raise NotImplementedError(self.name)
@@ -466,6 +473,10 @@ def num_params(self) -> int:
466473
# LM head.
467474
num_params += self.lm_head.num_params(self.d_model, self.vocab_size)
468475

476+
# The LM head weight is shared with the embeddings when tied.
477+
if self.tie_word_embeddings:
478+
num_params -= self.d_model * self.vocab_size
479+
469480
return num_params
470481

471482
@property
@@ -487,6 +498,10 @@ def num_active_params(self) -> int:
487498
# LM head.
488499
num_active_params += self.lm_head.num_params(self.d_model, self.vocab_size)
489500

501+
# The LM head weight is shared with the embeddings when tied.
502+
if self.tie_word_embeddings:
503+
num_active_params -= self.d_model * self.vocab_size
504+
490505
return num_active_params
491506

492507
@property
@@ -1301,6 +1316,7 @@ def qwen3_0_6B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
13011316
feed_forward=FeedForwardConfig(
13021317
hidden_size=3072, bias=False, dtype=kwargs.get("dtype", DType.float32)
13031318
),
1319+
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
13041320
**kwargs,
13051321
)
13061322

@@ -1337,6 +1353,7 @@ def qwen3_1_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
13371353
feed_forward=FeedForwardConfig(
13381354
hidden_size=6144, bias=False, dtype=kwargs.get("dtype", DType.float32)
13391355
),
1356+
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
13401357
**kwargs,
13411358
)
13421359

@@ -1373,6 +1390,7 @@ def qwen3_4B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
13731390
feed_forward=FeedForwardConfig(
13741391
hidden_size=9728, bias=False, dtype=kwargs.get("dtype", DType.float32)
13751392
),
1393+
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
13761394
**kwargs,
13771395
)
13781396

src/olmo_core/nn/transformer/model.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from ..buffer_cache import BufferCache
4747
from ..functional import l2_normalize
4848
from ..layer_norm import LayerNormConfig
49-
from ..lm_head import LMHeadConfig, LMOutputWithLoss
49+
from ..lm_head import LMHeadConfig, LMLossImplementation, LMOutputWithLoss
5050
from ..moe import MoEBase
5151
from ..rope import RoPEBuffers, RotaryEmbeddingBase
5252
from ..utils import selective_checkpointing_context_fn
@@ -117,6 +117,7 @@ def __init__(
117117
block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None,
118118
block_pattern: Optional[List[str]] = None,
119119
embed_scale: Optional[float] = None,
120+
tie_word_embeddings: bool = False,
120121
):
121122
super().__init__()
122123

@@ -160,6 +161,10 @@ def __init__(
160161
d_model=d_model, vocab_size=vocab_size, init_device=init_device
161162
)
162163

164+
self.tie_word_embeddings = tie_word_embeddings
165+
if tie_word_embeddings:
166+
self._tie_weights()
167+
163168
self.init_device = init_device
164169
self.init_method = InitMethod(init_method)
165170
self.init_seed = init_seed
@@ -183,6 +188,15 @@ def __init__(
183188
self.num_params
184189
self.num_non_embedding_params
185190

191+
def _tie_weights(self) -> None:
192+
if self.embeddings is None or self.lm_head is None:
193+
raise OLMoConfigurationError(
194+
"Cannot tie word embeddings without both embeddings and an LM head"
195+
)
196+
if self.lm_head.w_out.bias is not None:
197+
raise OLMoConfigurationError("Cannot tie word embeddings when the LM head uses a bias")
198+
self.lm_head.w_out.weight = self.embeddings.weight
199+
186200
def _validate_block(self, block: TransformerBlockBase) -> TransformerBlockBase:
187201
return block
188202

@@ -295,6 +309,10 @@ def init_weights(
295309
generator=generator,
296310
)
297311

312+
# Re-establish weight tying since `to_empty` above allocates fresh storage.
313+
if self.tie_word_embeddings:
314+
self._tie_weights()
315+
298316
for block in self.blocks.values():
299317
# This might fail if it's wrapped.
300318
# assert isinstance(block, TransformerBlock)
@@ -345,7 +363,7 @@ def init_weights(
345363
if max_seq_len is not None and att.rope is not None:
346364
att.rope.warmup_cache(max_seq_len, device)
347365

348-
if self.lm_head is not None:
366+
if self.lm_head is not None and not self.tie_word_embeddings:
349367
self.init_method.init_final_w_out(
350368
self.lm_head.w_out,
351369
d_model=self.d_model,
@@ -616,6 +634,16 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
616634
:param loss_parallel: Set to ``True`` if parallelizing the loss function as well.
617635
:param float8_enabled: Set this to ``True`` if training with float8 linear layers.
618636
"""
637+
if self.tie_word_embeddings and (
638+
self.lm_head is None
639+
or self.lm_head.loss_implementation == LMLossImplementation.fused_linear
640+
):
641+
raise NotImplementedError(
642+
"Tensor parallelism with tied word embeddings requires the default loss "
643+
"implementation; the fused-linear loss replicates the LM head weight, which is "
644+
"incompatible with the vocab-sharded embedding."
645+
)
646+
619647
if float8_enabled is None:
620648
float8_enabled = self.fp8_enabled
621649
elif not float8_enabled and self.fp8_enabled:
@@ -646,6 +674,12 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
646674
if self.lm_head is not None:
647675
self.lm_head.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate()))
648676

677+
# The embedding (RowwiseParallel) and the LM head (ColwiseParallel) both shard their
678+
# weight along the vocab dimension, so re-point the head at the embedding's sharded
679+
# parameter to restore the tie that `parallelize_module` broke.
680+
if self.tie_word_embeddings and self.embeddings is not None and self.lm_head is not None:
681+
self._tie_weights()
682+
649683
self._tp_enabled = True
650684
self._tp_mesh = tp_mesh
651685

@@ -831,7 +865,9 @@ def apply_fsdp(
831865
mp_policy=mp_policy,
832866
)
833867

834-
if self.embeddings is not None:
868+
# When weights are tied the embeddings and LM head share a parameter, so they must
869+
# stay in the same FSDP group (the root) rather than being sharded separately.
870+
if self.embeddings is not None and not self.tie_word_embeddings:
835871
fully_shard(
836872
self.embeddings,
837873
reshard_after_forward=reshard_after_forward,
@@ -843,7 +879,7 @@ def apply_fsdp(
843879
if wrapping_strategy != TransformerDataParallelWrappingStrategy.blocks:
844880
if self.embedding_norm is not None:
845881
fully_shard(self.embedding_norm, **fsdp_config)
846-
if self.lm_head is not None:
882+
if self.lm_head is not None and not self.tie_word_embeddings:
847883
fully_shard(self.lm_head, reshard_after_forward=False, **fsdp_config)
848884

849885
fully_shard(self, reshard_after_forward=reshard_after_forward, **fsdp_config)

src/olmo_core/train/train_module/transformer/config.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,15 @@ def split_model(
9393
self, model: Transformer, *, pp_mesh: DeviceMesh, device: torch.device
9494
) -> Tuple[List[PipelineStage], List[Transformer]]:
9595
split_points = self.get_split_points(model.n_layers)
96+
num_stages = len(split_points) + 1
97+
98+
if num_stages > 1 and model.tie_word_embeddings:
99+
raise NotImplementedError(
100+
"Pipeline parallelism with tied word embeddings is not supported: the input "
101+
"embeddings and LM head are placed on different pipeline stages, so they cannot "
102+
"share a weight."
103+
)
104+
96105
pp_rank = pp_mesh.get_local_rank()
97106

98107
def build_stage(
@@ -128,7 +137,6 @@ def build_stage(
128137
)
129138
return stage, model_chunk
130139

131-
num_stages = len(split_points) + 1
132140
stage_idx = pp_rank
133141

134142
stages = []

src/test/nn/hf/convert_test.py

Lines changed: 34 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,32 @@ def test_convert_state_from_hf_and_flatten():
118118
)
119119

120120

121+
def test_convert_state_from_hf_ties_word_embeddings():
122+
hf_config = AutoConfig.for_model(
123+
"qwen3",
124+
vocab_size=64,
125+
hidden_size=16,
126+
intermediate_size=32,
127+
num_hidden_layers=2,
128+
num_attention_heads=2,
129+
num_key_value_heads=2,
130+
max_position_embeddings=64,
131+
tie_word_embeddings=True,
132+
)
133+
134+
# A tied HF checkpoint omits `lm_head.weight`.
135+
hf_state = {
136+
"model.embed_tokens.weight": torch.randn(hf_config.vocab_size, hf_config.hidden_size)
137+
}
138+
139+
converted_state = convert_state_from_hf(hf_config, hf_state, model_type="qwen3")
140+
141+
assert "lm_head.w_out.weight" in converted_state
142+
torch.testing.assert_close(
143+
converted_state["lm_head.w_out.weight"], converted_state["embeddings.weight"]
144+
)
145+
146+
121147
def test_convert_state_to_hf():
122148
hf_config = _get_olmo2_config()
123149

@@ -228,35 +254,6 @@ def test_convert_state_to_flex_olmo_hf():
228254
)
229255

230256

231-
def _roundtrip_norm_keys(model_id, model_type, norm_suffixes, forbidden=()):
232-
hf_config = AutoConfig.from_pretrained(model_id)
233-
hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
234-
hf_state = {k: v.detach().cpu() for k, v in hf_model.state_dict().items()}
235-
del hf_model
236-
237-
oc_state = convert_state_from_hf(hf_config, hf_state, model_type=model_type)
238-
hf_roundtrip = convert_state_to_hf(hf_config, oc_state)
239-
240-
for i in range(hf_config.num_hidden_layers):
241-
for suffix in norm_suffixes:
242-
key = f"model.layers.{i}.{suffix}"
243-
assert key in hf_roundtrip, f"missing {key} in round-tripped state"
244-
torch.testing.assert_close(hf_roundtrip[key], hf_state[key])
245-
for suffix in forbidden:
246-
assert (
247-
f"model.layers.{i}.{suffix}" not in hf_roundtrip
248-
), f"unexpected key model.layers.{i}.{suffix}"
249-
250-
251-
def test_qwen3_0_6b_roundtrip_pre_norm():
252-
_roundtrip_norm_keys(
253-
"Qwen/Qwen3-0.6B",
254-
model_type="qwen3",
255-
norm_suffixes=("input_layernorm.weight", "post_attention_layernorm.weight"),
256-
forbidden=("post_feedforward_layernorm.weight",),
257-
)
258-
259-
260257
def test_llama_tiny_roundtrip_pre_norm():
261258
hf_config = AutoConfig.for_model(
262259
"llama",
@@ -287,20 +284,14 @@ def test_llama_tiny_roundtrip_pre_norm():
287284
assert f"model.layers.{i}.post_feedforward_layernorm.weight" not in hf_roundtrip
288285

289286

290-
def test_gemma3_270m_roundtrip_pre_norm():
291-
_roundtrip_norm_keys(
292-
"google/gemma-3-270m",
293-
model_type="gemma3_text",
294-
norm_suffixes=(
295-
"input_layernorm.weight",
296-
"post_attention_layernorm.weight",
297-
"pre_feedforward_layernorm.weight",
298-
"post_feedforward_layernorm.weight",
299-
),
300-
)
301-
302-
303-
def _assert_logprobs_match_after_roundtrip(model_id: str, model_type: str):
287+
@pytest.mark.parametrize(
288+
"model_id, model_type",
289+
[
290+
pytest.param("Qwen/Qwen3-0.6B", "qwen3", id="qwen3"),
291+
pytest.param("google/gemma-3-270m", "gemma3_text", id="gemma3"),
292+
],
293+
)
294+
def test_logprobs_match_after_roundtrip(model_id: str, model_type: str):
304295
hf_config = AutoConfig.from_pretrained(model_id)
305296
hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
306297
hf_model.eval()
@@ -323,11 +314,3 @@ def _assert_logprobs_match_after_roundtrip(model_id: str, model_type: str):
323314
rt_logprobs = torch.log_softmax(rt_logits, dim=-1)
324315

325316
torch.testing.assert_close(rt_logprobs, ref_logprobs, rtol=1e-5, atol=1e-5)
326-
327-
328-
def test_qwen3_0_6b_logprobs_roundtrip():
329-
_assert_logprobs_match_after_roundtrip("Qwen/Qwen3-0.6B", model_type="qwen3")
330-
331-
332-
def test_gemma3_270m_logprobs_roundtrip():
333-
_assert_logprobs_match_after_roundtrip("google/gemma-3-270m", model_type="gemma3_text")

0 commit comments

Comments
 (0)