Skip to content

Commit 5bb8fbc

Browse files
blasphemetheusclaudepolvalente
authored
fix: Handle wide matrices in orthogonal initializer (#629)
* fix: Handle wide matrices in orthogonal initializer QR decomposition of an {m, n} matrix produces Q of shape {m, m}, which fails when n > m (e.g. LSTM weights {hidden, 4*hidden}). Generate a {max(m,n), max(m,n)} square random matrix so QR always produces enough orthogonal columns, then slice to {m, n}. Adds tests for wide 2D and high-rank shapes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Apply suggestions from code review * Apply suggestion from @polvalente --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
1 parent d5ecacb commit 5bb8fbc

2 files changed

Lines changed: 34 additions & 2 deletions

File tree

lib/axon/initializers.ex

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,13 +686,20 @@ defmodule Axon.Initializers do
686686

687687
{m, n} = get_flat_shape(shape)
688688

689+
# Generate a square random matrix of size max(m, n) so QR
690+
# produces enough orthogonal columns for wide matrices
691+
# (e.g. LSTM weights shaped {hidden, 4*hidden})
692+
# This is because mode: :complete returns Q {m, m} for an {m, n} tensor.
693+
# mode: :reduced returns Q {m, min(m, n)}
694+
k = max(m, n)
695+
689696
random_seed =
690697
case distribution do
691698
:uniform ->
692-
Nx.Random.uniform_split(key, 0.0, 1.0, shape: {m, n}, type: type)
699+
Nx.Random.uniform_split(key, 0.0, 1.0, shape: {k, k}, type: type)
693700

694701
:normal ->
695-
Nx.Random.normal_split(key, 0.0, 1.0, shape: {m, n}, type: type)
702+
Nx.Random.normal_split(key, 0.0, 1.0, shape: {k, k}, type: type)
696703

697704
dist ->
698705
raise ArgumentError,

test/axon/initializers_test.exs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,31 @@ defmodule Axon.InitializersTest do
164164
)
165165
end
166166

167+
test "works with wide matrices (n > m, e.g. LSTM/GRU weights)" do
168+
init_fn = Axon.Initializers.orthogonal()
169+
170+
# Wide matrix like LSTM kernel {hidden, 4*hidden}
171+
# Using small dims to keep QR fast on BinaryBackend
172+
t = init_fn.({8, 32}, {:f, 32}, Nx.Random.key(1))
173+
assert Nx.shape(t) == {8, 32}
174+
175+
# Rows should be orthonormal: t * t^T = I
176+
identity = Nx.dot(t, [1], t, [1])
177+
178+
assert_all_close(identity, Nx.eye(Nx.shape(identity)),
179+
atol: 1.0e-3,
180+
rtol: 1.0e-3
181+
)
182+
end
183+
184+
test "works with wide high-rank shapes" do
185+
init_fn = Axon.Initializers.orthogonal()
186+
187+
# Shape that flattens to wide: {2, 8} -> {2, 8} where n > m
188+
t = init_fn.({2, 8}, {:f, 32}, Nx.Random.key(1))
189+
assert Nx.shape(t) == {2, 8}
190+
end
191+
167192
test "raises on input rank less than 2" do
168193
assert_raise ArgumentError,
169194
~r/Axon.Initializers.orthogonal: expected input_shape shape to have at least rank 2/,

0 commit comments

Comments
 (0)