Skip to content

Commit 9f1a6f1

Browse files
committed
Add eCDF plot to SBC plots
1 parent ea71482 commit 9f1a6f1

1 file changed

Lines changed: 158 additions & 20 deletions

File tree

sbi/analysis/plot.py

Lines changed: 158 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,7 +1738,7 @@ def _sbc_rank_plot(
17381738
if isinstance(rank, Tensor):
17391739
ranks_list[idx]: np.ndarray = rank.numpy() # type: ignore
17401740

1741-
plot_types = ["hist", "cdf"]
1741+
plot_types = ["hist", "cdf", "ecdf"]
17421742
assert plot_type in plot_types, (
17431743
"plot type {plot_type} not implemented, use one in {plot_types}."
17441744
)
@@ -1814,6 +1814,26 @@ def _sbc_rank_plot(
18141814
num_repeats,
18151815
alpha=uniform_region_alpha,
18161816
)
1817+
elif plot_type == "ecdf":
1818+
_plot_ranks_as_ecdf(
1819+
ranki[:, jj], # type: ignore
1820+
num_bins,
1821+
num_repeats,
1822+
num_sbc_runs,
1823+
ranks_label=ranks_labels[ii],
1824+
color=f"C{ii}" if colors is None else colors[ii],
1825+
xlabel=f"posterior ranks {parameter_labels[jj]}",
1826+
# Show legend and ylabel only in first subplot.
1827+
show_ylabel=jj == 0,
1828+
alpha=line_alpha,
1829+
)
1830+
if ii == 0 and show_uniform_region:
1831+
_plot_ecdf_region_expected_under_uniformity(
1832+
num_sbc_runs,
1833+
num_bins,
1834+
num_repeats,
1835+
alpha=uniform_region_alpha,
1836+
)
18171837
elif plot_type == "hist":
18181838
_plot_ranks_as_hist(
18191839
ranki[:, jj], # type: ignore
@@ -1855,25 +1875,48 @@ def _sbc_rank_plot(
18551875

18561876
plt.sca(ax)
18571877
ranki = ranks_list[0]
1858-
for jj in range(num_parameters):
1859-
_plot_ranks_as_cdf(
1860-
ranki[:, jj], # type: ignore
1861-
num_bins,
1862-
num_repeats,
1863-
ranks_label=parameter_labels[jj],
1864-
color=f"C{jj}" if colors is None else colors[jj],
1865-
xlabel="posterior rank",
1866-
# Plot ylabel and legend at last.
1867-
show_ylabel=jj == (num_parameters - 1),
1868-
alpha=line_alpha,
1869-
)
1870-
if show_uniform_region:
1871-
_plot_cdf_region_expected_under_uniformity(
1872-
num_sbc_runs,
1873-
num_bins,
1874-
num_repeats,
1875-
alpha=uniform_region_alpha,
1876-
)
1878+
1879+
if plot_type == "cdf":
1880+
for jj in range(num_parameters):
1881+
_plot_ranks_as_cdf(
1882+
ranki[:, jj], # type: ignore
1883+
num_bins,
1884+
num_repeats,
1885+
ranks_label=parameter_labels[jj],
1886+
color=f"C{jj}" if colors is None else colors[jj],
1887+
xlabel="posterior rank",
1888+
# Plot ylabel and legend at last.
1889+
show_ylabel=jj == (num_parameters - 1),
1890+
alpha=line_alpha,
1891+
)
1892+
if show_uniform_region:
1893+
_plot_cdf_region_expected_under_uniformity(
1894+
num_sbc_runs,
1895+
num_bins,
1896+
num_repeats,
1897+
alpha=uniform_region_alpha,
1898+
)
1899+
elif plot_type == "ecdf":
1900+
for jj in range(num_parameters):
1901+
_plot_ranks_as_ecdf(
1902+
ranki[:, jj], # type: ignore
1903+
num_bins,
1904+
num_repeats,
1905+
num_sbc_runs,
1906+
ranks_label=parameter_labels[jj],
1907+
color=f"C{jj}" if colors is None else colors[jj],
1908+
xlabel="posterior rank",
1909+
# Plot ylabel and legend at last.
1910+
show_ylabel=jj == (num_parameters - 1),
1911+
alpha=line_alpha,
1912+
)
1913+
if show_uniform_region:
1914+
_plot_ecdf_region_expected_under_uniformity(
1915+
num_sbc_runs,
1916+
num_bins,
1917+
num_repeats,
1918+
alpha=uniform_region_alpha,
1919+
)
18771920
# show legend on the last subplot.
18781921
plt.legend(**legend_kwargs)
18791922

@@ -1982,6 +2025,66 @@ def _plot_ranks_as_cdf(
19822025
plt.xlabel("posterior rank" if xlabel is None else xlabel)
19832026

19842027

2028+
def _plot_ranks_as_ecdf(
    ranks: np.ndarray,
    num_bins: int,
    num_repeats: int,
    num_sbc_runs: int,
    ranks_label: Optional[str] = None,
    xlabel: Optional[str] = None,
    color: Optional[str] = None,
    alpha: float = 0.8,
    show_ylabel: bool = True,
    num_ticks: int = 3,
) -> None:
    """Plot ranks as the difference of their empirical CDF to the expected CDF.

    The ranks are binned into a histogram, accumulated and normalized by the
    maximum cumulative count. The normalized binomial median of the cumulative
    counts under a uniform histogram is subtracted from it, and the result is
    drawn on the current matplotlib axis.

    Args:
        ranks: SBC ranks in shape (num_sbc_runs, )
        num_bins: number of bins for the histogram, recommendation is
            num_sbc_runs / 20.
        num_repeats: number of repeats of each CDF step, i.e., resolution of the
            eCDF.
        num_sbc_runs: number of SBC runs, used as the number of trials of the
            binomial distribution from which the expected CDF is computed.
        ranks_label: label for the ranks, e.g., when comparing ranks of different
            methods.
        xlabel: label for the current parameter, defaults to "posterior rank".
        color: line color for the cdf.
        alpha: line transparency.
        show_ylabel: whether to show y-label "empirical CDF - expected CDF".
        num_ticks: number of ticks on the x-axis.
    """
    # CDF of a uniform histogram, evaluated at the right edge of every bin.
    # Computed directly instead of normalizing a constant histogram of median
    # counts: that median is 0 when num_sbc_runs is small relative to num_bins,
    # which turned the normalization into 0 / 0 and the whole curve into NaN.
    uni_bins_cdf = np.arange(1, num_bins + 1) / num_bins

    # Median cumulative count expected under uniformity; it is subtracted from
    # the empirical CDF below. The shape parameter p broadcasts over the bins.
    means = binom(num_sbc_runs, p=uni_bins_cdf).ppf(0.5)
    means_norm = means / np.max(means)

    # Generate histogram of ranks.
    hist, *_ = np.histogram(ranks, bins=num_bins, density=False)
    # Construct empirical CDF.
    histcs = hist.cumsum()
    histcs_norm = histcs / histcs.max()

    # Plot the difference and repeat each stair step. The last bin is not
    # included because the normalized CDF is 1 there by construction.
    plt.plot(
        np.linspace(0, 1, num_repeats * (num_bins - 1)),
        np.repeat(histcs_norm[:-1] - means_norm[:-1], num_repeats),
        label=ranks_label,
        color=color,
        alpha=alpha,
    )

    if show_ylabel:
        plt.ylabel("empirical CDF - expected CDF")

    plt.xlim(0, 1)
    plt.xticks(np.linspace(0, 1, num_ticks))
    plt.xlabel("posterior rank" if xlabel is None else xlabel)
2086+
2087+
19852088
def _plot_cdf_region_expected_under_uniformity(
19862089
num_sbc_runs: int,
19872090
num_bins: int,
@@ -2012,6 +2115,41 @@ def _plot_cdf_region_expected_under_uniformity(
20122115
)
20132116

20142117

2118+
def _plot_ecdf_region_expected_under_uniformity(
    num_sbc_runs: int,
    num_bins: int,
    num_repeats: int,
    alpha: float = 0.2,
    color: str = "gray",
) -> None:
    """Plot region of empirical ecdfs expected under uniformity on the current axis.

    The region is bounded by the 0.005 and 0.995 binomial quantiles of the
    cumulative counts under a uniform histogram. Each bound is normalized by its
    own maximum and shifted by the normalized binomial median.

    Args:
        num_sbc_runs: number of SBC runs, used as the number of trials of the
            binomial distribution.
        num_bins: number of histogram bins.
        num_repeats: number of repeats of each step of the region boundary.
        alpha: transparency of the filled region.
        color: fill color of the region.
    """
    # CDF of a uniform histogram, evaluated at the right edge of every bin.
    # Computed directly instead of normalizing a constant histogram of median
    # counts: that median is 0 when num_sbc_runs is small relative to num_bins,
    # which turned the normalization into 0 / 0 and the whole region into NaN.
    uni_bins_cdf = np.arange(1, num_bins + 1) / num_bins
    # Decrease value one in last entry by epsilon to find valid
    # confidence intervals.
    uni_bins_cdf[-1] -= 1e-9

    # Compute the mean, lower and upper bounds. The distribution is built once;
    # its shape parameter p broadcasts over the bins.
    cumulative_counts = binom(num_sbc_runs, p=uni_bins_cdf)
    lower = cumulative_counts.ppf(0.005)
    upper = cumulative_counts.ppf(0.995)
    means = cumulative_counts.ppf(0.5)
    means_norm = means / np.max(means)
    lower_norm = lower / np.max(lower)
    upper_norm = upper / np.max(upper)

    # Plot grey area with expected ECDF. The last bin is dropped, so the x-axis
    # has num_repeats * (num_bins - 1) points.
    plt.fill_between(
        x=np.linspace(0, 1, num_repeats * (num_bins - 1)),
        y1=np.repeat(lower_norm[:-1] - means_norm[:-1], num_repeats),
        y2=np.repeat(upper_norm[:-1] - means_norm[:-1], num_repeats),  # pyright: ignore[reportArgumentType]
        color=color,
        alpha=alpha,
        label="expected under uniformity",
    )
2151+
2152+
20152153
def _plot_hist_region_expected_under_uniformity(
20162154
num_sbc_runs: int,
20172155
num_bins: int,

0 commit comments

Comments
 (0)