
Commit b1a32af

awoll-bdai authored and exploy-bot committed
Allow to run inference asynchronously
### What change is being made

Add `SyncWorker` and `AsyncWorker`, which make it possible to run policy inference on a separate thread. The strategy is selected via `WorkerMode` at `init()` time.

#### `SyncWorker`

All three callbacks execute on the calling thread. The cycle blocks until inference completes.

```
main  ╔══════╦═══════════════╦═══════╗          ╔══════╦═══════════════╦═══════╗
      ║ read ║     work      ║ write ║  (idle)  ║ read ║     work      ║ write ║ ···
      ╚══════╩═══════════════╩═══════╝          ╚══════╩═══════════════╩═══════╝
      ◄────────── cycle N ─────────►            ◄───────── cycle N+1 ─────────►
```

#### `AsyncWorker`

`read` and `write` run on the calling thread; `work` is offloaded to a dedicated background thread. The result of cycle N is written at the start of cycle N+1.

```
main    ╔══════╗                       ╔═══════╦══════╗                   ╔═══════╗
        ║ read ║ · · · · · · · · · · · ║ write ║ read ║ · · · · · · · · · ║ write ║ ···
        ╚══╤═══╝                       ╚══╤════╩══╤═══╝                   ╚═══╤═══╝
           │                              │       │                           │
worker     ╚════════ work ════════════════╝       ╚═══════ work ══════════════╝
         ◄──────────── cycle N ───────►◄───────────── cycle N+1 ─────────►
```

#### `AsyncWorker` — overrun

If inference is still running when the next cycle boundary arrives, the cycle is skipped and `update()` returns `true`. The write and the following read resume on the next call once the work has finished.

```
main    ╔══════╗               ╎     (skip)      ╎ ╔═══════╦══════╗
        ║ read ║ · · · · · · · ╎                 ╎ ║ write ║ read ║ ···
        ╚══╤═══╝               ╎                 ╎ ╚═══════╩══╤═══╝
           │                boundary           │              │
worker     ╚══════════════════════ work ═══════╝              ╚═══ work ═══
```

Note: with `AsyncWorker`, `Write()` happens in the second `update()` call at the earliest. Short-running policies might therefore see a longer `Read()->Write()` period than with `SyncWorker`.

### Why this change is being made

Enable running policies without blocking the main thread.

### Tested

Added extensive tests. Ran in simulation and on hardware.

Tracy captures:

Navigation running with the sync worker blocks the main thread for a long time (red "work" block).

<img width="1072" height="216" alt="Screenshot from 2026-05-19 14-13-01" src="https://github.com/user-attachments/assets/56eae2c3-ec21-4d48-8109-f9810214144d" />

Navigation running with the async worker does not block the main thread (purple "work" block).

<img width="1071" height="300" alt="Screenshot from 2026-05-19 14-18-58" src="https://github.com/user-attachments/assets/1fba8603-ddbf-4f32-a383-b09b7a00e590" />

GitOrigin-RevId: e49e2e989fc1c1756e2c5998173f07a7440e82b6
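The handoff described for `AsyncWorker` can be illustrated with a minimal sketch. This is not the code from `worker.cpp` (which is not shown on this page): `MiniAsyncWorker` and `waitIdle()` are hypothetical names, and period scheduling, fault latching and `reset()` are left out. It only shows the result of one cycle being written at the start of the next, and a cycle being skipped while work is still running.

```cpp
#include <cassert>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>

// Hypothetical, simplified stand-in for AsyncWorker (assumption, not the
// commit's implementation): read and write run on the caller's thread, work
// runs on a background thread, and a busy worker causes the cycle to be skipped.
class MiniAsyncWorker {
 public:
  MiniAsyncWorker(std::function<void()> read_fn, std::function<void()> work_fn,
                  std::function<void()> write_fn)
      : read_fn_(std::move(read_fn)),
        work_fn_(std::move(work_fn)),
        write_fn_(std::move(write_fn)) {
    assert(read_fn_ && work_fn_ && write_fn_);
    thread_ = std::thread([this] { threadLoop(); });
  }

  ~MiniAsyncWorker() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      stop_ = true;
    }
    cv_.notify_all();
    thread_.join();
  }

  // Returns true if a new cycle was started, false if it was skipped because
  // the previous work is still running (overrun).
  bool update() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (working_) return false;
    }
    if (has_result_) write_fn_();  // Commit the result of the previous cycle.
    read_fn_();
    {
      std::lock_guard<std::mutex> lock(mutex_);
      working_ = true;  // Hand the freshly read inputs to the worker thread.
    }
    cv_.notify_all();
    has_result_ = true;
    return true;
  }

  // Helper for demonstrations: block until the background thread is idle.
  void waitIdle() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return !working_; });
  }

 private:
  void threadLoop() {
    std::unique_lock<std::mutex> lock(mutex_);
    while (true) {
      cv_.wait(lock, [this] { return stop_ || working_; });
      if (stop_) return;
      lock.unlock();
      work_fn_();  // Runs without holding the mutex.
      lock.lock();
      working_ = false;
      cv_.notify_all();
    }
  }

  std::function<void()> read_fn_, work_fn_, write_fn_;
  bool has_result_ = false;  // Main-thread-only.
  std::mutex mutex_;
  std::condition_variable cv_;
  bool stop_ = false;     // Guarded by mutex_.
  bool working_ = false;  // Guarded by mutex_.
  std::thread thread_;
};
```

In this sketch the first `update()` only reads and dispatches; the second one writes the first result before reading again, which matches the note above that the write happens in the second update call at the earliest.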
1 parent e6dcd16 commit b1a32af

8 files changed

Lines changed: 836 additions & 15 deletions


control/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,7 @@ add_library(exploy SHARED
   src/onnx_runtime.cpp
   src/components.cpp
   src/matcher.cpp
+  src/worker.cpp
 )

 # Create an alias target for symmetry between the build tree and install tree.
@@ -175,6 +176,7 @@ if(BUILD_TESTING)
     test/logging_test.cpp
     test/metadata_test.cpp
     test/onnx_runtime_test.cpp
+    test/worker_test.cpp
     "${GEN_OUTPUT}"
   )
   target_compile_definitions(${PROJECT_NAME}_test PRIVATE

control/include/exploy/controller.hpp

Lines changed: 24 additions & 1 deletion
@@ -7,11 +7,25 @@
 #include "exploy/data_collection_interface.hpp"
 #include "exploy/onnx_runtime.hpp"
 #include "exploy/state_interface.hpp"
+#include "exploy/worker.hpp"

+#include <memory>
 #include <string>

 namespace exploy::control {

+/**
+ * @brief Selects the ONNX inference execution strategy.
+ *
+ * Pass this enum to `OnnxRLController::init()` to choose between:
+ * - `SYNC` — inference blocks the calling thread
+ * - `ASYNC` — inference runs on a background thread
+ */
+enum class WorkerMode {
+  SYNC,   ///< Run inference synchronously on the calling thread.
+  ASYNC,  ///< Run inference on a background thread (see AsyncWorker).
+};
+
 /**
  * @class OnnxRLController
  *
@@ -53,9 +67,10 @@ class OnnxRLController {
    * @brief Initialize the controller.
    *
    * @param enable_data_collection Whether to enable data collection.
+   * @param mode Whether to run the ONNX inference synchronously or asynchronously.
    * @return True if initialization succeeds, false otherwise.
    */
-  bool init(bool enable_data_collection);
+  bool init(bool enable_data_collection, WorkerMode mode = WorkerMode::SYNC);
   /**
    * @brief Reset the controller.
    */
@@ -75,6 +90,9 @@ class OnnxRLController {
   bool initCommands();
   bool initSensors();

+  bool readInputs();
+  bool writeOutputs();
+
   OnnxContext context_{};
   OnnxRuntime onnx_model_{};
   RobotStateInterface& state_;
@@ -84,7 +102,12 @@ class OnnxRLController {

   // Data collection.
   DataCollectionInterface& data_collection_;
+  // Written by the work_fn (possibly on a background thread) before work_finished_
+  // is set under the worker's mutex. Read on the main thread only after observing
+  // work_finished_ under that same mutex — no atomic needed.
   double inference_duration_s_{};
+
+  std::unique_ptr<Worker> worker_{nullptr};
 };

 } // namespace exploy::control
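Because `mode` defaults to `WorkerMode::SYNC`, existing callers of `init(bool)` keep compiling and keep the blocking behaviour. The sketch below illustrates that selection pattern in isolation. `WorkerStub`, `SyncStub`, `AsyncStub`, `name()` and `makeWorker()` are hypothetical stand-ins introduced only for this example; the real selection happens inside `OnnxRLController::init()`.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Mirrors the enum added in controller.hpp.
enum class WorkerMode { SYNC, ASYNC };

// Hypothetical stand-ins for Worker / SyncWorker / AsyncWorker.
struct WorkerStub {
  virtual ~WorkerStub() = default;
  virtual std::string name() const = 0;
};

struct SyncStub : WorkerStub {
  std::string name() const override { return "sync"; }
};

struct AsyncStub : WorkerStub {
  std::string name() const override { return "async"; }
};

// The default argument keeps pre-existing call sites on the synchronous path.
std::unique_ptr<WorkerStub> makeWorker(WorkerMode mode = WorkerMode::SYNC) {
  if (mode == WorkerMode::ASYNC) {
    return std::make_unique<AsyncStub>();
  }
  return std::make_unique<SyncStub>();
}
```

Calling `makeWorker()` with no argument yields the synchronous stub, while `makeWorker(WorkerMode::ASYNC)` opts in to the asynchronous one.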

control/include/exploy/worker.hpp

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+// Copyright (c) 2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.
+#pragma once
+
+#include <condition_variable>
+#include <cstdint>
+#include <functional>
+#include <mutex>
+#include <thread>
+
+namespace exploy::control {
+
+/**
+ * @brief Abstract base class for controller execution strategies.
+ *
+ * A Worker owns the read → work → write pipeline. Concrete subclasses decide
+ * *when* and *how* the three callbacks are invoked:
+ *
+ * - `read_fn` — called on the main thread to snapshot robot state.
+ * - `work_fn` — runs the ONNX inference.
+ * - `write_fn` — called on the main thread to dispatch joint targets.
+ *
+ * Use `setCallbacks()` to register all three before calling `update()`.
+ */
+class Worker {
+ public:
+  virtual ~Worker() = default;
+
+  /** @brief Reset the worker to its initial state. */
+  virtual void reset() {}
+
+  /**
+   * @brief Advance the control pipeline by one tick.
+   *
+   * @param time_us Current timestamp in microseconds.
+   * @return `true` if the update succeeded or was safely skipped, `false` on error.
+   */
+  virtual bool update(uint64_t time_us) = 0;
+
+  /**
+   * @brief Register the read / work / write callbacks.
+   *
+   * Must be called exactly once before the first `update()`. All three
+   * arguments must be non-null; the function returns `false` otherwise.
+   *
+   * @param read_fn Reads observations from the robot state (main thread).
+   * @param work_fn Runs ONNX inference (may execute on a background thread).
+   * @param write_fn Writes joint targets back to the robot (main thread).
+   * @return `true` on success.
+   */
+  bool setCallbacks(std::function<bool()> read_fn, std::function<bool()> work_fn,
+                    std::function<bool()> write_fn);
+
+ protected:
+  std::function<bool()> read_fn_;
+  std::function<bool()> work_fn_;
+  std::function<bool()> write_fn_;
+};
+
+/**
+ * @brief Synchronous worker — runs read → work → write on the calling thread.
+ *
+ * All three callbacks execute inline inside `update()`. The call blocks until
+ * inference is complete, so the caller's thread must afford the full inference
+ * latency every control cycle.
+ *
+ * Phase is maintained across updates: the first call initialises the phase
+ * reference and subsequent calls fire whenever the elapsed time exceeds the
+ * configured period.
+ */
+class SyncWorker : public Worker {
+ public:
+  /**
+   * @param update_rate_hz Desired control frequency in Hz.
+   */
+  explicit SyncWorker(double update_rate_hz);
+
+  bool update(uint64_t time_us) override;
+  void reset() override;
+
+ private:
+  uint64_t period_ms_;
+  uint64_t last_scheduled_update_us_ = 0;
+  bool first_run_ = true;
+};
+
+/**
+ * @brief Asynchronous worker — offloads ONNX inference to a background thread.
+ *
+ * The pipeline is split across two consecutive `update()` calls:
+ *
+ * 1. **First call at cycle boundary** — `read_fn` executes on the main thread,
+ *    then `work_fn` is dispatched to a dedicated background thread.
+ * 2. **Subsequent calls** — if inference has finished, `write_fn` runs on the
+ *    main thread to commit joint targets. If the worker is still busy
+ *    (overrun), the cycle is skipped and `update()` returns `true`.
+ *
+ * This decouples the main control loop from the inference latency, allowing
+ * the robot to keep receiving state updates while the GPU or CPU is busy.
+ *
+ * Thread safety: the internal mutex guards all shared state between the main
+ * thread and the worker thread.
+ *
+ * `reset()` stops the background thread and clears all state. The thread is
+ * re-started on the next `update()` call.
+ *
+ * **Error handling**: when `work_fn` returns false the worker latches into a
+ * faulted state. All subsequent `update()` calls return `false` immediately
+ * without dispatching new work. Call `reset()` to clear the fault and resume.
+ */
+class AsyncWorker : public Worker {
+ public:
+  /**
+   * @param update_rate_hz Desired control frequency in Hz.
+   */
+  explicit AsyncWorker(double update_rate_hz);
+  ~AsyncWorker() override;
+
+  void reset() override;
+  bool update(uint64_t time_us) override;
+
+ private:
+  void startWorker();
+  void stopWorker();
+  void threadLoop();
+
+  // Main-thread-only — no synchronization needed.
+  uint64_t period_ms_;
+  uint64_t last_scheduled_update_us_ = 0;
+  bool first_run_ = true;
+  uint64_t consecutive_overruns_ = 0;
+  std::thread thread_;
+
+  // Shared between the main thread and the worker thread — all guarded by mutex_.
+  std::mutex mutex_;
+  std::condition_variable cv_;
+  bool stop_ = false;
+  bool working_ = false;
+  bool work_requested_ = false;
+  bool work_finished_ = false;
+  bool work_successful_ = true;
+  bool faulted_ = false;
+};
+
+} // namespace exploy::control
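The `SyncWorker` comment describes phase-based scheduling: the first call sets the phase reference and later calls fire once a full period has elapsed. Since `worker.cpp` is not shown on this page, the following is only a sketch of one way such a check could work. `PhaseScheduler`, `isDue()` and `periodFromRate()` are hypothetical names, and two details are assumptions: the first call is treated as due, and a cycle fires when the elapsed time is at least one period.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the commit): decides whether a control
// cycle is due at a given timestamp, keeping a fixed phase across calls.
class PhaseScheduler {
 public:
  explicit PhaseScheduler(double update_rate_hz)
      : period_us_(periodFromRate(update_rate_hz)) {}

  // Forget the phase reference; the next call re-initialises it.
  void reset() {
    first_run_ = true;
    last_scheduled_us_ = 0;
  }

  bool isDue(std::uint64_t time_us) {
    if (first_run_) {
      first_run_ = false;
      last_scheduled_us_ = time_us;  // Assumption: the first call fires.
      return true;
    }
    // Guard against a clock that moved backwards (unsigned underflow).
    if (time_us < last_scheduled_us_) return false;
    const std::uint64_t elapsed = time_us - last_scheduled_us_;
    if (elapsed < period_us_) return false;
    // Advance by whole periods so call jitter does not shift the phase.
    last_scheduled_us_ += (elapsed / period_us_) * period_us_;
    return true;
  }

 private:
  static std::uint64_t periodFromRate(double hz) {
    assert(hz > 0.0);  // The controller rejects non-positive rates in init().
    return static_cast<std::uint64_t>(1e6 / hz);
  }

  std::uint64_t period_us_;
  std::uint64_t last_scheduled_us_ = 0;
  bool first_run_ = true;
};
```

Advancing the reference by whole periods, rather than setting it to the current timestamp, is a design choice of this sketch: a late call then does not push every following cycle later.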

control/src/controller.cpp

Lines changed: 41 additions & 10 deletions
@@ -54,7 +54,7 @@ bool OnnxRLController::create(const std::string& onnx_model_path, bool register_
   return true;
 }

-bool OnnxRLController::init(bool enable_data_collection) {
+bool OnnxRLController::init(bool enable_data_collection, WorkerMode mode) {
   if (!onnx_model_.isInitialized()) {
     LOG_STREAM(ERROR, "ONNX model is not initialized.");
     return false;
@@ -73,6 +73,35 @@ bool OnnxRLController::init(bool enable_data_collection) {
     }
   }

+  auto rate = static_cast<double>(context_.updateRate());
+  if (rate <= 0) {
+    LOG_STREAM(ERROR, "Invalid update rate: " << rate << " Hz. Must be > 0.");
+    return false;
+  }
+  if (mode == WorkerMode::ASYNC) {
+    worker_ = std::make_unique<AsyncWorker>(rate);
+  } else {
+    worker_ = std::make_unique<SyncWorker>(rate);
+  }
+  auto success = worker_->setCallbacks(
+      [this]() {
+        return readInputs();
+      },
+      [this]() {
+        auto start = std::chrono::high_resolution_clock::now();
+        bool success = onnx_model_.evaluate();
+        auto end = std::chrono::high_resolution_clock::now();
+        inference_duration_s_ = std::chrono::duration<double>(end - start).count();
+        return success;
+      },
+      [this]() {
+        return writeOutputs();
+      });
+  if (!success) {
+    LOG(ERROR, "Failed to set worker callbacks.");
+    return false;
+  }
+
   if (enable_data_collection) {
     for (const auto& name : onnx_model_.inputNames()) {
       auto maybe_buffer = onnx_model_.inputBuffer<float>(name);
@@ -103,30 +132,32 @@ bool OnnxRLController::init(bool enable_data_collection) {

 void OnnxRLController::reset() {
   onnx_model_.resetBuffers();
+  if (worker_) worker_->reset();
 }

-bool OnnxRLController::update(uint64_t time_us) {
+bool OnnxRLController::readInputs() {
   for (const auto& input : context_.getInputs()) {
     if (!input->read(onnx_model_, state_, command_)) {
       LOG_STREAM(ERROR, "Failed to read input");
       return false;
     }
   }
+  return true;
+}

-  auto start_time = std::chrono::high_resolution_clock::now();
-  if (!onnx_model_.evaluate()) {
-    LOG_STREAM(ERROR, "Policy evaluation failed.");
-    return false;
-  }
-  auto end_time = std::chrono::high_resolution_clock::now();
-  inference_duration_s_ = std::chrono::duration<double>(end_time - start_time).count();
-
+bool OnnxRLController::writeOutputs() {
   for (const auto& output : context_.getOutputs()) {
     if (!output->write(onnx_model_, state_, command_)) {
       LOG_STREAM(ERROR, "Failed to write output");
       return false;
     }
   }
+  return true;
+}
+
+bool OnnxRLController::update(uint64_t time_us) {
+  if (!worker_) return false;
+  if (!worker_->update(time_us)) return false;

   if (!data_collection_.collectData(time_us)) {
     LOG(WARN, "Data collection failed.");
