Commit 470e10d

jomitchellnv and Jonathan Mitchell authored
ESM2 NVFP4 and MXFP8 support and documentation update. (#1484)
Layer-wise MXFP8/NVFP4 precision for ESM-2 TransformerEngine training

Adds support for per-layer quantization precision control, enabling mixed FP8/FP4/BF16 configurations across transformer layers during training. This allows users to assign different quantization formats to different layers via Hydra config (1-indexed fp8_layers and fp4_layers lists), enabling convergence/performance tradeoff exploration.

Key changes:

- Per-layer quantization context in encoder forward: NVEsmEncoder now maintains a layer_number_quantized_recipe_map that selects the appropriate TE autocast context per layer (nullcontext for FP8 to respect outer autocast, explicit autocast for FP4, or autocast(enabled=False) for BF16).
- quantization.py: New utilities for resolving layer-wise quantization assignments (resolve_quantization_layers), generating debug API regex patterns (generate_layer_regex), and initializing nvdlfw_inspect quant stats logging (initialize_quant_stats_logging). Handles 0-indexed (model internals) and 1-indexed (user-facing) layer numbering.
- train_ddp.py / train_fsdp2.py: Integrated layer-wise quantization setup: resolves layer assignments from config, builds recipe map, assigns to encoder, and optionally initializes quant stats logging.
- train_fsdp2_cp.py: Switched from AutoConfig/AutoModelForMaskedLM to local NVEsmConfig/NVEsmForMaskedLM for consistency and to avoid remote code trust issues.
- Hydra config (defaults.yaml): Added fp4_config, quant_stats_config, fp8_layers, fp4_layers, and use_fp32_master_weights settings.
- Model files: Updated esm_nv.py across all checkpoint directories (native_te, accelerate, peft) and the models package with layer-wise quantization support, NVTX annotations per encoder layer, and FP8_RECIPES/FP4_RECIPES type constants.
- Tests: Added comprehensive tests for resolve_quantization_layers, generate_layer_regex, and update_quant_stats_config covering defaults, explicit layers, mixed assignments, overlap validation, and edge cases.
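The 1-indexed to 0-indexed mapping described above can be illustrated with a minimal sketch. The function name `sketch_layer_precision` and its signature are hypothetical and are not the repository's `resolve_quantization_layers`; in particular, any default behavior the real utility applies when no layers are listed is not reproduced here. The sketch only shows the index shift and the overlap check the commit message mentions.

```python
def sketch_layer_precision(num_layers, fp8_layers=None, fp4_layers=None):
    """Hypothetical helper: turn 1-indexed layer lists into a 0-indexed precision list.

    Returns a list of length ``num_layers`` whose entries are "fp8", "fp4",
    or None (BF16), matching the shape documented for ``layer_precision``.
    """
    fp8 = set(fp8_layers or [])
    fp4 = set(fp4_layers or [])

    # A layer cannot be assigned to both formats.
    overlap = fp8 & fp4
    if overlap:
        raise ValueError(f"Layers assigned to both FP8 and FP4: {sorted(overlap)}")

    # User-facing numbering is 1-indexed, so valid values are 1..num_layers.
    out_of_range = [n for n in sorted(fp8 | fp4) if not 1 <= n <= num_layers]
    if out_of_range:
        raise ValueError(f"Layer numbers out of range 1..{num_layers}: {out_of_range}")

    precision = [None] * num_layers
    for n in fp8:
        precision[n - 1] = "fp8"  # shift to 0-indexed model internals
    for n in fp4:
        precision[n - 1] = "fp4"
    return precision


mixed = sketch_layer_precision(6, fp8_layers=[1, 2, 3], fp4_layers=[5, 6])
```

Here layer 4 is left unassigned, so it stays `None` and would run in BF16.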
## Summary by CodeRabbit

* **New Features**
  * Layer-wise FP8/FP4 quantization control and runtime recipe attachment for ESM2
  * TransformerEngine-optimized ESM2 model variants with per-layer autocast and NVTX profiling
* **Documentation**
  * Expanded low-precision training guide and quantization debugging examples; added benchmarks and convergence notes
* **Configuration**
  * New FP4 settings block and unified quantization statistics config; support for specifying per-layer precision
* **Tests**
  * Added comprehensive tests for quantization utilities and per-layer behavior

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Co-authored-by: Jonathan Mitchell <jomitchell@umb-b300-dp-218.cl1u1.colossus.nvidia.com>
1 parent b2ddae1 commit 470e10d

33 files changed

Lines changed: 2493 additions & 257 deletions

.vscode/settings.json

Lines changed: 2 additions & 1 deletion
@@ -26,5 +26,6 @@
     "editor.rulers": [
         120
     ],
-    "autoDocstring.docstringFormat": "google-notypes"
+    "autoDocstring.docstringFormat": "google-notypes",
+    "search.exclude": { "**/logs/**": true },
 }

bionemo-recipes/models/esm2/modeling_esm_te.py

Lines changed: 72 additions & 16 deletions
@@ -23,11 +23,13 @@
 """
 
 import warnings
+from contextlib import nullcontext
 from typing import ClassVar, Literal, Optional, Unpack
 
 # TODO: put import guard around transformer_engine here, with an informative error message around
 # installation and the nvidia docker container.
 import torch
+import transformer_engine.common.recipe
 import transformer_engine.pytorch
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -71,6 +73,7 @@ def __init__(
         max_seq_length: Optional[int] = None,
         padded_vocab_size: Optional[int] = 64,
         attn_mask_type: str = "padding",
+        layer_precision: list[str | None] | None = None,
         **kwargs,
     ):
         """Initialize the NVEsmConfig with additional TE-related config options.
@@ -100,6 +103,9 @@ def __init__(
             padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults
                 to vocab_size. Must be greater than or equal to vocab_size.
             attn_mask_type: The type of attention mask to use.
+            layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers``
+                where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None``
+                (the default) means no quantization is configured.
             **kwargs: Additional config options to pass to EsmConfig.
         """
         super().__init__(**kwargs)
@@ -111,6 +117,7 @@ def __init__(
         self.micro_batch_size = micro_batch_size
         self.max_seq_length = max_seq_length
         self.attn_mask_type = attn_mask_type
+        self.layer_precision = layer_precision
 
         # Set padded_vocab_size with default fallback to vocab_size
         self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size
@@ -165,6 +172,8 @@ def _init_method(x):
                 for i in range(config.num_hidden_layers)
             ]
         )
+        self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
+        self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None
         self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm(
             config.hidden_size,
             eps=config.layer_norm_eps,
@@ -174,6 +183,49 @@ def _init_method(x):
         if config.position_embedding_type == "rotary":
             self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
 
+    def set_recipes(
+        self,
+        fp8_recipe: transformer_engine.common.recipe.Recipe | None = None,
+        fp4_recipe: transformer_engine.common.recipe.Recipe | None = None,
+    ) -> None:
+        """Attach quantization recipe objects for per-layer autocast.
+
+        Recipes are not serializable and must be set at runtime after model creation
+        and sharding (FSDP/DDP/mFSDP) but before training. The per-layer precision
+        assignments are read from ``self.config.layer_precision``.
+
+        These recipes are also hardware specific, so we should not store them as
+        attributes of the model and attach them at runtime.
+
+        Args:
+            fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None.
+            fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None.
+        """
+        self._fp8_recipe = fp8_recipe
+        self._fp4_recipe = fp4_recipe
+
+    def get_layer_autocast(self, layer_number: int):
+        """Return the appropriate TE autocast context manager for a given layer.
+
+        The context interacts with the outer FP8 autocast in the training script:
+        - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect.
+        - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4.
+        - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute.
+
+        Args:
+            layer_number: The 0-indexed layer number.
+
+        Returns:
+            A context manager for the layer's quantization mode.
+        """
+        precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
+        if precision == "fp8":
+            return nullcontext()
+        elif precision == "fp4":
+            return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe)
+        else:
+            return transformer_engine.pytorch.autocast(enabled=False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -201,22 +253,26 @@ def forward(
         if te_rope_emb.dtype == torch.float32:
             warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning)
 
-        for layer_module in self.layers:
-            if kwargs.get("output_hidden_states", False):
-                all_hidden_states = (*all_hidden_states, hidden_states)
-
-            hidden_states = layer_module(
-                hidden_states,
-                attention_mask,
-                rotary_pos_emb=te_rope_emb,
-                cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
-                cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
-                cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
-                cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
-                max_seqlen_q=kwargs.get("max_length_q", None),
-                max_seqlen_kv=kwargs.get("max_length_k", None),
-                pad_between_seqs=kwargs.get("pad_between_seqs", None),
-            )
+        # Outer FP8 autocast enables FP8 compute for the encoder stack. Per-layer overrides (FP4, BF16) are handled
+        # by get_layer_autocast(), which nests inside this context.
+        with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe):
+            for layer_number, layer_module in enumerate(self.layers):
+                if kwargs.get("output_hidden_states", False):
+                    all_hidden_states = (*all_hidden_states, hidden_states)
+
+                with self.get_layer_autocast(layer_number):
+                    hidden_states = layer_module(
+                        hidden_states,
+                        attention_mask,
+                        rotary_pos_emb=te_rope_emb,
+                        cu_seqlens_q=kwargs.get("cu_seq_lens_q", None),
+                        cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None),
+                        cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None),
+                        cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None),
+                        max_seqlen_q=kwargs.get("max_length_q", None),
+                        max_seqlen_kv=kwargs.get("max_length_k", None),
+                        pad_between_seqs=kwargs.get("pad_between_seqs", None),
+                    )
 
         hidden_states = self.emb_layer_norm_after(hidden_states)

bionemo-recipes/models/esm2/tests/test_distributed_fp8.py

Lines changed: 8 additions & 4 deletions
@@ -161,7 +161,10 @@ def is_main_process(self) -> bool:
     )
     device = f"cuda:{dist_config.local_rank}"
 
+    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10)
+
     config = NVEsmConfig.from_pretrained("facebook/esm2_t6_8M_UR50D", dtype=torch.bfloat16, revision="c731040f")
+    config.layer_precision = ["fp8"] * config.num_hidden_layers
     model = NVEsmForMaskedLM(config)
 
     if args.strategy is Strategy.FSDP2:
@@ -195,13 +198,15 @@ def is_main_process(self) -> bool:
             tp_dim="tp",
         )
 
+    # Attach FP8 recipes to the encoder (layer precision is already on config).
+    encoder = model.module.esm.encoder if args.strategy is Strategy.DDP else model.esm.encoder
+    encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+
     model.train()
 
     generator = torch.Generator()
     generator.manual_seed(torch.distributed.get_rank())
 
-    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_compute_algo="max", amax_history_len=10)
-
     for _ in range(3):
         input_data = {
             "input_ids": torch.randint(0, config.vocab_size, (1, 32), generator=generator),
@@ -211,8 +216,7 @@ def is_main_process(self) -> bool:
         input_data = {k: v.to(torch.cuda.current_device()) for k, v in input_data.items()}
 
         with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
-            with transformer_engine.pytorch.autocast(enabled=True, recipe=fp8_recipe):
-                outputs = model(**input_data)
+            outputs = model(**input_data)
 
         outputs.loss.backward()
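
The test above reaches the encoder through `model.module` under DDP and through `model` directly otherwise. A strategy-agnostic variant is sketched below. The helper name `get_encoder` is hypothetical, and it rests on the assumption that only the DDP wrapper exposes a `.module` attribute; a model with its own submodule named `module` would be resolved incorrectly. `SimpleNamespace` objects stand in for real models so the sketch runs without PyTorch.

```python
from types import SimpleNamespace


def get_encoder(model):
    """Return ``model.esm.encoder``, looking through a DDP-style ``.module`` wrapper if present."""
    inner = getattr(model, "module", model)  # unwrap only when a wrapper exists
    return inner.esm.encoder


# Stand-ins for an unwrapped model and a DDP-wrapped model.
encoder = SimpleNamespace(name="encoder")
plain_model = SimpleNamespace(esm=SimpleNamespace(encoder=encoder))
ddp_like_model = SimpleNamespace(module=plain_model)

from_plain = get_encoder(plain_model)
from_wrapped = get_encoder(ddp_like_model)
```

Both calls return the same encoder object, so `set_recipes` could be called on the result regardless of the wrapping strategy.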

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for NVEsmEncoder.set_recipes and get_layer_autocast."""
+
+from contextlib import nullcontext
+from unittest.mock import patch
+
+import pytest
+import transformer_engine.common.recipe
+import transformer_engine.pytorch
+
+from modeling_esm_te import NVEsmConfig, NVEsmEncoder
+
+
+@pytest.fixture
+def encoder():
+    """Create a small NVEsmEncoder on CUDA for testing."""
+    config = NVEsmConfig(
+        hidden_size=320,
+        intermediate_size=1280,
+        num_hidden_layers=6,
+        num_attention_heads=20,
+        max_position_embeddings=1026,
+    )
+    return NVEsmEncoder(config)
+
+
+# -- set_recipes --
+
+
+def test_all_fp8(encoder):
+    encoder.config.layer_precision = ["fp8"] * 6
+    fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+    encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+    assert encoder._fp8_recipe is fp8_recipe
+    assert encoder._fp4_recipe is None
+    assert all(p == "fp8" for p in encoder.config.layer_precision)
+
+
+def test_all_fp4(encoder):
+    encoder.config.layer_precision = ["fp4"] * 6
+    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+    encoder.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe)
+    assert encoder._fp8_recipe is None
+    assert encoder._fp4_recipe is fp4_recipe
+    assert all(p == "fp4" for p in encoder.config.layer_precision)
+
+
+def test_all_bf16(encoder):
+    encoder.config.layer_precision = [None] * 6
+    encoder.set_recipes(fp8_recipe=None, fp4_recipe=None)
+    assert all(p is None for p in encoder.config.layer_precision)
+
+
+def test_mixed_fp8_fp4(encoder):
+    encoder.config.layer_precision = ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+    fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+    encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+    assert encoder.config.layer_precision == ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
+
+
+def test_mixed_fp8_bf16(encoder):
+    encoder.config.layer_precision = ["fp8", None, "fp8", None, "fp8", None]
+    fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+    encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)
+    assert encoder.config.layer_precision == ["fp8", None, "fp8", None, "fp8", None]
+
+
+def test_mixed_all_three(encoder):
+    encoder.config.layer_precision = ["fp8", "fp8", None, None, "fp4", "fp4"]
+    fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+    encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+    assert encoder.config.layer_precision == ["fp8", "fp8", None, None, "fp4", "fp4"]
+
+
+def test_covers_all_layers(encoder):
+    encoder.config.layer_precision = ["fp8"] + [None] * 5
+    encoder.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None)
+    assert len(encoder.config.layer_precision) == 6
+
+
+def test_recipes_stored_as_attributes(encoder):
+    encoder.config.layer_precision = ["fp8", "fp4", None, None, None, None]
+    fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+    encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+    assert encoder._fp8_recipe is fp8_recipe
+    assert encoder._fp4_recipe is fp4_recipe
+    # The precision list only contains strings/None, not recipe objects.
+    for v in encoder.config.layer_precision:
+        assert v is None or isinstance(v, str)
+
+
+# -- get_layer_autocast --
+
+
+def test_fp8_layer_returns_nullcontext(encoder):
+    encoder.config.layer_precision = ["fp8"] + [None] * 5
+    encoder.set_recipes(fp8_recipe=transformer_engine.common.recipe.DelayedScaling(), fp4_recipe=None)
+    ctx = encoder.get_layer_autocast(0)
+    assert isinstance(ctx, nullcontext)
+
+
+def test_fp4_layer_returns_te_autocast(encoder):
+    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+    encoder.config.layer_precision = ["fp4"] + [None] * 5
+    encoder.set_recipes(fp8_recipe=None, fp4_recipe=fp4_recipe)
+    with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+        mock_autocast.return_value = "fp4_context"
+        ctx = encoder.get_layer_autocast(0)
+        mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe)
+        assert ctx == "fp4_context"
+
+
+def test_bf16_layer_returns_te_autocast_disabled(encoder):
+    encoder.config.layer_precision = [None] * 6
+    encoder.set_recipes(fp8_recipe=None, fp4_recipe=None)
+    with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+        mock_autocast.return_value = "bf16_context"
+        ctx = encoder.get_layer_autocast(0)
+        mock_autocast.assert_called_once_with(enabled=False)
+        assert ctx == "bf16_context"
+
+
+def test_uninitialized_defaults_to_bf16(encoder):
+    """When layer_precision is None (default), all layers default to BF16."""
+    assert encoder.config.layer_precision is None
+    with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+        mock_autocast.return_value = "bf16_context"
+        ctx = encoder.get_layer_autocast(0)
+        mock_autocast.assert_called_once_with(enabled=False)
+        assert ctx == "bf16_context"
+
+
+def test_mixed_layers_return_correct_contexts(encoder):
+    fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
+    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
+    encoder.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
+    encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)
+
+    # FP8 layers -> nullcontext
+    assert isinstance(encoder.get_layer_autocast(0), nullcontext)
+    assert isinstance(encoder.get_layer_autocast(1), nullcontext)
+
+    # FP4 layers -> te.pytorch.autocast
+    with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+        mock_autocast.return_value = "fp4_context"
+        encoder.get_layer_autocast(2)
+        mock_autocast.assert_called_with(enabled=True, recipe=fp4_recipe)
+
+    # BF16 layers -> te.pytorch.autocast(enabled=False)
+    with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
+        mock_autocast.return_value = "bf16_context"
+        encoder.get_layer_autocast(4)
+        mock_autocast.assert_called_with(enabled=False)
+
+
+def test_layer_precision_is_pickleable(encoder):
+    """The config.layer_precision list should be trivially pickleable."""
+    import pickle
+
+    encoder.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
+    roundtripped = pickle.loads(pickle.dumps(encoder.config.layer_precision))
+    assert roundtripped == encoder.config.layer_precision
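
The last test checks pickling. Hugging Face configs are also written to disk as JSON, so the same property matters there. The sketch below uses only the standard library and a plain dict as a stand-in for the config; it is not a call into `NVEsmConfig`. It shows why keeping only strings and `None` in `layer_precision` lets the list survive a JSON round trip, while the recipe objects are attached separately at runtime.

```python
import json

# A plain dict stands in for the serialized config fields (assumption for illustration).
layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
encoded = json.dumps({"layer_precision": layer_precision})

# None is written as JSON null and read back as None.
decoded = json.loads(encoded)["layer_precision"]
```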
