Skip to content

Commit 001f7d3

Browse files
author
huangfeice
committed
[refactor] Refactor JoyImageEditPipeline to use explicit arguments instead of namespace and remove _build_arg
1 parent cc9d134 commit 001f7d3

1 file changed

Lines changed: 13 additions & 37 deletions

File tree

src/diffusers/pipelines/joyimage/pipeline_joyimage_edit.py

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import inspect
22
import math
33
from dataclasses import dataclass
4-
from types import SimpleNamespace
5-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Callable, Dict, List, Optional, Tuple, Union
65

76
import numpy as np
87
import torch
@@ -257,33 +256,6 @@ def _dynamic_resize_from_bucket(image_size: Tuple[int, int], basesize: int = 51
257256

258257

259258

260-
def _build_args(
261-
args: Any,
262-
text_encoder: Qwen3VLForConditionalGeneration,
263-
) -> Any:
264-
"""
265-
Return args unchanged if provided, otherwise construct a default namespace.
266-
267-
Args:
268-
args: Existing args object, or None.
269-
text_encoder: Text encoder used to resolve the checkpoint path when args is None.
270-
271-
Returns:
272-
The original args object, or a SimpleNamespace with sensible defaults.
273-
"""
274-
if args is not None:
275-
return args
276-
277-
text_encoder_ckpt = _get_text_encoder_ckpt(text_encoder)
278-
return SimpleNamespace(
279-
enable_multi_task_training=False,
280-
text_token_max_length=2048,
281-
dit_precision="bf16",
282-
vae_precision="bf16",
283-
text_encoder_arch_config={"params": {"text_encoder_ckpt": text_encoder_ckpt}},
284-
)
285-
286-
287259
def retrieve_timesteps(
288260
scheduler,
289261
num_inference_steps: Optional[int] = None,
@@ -364,7 +336,9 @@ def __init__(
364336
tokenizer: Qwen2Tokenizer,
365337
transformer: JoyImageEditTransformer3DModel,
366338
processor: Qwen3VLProcessor,
367-
args: Any = None,
339+
enable_multi_task_training: bool = False,
340+
text_token_max_length: int = 2048,
341+
text_encoder_ckpt: Optional[str] = None,
368342
):
369343
"""
370344
Initialise the pipeline and register all sub-modules.
@@ -376,10 +350,12 @@ def __init__(
376350
tokenizer: Tokenizer paired with the text encoder.
377351
transformer: 3-D transformer denoising network.
378352
processor: Qwen3-VL processor for multi-image prompt preparation.
379-
args: Optional configuration namespace. Defaults are inferred when None.
353+
enable_multi_task_training: Whether to enable multi-task training mode.
354+
text_token_max_length: Maximum number of text tokens for the encoder.
355+
text_encoder_ckpt: Path to text encoder checkpoint. Inferred from
356+
``text_encoder`` when not provided.
380357
"""
381358
super().__init__()
382-
self.args = _build_args(args=args, text_encoder=text_encoder)
383359
self.register_modules(
384360
vae=vae,
385361
text_encoder=text_encoder,
@@ -389,6 +365,9 @@ def __init__(
389365
processor=processor,
390366
)
391367

368+
self.enable_multi_task_training = enable_multi_task_training
369+
self.text_token_max_length = text_token_max_length
370+
392371
self.vae_scale_factor_temporal = (
393372
self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
394373
)
@@ -397,15 +376,12 @@ def __init__(
397376
)
398377
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
399378

400-
text_encoder_ckpt = dict(self.args.text_encoder_arch_config.get("params", {})).get(
401-
"text_encoder_ckpt", _get_text_encoder_ckpt(self.text_encoder)
402-
)
379+
if text_encoder_ckpt is None:
380+
text_encoder_ckpt = _get_text_encoder_ckpt(self.text_encoder)
403381
self.qwen_processor = (
404382
processor if processor is not None else AutoProcessor.from_pretrained(text_encoder_ckpt)
405383
)
406384

407-
self.text_token_max_length = self.args.text_token_max_length
408-
409385
# Prompt templates used when encoding text with / without image tokens.
410386
self.prompt_template_encode = {
411387
"image": (

0 commit comments

Comments
 (0)