diff --git a/test_stage_2.py b/test_stage_2.py index bed3ee6..4bd289d 100644 --- a/test_stage_2.py +++ b/test_stage_2.py @@ -40,7 +40,8 @@ def parse_args(): parser.add_argument("--steps", type=int, default=20, help="DDIM sampling steps") parser.add_argument("--fps", type=int) - parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)") + parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)") + parser.add_argument("--compare", type=int, default=1, help="render a side by side comparison of reference, pose, and result") args = parser.parse_args() print('Width:', args.W) @@ -63,10 +64,21 @@ def scale_video(video,width,height): return scaled_video -def main(): - args = parse_args() - - config = OmegaConf.load(args.config) +def run_video_generation( + config_path="./configs/test_stage_2.yaml", + width=768, + height=768, + length=300, + slice_num=48, + overlap=4, + cfg=3.5, + seed=99, + steps=20, + fps=None, + skip=1, + compare=1 +): + config = OmegaConf.load(config_path) if config.weight_dtype == "fp16": weight_dtype = torch.float16 @@ -102,9 +114,7 @@ def main(): sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs) scheduler = DDIMScheduler(**sched_kwargs) - generator = torch.manual_seed(args.seed) - - width, height = args.W, args.H + generator = torch.manual_seed(seed) # load pretrained weights denoising_unet.load_state_dict( @@ -143,17 +153,17 @@ def handle_single(ref_image_path,pose_video_path): pose_images = read_frames(pose_video_path) src_fps = get_fps(pose_video_path) print(f"pose video has {len(pose_images)} frames, with {src_fps} fps") - L = min(args.L, len(pose_images)) + L = min(length, len(pose_images)) pose_transform = transforms.Compose( [transforms.Resize((height, width)), transforms.ToTensor()] ) original_width,original_height = 0,0 - pose_images = pose_images[::args.skip+1] + pose_images = pose_images[::skip+1] print("processing length:", len(pose_images)) - src_fps = src_fps // (args.skip + 1) + src_fps = src_fps // (skip + 1) print("fps", src_fps) - L = L // ((args.skip + 1)) + L = L // ((skip + 1)) for pose_image_pil in pose_images[: L]: pose_tensor_list.append(pose_transform(pose_image_pil)) @@ -162,8 +172,8 @@ def handle_single(ref_image_path,pose_video_path): pose_image_pil = pose_image_pil.resize((width,height)) # repeart the last segment - last_segment_frame_num = (L - args.S) % (args.S - args.O) - repeart_frame_num = (args.S - args.O - last_segment_frame_num) % (args.S - args.O) + last_segment_frame_num = (L - slice_num) % (slice_num - overlap) + repeart_frame_num = (slice_num - overlap - last_segment_frame_num) % (slice_num - overlap) for i in range(repeart_frame_num): pose_list.append(pose_list[-1]) pose_tensor_list.append(pose_tensor_list[-1]) @@ -183,38 +193,47 @@ def handle_single(ref_image_path,pose_video_path): width, height, len(pose_list), - args.steps, - args.cfg, + steps, + cfg, generator=generator, - context_frames=args.S, + context_frames=slice_num, context_stride=1, - context_overlap=args.O, + context_overlap=overlap, ).videos m1 = config.pose_guider_path.split('.')[0].split('/')[-1] m2 = config.motion_module_path.split('.')[0].split('/')[-1] - save_dir_name = f"{time_str}-{args.cfg}-{m1}-{m2}" + save_dir_name = f"{time_str}-{cfg}-{m1}-{m2}" save_dir = Path(f"./output/video-{date_str}/{save_dir_name}") save_dir.mkdir(exist_ok=True, parents=True) result = scale_video(video[:,:,:L], original_width, original_height) + output_path1 = f"{save_dir}/{ref_name}_{pose_name}_{cfg}_{steps}_{skip}.mp4" save_videos_grid( result, - f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}.mp4", + output_path1, n_rows=1, - fps=src_fps if args.fps is None else args.fps, - ) + fps=src_fps if fps is None else fps, + ) + + if not compare: + return [output_path1] video = torch.cat([ref_image_tensor, pose_tensor[:,:,:L], video[:,:,:L]], dim=0) video = scale_video(video, original_width, original_height) + output_path2 = f"{save_dir}/{ref_name}_{pose_name}_{cfg}_{steps}_{skip}_{m1}_{m2}.mp4" save_videos_grid( video, - f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}_{m1}_{m2}.mp4", + output_path2, n_rows=3, - fps=src_fps if args.fps is None else args.fps, + fps=src_fps if fps is None else fps, ) + + return [output_path1, output_path2] + + all_video_paths = [] for ref_image_path_dir in config["test_cases"].keys(): if os.path.isdir(ref_image_path_dir): @@ -228,10 +247,15 @@ def handle_single(ref_image_path,pose_video_path): else: pose_video_paths = [pose_video_path_dir] for pose_video_path in pose_video_paths: - handle_single(ref_image_path, pose_video_path) - + video_path = handle_single(ref_image_path, pose_video_path) + all_video_paths.extend(video_path) + return all_video_paths if __name__ == "__main__": - main() + args = parse_args() + video_paths = run_video_generation( + args.config, args.W, args.H, args.L, args.S, args.O, args.cfg, args.seed, args.steps, args.fps, args.skip, args.compare + ) + print(json.dumps(video_paths))