-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path__init__.py
More file actions
78 lines (70 loc) · 2.4 KB
/
__init__.py
File metadata and controls
78 lines (70 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from typing import Optional

import matplotlib.axis as axis
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def plot(
    X: np.ndarray,
    means: np.ndarray,
    lower_bounds: np.ndarray,
    upper_bounds: np.ndarray,
    chart_type: str = "line",
    color: str = "green",
    ax: Optional["Axes"] = None,
    title: Optional[str] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
    weighted: bool = False,
) -> "Axes":
    """Visualize distributional parameters and their confidence intervals.

    Args:
        X (np.ndarray): Values to be used for the x axis. Any array-like is
            accepted and converted with ``np.asarray``.
        means (np.ndarray): Expected distributional parameters.
        lower_bounds (np.ndarray): Lower bound for the distributional parameters.
        upper_bounds (np.ndarray): Upper bound for the distributional parameters.
        chart_type (str): Chart type of the plotting. Available values are
            ``"line"`` or ``"bar"``.
        color (str): The color of lines or bars.
        ax (matplotlib.axes.Axes, optional): Target axes instance. If None, a
            new figure and axes will be created.
        title (str, optional): Axes title.
        xlabel (str, optional): X-axis title label.
        ylabel (str, optional): Y-axis title label.
        weighted (bool, optional): If True, multiply ``means`` and both bounds
            by ``X`` to show value-weighted effects. Defaults to False.

    Returns:
        matplotlib.axes.Axes: The axes with the plot.

    Raises:
        ValueError: If ``chart_type`` is neither ``"line"`` nor ``"bar"``.
    """
    # Validate first so an unsupported chart type does not leave behind a
    # freshly created (and never returned) figure.
    if chart_type not in ("line", "bar"):
        raise ValueError(f"Chart type {chart_type} is not supported")

    # Accept plain sequences as well as arrays; the arithmetic and the
    # ``.max()`` / ``.min()`` calls below need ndarray semantics.
    X = np.asarray(X)
    means = np.asarray(means)
    lower_bounds = np.asarray(lower_bounds)
    upper_bounds = np.asarray(upper_bounds)

    if ax is None:
        _, ax = plt.subplots()

    if weighted:
        means = means * X
        scaled_lower = lower_bounds * X
        scaled_upper = upper_bounds * X
        # Multiplying by a negative X flips the interval, so re-order the
        # scaled bounds element-wise to keep lower <= upper.
        lower_bounds = np.minimum(scaled_lower, scaled_upper)
        upper_bounds = np.maximum(scaled_lower, scaled_upper)

    if chart_type == "line":
        ax.plot(X, means, label="Values", color=color)
        ax.fill_between(
            X,
            lower_bounds,
            upper_bounds,
            color=color,
            alpha=0.3,
            label="Confidence Interval",
        )
    else:
        n_points = len(X)
        if n_points > 1:
            width = (X.max() - X.min()) / n_points
        else:
            # A single point has zero span (an invisible bar) and an empty
            # array has no max/min; use matplotlib's default bar width.
            width = 0.8
        ax.bar(
            X,
            means,
            # Error-bar lengths must be non-negative, so clamp in case a mean
            # lies outside its own bounds.
            yerr=[
                np.maximum(means - lower_bounds, 0),
                np.maximum(upper_bounds - means, 0),
            ],
            capsize=5,
            color=color,
            width=width,
        )

    if title is not None:
        ax.set_title(title)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    return ax