We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fa30df9 commit 7569d0dCopy full SHA for 7569d0d
1 file changed
model2vec/distill/inference.py
@@ -137,10 +137,9 @@ def _encode_with_model(
137
# NOTE: If the dtype is bfloat 16, we convert to float32,
138
# because numpy does not suport bfloat16
139
# See here: https://github.com/numpy/numpy/issues/19808
140
- if hidden.dtype == torch.bfloat16:
141
- hidden = hidden.float()
+ hidden = hidden.float()
142
pooler = getattr(outputs, "pooler_output", None)
143
- if pooler is not None and pooler.dtype == torch.bfloat16:
+ if pooler is not None:
144
pooler = pooler.float()
145
return hidden, pooler, encodings_on_device
146
0 commit comments