Skip to content

Commit 3f81a9f

Browse files
committed
Add multivariate rational interpolation
1 parent 46c3dd4 commit 3f81a9f

12 files changed

Lines changed: 699 additions & 0 deletions

File tree

src/AlgebraicSolving.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ include("algorithms/param-curve.jl")
1818
include("algorithms/hilbert.jl")
1919
#= siggb =#
2020
include("siggb/siggb.jl")
21+
#= progress =#
22+
include("progress/main.jl")
23+
#= interp =#
24+
include("interp/main.jl")
2125
#= examples =#
2226
include("examples/katsura.jl")
2327
include("examples/cyclic.jl")

src/interp/cuyt_lee.jl

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
mutable struct Atomic{T}
2+
@atomic x::T
3+
end
4+
5+
struct CuytLeeError <: Exception
6+
msg::String
7+
end
8+
9+
Base.show(io::IO, e::CuytLeeError) = print(io, "CuytLeeError: ", e.msg)
10+
11+
function _random_point(n::Int)::Vector{Int}
12+
map(a -> 1 + abs(a) % 99, rand(Int, n))
13+
end
14+
15+
function _estimate_total_degree(
16+
R::QQMPolyRing,
17+
bb::Function;
18+
samples::Int=5,
19+
show_progress::Bool=false
20+
)::Tuple{Int,Int}
21+
t = length(gens(R))
22+
@assert t > 0
23+
R_z, _ = polynomial_ring(QQ, :z)
24+
total_degree_counts = Dict{Tuple{Int,Int},Int}()
25+
total = 0
26+
while total < samples
27+
x = _random_point(t)
28+
try
29+
f = thiele(R_z, k -> bb(k .* x); show_progress=show_progress)
30+
total_degree = (degree(denominator(f)), degree(numerator(f)))
31+
total_degree_counts[total_degree] = get(total_degree_counts, total_degree, 0) + 1
32+
total += 1
33+
if findmax(total_degree_counts)[1] > samples ÷ 2
34+
break
35+
end
36+
catch e
37+
if isa(e, ThieleError)
38+
continue
39+
else
40+
rethrow(e)
41+
end
42+
end
43+
end
44+
return findmax(total_degree_counts)[2]
45+
end
46+
47+
function _homogenize(
48+
f::QQMPolyRingElem,
49+
d::Int
50+
)::QQMPolyRingElem
51+
R = parent(f)
52+
C = MPolyBuildCtx(R)
53+
for i in 1:f.length
54+
exp = exponent_vector(f, i)
55+
@assert exp[1] == 0
56+
total_deg = sum(exp)
57+
exp[1] += d - total_deg
58+
@assert exp[1] >= 0
59+
push_term!(C, coeff(f, i), exp)
60+
end
61+
return finish(C)
62+
end
63+
64+
function cuyt_lee_shifted(
65+
R::QQMPolyRing,
66+
bb::Function;
67+
retry::Int=10,
68+
nr_thrds::Int=1,
69+
show_progress::Bool=false,
70+
desc::String="Multivariate rational interpolation"
71+
)::FracFieldElem{QQMPolyRingElem}
72+
# https://arxiv.org/pdf/1608.01902
73+
t = length(gens(R))
74+
if t == 0
75+
return R(bb(Vector{QQFieldElem}())) // one(R)
76+
end
77+
R_z, _ = polynomial_ring(QQ, :z)
78+
d_den, d_num = _estimate_total_degree(R, bb; show_progress=show_progress)
79+
d = [0; fill(max(d_den, d_num), t - 1)...]
80+
x = Vector{Vector{ZZRingElem}}()
81+
coeffs_den = []
82+
coeffs_num = []
83+
data_lock = ReentrantLock()
84+
prog = ProgressBar(total=prod(d .+ 1); desc=desc, enabled=show_progress)
85+
update!(prog, 0)
86+
function populate(cur::Vector{ZZRingElem}, dim::Int; num_threads::Int=1, offset=1)::Bool
87+
if dim > t
88+
f = nothing
89+
try
90+
f = thiele(R_z, k -> bb(k .* cur); retry=retry, show_progress=show_progress, offset=offset)
91+
catch e
92+
if isa(e, BoundsError) || isa(e, ThieleError)
93+
return false
94+
else
95+
rethrow(e)
96+
end
97+
end
98+
if degree(denominator(f)) != d_den || degree(numerator(f)) != d_num
99+
return false
100+
end
101+
c = constant_coefficient(denominator(f))
102+
if c == 0
103+
return false
104+
end
105+
lock(data_lock) do
106+
push!(x, copy(cur))
107+
push!(coeffs_den, collect(coefficients(denominator(f))) ./ c)
108+
push!(coeffs_num, collect(coefficients(numerator(f))) ./ c)
109+
update!(prog, length(x))
110+
end
111+
return true
112+
end
113+
total = Atomic(0)
114+
failures = Atomic(0)
115+
i = 1
116+
while total.x < d[dim] + 1
117+
num_threads_chunk = min(num_threads, d[dim] + 1 - total.x)
118+
Threads.@threads for j in 0:num_threads_chunk-1
119+
if populate([cur; ZZ(i + j)], dim + 1; num_threads=1, offset=max(offset, j + 1))
120+
@atomic total.x += 1
121+
@atomic failures.x = 0
122+
else
123+
@atomic failures.x += 1
124+
end
125+
end
126+
i += num_threads_chunk
127+
if failures.x >= retry
128+
return false
129+
end
130+
end
131+
return true
132+
end
133+
res = populate([ZZ(1)], 2; num_threads=nr_thrds)
134+
if !res
135+
finish!(prog)
136+
throw(CuytLeeError("Failed to collect enough data points for interpolation. This could happen if the black box function is singular at zero, or if the expected total degree is incorrect."))
137+
end
138+
perm = sortperm(x)
139+
x = x[perm]
140+
coeffs_den = coeffs_den[perm]
141+
coeffs_num = coeffs_num[perm]
142+
# We interpolate the denominator and numerator separately
143+
den = zero(R)
144+
num = zero(R)
145+
for i in 0:d_den
146+
y = [coeffs_den[j][i+1] for j in 1:length(x)]
147+
den += _homogenize(newton(R, x, y, d), i)
148+
end
149+
for i in 0:d_num
150+
y = [coeffs_num[j][i+1] for j in 1:length(x)]
151+
num += _homogenize(newton(R, x, y, d), i)
152+
end
153+
finish!(prog)
154+
return num // den
155+
end
156+
157+
function cuyt_lee_with_shift(
158+
R::QQMPolyRing,
159+
bb::Function,
160+
shift::Vector{Int};
161+
retry::Int=10,
162+
nr_thrds::Int=1,
163+
show_progress::Bool=false,
164+
desc::String="Multivariate rational interpolation"
165+
)::FracFieldElem{QQMPolyRingElem}
166+
t = length(gens(R))
167+
if t == 0
168+
return R(bb(Vector{QQFieldElem}())) // one(R)
169+
end
170+
f_shifted = cuyt_lee_shifted(R, z -> bb(z .+ shift); retry=retry, nr_thrds=nr_thrds, show_progress=show_progress, desc=desc)
171+
x = gens(R) .- shift
172+
num = evaluate(numerator(f_shifted), x)
173+
den = evaluate(denominator(f_shifted), x)
174+
return num // den
175+
end
176+
177+
function cuyt_lee(
178+
R::QQMPolyRing,
179+
bb::Function;
180+
initial_shift=_random_point(length(gens(R))),
181+
retry::Int=10,
182+
nr_thrds::Int=1,
183+
show_progress::Bool=false,
184+
desc::String="Multivariate rational interpolation"
185+
)::FracFieldElem{QQMPolyRingElem}
186+
t = length(gens(R))
187+
if t == 0
188+
return R(bb(Vector{QQFieldElem}())) // one(R)
189+
end
190+
shift = initial_shift
191+
for i in 1:retry
192+
try
193+
return cuyt_lee_with_shift(R, bb, shift; retry=retry, nr_thrds=nr_thrds, show_progress=show_progress, desc=desc)
194+
catch e
195+
if isa(e, CuytLeeError)
196+
if show_progress
197+
@warn "Interpolation failed, retrying with a different shift... Retries left: $(retry - i)"
198+
end
199+
shift = _random_point(t)
200+
else
201+
rethrow(e)
202+
end
203+
end
204+
end
205+
throw(CuytLeeError("Interpolation failed after maximum number of retries."))
206+
end

src/interp/main.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module Interpolation
2+
3+
using Nemo
4+
import Nemo.Generic: FracFieldElem
5+
using ..Progress
6+
7+
export thiele, newton, cuyt_lee
8+
9+
# Interpolation algorithms
10+
include("thiele.jl")
11+
include("newton.jl")
12+
include("cuyt_lee.jl")
13+
14+
# Applications
15+
include("resultant.jl")
16+
17+
end # module Interpolation

src/interp/newton.jl

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
struct Fmpq
2+
num::Int
3+
den::Int
4+
end
5+
6+
function _flint_lagrange(
7+
R::QQPolyRing,
8+
x::Vector{ZZRingElem},
9+
y::Vector{QQFieldElem}
10+
)::QQPolyRingElem
11+
@assert length(x) == length(y)
12+
n = length(x)
13+
z = zero(R)
14+
ax = Vector{Int}(map(i -> Int(x[i].d), 1:n))
15+
ay = Vector{Fmpq}(map(i -> Fmpq(Int(y[i].num), Int(y[i].den)), 1:n))
16+
@ccall Nemo.libflint.fmpq_poly_interpolate_fmpz_fmpq_vec(z::Ref{QQPolyRingElem}, ax::Ptr{Int}, ay::Ptr{Fmpq}, n::Int)::Nothing
17+
@assert evaluate(z, x[n]) == y[n]
18+
return z
19+
end
20+
21+
function _from_univariate(
22+
x::QQMPolyRingElem,
23+
f::QQPolyRingElem
24+
)::QQMPolyRingElem
25+
R = parent(x)
26+
j = findfirst(u -> u == x, gens(R))
27+
t = length(gens(R))
28+
C = MPolyBuildCtx(R)
29+
for i in 0:length(f)-1
30+
c = coeff(f, i)
31+
if c != 0
32+
exp = fill(0, t)
33+
exp[j] = i
34+
push_term!(C, c, exp)
35+
end
36+
end
37+
finish(C)
38+
end
39+
40+
function newton(
41+
R::QQMPolyRing,
42+
x::Vector{Vector{ZZRingElem}},
43+
y::Vector{QQFieldElem},
44+
d::Vector{Int}
45+
)::QQMPolyRingElem
46+
n = length(x)
47+
@assert length(y) == prod(d .+ 1) == n
48+
step = prod(d[2:end] .+ 1)
49+
x_index = length(gens(R)) - length(d) + 1
50+
z = gens(R)[x_index]
51+
x_trunc = [x[i][x_index] for i in 1:step:n]
52+
if length(d) == 1
53+
R′, _ = polynomial_ring(QQ, :z)
54+
y_trunc = [y[i] for i in 1:(d[1]+1)]
55+
f = _flint_lagrange(R′, x_trunc, y_trunc)
56+
return _from_univariate(z, f)
57+
end
58+
y_trunc = [newton(R, x[i:i+step-1], y[i:i+step-1], d[2:end]) for i in 1:step:n]
59+
if d[1] == 0
60+
return y_trunc[1]
61+
end
62+
dd = [copy(y_trunc)]
63+
for j in 1:d[1]
64+
dd_j = Vector{QQMPolyRingElem}()
65+
for i in 1:(d[1]-j+1)
66+
push!(dd_j, (dd[j][i+1] - dd[j][i]) / (R(x_trunc[i+j]) - R(x_trunc[i])))
67+
end
68+
push!(dd, dd_j)
69+
end
70+
f = zero(R)
71+
term = one(R)
72+
for j in 0:d[1]
73+
f += dd[j+1][1] * term
74+
term *= (z - x_trunc[j+1])
75+
end
76+
return f
77+
end
78+
79+
function newton(
80+
R::QQMPolyRing,
81+
bb::Function,
82+
d::Vector{Int};
83+
non_vanishing_poly::QQMPolyRingElem=one(R),
84+
nr_thrds::Int=1,
85+
show_progress::Bool=false,
86+
desc::String="Multivariate polynomial interpolation"
87+
)::QQMPolyRingElem
88+
@assert !iszero(non_vanishing_poly)
89+
t = length(gens(R))
90+
@assert length(d) == t
91+
x = Vector{Vector{ZZRingElem}}()
92+
y = Vector{QQFieldElem}()
93+
data_lock = ReentrantLock()
94+
prog = ProgressBar(total=prod(d .+ 1); desc=desc, enabled=show_progress)
95+
update!(prog, 0)
96+
function populate(cur::Vector{ZZRingElem}, dim::Int; num_threads::Int=1)
97+
if dim > t
98+
lock(data_lock) do
99+
push!(x, cur)
100+
push!(y, bb(QQ.(cur)))
101+
next!(prog)
102+
end
103+
return
104+
end
105+
total = Atomic(0)
106+
i = 1
107+
while total.x < d[dim] + 1
108+
num_threads_chunk = min(num_threads, d[dim] + 1 - total.x)
109+
Threads.@threads for j in 0:num_threads_chunk-1
110+
if !iszero(evaluate(non_vanishing_poly, gens(R)[1:length(cur)+1], [cur; ZZ(i + j)]))
111+
populate([cur; ZZ(i + j)], dim + 1)
112+
@atomic total.x += 1
113+
end
114+
end
115+
i += num_threads_chunk
116+
end
117+
end
118+
populate(Vector{ZZRingElem}(), 1; num_threads=nr_thrds)
119+
f = newton(R, x, y, d)
120+
finish!(prog)
121+
return f
122+
end

0 commit comments

Comments
 (0)