@@ -57,103 +57,69 @@ def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]:
5757 V: (r x n) Right inverse Unimodular transform (H @ V = A)
5858 """
5959
60- H = A .clone (). to ( dtype = torch . long )
60+ H = A .clone ()
6161 m , n = H .shape
6262
63- U = torch .eye (n , dtype = torch . long )
64- V = torch .eye (n , dtype = torch . long )
63+ U = torch .eye (n , dtype = A . dtype )
64+ V = torch .eye (n , dtype = A . dtype )
6565
66- row = 0
6766 col = 0
6867
69- while row < m and col < n :
70- # --- 1. Pivot Selection ---
71- # Find first non-zero entry in current row from col onwards
72- pivot_idx = - 1
68+ for row in range (m ):
69+ if n <= col :
70+ break
71+ row_slice = H [row , col :n ]
72+ nonzero_indices = torch .nonzero (row_slice )
7373
74- # We extract the row slice to CPU for faster scalar checks if on GPU
75- # or just iterate. For HNF, strictly sequential loop is often easiest.
76- for j in range (col , n ):
77- if H [row , j ] != 0 :
78- pivot_idx = j
79- break
80-
81- if pivot_idx == - 1 :
82- row += 1
74+ if nonzero_indices .numel () > 0 :
75+ relative_pivot_idx = nonzero_indices [0 ][0 ].item ()
76+ pivot_idx = col + relative_pivot_idx
77+ else :
8378 continue
8479
85- # Swap to current column
8680 if pivot_idx != col :
87- # Swap Columns in H and U
8881 H [:, [col , pivot_idx ]] = H [:, [pivot_idx , col ]]
8982 U [:, [col , pivot_idx ]] = U [:, [pivot_idx , col ]]
90- # Swap ROWS in V
9183 V [[col , pivot_idx ], :] = V [[pivot_idx , col ], :]
9284
93- # --- 2. Gaussian Elimination via GCD ---
9485 for j in range (col + 1 , n ):
9586 if H [row , j ] != 0 :
96- # Extract values as python ints for GCD logic
9787 a_val = H [row , col ].item ()
9888 b_val = H [row , j ].item ()
9989
10090 g , x , y = extended_gcd (a_val , b_val )
10191
102- # Bezout: a*x + b*y = g
103- # c1 = -b // g, c2 = a // g
10492 c1 = - b_val // g
10593 c2 = a_val // g
10694
107- # --- Update H (Column Ops) ---
108- # Important: Clone columns to avoid in-place modification issues during calc
109- col_c = H [:, col ].clone ()
110- col_j = H [:, j ].clone ()
111-
112- H [:, col ] = col_c * x + col_j * y
113- H [:, j ] = col_c * c1 + col_j * c2
114-
115- # --- Update U (Column Ops) ---
116- u_c = U [:, col ].clone ()
117- u_j = U [:, j ].clone ()
118- U [:, col ] = u_c * x + u_j * y
119- U [:, j ] = u_c * c1 + u_j * c2
120-
121- # --- Update V (Inverse Row Ops) ---
122- # Inverse of [[x, c1], [y, c2]] is [[c2, -c1], [-y, x]]
123- v_r_c = V [col , :].clone ()
124- v_r_j = V [j , :].clone ()
125- V [col , :] = v_r_c * c2 - v_r_j * c1
126- V [j , :] = v_r_c * (- y ) + v_r_j * x
127-
128- # --- 3. Enforce Positive Diagonal ---
129- if H [row , col ] < 0 :
130- H [:, col ] *= - 1
131- U [:, col ] *= - 1
132- V [col , :] *= - 1
133-
134- # --- 4. Canonical Reduction (Modulo) ---
135- # Ensure 0 <= H[row, k] < H[row, col] for k < col
136- pivot_val = H [row , col ].clone ()
137- if pivot_val != 0 :
138- for j in range (col ):
139- # floor division
140- factor = torch .div (H [row , j ], pivot_val , rounding_mode = "floor" )
95+ H_col = H [:, col ]
96+ H_j = H [:, j ]
97+
98+ H [:, [col , j ]] = torch .stack ([H_col * x + H_j * y , H_col * c1 + H_j * c2 ], dim = 1 )
99+
100+ U_col = U [:, col ]
101+ U_j = U [:, j ]
102+ U [:, [col , j ]] = torch .stack ([U_col * x + U_j * y , U_col * c1 + U_j * c2 ], dim = 1 )
141103
142- if factor != 0 :
143- H [:, j ] -= factor * H [:, col ]
144- U [:, j ] -= factor * U [:, col ]
145- V [col , :] += factor * V [j , :]
104+ V_row_c = V [col , :]
105+ V_row_j = V [j , :]
106+ V [[col , j ], :] = torch .stack (
107+ [V_row_c * c2 - V_row_j * c1 , V_row_c * (- y ) + V_row_j * x ], dim = 0
108+ )
109+
110+ pivot_val = H [row , col ]
111+
112+ if pivot_val != 0 :
113+ H_row_prefix = H [row , 0 :col ]
114+ factors = torch .div (H_row_prefix , pivot_val , rounding_mode = "floor" )
115+ H [:, 0 :col ] -= factors .unsqueeze (0 ) * H [:, col ].unsqueeze (1 )
116+ U [:, 0 :col ] -= factors .unsqueeze (0 ) * U [:, col ].unsqueeze (1 )
117+ V [col , :] += factors @ V [0 :col , :]
146118
147- row += 1
148119 col += 1
149120
150121 col_magnitudes = torch .sum (torch .abs (H ), dim = 0 )
151- non_zero_indices = torch .nonzero (col_magnitudes , as_tuple = True )[0 ]
152-
153- if len (non_zero_indices ) == 0 :
154- rank = 0
155- else :
156- rank = non_zero_indices .max ().item () + 1
122+ rank = torch .count_nonzero (col_magnitudes ).item ()
157123
158124 reduced_H = H [:, :rank ]
159125 reduced_U = U [:, :rank ]
0 commit comments