import os
import torch
import decord
import numpy as np
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torchvision import transforms
from typing import List, Tuple

# Distributed training context. These variables are typically populated by a
# distributed launcher such as torchrun and default to single-process values.
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))


class VideoDataset(Dataset):
    """Dataset for video classification training."""

    def __init__(
        self,
        file_list: List[Tuple[str, int]],
        input_size: int = 224,
        sequence_length: int = 16,
        use_rgb: bool = True,
        use_flip: bool = True,
        reprob: float = 0.0,
        seed: int = 0,
    ):
        self.file_list = file_list
        self.input_size = input_size
        self.sequence_length = sequence_length
        self.use_rgb = use_rgb
        self.use_flip = use_flip
        # Random-erasing probability; stored for callers but not applied in
        # the transform pipeline below.
        self.reprob = reprob
        self.seed = seed

        # Default normalization statistics (CLIP mean/std), scaled to the
        # 0-255 pixel range since frames are not divided by 255 beforehand.
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]) * 255
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711]) * 255

        # Keep a known-good item to substitute when a video fails to decode.
        self.replace_example_info = self.file_list[0]

        # Training-time transforms. RandomResizedCrop and RandomHorizontalFlip
        # operate on the trailing (H, W) dimensions, so the same crop/flip is
        # applied to every frame of the (C, F, H, W) tensor.
        self.transform = transforms.Compose([
            transforms.Lambda(lambda x: torch.from_numpy(x.astype(np.float32)).permute(3, 0, 1, 2)),  # (F, H, W, C) -> (C, F, H, W)
            transforms.RandomResizedCrop(
                size=(self.input_size, self.input_size),
                scale=(0.5, 1.0),
                ratio=(0.75, 1.3333),
                antialias=True,
            ),
            transforms.RandomHorizontalFlip(p=0.5) if self.use_flip else transforms.Lambda(lambda x: x),
            transforms.Lambda(lambda x: (x - self.mean.view(-1, 1, 1, 1)) / self.std.view(-1, 1, 1, 1)),
        ])

    def _sample_frames(self, video_path):
        """Sample `sequence_length` frames from a video for training."""
        try:
            decord.bridge.set_bridge('torch')
            decord_vr = decord.VideoReader(video_path, num_threads=4)
            duration = len(decord_vr)

            average_duration = duration // self.sequence_length
            all_index = []

            if average_duration > 0:
                # Split the video into equal segments and pick one random
                # frame per segment.
                all_index = list(
                    np.multiply(list(range(self.sequence_length)), average_duration) +
                    np.random.randint(average_duration, size=self.sequence_length)
                )
            elif duration > self.sequence_length:
                # (Unreachable in practice: average_duration == 0 implies
                # duration < sequence_length. Kept for parity with the common
                # TSN-style sampling code this follows.)
                all_index = list(
                    np.sort(np.random.randint(duration, size=self.sequence_length))
                )
            else:
                # Shorter than the clip length: left-pad by repeating frame 0.
                all_index = [0] * (self.sequence_length - duration) + list(range(duration))

            decord_vr.seek(0)
            # With the torch bridge active, get_batch returns a torch tensor.
            video_data = decord_vr.get_batch(all_index).numpy()

            if not self.use_rgb:
                # Decord decodes to RGB; reverse the channel axis for BGR.
                # ascontiguousarray avoids the negative strides a plain
                # reversed view would carry.
                video_data = np.ascontiguousarray(video_data[:, :, :, ::-1])

            return video_data

        except Exception as e:
            print(f"Error processing {video_path}: {e}")
            return None

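    # Worked example of the sampling logic above (illustrative numbers, not
    # from the source): with duration=160 and sequence_length=16,
    # average_duration=10, so the indices are [0, 10, ..., 150] plus an
    # independent random offset in [0, 10) per segment. With duration=10,
    # the index list is left-padded with frame 0: [0]*6 + [0, 1, ..., 9].
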
    def __len__(self) -> int:
        return len(self.file_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        video_path, video_label = self.file_list[idx]

        try:
            video_data = self._sample_frames(video_path)
            if video_data is None:
                video_path, video_label = self.replace_example_info
                video_data = self._sample_frames(video_path)
        except Exception:
            print(f"Error: {video_path}")
            video_path, video_label = self.replace_example_info
            video_data = self._sample_frames(video_path)

        if video_data is None:
            # Even the replacement sample failed; fail loudly rather than
            # passing None into the transform pipeline.
            raise RuntimeError(f"Could not decode video or its replacement: {video_path}")

        video_tensor = self.transform(video_data)

        if isinstance(video_label, np.ndarray):
            label_tensor = torch.from_numpy(video_label).long()
        else:
            # Covers int labels and anything else torch.tensor can coerce.
            label_tensor = torch.tensor(video_label, dtype=torch.long)

        return video_tensor, label_tensor


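# Quick shape check for a single dataset item (illustrative sketch; the path
# and label below are hypothetical placeholders):
#
#     ds = VideoDataset([("/path/to/clip.mp4", 3)])
#     video, label = ds[0]
#     video.shape  # torch.Size([3, 16, 224, 224]) -> (C, F, H, W)
#     label        # tensor(3)
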
class DALIWarper(object):
    """Iterator wrapper in the style of NVIDIA DALI's PyTorch iterators: moves
    each batch to the GPU, reports a fixed number of steps per epoch, and can
    automatically reset at epoch end. Works with DALI iterators (which
    implement __next__/reset) as well as plain iterables such as DataLoader.
    """

    def __init__(self, dali_iter, step_data_num, mode="train", auto_reset=True):
        self.loader = dali_iter
        # DALI iterators are already iterators; a DataLoader must be turned
        # into one explicitly before next() can be called on it.
        self.iter = dali_iter if hasattr(dali_iter, "__next__") else iter(dali_iter)
        self.step_data_num = step_data_num
        assert mode in ("train", "val", "test")
        self.mode = mode
        self.auto_reset = auto_reset

    def __next__(self):
        try:
            videos, labels = next(self.iter)
            return videos.cuda(), labels.cuda()
        except StopIteration:
            if self.auto_reset:
                self.reset()
                return next(self)
            raise

    def __iter__(self):
        return self

    def __len__(self):
        return self.step_data_num

    def reset(self):
        if hasattr(self.loader, "reset"):
            self.loader.reset()  # DALI-style iterators reset in place
        else:
            self.iter = iter(self.loader)  # re-create the DataLoader iterator


def create_video_dataloader(
    file_list: List[Tuple[str, int]],
    batch_size: int = 32,
    num_workers: int = 4,
    input_size: int = 224,
    sequence_length: int = 16,
    seed: int = 0,
    num_shard: int = 1,
    shard_id: int = 0,
    pin_memory: bool = True,
    reprob: float = 0.0,
    use_dali_warper: bool = False,
    dali_mode: str = "train",
    auto_reset: bool = True,
):
    """Create a PyTorch dataloader for video training data."""

    dataset = VideoDataset(
        file_list=file_list,
        input_size=input_size,
        sequence_length=sequence_length,
        seed=seed,
        reprob=reprob,
    )

    # Shuffling is delegated to the sampler; callers should invoke
    # dataloader.sampler.set_epoch(epoch) each epoch so the shuffle order
    # differs across epochs.
    sampler = DistributedSampler(
        dataset,
        num_replicas=num_shard,
        rank=shard_id,
        shuffle=True,
        seed=seed,
    )

    loader_kwargs = dict(
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
    )
    if num_workers > 0:
        # These options are only valid with worker processes; passing them
        # with num_workers=0 raises a ValueError.
        loader_kwargs.update(persistent_workers=True, prefetch_factor=2)

    dataloader = DataLoader(dataset, **loader_kwargs)

    # Optionally wrap the DataLoader in DALIWarper for a DALI-style,
    # GPU-resident iterator interface.
    if use_dali_warper:
        step_data_num = len(dataloader)
        dataloader = DALIWarper(
            dataloader,
            step_data_num=step_data_num,
            mode=dali_mode,
            auto_reset=auto_reset,
        )

    return dataloader
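

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The paths and labels below are
# hypothetical placeholders, and num_shard/shard_id are wired to the
# launcher-provided globals defined at the top of this file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_file_list = [
        ("/path/to/video_000.mp4", 0),  # hypothetical (video_path, label) pairs
        ("/path/to/video_001.mp4", 1),
    ]
    loader = create_video_dataloader(
        example_file_list,
        batch_size=2,
        num_workers=0,
        num_shard=world_size,
        shard_id=rank,
    )
    videos, labels = next(iter(loader))
    # videos: (B, C, F, H, W) float tensor; labels: (B,) long tensor
    print(videos.shape, labels.shape)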