Skip to content

Commit 7569d0d

Browse files
authored
feat: convert output to float in distill (#320)
1 parent fa30df9 commit 7569d0d

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

model2vec/distill/inference.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,9 @@ def _encode_with_model(
137137
# NOTE: If the dtype is bfloat 16, we convert to float32,
138138
# because numpy does not suport bfloat16
139139
# See here: https://github.com/numpy/numpy/issues/19808
140-
if hidden.dtype == torch.bfloat16:
141-
hidden = hidden.float()
140+
hidden = hidden.float()
142141
pooler = getattr(outputs, "pooler_output", None)
143-
if pooler is not None and pooler.dtype == torch.bfloat16:
142+
if pooler is not None:
144143
pooler = pooler.float()
145144
return hidden, pooler, encodings_on_device
146145

0 commit comments

Comments
 (0)