diff --git a/sbi/analysis/plot.py b/sbi/analysis/plot.py index b3d8ce436..731b382f6 100644 --- a/sbi/analysis/plot.py +++ b/sbi/analysis/plot.py @@ -1688,6 +1688,7 @@ def _sbc_rank_plot( num_cols: int = 4, params_in_subplots: bool = False, show_ylabel: bool = False, + ylim: Optional[Tuple[float, float]] = (-0.125, 0.125), sharey: bool = False, fig: Optional[FigureBase] = None, legend_kwargs: Optional[Dict] = None, @@ -1702,7 +1703,8 @@ def _sbc_rank_plot( obtained from different methods. num_bins: number of bins used for binning the ranks, default is num_sbc_runs / 20. - plot_type: type of SBC plot, histograms ("hist") or empirical cdfs ("cdf"). + plot_type: type of SBC plot, histograms ("hist"), empirical cdfs ("cdf") or + empirical cdf minus expected cdf ("cdf-diff") parameter_labels: list of labels for each parameter dimension. ranks_labels: list of labels for each set of ranks. colors: list of colors for each parameter dimension, or each set of ranks. @@ -1718,6 +1720,7 @@ def _sbc_rank_plot( params_in_subplots: whether to show each parameter in a separate subplot, or all in one. show_ylabel: whether to show ylabels and ticks. + ylim: limits on the y-axis sharey: whether to share the y-labels, ticks, and limits across subplots. fig: figure object to plot in. ax: axis object, must contain as many sublpots as parameters or len(ranks). @@ -1738,10 +1741,12 @@ def _sbc_rank_plot( if isinstance(rank, Tensor): ranks_list[idx]: np.ndarray = rank.numpy() # type: ignore - plot_types = ["hist", "cdf"] - assert plot_type in plot_types, ( - "plot type {plot_type} not implemented, use one in {plot_types}." - ) + plot_types = ["hist", "cdf", "cdf-diff"] + + if plot_type not in plot_types: + raise ValueError( + f"plot type {plot_type} not supported, use one in {plot_types}." + ) if legend_kwargs is None: legend_kwargs = dict(loc="best", handlelength=0.8) @@ -1814,6 +1819,27 @@ def _sbc_rank_plot( num_repeats, alpha=uniform_region_alpha, ) + elif plot_type == "cdf-diff": + _plot_ranks_as_ecdf( + ranki[:, jj], # type: ignore + num_bins, + num_repeats, + num_sbc_runs, + ranks_label=ranks_labels[ii], + color=f"C{ii}" if colors is None else colors[ii], + xlabel=f"posterior ranks {parameter_labels[jj]}", + # Show legend and ylabel only in first subplot. + show_ylabel=jj == 0, + alpha=line_alpha, + ylim=ylim, + ) + if ii == 0 and show_uniform_region: + _plot_ecdf_region_expected_under_uniformity( + num_sbc_runs, + num_bins, + num_repeats, + alpha=uniform_region_alpha, + ) elif plot_type == "hist": _plot_ranks_as_hist( ranki[:, jj], # type: ignore @@ -1855,25 +1881,49 @@ def _sbc_rank_plot( plt.sca(ax) ranki = ranks_list[0] - for jj in range(num_parameters): - _plot_ranks_as_cdf( - ranki[:, jj], # type: ignore - num_bins, - num_repeats, - ranks_label=parameter_labels[jj], - color=f"C{jj}" if colors is None else colors[jj], - xlabel="posterior rank", - # Plot ylabel and legend at last. - show_ylabel=jj == (num_parameters - 1), - alpha=line_alpha, - ) - if show_uniform_region: - _plot_cdf_region_expected_under_uniformity( - num_sbc_runs, - num_bins, - num_repeats, - alpha=uniform_region_alpha, - ) + + if plot_type == "cdf": + for jj in range(num_parameters): + _plot_ranks_as_cdf( + ranki[:, jj], # type: ignore + num_bins, + num_repeats, + ranks_label=parameter_labels[jj], + color=f"C{jj}" if colors is None else colors[jj], + xlabel="posterior rank", + # Plot ylabel and legend at last. + show_ylabel=jj == (num_parameters - 1), + alpha=line_alpha, + ) + if show_uniform_region: + _plot_cdf_region_expected_under_uniformity( + num_sbc_runs, + num_bins, + num_repeats, + alpha=uniform_region_alpha, + ) + elif plot_type == "cdf-diff": + for jj in range(num_parameters): + _plot_ranks_as_ecdf( + ranki[:, jj], # type: ignore + num_bins, + num_repeats, + num_sbc_runs, + ranks_label=parameter_labels[jj], + color=f"C{jj}" if colors is None else colors[jj], + xlabel="posterior rank", + # Plot ylabel and legend at last. + show_ylabel=jj == (num_parameters - 1), + alpha=line_alpha, + ylim=ylim, + ) + if show_uniform_region: + _plot_ecdf_region_expected_under_uniformity( + num_sbc_runs, + num_bins, + num_repeats, + alpha=uniform_region_alpha, + ) # show legend on the last subplot. plt.legend(**legend_kwargs) @@ -1982,6 +2032,71 @@ def _plot_ranks_as_cdf( plt.xlabel("posterior rank" if xlabel is None else xlabel) +def _plot_ranks_as_ecdf( + ranks: np.ndarray, + num_bins: int, + num_repeats: int, + num_sbc_runs: int, + ranks_label: Optional[str] = None, + xlabel: Optional[str] = None, + color: Optional[str] = None, + alpha: float = 0.8, + show_ylabel: bool = True, + num_ticks: int = 3, + ylim: Optional[tuple[float, float]] = None, +) -> None: + """Plot ranks as a delta of the empirical CDFs to the expected CDF + + Args: + ranks: SBC ranks in shape (num_sbc_runs, ) + num_bins: number of bins for the histogram, recommendation is num_sbc_runs / 20. + num_repeats: number of repeats of each CDF step, i.e., resolution of the eCDF. + ranks_label: label for the ranks, e.g., when comparing ranks of different + methods. + xlabel: label for the current parameter + color: line color for the cdf. + alpha: line transparency. + show_ylabel: whether to show y-label. + num_ticks: number of ticks on the x-axis. + ylim: limits on the y-axis + + """ + # Construct uniform histogram. + uni_bins = binom(num_sbc_runs, p=1 / (num_bins)).ppf(0.5) * np.ones(num_bins) + uni_bins_cdf = uni_bins.cumsum() / uni_bins.sum() + + # Compute the mean to substract to all cdfs + means = [binom(num_sbc_runs, p=p).ppf(0.5) for p in uni_bins_cdf] + means_norm = means / np.max(means) + + # Generate histogram of ranks. + hist, *_ = np.histogram(ranks, bins=num_bins, density=False) + # Construct empirical CDF, don't include last bin because it is 1 by default + histcs = hist.cumsum() + histcs_norm = histcs / histcs.max() + + # Plot cdf and repeat each stair step + plt.plot( + np.linspace(0, 1, num_repeats * (num_bins - 1)), + np.repeat(histcs_norm[:-1] - means_norm[:-1], num_repeats), + label=ranks_label, + color=color, + alpha=alpha, + ) + + if show_ylabel: + plt.ylabel("empirical CDF - expected CDF") + + if ylim is not None: + plt.ylim(ylim) + else: + plt.ylim(-0.125, 0.125) + + plt.xlim(0, 1) + plt.xticks(np.linspace(0, 1, num_ticks)) + plt.xlabel("posterior rank" if xlabel is None else xlabel) + + def _plot_cdf_region_expected_under_uniformity( num_sbc_runs: int, num_bins: int, @@ -2012,6 +2127,41 @@ def _plot_cdf_region_expected_under_uniformity( ) +def _plot_ecdf_region_expected_under_uniformity( + num_sbc_runs: int, + num_bins: int, + num_repeats: int, + alpha: float = 0.2, + color: str = "gray", +) -> None: + """Plot region of empirical ecdfs expected under uniformity on the current axis.""" + + # Construct uniform histogram. + uni_bins = binom(num_sbc_runs, p=1 / num_bins).ppf(0.5) * np.ones(num_bins) + uni_bins_cdf = uni_bins.cumsum() / uni_bins.sum() + # Decrease value one in last entry by epsilon to find valid + # confidence intervals. + uni_bins_cdf[-1] -= 1e-9 + + # Compute the mean, lower and upper bounds + lower = [binom(num_sbc_runs, p=p).ppf(0.005) for p in uni_bins_cdf] + upper = [binom(num_sbc_runs, p=p).ppf(0.995) for p in uni_bins_cdf] + means = [binom(num_sbc_runs, p=p).ppf(0.5) for p in uni_bins_cdf] + means_norm = means / np.max(means) + lower_norm = lower / np.max(lower) + upper_norm = upper / np.max(upper) + + # Plot grey area with expected ECDF. + plt.fill_between( + x=np.linspace(0, 1, num_repeats * (num_bins - 1)), + y1=np.repeat(lower_norm[:-1] - means_norm[:-1], num_repeats), + y2=np.repeat(upper_norm[:-1] - means_norm[:-1], num_repeats), # pyright: ignore[reportArgumentType] + color=color, + alpha=alpha, + label="expected under uniformity", + ) + + def _plot_hist_region_expected_under_uniformity( num_sbc_runs: int, num_bins: int, diff --git a/tests/sbc_test.py b/tests/sbc_test.py index c1f7cbadf..34446ff92 100644 --- a/tests/sbc_test.py +++ b/tests/sbc_test.py @@ -270,7 +270,7 @@ def test_sbc_checks(): @pytest.mark.parametrize("num_bins", (None, 30)) -@pytest.mark.parametrize("plot_type", ("cdf", "hist")) +@pytest.mark.parametrize("plot_type", ("cdf", "hist", "cdf-diff")) @pytest.mark.parametrize("legend_kwargs", (None, {"loc": "upper left"})) @pytest.mark.parametrize("num_rank_sets", (1, 2)) def test_sbc_plotting(