Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ env:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
GOOGLE_CREDENTIALS: ${{ secrets.GOOGLE_CREDENTIALS }}
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep Hub token out of the default test job

When the Test matrix runs on same-repo PRs or pushes where secrets are available, setting HF_TOKEN globally makes src/test/nn/hf/golden_tests.py's @skipif(not HF_TOKEN) false, so the ordinary 15-minute CPU test job now downloads and runs the Qwen3-0.6B and Gemma3-270m generation golden tests. That couples normal unit CI to Hub availability and large-model runtime; scope this token to a dedicated/gated job or keep those tests skipped unless explicitly requested.

Useful? React with 👍 / 👎.


jobs:
checks:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Distributed checkpoint writes now clone each tensor before serialization to avoid accidentally writing the full backing storage of a view/shared tensor, with a guard that raises `OLMoCheckpointError` if a written tensor is unexpectedly larger than its `nbytes`.
- Fixed LM in-loop evaluator data-order drift across repeated runs by resetting loader bookkeeping before each pass and making deterministic reshuffling the default.
- Fixed Qwen3 implementation to match HuggingFace by applying RoPE in the input dtype (bf16) rather than upcasting to fp32.
- Fixed HF model conversion for Llama, Qwen3, and Gemma so that converted checkpoints roundtrip correctly.
- Fixed Beaker secret existence check to use the case-insensitive HTTP endpoint, avoiding spurious "secret not found" errors when secret names differ only in case.
- Fixed `Transformer.init_weights` so that under interleaved pipeline parallelism (e.g. `Interleaved1F1B`, `InterleavedZeroBubble`) the multiple model chunks owned by a single rank no longer initialize to identical parameters. Adds a `model_part_idx` kwarg incorporated into the seed as `model_part_idx * pp_size`.
- Disabled `torch.compile` tracing through `TEAttentionBackend.forward`, whose Python/pybind setup is not Dynamo-safe.
Expand Down
6 changes: 3 additions & 3 deletions src/olmo_core/nn/hf/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_flex_olmo_config(model: MoETransformer) -> PretrainedConfig:
rms_norm_eps=block.feed_forward_norm.eps,
num_experts_per_tok=block.feed_forward_moe.router.top_k,
num_experts=block.feed_forward_moe.router.num_experts,
tie_word_embeddings=False,
tie_word_embeddings=model.tie_word_embeddings,
)


Expand Down Expand Up @@ -133,7 +133,7 @@ def get_hf_config(model: Transformer) -> PretrainedConfig:
"bos_token_id": None,
"eos_token_id": None,
"rms_norm_eps": first_block.feed_forward_norm.eps,
"tie_word_embeddings": False,
"tie_word_embeddings": model.tie_word_embeddings,
}

# The OLMo 3 model family is identical to the OLMo 2 model family, except:
Expand Down Expand Up @@ -387,7 +387,7 @@ def get_hybrid_hf_config(
"attention_bias": attn.w_out.bias is not None,
"attention_dropout": 0.0,
"rms_norm_eps": attn_block.feed_forward_norm.eps, # todo: revisit
"tie_word_embeddings": False,
"tie_word_embeddings": model.tie_word_embeddings,
# Hybrid layer configuration
"layer_types": layer_types,
# GDN (linear attention) parameters
Expand Down
83 changes: 80 additions & 3 deletions src/olmo_core/nn/hf/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,53 @@
f"model.layers.{LAYER}.mlp.gate.weight",
unflatten_dim=(0, (TemplatePlaceholder.EXPERT, -1)),
),
}
},
"llama": {
f"blocks.{LAYER}.attention_norm.weight": StateMappingTemplate(
f"blocks.{LAYER}.attention_norm.weight",
f"model.layers.{LAYER}.input_layernorm.weight",
state_type=StateType.weight,
),
f"blocks.{LAYER}.feed_forward_norm.weight": StateMappingTemplate(
f"blocks.{LAYER}.feed_forward_norm.weight",
f"model.layers.{LAYER}.post_attention_layernorm.weight",
state_type=StateType.weight,
),
},
"qwen3": {
f"blocks.{LAYER}.attention_norm.weight": StateMappingTemplate(
f"blocks.{LAYER}.attention_norm.weight",
f"model.layers.{LAYER}.input_layernorm.weight",
state_type=StateType.weight,
),
f"blocks.{LAYER}.feed_forward_norm.weight": StateMappingTemplate(
f"blocks.{LAYER}.feed_forward_norm.weight",
f"model.layers.{LAYER}.post_attention_layernorm.weight",
state_type=StateType.weight,
),
},
"gemma3_text": {
f"blocks.{LAYER}.attention_norm.weight": StateMappingTemplate(
f"blocks.{LAYER}.attention_norm.weight",
f"model.layers.{LAYER}.input_layernorm.weight",
state_type=StateType.weight,
),
f"blocks.{LAYER}.post_attention_norm.weight": StateMappingTemplate(
f"blocks.{LAYER}.post_attention_norm.weight",
f"model.layers.{LAYER}.post_attention_layernorm.weight",
state_type=StateType.weight,
),
f"blocks.{LAYER}.feed_forward_norm.weight": StateMappingTemplate(
f"blocks.{LAYER}.feed_forward_norm.weight",
f"model.layers.{LAYER}.pre_feedforward_layernorm.weight",
state_type=StateType.weight,
),
f"blocks.{LAYER}.post_feed_forward_norm.weight": StateMappingTemplate(
f"blocks.{LAYER}.post_feed_forward_norm.weight",
f"model.layers.{LAYER}.post_feedforward_layernorm.weight",
state_type=StateType.weight,
),
},
}


Expand Down Expand Up @@ -427,6 +473,9 @@ def convert_state_from_hf(
if model_type == "gemma3_text":
converted_state = _apply_gemma3_norm_transform(converted_state)

if config.tie_word_embeddings:
converted_state["lm_head.w_out.weight"] = converted_state["embeddings.weight"]

return converted_state


Expand Down Expand Up @@ -456,6 +505,29 @@ def get_converter_to_hf(model_type: str | None = None) -> StateConverter:
return _get_converter_to_hf(model_type)


def _apply_gemma3_norm_inverse_transform(state: Dict[str, Any]) -> Dict[str, Any]:
"""
Inverse of :func:`_apply_gemma3_norm_transform`: subtracts 1 from norm weights
so that an OLMo-core checkpoint can be exported back into HF Gemma 3 format.
"""
norm_patterns = [
"input_layernorm.weight",
"post_attention_layernorm.weight",
"pre_feedforward_layernorm.weight",
"post_feedforward_layernorm.weight",
"model.norm.weight",
"q_norm.weight",
"k_norm.weight",
]

for key, value in state.items():
if any(pattern in key for pattern in norm_patterns):
if isinstance(value, torch.Tensor):
state[key] = value - 1.0

return state


@beta_feature
def convert_state_to_hf(
config: PretrainedConfig, olmo_core_state: Dict[str, Any]
Expand All @@ -468,9 +540,14 @@ def convert_state_to_hf(
:class:`DTensor` or :class:`ShardedTensor`
"""

converter = _get_converter_to_hf(getattr(config, "model_type", None))
converter = _get_converter_to_hf(config.model_type)

converted_state = _convert_state(config, olmo_core_state, converter)

return _convert_state(config, olmo_core_state, converter)
if config.model_type == "gemma3_text":
converted_state = _apply_gemma3_norm_inverse_transform(converted_state)

return converted_state


# ---------------------------------------------------------------------------
Expand Down
18 changes: 18 additions & 0 deletions src/olmo_core/nn/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,13 @@ class TransformerConfig(ModelConfig):
block_pattern: Optional[List[str]] = None
block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None
embed_scale: Optional[float] = None
tie_word_embeddings: bool = False

def __post_init__(self):
if self.tie_word_embeddings and self.name == TransformerType.normalized:
raise OLMoConfigurationError(
"Tying word embeddings is not supported with the normalized transformer"
)
validate_block_resolution_config(
n_layers=self.n_layers,
block=self.block,
Expand Down Expand Up @@ -380,6 +385,7 @@ def build(
block_overrides=self.block_overrides,
block_pattern=self.block_pattern,
embed_scale=self.embed_scale,
tie_word_embeddings=self.tie_word_embeddings,
)
elif self.name == TransformerType.normalized:
assert self.embedding_norm is None
Expand Down Expand Up @@ -414,6 +420,7 @@ def build(
embedding_init_std=self.embedding_init_std,
block_overrides=self.block_overrides,
block_pattern=self.block_pattern,
tie_word_embeddings=self.tie_word_embeddings,
)
else:
raise NotImplementedError(self.name)
Expand Down Expand Up @@ -466,6 +473,10 @@ def num_params(self) -> int:
# LM head.
num_params += self.lm_head.num_params(self.d_model, self.vocab_size)

# The LM head weight is shared with the embeddings when tied.
if self.tie_word_embeddings:
num_params -= self.d_model * self.vocab_size

return num_params

@property
Expand All @@ -487,6 +498,10 @@ def num_active_params(self) -> int:
# LM head.
num_active_params += self.lm_head.num_params(self.d_model, self.vocab_size)

# The LM head weight is shared with the embeddings when tied.
if self.tie_word_embeddings:
num_active_params -= self.d_model * self.vocab_size

return num_active_params

@property
Expand Down Expand Up @@ -1301,6 +1316,7 @@ def qwen3_0_6B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
feed_forward=FeedForwardConfig(
hidden_size=3072, bias=False, dtype=kwargs.get("dtype", DType.float32)
),
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
**kwargs,
)

Expand Down Expand Up @@ -1337,6 +1353,7 @@ def qwen3_1_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
feed_forward=FeedForwardConfig(
hidden_size=6144, bias=False, dtype=kwargs.get("dtype", DType.float32)
),
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
**kwargs,
)

Expand Down Expand Up @@ -1373,6 +1390,7 @@ def qwen3_4B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
feed_forward=FeedForwardConfig(
hidden_size=9728, bias=False, dtype=kwargs.get("dtype", DType.float32)
),
tie_word_embeddings=kwargs.pop("tie_word_embeddings", True),
**kwargs,
)

Expand Down
44 changes: 40 additions & 4 deletions src/olmo_core/nn/transformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..buffer_cache import BufferCache
from ..functional import l2_normalize
from ..layer_norm import LayerNormConfig
from ..lm_head import LMHeadConfig, LMOutputWithLoss
from ..lm_head import LMHeadConfig, LMLossImplementation, LMOutputWithLoss
from ..moe import MoEBase
from ..rope import RoPEBuffers, RotaryEmbeddingBase
from ..utils import selective_checkpointing_context_fn
Expand Down Expand Up @@ -117,6 +117,7 @@ def __init__(
block_overrides: Optional[Dict[int, TransformerBlockConfig]] = None,
block_pattern: Optional[List[str]] = None,
embed_scale: Optional[float] = None,
tie_word_embeddings: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -160,6 +161,10 @@ def __init__(
d_model=d_model, vocab_size=vocab_size, init_device=init_device
)

self.tie_word_embeddings = tie_word_embeddings
if tie_word_embeddings:
self._tie_weights()

self.init_device = init_device
self.init_method = InitMethod(init_method)
self.init_seed = init_seed
Expand All @@ -183,6 +188,15 @@ def __init__(
self.num_params
self.num_non_embedding_params

def _tie_weights(self) -> None:
if self.embeddings is None or self.lm_head is None:
raise OLMoConfigurationError(
"Cannot tie word embeddings without both embeddings and an LM head"
)
if self.lm_head.w_out.bias is not None:
raise OLMoConfigurationError("Cannot tie word embeddings when the LM head uses a bias")
self.lm_head.w_out.weight = self.embeddings.weight

def _validate_block(self, block: TransformerBlockBase) -> TransformerBlockBase:
return block

Expand Down Expand Up @@ -295,6 +309,10 @@ def init_weights(
generator=generator,
)

# Re-establish weight tying since `to_empty` above allocates fresh storage.
if self.tie_word_embeddings:
self._tie_weights()

for block in self.blocks.values():
# This might fail if it's wrapped.
# assert isinstance(block, TransformerBlock)
Expand Down Expand Up @@ -345,7 +363,7 @@ def init_weights(
if max_seq_len is not None and att.rope is not None:
att.rope.warmup_cache(max_seq_len, device)

if self.lm_head is not None:
if self.lm_head is not None and not self.tie_word_embeddings:
self.init_method.init_final_w_out(
self.lm_head.w_out,
d_model=self.d_model,
Expand Down Expand Up @@ -616,6 +634,16 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
:param loss_parallel: Set to ``True`` if parallelizing the loss function as well.
:param float8_enabled: Set this to ``True`` if training with float8 linear layers.
"""
if self.tie_word_embeddings and (
self.lm_head is None
or self.lm_head.loss_implementation == LMLossImplementation.fused_linear
):
raise NotImplementedError(
"Tensor parallelism with tied word embeddings requires the default loss "
"implementation; the fused-linear loss replicates the LM head weight, which is "
"incompatible with the vocab-sharded embedding."
)

if float8_enabled is None:
float8_enabled = self.fp8_enabled
elif not float8_enabled and self.fp8_enabled:
Expand Down Expand Up @@ -646,6 +674,12 @@ def apply_tp(self, tp_mesh: DeviceMesh, float8_enabled: Optional[bool] = None):
if self.lm_head is not None:
self.lm_head.apply_tp(tp_mesh, input_layouts=(Shard(1), Replicate()))

# The embedding (RowwiseParallel) and the LM head (ColwiseParallel) both shard their
# weight along the vocab dimension, so re-point the head at the embedding's sharded
# parameter to restore the tie that `parallelize_module` broke.
if self.tie_word_embeddings and self.embeddings is not None and self.lm_head is not None:
self._tie_weights()

self._tp_enabled = True
self._tp_mesh = tp_mesh

Expand Down Expand Up @@ -831,7 +865,9 @@ def apply_fsdp(
mp_policy=mp_policy,
)

if self.embeddings is not None:
# When weights are tied the embeddings and LM head share a parameter, so they must
# stay in the same FSDP group (the root) rather than being sharded separately.
if self.embeddings is not None and not self.tie_word_embeddings:
fully_shard(
self.embeddings,
reshard_after_forward=reshard_after_forward,
Expand All @@ -843,7 +879,7 @@ def apply_fsdp(
if wrapping_strategy != TransformerDataParallelWrappingStrategy.blocks:
if self.embedding_norm is not None:
fully_shard(self.embedding_norm, **fsdp_config)
if self.lm_head is not None:
if self.lm_head is not None and not self.tie_word_embeddings:
fully_shard(self.lm_head, reshard_after_forward=False, **fsdp_config)

fully_shard(self, reshard_after_forward=reshard_after_forward, **fsdp_config)
Expand Down
10 changes: 9 additions & 1 deletion src/olmo_core/train/train_module/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ def split_model(
self, model: Transformer, *, pp_mesh: DeviceMesh, device: torch.device
) -> Tuple[List[PipelineStage], List[Transformer]]:
split_points = self.get_split_points(model.n_layers)
num_stages = len(split_points) + 1

if num_stages > 1 and model.tie_word_embeddings:
raise NotImplementedError(
"Pipeline parallelism with tied word embeddings is not supported: the input "
"embeddings and LM head are placed on different pipeline stages, so they cannot "
"share a weight."
)

pp_rank = pp_mesh.get_local_rank()

def build_stage(
Expand Down Expand Up @@ -128,7 +137,6 @@ def build_stage(
)
return stage, model_chunk

num_stages = len(split_points) + 1
stage_idx = pp_rank

stages = []
Expand Down
Loading
Loading