Skip to content

Commit 075d46d

Browse files
committed
Add CAD-based sample points
1 parent a3713b0 commit 075d46d

5 files changed

Lines changed: 220 additions & 1 deletion

File tree

src/AlgebraicSolving.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include("algorithms/dimension.jl")
1616
include("algorithms/decomposition.jl")
1717
include("algorithms/param-curve.jl")
1818
include("algorithms/hilbert.jl")
19+
include("algorithms/sampling.jl")
1920
#= siggb =#
2021
include("siggb/siggb.jl")
2122
#= progress =#

src/algorithms/sampling.jl

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
2+
function _from_univariate(R::QQMPolyRing, f::QQPolyRingElem)::QQMPolyRingElem
3+
@assert length(gens(R)) == 1
4+
M = MPolyBuildCtx(R)
5+
for i in 0:length(f)
6+
push_term!(M, coeff(f, i), [i])
7+
end
8+
finish(M)
9+
end
10+
11+
function _sample_points(f::Vector{QQPolyRingElem}; worker_pool::AbstractWorkerPool=default_worker_pool())::Vector{QQFieldElem}
12+
R, _ = polynomial_ring(QQ, [:x])
13+
fₓ = map(p -> _from_univariate(R, p), f)
14+
factors = unique([p for fₓ′ in fₓ for (p, _) in factor(fₓ′)])
15+
# We map each factor to its roots, and each root to its factor
16+
roots_by_factor = Dict{QQMPolyRingElem,Vector{Vector{Vector{QQFieldElem}}}}()
17+
factors_by_root = Dict{Vector{Vector{QQFieldElem}},QQMPolyRingElem}()
18+
for factor in factors
19+
roots = real_solutions(Ideal([factor]); interval=true, worker_pool=worker_pool)
20+
roots_by_factor[factor] = roots
21+
for r in roots
22+
factors_by_root[r] = factor
23+
end
24+
end
25+
# We order intervals by their left endpoint
26+
roots = sort(collect(keys(factors_by_root)), by=x -> x[1][1])
27+
# If intervals are not ordered by their right endpoint,
28+
# we merge the offending factors
29+
while (!issorted(roots, by=x -> x[1][2]))
30+
i = findfirst(i -> roots[i][1][2] > roots[i+1][1][2], 1:length(roots)-1)
31+
bad_factors = [factors_by_root[roots[i]], factors_by_root[roots[i+1]]]
32+
# Remove old factors and roots
33+
for bad_factor in bad_factors
34+
for r in roots_by_factor[bad_factor]
35+
delete!(factors_by_root, r)
36+
end
37+
delete!(roots_by_factor, bad_factor)
38+
end
39+
# Replace with merged factor and its roots
40+
merged_factor = prod(bad_factors)
41+
merged_roots = real_solutions(Ideal([merged_factor]); interval=true, worker_pool=worker_pool)
42+
roots_by_factor[merged_factor] = merged_roots
43+
for r in merged_roots
44+
factors_by_root[r] = merged_factor
45+
end
46+
roots = sort(collect(keys(factors_by_root)), by=r -> r[1][1])
47+
end
48+
# Now the intervals are properly ordered, we can sample points
49+
points = Vector{QQFieldElem}()
50+
if length(roots) == 0
51+
push!(points, QQ(0))
52+
return points
53+
end
54+
push!(points, floor(roots[1][1][1]) - 1)
55+
for i in 1:length(roots)-1
56+
push!(points, (roots[i][1][2] + roots[i+1][1][1]) // 2)
57+
end
58+
push!(points, ceil(roots[end][1][2]) + 1)
59+
return points
60+
end
61+
62+
function _sample_points_0(f::Vector{QQMPolyRingElem})::Vector{Vector{QQFieldElem}}
63+
@assert all(map(is_constant, f))
64+
return [Vector{QQFieldElem}()]
65+
end
66+
67+
function _sample_points_1(f::Vector{QQMPolyRingElem}; worker_pool::AbstractWorkerPool=default_worker_pool())::Vector{Vector{QQFieldElem}}
68+
@assert all(map(is_univariate, f))
69+
R, _ = polynomial_ring(QQ, :x)
70+
return [[p] for p in _sample_points(map(p -> to_univariate(R, p), f); worker_pool=worker_pool)]
71+
end
72+
73+
function _sample_points_desc(n::Int)::String
74+
if n == 0
75+
return "Constant"
76+
elseif n == 1
77+
return "Univariate"
78+
elseif n == 2
79+
return "Bivariate"
80+
elseif n == 3
81+
return "Trivariate"
82+
else
83+
return "Multivariate"
84+
end
85+
end
86+
87+
function _sample_points_2(
88+
f::Vector{QQMPolyRingElem},
89+
xs::Vector{QQMPolyRingElem};
90+
nr_thrds::Int=1,
91+
worker_pool::AbstractWorkerPool=default_worker_pool(),
92+
show_progress::Bool=false,
93+
desc::String="$(_sample_points_desc(length(xs))) sample points"
94+
)::Vector{Vector{QQFieldElem}}
95+
@assert length(xs) >= 2
96+
x₁ = xs[1:end-1]
97+
x₂ = xs[end]
98+
factors = unique([p for fₓ in f for (p, _) in factor(fₓ)])
99+
v = Vector{QQMPolyRingElem}()
100+
for i in eachindex(factors)
101+
if !isempty(intersect(x₁, vars(factors[i])))
102+
push!(v, leading_coefficient(factors[i], length(xs)))
103+
end
104+
if x₂ in vars(factors[i])
105+
push!(v, Interpolation.discriminant(factors[i], length(xs); nr_thrds=nr_thrds))
106+
end
107+
end
108+
for i in 1:length(factors)-1
109+
for j in i+1:length(factors)
110+
if x₂ in vars(factors[i]) || x₂ in vars(factors[j])
111+
push!(v, Interpolation.resultant(factors[i], factors[j], length(xs); nr_thrds=nr_thrds))
112+
end
113+
end
114+
end
115+
p₁ = if length(xs) == 2
116+
_sample_points_1(v; worker_pool=worker_pool)
117+
else
118+
_sample_points_2(v, x₁; show_progress=show_progress, worker_pool=worker_pool)
119+
end
120+
prog = Progress.ProgressBar(total=length(p₁); desc=desc, enabled=show_progress)
121+
Progress.update!(prog, 0)
122+
function _points_chunk(p::Vector{Vector{QQFieldElem}})::Vector{Vector{QQFieldElem}}
123+
res_chunk = Vector{Vector{QQFieldElem}}()
124+
for i in eachindex(p)
125+
p₂ = _sample_points_1(map(p′ -> evaluate(p′, x₁, p[i]), f); worker_pool=worker_pool)
126+
for j in eachindex(p₂)
127+
push!(res_chunk, [p[i]; p₂[j][1]])
128+
end
129+
Progress.next!(prog)
130+
end
131+
res_chunk
132+
end
133+
chunk_size = ceil(Int, length(p₁) / nr_thrds)
134+
chunks = [p₁[i:min(i + chunk_size - 1, end)] for i in 1:chunk_size:length(p₁)]
135+
tasks = [Threads.@spawn _points_chunk(chunk) for chunk in chunks]
136+
res = sort(vcat(fetch.(tasks)...))
137+
Progress.finish!(prog)
138+
res
139+
end
140+
141+
function _sample_points(
142+
f::Vector{QQMPolyRingElem};
143+
nr_thrds::Int=1,
144+
worker_pool::AbstractWorkerPool=default_worker_pool(),
145+
show_progress::Bool=false
146+
)::Vector{Vector{QQFieldElem}}
147+
if any(is_zero, f)
148+
error("Cannot sample points for polynomials containing zero polynomial")
149+
end
150+
p = Vector{Vector{QQFieldElem}}([])
151+
if length(gens(parent(f[1]))) == 0
152+
p = _sample_points_0(f)
153+
elseif length(gens(parent(f[1]))) == 1
154+
p = _sample_points_1(f; worker_pool=worker_pool)
155+
else
156+
p = _sample_points_2(f, gens(parent(f[1])); nr_thrds=nr_thrds, worker_pool=worker_pool, show_progress=show_progress)
157+
end
158+
return p
159+
end
160+
161+
function points_per_components(
162+
eqs::Vector{QQMPolyRingElem},
163+
pos::Vector{QQMPolyRingElem},
164+
ineqs::Vector{QQMPolyRingElem};
165+
nr_thrds::Int=1,
166+
worker_pool::AbstractWorkerPool=default_worker_pool(),
167+
info_level::Int=0
168+
)::Vector{Vector{Vector{QQFieldElem}}}
169+
if !isempty(eqs)
170+
throw(ArgumentError("Sampling points per components is not implemented for systems involving equations"))
171+
end
172+
points = _sample_points([pos; ineqs]; nr_thrds=nr_thrds, worker_pool=worker_pool, show_progress=info_level > 0)
173+
res = Vector{Vector{Vector{QQFieldElem}}}()
174+
for point in points
175+
if all(evaluate(constraint, point) > 0 for constraint in pos)
176+
push!(res, map(x -> [x, x], point))
177+
end
178+
end
179+
return res
180+
end

src/imports.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using StaticArrays
77
import Random: MersenneTwister
88
import Logging: ConsoleLogger, with_logger, Warn, Info
99
import Printf: @sprintf, @printf
10-
import Distributed: AbstractWorkerPool, remotecall_fetch
10+
import Distributed: AbstractWorkerPool, default_worker_pool, remotecall_fetch
1111

1212
import Nemo:
1313
bell,

test/algorithms/sampling.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
function points_per_components(ineqs::Vector{QQMPolyRingElem})::Vector{Vector{Vector{QQFieldElem}}}
2+
AlgebraicSolving.points_per_components(QQMPolyRingElem[], QQMPolyRingElem[], ineqs)
3+
end
4+
5+
@testset "Algorithms -> Univariate sample points" begin
6+
R, (x,) = polynomial_ring(QQ, ["x"])
7+
f = [one(R)]
8+
@test length(points_per_components(f)) == 1
9+
f = [x + 1, x + 4294967295 // 4294967296, x + 4294967297 // 4294967296]
10+
@test length(points_per_components(f)) == 4
11+
end
12+
13+
@testset "Algorithms -> Bivariate sample points" begin
14+
R, (x, y) = polynomial_ring(QQ, ["x", "y"])
15+
f = [one(R)]
16+
@test length(points_per_components(f)) == 1
17+
f = [x, x + 1]
18+
@test length(points_per_components(f)) == 3
19+
f = [y, y + 1]
20+
@test length(points_per_components(f)) == 3
21+
f = [x, x + 1, y - x, y + x - 1]
22+
@test length(points_per_components(f)) >= 9
23+
end
24+
25+
@testset "Algorithms -> Trivariate sample points" begin
26+
R, (x, y, z) = polynomial_ring(QQ, ["x", "y", "z"])
27+
f = [one(R)]
28+
@test length(points_per_components(f)) == 1
29+
f = [x, x + 1]
30+
@test length(points_per_components(f)) == 3
31+
f = [y, y + 1]
32+
@test length(points_per_components(f)) == 3
33+
f = [x, x + 1, y - x, y + x - 1]
34+
@test length(points_per_components(f)) >= 9
35+
f = [x, y, z]
36+
@test length(points_per_components(f)) == 8
37+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ include("algorithms/dimension.jl")
1010
include("algorithms/hilbert.jl")
1111
include("algorithms/decomposition.jl")
1212
include("algorithms/param-curves.jl")
13+
include("algorithms/sampling.jl")
1314
include("examples/katsura.jl")
1415
include("interp/thiele.jl")
1516
include("interp/newton.jl")

0 commit comments

Comments
 (0)