Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/olmo_core/nn/hf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_flex_olmo_config(model: MoETransformer) -> PretrainedConfig:
rms_norm_eps=block.feed_forward_norm.eps,
num_experts_per_tok=block.feed_forward_moe.router.top_k,
num_experts=block.feed_forward_moe.router.num_experts,
tie_word_embeddings=False,
tie_word_embeddings=model.tie_word_embeddings,
)


Expand Down Expand Up @@ -133,7 +133,7 @@ def get_hf_config(model: Transformer) -> PretrainedConfig:
"bos_token_id": None,
"eos_token_id": None,
"rms_norm_eps": first_block.feed_forward_norm.eps,
"tie_word_embeddings": False,
"tie_word_embeddings": model.tie_word_embeddings,
}

# The OLMo 3 model family is identical to the OLMo 2 model family, except:
Expand Down Expand Up @@ -387,7 +387,7 @@ def get_hybrid_hf_config(
"attention_bias": attn.w_out.bias is not None,
"attention_dropout": 0.0,
"rms_norm_eps": attn_block.feed_forward_norm.eps, # todo: revisit
"tie_word_embeddings": False,
"tie_word_embeddings": model.tie_word_embeddings,
# Hybrid layer configuration
"layer_types": layer_types,
# GDN (linear attention) parameters
Expand Down
8 changes: 5 additions & 3 deletions src/olmo_core/nn/hf/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ def convert_state_from_hf(
if model_type == "gemma3_text":
converted_state = _apply_gemma3_norm_transform(converted_state)

if config.tie_word_embeddings:
converted_state["lm_head.w_out.weight"] = converted_state["embeddings.weight"]

return converted_state


Expand Down Expand Up @@ -537,12 +540,11 @@ def convert_state_to_hf(
:class:`DTensor` or :class:`ShardedTensor`
"""

model_type = getattr(config, "model_type", None)
converter = _get_converter_to_hf(model_type)
converter = _get_converter_to_hf(config.model_type)

converted_state = _convert_state(config, olmo_core_state, converter)

if model_type == "gemma3_text":
if config.model_type == "gemma3_text":
converted_state = _apply_gemma3_norm_inverse_transform(converted_state)

return converted_state
Expand Down
18 changes: 18 additions & 0 deletions src/olmo_core/nn/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,13 @@ class TransformerConfig(ModelConfig):
block_pattern: Optional[List[str]] = None
block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None
embed_scale: Optional[float] = None
tie_word_embeddings: bool = False

def __post_init__(self):
if self.tie_word_embeddings and self.name == TransformerType.normalized:
raise OLMoConfigurationError(
"Tying word embeddings is not supported with the normalized transformer"
)
validate_block_resolution_config(
n_layers=self.n_layers,
block=self.block,
Expand Down Expand Up @@ -380,6 +385,7 @@ def build(
block_overrides=self.block_overrides,
block_pattern=self.block_pattern,
embed_scale=self.embed_scale,
tie_word_embeddings=self.tie_word_embeddings,
)
elif self.name == TransformerType.normalized:
assert self.embedding_norm is None
Expand Down Expand Up @@ -414,6 +420,7 @@ def build(
embedding_init_std=self.embedding_init_std,
block_overrides=self.block_overrides,
block_pattern=self.block_pattern,
tie_word_embeddings=self.tie_word_embeddings,
)
else:
raise NotImplementedError(self.name)
Expand Down Expand Up @@ -466,6 +473,10 @@ def num_params(self) -> int:
# LM head.
num_params += self.lm_head.num_params(self.d_model, self.vocab_size)

# The LM head weight is shared with the embeddings when tied.
if self.tie_word_embeddings:
num_params -= self.d_model * self.vocab_size

return num_params

@property
Expand All @@ -487,6 +498,10 @@ def num_active_params(self) -> int:
# LM head.
num_active_params += self.lm_head.num_params(self.d_model, self.vocab_size)

# The LM head weight is shared with the embeddings when tied.
if self.tie_word_embeddings:
num_active_params -= self.d_model * self.vocab_size

return num_active_params

@property
Expand Down Expand Up @@ -1301,6 +1316,7 @@ def qwen3_0_6B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
feed_forward=FeedForwardConfig(
hidden_size=3072, bias=False, dtype=kwargs.get("dtype", DType.float32)
),
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
**kwargs,
)

Expand Down Expand Up @@ -1337,6 +1353,7 @@ def qwen3_1_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
feed_forward=FeedForwardConfig(
hidden_size=6144, bias=False, dtype=kwargs.get("dtype", DType.float32)
),
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
**kwargs,
)

Expand Down Expand Up @@ -1373,6 +1390,7 @@ def qwen3_4B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
feed_forward=FeedForwardConfig(
hidden_size=9728, bias=False, dtype=kwargs.get("dtype", DType.float32)
),
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
**kwargs,
)

Expand Down
44 changes: 40 additions & 4 deletions src/olmo_core/nn/transformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..buffer_cache import BufferCache
from ..functional import l2_normalize
from ..layer_norm import LayerNormConfig
from ..lm_head import LMHeadConfig, LMOutputWithLoss
from ..lm_head import LMHeadConfig, LMLossImplementation, LMOutputWithLoss
from ..moe import MoEBase
from ..rope import RoPEBuffers, RotaryEmbeddingBase
from ..utils import selective_checkpointing_context_fn
Expand Down Expand Up @@ -117,6 +117,7 @@ def __init__(
block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None,
block_pattern: Optional[List[str]] = None,
embed_scale: Optional[float] = None,
tie_word_embeddings: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -160,6 +161,10 @@ def __init__(
d_model=d_model, vocab_size=vocab_size, init_device=init_device
)

self.tie_word_embeddings = tie_word_embeddings
if tie_word_embeddings:
self._tie_weights()

self.init_device = init_device
self.init_method = InitMethod(init_method)
self.init_seed = init_seed
Expand All @@ -183,6 +188,15 @@ def __init__(
self.num_params
self.num_non_embedding_params

def _tie_weights(self) -> None:
    """
    Make the LM head's output projection use the embedding weight.

    After this call ``self.lm_head.w_out.weight`` is the same object as
    ``self.embeddings.weight``.

    :raises OLMoConfigurationError: If the embeddings or the LM head is ``None``, or if
        the LM head's output projection has a bias.
    """
    embeddings, lm_head = self.embeddings, self.lm_head
    if embeddings is None or lm_head is None:
        raise OLMoConfigurationError(
            "Cannot tie word embeddings without both embeddings and an LM head"
        )

    w_out = lm_head.w_out
    if w_out.bias is not None:
        raise OLMoConfigurationError("Cannot tie word embeddings when the LM head uses a bias")

    # Assign the embedding weight itself (no copy) so both modules reference one object.
    w_out.weight = embeddings.weight

def _validate_block(self, block: TransformerBlockBase) -> TransformerBlockBase:
return block

Expand Down Expand Up @@ -295,6 +309,10 @@ def init_weights(
generator=generator,
)

# Re-establish weight tying since `to_empty` above allocates fresh storage.
if self.tie_word_embeddings:
self._tie_weights()

for block in self.blocks.values():
# This might fail if it's wrapped.
# assert isinstance(block, TransformerBlock)
Expand Down Expand Up @@ -345,7 +363,7 @@ def init_weights(
if max_seq_len is not None and att.rope is not None:
att.rope.warmup_cache(max_seq_len, device)

if self.lm_head is not None:
if self.lm_head is not None and not self.tie_word_embeddings:
Comment thread
finbarrtimbers marked this conversation as resolved.
self.init_method.init_final_w_out(
self.lm_head.w_out,
d_model=self.d_model,
Expand Down Expand Up @@ -616,6 +634,16 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
:param loss_parallel: Set to ``True`` if parallelizing the loss function as well.
:param float8_enabled: Set this to ``True`` if training with float8 linear layers.
"""
if self.tie_word_embeddings and (
self.lm_head is None
or self.lm_head.loss_implementation == LMLossImplementation.fused_linear
):
raise NotImplementedError(
"Tensor parallelism with tied word embeddings requires the default loss "
"implementation; the fused-linear loss replicates the LM head weight, which is "
"incompatible with the vocab-sharded embedding."
)

if float8_enabled is None:
float8_enabled = self.fp8_enabled
elif not float8_enabled and self.fp8_enabled:
Expand Down Expand Up @@ -646,6 +674,12 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
if self.lm_head is not None:
self.lm_head.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate()))

# The embedding (RowwiseParallel) and the LM head (ColwiseParallel) both shard their
# weight along the vocab dimension, so re-point the head at the embedding's sharded
# parameter to restore the tie that `parallelize_module` broke.
if self.tie_word_embeddings and self.embeddings is not None and self.lm_head is not None:
self._tie_weights()

self._tp_enabled = True
self._tp_mesh = tp_mesh

Expand Down Expand Up @@ -831,7 +865,9 @@ def apply_fsdp(
mp_policy=mp_policy,
)

if self.embeddings is not None:
# When weights are tied the embeddings and LM head share a parameter, so they must
# stay in the same FSDP group (the root) rather than being sharded separately.
if self.embeddings is not None and not self.tie_word_embeddings:
fully_shard(
self.embeddings,
reshard_after_forward=reshard_after_forward,
Expand All @@ -843,7 +879,7 @@ def apply_fsdp(
if wrapping_strategy != TransformerDataParallelWrappingStrategy.blocks:
if self.embedding_norm is not None:
fully_shard(self.embedding_norm, **fsdp_config)
if self.lm_head is not None:
if self.lm_head is not None and not self.tie_word_embeddings:
fully_shard(self.lm_head, reshard_after_forward=False, **fsdp_config)

fully_shard(self, reshard_after_forward=reshard_after_forward, **fsdp_config)
Expand Down
10 changes: 9 additions & 1 deletion src/olmo_core/train/train_module/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def split_model(
self, model: Transformer, *, pp_mesh: DeviceMesh, device: torch.device
) -> Tuple[List[PipelineStage], List[Transformer]]:
split_points = self.get_split_points(model.n_layers)
num_stages = len(split_points) + 1

if num_stages > 1 and model.tie_word_embeddings:
raise NotImplementedError(
"Pipeline parallelism with tied word embeddings is not supported: the input "
"embeddings and LM head are placed on different pipeline stages, so they cannot "
"share a weight."
)

pp_rank = pp_mesh.get_local_rank()

def build_stage(
Expand Down Expand Up @@ -128,7 +137,6 @@ def build_stage(
)
return stage, model_chunk

num_stages = len(split_points) + 1
stage_idx = pp_rank

stages = []
Expand Down
85 changes: 34 additions & 51 deletions src/test/nn/hf/convert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,32 @@ def test_convert_state_from_hf_and_flatten():
)


def test_convert_state_from_hf_ties_word_embeddings():
    """
    Converting a tied HF state that has no ``lm_head.weight`` entry must still produce
    ``lm_head.w_out.weight``, with values matching ``embeddings.weight``.
    """
    hf_config = AutoConfig.for_model(
        "qwen3",
        vocab_size=64,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        max_position_embeddings=64,
        tie_word_embeddings=True,
    )

    # A tied HF checkpoint omits `lm_head.weight`, so only the embedding table is supplied.
    embed_tokens = torch.randn(hf_config.vocab_size, hf_config.hidden_size)
    hf_state = {"model.embed_tokens.weight": embed_tokens}

    converted_state = convert_state_from_hf(hf_config, hf_state, model_type="qwen3")

    lm_head_key = "lm_head.w_out.weight"
    assert lm_head_key in converted_state
    torch.testing.assert_close(
        converted_state[lm_head_key], converted_state["embeddings.weight"]
    )


def test_convert_state_to_hf():
hf_config = _get_olmo2_config()

Expand Down Expand Up @@ -228,35 +254,6 @@ def test_convert_state_to_flex_olmo_hf():
)


def _roundtrip_norm_keys(model_id, model_type, norm_suffixes, forbidden=()):
    """
    Load a pretrained HF model, convert its state with ``convert_state_from_hf`` and back
    with ``convert_state_to_hf``, then check the per-layer keys of the result.

    :param model_id: Identifier passed to ``from_pretrained`` for both config and model.
    :param model_type: Forwarded to ``convert_state_from_hf`` as ``model_type``.
    :param norm_suffixes: Key suffixes under ``model.layers.{i}.`` that must be present in
        the round-tripped state and close to the original tensors.
    :param forbidden: Key suffixes under ``model.layers.{i}.`` that must be absent from
        the round-tripped state.
    """
    hf_config = AutoConfig.from_pretrained(model_id)
    hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
    # Keep detached CPU copies of the weights; the model object itself is dropped below.
    hf_state = {name: tensor.detach().cpu() for name, tensor in hf_model.state_dict().items()}
    del hf_model

    hf_roundtrip = convert_state_to_hf(
        hf_config,
        convert_state_from_hf(hf_config, hf_state, model_type=model_type),
    )

    for layer_idx in range(hf_config.num_hidden_layers):
        layer_prefix = f"model.layers.{layer_idx}"

        for suffix in norm_suffixes:
            key = f"{layer_prefix}.{suffix}"
            assert key in hf_roundtrip, f"missing {key} in round-tripped state"
            torch.testing.assert_close(hf_roundtrip[key], hf_state[key])

        for suffix in forbidden:
            unexpected = f"{layer_prefix}.{suffix}"
            assert unexpected not in hf_roundtrip, f"unexpected key {unexpected}"


def test_qwen3_0_6b_roundtrip_pre_norm():
    """Run the norm-key round-trip check for ``Qwen/Qwen3-0.6B`` with model type ``qwen3``."""
    # Suffixes that must survive the round trip.
    expected_norms = ("input_layernorm.weight", "post_attention_layernorm.weight")
    # Suffixes that must not appear in the round-tripped state.
    absent_norms = ("post_feedforward_layernorm.weight",)

    _roundtrip_norm_keys(
        "Qwen/Qwen3-0.6B",
        model_type="qwen3",
        norm_suffixes=expected_norms,
        forbidden=absent_norms,
    )


def test_llama_tiny_roundtrip_pre_norm():
hf_config = AutoConfig.for_model(
"llama",
Expand Down Expand Up @@ -287,20 +284,14 @@ def test_llama_tiny_roundtrip_pre_norm():
assert f"model.layers.{i}.post_feedforward_layernorm.weight" not in hf_roundtrip


def test_gemma3_270m_roundtrip_pre_norm():
    """
    Run the norm-key round-trip check for ``google/gemma-3-270m`` with model type
    ``gemma3_text``; no suffixes are marked as forbidden.
    """
    expected_norms = (
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "pre_feedforward_layernorm.weight",
        "post_feedforward_layernorm.weight",
    )
    _roundtrip_norm_keys(
        "google/gemma-3-270m",
        model_type="gemma3_text",
        norm_suffixes=expected_norms,
    )


def _assert_logprobs_match_after_roundtrip(model_id: str, model_type: str):
@pytest.mark.parametrize(
"model_id, model_type",
[
pytest.param("Qwen/Qwen3-0.6B", "qwen3", id="qwen3"),
pytest.param("google/gemma-3-270m", "gemma3_text", id="gemma3"),
],
)
def test_logprobs_match_after_roundtrip(model_id: str, model_type: str):
hf_config = AutoConfig.from_pretrained(model_id)
hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
hf_model.eval()
Expand All @@ -323,11 +314,3 @@ def _assert_logprobs_match_after_roundtrip(model_id: str, model_type: str):
rt_logprobs = torch.log_softmax(rt_logits, dim=-1)

torch.testing.assert_close(rt_logprobs, ref_logprobs, rtol=1e-5, atol=1e-5)


def test_qwen3_0_6b_logprobs_roundtrip():
    """Run the log-prob round-trip check for ``Qwen/Qwen3-0.6B`` with model type ``qwen3``."""
    _assert_logprobs_match_after_roundtrip(
        model_id="Qwen/Qwen3-0.6B",
        model_type="qwen3",
    )


def test_gemma3_270m_logprobs_roundtrip():
    """
    Run the log-prob round-trip check for ``google/gemma-3-270m`` with model type
    ``gemma3_text``.
    """
    _assert_logprobs_match_after_roundtrip(
        model_id="google/gemma-3-270m",
        model_type="gemma3_text",
    )
Loading