@@ -166,6 +166,9 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
166166 def prepare_inputs (self , data : "BlockState" ) -> List ["BlockState" ]:
167167 raise NotImplementedError ("BaseGuidance::prepare_inputs must be implemented in subclasses." )
168168
169+ def prepare_inputs_from_block_state (self , data : "BlockState" , input_fields : Dict [str , Union [str , Tuple [str , str ]]]) -> List ["BlockState" ]:
170+ raise NotImplementedError ("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses." )
171+
169172 def __call__ (self , data : List ["BlockState" ]) -> Any :
170173 if not all (hasattr (d , "noise_pred" ) for d in data ):
171174 raise ValueError ("Expected all data to have `noise_pred` attribute." )
@@ -234,6 +237,53 @@ def _prepare_batch(
234237 data_batch [cls ._identifier_key ] = identifier
235238 return BlockState (** data_batch )
236239
240+
241+ @classmethod
242+ def _prepare_batch_from_block_state (
243+ cls ,
244+ input_fields : Dict [str , Union [str , Tuple [str , str ]]],
245+ data : "BlockState" ,
246+ tuple_index : int ,
247+ identifier : str ,
248+ ) -> "BlockState" :
249+ """
250+ Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
251+ `BaseGuidance` class. It prepares the batch based on the provided tuple index.
252+
253+ Args:
254+ input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
255+ A dictionary where the keys are the names of the fields that will be used to store the data once it is
256+ prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
257+ to look up the required data provided for preparation. If a string is provided, it will be used as the
258+ conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
259+ length 2 is provided, the first element must be the conditional data identifier and the second element
260+ must be the unconditional data identifier or None.
261+ data (`BlockState`):
262+ The input data to be prepared.
263+ tuple_index (`int`):
264+ The index to use when accessing input fields that are tuples.
265+
266+ Returns:
267+ `BlockState`: The prepared batch of data.
268+ """
269+ from ..modular_pipelines .modular_pipeline import BlockState
270+
271+
272+ data_batch = {}
273+ for key , value in input_fields .items ():
274+ try :
275+ if isinstance (value , str ):
276+ data_batch [key ] = getattr (data , value )
277+ elif isinstance (value , tuple ):
278+ data_batch [key ] = getattr (data , value [tuple_index ])
279+ else :
280+ # We've already checked that value is a string or a tuple of strings with length 2
281+ pass
282+ except AttributeError :
283+ logger .debug (f"`data` does not have attribute(s) { value } , skipping." )
284+ data_batch [cls ._identifier_key ] = identifier
285+ return BlockState (** data_batch )
286+
237287 @classmethod
238288 @validate_hf_hub_args
239289 def from_pretrained (
0 commit comments