Skip to content

Commit e71078b

Browse files
committed
init commit
0 parents  commit e71078b

53 files changed

Lines changed: 17870 additions & 0 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 491 additions & 0 deletions
Large diffs are not rendered by default.

README.md

Whitespace-only changes.

dataloader/data_decord_torch.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
import os
2+
import random
3+
import torch
4+
import decord
5+
import numpy as np
6+
from torch.utils.data import Dataset, DataLoader, DistributedSampler
7+
from torchvision import transforms
8+
from typing import List, Tuple
9+
10+
# Global variables
# Distributed-training identity read from torchrun-style environment
# variables; defaults assume single-process (non-distributed) execution.
# NOTE(review): none of these are referenced elsewhere in this file —
# confirm they are consumed by importers before relying on them.
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
14+
15+
16+
class VideoDataset(Dataset):
    """Dataset for video classification training.

    Each item is a ``(video, label)`` pair: ``video`` is a float32 tensor of
    shape (C, T, H, W) — frames decoded with decord, randomly resized-cropped,
    optionally flipped, and mean/std normalized — and ``label`` is a long
    tensor (scalar for int labels, 1-D for array labels).
    """

    def __init__(
        self,
        file_list: List[Tuple[str, int]],
        input_size: int = 224,
        sequence_length: int = 16,
        use_rgb: bool = True,
        use_flip: bool = True,
        reprob: float = 0.0,
        seed: int = 0,
    ):
        """
        Args:
            file_list: (video_path, label) pairs; must be non-empty because
                the first entry doubles as the fallback for corrupt videos.
            input_size: spatial side length of the cropped output.
            sequence_length: number of frames sampled per clip.
            use_rgb: if False, the channel axis is reversed (BGR <-> RGB).
            use_flip: enable random horizontal flipping.
            reprob: random-erasing probability. NOTE(review): stored but not
                applied anywhere in this class — confirm intended.
            seed: RNG seed. NOTE(review): stored but never used here; frame
                sampling uses the global numpy RNG.
        """
        if not file_list:
            # The original code would fail later with a bare IndexError.
            raise ValueError("file_list must not be empty")

        self.file_list = file_list
        self.input_size = input_size
        self.sequence_length = sequence_length
        self.use_rgb = use_rgb
        self.use_flip = use_flip
        self.reprob = reprob
        self.seed = seed

        # CLIP-style normalization constants, scaled to 0-255 pixel range.
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]) * 255
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711]) * 255

        # A known item used as a replacement when a video fails to decode.
        self.replace_example_info = self.file_list[0]

        # Training-time augmentation pipeline. Input is a (T, H, W, C) uint8
        # numpy array; the first Lambda converts it to a (C, T, H, W) float
        # tensor (the astype() also copies, which makes negative-stride
        # BGR-flipped input safe to convert).
        self.transform = transforms.Compose([
            transforms.Lambda(lambda x: torch.from_numpy(x.astype(np.float32)).permute(3, 0, 1, 2)),  # FHWC -> CFHW
            transforms.RandomResizedCrop(
                size=(self.input_size, self.input_size),
                scale=(0.5, 1.0),
                ratio=(0.75, 1.3333),
                antialias=True
            ),
            transforms.RandomHorizontalFlip(p=0.5) if self.use_flip else transforms.Lambda(lambda x: x),
            transforms.Lambda(lambda x: (x - self.mean.view(-1, 1, 1, 1)) / self.std.view(-1, 1, 1, 1)),
        ])

    def _sample_frames(self, video_path):
        """Decode ``sequence_length`` frames from ``video_path``.

        Frames are sampled one per equal-length segment (with random jitter)
        when the video is long enough, randomly (sorted) when it is shorter
        than one frame per segment, and front-padded with frame 0 when the
        video has fewer frames than ``sequence_length``.

        Returns:
            A (T, H, W, C) numpy array, or ``None`` if decoding failed.
        """
        try:
            decord.bridge.set_bridge('torch')
            decord_vr = decord.VideoReader(video_path, num_threads=4)
            duration = len(decord_vr)

            average_duration = duration // self.sequence_length
            all_index = []

            if average_duration > 0:
                # One frame per segment plus uniform jitter inside the segment.
                all_index = list(
                    np.multiply(list(range(self.sequence_length)), average_duration) +
                    np.random.randint(average_duration, size=self.sequence_length)
                )
            elif duration > self.sequence_length:
                all_index = list(
                    np.sort(np.random.randint(duration, size=self.sequence_length))
                )
            else:
                # Too few frames: repeat frame 0 to pad up to sequence_length.
                all_index = [0] * (self.sequence_length - duration) + list(range(duration))

            frame_id_list = list(np.array(all_index))
            decord_vr.seek(0)
            video_data = decord_vr.get_batch(frame_id_list).numpy()

            if not self.use_rgb:  # Convert BGR to RGB if needed
                video_data = video_data[:, :, :, ::-1]

            return video_data

        except Exception as e:
            # Decoding is best-effort; the caller substitutes a fallback item.
            print(f"Error processing {video_path}: {e}")
            return None

    def __len__(self) -> int:
        return len(self.file_list)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        video_path, video_label = self.file_list[idx]

        # _sample_frames() never raises — it returns None on failure — so no
        # try/except is needed here; fall back to the known-good example.
        video_data = self._sample_frames(video_path)
        if video_data is None:
            video_path, video_label = self.replace_example_info
            video_data = self._sample_frames(video_path)
        if video_data is None:
            # Previously None was passed into self.transform, producing an
            # obscure crash. Fail loudly instead: the fallback item is broken.
            raise RuntimeError(f"Failed to decode fallback video: {video_path}")

        video_tensor = self.transform(video_data)

        # ndarray labels keep their shape; everything else (int, etc.) becomes
        # a scalar long tensor. The original int/else branches were identical.
        if isinstance(video_label, np.ndarray):
            label_tensor = torch.from_numpy(video_label).long()
        else:
            label_tensor = torch.tensor(video_label, dtype=torch.long)

        return video_tensor, label_tensor
118+
119+
120+
class DALIWarper(object):
    """Iterator adapter that yields CUDA-resident ``(videos, labels)`` batches.

    Wraps an underlying batch iterator (e.g. a DataLoader), moves each batch
    to the GPU, and — when ``auto_reset`` is enabled — transparently resets
    the wrapped iterator at epoch end instead of raising ``StopIteration``.
    """

    _MODES = ("train", "val", "test")

    def __init__(self, dali_iter, step_data_num, mode="train", auto_reset=True):
        """
        Args:
            dali_iter: underlying iterator; must support ``__next__`` and,
                when ``auto_reset`` is True, a ``reset()`` method.
            step_data_num: number of steps reported by ``__len__``.
            mode: one of "train", "val", "test".
            auto_reset: reset and continue instead of raising StopIteration.

        Raises:
            ValueError: if ``mode`` is invalid. (Was an ``assert``, which is
            silently stripped under ``python -O``.)
        """
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.iter = dali_iter
        self.step_data_num = step_data_num
        self.mode = mode
        self.auto_reset = auto_reset

    def __next__(self):
        # Loop instead of the original recursive self.__next__() call so a
        # reset does not grow the call stack.
        while True:
            try:
                videos, labels = next(self.iter)
            except StopIteration:
                if not self.auto_reset:
                    raise
                self.iter.reset()
                continue
            return videos.cuda(), labels.cuda()

    def __iter__(self):
        return self

    def __len__(self):
        return self.step_data_num

    def reset(self):
        """Manually reset the wrapped iterator."""
        self.iter.reset()
149+
150+
151+
def create_video_dataloader(
    file_list: List[Tuple[str, int]],
    batch_size: int = 32,
    num_workers: int = 4,
    input_size: int = 224,
    sequence_length: int = 16,
    seed: int = 0,
    num_shard: int = 1,
    shard_id: int = 0,
    pin_memory: bool = True,
    reprob: float = 0.0,
    use_dali_warper: bool = False,
    dali_mode: str = "train",
    auto_reset: bool = True,
    use_rgb: bool = True,
    use_flip: bool = True,
):
    """Create a PyTorch dataloader for video training data.

    Builds a ``VideoDataset`` over ``file_list``, shards it across
    ``num_shard`` replicas with a shuffling ``DistributedSampler``, and
    optionally wraps the resulting ``DataLoader`` in a ``DALIWarper``.

    Args:
        file_list: (video_path, label) pairs.
        batch_size: batch size per shard; incomplete batches are dropped.
        num_workers: DataLoader worker processes (0 = load in main process).
        input_size / sequence_length: forwarded to ``VideoDataset``.
        seed: shuffling seed for the sampler (and forwarded to the dataset).
        num_shard / shard_id: distributed world size and this process's rank.
        pin_memory: pin host memory for faster H2D copies.
        reprob: forwarded to ``VideoDataset``.
        use_dali_warper / dali_mode / auto_reset: wrap in ``DALIWarper``.
        use_rgb / use_flip: forwarded to ``VideoDataset`` (new, defaults
            match the dataset's own defaults, so existing callers are
            unaffected; previously these could not be configured here).

    Returns:
        A ``DataLoader``, or a ``DALIWarper`` around it when requested.
    """
    dataset = VideoDataset(
        file_list=file_list,
        input_size=input_size,
        sequence_length=sequence_length,
        use_rgb=use_rgb,
        use_flip=use_flip,
        seed=seed,
        reprob=reprob,
    )

    # The sampler handles shuffling, so DataLoader's shuffle stays unset.
    sampler = DistributedSampler(
        dataset,
        num_replicas=num_shard,
        rank=shard_id,
        shuffle=True,
        seed=seed
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
        persistent_workers=(num_workers > 0),
        # prefetch_factor is only legal with worker processes: passing it
        # alongside num_workers=0 makes DataLoader raise ValueError.
        prefetch_factor=2 if num_workers > 0 else None,
    )

    # Optionally wrap the DataLoader so batches are delivered on the GPU
    # with automatic epoch reset.
    if use_dali_warper:
        step_data_num = len(dataloader)
        dataloader = DALIWarper(
            dataloader,
            step_data_num=step_data_num,
            mode=dali_mode,
            auto_reset=auto_reset
        )

    return dataloader

0 commit comments

Comments
 (0)