@@ -145,3 +145,138 @@ def test_enable_disable_slicing(self):
145145 output_without_slicing .detach ().cpu ().numpy ().all (),
146146 output_without_slicing_2 .detach ().cpu ().numpy ().all (),
147147 ), "Without slicing outputs should match with the outputs when slicing is manually disabled."
148+
149+
class NewAutoencoderTesterMixin:
    """Shared autoencoder tests: norm-group overrides and tiling/slicing round-trips.

    The concrete test class is expected to provide ``self.model_class``,
    ``self.get_init_dict()`` and ``self.get_dummy_inputs()`` (presumably
    returning a dict with a ``"sample"`` tensor — confirm against callers).
    """

    @staticmethod
    def _accepts_generator(model):
        """Return True if ``model.forward`` accepts a ``generator`` parameter."""
        model_sig = inspect.signature(model.forward)
        return "generator" in model_sig.parameters

    @staticmethod
    def _accepts_norm_num_groups(model_class):
        """Return True if ``model_class.__init__`` accepts ``norm_num_groups``."""
        model_sig = inspect.signature(model_class.__init__)
        return "norm_num_groups" in model_sig.parameters

    def test_forward_with_norm_groups(self):
        """A forward pass with overridden ``norm_num_groups`` must preserve the input shape."""
        if not self._accepts_norm_num_groups(self.model_class):
            pytest.skip(f"Test not supported for {self.model_class.__name__}")
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        # Small group count / channel layout keeps the model light for CI.
        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

        assert output is not None
        expected_shape = inputs_dict["sample"].shape
        assert output.shape == expected_shape, "Input and output shapes do not match"

    def _check_enable_disable(self, feature):
        """Shared enable/disable round-trip check for ``feature`` ("tiling" or "slicing").

        Verifies that enabling the feature perturbs outputs by at most 0.5 and
        that disabling it restores outputs identical to the baseline run.
        Skips when the model class or instance does not expose the feature.
        """
        if not hasattr(self.model_class, f"enable_{feature}"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support {feature}.")

        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        if not hasattr(model, f"use_{feature}"):
            # NOTE: the original slicing test wrongly said "tiling" here; fixed
            # by deriving the message from the feature name.
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support {feature}.")

        inputs_dict.update({"return_dict": False})
        _ = inputs_dict.pop("generator", None)
        accepts_generator = self._accepts_generator(model)

        def _run_inference():
            # Reseed the global RNG and (when supported) pass a freshly seeded
            # generator so every invocation is deterministic and comparable.
            torch.manual_seed(0)
            if accepts_generator:
                inputs_dict["generator"] = torch.manual_seed(0)
            out = model(**inputs_dict)[0]
            if isinstance(out, DecoderOutput):
                out = out.sample
            return out

        with torch.no_grad():
            output_baseline = _run_inference()

            getattr(model, f"enable_{feature}")()
            output_enabled = _run_inference()

            assert (output_baseline.cpu() - output_enabled.cpu()).max() < 0.5, (
                f"VAE {feature} should not affect the inference results"
            )

            getattr(model, f"disable_{feature}")()
            output_disabled = _run_inference()

            assert torch.allclose(output_baseline.cpu(), output_disabled.cpu()), (
                f"Without {feature} outputs should match with the outputs when {feature} is manually disabled."
            )

    def test_enable_disable_tiling(self):
        """Tiling must be a near-no-op while enabled and an exact no-op once disabled."""
        self._check_enable_disable("tiling")

    def test_enable_disable_slicing(self):
        """Slicing must be a near-no-op while enabled and an exact no-op once disabled."""
        self._check_enable_disable("slicing")