Skip to content

Commit c2e9195

Browse files
committed
Add CAD-based sample points
1 parent 2520e64 commit c2e9195

4 files changed

Lines changed: 223 additions & 0 deletions

File tree

src/AlgebraicSolving.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include("algorithms/dimension.jl")
1616
include("algorithms/decomposition.jl")
1717
include("algorithms/param-curve.jl")
1818
include("algorithms/hilbert.jl")
19+
include("algorithms/sampling.jl")
1920
#= siggb =#
2021
include("siggb/siggb.jl")
2122
#= progress =#

src/algorithms/sampling.jl

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
algebraic_solving_lock = ReentrantLock()
2+
3+
function _real_roots(f::Vector{QQMPolyRingElem})::Vector{Vector{Vector{QQFieldElem}}}
4+
rs = nothing
5+
lock(algebraic_solving_lock) do
6+
rs = real_solutions(AlgebraicSolving.Ideal(f), interval=true)
7+
end
8+
return rs
9+
end
10+
11+
function _from_univariate(R::QQMPolyRing, f::QQPolyRingElem)::QQMPolyRingElem
12+
@assert length(gens(R)) == 1
13+
M = MPolyBuildCtx(R)
14+
for i in 0:length(f)
15+
push_term!(M, coeff(f, i), [i])
16+
end
17+
finish(M)
18+
end
19+
20+
function _sample_points(f::Vector{QQPolyRingElem})::Vector{QQFieldElem}
21+
R, _ = polynomial_ring(QQ, [:x])
22+
fₓ = map(p -> _from_univariate(R, p), f)
23+
factors = unique([p for fₓ′ in fₓ for (p, _) in factor(fₓ′)])
24+
# We map each factor to its roots, and each root to its factor
25+
roots_by_factor = Dict{QQMPolyRingElem,Vector{Vector{Vector{QQFieldElem}}}}()
26+
factors_by_root = Dict{Vector{Vector{QQFieldElem}},QQMPolyRingElem}()
27+
for factor in factors
28+
roots = _real_roots([factor])
29+
roots_by_factor[factor] = roots
30+
for r in roots
31+
factors_by_root[r] = factor
32+
end
33+
end
34+
# We order intervals by their left endpoint
35+
roots = sort(collect(keys(factors_by_root)), by=x -> x[1][1])
36+
# If intervals are not ordered by their right endpoint,
37+
# we merge the offending factors
38+
while (!issorted(roots, by=x -> x[1][2]))
39+
i = findfirst(i -> roots[i][1][2] > roots[i+1][1][2], 1:length(roots)-1)
40+
bad_factors = [factors_by_root[roots[i]], factors_by_root[roots[i+1]]]
41+
# Remove old factors and roots
42+
for bad_factor in bad_factors
43+
for r in roots_by_factor[bad_factor]
44+
delete!(factors_by_root, r)
45+
end
46+
delete!(roots_by_factor, bad_factor)
47+
end
48+
# Replace with merged factor and its roots
49+
merged_factor = prod(bad_factors)
50+
merged_roots = _real_roots([merged_factor])
51+
roots_by_factor[merged_factor] = merged_roots
52+
for r in merged_roots
53+
factors_by_root[r] = merged_factor
54+
end
55+
roots = sort(collect(keys(factors_by_root)), by=r -> r[1][1])
56+
end
57+
# Now the intervals are properly ordered, we can sample points
58+
points = Vector{QQFieldElem}()
59+
if length(roots) == 0
60+
push!(points, QQ(0))
61+
return points
62+
end
63+
push!(points, floor(roots[1][1][1]) - 1)
64+
for i in 1:length(roots)-1
65+
push!(points, (roots[i][1][2] + roots[i+1][1][1]) // 2)
66+
end
67+
push!(points, ceil(roots[end][1][2]) + 1)
68+
return points
69+
end
70+
71+
function _sample_points_0(f::Vector{QQMPolyRingElem})::Vector{Vector{QQFieldElem}}
72+
@assert all(map(is_constant, f))
73+
return [Vector{QQFieldElem}()]
74+
end
75+
76+
function _sample_points_1(f::Vector{QQMPolyRingElem})::Vector{Vector{QQFieldElem}}
77+
@assert all(map(is_univariate, f))
78+
R, _ = polynomial_ring(QQ, :x)
79+
return [[p] for p in _sample_points(map(p -> to_univariate(R, p), f))]
80+
end
81+
82+
function _sample_points_desc(n::Int)::String
83+
if n == 0
84+
return "Constant"
85+
elseif n == 1
86+
return "Univariate"
87+
elseif n == 2
88+
return "Bivariate"
89+
elseif n == 3
90+
return "Trivariate"
91+
else
92+
return "Multivariate"
93+
end
94+
end
95+
96+
function _sample_points_2(
97+
f::Vector{QQMPolyRingElem},
98+
xs::Vector{QQMPolyRingElem};
99+
nr_thrds::Int=1,
100+
show_progress::Bool=false,
101+
desc::String="$(_sample_points_desc(length(xs))) sample points"
102+
)::Vector{Vector{QQFieldElem}}
103+
@assert length(xs) >= 2
104+
x₁ = xs[1:end-1]
105+
x₂ = xs[end]
106+
factors = unique([p for fₓ in f for (p, _) in factor(fₓ)])
107+
v = Vector{QQMPolyRingElem}()
108+
for i in eachindex(factors)
109+
if !isempty(intersect(x₁, vars(factors[i])))
110+
push!(v, leading_coefficient(factors[i], length(xs)))
111+
end
112+
if x₂ in vars(factors[i])
113+
push!(v, Interpolation.discriminant(factors[i], length(xs); nr_thrds=nr_thrds))
114+
end
115+
end
116+
for i in 1:length(factors)-1
117+
for j in i+1:length(factors)
118+
if x₂ in vars(factors[i]) || x₂ in vars(factors[j])
119+
push!(v, Interpolation.resultant(factors[i], factors[j], length(xs); nr_thrds=nr_thrds))
120+
end
121+
end
122+
end
123+
p₁ = length(xs) == 2 ? _sample_points_1(v) : _sample_points_2(v, x₁; show_progress=show_progress)
124+
prog = Progress.ProgressBar(total=length(p₁); desc=desc, enabled=show_progress)
125+
Progress.update!(prog, 0)
126+
function _points_chunk(p::Vector{Vector{QQFieldElem}})::Vector{Vector{QQFieldElem}}
127+
res_chunk = Vector{Vector{QQFieldElem}}()
128+
for i in eachindex(p)
129+
p₂ = _sample_points_1(map(p′ -> evaluate(p′, x₁, p[i]), f))
130+
for j in eachindex(p₂)
131+
push!(res_chunk, [p[i]; p₂[j][1]])
132+
end
133+
Progress.next!(prog)
134+
end
135+
res_chunk
136+
end
137+
chunk_size = ceil(Int, length(p₁) / nr_thrds)
138+
chunks = [p₁[i:min(i + chunk_size - 1, end)] for i in 1:chunk_size:length(p₁)]
139+
tasks = [Threads.@spawn _points_chunk(chunk) for chunk in chunks]
140+
res = sort(vcat(fetch.(tasks)...))
141+
Progress.finish!(prog)
142+
res
143+
end
144+
145+
function _unique_by_sign(p::Vector{Vector{QQFieldElem}}, f::Vector{QQMPolyRingElem})::Vector{Vector{QQFieldElem}}
146+
signs = Set{Vector{QQFieldElem}}()
147+
p′ = Vector{Vector{QQFieldElem}}()
148+
for pᵢ in p
149+
s = [sign(evaluate(q, pᵢ)) for q in f]
150+
if !(s in signs)
151+
push!(signs, s)
152+
push!(p′, pᵢ)
153+
end
154+
end
155+
p′
156+
end
157+
158+
function _sample_points(
159+
f::Vector{QQMPolyRingElem};
160+
distinct_signs::Bool=false,
161+
nr_thrds::Int=1,
162+
show_progress::Bool=false
163+
)::Vector{Vector{QQFieldElem}}
164+
if any(is_zero, f)
165+
error("Cannot sample points for polynomials containing zero polynomial")
166+
end
167+
p = Vector{Vector{QQFieldElem}}([])
168+
if length(gens(parent(f[1]))) == 0
169+
p = _sample_points_0(f)
170+
elseif length(gens(parent(f[1]))) == 1
171+
p = _sample_points_1(f)
172+
else
173+
p = _sample_points_2(f, gens(parent(f[1])); nr_thrds=nr_thrds, show_progress=show_progress)
174+
end
175+
if distinct_signs
176+
p = _unique_by_sign(p, f)
177+
end
178+
return p
179+
end

test/algorithms/sampling.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import AlgebraicSolving: _real_roots, _sample_points
2+
3+
@testset "Algorithms -> Real roots" begin
4+
R, (x, y, z) = polynomial_ring(QQ, [:x, :y, :z])
5+
f = [x + 2 * y + 2 * z - 1, x^2 + 2 * y^2 + 2 * z^2 - x, 2 * x * y + 2 * y * z - y]
6+
rs = _real_roots(f)
7+
@test length(rs) == 4
8+
end
9+
10+
@testset "Algorithms -> Univariate sample points" begin
11+
R, (x,) = polynomial_ring(QQ, ["x"])
12+
f = [one(R)]
13+
@test length(_sample_points(f)) == 1
14+
f = [x + 1, x + 4294967295 // 4294967296, x + 4294967297 // 4294967296]
15+
@test length(_sample_points(f)) == 4
16+
end
17+
18+
@testset "Algorithms -> Bivariate sample points" begin
19+
R, (x, y) = polynomial_ring(QQ, ["x", "y"])
20+
f = [one(R)]
21+
@test length(_sample_points(f)) == 1
22+
f = [x, x + 1]
23+
@test length(_sample_points(f)) == 3
24+
f = [y, y + 1]
25+
@test length(_sample_points(f)) == 3
26+
f = [x, x + 1, y - x, y + x - 1]
27+
@test length(_sample_points(f)) >= 9
28+
end
29+
30+
@testset "Algorithms -> Trivariate sample points" begin
31+
R, (x, y, z) = polynomial_ring(QQ, ["x", "y", "z"])
32+
f = [one(R)]
33+
@test length(_sample_points(f)) == 1
34+
f = [x, x + 1]
35+
@test length(_sample_points(f)) == 3
36+
f = [y, y + 1]
37+
@test length(_sample_points(f)) == 3
38+
f = [x, x + 1, y - x, y + x - 1]
39+
@test length(_sample_points(f)) >= 9
40+
f = [x, y, z]
41+
@test length(_sample_points(f)) == 8
42+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ include("algorithms/dimension.jl")
1010
include("algorithms/hilbert.jl")
1111
include("algorithms/decomposition.jl")
1212
include("algorithms/param-curves.jl")
13+
include("algorithms/sampling.jl")
1314
include("examples/katsura.jl")
1415
include("interp/thiele.jl")
1516
include("interp/newton.jl")

0 commit comments

Comments
 (0)