@@ -566,3 +566,127 @@ def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=Non
566566 "layers_per_block" : 1 ,
567567 }
568568 return init_dict
569+
570+
571+ # Model with conditionally-executed modules, simulating Helios patch_short/patch_mid/patch_long behavior.
572+ # These modules are only called when optional inputs are provided, which means the lazy prefetch
573+ # execution order tracer may not see them on the first forward pass. This can cause a device mismatch
574+ # on subsequent calls when the modules ARE invoked but their weights were never onloaded.
575+ # See: https://github.com/huggingface/diffusers/pull/13211
class DummyModelWithConditionalModules(ModelMixin):
    """Toy model whose optional projection layers only run when `optional_input` is given.

    Mirrors the Helios `patch_short`/`patch_mid`/`patch_long` situation: modules that are
    skipped on some forward passes, which the lazy-prefetch offload tracer may never see.
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()

        # NOTE: module creation order is significant — it fixes both the state_dict
        # layout and the RNG draw order for weight initialization.
        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        stacked = [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
        self.blocks = torch.nn.ModuleList(stacked)
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)

        # Only touched when `optional_input` is passed to forward(). They project to
        # hidden_features so their outputs can be summed with the post-linear_1 activations.
        self.optional_proj_1 = torch.nn.Linear(in_features, hidden_features)
        self.optional_proj_2 = torch.nn.Linear(in_features, hidden_features)

    def forward(self, x: torch.Tensor, optional_input: torch.Tensor | None = None) -> torch.Tensor:
        hidden = self.activation(self.linear_1(x))
        if optional_input is not None:
            # Both projections emit hidden_features, so a residual-style add is shape-safe here.
            hidden = hidden + self.optional_proj_1(optional_input)
            hidden = hidden + self.optional_proj_2(optional_input)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.linear_2(hidden)
603+
604+
class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
    """Group-offloading (stream) tests for models with conditionally-executed submodules.

    Guards against the regression where a module skipped during the first forward pass
    (the pass on which the lazy prefetch order is traced) later runs with CPU weights
    against GPU inputs, raising a device-mismatch RuntimeError.
    """

    def get_model(self):
        torch.manual_seed(0)
        return DummyModelWithConditionalModules(
            in_features=self.in_features,
            hidden_features=self.hidden_features,
            out_features=self.out_features,
            num_layers=self.num_layers,
        )

    @parameterized.expand([("leaf_level",), ("block_level",)])
    @unittest.skipIf(
        torch.device(torch_device).type not in ["cuda", "xpu"],
        "Test requires a CUDA or XPU device.",
    )
    def test_conditional_modules_with_stream(self, offload_type: str):
        """Conditionally-executed modules must not trigger a device mismatch under streams.

        `optional_proj_1`/`optional_proj_2` only run when `optional_input` is supplied,
        simulating Helios' patch_short/patch_mid/patch_long modules. With streams enabled,
        `LazyPrefetchGroupOffloadingHook` records the execution order during the first
        forward pass and builds a prefetch chain from it; modules absent from that pass
        are left out of the chain.

        The bug being regression-tested: such an absent module still had `onload_self`
        flipped to False ("someone else onloads me"), yet nothing in the chain ever did,
        leaving its weights on CPU. Its first real invocation then crashed with a
        RuntimeError (GPU input vs. CPU weights).

        Three passes are required to exercise this:
          1. without `optional_input` — traces the prefetch order sans the optional projections;
          2. with `optional_input` — the regression case (crashes without the fix);
          3. without `optional_input` again — confirms stability after both code paths ran.
        """

        offloaded = self.get_model()
        reference = self.get_model()
        reference.load_state_dict(offloaded.state_dict(), strict=True)
        reference.to(torch_device)

        offloaded.enable_group_offload(
            torch_device,
            offload_type=offload_type,
            num_blocks_per_group=1,
            use_stream=True,
        )

        sample = torch.randn(4, self.in_features).to(torch_device)
        optional_input = torch.randn(4, self.in_features).to(torch_device)

        def check_match(expected, actual, message):
            # Helper keeps the three pass comparisons uniform.
            self.assertTrue(torch.allclose(expected, actual, atol=1e-5), message)

        with torch.no_grad():
            # Pass 1 (no optional_input): this is when the lazy prefetch execution order
            # is traced; optional_proj_1/2 never appear in the traced order.
            check_match(
                reference(sample, optional_input=None),
                offloaded(sample, optional_input=None),
                f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
            )

            # Pass 2 (with optional_input): optional_proj_1/2 run for the first time.
            check_match(
                reference(sample, optional_input=optional_input),
                offloaded(sample, optional_input=optional_input),
                f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
            )

            # Pass 3 (no optional_input again): behavior must remain stable afterwards.
            check_match(
                reference(sample, optional_input=None),
                offloaded(sample, optional_input=None),
                f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
            )
0 commit comments