Skip to content

Commit 80ffb14

Browse files
committed
Improve hnf_decomposition and add a test for it (failing)
1 parent c6b19c7 commit 80ffb14

2 files changed

Lines changed: 66 additions & 69 deletions

File tree

src/torchjd/sparse/_linalg.py

Lines changed: 35 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -57,103 +57,69 @@ def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]:
5757
V: (r x n) Right inverse Unimodular transform (H @ V = A)
5858
"""
5959

60-
H = A.clone().to(dtype=torch.long)
60+
H = A.clone()
6161
m, n = H.shape
6262

63-
U = torch.eye(n, dtype=torch.long)
64-
V = torch.eye(n, dtype=torch.long)
63+
U = torch.eye(n, dtype=A.dtype)
64+
V = torch.eye(n, dtype=A.dtype)
6565

66-
row = 0
6766
col = 0
6867

69-
while row < m and col < n:
70-
# --- 1. Pivot Selection ---
71-
# Find first non-zero entry in current row from col onwards
72-
pivot_idx = -1
68+
for row in range(m):
69+
if n <= col:
70+
break
71+
row_slice = H[row, col:n]
72+
nonzero_indices = torch.nonzero(row_slice)
7373

74-
# We extract the row slice to CPU for faster scalar checks if on GPU
75-
# or just iterate. For HNF, strictly sequential loop is often easiest.
76-
for j in range(col, n):
77-
if H[row, j] != 0:
78-
pivot_idx = j
79-
break
80-
81-
if pivot_idx == -1:
82-
row += 1
74+
if nonzero_indices.numel() > 0:
75+
relative_pivot_idx = nonzero_indices[0][0].item()
76+
pivot_idx = col + relative_pivot_idx
77+
else:
8378
continue
8479

85-
# Swap to current column
8680
if pivot_idx != col:
87-
# Swap Columns in H and U
8881
H[:, [col, pivot_idx]] = H[:, [pivot_idx, col]]
8982
U[:, [col, pivot_idx]] = U[:, [pivot_idx, col]]
90-
# Swap ROWS in V
9183
V[[col, pivot_idx], :] = V[[pivot_idx, col], :]
9284

93-
# --- 2. Gaussian Elimination via GCD ---
9485
for j in range(col + 1, n):
9586
if H[row, j] != 0:
96-
# Extract values as python ints for GCD logic
9787
a_val = H[row, col].item()
9888
b_val = H[row, j].item()
9989

10090
g, x, y = extended_gcd(a_val, b_val)
10191

102-
# Bezout: a*x + b*y = g
103-
# c1 = -b // g, c2 = a // g
10492
c1 = -b_val // g
10593
c2 = a_val // g
10694

107-
# --- Update H (Column Ops) ---
108-
# Important: Clone columns to avoid in-place modification issues during calc
109-
col_c = H[:, col].clone()
110-
col_j = H[:, j].clone()
111-
112-
H[:, col] = col_c * x + col_j * y
113-
H[:, j] = col_c * c1 + col_j * c2
114-
115-
# --- Update U (Column Ops) ---
116-
u_c = U[:, col].clone()
117-
u_j = U[:, j].clone()
118-
U[:, col] = u_c * x + u_j * y
119-
U[:, j] = u_c * c1 + u_j * c2
120-
121-
# --- Update V (Inverse Row Ops) ---
122-
# Inverse of [[x, c1], [y, c2]] is [[c2, -c1], [-y, x]]
123-
v_r_c = V[col, :].clone()
124-
v_r_j = V[j, :].clone()
125-
V[col, :] = v_r_c * c2 - v_r_j * c1
126-
V[j, :] = v_r_c * (-y) + v_r_j * x
127-
128-
# --- 3. Enforce Positive Diagonal ---
129-
if H[row, col] < 0:
130-
H[:, col] *= -1
131-
U[:, col] *= -1
132-
V[col, :] *= -1
133-
134-
# --- 4. Canonical Reduction (Modulo) ---
135-
# Ensure 0 <= H[row, k] < H[row, col] for k < col
136-
pivot_val = H[row, col].clone()
137-
if pivot_val != 0:
138-
for j in range(col):
139-
# floor division
140-
factor = torch.div(H[row, j], pivot_val, rounding_mode="floor")
95+
H_col = H[:, col]
96+
H_j = H[:, j]
97+
98+
H[:, [col, j]] = torch.stack([H_col * x + H_j * y, H_col * c1 + H_j * c2], dim=1)
99+
100+
U_col = U[:, col]
101+
U_j = U[:, j]
102+
U[:, [col, j]] = torch.stack([U_col * x + U_j * y, U_col * c1 + U_j * c2], dim=1)
141103

142-
if factor != 0:
143-
H[:, j] -= factor * H[:, col]
144-
U[:, j] -= factor * U[:, col]
145-
V[col, :] += factor * V[j, :]
104+
V_row_c = V[col, :]
105+
V_row_j = V[j, :]
106+
V[[col, j], :] = torch.stack(
107+
[V_row_c * c2 - V_row_j * c1, V_row_c * (-y) + V_row_j * x], dim=0
108+
)
109+
110+
pivot_val = H[row, col]
111+
112+
if pivot_val != 0:
113+
H_row_prefix = H[row, 0:col]
114+
factors = torch.div(H_row_prefix, pivot_val, rounding_mode="floor")
115+
H[:, 0:col] -= factors.unsqueeze(0) * H[:, col].unsqueeze(1)
116+
U[:, 0:col] -= factors.unsqueeze(0) * U[:, col].unsqueeze(1)
117+
V[col, :] += factors @ V[0:col, :]
146118

147-
row += 1
148119
col += 1
149120

150121
col_magnitudes = torch.sum(torch.abs(H), dim=0)
151-
non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0]
152-
153-
if len(non_zero_indices) == 0:
154-
rank = 0
155-
else:
156-
rank = non_zero_indices.max().item() + 1
122+
rank = torch.count_nonzero(col_magnitudes).item()
157123

158124
reduced_H = H[:, :rank]
159125
reduced_U = U[:, :rank]

tests/unit/sparse/test_linalg.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
import torch
from pytest import mark

from torchjd.sparse._linalg import hnf_decomposition


@mark.parametrize(
    ["shape", "max_rank"],
    [
        ([5, 7], 3),
        ([1, 7], 1),
        ([5, 1], 1),
        ([7, 5], 2),
        ([5, 7], 5),
        ([7, 5], 5),
    ],
)
def test_hnf_decomposition(shape: tuple[int, int], max_rank: int):
    """
    Check that hnf_decomposition(A) yields H, U, V satisfying the advertised
    contract: A @ U = H, H @ V = A, and V is a left inverse of U (V @ U = I).
    """

    # Build A as a product of two random integer factors so that rank(A) is at
    # most max_rank (and equals it with high probability).
    left_factor = torch.randint(-50, 51, [shape[0], max_rank], dtype=torch.int64)
    right_factor = torch.randint(-50, 51, [max_rank, shape[1]], dtype=torch.int64)
    A = left_factor @ right_factor

    H, U, V = hnf_decomposition(A)

    # The decomposition reports the rank as the number of columns kept in H.
    rank = H.shape[1]
    assert rank <= max_rank

    assert torch.equal(V @ U, torch.eye(rank, dtype=torch.int64))
    assert torch.equal(H @ V, A)
    assert torch.equal(A @ U, H)

0 commit comments

Comments (0)