Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cff-version: 1.2.0
message: "If you use this software, please cite it as below."
type: software
title: "dte_adj: A Python Package for Estimating Distribution Treatment Effects"
version: 0.1.7
version: 0.1.8
date-released: 2024-12-01
url: "https://github.com/CyberAgentAILab/python-dte-adjustment"
repository-code: "https://github.com/CyberAgentAILab/python-dte-adjustment"
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
project = "dte_adj"
copyright = "2024, CyberAgent, Inc."
author = "CyberAgent, Inc"
release = "0.1.7"
release = "0.1.8"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
9 changes: 8 additions & 1 deletion dte_adj/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def plot(
title: Optional[str] = None,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
weighted: bool = False,
):
"""Visualize distributional parameters and their confidence intervals.

Expand All @@ -29,12 +30,18 @@ def plot(
title (str, optional): Axes title.
xlabel (str, optional): X-axis title label.
ylabel (str, optional): Y-axis title label.
weighted (bool, optional): If True, multiply treatment effects by X values to show value-weighted effects. Defaults to False.

Returns:
matplotlib.axes.Axes: The axes with the plot.
"""
if ax is None:
fig, ax = plt.subplots()
_, ax = plt.subplots()

if weighted:
means = means * X
lower_bounds = lower_bounds * X
upper_bounds = upper_bounds * X

if chart_type == "line":
ax.plot(X, means, label="Values", color=color)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "dte_adj"
version = "0.1.7"
version = "0.1.8"
description = "This is a Python library for estimating distributional treatment effects"
readme = "README.md"
requires-python = ">=3.10"
Expand Down
42 changes: 40 additions & 2 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,44 @@ def test_plot_fail_unknown_chart_type(self):
"Chart type other is not supported",
)

@patch("dte_adj.plot.plt")
def test_plot_weighted(self, mock_plt):
# Arrange
x_values = np.array([1, 2, 3, 4, 5])
means = np.array([0.1, 0.2, 0.3, 0.4, 0.5])
upper_bands = np.array([0.2, 0.3, 0.4, 0.5, 0.6])
lower_bands = np.array([0.0, 0.1, 0.2, 0.3, 0.4])
mock_ax = MagicMock()
mock_plt.subplots.return_value = (MagicMock(), mock_ax)

# Act
result_ax = plot(
x_values,
means,
lower_bands,
upper_bands,
chart_type="line",
weighted=True,
)

if __name__ == "__main__":
unittest.main()
# Assert
self.assertEqual(result_ax, mock_ax)
mock_plt.subplots.assert_called_once()
plot_call = mock_ax.plot.call_args
fill_between_call = mock_ax.fill_between.call_args

# Check that values are weighted (multiplied by x_values)
plot_args, plot_kwargs = plot_call
x_values_arg, y_values_arg = plot_args
expected_weighted_means = means * x_values
self.assertTrue(np.array_equal(x_values_arg, x_values))
self.assertTrue(np.array_equal(y_values_arg, expected_weighted_means))

# Check that confidence intervals are also weighted
fill_between_args, fill_between_kwargs = fill_between_call
x_fill, lower_fill, upper_fill = fill_between_args
expected_weighted_lower = lower_bands * x_values
expected_weighted_upper = upper_bands * x_values
self.assertTrue(np.array_equal(x_fill, x_values_arg))
self.assertTrue(np.array_equal(lower_fill, expected_weighted_lower))
self.assertTrue(np.array_equal(upper_fill, expected_weighted_upper))
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading