|
| 1 | +# Copyright 2026 Enactic, Inc. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Shared MuJoCo context for OpenArm FK, IK, and controller nodes.""" |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import argparse |
| 20 | + |
| 21 | +import mujoco |
| 22 | +import openarm_mujoco_v2 as openarm_mujoco |
| 23 | +from openarm_mujoco_v2 import JointResolver |
| 24 | + |
| 25 | +_DEFAULT_XML = openarm_mujoco.openarm_cell_xml() |
| 26 | + |
| 27 | +_DEFAULT_FRAME_RIGHT = "right_ee_control_point" |
| 28 | +_DEFAULT_FRAME_TYPE_RIGHT = "site" |
| 29 | +_DEFAULT_FRAME_LEFT = "left_ee_control_point" |
| 30 | +_DEFAULT_FRAME_TYPE_LEFT = "site" |
| 31 | + |
| 32 | +_FRAME_OBJ = { |
| 33 | + "body": mujoco.mjtObj.mjOBJ_BODY, |
| 34 | + "site": mujoco.mjtObj.mjOBJ_SITE, |
| 35 | + "geom": mujoco.mjtObj.mjOBJ_GEOM, |
| 36 | +} |
| 37 | + |
| 38 | + |
| 39 | +class ArmSetup: |
| 40 | + """MuJoCo context shared across FK, IK, and controller nodes. |
| 41 | +
|
| 42 | + Bundles the model, data, joint resolver, active arm sides, and per-arm |
| 43 | + EE frame IDs/types. Instantiate once per process; pass into any solver |
| 44 | + or controller that needs model access. |
| 45 | +
|
| 46 | + Pose convention throughout: float32[7] = [px, py, pz, qw, qx, qy, qz] |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__( |
| 50 | + self, |
| 51 | + model: mujoco.MjModel, |
| 52 | + data: mujoco.MjData, |
| 53 | + joint_resolver: JointResolver, |
| 54 | + sides: list[str], |
| 55 | + frame_ids: dict[str, int], |
| 56 | + frame_types: dict[str, str], |
| 57 | + ) -> None: |
| 58 | + self.model = model |
| 59 | + self.data = data |
| 60 | + self.joint_resolver = joint_resolver |
| 61 | + self.sides = sides |
| 62 | + self.frame_ids = frame_ids # side → MuJoCo object ID |
| 63 | + self.frame_types = frame_types # side → "body" | "site" | "geom" |
| 64 | + |
| 65 | + @classmethod |
| 66 | + def from_args( |
| 67 | + cls, |
| 68 | + xml: str, |
| 69 | + mode: str, |
| 70 | + frame_right: str, |
| 71 | + frame_type_right: str, |
| 72 | + frame_left: str, |
| 73 | + frame_type_left: str, |
| 74 | + keyframe: str | None = "home", |
| 75 | + ) -> ArmSetup: |
| 76 | + model = mujoco.MjModel.from_xml_path(xml) |
| 77 | + data = mujoco.MjData(model) |
| 78 | + |
| 79 | + if keyframe: |
| 80 | + key_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_KEY, keyframe) |
| 81 | + if key_id >= 0: |
| 82 | + mujoco.mj_resetDataKeyframe(model, data, key_id) |
| 83 | + else: |
| 84 | + print(f"Warning: keyframe '{keyframe}' not found, using defaults.") |
| 85 | + |
| 86 | + mujoco.mj_forward(model, data) |
| 87 | + |
| 88 | + sides: list[str] = [] |
| 89 | + if mode in ("right", "bimanual"): |
| 90 | + sides.append("right") |
| 91 | + if mode in ("left", "bimanual"): |
| 92 | + sides.append("left") |
| 93 | + |
| 94 | + frame_ids: dict[str, int] = {} |
| 95 | + frame_types: dict[str, str] = {} |
| 96 | + for side in sides: |
| 97 | + name = frame_right if side == "right" else frame_left |
| 98 | + ftype = frame_type_right if side == "right" else frame_type_left |
| 99 | + frame_ids[side] = _resolve_frame_id(model, name, ftype) |
| 100 | + frame_types[side] = ftype |
| 101 | + |
| 102 | + return cls( |
| 103 | + model=model, |
| 104 | + data=data, |
| 105 | + joint_resolver=JointResolver(model), |
| 106 | + sides=sides, |
| 107 | + frame_ids=frame_ids, |
| 108 | + frame_types=frame_types, |
| 109 | + ) |
| 110 | + |
| 111 | + def read_ee_pose(self, side: str) -> "np.ndarray": |
| 112 | + """Return float32[7] = [px, py, pz, qw, qx, qy, qz] for the given arm.""" |
| 113 | + from openarm_control.poses import read_ee_pose |
| 114 | + return read_ee_pose(self.data, self.frame_ids[side], self.frame_types[side]) |
| 115 | + |
| 116 | + |
| 117 | +def _resolve_frame_id(model: mujoco.MjModel, name: str, ftype: str) -> int: |
| 118 | + obj = _FRAME_OBJ.get(ftype) |
| 119 | + if obj is None: |
| 120 | + raise ValueError(f"Unknown frame_type '{ftype}'. Expected body/site/geom.") |
| 121 | + fid = mujoco.mj_name2id(model, obj, name) |
| 122 | + if fid < 0: |
| 123 | + raise ValueError(f"{ftype.capitalize()} '{name}' not found in model.") |
| 124 | + return fid |
| 125 | + |
| 126 | + |
| 127 | +def register_common_args(parser: argparse.ArgumentParser) -> None: |
| 128 | + """Register shared CLI flags used by all arm nodes: --xml, --keyframe, --mode, --frame-*.""" |
| 129 | + parser.add_argument( |
| 130 | + "--xml", |
| 131 | + default=_DEFAULT_XML, |
| 132 | + help=f"MJCF scene file (default: {_DEFAULT_XML})", |
| 133 | + ) |
| 134 | + parser.add_argument( |
| 135 | + "--keyframe", "-k", |
| 136 | + default="home", |
| 137 | + help="Initial keyframe name (default: home)", |
| 138 | + ) |
| 139 | + parser.add_argument( |
| 140 | + "--mode", |
| 141 | + choices=["right", "left", "bimanual"], |
| 142 | + default="bimanual", |
| 143 | + help="Which arm(s) to compute (default: bimanual)", |
| 144 | + ) |
| 145 | + parser.add_argument( |
| 146 | + "--frame-right", |
| 147 | + default=_DEFAULT_FRAME_RIGHT, |
| 148 | + help=f"EE frame name for right arm (default: {_DEFAULT_FRAME_RIGHT})", |
| 149 | + ) |
| 150 | + parser.add_argument( |
| 151 | + "--frame-type-right", |
| 152 | + choices=["body", "site", "geom"], |
| 153 | + default=_DEFAULT_FRAME_TYPE_RIGHT, |
| 154 | + help="EE frame type for right arm (default: site)", |
| 155 | + ) |
| 156 | + parser.add_argument( |
| 157 | + "--frame-left", |
| 158 | + default=_DEFAULT_FRAME_LEFT, |
| 159 | + help=f"EE frame name for left arm (default: {_DEFAULT_FRAME_LEFT})", |
| 160 | + ) |
| 161 | + parser.add_argument( |
| 162 | + "--frame-type-left", |
| 163 | + choices=["body", "site", "geom"], |
| 164 | + default=_DEFAULT_FRAME_TYPE_LEFT, |
| 165 | + help="EE frame type for left arm (default: site)", |
| 166 | + ) |
| 167 | + |
| 168 | + |
| 169 | +def setup_from_args(args: argparse.Namespace) -> ArmSetup: |
| 170 | + """Build ArmSetup from a namespace that contains the common CLI flags.""" |
| 171 | + return ArmSetup.from_args( |
| 172 | + xml=args.xml, |
| 173 | + mode=args.mode, |
| 174 | + frame_right=args.frame_right, |
| 175 | + frame_type_right=args.frame_type_right, |
| 176 | + frame_left=args.frame_left, |
| 177 | + frame_type_left=args.frame_type_left, |
| 178 | + keyframe=args.keyframe, |
| 179 | + ) |
0 commit comments