Skip to content

Commit 543248e

Browse files
committed
Adapt pytest test_arm to support UR arm.
Signed-off-by: Jelmer de Wolde <jelmer.de.wolde@alliander.com>
1 parent ecc5209 commit 543248e

2 files changed

Lines changed: 101 additions & 30 deletions

File tree

alliander_tests/src/alliander_tests/tests/test_arm.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import numpy as np
56
import pytest
7+
import xmltodict
68
from alliander_utilities.config_objects import Arm
9+
from control_msgs.msg import JointTrajectoryControllerState
710
from rclpy.node import Node
811
from sensor_msgs.msg import JointState
912

@@ -13,18 +16,17 @@
1316
call_trigger_action,
1417
check_joint_positions,
1518
follow_joint_trajectory_goal,
19+
get_message,
20+
get_parameter,
1621
wait_until_reached_joint,
1722
)
1823

19-
arm = Arm("franka", gripper=True, moveit=True)
20-
PLATFORMS = [arm]
21-
2224

2325
class _TestArm:
2426
"""Base class for arm tests.
2527
2628
Attributes:
27-
platforms (dict): A dictionary of the platforms to launch.
29+
platforms (dict): A dictionary of the platforms to launch.
2830
"""
2931

3032
platforms: dict
@@ -56,6 +58,8 @@ def test_gripper_action(
5658
finger_joint_fault_tolerance (float): The tolerance for the finger joint position.
5759
timeout (int): The timeout in seconds before stopping the test.
5860
"""
61+
if self.platforms["arm"].name == "ur":
62+
pytest.skip("Gripper is not yet implemented for UR arm.")
5963
assert (
6064
call_trigger_action(
6165
test_node,
@@ -85,18 +89,27 @@ def test_follow_joint_trajectory_goal(
8589
joint_movement_tolerance (float): The tolerance for joint movement.
8690
timeout (int): The timeout in seconds to wait for the joint trajectory goal to be followed.
8791
"""
88-
expected_positions = [0.15, -0.39, 0.1, -2.06, 0.0, 1.68, 1.01]
92+
# Get joint names and current position and define a goal position:
93+
controller_state = get_message(
94+
JointTrajectoryControllerState,
95+
f"/{self.platforms['arm'].namespace}/joint_trajectory_controller/controller_state",
96+
timeout=timeout,
97+
)
98+
current_positions = list(controller_state.reference.positions)
99+
goal_positions = [position + np.deg2rad(10) for position in current_positions]
100+
101+
# Call the follow_joint_trajectory action and check if the joints reached the expected positions:
89102
follow_joint_trajectory_goal(
90103
test_node,
91-
positions=expected_positions,
92-
controller=f"{self.platforms['arm'].namespace}/fr3_arm_controller",
104+
controller_state.joint_names,
105+
goal_positions,
106+
controller=f"{self.platforms['arm'].namespace}/joint_trajectory_controller",
93107
timeout=timeout,
94108
)
95-
joint_names = [f"fr3_joint{i + 1}" for i in range(7)]
96109
check_joint_positions(
97110
self.platforms["arm"].namespace,
98-
joint_names,
99-
expected_positions,
111+
controller_state.joint_names,
112+
goal_positions,
100113
joint_movement_tolerance,
101114
timeout,
102115
)
@@ -111,21 +124,43 @@ def test_move_to_drop_configuration(
111124
joint_movement_tolerance (float): The tolerance for joint movement.
112125
timeout (int): The timeout in seconds before stopping the test.
113126
"""
127+
# Get the robot_description_semantic and convert to a dictionary:
128+
robot_description_semantic_str = get_parameter(
129+
test_node,
130+
f"/{self.platforms['arm'].namespace}/move_group",
131+
"robot_description_semantic",
132+
timeout=timeout,
133+
).string_value
134+
robot_description_semantic = xmltodict.parse(robot_description_semantic_str)
135+
136+
# Extract the configurations from the robot_description_semantic:
137+
configurations = {}
138+
group_states = robot_description_semantic["robot"]["group_state"]
139+
for group_state in group_states:
140+
configurations[group_state["@name"]] = {"names": [], "positions": []}
141+
for joint in group_state["joint"]:
142+
configurations[group_state["@name"]]["names"].append(joint["@name"])
143+
configurations[group_state["@name"]]["positions"].append(
144+
joint["@value"]
145+
)
146+
147+
# Call the move_to_configuration service and check if the joints reached the expected positions:
148+
configuration = "drop"
149+
names = configurations[configuration]["names"]
150+
positions = [float(pos) for pos in configurations[configuration]["positions"]]
114151
assert call_move_to_configuration_service(
115-
test_node, self.platforms["arm"].namespace, "drop", timeout=timeout
152+
test_node, self.platforms["arm"].namespace, configuration, timeout=timeout
116153
), "Failed to call move_to_configuration service."
117-
joint_names = [f"fr3_joint{i + 1}" for i in range(7)]
118-
expected_positions = [-1.57079632679, -0.65, 0, -2.4, 0, 1.75, 0.78539816339]
119154
check_joint_positions(
120155
self.platforms["arm"].namespace,
121-
joint_names,
122-
expected_positions,
156+
names,
157+
positions,
123158
joint_movement_tolerance,
124159
timeout,
125160
)
126161

127162

128-
for arm in ["franka"]:
163+
for arm in ["franka", "ur"]:
129164
arm_platform = Arm(arm, (0, 0, 0.5), gripper=True, moveit=True)
130165
test_class = type(
131166
f"Test{arm.capitalize()}",

alliander_tests/src/alliander_tests/utils.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from builtin_interfaces.msg import Duration
1414
from control_msgs.action import FollowJointTrajectory
1515
from launch_testing_ros.wait_for_topics import WaitForTopics
16+
from rcl_interfaces.msg import ParameterValue
17+
from rcl_interfaces.srv import GetParameters
1618
from rclpy.action import ActionClient
1719
from rclpy.action.client import ClientGoalHandle
1820
from rclpy.client import Client
@@ -163,6 +165,31 @@ def call_trigger_service(node: Node, service_name: str, timeout: int) -> bool:
163165
return future.result() is not None
164166

165167

168+
def get_parameter(
169+
node: Node, parameter_node: str, parameter_name: str, timeout: int
170+
) -> ParameterValue:
171+
"""Get a parameter using a service call.
172+
173+
Args:
174+
node (Node): The rclpy node used to get the parameter.
175+
parameter_node (str): The name of the node that has the parameter.
176+
parameter_name (str): The name of the parameter to get.
177+
timeout (int): Timeout in seconds to wait for the parameter.
178+
179+
Returns:
180+
ParameterValue: The value of the parameter
181+
"""
182+
client = create_ready_service_client(
183+
node, GetParameters, f"{parameter_node}/get_parameters", timeout
184+
)
185+
request = GetParameters.Request()
186+
request.names = [parameter_name]
187+
188+
future = client.call_async(request)
189+
rclpy.spin_until_future_complete(node, future=future, timeout_sec=timeout)
190+
return future.result().values[0]
191+
192+
166193
def call_trigger_action(node: Node, action_name: str, timeout: int) -> bool:
167194
"""Call a trigger action and return True if the action was called successfully.
168195
@@ -223,18 +250,33 @@ def create_ready_action_client(
223250
return client
224251

225252

226-
def assert_for_message(message_type: type, topic: str, timeout: int) -> None:
227-
"""Assert that a message of a specific type is received on a given topic within a timeout period.
253+
def get_message(message_type: type, topic: str, timeout: int) -> Any:
254+
"""Try to receive a message of a specific type on a given topic within a timeout period.
228255
229256
Args:
230257
message_type (type): The type of the message to wait for.
231258
topic (str): The topic to listen to.
232259
timeout (int): The maximum time in seconds to wait for the message.
260+
261+
Returns:
262+
Any: The received message, or None if no message was received within the timeout.
233263
"""
234264
wait_for_topics = WaitForTopics([(topic, message_type)], timeout)
235265
received = wait_for_topics.wait()
236266
wait_for_topics.shutdown()
237-
assert received, (
267+
return wait_for_topics.received_messages(topic)[-1] if received else None
268+
269+
270+
def assert_for_message(message_type: type, topic: str, timeout: int) -> None:
271+
"""Assert that a message of a specific type is received on a given topic within a timeout period.
272+
273+
Args:
274+
message_type (type): The type of the message to wait for.
275+
topic (str): The topic to listen to.
276+
timeout (int): The maximum time in seconds to wait for the message.
277+
"""
278+
message = get_message(message_type, topic, timeout)
279+
assert message is not None, (
238280
f"No message received of type {message_type.__name__} on topic {topic} within {timeout} seconds."
239281
)
240282

@@ -303,20 +345,22 @@ def call_move_to_configuration_service(
303345

304346
def follow_joint_trajectory_goal(
305347
node: Node,
348+
names: list[str],
306349
positions: list[float],
307350
controller: str,
308351
timeout: int,
309-
time_from_start: int = 3,
310352
) -> None:
311353
"""Test sending a joint trajectory goal to the arm controller.
312354
313355
Args:
314356
node (Node): The ROS 2 node to use for the action client.
357+
names (list[str]): The names of the joints to control.
315358
positions (list[float]): The joint positions to move to.
316359
controller (str): The name of the controller to use.
317360
timeout (int): The timeout in seconds for the action client.
318-
time_from_start (int, optional): The time from start in seconds. Defaults to 3.
319361
"""
362+
time_from_start = 3
363+
320364
action_client = create_ready_action_client(
321365
node,
322366
FollowJointTrajectory,
@@ -325,15 +369,7 @@ def follow_joint_trajectory_goal(
325369
)
326370

327371
goal_msg = FollowJointTrajectory.Goal()
328-
goal_msg.trajectory.joint_names = [
329-
"fr3_joint1",
330-
"fr3_joint2",
331-
"fr3_joint3",
332-
"fr3_joint4",
333-
"fr3_joint5",
334-
"fr3_joint6",
335-
"fr3_joint7",
336-
]
372+
goal_msg.trajectory.joint_names = names
337373

338374
point = JointTrajectoryPoint()
339375
point.positions = positions

0 commit comments

Comments
 (0)