11import inspect
22import math
33from dataclasses import dataclass
4- from types import SimpleNamespace
5- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
4+ from typing import Callable , Dict , List , Optional , Tuple , Union
65
76import numpy as np
87import torch
@@ -257,33 +256,6 @@ def _dynamic_resize_from_bucket(image_size: Tuple[int, int], basesize: int = 51
257256
258257
259258
260- def _build_args (
261- args : Any ,
262- text_encoder : Qwen3VLForConditionalGeneration ,
263- ) -> Any :
264- """
265- Return args unchanged if provided, otherwise construct a default namespace.
266-
267- Args:
268- args: Existing args object, or None.
269- text_encoder: Text encoder used to resolve the checkpoint path when args is None.
270-
271- Returns:
272- The original args object, or a SimpleNamespace with sensible defaults.
273- """
274- if args is not None :
275- return args
276-
277- text_encoder_ckpt = _get_text_encoder_ckpt (text_encoder )
278- return SimpleNamespace (
279- enable_multi_task_training = False ,
280- text_token_max_length = 2048 ,
281- dit_precision = "bf16" ,
282- vae_precision = "bf16" ,
283- text_encoder_arch_config = {"params" : {"text_encoder_ckpt" : text_encoder_ckpt }},
284- )
285-
286-
287259def retrieve_timesteps (
288260 scheduler ,
289261 num_inference_steps : Optional [int ] = None ,
@@ -364,7 +336,9 @@ def __init__(
364336 tokenizer : Qwen2Tokenizer ,
365337 transformer : JoyImageEditTransformer3DModel ,
366338 processor : Qwen3VLProcessor ,
367- args : Any = None ,
339+ enable_multi_task_training : bool = False ,
340+ text_token_max_length : int = 2048 ,
341+ text_encoder_ckpt : Optional [str ] = None ,
368342 ):
369343 """
370344 Initialise the pipeline and register all sub-modules.
@@ -376,10 +350,12 @@ def __init__(
376350 tokenizer: Tokenizer paired with the text encoder.
377351 transformer: 3-D transformer denoising network.
378352 processor: Qwen3-VL processor for multi-image prompt preparation.
379- args: Optional configuration namespace. Defaults are inferred when None.
353+ enable_multi_task_training: Whether to enable multi-task training mode.
354+ text_token_max_length: Maximum number of text tokens for the encoder.
355+ text_encoder_ckpt: Path to text encoder checkpoint. Inferred from
356+ ``text_encoder`` when not provided.
380357 """
381358 super ().__init__ ()
382- self .args = _build_args (args = args , text_encoder = text_encoder )
383359 self .register_modules (
384360 vae = vae ,
385361 text_encoder = text_encoder ,
@@ -389,6 +365,9 @@ def __init__(
389365 processor = processor ,
390366 )
391367
368+ self .enable_multi_task_training = enable_multi_task_training
369+ self .text_token_max_length = text_token_max_length
370+
392371 self .vae_scale_factor_temporal = (
393372 self .vae .config .scale_factor_temporal if getattr (self , "vae" , None ) else 4
394373 )
@@ -397,15 +376,12 @@ def __init__(
397376 )
398377 self .image_processor = VaeImageProcessor (vae_scale_factor = self .vae_scale_factor_spatial )
399378
400- text_encoder_ckpt = dict (self .args .text_encoder_arch_config .get ("params" , {})).get (
401- "text_encoder_ckpt" , _get_text_encoder_ckpt (self .text_encoder )
402- )
379+ if text_encoder_ckpt is None :
380+ text_encoder_ckpt = _get_text_encoder_ckpt (self .text_encoder )
403381 self .qwen_processor = (
404382 processor if processor is not None else AutoProcessor .from_pretrained (text_encoder_ckpt )
405383 )
406384
407- self .text_token_max_length = self .args .text_token_max_length
408-
409385 # Prompt templates used when encoding text with / without image tokens.
410386 self .prompt_template_encode = {
411387 "image" : (