Skip to content

Commit 921185c

Browse files
committed
support image2video
1 parent 84e030c commit 921185c

17 files changed

Lines changed: 973 additions & 150 deletions

src/diffusers/guiders/adaptive_projected_guidance.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -88,6 +88,17 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
8888
data_batches.append(data_batch)
8989
return data_batches
9090

91+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
92+
if self._step == 0:
93+
if self.adaptive_projected_guidance_momentum is not None:
94+
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
95+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
96+
data_batches = []
97+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
98+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
99+
data_batches.append(data_batch)
100+
return data_batches
101+
91102
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
92103
pred = None
93104

src/diffusers/guiders/adaptive_projected_guidance_mix.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -99,6 +99,21 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
9999
data_batches.append(data_batch)
100100
return data_batches
101101

102+
103+
def prepare_inputs_from_block_state(
104+
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
105+
) -> List["BlockState"]:
106+
107+
if self._step == 0:
108+
if self.adaptive_projected_guidance_momentum is not None:
109+
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
110+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
111+
data_batches = []
112+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
113+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
114+
data_batches.append(data_batch)
115+
return data_batches
116+
102117
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
103118
pred = None
104119

src/diffusers/guiders/auto_guidance.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,15 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
141141
data_batches.append(data_batch)
142142
return data_batches
143143

144+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
145+
146+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
147+
data_batches = []
148+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
149+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
150+
data_batches.append(data_batch)
151+
return data_batches
152+
144153
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
145154
pred = None
146155

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -99,6 +99,14 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
9999
data_batches.append(data_batch)
100100
return data_batches
101101

102+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
103+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
104+
data_batches = []
105+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
106+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
107+
data_batches.append(data_batch)
108+
return data_batches
109+
102110
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
103111
pred = None
104112

src/diffusers/guiders/classifier_free_zero_star_guidance.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import math
16-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
16+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
1717

1818
import torch
1919

@@ -85,6 +85,14 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
8585
data_batches.append(data_batch)
8686
return data_batches
8787

88+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
89+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
90+
data_batches = []
91+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
92+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
93+
data_batches.append(data_batch)
94+
return data_batches
95+
8896
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
8997
pred = None
9098

src/diffusers/guiders/frequency_decoupled_guidance.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,14 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
225225
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
226226
data_batches.append(data_batch)
227227
return data_batches
228+
229+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
230+
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
231+
data_batches = []
232+
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
233+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
234+
data_batches.append(data_batch)
235+
return data_batches
228236

229237
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
230238
pred = None

src/diffusers/guiders/guider_utils.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
166166
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
167167
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
168168

169+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
170+
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
171+
169172
def __call__(self, data: List["BlockState"]) -> Any:
170173
if not all(hasattr(d, "noise_pred") for d in data):
171174
raise ValueError("Expected all data to have `noise_pred` attribute.")
@@ -234,6 +237,53 @@ def _prepare_batch(
234237
data_batch[cls._identifier_key] = identifier
235238
return BlockState(**data_batch)
236239

240+
241+
@classmethod
242+
def _prepare_batch_from_block_state(
243+
cls,
244+
input_fields: Dict[str, Union[str, Tuple[str, str]]],
245+
data: "BlockState",
246+
tuple_index: int,
247+
identifier: str,
248+
) -> "BlockState":
249+
"""
250+
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
251+
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
252+
253+
Args:
254+
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
255+
A dictionary where the keys are the names of the fields that will be used to store the data once it is
256+
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
257+
to look up the required data provided for preparation. If a string is provided, it will be used as the
258+
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
259+
length 2 is provided, the first element must be the conditional data identifier and the second element
260+
must be the unconditional data identifier or None.
261+
data (`BlockState`):
262+
The input data to be prepared.
263+
tuple_index (`int`):
264+
The index to use when accessing input fields that are tuples.
265+
266+
Returns:
267+
`BlockState`: The prepared batch of data.
268+
"""
269+
from ..modular_pipelines.modular_pipeline import BlockState
270+
271+
272+
data_batch = {}
273+
for key, value in input_fields.items():
274+
try:
275+
if isinstance(value, str):
276+
data_batch[key] = getattr(data, value)
277+
elif isinstance(value, tuple):
278+
data_batch[key] = getattr(data, value[tuple_index])
279+
else:
280+
# We've already checked that value is a string or a tuple of strings with length 2
281+
pass
282+
except AttributeError:
283+
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
284+
data_batch[cls._identifier_key] = identifier
285+
return BlockState(**data_batch)
286+
237287
@classmethod
238288
@validate_hf_hub_args
239289
def from_pretrained(

src/diffusers/guiders/perturbed_attention_guidance.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,24 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
186186
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
187187
data_batches.append(data_batch)
188188
return data_batches
189+
190+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
191+
if self.num_conditions == 1:
192+
tuple_indices = [0]
193+
input_predictions = ["pred_cond"]
194+
elif self.num_conditions == 2:
195+
tuple_indices = [0, 1]
196+
input_predictions = (
197+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
198+
)
199+
else:
200+
tuple_indices = [0, 1, 0]
201+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
202+
data_batches = []
203+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
204+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
205+
data_batches.append(data_batch)
206+
return data_batches
189207

190208
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
191209
def forward(

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,24 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
183183
data_batches.append(data_batch)
184184
return data_batches
185185

186+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
187+
if self.num_conditions == 1:
188+
tuple_indices = [0]
189+
input_predictions = ["pred_cond"]
190+
elif self.num_conditions == 2:
191+
tuple_indices = [0, 1]
192+
input_predictions = (
193+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
194+
)
195+
else:
196+
tuple_indices = [0, 1, 0]
197+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
198+
data_batches = []
199+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
200+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
201+
data_batches.append(data_batch)
202+
return data_batches
203+
186204
def forward(
187205
self,
188206
pred_cond: torch.Tensor,

src/diffusers/guiders/smoothed_energy_guidance.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,24 @@ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) ->
172172
data_batches.append(data_batch)
173173
return data_batches
174174

175+
def prepare_inputs_from_block_state(self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]) -> List["BlockState"]:
176+
if self.num_conditions == 1:
177+
tuple_indices = [0]
178+
input_predictions = ["pred_cond"]
179+
elif self.num_conditions == 2:
180+
tuple_indices = [0, 1]
181+
input_predictions = (
182+
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
183+
)
184+
else:
185+
tuple_indices = [0, 1, 0]
186+
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
187+
data_batches = []
188+
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
189+
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
190+
data_batches.append(data_batch)
191+
return data_batches
192+
175193
def forward(
176194
self,
177195
pred_cond: torch.Tensor,

0 commit comments

Comments
 (0)