-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Expand file tree
/
Copy pathrunner.py
More file actions
250 lines (231 loc) · 9.69 KB
/
runner.py
File metadata and controls
250 lines (231 loc) · 9.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import os, torch
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
def _pad_frames(frames, target_frames):
if target_frames is None:
return frames
if len(frames) >= target_frames:
return frames[:target_frames]
if len(frames) == 0:
raise ValueError("Cannot pad empty frame list.")
pad_frame = frames[-1]
return frames + [pad_frame] * (target_frames - len(frames))
def _frame_to_tensor(frame, min_value=-1.0, max_value=1.0):
if isinstance(frame, torch.Tensor):
tensor = frame
if tensor.dim() == 3 and tensor.shape[0] not in (1, 3):
tensor = tensor.permute(2, 0, 1)
return tensor
array = np.array(frame, dtype=np.float32)
tensor = torch.from_numpy(array).permute(2, 0, 1)
tensor = tensor * ((max_value - min_value) / 255.0) + min_value
return tensor
def _frames_to_tensor(frames, min_value=-1.0, max_value=1.0):
    """Convert a sequence of frames into one tensor with time on axis 1.

    Each frame is converted by ``_frame_to_tensor`` with the same value range.
    The results are stacked along dim 1, so per-frame (C, H, W) tensors become
    a single (C, T, H, W) tensor.
    """
    def convert(single):
        return _frame_to_tensor(single, min_value=min_value, max_value=max_value)

    return torch.stack([convert(single) for single in frames], dim=1)
def _collate_batch(batch, data_file_keys, num_frames):
if len(batch) == 1:
return batch[0]
single_frame_keys = {"reference_image", "vace_reference_image"}
output = {}
keys = batch[0].keys()
for key in keys:
values = [sample.get(key) for sample in batch]
if key in data_file_keys:
is_mask = "mask" in key
min_value = 0.0 if is_mask else -1.0
max_value = 1.0 if is_mask else 1.0
if any(value is None for value in values):
raise ValueError(f"Missing key '{key}' in one or more batch samples.")
if key in single_frame_keys:
frames = []
for value in values:
if isinstance(value, list):
if len(value) == 0:
raise ValueError(f"Key '{key}' has empty frame list.")
frames.append(value[0])
else:
frames.append(value)
tensors = [_frame_to_tensor(frame, min_value=min_value, max_value=max_value) for frame in frames]
output[key] = torch.stack(tensors, dim=0)
else:
tensors = []
for value in values:
if isinstance(value, list):
padded = _pad_frames(value, num_frames)
tensors.append(_frames_to_tensor(padded, min_value=min_value, max_value=max_value))
elif isinstance(value, torch.Tensor):
tensors.append(value)
else:
raise ValueError(f"Unsupported value type for key '{key}': {type(value)}")
output[key] = torch.stack(tensors, dim=0)
else:
output[key] = values
return output
def run_validation(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    num_workers: int,
    batch_size: int,
    data_file_keys: list[str],
    num_frames: int,
    max_batches: int = None,
):
    """Evaluate ``model`` on ``dataset`` without gradients and return the mean loss.

    Args:
        accelerator: prepares the dataloader and gathers losses across processes.
        dataset: evaluation dataset; ``None`` skips evaluation. When its
            ``load_from_cache`` attribute is truthy, each batch is passed to
            ``model`` as ``inputs=`` together with an empty data dict.
        model: module whose forward returns the loss for a batch.
        num_workers: DataLoader worker count.
        batch_size: values > 1 collate samples with ``_collate_batch``;
            otherwise each batch is the raw sample.
        data_file_keys: keys ``_collate_batch`` converts to tensors.
        num_frames: frame count video keys are padded/truncated to when batching.
        max_batches: if set, evaluate at most this many batches per process.

    Returns:
        The mean of all gathered per-batch losses as a float, or ``None`` when
        ``dataset`` is ``None`` or no batch was evaluated. The value is printed
        on the main process.
    """
    if dataset is None:
        return None
    if batch_size > 1:
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda batch: _collate_batch(batch, data_file_keys, num_frames),
            num_workers=num_workers,
        )
    else:
        # Unbatched mode: each "batch" is the single raw sample.
        dataloader = torch.utils.data.DataLoader(
            dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers
        )
    dataloader = accelerator.prepare(dataloader)

    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for step, data in enumerate(tqdm(dataloader, desc="Eval")):
                if max_batches is not None and step >= max_batches:
                    break
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                # gather() is collective: every rank receives every rank's loss,
                # so all ranks end up computing the same mean below.
                loss = accelerator.gather(loss.detach().float())
                losses.append(loss.flatten())
    finally:
        # Restore training mode even if evaluation raised, so a caller that
        # recovers does not continue training with modules left in eval mode.
        if was_training:
            model.train()

    if not losses:
        return None
    mean_loss = torch.cat(losses).mean().item()
    if accelerator.is_main_process:
        print(f"Eval loss: {mean_loss:.6f}")
    return mean_loss
def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    val_dataset: torch.utils.data.Dataset = None,
    args=None,
):
    """Train ``model`` on ``dataset`` with AdamW, optionally evaluating between epochs.

    When ``args`` is given, the keyword hyperparameters above are read from it
    and override the passed values. The batching and evaluation settings are
    also taken from ``args`` (``batch_size``, ``data_file_keys``,
    ``num_frames``, ``val_*``, ``eval_every_n_epochs``, ``eval_max_batches``).

    Each epoch's step-weighted mean training loss is gathered across processes
    and printed on the main process. ``model_logger`` hooks run after every
    step, at each epoch end when ``save_steps`` is ``None``, and once when
    training finishes. If ``val_dataset`` is set and ``eval_every_n_epochs > 0``,
    ``run_validation`` runs every ``eval_every_n_epochs`` epochs. The model is
    saved as ``best.safetensors`` whenever the eval loss improves.
    """
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
        batch_size = args.batch_size
        data_file_keys = args.data_file_keys.split(",")
        num_frames = getattr(args, "num_frames", None)
        val_num_workers = args.val_dataset_num_workers
        val_batch_size = args.val_batch_size or batch_size
        val_data_file_keys = (args.val_data_file_keys or args.data_file_keys).split(",")
        eval_every_n_epochs = args.eval_every_n_epochs
        eval_max_batches = args.eval_max_batches
    else:
        batch_size = 1
        data_file_keys = []
        num_frames = None
        val_num_workers = 0
        val_batch_size = 1
        val_data_file_keys = []
        eval_every_n_epochs = 0
        eval_max_batches = None

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    # NOTE(review): ConstantLR's defaults (factor=1/3, total_iters=5) scale the
    # LR by 1/3 for the first 5 scheduler steps -- confirm this is intended.
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    if batch_size > 1:
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=lambda batch: _collate_batch(batch, data_file_keys, num_frames),
            num_workers=num_workers,
        )
    else:
        # Unbatched mode: each "batch" is the single raw sample.
        dataloader = torch.utils.data.DataLoader(
            dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers
        )
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

    best_val_loss = None
    for epoch_id in range(num_epochs):
        epoch_loss_sum = None
        epoch_steps = 0
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                loss_value = loss.detach().float()
                if epoch_loss_sum is None:
                    epoch_loss_sum = loss_value
                else:
                    epoch_loss_sum = epoch_loss_sum + loss_value
                epoch_steps += 1
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()

        # Each rank contributes [loss_sum, step_count]; summing both columns
        # after the gather yields a step-weighted mean over all ranks.
        if epoch_loss_sum is None:
            epoch_loss_sum = torch.tensor(0.0, device=accelerator.device)
        steps_tensor = torch.tensor(float(epoch_steps), device=epoch_loss_sum.device)
        loss_stats = torch.stack([epoch_loss_sum, steps_tensor]).unsqueeze(0)
        gathered_stats = accelerator.gather(loss_stats)
        if accelerator.is_main_process:
            total_loss = gathered_stats[:, 0].sum().item()
            total_steps = gathered_stats[:, 1].sum().item()
            avg_loss = total_loss / total_steps if total_steps > 0 else float("nan")
            print(f"Train loss (epoch {epoch_id}): {avg_loss:.6f}")

        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)

        if val_dataset is not None and eval_every_n_epochs > 0 and (epoch_id + 1) % eval_every_n_epochs == 0:
            val_loss = run_validation(
                accelerator,
                val_dataset,
                model,
                val_num_workers,
                val_batch_size,
                val_data_file_keys,
                num_frames,
                max_batches=eval_max_batches,
            )
            # val_loss is the mean of gathered losses, so every rank takes this
            # branch together.
            if val_loss is not None and (best_val_loss is None or val_loss < best_val_loss):
                best_val_loss = val_loss
                if accelerator.is_main_process:
                    print(f"New best eval loss: {best_val_loss:.6f}. Saving best checkpoint.")
                # Call on every rank, like the other model_logger hooks above.
                # NOTE(review): save_model presumably synchronizes ranks and
                # writes only on the main process. Calling it from the main
                # process alone would leave it waiting on ranks that never arrive.
                model_logger.save_model(accelerator, model, "best.safetensors")
    model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args=None,
):
    """Run ``model`` over every sample without gradients and save each output.

    Each output is written with ``torch.save`` to
    ``<model_logger.output_path>/<accelerator.process_index>/<data_id>.pth``.
    ``data_id`` is this process's enumeration index over the prepared dataloader.
    Giving each process its own sub-folder keeps ranks from writing to the
    same file name. When ``args`` is given, ``args.dataset_num_workers``
    overrides ``num_workers``.
    """
    if args is not None:
        num_workers = args.dataset_num_workers
    dataloader = torch.utils.data.DataLoader(
        dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers
    )
    model, dataloader = accelerator.prepare(model, dataloader)
    # The output folder only depends on the process, so create it once up front
    # instead of on every iteration.
    folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
    os.makedirs(folder, exist_ok=True)
    for data_id, data in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(model):
            with torch.no_grad():
                data = model(data)
                torch.save(data, os.path.join(folder, f"{data_id}.pth"))