@@ -1803,6 +1803,25 @@ def test_qnn_backend_prelu(self):
18031803 self .lower_module_and_test_output (module , sample_input )
18041804
18051805 def test_qnn_backend_rand (self ):
1806+ module = Rand () # noqa: F405
1807+ sample_inputs = [
1808+ (torch .randn (3 , 4 , 5 ),),
1809+ (torch .randn (2 , 8 ),),
1810+ (
1811+ torch .randn (
1812+ 10 ,
1813+ ),
1814+ ),
1815+ (torch .randn (1 , 3 , 32 , 32 ),),
1816+ ]
1817+ for i , sample_input in enumerate (sample_inputs ):
1818+ with self .subTest (i = i ):
1819+ self .lower_module_and_test_output (
1820+ module , sample_input , assert_output_equal = False
1821+ )
1822+
1823+ def test_qnn_backend_randn (self ):
1824+ module = Randn () # noqa: F405
18061825 sample_inputs = [
18071826 (torch .randn (3 , 4 , 5 ),),
18081827 (torch .randn (2 , 8 ),),
@@ -1815,7 +1834,6 @@ def test_qnn_backend_rand(self):
18151834 ]
18161835 for i , sample_input in enumerate (sample_inputs ):
18171836 with self .subTest (i = i ):
1818- module = Rand () # noqa: F405
18191837 self .lower_module_and_test_output (
18201838 module , sample_input , assert_output_equal = False
18211839 )
@@ -4380,6 +4398,7 @@ def test_qnn_backend_prelu(self):
43804398 self .lower_module_and_test_output (module , sample_input )
43814399
43824400 def test_qnn_backend_rand (self ):
4401+ module = Rand () # noqa: F405
43834402 sample_inputs = [
43844403 (torch .randn (3 , 4 , 5 ),),
43854404 (torch .randn (2 , 8 ),),
@@ -4392,10 +4411,28 @@ def test_qnn_backend_rand(self):
43924411 ]
43934412 for i , sample_input in enumerate (sample_inputs ):
43944413 with self .subTest (i = i ):
4395- module = Rand () # noqa: F405
4396- module = self .get_qdq_module (module , sample_input )
4414+ qdq_module = self .get_qdq_module (module , sample_input )
43974415 self .lower_module_and_test_output (
4398- module , sample_input , assert_output_equal = False
4416+ qdq_module , sample_input , assert_output_equal = False
4417+ )
4418+
4419+ def test_qnn_backend_randn (self ):
4420+ module = Randn () # noqa: F405
4421+ sample_inputs = [
4422+ (torch .randn (3 , 4 , 5 ),),
4423+ (torch .randn (2 , 8 ),),
4424+ (
4425+ torch .randn (
4426+ 10 ,
4427+ ),
4428+ ),
4429+ (torch .randn (1 , 3 , 32 , 32 ),),
4430+ ]
4431+ for i , sample_input in enumerate (sample_inputs ):
4432+ with self .subTest (i = i ):
4433+ qdq_module = self .get_qdq_module (module , sample_input )
4434+ self .lower_module_and_test_output (
4435+ qdq_module , sample_input , assert_output_equal = False
43994436 )
44004437
44014438 def test_qnn_backend_reciprocal (self ):
0 commit comments