|
17 | 17 | from transformers import Qwen2Tokenizer, Qwen3VLModel |
18 | 18 | from transformers.masking_utils import create_causal_mask |
19 | 19 |
|
20 | | -from ...utils import logging |
| 20 | +from ...pipelines.ideogram4.prompt_enhancer import ( |
| 21 | + PROMPT_UPSAMPLE_TEMPERATURE, |
| 22 | + Ideogram4PromptEnhancerHead, |
| 23 | + build_caption_logits_processor, |
| 24 | + build_prompt_enhancer, |
| 25 | + generate_captions, |
| 26 | +) |
| 27 | +from ...utils import is_outlines_available, logging |
21 | 28 | from ..modular_pipeline import ModularPipelineBlocks, PipelineState |
22 | 29 | from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam |
23 | 30 | from .modular_pipeline import Ideogram4ModularPipeline |
|
31 | 38 | QWEN3_VL_ACTIVATION_LAYERS = (0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 35) |
32 | 39 |
|
33 | 40 |
|
| 41 | +# auto_docstring |
class Ideogram4PromptUpsampleStep(ModularPipelineBlocks):
    """
    Optional step that rewrites the prompt(s) into Ideogram4's native structured JSON caption (the format the model is
    trained on) when ``prompt_upsampling=True``. Requires the optional ``prompt_enhancer_head`` component, which is
    grafted onto the shared ``text_encoder`` body to make it generative; install ``outlines`` for schema-constrained
    captions.

      Components:
          text_encoder (`Qwen3VLModel`):
              The Qwen3-VL text encoder.
          tokenizer (`Qwen2Tokenizer`):
              The tokenizer paired with the text encoder.
          prompt_enhancer_head (`Ideogram4PromptEnhancerHead`):
              The LM head grafted onto the text encoder for upsampling.

      Inputs:
          prompt (`str`):
              The prompt or prompts to guide image generation.
          prompt_upsampling (`bool`, *optional*, defaults to False):
              If True, rewrite the prompt into the native JSON caption before encoding.
          prompt_upsampling_temperature (`float`, *optional*, defaults to 1.0):
              Sampling temperature for prompt upsampling.
          height (`int`, *optional*):
              Together with width, sets the caption's target aspect ratio.
          width (`int`, *optional*):
              Together with height, sets the caption's target aspect ratio.
          max_sequence_length (`int`, *optional*, defaults to 2048):
              Forwarded to caption generation as ``max_new_tokens``.
          generator (`Generator`, *optional*):
              Reused to make the upsampling reproducible.

      Outputs:
          prompt (`list`):
              The (possibly upsampled) prompt forwarded to the text encoder.
    """

    model_name = "ideogram4"

    def __init__(self):
        # Built lazily on first upsample: the head-less encoder body + `prompt_enhancer_head`, combined.
        self._prompt_enhancer = None
        # Outlines logits processor for schema-constrained captions; built lazily on first upsample.
        self._caption_logits_processor = None
        # The (text_encoder, prompt_enhancer_head, tokenizer) objects the two caches above were built from.
        # Compared by identity on every upsample so that swapping a component rebuilds the caches instead of
        # silently generating with the previous modules.
        self._prompt_enhancer_sources = None
        super().__init__()

    @property
    def description(self) -> str:
        return (
            "Optional step that rewrites the prompt(s) into Ideogram4's native structured JSON caption when "
            "`prompt_upsampling=True` (the format the model is trained on). Requires the optional "
            "`prompt_enhancer_head` component, which is grafted onto the shared `text_encoder` body to make it "
            "generative; install `outlines` for schema-constrained captions."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", Qwen3VLModel, description="The Qwen3-VL text encoder."),
            ComponentSpec("tokenizer", Qwen2Tokenizer, description="The tokenizer paired with the text encoder."),
            ComponentSpec(
                "prompt_enhancer_head",
                Ideogram4PromptEnhancerHead,
                description="LM head grafted onto the text encoder for prompt upsampling.",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("prompt", required=True),
            InputParam(
                name="prompt_upsampling",
                type_hint=bool,
                default=False,
                description="If True, rewrite the prompt into Ideogram4's native JSON caption before encoding.",
            ),
            InputParam(
                name="prompt_upsampling_temperature",
                type_hint=float,
                default=PROMPT_UPSAMPLE_TEMPERATURE,
                description="Sampling temperature for prompt upsampling.",
            ),
            InputParam.template("height"),
            InputParam.template("width"),
            InputParam.template("max_sequence_length", default=2048),
            InputParam.template("generator"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                name="prompt",
                type_hint=list,
                description="The (possibly upsampled) prompt forwarded to the text encoder.",
            ),
        ]

    def _refresh_generation_cache(self, components) -> None:
        """Build the cached prompt enhancer and logits processor, rebuilding them if a source component changed."""
        sources = (components.text_encoder, components.prompt_enhancer_head, components.tokenizer)
        cached_sources = self._prompt_enhancer_sources
        is_stale = (
            self._prompt_enhancer is None
            or cached_sources is None
            or any(old is not new for old, new in zip(cached_sources, sources))
        )
        if is_stale:
            self._prompt_enhancer = build_prompt_enhancer(components.text_encoder, components.prompt_enhancer_head)
            # The logits processor is built from the enhancer and tokenizer, so it must be rebuilt with them.
            self._caption_logits_processor = None
            self._prompt_enhancer_sources = sources

        if self._caption_logits_processor is None and is_outlines_available():
            self._caption_logits_processor = build_caption_logits_processor(
                self._prompt_enhancer, components.tokenizer
            )

    @torch.no_grad()
    def __call__(self, components: Ideogram4ModularPipeline, state: PipelineState) -> PipelineState:
        """
        Replace ``prompt`` in the block state with generated caption(s) when ``prompt_upsampling`` is truthy;
        otherwise leave the state untouched.

        Returns the ``(components, state)`` pair.

        Raises:
            ValueError: If upsampling is requested but ``components.prompt_enhancer_head`` is ``None``.
        """
        block_state = self.get_block_state(state)

        if block_state.prompt_upsampling:
            if components.prompt_enhancer_head is None:
                raise ValueError(
                    "Prompt upsampling requires the `prompt_enhancer_head` component, which is not loaded. Load an "
                    "`Ideogram4PromptEnhancerHead` and add it to the pipeline."
                )

            self._refresh_generation_cache(components)
            if self._caption_logits_processor is None:
                logger.warning_once(
                    "`outlines` is not installed; prompt upsampling runs unconstrained and may not return "
                    "schema-valid JSON. Install with `pip install outlines` for structured captions."
                )

            # Fall back to the pipeline defaults when the caller did not pin a resolution.
            height = block_state.height or components.default_height
            width = block_state.width or components.default_width
            block_state.prompt = generate_captions(
                self._prompt_enhancer,
                components.tokenizer,
                self._caption_logits_processor,
                block_state.prompt,
                height,
                width,
                temperature=block_state.prompt_upsampling_temperature,
                # The encoder's sequence budget doubles as the generation budget for the caption.
                max_new_tokens=block_state.max_sequence_length,
                generator=block_state.generator,
                device=components._execution_device,
            )

        self.set_block_state(state, block_state)
        return components, state
| 172 | + |
| 173 | + |
34 | 174 | # auto_docstring |
35 | 175 | class Ideogram4TextEncoderStep(ModularPipelineBlocks): |
36 | 176 | """ |
|
0 commit comments