From ae85c1100b81436ca7e29c50cccc45f6d206fd7b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 08:24:10 -0600 Subject: [PATCH 1/2] Add Olmo Hybrid 7B config preset and bidirectional HF weight conversion Co-Authored-By: Claude Opus 4.8 --- CHANGELOG.md | 2 + src/olmo_core/nn/hf/__init__.py | 2 + src/olmo_core/nn/hf/convert.py | 88 ++++++++++++++++++++++++++ src/olmo_core/nn/transformer/config.py | 43 +++++++++++++ src/test/nn/hf/convert_test.py | 49 +++++++++++++- 5 files changed, 183 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73feba949..e8c85df7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `TransformerConfig.olmo3_hybrid_7B`, a preset for the Olmo Hybrid 7B model (Gated Delta Net + attention layers in a `[gdn, gdn, gdn, attn]` pattern). +- Added `convert_hybrid_state_from_hf` and wired hybrid (`olmo_hybrid`) model dispatch into both `convert_state_from_hf` and `convert_state_to_hf`, enabling bidirectional HF weight conversion for hybrid models. - Added `HFConverterCallback`, which can be used to convert models to huggingface format at the end of the training run. - Trainer now records checkpoint save and load durations as `train/checkpoint_save_duration_s` and `train/checkpoint_load_duration_s` metrics. - Added `PowerLR`, a power-law learning rate scheduler with linear warmup, power-decay phase (`lr = initial_lr * (current / warmup) ** b` for negative `b`, making the LR independent of the training horizon), and an optional linear decay tail. Registered as `"power_lr"`. diff --git a/src/olmo_core/nn/hf/__init__.py b/src/olmo_core/nn/hf/__init__.py index 9e209a1b4..91b642b3c 100644 --- a/src/olmo_core/nn/hf/__init__.py +++ b/src/olmo_core/nn/hf/__init__.py @@ -12,6 +12,7 @@ is_olmo_hybrid_model, ) from .convert import ( + convert_hybrid_state_from_hf, convert_hybrid_state_to_hf, convert_state_from_hf, convert_state_to_hf, @@ -26,6 +27,7 @@ __all__ = [ "convert_checkpoint_to_hf", + "convert_hybrid_state_from_hf", "convert_hybrid_state_to_hf", "convert_state_from_hf", "convert_state_to_hf", diff --git a/src/olmo_core/nn/hf/convert.py b/src/olmo_core/nn/hf/convert.py index a1f594ab2..a1d03834f 100644 --- a/src/olmo_core/nn/hf/convert.py +++ b/src/olmo_core/nn/hf/convert.py @@ -466,6 +466,12 @@ def convert_state_from_hf( :param model_type: The model type of the HF model. """ + if model_type is None: + model_type = getattr(config, "model_type", None) + + if model_type == "olmo_hybrid": + return convert_hybrid_state_from_hf(hf_state, _hybrid_layer_types_from_config(config)) + converter = _get_converter_from_hf(model_type=model_type) converted_state = _convert_state(config, hf_state, converter) @@ -538,6 +544,10 @@ def convert_state_to_hf( """ model_type = getattr(config, "model_type", None) + + if model_type == "olmo_hybrid": + return convert_hybrid_state_to_hf(olmo_core_state, _hybrid_layer_types_from_config(config)) + converter = _get_converter_to_hf(model_type) converted_state = _convert_state(config, olmo_core_state, converter) @@ -650,3 +660,81 @@ def convert_hybrid_state_to_hf( hf_state[hf_key] = value return hf_state + + +def _invert_hybrid_key_map(key_map: Dict[str, str]) -> Dict[str, str]: + inverse: Dict[str, str] = {} + for olmo_suffix, hf_suffix in key_map.items(): + if hf_suffix in inverse: + raise ValueError(f"Non-invertible hybrid key map: duplicate HF key {hf_suffix!r}") + inverse[hf_suffix] = olmo_suffix + return inverse + + +#: Inverse of the hybrid maps, for the HF -> OLMo-core direction. +_HF_TO_OLMO_HYBRID_SHARED_KEY_MAP: Dict[str, str] = _invert_hybrid_key_map(HYBRID_SHARED_KEY_MAP) +_HF_TO_OLMO_HYBRID_GDN_LAYER_KEY_MAP: Dict[str, str] = _invert_hybrid_key_map( + HYBRID_GDN_LAYER_KEY_MAP +) +_HF_TO_OLMO_HYBRID_ATTN_LAYER_KEY_MAP: Dict[str, str] = _invert_hybrid_key_map( + HYBRID_ATTN_LAYER_KEY_MAP +) + +_HF_HYBRID_BLOCK_KEY_RE = re.compile(r"^model\.layers\.(\d+)\.(.+)$") + + +@beta_feature +def convert_hybrid_state_from_hf( + state_dict: Dict[str, Any], + layer_types: List[str], +) -> Dict[str, Any]: + """ + Convert an HF ``olmo_hybrid`` state dict to OLMo-core format. + + Inverse of :func:`convert_hybrid_state_to_hf`: uses the inverse of + :data:`HYBRID_SHARED_KEY_MAP` for non-block keys, and the inverse of + :data:`HYBRID_GDN_LAYER_KEY_MAP` / :data:`HYBRID_ATTN_LAYER_KEY_MAP` + based on *layer_types*. + + :param state_dict: An unsharded HF ``olmo_hybrid`` model state dict. + :param layer_types: Per-layer type list (``"linear_attention"`` or ``"full_attention"``). + """ + olmo_state: Dict[str, Any] = {} + + for hf_key, value in state_dict.items(): + # Try shared (non-block) keys first. + if hf_key in _HF_TO_OLMO_HYBRID_SHARED_KEY_MAP: + olmo_state[_HF_TO_OLMO_HYBRID_SHARED_KEY_MAP[hf_key]] = value + continue + + m = _HF_HYBRID_BLOCK_KEY_RE.match(hf_key) + if m is None: + raise KeyError(f"Unmapped key: {hf_key}") + + layer_idx = int(m.group(1)) + suffix = m.group(2) + + key_map = ( + _HF_TO_OLMO_HYBRID_GDN_LAYER_KEY_MAP + if layer_types[layer_idx] == "linear_attention" + else _HF_TO_OLMO_HYBRID_ATTN_LAYER_KEY_MAP + ) + if suffix not in key_map: + raise KeyError( + f"Unmapped block suffix for layer {layer_idx} " + f"(type={layer_types[layer_idx]!r}): {hf_key}" + ) + + olmo_state[f"blocks.{layer_idx}.{key_map[suffix]}"] = value + + return olmo_state + + +def _hybrid_layer_types_from_config(config: PretrainedConfig) -> List[str]: + """Read the per-layer type list from an HF ``olmo_hybrid`` config.""" + layer_types = getattr(config, "layer_types", None) + if layer_types is None: + raise ValueError( + "olmo_hybrid HF config is missing the `layer_types` field required for conversion." + ) + return list(layer_types) diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index 3626f3559..d25373381 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -17,6 +17,7 @@ AttentionConfig, AttentionType, GateConfig, + GatedDeltaNetConfig, SlidingWindowAttentionConfig, ) from ..buffer_cache import BufferCache @@ -971,6 +972,48 @@ def olmo3_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": ) return config + @classmethod + def olmo3_hybrid_7B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": + """ + A 7B Olmo 3 hybrid model config that combines Gated Delta Net (GDN) recurrent + layers with standard attention layers in a ``[gdn, gdn, gdn, attn]`` pattern. + + Matches the recipe used to train ``allenai/Olmo-Hybrid-Instruct-SFT-7B``: it + starts from :meth:`olmo3_7B`, removes ``remove_heads`` attention heads (scaling + ``d_model`` down accordingly) to match the params/throughput of the dense 7B + model, disables RoPE (as in long-context extension), and replaces 3 of every 4 + layers with GDN layers. + """ + remove_heads = kwargs.pop("remove_heads", 2) + head_dim = kwargs.pop("head_dim", 128) + config = cls.olmo3_7B(vocab_size=vocab_size, **kwargs) + assert isinstance(config.block, TransformerBlockConfig) + assert isinstance(config.block.sequence_mixer, AttentionConfig) + + # Remove heads (and scale down d_model) to compensate for the extra GDN params. + config.d_model -= remove_heads * head_dim + num_heads = config.block.sequence_mixer.n_heads - remove_heads + config.block.sequence_mixer.n_heads = num_heads + assert config.d_model / num_heads == head_dim + + # RoPE was disabled at the start of long-context extension. + attn_block = config.block.replace( + sequence_mixer=config.block.sequence_mixer.replace(rope=None), + ) + gdn_block = attn_block.replace( + sequence_mixer=GatedDeltaNetConfig( + n_heads=num_heads, + head_dim=int(0.75 * config.d_model / num_heads), + allow_neg_eigval=True, + ), + ) + + # 3 GDN layers followed by 1 attention layer, repeating. + config.block = {"gdn": gdn_block, "attn": attn_block} + config.block_pattern = ["gdn", "gdn", "gdn", "attn"] + assert config.n_layers % len(config.block_pattern) == 0 + return config + @classmethod def olmo3_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": """ diff --git a/src/test/nn/hf/convert_test.py b/src/test/nn/hf/convert_test.py index afb157965..f42dc0310 100644 --- a/src/test/nn/hf/convert_test.py +++ b/src/test/nn/hf/convert_test.py @@ -1,8 +1,18 @@ +from types import SimpleNamespace + import pytest import torch from transformers import AutoConfig, AutoModelForCausalLM, Olmo2Config -from olmo_core.nn.hf.convert import convert_state_from_hf, convert_state_to_hf +from olmo_core.nn.hf.convert import ( + HYBRID_ATTN_LAYER_KEY_MAP, + HYBRID_GDN_LAYER_KEY_MAP, + HYBRID_SHARED_KEY_MAP, + convert_hybrid_state_from_hf, + convert_hybrid_state_to_hf, + convert_state_from_hf, + convert_state_to_hf, +) try: from transformers import FlexOlmoConfig # type: ignore @@ -331,3 +341,40 @@ def test_qwen3_0_6b_logprobs_roundtrip(): def test_gemma3_270m_logprobs_roundtrip(): _assert_logprobs_match_after_roundtrip("google/gemma-3-270m", model_type="gemma3_text") + + +# Layer 0 is GDN (linear_attention), layer 1 is standard attention (full_attention). +_HYBRID_LAYER_TYPES = ["linear_attention", "full_attention"] + + +def _make_hybrid_olmo_core_state() -> dict: + torch.manual_seed(0) + state = {olmo_key: torch.randn(3) for olmo_key in HYBRID_SHARED_KEY_MAP} + for olmo_suffix in HYBRID_GDN_LAYER_KEY_MAP: + state[f"blocks.0.{olmo_suffix}"] = torch.randn(3) + for olmo_suffix in HYBRID_ATTN_LAYER_KEY_MAP: + state[f"blocks.1.{olmo_suffix}"] = torch.randn(3) + return state + + +def test_convert_hybrid_state_roundtrip(): + olmo_core_state = _make_hybrid_olmo_core_state() + + hf_state = convert_hybrid_state_to_hf(olmo_core_state, _HYBRID_LAYER_TYPES) + roundtrip_state = convert_hybrid_state_from_hf(hf_state, _HYBRID_LAYER_TYPES) + + assert set(roundtrip_state.keys()) == set(olmo_core_state.keys()) + for key, value in olmo_core_state.items(): + torch.testing.assert_close(roundtrip_state[key], value) + + +def test_convert_state_hybrid_dispatch_roundtrip(): + config = SimpleNamespace(model_type="olmo_hybrid", layer_types=_HYBRID_LAYER_TYPES) + olmo_core_state = _make_hybrid_olmo_core_state() + + hf_state = convert_state_to_hf(config, olmo_core_state) + roundtrip_state = convert_state_from_hf(config, hf_state, model_type="olmo_hybrid") + + assert set(roundtrip_state.keys()) == set(olmo_core_state.keys()) + for key, value in olmo_core_state.items(): + torch.testing.assert_close(roundtrip_state[key], value) From 1b021a41c16e51b2f060bfd0f69306bea5b1e07b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 10:00:10 -0600 Subject: [PATCH 2/2] cleaned up PR --- src/olmo_core/nn/hf/convert.py | 123 ++++++++++----------------------- 1 file changed, 38 insertions(+), 85 deletions(-) diff --git a/src/olmo_core/nn/hf/convert.py b/src/olmo_core/nn/hf/convert.py index a1d03834f..2b0b79af9 100644 --- a/src/olmo_core/nn/hf/convert.py +++ b/src/olmo_core/nn/hf/convert.py @@ -1,5 +1,4 @@ import logging -import re from typing import Any, Dict, List import torch @@ -612,75 +611,58 @@ def convert_state_to_hf( "lm_head.w_out.weight": "lm_head.weight", } -_HYBRID_BLOCK_KEY_RE = re.compile(r"^blocks\.(\d+)\.(.+)$") - -@beta_feature -def convert_hybrid_state_to_hf( - state_dict: Dict[str, Any], - layer_types: List[str], -) -> Dict[str, Any]: +def _build_hybrid_mapping_templates( + layer_types: List[str], *, to_hf: bool +) -> List[StateMappingTemplate]: """ - Convert an OLMo-core hybrid state dict to HF ``olmo_hybrid`` format. + Build concrete (placeholder-free) :class:`StateMappingTemplate`s for a hybrid model. - Uses :data:`HYBRID_SHARED_KEY_MAP` for non-block keys, and per-layer - :data:`HYBRID_GDN_LAYER_KEY_MAP` / :data:`HYBRID_ATTN_LAYER_KEY_MAP` - based on *layer_types*. + For each layer, the GDN (``"linear_attention"``) or attention (``"full_attention"``) suffix + map is selected so the per-layer-type naming difference is baked into the concrete keys. The + resulting templates can be fed to a standard :class:`StateConverter`. - :param state_dict: An unsharded OLMo-core model state dict. :param layer_types: Per-layer type list (``"linear_attention"`` or ``"full_attention"``). + :param to_hf: If ``True``, map OLMo-core keys to HF keys, otherwise HF keys to OLMo-core keys. """ - hf_state: Dict[str, Any] = {} - - for olmo_key, value in state_dict.items(): - # Try shared (non-block) keys first. - if olmo_key in HYBRID_SHARED_KEY_MAP: - hf_state[HYBRID_SHARED_KEY_MAP[olmo_key]] = value - continue - m = _HYBRID_BLOCK_KEY_RE.match(olmo_key) - if m is None: - raise KeyError(f"Unmapped key: {olmo_key}") + def template(olmo_key: str, hf_key: str) -> StateMappingTemplate: + src, dst = (olmo_key, hf_key) if to_hf else (hf_key, olmo_key) + return StateMappingTemplate(src, dst) - layer_idx = int(m.group(1)) - suffix = m.group(2) + templates = [template(olmo_key, hf_key) for olmo_key, hf_key in HYBRID_SHARED_KEY_MAP.items()] + for layer_idx, layer_type in enumerate(layer_types): key_map = ( HYBRID_GDN_LAYER_KEY_MAP - if layer_types[layer_idx] == "linear_attention" + if layer_type == "linear_attention" else HYBRID_ATTN_LAYER_KEY_MAP ) - if suffix not in key_map: - raise KeyError( - f"Unmapped block suffix for layer {layer_idx} " - f"(type={layer_types[layer_idx]!r}): {olmo_key}" - ) - - hf_key = f"model.layers.{layer_idx}.{key_map[suffix]}" - hf_state[hf_key] = value - - return hf_state + templates.extend( + template(f"blocks.{layer_idx}.{olmo_suffix}", f"model.layers.{layer_idx}.{hf_suffix}") + for olmo_suffix, hf_suffix in key_map.items() + ) + return templates -def _invert_hybrid_key_map(key_map: Dict[str, str]) -> Dict[str, str]: - inverse: Dict[str, str] = {} - for olmo_suffix, hf_suffix in key_map.items(): - if hf_suffix in inverse: - raise ValueError(f"Non-invertible hybrid key map: duplicate HF key {hf_suffix!r}") - inverse[hf_suffix] = olmo_suffix - return inverse +@beta_feature +def convert_hybrid_state_to_hf( + state_dict: Dict[str, Any], + layer_types: List[str], +) -> Dict[str, Any]: + """ + Convert an OLMo-core hybrid state dict to HF ``olmo_hybrid`` format. -#: Inverse of the hybrid maps, for the HF -> OLMo-core direction. -_HF_TO_OLMO_HYBRID_SHARED_KEY_MAP: Dict[str, str] = _invert_hybrid_key_map(HYBRID_SHARED_KEY_MAP) -_HF_TO_OLMO_HYBRID_GDN_LAYER_KEY_MAP: Dict[str, str] = _invert_hybrid_key_map( - HYBRID_GDN_LAYER_KEY_MAP -) -_HF_TO_OLMO_HYBRID_ATTN_LAYER_KEY_MAP: Dict[str, str] = _invert_hybrid_key_map( - HYBRID_ATTN_LAYER_KEY_MAP -) + Uses :data:`HYBRID_SHARED_KEY_MAP` for non-block keys, and per-layer + :data:`HYBRID_GDN_LAYER_KEY_MAP` / :data:`HYBRID_ATTN_LAYER_KEY_MAP` + based on *layer_types*. -_HF_HYBRID_BLOCK_KEY_RE = re.compile(r"^model\.layers\.(\d+)\.(.+)$") + :param state_dict: An unsharded OLMo-core model state dict. + :param layer_types: Per-layer type list (``"linear_attention"`` or ``"full_attention"``). + """ + templates = _build_hybrid_mapping_templates(layer_types, to_hf=True) + return StateConverter(templates).convert(state_dict, placeholder_bounds={}) @beta_feature @@ -691,43 +673,14 @@ def convert_hybrid_state_from_hf( """ Convert an HF ``olmo_hybrid`` state dict to OLMo-core format. - Inverse of :func:`convert_hybrid_state_to_hf`: uses the inverse of - :data:`HYBRID_SHARED_KEY_MAP` for non-block keys, and the inverse of - :data:`HYBRID_GDN_LAYER_KEY_MAP` / :data:`HYBRID_ATTN_LAYER_KEY_MAP` - based on *layer_types*. + Inverse of :func:`convert_hybrid_state_to_hf`: uses the same suffix maps in the HF -> + OLMo-core direction, selecting per layer based on *layer_types*. :param state_dict: An unsharded HF ``olmo_hybrid`` model state dict. :param layer_types: Per-layer type list (``"linear_attention"`` or ``"full_attention"``). """ - olmo_state: Dict[str, Any] = {} - - for hf_key, value in state_dict.items(): - # Try shared (non-block) keys first. - if hf_key in _HF_TO_OLMO_HYBRID_SHARED_KEY_MAP: - olmo_state[_HF_TO_OLMO_HYBRID_SHARED_KEY_MAP[hf_key]] = value - continue - - m = _HF_HYBRID_BLOCK_KEY_RE.match(hf_key) - if m is None: - raise KeyError(f"Unmapped key: {hf_key}") - - layer_idx = int(m.group(1)) - suffix = m.group(2) - - key_map = ( - _HF_TO_OLMO_HYBRID_GDN_LAYER_KEY_MAP - if layer_types[layer_idx] == "linear_attention" - else _HF_TO_OLMO_HYBRID_ATTN_LAYER_KEY_MAP - ) - if suffix not in key_map: - raise KeyError( - f"Unmapped block suffix for layer {layer_idx} " - f"(type={layer_types[layer_idx]!r}): {hf_key}" - ) - - olmo_state[f"blocks.{layer_idx}.{key_map[suffix]}"] = value - - return olmo_state + templates = _build_hybrid_mapping_templates(layer_types, to_hf=False) + return StateConverter(templates).convert(state_dict, placeholder_bounds={}) def _hybrid_layer_types_from_config(config: PretrainedConfig) -> List[str]: