Commit f7bef3f

adds comment to explain why we always use bf16 heads (#1431)
1 parent e5e58c8 commit f7bef3f

4 files changed

Lines changed: 8 additions & 0 deletions

bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py

Lines changed: 2 additions & 0 deletions
@@ -484,6 +484,8 @@ def forward(self, features, **kwargs):
             features (torch.Tensor): The features.
             **kwargs: Additional arguments.
         """
+        # Keep the last layers of the network in higher precision to avoid numerical instability.
+        # Please see recipes/fp8_analysis/README.md for more details.
         with transformer_engine.pytorch.fp8_autocast(enabled=False):
             x = self.dense(features)
             x = torch.nn.functional.gelu(x)
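
The hunk above turns FP8 autocast off for the head so that the final layers run in higher precision than the rest of the network. The sketch below illustrates the same nesting pattern with plain `torch.autocast` on CPU. It is an analogy, not the Transformer Engine code path: bf16 autocast stands in for FP8, fp32 stands in for the head's higher precision, and `Head`, `TinyModel`, and all sizes are hypothetical names chosen for illustration.

```python
import torch
from torch import nn


class Head(nn.Module):
    """Hypothetical output head that opts out of the surrounding autocast region."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Disable autocast locally so the last layers stay in fp32.
        with torch.autocast(device_type=features.device.type, enabled=False):
            # Inputs produced under autocast may be bf16; cast them back up
            # so they match the fp32 weights.
            x = self.dense(features.float())
            x = torch.nn.functional.gelu(x)
            return self.decoder(x)


class TinyModel(nn.Module):
    """Hypothetical model: one low-precision body layer plus the head."""

    def __init__(self, hidden_size: int = 16, vocab_size: int = 33):
        super().__init__()
        self.body = nn.Linear(hidden_size, hidden_size)
        self.head = Head(hidden_size, vocab_size)

    def forward(self, x: torch.Tensor):
        hidden = self.body(x)  # runs in bf16 under the outer autocast
        return hidden, self.head(hidden)


model = TinyModel()
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    hidden, logits = model(torch.randn(2, 16))

print(hidden.dtype, logits.dtype)
```

The explicit `.float()` cast is needed only in this analogy, because the disabled region receives bf16 activations from the body while the head weights are fp32.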

bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py

Lines changed: 2 additions & 0 deletions
@@ -484,6 +484,8 @@ def forward(self, features, **kwargs):
             features (torch.Tensor): The features.
             **kwargs: Additional arguments.
         """
+        # Keep the last layers of the network in higher precision to avoid numerical instability.
+        # Please see recipes/fp8_analysis/README.md for more details.
         with transformer_engine.pytorch.fp8_autocast(enabled=False):
             x = self.dense(features)
             x = torch.nn.functional.gelu(x)

bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py

Lines changed: 2 additions & 0 deletions
@@ -484,6 +484,8 @@ def forward(self, features, **kwargs):
             features (torch.Tensor): The features.
             **kwargs: Additional arguments.
         """
+        # Keep the last layers of the network in higher precision to avoid numerical instability.
+        # Please see recipes/fp8_analysis/README.md for more details.
         with transformer_engine.pytorch.fp8_autocast(enabled=False):
             x = self.dense(features)
             x = torch.nn.functional.gelu(x)

bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py

Lines changed: 2 additions & 0 deletions
@@ -484,6 +484,8 @@ def forward(self, features, **kwargs):
             features (torch.Tensor): The features.
             **kwargs: Additional arguments.
         """
+        # Keep the last layers of the network in higher precision to avoid numerical instability.
+        # Please see recipes/fp8_analysis/README.md for more details.
         with transformer_engine.pytorch.fp8_autocast(enabled=False):
             x = self.dense(features)
             x = torch.nn.functional.gelu(x)
