Skip to content

Commit baeaff0

Browse files
Enhance flaw and resolver integration: Update flaw constructors to accept causes and add new clause/disjunction methods in basic_solver
1 parent f5679bf commit baeaff0

4 files changed

Lines changed: 162 additions & 11 deletions

File tree

include/basic_solver.hpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,16 @@ namespace ratio
1818
class flaw
1919
{
2020
public:
21-
flaw(basic_solver &slv) noexcept;
21+
flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes) noexcept;
2222
flaw(const flaw &) = delete;
2323
virtual ~flaw() = default;
2424

2525
private:
2626
virtual void compute_resolvers() = 0;
2727

2828
protected:
29-
basic_solver &slv; // The solver managing this flaw..
29+
basic_solver &slv; // The solver managing this flaw..
30+
std::vector<std::reference_wrapper<resolver>> causes; // The causes of this flaw..
3031
};
3132

3233
class resolver
@@ -51,6 +52,9 @@ namespace ratio
5152

5253
[[nodiscard]] riddle::expr new_enum(riddle::component_type &tp, std::vector<riddle::expr> &&values) override;
5354

55+
void new_clause(std::vector<riddle::bool_expr> &&exprs) override;
56+
void new_disjunction(std::vector<std::unique_ptr<riddle::conjunction>> &&disjuncts) override;
57+
5458
/**
5559
* @brief Creates a new flaw of the given type.
5660
*
@@ -113,14 +117,16 @@ namespace ratio
113117
#endif
114118

115119
private:
116-
std::vector<std::unique_ptr<flaw>> flaws; // The set of flaws
117-
std::vector<std::unique_ptr<resolver>> resolvers; // The set of resolvers
120+
std::vector<std::unique_ptr<flaw>> flaws; // The set of flaws
121+
std::vector<std::unique_ptr<resolver>> resolvers; // The set of resolvers
122+
std::optional<std::reference_wrapper<flaw>> c_flaw; // The current flaw..
123+
std::optional<std::reference_wrapper<resolver>> c_res; // The current resolver..
118124
};
119125

120126
class enum_flaw final : public flaw
121127
{
122128
public:
123-
enum_flaw(basic_solver &slv, riddle::enum_expr var) noexcept;
129+
enum_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, riddle::enum_expr var) noexcept;
124130

125131
[[nodiscard]] const riddle::enum_expr &get_var() const noexcept { return var; }
126132

@@ -130,4 +136,32 @@ namespace ratio
130136
private:
131137
riddle::enum_expr var;
132138
};
139+
140+
class clause_flaw final : public flaw
141+
{
142+
public:
143+
clause_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<riddle::bool_expr> &&clause, const bool &exclusive = false) noexcept;
144+
145+
[[nodiscard]] const std::vector<riddle::bool_expr> &get_clause() const noexcept { return clause; }
146+
147+
private:
148+
void compute_resolvers() override;
149+
150+
private:
151+
std::vector<riddle::bool_expr> clause;
152+
};
153+
154+
class disjunction_flaw final : public flaw
155+
{
156+
public:
157+
disjunction_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<std::unique_ptr<riddle::conjunction>> &&disjuncts) noexcept;
158+
159+
[[nodiscard]] const std::vector<std::unique_ptr<riddle::conjunction>> &get_disjuncts() const noexcept { return disjuncts; }
160+
161+
private:
162+
void compute_resolvers() override;
163+
164+
private:
165+
std::vector<std::unique_ptr<riddle::conjunction>> disjuncts;
166+
};
133167
} // namespace ratio

include/solver_core.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ namespace ratio
4343
[[nodiscard]] riddle::arith_expr new_product(std::vector<riddle::arith_expr> &&xprs) override;
4444
[[nodiscard]] riddle::arith_expr new_division(std::vector<riddle::arith_expr> &&xprs) override;
4545

46-
void new_clause(std::vector<riddle::bool_expr> &&exprs) override;
47-
void new_disjunction(std::vector<std::unique_ptr<riddle::conjunction>> &&disjuncts) override;
48-
4946
virtual void solve() = 0;
5047

5148
private:

src/basic_solver.cpp

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace ratio
77
{
8-
flaw::flaw(basic_solver &slv) noexcept : slv(slv) {}
8+
flaw::flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes) noexcept : slv(slv), causes(std::move(causes)) {}
99

1010
resolver::resolver(flaw &flw, utils::rational &&intrinsic_cost) noexcept : flw(flw), intrinsic_cost(std::move(intrinsic_cost)) {}
1111

@@ -21,19 +21,66 @@ namespace ratio
2121
}
2222
else
2323
{
24+
std::vector<std::reference_wrapper<resolver>> causes;
25+
if (c_res)
26+
causes.push_back(c_res.value());
2427
std::vector<std::reference_wrapper<const utils::enum_val>> ev_refs;
2528
for (auto &ev_ptr : values)
2629
ev_refs.emplace_back(*ev_ptr);
2730
auto ev = ac_slv.new_var(ev_refs);
2831
// .. and create a new enum flaw to manage the variable..
29-
auto &ef = new_flaw<enum_flaw>(*this, std::make_shared<riddle::enum_item>(tp, std::move(values), ev));
32+
auto &ef = new_flaw<enum_flaw>(*this, std::move(causes), std::make_shared<riddle::enum_item>(tp, std::move(values), ev));
3033
return ef.get_var();
3134
}
3235
}
3336

37+
void basic_solver::new_clause(std::vector<riddle::bool_expr> &&exprs)
38+
{
39+
assert(!exprs.empty());
40+
if (exprs.size() == 1)
41+
{ // if there is only one expression, just execute it..
42+
if (!execute(exprs[0]))
43+
throw std::runtime_error("Unsatisfiable constraints");
44+
}
45+
else
46+
{ // otherwise, create a new clause flaw..
47+
std::vector<std::reference_wrapper<resolver>> causes;
48+
if (c_res)
49+
causes.push_back(c_res.value());
50+
std::vector<utils::lit> clause;
51+
clause.reserve(exprs.size());
52+
for (const riddle::bool_expr &expr : exprs)
53+
clause.push_back(static_cast<const riddle::bool_item &>(*expr).get_lit());
54+
55+
auto &ac_cnstr = ac_slv.new_clause(std::move(clause));
56+
if (c_res) // if there is a current resolver, add the expression to it..
57+
c_res->get().ac_cnsts.push_back(ac_cnstr);
58+
else
59+
ac_slv.add_constraint(ac_cnstr);
60+
new_flaw<clause_flaw>(*this, std::move(causes), std::move(exprs));
61+
}
62+
}
63+
void basic_solver::new_disjunction(std::vector<std::unique_ptr<riddle::conjunction>> &&disjuncts)
64+
{
65+
assert(disjuncts.size() > 1);
66+
std::vector<std::reference_wrapper<resolver>> causes;
67+
if (c_res)
68+
causes.push_back(c_res.value());
69+
70+
new_flaw<disjunction_flaw>(*this, std::move(causes), std::move(disjuncts));
71+
}
72+
3473
void basic_solver::solve() {}
3574

36-
enum_flaw::enum_flaw(basic_solver &slv, riddle::enum_expr var) noexcept : flaw(slv), var(std::move(var)) {}
75+
enum_flaw::enum_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, riddle::enum_expr var) noexcept : flaw(slv, std::move(causes)), var(std::move(var)) {}
3776

3877
void enum_flaw::compute_resolvers() {}
78+
79+
clause_flaw::clause_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<riddle::bool_expr> &&clause, const bool &exclusive) noexcept : flaw(slv, std::move(causes), exclusive), clause(std::move(clause)) {}
80+
81+
void clause_flaw::compute_resolvers() {}
82+
83+
disjunction_flaw::disjunction_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<std::unique_ptr<riddle::conjunction>> &&disjuncts) noexcept : flaw(slv, std::move(causes)), disjuncts(std::move(disjuncts)) {}
84+
85+
void disjunction_flaw::compute_resolvers() {}
3986
} // namespace ratio

src/solver_core.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,77 @@ namespace ratio
4343
values.push_back(ev_ptr);
4444
return values;
4545
};
46+
47+
riddle::arith_expr solver_core::new_negation(riddle::arith_expr xpr)
48+
{
49+
if (xpr->get_type().get_name() == riddle::int_kw)
50+
return std::make_shared<riddle::arith_item>(static_cast<riddle::int_type &>(get_type(riddle::int_kw)), -static_cast<const riddle::arith_item &>(*xpr).get_lin());
51+
else if (xpr->get_type().get_name() == riddle::real_kw)
52+
return std::make_shared<riddle::arith_item>(static_cast<riddle::real_type &>(get_type(riddle::real_kw)), -static_cast<const riddle::arith_item &>(*xpr).get_lin());
53+
else
54+
throw std::runtime_error("Invalid type");
55+
}
56+
57+
riddle::arith_expr solver_core::new_sum(std::vector<riddle::arith_expr> &&xprs)
58+
{
59+
assert(xprs.size() > 1);
60+
utils::lin sum;
61+
for (const riddle::arith_expr &xpr : xprs)
62+
sum += static_cast<const riddle::arith_item &>(*xpr).get_lin();
63+
auto &tp = type_promotion(xprs);
64+
if (tp.get_name() == riddle::int_kw)
65+
return std::make_shared<riddle::arith_item>(static_cast<riddle::int_type &>(get_type(riddle::int_kw)), std::move(sum));
66+
else if (tp.get_name() == riddle::real_kw)
67+
return std::make_shared<riddle::arith_item>(static_cast<riddle::real_type &>(get_type(riddle::real_kw)), std::move(sum));
68+
else
69+
throw std::runtime_error("Invalid type");
70+
}
71+
riddle::arith_expr solver_core::new_subtraction(std::vector<riddle::arith_expr> &&xprs)
72+
{
73+
assert(xprs.size() > 1);
74+
utils::lin sub = static_cast<const riddle::arith_item &>(*xprs[0]).get_lin();
75+
for (size_t i = 1; i < xprs.size(); i++)
76+
sub -= static_cast<const riddle::arith_item &>(*xprs[i]).get_lin();
77+
auto &tp = type_promotion(xprs);
78+
if (tp.get_name() == riddle::int_kw)
79+
return std::make_shared<riddle::arith_item>(static_cast<riddle::int_type &>(get_type(riddle::int_kw)), std::move(sub));
80+
else if (tp.get_name() == riddle::real_kw)
81+
return std::make_shared<riddle::arith_item>(static_cast<riddle::real_type &>(get_type(riddle::real_kw)), std::move(sub));
82+
else
83+
throw std::runtime_error("Invalid type");
84+
}
85+
riddle::arith_expr solver_core::new_product(std::vector<riddle::arith_expr> &&xprs)
86+
{
87+
assert(xprs.size() > 1);
88+
utils::lin prod;
89+
for (const riddle::arith_expr &xpr : xprs)
90+
if (static_cast<const riddle::arith_item &>(*xpr).get_lin().vars.empty())
91+
prod *= static_cast<const riddle::arith_item &>(*xpr).get_lin().known_term;
92+
else
93+
throw std::runtime_error("Non-linear arithmetic not supported");
94+
auto &tp = type_promotion(xprs);
95+
if (tp.get_name() == riddle::int_kw)
96+
return std::make_shared<riddle::arith_item>(static_cast<riddle::int_type &>(get_type(riddle::int_kw)), std::move(prod));
97+
else if (tp.get_name() == riddle::real_kw)
98+
return std::make_shared<riddle::arith_item>(static_cast<riddle::real_type &>(get_type(riddle::real_kw)), std::move(prod));
99+
else
100+
throw std::runtime_error("Invalid type");
101+
}
102+
riddle::arith_expr solver_core::new_division(std::vector<riddle::arith_expr> &&xprs)
103+
{
104+
assert(xprs.size() > 1);
105+
utils::lin div = static_cast<const riddle::arith_item &>(*xprs[0]).get_lin();
106+
for (size_t i = 1; i < xprs.size(); i++)
107+
if (static_cast<const riddle::arith_item &>(*xprs[i]).get_lin().vars.empty())
108+
div /= static_cast<const riddle::arith_item &>(*xprs[i]).get_lin().known_term;
109+
else
110+
throw std::runtime_error("Non-linear arithmetic not supported");
111+
auto &tp = type_promotion(xprs);
112+
if (tp.get_name() == riddle::int_kw)
113+
return std::make_shared<riddle::arith_item>(static_cast<riddle::int_type &>(get_type(riddle::int_kw)), std::move(div));
114+
else if (tp.get_name() == riddle::real_kw)
115+
return std::make_shared<riddle::arith_item>(static_cast<riddle::real_type &>(get_type(riddle::real_kw)), std::move(div));
116+
else
117+
throw std::runtime_error("Invalid type");
118+
}
46119
} // namespace ratio

0 commit comments

Comments
 (0)