Skip to content

Commit dcdc6b3

Browse files
authored
Merge pull request #537 from DrJKL/drjkl/feat/extract_condition_step
[Feat] Add node to pull a single condition out of a scheduled prompt condition
2 parents f02b113 + 10ba1fa commit dcdc6b3

3 files changed

Lines changed: 61 additions & 3 deletions

File tree

animatediff/nodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from .nodes_ad_settings import (AnimateDiffSettingsNode, ManualAdjustPENode, SweetspotStretchPENode, FullStretchPENode,
4040
WeightAdjustAllAddNode, WeightAdjustAllMultNode, WeightAdjustIndivAddNode, WeightAdjustIndivMultNode,
4141
WeightAdjustIndivAttnAddNode, WeightAdjustIndivAttnMultNode)
42-
from .nodes_scheduling import (PromptSchedulingNode, PromptSchedulingLatentsNode, ValueSchedulingNode, ValueSchedulingLatentsNode,
42+
from .nodes_scheduling import (ConditionExtractionNode, PromptSchedulingNode, PromptSchedulingLatentsNode, ValueSchedulingNode, ValueSchedulingLatentsNode,
4343
AddValuesReplaceNode, FloatToFloatsNode)
4444
from .nodes_per_block import (ADBlockComboNode, ADBlockIndivNode, PerBlockHighLevelNode,
4545
PerBlock_SD15_LowLevelNode, PerBlock_SD15_MidLevelNode, PerBlock_SD15_FromFloatsNode,
@@ -165,6 +165,7 @@
165165
PromptSchedulingLatentsNode.NodeID: PromptSchedulingLatentsNode,
166166
ValueSchedulingNode.NodeID: ValueSchedulingNode,
167167
ValueSchedulingLatentsNode.NodeID: ValueSchedulingLatentsNode,
168+
ConditionExtractionNode.NodeID: ConditionExtractionNode,
168169
AddValuesReplaceNode.NodeID: AddValuesReplaceNode,
169170
FloatToFloatsNode.NodeID: FloatToFloatsNode,
170171
# Per-Block
@@ -345,6 +346,7 @@
345346
PromptSchedulingLatentsNode.NodeID: PromptSchedulingLatentsNode.NodeName,
346347
ValueSchedulingNode.NodeID: ValueSchedulingNode.NodeName,
347348
ValueSchedulingLatentsNode.NodeID: ValueSchedulingLatentsNode.NodeName,
349+
ConditionExtractionNode.NodeID: ConditionExtractionNode.NodeName,
348350
AddValuesReplaceNode.NodeID: AddValuesReplaceNode.NodeName,
349351
FloatToFloatsNode.NodeID:FloatToFloatsNode.NodeName,
350352
# Per-Block

animatediff/nodes_scheduling.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Union
22

33
from .documentation import register_description, short_desc, coll, DocHelper
4-
from .scheduling import (evaluate_prompt_schedule, evaluate_value_schedule, TensorInterp, PromptOptions,
4+
from .scheduling import (evaluate_prompt_schedule, evaluate_value_schedule, extract_cond_from_schedule, TensorInterp, PromptOptions,
55
verify_key_value)
66
from .utils_model import BIGMAX
77
from .logger import logger
@@ -24,6 +24,10 @@
2424
desc_value_key = {'value_key': 'Key to use for value schedule in Prompt Scheduling node. Can only contain a-z, A-Z, 0-9, and _ characters. In Prompt Scheduling, keys can be referred to as `some_key`, where the key is surrounded by ` characters.'}
2525
desc_prev_replace = {'prev_replace': 'OPTIONAL, other values_replace can be chained.'}
2626

27+
desc_input_conditioning = {'conditioning': 'Encoded prompts. The output of a Prompt Scheduling node.'}
28+
desc_index = {'index': 'The index to extract. Must be within the range [0,N] where N is the length of scheduled prompts.'}
29+
desc_output_conditioning_single = {'CONDITIONING': 'The single step conditioning from the schedule.'}
30+
2731
desc_output_conditioning = {'CONDITIONING': 'Encoded prompts.'}
2832
desc_output_latent = {'LATENT': 'Unmodified input latents; can be used as pipe, or can be ignored.'}
2933

@@ -280,4 +284,31 @@ def convert_to_floats(self, FLOAT: Union[float, list[float]]):
280284
floats = [float(FLOAT)]
281285
else:
282286
floats = list(FLOAT)
283-
return (floats,)
287+
return (floats,)
288+
289+
class ConditionExtractionNode:
290+
NodeID = 'ADE_ConditionExtraction'
291+
NodeName = 'Condition Step Extraction 🎭🅐🅓'
292+
@classmethod
293+
def INPUT_TYPES(s):
294+
return {
295+
"required": {
296+
"conditioning": ("CONDITIONING",),
297+
"index": ("INT", {"default": 0, "min": 0, "step": 1})
298+
},
299+
}
300+
301+
RETURN_TYPES = ("CONDITIONING",)
302+
CATEGORY = "Animate Diff 🎭🅐🅓/scheduling"
303+
FUNCTION = "extract_conditioning"
304+
305+
Desc = [
306+
short_desc('Extract a single conditioning step from a schedule of prompts.'),
307+
{coll('Inputs'): DocHelper.combine(desc_input_conditioning, desc_index)},
308+
{coll('Outputs'): DocHelper.combine(desc_output_conditioning)}
309+
]
310+
register_description(NodeID, Desc)
311+
312+
def extract_conditioning(self, conditioning, index: int=0):
313+
conditioning_step = extract_cond_from_schedule(conditioning, index)
314+
return (conditioning_step,)

animatediff/scheduling.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ class PromptOptions:
112112
print_schedule: bool = False
113113
add_dict: dict[str] = None
114114

115+
IndividualConditioning = tuple[torch.Tensor, dict[str, torch.Tensor]]
116+
Conditioning = list[IndividualConditioning]
115117

116118
def evaluate_prompt_schedule(text: str, length: int, clip: CLIP, options: PromptOptions):
117119
text = strip_input(text)
@@ -445,6 +447,29 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP
445447
clip.add_hooks_to_dict(final_pooled_dict)
446448
return [[final_cond, final_pooled_dict]]
447449

450+
def extract_cond_from_schedule(conditioning: Conditioning, index: int) -> Conditioning:
451+
return [_extract_single_cond(t, index) for t in conditioning]
452+
453+
def _extract_single_cond(single_cond: IndividualConditioning, index:int) -> IndividualConditioning:
454+
if index < 0:
455+
return single_cond
456+
457+
cond, kwargs = single_cond[0], single_cond[1].copy()
458+
original_pooled = kwargs["pooled_output"]
459+
460+
cond_schedules = cond.shape[0]
461+
pooled_schedules = original_pooled.shape[0]
462+
463+
if cond_schedules <= index or pooled_schedules <= index:
464+
logger.warning(f"Trying to get index {index}, only have {cond_schedules} items")
465+
return single_cond
466+
467+
cond_chunks = cond.chunk(cond_schedules)
468+
chosen_cond = cond_chunks[index]
469+
470+
pool_chunks = original_pooled.chunk(pooled_schedules)
471+
kwargs["pooled_output"] = pool_chunks[index]
472+
return [chosen_cond, kwargs]
448473

449474
def pad_cond(cond: Tensor, target_length: int):
450475
# FizzNodes-style cond padding

0 commit comments

Comments
 (0)