Skip to content

Commit ffc66be

Browse files
authored
Merge pull request #130 from BuildingEnergySimulationTools/add-kwargs-to-as-plots
Add kwargs to as plots
2 parents 741ac97 + b594e55 commit ffc66be

3 files changed

Lines changed: 156 additions & 9 deletions

File tree

corrai/sampling.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,15 @@ class Sample:
111111
112112
Parameters
113113
----------
114-
parameters : list of Parameter
114+
parameters : list of Parameters
115115
List of model parameters used to generate the samples.
116116
117117
Attributes
118118
----------
119119
parameters : list of Parameter
120120
Parameters associated with this sample.
121121
is_dynamic : Bool default True
122-
Specify if stored results are timeeries in a DataFrame for dynamic models
122+
Specify if stored results are timeseries in a DataFrame for dynamic models
123123
or a Series of float for static models
124124
values : ndarray of shape (n_samples, n_parameters)
125125
Numerical values of the sampled parameters.
@@ -696,6 +696,7 @@ def plot_sample(
696696
round_ndigits: int = 2,
697697
quantile_band: float = 0.75,
698698
type_graph: str = "area",
699+
plot_kwargs: dict = None,
699700
) -> go.Figure:
700701
"""
701702
Plot simulation results with different visualization modes.
@@ -736,6 +737,9 @@ def plot_sample(
736737
- ``"scatter"`` : plot all samples individually as scatter markers.
737738
- ``"area"`` : plot aggregated area with min–max envelope,
738739
median line, and quantile bands.
740+
plot_kwargs : dict, optional
741+
Extra keyword arguments passed to ``fig.update_layout()``.
742+
Use to override any layout property (font, colors, axis sizes, etc.).
739743
740744
Examples
741745
--------
@@ -889,6 +893,8 @@ def _legend_for(i: int) -> str:
889893
showlegend=True,
890894
legend_traceorder="normal",
891895
)
896+
if plot_kwargs:
897+
fig.update_layout(**plot_kwargs)
892898
return fig
893899

894900
def plot_pcp(
@@ -1020,6 +1026,7 @@ def plot_sample(
10201026
round_ndigits: int = 2,
10211027
quantile_band: float = 0.75,
10221028
type_graph: str = "area",
1029+
plot_kwargs: dict = None,
10231030
) -> go.Figure:
10241031
return self.sample.plot_sample(
10251032
indicator=indicator,
@@ -1032,6 +1039,7 @@ def plot_sample(
10321039
round_ndigits=round_ndigits,
10331040
quantile_band=quantile_band,
10341041
type_graph=type_graph,
1042+
plot_kwargs=plot_kwargs,
10351043
)
10361044

10371045
@wraps(Sample.plot_pcp)

corrai/sensitivity.py

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def salib_plot_bar(
213213
reference_time_series: pd.Series = None,
214214
agg_method_kwarg: dict = None,
215215
title: str = None,
216+
plot_kwargs: dict = None,
216217
**analyse_kwarg,
217218
):
218219
"""
@@ -256,13 +257,16 @@ def salib_plot_bar(
256257
**analyse_kwarg,
257258
)[f"{method}_{indicator}"]
258259

260+
_kwargs = dict(plot_kwargs or {})
261+
effective_title = _kwargs.pop("title", title)
259262
return plot_bars(
260263
pd.Series(
261264
data=res[sensitivity_metric],
262265
index=[par.name for par in self.sampler.sample.parameters],
263266
name=f"{sensitivity_metric} {unit}",
264267
).sort_values(),
265-
title=title,
268+
title=effective_title,
269+
**_kwargs,
266270
)
267271

268272
def salib_plot_dynamic_metric(
@@ -277,6 +281,7 @@ def salib_plot_dynamic_metric(
277281
reference_time_series: pd.Series = None,
278282
title: str = None,
279283
stacked: bool = False,
284+
plot_kwargs: dict = None,
280285
**analyse_kwarg,
281286
):
282287
"""
@@ -323,7 +328,11 @@ def salib_plot_dynamic_metric(
323328
index=res.index,
324329
)
325330

326-
return plot_dynamic_metric(metrics, sensitivity_metric, unit, title, stacked)
331+
_kwargs = dict(plot_kwargs or {})
332+
effective_title = _kwargs.pop("title", title)
333+
return plot_dynamic_metric(
334+
metrics, sensitivity_metric, unit, effective_title, stacked, **_kwargs
335+
)
327336

328337
def salib_plot_matrix(
329338
self,
@@ -333,6 +342,8 @@ def salib_plot_matrix(
333342
reference_time_series: pd.Series = None,
334343
agg_method_kwarg: dict = None,
335344
title: str = None,
345+
fixed_range: bool = True,
346+
plot_kwargs: dict = None,
336347
**analyse_kwarg,
337348
):
338349
"""
@@ -376,7 +387,15 @@ def salib_plot_matrix(
376387
)[f"{method}_{indicator}"]
377388

378389
parameter_names = [p.name for p in self.sampler.sample.parameters]
379-
return plot_s2_matrix(result, parameter_names, title=title)
390+
_kwargs = dict(plot_kwargs or {})
391+
effective_title = _kwargs.pop("title", title)
392+
return plot_s2_matrix(
393+
result,
394+
parameter_names,
395+
title=effective_title,
396+
fixed_range=fixed_range,
397+
**_kwargs,
398+
)
380399

381400

382401
class SobolSanalysis(Sanalysis):
@@ -453,6 +472,7 @@ def plot_bar(
453472
unit: str = "",
454473
agg_method_kwarg: dict = None,
455474
title: str = None,
475+
plot_kwargs: dict = None,
456476
**analyse_kwargs,
457477
):
458478
return super().salib_plot_bar(
@@ -464,6 +484,7 @@ def plot_bar(
464484
reference_time_series=reference_time_series,
465485
agg_method_kwarg=agg_method_kwarg,
466486
title=title,
487+
plot_kwargs=plot_kwargs,
467488
calc_second_order=calc_second_order,
468489
**analyse_kwargs,
469490
)
@@ -479,6 +500,7 @@ def plot_dynamic_metric(
479500
agg_method_kwarg: dict = None,
480501
calc_second_order: bool = True,
481502
title: str = None,
503+
plot_kwargs: dict = None,
482504
):
483505
return super().salib_plot_dynamic_metric(
484506
indicator=indicator,
@@ -492,6 +514,7 @@ def plot_dynamic_metric(
492514
calc_second_order=calc_second_order,
493515
stacked=True,
494516
title=title,
517+
plot_kwargs=plot_kwargs,
495518
)
496519

497520
def plot_s2_matrix(
@@ -501,6 +524,8 @@ def plot_s2_matrix(
501524
reference_time_series: pd.Series = None,
502525
agg_method_kwarg: dict = None,
503526
title: str = None,
527+
fixed_range: bool = True,
528+
plot_kwargs: dict = None,
504529
**analyse_kwargs,
505530
):
506531
return super().salib_plot_matrix(
@@ -510,6 +535,8 @@ def plot_s2_matrix(
510535
reference_time_series=reference_time_series,
511536
agg_method_kwarg=agg_method_kwarg,
512537
title=title,
538+
fixed_range=fixed_range,
539+
plot_kwargs=plot_kwargs,
513540
**analyse_kwargs,
514541
)
515542

@@ -581,6 +608,7 @@ def plot_scatter(
581608
unit: str = "",
582609
scaler: float = 100,
583610
autosize: bool = True,
611+
plot_kwargs: dict = None,
584612
**analyse_kwargs,
585613
):
586614
cache_key = (indicator, method, "None")
@@ -596,12 +624,15 @@ def plot_scatter(
596624
)[f"{method}_{indicator}"]
597625
self._analysis_cache[cache_key] = {f"{method}_{indicator}": result}
598626

627+
_kwargs = dict(plot_kwargs or {})
628+
effective_title = _kwargs.pop("title", title)
599629
return plot_morris_scatter(
600630
result,
601-
title=title,
631+
title=effective_title,
602632
unit=unit,
603633
scaler=scaler,
604634
autosize=autosize,
635+
**_kwargs,
605636
)
606637

607638
def plot_bar(
@@ -613,6 +644,7 @@ def plot_bar(
613644
unit: str = "",
614645
agg_method_kwarg: dict = None,
615646
title: str = None,
647+
plot_kwargs: dict = None,
616648
**analyse_kwargs,
617649
):
618650
return super().salib_plot_bar(
@@ -624,6 +656,7 @@ def plot_bar(
624656
reference_time_series=reference_time_series,
625657
agg_method_kwarg=agg_method_kwarg,
626658
title=title,
659+
plot_kwargs=plot_kwargs,
627660
**analyse_kwargs,
628661
)
629662

@@ -637,6 +670,7 @@ def plot_dynamic_metric(
637670
unit: str = "",
638671
agg_method_kwarg: dict = None,
639672
title: str = None,
673+
plot_kwargs: dict = None,
640674
):
641675
return super().salib_plot_dynamic_metric(
642676
indicator,
@@ -648,6 +682,7 @@ def plot_dynamic_metric(
648682
agg_method_kwarg,
649683
reference_time_series,
650684
title,
685+
plot_kwargs=plot_kwargs,
651686
)
652687

653688

@@ -709,6 +744,7 @@ def plot_bar(
709744
unit: str = "",
710745
agg_method_kwarg: dict = None,
711746
title: str = None,
747+
plot_kwargs: dict = None,
712748
**analyse_kwargs,
713749
):
714750
return super().salib_plot_bar(
@@ -720,6 +756,7 @@ def plot_bar(
720756
reference_time_series=reference_time_series,
721757
agg_method_kwarg=agg_method_kwarg,
722758
title=title,
759+
plot_kwargs=plot_kwargs,
723760
**analyse_kwargs,
724761
)
725762

@@ -733,6 +770,7 @@ def plot_dynamic_metric(
733770
unit: str = "",
734771
agg_method_kwarg: dict = None,
735772
title: str = None,
773+
plot_kwargs: dict = None,
736774
):
737775
return super().salib_plot_dynamic_metric(
738776
indicator=indicator,
@@ -745,6 +783,7 @@ def plot_dynamic_metric(
745783
reference_time_series=reference_time_series,
746784
stacked=True,
747785
title=title,
786+
plot_kwargs=plot_kwargs,
748787
)
749788

750789

@@ -803,6 +842,7 @@ def plot_bar(
803842
unit: str = "",
804843
agg_method_kwarg: dict = None,
805844
title: str = None,
845+
plot_kwargs: dict = None,
806846
**analyse_kwargs,
807847
):
808848
return super().salib_plot_bar(
@@ -814,6 +854,7 @@ def plot_bar(
814854
reference_time_series=reference_time_series,
815855
agg_method_kwarg=agg_method_kwarg,
816856
title=title,
857+
plot_kwargs=plot_kwargs,
817858
**analyse_kwargs,
818859
)
819860

@@ -827,6 +868,7 @@ def plot_dynamic_metric(
827868
unit: str = "",
828869
agg_method_kwarg: dict = None,
829870
title: str = None,
871+
plot_kwargs: dict = None,
830872
):
831873
return super().salib_plot_dynamic_metric(
832874
indicator=indicator,
@@ -839,15 +881,25 @@ def plot_dynamic_metric(
839881
reference_time_series=reference_time_series,
840882
stacked=True,
841883
title=title,
884+
plot_kwargs=plot_kwargs,
842885
)
843886

844887

888+
def _apply_figure_kwargs(fig, **kwargs):
889+
for key, val in kwargs.items():
890+
try:
891+
fig.update_layout(**{key: val})
892+
except ValueError:
893+
fig.update_traces(**{key: val})
894+
895+
845896
def plot_dynamic_metric(
846897
metrics: pd.DataFrame,
847898
metric_name: str = "",
848899
unit: str = "",
849900
title: str = None,
850901
stacked: bool = False,
902+
**plot_kwargs,
851903
):
852904
fig = go.Figure()
853905
for param in metrics.columns:
@@ -866,12 +918,16 @@ def plot_dynamic_metric(
866918
xaxis_title="Time",
867919
yaxis_title=f"{metric_name} {unit}",
868920
)
921+
_apply_figure_kwargs(fig, **plot_kwargs)
869922

870923
return fig
871924

872925

873926
def plot_bars(
874-
sensitivity_results: pd.Series, title: str = None, error: pd.Series = None
927+
sensitivity_results: pd.Series,
928+
title: str = None,
929+
error: pd.Series = None,
930+
**plot_kwargs,
875931
):
876932
error = {} if error is None else dict(type="data", array=error.values)
877933
fig = go.Figure()
@@ -890,6 +946,7 @@ def plot_bars(
890946
xaxis_title="Parameters",
891947
yaxis_title=f"{sensitivity_results.name}",
892948
)
949+
_apply_figure_kwargs(fig, **plot_kwargs)
893950

894951
return fig
895952

@@ -900,6 +957,7 @@ def plot_morris_scatter(
900957
unit: str = "",
901958
scaler: float = 100,
902959
autosize: bool = True,
960+
**plot_kwargs,
903961
) -> go.Figure:
904962
"""
905963
Plot a Morris sensitivity analysis scatter plot using μ* and σ.
@@ -986,6 +1044,7 @@ def plot_morris_scatter(
9861044
yaxis_title=f"Standard deviation of elementary effects σ [{unit}]",
9871045
yaxis_range=[-0.1 * y_max, y_max],
9881046
)
1047+
_apply_figure_kwargs(fig, **plot_kwargs)
9891048

9901049
return fig
9911050

@@ -995,17 +1054,22 @@ def plot_s2_matrix(
9951054
param_names: list[str],
9961055
title: str = "Sobol 2nd-order interactions (S2)",
9971056
colorscale: str = "Reds",
1057+
fixed_range: bool = True,
1058+
**plot_kwargs,
9981059
):
9991060
df_S2 = pd.DataFrame(result["S2"], index=param_names, columns=param_names)
10001061

1062+
zmin = -1 if fixed_range else df_S2.values.min()
1063+
zmax = 1 if fixed_range else df_S2.values.max()
1064+
10011065
fig = go.Figure(
10021066
data=go.Heatmap(
10031067
z=df_S2.values,
10041068
x=df_S2.columns,
10051069
y=df_S2.index,
10061070
colorscale=colorscale,
1007-
zmin=0,
1008-
zmax=df_S2.values.max(),
1071+
zmin=zmin,
1072+
zmax=zmax,
10091073
colorbar=dict(title="S2"),
10101074
text=df_S2.round(3).astype(str),
10111075
texttemplate="%{text}",
@@ -1017,5 +1081,6 @@ def plot_s2_matrix(
10171081
xaxis_title="Parameter",
10181082
yaxis_title="Parameter",
10191083
)
1084+
_apply_figure_kwargs(fig, **plot_kwargs)
10201085

10211086
return fig

0 commit comments

Comments
 (0)