Skip to content

Commit 4146c2a

Browse files
committed
Fix learning rate scheduling to update only after optimizer step in training loop
1 parent 9f941a9 commit 4146c2a

15 files changed

Lines changed: 395 additions & 2756 deletions

dataloader/ap_dataloader_dali.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(self, mode: str, source_params: Dict[str, Any]):
7373
self.fallback_example = self.file_list[0] if self.file_list else ("", 0)
7474

7575
def _get_frame_indices(self, num_frames: int) -> List[int]:
76-
76+
7777
if num_frames < self.sequence_length:
7878
indices = list(range(num_frames))
7979
indices += [num_frames - 1] * (self.sequence_length - num_frames)
@@ -134,40 +134,40 @@ def preprocess_videos(videos, mode, input_size, mean, std):
134134
output_layout="FHWC",
135135
)
136136

137-
# if mode == "train":
137+
if mode == "train":
138138
# Brightness/contrast
139-
# if fn.random.coin_flip(dtype=types.BOOL, probability=0.2):
140-
# videos = fn.brightness_contrast(
141-
# videos,
142-
# contrast=fn.random.uniform(range=(0.6, 1.4)),
143-
# brightness=fn.random.uniform(range=(-0.125, 0.125)),
144-
# device="gpu",
145-
# )
146-
147-
# # Saturation
148-
# if fn.random.coin_flip(dtype=types.BOOL, probability=0.2):
149-
# videos = fn.saturation(
150-
# videos,
151-
# saturation=fn.random.uniform(range=[0.6, 1.4]),
152-
# device="gpu",
153-
# )
154-
155-
# # Hue
156-
# if fn.random.coin_flip(dtype=types.BOOL, probability=0.2):
157-
# videos = fn.hue(
158-
# videos,
159-
# hue=fn.random.uniform(range=[-0.2, 0.2]),
160-
# device="gpu",
161-
# )
162-
163-
# # Color space conversion
164-
# if fn.random.coin_flip(dtype=types.BOOL, probability=0.1):
165-
# videos = fn.color_space_conversion(
166-
# videos,
167-
# image_type=types.RGB,
168-
# output_type=types.BGR,
169-
# device="gpu",
170-
# )
139+
if fn.random.coin_flip(dtype=types.BOOL, probability=0.8):
140+
videos = fn.brightness_contrast(
141+
videos,
142+
contrast=fn.random.uniform(range=(0.6, 1.4)),
143+
brightness=fn.random.uniform(range=(-0.125, 0.125)),
144+
device="gpu",
145+
)
146+
147+
# Saturation
148+
if fn.random.coin_flip(dtype=types.BOOL, probability=0.8):
149+
videos = fn.saturation(
150+
videos,
151+
saturation=fn.random.uniform(range=[0.6, 1.4]),
152+
device="gpu",
153+
)
154+
155+
# Hue
156+
if fn.random.coin_flip(dtype=types.BOOL, probability=0.8):
157+
videos = fn.hue(
158+
videos,
159+
hue=fn.random.uniform(range=[-0.2, 0.2]),
160+
device="gpu",
161+
)
162+
163+
# Color space conversion
164+
if fn.random.coin_flip(dtype=types.BOOL, probability=0.1):
165+
videos = fn.color_space_conversion(
166+
videos,
167+
image_type=types.RGB,
168+
output_type=types.BGR,
169+
device="gpu",
170+
)
171171

172172
# Unified normalization to FLOAT / CFHW
173173
videos = fn.crop_mirror_normalize(

dataloader/data_decord_torch.py

Lines changed: 0 additions & 208 deletions
This file was deleted.

0 commit comments

Comments
 (0)