@@ -84,11 +84,12 @@ def change_plot_reshape_data(
8484
8585def change_plot_create (
8686 observations : pd .DataFrame ,
87+ predictions : pd .DataFrame | None = None ,
8788 no_change_col : str = 'no_change' ,
8889 change_col : str = 'change' ,
8990 final_observation_col : str = 'final_obs' ,
90- title : str = None ,
91- subtitle : str = None ,
91+ title : str | None = None ,
92+ subtitle : str | None = None ,
9293 x_label : str = '' ,
9394 y_label : str = '' ,
9495 day_range : int = 365 * 10 ,
@@ -99,6 +100,7 @@ def change_plot_create(
99100 Args:
100101 observations: DataFrame with observations. Each row is an iteration of a
101102 tag, with the three columns described below.
103+ predictions: DataFrame with modeled predictions.
102104 no_change_col: Column name for the days elapsed from when the tag was added to
103105 when it was last confirmed (observed unchanged).
104106 change_col: Column name for the days elapsed from when the tag was added to when
@@ -122,6 +124,11 @@ def change_plot_create(
122124 final_observation_col = final_observation_col ,
123125 day_range = day_range ,
124126 )
127+ if predictions is not None :
128+ if subtitle is not None :
129+ subtitle = f"{ subtitle } \n Modeled predictions in red"
130+ else :
131+ subtitle = "Modeled predictions in red"
125132 fig = (
126133 gg .ggplot (
127134 data = reshaped ,
@@ -150,6 +157,20 @@ def change_plot_create(
150157 ) +
151158 gg .theme_bw ()
152159 )
160+ if predictions is not None :
161+ p_renamed = predictions .assign (
162+ year = pd .col ('t2' ),
163+ y = pd .col ('conf_mean' ),
164+ ymin = pd .col ('conf_lower' ),
165+ ymax = pd .col ('conf_upper' ),
166+ )
167+ fig = fig + gg .geom_ribbon (
168+ data = p_renamed ,
169+ fill = 'red' , alpha = 0.25 , linetype = 'dashed' , color = '#444444'
170+ ) + gg .geom_line (
171+ data = p_renamed , color = '#444444' ,
172+ mapping = gg .aes (x = 'year' , y = 'y' )
173+ )
153174 return fig
154175
155176
0 commit comments