Skip to content

Commit cc5424d

Browse files
committed
adds NVFP4 support
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
1 parent 6ef9dcd commit cc5424d

3 files changed

Lines changed: 42 additions & 9 deletions

File tree

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
example_fp4_tensor_stat_collection:
2+
enabled: True
3+
layers:
4+
# Match MiniFold te.Linear sublayers in FP4 blocks
5+
layer_types: [pi, gi, po, go, fc1, fc2]
6+
transformer_engine:
7+
LogNvfp4TensorStats:
8+
enabled: True
9+
tensors_struct:
10+
- tensor: activation
11+
stats: [underflows%, mse]
12+
freq: 100
13+
- tensor: gradient
14+
stats: [underflows%, mse]
15+
freq: 100
16+
17+
example_fp8_tensor_stat_collection:
18+
enabled: True
19+
layers:
20+
# Match MiniFold te.Linear sublayers in FP8 blocks
21+
layer_types: [pi, gi, po, go, fc1, fc2]
22+
transformer_engine:
23+
LogFp8TensorStats:
24+
enabled: True
25+
tensors_struct:
26+
- tensor: activation
27+
stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
28+
freq: 100
29+
- tensor: gradient
30+
stats: [mxfp8_underflows%, mxfp8_scale_inv_min, mxfp8_scale_inv_max, mxfp8_mse]
31+
freq: 100

bionemo-recipes/recipes/esm2_minifold_te/quantization.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,12 @@ def update_quant_stats_config(
331331
config = yaml.safe_load(f)
332332

333333
if "example_fp4_tensor_stat_collection" in config:
334-
config["example_fp4_tensor_stat_collection"]["enabled"] = False
334+
fp4_regex = generate_layer_regex(fp4_layers, component_precision=component_precision)
335+
config["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"] = fp4_regex
335336
if fp4_layers:
336-
logger.warning(
337-
"NVFP4 quant stats logging is not yet supported (requires a future TransformerEngine release). "
338-
f"Disabling FP4 stats collection for blocks {fp4_layers}. FP8 stats will still be collected."
339-
)
337+
logger.info(f"Updated FP4 block regex to match blocks: {fp4_layers}")
340338
else:
341-
logger.info("FP4 stats section disabled (no FP4 blocks and feature not yet supported)")
339+
logger.info("FP4 blocks empty - regex set to match nothing")
342340

343341
if "example_fp8_tensor_stat_collection" in config:
344342
fp8_regex = generate_layer_regex(fp8_layers, component_precision=component_precision)

bionemo-recipes/recipes/esm2_minifold_te/tests/test_quantization.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,15 +256,19 @@ def test_none_layers_disables_matching(fp8_only_config):
256256
assert "DISABLED" in regex
257257

258258

259-
def test_fp4_section_disabled_fp8_still_updated(fp4_fp8_config):
259+
def test_fp4_and_fp8_both_updated(fp4_fp8_config):
260260
output_path = update_quant_stats_config(config_file=fp4_fp8_config, fp4_layers=[1, 2, 3], fp8_layers=[4, 5, 6])
261261
with open(output_path) as f:
262262
result = yaml.safe_load(f)
263263

264-
assert result["example_fp4_tensor_stat_collection"]["enabled"] is False
264+
# FP4 section should have regex for blocks 1-3 (0-indexed 0-2)
265+
fp4_regex = result["example_fp4_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
266+
assert re.search(fp4_regex, "fold.miniformer.blocks.0.transition.fc1")
267+
assert re.search(fp4_regex, "fold.miniformer.blocks.2.triangular.pi")
268+
assert not re.search(fp4_regex, "fold.miniformer.blocks.3.triangular.pi")
265269

270+
# FP8 section should have regex for blocks 4-6 (0-indexed 3-5)
266271
fp8_regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
267-
# 1-indexed [4,5,6] -> 0-indexed [3,4,5]
268272
assert re.search(fp8_regex, "fold.miniformer.blocks.4.triangular.pi")
269273
assert not re.search(fp8_regex, "fold.miniformer.blocks.1.triangular.pi")
270274

0 commit comments

Comments
 (0)