
sana model/pipeline review #13614

@hlky

Description


Commit tested: 0f1abc4ae8b0eb2a3b40e82a310507281144c423

Review performed against the repository review rules.

Duplicate search: GitHub Issues and PRs were searched for sana, the affected class/file names, and the failure terms. No likely duplicates were found except for Issue 2 below, which is already tracked.

Issue 1: Sana Sprint rejects documented 1/3/4-step inference by default

Affected code:

# Identical check in SanaSprintPipeline.check_inputs and SanaSprintImg2ImgPipeline.check_inputs:
if intermediate_timesteps is not None and num_inference_steps != 2:
    raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")

# Matching __call__ defaults in both pipelines:
num_inference_steps: int = 2,
timesteps: list[int] = None,
max_timesteps: float = 1.57080,
intermediate_timesteps: float = 1.3,

Problem:
intermediate_timesteps defaults to 1.3, but check_inputs() rejects any non-None value unless num_inference_steps == 2. As a result, num_inference_steps=1, 3, or 4 fails unless users know to pass intermediate_timesteps=None.

Impact:
SANA-Sprint is documented as a 1-4 step model, but the pipeline rejects the 1-, 3-, and 4-step paths by default.

Reproduction:

from diffusers import SanaSprintImg2ImgPipeline, SanaSprintPipeline

common = dict(
    prompt="cat",
    height=1024,
    width=1024,
    num_inference_steps=1,
    timesteps=None,
    max_timesteps=1.5708,
    intermediate_timesteps=1.3,
    callback_on_step_end_tensor_inputs=None,
    prompt_embeds=None,
    prompt_attention_mask=None,
)

for cls, extra in [(SanaSprintPipeline, {}), (SanaSprintImg2ImgPipeline, {"strength": 0.5})]:
    try:
        cls.check_inputs(None, **common, **extra)
    except Exception as e:
        print(cls.__name__, type(e).__name__, e)

Relevant precedent:

*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*

Suggested fix:

# In both Sprint pipeline __call__ signatures:
intermediate_timesteps: float | None = None,

# Before retrieve_timesteps:
if num_inference_steps == 2 and intermediate_timesteps is None:
    intermediate_timesteps = 1.3
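
Until that lands, the workaround implied by the current check is to pass intermediate_timesteps=None explicitly whenever num_inference_steps != 2. A minimal sketch; the checkpoint id and dtype are illustrative, not verified here:

import torch
from diffusers import SanaSprintPipeline

# Workaround sketch for the current behavior: clear intermediate_timesteps so
# check_inputs() accepts a step count other than 2. Checkpoint id is illustrative.
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",  # illustrative id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="cat",
    num_inference_steps=1,
    intermediate_timesteps=None,  # required today for any step count other than 2
).images[0]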

Issue 2: Known duplicate: guidance_embeds=True crashes without guidance

Affected code:

if guidance is not None:
    timestep, embedded_timestep = self.time_embed(
        timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
    )
else:
    timestep, embedded_timestep = self.time_embed(
        timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
    )

Problem:
SanaTransformer2DModel.forward() dispatches the time embedding call based on whether guidance was passed, not on which time embedding module was configured. With guidance_embeds=True and no guidance, it calls SanaCombinedTimestepGuidanceEmbeddings.forward(..., batch_size=...), which is not accepted.

Impact:
A model configured with guidance embeddings cannot be used by non-guidance Sana pipelines; users get a low-level TypeError.

Reproduction:

import torch
from diffusers import SanaTransformer2DModel

model = SanaTransformer2DModel(
    in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=4,
    num_layers=1, num_cross_attention_heads=2, cross_attention_head_dim=4,
    cross_attention_dim=8, caption_channels=8, sample_size=4, patch_size=1,
    guidance_embeds=True,
)

model(
    hidden_states=torch.randn(1, 4, 4, 4),
    encoder_hidden_states=torch.randn(1, 3, 8),
    timestep=torch.tensor([1.0]),
)

Relevant precedent:
Duplicate: #12540
Related PRs: #13109 and closed-unmerged #13517

Suggested fix:
Route on the configured time-embedding module (or the guidance_embeds config flag) instead of on whether guidance was passed, and raise a clear ValueError when guidance_embeds=True but guidance is absent.
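
A minimal sketch of that routing inside SanaTransformer2DModel.forward(), reusing the names quoted above; the exact error wording and config access are placeholders:

# Sketch: dispatch on the configured embedding type instead of on whether
# `guidance` happened to be passed, and fail with a clear message otherwise.
if self.config.guidance_embeds:
    if guidance is None:
        raise ValueError(
            "This model was initialized with guidance_embeds=True, so `guidance` must be "
            "provided (for example by the Sana Sprint pipelines)."
        )
    timestep, embedded_timestep = self.time_embed(
        timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
    )
else:
    timestep, embedded_timestep = self.time_embed(
        timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
    )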

Issue 3: cross_attention_dim=None constructs a broken block

Affected code:

if cross_attention_dim is not None:
    self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
    self.attn2 = Attention(
        query_dim=dim,
        qk_norm=qk_norm,
        kv_heads=num_cross_attention_heads if qk_norm is not None else None,
        cross_attention_dim=cross_attention_dim,
        heads=num_cross_attention_heads,
        dim_head=cross_attention_head_dim,
        dropout=dropout,
        bias=True,
        out_bias=attention_out_bias,
        processor=SanaAttnProcessor2_0(),
    )

# 3. Feed-forward
self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)

self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    encoder_hidden_states: torch.Tensor | None = None,
    encoder_attention_mask: torch.Tensor | None = None,
    timestep: torch.LongTensor | None = None,
    height: int = None,
    width: int = None,
) -> torch.Tensor:
    batch_size = hidden_states.shape[0]

    # 1. Modulation
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
    ).chunk(6, dim=1)

    # 2. Self Attention
    norm_hidden_states = self.norm1(hidden_states)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
    norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

    attn_output = self.attn1(norm_hidden_states)
    hidden_states = hidden_states + gate_msa * attn_output

    # 3. Cross Attention
    if self.attn2 is not None:
        attn_output = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
        )
        hidden_states = attn_output + hidden_states

    # 4. Feed-forward
    norm_hidden_states = self.norm2(hidden_states)

    cross_attention_dim: int | None = 2240,
    caption_channels: int = 2304,
    mlp_ratio: float = 2.5,
    dropout: float = 0.0,
    attention_bias: bool = False,
    sample_size: int = 32,
    patch_size: int = 1,
    norm_elementwise_affine: bool = False,
    norm_eps: float = 1e-6,
    interpolation_scale: int | None = None,
) -> None:
    super().__init__()

    out_channels = out_channels or in_channels
    inner_dim = num_attention_heads * attention_head_dim

    # 1. Patch Embedding
    self.patch_embed = PatchEmbed(
        height=sample_size,
        width=sample_size,
        patch_size=patch_size,
        in_channels=in_channels,
        embed_dim=inner_dim,
        interpolation_scale=interpolation_scale,
        pos_embed_type="sincos" if interpolation_scale is not None else None,
    )

    # 2. Additional condition embeddings
    self.time_embed = AdaLayerNormSingle(inner_dim)

    self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
    self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)

    # 3. Transformer blocks
    self.transformer_blocks = nn.ModuleList(
        [
            SanaTransformerBlock(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
                dropout=dropout,
                num_cross_attention_heads=num_cross_attention_heads,
                cross_attention_head_dim=cross_attention_head_dim,
                cross_attention_dim=cross_attention_dim,

Problem:
cross_attention_dim is annotated as optional, but SanaTransformerBlock only defines self.attn2 and self.norm2 inside the if cross_attention_dim is not None branch. forward() still reads both attributes unconditionally (the if self.attn2 is not None check and the norm2 call before the feed-forward), so constructing a block with cross_attention_dim=None leads to an AttributeError at runtime.

Impact:
Valid-looking configs fail at runtime, including SanaControlNetModel, which reuses the same block.

Reproduction:

import torch
from diffusers import SanaControlNetModel, SanaTransformer2DModel

for cls in (SanaTransformer2DModel, SanaControlNetModel):
    model = cls(
        in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=4,
        num_layers=1, num_cross_attention_heads=2, cross_attention_head_dim=4,
        cross_attention_dim=None, caption_channels=8, sample_size=4, patch_size=1,
    )
    kwargs = dict(
        hidden_states=torch.randn(1, 4, 4, 4),
        encoder_hidden_states=torch.randn(1, 3, 8),
        timestep=torch.tensor([1.0]),
    )
    if cls is SanaControlNetModel:
        kwargs["controlnet_cond"] = torch.randn(1, 4, 4, 4)

    try:
        model(**kwargs)
    except Exception as e:
        print(cls.__name__, type(e).__name__, e)

Relevant precedent:
Standard transformer blocks either define the optional attention attributes unconditionally or reject unsupported configs early.

Suggested fix:

self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = None
if cross_attention_dim is not None:
    self.attn2 = Attention(...)
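
With attn2 initialized to None, the existing if self.attn2 is not None: guard in forward() becomes a real no-op path, and norm2 stays defined for the feed-forward normalization, so the cross_attention_dim=None reproduction above should construct and run.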

Issue 4: Sana attention bypasses backend dispatch and silently ignores self-attention masks

Affected code:

class SanaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    encoder_hidden_states: torch.Tensor | None = None,
    encoder_attention_mask: torch.Tensor | None = None,
    timestep: torch.LongTensor | None = None,
    height: int = None,
    width: int = None,
) -> torch.Tensor:
    batch_size = hidden_states.shape[0]

    # 1. Modulation
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
        self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
    ).chunk(6, dim=1)

    # 2. Self Attention
    norm_hidden_states = self.norm1(hidden_states)
    norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
    norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

    attn_output = self.attn1(norm_hidden_states)

    encoder_attention_mask: torch.Tensor | None = None,
    attention_mask: torch.Tensor | None = None,
    attention_kwargs: dict[str, Any] | None = None,
    controlnet_block_samples: tuple[torch.Tensor] | None = None,
    return_dict: bool = True,
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    #   we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    #   we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None and attention_mask.ndim == 2:
        # assume that mask is expressed as:
        #   (1 = keep, 0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #   (keep = +0, discard = -10000.0)
        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

Problem:
SanaAttnProcessor2_0 calls F.scaled_dot_product_attention directly and has no _attention_backend / _parallel_config, so set_attention_backend() cannot configure it. Separately, the public attention_mask is normalized and passed into blocks, but self-attention calls self.attn1(norm_hidden_states) without the mask.

Impact:
Backend selection and context-parallel attention do not work here the way they do for newer transformer families. Passing attention_mask gives users a false signal because it is silently ignored.

Reproduction:

import torch
from diffusers import SanaTransformer2DModel

model = SanaTransformer2DModel(
    in_channels=4, out_channels=4, num_attention_heads=2, attention_head_dim=4,
    num_layers=1, num_cross_attention_heads=2, cross_attention_head_dim=4,
    cross_attention_dim=8, caption_channels=8, sample_size=4, patch_size=1,
).eval()

print({name: hasattr(proc, "_attention_backend") for name, proc in model.attn_processors.items()})
model.set_attention_backend("_native_math")
print({name: getattr(proc, "_attention_backend", None) for name, proc in model.attn_processors.items()})

inputs = dict(
    hidden_states=torch.randn(1, 4, 4, 4),
    encoder_hidden_states=torch.randn(1, 3, 8),
    timestep=torch.tensor([1.0]),
)

with torch.no_grad():
    a = model(**inputs).sample
    b = model(**inputs, attention_mask=torch.zeros(1, 16)).sample

print((a - b).abs().max().item())  # 0.0: mask had no effect

Relevant precedent:

class SanaAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, attn.heads, head_dim)
        value = value.view(batch_size, -1, attn.heads, head_dim)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )

Suggested fix:
Port the Sana Video processor pattern: add _attention_backend / _parallel_config and call dispatch_attention_fn() for cross-attention. For self-attention, either implement mask handling in SanaLinearAttnProcessor2_0 or remove/reject the unsupported public attention_mask.
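
For the self-attention half, one shape the "reject" option could take is an early guard in SanaTransformer2DModel.forward(); this is a sketch of that alternative, not a decided design:

# Sketch: if mask support is not added to the linear self-attention processor,
# reject the argument instead of silently dropping it.
if attention_mask is not None:
    raise ValueError(
        "attention_mask is not supported by Sana's self-attention; "
        "use encoder_attention_mask to mask text tokens instead."
    )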

Issue 5: SanaPipelineOutput is not exported from the Sana subpackage

Affected code:

_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_sana"] = ["SanaPipeline"]
    _import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
    _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
    _import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]

class SanaPipelineOutput(BaseOutput):

Problem:
pipeline_output.py defines SanaPipelineOutput, and docs autodoc it, but src/diffusers/pipelines/sana/__init__.py never adds pipeline_output to _import_structure.

Impact:
Importing SanaPipelineOutput from diffusers.pipelines.sana fails, even though similar pipeline families expose their output classes through the same lazy-import mechanism.

Reproduction:

from diffusers.pipelines.sana import SanaPipelineOutput

Relevant precedent:

_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}

Suggested fix:

_import_structure = {"pipeline_output": ["SanaPipelineOutput"]}

# in TYPE_CHECKING / slow import branch
from .pipeline_output import SanaPipelineOutput
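
In context, the change mirrors the Flux pattern quoted above; a sketch of the relevant parts of src/diffusers/pipelines/sana/__init__.py, with surrounding lines abbreviated:

_import_structure = {"pipeline_output": ["SanaPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_sana"] = ["SanaPipeline"]
    # ... other pipeline entries unchanged ...

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_output import SanaPipelineOutput
        from .pipeline_sana import SanaPipeline
        # ... other pipeline imports unchanged ...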

Issue 6: Test coverage gaps for Sana ControlNet and Sprint variants

Affected code:

@slow
@require_torch_accelerator
class SanaPipelineIntegrationTests(unittest.TestCase):

class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = SanaControlNetPipeline

class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = SanaSprintPipeline

class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = SanaSprintImg2ImgPipeline

Problem:
Fast pipeline tests exist for all four pipelines, and slow tests exist for base SanaPipeline only. There are no slow tests for SanaControlNetPipeline, SanaSprintPipeline, or SanaSprintImg2ImgPipeline. There is also no tests/models/controlnets/test_models_controlnet_sana.py, so SanaControlNetModel lacks direct model-mixin coverage.

Impact:
Real checkpoint loading, expected-output slices, ControlNet serialization/model behavior, and Sprint 1/3/4-step behavior are not covered.

Reproduction:

from pathlib import Path

for path in sorted(Path("tests/pipelines/sana").glob("test_*.py")):
    text = path.read_text()
    print(path, "@slow" in text)

print("Sana ControlNet model tests:", list(Path("tests/models/controlnets").glob("*sana*.py")))

Relevant precedent:
Base Sana has slow integration tests:

@slow
@require_torch_accelerator
class SanaPipelineIntegrationTests(unittest.TestCase):
    prompt = "A painting of a squirrel eating a burger."

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_sana_1024(self):
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = SanaPipeline.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_1024px_diffusers", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload(device=torch_device)

        image = pipe(
            prompt=self.prompt,
            height=1024,
            width=1024,
            generator=generator,
            num_inference_steps=20,
            output_type="np",
        ).images[0]

        image = image.flatten()
        output_slice = np.concatenate((image[:16], image[-16:]))

        # fmt: off
        expected_slice = np.array([0.0427, 0.0789, 0.0662, 0.0464, 0.082, 0.0574, 0.0535, 0.0886, 0.0647, 0.0549, 0.0872, 0.0605, 0.0593, 0.0942, 0.0674, 0.0581, 0.0076, 0.0168, 0.0027, 0.0063, 0.0159, 0.0, 0.0071, 0.0198, 0.0034, 0.0105, 0.0212, 0.0, 0.0, 0.0166, 0.0042, 0.0125])
        # fmt: on

        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-4))

    def test_sana_512(self):
        generator = torch.Generator("cpu").manual_seed(0)

        pipe = SanaPipeline.from_pretrained(
            "Efficient-Large-Model/Sana_1600M_512px_diffusers", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload(device=torch_device)

        image = pipe(
            prompt=self.prompt,
            height=512,
            width=512,
            generator=generator,
            num_inference_steps=20,
            output_type="np",
        ).images[0]

        image = image.flatten()
        output_slice = np.concatenate((image[:16], image[-16:]))

        # fmt: off
        expected_slice = np.array([0.0803, 0.0774, 0.1108, 0.0872, 0.093, 0.1118, 0.0952, 0.0898, 0.1038, 0.0818, 0.0754, 0.0894, 0.074, 0.0691, 0.0906, 0.0671, 0.0154, 0.0254, 0.0203, 0.0178, 0.0283, 0.0193, 0.0215, 0.0273, 0.0188, 0.0212, 0.0273, 0.0151, 0.0061, 0.0244, 0.0212, 0.0259])
        # fmt: on

        self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-4))

Suggested fix:
Add slow tests with small output slices for the public ControlNet and Sprint checkpoints, add num_inference_steps coverage for Sprint 1, 2, 3, and 4, and add a ModelTesterMixin-style SanaControlNetModel test file under tests/models/controlnets/.
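
A sketch of what the Sprint coverage could look like, modeled on the base Sana integration tests above; the checkpoint id, dtype, and the eventual expected slices are placeholders:

@slow
@require_torch_accelerator
class SanaSprintPipelineIntegrationTests(unittest.TestCase):
    prompt = "A painting of a squirrel eating a burger."

    def test_sana_sprint_step_counts(self):
        # Placeholder checkpoint id; use the published Sprint diffusers checkpoint.
        pipe = SanaSprintPipeline.from_pretrained(
            "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)

        for num_inference_steps in (1, 2, 3, 4):
            generator = torch.Generator("cpu").manual_seed(0)
            image = pipe(
                prompt=self.prompt,
                height=1024,
                width=1024,
                generator=generator,
                num_inference_steps=num_inference_steps,
                output_type="np",
            ).images[0]
            self.assertEqual(image.shape, (1024, 1024, 3))
            # TODO: pin expected output slices per step count, as in SanaPipelineIntegrationTests.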
