Skip to content

Commit 3482b79

Browse files
committed
fix: use nx-defined implementation for non-full svd computation
1 parent f9894c0 commit 3482b79

1 file changed

Lines changed: 2 additions & 16 deletions

File tree

emlx/lib/emlx/backend.ex

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,24 +2140,10 @@ defmodule EMLX.Backend do
21402140
end
21412141

21422142
@impl true
2143-
def block(%Nx.Block.LinAlg.SVD{full_matrices?: full?}, {out_u, out_s, out_v}, [tensor], _fun) do
2143+
def block(%Nx.Block.LinAlg.SVD{full_matrices?: true}, {out_u, out_s, out_v}, [tensor], _fun) do
21442144
t = to_typed_ref(from_nx(tensor), tensor.type, {:f, 32})
21452145
[u, s, vt] = EMLX.linalg_svd(t, true)
2146-
2147-
# MLX always returns full matrices; truncate for full_matrices?: false
2148-
{u_final, vt_final} =
2149-
if full? do
2150-
{u, vt}
2151-
else
2152-
# Truncate U to {m, k} and Vt to {k, n} where k = min(m, n)
2153-
{_m, k} = out_u.shape
2154-
u_sliced = EMLX.slice(u, [0, 0], [elem(EMLX.shape(u), 0), k], [1, 1])
2155-
{_k2, n} = out_v.shape
2156-
vt_sliced = EMLX.slice(vt, [0, 0], [k, n], [1, 1])
2157-
{u_sliced, vt_sliced}
2158-
end
2159-
2160-
{to_nx(u_final, out_u), to_nx(s, out_s), to_nx(vt_final, out_v)}
2146+
{to_nx(u, out_u), to_nx(s, out_s), to_nx(vt, out_v)}
21612147
end
21622148

21632149
@impl true

0 commit comments

Comments
 (0)