Sana model/pipeline review
Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423
Review performed against the repository review rules.
Duplicate search: searched GitHub Issues and PRs for sana, affected class/file names, and failure terms. No likely duplicates found except Issue 2, which is already tracked.
Issue 1: Sana Sprint rejects documented 1/3/4-step inference by default
Affected code:
src/diffusers/pipelines/sana/pipeline_sana_sprint.py and src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py (identical pattern in both), check_inputs():

if intermediate_timesteps is not None and num_inference_steps != 2:
    raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")

__call__() signature defaults, both pipelines:

num_inference_steps: int = 2,
timesteps: list[int] = None,
max_timesteps: float = 1.57080,
intermediate_timesteps: float = 1.3,
Problem:
intermediate_timesteps defaults to 1.3, but check_inputs() rejects any non-None value unless num_inference_steps == 2. As a result, num_inference_steps=1, 3, or 4 fails unless users know to pass intermediate_timesteps=None.
Impact:
SANA-Sprint is documented as a 1-4 step model, but the pipeline blocks the one-step path by default.
Reproduction:
from diffusers import SanaSprintImg2ImgPipeline, SanaSprintPipeline

common = dict(
    prompt="cat",
    height=1024,
    width=1024,
    num_inference_steps=1,
    timesteps=None,
    max_timesteps=1.5708,
    intermediate_timesteps=1.3,
    callback_on_step_end_tensor_inputs=None,
    prompt_embeds=None,
    prompt_attention_mask=None,
)

for cls, extra in [(SanaSprintPipeline, {}), (SanaSprintImg2ImgPipeline, {"strength": 0.5})]:
    try:
        cls.check_inputs(None, **common, **extra)
    except Exception as e:
        print(cls.__name__, type(e).__name__, e)
Relevant precedent:
docs/source/en/api/pipelines/sana_sprint.md quotes the SANA-Sprint paper abstract:

"This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced."
Suggested fix:
# In both Sprint pipeline __call__ signatures:
intermediate_timesteps: float | None = None,

# Before retrieve_timesteps:
if num_inference_steps == 2 and intermediate_timesteps is None:
    intermediate_timesteps = 1.3
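The proposed defaulting rule can be sketched standalone (a toy helper, not the diffusers source; `resolve_intermediate_timesteps` is a hypothetical name): only the two-step schedule gets the 1.3 default, so 1/3/4-step calls no longer trip the check_inputs() guard.

```python
def resolve_intermediate_timesteps(num_inference_steps, intermediate_timesteps=None):
    # Default only where the value is actually supported (the two-step schedule).
    if intermediate_timesteps is None:
        return 1.3 if num_inference_steps == 2 else None
    # Keep the existing guard for explicitly passed values.
    if num_inference_steps != 2:
        raise ValueError(
            "Intermediate timesteps for SCM is not supported when num_inference_steps != 2."
        )
    return intermediate_timesteps
```

With this ordering, `num_inference_steps=1` with the default arguments simply runs without an intermediate timestep, while an explicit `intermediate_timesteps=1.3` at three steps still fails loudly.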
Issue 2: Known duplicate: guidance_embeds=True crashes without guidance
Affected code:
src/diffusers/models/transformers/sana_transformer.py, SanaTransformer2DModel.forward():

if guidance is not None:
    timestep, embedded_timestep = self.time_embed(
        timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
    )
else:
    timestep, embedded_timestep = self.time_embed(
        timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
    )
Problem:
SanaTransformer2DModel.forward() dispatches the time embedding call based on whether guidance was passed, not on which time embedding module was configured. With guidance_embeds=True and no guidance, it calls SanaCombinedTimestepGuidanceEmbeddings.forward(..., batch_size=...), which is not accepted.
Impact:
A model configured with guidance embeddings cannot be used by non-guidance Sana pipelines; users get a low-level TypeError.
Reproduction:
import torch
from diffusers import SanaTransformer2DModel

model = SanaTransformer2DModel(
    in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=4,
    num_layers=1, num_cross_attention_heads=2, cross_attention_head_dim=4,
    cross_attention_dim=8, caption_channels=8, sample_size=4, patch_size=1,
    guidance_embeds=True,
)
model(
    hidden_states=torch.randn(1, 4, 4, 4),
    encoder_hidden_states=torch.randn(1, 3, 8),
    timestep=torch.tensor([1.0]),
)
Relevant precedent:
Duplicate: #12540
Related PRs: #13109 and closed-unmerged #13517
Suggested fix:
Route by embedding type/configuration instead of guidance is not None, and raise a clear ValueError when guidance_embeds=True but guidance is absent.
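The routing idea can be shown with a torch-free toy (all class names below are illustrative stand-ins for the real time-embedding modules, not diffusers API): dispatch is keyed on the configured embedding type, and a missing `guidance` becomes an explicit error instead of a TypeError deep in forward().

```python
class PlainTimeEmbed:
    # Stand-in for the non-guidance time embedding.
    def __call__(self, timestep, *, batch_size):
        return ("plain", timestep, batch_size)

class GuidanceTimeEmbed:
    # Stand-in for SanaCombinedTimestepGuidanceEmbeddings-style modules.
    def __call__(self, timestep, *, guidance):
        return ("guided", timestep, guidance)

class ToyModel:
    def __init__(self, guidance_embeds: bool):
        self.guidance_embeds = guidance_embeds
        self.time_embed = GuidanceTimeEmbed() if guidance_embeds else PlainTimeEmbed()

    def forward(self, timestep, guidance=None, batch_size=1):
        # Route by configuration, not by whether `guidance` happened to be passed.
        if self.guidance_embeds:
            if guidance is None:
                raise ValueError(
                    "This model was configured with guidance_embeds=True; pass `guidance`."
                )
            return self.time_embed(timestep, guidance=guidance)
        return self.time_embed(timestep, batch_size=batch_size)
```

The non-guidance path is unchanged; only the misconfigured combination now fails with an actionable message.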
Issue 3: cross_attention_dim=None constructs a broken block
Affected code:
src/diffusers/models/transformers/sana_transformer.py, SanaTransformerBlock:

if cross_attention_dim is not None:
    self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
    self.attn2 = Attention(
        query_dim=dim,
        qk_norm=qk_norm,
        kv_heads=num_cross_attention_heads if qk_norm is not None else None,
        cross_attention_dim=cross_attention_dim,
        heads=num_cross_attention_heads,
        dim_head=cross_attention_head_dim,
        dropout=dropout,
        bias=True,
        out_bias=attention_out_bias,
        processor=SanaAttnProcessor2_0(),
    )

# 3. Feed-forward
self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)

self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    encoder_hidden_states: torch.Tensor | None = None,
    encoder_attention_mask: torch.Tensor | None = None,
    timestep: torch.LongTensor | None = None,
    height: int = None,
    width: int = None,
) -> torch.Tensor:
    batch_size = hidden_states.shape[0]

    # 1. Modulation
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
    ).chunk(6, dim=1)

    # 2. Self Attention
    norm_hidden_states = self.norm1(hidden_states)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
    norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

    attn_output = self.attn1(norm_hidden_states)
    hidden_states = hidden_states + gate_msa * attn_output

    # 3. Cross Attention
    if self.attn2 is not None:
        attn_output = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward
    norm_hidden_states = self.norm2(hidden_states)

src/diffusers/models/controlnets/controlnet_sana.py, SanaControlNetModel.__init__() (excerpt, which reuses the same block):

cross_attention_dim: int | None = 2240,
caption_channels: int = 2304,
mlp_ratio: float = 2.5,
dropout: float = 0.0,
attention_bias: bool = False,
sample_size: int = 32,
patch_size: int = 1,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: int | None = None,
) -> None:
    super().__init__()

    out_channels = out_channels or in_channels
    inner_dim = num_attention_heads * attention_head_dim

    # 1. Patch Embedding
    self.patch_embed = PatchEmbed(
        height=sample_size,
        width=sample_size,
        patch_size=patch_size,
        in_channels=in_channels,
        embed_dim=inner_dim,
        interpolation_scale=interpolation_scale,
        pos_embed_type="sincos" if interpolation_scale is not None else None,
    )

    # 2. Additional condition embeddings
    self.time_embed = AdaLayerNormSingle(inner_dim)

    self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
    self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)

    # 3. Transformer blocks
    self.transformer_blocks = nn.ModuleList(
        [
            SanaTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                dropout=dropout,
                num_cross_attention_heads=num_cross_attention_heads,
                cross_attention_head_dim=cross_attention_head_dim,
                cross_attention_dim=cross_attention_dim,
Problem:
cross_attention_dim is annotated as optional, but SanaTransformerBlock only defines self.attn2 and self.norm2 inside if cross_attention_dim is not None. forward() always reads them.
Impact:
Valid-looking configs fail at runtime, including SanaControlNetModel, which reuses the same block.
Reproduction:
import torch
from diffusers import SanaControlNetModel, SanaTransformer2DModel

for cls in (SanaTransformer2DModel, SanaControlNetModel):
    model = cls(
        in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=4,
        num_layers=1, num_cross_attention_heads=2, cross_attention_head_dim=4,
        cross_attention_dim=None, caption_channels=8, sample_size=4, patch_size=1,
    )
    kwargs = dict(
        hidden_states=torch.randn(1, 4, 4, 4),
        encoder_hidden_states=torch.randn(1, 3, 8),
        timestep=torch.tensor([1.0]),
    )
    if cls is SanaControlNetModel:
        kwargs["controlnet_cond"] = torch.randn(1, 4, 4, 4)
    try:
        model(**kwargs)
    except Exception as e:
        print(cls.__name__, type(e).__name__, e)
Relevant precedent:
Standard transformer blocks either define the optional attention attributes unconditionally or reject unsupported configs early.
Suggested fix:
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = None
if cross_attention_dim is not None:
    self.attn2 = Attention(...)
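The "optional submodule" pattern can be sketched torch-free (`Block` and the lambda stand-ins below are illustrative, not the real nn.Module code): every attribute forward() reads is defined unconditionally, and only the optional branch may be None behind a guard.

```python
class Block:
    def __init__(self, dim, cross_attention_dim=None):
        # Always defined, because forward() always reads it
        # (stand-in for nn.LayerNorm(dim)).
        self.norm2 = lambda x: x
        # Optional branch: None unless configured
        # (stand-in for Attention(...)).
        self.attn2 = None
        if cross_attention_dim is not None:
            self.attn2 = lambda x, ctx: x + ctx

    def forward(self, x, ctx=None):
        if self.attn2 is not None:   # guard matches the real forward()
            x = self.attn2(x, ctx)
        return self.norm2(x)         # safe even when cross-attention is off
```

With this shape, `cross_attention_dim=None` builds a block whose forward() runs, instead of one that crashes on a missing `attn2`/`norm2` attribute.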
Issue 4: Sana attention bypasses backend dispatch and silently ignores self-attention masks
Affected code:
src/diffusers/models/transformers/sana_transformer.py, SanaAttnProcessor2_0:

class SanaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False

src/diffusers/models/transformers/sana_transformer.py, SanaTransformerBlock.forward():

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    encoder_hidden_states: torch.Tensor | None = None,
    encoder_attention_mask: torch.Tensor | None = None,
    timestep: torch.LongTensor | None = None,
    height: int = None,
    width: int = None,
) -> torch.Tensor:
    batch_size = hidden_states.shape[0]

    # 1. Modulation
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
    ).chunk(6, dim=1)

    # 2. Self Attention
    norm_hidden_states = self.norm1(hidden_states)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
    norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

    attn_output = self.attn1(norm_hidden_states)

src/diffusers/models/transformers/sana_transformer.py, SanaTransformer2DModel.forward():

encoder_attention_mask: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
attention_kwargs: dict[str, Any] | None = None,
controlnet_block_samples: tuple[torch.Tensor] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None and attention_mask.ndim == 2:
        # assume that mask is expressed as:
        #   (1 = keep, 0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #   (keep = +0, discard = -10000.0)
        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
Problem:
SanaAttnProcessor2_0 calls F.scaled_dot_product_attention directly and has no _attention_backend / _parallel_config, so set_attention_backend() cannot configure it. Separately, the public attention_mask is normalized and passed into blocks, but self-attention calls self.attn1(norm_hidden_states) without the mask.
Impact:
Backend selection and context-parallel attention support do not behave like newer transformer families. Passing attention_mask gives users a false signal because it is ignored.
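The mask normalization in the excerpt above is easy to restate torch-free (`mask_to_bias` is an illustrative helper, not diffusers API): a keep/discard mask becomes an additive bias, which then gains a singleton query dimension for broadcasting. The bug is that this bias is computed and threaded through, yet never reaches the self-attention call.

```python
def mask_to_bias(mask_row):
    # 1 = keep -> +0.0, 0 = discard -> -10000.0, mirroring the model code
    return [(1.0 - keep) * -10000.0 for keep in mask_row]

mask = [[1.0, 1.0, 0.0]]                    # (batch, key_tokens)
bias = [mask_to_bias(row) for row in mask]  # additive attention bias
bias = [[row] for row in bias]              # singleton query dim -> (batch, 1, key_tokens)
```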
Reproduction:
import torch
from diffusers import SanaTransformer2DModel

model = SanaTransformer2DModel(
    in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=4,
    num_layers=1, num_cross_attention_heads=2, cross_attention_head_dim=4,
    cross_attention_dim=8, caption_channels=8, sample_size=4, patch_size=1,
).eval()
print({name: hasattr(proc, "_attention_backend") for name, proc in model.attn_processors.items()})
model.set_attention_backend("_native_math")
print({name: getattr(proc, "_attention_backend", None) for name, proc in model.attn_processors.items()})
inputs = dict(
    hidden_states=torch.randn(1, 4, 4, 4),
    encoder_hidden_states=torch.randn(1, 3, 8),
    timestep=torch.tensor([1.0]),
)
with torch.no_grad():
    a = model(**inputs).sample
    b = model(**inputs, attention_mask=torch.zeros(1, 16)).sample
print((a - b).abs().max().item())  # 0.0: mask had no effect
Relevant precedent:
src/diffusers/models/transformers/transformer_sana_video.py:

class SanaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
Suggested fix:
Port the Sana Video processor pattern: add _attention_backend / _parallel_config and call dispatch_attention_fn() for cross-attention. For self-attention, either implement mask handling in SanaLinearAttnProcessor2_0 or remove/reject the unsupported public attention_mask.
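The state convention the setter relies on can be shown with a torch-free toy (`Processor`, `LegacyProcessor`, and `set_backend` are illustrative stand-ins, not the diffusers implementation): class-level defaults mark a processor as backend-capable, and the setter rejects processors that lack them instead of silently skipping.

```python
class Processor:
    # Class-level defaults a backend setter can detect and overwrite.
    _attention_backend = None
    _parallel_config = None

class LegacyProcessor:
    pass  # no backend attributes: setter should reject it, not silently no-op

def set_backend(processors, backend):
    for proc in processors:
        if not hasattr(proc, "_attention_backend"):
            raise ValueError(f"{type(proc).__name__} does not support attention backends")
        proc._attention_backend = backend
```

This mirrors the gap in the current Sana processor: without the attributes, any configuration attempt either fails opaquely or does nothing.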
Issue 5: SanaPipelineOutput is not exported from the Sana subpackage
Affected code:
src/diffusers/pipelines/sana/__init__.py:

_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_sana"] = ["SanaPipeline"]
    _import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
    _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
    _import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]

src/diffusers/pipelines/sana/pipeline_output.py:

class SanaPipelineOutput(BaseOutput):
Problem:
pipeline_output.py defines SanaPipelineOutput, and docs autodoc it, but src/diffusers/pipelines/sana/__init__.py never adds pipeline_output to _import_structure.
Impact:
The expected subpackage import fails while similar pipeline families expose their output classes through lazy imports.
Reproduction:
from diffusers.pipelines.sana import SanaPipelineOutput
Relevant precedent:
src/diffusers/pipelines/flux/__init__.py:

_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
Suggested fix:
_import_structure = {"pipeline_output": ["SanaPipelineOutput"]}
# in TYPE_CHECKING / slow import branch
from .pipeline_output import SanaPipelineOutput
Issue 6: Test coverage gaps for Sana ControlNet and Sprint variants
Affected code:
tests/pipelines/sana/test_sana.py:

@slow
@require_torch_accelerator
class SanaPipelineIntegrationTests(unittest.TestCase):

tests/pipelines/sana/test_sana_controlnet.py:

class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = SanaControlNetPipeline

tests/pipelines/sana/test_sana_sprint.py:

class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = SanaSprintPipeline

tests/pipelines/sana/test_sana_sprint_img2img.py:

class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = SanaSprintImg2ImgPipeline
Problem:
Fast pipeline tests exist for all four pipelines, and slow tests exist for base SanaPipeline only. There are no slow tests for SanaControlNetPipeline, SanaSprintPipeline, or SanaSprintImg2ImgPipeline. There is also no tests/models/controlnets/test_models_controlnet_sana.py, so SanaControlNetModel lacks direct model-mixin coverage.
Impact:
Real checkpoint loading, expected-output slices, ControlNet serialization/model behavior, and Sprint 1/3/4-step behavior are not covered.
Reproduction:
from pathlib import Path

for path in sorted(Path("tests/pipelines/sana").glob("test_*.py")):
    text = path.read_text()
    print(path, "@slow" in text)

print("Sana ControlNet model tests:", list(Path("tests/models/controlnets").glob("*sana*.py")))
Relevant precedent:
Base Sana has slow integration tests:
tests/pipelines/sana/test_sana.py:

@slow
@require_torch_accelerator
class SanaPipelineIntegrationTests(unittest.TestCase):
    prompt = "A painting of a squirrel eating a burger."

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_sana_1024(self):
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = SanaPipeline.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload(device=torch_device)

        image = pipe(
            prompt=self.prompt,
            height=1024,
            width=1024,
            generator=generator,
            num_inference_steps=20,
            output_type="np",
        ).images[0]

        image = image.flatten()
        output_slice = np.concatenate((image[:16], image[-16:]))

        # fmt: off
        expected_slice = np.array([0.0427, 0.0789, 0.0662, 0.0464, 0.082, 0.0574, 0.0535, 0.0886, 0.0647, 0.0549, 0.0872, 0.0605, 0.0593, 0.0942, 0.0674, 0.0581, 0.0076, 0.0168, 0.0027, 0.0063, 0.0159, 0.0, 0.0071, 0.0198, 0.0034, 0.0105, 0.0212, 0.0, 0.0, 0.0166, 0.0042, 0.0125])
        # fmt: on

        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-4))

    def test_sana_512(self):
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = SanaPipeline.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_512px_diffusers", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload(device=torch_device)

        image = pipe(
            prompt=self.prompt,
            height=512,
            width=512,
            generator=generator,
            num_inference_steps=20,
            output_type="np",
        ).images[0]

        image = image.flatten()
        output_slice = np.concatenate((image[:16], image[-16:]))

        # fmt: off
        expected_slice = np.array([0.0803, 0.0774, 0.1108, 0.0872, 0.093, 0.1118, 0.0952, 0.0898, 0.1038, 0.0818, 0.0754, 0.0894, 0.074, 0.0691, 0.0906, 0.0671, 0.0154, 0.0254, 0.0203, 0.0178, 0.0283, 0.0193, 0.0215, 0.0273, 0.0188, 0.0212, 0.0273, 0.0151, 0.0061, 0.0244, 0.0212, 0.0259])
        # fmt: on

        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-4))
Suggested fix:
Add slow tests with small output slices for the public ControlNet and Sprint checkpoints, add num_inference_steps coverage for Sprint 1, 2, 3, and 4, and add a ModelTesterMixin-style SanaControlNetModel test file under tests/models/controlnets/.
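A skeleton for the missing Sprint step-count coverage could look like the following; `run_sprint` is a placeholder (a real slow test would load the public SANA-Sprint checkpoint and call the pipeline), and the subTest matrix over 1-4 steps is the point.

```python
import unittest

class SanaSprintStepCountTests(unittest.TestCase):
    def run_sprint(self, num_inference_steps):
        # Placeholder for: SanaSprintPipeline.from_pretrained(...)(
        #     prompt=..., num_inference_steps=num_inference_steps, output_type="np")
        return {"steps": num_inference_steps}

    def test_supported_step_counts(self):
        # Cover every documented step count, not just the default of 2.
        for steps in (1, 2, 3, 4):
            with self.subTest(steps=steps):
                out = self.run_sprint(steps)
                self.assertEqual(out["steps"], steps)
```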