Skip to content

Commit e3184b1

Browse files
Enhance flaw handling: Add compute_flaw_cost method, update flaw cost notifications, and improve cost estimation logic in solver_core and server classes.
1 parent a0db490 commit e3184b1

6 files changed

Lines changed: 107 additions & 19 deletions

File tree

include/basic_solver.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ namespace ratio
3535
private:
3636
[[nodiscard]] riddle::atom_expr create_atom(bool is_fact, riddle::predicate &pred, std::map<std::string, riddle::expr, std::less<>> &&args) override;
3737

38+
void compute_flaw_cost(flaw &f) noexcept;
39+
3840
struct Node
3941
{
4042
std::size_t id = 0; // The unique identifier of the node..

include/server/basic_solver_server.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@ namespace ratio
1919
void on_ws_error(network::ws_server_session_base &ws, const std::error_code &);
2020

2121
void state_changed() noexcept override;
22+
2223
void flaw_created(const flaw &f) noexcept override;
24+
void flaw_cost_changed(const flaw &f) noexcept override;
25+
2326
void resolver_created(const ratio::resolver &r) noexcept override;
2427

2528
void current_flaw(std::optional<std::reference_wrapper<ratio::flaw>> f) noexcept override;
2629
void current_resolver(std::optional<std::reference_wrapper<ratio::resolver>> r) noexcept override;
2730

2831
void causal_link_added(const flaw &f, const resolver &r) override;
29-
32+
3033
private:
3134
std::unordered_set<network::ws_server_session_base *> clients;
3235
};

include/solver_core.hpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace ratio
2727
[[nodiscard]] const std::vector<std::reference_wrapper<resolver>> &get_resolvers() const noexcept { return resolvers; }
2828
[[nodiscard]] const std::vector<std::reference_wrapper<resolver>> &get_supports() const noexcept { return supports; }
2929

30+
[[nodiscard]] const utils::rational &get_estimated_cost() const noexcept { return est_cost; }
31+
3032
[[nodiscard]] virtual json::json to_json() const;
3133

3234
protected:
@@ -38,11 +40,12 @@ namespace ratio
3840
virtual void compute_resolvers() = 0;
3941

4042
private:
41-
solver_core &slv; // The solver managing this flaw..
42-
bool expanded = false; // Whether the flaw has been expanded..
43-
std::vector<std::reference_wrapper<resolver>> causes; // The causes of this flaw..
44-
std::vector<std::reference_wrapper<resolver>> resolvers; // The resolvers for this flaw..
45-
std::vector<std::reference_wrapper<resolver>> supports; // The resolvers supported by this flaw..
43+
solver_core &slv; // The solver managing this flaw..
44+
bool expanded = false; // Whether the flaw has been expanded..
45+
std::vector<std::reference_wrapper<resolver>> causes; // The causes of this flaw..
46+
std::vector<std::reference_wrapper<resolver>> resolvers; // The resolvers for this flaw..
47+
std::vector<std::reference_wrapper<resolver>> supports; // The resolvers supported by this flaw..
48+
utils::rational est_cost = utils::rational::positive_infinite; // The estimated cost to resolve this flaw..
4649
};
4750

4851
class resolver
@@ -61,6 +64,8 @@ namespace ratio
6164

6265
[[nodiscard]] const utils::rational &get_intrinsic_cost() const noexcept { return intrinsic_cost; }
6366

67+
[[nodiscard]] utils::rational get_estimated_cost() const noexcept;
68+
6469
[[nodiscard]] const std::vector<std::reference_wrapper<flaw>> &get_preconditions() const noexcept { return preconditions; }
6570

6671
[[nodiscard]] virtual json::json to_json() const;
@@ -187,6 +192,8 @@ namespace ratio
187192

188193
void retract_resolver(resolver &res) noexcept;
189194

195+
void set_flaw_cost(flaw &f, const utils::rational &cost) noexcept;
196+
190197
private:
191198
[[nodiscard]] riddle::atom_state get_atom_state(const riddle::atom_term &atom) const noexcept override;
192199

@@ -210,6 +217,15 @@ namespace ratio
210217
*/
211218
virtual void flaw_created([[maybe_unused]] const flaw &f) noexcept {}
212219

220+
/**
221+
* @brief Notifies when the cost of a flaw has changed.
222+
*
223+
* This function is called when the cost of a flaw has changed. It is a virtual function that can be overridden by derived classes to perform specific actions when a flaw's cost changes.
224+
*
225+
* @param f The flaw whose cost has changed.
226+
*/
227+
virtual void flaw_cost_changed([[maybe_unused]] const flaw &f) {}
228+
213229
/**
214230
* @brief Notifies that a new resolver has been created.
215231
*
@@ -224,28 +240,28 @@ namespace ratio
224240
*
225241
* This function is called to inform about the current flaw being processed in the solver.
226242
*
227-
* @param The current flaw being processed.
243+
* @param f The current flaw being processed.
228244
*/
229-
virtual void current_flaw([[maybe_unused]] std::optional<std::reference_wrapper<ratio::flaw>>) noexcept {}
245+
virtual void current_flaw([[maybe_unused]] std::optional<std::reference_wrapper<ratio::flaw>> f) noexcept {}
230246

231247
/**
232248
* @brief Notifies about the current resolver being applied.
233249
*
234250
* This function is called to inform about the current resolver being applied in the solver.
235251
*
236-
* @param The current resolver being applied.
252+
* @param r The current resolver being applied.
237253
*/
238-
virtual void current_resolver([[maybe_unused]] std::optional<std::reference_wrapper<ratio::resolver>>) noexcept {}
254+
virtual void current_resolver([[maybe_unused]] std::optional<std::reference_wrapper<ratio::resolver>> r) noexcept {}
239255

240256
/**
241257
* @brief Notifies when a causal link has been added.
242258
*
243259
* This function is called when a causal link has been added. It is a virtual function that can be overridden by derived classes to perform specific actions when a causal link is added.
244260
*
245-
* @param flaw The flaw that is the source of the causal link.
246-
* @param resolver The resolver that is the destination of the causal link.
261+
* @param f The flaw that is the source of the causal link.
262+
* @param r The resolver that is the destination of the causal link.
247263
*/
248-
virtual void causal_link_added([[maybe_unused]] const flaw &, [[maybe_unused]] const resolver &) {}
264+
virtual void causal_link_added([[maybe_unused]] const flaw &f, [[maybe_unused]] const resolver &r) {}
249265
#endif
250266

251267
protected:

src/basic_solver.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "basic_solver.hpp"
22
#include "conjunction.hpp"
33
#include "logging.hpp"
4+
#include <stack>
45
#include <cassert>
56

67
namespace ratio
@@ -93,7 +94,10 @@ namespace ratio
9394
LOG_DEBUG(flw.to_json().dump());
9495
// Compute the resolvers for the selected flaw..
9596
if (!flw.is_expanded())
97+
{
9698
compute_resolvers(flw);
99+
compute_flaw_cost(flw);
100+
}
97101
if (flw.get_resolvers().size() == 1)
98102
{ // If there is only one resolver and applying it does not lead to a conflict, continue from the current node..
99103
auto &res = flw.get_resolvers().front().get();
@@ -130,6 +134,33 @@ namespace ratio
130134
return af.get_atom();
131135
}
132136

137+
void basic_solver::compute_flaw_cost(flaw &f) noexcept
138+
{
139+
std::stack<std::pair<flaw *, std::unordered_set<flaw *>>> stk;
140+
stk.push({&f, {}}); // we push the flaw in the stack..
141+
142+
while (!stk.empty())
143+
{
144+
auto c_f = stk.top();
145+
stk.pop();
146+
147+
utils::rational c_cost = utils::rational::positive_infinite;
148+
for (const auto &res : c_f.first->get_resolvers())
149+
c_cost = std::min(c_cost, res.get().get_estimated_cost());
150+
151+
if (c_f.first->get_estimated_cost() != c_cost) // we update the cost of the flaw..
152+
{
153+
set_flaw_cost(*c_f.first, c_cost);
154+
// we propagate the cost to the causes..
155+
for (auto &cause : c_f.first->get_causes())
156+
stk.push({&cause.get().get_flaw(), c_f.second}); // we push the cause flaw in the stack..
157+
// we propagate the cost to the supported resolvers..
158+
for (auto &support : c_f.first->get_supports())
159+
stk.push({&support.get().get_flaw(), c_f.second}); // we push the supported flaw in the stack..
160+
}
161+
}
162+
}
163+
133164
std::shared_ptr<basic_solver::Node> basic_solver::find_common_ancestor(std::shared_ptr<Node> a, std::shared_ptr<Node> b) const
134165
{
135166
std::unordered_set<std::shared_ptr<Node>> ancestors;

src/server/basic_solver_server.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,15 @@ namespace ratio
5353
{
5454
auto j_msg = f.to_json();
5555
j_msg["id"] = static_cast<uint64_t>(f.get_id());
56-
j_msg["cost"] = linspire::to_json(utils::rational::zero);
5756
j_msg["state"] = "active";
5857
j_msg["msg_type"] = "flaw_created";
5958
auto msg = j_msg.dump();
6059
for (auto client : clients)
6160
client->send(msg);
6261
}
63-
void server::current_flaw(std::optional<std::reference_wrapper<ratio::flaw>> f) noexcept
62+
void server::flaw_cost_changed(const ratio::flaw &f) noexcept
6463
{
65-
auto j_msg = json::json{{"msg_type", "current_flaw"}};
66-
if (f)
67-
j_msg["id"] = f.value().get().get_id();
64+
auto j_msg = json::json{{"msg_type", "flaw_cost_changed"}, {"id", f.get_id()}, {"cost", linspire::to_json(f.get_estimated_cost())}};
6865
auto msg = j_msg.dump();
6966
for (auto client : clients)
7067
client->send(msg);
@@ -80,6 +77,16 @@ namespace ratio
8077
for (auto client : clients)
8178
client->send(msg);
8279
}
80+
81+
void server::current_flaw(std::optional<std::reference_wrapper<ratio::flaw>> f) noexcept
82+
{
83+
auto j_msg = json::json{{"msg_type", "current_flaw"}};
84+
if (f)
85+
j_msg["id"] = f.value().get().get_id();
86+
auto msg = j_msg.dump();
87+
for (auto client : clients)
88+
client->send(msg);
89+
}
8390
void server::current_resolver(std::optional<std::reference_wrapper<ratio::resolver>> r) noexcept
8491
{
8592
auto j_msg = json::json{{"msg_type", "current_resolver"}};

src/solver_core.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55

66
#ifdef ORATIO_ENABLE_LISTENERS
77
#define CURRENT_FLAW(f) current_flaw(f)
8+
#define FLAW_COST_CHANGED(f) flaw_cost_changed(f)
89
#define CURRENT_RESOLVER(r) current_resolver(r)
910
#define NEW_CAUSAL_LINK(f, r) causal_link_added(f, r)
1011
#else
1112
#define CURRENT_FLAW(f)
13+
#define FLAW_COST_CHANGED(f)
1214
#define CURRENT_RESOLVER(r)
1315
#define NEW_CAUSAL_LINK(f, r)
1416
#endif
@@ -25,7 +27,7 @@ namespace ratio
2527

2628
json::json flaw::to_json() const
2729
{
28-
json::json j_flaw;
30+
json::json j_flaw{{"cost", linspire::to_json(est_cost)}};
2931
if (!causes.empty())
3032
{
3133
json::json j_causes(json::json_type::array);
@@ -43,6 +45,24 @@ namespace ratio
4345
throw std::runtime_error("Failed to execute expression in resolver");
4446
}
4547

48+
utils::rational resolver::resolver::get_estimated_cost() const noexcept
49+
{
50+
if (preconditions.empty())
51+
return intrinsic_cost;
52+
#ifdef H_ADD
53+
// we compute the cost of the resolver as the sum of its intrinsic cost and the estimated costs of its preconditions..
54+
return std::accumulate(preconditions.begin(), preconditions.end(), intrinsic_cost, [](const auto &lhs, const auto &prec)
55+
{ return lhs + prec.get().get_estimated_cost(); });
56+
#endif
57+
#ifdef H_MAX
58+
// we compute the cost of the resolver as the sum of its intrinsic cost and the maximum of its preconditions' estimated costs..
59+
return intrinsic_cost + (*std::max_element(preconditions.begin(), preconditions.end(), [](const auto &lhs, const auto &rhs)
60+
{ return lhs.get().get_estimated_cost() < rhs.get().get_estimated_cost(); }))
61+
.get()
62+
.get_estimated_cost();
63+
#endif
64+
}
65+
4666
json::json resolver::to_json() const
4767
{
4868
json::json j_resolver{{"flaw", flw.get_id()}, {"intrinsic_cost", linspire::to_json(intrinsic_cost)}};
@@ -445,6 +465,15 @@ namespace ratio
445465
ac_slv.retract(ac_cnst);
446466
}
447467

468+
void solver_core::set_flaw_cost(flaw &flw, const utils::rational &cost) noexcept
469+
{
470+
if (flw.est_cost != cost)
471+
{
472+
flw.est_cost = cost;
473+
FLAW_COST_CHANGED(flw);
474+
}
475+
}
476+
448477
riddle::atom_state solver_core::get_atom_state(const riddle::atom_term &atm) const noexcept
449478
{
450479
switch (ac_slv.sat_val(static_cast<const riddle::atom &>(atm).get_sigma()))

0 commit comments

Comments
 (0)