@@ -72,6 +72,44 @@ def test_plot_fail_unknown_chart_type(self):
7272 "Chart type other is not supported" ,
7373 )
7474
75+ @patch ("dte_adj.plot.plt" )
76+ def test_plot_weighted (self , mock_plt ):
77+ # Arrange
78+ x_values = np .array ([1 , 2 , 3 , 4 , 5 ])
79+ means = np .array ([0.1 , 0.2 , 0.3 , 0.4 , 0.5 ])
80+ upper_bands = np .array ([0.2 , 0.3 , 0.4 , 0.5 , 0.6 ])
81+ lower_bands = np .array ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 ])
82+ mock_ax = MagicMock ()
83+ mock_plt .subplots .return_value = (MagicMock (), mock_ax )
84+
85+ # Act
86+ result_ax = plot (
87+ x_values ,
88+ means ,
89+ lower_bands ,
90+ upper_bands ,
91+ chart_type = "line" ,
92+ weighted = True ,
93+ )
7594
76- if __name__ == "__main__" :
77- unittest .main ()
95+ # Assert
96+ self .assertEqual (result_ax , mock_ax )
97+ mock_plt .subplots .assert_called_once ()
98+ plot_call = mock_ax .plot .call_args
99+ fill_between_call = mock_ax .fill_between .call_args
100+
101+ # Check that values are weighted (multiplied by x_values)
102+ plot_args , plot_kwargs = plot_call
103+ x_values_arg , y_values_arg = plot_args
104+ expected_weighted_means = means * x_values
105+ self .assertTrue (np .array_equal (x_values_arg , x_values ))
106+ self .assertTrue (np .array_equal (y_values_arg , expected_weighted_means ))
107+
108+ # Check that confidence intervals are also weighted
109+ fill_between_args , fill_between_kwargs = fill_between_call
110+ x_fill , lower_fill , upper_fill = fill_between_args
111+ expected_weighted_lower = lower_bands * x_values
112+ expected_weighted_upper = upper_bands * x_values
113+ self .assertTrue (np .array_equal (x_fill , x_values_arg ))
114+ self .assertTrue (np .array_equal (lower_fill , expected_weighted_lower ))
115+ self .assertTrue (np .array_equal (upper_fill , expected_weighted_upper ))
0 commit comments