[wwb] move to dtype auto #3793
Pull request overview
Updates WWB’s default HuggingFace model loading dtype behavior to rely on Transformers’ automatic dtype selection rather than forcing fp32.
Changes:
- Switch PYTORCH_MODEL_DTYPE_KWARG from torch.float32 to "auto" for HF model loading in WWB.
PYTORCH_MODEL_DTYPE_KWARG = {"torch_dtype": "auto"}
Setting torch_dtype to "auto" for all HuggingFace loads can break CPU execution: many models advertise fp16/bf16 in config, and loading them on CPU frequently leads to runtime errors (e.g., Half ops not implemented) or unexpected slowdowns. Consider keeping torch.float32 when device is CPU / CUDA is unavailable, and only using "auto" on GPU (or make dtype configurable via CLI).
Suggested change:
- PYTORCH_MODEL_DTYPE_KWARG = {"torch_dtype": "auto"}
+ def _get_pytorch_model_dtype_kwarg():
+     if torch.cuda.is_available():
+         return {"torch_dtype": "auto"}
+     return {"torch_dtype": torch.float32}
+ PYTORCH_MODEL_DTYPE_KWARG = _get_pytorch_model_dtype_kwarg()
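The other option raised in this comment, making the dtype configurable via CLI, could look roughly like the sketch below. It is only an illustration: the `--dtype` flag, `DTYPE_CHOICES`, `resolve_dtype_name`, and `build_parser` are hypothetical names and are not part of WWB. The helper works on dtype names rather than `torch.dtype` objects so the selection logic can be checked without importing torch.

```python
import argparse

# Hypothetical set of accepted values; "auto" defers to the model config.
DTYPE_CHOICES = ("auto", "float32", "float16", "bfloat16")


def resolve_dtype_name(requested, cuda_available):
    """Return the dtype name to load with.

    An explicit request always wins. Without one, fall back to the
    device-based default from the review suggestion: "auto" when CUDA
    is available, "float32" otherwise.
    """
    if requested is not None:
        if requested not in DTYPE_CHOICES:
            raise ValueError(f"Unsupported dtype: {requested!r}")
        return requested
    return "auto" if cuda_available else "float32"


def build_parser():
    """Build a parser with an optional --dtype flag (assumed flag name)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        choices=DTYPE_CHOICES,
        default=None,
        help="Override the dtype used when loading HF models.",
    )
    return parser
```

A caller would pass `torch.cuda.is_available()` as `cuda_available` and, for any result other than "auto", convert the name with `getattr(torch, name)` before building the `torch_dtype` kwarg.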
logger = logging.getLogger(__name__)

- PYTORCH_MODEL_DTYPE_KWARG = {"torch_dtype": torch.float32}
+ PYTORCH_MODEL_DTYPE_KWARG = {"torch_dtype": "auto"}
PR description still contains placeholders (e.g., CVS-###, Fixes #(issue)) and the checklist is not filled out. Please update the PR description to match the repository template before merging so reviewers can confirm scope, tests, and docs impact.
logger = logging.getLogger(__name__)

- PYTORCH_MODEL_DTYPE_KWARG = {"torch_dtype": torch.float32}
+ PYTORCH_MODEL_DTYPE_KWARG = {"torch_dtype": "auto"}
This change alters the default dtype selection for HF models, but the WWB test suite doesn’t appear to cover the new behavior (e.g., loading a model whose config defaults to fp16/bf16 on CPU). Please add/update a WWB test to exercise HF loading with a non-fp32 default dtype (ideally using a tiny-random model) so regressions are caught.
rkazants left a comment
Before moving to the auto dtype, please make sure it causes no problems on CPU.
I expect multiple JIRA tickets to be assigned to optimum-intel after this change. However, the problem may actually be in the CPU plugin.
So I would ask you to do validation runs and identify all existing issues on CPU, so that we know about them in advance.
Yes, this PR was created for validation purposes.
Description
CVS-###
Fixes #(issue)
Checklist: