-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathvisualization.py
More file actions
117 lines (90 loc) · 3.54 KB
/
visualization.py
File metadata and controls
117 lines (90 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import cv2
import hydra
import imageio
import os
import numpy as np
import torch as t
from omegaconf import DictConfig, OmegaConf
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.meltingpot import MeltingpotEnv
from utils import (
instantiate_agent
)
# Number of environment frames to render into the output video.
num_frames = 1000
# Directory holding the trained checkpoint. Defaults to the original author's
# experiment folder; set the VIS_MODEL_FOLDER environment variable to point at
# a different run without editing this file.
model_folder = os.environ.get(
    'VIS_MODEL_FOLDER',
    '/home/mila/j/juan.duque/projects/advantage-alignment/experiments/3klboisl',
)
# Checkpoint file inside model_folder; main() reads its 'actor_state_dict' entry.
model_name = 'agent.pt'
# Directory (relative path) where generated videos are written.
output_folder = 'videos'
def create_video_with_imageio(frames, output_folder, video_name='output.mp4', frame_rate=3):
    """
    Write a sequence of frames to a video file using imageio.

    Args:
    - frames (array-like): Frames with shape (time, height, width, channels).
      Anything ``np.asarray`` accepts (e.g. a list of HxWxC arrays) works.
      Non-uint8 input is clipped to [0, 255] and then cast to uint8; values are
      NOT rescaled, so float frames in [0, 1] must be multiplied by 255 first.
    - output_folder (str): Folder to save the output video (created if missing).
    - video_name (str): Name of the output video file.
    - frame_rate (int): Frames per second of the output video.

    Returns:
    - None
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)
    video_path = os.path.join(output_folder, video_name)

    frames = np.asarray(frames)
    if frames.dtype != np.uint8:
        # Clip before casting: a bare astype wraps integer values and is
        # undefined for floats outside the uint8 range.
        frames = np.clip(frames, 0, 255).astype(np.uint8)

    # macro_block_size=None disables imageio's ffmpeg auto-resize of frames
    # whose sides are not multiples of the macro block size.
    imageio.mimsave(video_path, frames, fps=frame_rate, macro_block_size=None)
    print(f"Video saved to {video_path}")
def create_video_from_frames(frames, output_folder, video_name='output.mp4', frame_rate=3, rgb=True):
    """
    Create an MP4 video from a numpy array of frames using OpenCV.

    Args:
    - frames (array-like): 4D array with shape (time, height, width, channels).
      Non-uint8 input is clipped to [0, 255] and cast to uint8 (not rescaled).
    - output_folder (str): Folder to save the output video (created if missing).
    - video_name (str): Name of the output video file.
    - frame_rate (int): Frame rate of the output video.
    - rgb (bool): If True (default), 3-channel frames are treated as RGB, the
      same way create_video_with_imageio treats them, and converted to the BGR
      order that cv2.VideoWriter expects. Pass False if frames are already BGR.

    Returns:
    - None

    Raises:
    - ValueError: if frames is not a non-empty 4D array.
    - RuntimeError: if OpenCV cannot open a writer for the output path/codec.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)

    frames = np.asarray(frames)
    if frames.ndim != 4 or frames.shape[0] == 0:
        raise ValueError(
            f"expected a non-empty (time, height, width, channels) array, got shape {frames.shape}"
        )
    if frames.dtype != np.uint8:
        # Clip before casting so out-of-range values saturate instead of wrapping.
        frames = np.clip(frames, 0, 255).astype(np.uint8)

    # All frames share the first frame's size; VideoWriter takes (width, height).
    height, width, channels = frames[0].shape

    # 'mp4v' selects the MPEG-4 Part 2 codec (not H.264).
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_path = os.path.join(output_folder, video_name)
    video_writer = cv2.VideoWriter(video_path, fourcc, frame_rate, (width, height))
    if not video_writer.isOpened():
        # Without this check every write() silently no-ops and we would still
        # report success.
        raise RuntimeError(f"Could not open video writer for {video_path}")

    try:
        for frame in frames:
            if rgb and channels == 3:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            video_writer.write(frame)
    finally:
        # Always finalize the container, even if a write fails.
        video_writer.release()
    print(f"Video saved to {video_path}")
@hydra.main(version_base="1.3", config_path="configs", config_name="meltingpot.yaml")
def main(cfg: DictConfig) -> None:
    """
    Roll out a trained agent in a Melting Pot scenario and save the rendered
    full-map frames as an MP4.

    Every player slot is controlled by the same agent instance, whose actor
    weights are loaded from ``os.path.join(model_folder, model_name)``.
    ``num_frames`` frames are collected in chunks of ``cfg['max_cxt_len']``
    steps via ``train._gen_sim`` and written to ``output_folder`` with
    ``create_video_with_imageio``.
    """
    # Imported inside main as in the original; presumably to keep module import
    # of the training code lazy — verify before moving to top level.
    from train import (
        _gen_sim,
        ReplayBuffer,
    )

    cxt_len = cfg['max_cxt_len']
    scenario = cfg['env']['scenario']
    env = ParallelEnv(1, lambda: MeltingpotEnv(scenario))
    try:
        agent = instantiate_agent(cfg)
        model_path = os.path.join(model_folder, model_name)
        # map_location lets a checkpoint saved on GPU load on the configured
        # device (e.g. a CPU-only machine).
        model_dict = t.load(model_path, map_location=cfg['device'])
        agent.actor.load_state_dict(model_dict['actor_state_dict'])

        # All players share one policy: the same object repeated, not copies.
        agents = [agent] * cfg['env']['num_agents']

        state = env.reset()
        replay_buffer = ReplayBuffer(
            env=env,
            replay_buffer_size=cfg['rb_size'],
            n_agents=len(agents),
            cxt_len=cxt_len,
            device=cfg['device'],
        )

        # Ceil division: always run at least one chunk (floor gave zero chunks
        # when cxt_len > num_frames, crashing np.concatenate) and cover the
        # remainder when cxt_len does not divide num_frames.
        num_chunks = -(-num_frames // cxt_len)
        images = []
        for _ in range(num_chunks):
            trajectory, state = _gen_sim(state, env, 1, cxt_len, agents, replay_buffer, cfg)
            # [0] presumably selects the single env in the batch (batch size 1
            # is passed to _gen_sim above) — verify against _gen_sim.
            images.append(trajectory.data['full_maps'][0].cpu().numpy())
    finally:
        # Shut down the ParallelEnv worker process(es) even on error.
        if not env.is_closed:
            env.close()

    # Trim the overshoot from the last chunk so exactly num_frames are written.
    frames = np.concatenate(images)[:num_frames]
    create_video_with_imageio(frames, output_folder=output_folder, video_name='adalign_5.mp4', frame_rate=3)
if __name__ == "__main__":
    # Register Python's built-in eval as the OmegaConf "eval" resolver so the
    # Hydra config can use ${eval:...} interpolations.
    # NOTE(review): eval runs arbitrary code from the config; only use this
    # with trusted config files.
    OmegaConf.register_new_resolver(name="eval", resolver=eval)
    main()