Skip to content

Commit acf5bff

Browse files
committed
Implement text encoder outputs extra conditioning data
1 parent a9a022b commit acf5bff

2 files changed

Lines changed: 14 additions & 6 deletions

File tree

batch_encoding.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,11 @@ def batch_encode_with_cache(clip_model, prompts, cond_cache, prompt_type="positi
7676
original_layer = clip_model.cond_stage_model.clip_layer
7777
clip_model.cond_stage_model.set_clip_options({"layer": clip_skip})
7878

79-
cond, pooled = clip_model.encode_from_tokens(tokens, return_pooled=True)
80-
conditioning = [[cond, {"pooled_output": pooled}]]
79+
# Use return_dict=True to preserve all extra conditioning keys
80+
# (e.g. t5xxl_ids/t5xxl_weights for Anima, attention_mask for Lumina, etc.)
81+
pooled_dict = clip_model.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
82+
cond = pooled_dict.pop("cond")
83+
conditioning = [[cond, pooled_dict]]
8184
results[prompt] = conditioning
8285
cond_cache.set(prompt, conditioning, prompt_type)
8386

generation_orchestrator.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,11 @@ def run_generation_loop(
616616
patched_clip.cond_stage_model.set_clip_options({"layer": clip_skip})
617617

618618
tokens = patched_clip.tokenize(prompt)
619-
cond, pooled = patched_clip.encode_from_tokens(tokens, return_pooled=True)
620-
conditioning_cache["positive"][prompt] = [[cond, {"pooled_output": pooled}]]
619+
# Use return_dict=True to preserve all extra conditioning keys
620+
# (e.g. t5xxl_ids/t5xxl_weights for Anima, attention_mask for Lumina, etc.)
621+
pooled_dict = patched_clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
622+
cond = pooled_dict.pop("cond")
623+
conditioning_cache["positive"][prompt] = [[cond, pooled_dict]]
621624

622625
if original_layer is not None:
623626
patched_clip.cond_stage_model.set_clip_options({"layer": original_layer})
@@ -635,8 +638,10 @@ def run_generation_loop(
635638
patched_clip.cond_stage_model.set_clip_options({"layer": clip_skip})
636639

637640
tokens = patched_clip.tokenize(prompt)
638-
cond, pooled = patched_clip.encode_from_tokens(tokens, return_pooled=True)
639-
conditioning_cache["negative"][prompt] = [[cond, {"pooled_output": pooled}]]
641+
# Use return_dict=True to preserve all extra conditioning keys
642+
pooled_dict = patched_clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
643+
cond = pooled_dict.pop("cond")
644+
conditioning_cache["negative"][prompt] = [[cond, pooled_dict]]
640645

641646
if original_layer is not None:
642647
patched_clip.cond_stage_model.set_clip_options({"layer": original_layer})

0 commit comments

Comments
 (0)