Skip to content

Commit 8a288a1

Browse files
authored
Merge pull request #2177 from JohnAAbbott/JAA/PATCH-solve_triu
Patched _solve_triu and _solve_tril; beefed up the tests
2 parents b22e086 + 95f986c commit 8a288a1

2 files changed

Lines changed: 43 additions & 37 deletions

File tree

src/flint/fmpz_mat.jl

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,7 +1673,7 @@ function _solve_dixon(a::ZZMatrix, b::ZZMatrix)
16731673
return z, d
16741674
end
16751675

1676-
#XU = B. only the upper triangular part of U is used
1676+
# Solve XU = B for X given U & B. Only the upper triangular part of U is used.
16771677
function AbstractAlgebra._solve_triu_left(U::ZZMatrix, b::ZZMatrix; unipotent::Bool = false)
16781678
n = ncols(U)
16791679
m = nrows(b)
@@ -1719,9 +1719,9 @@ function AbstractAlgebra._solve_triu_left(U::ZZMatrix, b::ZZMatrix; unipotent::B
17191719
return X
17201720
end
17211721

1722-
#UX = B, U has to be upper triangular
1723-
#I think due to the Strassen calling path, where Strasse.solve(side = :left)
1724-
#call directly AA.solve_left, this has to be in AA and cannot be independent.
1722+
# Solve UX = B for X given U & B: U has to be upper triangular.
1723+
# I think due to the Strassen calling path, where Strasse.solve(side = :left)
1724+
# call directly AA.solve_left, this has to be in AA and cannot be independent.
17251725
function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:left, unipotent::Bool = false)
17261726
if side == :left
17271727
return AbstractAlgebra._solve_triu_left(U, b; unipotent)
@@ -1732,20 +1732,22 @@ function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:lef
17321732
X = zero(b)
17331733
tmp = zero_matrix(ZZ, 1, n)
17341734
s = ZZ()
1735+
# We build up the solution column by column
17351736
GC.@preserve U b X tmp begin
1736-
for i = 1:m
1737+
for i = 1:m # i indexes the columns
17371738
tmp_ptr = mat_entry_ptr(tmp, 1, 1)
17381739
for j = 1:n
17391740
X_ptr = mat_entry_ptr(X, j, i)
17401741
set!(tmp_ptr, X_ptr)
17411742
tmp_ptr += sizeof(ZZRingElem)
17421743
end
1743-
for j = n:-1:1
1744+
# At this point tmp is full of zeroes
1745+
for j = n:-1:1 # j indexes the rows (in i-th column)
17441746
zero!(s)
17451747
tmp_ptr = mat_entry_ptr(tmp, 1, j+1)
17461748
for k = j + 1:n
17471749
U_ptr = mat_entry_ptr(U, j, k)
1748-
mul!(s, U_ptr, tmp_ptr)
1750+
addmul!(s, U_ptr, tmp_ptr)
17491751
tmp_ptr += sizeof(ZZRingElem)
17501752
# s = addmul!(s, U[j, k], tmp[k])
17511753
end
@@ -1772,35 +1774,35 @@ function AbstractAlgebra._solve_triu(U::ZZMatrix, b::ZZMatrix; side::Symbol=:lef
17721774
return X
17731775
end
17741776

1775-
#solves Ax = B for A lower triagular. if f != 0 (f is true), the diagonal
1776-
#is assumed to be 1 and not actually used.
1777-
#the upper part of A is not used/ touched.
1778-
#one cannot assert is_lower_triangular as this is used for the inplace
1779-
#lu decomposition where the matrix is full, encoding an upper triangular
1780-
#using the diagonal and a lower triangular with trivial diagonal
1781-
function AbstractAlgebra._solve_tril!(A::ZZMatrix, B::ZZMatrix, C::ZZMatrix, f::Int = 0)
1777+
# Solves Lx = B for L lower triangular. If unipotent is true, the diagonal
1778+
# is assumed to be 1 and not actually used.
1779+
# The upper part of L is not used/ touched.
1780+
# One cannot assert is_lower_triangular as this is used for the inplace
1781+
# lu decomposition where the matrix is full, encoding an upper triangular
1782+
# using the diagonal and a lower triangular with trivial diagonal
1783+
function AbstractAlgebra._solve_tril!(X::ZZMatrix, L::ZZMatrix, B::ZZMatrix; unipotent::Bool = false)
17821784

17831785
# a x u ax = u
17841786
# b c * y = v bx + cy = v
17851787
# d e f z w ....
17861788

1787-
@assert ncols(A) == ncols(C)
1789+
@assert ncols(X) == ncols(B)
17881790
s = ZZ(0)
1789-
GC.@preserve A B C begin
1790-
for i=1:ncols(A)
1791-
for j = 1:nrows(A)
1792-
t = C[j, i]
1793-
B_ptr = mat_entry_ptr(B, j, 1)
1791+
GC.@preserve X L B begin
1792+
for i=1:ncols(X)
1793+
for j = 1:nrows(X)
1794+
t = B[j, i]
1795+
L_ptr = mat_entry_ptr(L, j, 1)
17941796
for k = 1:j-1
1795-
A_ptr = mat_entry_ptr(B, k, i)
1796-
mul!(s, A_ptr, B_ptr)
1797-
B_ptr += sizeof(ZZRingElem)
1797+
X_ptr = mat_entry_ptr(X, k, i)
1798+
mul!(s, X_ptr, L_ptr)
1799+
L_ptr += sizeof(ZZRingElem)
17981800
sub!(t, t, s)
17991801
end
1800-
if f == 1
1801-
A[j,i] = t
1802+
if unipotent
1803+
X[j,i] = t
18021804
else
1803-
A[j,i] = divexact(t, B[j, j])
1805+
X[j,i] = divexact(t, L[j, j])
18041806
end
18051807
end
18061808
end

test/flint/fmpz_mat-test.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -759,19 +759,23 @@ end
759759
end
760760

761761
@testset "ZZMatrix.solve" begin
762-
A = matrix(ZZ, 2, 2, [1,2,3,4])
762+
# Test matrices have size (at least) 3x3 -- smaller failed to detect a bug.
763+
U_triang = matrix(ZZ, 3, 3, [1,2,3, 0,4,5, 0,0,6])
764+
L_triang = matrix(ZZ, 3, 3, [1,0,0, 2,3,0, 4,5,6])
763765

764766
@test AbstractAlgebra.Solve.matrix_normal_form_type(ZZ) === AbstractAlgebra.Solve.HermiteFormTrait()
765-
@test AbstractAlgebra.Solve.matrix_normal_form_type(A) === AbstractAlgebra.Solve.HermiteFormTrait()
766-
767-
b = matrix(ZZ, 1, 2, [1, 6])
768-
@test AbstractAlgebra._solve_triu_left(A, b) == matrix(ZZ, 1, 2, [1, 1])
769-
b = matrix(ZZ, 2, 1, [3, 4])
770-
@test AbstractAlgebra._solve_triu(A, b; side = :right) == matrix(ZZ, 2, 1, [1, 1])
771-
b = matrix(ZZ, 2, 1, [1, 7])
772-
c = similar(b)
773-
AbstractAlgebra._solve_tril!(c, A, b)
774-
@test c == matrix(ZZ, 2, 1, [1, 1])
767+
@test AbstractAlgebra.Solve.matrix_normal_form_type(U_triang) === AbstractAlgebra.Solve.HermiteFormTrait()
768+
@test AbstractAlgebra.Solve.matrix_normal_form_type(L_triang) === AbstractAlgebra.Solve.HermiteFormTrait()
769+
770+
X = matrix(ZZ, 3, 2, [3,1, 4,1, 5,9])
771+
trX = transpose(X)
772+
@test AbstractAlgebra._solve_triu_left(U_triang, trX*U_triang) == trX
773+
@test AbstractAlgebra._solve_triu(U_triang, trX*U_triang; side = :left) == trX
774+
@test AbstractAlgebra._solve_triu(U_triang, U_triang*X; side = :right) == X
775+
776+
c = similar(X)
777+
AbstractAlgebra._solve_tril!(c, L_triang, L_triang*X)
778+
@test c == X
775779

776780
S = matrix_space(ZZ, 3, 3)
777781

0 commit comments

Comments
 (0)