diff --git a/CITATION.cff b/CITATION.cff index 02d6a46..fb74da6 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -2,7 +2,7 @@ cff-version: 1.2.0 message: "If you use this software, please cite it as below." type: software title: "dte_adj: A Python Package for Estimating Distribution Treatment Effects" -version: 0.1.7 +version: 0.1.8 date-released: 2024-12-01 url: "https://github.com/CyberAgentAILab/python-dte-adjustment" repository-code: "https://github.com/CyberAgentAILab/python-dte-adjustment" diff --git a/docs/source/conf.py b/docs/source/conf.py index 1e4d356..444e311 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,7 +13,7 @@ project = "dte_adj" copyright = "2024, CyberAgent, Inc." author = "CyberAgent, Inc" -release = "0.1.7" +release = "0.1.8" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/dte_adj/plot/__init__.py b/dte_adj/plot/__init__.py index 25c380c..b9ee640 100644 --- a/dte_adj/plot/__init__.py +++ b/dte_adj/plot/__init__.py @@ -15,6 +15,7 @@ def plot( title: Optional[str] = None, xlabel: Optional[str] = None, ylabel: Optional[str] = None, + weighted: bool = False, ): """Visualize distributional parameters and their confidence intervals. @@ -29,12 +30,18 @@ def plot( title (str, optional): Axes title. xlabel (str, optional): X-axis title label. ylabel (str, optional): Y-axis title label. + weighted (bool, optional): If True, multiply treatment effects by X values to show value-weighted effects. Defaults to False. Returns: matplotlib.axes.Axes: The axes with the plot. """ if ax is None: - fig, ax = plt.subplots() + _, ax = plt.subplots() + + if weighted: + means = means * X + lower_bounds = lower_bounds * X + upper_bounds = upper_bounds * X if chart_type == "line": ax.plot(X, means, label="Values", color=color) diff --git a/pyproject.toml b/pyproject.toml index b0085d1..3d20e34 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dte_adj" -version = "0.1.7" +version = "0.1.8" description = "This is a Python library for estimating distributional treatment effects" readme = "README.md" requires-python = ">=3.10" diff --git a/tests/test_plot.py b/tests/test_plot.py index 6afcc98..4e574bf 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -72,6 +72,44 @@ def test_plot_fail_unknown_chart_type(self): "Chart type other is not supported", ) + @patch("dte_adj.plot.plt") + def test_plot_weighted(self, mock_plt): + # Arrange + x_values = np.array([1, 2, 3, 4, 5]) + means = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + upper_bands = np.array([0.2, 0.3, 0.4, 0.5, 0.6]) + lower_bands = np.array([0.0, 0.1, 0.2, 0.3, 0.4]) + mock_ax = MagicMock() + mock_plt.subplots.return_value = (MagicMock(), mock_ax) + + # Act + result_ax = plot( + x_values, + means, + lower_bands, + upper_bands, + chart_type="line", + weighted=True, + ) -if __name__ == "__main__": - unittest.main() + # Assert + self.assertEqual(result_ax, mock_ax) + mock_plt.subplots.assert_called_once() + plot_call = mock_ax.plot.call_args + fill_between_call = mock_ax.fill_between.call_args + + # Check that values are weighted (multiplied by x_values) + plot_args, plot_kwargs = plot_call + x_values_arg, y_values_arg = plot_args + expected_weighted_means = means * x_values + self.assertTrue(np.array_equal(x_values_arg, x_values)) + self.assertTrue(np.array_equal(y_values_arg, expected_weighted_means)) + + # Check that confidence intervals are also weighted + fill_between_args, fill_between_kwargs = fill_between_call + x_fill, lower_fill, upper_fill = fill_between_args + expected_weighted_lower = lower_bands * x_values + expected_weighted_upper = upper_bands * x_values + self.assertTrue(np.array_equal(x_fill, x_values_arg)) + self.assertTrue(np.array_equal(lower_fill, expected_weighted_lower)) + self.assertTrue(np.array_equal(upper_fill, expected_weighted_upper)) diff --git a/uv.lock b/uv.lock index b467304..05e5f37 100644 --- a/uv.lock +++ b/uv.lock @@ -312,7 +312,7 @@ wheels = [ [[package]] name = "dte-adj" -version = "0.1.7" +version = "0.1.8" source = { editable = "." } dependencies = [ { name = "matplotlib" },