Skip to content

Commit bdb1f35

Browse files
trainer: persist accumulator + grads across stop/restart
When training is stopped mid-accumulation-window and resumed from a backup (end()/backup_before_save path), accumulated_loss was a function-local in GenericTrainer.train() so it reset to 0.0 on resume while train_progress.global_step was restored verbatim. The first post-resume update step then logged only the trailing micro-batches of the affected window (a sharp downward spike of factor (acc - k_lost)/acc) and the optimizer.step() at that boundary applied under-accumulated grads. This change persists the in-flight gradient-accumulation state: accumulated_loss, per-trainable-parameter .grad tensors keyed by NamedParameterGroup unique_name, GradScaler state, RNG snapshots (torch_cpu/cuda + python + numpy), and a cheap dataset fingerprint. Stored as accumulator/accumulator.pt alongside optimizer.pt and meta.json. On resume the trainer reattaches grads, restores the accumulator and RNG, and the next optimizer step proceeds as if the stop never happened. Mismatched fingerprints (changed concept set or different gradient_accumulation_steps) warn-only and restore anyway; losing accumulated gradient state is worse than an off-spec effective batch. Legacy backups without the new file load silently with the prior behavior. Stop semantics are deliberately untouched. Files modified: - modules/model/BaseModel.py: new accumulator_state attribute - modules/modelSaver/mixin/InternalModelSaverMixin.py: write accumulator/accumulator.pt when staged - modules/modelLoader/mixin/InternalModelLoaderMixin.py: read it under contextlib.suppress(FileNotFoundError) - modules/trainer/GenericTrainer.py: stage on backup/save, restore on train() entry, mirror loop-locals each iteration - modules/util/NamedParameterGroup.py: iter_named_parameters() yields stable (group_unique_name.idx, param) pairs - modules/util/dataset_fingerprint.py: SHA-256 over concept identifiers with concept_file_name fallback
1 parent 2698734 commit bdb1f35

6 files changed

Lines changed: 237 additions & 0 deletions

File tree

modules/model/BaseModel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class BaseModel(metaclass=ABCMeta):
7676
embedding_state_dicts: dict[str, dict[str, Tensor]] | None
7777
autocast_context: torch.autocast | nullcontext
7878
train_dtype: DataType
79+
accumulator_state: dict | None
7980

8081
def __init__(
8182
self,
@@ -93,6 +94,7 @@ def __init__(
9394
self.embedding_state_dicts = {}
9495
self.autocast_context = nullcontext()
9596
self.train_dtype = DataType.FLOAT_32
97+
self.accumulator_state = None
9698

9799
@abstractmethod
98100
def to(self, device: torch.device):

modules/modelLoader/mixin/InternalModelLoaderMixin.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,13 @@ def _load_internal_data(
3838
with contextlib.suppress(FileNotFoundError):
3939
model.ema_state_dict = torch.load(os.path.join(model_name, "ema", "ema.pt"), weights_only=True)
4040

41+
# Optional grad-accum snapshot; legacy backups without it fall through to defaults.
42+
# weights_only=False: payload mixes tensors with python dicts/tuples (RNG state).
43+
with contextlib.suppress(FileNotFoundError):
44+
model.accumulator_state = torch.load(
45+
os.path.join(model_name, "accumulator", "accumulator.pt"),
46+
weights_only=False,
47+
)
48+
4149
# meta
4250
model.train_progress = train_progress

modules/modelSaver/mixin/InternalModelSaverMixin.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,12 @@ def _save_internal_data(
4040
'global_step': model.train_progress.global_step,
4141
},
4242
}, meta_file)
43+
44+
# In-flight grad-accum snapshot; staged by the trainer, skipped on non-training paths.
45+
accumulator_state = getattr(model, "accumulator_state", None)
46+
if accumulator_state is not None:
47+
os.makedirs(os.path.join(destination, "accumulator"), exist_ok=True)
48+
torch.save(
49+
accumulator_state,
50+
os.path.join(destination, "accumulator", "accumulator.pt"),
51+
)

modules/trainer/GenericTrainer.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import math
55
import os
6+
import random
67
import shutil
78
import traceback
89
from collections.abc import Callable
@@ -22,6 +23,7 @@
2223
from modules.util.commands.TrainCommands import TrainCommands
2324
from modules.util.config.SampleConfig import SampleConfig
2425
from modules.util.config.TrainConfig import TrainConfig
26+
from modules.util.dataset_fingerprint import compute_concept_fingerprint
2527
from modules.util.dtype_util import create_grad_scaler, enable_grad_scaling
2628
from modules.util.enum.ConceptType import ConceptType
2729
from modules.util.enum.EMAMode import EMAMode
@@ -42,6 +44,7 @@
4244
from torchvision.transforms.functional import pil_to_tensor
4345

4446
import huggingface_hub
47+
import numpy as np
4548
from requests.exceptions import ConnectionError
4649
from tqdm import tqdm
4750

@@ -78,6 +81,11 @@ def __init__(self, config: TrainConfig, callbacks: TrainCallbacks, commands: Tra
7881
self.one_step_trained = False
7982
self.grad_hook_handles = []
8083

84+
# Loop locals mirrored so __backup/__save can read them without threading.
85+
self._loop_accumulated_loss: float = 0.0
86+
self._loop_accumulated_loss_tensor: torch.Tensor | None = None
87+
self._loop_scaler = None
88+
8189
def start(self):
8290
if multi.is_master():
8391
self.__save_config_to_workspace()
@@ -445,6 +453,7 @@ def __backup(self, train_progress: TrainProgress, print_msg: bool = True, print_
445453
if print_msg:
446454
print_cb("Creating Backup " + backup_path)
447455

456+
self._stage_accumulator_state_for_save()
448457
self.model_saver.save(
449458
self.model,
450459
self.config.model_type,
@@ -464,6 +473,7 @@ def __backup(self, train_progress: TrainProgress, print_msg: bool = True, print_
464473
traceback.print_exc()
465474
print("Could not delete partial backup")
466475
finally:
476+
self._clear_staged_accumulator_state()
467477
if self.config.rolling_backup:
468478
self.__prune_backups(self.config.rolling_backup_count)
469479

@@ -496,17 +506,20 @@ def __save(self, train_progress: TrainProgress, print_msg: bool = True, print_cb
496506
if self.config.optimizer.optimizer.is_schedule_free:
497507
torch.clear_autocast_cache()
498508
self.model.optimizer.eval()
509+
self._stage_accumulator_state_for_save()
499510
self.model_saver.save(
500511
model=self.model,
501512
model_type=self.config.model_type,
502513
output_model_format=self.config.output_model_format,
503514
output_model_destination=save_path,
504515
dtype=self.config.output_dtype.torch_dtype()
505516
)
517+
self._clear_staged_accumulator_state()
506518
if self.config.optimizer.optimizer.is_schedule_free:
507519
torch.clear_autocast_cache()
508520
self.model.optimizer.train()
509521
except Exception:
522+
self._clear_staged_accumulator_state()
510523
traceback.print_exc()
511524
print("Could not save model. Check your disk space!")
512525
try:
@@ -553,6 +566,142 @@ def __is_update_step(self, train_progress: TrainProgress) -> bool:
553566
"update_step", self.config.gradient_accumulation_steps, TimeUnit.STEP, train_progress, start_at_zero=False
554567
)
555568

569+
def _stage_accumulator_state_for_save(self):
570+
# Build the in-flight grad-accum snapshot for InternalModelSaverMixin.
571+
if not multi.is_master():
572+
self.model.accumulator_state = None
573+
return
574+
575+
if self._loop_accumulated_loss_tensor is not None and \
576+
isinstance(self._loop_accumulated_loss_tensor, torch.Tensor):
577+
try:
578+
acc_loss_f = float(self._loop_accumulated_loss_tensor.item())
579+
except Exception:
580+
acc_loss_f = float(self._loop_accumulated_loss)
581+
else:
582+
acc_loss_f = float(self._loop_accumulated_loss)
583+
584+
param_grads: dict[str, torch.Tensor] = {}
585+
if self.model is not None and self.model.parameters is not None:
586+
for key, p in self.model.parameters.iter_named_parameters():
587+
if not p.requires_grad or p.grad is None:
588+
continue
589+
param_grads[key] = p.grad.detach().to(device="cpu", copy=True)
590+
591+
scaler_state = None
592+
if self._loop_scaler is not None:
593+
try:
594+
scaler_state = self._loop_scaler.state_dict()
595+
except Exception:
596+
scaler_state = None
597+
598+
rng: dict = {
599+
"torch_cpu": torch.get_rng_state(),
600+
"torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
601+
"python": random.getstate(),
602+
# Snapshots the GLOBAL numpy RNG; Generator-based snapshots don't round-trip with set_state.
603+
"numpy": np.random.get_state(legacy=True), # noqa: NPY002
604+
}
605+
606+
fp_hash, fp_count = compute_concept_fingerprint(
607+
getattr(self.config, "concepts", None),
608+
getattr(self.config, "concept_file_name", None),
609+
)
610+
self.model.accumulator_state = {
611+
"accumulated_loss": acc_loss_f,
612+
"param_grads": param_grads,
613+
"scaler": scaler_state,
614+
"rng": rng,
615+
"fingerprint": {
616+
"gradient_accumulation_steps": int(self.config.gradient_accumulation_steps),
617+
"dataset_hash": fp_hash,
618+
"concept_count": fp_count,
619+
},
620+
}
621+
622+
def _clear_staged_accumulator_state(self):
623+
if self.model is not None:
624+
self.model.accumulator_state = None
625+
626+
def _restore_accumulator_state(
627+
self,
628+
accumulated_loss: torch.Tensor,
629+
train_device: torch.device,
630+
scaler,
631+
) -> tuple[torch.Tensor, bool]:
632+
# Returns (accumulated_loss, has_gradient). Warn-only on mismatch; never discards state.
633+
if not multi.is_master():
634+
return accumulated_loss, False
635+
state = getattr(self.model, "accumulator_state", None)
636+
if state is None:
637+
return accumulated_loss, False
638+
639+
fp = state.get("fingerprint", {})
640+
saved_acc = fp.get("gradient_accumulation_steps")
641+
if saved_acc is not None and saved_acc != self.config.gradient_accumulation_steps:
642+
print(
643+
f"Warning: gradient_accumulation_steps mismatch on resume: "
644+
f"saved={saved_acc} current={self.config.gradient_accumulation_steps}; "
645+
f"restoring partial accumulator state anyway."
646+
)
647+
current_hash, current_count = compute_concept_fingerprint(
648+
getattr(self.config, "concepts", None),
649+
getattr(self.config, "concept_file_name", None),
650+
)
651+
if fp.get("dataset_hash") and fp.get("dataset_hash") != current_hash:
652+
delta = current_count - int(fp.get("concept_count", current_count))
653+
print(
654+
f"Warning: dataset fingerprint mismatch on resume: "
655+
f"saved_concepts={fp.get('concept_count')} current_concepts={current_count} "
656+
f"(delta={delta}); restoring partial accumulator state anyway."
657+
)
658+
659+
acc_loss_f = float(state.get("accumulated_loss", 0.0) or 0.0)
660+
accumulated_loss = torch.tensor(acc_loss_f, device=train_device)
661+
662+
saved_grads: dict = state.get("param_grads", {}) or {}
663+
if self.model is not None and self.model.parameters is not None:
664+
current_keys = {k for k, _ in self.model.parameters.iter_named_parameters()}
665+
missing = [k for k in saved_grads if k not in current_keys]
666+
if saved_grads and len(missing) / len(saved_grads) > 0.10:
667+
print(
668+
f"Warning: {len(missing)} of {len(saved_grads)} saved grad keys are "
669+
f"absent in the current model; skipping those grads."
670+
)
671+
applied = 0
672+
for key, p in self.model.parameters.iter_named_parameters():
673+
if not p.requires_grad:
674+
continue
675+
if key in saved_grads:
676+
p.grad = saved_grads[key].to(device=p.device, dtype=p.dtype, non_blocking=True)
677+
applied += 1
678+
else:
679+
p.grad = None
680+
has_gradient = applied > 0
681+
else:
682+
has_gradient = False
683+
684+
if scaler is not None and state.get("scaler") is not None:
685+
try:
686+
scaler.load_state_dict(state["scaler"])
687+
except Exception:
688+
print("Warning: could not restore GradScaler state; continuing with a fresh scaler.")
689+
690+
rng = state.get("rng", {}) or {}
691+
if "torch_cpu" in rng and rng["torch_cpu"] is not None:
692+
torch.set_rng_state(rng["torch_cpu"])
693+
if rng.get("torch_cuda") is not None and torch.cuda.is_available():
694+
with contextlib.suppress(Exception):
695+
torch.cuda.set_rng_state_all(rng["torch_cuda"])
696+
if "python" in rng and rng["python"] is not None:
697+
random.setstate(rng["python"])
698+
if rng.get("numpy") is not None:
699+
with contextlib.suppress(Exception):
700+
np.random.set_state(rng["numpy"]) # noqa: NPY002
701+
702+
self.model.accumulator_state = None
703+
return accumulated_loss, has_gradient
704+
556705
def __apply_fused_back_pass(self, scaler):
557706
fused_optimizer_step = self.config.optimizer.optimizer.supports_fused_back_pass() and self.config.optimizer.fused_back_pass
558707
fused_reduce = self.config.multi_gpu and self.config.fused_gradient_reduce
@@ -621,6 +770,7 @@ def train(self):
621770
return
622771

623772
scaler = create_grad_scaler() if enable_grad_scaling(self.config.train_dtype, self.parameters) else None
773+
self._loop_scaler = scaler # mirror so save-side staging can capture state_dict
624774

625775
self.__apply_fused_back_pass(scaler)
626776

@@ -634,6 +784,15 @@ def train(self):
634784
ema_loss_steps = 0
635785
epochs = range(train_progress.epoch, self.config.epochs, 1)
636786

787+
# If resuming from a mid-window save, restore in-flight accumulator + grads + RNG.
788+
accumulated_loss, restored_has_grad = self._restore_accumulator_state(
789+
accumulated_loss, train_device, scaler,
790+
)
791+
if restored_has_grad:
792+
has_gradient = True
793+
self._loop_accumulated_loss_tensor = accumulated_loss
794+
self._loop_accumulated_loss = float(accumulated_loss.item()) if accumulated_loss is not None else 0.0
795+
637796
for _epoch in tqdm(epochs, desc="epoch") if multi.is_master() else epochs:
638797
multi.sync_commands(self.commands)
639798
if self.commands.get_stop_command():
@@ -761,6 +920,7 @@ def sample_commands_fun():
761920
detached_loss = loss.detach()
762921
multi.reduce_tensor_mean(detached_loss)
763922
accumulated_loss += detached_loss
923+
self._loop_accumulated_loss_tensor = accumulated_loss # save-side stage mirror
764924

765925
if self.__is_update_step(train_progress):
766926
if self.config.fused_gradient_reduce:
@@ -807,6 +967,8 @@ def sample_commands_fun():
807967
self.tensorboard.add_scalar("smooth_loss/train_step", ema_loss, train_progress.global_step)
808968

809969
accumulated_loss = 0.0
970+
self._loop_accumulated_loss = 0.0 # clear save-side mirror at boundary
971+
self._loop_accumulated_loss_tensor = None
810972
self.model_setup.after_optimizer_step(self.model, self.config, train_progress)
811973

812974
if self.model.ema:

modules/util/NamedParameterGroup.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ def add_group(self, group: NamedParameterGroup):
3232
def parameters(self) -> list[Parameter]:
3333
return [p for group in self.__groups for p in group.parameters]
3434

35+
def iter_named_parameters(self) -> Iterable[tuple[str, Parameter]]:
36+
# Stable per-parameter keys for accumulator-state save/load.
37+
for group in self.__groups:
38+
for i, p in enumerate(group.parameters):
39+
yield f"{group.unique_name}.{i}", p
40+
3541
def parameters_for_optimizer(self, config: TrainConfig) -> list[dict]:
3642
parameters = []
3743

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""SHA-256 fingerprint of the configured dataset, warn-only on resume mismatch."""
2+
from __future__ import annotations
3+
4+
import hashlib
5+
import json
6+
import os
7+
from collections.abc import Iterable
8+
9+
from modules.util.config.ConceptConfig import ConceptConfig
10+
11+
12+
def _identifier_tuple(c) -> tuple:
13+
def g(name, default):
14+
if hasattr(c, name):
15+
return getattr(c, name)
16+
if isinstance(c, dict):
17+
return c.get(name, default)
18+
return default
19+
20+
raw_type = g('type', '')
21+
type_str = getattr(raw_type, 'value', raw_type)
22+
return (
23+
str(g('name', '') or ''),
24+
str(g('path', '') or ''),
25+
int(g('seed', 0) or 0),
26+
str(type_str or ''),
27+
bool(g('include_subdirectories', False)),
28+
bool(g('enabled', True)),
29+
)
30+
31+
32+
def compute_concept_fingerprint(
33+
concepts: Iterable[ConceptConfig] | Iterable[dict] | None,
34+
concept_file_name: str | None = None,
35+
) -> tuple[str, int]:
36+
items: list = []
37+
if concepts:
38+
items = list(concepts)
39+
elif concept_file_name and os.path.exists(concept_file_name):
40+
# Mirrors TrainConfig.to_pack_dict: under the GUI, concepts live in a file.
41+
try:
42+
with open(concept_file_name, 'r') as f:
43+
items = json.load(f) or []
44+
except (OSError, ValueError):
45+
items = []
46+
47+
payload = [_identifier_tuple(c) for c in items]
48+
payload.sort(key=lambda t: t[1])
49+
blob = json.dumps(payload, separators=(',', ':'), sort_keys=False).encode()
50+
return hashlib.sha256(blob).hexdigest(), len(payload)

0 commit comments

Comments
 (0)