Skip to content

Commit bcc20e4

Browse files
yzhautouskayatharvajoshi10github-actions[bot]yiyixuxu
authored
Add Cosmos3 action generation support (#13823)
* Add Cosmos3 action generation support * Add README action examples * Use do_classifier_free_guidance property * Remove unused method * Add action policy example to pipelines doc * Adding model selection for action example doc. * Remove redundant casts * Rename _pack_action_tokens to _prepare_action_segment * Move validation checks to check_inputs * Add action arguments in the __call__ docstring * Move action mode check to check_inputs * Rename action to action_tokens * Add warning for num_frames ovewrite attempt * Rename action_tokens to raw_actions * Remove scheduler config override * Refactor action to use CosmosActionCondition * Fix examples script to support flow_shift arg * Apply styling fixes * Remove CosmosActionCondition properties, move to pipeline * Replace validate wiht post init * Set height/width/num_frames to None, raise error if set ofr action * Fix action_dim default setting * Remove video argument before v2v is added * Fix None args * Add _EMBODIMENT_TO_RAW_ACTION_DIM mapping * Remove --raw-action-dim from README.md * Added prompt upsampler docs and examples * Apply style fixes * Bugfix for action mrope compression factor * Add action prompt json formatting, update docs * Add hand_pose raw_action_dim value * Add missing args to docstrings --------- Co-authored-by: Atharva Joshi <atjoshi@nvidia.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 25b85c1 commit bcc20e4

9 files changed

Lines changed: 1362 additions & 322 deletions

File tree

docs/source/en/api/pipelines/cosmos3.md

Lines changed: 234 additions & 225 deletions
Large diffs are not rendered by default.

examples/cosmos3/README.md

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,123 @@ python examples/cosmos3/inference_cosmos3.py \
4848
--enable-sound
4949
```
5050

51+
Action forward dynamics, robot domain (predict video from an observation video and a provided action chunk):
52+
53+
```bash
54+
python examples/cosmos3/inference_cosmos3.py \
55+
--model nano \
56+
--prompt "Put the pot to the left of the purple item." \
57+
--vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \
58+
--action-mode forward_dynamics \
59+
--action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.json" \
60+
--action-chunk-size 16 \
61+
--domain-name bridge_orig_lerobot \
62+
--resolution-tier 480 --fps 5 \
63+
--num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \
64+
--output results/cosmos3_forward_dynamics_robot
65+
```
66+
67+
Action forward dynamics, autonomous-vehicle domain:
68+
69+
```bash
70+
python examples/cosmos3/inference_cosmos3.py \
71+
--model nano \
72+
--prompt "You are an autonomous vehicle planning system." \
73+
--vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \
74+
--action-mode forward_dynamics \
75+
--action-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_action_25.json" \
76+
--action-chunk-size 60 \
77+
--domain-name av \
78+
--resolution-tier 480 --fps 10 \
79+
--num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \
80+
--output results/cosmos3_forward_dynamics_av
81+
```
82+
83+
Action inverse dynamics, robot domain (predict actions from an observed video):
84+
85+
```bash
86+
python examples/cosmos3/inference_cosmos3.py \
87+
--model nano \
88+
--prompt "Put the pot to the left of the purple item." \
89+
--vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \
90+
--action-mode inverse_dynamics \
91+
--action-chunk-size 16 \
92+
--domain-name bridge_orig_lerobot \
93+
--resolution-tier 480 --fps 5 \
94+
--num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \
95+
--output results/cosmos3_inverse_dynamics_robot
96+
```
97+
98+
Action inverse dynamics, autonomous-vehicle domain:
99+
100+
```bash
101+
python examples/cosmos3/inference_cosmos3.py \
102+
--model nano \
103+
--prompt "You are an autonomous vehicle planning system." \
104+
--vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \
105+
--action-mode inverse_dynamics \
106+
--action-chunk-size 60 \
107+
--domain-name av \
108+
--resolution-tier 480 --fps 10 \
109+
--num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \
110+
--output results/cosmos3_inverse_dynamics_av
111+
```
112+
113+
Action policy, robot domain (predict both future video and actions from the first observation frame):
114+
115+
```bash
116+
python examples/cosmos3/inference_cosmos3.py \
117+
--model nano \
118+
--prompt "Put the pot to the left of the purple item." \
119+
--vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/bridge_0.mp4" \
120+
--action-mode policy \
121+
--action-chunk-size 16 \
122+
--domain-name bridge_orig_lerobot \
123+
--resolution-tier 480 --fps 5 \
124+
--num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \
125+
--output results/cosmos3_policy_robot
126+
```
127+
128+
Action policy, autonomous-vehicle domain:
129+
130+
```bash
131+
python examples/cosmos3/inference_cosmos3.py \
132+
--model nano \
133+
--prompt "You are an autonomous vehicle planning system. Please go backward." \
134+
--vision-path "https://github.com/nvidia-cosmos/cosmos-dependencies/raw/refs/heads/assets/cosmos3/inputs/action/av_vision_25_73d01c91-51f0-46cf-9b76-5682a76fb349.mp4" \
135+
--action-mode policy \
136+
--action-chunk-size 60 \
137+
--domain-name av \
138+
--resolution-tier 480 --fps 10 \
139+
--num-inference-steps 30 --guidance-scale 1.0 --flow-shift 10.0 --seed 0 \
140+
--output results/cosmos3_policy_av
141+
```
142+
143+
Action modes use `action_chunk_size + 1` conditioning frames. `forward_dynamics` consumes `--action-path`; `inverse_dynamics` and `policy` write predicted actions to `sample_action.json` in model-normalized action space. This script loads `--vision-path` as a video for all action modes; `policy` and `forward_dynamics` condition only on the first frame, while `inverse_dynamics` uses the whole video.
144+
145+
Pass `--prompt` as a plain task description and select the camera perspective with `--view-point` (default `ego_view`); the pipeline builds the structured action caption (task, viewpoint, duration, FPS, resolution) the model was trained on. Do not hand-write the viewpoint sentence into `--prompt`.
146+
147+
`--resolution-tier` is a resolution *tier* (`256`/`480`/`704`/`720`). The tier keys a table of predefined aspect-ratio canvases; the one closest to the input aspect ratio becomes the padded conditioning canvas. It is not the output frame size: the input is downscaled (never upscaled) and padded to fill the canvas, then the padding is cropped from the latents so the decoded output follows the downscaled input content. `--height` / `--width` (and `--num-frames`) are ignored for action modes.
148+
149+
Pick the tier that matches the native resolution of your conditioning input (`480` for ~480p, `720` for ~720p). A tier below your input downscales it and discards detail; a tier above your input gains no resolution (content is never upscaled), wastes compute on padding, and is a train/inference distribution mismatch that can degrade quality.
150+
51151
### Useful flags
52152

53153
| Flag | Default | Description |
54154
|---|---|---|
55155
| `--prompt` | (required) | Text prompt. |
56-
| `--vision-path` | `None` | URL or local path for an image-conditioning frame (image-to-video). |
57-
| `--num-frames` | `189` | `1` = image, otherwise number of video frames (`189` ≈ 7.9 s @ 24 FPS). |
58-
| `--height` / `--width` | `720` / `1280` | Output resolution (must be a multiple of the VAE spatial scale factor). |
156+
| `--vision-path` | `None` | URL or local path for an image-conditioning frame (image-to-video), or the image/video conditioning for action modes. |
157+
| `--num-frames` | `189` | `1` = image, otherwise number of video frames (`189` ≈ 7.9 s @ 24 FPS). Ignored for action modes (derived from `--action-chunk-size`). |
158+
| `--height` / `--width` | `720` / `1280` | Output resolution (must be a multiple of the VAE spatial scale factor). Ignored for action modes; use `--resolution-tier`. |
159+
| `--resolution-tier` | `480` | Action resolution tier (`256`/`480`/`704`/`720`): selects the aspect bin / padded conditioning canvas, not the output size. |
59160
| `--fps` | `24.0` | Frame rate of the generated video. |
161+
| `--flow-shift` | `None` | Override `UniPCMultistepScheduler.flow_shift` (and force `use_karras_sigmas=False`); left at the checkpoint default when unset. Cosmos3 runs use `10.0`. |
60162
| `--enable-sound` | off | Generate a synchronized audio track. |
61-
| `--no-duration-template` | off | Skip the duration metadata sentence appended to the prompt and negative prompt. Ignored for `--num-frames 1`. |
62-
| `--no-resolution-template` | off | Skip the resolution metadata sentence appended to the prompt and negative prompt. |
163+
| `--action-mode` | `None` | Enable action conditioning/generation. One of `forward_dynamics`, `inverse_dynamics`, or `policy`. |
164+
| `--action-path` | `None` | URL or local JSON action path for `forward_dynamics`. |
165+
| `--action-chunk-size` | `None` | Number of action tokens. Action runs generate/use `action_chunk_size + 1` video frames. |
166+
| `--domain-name` | `None` | Action embodiment domain, for example `bridge_orig_lerobot` or `av`. |
167+
| `--view-point` | `ego_view` | Camera perspective for the action caption's framing (`ego_view`, `third_person_view`, `wrist_view`, `concat_view`). Action only. |
168+
| `--no-duration-template` | off | Skip the duration metadata sentence appended to the prompt and negative prompt. Ignored for `--num-frames 1` and for action modes (which build a structured caption instead). |
169+
| `--no-resolution-template` | off | Skip the resolution metadata sentence appended to the prompt and negative prompt. Ignored for action modes. |
63170
| `--output` | `.` | Directory to write `sample.jpg` or `sample.mp4`. |

examples/cosmos3/inference_cosmos3.py

Lines changed: 125 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
"""
2424

2525
import argparse
26+
import json
2627
import pathlib
28+
import urllib.request
2729

2830
import torch
2931
from huggingface_hub import snapshot_download
3032

31-
from diffusers import Cosmos3OmniPipeline
32-
from diffusers.utils import encode_video, export_to_video, load_image
33+
from diffusers import Cosmos3OmniPipeline, CosmosActionCondition, UniPCMultistepScheduler
34+
from diffusers.utils import encode_video, export_to_video, load_image, load_video
3335

3436

3537
HF_REPOS = {
@@ -38,6 +40,22 @@
3840
}
3941

4042

43+
def _load_action(path: str | None):
44+
if path is None:
45+
raise ValueError("--action-path is required for forward_dynamics mode.")
46+
if path.startswith(("http://", "https://")):
47+
with urllib.request.urlopen(path) as response:
48+
action = json.loads(response.read().decode("utf-8"))
49+
else:
50+
action = json.loads(pathlib.Path(path).read_text())
51+
tensor = torch.as_tensor(action, dtype=torch.float32)
52+
if tensor.ndim == 3 and tensor.shape[0] == 1:
53+
tensor = tensor.squeeze(0)
54+
if tensor.ndim != 2:
55+
raise ValueError(f"Cosmos3 action must have shape [T, D], got {tuple(tensor.shape)}.")
56+
return tensor
57+
58+
4159
def main():
4260
parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
4361
parser.add_argument("--prompt", required=True, help="Text prompt.")
@@ -50,24 +68,68 @@ def main():
5068
parser.add_argument(
5169
"--vision-path",
5270
default=None,
53-
help="Optional URL or local path for an image-conditioning frame (enables image-to-video).",
71+
help="Optional URL or local path for an image-conditioning frame, or an action conditioning video.",
5472
)
5573
parser.add_argument("--output", default=".", help="Directory to save generated video/image/audio files.")
56-
parser.add_argument("--height", type=int, default=720)
57-
parser.add_argument("--width", type=int, default=1280)
74+
parser.add_argument(
75+
"--height",
76+
type=int,
77+
default=None,
78+
help="Output height in pixels (default 720). Ignored for action modes; use --resolution-tier instead.",
79+
)
80+
parser.add_argument(
81+
"--width",
82+
type=int,
83+
default=None,
84+
help="Output width in pixels (default 1280). Ignored for action modes; use --resolution-tier instead.",
85+
)
5886
parser.add_argument(
5987
"--num-frames",
6088
type=int,
6189
default=189,
6290
help="Number of frames to generate. Use 1 for text-to-image; defaults to 189 for video (≈ 7.9s @ 24 FPS).",
6391
)
6492
parser.add_argument("--fps", type=float, default=24.0)
93+
parser.add_argument("--guidance-scale", type=float, default=6.0, help="Classifier-free guidance scale.")
94+
parser.add_argument("--num-inference-steps", type=int, default=35, help="Number of denoising steps.")
95+
parser.add_argument(
96+
"--flow-shift",
97+
type=float,
98+
default=None,
99+
help="Override the scheduler's flow-matching shift (UniPCMultistepScheduler.flow_shift).",
100+
)
101+
parser.add_argument("--seed", type=int, default=None, help="Random seed for latent initialization.")
65102
parser.add_argument(
66103
"--enable-sound",
67104
action="store_true",
68105
default=False,
69106
help="Generate sound alongside video (requires a sound-capable checkpoint).",
70107
)
108+
parser.add_argument(
109+
"--action-mode",
110+
choices=["forward_dynamics", "inverse_dynamics", "policy"],
111+
default=None,
112+
help="Enable Cosmos3 action generation with a loaded conditioning video.",
113+
)
114+
parser.add_argument("--action-path", default=None, help="JSON action path for forward_dynamics mode.")
115+
parser.add_argument("--action-chunk-size", type=int, default=None, help="Number of action tokens to generate/use.")
116+
parser.add_argument("--domain-name", default=None, help="Cosmos3 action embodiment domain name.")
117+
parser.add_argument(
118+
"--view-point",
119+
choices=["ego_view", "third_person_view", "wrist_view", "concat_view"],
120+
default="ego_view",
121+
help="Camera perspective for the action caption's cinematography.framing field (default: ego_view).",
122+
)
123+
parser.add_argument(
124+
"--resolution-tier",
125+
type=int,
126+
default=480,
127+
choices=[256, 480, 704, 720],
128+
help=(
129+
"Action resolution tier (256/480/704/720). Selects the aspect bin / padded conditioning canvas, "
130+
"not the output frame size."
131+
),
132+
)
71133
parser.add_argument(
72134
"--no-duration-template",
73135
dest="add_duration_template",
@@ -108,23 +170,59 @@ def main():
108170
)
109171
print("Pipeline loaded successfully.")
110172

173+
if args.flow_shift is not None:
174+
pipeline.scheduler = UniPCMultistepScheduler.from_config(
175+
pipeline.scheduler.config, flow_shift=args.flow_shift, use_karras_sigmas=False
176+
)
177+
111178
output_dir = pathlib.Path(args.output)
112179
output_dir.mkdir(parents=True, exist_ok=True)
113-
114-
image = load_image(args.vision_path) if args.vision_path is not None else None
115-
116-
result = pipeline(
117-
prompt=args.prompt,
118-
image=image,
119-
num_frames=args.num_frames,
120-
height=args.height,
121-
width=args.width,
122-
fps=args.fps,
123-
enable_sound=args.enable_sound,
124-
add_resolution_template=args.add_resolution_template,
125-
add_duration_template=args.add_duration_template,
126-
enable_safety_check=not args.no_safety_check,
127-
)
180+
generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None
181+
182+
if args.action_mode is not None:
183+
if args.vision_path is None:
184+
raise ValueError("--vision-path must point to a conditioning video for action modes.")
185+
if args.action_chunk_size is None:
186+
raise ValueError("--action-chunk-size is required for action modes.")
187+
video = load_video(args.vision_path)
188+
raw_actions = _load_action(args.action_path) if args.action_mode == "forward_dynamics" else None
189+
result = pipeline(
190+
prompt=args.prompt,
191+
action=CosmosActionCondition(
192+
mode=args.action_mode,
193+
chunk_size=args.action_chunk_size,
194+
domain_name=args.domain_name,
195+
resolution_tier=args.resolution_tier,
196+
raw_actions=raw_actions,
197+
video=video,
198+
view_point=args.view_point,
199+
),
200+
fps=args.fps,
201+
num_inference_steps=args.num_inference_steps,
202+
guidance_scale=args.guidance_scale,
203+
generator=generator,
204+
use_system_prompt=False,
205+
add_resolution_template=args.add_resolution_template,
206+
add_duration_template=args.add_duration_template,
207+
enable_safety_check=not args.no_safety_check,
208+
)
209+
else:
210+
image = load_image(args.vision_path) if args.vision_path is not None else None
211+
result = pipeline(
212+
prompt=args.prompt,
213+
image=image,
214+
num_frames=args.num_frames,
215+
height=args.height,
216+
width=args.width,
217+
fps=args.fps,
218+
num_inference_steps=args.num_inference_steps,
219+
enable_sound=args.enable_sound,
220+
guidance_scale=args.guidance_scale,
221+
generator=generator,
222+
add_resolution_template=args.add_resolution_template,
223+
add_duration_template=args.add_duration_template,
224+
enable_safety_check=not args.no_safety_check,
225+
)
128226

129227
if args.num_frames == 1:
130228
save_path = output_dir / "sample.jpg"
@@ -145,6 +243,13 @@ def main():
145243
export_to_video(result.video, str(save_path), fps=int(args.fps), quality=10, macro_block_size=1)
146244
print(f"Saved: {save_path}")
147245

246+
if result.action is not None:
247+
for action in result.action:
248+
action_path = output_dir / "sample_action.json"
249+
with open(action_path, "w") as f:
250+
json.dump(action.tolist(), f)
251+
print(f"Saved: {action_path}")
252+
148253

149254
if __name__ == "__main__":
150255
main()

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@
553553
"Cosmos2TextToImagePipeline",
554554
"Cosmos2VideoToWorldPipeline",
555555
"Cosmos3OmniPipeline",
556+
"CosmosActionCondition",
556557
"CosmosTextToWorldPipeline",
557558
"CosmosVideoToWorldPipeline",
558559
"CycleDiffusionPipeline",
@@ -1373,6 +1374,7 @@
13731374
Cosmos2TextToImagePipeline,
13741375
Cosmos2VideoToWorldPipeline,
13751376
Cosmos3OmniPipeline,
1377+
CosmosActionCondition,
13761378
CosmosTextToWorldPipeline,
13771379
CosmosVideoToWorldPipeline,
13781380
CycleDiffusionPipeline,

0 commit comments

Comments
 (0)