Skip to content

Commit 08c9ef9

Browse files
Fix to_numpy() to handle bfloat16 tensors (#1346)
NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises TypeError. Detach/move to CPU, upcast bfloat16 to float32, then convert. bfloat16 is common since many pretrained models load in reduced precision. Adds a TestToNumpy class covering bfloat16, float32/float16/int passthrough, numpy/list/tuple/scalar inputs, and the invalid-type error.
1 parent 34e6dc4 commit 08c9ef9

2 files changed

Lines changed: 80 additions & 2 deletions

File tree

tests/unit/utilities/test_tensors.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
This module tests the tensor utility functions, particularly the filter_dict_by_prefix function.
44
"""
55

6+
import numpy as np
7+
import pytest
68
import torch
79

8-
from transformer_lens.utilities.tensors import filter_dict_by_prefix
10+
from transformer_lens.utilities.tensors import filter_dict_by_prefix, to_numpy
911

1012

1113
class TestFilterDictByPrefix:
@@ -205,3 +207,73 @@ def test_filter_dict_multiple_prefixes_sequentially(self):
205207
assert "layer1.weight" in decoder_result
206208
# Ensure original dict is not modified
207209
assert len(test_dict) == 4
210+
211+
212+
class TestToNumpy:
213+
"""Test cases for the to_numpy function."""
214+
215+
def test_bfloat16_tensor(self):
216+
"""bfloat16 tensors should be upcast to float32 instead of raising a TypeError.
217+
218+
NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises
219+
``TypeError: Got unsupported ScalarType BFloat16``. bfloat16 is common in
220+
TransformerLens because many pretrained models load in reduced precision.
221+
"""
222+
tensor = torch.tensor([1.0, 2.0, -3.5], dtype=torch.bfloat16)
223+
result = to_numpy(tensor)
224+
assert isinstance(result, np.ndarray)
225+
assert result.dtype == np.float32
226+
# Values that are exactly representable in bfloat16 should round-trip exactly.
227+
np.testing.assert_array_equal(result, np.array([1.0, 2.0, -3.5], dtype=np.float32))
228+
229+
def test_float32_tensor_passthrough(self):
230+
"""float32 tensors should convert without dtype changes."""
231+
tensor = torch.tensor([1.5, 2.5], dtype=torch.float32)
232+
result = to_numpy(tensor)
233+
assert isinstance(result, np.ndarray)
234+
assert result.dtype == np.float32
235+
np.testing.assert_array_equal(result, np.array([1.5, 2.5], dtype=np.float32))
236+
237+
def test_float16_tensor(self):
238+
"""float16 tensors are representable in numpy and should be preserved."""
239+
tensor = torch.tensor([1.0, 2.0], dtype=torch.float16)
240+
result = to_numpy(tensor)
241+
assert isinstance(result, np.ndarray)
242+
assert result.dtype == np.float16
243+
244+
def test_int_tensor(self):
245+
"""Integer tensors should convert without modification."""
246+
tensor = torch.tensor([1, 2, 3], dtype=torch.int64)
247+
result = to_numpy(tensor)
248+
assert isinstance(result, np.ndarray)
249+
np.testing.assert_array_equal(result, np.array([1, 2, 3]))
250+
251+
def test_parameter_bfloat16(self):
252+
"""nn.Parameter wrapping a bfloat16 tensor should also be handled."""
253+
param = torch.nn.Parameter(torch.tensor([4.0, 5.0], dtype=torch.bfloat16))
254+
result = to_numpy(param)
255+
assert isinstance(result, np.ndarray)
256+
assert result.dtype == np.float32
257+
np.testing.assert_array_equal(result, np.array([4.0, 5.0], dtype=np.float32))
258+
259+
def test_numpy_array_passthrough(self):
260+
"""numpy arrays should be returned unchanged."""
261+
array = np.array([1.0, 2.0])
262+
result = to_numpy(array)
263+
assert result is array
264+
265+
def test_list_and_tuple(self):
266+
"""Lists and tuples should be converted to numpy arrays."""
267+
np.testing.assert_array_equal(to_numpy([1, 2, 3]), np.array([1, 2, 3]))
268+
np.testing.assert_array_equal(to_numpy((4, 5, 6)), np.array([4, 5, 6]))
269+
270+
def test_scalar(self):
271+
"""Python scalars should be converted to numpy arrays."""
272+
result = to_numpy(3.5)
273+
assert isinstance(result, np.ndarray)
274+
assert result.item() == 3.5
275+
276+
def test_invalid_type_raises(self):
277+
"""Unsupported types should raise a ValueError."""
278+
with pytest.raises(ValueError, match="invalid type"):
279+
to_numpy({"a": 1})

transformer_lens/utilities/tensors.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ def to_numpy(tensor):
2323
array = np.array(tensor)
2424
return array
2525
elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
26-
return tensor.detach().cpu().numpy()
26+
tensor = tensor.detach().cpu()
27+
# NumPy has no bfloat16 dtype, so calling .numpy() directly on a bfloat16
28+
# tensor raises a TypeError. Upcast to float32 first (bfloat16 is common in
29+
# TransformerLens since many pretrained models are loaded in reduced precision).
30+
if tensor.dtype == torch.bfloat16:
31+
tensor = tensor.to(torch.float32)
32+
return tensor.numpy()
2733
elif isinstance(tensor, (int, float, bool, str)):
2834
return np.array(tensor)
2935
else:

0 commit comments

Comments
 (0)