|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: LicenseRef-Apache2 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Unit tests for NVEsmEncoder.set_recipes and get_layer_autocast.""" |
| 17 | + |
| 18 | +from contextlib import nullcontext |
| 19 | +from unittest.mock import patch |
| 20 | + |
| 21 | +import pytest |
| 22 | +import transformer_engine.common.recipe |
| 23 | +import transformer_engine.pytorch |
| 24 | + |
| 25 | +from modeling_esm_te import NVEsmConfig, NVEsmEncoder |
| 26 | + |
| 27 | + |
@pytest.fixture
def encoder():
    """Build a compact six-layer NVEsmEncoder shared by the tests in this module.

    NOTE(review): the previous docstring said the encoder is created "on CUDA",
    but no explicit device placement happens here -- presumably the encoder's
    constructor takes care of it; confirm.
    """
    small_model_kwargs = {
        "hidden_size": 320,
        "intermediate_size": 1280,
        "num_hidden_layers": 6,
        "num_attention_heads": 20,
        "max_position_embeddings": 1026,
    }
    return NVEsmEncoder(NVEsmConfig(**small_model_kwargs))
| 39 | + |
| 40 | + |
| 41 | +# -- set_recipes -- |
| 42 | + |
| 43 | + |
def test_all_fp8(encoder):
    """All six layers marked "fp8": the FP8 recipe is stored and no FP4 recipe is kept."""
    delayed_scaling = transformer_engine.common.recipe.DelayedScaling()
    encoder.config.layer_precision = ["fp8"] * 6

    encoder.set_recipes(fp8_recipe=delayed_scaling, fp4_recipe=None)

    assert encoder._fp8_recipe is delayed_scaling
    assert encoder._fp4_recipe is None
    # set_recipes must leave the per-layer precision tags untouched.
    for precision in encoder.config.layer_precision:
        assert precision == "fp8"
| 51 | + |
| 52 | + |
def test_all_fp4(encoder):
    """All six layers marked "fp4": the FP4 recipe is stored and no FP8 recipe is kept."""
    block_scaling = transformer_engine.common.recipe.NVFP4BlockScaling()
    encoder.config.layer_precision = ["fp4"] * 6

    encoder.set_recipes(fp8_recipe=None, fp4_recipe=block_scaling)

    assert encoder._fp8_recipe is None
    assert encoder._fp4_recipe is block_scaling
    # set_recipes must leave the per-layer precision tags untouched.
    for precision in encoder.config.layer_precision:
        assert precision == "fp4"
| 60 | + |
| 61 | + |
def test_all_bf16(encoder):
    """All six layers left at None (BF16): no recipe is stored and precisions stay None."""
    encoder.config.layer_precision = [None] * 6
    encoder.set_recipes(fp8_recipe=None, fp4_recipe=None)
    # Check what set_recipes actually stored, mirroring test_all_fp8 / test_all_fp4;
    # without these the test would only re-read the list assigned above.
    assert encoder._fp8_recipe is None
    assert encoder._fp4_recipe is None
    assert all(p is None for p in encoder.config.layer_precision)
| 66 | + |
| 67 | + |
def test_mixed_fp8_fp4(encoder):
    """Half FP8 / half FP4: set_recipes leaves the per-layer precision list as assigned."""
    expected_precisions = ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]
    encoder.config.layer_precision = ["fp8", "fp8", "fp8", "fp4", "fp4", "fp4"]

    encoder.set_recipes(
        fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
        fp4_recipe=transformer_engine.common.recipe.NVFP4BlockScaling(),
    )

    assert encoder.config.layer_precision == expected_precisions
| 74 | + |
| 75 | + |
def test_mixed_fp8_bf16(encoder):
    """Alternating FP8 / None layers: set_recipes leaves the precision list as assigned."""
    expected_precisions = ["fp8", None, "fp8", None, "fp8", None]
    encoder.config.layer_precision = ["fp8", None, "fp8", None, "fp8", None]

    encoder.set_recipes(
        fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
        fp4_recipe=None,
    )

    assert encoder.config.layer_precision == expected_precisions
| 81 | + |
| 82 | + |
def test_mixed_all_three(encoder):
    """FP8, None and FP4 layers together: set_recipes leaves the precision list as assigned."""
    expected_precisions = ["fp8", "fp8", None, None, "fp4", "fp4"]
    encoder.config.layer_precision = ["fp8", "fp8", None, None, "fp4", "fp4"]

    encoder.set_recipes(
        fp8_recipe=transformer_engine.common.recipe.DelayedScaling(),
        fp4_recipe=transformer_engine.common.recipe.NVFP4BlockScaling(),
    )

    assert encoder.config.layer_precision == expected_precisions
| 89 | + |
| 90 | + |
def test_covers_all_layers(encoder):
    """After set_recipes the precision list still has one entry for each of the six layers."""
    delayed_scaling = transformer_engine.common.recipe.DelayedScaling()
    encoder.config.layer_precision = ["fp8"] + [None] * 5

    encoder.set_recipes(fp8_recipe=delayed_scaling, fp4_recipe=None)

    layer_count = len(encoder.config.layer_precision)
    assert layer_count == 6
| 95 | + |
| 96 | + |
def test_recipes_stored_as_attributes(encoder):
    """Recipes are stored on the encoder's private attributes, not inside the precision list."""
    delayed_scaling = transformer_engine.common.recipe.DelayedScaling()
    block_scaling = transformer_engine.common.recipe.NVFP4BlockScaling()
    encoder.config.layer_precision = ["fp8", "fp4", None, None, None, None]

    encoder.set_recipes(fp8_recipe=delayed_scaling, fp4_recipe=block_scaling)

    assert encoder._fp8_recipe is delayed_scaling
    assert encoder._fp4_recipe is block_scaling
    # Entries of the precision list must stay plain strings or None -- never recipe objects.
    assert all(entry is None or isinstance(entry, str) for entry in encoder.config.layer_precision)
| 107 | + |
| 108 | + |
| 109 | +# -- get_layer_autocast -- |
| 110 | + |
| 111 | + |
def test_fp8_layer_returns_nullcontext(encoder):
    """get_layer_autocast hands back a plain nullcontext for a layer tagged "fp8"."""
    delayed_scaling = transformer_engine.common.recipe.DelayedScaling()
    encoder.config.layer_precision = ["fp8"] + [None] * 5
    encoder.set_recipes(fp8_recipe=delayed_scaling, fp4_recipe=None)

    assert isinstance(encoder.get_layer_autocast(0), nullcontext)
| 117 | + |
| 118 | + |
def test_fp4_layer_returns_te_autocast(encoder):
    """A layer tagged "fp4" gets transformer_engine.pytorch.autocast enabled with the FP4 recipe."""
    block_scaling = transformer_engine.common.recipe.NVFP4BlockScaling()
    encoder.config.layer_precision = ["fp4"] + [None] * 5
    encoder.set_recipes(fp8_recipe=None, fp4_recipe=block_scaling)

    # Replace TE's autocast with a mock so the call arguments can be inspected.
    with patch.object(transformer_engine.pytorch, "autocast", return_value="fp4_context") as fake_autocast:
        returned = encoder.get_layer_autocast(0)

    fake_autocast.assert_called_once_with(enabled=True, recipe=block_scaling)
    assert returned == "fp4_context"
| 128 | + |
| 129 | + |
def test_bf16_layer_returns_te_autocast_disabled(encoder):
    """A layer whose precision is None gets transformer_engine.pytorch.autocast(enabled=False)."""
    encoder.config.layer_precision = [None] * 6
    encoder.set_recipes(fp8_recipe=None, fp4_recipe=None)

    # Replace TE's autocast with a mock so the call arguments can be inspected.
    with patch.object(transformer_engine.pytorch, "autocast", return_value="bf16_context") as fake_autocast:
        returned = encoder.get_layer_autocast(0)

    fake_autocast.assert_called_once_with(enabled=False)
    assert returned == "bf16_context"
| 138 | + |
| 139 | + |
def test_uninitialized_defaults_to_bf16(encoder):
    """With layer_precision still None (never assigned), a layer gets autocast(enabled=False)."""
    # Precondition: the fixture's config has no per-layer precision configured.
    assert encoder.config.layer_precision is None

    with patch.object(transformer_engine.pytorch, "autocast", return_value="bf16_context") as fake_autocast:
        returned = encoder.get_layer_autocast(0)

    fake_autocast.assert_called_once_with(enabled=False)
    assert returned == "bf16_context"
| 148 | + |
| 149 | + |
def test_mixed_layers_return_correct_contexts(encoder):
    """Each layer of a mixed FP8 / FP4 / BF16 encoder gets the context matching its precision.

    Every layer index is checked (not just one per precision), and each mocked
    autocast call is verified with ``assert_called_once_with`` plus the returned
    value, matching the single-layer tests in this module.
    """
    fp8_recipe = transformer_engine.common.recipe.DelayedScaling()
    fp4_recipe = transformer_engine.common.recipe.NVFP4BlockScaling()
    encoder.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]
    encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=fp4_recipe)

    # FP8 layers -> nullcontext
    for layer_idx in (0, 1):
        assert isinstance(encoder.get_layer_autocast(layer_idx), nullcontext)

    # FP4 layers -> te.pytorch.autocast(enabled=True, recipe=fp4_recipe)
    for layer_idx in (2, 3):
        # A fresh mock per layer keeps the "called once" check meaningful.
        with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
            mock_autocast.return_value = "fp4_context"
            ctx = encoder.get_layer_autocast(layer_idx)
        mock_autocast.assert_called_once_with(enabled=True, recipe=fp4_recipe)
        assert ctx == "fp4_context"

    # BF16 layers -> te.pytorch.autocast(enabled=False)
    for layer_idx in (4, 5):
        with patch.object(transformer_engine.pytorch, "autocast") as mock_autocast:
            mock_autocast.return_value = "bf16_context"
            ctx = encoder.get_layer_autocast(layer_idx)
        mock_autocast.assert_called_once_with(enabled=False)
        assert ctx == "bf16_context"
| 171 | + |
| 172 | + |
def test_layer_precision_is_pickleable(encoder):
    """A pickle round trip of config.layer_precision yields an equal list."""
    import pickle

    encoder.config.layer_precision = ["fp8", "fp8", "fp4", "fp4", None, None]

    serialized = pickle.dumps(encoder.config.layer_precision)
    restored = pickle.loads(serialized)

    assert restored == encoder.config.layer_precision
0 commit comments