Skip to content

Commit 97bf516

Browse files
awoll-bdaiexploy-bot
authored andcommitted
Add joint position command
### What change is being made Add joint position command, e.g. `cmd.joint_pos.arm` The command provides the list of joints in the metadata. ### Why this change is being made Support policies which command joint positions (e.g. manipulation) ### Tested Covered by unit and integration tests. GitOrigin-RevId: fbd2158274d35d5bcfd930a6ba21c7cfcff521e2
1 parent 3dccc1d commit 97bf516

11 files changed

Lines changed: 270 additions & 0 deletions

control/include/exploy/command_interface.hpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,33 @@ class CommandInterface {
144144
LOG_STREAM(ERROR, "floatValue() not implemented for command: " << command_name);
145145
return std::nullopt;
146146
}
147+
/**
148+
* @brief Initialize data source of a commanded joint position.
149+
*
150+
* Called once per joint during initialization (usually non real-time).
151+
*
152+
* @param command_name The name of the command.
153+
* @param joint_name The name of the joint.
154+
* @return True if initialization succeeded, false otherwise.
155+
*/
156+
virtual bool initJointPosition(const std::string& command_name, const std::string& joint_name) {
157+
LOG_STREAM(ERROR, "initJointPosition() not implemented for command: "
158+
<< command_name << ", joint: " << joint_name);
159+
return false;
160+
}
161+
/**
162+
* @brief Get the commanded position for a single joint.
163+
*
164+
* @param command_name The name of the command.
165+
* @param joint_name The name of the joint.
166+
* @return The commanded joint position, or std::nullopt if unavailable.
167+
*/
168+
virtual std::optional<float> jointPosition(const std::string& command_name,
169+
const std::string& joint_name) const {
170+
LOG_STREAM(ERROR, "jointPosition() not implemented for command: " << command_name
171+
<< ", joint: " << joint_name);
172+
return std::nullopt;
173+
}
147174
};
148175

149176
} // namespace exploy::control

control/include/exploy/components.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,34 @@ class CommandBooleanInput : public Input {
485485
std::string command_name_; ///< Boolean command name.
486486
};
487487

488+
/**
489+
* @brief Input component that reads commanded joint positions.
490+
*
491+
* Reads the commanded position for each joint specified in the metadata by calling
492+
* CommandInterface::jointPosition() once per joint, then copies the values to the
493+
* ONNX input buffer in the order defined by the metadata joint names.
494+
*/
495+
class CommandJointPositionInput : public Input {
496+
public:
497+
/**
498+
* @brief Construct a joint position command input component.
499+
*
500+
* @param key ONNX input tensor name (e.g., "cmd.joint_pos.arm").
501+
* @param command_name Name of the joint position command to read.
502+
* @param metadata Metadata specifying the ordered list of joint names.
503+
*/
504+
CommandJointPositionInput(const std::string& key, const std::string& command_name,
505+
const metadata::JointPositionCommandMetadata& metadata);
506+
507+
bool init(RobotStateInterface& state, CommandInterface& command) override;
508+
bool read(OnnxRuntime& runtime, RobotStateInterface& state, CommandInterface& command) override;
509+
510+
private:
511+
std::string key_; ///< ONNX input tensor name.
512+
std::string command_name_; ///< Joint position command name.
513+
metadata::JointPositionCommandMetadata metadata_; ///< Command configuration.
514+
};
515+
488516
/**
489517
* @brief Input component that reads a floating-point command value.
490518
*

control/include/exploy/matcher.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,19 @@ class CommandFloatMatcher : public Matcher {
337337
bool matches(const Match& maybe_match) override;
338338
std::vector<std::unique_ptr<Input>> createInputs() const override;
339339
};
340+
341+
/**
342+
* @brief Matcher for joint position command input tensors.
343+
*
344+
* Matches patterns like "cmd.joint_pos.{name}" and creates CommandJointPositionInput
345+
* components. The joint names are read from the tensor's JSON metadata field
346+
* ("joint_names" array).
347+
*/
348+
class CommandJointPositionMatcher : public Matcher {
349+
public:
350+
bool matches(const Match& maybe_match) override;
351+
std::vector<std::unique_ptr<Input>> createInputs() const override;
352+
};
340353
// ---------------------------------------------------------------
341354

342355
// --------------- Body matchers --------------------------------

control/include/exploy/metadata.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,15 @@ struct JointMetadata {
265265
std::vector<std::string> names{}; ///< Joint names.
266266
};
267267

268+
/**
269+
* @brief Metadata for joint position commands.
270+
*
271+
* Specifies the ordered list of joint names whose positions are provided by the command.
272+
*/
273+
struct JointPositionCommandMetadata {
274+
std::vector<std::string> joint_names{}; ///< Ordered joint names.
275+
};
276+
268277
/**
269278
* @brief Parse JointMetadata from JSON.
270279
*
@@ -275,6 +284,16 @@ inline void from_json(const json& j, JointMetadata& jm) {
275284
j.at("joint_names").get_to(jm.names);
276285
}
277286

287+
/**
288+
* @brief Parse JointPositionCommandMetadata from JSON.
289+
*
290+
* @param j JSON object containing a "joint_names" array.
291+
* @param cmd JointPositionCommandMetadata object to populate.
292+
*/
293+
inline void from_json(const json& j, JointPositionCommandMetadata& cmd) {
294+
j.at("joint_names").get_to(cmd.joint_names);
295+
}
296+
278297
/**
279298
* @brief Parsed version (MAJOR.MINOR).
280299
*/

control/src/components.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,43 @@ bool CommandBooleanInput::read(OnnxRuntime& runtime, RobotStateInterface& /*stat
485485
return true;
486486
}
487487

488+
// Implementation of CommandJointPositionInput methods
489+
CommandJointPositionInput::CommandJointPositionInput(
490+
const std::string& key, const std::string& command_name,
491+
const metadata::JointPositionCommandMetadata& metadata)
492+
: key_(key), command_name_(command_name), metadata_(metadata) {}
493+
494+
bool CommandJointPositionInput::init(RobotStateInterface& /*state*/, CommandInterface& command) {
495+
if (metadata_.joint_names.empty()) {
496+
LOG_STREAM(ERROR,
497+
"initJointPosition() called with empty joint_names for command: " << command_name_);
498+
return false;
499+
}
500+
for (const auto& joint_name : metadata_.joint_names) {
501+
if (!command.initJointPosition(command_name_, joint_name)) return false;
502+
}
503+
return true;
504+
}
505+
506+
bool CommandJointPositionInput::read(OnnxRuntime& runtime, RobotStateInterface& /*state*/,
507+
CommandInterface& command) {
508+
auto maybe_buffer = runtime.inputBuffer<float>(key_);
509+
if (!maybe_buffer.has_value()) return false;
510+
auto buffer = maybe_buffer.value();
511+
if (buffer.size() != metadata_.joint_names.size()) {
512+
LOG_STREAM(ERROR, "Buffer size " << buffer.size() << " does not match joint_names size "
513+
<< metadata_.joint_names.size()
514+
<< " for command: " << command_name_);
515+
return false;
516+
}
517+
for (std::size_t i = 0; i < metadata_.joint_names.size(); ++i) {
518+
auto maybe_pos = command.jointPosition(command_name_, metadata_.joint_names[i]);
519+
if (!maybe_pos.has_value()) return false;
520+
buffer[i] = maybe_pos.value();
521+
}
522+
return true;
523+
}
524+
488525
// Implementation of CommandFloatInput methods
489526
CommandFloatInput::CommandFloatInput(const std::string& key, const std::string& command_name,
490527
const metadata::FloatCommandMetadata& metadata)

control/src/controller.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ OnnxRLController::OnnxRLController(RobotStateInterface& state, CommandInterface&
2929
context_.registerMatcher(std::make_unique<CommandSE2VelocityMatcher>());
3030
context_.registerMatcher(std::make_unique<CommandBooleanMatcher>());
3131
context_.registerMatcher(std::make_unique<CommandFloatMatcher>());
32+
context_.registerMatcher(std::make_unique<CommandJointPositionMatcher>());
3233
context_.registerMatcher(std::make_unique<StepCountMatcher>());
3334
// Register all group matchers
3435
context_.registerGroupMatcher(std::make_unique<JointMatcher>());

control/src/matcher.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,30 @@ std::vector<std::unique_ptr<Input>> CommandFloatMatcher::createInputs() const {
434434
return inputs;
435435
}
436436

437+
bool CommandJointPositionMatcher::matches(const Match& maybe_match) {
438+
std::smatch match;
439+
std::regex pattern = std::regex(fmt::format("cmd\\.joint_pos\\.({})", kAlphanumeric));
440+
if (std::regex_match(maybe_match.name, match, pattern)) {
441+
found_matches_[match[1].str()] = maybe_match;
442+
return true;
443+
}
444+
return false;
445+
}
446+
447+
std::vector<std::unique_ptr<Input>> CommandJointPositionMatcher::createInputs() const {
448+
std::vector<std::unique_ptr<Input>> inputs;
449+
for (const auto& [name, match] : found_matches_) {
450+
if (!match.metadata.has_value()) continue;
451+
auto maybe_metadata =
452+
metadata::safe_json_get<metadata::JointPositionCommandMetadata>(match.metadata.value());
453+
if (!maybe_metadata.has_value()) continue;
454+
if (maybe_metadata.value().joint_names.empty()) continue;
455+
inputs.push_back(
456+
std::make_unique<CommandJointPositionInput>(match.name, name, maybe_metadata.value()));
457+
}
458+
return inputs;
459+
}
460+
437461
bool CommandSE2VelocityMatcher::matches(const Match& maybe_match) {
438462
std::smatch match;
439463
std::regex pattern = std::regex(fmt::format("cmd\\.se2_velocity\\.({})", kAlphanumeric));

control/test/components_test.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,51 @@ TEST_F(OnnxComponentsTest, MemoryOutput_WithRealRuntime) {
212212
EXPECT_TRUE(memory_was_overwritten) << "Memory should be overwritten with new output data";
213213
}
214214

215+
TEST_F(OnnxComponentsTest, CommandJointPositionInput_InitAndRead) {
216+
metadata::JointPositionCommandMetadata meta;
217+
meta.joint_names = {"j1", "j2"};
218+
CommandJointPositionInput input("cmd.joint_pos.arm", "arm", meta);
219+
220+
EXPECT_CALL(command_mock_, initJointPosition("arm", "j1")).WillOnce(Return(true));
221+
EXPECT_CALL(command_mock_, initJointPosition("arm", "j2")).WillOnce(Return(true));
222+
EXPECT_TRUE(input.init(state_mock_, command_mock_));
223+
224+
EXPECT_CALL(command_mock_, jointPosition("arm", "j1")).WillOnce(Return(std::make_optional(0.1f)));
225+
EXPECT_CALL(command_mock_, jointPosition("arm", "j2")).WillOnce(Return(std::make_optional(0.2f)));
226+
EXPECT_TRUE(input.read(runtime, state_mock_, command_mock_));
227+
228+
auto buffer = runtime.inputBuffer<float>("cmd.joint_pos.arm");
229+
ASSERT_TRUE(buffer.has_value());
230+
ASSERT_EQ(buffer->size(), 2u);
231+
EXPECT_FLOAT_EQ((*buffer)[0], 0.1f);
232+
EXPECT_FLOAT_EQ((*buffer)[1], 0.2f);
233+
}
234+
235+
TEST_F(OnnxComponentsTest, CommandJointPositionInput_InitFailsWhenJointNamesEmpty) {
236+
metadata::JointPositionCommandMetadata meta; // joint_names left empty
237+
CommandJointPositionInput input("cmd.joint_pos.arm", "arm", meta);
238+
EXPECT_FALSE(input.init(state_mock_, command_mock_));
239+
}
240+
241+
TEST_F(OnnxComponentsTest, CommandJointPositionInput_InitFailsWhenOneJointFails) {
242+
metadata::JointPositionCommandMetadata meta;
243+
meta.joint_names = {"j1", "j2"};
244+
CommandJointPositionInput input("cmd.joint_pos.arm", "arm", meta);
245+
246+
EXPECT_CALL(command_mock_, initJointPosition("arm", "j1")).WillOnce(Return(false));
247+
EXPECT_FALSE(input.init(state_mock_, command_mock_));
248+
}
249+
250+
TEST_F(OnnxComponentsTest, CommandJointPositionInput_ReadFailsWhenJointUnavailable) {
251+
metadata::JointPositionCommandMetadata meta;
252+
meta.joint_names = {"j1", "j2"};
253+
CommandJointPositionInput input("cmd.joint_pos.arm", "arm", meta);
254+
255+
EXPECT_CALL(command_mock_, jointPosition("arm", "j1")).WillOnce(Return(std::make_optional(0.1f)));
256+
EXPECT_CALL(command_mock_, jointPosition("arm", "j2")).WillOnce(Return(std::nullopt));
257+
EXPECT_FALSE(input.read(runtime, state_mock_, command_mock_));
258+
}
259+
215260
// --------------- Matcher tests for default metadata --------------------------------
216261

217262
TEST(CommandFloatMatcherTest, CreatesInputWithoutMetadata) {
@@ -257,4 +302,61 @@ TEST(CommandSE2VelocityMatcherTest, CreatesInputWithMetadata) {
257302
ASSERT_EQ(inputs.size(), 1u);
258303
}
259304

305+
TEST(CommandJointPositionMatcherTest, DoesNotMatchOtherPatterns) {
306+
CommandJointPositionMatcher matcher;
307+
EXPECT_FALSE(matcher.matches({.name = "cmd.float.gain"}));
308+
EXPECT_FALSE(matcher.matches({.name = "cmd.joint_vel.arm"}));
309+
EXPECT_FALSE(matcher.matches({.name = "cmd.joint_pos"}));
310+
}
311+
312+
TEST(CommandJointPositionMatcherTest, SkipsInputWithoutMetadata) {
313+
CommandJointPositionMatcher matcher;
314+
ASSERT_TRUE(matcher.matches({.name = "cmd.joint_pos.arm"}));
315+
316+
auto inputs = matcher.createInputs();
317+
ASSERT_EQ(inputs.size(), 0u);
318+
}
319+
320+
TEST(CommandJointPositionMatcherTest, SkipsInputWithEmptyJointNames) {
321+
CommandJointPositionMatcher matcher;
322+
ASSERT_TRUE(matcher.matches({.name = "cmd.joint_pos.arm", .metadata = R"({"joint_names": []})"}));
323+
324+
auto inputs = matcher.createInputs();
325+
ASSERT_EQ(inputs.size(), 0u);
326+
}
327+
328+
TEST(CommandJointPositionMatcherTest, CreatesInputWithMetadata) {
329+
CommandJointPositionMatcher matcher;
330+
Match m{
331+
.name = "cmd.joint_pos.arm",
332+
.metadata = R"({"joint_names": ["j1", "j2"]})",
333+
};
334+
ASSERT_TRUE(matcher.matches(m));
335+
336+
auto inputs = matcher.createInputs();
337+
ASSERT_EQ(inputs.size(), 1u);
338+
}
339+
340+
TEST(CommandJointPositionMatcherTest, DoesNotMatchSameNameTwice) {
341+
CommandJointPositionMatcher matcher;
342+
ASSERT_TRUE(matcher.matches({.name = "cmd.joint_pos.arm"}));
343+
// Second match with the same name overwrites, still produces one entry (but 0 inputs without
344+
// metadata)
345+
ASSERT_TRUE(matcher.matches({.name = "cmd.joint_pos.arm"}));
346+
347+
auto inputs = matcher.createInputs();
348+
ASSERT_EQ(inputs.size(), 0u);
349+
}
350+
351+
TEST(CommandJointPositionMatcherTest, MatchesMultipleCommands) {
352+
CommandJointPositionMatcher matcher;
353+
ASSERT_TRUE(matcher.matches(
354+
{.name = "cmd.joint_pos.arm", .metadata = R"({"joint_names": ["j1", "j2"]})"}));
355+
ASSERT_TRUE(matcher.matches(
356+
{.name = "cmd.joint_pos.leg", .metadata = R"({"joint_names": ["j3", "j4"]})"}));
357+
358+
auto inputs = matcher.createInputs();
359+
ASSERT_EQ(inputs.size(), 2u);
360+
}
361+
260362
} // namespace exploy::control

control/test/controller_test.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ class OnnxControllerTest : public ::testing::Test {
225225
EXPECT_CALL(command_mock_, initSe3Pose("pose", _)).WillOnce(Return(true));
226226
EXPECT_CALL(command_mock_, initBooleanSelector("selector", _)).WillOnce(Return(true));
227227
EXPECT_CALL(command_mock_, initFloatValue("value", _)).WillOnce(Return(true));
228+
EXPECT_CALL(command_mock_, initJointPosition("arm", "j1")).WillOnce(Return(true));
229+
EXPECT_CALL(command_mock_, initJointPosition("arm", "j2")).WillOnce(Return(true));
228230
}
229231

230232
void ExpectInitCustom() {
@@ -283,6 +285,10 @@ class OnnxControllerTest : public ::testing::Test {
283285
.WillRepeatedly(Return(std::make_optional(true)));
284286
EXPECT_CALL(command_mock_, floatValue("value"))
285287
.WillRepeatedly(Return(std::make_optional(1.23f)));
288+
EXPECT_CALL(command_mock_, jointPosition("arm", "j1"))
289+
.WillRepeatedly(Return(std::make_optional(0.0f)));
290+
EXPECT_CALL(command_mock_, jointPosition("arm", "j2"))
291+
.WillRepeatedly(Return(std::make_optional(0.0f)));
286292
}
287293

288294
void ExpectReadCustom() {

control/test/mock_command_interface.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ class MockCommandInterface : public CommandInterface {
2424
(const std::string& command_name, const FloatScalarConfig& config), (override));
2525
MOCK_METHOD(std::optional<float>, floatValue, (const std::string& command_name),
2626
(const override));
27+
MOCK_METHOD(bool, initJointPosition,
28+
(const std::string& command_name, const std::string& joint_name), (override));
29+
MOCK_METHOD(std::optional<float>, jointPosition,
30+
(const std::string& command_name, const std::string& joint_name), (const, override));
2731
};
2832

2933
} // namespace exploy::control

0 commit comments

Comments
 (0)