Skip to content

Commit 7f13f68

Browse files
committed
test: pixart compile mode on neuron
1 parent 16b9606 commit 7f13f68

1 file changed

Lines changed: 44 additions & 0 deletions

File tree

tests/pipelines/pixart_alpha/test_pixart.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,3 +381,47 @@ def test_pixart_512_without_resolution_binning(self):
381381
no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1]
382382

383383
assert not np.allclose(image_slice, no_res_bin_image_slice, atol=1e-4, rtol=1e-4)
384+
385+
@unittest.skipUnless(is_torch_neuronx_available(), "torch_neuronx not available")
386+
def test_pixart_512_neuron_compile(self):
387+
"""
388+
Smoke-test PixArtAlphaPipeline under torch.compile(backend="neuron") at 512×512.
389+
"""
390+
import torch_neuronx # noqa: F401 — registers torch.neuron
391+
from torch_neuronx.neuron_dynamo_backend import set_model_name
392+
393+
device = torch.neuron.current_device()
394+
generator = torch.Generator("cpu").manual_seed(0)
395+
396+
pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.bfloat16)
397+
pipe = pipe.to(device)
398+
# Flush pending lazy-XLA parameter-copy ops before compiling.
399+
torch.neuron.synchronize()
400+
401+
pipe.transformer.eval()
402+
pipe.vae.eval()
403+
pipe.text_encoder.eval()
404+
405+
set_model_name("pixart_text_encoder")
406+
pipe.text_encoder = torch.compile(pipe.text_encoder, backend="neuron", fullgraph=True)
407+
set_model_name("pixart_transformer")
408+
pipe.transformer = torch.compile(pipe.transformer, backend="neuron", fullgraph=True)
409+
# VAE must be compiled after pipeline __init__ (which reads vae.config.block_out_channels).
410+
set_model_name("pixart_vae")
411+
pipe.vae = torch.compile(pipe.vae, backend="neuron", fullgraph=True)
412+
413+
image = pipe(
414+
self.prompt,
415+
generator=generator,
416+
height=512,
417+
width=512,
418+
num_inference_steps=2,
419+
output_type="np",
420+
).images
421+
422+
self.assertEqual(image.shape, (1, 512, 512, 3))
423+
self.assertFalse(np.isnan(image).any(), "Output contains NaN values")
424+
self.assertTrue(
425+
(image >= 0.0).all() and (image <= 1.0).all(),
426+
"Output pixel values outside [0, 1]",
427+
)

0 commit comments

Comments
 (0)