Skip to content

Commit 270e7d0

Browse files
Enhance basic_solver and solver_core: Add atom creation and matching functionalities
1 parent 8cf849b commit 270e7d0

4 files changed

Lines changed: 252 additions & 3 deletions

File tree

include/basic_solver.hpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ namespace ratio
9999

100100
void solve() override;
101101

102+
private:
103+
[[nodiscard]] riddle::atom_expr create_atom(bool is_fact, riddle::predicate &pred, std::map<std::string, riddle::expr, std::less<>> &&args) override;
104+
105+
[[nodiscard]] bool execute(const riddle::bool_expr &expr) noexcept;
106+
102107
#ifdef ORATIO_ENABLE_LISTENERS
103108
private:
104109
/**
@@ -144,7 +149,7 @@ namespace ratio
144149
class clause_flaw final : public flaw
145150
{
146151
public:
147-
clause_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<riddle::bool_expr> &&clause, const bool &exclusive = false) noexcept;
152+
clause_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<riddle::bool_expr> &&clause) noexcept;
148153

149154
[[nodiscard]] const std::vector<riddle::bool_expr> &get_clause() const noexcept { return clause; }
150155

@@ -168,4 +173,18 @@ namespace ratio
168173
private:
169174
std::vector<std::unique_ptr<riddle::conjunction>> disjuncts;
170175
};
176+
177+
class atom_flaw final : public flaw
178+
{
179+
public:
180+
atom_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, bool is_fact, riddle::predicate &pred, std::map<std::string, riddle::expr, std::less<>> &&args, utils::lit &&sigma) noexcept;
181+
182+
[[nodiscard]] const riddle::atom_expr &get_atom() const noexcept { return atm; }
183+
184+
private:
185+
void compute_resolvers() override;
186+
187+
private:
188+
riddle::atom_expr atm;
189+
};
171190
} // namespace ratio

include/solver_core.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ namespace ratio
4545

4646
virtual void solve() = 0;
4747

48+
[[nodiscard]] bool match(riddle::term &lhs, riddle::term &rhs) const;
49+
4850
private:
49-
[[nodiscard]] riddle::atom_expr create_atom(bool is_fact, riddle::predicate &pred, std::map<std::string, riddle::expr, std::less<>> &&args) override;
5051
[[nodiscard]] riddle::atom_state get_atom_state(const riddle::atom_term &atom) const noexcept override;
5152

5253
protected:

src/basic_solver.cpp

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "basic_solver.hpp"
22
#include "items.hpp"
3+
#include "conjunction.hpp"
34
#include "logging.hpp"
45
#include <cassert>
56

@@ -72,15 +73,207 @@ namespace ratio
7273

7374
void basic_solver::solve() {}
7475

76+
riddle::atom_expr basic_solver::create_atom(bool is_fact, riddle::predicate &pred, std::map<std::string, riddle::expr, std::less<>> &&args)
77+
{
78+
std::vector<std::reference_wrapper<resolver>> causes;
79+
if (c_res)
80+
causes.push_back(c_res.value());
81+
82+
auto &af = new_flaw<atom_flaw>(*this, std::move(causes), is_fact, pred, std::move(args), ac_slv.new_sat());
83+
return af.get_atom();
84+
}
85+
86+
bool basic_solver::execute(const riddle::bool_expr &expr) noexcept
87+
{
88+
if (auto n_xpr = dynamic_cast<riddle::bool_not *>(expr.get()))
89+
{
90+
if (auto b_xpr = dynamic_cast<riddle::bool_item *>(n_xpr->get_arg().get()))
91+
{
92+
auto &a_cnstr = ac_slv.new_assign(utils::variable(b_xpr->get_lit()), utils::sign(b_xpr->get_lit()) ? arc_consistency::solver::False : arc_consistency::solver::True);
93+
ac_slv.add_constraint(a_cnstr);
94+
if (c_res)
95+
c_res->get().ac_cnsts.push_back(a_cnstr);
96+
return true;
97+
}
98+
else if (auto lt_xpr = dynamic_cast<riddle::lt_term *>(n_xpr->get_arg().get()))
99+
return lin_slv.new_gt(static_cast<riddle::arith_item *>(lt_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(lt_xpr->get_rhs().get())->get_lin(), false, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
100+
else if (auto le_xpr = dynamic_cast<riddle::le_term *>(n_xpr->get_arg().get()))
101+
return lin_slv.new_gt(static_cast<riddle::arith_item *>(le_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(le_xpr->get_rhs().get())->get_lin(), true, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
102+
else if (auto eq_xpr = dynamic_cast<riddle::eq_term *>(n_xpr->get_arg().get()))
103+
{
104+
if (&*eq_xpr->get_lhs() == &*eq_xpr->get_rhs()) // the terms are the same, so they are equal..
105+
return false;
106+
else if (&eq_xpr->get_lhs()->get_type() != &eq_xpr->get_lhs()->get_type()) // the types are different, so the constraint is always false..
107+
return true;
108+
else if (auto lhs_xpr = std::dynamic_pointer_cast<riddle::arith_item>(eq_xpr->get_lhs())) // we are dealing with an arithmetic constraint..
109+
{
110+
new_clause({std::make_shared<riddle::lt_term>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), lhs_xpr, std::static_pointer_cast<riddle::arith_item>(eq_xpr->get_rhs())), std::make_shared<riddle::gt_term>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), lhs_xpr, std::static_pointer_cast<riddle::arith_item>(eq_xpr->get_rhs()))});
111+
return true;
112+
}
113+
else if (auto lhs_sxpr = dynamic_cast<riddle::string_item *>(eq_xpr->get_lhs().get())) // we are dealing with a string constraint..
114+
return lhs_sxpr->get_string() != static_cast<riddle::string_item &>(*eq_xpr->get_rhs()).get_string();
115+
else if (auto lhs_bxpr = dynamic_cast<riddle::bool_item *>(eq_xpr->get_lhs().get())) // we are dealing with a boolean constraint..
116+
{
117+
auto &neq_cnstr = ac_slv.new_distinct(utils::variable(lhs_bxpr->get_lit()), utils::variable(static_cast<riddle::bool_item &>(*eq_xpr->get_rhs()).get_lit()));
118+
ac_slv.add_constraint(neq_cnstr);
119+
if (c_res) // if there is a current resolver, add the expression to it..
120+
c_res->get().ac_cnsts.push_back(neq_cnstr);
121+
return true;
122+
}
123+
else if (auto lhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_lhs().get())) // we are dealing with an enum constraint..
124+
{
125+
if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_rhs().get()))
126+
{ // both sides are enum items..
127+
auto &neq_cnstr = ac_slv.new_distinct(lhs_enum_xpr->get_var(), rhs_enum_xpr->get_var());
128+
ac_slv.add_constraint(neq_cnstr);
129+
if (c_res) // if there is a current resolver, add the expression to it..
130+
c_res->get().ac_cnsts.push_back(neq_cnstr);
131+
return true;
132+
}
133+
else
134+
{
135+
auto &neq_cnstr = ac_slv.new_forbid(lhs_enum_xpr->get_var(), *eq_xpr->get_rhs());
136+
ac_slv.add_constraint(neq_cnstr);
137+
if (c_res) // if there is a current resolver, add the expression to it..
138+
c_res->get().ac_cnsts.push_back(neq_cnstr);
139+
return true;
140+
}
141+
}
142+
else if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_rhs().get()))
143+
{
144+
auto &neq_cnstr = ac_slv.new_forbid(rhs_enum_xpr->get_var(), *eq_xpr->get_lhs());
145+
ac_slv.add_constraint(neq_cnstr);
146+
if (c_res) // if there is a current resolver, add the expression to it..
147+
c_res->get().ac_cnsts.push_back(neq_cnstr);
148+
return true;
149+
}
150+
else if (auto lhs_atm = dynamic_cast<riddle::atom_term *>(eq_xpr->get_lhs().get()))
151+
{
152+
auto rhs_atm = static_cast<riddle::atom_term *>(eq_xpr->get_rhs().get());
153+
std::vector<riddle::bool_expr> clause_exprs;
154+
std::queue<riddle::predicate *> q;
155+
q.push(static_cast<riddle::predicate *>(&lhs_xpr->get_type()));
156+
while (!q.empty())
157+
{
158+
for (const auto &[f_name, f] : q.front()->get_fields())
159+
clause_exprs.push_back(std::make_shared<riddle::bool_not>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), std::make_shared<riddle::eq_term>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), lhs_atm->get(f_name), rhs_atm->get(f_name))));
160+
for (const auto &pp : q.front()->get_parents())
161+
q.push(&pp.get());
162+
q.pop();
163+
}
164+
new_clause(std::move(clause_exprs));
165+
return true;
166+
}
167+
else
168+
return true;
169+
}
170+
else if (auto ge_xpr = dynamic_cast<riddle::ge_term *>(n_xpr->get_arg().get()))
171+
return lin_slv.new_lt(static_cast<riddle::arith_item *>(ge_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(ge_xpr->get_rhs().get())->get_lin(), false, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
172+
else if (auto gt_xpr = dynamic_cast<riddle::gt_term *>(n_xpr->get_arg().get()))
173+
return lin_slv.new_lt(static_cast<riddle::arith_item *>(gt_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(gt_xpr->get_rhs().get())->get_lin(), true, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
174+
else
175+
return false; // unknown expression inside negation..
176+
}
177+
else
178+
{
179+
if (auto b_xpr = dynamic_cast<riddle::bool_item *>(expr.get()))
180+
{
181+
auto &a_cnstr = ac_slv.new_assign(utils::variable(b_xpr->get_lit()), utils::sign(b_xpr->get_lit()) ? arc_consistency::solver::True : arc_consistency::solver::False);
182+
ac_slv.add_constraint(a_cnstr);
183+
if (c_res)
184+
c_res->get().ac_cnsts.push_back(a_cnstr);
185+
return true;
186+
}
187+
else if (auto lt_xpr = dynamic_cast<riddle::lt_term *>(expr.get()))
188+
return lin_slv.new_lt(static_cast<riddle::arith_item *>(lt_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(lt_xpr->get_rhs().get())->get_lin(), true, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
189+
else if (auto le_xpr = dynamic_cast<riddle::le_term *>(expr.get()))
190+
return lin_slv.new_lt(static_cast<riddle::arith_item *>(le_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(le_xpr->get_rhs().get())->get_lin(), false, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
191+
else if (auto eq_xpr = dynamic_cast<riddle::eq_term *>(expr.get()))
192+
{
193+
if (&*eq_xpr->get_lhs() == &*eq_xpr->get_rhs()) // the terms are the same, so they are equal..
194+
return true;
195+
else if (&eq_xpr->get_lhs()->get_type() != &eq_xpr->get_lhs()->get_type()) // the types are different, so the constraint is always false..
196+
return false;
197+
else if (auto lhs_xpr = dynamic_cast<riddle::arith_item *>(eq_xpr->get_lhs().get())) // we are dealing with an arithmetic constraint..
198+
return lin_slv.new_eq(lhs_xpr->get_lin(), static_cast<riddle::arith_item *>(eq_xpr->get_rhs().get())->get_lin(), c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
199+
else if (auto lhs_sxpr = dynamic_cast<riddle::string_item *>(eq_xpr->get_lhs().get())) // we are dealing with a string constraint..
200+
return lhs_sxpr->get_string() == static_cast<riddle::string_item &>(*eq_xpr->get_rhs()).get_string();
201+
else if (auto lhs_bxpr = dynamic_cast<riddle::bool_item *>(eq_xpr->get_lhs().get())) // we are dealing with a boolean constraint..
202+
{
203+
auto &eq_cnstr = ac_slv.new_equal(utils::variable(lhs_bxpr->get_lit()), utils::variable(static_cast<riddle::bool_item &>(*eq_xpr->get_rhs()).get_lit()));
204+
ac_slv.add_constraint(eq_cnstr);
205+
if (c_res) // if there is a current resolver, add the expression to it..
206+
c_res->get().ac_cnsts.push_back(eq_cnstr);
207+
return true;
208+
}
209+
else if (auto lhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_lhs().get())) // we are dealing with an enum constraint..
210+
{
211+
if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_rhs().get()))
212+
{ // both sides are enum items..
213+
auto &eq_cnstr = ac_slv.new_equal(lhs_enum_xpr->get_var(), rhs_enum_xpr->get_var());
214+
ac_slv.add_constraint(eq_cnstr);
215+
if (c_res) // if there is a current resolver, add the expression to it..
216+
c_res->get().ac_cnsts.push_back(eq_cnstr);
217+
return true;
218+
}
219+
else
220+
{
221+
auto &eq_cnstr = ac_slv.new_assign(lhs_enum_xpr->get_var(), *eq_xpr->get_rhs());
222+
ac_slv.add_constraint(eq_cnstr);
223+
if (c_res) // if there is a current resolver, add the expression to it..
224+
c_res->get().ac_cnsts.push_back(eq_cnstr);
225+
return true;
226+
}
227+
}
228+
else if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_rhs().get()))
229+
{
230+
auto &eq_cnstr = ac_slv.new_assign(rhs_enum_xpr->get_var(), *eq_xpr->get_lhs());
231+
ac_slv.add_constraint(eq_cnstr);
232+
if (c_res) // if there is a current resolver, add the expression to it..
233+
c_res->get().ac_cnsts.push_back(eq_cnstr);
234+
return true;
235+
}
236+
else if (auto lhs_atm = dynamic_cast<riddle::atom_term *>(eq_xpr->get_lhs().get())) // we are dealing with an atom constraint..
237+
{
238+
auto rhs_atm = static_cast<riddle::atom_term *>(eq_xpr->get_rhs().get());
239+
std::queue<riddle::predicate *> q;
240+
q.push(static_cast<riddle::predicate *>(&rhs_atm->get_type()));
241+
while (!q.empty())
242+
{
243+
for (const auto &[f_name, f] : q.front()->get_fields())
244+
if (!execute(std::make_shared<riddle::eq_term>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), lhs_atm->get(f_name), rhs_atm->get(f_name))))
245+
return false;
246+
for (const auto &pp : q.front()->get_parents())
247+
q.push(&pp.get());
248+
q.pop();
249+
}
250+
return true;
251+
}
252+
else
253+
return false;
254+
}
255+
else if (auto ge_xpr = dynamic_cast<riddle::ge_term *>(expr.get()))
256+
return lin_slv.new_gt(static_cast<riddle::arith_item *>(ge_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(ge_xpr->get_rhs().get())->get_lin(), false, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
257+
else if (auto gt_xpr = dynamic_cast<riddle::gt_term *>(expr.get()))
258+
return lin_slv.new_gt(static_cast<riddle::arith_item *>(gt_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(gt_xpr->get_rhs().get())->get_lin(), true, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt);
259+
else
260+
return false; // unsupported expression, just return false..
261+
}
262+
}
263+
75264
enum_flaw::enum_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, riddle::enum_expr var) noexcept : flaw(slv, std::move(causes)), var(std::move(var)) {}
76265

77266
void enum_flaw::compute_resolvers() {}
78267

79-
clause_flaw::clause_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<riddle::bool_expr> &&clause, const bool &exclusive) noexcept : flaw(slv, std::move(causes), exclusive), clause(std::move(clause)) {}
268+
clause_flaw::clause_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<riddle::bool_expr> &&clause) noexcept : flaw(slv, std::move(causes)), clause(std::move(clause)) {}
80269

81270
void clause_flaw::compute_resolvers() {}
82271

83272
disjunction_flaw::disjunction_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<std::unique_ptr<riddle::conjunction>> &&disjuncts) noexcept : flaw(slv, std::move(causes)), disjuncts(std::move(disjuncts)) {}
84273

85274
void disjunction_flaw::compute_resolvers() {}
275+
276+
atom_flaw::atom_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, bool is_fact, riddle::predicate &pred, std::map<std::string, riddle::expr, std::less<>> &&args, utils::lit &&sigma) noexcept : flaw(slv, std::move(causes)), atm(std::make_shared<riddle::atom>(pred, is_fact, std::move(args), std::move(sigma))) {}
277+
278+
void atom_flaw::compute_resolvers() {}
86279
} // namespace ratio

src/solver_core.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,40 @@ namespace ratio
116116
else
117117
throw std::runtime_error("Invalid type");
118118
}
119+
120+
bool solver_core::match(riddle::term &lhs, riddle::term &rhs) const
121+
{
122+
if (&lhs == &rhs) // the terms are the same, so they match..
123+
return true;
124+
else if (&lhs.get_type() != &rhs.get_type()) // the types are different, so the terms cannot match..
125+
return false;
126+
else if (auto lhs_xpr = dynamic_cast<riddle::arith_item *>(&lhs)) // we are dealing with arithmetic terms..
127+
return lin_slv.match(lhs_xpr->get_lin(), static_cast<riddle::arith_item &>(rhs).get_lin());
128+
else if (auto lhs_bxpr = dynamic_cast<riddle::bool_item *>(&lhs)) // we are dealing with boolean terms..
129+
return ac_slv.match(lhs_bxpr->get_lit(), static_cast<riddle::bool_item &>(rhs).get_lit());
130+
else if (auto lhs_enum_xpr = dynamic_cast<riddle::enum_item *>(&lhs)) // we are dealing with enum terms..
131+
{
132+
if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(&rhs))
133+
return ac_slv.match(lhs_enum_xpr->get_var(), rhs_enum_xpr->get_var());
134+
else
135+
return ac_slv.allows(lhs_enum_xpr->get_var(), rhs);
136+
}
137+
else if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(&rhs)) // we are dealing with enum terms..
138+
return ac_slv.allows(rhs_enum_xpr->get_var(), lhs);
139+
else
140+
throw std::runtime_error("Matching not supported for this term type");
141+
}
142+
143+
riddle::atom_state solver_core::get_atom_state(const riddle::atom_term &atm) const noexcept
144+
{
145+
switch (ac_slv.sat_val(static_cast<const riddle::atom &>(atm).get_sigma()))
146+
{
147+
case utils::True:
148+
return riddle::active;
149+
case utils::False:
150+
return riddle::unified;
151+
default:
152+
return riddle::inactive;
153+
}
154+
}
119155
} // namespace ratio

0 commit comments

Comments
 (0)