Skip to content

Commit 150856c

Browse files
committed
Add CAD-based sample points
1 parent a3713b0 commit 150856c

5 files changed

Lines changed: 219 additions & 1 deletion

File tree

src/AlgebraicSolving.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include("algorithms/dimension.jl")
1616
include("algorithms/decomposition.jl")
1717
include("algorithms/param-curve.jl")
1818
include("algorithms/hilbert.jl")
19+
include("algorithms/sampling.jl")
1920
#= siggb =#
2021
include("siggb/siggb.jl")
2122
#= progress =#

src/algorithms/sampling.jl

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
function _from_univariate(R::QQMPolyRing, f::QQPolyRingElem)::QQMPolyRingElem
2+
@assert length(gens(R)) == 1
3+
M = MPolyBuildCtx(R)
4+
for i in 0:length(f)
5+
push_term!(M, coeff(f, i), [i])
6+
end
7+
finish(M)
8+
end
9+
10+
function _sample_points(f::Vector{QQPolyRingElem}; worker_pool::AbstractWorkerPool=default_worker_pool())::Vector{QQFieldElem}
11+
R, _ = polynomial_ring(QQ, [:x])
12+
fₓ = map(p -> _from_univariate(R, p), f)
13+
factors = unique([p for fₓ′ in fₓ for (p, _) in factor(fₓ′)])
14+
# We map each factor to its roots, and each root to its factor
15+
roots_by_factor = Dict{QQMPolyRingElem,Vector{Vector{Vector{QQFieldElem}}}}()
16+
factors_by_root = Dict{Vector{Vector{QQFieldElem}},QQMPolyRingElem}()
17+
for factor in factors
18+
roots = real_solutions(Ideal([factor]); interval=true, worker_pool=worker_pool)
19+
roots_by_factor[factor] = roots
20+
for r in roots
21+
factors_by_root[r] = factor
22+
end
23+
end
24+
# We order intervals by their left endpoint
25+
roots = sort(collect(keys(factors_by_root)), by=x -> x[1][1])
26+
# If intervals are not ordered by their right endpoint,
27+
# we merge the offending factors
28+
while (!issorted(roots, by=x -> x[1][2]))
29+
i = findfirst(i -> roots[i][1][2] > roots[i+1][1][2], 1:length(roots)-1)
30+
bad_factors = [factors_by_root[roots[i]], factors_by_root[roots[i+1]]]
31+
# Remove old factors and roots
32+
for bad_factor in bad_factors
33+
for r in roots_by_factor[bad_factor]
34+
delete!(factors_by_root, r)
35+
end
36+
delete!(roots_by_factor, bad_factor)
37+
end
38+
# Replace with merged factor and its roots
39+
merged_factor = prod(bad_factors)
40+
merged_roots = real_solutions(Ideal([merged_factor]); interval=true, worker_pool=worker_pool)
41+
roots_by_factor[merged_factor] = merged_roots
42+
for r in merged_roots
43+
factors_by_root[r] = merged_factor
44+
end
45+
roots = sort(collect(keys(factors_by_root)), by=r -> r[1][1])
46+
end
47+
# Now the intervals are properly ordered, we can sample points
48+
points = Vector{QQFieldElem}()
49+
if length(roots) == 0
50+
push!(points, QQ(0))
51+
return points
52+
end
53+
push!(points, floor(roots[1][1][1]) - 1)
54+
for i in 1:length(roots)-1
55+
push!(points, (roots[i][1][2] + roots[i+1][1][1]) // 2)
56+
end
57+
push!(points, ceil(roots[end][1][2]) + 1)
58+
return points
59+
end
60+
61+
function _sample_points_0(f::Vector{QQMPolyRingElem})::Vector{Vector{QQFieldElem}}
62+
@assert all(map(is_constant, f))
63+
return [Vector{QQFieldElem}()]
64+
end
65+
66+
function _sample_points_1(f::Vector{QQMPolyRingElem}; worker_pool::AbstractWorkerPool=default_worker_pool())::Vector{Vector{QQFieldElem}}
67+
@assert all(map(is_univariate, f))
68+
R, _ = polynomial_ring(QQ, :x)
69+
return [[p] for p in _sample_points(map(p -> to_univariate(R, p), f); worker_pool=worker_pool)]
70+
end
71+
72+
function _sample_points_desc(n::Int)::String
73+
if n == 0
74+
return "Constant"
75+
elseif n == 1
76+
return "Univariate"
77+
elseif n == 2
78+
return "Bivariate"
79+
elseif n == 3
80+
return "Trivariate"
81+
else
82+
return "Multivariate"
83+
end
84+
end
85+
86+
function _sample_points_2(
87+
f::Vector{QQMPolyRingElem},
88+
xs::Vector{QQMPolyRingElem};
89+
nr_thrds::Int=1,
90+
worker_pool::AbstractWorkerPool=default_worker_pool(),
91+
show_progress::Bool=false,
92+
desc::String="$(_sample_points_desc(length(xs))) sample points"
93+
)::Vector{Vector{QQFieldElem}}
94+
@assert length(xs) >= 2
95+
x₁ = xs[1:end-1]
96+
x₂ = xs[end]
97+
factors = unique([p for fₓ in f for (p, _) in factor(fₓ)])
98+
v = Vector{QQMPolyRingElem}()
99+
for i in eachindex(factors)
100+
if !isempty(intersect(x₁, vars(factors[i])))
101+
push!(v, leading_coefficient(factors[i], length(xs)))
102+
end
103+
if x₂ in vars(factors[i])
104+
push!(v, Interpolation.discriminant(factors[i], length(xs); nr_thrds=nr_thrds))
105+
end
106+
end
107+
for i in 1:length(factors)-1
108+
for j in i+1:length(factors)
109+
if x₂ in vars(factors[i]) || x₂ in vars(factors[j])
110+
push!(v, Interpolation.resultant(factors[i], factors[j], length(xs); nr_thrds=nr_thrds))
111+
end
112+
end
113+
end
114+
p₁ = if length(xs) == 2
115+
_sample_points_1(v; worker_pool=worker_pool)
116+
else
117+
_sample_points_2(v, x₁; show_progress=show_progress, worker_pool=worker_pool)
118+
end
119+
prog = Progress.ProgressBar(total=length(p₁); desc=desc, enabled=show_progress)
120+
Progress.update!(prog, 0)
121+
function _points_chunk(p::Vector{Vector{QQFieldElem}})::Vector{Vector{QQFieldElem}}
122+
res_chunk = Vector{Vector{QQFieldElem}}()
123+
for i in eachindex(p)
124+
p₂ = _sample_points_1(map(p′ -> evaluate(p′, x₁, p[i]), f); worker_pool=worker_pool)
125+
for j in eachindex(p₂)
126+
push!(res_chunk, [p[i]; p₂[j][1]])
127+
end
128+
Progress.next!(prog)
129+
end
130+
res_chunk
131+
end
132+
chunk_size = ceil(Int, length(p₁) / nr_thrds)
133+
chunks = [p₁[i:min(i + chunk_size - 1, end)] for i in 1:chunk_size:length(p₁)]
134+
tasks = [Threads.@spawn _points_chunk(chunk) for chunk in chunks]
135+
res = sort(vcat(fetch.(tasks)...))
136+
Progress.finish!(prog)
137+
res
138+
end
139+
140+
function _sample_points(
141+
f::Vector{QQMPolyRingElem};
142+
nr_thrds::Int=1,
143+
worker_pool::AbstractWorkerPool=default_worker_pool(),
144+
show_progress::Bool=false
145+
)::Vector{Vector{QQFieldElem}}
146+
if any(is_zero, f)
147+
error("Cannot sample points for polynomials containing zero polynomial")
148+
end
149+
p = Vector{Vector{QQFieldElem}}([])
150+
if length(gens(parent(f[1]))) == 0
151+
p = _sample_points_0(f)
152+
elseif length(gens(parent(f[1]))) == 1
153+
p = _sample_points_1(f; worker_pool=worker_pool)
154+
else
155+
p = _sample_points_2(f, gens(parent(f[1])); nr_thrds=nr_thrds, worker_pool=worker_pool, show_progress=show_progress)
156+
end
157+
return p
158+
end
159+
160+
function points_per_components(
161+
eqs::Vector{QQMPolyRingElem},
162+
pos::Vector{QQMPolyRingElem},
163+
ineqs::Vector{QQMPolyRingElem};
164+
nr_thrds::Int=1,
165+
worker_pool::AbstractWorkerPool=default_worker_pool(),
166+
info_level::Int=0
167+
)::Vector{Vector{Vector{QQFieldElem}}}
168+
if !isempty(eqs)
169+
throw(ArgumentError("Sampling points per components is not implemented for systems involving equations"))
170+
end
171+
points = _sample_points([pos; ineqs]; nr_thrds=nr_thrds, worker_pool=worker_pool, show_progress=info_level > 0)
172+
res = Vector{Vector{Vector{QQFieldElem}}}()
173+
for point in points
174+
if all(evaluate(constraint, point) > 0 for constraint in pos)
175+
push!(res, map(x -> [x, x], point))
176+
end
177+
end
178+
return res
179+
end

src/imports.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using StaticArrays
77
import Random: MersenneTwister
88
import Logging: ConsoleLogger, with_logger, Warn, Info
99
import Printf: @sprintf, @printf
10-
import Distributed: AbstractWorkerPool, remotecall_fetch
10+
import Distributed: AbstractWorkerPool, default_worker_pool, remotecall_fetch
1111

1212
import Nemo:
1313
bell,

test/algorithms/sampling.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
function points_per_components(ineqs::Vector{QQMPolyRingElem})::Vector{Vector{Vector{QQFieldElem}}}
2+
AlgebraicSolving.points_per_components(QQMPolyRingElem[], QQMPolyRingElem[], ineqs)
3+
end
4+
5+
@testset "Algorithms -> Univariate sample points" begin
6+
R, (x,) = polynomial_ring(QQ, ["x"])
7+
f = [one(R)]
8+
@test length(points_per_components(f)) == 1
9+
f = [x + 1, x + 4294967295 // 4294967296, x + 4294967297 // 4294967296]
10+
@test length(points_per_components(f)) == 4
11+
end
12+
13+
@testset "Algorithms -> Bivariate sample points" begin
14+
R, (x, y) = polynomial_ring(QQ, ["x", "y"])
15+
f = [one(R)]
16+
@test length(points_per_components(f)) == 1
17+
f = [x, x + 1]
18+
@test length(points_per_components(f)) == 3
19+
f = [y, y + 1]
20+
@test length(points_per_components(f)) == 3
21+
f = [x, x + 1, y - x, y + x - 1]
22+
@test length(points_per_components(f)) >= 9
23+
end
24+
25+
@testset "Algorithms -> Trivariate sample points" begin
26+
R, (x, y, z) = polynomial_ring(QQ, ["x", "y", "z"])
27+
f = [one(R)]
28+
@test length(points_per_components(f)) == 1
29+
f = [x, x + 1]
30+
@test length(points_per_components(f)) == 3
31+
f = [y, y + 1]
32+
@test length(points_per_components(f)) == 3
33+
f = [x, x + 1, y - x, y + x - 1]
34+
@test length(points_per_components(f)) >= 9
35+
f = [x, y, z]
36+
@test length(points_per_components(f)) == 8
37+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ include("algorithms/dimension.jl")
1010
include("algorithms/hilbert.jl")
1111
include("algorithms/decomposition.jl")
1212
include("algorithms/param-curves.jl")
13+
include("algorithms/sampling.jl")
1314
include("examples/katsura.jl")
1415
include("interp/thiele.jl")
1516
include("interp/newton.jl")

0 commit comments

Comments
 (0)