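"""Training entry points.

`launch_training_task` runs an Accelerate-based training loop with per-step
loss, timing, and FLOPs logging; `launch_data_process_task` runs the model
in inference mode and caches its outputs to disk, one file per sample.
"""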
import os
import time
from typing import Optional

import torch
from accelerate import Accelerator
from tqdm import tqdm

from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
from ..utils.profiling.flops_profiler import (
    print_model_profile,
    get_flops,
    profile_entire_model,
    unprofile_entire_model,
)


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: Optional[int] = None,
    num_epochs: int = 1,
    args=None,
):
    """Run the training loop, logging per-step loss, timing, and FLOPs."""
    # CLI arguments, when provided, override the keyword defaults.
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    # The dataset yields ready-made batches, so collate_fn just unwraps the singleton list.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
    train_step = 0
    # Attach FLOPs counters to every submodule; detached again after training.
    profile_entire_model(model)
    for epoch_id in range(num_epochs):
        progress = tqdm(
            dataloader,
            disable=not accelerator.is_main_process,
            desc=f"Epoch {epoch_id + 1}/{num_epochs}",
        )
        for data in progress:
            iter_start = time.time()
            if data is None:
                continue
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                t5_Tflops, wan_Tflops, vae_Tflops = get_flops(model)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                scheduler.step()
            torch.cuda.synchronize()  # wait for pending kernels so the step time is accurate
            time_step = time.time() - iter_start
            train_step += 1
            total_flops = t5_Tflops + wan_Tflops + vae_Tflops
            # get_flops counts forward-pass FLOPs; the factor of 3 is the usual
            # forward-plus-backward estimate (backward ~ 2x the forward cost).
            TFLOPS = total_flops * 3 / time_step
            if accelerator.is_main_process:
                postfix_dict = {
                    "Rank": f"{accelerator.process_index}",
                    "loss": f"{loss.item():.5f}",
                    "lr": f"{optimizer.param_groups[0]['lr']:.5e}",
                    "step/t": f"{time_step:.3f}",
                    "[t5] Tflops": f"{t5_Tflops:.3f}",
                    "[dit] Tflops": f"{wan_Tflops:.3f}",
                    "[vae] Tflops": f"{vae_Tflops:.3f}",
                    "TFLOPS": f"{TFLOPS:.3f}",
                }
                progress.set_postfix(postfix_dict)
                log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items())
                progress.write(log_msg)
        # Without step-based checkpointing, save once at the end of each epoch.
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    unprofile_entire_model(model)
    model_logger.on_training_end(accelerator, model, save_steps)
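
# Example (illustrative sketch, not part of this module's API): one minimal
# way to wire up `launch_training_task`. `MyDataset` and `MyTrainingModule`
# are hypothetical placeholders for a `torch.utils.data.Dataset` and a
# concrete `DiffusionTrainingModule` subclass, and the `ModelLogger`
# constructor arguments are assumed, not taken from this repository.
#
#   accelerator = Accelerator()
#   launch_training_task(
#       accelerator=accelerator,
#       dataset=MyDataset(),
#       model=MyTrainingModule(),
#       model_logger=ModelLogger(output_path="./checkpoints"),
#       learning_rate=1e-5,
#       num_epochs=1,
#   )
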
def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args=None,
):
    """Run the model over the dataset without gradients and cache each output to disk."""
    if args is not None:
        num_workers = args.dataset_num_workers
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, dataloader = accelerator.prepare(model, dataloader)
    for data_id, data in enumerate(tqdm(dataloader)):
        with accelerator.accumulate(model):
            with torch.no_grad():
                # Each process writes to its own subdirectory to avoid filename collisions.
                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
                os.makedirs(folder, exist_ok=True)
                save_path = os.path.join(folder, f"{data_id}.pth")
                data = model(data)
                torch.save(data, save_path)
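
# Example (illustrative sketch): precompute and cache model outputs before
# training. As above, `MyDataset` and `MyTrainingModule` are hypothetical
# placeholders and the `ModelLogger` arguments are assumptions; the cached
# `{data_id}.pth` files land under `<output_path>/<process_index>/`.
#
#   accelerator = Accelerator()
#   launch_data_process_task(
#       accelerator=accelerator,
#       dataset=MyDataset(),
#       model=MyTrainingModule(),
#       model_logger=ModelLogger(output_path="./cache"),
#   )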