Skip to content

Commit 56e97c8

Browse files
authored
Bug fix 5875873 (#865)
## What does this PR do? **Type of change:** Bug fix <!-- Use one of the following: Bug fix, new feature, new example, new tests, documentation. --> **Overview:** Newer version of trl uses dtype instead of torch_dtype. Modified code to set float32 as default for older versions of trl that you torch_dtype. ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Enhanced error handling in model training examples to safely manage missing dtype attributes, preventing crashes during initialization when torch_dtype is not configured. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>
1 parent 110a44c commit 56e97c8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/gpt-oss/sft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def main(script_args, training_args, model_args, quant_args):
7272
"revision": model_args.model_revision,
7373
"trust_remote_code": model_args.trust_remote_code,
7474
"attn_implementation": model_args.attn_implementation,
75-
"torch_dtype": model_args.torch_dtype,
75+
"torch_dtype": getattr(model_args, "dtype", "float32"),
7676
"use_cache": not training_args.gradient_checkpointing,
7777
}
7878

0 commit comments

Comments
 (0)