Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `TransformerConfig.olmo3_hybrid_7B`, a preset for the Olmo Hybrid 7B model (Gated Delta Net + attention layers in a `[gdn, gdn, gdn, attn]` pattern).
- Added `convert_hybrid_state_from_hf` and wired hybrid (`olmo_hybrid`) model dispatch into both `convert_state_from_hf` and `convert_state_to_hf`, enabling bidirectional HF weight conversion for hybrid models.
- Added `HFConverterCallback`, which can be used to convert models to Hugging Face format at the end of a training run.
- Trainer now records checkpoint save and load durations as `train/checkpoint_save_duration_s` and `train/checkpoint_load_duration_s` metrics.
- Added `PowerLR`, a power-law learning rate scheduler with linear warmup, power-decay phase (`lr = initial_lr * (current / warmup) ** b` for negative `b`, making the LR independent of the training horizon), and an optional linear decay tail. Registered as `"power_lr"`.
Expand Down
2 changes: 2 additions & 0 deletions src/olmo_core/nn/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
is_olmo_hybrid_model,
)
from .convert import (
convert_hybrid_state_from_hf,
convert_hybrid_state_to_hf,
convert_state_from_hf,
convert_state_to_hf,
Expand All @@ -26,6 +27,7 @@

__all__ = [
"convert_checkpoint_to_hf",
"convert_hybrid_state_from_hf",
"convert_hybrid_state_to_hf",
"convert_state_from_hf",
"convert_state_to_hf",
Expand Down
93 changes: 67 additions & 26 deletions src/olmo_core/nn/hf/convert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import re
from typing import Any, Dict, List

import torch
Expand Down Expand Up @@ -466,6 +465,12 @@ def convert_state_from_hf(
:param model_type: The model type of the HF model.
"""

if model_type is None:
model_type = getattr(config, "model_type", None)

if model_type == "olmo_hybrid":
return convert_hybrid_state_from_hf(hf_state, _hybrid_layer_types_from_config(config))

converter = _get_converter_from_hf(model_type=model_type)

converted_state = _convert_state(config, hf_state, converter)
Expand Down Expand Up @@ -538,6 +543,10 @@ def convert_state_to_hf(
"""

model_type = getattr(config, "model_type", None)

if model_type == "olmo_hybrid":
return convert_hybrid_state_to_hf(olmo_core_state, _hybrid_layer_types_from_config(config))

converter = _get_converter_to_hf(model_type)

converted_state = _convert_state(config, olmo_core_state, converter)
Expand Down Expand Up @@ -602,7 +611,39 @@ def convert_state_to_hf(
"lm_head.w_out.weight": "lm_head.weight",
}

_HYBRID_BLOCK_KEY_RE = re.compile(r"^blocks\.(\d+)\.(.+)$")

def _build_hybrid_mapping_templates(
    layer_types: List[str], *, to_hf: bool
) -> List[StateMappingTemplate]:
    """
    Expand the hybrid key maps into one concrete :class:`StateMappingTemplate` per parameter.

    The shared (non-block) keys come first, followed by the keys of each layer in order. A layer
    whose type is ``"linear_attention"`` uses the GDN suffix map; every other type uses the
    attention suffix map. Layer indices are written directly into the keys, so no template
    contains a placeholder.

    :param layer_types: Per-layer type list (``"linear_attention"`` or ``"full_attention"``).
    :param to_hf: If ``True``, each template maps an OLMo-core key to its HF key; otherwise the
        direction is reversed (HF key to OLMo-core key).

    :returns: The list of concrete mapping templates, in the order described above.
    """
    # Gather every (OLMo-core key, HF key) pair first; the direction is applied at the end.
    key_pairs = list(HYBRID_SHARED_KEY_MAP.items())

    for idx, kind in enumerate(layer_types):
        if kind == "linear_attention":
            suffix_map = HYBRID_GDN_LAYER_KEY_MAP
        else:
            suffix_map = HYBRID_ATTN_LAYER_KEY_MAP

        for olmo_suffix, hf_suffix in suffix_map.items():
            key_pairs.append((f"blocks.{idx}.{olmo_suffix}", f"model.layers.{idx}.{hf_suffix}"))

    if to_hf:
        return [StateMappingTemplate(olmo_key, hf_key) for olmo_key, hf_key in key_pairs]
    return [StateMappingTemplate(hf_key, olmo_key) for olmo_key, hf_key in key_pairs]


@beta_feature
Expand All @@ -620,33 +661,33 @@ def convert_hybrid_state_to_hf(
:param state_dict: An unsharded OLMo-core model state dict.
:param layer_types: Per-layer type list (``"linear_attention"`` or ``"full_attention"``).
"""
hf_state: Dict[str, Any] = {}
templates = _build_hybrid_mapping_templates(layer_types, to_hf=True)
return StateConverter(templates).convert(state_dict, placeholder_bounds={})

for olmo_key, value in state_dict.items():
# Try shared (non-block) keys first.
if olmo_key in HYBRID_SHARED_KEY_MAP:
hf_state[HYBRID_SHARED_KEY_MAP[olmo_key]] = value
continue

m = _HYBRID_BLOCK_KEY_RE.match(olmo_key)
if m is None:
raise KeyError(f"Unmapped key: {olmo_key}")
@beta_feature
def convert_hybrid_state_from_hf(
state_dict: Dict[str, Any],
layer_types: List[str],
) -> Dict[str, Any]:
"""
Convert an HF ``olmo_hybrid`` state dict to OLMo-core format.

layer_idx = int(m.group(1))
suffix = m.group(2)
Inverse of :func:`convert_hybrid_state_to_hf`: uses the same suffix maps in the HF ->
OLMo-core direction, selecting per layer based on *layer_types*.

key_map = (
HYBRID_GDN_LAYER_KEY_MAP
if layer_types[layer_idx] == "linear_attention"
else HYBRID_ATTN_LAYER_KEY_MAP
)
if suffix not in key_map:
raise KeyError(
f"Unmapped block suffix for layer {layer_idx} "
f"(type={layer_types[layer_idx]!r}): {olmo_key}"
)
:param state_dict: An unsharded HF ``olmo_hybrid`` model state dict.
:param layer_types: Per-layer type list (``"linear_attention"`` or ``"full_attention"``).
"""
templates = _build_hybrid_mapping_templates(layer_types, to_hf=False)
return StateConverter(templates).convert(state_dict, placeholder_bounds={})

hf_key = f"model.layers.{layer_idx}.{key_map[suffix]}"
hf_state[hf_key] = value

return hf_state
def _hybrid_layer_types_from_config(config: PretrainedConfig) -> List[str]:
    """
    Extract the per-layer type list from an HF ``olmo_hybrid`` config.

    :param config: The HF config; only its ``layer_types`` attribute is read.

    :returns: A new list holding the entries of ``config.layer_types``.

    :raises ValueError: If the config has no ``layer_types`` attribute, or it is ``None``.
    """
    configured = getattr(config, "layer_types", None)
    if configured is not None:
        # Copy so callers never share a list object with the config.
        return [*configured]

    raise ValueError(
        "olmo_hybrid HF config is missing the `layer_types` field required for conversion."
    )
43 changes: 43 additions & 0 deletions src/olmo_core/nn/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
AttentionConfig,
AttentionType,
GateConfig,
GatedDeltaNetConfig,
SlidingWindowAttentionConfig,
)
from ..buffer_cache import BufferCache
Expand Down Expand Up @@ -971,6 +972,48 @@ def olmo3_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
)
return config

@classmethod
def olmo3_hybrid_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
    """
    A 7B Olmo 3 hybrid model config that interleaves Gated Delta Net (GDN) recurrent
    layers with standard attention layers in a ``[gdn, gdn, gdn, attn]`` pattern.

    Matches the recipe used to train ``allenai/Olmo-Hybrid-Instruct-SFT-7B``. The config
    is derived from :meth:`olmo3_7B` in three steps:

    1. ``remove_heads`` attention heads are removed and ``d_model`` is reduced by
       ``remove_heads * head_dim``, to match the params/throughput of the dense 7B model.
    2. RoPE is disabled on the attention block (as in long-context extension).
    3. Three of every four layers use a GDN block instead of the attention block.

    :param vocab_size: The vocabulary size, forwarded to :meth:`olmo3_7B`.
    :param kwargs: ``remove_heads`` (default ``2``) and ``head_dim`` (default ``128``) are
        consumed here; everything else is forwarded to :meth:`olmo3_7B`.
    """
    heads_to_drop = kwargs.pop("remove_heads", 2)
    attn_head_dim = kwargs.pop("head_dim", 128)

    config = cls.olmo3_7B(vocab_size=vocab_size, **kwargs)
    assert isinstance(config.block, TransformerBlockConfig)
    dense_mixer = config.block.sequence_mixer
    assert isinstance(dense_mixer, AttentionConfig)

    # Shrink the model width and head count together so the per-head width is unchanged;
    # this compensates for the extra GDN params.
    config.d_model -= heads_to_drop * attn_head_dim
    num_heads = dense_mixer.n_heads - heads_to_drop
    dense_mixer.n_heads = num_heads
    assert config.d_model / num_heads == attn_head_dim

    # RoPE was disabled at the start of long-context extension.
    rope_free_mixer = dense_mixer.replace(rope=None)
    attn_block = config.block.replace(sequence_mixer=rope_free_mixer)

    # The GDN block reuses everything from the attention block except the sequence mixer.
    # Its heads are 75% as wide as the attention heads.
    gdn_mixer = GatedDeltaNetConfig(
        n_heads=num_heads,
        head_dim=int(0.75 * config.d_model / num_heads),
        allow_neg_eigval=True,
    )
    gdn_block = attn_block.replace(sequence_mixer=gdn_mixer)

    # 3 GDN layers followed by 1 attention layer, repeating.
    pattern = ["gdn", "gdn", "gdn", "attn"]
    config.block = {"gdn": gdn_block, "attn": attn_block}
    config.block_pattern = pattern
    assert config.n_layers % len(pattern) == 0
    return config

@classmethod
def olmo3_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
"""
Expand Down
49 changes: 48 additions & 1 deletion src/test/nn/hf/convert_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from types import SimpleNamespace

import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM, Olmo2Config

from olmo_core.nn.hf.convert import convert_state_from_hf, convert_state_to_hf
from olmo_core.nn.hf.convert import (
HYBRID_ATTN_LAYER_KEY_MAP,
HYBRID_GDN_LAYER_KEY_MAP,
HYBRID_SHARED_KEY_MAP,
convert_hybrid_state_from_hf,
convert_hybrid_state_to_hf,
convert_state_from_hf,
convert_state_to_hf,
)

try:
from transformers import FlexOlmoConfig # type: ignore
Expand Down Expand Up @@ -331,3 +341,40 @@ def test_qwen3_0_6b_logprobs_roundtrip():

def test_gemma3_270m_logprobs_roundtrip():
_assert_logprobs_match_after_roundtrip("google/gemma-3-270m", model_type="gemma3_text")


# Minimal two-layer hybrid layout shared by the conversion tests below:
# layer 0 is GDN (linear_attention), layer 1 is standard attention (full_attention).
_HYBRID_LAYER_TYPES = ["linear_attention", "full_attention"]


def _make_hybrid_olmo_core_state() -> dict:
    """
    Build a small, seeded OLMo-core state dict covering every hybrid key.

    The dict holds one random length-3 tensor for each shared key, each GDN key under
    ``blocks.0``, and each attention key under ``blocks.1``.
    """
    torch.manual_seed(0)
    # Order matters: tensors are drawn in this sequence, so it fixes the generated values.
    key_names = [
        *HYBRID_SHARED_KEY_MAP,
        *(f"blocks.0.{suffix}" for suffix in HYBRID_GDN_LAYER_KEY_MAP),
        *(f"blocks.1.{suffix}" for suffix in HYBRID_ATTN_LAYER_KEY_MAP),
    ]
    return {name: torch.randn(3) for name in key_names}


def test_convert_hybrid_state_roundtrip():
    """
    Check OLMo-core -> HF -> OLMo-core conversion of a hybrid state dict.

    Besides the round trip, the intermediate HF state is checked against the expected HF
    key names, because a round-trip comparison alone would still pass if both converters
    left the keys untouched.
    """
    olmo_core_state = _make_hybrid_olmo_core_state()

    hf_state = convert_hybrid_state_to_hf(olmo_core_state, _HYBRID_LAYER_TYPES)

    # Layer 0 is GDN and layer 1 is attention (see ``_HYBRID_LAYER_TYPES``), so the HF side
    # must contain exactly the mapped shared keys plus the per-layer mapped suffixes.
    expected_hf_keys = set(HYBRID_SHARED_KEY_MAP.values())
    expected_hf_keys.update(
        f"model.layers.0.{hf_suffix}" for hf_suffix in HYBRID_GDN_LAYER_KEY_MAP.values()
    )
    expected_hf_keys.update(
        f"model.layers.1.{hf_suffix}" for hf_suffix in HYBRID_ATTN_LAYER_KEY_MAP.values()
    )
    assert set(hf_state.keys()) == expected_hf_keys

    roundtrip_state = convert_hybrid_state_from_hf(hf_state, _HYBRID_LAYER_TYPES)

    assert set(roundtrip_state.keys()) == set(olmo_core_state.keys())
    for key, value in olmo_core_state.items():
        torch.testing.assert_close(roundtrip_state[key], value)


def test_convert_state_hybrid_dispatch_roundtrip():
    """
    Check that the generic converters dispatch to the hybrid path for ``olmo_hybrid``.

    A stand-in config carrying only ``model_type`` and ``layer_types`` is passed to
    ``convert_state_to_hf`` and ``convert_state_from_hf``; the state must survive the
    round trip unchanged.
    """
    hf_config = SimpleNamespace(model_type="olmo_hybrid", layer_types=_HYBRID_LAYER_TYPES)
    original = _make_hybrid_olmo_core_state()

    restored = convert_state_from_hf(
        hf_config,
        convert_state_to_hf(hf_config, original),
        model_type="olmo_hybrid",
    )

    assert set(restored.keys()) == set(original.keys())
    for name, tensor in original.items():
        torch.testing.assert_close(restored[name], tensor)