Skip to content

Commit bba71c1

Browse files
committed
Fix int32 overflow in elem_count for shapes with >INT32_MAX elements
1 parent 894a625 commit bba71c1

2 files changed

Lines changed: 20 additions & 1 deletion

File tree

emlx/c_src/emlx_nif.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ NIF(to_blob) {
218218
}
219219

220220
uint64_t elem_count(std::vector<int> shape) {
221-
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>{});
221+
return std::accumulate(shape.begin(), shape.end(), uint64_t{1}, std::multiplies<uint64_t>{});
222222
}
223223

224224
NIF(from_blob) {

emlx/test/emlx_test.exs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,4 +170,23 @@ defmodule EMLXTest do
170170
end
171171
end
172172
end
173+
174+
describe "large tensors (element count > INT32_MAX)" do
175+
# Regression: elem_count overflowed int32 for shapes whose element count
176+
# exceeds INT32_MAX, causing "Binary size is too small" on valid binaries.
177+
@tag :large_tensor
178+
test "from_binary accepts shape whose element count exceeds INT32_MAX" do
179+
# Reshape on BinaryBackend first — Nx.from_binary creates a flat 1D tensor
180+
# whose single dimension would also exceed INT32_MAX.
181+
binary = :binary.copy(<<7>>, 2_147_483_648)
182+
183+
t =
184+
Nx.from_binary(binary, :u8, backend: Nx.BinaryBackend)
185+
|> Nx.reshape({2, 1_073_741_824})
186+
|> Nx.backend_transfer(EMLX.Backend)
187+
188+
assert Nx.shape(t) == {2, 1_073_741_824}
189+
assert Nx.to_number(t[0][0]) == 7
190+
end
191+
end
173192
end

0 commit comments

Comments
 (0)