diff --git a/mpx/config/config_aliengo_trot_two_step.py b/mpx/config/config_aliengo_trot_two_step.py
index 3bb9a37..22456ff 100644
--- a/mpx/config/config_aliengo_trot_two_step.py
+++ b/mpx/config/config_aliengo_trot_two_step.py
@@ -48,6 +48,8 @@
initial_state = base.initial_state
cost = partial(mpc_objectives.quadruped_wb_obj, True, n_joints, n_contact, N)
+cost_smooth = partial(mpc_objectives.quadruped_wb_smooth_cost, True, n_joints, n_contact, N)
+inequalities = partial(mpc_objectives.quadruped_wb_inequalities, n_joints, n_contact, 0.5, 44.0, 10.0)
hessian_approx = base.hessian_approx
dynamics = base.dynamics
@@ -64,3 +66,20 @@
solver_mode = "fddp"
max_torque = base.max_torque
min_torque = base.min_torque
+
+lipa_enforce_inequalities = True
+
+# Build the LIPA `SolverSettings` used by this task.
+# NOTE(review): the import is function-local, but the function is invoked at
+# module import time (see the assignment below), so importing this config
+# requires `primal_dual_lipa` even though the default solver_mode here is "fddp".
+def _lipa_settings():
+ from primal_dual_lipa.types import SolverSettings
+ return SolverSettings(
+ max_iterations=2000,
+ η0=1e9,
+ # Factor of 1.0: presumably keeps η fixed at η0 - confirm against primal_dual_lipa.
+ η_update_factor=1.0,
+ µ_update_factor=0.9,
+ cost_improvement_threshold=1e-3,
+ primal_violation_threshold=1e-5,
+ use_parallel_lqr=False,
+ num_parallel_line_search_steps=1,
+ )
+
+# Evaluated once, when this config module is imported.
+lipa_settings = _lipa_settings()
diff --git a/mpx/config/config_barrel_roll.py b/mpx/config/config_barrel_roll.py
index 05f6b12..8e8bbee 100644
--- a/mpx/config/config_barrel_roll.py
+++ b/mpx/config/config_barrel_roll.py
@@ -88,6 +88,7 @@
)
cost = partial(mpc_objectives.quadruped_wb_obj, False, n_joints, n_contact, N)
+cost_smooth = partial(mpc_objectives.quadruped_wb_smooth_cost, False, n_joints, n_contact, N)
hessian_approx = None
def dynamics(model, mjx_model, contact_id, body_id):
@@ -105,4 +106,24 @@ def dynamics(model, mjx_model, contact_id, body_id):
# dynamics = mpc_dyn_model.quadruped_wb_dynamics_learned_contact_model
# dynamics = mpc_dyn_model.quadruped_wb_dynamics_explicit_contact
max_torque = 40
-min_torque = -40
\ No newline at end of file
+min_torque = -40
+
+inequalities = partial(
+ mpc_objectives.quadruped_wb_inequalities, n_joints, n_contact, 0.5, 50.0, 20.0
+)
+lipa_enforce_inequalities = True
+
+# Build the LIPA `SolverSettings` used by the barrel-roll task.
+# NOTE(review): the import is function-local, but the function is invoked at
+# module import time (see the assignment below), so importing this config
+# requires `primal_dual_lipa` regardless of which solver is selected.
+def _lipa_settings():
+ from primal_dual_lipa.types import SolverSettings
+ return SolverSettings(
+ max_iterations=2000,
+ η0=1e9,
+ η_update_factor=1.1,
+ µ_update_factor=0.9,
+ cost_improvement_threshold=1e-3,
+ primal_violation_threshold=1e-5,
+ use_parallel_lqr=False,
+ num_parallel_line_search_steps=1,
+ )
+
+# Evaluated once, when this config module is imported.
+lipa_settings = _lipa_settings()
diff --git a/mpx/config/config_h1_jump_forward.py b/mpx/config/config_h1_jump_forward.py
index a13822b..a9ec2cd 100644
--- a/mpx/config/config_h1_jump_forward.py
+++ b/mpx/config/config_h1_jump_forward.py
@@ -41,6 +41,8 @@
torque_limits = base.torque_limits
cost = partial(mpc_objectives.h1_kinodynamic_obj, n_joints, n_contact, N)
+cost_smooth = partial(mpc_objectives.h1_kinodynamic_smooth_cost, n_joints, n_contact, N)
+inequalities = partial(mpc_objectives.h1_kinodynamic_inequalities, n_joints, n_contact, 0.7)
hessian_approx = base.hessian_approx
dynamics = base.dynamics
MPCWrapper = base.MPCWrapper
@@ -58,3 +60,37 @@
solver_mode = "fddp"
max_torque = base.max_torque
min_torque = base.min_torque
+
+lipa_enforce_inequalities = True
+
+# Build the default LIPA `SolverSettings` for the H1 jump-forward task.
+# NOTE(review): the import is function-local, but the function is invoked at
+# module import time (see the assignment below), so importing this config
+# requires `primal_dual_lipa` even though the default solver_mode here is "fddp".
+def _lipa_settings():
+ from primal_dual_lipa.types import SolverSettings
+ return SolverSettings(
+ max_iterations=2000,
+ η0=1e9,
+ # Factor of 1.0: presumably keeps η fixed at η0 - confirm against primal_dual_lipa.
+ η_update_factor=1.0,
+ µ_update_factor=0.9,
+ cost_improvement_threshold=1e-3,
+ primal_violation_threshold=1e-5,
+ num_iterative_refinement_steps=2,
+ use_parallel_lqr=False,
+ num_parallel_line_search_steps=1,
+ )
+
+# Evaluated once, when this config module is imported.
+lipa_settings = _lipa_settings()
+
+# Build an alternative LIPA `SolverSettings`: compared with `_lipa_settings`
+# in this file it uses fewer iterations (500), a smaller η0 (1e5) and a
+# growing η (update factor 2.0).
+# NOTE(review): presumably picked up when inequalities are enforced as true
+# constraints - confirm where `lipa_settings_enforce` is consumed.
+def _lipa_settings_enforce():
+ from primal_dual_lipa.types import SolverSettings
+ return SolverSettings(
+ max_iterations=500,
+ η0=1e5,
+ η_update_factor=2.0,
+ µ_update_factor=0.9,
+ cost_improvement_threshold=1e-3,
+ primal_violation_threshold=1e-5,
+ num_iterative_refinement_steps=2,
+ use_parallel_lqr=False,
+ num_parallel_line_search_steps=1,
+ )
+
+# Evaluated once, when this config module is imported.
+lipa_settings_enforce = _lipa_settings_enforce()
diff --git a/mpx/data/acrobot/scene.xml b/mpx/data/acrobot/scene.xml
index 2ca267d..1fd4d9f 100644
--- a/mpx/data/acrobot/scene.xml
+++ b/mpx/data/acrobot/scene.xml
@@ -3,7 +3,7 @@
-
+
diff --git a/mpx/data/aliengo/scene_flat.xml b/mpx/data/aliengo/scene_flat.xml
index 89e46fc..8ad354c 100644
--- a/mpx/data/aliengo/scene_flat.xml
+++ b/mpx/data/aliengo/scene_flat.xml
@@ -22,7 +22,7 @@
-
+
diff --git a/mpx/data/unitree_h1/mjx_scene_h1_walk.xml b/mpx/data/unitree_h1/mjx_scene_h1_walk.xml
index 110b364..4c0a500 100644
--- a/mpx/data/unitree_h1/mjx_scene_h1_walk.xml
+++ b/mpx/data/unitree_h1/mjx_scene_h1_walk.xml
@@ -6,7 +6,7 @@
-
+
diff --git a/mpx/examples/acrobot.py b/mpx/examples/acrobot.py
index 7d71ae3..15423a6 100644
--- a/mpx/examples/acrobot.py
+++ b/mpx/examples/acrobot.py
@@ -228,6 +228,6 @@ def step_controller(viewer=None):
parser = argparse.ArgumentParser()
parser.add_argument("--headless", action="store_true")
parser.add_argument("--steps", type=int, default=500)
- parser.add_argument("--solver", choices=("primal_dual", "fddp"), default="primal_dual")
+ parser.add_argument("--solver", choices=("primal_dual", "fddp", "lipa"), default="primal_dual")
args = parser.parse_args()
main(headless=args.headless, steps=args.steps, solver_mode=args.solver)
diff --git a/mpx/examples/mjx_h1.py b/mpx/examples/mjx_h1.py
index 7969fb7..df6944f 100644
--- a/mpx/examples/mjx_h1.py
+++ b/mpx/examples/mjx_h1.py
@@ -7,6 +7,10 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")
+# Select the EGL backend before the first `import mujoco` when recording video.
+# argparse accepts both `--video PATH` and `--video=PATH`, so match either form.
+if any(arg == "--video" or arg.startswith("--video=") for arg in sys.argv[1:]):
+    os.environ.setdefault("MUJOCO_GL", "egl")
+    os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
+
import jax
import jax.numpy as jnp
import mujoco
@@ -36,7 +40,7 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact):
return solve_mpc
-def main(steps=500):
+def main(steps=500, video=None, vx=0.0, vy=0.0, wz=0.0, fps=30, headless=False):
model = mujoco.MjModel.from_xml_path(
dir_path + "/../data/unitree_h1/mjx_scene_h1_walk.xml"
)
@@ -45,7 +49,7 @@ def main(steps=500):
model.opt.timestep = 1 / sim_frequency
mpc = mpc_wrapper.MPCWrapper(config, limited_memory=True)
- command_handle = sim_utils.KeyboardVelocityCommand()
+ command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz)
solve_mpc = _build_solve_fn(mpc)
reset_mpc = jax.jit(mpc.reset)
@@ -102,6 +106,27 @@ def step_controller():
mujoco.mj_step(model, data)
counter += 1
+ if headless or video is not None:
+ recorder = None
+ capture_period = max(1, int(round(sim_frequency / fps)))
+ if video is not None:
+ os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True)
+ recorder = sim_utils.VideoRecorder(model, video, fps=fps)
+ p_start = np.asarray(data.qpos[:3]).copy()
+ try:
+ for i in range(steps):
+ step_controller()
+ if recorder is not None and i % capture_period == 0:
+ recorder.capture(data)
+ finally:
+ if recorder is not None:
+ recorder.close()
+ print(f"Wrote video: {video}")
+ p_end = np.asarray(data.qpos[:3])
+ delta = p_end - p_start
+ print(f"Base position: start={p_start} end={p_end} delta={delta}")
+ return
+
with mujoco.viewer.launch_passive(
model,
data,
@@ -119,5 +144,20 @@ def step_controller():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=500)
+ parser.add_argument("--headless", action="store_true")
+ parser.add_argument("--video", type=str, default=None,
+ help="Write an mp4 of the run to this path (forces headless).")
+ parser.add_argument("--vx", type=float, default=0.0)
+ parser.add_argument("--vy", type=float, default=0.0)
+ parser.add_argument("--wz", type=float, default=0.0)
+ parser.add_argument("--fps", type=int, default=30)
args = parser.parse_args()
- main(steps=args.steps)
+ main(
+ steps=args.steps,
+ video=args.video,
+ vx=args.vx,
+ vy=args.vy,
+ wz=args.wz,
+ fps=args.fps,
+ headless=args.headless,
+ )
diff --git a/mpx/examples/mjx_h1_kinodynamic.py b/mpx/examples/mjx_h1_kinodynamic.py
index 043a29e..39d2893 100644
--- a/mpx/examples/mjx_h1_kinodynamic.py
+++ b/mpx/examples/mjx_h1_kinodynamic.py
@@ -7,6 +7,10 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")
+# Select the EGL backend before the first `import mujoco` when recording video.
+# argparse accepts both `--video PATH` and `--video=PATH`, so match either form.
+if any(arg == "--video" or arg.startswith("--video=") for arg in sys.argv[1:]):
+    os.environ.setdefault("MUJOCO_GL", "egl")
+    os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
+
import jax
import jax.numpy as jnp
import mujoco
@@ -35,7 +39,7 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact):
return solve_mpc
-def main(steps=500):
+def main(steps=500, video=None, vx=0.0, vy=0.0, wz=0.0, fps=30, headless=False):
model = mujoco.MjModel.from_xml_path(
dir_path + "/../data/unitree_h1/mjx_scene_h1_walk.xml"
)
@@ -44,7 +48,7 @@ def main(steps=500):
model.opt.timestep = 1 / sim_frequency
mpc = config.MPCWrapper(config, limited_memory=True)
- command_handle = sim_utils.KeyboardVelocityCommand()
+ command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz)
solve_mpc = _build_solve_fn(mpc)
reset_mpc = jax.jit(mpc.reset)
@@ -100,6 +104,27 @@ def step_controller():
mujoco.mj_step(model, data)
counter += 1
+ if headless or video is not None:
+ recorder = None
+ capture_period = max(1, int(round(sim_frequency / fps)))
+ if video is not None:
+ os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True)
+ recorder = sim_utils.VideoRecorder(model, video, fps=fps)
+ p_start = np.asarray(data.qpos[:3]).copy()
+ try:
+ for i in range(steps):
+ step_controller()
+ if recorder is not None and i % capture_period == 0:
+ recorder.capture(data)
+ finally:
+ if recorder is not None:
+ recorder.close()
+ print(f"Wrote video: {video}")
+ p_end = np.asarray(data.qpos[:3])
+ delta = p_end - p_start
+ print(f"Base position: start={p_start} end={p_end} delta={delta}")
+ return
+
with mujoco.viewer.launch_passive(
model,
data,
@@ -117,5 +142,20 @@ def step_controller():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=500)
+ parser.add_argument("--headless", action="store_true")
+ parser.add_argument("--video", type=str, default=None,
+ help="Write an mp4 of the run to this path (forces headless).")
+ parser.add_argument("--vx", type=float, default=0.0)
+ parser.add_argument("--vy", type=float, default=0.0)
+ parser.add_argument("--wz", type=float, default=0.0)
+ parser.add_argument("--fps", type=int, default=30)
args = parser.parse_args()
- main(steps=args.steps)
+ main(
+ steps=args.steps,
+ video=args.video,
+ vx=args.vx,
+ vy=args.vy,
+ wz=args.wz,
+ fps=args.fps,
+ headless=args.headless,
+ )
diff --git a/mpx/examples/mjx_quad.py b/mpx/examples/mjx_quad.py
index c223900..880ca75 100644
--- a/mpx/examples/mjx_quad.py
+++ b/mpx/examples/mjx_quad.py
@@ -8,6 +8,12 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")
+# Headless video recording uses `mujoco.Renderer`, which requires an OpenGL
+# backend to be configured before the first `import mujoco` in the process.
+# argparse accepts both `--video PATH` and `--video=PATH`, so match either form.
+if any(arg == "--video" or arg.startswith("--video=") for arg in sys.argv[1:]):
+    os.environ.setdefault("MUJOCO_GL", "egl")
+    os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
+
import jax
import jax.numpy as jnp
import mujoco
@@ -37,7 +43,16 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact):
return solve_mpc
-def main(headless=False, steps=500, scene="flat"):
+def main(
+ headless=False,
+ steps=500,
+ scene="flat",
+ video=None,
+ vx=0.0,
+ vy=0.0,
+ wz=0.0,
+ fps=30,
+):
model = mujoco.MjModel.from_xml_path(
dir_path + f"/../data/aliengo/scene_{scene}.xml"
)
@@ -47,7 +62,9 @@ def main(headless=False, steps=500, scene="flat"):
contact_ids = sim_utils.geom_ids(model, config.contact_frame)
mpc = mpc_wrapper.MPCWrapper(config, limited_memory=True)
- command_handle = sim_utils.KeyboardVelocityCommand()
+ # Headless+video: scripted velocity (no keyboard); viewer mode keeps the
+ # interactive arrow-key handle.
+ command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz)
solve_mpc = _build_solve_fn(mpc)
reset_mpc = jax.jit(mpc.reset)
@@ -112,9 +129,25 @@ def step_controller():
mujoco.mj_step(model, data)
counter += 1
- if headless:
- for _ in range(steps):
- step_controller()
+ if headless or video is not None:
+ recorder = None
+ capture_period = max(1, int(round(sim_frequency / fps)))
+ if video is not None:
+ os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True)
+ recorder = sim_utils.VideoRecorder(model, video, fps=fps)
+ p_start = np.asarray(data.qpos[:3]).copy()
+ try:
+ for i in range(steps):
+ step_controller()
+ if recorder is not None and i % capture_period == 0:
+ recorder.capture(data)
+ finally:
+ if recorder is not None:
+ recorder.close()
+ print(f"Wrote video: {video}")
+ p_end = np.asarray(data.qpos[:3])
+ delta = p_end - p_start
+ print(f"Base position: start={p_start} end={p_end} delta={delta}")
return
with mujoco.viewer.launch_passive(
@@ -141,9 +174,22 @@ def step_controller():
parser.add_argument("--steps", type=int, default=500)
parser.add_argument("--scene", type=str, default="flat")
parser.add_argument("--headless", action="store_true")
+ parser.add_argument("--video", type=str, default=None,
+ help="Write an mp4 of the run to this path (forces headless).")
+ parser.add_argument("--vx", type=float, default=0.0,
+ help="Forward velocity command (m/s) for headless/video runs.")
+ parser.add_argument("--vy", type=float, default=0.0)
+ parser.add_argument("--wz", type=float, default=0.0,
+ help="Yaw-rate command (rad/s).")
+ parser.add_argument("--fps", type=int, default=30)
args = parser.parse_args()
main(
headless=args.headless,
steps=args.steps,
scene=args.scene,
+ video=args.video,
+ vx=args.vx,
+ vy=args.vy,
+ wz=args.wz,
+ fps=args.fps,
)
diff --git a/mpx/examples/mjx_talos.py b/mpx/examples/mjx_talos.py
index d14ea2c..5c95e78 100644
--- a/mpx/examples/mjx_talos.py
+++ b/mpx/examples/mjx_talos.py
@@ -7,6 +7,10 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")
+# Select the EGL backend before the first `import mujoco` when recording video.
+# argparse accepts both `--video PATH` and `--video=PATH`, so match either form.
+if any(arg == "--video" or arg.startswith("--video=") for arg in sys.argv[1:]):
+    os.environ.setdefault("MUJOCO_GL", "egl")
+    os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
+
import jax
import jax.numpy as jnp
import mujoco
@@ -36,7 +40,7 @@ def solve_mpc(mpc_data, qpos, qvel, foot, command, contact):
return solve_mpc
-def main(steps=500):
+def main(steps=500, video=None, vx=0.0, vy=0.0, wz=0.0, fps=30, headless=False):
model = mujoco.MjModel.from_xml_path(
dir_path + "/../data/pal_talos/talos_motor_rough.xml"
)
@@ -45,7 +49,7 @@ def main(steps=500):
model.opt.timestep = 1 / sim_frequency
mpc = mpc_wrapper.MPCWrapper(config, limited_memory=True)
- command_handle = sim_utils.KeyboardVelocityCommand()
+ command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz)
solve_mpc = _build_solve_fn(mpc)
reset_mpc = jax.jit(mpc.reset)
@@ -102,6 +106,27 @@ def step_controller():
mujoco.mj_step(model, data)
counter += 1
+ if headless or video is not None:
+ recorder = None
+ capture_period = max(1, int(round(sim_frequency / fps)))
+ if video is not None:
+ os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True)
+ recorder = sim_utils.VideoRecorder(model, video, fps=fps)
+ p_start = np.asarray(data.qpos[:3]).copy()
+ try:
+ for i in range(steps):
+ step_controller()
+ if recorder is not None and i % capture_period == 0:
+ recorder.capture(data)
+ finally:
+ if recorder is not None:
+ recorder.close()
+ print(f"Wrote video: {video}")
+ p_end = np.asarray(data.qpos[:3])
+ delta = p_end - p_start
+ print(f"Base position: start={p_start} end={p_end} delta={delta}")
+ return
+
with mujoco.viewer.launch_passive(
model,
data,
@@ -119,5 +144,20 @@ def step_controller():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--steps", type=int, default=500)
+ parser.add_argument("--headless", action="store_true")
+ parser.add_argument("--video", type=str, default=None,
+ help="Write an mp4 of the run to this path (forces headless).")
+ parser.add_argument("--vx", type=float, default=0.0)
+ parser.add_argument("--vy", type=float, default=0.0)
+ parser.add_argument("--wz", type=float, default=0.0)
+ parser.add_argument("--fps", type=int, default=30)
args = parser.parse_args()
- main(steps=args.steps)
+ main(
+ steps=args.steps,
+ video=args.video,
+ vx=args.vx,
+ vy=args.vy,
+ wz=args.wz,
+ fps=args.fps,
+ headless=args.headless,
+ )
diff --git a/mpx/examples/offline_task.py b/mpx/examples/offline_task.py
index c72fe0c..ca89f9e 100644
--- a/mpx/examples/offline_task.py
+++ b/mpx/examples/offline_task.py
@@ -9,6 +9,12 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")
+# Headless video recording uses `mujoco.Renderer`, which requires an OpenGL
+# backend to be configured before the first `import mujoco` in the process.
+# argparse accepts both `--video PATH` and `--video=PATH`, so match either form.
+if any(arg == "--video" or arg.startswith("--video=") for arg in sys.argv[1:]):
+    os.environ.setdefault("MUJOCO_GL", "egl")
+    os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
+
import jax
import jax.numpy as jnp
import mujoco
@@ -52,15 +58,17 @@
},
}
-SOLVERS = ("primal_dual", "fddp")
+SOLVERS = ("primal_dual", "fddp", "lipa")
-def _clone_config(module_name, solver_mode):
+def _clone_config(module_name, solver_mode, lipa_enforce_inequalities=None):
module = importlib.import_module(module_name)
attrs = {name: getattr(module, name) for name in dir(module) if not name.startswith("__")}
config = SimpleNamespace(**attrs)
if solver_mode is not None:
config.solver_mode = solver_mode
+ if lipa_enforce_inequalities is not None:
+ config.lipa_enforce_inequalities = lipa_enforce_inequalities
return config
@@ -88,29 +96,57 @@ def _solve_wrapper_task(config, max_iter, verbose):
def _solve_direct_task(config, max_iter, verbose):
- _, solve = base_mpc_wrapper.build_solver_step(
- config,
- config.cost,
- config.dynamics,
- config.hessian_approx,
- limited_memory=False,
- )
- solve = jax.jit(solve)
- X, U, V, history, stats = offline_solver.run_offline_solve(
- solve,
- config.cost,
- config.dynamics,
- config.solver_mode,
- config.reference,
- config.parameter,
- config.W,
- config.x0,
- config.initial_X0,
- config.initial_U0,
- config.initial_V0,
- max_iter=max_iter,
- verbose=verbose,
- )
+ if getattr(config, "solver_mode", None) == "lipa":
+ from mpx.utils.lipa_solver import run_lipa_offline
+ from mpx.utils.mpc_wrapper import lipa_pick_cost_and_inequalities
+
+ (
+ lipa_cost,
+ lipa_inequalities,
+ lipa_settings,
+ lipa_warmup_cost,
+ lipa_warmup_settings,
+ ) = lipa_pick_cost_and_inequalities(config, config.cost)
+ X, U, V, history, stats = run_lipa_offline(
+ lipa_cost,
+ config.dynamics,
+ config.reference,
+ config.parameter,
+ config.W,
+ config.x0,
+ config.initial_X0,
+ config.initial_U0,
+ config.initial_V0,
+ settings=lipa_settings,
+ inequalities=lipa_inequalities,
+ warmup_cost=lipa_warmup_cost,
+ warmup_settings=lipa_warmup_settings,
+ verbose=verbose,
+ )
+ else:
+ _, solve = base_mpc_wrapper.build_solver_step(
+ config,
+ config.cost,
+ config.dynamics,
+ config.hessian_approx,
+ limited_memory=False,
+ )
+ solve = jax.jit(solve)
+ X, U, V, history, stats = offline_solver.run_offline_solve(
+ solve,
+ config.cost,
+ config.dynamics,
+ config.solver_mode,
+ config.reference,
+ config.parameter,
+ config.W,
+ config.x0,
+ config.initial_X0,
+ config.initial_U0,
+ config.initial_V0,
+ max_iter=max_iter,
+ verbose=verbose,
+ )
return {
"config": config,
"X": X,
@@ -123,9 +159,11 @@ def _solve_direct_task(config, max_iter, verbose):
}
-def solve_task(task_name, solver_mode=None, max_iter=100, verbose=True):
+def solve_task(
+ task_name, solver_mode=None, max_iter=100, verbose=True, lipa_enforce_inequalities=None
+):
task = TASKS[task_name]
- config = _clone_config(task["config"], solver_mode)
+ config = _clone_config(task["config"], solver_mode, lipa_enforce_inequalities)
benchmark_mode = task["benchmark_mode"]
if benchmark_mode == "direct":
result = _solve_direct_task(config, max_iter=max_iter, verbose=verbose)
@@ -185,7 +223,52 @@ def _predicted_base_positions(config, model, qpos_sequence):
return base_positions
-def _play_mujoco_trajectory(result, headless=False, loop=True, ghost_stride=1):
+def _record_offline_trajectory_video(result, video_path, fps=None, loop_count=1, width=1280, height=720):
+    """Render the optimised offline trajectory `X` to an mp4 via offscreen GL.
+
+    No viewer involved. Uses `MUJOCO_GL=egl` (auto-set when the example sees
+    `--video` in argv) and the shared `sim_utils.VideoRecorder`. States of `X`
+    are written as frames (every `subsample`-th state when `fps` is given),
+    and the whole pass is repeated `loop_count` times so short trajectories
+    aren't blink-and-miss-it.
+
+    Args:
+        result: Mapping providing at least "config", "scene_path" and "X".
+        video_path: Output mp4 path; missing parent directories are created.
+        fps: Target frame rate. `None` plays back at round(1 / config.dt).
+            Otherwise the trajectory is subsampled with stride
+            round((1 / config.dt) / fps) (at least 1) and the recorder fps is
+            recomputed from that stride, so it may differ from the request.
+        loop_count: Number of times the trajectory is rendered back to back.
+        width: Frame width passed to the recorder.
+        height: Frame height passed to the recorder.
+
+    Raises:
+        ValueError: If `fps` is given but is not positive.
+    """
+    # Fail fast: fps == 0 would otherwise raise ZeroDivisionError further down,
+    # after the model has already been loaded.
+    if fps is not None and fps <= 0:
+        raise ValueError(f"fps must be positive, got {fps}")
+
+    config = result["config"]
+    scene_path = result["scene_path"]
+    X = np.asarray(result["X"])
+    model = mujoco.MjModel.from_xml_path(scene_path)
+    data = mujoco.MjData(model)
+    os.makedirs(os.path.dirname(os.path.abspath(video_path)) or ".", exist_ok=True)
+
+    config_fps = 1.0 / float(config.dt)
+    if fps is None:
+        fps = int(round(config_fps))
+        subsample = 1
+    else:
+        subsample = max(1, int(round(config_fps / fps)))
+        # Report the rate that the chosen stride actually yields.
+        fps = int(round(config_fps / subsample))
+
+    recorder = sim_utils.VideoRecorder(model, video_path, fps=fps, width=width, height=height)
+    try:
+        for _ in range(loop_count):
+            for i, state in enumerate(X):
+                if i % subsample != 0:
+                    continue
+                qpos, qvel = _state_to_mujoco(config, state)
+                data.qpos = np.asarray(qpos)
+                data.qvel = np.asarray(qvel)
+                mujoco.mj_forward(model, data)
+                recorder.capture(data)
+    finally:
+        # Always finalise the file, even if rendering raised part-way through.
+        recorder.close()
+    print(f"Wrote video: {video_path} at {fps} fps (subsample {subsample}), {width}x{height}")
+
+
+def _play_mujoco_trajectory(result, headless=False, loop=True, ghost_stride=1, video=None, fps=None, width=None, height=None):
+ if video is not None:
+ # Recording mode: render trajectory frames to mp4 instead of opening
+ # the viewer. Implies headless.
+ _record_offline_trajectory_video(result, video, fps=fps, width=width or 1280, height=height or 720)
+ return
+
config = result["config"]
scene_path = result["scene_path"]
X = np.asarray(result["X"])
@@ -252,20 +335,35 @@ def _play_mujoco_trajectory(result, headless=False, loop=True, ghost_stride=1):
time.sleep(config.dt)
-def run_task(task_name, solver_mode=None, headless=False, max_iter=100, verbose=True, loop=True):
+def run_task(
+ task_name,
+ solver_mode=None,
+ headless=False,
+ max_iter=100,
+ verbose=True,
+ loop=True,
+ lipa_enforce_inequalities=None,
+ video=None,
+ fps=None,
+ width=None,
+ height=None,
+):
result = solve_task(
task_name,
solver_mode=solver_mode,
max_iter=max_iter,
verbose=verbose,
+ lipa_enforce_inequalities=lipa_enforce_inequalities,
)
stats = result["stats"]
+ enforce = getattr(result["config"], "lipa_enforce_inequalities", False)
+ enforce_tag = " | enforce-ineq" if (result["config"].solver_mode == "lipa" and enforce) else ""
print(
- f"{task_name} | {result['config'].solver_mode} | "
+ f"{task_name} | {result['config'].solver_mode}{enforce_tag} | "
f"iterations {stats['n_iterations']} | "
f"avg iter time {stats['average_iteration_time_ms']:.3f} ms"
)
- _play_mujoco_trajectory(result, headless=headless, loop=loop)
+ _play_mujoco_trajectory(result, headless=headless, loop=loop, video=video, fps=fps, width=width, height=height)
def build_parser(default_task=None):
@@ -284,6 +382,27 @@ def build_parser(default_task=None):
parser.add_argument("--max-iter", type=int, default=100)
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--no-loop", action="store_true")
+ parser.add_argument("--video", type=str, default=None,
+ help="Render the optimised trajectory to this mp4 path (forces headless).")
+ parser.add_argument("--fps", type=int, default=None,
+ help="Target FPS for video recording (e.g. 50 or 60). Subsamples the trajectory.")
+ parser.add_argument("--width", type=int, default=1280, help="Video width (default: 1280).")
+ parser.add_argument("--height", type=int, default=720, help="Video height (default: 720).")
+ enforce_group = parser.add_mutually_exclusive_group()
+ enforce_group.add_argument(
+ "--lipa-enforce-inequalities",
+ dest="lipa_enforce_inequalities",
+ action="store_true",
+ default=None,
+ help="(LIPA only) Enforce config inequalities as true constraints; overrides config attr.",
+ )
+ enforce_group.add_argument(
+ "--no-lipa-enforce-inequalities",
+ dest="lipa_enforce_inequalities",
+ action="store_false",
+ default=None,
+ help="(LIPA only) Disable enforcement; revert to soft-penalty cost shared with FDDP/PD.",
+ )
return parser
@@ -299,6 +418,11 @@ def main(default_task=None):
max_iter=args.max_iter,
verbose=not args.quiet,
loop=not args.no_loop,
+ lipa_enforce_inequalities=args.lipa_enforce_inequalities,
+ video=args.video,
+ fps=args.fps,
+ width=args.width,
+ height=args.height,
)
diff --git a/mpx/examples/srbd_quad.py b/mpx/examples/srbd_quad.py
index 7547cd9..fbd1625 100644
--- a/mpx/examples/srbd_quad.py
+++ b/mpx/examples/srbd_quad.py
@@ -8,6 +8,10 @@
sys.path.append(os.path.abspath(os.path.join(dir_path, "..")))
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")
+# Select the EGL backend before the first `import mujoco` when recording video.
+# argparse accepts both `--video PATH` and `--video=PATH`, so match either form.
+if any(arg == "--video" or arg.startswith("--video=") for arg in sys.argv[1:]):
+    os.environ.setdefault("MUJOCO_GL", "egl")
+    os.environ.setdefault("PYOPENGL_PLATFORM", "egl")
+
import jax
import jax.numpy as jnp
import mujoco
@@ -44,7 +48,16 @@ def _srbd_state(qpos, qvel):
)
-def main(headless=False, steps=500, scene="flat"):
+def main(
+ headless=False,
+ steps=500,
+ scene="flat",
+ video=None,
+ vx=0.0,
+ vy=0.0,
+ wz=0.0,
+ fps=30,
+):
model = mujoco.MjModel.from_xml_path(
dir_path + f"/../data/aliengo/scene_{scene}.xml"
)
@@ -53,7 +66,7 @@ def main(headless=False, steps=500, scene="flat"):
model.opt.timestep = 1.0 / sim_frequency
contact_ids = sim_utils.geom_ids(model, config.contact_frame)
- command_handle = sim_utils.KeyboardVelocityCommand(vx=0.0, vy=0.0, wz=0.0)
+ command_handle = sim_utils.KeyboardVelocityCommand(vx=vx, vy=vy, wz=wz)
mpc = mpc_wrapper_srbd.BatchedMPCControllerWrapper(config, n_env=1)
_reset_to_initial_state(model, data)
@@ -104,9 +117,25 @@ def step_controller():
mujoco.mj_step(model, data)
counter += 1
- if headless:
- for _ in range(steps):
- step_controller()
+ if headless or video is not None:
+ recorder = None
+ capture_period = max(1, int(round(sim_frequency / fps)))
+ if video is not None:
+ os.makedirs(os.path.dirname(os.path.abspath(video)) or ".", exist_ok=True)
+ recorder = sim_utils.VideoRecorder(model, video, fps=fps)
+ p_start = np.asarray(data.qpos[:3]).copy()
+ try:
+ for i in range(steps):
+ step_controller()
+ if recorder is not None and i % capture_period == 0:
+ recorder.capture(data)
+ finally:
+ if recorder is not None:
+ recorder.close()
+ print(f"Wrote video: {video}")
+ p_end = np.asarray(data.qpos[:3])
+ delta = p_end - p_start
+ print(f"Base position: start={p_start} end={p_end} delta={delta}")
return
with mujoco.viewer.launch_passive(
@@ -132,9 +161,20 @@ def step_controller():
parser.add_argument("--steps", type=int, default=500)
parser.add_argument("--scene", type=str, default="flat")
parser.add_argument("--headless", action="store_true")
+ parser.add_argument("--video", type=str, default=None,
+ help="Write an mp4 of the run to this path (forces headless).")
+ parser.add_argument("--vx", type=float, default=0.0)
+ parser.add_argument("--vy", type=float, default=0.0)
+ parser.add_argument("--wz", type=float, default=0.0)
+ parser.add_argument("--fps", type=int, default=30)
args = parser.parse_args()
main(
headless=args.headless,
steps=args.steps,
scene=args.scene,
+ video=args.video,
+ vx=args.vx,
+ vy=args.vy,
+ wz=args.wz,
+ fps=args.fps,
)
diff --git a/mpx/utils/lipa_solver.py b/mpx/utils/lipa_solver.py
new file mode 100644
index 0000000..7eafd74
--- /dev/null
+++ b/mpx/utils/lipa_solver.py
@@ -0,0 +1,298 @@
+"""Adapter that exposes the Primal-Dual LIPA solver via the mpx solver API.
+
+mpx solvers all share the signature
+ solve(reference, parameter, W, x0, X0, U0, V0) -> (X, U, V)
+with V having shape (N+1, n). LIPA expects a different problem statement
+(`Variables` pytree, cost/dynamics with a (x, u, theta, t) signature, no
+externalised W/reference/parameter). This module bridges the two.
+
+Note on offline use vs mpx's other solvers: mpx's primal_dual / fddp do
+*one* SQP/iLQR step per call and rely on `run_offline_solve`'s outer loop
+to converge. LIPA is a complete NLP solver — its main loop schedules µ
+(IPM barrier) and η (per-constraint AL penalty) internally. Calling it
+many times restarts those parameters at every call (see
+`primal_dual_lipa.optimizers.solve` lines 78-81), wasting iterations and
+producing misleading benchmark numbers. So for offline mode use
+`run_lipa_offline`, which calls LIPA exactly once and reports its
+internal iteration count and wall time.
+"""
+
+from functools import partial
+from timeit import default_timer as timer
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from primal_dual_lipa.optimizers import solve as lipa_solve
+from primal_dual_lipa.types import SolverSettings, Variables
+
+
+def _wrap_cost(cost):
+    """Adapt an mpx stage cost to a signature that carries LIPA's ``theta``.
+
+    ``cost`` is called as ``cost(W, reference, x, u, t)``. The returned
+    callable takes ``(W, reference, x, u, theta, t)`` and drops ``theta``.
+    """
+
+    def lipa_cost(W, reference, x, u, theta, t):
+        _ = theta  # accepted for signature compatibility, never used
+        return cost(W, reference, x, u, t)
+
+    return lipa_cost
+
+
+def _wrap_dynamics(dynamics):
+    """Adapt mpx dynamics to a signature that carries LIPA's ``theta``.
+
+    ``dynamics`` is called as ``dynamics(x, u, t, parameter=parameter)``. The
+    returned callable takes ``(parameter, x, u, theta, t)`` and drops ``theta``.
+    """
+
+    def lipa_dynamics(parameter, x, u, theta, t):
+        _ = theta  # accepted for signature compatibility, never used
+        return dynamics(x, u, t, parameter=parameter)
+
+    return lipa_dynamics
+
+
+def _wrap_inequalities(inequalities):
+    """Adapt an mpx inequality callable to a signature that carries ``theta``.
+
+    ``inequalities`` is called as ``inequalities(reference, x, u, t)``. The
+    returned callable takes ``(reference, x, u, theta, t)`` and drops ``theta``.
+    """
+
+    def lipa_inequalities(reference, x, u, theta, t):
+        _ = theta  # accepted for signature compatibility, never used
+        return inequalities(reference, x, u, t)
+
+    return lipa_inequalities
+
+
+def _empty_inequalities(reference, x, u, t):
+    """Return a zero-length constraint vector; every argument is ignored."""
+    return jnp.empty(0)
+
+
+@partial(jax.jit, static_argnames=("cost", "dynamics", "inequalities"))
+def _lipa_solve_with_stats(
+    cost, dynamics, inequalities, settings, reference, parameter, W, x0, X_in, U_in, V_in
+):
+    """Run one jitted LIPA solve and return ``(X, U, Y_dyn, iterations, no_errors)``.
+
+    With ``inequalities=None`` an empty constraint callable is substituted, so
+    the slack and inequality-multiplier blocks have zero columns. Otherwise
+    the constraint count is read from a trace-time evaluation of the callable
+    on the first warm-start sample ``(X_in[0], U_in[0])`` at ``t = 0``.
+    """
+
+    dtype = X_in.dtype
+    no_theta = jnp.empty(0, dtype=dtype)
+
+    # Bind the mpx-side extras (W, reference, parameter) so that LIPA only
+    # sees callables of (x, u, theta, t).
+    stage_cost = partial(_wrap_cost(cost), W, reference)
+    stage_dynamics = partial(_wrap_dynamics(dynamics), parameter)
+    if inequalities is None:
+        stage_inequalities = partial(_wrap_inequalities(_empty_inequalities), reference)
+    else:
+        stage_inequalities = partial(_wrap_inequalities(inequalities), reference)
+
+    # One row per node: U_in.shape[0] control stages plus the terminal node.
+    n_nodes = U_in.shape[0] + 1
+    n_ineq = stage_inequalities(X_in[0], U_in[0], no_theta, 0).shape[0]
+
+    def zero_block(width):
+        return jnp.zeros((n_nodes, width), dtype=dtype)
+
+    result, iterations, no_errors = lipa_solve(
+        vars_in=Variables(
+            X=X_in,
+            U=U_in,
+            S=zero_block(n_ineq),
+            Y_dyn=V_in,
+            Y_eq=zero_block(0),
+            Z=zero_block(n_ineq),
+            Theta=no_theta,
+        ),
+        x0=x0,
+        cost=stage_cost,
+        dynamics=stage_dynamics,
+        inequalities=stage_inequalities,
+        settings=settings,
+    )
+    return result.X, result.U, result.Y_dyn, iterations, no_errors
+
+
+def _default_settings():
+    """Build conservative `SolverSettings` for a problem with no config override.
+
+    Robustness is preferred over peak performance here; aggressive tuning
+    belongs in a per-config `lipa_settings`. When any JAX device reports the
+    ``gpu`` platform, parallel LQR and an 8-step parallel line search are
+    switched on; otherwise those two fields keep the `SolverSettings` defaults.
+    """
+
+    if any(device.platform == "gpu" for device in jax.devices()):
+        backend_overrides = dict(use_parallel_lqr=True, num_parallel_line_search_steps=8)
+    else:
+        backend_overrides = {}
+
+    # Settings are spelled as keyword identifiers (not string keys) so that the
+    # non-ASCII names are normalised by the parser exactly as in the callee.
+    return SolverSettings(
+        max_iterations=2000,
+        η0=1e3,
+        η_update_factor=1.0,
+        µ_update_factor=0.9,
+        cost_improvement_threshold=1e-3,
+        primal_violation_threshold=1e-5,
+        **backend_overrides,
+    )
+
+
+def build_lipa_solve(cost, dynamics, settings=None, *, inequalities=None):
+    """Return a `solve(reference, parameter, W, x0, X0, U0, V0) -> (X, U, V)`.
+
+    This is the entry point for online MPC (e.g. `MPCWrapper.run`). Offline
+    benchmarks should use `run_lipa_offline` instead: it is a single-call
+    path that surfaces LIPA's internal iteration count and avoids resetting
+    µ/η repeatedly.
+
+    When `settings` is None the backend-dependent defaults are used (parallel
+    LQR + parallel line search on GPU); override via `config.lipa_settings`.
+    Pass `inequalities=callable(reference, x, u, t) -> g` to enforce
+    ``g <= 0``; leave it out for the inequality-free behavior shared with the
+    FDDP / primal-dual solvers.
+    """
+
+    resolved_settings = _default_settings() if settings is None else settings
+
+    def solve(reference, parameter, W, x0, X0, U0, V0):
+        # Iteration count and error flag are dropped to match the mpx solver API.
+        X, U, V, *_stats = _lipa_solve_with_stats(
+            cost,
+            dynamics,
+            inequalities,
+            resolved_settings,
+            reference,
+            parameter,
+            W,
+            x0,
+            X0,
+            U0,
+            V0,
+        )
+        return X, U, V
+
+    return solve
+
+
+def run_lipa_offline(
+    cost,
+    dynamics,
+    reference,
+    parameter,
+    W,
+    x0,
+    X0,
+    U0,
+    V0,
+    *,
+    settings=None,
+    inequalities=None,
+    warmup_cost=None,
+    warmup_settings=None,
+    warmup=True,
+    verbose=True,
+):
+    """Solve a single OCP with LIPA and return stats matching `run_offline_solve`.
+
+    Unlike `run_offline_solve`, which loops one-step solvers until cost
+    plateaus, this calls LIPA exactly once. Reported `n_iterations` is
+    LIPA's internal IPM iteration count.
+
+    Two-phase warm start: if `warmup_cost` is provided (typically the soft-
+    penalty version of `cost`), an initial LIPA solve is run on that
+    inequality-free formulation, then the main inequality-enforcing solve
+    starts from its result. This sidesteps a class of local-basin pitfalls
+    where the AL term η·Jᵀc dominates and the IPM parks at a degenerate
+    iterate (e.g. on barrel_roll, the multi-shooting quaternion defect at
+    the apex of the maneuver hits a sign-flip singularity that the cold-
+    start solve cannot escape). The warm-start phase uses the same LIPA
+    solver — this is not bootstrapping from a different solver.
+
+    Args:
+        cost: Main stage cost ``cost(W, reference, x, u, t)``; also used to
+            evaluate the reported objective metrics.
+        dynamics: Callable accepting ``(x, u, t, parameter=...)``.
+        reference, parameter, W, x0: Problem data forwarded to the solver.
+        X0, U0, V0: Initial guess for states, controls and dynamics
+            multipliers.
+        settings: Settings for the main solve; `_default_settings()` if None.
+        inequalities: Optional ``inequalities(reference, x, u, t)`` callable.
+        warmup_cost: Cost for the phase-1 solve. Phase 1 runs only when both
+            this and `inequalities` are given.
+        warmup_settings: Settings for phase 1; falls back to `settings`.
+        warmup: In single-phase mode, run one untimed solve before the timed
+            one. Has no effect when phase 1 runs.
+        verbose: Print a small progress table.
+
+    Returns:
+        ``(X, U, V, history, stats)``. `history` is ``[X0, X]`` with the
+        caller's original state guess first, so it lines up with the
+        ``[initial, final]`` metric histories in `stats`.
+    """
+
+    from mpx.jax_ocp_solvers.jax_ocp_solvers import optimizers as ocp_opt
+
+    if settings is None:
+        settings = _default_settings()
+
+    # Keep the caller's guess: X0 is rebound to the phase-1 iterate below, but
+    # the returned history must describe the same point as the initial metrics.
+    X_initial = X0
+
+    model_evaluator = jax.jit(
+        partial(
+            ocp_opt.model_evaluator_helper,
+            partial(cost, W, reference),
+            partial(dynamics, parameter=parameter),
+            x0,
+        )
+    )
+
+    def evaluate(X_traj, U_traj):
+        """Return (objective, l2 cost, squared dynamics violation) as floats."""
+        g, c = model_evaluator(X_traj, U_traj)
+        g_np = np.asarray(g)
+        c_np = np.asarray(c)
+        return float(g), float(np.sqrt(np.sum(g_np * g_np))), float(np.sum(c_np * c_np))
+
+    initial_objective, initial_l2_cost, initial_dynamics_violation = evaluate(X0, U0)
+
+    if verbose:
+        print("{:<10} {:<20} {:<20} {:<20}".format("Iter", "Cost", "Constraint", "Time [ms]"))
+        print("{:<10d} {:<20.5f} {:<20.5f} {:<20}".format(0, initial_l2_cost, initial_dynamics_violation, "-"))
+
+    do_warmup_phase = warmup_cost is not None and inequalities is not None
+    # The discarded (untimed) call only happens in single-phase mode.
+    discard_first_call = bool(warmup) and not do_warmup_phase
+    warmup_phase_settings = warmup_settings if warmup_settings is not None else settings
+    warmup_iters = 0
+    warmup_time_ms = 0.0
+
+    if do_warmup_phase:
+        # Phase 1: solve the inequality-free (soft-penalty) problem once and
+        # use its (X, U, V) as the warm start for phase 2. We deliberately do
+        # NOT call _lipa_solve_with_stats twice (warmup + timed) here — the
+        # parallel-LQR scan reduction is not bit-deterministic across
+        # back-to-back invocations of the same compiled function on the same
+        # inputs (different floating-point summation order can land on
+        # numerically different iterates), and on stiff problems like
+        # h1_jump_forward that's enough drift to make phase 2 sometimes
+        # converge in 100 iters and sometimes hit max_iterations. The trade
+        # here is mildly inaccurate phase-1 wall-time accounting (first call
+        # includes any JIT compile that wasn't already cached) for
+        # reproducible phase-2 starting iterates.
+        start = timer()
+        Xp1, Up1, Vp1, iters_p1, _ = _lipa_solve_with_stats(
+            warmup_cost, dynamics, None, warmup_phase_settings,
+            reference, parameter, W, x0, X0, U0, V0,
+        )
+        Xp1.block_until_ready()
+        warmup_time_ms = 1e3 * (timer() - start)
+        warmup_iters = int(iters_p1)
+        if verbose:
+            print(
+                "{:<10s} {:<20s} {:<20s} {:<20.5f}".format(
+                    "ph1", "(warmup)", "(warmup)", warmup_time_ms
+                )
+            )
+            print(f"  Phase 1 (soft-penalty warm start): {warmup_iters} iters")
+        # Phase 2 starts from phase 1's iterate.
+        X0, U0, V0 = Xp1, Up1, Vp1
+
+    if discard_first_call:
+        # Single-phase mode: traditional warmup-then-timed pattern.
+        Xw, _, _, _, _ = _lipa_solve_with_stats(
+            cost, dynamics, inequalities, settings, reference, parameter, W, x0, X0, U0, V0
+        )
+        Xw.block_until_ready()
+
+    start = timer()
+    X, U, V, iterations, no_errors = _lipa_solve_with_stats(
+        cost, dynamics, inequalities, settings, reference, parameter, W, x0, X0, U0, V0
+    )
+    X.block_until_ready()
+    iteration_time_ms = 1e3 * (timer() - start)
+
+    final_objective, final_l2_cost, final_dynamics_violation = evaluate(X, U)
+    n_iters = int(iterations) + warmup_iters
+    converged = bool(no_errors)
+
+    if verbose:
+        print(
+            "{:<10d} {:<20.5f} {:<20.5f} {:<20.5f}".format(
+                1, final_l2_cost, final_dynamics_violation, iteration_time_ms
+            )
+        )
+        if do_warmup_phase:
+            print(
+                f"  Phase 2 (constrained): {int(iterations)} iters, no_errors: {converged}\n"
+                f"  Total LIPA internal iterations: {n_iters}"
+            )
+        else:
+            print(f"  LIPA internal iterations: {n_iters}, no_errors: {converged}")
+
+    # Both phases count towards the reported wall time.
+    total_time_ms = iteration_time_ms + warmup_time_ms
+    history = [X_initial, X]
+    stats = {
+        "n_iterations": n_iters,
+        "warmup_iterations": warmup_iters,
+        "converged": converged,
+        "warmup_discarded": discard_first_call,
+        "objective_history": [initial_objective, final_objective],
+        "l2_cost_history": [initial_l2_cost, final_l2_cost],
+        "dynamics_violation_history": [initial_dynamics_violation, final_dynamics_violation],
+        "metric_iteration_history": [0, 1],
+        "iteration_time_ms_history": [total_time_ms],
+        "initial_objective": initial_objective,
+        "initial_l2_cost": initial_l2_cost,
+        "initial_dynamics_violation": initial_dynamics_violation,
+        "average_iteration_time_ms": total_time_ms,
+        "final_objective": final_objective,
+        "final_l2_cost": final_l2_cost,
+        "final_dynamics_violation": final_dynamics_violation,
+    }
+    return X, U, V, history, stats
diff --git a/mpx/utils/mpc_wrapper.py b/mpx/utils/mpc_wrapper.py
index 4ca4b59..157cfc7 100644
--- a/mpx/utils/mpc_wrapper.py
+++ b/mpx/utils/mpc_wrapper.py
@@ -7,6 +7,7 @@
from mujoco.mjx._src.dataclasses import PyTreeNode
from mpx.jax_ocp_solvers.jax_ocp_solvers import optimizers
+from mpx.utils.lipa_solver import build_lipa_solve, run_lipa_offline
import mpx.utils.offline_solver as offline_solver
import mpx.utils.mpc_utils as mpc_utils
@@ -28,6 +29,41 @@ class MPCData(PyTreeNode):
mpx_data = MPCData
+def lipa_pick_cost_and_inequalities(config, cost):
+    """Choose the LIPA cost / inequality / settings combination for a config.
+
+    Returns ``(main_cost, inequalities, main_settings, warmup_cost,
+    warmup_settings)``:
+
+    * Not enforcing (``lipa_enforce_inequalities`` missing or falsy): the
+      main cost is ``cost`` (the soft-penalty cost), there are no
+      inequalities, settings come from ``config.lipa_settings`` and both
+      warm-up entries are ``None``.
+    * Enforcing: the main problem is ``config.cost_smooth`` with
+      ``config.inequalities``, using ``config.lipa_settings_enforce`` when
+      that is set and truthy, else ``config.lipa_settings``. The pair
+      (``cost``, ``lipa_settings``) is returned as the warm-up so the offline
+      path can do a two-phase solve: phase 1 on the inequality-free
+      formulation, phase 2 on the constrained one starting from phase 1's
+      iterate. This sidesteps local-basin pitfalls (notably the
+      multi-shooting quaternion singularity at the apex of the barrel-roll
+      maneuver) without bootstrapping from a different solver.
+
+    Raises:
+        ValueError: enforcing was requested but ``cost_smooth`` or
+            ``inequalities`` is missing from the config.
+    """
+    fallback_settings = getattr(config, "lipa_settings", None)
+    if not getattr(config, "lipa_enforce_inequalities", False):
+        return cost, None, fallback_settings, None, None
+
+    smooth = getattr(config, "cost_smooth", None)
+    constraints = getattr(config, "inequalities", None)
+    if any(item is None for item in (smooth, constraints)):
+        raise ValueError(
+            "lipa_enforce_inequalities=True requires both `cost_smooth` and "
+            "`inequalities` to be defined on the config."
+        )
+
+    main_settings = getattr(config, "lipa_settings_enforce", None) or fallback_settings
+    return smooth, constraints, main_settings, cost, fallback_settings
+
+
def build_solver_step(config, cost, dynamics, hessian_approx, limited_memory):
solver_mode = getattr(config, "solver_mode", "primal_dual")
@@ -54,6 +90,19 @@ def solve(reference, parameter, W, x0, X0, U0, V0):
return solver_mode, solve
+ if solver_mode == "lipa":
+ # Online MPC stays single-phase: per-step warm-start via the data
+ # carry already chains across calls, and a per-step phase-1 would
+ # double the compile + per-step compute. The two-phase flow is
+ # offline-only (see run_lipa_offline / runOffline).
+ lipa_cost, lipa_inequalities, lipa_settings, _, _ = lipa_pick_cost_and_inequalities(
+ config, cost
+ )
+ solve = build_lipa_solve(
+ lipa_cost, dynamics, settings=lipa_settings, inequalities=lipa_inequalities
+ )
+ return solver_mode, solve
+
raise ValueError(f"Unsupported MPC solver_mode: {solver_mode}")
@@ -326,21 +375,49 @@ def runOffline(self, qpos, qvel, *, return_stats=False, verbose=True, max_iter=1
U0 = self.initial_U0
V0 = self.initial_V0
- X0, U0, _, output, stats = offline_solver.run_offline_solve(
- self._solve,
- self.cost,
- self.dynamics,
- self.config.solver_mode,
- reference,
- parameter,
- W,
- x0,
- X0,
- U0,
- V0,
- max_iter=max_iter,
- verbose=verbose,
- )
+ if self.solver_mode == "lipa":
+ # LIPA is a complete NLP solver; one call converges. Looping it
+ # restarts the IPM µ/η each time, which inflates "iterations"
+ # and wall time without improving the solution.
+ (
+ lipa_cost,
+ lipa_inequalities,
+ lipa_settings,
+ lipa_warmup_cost,
+ lipa_warmup_settings,
+ ) = lipa_pick_cost_and_inequalities(self.config, self.cost)
+ X0, U0, _, output, stats = run_lipa_offline(
+ lipa_cost,
+ self.dynamics,
+ reference,
+ parameter,
+ W,
+ x0,
+ X0,
+ U0,
+ V0,
+ settings=lipa_settings,
+ inequalities=lipa_inequalities,
+ warmup_cost=lipa_warmup_cost,
+ warmup_settings=lipa_warmup_settings,
+ verbose=verbose,
+ )
+ else:
+ X0, U0, _, output, stats = offline_solver.run_offline_solve(
+ self._solve,
+ self.cost,
+ self.dynamics,
+ self.config.solver_mode,
+ reference,
+ parameter,
+ W,
+ x0,
+ X0,
+ U0,
+ V0,
+ max_iter=max_iter,
+ verbose=verbose,
+ )
if return_stats:
return X0, U0, reference, output, stats
diff --git a/mpx/utils/objectives.py b/mpx/utils/objectives.py
index 9c14941..b69c3fa 100644
--- a/mpx/utils/objectives.py
+++ b/mpx/utils/objectives.py
@@ -74,8 +74,40 @@ def friction_constraint(u):
H_constraint = J_friction_cone(u).T@H_penalty@J_friction_cone(u)
return J_x(x,u).T@W@J_x(x,u), J_u(x,u).T@W@J_u(x,u) + H_constraint, J_x(x,u).T@W@J_u(x,u)
-def quadruped_wb_obj(swing_tracking,n_joints,n_contact,N,W,reference,x, u, t):
-
+def _quadruped_wb_constraint_slacks(n_joints, n_contact, mu, torque_limit, dq_limit, x, u, friction_eps=1e-2):
+    """Return ``(s_friction, s_torque, s_dq)`` slack vectors for the quadruped limits.
+
+    * ``s_friction``: ``mu * Fz - sqrt(Fx^2 + Fy^2 + friction_eps)`` per contact,
+      with forces read from the tail of ``x``.
+    * ``s_torque`` / ``s_dq``: length ``2 * n_joints``, interleaving
+      ``limit + 1e-2 - v_i`` and ``limit + 1e-2 + v_i`` for the torques
+      ``u[:n_joints]`` and the joint velocities in ``x``.
+    """
+    joint_vel = x[13 + n_joints:13 + 2 * n_joints]
+    forces = x[13 + 2 * n_joints + 3 * n_contact:]
+    torques = u[:n_joints]
+
+    tangential = jnp.sqrt(
+        jnp.square(forces[0::3]) + jnp.square(forces[1::3]) + jnp.ones(n_contact) * friction_eps
+    )
+    s_friction = mu * forces[2::3] - tangential
+
+    # (2 * n_joints, n_joints) map whose row pairs are (-e_i, +e_i).
+    two_sided = jnp.kron(jnp.eye(n_joints), jnp.array([-1.0, 1.0])).T
+    s_torque = two_sided @ torques + (torque_limit + 1e-2)
+    s_dq = two_sided @ joint_vel + (dq_limit + 1e-2)
+    return s_friction, s_torque, s_dq
+
+
+def quadruped_wb_inequalities(
+    n_joints, n_contact, mu, torque_limit, dq_limit, reference, x, u, t, friction_eps=1e-12
+):
+    """Inequalities ``g(x, u, t) <= 0`` for the quadruped whole-body problem, in LIPA form.
+
+    The friction rows are multiplied by the reference contact mask, so they
+    are vacuous in swing; torque and joint-speed rows are always active. The
+    terminal stage (``t == reference.shape[0] - 1``) has no control input, so
+    every entry collapses to zero there.
+    """
+    s_friction, s_torque, s_dq = _quadruped_wb_constraint_slacks(
+        n_joints, n_contact, mu, torque_limit, dq_limit, x, u, friction_eps=friction_eps
+    )
+    mask_start = 13 + n_joints + 3 * n_contact
+    contact = reference[t, mask_start:mask_start + n_contact]
+
+    # A non-negative slack is feasible, hence the sign flip into g <= 0 form.
+    rows = [-contact * s_friction, -s_torque, -s_dq]
+    g = jnp.concatenate(rows)
+
+    terminal_index = reference.shape[0] - 1
+    return jnp.where(t == terminal_index, jnp.zeros_like(g), g)
+
+
+def quadruped_wb_smooth_cost(swing_tracking, n_joints, n_contact, N, W, reference, x, u, t):
+ """Stage cost without any soft-inequality penalties (friction/torque/dq)."""
p = x[:3]
quat = x[3:7]
q = x[7:7+n_joints]
@@ -94,18 +126,6 @@ def quadruped_wb_obj(swing_tracking,n_joints,n_contact,N,W,reference,x, u, t):
p_leg_ref = reference[t,13+n_joints:13+n_joints+3*n_contact]
contact = reference[t,13+n_joints+3*n_contact:13+n_joints+4*n_contact]
grf_ref = reference[t,13+n_joints+4*n_contact:13+n_joints+7*n_contact]
- mu = 0.5
- friction_cone = mu*grf[2::3] - jnp.sqrt(jnp.square(grf[1::3]) + jnp.square(grf[::3]) + jnp.ones(n_contact)*1e-2)
- friction_cone = penalty(friction_cone)
- torque_limits = jnp.array([
- 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44,
- 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44 ])
- #min grf
- # min_force = grf[2::3] - jnp.ones(n_contact)*10
- torque_limits = jnp.kron(jnp.eye(n_joints),(jnp.array([-1,1]))).T@tau+torque_limits + jnp.ones_like(torque_limits)*1e-2
-
- joint_speed_limits = jnp.ones(2*n_joints)*10
- joint_speed_limits = jnp.kron(jnp.eye(n_joints),(jnp.array([-1,1]))).T@dq + joint_speed_limits + jnp.ones_like(joint_speed_limits)*1e-2
if swing_tracking:
contact_map = jnp.ones(3*n_contact)
@@ -116,14 +136,26 @@ def quadruped_wb_obj(swing_tracking,n_joints,n_contact,N,W,reference,x, u, t):
(dp - dp_ref).T @ W[6+n_joints:9+n_joints,6+n_joints:9+n_joints] @ (dp - dp_ref) + (omega - omega_ref).T @ W[9+n_joints:12+n_joints,9+n_joints:12+n_joints] @ (omega - omega_ref) + dq.T @ W[12+n_joints:12+2*n_joints,12+n_joints:12+2*n_joints] @ dq +\
(contact_map*(p_leg - p_leg_ref)).T @W[12+2*n_joints:12+2*n_joints+3*n_contact,12+2*n_joints:12+2*n_joints+3*n_contact]@ (contact_map*(p_leg - p_leg_ref))+ \
tau.T @ W[12+2*n_joints+3*n_contact:12+3*n_joints+3*n_contact,12+2*n_joints+3*n_contact:12+3*n_joints+3*n_contact] @ tau +\
- (grf-grf_ref).T @ W[12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact,12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact] @ (grf-grf_ref) +\
- jnp.sum(penalty(torque_limits,1,1)) + jnp.sum(friction_cone*contact) + jnp.sum(penalty(joint_speed_limits,1,1))
+ (grf-grf_ref).T @ W[12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact,12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact] @ (grf-grf_ref)
term_cost = (p - p_ref).T @ W[:3,:3] @ (p - p_ref) + math.quat_sub(quat,quat_ref).T@W[3:6,3:6]@math.quat_sub(quat,quat_ref) + (q - q_ref).T @ W[6:6+n_joints,6:6+n_joints] @ (q - q_ref) +\
(dp - dp_ref).T @ W[6+n_joints:9+n_joints,6+n_joints:9+n_joints] @ (dp - dp_ref) + (omega - omega_ref).T @ W[9+n_joints:12+n_joints,9+n_joints:12+n_joints] @ (omega - omega_ref) + dq.T @ W[12+n_joints:12+2*n_joints,12+n_joints:12+2*n_joints] @ dq
-
return jnp.where(t == N, 0.5 * term_cost, 0.5 * stage_cost)
+
+def quadruped_wb_obj(swing_tracking, n_joints, n_contact, N, W, reference, x, u, t):
+    """Quadruped whole-body cost: the smooth cost plus soft-inequality penalties.
+
+    The penalties use fixed limits (``mu = 0.5``, torque 44.0, joint speed
+    10.0) and are applied on every stage except the terminal one (``t == N``).
+    """
+    s_friction, s_torque, s_dq = _quadruped_wb_constraint_slacks(
+        n_joints, n_contact, 0.5, 44.0, 10.0, x, u
+    )
+    mask_start = 13 + n_joints + 3 * n_contact
+    contact = reference[t, mask_start:mask_start + n_contact]
+
+    # The friction penalty is gated by the contact mask; the other two are not.
+    friction_term = jnp.sum(penalty(s_friction) * contact)
+    torque_term = jnp.sum(penalty(s_torque, 1, 1))
+    speed_term = jnp.sum(penalty(s_dq, 1, 1))
+    soft = friction_term + torque_term + speed_term
+
+    tracking = quadruped_wb_smooth_cost(swing_tracking, n_joints, n_contact, N, W, reference, x, u, t)
+    return tracking + jnp.where(t == N, 0.0, 0.5 * soft)
+
def quadruped_wb_hessian_gn(swing_tracking,n_joints,n_contact,W,reference,x, u, t):
contact = reference[t,13+n_joints+3*n_contact:13+n_joints+4*n_contact]
@@ -323,7 +355,43 @@ def torque_constraint(u):
return J_x(x,u).T@W@J_x(x,u), J_u(x,u).T@W@J_u(x,u), J_x(x,u).T@W@J_u(x,u)
-def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t):
+def _h1_kinodynamic_friction_slack(n_joints, n_contact, mu, u, friction_eps=1e-1):
+    """Return ``mu * Fz - sqrt(Fx^2 + Fy^2 + friction_eps)`` for each contact.
+
+    Contact forces are read from ``u[n_joints:]`` as consecutive
+    ``(Fx, Fy, Fz)`` triples.
+    """
+    forces = u[n_joints:]
+    tangential = jnp.sqrt(
+        jnp.square(forces[0::3]) + jnp.square(forces[1::3]) + jnp.ones(n_contact) * friction_eps
+    )
+    return mu * forces[2::3] - tangential
+
+
+def h1_kinodynamic_inequalities(n_joints, n_contact, mu, reference, x, u, t, friction_eps=1e-12):
+    """Inequalities ``g <= 0`` for the H1 kinodynamic problem, in LIPA form.
+
+    Two physical constraints are stacked (friction rows first, then ``Fz``
+    rows), both multiplied by the reference contact mask so they are vacuous
+    during swing:
+
+    * Friction cone: ``sqrt(Fx² + Fy²) <= mu * Fz``.
+    * ``Fz >= 0`` — a foot can only push into the ground, not pull. Without
+      this, the soft-penalty optimizer happily uses negative Fz to "anchor"
+      the foot, which produces unphysical jump take-offs and breaks the
+      Coulomb-cone interpretation: with Fz < 0 and `g = mu*Fz - sqrt(Fx²+Fy²)`
+      the cone becomes infeasible by `≈ |mu*Fz|` regardless of (Fx, Fy).
+
+    Other limits (joint-velocity, torque) live elsewhere and are not enforced
+    as constraints in this solver. At the terminal stage
+    (``t == reference.shape[0] - 1``) every entry is zero.
+    """
+    mask_start = 13 + n_joints + 3 * n_contact
+    contact = reference[t, mask_start:mask_start + n_contact]
+
+    normal_force = u[n_joints:][2::3]
+    s_friction = _h1_kinodynamic_friction_slack(n_joints, n_contact, mu, u, friction_eps=friction_eps)
+
+    g = jnp.concatenate([-contact * s_friction, -contact * normal_force])
+    terminal_index = reference.shape[0] - 1
+    return jnp.where(t == terminal_index, jnp.zeros_like(g), g)
+
+
+def h1_kinodynamic_smooth_cost(n_joints, n_contact, N, W, reference, x, u, t):
+ """H1 kinodynamic stage cost with the friction soft-penalty stripped out."""
p = x[:3]
quat = x[3:7]
@@ -342,14 +410,8 @@ def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t):
dp_ref = reference[t,7+n_joints:10+n_joints]
omega_ref = reference[t,10+n_joints:13+n_joints]
p_leg_ref = reference[t,13+n_joints:13+n_joints+3*n_contact]
- contact = reference[t,13+n_joints+3*n_contact:13+n_joints+4*n_contact]
grf_ref = reference[t,13+n_joints+4*n_contact:13+n_joints+7*n_contact]
- mu = 0.7
- friction_cone = mu * grf[2::3] - jnp.sqrt(
- jnp.square(grf[1::3]) + jnp.square(grf[::3]) + jnp.ones(n_contact) * 1e-1
- )
-
stage_cost = (
(p - p_ref).T @ W[:3,:3] @ (p - p_ref)
+ math.quat_sub(quat,quat_ref).T @ W[3:6,3:6] @ math.quat_sub(quat,quat_ref)
@@ -366,7 +428,6 @@ def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t):
+ (grf - grf_ref).T
@ W[12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact,12+3*n_joints+3*n_contact:12+3*n_joints+6*n_contact]
@ (grf - grf_ref)
- + jnp.sum(penalty(friction_cone) * contact)
)
term_cost = (
(p - p_ref).T @ W[:3,:3] @ (p - p_ref)
@@ -379,6 +440,14 @@ def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t):
return jnp.where(t == N, 0.5 * term_cost, 0.5 * stage_cost)
+
+def h1_kinodynamic_obj(n_joints, n_contact, N, W, reference, x, u, t):
+    """H1 kinodynamic cost: the smooth cost plus the friction soft penalty.
+
+    The penalty uses ``mu = 0.7``, is gated by the reference contact mask and
+    is dropped at the terminal stage (``t == N``).
+    """
+    mask_start = 13 + n_joints + 3 * n_contact
+    contact = reference[t, mask_start:mask_start + n_contact]
+    friction_slack = _h1_kinodynamic_friction_slack(n_joints, n_contact, 0.7, u)
+    soft = jnp.sum(penalty(friction_slack) * contact)
+
+    tracking = h1_kinodynamic_smooth_cost(n_joints, n_contact, N, W, reference, x, u, t)
+    return tracking + jnp.where(t == N, 0.0, 0.5 * soft)
+
def talos_wb_obj(n_joints,n_contact,N,W,reference,x, u, t):
p = x[:3]
diff --git a/mpx/utils/sim.py b/mpx/utils/sim.py
index 55b2aae..02dc0fb 100644
--- a/mpx/utils/sim.py
+++ b/mpx/utils/sim.py
@@ -393,3 +393,84 @@ def render_ghost_trajectory(
)
return ghost_geoms, scratch_data
+
+
+class VideoRecorder:
+    """Offscreen mp4 recorder built around `mujoco.Renderer`.
+
+    Designed for headless execution of the online MPC examples (mjx_quad,
+    mjx_h1, ...). Tracks the robot base by default (lookat = qpos[:3]).
+    Usable as a context manager; `close()` may be called more than once.
+
+    Requires:
+      * `MUJOCO_GL=egl` (or another working backend) set BEFORE `import mujoco`
+        — see `enable_offscreen_gl_for_video()`.
+      * `imageio[ffmpeg]` for libx264 mp4 output.
+    """
+
+    def __init__(
+        self,
+        model: mujoco.MjModel,
+        path: str,
+        *,
+        fps: int = 30,
+        width: int = 640,
+        height: int = 480,
+        distance: float = 3.0,
+        azimuth: float = 90.0,
+        elevation: float = -20.0,
+        bit_depth: int = 10,
+    ):
+        """Create the renderer, the ffmpeg writer and a free camera.
+
+        `bit_depth == 10` selects the ``yuv420p10le`` pixel format; any other
+        value selects ``yuv420p``.
+        """
+        import imageio  # late import: only needed when recording
+
+        self._closed = False
+        self._renderer = mujoco.Renderer(model, height=height, width=width)
+
+        # Pick pixel format based on bit depth
+        pix_fmt = "yuv420p10le" if bit_depth == 10 else "yuv420p"
+
+        try:
+            self._writer = imageio.get_writer(
+                path,
+                format="FFMPEG",
+                codec="libx264",
+                fps=fps,
+                macro_block_size=1,
+                output_params=["-pix_fmt", pix_fmt],
+            )
+        except Exception:
+            # The renderer already exists; release it rather than leak it
+            # when the writer cannot be opened, then surface the error.
+            self._closed = True
+            self._renderer.close()
+            raise
+
+        self._cam = mujoco.MjvCamera()
+        self._cam.distance = float(distance)
+        self._cam.azimuth = float(azimuth)
+        self._cam.elevation = float(elevation)
+        self._cam.lookat[:] = [0.0, 0.0, 0.0]
+
+    def capture(self, data: mujoco.MjData, lookat: np.ndarray | None = None) -> None:
+        """Render and append one frame.
+
+        Default `lookat` is the floating-base world position (`qpos[:3]`) when
+        the model has at least 3 generalised coords — the typical case for
+        legged-robot configs. For low-dimensional models (e.g. acrobot's
+        2-DOF qpos), falls back to the first non-world body's xpos (or the
+        world body's if it is the only one) so the camera stays pointed at
+        something sensible instead of crashing on the reshape.
+        """
+
+        if lookat is None:
+            qpos = np.asarray(data.qpos, dtype=np.float64)
+            if qpos.size >= 3:
+                lookat = qpos[:3]
+            else:
+                # Centre on the first non-world body's world position. Always 3D.
+                body_index = 1 if data.xpos.shape[0] > 1 else 0
+                lookat = np.asarray(data.xpos[body_index], dtype=np.float64)
+        self._cam.lookat[:] = np.asarray(lookat, dtype=np.float64).reshape(3)
+        self._renderer.update_scene(data, self._cam)
+        self._writer.append_data(self._renderer.render())
+
+    def close(self) -> None:
+        """Finalize the video file and release the renderer (idempotent)."""
+        if self._closed:
+            return
+        self._closed = True
+        try:
+            self._writer.close()
+        finally:
+            # Release the renderer even if finalizing the file fails.
+            self._renderer.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *_exc):
+        self.close()
diff --git a/pyproject.toml b/pyproject.toml
index 1bd39d6..b6a1ac3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,8 @@ dependencies = [
"jax[cuda12]",
"mujoco",
"mujoco-mjx",
- "trajax @ git+https://github.com/google/trajax"
+ "trajax @ git+https://github.com/google/trajax",
+ "primal-dual-lipa @ git+https://github.com/joaospinto/primal-dual-lipa"
]
[project.urls]