-
Notifications
You must be signed in to change notification settings - Fork 7k
Expand file tree
/
Copy pathencoders.py
More file actions
264 lines (230 loc) · 9.6 KB
/
encoders.py
File metadata and controls
264 lines (230 loc) · 9.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import torch
from transformers import AutoTokenizer, Mistral3Model
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import logging
from ...utils.import_utils import is_transformers_version
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import ErnieImageModularPipeline
if is_transformers_version("<", "5.0.0"):
raise ImportError("`ErnieImageModularPipeline` requires `transformers>=5.0.0` for `Ministral3ForCausalLM`.")
from transformers import Ministral3ForCausalLM # noqa: E402
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ErnieImagePromptEnhancerStep(ModularPipelineBlocks):
model_name = "ernie-image"
@property
def description(self) -> str:
return "Prompt enhancer step that rewrites the input prompt using a causal language model (PE)."
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("pe", Ministral3ForCausalLM),
ComponentSpec("pe_tokenizer", AutoTokenizer),
]
@property
def inputs(self) -> list[InputParam]:
return [
InputParam(
"prompt",
required=True,
type_hint=str,
description="The prompt or prompts to guide image generation.",
),
InputParam("height", type_hint=int, description="The height in pixels of the generated image."),
InputParam("width", type_hint=int, description="The width in pixels of the generated image."),
InputParam(
"pe_system_prompt",
type_hint=str,
default=None,
description="Optional system prompt passed to the prompt enhancer.",
),
InputParam(
"pe_temperature",
type_hint=float,
default=0.6,
description="Sampling temperature used when generating with the prompt enhancer.",
),
InputParam(
"pe_top_p",
type_hint=float,
default=0.95,
description="Nucleus sampling `top_p` used when generating with the prompt enhancer.",
),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam("prompt", type_hint=list, description="The prompt list after prompt-enhancer rewriting."),
OutputParam("height", type_hint=int, description="The resolved image height in pixels."),
OutputParam("width", type_hint=int, description="The resolved image width in pixels."),
]
@staticmethod
def _enhance_prompt(
pe: Ministral3ForCausalLM,
pe_tokenizer: AutoTokenizer,
prompt: str,
device: torch.device,
width: int,
height: int,
system_prompt: str | None,
temperature: float,
top_p: float,
) -> str:
user_content = json.dumps({"prompt": prompt, "width": width, "height": height}, ensure_ascii=False)
messages = []
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": user_content})
input_text = pe_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
inputs = pe_tokenizer(input_text, return_tensors="pt").to(device)
output_ids = pe.generate(
**inputs,
max_new_tokens=pe_tokenizer.model_max_length,
do_sample=temperature != 1.0 or top_p != 1.0,
temperature=temperature,
top_p=top_p,
pad_token_id=pe_tokenizer.pad_token_id,
eos_token_id=pe_tokenizer.eos_token_id,
)
generated_ids = output_ids[0][inputs["input_ids"].shape[1] :]
return pe_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
@torch.no_grad()
def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
prompt = block_state.prompt
if isinstance(prompt, str):
prompt = [prompt]
height = block_state.height or components.default_height
width = block_state.width or components.default_width
revised = [
self._enhance_prompt(
pe=components.pe,
pe_tokenizer=components.pe_tokenizer,
prompt=p,
device=device,
width=width,
height=height,
system_prompt=block_state.pe_system_prompt,
temperature=block_state.pe_temperature,
top_p=block_state.pe_top_p,
)
for p in prompt
]
block_state.prompt = revised
block_state.height = height
block_state.width = width
self.set_block_state(state, block_state)
return components, state
class ErnieImageTextEncoderStep(ModularPipelineBlocks):
model_name = "ernie-image"
@property
def description(self) -> str:
return (
"Text encoder step that encodes prompts into variable-length hidden states for the ErnieImage transformer."
)
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("text_encoder", Mistral3Model),
ComponentSpec("tokenizer", AutoTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> list[InputParam]:
return [
InputParam("prompt", type_hint=str, description="The prompt or prompts to guide image generation."),
InputParam(
"negative_prompt",
type_hint=str,
description="The prompt or prompts to avoid during image generation.",
),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam(
"prompt_embeds",
type_hint=list,
kwargs_type="denoiser_input_fields",
description="List of per-prompt text embeddings of shape (T, H).",
),
OutputParam(
"negative_prompt_embeds",
type_hint=list,
kwargs_type="denoiser_input_fields",
description="List of per-prompt negative text embeddings for classifier-free guidance.",
),
]
@staticmethod
def _encode(
text_encoder: Mistral3Model,
tokenizer: AutoTokenizer,
prompt: list[str],
device: torch.device,
) -> list[torch.Tensor]:
text_hiddens = []
for p in prompt:
ids = tokenizer(p, add_special_tokens=True, truncation=True, padding=False)["input_ids"]
if len(ids) == 0:
ids = [tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 0]
input_ids = torch.tensor([ids], device=device)
outputs = text_encoder(input_ids=input_ids, output_hidden_states=True)
text_hiddens.append(outputs.hidden_states[-2][0])
return text_hiddens
@torch.no_grad()
def __call__(self, components: ErnieImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
prompt = block_state.prompt
if prompt is None:
prompt = [""]
if isinstance(prompt, str):
prompt = [prompt]
block_state.prompt_embeds = self._encode(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
)
if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt
if negative_prompt is None:
negative_prompt = ""
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * len(prompt)
if len(negative_prompt) != len(prompt):
raise ValueError(
f"`negative_prompt` must have the same length as `prompt` ({len(prompt)}), "
f"got {len(negative_prompt)}."
)
block_state.negative_prompt_embeds = self._encode(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=negative_prompt,
device=device,
)
else:
block_state.negative_prompt_embeds = None
self.set_block_state(state, block_state)
return components, state