@@ -24,54 +24,78 @@ def foo():
2424 pass
2525
2626
27- def test_constructor ():
27+ @pytest .mark .parametrize ("reduction" , [None , "mean" , "sum" ])
28+ def test_constructor (reduction ):
2829
29- SystemEquation ([eq1 , eq2 ])
30- SystemEquation ([eq1 , eq2 ], reduction = "sum" )
30+ # Constructor with callable functions
31+ SystemEquation ([eq1 , eq2 ], reduction = reduction )
32+
33+ # Constructor with Equation instances
3134 SystemEquation (
3235 [
3336 FixedValue (value = 0.0 , components = ["u1" ]),
3437 FixedGradient (value = 0.0 , components = ["u2" ]),
3538 ],
36- reduction = "mean" ,
39+ reduction = reduction ,
3740 )
3841
42+ # Constructor with mixed types
43+ SystemEquation (
44+ [
45+ FixedValue (value = 0.0 , components = ["u1" ]),
46+ eq1 ,
47+ ],
48+ reduction = reduction ,
49+ )
50+
51+ # Non-standard reduction not implemented
3952 with pytest .raises (NotImplementedError ):
4053 SystemEquation ([eq1 , eq2 ], reduction = "foo" )
4154
55+ # Invalid input type
4256 with pytest .raises (ValueError ):
4357 SystemEquation (foo )
4458
4559
46- def test_residual ():
60+ @pytest .mark .parametrize ("reduction" , [None , "mean" , "sum" ])
61+ def test_residual (reduction ):
4762
63+ # Generate random points and output
4864 pts = LabelTensor (torch .rand (10 , 2 ), labels = ["x" , "y" ])
4965 pts .requires_grad = True
5066 u = torch .pow (pts , 2 )
5167 u .labels = ["u1" , "u2" ]
5268
53- eq_1 = SystemEquation ([ eq1 , eq2 ], reduction = "mean" )
54- res = eq_1 . residual ( pts , u )
55- assert res . shape == torch . Size ([ 10 ] )
69+ # System with callable functions
70+ system_eq = SystemEquation ([ eq1 , eq2 ], reduction = reduction )
71+ res = system_eq . residual ( pts , u )
5672
57- eq_1 = SystemEquation ([ eq1 , eq2 ], reduction = "sum" )
58- res = eq_1 . residual ( pts , u )
59- assert res .shape == torch . Size ([ 10 ])
73+ # Checks on the shape of the residual
74+ shape = torch . Size ([ 10 , 3 ]) if reduction is None else torch . Size ([ 10 ] )
75+ assert res .shape == shape
6076
61- eq_1 = SystemEquation ([eq1 , eq2 ], reduction = None )
62- res = eq_1 .residual (pts , u )
63- assert res .shape == torch .Size ([10 , 3 ])
77+ # System with Equation instances
78+ system_eq = SystemEquation (
79+ [
80+ FixedValue (value = 0.0 , components = ["u1" ]),
81+ FixedGradient (value = 0.0 , components = ["u2" ]),
82+ ],
83+ reduction = reduction ,
84+ )
6485
65- eq_1 = SystemEquation ([ eq1 , eq2 ])
66- res = eq_1 . residual ( pts , u )
67- assert res .shape == torch . Size ([ 10 , 3 ])
86+ # Checks on the shape of the residual
87+ shape = torch . Size ([ 10 , 3 ]) if reduction is None else torch . Size ([ 10 ] )
88+ assert res .shape == shape
6889
90+ # System with mixed types
6991 system_eq = SystemEquation (
7092 [
7193 FixedValue (value = 0.0 , components = ["u1" ]),
72- FixedGradient ( value = 0.0 , components = [ "u2" ]) ,
94+ eq1 ,
7395 ],
74- reduction = "mean" ,
96+ reduction = reduction ,
7597 )
76- res = system_eq .residual (pts , u )
77- assert res .shape == torch .Size ([10 ])
98+
99+ # Checks on the shape of the residual
100+ shape = torch .Size ([10 , 3 ]) if reduction is None else torch .Size ([10 ])
101+ assert res .shape == shape
0 commit comments