|
1 | 1 | #include "basic_solver.hpp" |
2 | 2 | #include "items.hpp" |
| 3 | +#include "conjunction.hpp" |
3 | 4 | #include "logging.hpp" |
4 | 5 | #include <cassert> |
5 | 6 |
|
@@ -72,15 +73,207 @@ namespace ratio |
72 | 73 |
|
73 | 74 | void basic_solver::solve() {} |
74 | 75 |
|
| 76 | + riddle::atom_expr basic_solver::create_atom(bool is_fact, riddle::predicate &pred, std::map<std::string, riddle::expr, std::less<>> &&args) |
| 77 | + { |
| 78 | + std::vector<std::reference_wrapper<resolver>> causes; |
| 79 | + if (c_res) |
| 80 | + causes.push_back(c_res.value()); |
| 81 | + |
| 82 | + auto &af = new_flaw<atom_flaw>(*this, std::move(causes), is_fact, pred, std::move(args), ac_slv.new_sat()); |
| 83 | + return af.get_atom(); |
| 84 | + } |
| 85 | + |
| 86 | + bool basic_solver::execute(const riddle::bool_expr &expr) noexcept |
| 87 | + { |
| 88 | + if (auto n_xpr = dynamic_cast<riddle::bool_not *>(expr.get())) |
| 89 | + { |
| 90 | + if (auto b_xpr = dynamic_cast<riddle::bool_item *>(n_xpr->get_arg().get())) |
| 91 | + { |
| 92 | + auto &a_cnstr = ac_slv.new_assign(utils::variable(b_xpr->get_lit()), utils::sign(b_xpr->get_lit()) ? arc_consistency::solver::False : arc_consistency::solver::True); |
| 93 | + ac_slv.add_constraint(a_cnstr); |
| 94 | + if (c_res) |
| 95 | + c_res->get().ac_cnsts.push_back(a_cnstr); |
| 96 | + return true; |
| 97 | + } |
| 98 | + else if (auto lt_xpr = dynamic_cast<riddle::lt_term *>(n_xpr->get_arg().get())) |
| 99 | + return lin_slv.new_gt(static_cast<riddle::arith_item *>(lt_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(lt_xpr->get_rhs().get())->get_lin(), false, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 100 | + else if (auto le_xpr = dynamic_cast<riddle::le_term *>(n_xpr->get_arg().get())) |
| 101 | + return lin_slv.new_gt(static_cast<riddle::arith_item *>(le_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(le_xpr->get_rhs().get())->get_lin(), true, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 102 | + else if (auto eq_xpr = dynamic_cast<riddle::eq_term *>(n_xpr->get_arg().get())) |
| 103 | + { |
| 104 | + if (&*eq_xpr->get_lhs() == &*eq_xpr->get_rhs()) // the terms are the same, so they are equal.. |
| 105 | + return false; |
| 106 | + else if (&eq_xpr->get_lhs()->get_type() != &eq_xpr->get_lhs()->get_type()) // the types are different, so the constraint is always false.. |
| 107 | + return true; |
| 108 | + else if (auto lhs_xpr = std::dynamic_pointer_cast<riddle::arith_item>(eq_xpr->get_lhs())) // we are dealing with an arithmetic constraint.. |
| 109 | + { |
| 110 | + new_clause({std::make_shared<riddle::lt_term>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), lhs_xpr, std::static_pointer_cast<riddle::arith_item>(eq_xpr->get_rhs())), std::make_shared<riddle::gt_term>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), lhs_xpr, std::static_pointer_cast<riddle::arith_item>(eq_xpr->get_rhs()))}); |
| 111 | + return true; |
| 112 | + } |
| 113 | + else if (auto lhs_sxpr = dynamic_cast<riddle::string_item *>(eq_xpr->get_lhs().get())) // we are dealing with a string constraint.. |
| 114 | + return lhs_sxpr->get_string() != static_cast<riddle::string_item &>(*eq_xpr->get_rhs()).get_string(); |
| 115 | + else if (auto lhs_bxpr = dynamic_cast<riddle::bool_item *>(eq_xpr->get_lhs().get())) // we are dealing with a boolean constraint.. |
| 116 | + { |
| 117 | + auto &neq_cnstr = ac_slv.new_distinct(utils::variable(lhs_bxpr->get_lit()), utils::variable(static_cast<riddle::bool_item &>(*eq_xpr->get_rhs()).get_lit())); |
| 118 | + ac_slv.add_constraint(neq_cnstr); |
| 119 | + if (c_res) // if there is a current resolver, add the expression to it.. |
| 120 | + c_res->get().ac_cnsts.push_back(neq_cnstr); |
| 121 | + return true; |
| 122 | + } |
| 123 | + else if (auto lhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_lhs().get())) // we are dealing with an enum constraint.. |
| 124 | + { |
| 125 | + if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_rhs().get())) |
| 126 | + { // both sides are enum items.. |
| 127 | + auto &neq_cnstr = ac_slv.new_distinct(lhs_enum_xpr->get_var(), rhs_enum_xpr->get_var()); |
| 128 | + ac_slv.add_constraint(neq_cnstr); |
| 129 | + if (c_res) // if there is a current resolver, add the expression to it.. |
| 130 | + c_res->get().ac_cnsts.push_back(neq_cnstr); |
| 131 | + return true; |
| 132 | + } |
| 133 | + else |
| 134 | + { |
| 135 | + auto &neq_cnstr = ac_slv.new_forbid(lhs_enum_xpr->get_var(), *eq_xpr->get_rhs()); |
| 136 | + ac_slv.add_constraint(neq_cnstr); |
| 137 | + if (c_res) // if there is a current resolver, add the expression to it.. |
| 138 | + c_res->get().ac_cnsts.push_back(neq_cnstr); |
| 139 | + return true; |
| 140 | + } |
| 141 | + } |
| 142 | + else if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_rhs().get())) |
| 143 | + { |
| 144 | + auto &neq_cnstr = ac_slv.new_forbid(rhs_enum_xpr->get_var(), *eq_xpr->get_lhs()); |
| 145 | + ac_slv.add_constraint(neq_cnstr); |
| 146 | + if (c_res) // if there is a current resolver, add the expression to it.. |
| 147 | + c_res->get().ac_cnsts.push_back(neq_cnstr); |
| 148 | + return true; |
| 149 | + } |
| 150 | + else if (auto lhs_atm = dynamic_cast<riddle::atom_term *>(eq_xpr->get_lhs().get())) |
| 151 | + { |
| 152 | + auto rhs_atm = static_cast<riddle::atom_term *>(eq_xpr->get_rhs().get()); |
| 153 | + std::vector<riddle::bool_expr> clause_exprs; |
| 154 | + std::queue<riddle::predicate *> q; |
| 155 | + q.push(static_cast<riddle::predicate *>(&lhs_xpr->get_type())); |
| 156 | + while (!q.empty()) |
| 157 | + { |
| 158 | + for (const auto &[f_name, f] : q.front()->get_fields()) |
| 159 | + clause_exprs.push_back(std::make_shared<riddle::bool_not>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), std::make_shared<riddle::eq_term>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), lhs_atm->get(f_name), rhs_atm->get(f_name)))); |
| 160 | + for (const auto &pp : q.front()->get_parents()) |
| 161 | + q.push(&pp.get()); |
| 162 | + q.pop(); |
| 163 | + } |
| 164 | + new_clause(std::move(clause_exprs)); |
| 165 | + return true; |
| 166 | + } |
| 167 | + else |
| 168 | + return true; |
| 169 | + } |
| 170 | + else if (auto ge_xpr = dynamic_cast<riddle::ge_term *>(n_xpr->get_arg().get())) |
| 171 | + return lin_slv.new_lt(static_cast<riddle::arith_item *>(ge_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(ge_xpr->get_rhs().get())->get_lin(), false, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 172 | + else if (auto gt_xpr = dynamic_cast<riddle::gt_term *>(n_xpr->get_arg().get())) |
| 173 | + return lin_slv.new_lt(static_cast<riddle::arith_item *>(gt_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(gt_xpr->get_rhs().get())->get_lin(), true, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 174 | + else |
| 175 | + return false; // unknown expression inside negation.. |
| 176 | + } |
| 177 | + else |
| 178 | + { |
| 179 | + if (auto b_xpr = dynamic_cast<riddle::bool_item *>(expr.get())) |
| 180 | + { |
| 181 | + auto &a_cnstr = ac_slv.new_assign(utils::variable(b_xpr->get_lit()), utils::sign(b_xpr->get_lit()) ? arc_consistency::solver::True : arc_consistency::solver::False); |
| 182 | + ac_slv.add_constraint(a_cnstr); |
| 183 | + if (c_res) |
| 184 | + c_res->get().ac_cnsts.push_back(a_cnstr); |
| 185 | + return true; |
| 186 | + } |
| 187 | + else if (auto lt_xpr = dynamic_cast<riddle::lt_term *>(expr.get())) |
| 188 | + return lin_slv.new_lt(static_cast<riddle::arith_item *>(lt_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(lt_xpr->get_rhs().get())->get_lin(), true, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 189 | + else if (auto le_xpr = dynamic_cast<riddle::le_term *>(expr.get())) |
| 190 | + return lin_slv.new_lt(static_cast<riddle::arith_item *>(le_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(le_xpr->get_rhs().get())->get_lin(), false, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 191 | + else if (auto eq_xpr = dynamic_cast<riddle::eq_term *>(expr.get())) |
| 192 | + { |
| 193 | + if (&*eq_xpr->get_lhs() == &*eq_xpr->get_rhs()) // the terms are the same, so they are equal.. |
| 194 | + return true; |
| 195 | + else if (&eq_xpr->get_lhs()->get_type() != &eq_xpr->get_lhs()->get_type()) // the types are different, so the constraint is always false.. |
| 196 | + return false; |
| 197 | + else if (auto lhs_xpr = dynamic_cast<riddle::arith_item *>(eq_xpr->get_lhs().get())) // we are dealing with an arithmetic constraint.. |
| 198 | + return lin_slv.new_eq(lhs_xpr->get_lin(), static_cast<riddle::arith_item *>(eq_xpr->get_rhs().get())->get_lin(), c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 199 | + else if (auto lhs_sxpr = dynamic_cast<riddle::string_item *>(eq_xpr->get_lhs().get())) // we are dealing with a string constraint.. |
| 200 | + return lhs_sxpr->get_string() == static_cast<riddle::string_item &>(*eq_xpr->get_rhs()).get_string(); |
| 201 | + else if (auto lhs_bxpr = dynamic_cast<riddle::bool_item *>(eq_xpr->get_lhs().get())) // we are dealing with a boolean constraint.. |
| 202 | + { |
| 203 | + auto &eq_cnstr = ac_slv.new_equal(utils::variable(lhs_bxpr->get_lit()), utils::variable(static_cast<riddle::bool_item &>(*eq_xpr->get_rhs()).get_lit())); |
| 204 | + ac_slv.add_constraint(eq_cnstr); |
| 205 | + if (c_res) // if there is a current resolver, add the expression to it.. |
| 206 | + c_res->get().ac_cnsts.push_back(eq_cnstr); |
| 207 | + return true; |
| 208 | + } |
| 209 | + else if (auto lhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_lhs().get())) // we are dealing with an enum constraint.. |
| 210 | + { |
| 211 | + if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_rhs().get())) |
| 212 | + { // both sides are enum items.. |
| 213 | + auto &eq_cnstr = ac_slv.new_equal(lhs_enum_xpr->get_var(), rhs_enum_xpr->get_var()); |
| 214 | + ac_slv.add_constraint(eq_cnstr); |
| 215 | + if (c_res) // if there is a current resolver, add the expression to it.. |
| 216 | + c_res->get().ac_cnsts.push_back(eq_cnstr); |
| 217 | + return true; |
| 218 | + } |
| 219 | + else |
| 220 | + { |
| 221 | + auto &eq_cnstr = ac_slv.new_assign(lhs_enum_xpr->get_var(), *eq_xpr->get_rhs()); |
| 222 | + ac_slv.add_constraint(eq_cnstr); |
| 223 | + if (c_res) // if there is a current resolver, add the expression to it.. |
| 224 | + c_res->get().ac_cnsts.push_back(eq_cnstr); |
| 225 | + return true; |
| 226 | + } |
| 227 | + } |
| 228 | + else if (auto rhs_enum_xpr = dynamic_cast<riddle::enum_item *>(eq_xpr->get_rhs().get())) |
| 229 | + { |
| 230 | + auto &eq_cnstr = ac_slv.new_assign(rhs_enum_xpr->get_var(), *eq_xpr->get_lhs()); |
| 231 | + ac_slv.add_constraint(eq_cnstr); |
| 232 | + if (c_res) // if there is a current resolver, add the expression to it.. |
| 233 | + c_res->get().ac_cnsts.push_back(eq_cnstr); |
| 234 | + return true; |
| 235 | + } |
| 236 | + else if (auto lhs_atm = dynamic_cast<riddle::atom_term *>(eq_xpr->get_lhs().get())) // we are dealing with an atom constraint.. |
| 237 | + { |
| 238 | + auto rhs_atm = static_cast<riddle::atom_term *>(eq_xpr->get_rhs().get()); |
| 239 | + std::queue<riddle::predicate *> q; |
| 240 | + q.push(static_cast<riddle::predicate *>(&rhs_atm->get_type())); |
| 241 | + while (!q.empty()) |
| 242 | + { |
| 243 | + for (const auto &[f_name, f] : q.front()->get_fields()) |
| 244 | + if (!execute(std::make_shared<riddle::eq_term>(static_cast<riddle::bool_type &>(get_type(riddle::bool_kw)), lhs_atm->get(f_name), rhs_atm->get(f_name)))) |
| 245 | + return false; |
| 246 | + for (const auto &pp : q.front()->get_parents()) |
| 247 | + q.push(&pp.get()); |
| 248 | + q.pop(); |
| 249 | + } |
| 250 | + return true; |
| 251 | + } |
| 252 | + else |
| 253 | + return false; |
| 254 | + } |
| 255 | + else if (auto ge_xpr = dynamic_cast<riddle::ge_term *>(expr.get())) |
| 256 | + return lin_slv.new_gt(static_cast<riddle::arith_item *>(ge_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(ge_xpr->get_rhs().get())->get_lin(), false, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 257 | + else if (auto gt_xpr = dynamic_cast<riddle::gt_term *>(expr.get())) |
| 258 | + return lin_slv.new_gt(static_cast<riddle::arith_item *>(gt_xpr->get_lhs().get())->get_lin(), static_cast<riddle::arith_item *>(gt_xpr->get_rhs().get())->get_lin(), true, c_res ? std::make_optional(std::ref(c_res->get().cnst)) : std::nullopt); |
| 259 | + else |
| 260 | + return false; // unsupported expression, just return false.. |
| 261 | + } |
| 262 | + } |
| 263 | + |
75 | 264 | enum_flaw::enum_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, riddle::enum_expr var) noexcept : flaw(slv, std::move(causes)), var(std::move(var)) {} |
76 | 265 |
|
77 | 266 | void enum_flaw::compute_resolvers() {} |
78 | 267 |
|
79 | | - clause_flaw::clause_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<riddle::bool_expr> &&clause, const bool &exclusive) noexcept : flaw(slv, std::move(causes), exclusive), clause(std::move(clause)) {} |
| 268 | + clause_flaw::clause_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<riddle::bool_expr> &&clause) noexcept : flaw(slv, std::move(causes)), clause(std::move(clause)) {} |
80 | 269 |
|
81 | 270 | void clause_flaw::compute_resolvers() {} |
82 | 271 |
|
83 | 272 | disjunction_flaw::disjunction_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, std::vector<std::unique_ptr<riddle::conjunction>> &&disjuncts) noexcept : flaw(slv, std::move(causes)), disjuncts(std::move(disjuncts)) {} |
84 | 273 |
|
85 | 274 | void disjunction_flaw::compute_resolvers() {} |
| 275 | + |
| 276 | + atom_flaw::atom_flaw(basic_solver &slv, std::vector<std::reference_wrapper<resolver>> &&causes, bool is_fact, riddle::predicate &pred, std::map<std::string, riddle::expr, std::less<>> &&args, utils::lit &&sigma) noexcept : flaw(slv, std::move(causes)), atm(std::make_shared<riddle::atom>(pred, is_fact, std::move(args), std::move(sigma))) {} |
| 277 | + |
| 278 | + void atom_flaw::compute_resolvers() {} |
86 | 279 | } // namespace ratio |
0 commit comments