Skip to content

Commit 85823fa

Browse files
LCOA R script + python hook
1 parent 08a8fbb commit 85823fa

5 files changed

Lines changed: 883 additions & 14 deletions

File tree

src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,12 @@
1414
import pickle
1515
from pathlib import Path
1616
import pandas as pd
17-
import numpy as np
18-
from matplotlib import pyplot as plt
17+
18+
1919
from tlo import Date
2020
from tlo.util import create_age_range_lookup
2121

22-
from scripts.lcoa_inputs_from_tlo_analyses.fig_utils import (
23-
do_bar_plot_with_ci,
24-
plot_multiindex_dot_with_interval,
25-
)
22+
2623
from scripts.lcoa_inputs_from_tlo_analyses.results_processing_utils import (
2724
get_counts_of_appts,
2825
get_counts_of_hsi_by_short_treatment_id,
@@ -63,7 +60,6 @@
6360
extract_results,
6461
get_color_short_treatment_id,
6562
make_age_grp_lookup,
66-
squarify_neat,
6763
summarize,
6864
)
6965
# python src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-2026-02-12T120859Z figs/ --target-start=2010-01-01 --target-end=2025-12-31
@@ -339,6 +335,10 @@ def apply(
339335
pipe(set_param_names_as_column_index_level_0, param_names=param_names)
340336
)
341337

338+
capacity_used_by_cadre = (
339+
compute_summary_statistics(capacity_used_by_cadre, central_measure='median')
340+
)
341+
342342
results['dalys'] = dalys
343343
results['dalys_averted'] = dalys_averted if do_comparison else None
344344
results['pc_dalys_averted'] = pc_dalys_averted if do_comparison else None

src/scripts/lcoa_inputs_from_tlo_analyses/fig_utils.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,65 @@ def do_label_barh_plot(_df: pd.DataFrame, _ax):
437437
size=7,
438438
)
439439

440+
def plot_cadre_time_by_draw_stacked(
441+
_df: pd.DataFrame,
442+
stat: str = "central",
443+
figsize: tuple[float, float] | None = None,
444+
):
445+
"""Plot horizontal stacked bars of cadre time use by draw for one summary stat."""
446+
if not isinstance(_df.columns, pd.MultiIndex) or _df.columns.nlevels != 2:
447+
raise ValueError("_df columns must be a 2-level MultiIndex with levels for draw and stat.")
448+
449+
stat_level_name = "stat" if "stat" in _df.columns.names else _df.columns.names[1]
450+
available_stats = pd.Index(_df.columns.get_level_values(stat_level_name).unique())
451+
if stat not in available_stats:
452+
raise ValueError(f"Statistic '{stat}' not found. Available stats: {available_stats.tolist()}")
453+
454+
_plot = _df.xs(stat, axis=1, level=stat_level_name).T.fillna(0.0)
455+
if _plot.empty:
456+
raise ValueError(f"No plottable data remain for stat '{stat}'.")
457+
458+
_plot = _plot.loc[_plot.sum(axis=1).sort_values(ascending=True).index]
459+
460+
if figsize is None:
461+
fig_height = max(6, min(0.35 * len(_plot.index) + 3, 20))
462+
figsize = (12, fig_height)
463+
fig, ax = plt.subplots(figsize=figsize)
464+
465+
cadre_colors = list(plt.get_cmap("tab10").colors)
466+
left = np.zeros(len(_plot.index), dtype=float)
467+
y = np.arange(len(_plot.index))
468+
469+
for i, cadre in enumerate(_plot.columns):
470+
values = _plot[cadre].to_numpy(dtype=float)
471+
ax.barh(
472+
y,
473+
values,
474+
left=left,
475+
color=cadre_colors[i % len(cadre_colors)],
476+
label=str(cadre),
477+
)
478+
left += values
479+
480+
ax.set_yticks(y)
481+
ax.set_yticklabels([str(draw) for draw in _plot.index])
482+
ax.set_xlabel("Time used")
483+
ax.set_ylabel("Draw")
484+
ax.grid(axis="x")
485+
ax.spines["top"].set_visible(False)
486+
ax.spines["right"].set_visible(False)
487+
ax.legend(
488+
loc="lower right",
489+
fontsize=12,
490+
handlelength=2.4,
491+
handleheight=1.6,
492+
borderpad=1.0,
493+
labelspacing=0.8,
494+
frameon=True,
495+
)
496+
fig.tight_layout()
497+
return fig, ax
498+
440499
def plot_hsi_counts_stacked_bar(_df: pd.DataFrame, plot_stat: str = "central"):
441500
"""Plot horizontal stacked bars of HSI counts by draw for a selected summary statistic."""
442501
if not isinstance(_df.columns, pd.MultiIndex) or _df.columns.nlevels != 2:

src/scripts/lcoa_inputs_from_tlo_analyses/figures_effect_of_treatment_ids.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
make_graph_file_name,
1818
do_barh_plot_with_ci,
1919
do_bar_plot_with_ci,
20+
plot_cadre_time_by_draw_stacked,
2021
plot_deaths_by_period_for_cause,
2122
plot_deaths_by_period_for_draw,
2223
plot_hsi_counts_by_period_for_draw,
@@ -33,37 +34,45 @@
3334
def load_results_files(results_files: list[Path]) -> dict[Path, dict]:
3435
loaded = {}
3536
for results_file in results_files:
37+
print(f"Loading results file: {results_file}")
3638
with open(results_file, "rb") as f:
3739
loaded[results_file] = pickle.load(f)
3840
return loaded
3941

4042

4143
def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path = None):
4244
"""Produce standard plots describing effect of each TREATMENT_ID."""
45+
print("Starting figure generation for treatment-ID effects.")
46+
print(f"Output folder: {output_folder}")
4347

4448
param_names = get_parameter_names_from_scenario_file()
49+
print(f"Loaded parameter names: {len(param_names)}")
4550

4651
all_results = load_results_files(results_files)
4752
primary_results = all_results[results_files[0]]
53+
print(f"Using primary results from: {results_files[0]}")
4854

4955
num_deaths_averted = primary_results.get('num_deaths_averted')
5056
pc_deaths_averted = primary_results.get('pc_deaths_averted')
51-
num_dalys_averted = primary_results.get('num_dalys_averted')
57+
dalys_averted = primary_results.get('dalys_averted')
5258
pc_dalys_averted = primary_results.get('pc_dalys_averted')
5359
icers = primary_results.get('icers_summarized')
5460
comparison_metrics_available = all(
5561
metric is not None
5662
for metric in (
5763
num_deaths_averted,
5864
pc_deaths_averted,
59-
num_dalys_averted,
65+
dalys_averted,
6066
pc_dalys_averted,
6167
icers,
6268
)
6369
)
70+
print(f"Comparison metrics available: {comparison_metrics_available}")
6471

6572
counts_of_hsi_in_implementation_period = primary_results['counts_of_hsi_by_period']
6673
counts_of_hsi_in_implementation_period = counts_of_hsi_in_implementation_period.drop(['2010-2041'], level=1)
74+
capacity_used_by_cadre = primary_results.get("capacity_used_by_cadre")
75+
6776

6877
result_df_by_period = pd.DataFrame([
6978
{'treatment_id_included': draw, 'nonzero_hsis': treatment_id, 'period': period}
@@ -82,18 +91,48 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
8291
if param == "Nothing":
8392
continue
8493
draw = format_scenario_name(param)
94+
print(f"Plotting yearly HSI counts for draw: {draw}")
8595
name_of_plot = f"Yearly HSI counts for {draw}"
96+
# Since all HSIs will be delivered before the service availability switch
97+
# retain only the treatment id of interest in this period to avoid plot
98+
# clutter.
99+
pre_switch_periods = (
100+
['2010-2010', '2011-2011', '2012-2012', '2013-2013',
101+
'2014-2014', '2015-2015', '2016-2016', '2017-2017',
102+
'2018-2018', '2019-2019', '2020-2020', '2021-2021',
103+
'2022-2022', '2023-2023', '2024-2024', '2025-2025']
104+
)
105+
mask_other_periods = (
106+
~counts_of_hsi_in_implementation_period.
107+
index.
108+
get_level_values("period").
109+
isin(pre_switch_periods)
110+
)
111+
mask_early_periods = (
112+
counts_of_hsi_in_implementation_period.index.get_level_values("period").isin(pre_switch_periods) &
113+
(counts_of_hsi_in_implementation_period.index.get_level_values("appt_type") == draw.replace("_*", ""))
114+
)
115+
plot_this = counts_of_hsi_in_implementation_period[mask_other_periods | mask_early_periods]
86116
fig, ax = plot_hsi_counts_by_period_for_draw(
87-
counts_of_hsi_in_implementation_period,
117+
plot_this,
88118
draw,
89119
)
90120
ax.set_title(name_of_plot)
91121
outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot))
92122
fig.savefig(outfile)
93123
plt.close(fig)
94124

125+
print("Plotting capacity used by cadres across draws.")
126+
fig, ax = plot_cadre_time_by_draw_stacked(capacity_used_by_cadre, stat="central")
127+
name_of_plot = "Capacity Used by Cadres (2026-2040)"
128+
ax.set_title(name_of_plot)
129+
outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot))
130+
fig.savefig(outfile)
131+
plt.close(fig)
132+
95133
# Plot population growth
96134
total_population_in_implementation = primary_results['total_population_by_year']
135+
print("Plotting population size by year.")
97136
fig, ax = plot_population_by_year(total_population_in_implementation / 1e6)
98137
name_of_plot = "Population size by year"
99138
ax.set_title(name_of_plot)
@@ -102,14 +141,14 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
102141
plt.close(fig)
103142

104143
# Plot number of deaths and DALYS by cause for each parameter, with confidence intervals, for the target period
105-
106-
107-
num_dalys_by_cause_label_implementation = primary_results['num_dalys'].drop(['2010-2041'], level=1)
144+
num_dalys_by_cause_label_implementation = primary_results['dalys'].drop(['2010-2041'], level=1)
108145

109146
num_deaths_by_cause_label_implementation = primary_results['num_deaths'].drop(['2010-2041'], level=1)
147+
print("Prepared deaths and DALYs by cause for plotting.")
110148

111149
for param in param_names:
112150
draw = format_scenario_name(param)
151+
print(f"Plotting deaths over time by cause for draw: {draw}")
113152
fig, ax = plot_deaths_by_period_for_draw(
114153
num_deaths_by_cause_label_implementation / 1e3,
115154
draw,
@@ -123,6 +162,7 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
123162

124163
cause_labels = num_deaths_by_cause_label_implementation.index.get_level_values("label").unique()
125164
for cause_label in cause_labels:
165+
print(f"Plotting cause-specific time series for: {cause_label}")
126166
fig, ax = plot_deaths_by_period_for_cause(
127167
num_deaths_by_cause_label_implementation / 1e3,
128168
cause_label=cause_label,
@@ -146,6 +186,7 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
146186
plt.close(fig)
147187

148188
if comparison_metrics_available:
189+
print("Plotting comparison metrics: deaths/DALYs averted, percentages, and ICERs.")
149190
deaths_averted_sorted = (num_deaths_averted.sort_values(by="central", ascending=True) / 1e3)
150191
fig_height = max(6, min(0.28 * len(deaths_averted_sorted.index) + 4, 18))
151192
fig, ax = plt.subplots(figsize=(10, fig_height))
@@ -160,8 +201,9 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
160201
fig.tight_layout()
161202
fig.savefig(outfile)
162203
plt.close(fig)
204+
print("Saved: Deaths Averted by Each Treatment ID")
163205

164-
dalys_averted_sorted = (num_dalys_averted.sort_values(by="central", ascending=True) / 1e3)
206+
dalys_averted_sorted = (dalys_averted.sort_values(by="central", ascending=True) / 1e3)
165207
fig_height = max(6, min(0.28 * len(dalys_averted_sorted.index) + 4, 18))
166208
fig, ax = plt.subplots(figsize=(10, fig_height))
167209
name_of_plot = "DALYS Averted by Each Treatment ID"
@@ -175,6 +217,7 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
175217
fig.tight_layout()
176218
fig.savefig(outfile)
177219
plt.close(fig)
220+
print("Saved: DALYS Averted by Each Treatment ID")
178221

179222
pc_deaths_averted_sorted = (pc_deaths_averted.sort_values(by="central", ascending=True))
180223
fig_height = max(6, min(0.28 * len(pc_deaths_averted_sorted.index) + 4, 18))
@@ -190,6 +233,7 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
190233
fig.tight_layout()
191234
fig.savefig(outfile)
192235
plt.close(fig)
236+
print("Saved: Percentage Deaths Averted by Each Treatment ID")
193237

194238
pc_dalys_averted_sorted = (pc_dalys_averted.sort_values(by="central", ascending=True))
195239
fig_height = max(6, min(0.28 * len(pc_dalys_averted_sorted.index) + 4, 18))
@@ -205,6 +249,7 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
205249
fig.tight_layout()
206250
fig.savefig(outfile)
207251
plt.close(fig)
252+
print("Saved: Percentage DALYs Averted by Each Treatment ID")
208253

209254
icers_sorted = icers.sort_values(by="central", ascending=True)
210255
# Do not plot treatment ids with very wide uncertainty
@@ -227,6 +272,9 @@ def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path
227272
fig.tight_layout()
228273
fig.savefig(outfile)
229274
plt.close(fig)
275+
print("Saved: ICERs for Each Treatment ID")
276+
277+
print("Finished generating figures.")
230278

231279
if __name__ == "__main__":
232280
parser = argparse.ArgumentParser()

0 commit comments

Comments
 (0)