|
6 | 6 | from pathlib import Path |
7 | 7 | from typing import TYPE_CHECKING |
8 | 8 |
|
| 9 | +import numpy as np |
9 | 10 | from rosbags.rosbag2 import Reader |
10 | 11 | from rosbags.typesys import Stores, get_typestore |
11 | 12 |
|
|
26 | 27 |
|
27 | 28 | _POINTCLOUD2_MSGTYPE = "sensor_msgs/msg/PointCloud2" |
28 | 29 | _PANDARSCAN_MSGTYPE = "pandar_msgs/msg/PandarScan" |
| 30 | +_TF_STATIC_TOPIC = "/tf_static" |
29 | 31 |
|
30 | 32 | _SUPPORTED_MSGTYPES = {_POINTCLOUD2_MSGTYPE, _PANDARSCAN_MSGTYPE} |
31 | 33 |
|
32 | 34 | logger = logging.getLogger(__name__) |
33 | 35 |
|
34 | 36 |
|
| 37 | +def _quat_to_matrix(x: float, y: float, z: float, w: float) -> np.ndarray: |
| 38 | + """Convert quaternion (x, y, z, w) to a 3x3 rotation matrix.""" |
| 39 | + return np.array([ |
| 40 | + [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)], |
| 41 | + [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)], |
| 42 | + [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)], |
| 43 | + ]) |
| 44 | + |
| 45 | + |
| 46 | +def _build_tf_to_base(typestore: object, reader: Reader) -> dict[str, tuple[np.ndarray, np.ndarray]]: |
| 47 | + """Read /tf_static and compute transforms from each frame to base_link. |
| 48 | +
|
| 49 | + Returns: |
| 50 | + Dict mapping ``frame_id`` to ``(R, t)`` where ``R`` is a 3x3 rotation |
| 51 | + matrix and ``t`` is a 3-element translation vector, representing the |
| 52 | + transform from that frame to ``base_link``. |
| 53 | + """ |
| 54 | + tf_conns = [c for c in reader.connections if c.topic == _TF_STATIC_TOPIC] |
| 55 | + if not tf_conns: |
| 56 | + return {} |
| 57 | + |
| 58 | + # Parse the TF tree: child_frame -> (parent_frame, R, t) |
| 59 | + tree: dict[str, tuple[str, np.ndarray, np.ndarray]] = {} |
| 60 | + for conn, _ts, rawdata in reader.messages(connections=tf_conns): |
| 61 | + msg = typestore.deserialize_cdr(rawdata, conn.msgtype) |
| 62 | + for tf in msg.transforms: |
| 63 | + parent = tf.header.frame_id |
| 64 | + child = tf.child_frame_id |
| 65 | + tr = tf.transform.translation |
| 66 | + rot = tf.transform.rotation |
| 67 | + R = _quat_to_matrix(rot.x, rot.y, rot.z, rot.w) |
| 68 | + t = np.array([tr.x, tr.y, tr.z]) |
| 69 | + tree[child] = (parent, R, t) |
| 70 | + |
| 71 | + # Compose transforms from each frame to base_link |
| 72 | + result: dict[str, tuple[np.ndarray, np.ndarray]] = {} |
| 73 | + result["base_link"] = (np.eye(3), np.zeros(3)) |
| 74 | + |
| 75 | + def _resolve(frame: str) -> tuple[np.ndarray, np.ndarray] | None: |
| 76 | + if frame in result: |
| 77 | + return result[frame] |
| 78 | + if frame not in tree: |
| 79 | + return None |
| 80 | + parent, R_child, t_child = tree[frame] |
| 81 | + parent_tf = _resolve(parent) |
| 82 | + if parent_tf is None: |
| 83 | + return None |
| 84 | + R_parent, t_parent = parent_tf |
| 85 | + # T_base = T_parent * T_child: p_base = R_parent*(R_child*p + t_child) + t_parent |
| 86 | + R = R_parent @ R_child |
| 87 | + t = R_parent @ t_child + t_parent |
| 88 | + result[frame] = (R, t) |
| 89 | + return result[frame] |
| 90 | + |
| 91 | + for frame in tree: |
| 92 | + _resolve(frame) |
| 93 | + |
| 94 | + return result |
| 95 | + |
| 96 | + |
35 | 97 | class Rosbag2Reader: |
36 | 98 | """Reader for rosbag2 files that provides LiDAR point cloud data. |
37 | 99 |
|
@@ -122,6 +184,25 @@ def __init__( |
122 | 184 | for conn in self._connections: |
123 | 185 | self._topic_connections.setdefault(conn.topic, []).append(conn) |
124 | 186 |
|
| 187 | + # Read /tf_static and build transforms to base_link |
| 188 | + self._tf_to_base = _build_tf_to_base(self._typestore, self._reader) |
| 189 | + |
| 190 | + # For PandarScan topics with frame_id specified, look up TF transform |
| 191 | + self._channel_tf: dict[str, tuple[np.ndarray, np.ndarray]] = {} |
| 192 | + if topic_mapping is not None: |
| 193 | + for m in topic_mapping: |
| 194 | + if m.frame_id is None: |
| 195 | + continue |
| 196 | + if m.frame_id in self._tf_to_base: |
| 197 | + self._channel_tf[m.channel] = self._tf_to_base[m.frame_id] |
| 198 | + else: |
| 199 | + logger.warning( |
| 200 | + "Channel '%s': frame_id='%s' not found in /tf_static. " |
| 201 | + "Points will be in sensor frame.", |
| 202 | + m.channel, |
| 203 | + m.frame_id, |
| 204 | + ) |
| 205 | + |
125 | 206 | # Build timestamp index: channel -> sorted list of timestamp_ns |
126 | 207 | # Also build a cached list of timestamp_us per channel for bisect lookups |
127 | 208 | self._timestamp_ns: dict[str, list[int]] = {} |
@@ -230,7 +311,13 @@ def get_pointcloud( |
230 | 311 | ): |
231 | 312 | msg = self._typestore.deserialize_cdr(rawdata, conn.msgtype) |
232 | 313 | if conn.msgtype == _PANDARSCAN_MSGTYPE: |
233 | | - return pandarscan_to_lidar(msg, self._channel_to_sensor_type[channel]) |
| 314 | + pc = pandarscan_to_lidar(msg, self._channel_to_sensor_type[channel]) |
| 315 | + tf = self._channel_tf.get(channel) |
| 316 | + if tf is not None: |
| 317 | + R, t = tf |
| 318 | + xyz = pc.points[:3, :] |
| 319 | + pc.points[:3, :] = R @ xyz + t[:, np.newaxis] |
| 320 | + return pc |
234 | 321 | return pointcloud2_to_lidar(msg) |
235 | 322 |
|
236 | 323 | raise ValueError( |
|
0 commit comments