Skip to content

Commit 23b49d9

Browse files
awoll-bdaiexploy-bot
authored andcommitted
Use base_names in base matchers (#69)
# Pull Request ### What change is being made Use the base_names metadata in base matchers. ### Why this change is being made Allow arbitrary base names. ### Tested Adapted tests. GitOrigin-RevId: dd7d3e298abe386a88eb9824e43e4a9ee54150e6
1 parent f3735a9 commit 23b49d9

7 files changed

Lines changed: 61 additions & 23 deletions

File tree

control/context.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,21 @@ std::optional<int> parseUpdateRate(OnnxRuntime& onnx_model) {
2626
return static_cast<int>(std::stod(maybe_update_rate.value()));
2727
}
2828

29+
std::unordered_map<std::string, std::string> parseBaseNames(const OnnxRuntime& onnx_model) {
30+
std::unordered_map<std::string, std::string> base_names;
31+
const auto maybe_base_names = onnx_model.getCustomMetadata("base_names");
32+
if (!maybe_base_names.has_value()) return base_names;
33+
try {
34+
auto json_base_names = json::parse(maybe_base_names.value());
35+
for (auto it = json_base_names.begin(); it != json_base_names.end(); ++it) {
36+
base_names[it.key()] = it.value().get<std::string>();
37+
}
38+
} catch (const json::exception& e) {
39+
LOG_STREAM(ERROR, "Failed to parse base_names metadata: " << e.what());
40+
}
41+
return base_names;
42+
}
43+
2944
} // namespace
3045

3146
// Registration methods
@@ -56,10 +71,13 @@ bool OnnxContext::createContext(OnnxRuntime& onnx_model, bool strict) {
5671
if (!maybe_update_rate.has_value()) return false;
5772
update_rate_ = maybe_update_rate.value();
5873

74+
base_names_ = parseBaseNames(onnx_model);
75+
5976
for (const auto& input_name : onnx_model.inputNames()) {
6077
Match maybe_match{
6178
.name = input_name,
6279
.metadata = onnx_model.getCustomMetadata(input_name),
80+
.base_names = base_names_,
6381
};
6482
bool found_match = false;
6583
for (auto& group_matchers : group_matchers_) {
@@ -81,6 +99,7 @@ bool OnnxContext::createContext(OnnxRuntime& onnx_model, bool strict) {
8199
Match maybe_match{
82100
.name = output_name,
83101
.metadata = onnx_model.getCustomMetadata(output_name),
102+
.base_names = base_names_,
84103
};
85104
bool found_match = false;
86105
for (auto& group_matchers : group_matchers_) {

control/context.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ class OnnxContext {
8181
std::vector<std::unique_ptr<Output>> outputs_; ///< Output components for writing robot commands.
8282
std::vector<std::unique_ptr<Matcher>> matchers_; ///< Registered single-tensor matchers.
8383
std::vector<std::unique_ptr<GroupMatcher>>
84-
group_matchers_; ///< Registered multi-tensor matchers.
85-
int update_rate_{0}; ///< Control loop update rate in Hz.
84+
group_matchers_; ///< Registered multi-tensor matchers.
85+
int update_rate_{0}; ///< Control loop update rate in Hz.
86+
std::unordered_map<std::string, std::string> base_names_; ///< Map of base names.
8687
};
8788

8889
} // namespace exploy::control

control/matcher.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Copyright (c) 2026 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.
22

33
#include <fmt/core.h>
4+
#include <fmt/ranges.h>
45
#include <optional>
6+
#include <ranges>
57
#include <regex>
68
#include <string>
79
#include <string_view>
@@ -13,6 +15,21 @@
1315

1416
namespace exploy::control {
1517

18+
namespace {
19+
20+
// Builds a regex pattern for base tensor matching.
21+
// Returns std::nullopt when base_names is empty, causing matchers to reject the tensor.
22+
std::optional<std::regex> buildBasePattern(
23+
const std::unordered_map<std::string, std::string>& base_names, std::string_view field) {
24+
if (base_names.empty()) return std::nullopt;
25+
auto pairs = base_names | std::views::transform([](const auto& p) {
26+
return fmt::format(R"({}\.{})", p.first, p.second);
27+
});
28+
return std::regex(fmt::format(R"(obj\.({})\.{})", fmt::join(pairs, "|"), field));
29+
}
30+
31+
} // namespace
32+
1633
// --------------- Joint matchers --------------------------------
1734
bool JointMatcher::matches(const Match& maybe_match) {
1835
std::smatch match;
@@ -53,8 +70,8 @@ std::vector<std::unique_ptr<Input>> JointMatcher::createInputs() const {
5370

5471
// --------------- Base matchers --------------------------------
5572
bool BasePositionMatcher::matches(const Match& maybe_match) {
56-
std::regex pattern = std::regex(fmt::format("obj\\.({})\\.base\\.pos_b_rt_w_in_w", alphanumeric));
57-
if (std::regex_match(maybe_match.name, pattern)) {
73+
auto maybe_pattern = buildBasePattern(maybe_match.base_names, "pos_b_rt_w_in_w");
74+
if (maybe_pattern.has_value() && std::regex_match(maybe_match.name, maybe_pattern.value())) {
5875
found_matches_[maybe_match.name] = maybe_match;
5976
return true;
6077
}
@@ -70,8 +87,8 @@ std::vector<std::unique_ptr<Input>> BasePositionMatcher::createInputs() const {
7087
}
7188

7289
bool BaseOrientationMatcher::matches(const Match& maybe_match) {
73-
std::regex pattern = std::regex(fmt::format("obj\\.({})\\.base\\.w_Q_b", alphanumeric));
74-
if (std::regex_match(maybe_match.name, pattern)) {
90+
auto maybe_pattern = buildBasePattern(maybe_match.base_names, "w_Q_b");
91+
if (maybe_pattern.has_value() && std::regex_match(maybe_match.name, maybe_pattern.value())) {
7592
found_matches_[maybe_match.name] = maybe_match;
7693
return true;
7794
}
@@ -87,9 +104,8 @@ std::vector<std::unique_ptr<Input>> BaseOrientationMatcher::createInputs() const
87104
}
88105

89106
bool BaseLinearVelocityMatcher::matches(const Match& maybe_match) {
90-
std::regex pattern =
91-
std::regex(fmt::format("obj\\.({})\\.base\\.lin_vel_b_rt_w_in_b", alphanumeric));
92-
if (std::regex_match(maybe_match.name, pattern)) {
107+
auto maybe_pattern = buildBasePattern(maybe_match.base_names, "lin_vel_b_rt_w_in_b");
108+
if (maybe_pattern.has_value() && std::regex_match(maybe_match.name, maybe_pattern.value())) {
93109
found_matches_[maybe_match.name] = maybe_match;
94110
return true;
95111
}
@@ -105,9 +121,8 @@ std::vector<std::unique_ptr<Input>> BaseLinearVelocityMatcher::createInputs() co
105121
}
106122

107123
bool BaseAngularVelocityMatcher::matches(const Match& maybe_match) {
108-
std::regex pattern =
109-
std::regex(fmt::format("obj\\.({})\\.base\\.ang_vel_b_rt_w_in_b", alphanumeric));
110-
if (std::regex_match(maybe_match.name, pattern)) {
124+
auto maybe_pattern = buildBasePattern(maybe_match.base_names, "ang_vel_b_rt_w_in_b");
125+
if (maybe_pattern.has_value() && std::regex_match(maybe_match.name, maybe_pattern.value())) {
111126
found_matches_[maybe_match.name] = maybe_match;
112127
return true;
113128
}

control/matcher.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ constexpr std::string_view alphanumeric = "[a-zA-Z0-9_]+";
2828
* from the ONNX model.
2929
*/
3030
struct Match {
31-
std::string name{}; ///< ONNX tensor name.
32-
std::optional<std::string> metadata{}; ///< Optional JSON metadata string.
31+
std::string name{}; ///< ONNX tensor name.
32+
std::optional<std::string> metadata{}; ///< Optional JSON metadata string.
33+
std::unordered_map<std::string, std::string> base_names; ///< Map of base names.
3334
};
3435

3536
/**

control/test/components_test.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ TEST_F(OnnxComponentsTest, JointPositionInput_InitAndRead) {
5353
}
5454

5555
TEST_F(OnnxComponentsTest, BasePositionInput_InitAndRead) {
56-
BasePositionInput base_input("obj.robot1.base.pos_b_rt_w_in_w");
56+
BasePositionInput base_input("obj.robot1.base_name.pos_b_rt_w_in_w");
5757

5858
// Test successful initialization
5959
EXPECT_CALL(state_mock_, initBasePosW()).WillOnce(Return(true));
@@ -67,13 +67,13 @@ TEST_F(OnnxComponentsTest, BasePositionInput_InitAndRead) {
6767
}
6868

6969
TEST_F(OnnxComponentsTest, BasePositionInput_InitFailure) {
70-
BasePositionInput base_input("obj.robot1.base.pos_b_rt_w_in_w");
70+
BasePositionInput base_input("obj.robot1.base_name.pos_b_rt_w_in_w");
7171
EXPECT_CALL(state_mock_, initBasePosW()).WillOnce(Return(false));
7272
EXPECT_FALSE(base_input.init(state_mock_, command_mock_));
7373
}
7474

7575
TEST_F(OnnxComponentsTest, BaseOrientationInput_InitAndRead) {
76-
BaseOrientationInput base_input("obj.robot1.base.w_Q_b");
76+
BaseOrientationInput base_input("obj.robot1.base_name.w_Q_b");
7777

7878
// Test initialization
7979
EXPECT_CALL(state_mock_, initBaseQuatW()).WillOnce(Return(true));
@@ -87,7 +87,7 @@ TEST_F(OnnxComponentsTest, BaseOrientationInput_InitAndRead) {
8787
}
8888

8989
TEST_F(OnnxComponentsTest, BaseLinearVelocityInput_InitAndRead) {
90-
BaseLinearVelocityInput base_input("obj.robot1.base.lin_vel_b_rt_w_in_b");
90+
BaseLinearVelocityInput base_input("obj.robot1.base_name.lin_vel_b_rt_w_in_b");
9191

9292
// Test initialization
9393
EXPECT_CALL(state_mock_, initBaseLinVelB()).WillOnce(Return(true));
@@ -101,7 +101,7 @@ TEST_F(OnnxComponentsTest, BaseLinearVelocityInput_InitAndRead) {
101101
}
102102

103103
TEST_F(OnnxComponentsTest, BaseAngularVelocityInput_InitAndRead) {
104-
BaseAngularVelocityInput base_input("obj.robot1.base.ang_vel_b_rt_w_in_b");
104+
BaseAngularVelocityInput base_input("obj.robot1.base_name.ang_vel_b_rt_w_in_b");
105105

106106
// Test initialization
107107
EXPECT_CALL(state_mock_, initBaseAngVelB()).WillOnce(Return(true));

control/test/testdata/test_onnx_generator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
"obj.robot1.joints.pos",
1818
"obj.robot1.joints.vel",
1919
# base
20-
"obj.robot1.base.pos_b_rt_w_in_w",
21-
"obj.robot1.base.w_Q_b",
22-
"obj.robot1.base.lin_vel_b_rt_w_in_b",
23-
"obj.robot1.base.ang_vel_b_rt_w_in_b",
20+
"obj.robot1.base_name.pos_b_rt_w_in_w",
21+
"obj.robot1.base_name.w_Q_b",
22+
"obj.robot1.base_name.lin_vel_b_rt_w_in_b",
23+
"obj.robot1.base_name.ang_vel_b_rt_w_in_b",
2424
# commands
2525
"cmd.se2_velocity.vel",
2626
"cmd.se2_velocity.vel_with_range",
@@ -216,6 +216,7 @@ def get_env_metadata() -> dict:
216216
return {
217217
"exploy_version": "0.1.0",
218218
"update_rate": 10.0,
219+
"base_names": {"robot1": "base_name"},
219220
}
220221

221222

pixi.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pip = "*"
2525
# FEATURE: Python Development - Python tools, linting, formatting, testing
2626
[feature.python.dependencies]
2727
python = "3.11.*"
28+
onnx = "*"
2829
onnxscript = "*"
2930
pybind11 = ">=2.10"
3031
ruff = "*"

0 commit comments

Comments
 (0)