diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 6f505cd52..fd85750c8 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -16,6 +16,7 @@ import json import os +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Any, Dict, List, Optional @@ -46,28 +47,45 @@ def _get_rel_paths(path_dir: str) -> List[str]: return paths -def _resize_image_folder(image_dir: str, resized_dir: str, factor: int) -> str: +def _resize_image(image_path: str, resized_path: str, factor: int) -> None: + """Resize a single image.""" + if os.path.isfile(resized_path): + return + os.makedirs(os.path.dirname(resized_path), exist_ok=True) + image = imageio.imread(image_path)[..., :3] + resized_size = ( + int(round(image.shape[1] / factor)), + int(round(image.shape[0] / factor)), + ) + resized_image = np.array(Image.fromarray(image).resize(resized_size, Image.BICUBIC)) + imageio.imwrite(resized_path, resized_image) + + +def _resize_image_folder( + image_dir: str, resized_dir: str, factor: int, num_workers: int = 0 +) -> str: """Resize image folder.""" - print(f"Downscaling images by {factor}x from {image_dir} to {resized_dir}.") + if num_workers <= 0: + num_workers = os.cpu_count() + print( + f"Downscaling images by {factor}x from {image_dir} to {resized_dir} " + f"({num_workers} threads)." + ) os.makedirs(resized_dir, exist_ok=True) image_files = _get_rel_paths(image_dir) - for image_file in tqdm(image_files): + tasks = [] + for image_file in image_files: image_path = os.path.join(image_dir, image_file) resized_path = os.path.join( resized_dir, os.path.splitext(image_file)[0] + ".png" ) - if os.path.isfile(resized_path): - continue - image = imageio.imread(image_path)[..., :3] - resized_size = ( - int(round(image.shape[1] / factor)), - int(round(image.shape[0] / factor)), - ) - resized_image = np.array( - Image.fromarray(image).resize(resized_size, Image.BICUBIC) - ) - imageio.imwrite(resized_path, resized_image) + tasks.append((image_path, resized_path, factor)) + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(_resize_image, *t) for t in tasks] + for future in tqdm(as_completed(futures), total=len(futures)): + future.result() return resized_dir @@ -81,6 +99,7 @@ def __init__( normalize: bool = False, test_every: int = 8, load_exposure: bool = False, + num_resize_workers: int = 0, ): self.data_dir = data_dir self.factor = factor @@ -212,7 +231,10 @@ def __init__( image_files = sorted(_get_rel_paths(image_dir)) if factor > 1 and os.path.splitext(image_files[0])[1].lower() == ".jpg": image_dir = _resize_image_folder( - colmap_image_dir, image_dir + "_png", factor=factor + colmap_image_dir, + image_dir + "_png", + factor=factor, + num_workers=num_resize_workers, ) image_files = sorted(_get_rel_paths(image_dir)) colmap_to_image = dict(zip(colmap_files, image_files))