@@ -2140,24 +2140,10 @@ defmodule EMLX.Backend do
21402140 end
21412141
21422142 @ impl true
2143- def block ( % Nx.Block.LinAlg.SVD { full_matrices?: full? } , { out_u , out_s , out_v } , [ tensor ] , _fun ) do
2143+ def block ( % Nx.Block.LinAlg.SVD { full_matrices?: true } , { out_u , out_s , out_v } , [ tensor ] , _fun ) do
21442144 t = to_typed_ref ( from_nx ( tensor ) , tensor . type , { :f , 32 } )
21452145 [ u , s , vt ] = EMLX . linalg_svd ( t , true )
2146-
2147- # MLX always returns full matrices; truncate for full_matrices?: false
2148- { u_final , vt_final } =
2149- if full? do
2150- { u , vt }
2151- else
2152- # Truncate U to {m, k} and Vt to {k, n} where k = min(m, n)
2153- { _m , k } = out_u . shape
2154- u_sliced = EMLX . slice ( u , [ 0 , 0 ] , [ elem ( EMLX . shape ( u ) , 0 ) , k ] , [ 1 , 1 ] )
2155- { _k2 , n } = out_v . shape
2156- vt_sliced = EMLX . slice ( vt , [ 0 , 0 ] , [ k , n ] , [ 1 , 1 ] )
2157- { u_sliced , vt_sliced }
2158- end
2159-
2160- { to_nx ( u_final , out_u ) , to_nx ( s , out_s ) , to_nx ( vt_final , out_v ) }
2146+ { to_nx ( u , out_u ) , to_nx ( s , out_s ) , to_nx ( vt , out_v ) }
21612147 end
21622148
21632149 @ impl true
0 commit comments