Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 51 additions & 27 deletions test_stage_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def parse_args():
parser.add_argument("--steps", type=int, default=20, help="DDIM sampling steps")
parser.add_argument("--fps", type=int)

parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)")
parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)")
parser.add_argument("--compare", type=int, default=1, help="render a side by side comparison of reference, pose, and result")
args = parser.parse_args()

print('Width:', args.W)
Expand All @@ -63,10 +64,21 @@ def scale_video(video,width,height):
return scaled_video


def main():
args = parse_args()

config = OmegaConf.load(args.config)
def run_video_generation(
config_path="./configs/test_stage_2.yaml",
width=768,
height=768,
length=300,
slice_num=48,
overlap=4,
cfg=3.5,
seed=99,
steps=20,
fps=None,
skip=1,
compare=1
):
config = OmegaConf.load(config_path)

if config.weight_dtype == "fp16":
weight_dtype = torch.float16
Expand Down Expand Up @@ -102,9 +114,7 @@ def main():
sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
scheduler = DDIMScheduler(**sched_kwargs)

generator = torch.manual_seed(args.seed)

width, height = args.W, args.H
generator = torch.manual_seed(seed)

# load pretrained weights
denoising_unet.load_state_dict(
Expand Down Expand Up @@ -143,17 +153,17 @@ def handle_single(ref_image_path,pose_video_path):
pose_images = read_frames(pose_video_path)
src_fps = get_fps(pose_video_path)
print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
L = min(args.L, len(pose_images))
L = min(length, len(pose_images))
pose_transform = transforms.Compose(
[transforms.Resize((height, width)), transforms.ToTensor()]
)
original_width,original_height = 0,0

pose_images = pose_images[::args.skip+1]
pose_images = pose_images[::skip+1]
print("processing length:", len(pose_images))
src_fps = src_fps // (args.skip + 1)
src_fps = src_fps // (skip + 1)
print("fps", src_fps)
L = L // ((args.skip + 1))
L = L // ((skip + 1))

for pose_image_pil in pose_images[: L]:
pose_tensor_list.append(pose_transform(pose_image_pil))
Expand All @@ -162,8 +172,8 @@ def handle_single(ref_image_path,pose_video_path):
pose_image_pil = pose_image_pil.resize((width,height))

# repeat the last frame so the final (slice_num - overlap) segment is complete
last_segment_frame_num = (L - args.S) % (args.S - args.O)
repeart_frame_num = (args.S - args.O - last_segment_frame_num) % (args.S - args.O)
last_segment_frame_num = (L - slice_num) % (slice_num - overlap)
repeart_frame_num = (slice_num - overlap - last_segment_frame_num) % (slice_num - overlap)
for i in range(repeart_frame_num):
pose_list.append(pose_list[-1])
pose_tensor_list.append(pose_tensor_list[-1])
Expand All @@ -183,38 +193,47 @@ def handle_single(ref_image_path,pose_video_path):
width,
height,
len(pose_list),
args.steps,
args.cfg,
steps,
cfg,
generator=generator,
context_frames=args.S,
context_frames=slice_num,
context_stride=1,
context_overlap=args.O,
context_overlap=overlap,
).videos


m1 = config.pose_guider_path.split('.')[0].split('/')[-1]
m2 = config.motion_module_path.split('.')[0].split('/')[-1]

save_dir_name = f"{time_str}-{args.cfg}-{m1}-{m2}"
save_dir_name = f"{time_str}-{cfg}-{m1}-{m2}"
save_dir = Path(f"./output/video-{date_str}/{save_dir_name}")
save_dir.mkdir(exist_ok=True, parents=True)

result = scale_video(video[:,:,:L], original_width, original_height)
output_path1 = f"{save_dir}/{ref_name}_{pose_name}_{cfg}_{steps}_{skip}.mp4"
save_videos_grid(
result,
f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}.mp4",
output_path1,
n_rows=1,
fps=src_fps if args.fps is None else args.fps,
)
fps=src_fps if fps is None else fps,
)

if not compare:
return [output_path1]

video = torch.cat([ref_image_tensor, pose_tensor[:,:,:L], video[:,:,:L]], dim=0)
video = scale_video(video, original_width, original_height)
output_path2 = f"{save_dir}/{ref_name}_{pose_name}_{cfg}_{steps}_{skip}_{m1}_{m2}.mp4"
save_videos_grid(
video,
f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}_{m1}_{m2}.mp4",
output_path2,
n_rows=3,
fps=src_fps if args.fps is None else args.fps,
fps=src_fps if fps is None else fps,
)

return [output_path1, output_path2]

all_video_paths = []

for ref_image_path_dir in config["test_cases"].keys():
if os.path.isdir(ref_image_path_dir):
Expand All @@ -228,10 +247,15 @@ def handle_single(ref_image_path,pose_video_path):
else:
pose_video_paths = [pose_video_path_dir]
for pose_video_path in pose_video_paths:
handle_single(ref_image_path, pose_video_path)

video_path = handle_single(ref_image_path, pose_video_path)
all_video_paths.extend(video_path)

return all_video_paths


if __name__ == "__main__":
    # Script entry point: parse CLI flags, run generation once, and print the
    # list of written video paths as a JSON array on stdout.

    # Imported here so the entry point does not depend on a module-level
    # `import json` that is not visible in this part of the file.
    import json

    args = parse_args()

    # Pass everything by keyword: run_video_generation takes twelve
    # parameters, and a positional call would silently mis-bind values if the
    # signature were ever reordered.
    video_paths = run_video_generation(
        config_path=args.config,
        width=args.W,
        height=args.H,
        length=args.L,
        slice_num=args.S,
        overlap=args.O,
        cfg=args.cfg,
        seed=args.seed,
        steps=args.steps,
        fps=args.fps,
        skip=args.skip,
        compare=args.compare,
    )
    print(json.dumps(video_paths))