@@ -381,3 +381,47 @@ def test_pixart_512_without_resolution_binning(self):
381381 no_res_bin_image_slice = no_res_bin_image [0 , - 3 :, - 3 :, - 1 ]
382382
383383 assert not np .allclose (image_slice , no_res_bin_image_slice , atol = 1e-4 , rtol = 1e-4 )
384+
385+ @unittest .skipUnless (is_torch_neuronx_available (), "torch_neuronx not available" )
386+ def test_pixart_512_neuron_compile (self ):
387+ """
388+ Smoke-test PixArtAlphaPipeline under torch.compile(backend="neuron") at 512×512.
389+ """
390+ import torch_neuronx # noqa: F401 — registers torch.neuron
391+ from torch_neuronx .neuron_dynamo_backend import set_model_name
392+
393+ device = torch .neuron .current_device ()
394+ generator = torch .Generator ("cpu" ).manual_seed (0 )
395+
396+ pipe = PixArtAlphaPipeline .from_pretrained (self .ckpt_id_512 , torch_dtype = torch .bfloat16 )
397+ pipe = pipe .to (device )
398+ # Flush pending lazy-XLA parameter-copy ops before compiling.
399+ torch .neuron .synchronize ()
400+
401+ pipe .transformer .eval ()
402+ pipe .vae .eval ()
403+ pipe .text_encoder .eval ()
404+
405+ set_model_name ("pixart_text_encoder" )
406+ pipe .text_encoder = torch .compile (pipe .text_encoder , backend = "neuron" , fullgraph = True )
407+ set_model_name ("pixart_transformer" )
408+ pipe .transformer = torch .compile (pipe .transformer , backend = "neuron" , fullgraph = True )
409+ # VAE must be compiled after pipeline __init__ (which reads vae.config.block_out_channels).
410+ set_model_name ("pixart_vae" )
411+ pipe .vae = torch .compile (pipe .vae , backend = "neuron" , fullgraph = True )
412+
413+ image = pipe (
414+ self .prompt ,
415+ generator = generator ,
416+ height = 512 ,
417+ width = 512 ,
418+ num_inference_steps = 2 ,
419+ output_type = "np" ,
420+ ).images
421+
422+ self .assertEqual (image .shape , (1 , 512 , 512 , 3 ))
423+ self .assertFalse (np .isnan (image ).any (), "Output contains NaN values" )
424+ self .assertTrue (
425+ (image >= 0.0 ).all () and (image <= 1.0 ).all (),
426+ "Output pixel values outside [0, 1]" ,
427+ )
0 commit comments