33import json
44import math
55import os
6+ import random
67import shutil
78import traceback
89from collections .abc import Callable
2223from modules .util .commands .TrainCommands import TrainCommands
2324from modules .util .config .SampleConfig import SampleConfig
2425from modules .util .config .TrainConfig import TrainConfig
26+ from modules .util .dataset_fingerprint import compute_concept_fingerprint
2527from modules .util .dtype_util import create_grad_scaler , enable_grad_scaling
2628from modules .util .enum .ConceptType import ConceptType
2729from modules .util .enum .EMAMode import EMAMode
4244from torchvision .transforms .functional import pil_to_tensor
4345
4446import huggingface_hub
47+ import numpy as np
4548from requests .exceptions import ConnectionError
4649from tqdm import tqdm
4750
@@ -78,6 +81,11 @@ def __init__(self, config: TrainConfig, callbacks: TrainCallbacks, commands: Tra
7881 self .one_step_trained = False
7982 self .grad_hook_handles = []
8083
84+ # Loop locals mirrored so __backup/__save can read them without threading.
85+ self ._loop_accumulated_loss : float = 0.0
86+ self ._loop_accumulated_loss_tensor : torch .Tensor | None = None
87+ self ._loop_scaler = None
88+
8189 def start (self ):
8290 if multi .is_master ():
8391 self .__save_config_to_workspace ()
@@ -445,6 +453,7 @@ def __backup(self, train_progress: TrainProgress, print_msg: bool = True, print_
445453 if print_msg :
446454 print_cb ("Creating Backup " + backup_path )
447455
456+ self ._stage_accumulator_state_for_save ()
448457 self .model_saver .save (
449458 self .model ,
450459 self .config .model_type ,
@@ -464,6 +473,7 @@ def __backup(self, train_progress: TrainProgress, print_msg: bool = True, print_
464473 traceback .print_exc ()
465474 print ("Could not delete partial backup" )
466475 finally :
476+ self ._clear_staged_accumulator_state ()
467477 if self .config .rolling_backup :
468478 self .__prune_backups (self .config .rolling_backup_count )
469479
@@ -496,17 +506,20 @@ def __save(self, train_progress: TrainProgress, print_msg: bool = True, print_cb
496506 if self .config .optimizer .optimizer .is_schedule_free :
497507 torch .clear_autocast_cache ()
498508 self .model .optimizer .eval ()
509+ self ._stage_accumulator_state_for_save ()
499510 self .model_saver .save (
500511 model = self .model ,
501512 model_type = self .config .model_type ,
502513 output_model_format = self .config .output_model_format ,
503514 output_model_destination = save_path ,
504515 dtype = self .config .output_dtype .torch_dtype ()
505516 )
517+ self ._clear_staged_accumulator_state ()
506518 if self .config .optimizer .optimizer .is_schedule_free :
507519 torch .clear_autocast_cache ()
508520 self .model .optimizer .train ()
509521 except Exception :
522+ self ._clear_staged_accumulator_state ()
510523 traceback .print_exc ()
511524 print ("Could not save model. Check your disk space!" )
512525 try :
@@ -553,6 +566,142 @@ def __is_update_step(self, train_progress: TrainProgress) -> bool:
553566 "update_step" , self .config .gradient_accumulation_steps , TimeUnit .STEP , train_progress , start_at_zero = False
554567 )
555568
def _stage_accumulator_state_for_save(self):
    """Attach a snapshot of the in-flight gradient-accumulation state to the model.

    The snapshot is stored on ``self.model.accumulator_state`` so that the
    model saver can persist it alongside the weights (per the original
    comment it is consumed by ``InternalModelSaverMixin`` -- not visible
    from this block). It contains:

    * ``accumulated_loss`` -- the running loss of the current accumulation
      window as a Python float,
    * ``param_grads`` -- CPU copies of every non-``None`` gradient of a
      trainable parameter, keyed by parameter name,
    * ``scaler`` -- the GradScaler ``state_dict()`` (or ``None``),
    * ``rng`` -- torch CPU/CUDA, Python and global NumPy RNG states,
    * ``fingerprint`` -- accumulation step count and concept fingerprint
      used to warn about config changes on resume.

    Non-master ranks stage nothing and reset the attribute to ``None``.
    When no model is attached there is nothing to stage and the call is a
    no-op.
    """
    model = self.model
    if model is None:
        return

    if not multi.is_master():
        model.accumulator_state = None
        return

    # Prefer the live tensor mirror; fall back to the float mirror when the
    # mirror is not a tensor or cannot be read back as a scalar.
    acc_loss_f = None
    loss_mirror = self._loop_accumulated_loss_tensor
    if isinstance(loss_mirror, torch.Tensor):
        try:
            acc_loss_f = float(loss_mirror.item())
        except Exception:
            print("Warning: could not read the accumulated loss tensor; using the last mirrored float value.")
    if acc_loss_f is None:
        acc_loss_f = float(self._loop_accumulated_loss)

    # copy=True guarantees the snapshot does not alias live gradient storage,
    # even for parameters that already live on the CPU.
    param_grads: dict[str, torch.Tensor] = {}
    if model.parameters is not None:
        param_grads = {
            key: p.grad.detach().to(device="cpu", copy=True)
            for key, p in model.parameters.iter_named_parameters()
            if p.requires_grad and p.grad is not None
        }

    scaler_state = None
    if self._loop_scaler is not None:
        try:
            scaler_state = self._loop_scaler.state_dict()
        except Exception:
            # Best effort: the save proceeds without scaler state.
            print("Warning: could not capture GradScaler state; the snapshot will not include it.")
            scaler_state = None

    rng: dict = {
        "torch_cpu": torch.get_rng_state(),
        "torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "python": random.getstate(),
        # Snapshots the GLOBAL numpy RNG; Generator-based snapshots don't round-trip with set_state.
        "numpy": np.random.get_state(legacy=True),  # noqa: NPY002
    }

    fp_hash, fp_count = compute_concept_fingerprint(
        getattr(self.config, "concepts", None),
        getattr(self.config, "concept_file_name", None),
    )
    model.accumulator_state = {
        "accumulated_loss": acc_loss_f,
        "param_grads": param_grads,
        "scaler": scaler_state,
        "rng": rng,
        "fingerprint": {
            "gradient_accumulation_steps": int(self.config.gradient_accumulation_steps),
            "dataset_hash": fp_hash,
            "concept_count": fp_count,
        },
    }
621+
def _clear_staged_accumulator_state(self):
    """Drop any accumulator snapshot currently attached to the model.

    Does nothing when no model is attached.
    """
    model = self.model
    if model is None:
        return
    model.accumulator_state = None
625+
def _restore_accumulator_state(
        self,
        accumulated_loss: torch.Tensor,
        train_device: torch.device,
        scaler,
) -> tuple[torch.Tensor, bool]:
    """Re-apply a saved mid-window accumulator snapshot, if the model carries one.

    Reads ``self.model.accumulator_state`` and, on the master rank, restores
    the accumulated loss, per-parameter gradients, GradScaler state and the
    torch / CUDA / Python / NumPy RNG states from it. The snapshot is
    consumed: the attribute is reset to ``None`` afterwards on every rank.

    Mismatches between the snapshot and the current run (accumulation step
    count, dataset fingerprint, missing or differently-shaped gradients,
    unrestorable scaler/RNG state) only print a warning.

    Args:
        accumulated_loss: The caller's current accumulated loss; returned
            unchanged when there is nothing to restore.
        train_device: Device on which the restored loss tensor is created.
        scaler: The active GradScaler, or ``None`` when grad scaling is off.

    Returns:
        ``(accumulated_loss, has_gradient)`` where ``has_gradient`` is True
        only if at least one saved gradient was applied to a parameter.
    """
    state = getattr(self.model, "accumulator_state", None)
    if state is None:
        return accumulated_loss, False

    if not multi.is_master():
        # Only the master rank restores, but the snapshot holds CPU copies of
        # gradients; release it here too instead of keeping it alive.
        self.model.accumulator_state = None
        return accumulated_loss, False

    fp = state.get("fingerprint") or {}
    saved_acc = fp.get("gradient_accumulation_steps")
    if saved_acc is not None and saved_acc != self.config.gradient_accumulation_steps:
        print(
            f"Warning: gradient_accumulation_steps mismatch on resume: "
            f"saved={saved_acc} current={self.config.gradient_accumulation_steps}; "
            f"restoring partial accumulator state anyway."
        )

    current_hash, current_count = compute_concept_fingerprint(
        getattr(self.config, "concepts", None),
        getattr(self.config, "concept_file_name", None),
    )
    saved_hash = fp.get("dataset_hash")
    if saved_hash and saved_hash != current_hash:
        # A missing or malformed count must not abort a warn-only check;
        # report a delta of 0, the same value a missing key produced before.
        try:
            delta = current_count - int(fp.get("concept_count"))
        except (TypeError, ValueError):
            delta = 0
        print(
            f"Warning: dataset fingerprint mismatch on resume: "
            f"saved_concepts={fp.get('concept_count')} current_concepts={current_count} "
            f"(delta={delta}); restoring partial accumulator state anyway."
        )

    acc_loss_f = float(state.get("accumulated_loss", 0.0) or 0.0)
    accumulated_loss = torch.tensor(acc_loss_f, device=train_device)

    saved_grads: dict = state.get("param_grads") or {}
    has_gradient = False
    if self.model is not None and self.model.parameters is not None:
        named_params = list(self.model.parameters.iter_named_parameters())
        current_keys = {key for key, _ in named_params}
        missing = [key for key in saved_grads if key not in current_keys]
        if saved_grads and len(missing) / len(saved_grads) > 0.10:
            print(
                f"Warning: {len(missing)} of {len(saved_grads)} saved grad keys are "
                f"absent in the current model; skipping those grads."
            )

        applied = 0
        for key, p in named_params:
            if not p.requires_grad:
                continue
            saved = saved_grads.get(key)
            if saved is None:
                p.grad = None
                continue
            if saved.shape != p.shape:
                # Assigning a differently-shaped tensor to .grad raises;
                # skip it so a changed model does not abort the resume.
                print(
                    f"Warning: saved grad for '{key}' has shape {tuple(saved.shape)} "
                    f"but the parameter has shape {tuple(p.shape)}; skipping this grad."
                )
                p.grad = None
                continue
            p.grad = saved.to(device=p.device, dtype=p.dtype, non_blocking=True)
            applied += 1
        has_gradient = applied > 0

    if scaler is not None and state.get("scaler") is not None:
        try:
            scaler.load_state_dict(state["scaler"])
        except Exception:
            print("Warning: could not restore GradScaler state; continuing with a fresh scaler.")

    rng = state.get("rng") or {}
    if rng.get("torch_cpu") is not None:
        torch.set_rng_state(rng["torch_cpu"])
    if rng.get("torch_cuda") is not None and torch.cuda.is_available():
        try:
            torch.cuda.set_rng_state_all(rng["torch_cuda"])
        except Exception:
            # Best effort, e.g. when the saved device set differs from the current one.
            print("Warning: could not restore CUDA RNG state; continuing with the current CUDA RNG state.")
    if rng.get("python") is not None:
        random.setstate(rng["python"])
    if rng.get("numpy") is not None:
        try:
            np.random.set_state(rng["numpy"])  # noqa: NPY002
        except Exception:
            print("Warning: could not restore NumPy RNG state; continuing with the current NumPy RNG state.")

    self.model.accumulator_state = None
    return accumulated_loss, has_gradient
704+
556705 def __apply_fused_back_pass (self , scaler ):
557706 fused_optimizer_step = self .config .optimizer .optimizer .supports_fused_back_pass () and self .config .optimizer .fused_back_pass
558707 fused_reduce = self .config .multi_gpu and self .config .fused_gradient_reduce
@@ -621,6 +770,7 @@ def train(self):
621770 return
622771
623772 scaler = create_grad_scaler () if enable_grad_scaling (self .config .train_dtype , self .parameters ) else None
773+ self ._loop_scaler = scaler # mirror so save-side staging can capture state_dict
624774
625775 self .__apply_fused_back_pass (scaler )
626776
@@ -634,6 +784,15 @@ def train(self):
634784 ema_loss_steps = 0
635785 epochs = range (train_progress .epoch , self .config .epochs , 1 )
636786
787+ # If resuming from a mid-window save, restore in-flight accumulator + grads + RNG.
788+ accumulated_loss , restored_has_grad = self ._restore_accumulator_state (
789+ accumulated_loss , train_device , scaler ,
790+ )
791+ if restored_has_grad :
792+ has_gradient = True
793+ self ._loop_accumulated_loss_tensor = accumulated_loss
794+ self ._loop_accumulated_loss = float (accumulated_loss .item ()) if accumulated_loss is not None else 0.0
795+
637796 for _epoch in tqdm (epochs , desc = "epoch" ) if multi .is_master () else epochs :
638797 multi .sync_commands (self .commands )
639798 if self .commands .get_stop_command ():
@@ -761,6 +920,7 @@ def sample_commands_fun():
761920 detached_loss = loss .detach ()
762921 multi .reduce_tensor_mean (detached_loss )
763922 accumulated_loss += detached_loss
923+ self ._loop_accumulated_loss_tensor = accumulated_loss # save-side stage mirror
764924
765925 if self .__is_update_step (train_progress ):
766926 if self .config .fused_gradient_reduce :
@@ -807,6 +967,8 @@ def sample_commands_fun():
807967 self .tensorboard .add_scalar ("smooth_loss/train_step" , ema_loss , train_progress .global_step )
808968
809969 accumulated_loss = 0.0
970+ self ._loop_accumulated_loss = 0.0 # clear save-side mirror at boundary
971+ self ._loop_accumulated_loss_tensor = None
810972 self .model_setup .after_optimizer_step (self .model , self .config , train_progress )
811973
812974 if self .model .ema :
0 commit comments