Skip to content

Commit 5079d84

Browse files
committed
allow specifying how to handle outputs
1 parent 677011e commit 5079d84

9 files changed

Lines changed: 124 additions & 12 deletions

File tree

include/scl/simulation/context.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#pragma once
1919

2020
#include <cstddef>
21+
#include <functional>
2122
#include <type_traits>
2223

2324
#include "scl/simulation/event.h"
@@ -40,7 +41,8 @@ class SimulatorContext final {
4041
static SimulatorContext create(
4142
std::size_t number_of_parties,
4243
NetworkParams network_params,
43-
std::vector<Simulator::SimulationHook>&& hooks);
44+
std::vector<Simulator::SimulationHook>&& hooks,
45+
std::optional<Simulator::OutputHandler>&& output_handler);
4446

4547
/**
4648
* @brief Get the context of a particular party in the simulation.
@@ -62,6 +64,15 @@ class SimulatorContext final {
6264
}
6365
}
6466

67+
/**
68+
* @brief Handle a protocol output.
69+
*/
70+
void handleOutput(std::size_t pid, std::any output) {
71+
if (m_output_handler.has_value()) {
72+
m_output_handler.value()(pid, output);
73+
}
74+
}
75+
6576
/**
6677
* @brief Start the clock of party.
6778
*/
@@ -97,6 +108,9 @@ class SimulatorContext final {
97108
return m_number_of_parties;
98109
}
99110

111+
/**
112+
* @brief Extract result of a simulation.
113+
*/
100114
Simulator::Result toResult() {
101115
return Simulator::Result{std::move(m_events)};
102116
}
@@ -107,6 +121,8 @@ class SimulatorContext final {
107121
std::vector<EventList> m_events;
108122
std::vector<Time::TimePoint> m_clocks;
109123
std::vector<Simulator::SimulationHook> m_hooks;
124+
std::optional<Simulator::OutputHandler> m_output_handler;
125+
std::vector<bool> m_cancellation_map;
110126

111127
SimulatorContext(NetworkParams network_params)
112128
: m_network_params(network_params) {}
@@ -148,6 +164,13 @@ class Context {
148164
m_ctx.startClock(m_id);
149165
}
150166

167+
/**
168+
* @brief Handle a protocol output.
169+
*/
170+
void output(std::any output) {
171+
m_ctx.handleOutput(m_id, output);
172+
}
173+
151174
/**
152175
* @brief Append an event to this party's list of events.
153176
*/

include/scl/simulation/event.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ namespace scl {
3232
/**
3333
* @brief Event types.
3434
* @ingroup eval-sim
35+
*
36+
* EventType denotes the different types of events that can be spawned by
37+
* simulation. All event types, with the exception of TRANSIENT, arrise due to
38+
* protocol actions of one form of another.
3539
*/
3640
enum class EventType {
3741
/**
@@ -97,7 +101,12 @@ enum class EventType {
97101
/**
98102
* @brief Event emitted when a protocol sleeps.
99103
*/
100-
SLEEP
104+
SLEEP,
105+
106+
/**
107+
* @brief Event emitted when a protocol outputs something.
108+
*/
109+
OUTPUT
101110
};
102111

103112
/**
@@ -419,6 +428,19 @@ class SleepEvent final : public Event {
419428
Time::Duration m_duration;
420429
};
421430

431+
/**
432+
* @brief Event issued when a protocol has an output.
433+
* @ingroup eval-sim
434+
*/
435+
class OutputEvent final : public Event {
436+
public:
437+
using Event::Event;
438+
void write(std::ostream& stream) override;
439+
EventType type() const override {
440+
return EventType::OUTPUT;
441+
}
442+
};
443+
422444
/**
423445
* @brief Tracks events added by a party during simulation.
424446
* @ingroup eval-sim

include/scl/simulation/simulator.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#pragma once
1919

20+
#include <any>
2021
#include <concepts>
2122
#include <functional>
2223
#include <memory>
@@ -102,6 +103,11 @@ class Simulator final {
102103
HookType hook;
103104
};
104105

106+
/**
107+
* @brief Type for functions that can be used to process protocol outputs.
108+
*/
109+
using OutputHandler = std::function<void(std::size_t, std::any)>;
110+
105111
/**
106112
* @brief Run the simulation.
107113
*/
@@ -129,8 +135,18 @@ class Simulator final {
129135
m_hooks.emplace_back(SimulationHook{.trigger = {}, .hook = hook});
130136
}
131137

138+
/**
139+
* @brief Instruct the simulator how to handler protocol outputs.
140+
*/
141+
template <typename HANDLER>
142+
requires(std::convertible_to<HANDLER, OutputHandler>)
143+
void addOutputHandler(HANDLER handler) {
144+
m_output_handler = handler;
145+
}
146+
132147
private:
133148
std::vector<SimulationHook> m_hooks;
149+
std::optional<OutputHandler> m_output_handler;
134150

135151
Result run(std::vector<std::unique_ptr<Protocol>>&& protocols,
136152
NetworkParams network_params);

src/scl/simulation/context.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@ using namespace scl;
2525
details::SimulatorContext details::SimulatorContext::create(
2626
std::size_t number_of_parties,
2727
NetworkParams network_params,
28-
std::vector<Simulator::SimulationHook>&& hooks) {
28+
std::vector<Simulator::SimulationHook>&& hooks,
29+
std::optional<Simulator::OutputHandler>&& output_handler) {
2930
SimulatorContext ctx{network_params};
3031

3132
ctx.m_number_of_parties = number_of_parties;
3233
ctx.m_events.resize(number_of_parties);
3334
ctx.m_clocks.resize(number_of_parties);
3435
ctx.m_hooks = std::move(hooks);
36+
ctx.m_output_handler = std::move(output_handler);
37+
ctx.m_cancellation_map.resize(number_of_parties, false);
3538

3639
return ctx;
3740
}

src/scl/simulation/event.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ void SleepEvent::write(std::ostream& stream) {
176176
JSON_OBJ_END;
177177
}
178178

179+
void OutputEvent::write(std::ostream& stream) {
180+
JSON_OBJ_START;
181+
writeString(stream, "type", "OUTPUT");
182+
JSON_COMMA;
183+
writeTimestamp(stream, time());
184+
JSON_OBJ_END;
185+
}
186+
179187
#undef JSON_OBJ_START
180188
#undef JSON_OBJ_END
181189
#undef JSON_COMMA

src/scl/simulation/simulator.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ Task<void> runProtocol(std::unique_ptr<Protocol> protocol,
111111
ctx.addEvent<EndEvent>(et, protocol->name());
112112

113113
if (next.output.has_value()) {
114-
// TODO: handle output
114+
ctx.output(next.output);
115115
}
116116

117117
// This will suspend this party, allowing someone else to run. It's not
@@ -126,6 +126,10 @@ Task<void> runProtocol(std::unique_ptr<Protocol> protocol,
126126

127127
ctx.addEvent<StopEvent>(ctx.lastEvent()->time());
128128

129+
} catch (CancelledEvent&) {
130+
// this party was stopped by a user supplied hook.
131+
ctx.addEvent<CancelledEvent>(ctx.lastEvent()->time());
132+
129133
} catch (std::exception& e) {
130134
// all exceptions are caught and discarded, but we make sure to record an
131135
// event.
@@ -209,9 +213,11 @@ Simulator::Result Simulator::run(
209213
std::vector<std::unique_ptr<Protocol>>&& protocols,
210214
NetworkParams network_params) {
211215
if (!protocols.empty()) {
212-
auto sim_ctx = details::SimulatorContext::create(protocols.size(),
213-
network_params,
214-
std::move(m_hooks));
216+
auto sim_ctx =
217+
details::SimulatorContext::create(protocols.size(),
218+
network_params,
219+
std::move(m_hooks),
220+
std::move(m_output_handler));
215221
auto runtime = std::make_unique<details::SimulatorRuntime>(sim_ctx);
216222

217223
runtime->run(simulate(std::move(protocols), sim_ctx));

test/scl/simulation/test_channel.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using namespace std::chrono_literals;
3131

3232
TEST_CASE("SimulatedChannel send", "[sim]") {
3333
auto nd = NetworkParams::create(2);
34-
auto ctx = details::SimulatorContext::create(2, nd, {});
34+
auto ctx = details::SimulatorContext::create(2, nd, {}, {});
3535

3636
auto transport = std::make_shared<details::Transport>(ctx);
3737
ChannelId id{0, 1};
@@ -70,7 +70,7 @@ TEST_CASE("SimulatedChannel send", "[sim]") {
7070

7171
TEST_CASE("SimulatedChannel recv", "[sim]") {
7272
auto nd = NetworkParams::create(2);
73-
auto ctx = details::SimulatorContext::create(2, nd, {});
73+
auto ctx = details::SimulatorContext::create(2, nd, {}, {});
7474
auto tp = std::make_shared<details::Transport>(ctx);
7575
ChannelId id{0, 1};
7676
auto channel = details::SimulatedChannel::create(id, ctx.getContext(0), tp);
@@ -106,7 +106,7 @@ TEST_CASE("SimulatedChannel recv timeout") {
106106
using namespace std::chrono_literals;
107107

108108
auto nd = NetworkParams::create(2);
109-
auto ctx = details::SimulatorContext::create(2, nd, {});
109+
auto ctx = details::SimulatorContext::create(2, nd, {}, {});
110110
auto tp = std::make_shared<details::Transport>(ctx);
111111
ChannelId id{0, 1};
112112

test/scl/simulation/test_simulator.cc

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
* along with this program. If not, see <https://www.gnu.org/licenses/>.
1616
*/
1717

18+
#include <any>
1819
#include <catch2/catch_test_macros.hpp>
1920
#include <initializer_list>
21+
#include <memory>
2022

2123
#include "scl/coro.h"
2224
#include "scl/protocol.h"
@@ -98,3 +100,35 @@ TEST_CASE("Simulator test", "[sim]") {
98100
EventType::END,
99101
EventType::STOP});
100102
}
103+
104+
namespace {
105+
106+
struct OutputProtocol final : public Protocol {
107+
Task<Result> run(Env& /* ignored */) const override {
108+
co_return Result::done(123);
109+
}
110+
};
111+
112+
} // namespace
113+
114+
TEST_CASE("Simulator handle output", "[sim]") {
115+
auto nd = NetworkParams::create(1);
116+
Simulator sim;
117+
118+
bool called = false;
119+
120+
sim.addOutputHandler([&called](std::size_t pid, std::any output) {
121+
called = (pid == 0) && (output.has_value()) &&
122+
(std::any_cast<int>(output) == 123);
123+
});
124+
125+
auto res = sim.run(
126+
[]() {
127+
std::vector<std::unique_ptr<Protocol>> p;
128+
p.emplace_back(std::make_unique<OutputProtocol>());
129+
return p;
130+
},
131+
nd);
132+
133+
REQUIRE(called);
134+
}

test/scl/simulation/test_transport.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ using namespace std::chrono_literals;
2929

3030
TEST_CASE("Transport send", "[sim]") {
3131
auto nd = NetworkParams::create(2);
32-
auto ctx = details::SimulatorContext::create(2, nd, {});
32+
auto ctx = details::SimulatorContext::create(2, nd, {}, {});
3333
ChannelId id{0, 1};
3434

3535
details::Transport transport(ctx);
@@ -53,7 +53,7 @@ TEST_CASE("Transport send", "[sim]") {
5353

5454
TEST_CASE("Transport ready w. limit", "[sim]") {
5555
auto nd = NetworkParams::create(2);
56-
auto ctx = details::SimulatorContext::create(2, nd, {});
56+
auto ctx = details::SimulatorContext::create(2, nd, {}, {});
5757
ChannelId id{0, 1};
5858
auto sid = id.flip();
5959

0 commit comments

Comments
 (0)