Commit 56e97c8
authored
Bug fix 5875873 (#865)
## What does this PR do?
**Type of change:** Bug fix <!-- Use one of the following: Bug fix, new
feature, new example, new tests, documentation. -->
**Overview:** Newer version of trl uses dtype instead of torch_dtype.
Modified code to set float32 as default for older versions of trl that
you torch_dtype.
## Usage
<!-- You can potentially add a usage example below. -->
```python
# Add a code snippet demonstrating how to use this
```
## Testing
<!-- Mention how have you tested your change if applicable. -->
## Before your PR is "*Ready for review*"
<!-- If you haven't finished some of the above items you can still open
`Draft` PR. -->
- **Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)**
and your commits are signed.
- **Is this change backward compatible?**: Yes/No <!--- If No, explain
why. -->
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**:
Yes/No <!--- Only for new features, API changes, critical bug fixes or
bw breaking changes. -->
## Additional Information
<!-- E.g. related issue. -->
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
* **Bug Fixes**
* Enhanced error handling in model training examples to safely manage
missing dtype attributes, preventing crashes during initialization when
torch_dtype is not configured.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
---------
Signed-off-by: Suguna Velury <178320438+sugunav14@users.noreply.github.com>1 parent 110a44c commit 56e97c8
1 file changed
+1
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
72 | 72 | | |
73 | 73 | | |
74 | 74 | | |
75 | | - | |
| 75 | + | |
76 | 76 | | |
77 | 77 | | |
78 | 78 | | |
| |||
0 commit comments