diff --git a/resources/healthsystem/ResourceFile_HealthSystem_parameters.csv b/resources/healthsystem/ResourceFile_HealthSystem_parameters.csv
index c6bd6414e7..44a0f60bc3 100644
--- a/resources/healthsystem/ResourceFile_HealthSystem_parameters.csv
+++ b/resources/healthsystem/ResourceFile_HealthSystem_parameters.csv
@@ -3,6 +3,8 @@ policy_name,Naive
 year_mode_switch,2100
 scale_to_effective_capabilities,FALSE
 Service_Availability,"[""*""]"
+year_service_availability_switch,2100
+service_availability_postSwitch,"[""*""]"
 use_funded_or_actual_staffing,funded_plus
 mode_appt_constraints,1
 mode_appt_constraints_postSwitch,1
diff --git a/src/scripts/costing/cost_estimation.py b/src/scripts/costing/cost_estimation.py
index 02d0971955..594c2d1d01 100644
--- a/src/scripts/costing/cost_estimation.py
+++ b/src/scripts/costing/cost_estimation.py
@@ -156,7 +156,8 @@ def get_discount_factor(year):
             # Compute the cumulative discount factor as the product of (1 + discount_rate) for all previous years
             discount_factor = 1
             for y in range(_initial_year + 1,
-                           year + 1):  # only starting from initial year + 1 as the discount factor for initial year should be 1
+                           year + 1):  # only starting from initial year + 1 as the discount factor for initial year
+                           # should be 1
                 discount_factor *= (1 + _discount_rate.get(y, 0))  # Default to 0 if year not in dictionary
             return discount_factor
         else:
@@ -282,6 +283,7 @@ def clean_equipment_name(name: str, equipment_drop_list = None) -> str:
 def estimate_input_cost_of_scenarios(results_folder: Path,
                                      resourcefilepath: Path,
+                                     suspended_results_folder: Path = None,
                                      _draws: Optional[list[int]] = None,
                                      _runs: Optional[list[int]] = None,
                                      summarize: bool = False,
@@ -298,6 +300,10 @@ def estimate_input_cost_of_scenarios(results_folder: Path,
         Path to the directory containing simulation output files.
     resourcefilepath : Path, optional
         Path to the resource files
+    suspended_results_folder: Path, optional
+        Path to the directory containing suspended simulation output files (using the suspend and resume functionality).
+        This is used to extract the scaling_factor to scale results to the actual population size. If None, then the
+        'scaling_factor' is obtained from the results_folder.
    _draws : list, optional
         Specific draws to include in the cost estimation. Defaults to all available draws.
     _runs : list, optional
@@ -316,8 +322,10 @@ def estimate_input_cost_of_scenarios(results_folder: Path,
     Returns:
     -------
     pd.DataFrame
-        A dataframe containing discounted costs disaggregated by category, sub-category, category-specific subgroup, year, draw, and run.
-        Note that if a discount rate is used, the dataframe will provide cost as the NPV during the first year of the dataframe
+        A dataframe containing discounted costs disaggregated by category, sub-category, category-specific subgroup,
+        year, draw, and run.
+ Note that if a discount rate is used, the dataframe will provide cost as the NPV during the first year of the + dataframe """ # Useful common functions @@ -343,6 +351,8 @@ def melt_model_output_draws_and_runs(_df, id_vars): _draws = range(0, info['number_of_draws']) if _runs is None: _runs = range(0, info['runs_per_draw']) + if suspended_results_folder is None: + suspended_results_folder = results_folder # Load cost input files # ------------------------ @@ -354,7 +364,8 @@ def melt_model_output_draws_and_runs(_df, id_vars): facility_id_levels_dict = dict(zip(mfl['Facility_ID'], mfl['Facility_Level'])) fac_levels = set(mfl.Facility_Level) - # If variable discount rate is provided, use the average across the relevant years for the purpose of annuitization of HR and equipment costs + # If variable discount rate is provided, use the average across the relevant years for the purpose of annuitization + # of HR and equipment costs def calculate_annuitization_rate(_discount_rate, _years): if isinstance(_discount_rate, (int, float)): # Single discount rate, return as is @@ -477,25 +488,68 @@ def merge_cost_and_model_data(cost_df, model_df, varnames): return merged_df # Get available staff count for each year and draw - def get_staff_count_by_facid_and_officer_type(_df: pd.Series) -> pd.Series: - """Summarise the parsed logged-key results for one draw (as dataframe) into a pd.Series.""" - _df = _df.set_axis(_df['date'].dt.year).drop(columns=['date']) - _df.index.name = 'year' + def get_staff_count_by_facid_and_officer_type(_df: pd.DataFrame) -> pd.Series: + """ + Convert logged staff dictionary output into tidy format, + summing staff counts across all clinic columns. + + Returns pd.Series indexed by: + (year, FacilityID, Officer) + """ - def change_to_standard_flattened_index_format(col): - parts = col.split("_", 3) # Split by "_" only up to 3 parts - if len(parts) > 2: - return parts[0] + "=" + parts[1] + "|" + parts[2] + "=" + parts[ - 3] # Rejoin with "I" at the second occurrence - return col # If there's no second underscore, return the string as it is + df = _df.copy() + df["year"] = df["date"].dt.year + df = df.drop(columns=["date"]) - _df.columns = [change_to_standard_flattened_index_format(col) for col in _df.columns] + clinic_cols = df.columns.difference(["year"]) - return unflatten_flattened_multi_index_in_logging(_df).stack(level=[0, 1]) # expanded flattened axis + long_frames = [] + + for clinic in clinic_cols: + expanded = df[[clinic, "year"]].copy() + expanded = expanded[expanded[clinic].notna()] + + expanded_dict = expanded[clinic].apply(pd.Series) + expanded_dict["year"] = expanded["year"].values + + long_frames.append(expanded_dict) + + # Combine all clinics + combined = pd.concat(long_frames, ignore_index=True) + + # Melt to long format + long_df = ( + combined + .melt(id_vars=["year"], + var_name="facility_officer", + value_name="count") + .dropna(subset=["count"]) + ) + + # Split FacilityID and Officer + parts = long_df["facility_officer"].str.split("_Officer_", expand=True) + + long_df["FacilityID"] = ( + parts[0] + .str.replace("FacilityID_", "", regex=False) + .astype(int) + ) + long_df["Officer"] = parts[1] + + # SUM ACROSS CLINICS HERE + result = ( + long_df + .groupby(["year", "FacilityID", "Officer"])["count"] + .sum() + .sort_index() + ) + + return result # Staff count by Facility ID available_staff_count_by_facid_and_officertype = extract_results( Path(results_folder), + suspended_results_folder=suspended_results_folder, module='tlo.methods.healthsystem.summary', 
key='number_of_hcw_staff', custom_generate_series=get_staff_count_by_facid_and_officer_type, @@ -519,22 +573,71 @@ def change_to_standard_flattened_index_format(col): 'Facility_Level'].astype(str) # make sure facility level is stored as string available_staff_count_by_level_and_officer_type = available_staff_count_by_level_and_officer_type.drop( available_staff_count_by_level_and_officer_type[available_staff_count_by_level_and_officer_type[ - 'Facility_Level'] == '5'].index) # drop headquarters because we're only concerned with staff engaged in service delivery + 'Facility_Level'] == '5'].index) # drop headquarters + # because we're only concerned with staff engaged in service delivery available_staff_count_by_level_and_officer_type.rename(columns={'value': 'staff_count'}, inplace=True) # Get list of cadres which were utilised in each run to get the count of staff used in the simulation - # Note that we still cost the full staff count for any cadre-Facility_Level combination that was ever used in a run, and - # not the amount of time which was used - def get_capacity_used_by_officer_type_and_facility_level(_df: pd.Series) -> pd.Series: - """Summarise the parsed logged-key results for one draw (as dataframe) into a pd.Series.""" - _df = _df.set_axis(_df['date'].dt.year).drop(columns=['date']) - _df.index.name = 'year' - return unflatten_flattened_multi_index_in_logging(_df).stack(level=[0, 1]) # expanded flattened axis + # Note that we still cost the full staff count for any cadre-Facility_Level combination that was ever used in a run, + # and not the amount of time which was used + def get_capacity_used_by_officer_type_and_facility_level( + _df: pd.DataFrame + ) -> pd.Series: + """ + Parse logging output and return a Series indexed by: + (year, OfficerType, FacilityLevel) + + Collapses (sums) across clinics. + Uses facility_id_levels_dict to map FacilityID → FacilityLevel. + """ + + # ---- 1. Set year index ---- + _df = _df.set_axis(_df["date"].dt.year).drop(columns=["date"]) + _df.index.name = "year" + + # ---- 2. Unflatten logging columns ---- + _df = unflatten_flattened_multi_index_in_logging(_df) + + # Expect columns like: + # ('Clinic', 'facID_and_officer') + + col_df = _df.columns.to_frame(index=False) + + # ---- 3. Extract OfficerType ---- + col_df["OfficerType"] = ( + col_df["facID_and_officer"] + .str.split("_Officer_") + .str[-1] + ) + + # ---- 4. Extract FacilityID ---- + col_df["FacilityID"] = ( + col_df["facID_and_officer"] + .str.split("_Officer_") + .str[0] + .str.replace("FacilityID_", "", regex=False) + .astype(int) + ) + + # ---- 5. Map to FacilityLevel ---- + col_df["FacilityLevel"] = col_df["FacilityID"].map(facility_id_levels_dict) + + # ---- 6. Rebuild MultiIndex (drop clinic level) ---- + _df.columns = pd.MultiIndex.from_frame( + col_df[["OfficerType", "FacilityLevel"]] + ) + + # ---- 7. Collapse across clinics ---- + _df = _df.groupby(level=["OfficerType", "FacilityLevel"], axis=1).sum() + + # ---- 8. 
Return stacked format ---- + return _df.stack(["OfficerType", "FacilityLevel"]) annual_capacity_used_by_cadre_and_level = extract_results( Path(results_folder), + suspended_results_folder=suspended_results_folder, module='tlo.methods.healthsystem.summary', - key='Capacity_By_OfficerType_And_FacilityLevel', + key='Capacity_By_FacID_and_Officer', custom_generate_series=get_capacity_used_by_officer_type_and_facility_level, do_scaling=False, ) @@ -552,7 +655,10 @@ def get_capacity_used_by_officer_type_and_facility_level(_df: pd.Series) -> pd.S average_capacity_used_by_cadre_and_level[average_capacity_used_by_cadre_and_level['capacity_used'] != 0][ ['OfficerType', 'FacilityLevel', 'draw', 'run']] print( - f"Out of {average_capacity_used_by_cadre_and_level.groupby(['OfficerType', 'FacilityLevel']).size().count()} cadre and level combinations available, {list_of_cadre_and_level_combinations_used.groupby(['OfficerType', 'FacilityLevel']).size().count()} are used across the simulations") + f"Out of {average_capacity_used_by_cadre_and_level.groupby(['OfficerType', 'FacilityLevel']).size().count()} " + f"cadre and level combinations available, " + f"{list_of_cadre_and_level_combinations_used.groupby(['OfficerType', 'FacilityLevel']).size().count()} " + f"are used across the simulations") list_of_cadre_and_level_combinations_used = list_of_cadre_and_level_combinations_used.rename( columns={'FacilityLevel': 'Facility_Level'}) @@ -564,11 +670,13 @@ def get_capacity_used_by_officer_type_and_facility_level(_df: pd.Series) -> pd.S if (cost_only_used_staff): print( - "The input for 'cost_only_used_staff' implies that only cadre-level combinations which have been used in the run are costed") + "The input for 'cost_only_used_staff' implies that only cadre-level combinations which have been used in " + "the run are costed") staff_size_chosen_for_costing = used_staff_count_by_level_and_officer_type else: print( - "The input for 'cost_only_used_staff' implies that all staff are costed regardless of the cadre-level combinations which have been used in the run are costed") + "The input for 'cost_only_used_staff' implies that all staff are costed regardless of the cadre-level " + "combinations which have been used in the run are costed") staff_size_chosen_for_costing = available_staff_count_by_level_and_officer_type # Calculate various components of HR cost @@ -607,7 +715,8 @@ def calculate_npv_past_training_expenses_by_row(row, r=_discount_rate): if partial_year > 0: npv += annual_cost * partial_year * (1 + r) ** (1 + r) - # Add recruitment cost assuming this happens during the partial year or the year after graduation if partial year == 0 + # Add recruitment cost assuming this happens during the partial year or the year after graduation if + # partial year == 0 npv += row['recruitment_cost_per_person_recruited_usd'] * (1 + r) return npv @@ -619,34 +728,33 @@ def calculate_npv_past_training_expenses_by_row(row, r=_discount_rate): npv_values.append(npv) preservice_training_cost['npv_of_training_and_recruitment_cost'] = npv_values - preservice_training_cost['npv_of_training_and_recruitment_cost_per_recruit'] = preservice_training_cost[ - 'npv_of_training_and_recruitment_cost'] * \ - (1 / (preservice_training_cost[ - 'absorption_rate_of_students_into_public_workforce'] + - preservice_training_cost[ - 'proportion_of_workforce_recruited_from_abroad'])) * \ - (1 / preservice_training_cost[ - 'graduation_rate']) * (1 / - preservice_training_cost[ - 'licensure_exam_passing_rate']) - if _discount_rate == 0: # if the 
discount rate is 0, then the pre-service + recruitment cost simply needs to be divided by the number of years in tenure + preservice_training_cost['npv_of_training_and_recruitment_cost_per_recruit'] \ + = (preservice_training_cost['npv_of_training_and_recruitment_cost'] * + (1 / (preservice_training_cost['absorption_rate_of_students_into_public_workforce'] + + preservice_training_cost['proportion_of_workforce_recruited_from_abroad'])) * + (1 / preservice_training_cost['graduation_rate']) * + (1 /preservice_training_cost['licensure_exam_passing_rate'])) + if _discount_rate == 0: # if the discount rate is 0, then the pre-service + recruitment cost simply + # needs to be divided by the number of years in tenure preservice_training_cost['annuitisation_rate'] = preservice_training_cost[ 'average_length_of_tenure_in_the_public_sector'] else: preservice_training_cost['annuitisation_rate'] = 1 + (1 - (1 + annuitization_rate) ** ( -preservice_training_cost[ 'average_length_of_tenure_in_the_public_sector'] + 1)) / annuitization_rate - preservice_training_cost['annuitised_training_and_recruitment_cost_per_recruit'] = preservice_training_cost[ - 'npv_of_training_and_recruitment_cost_per_recruit'] / \ - preservice_training_cost[ - 'annuitisation_rate'] - - # Cost per student trained * 1/Rate of absorption from the local and foreign graduates * 1/Graduation rate * attrition rate - # the inverse of attrition rate is the average expected tenure; and the preservice training cost needs to be divided by the average tenure + preservice_training_cost['annuitised_training_and_recruitment_cost_per_recruit'] = \ + (preservice_training_cost['npv_of_training_and_recruitment_cost_per_recruit'] / + preservice_training_cost['annuitisation_rate']) + + # Cost per student trained * 1/Rate of absorption from the local and foreign graduates + # * 1/Graduation rate * attrition rate + # the inverse of attrition rate is the average expected tenure; and the preservice training cost needs to + # be divided by the average tenure preservice_training_cost['cost'] = preservice_training_cost[ 'annuitised_training_and_recruitment_cost_per_recruit'] * \ preservice_training_cost['staff_count'] * preservice_training_cost[ - 'annual_attrition_rate'] # not multiplied with attrition rate again because this is already factored into 'Annual_cost_per_staff_recruited' + 'annual_attrition_rate'] # not multiplied with attrition rate again + # because this is already factored into 'Annual_cost_per_staff_recruited' preservice_training_cost = preservice_training_cost[ ['draw', 'run', 'year', 'OfficerType', 'Facility_Level', 'cost']] @@ -680,7 +788,8 @@ def label_rows_of_cost_dataframe(_df, label_var, label): # Initialize HR with the salary data if (cost_only_used_staff): human_resource_costs = retain_relevant_column_subset( - label_rows_of_cost_dataframe(salary_for_staff, 'cost_subcategory', 'salary_for_cadres_used'), 'OfficerType') + label_rows_of_cost_dataframe(salary_for_staff, 'cost_subcategory', 'salary_for_cadres_used'), + 'OfficerType') # Concatenate additional cost categories additional_costs = [ (preservice_training_cost, 'preservice_training_and_recruitment_cost_for_attrited_workers'), @@ -689,7 +798,8 @@ def label_rows_of_cost_dataframe(_df, label_var, label): ] else: human_resource_costs = retain_relevant_column_subset( - label_rows_of_cost_dataframe(salary_for_staff, 'cost_subcategory', 'salary_for_all_staff'), 'OfficerType') + label_rows_of_cost_dataframe(salary_for_staff, 'cost_subcategory', 'salary_for_all_staff'), + 'OfficerType') 
# Concatenate additional cost categories additional_costs = [ (preservice_training_cost, 'preservice_training_and_recruitment_cost_for_attrited_workers'), @@ -741,6 +851,7 @@ def get_counts_of_items_requested(_df): cons_req = extract_results( results_folder, + suspended_results_folder=suspended_results_folder, module='tlo.methods.healthsystem.summary', key='Consumables', custom_generate_series=get_counts_of_items_requested, @@ -762,7 +873,8 @@ def get_counts_of_items_requested(_df): # 2.1 Cost of consumables dispensed # --------------------------------------------------------------------------------------------------------------- # Multiply number of items needed by cost of consumable - # consumables_dispensed.columns = consumables_dispensed.columns.get_level_values(0).str() + "_" + consumables_dispensed.columns.get_level_values(1) # Flatten multi-level columns for pandas merge + # consumables_dispensed.columns = consumables_dispensed.columns.get_level_values(0).str() + "_" + + # consumables_dispensed.columns.get_level_values(1) # Flatten multi-level columns for pandas merge unit_costs['consumables'].columns = pd.MultiIndex.from_arrays( [unit_costs['consumables'].columns, [''] * len(unit_costs['consumables'].columns)]) cost_of_consumables_dispensed = consumables_dispensed.merge(unit_costs['consumables'], on=idx['Item_Code'], @@ -796,8 +908,20 @@ def get_counts_of_items_requested(_df): left_on='Item_Code', right_on='item_code', validate='m:1', how='left') + # Identify rows where excess_stock_proportion_of_dispensed is NaN + missing_excess_stock = ( + cost_of_excess_consumables_stocked + ['excess_stock_proportion_of_dispensed'] + .isna() + ) + + # Fill missing values with the average inflow-to-outflow ratio minus 1 + fill_value = average_inflow_to_outflow_ratio_ratio - 1 + cost_of_excess_consumables_stocked.loc[ - cost_of_excess_consumables_stocked.excess_stock_proportion_of_dispensed.isna(), 'excess_stock_proportion_of_dispensed'] = average_inflow_to_outflow_ratio_ratio - 1 # TODO disaggregate the average by program + missing_excess_stock, + 'excess_stock_proportion_of_dispensed' + ] = fill_value # TODO: disaggregate the average by program cost_of_excess_consumables_stocked[quantity_columns] = cost_of_excess_consumables_stocked[ quantity_columns].multiply(cost_of_excess_consumables_stocked[idx[price_column]], axis=0) cost_of_excess_consumables_stocked[quantity_columns] = cost_of_excess_consumables_stocked[ @@ -815,7 +939,8 @@ def get_counts_of_items_requested(_df): def melt_and_label_consumables_cost(_df, label): multi_index = pd.MultiIndex.from_tuples(_df.columns) _df.columns = multi_index - # Select 'Item_Code', 'year', and all columns where both levels of the MultiIndex are numeric (these are the (draw,run) columns with cost values) + # Select 'Item_Code', 'year', and all columns where both levels of the MultiIndex are numeric + # (these are the (draw,run) columns with cost values) selected_columns = [col for col in _df.columns if (col[0] in ['Item_Code', 'year']) or (isinstance(col[0], int) and isinstance(col[1], int))] _df = _df[selected_columns] # Subset the dataframe with the selected columns @@ -827,13 +952,14 @@ def melt_and_label_consumables_cost(_df, label): melted_df['consumable'] = melted_df['Item_Code'].map(consumables_dict) melted_df['cost_subcategory'] = label melted_df[ - 'Facility_Level'] = 'all' # TODO this is temporary until 'tlo.methods.healthsystem.summary' only logs consumable at the aggregate level + 'Facility_Level'] = 'all' + # TODO this is temporary until 
'tlo.methods.healthsystem.summary' only logs consumables at the aggregate level
         melted_df = melted_df.rename(columns={'value': 'cost'})
         return melted_df
 
     def disaggregate_separately_managed_medical_supplies_from_consumable_costs(_df,
                                                                                _consumables_dict,
-                                                                               # This is a dictionary mapping codes to names
+                                                                               # This is a dictionary mapping codes to names
                                                                                list_of_unique_medical_products):
         reversed_consumables_dict = {value: key for key, value in
                                      _consumables_dict.items()}  # reverse dictionary to map names to codes
@@ -849,24 +975,26 @@ def disaggregate_separately_managed_medical_supplies_from_consumable_costs(_df,
                                                             columns='item_code')
     separately_managed_medical_supplies = [127, 141, 161]  # Oxygen, Blood, IRS
-    cost_of_consumables_dispensed, cost_of_separately_managed_medical_supplies_dispensed = disaggregate_separately_managed_medical_supplies_from_consumable_costs(
+    cost_of_consumables_dispensed, cost_of_separately_managed_medical_supplies_dispensed = (
+        disaggregate_separately_managed_medical_supplies_from_consumable_costs(
         _df=retain_relevant_column_subset(
             melt_and_label_consumables_cost(cost_of_consumables_dispensed, 'cost_of_consumables_dispensed'),
             'consumable'),
         _consumables_dict=consumables_dict,
-        list_of_unique_medical_products=separately_managed_medical_supplies)
+        list_of_unique_medical_products=separately_managed_medical_supplies))
-    cost_of_excess_consumables_stocked, cost_of_separately_managed_medical_supplies_excess_stock = disaggregate_separately_managed_medical_supplies_from_consumable_costs(
+    cost_of_excess_consumables_stocked, cost_of_separately_managed_medical_supplies_excess_stock = (
+        disaggregate_separately_managed_medical_supplies_from_consumable_costs(
         _df=retain_relevant_column_subset(
             melt_and_label_consumables_cost(cost_of_excess_consumables_stocked,
                                             'cost_of_excess_consumables_stocked'),
             'consumable'),
         _consumables_dict=consumables_dict,
-        list_of_unique_medical_products=separately_managed_medical_supplies)
+        list_of_unique_medical_products=separately_managed_medical_supplies))
     consumable_costs = pd.concat([cost_of_consumables_dispensed, cost_of_excess_consumables_stocked])
 
     # 2.4 Supply chain costs
     # ---------------------------------------------------------------------------------------------------------------
-    # Assume that the cost of procurement, warehousing and distribution is a fixed proportion of consumable purchase costs
+    # Assume that the cost of procurement, warehousing and distribution is a fixed proportion of consumable purchase
+    # costs
     # The fixed proportion is based on Resource Mapping Expenditure data from 2018
     resource_mapping_data = unit_costs['actual_expenditure_data']
     # Make sure values are numeric
@@ -922,7 +1050,8 @@ def disaggregate_separately_managed_medical_supplies_from_consumable_costs(_df,
     # --------------------------------------------
     print("Now estimating Medical equipment costs...")
 
-    # Total cost of equipment required as per SEL (HSSP-III) only at facility IDs where it has been used in the simulation
+    # Total cost of equipment required as per SEL (HSSP-III) only at facility IDs where it has been used in the
+    # simulation
     # Get list of equipment used in the simulation by district and level
     def get_equipment_used_by_district_and_facility(_df: pd.Series) -> pd.Series:
         """Summarise the parsed logged-key results for one draw (as dataframe) into a pd.Series."""
@@ -987,7 +1116,8 @@ def get_equipment_used_by_district_and_facility(_df: pd.Series) -> pd.Series:
                                        on=['District', 'Facility_Level'], how='left')
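+    # District-level combinations that are absent from the facility counts come out of the left merge above with
+    # Facility_Count = NaN; the line below treats them as having zero facilities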
equipment_df.loc[equipment_df.Facility_Count.isna(), 'Facility_Count'] = 0 - # Because levels 1b and 2 are collapsed together, we assume that the same equipment is used by level 1b as that recorded for level 2 + # Because levels 1b and 2 are collapsed together, we assume that the same equipment is used by level 1b as + # that recorded for level 2 def update_itemuse_for_level1b_using_level2_data(_df): # Create a list of District and Item_code combinations for which use == True list_of_equipment_used_at_level2 = \ @@ -1043,7 +1173,8 @@ def update_itemuse_for_level1b_using_level2_data(_df): # Assume that the annual costs are constant each year of the simulation equipment_costs = pd.concat([equipment_costs.assign(year=year) for year in years]) - # TODO If the logger is updated to include year, we may wish to calculate equipment costs by year - currently we assume the same annuitised equipment cost each year + # TODO If the logger is updated to include year, we may wish to calculate equipment costs by year + # (currently we assume the same annuitised equipment cost each year) equipment_costs = equipment_costs.reset_index(drop=True) equipment_costs = equipment_costs.rename(columns={'Equipment_tlo': 'Equipment'}) equipment_costs = prepare_cost_dataframe(equipment_costs, _category_specific_group='Equipment', @@ -1137,8 +1268,8 @@ def update_itemuse_for_level1b_using_level2_data(_df): # Define a function to summarize cost data from -# Note that the dataframe needs to have draw as index and run as columns. if the dataframe is long with draw and run as index, then -# first unstack the dataframe and subsequently apply the summarize function +# Note that the dataframe needs to have draw as index and run as columns. if the dataframe is long with draw and run as +# index, then first unstack the dataframe and subsequently apply the summarize function def summarize_cost_data(_df, _metric: Literal['mean', 'median'] = 'mean') -> pd.DataFrame: """ @@ -1194,8 +1325,8 @@ def estimate_projected_health_spending(resourcefilepath: Path, """ Estimate total projected health spending for a simulation period. - Combines health spending per capita projections (Dieleman et al, 2019) with simulated population estimates to calculate - total health expenditure, optionally applying a discount rate and summarizing across runs. + Combines health spending per capita projections (Dieleman et al, 2019) with simulated population estimates to + calculate total health expenditure, optionally applying a discount rate and summarizing across runs. 
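+    (For each year of the period, total spending is the per-capita spending projection multiplied by the simulated
+    population for that year; the discount rate, if provided, is then applied to these yearly totals.)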
    Parameters:
    ----------
@@ -1424,7 +1555,8 @@ def do_stacked_bar_plot_of_cost_by_category(_df: pd.DataFrame,
     if (_disaggregate_by_subgroup is True):
         for name, df in dfs.items():
             dfs[name] = df.copy()  # Choose the dataframe to modify
-            # If sub-groups are more than 10 in number, then disaggregate the top 10 and group the rest into an 'other' category
+            # If sub-groups are more than 10 in number, then disaggregate the top 10 and group the rest into an
+            # 'other' category
             if (len(dfs[name]['cost_subgroup'].unique()) > 10):
                 # Calculate total cost per subgroup
                 subgroup_totals = dfs[name].groupby('cost_subgroup')['cost'].sum()
@@ -1870,7 +2002,8 @@ def wrap_text(text, width=15):
     if (len(_df['cost_subgroup'].unique()) > 10):
         # Step 2: Group all other consumables into "Other"
         other_cost = _df.iloc[10:]["cost"].sum()
-        top_10 = pd.concat([top_10, pd.DataFrame([{"cost_subgroup": "Other", "cost": other_cost}])], ignore_index=True)
+        top_10 = pd.concat([top_10, pd.DataFrame([{"cost_subgroup": "Other", "cost": other_cost}])],
+                           ignore_index=True)
 
     # Prepare data for the treemap
     total_cost = top_10["cost"].sum()
@@ -2015,7 +2148,8 @@ def generate_multiple_scenarios_roi_plot(_monetary_value_of_incremental_health:
         # Initialize an empty DataFrame to store values for each 'run'
         all_run_values = pd.DataFrame()
 
-        # Create an array of implementation costs ranging from 0 to the max value of max ability to pay for the current draw
+        # Create an array of implementation costs ranging from 0 to the max value of max ability to pay for the current
+        # draw
         implementation_costs = np.linspace(0, max_ability_to_pay_for_implementation.loc[draw_index].max(), 50)
         # Add fixed values for ROI ratio calculation
         additional_costs = np.array([1_000_000_000, 3_000_000_000])
@@ -2121,7 +2255,8 @@ def generate_multiple_scenarios_roi_plot(_monetary_value_of_incremental_health:
     # Replace specific x-ticks with % of health spending values
     if _projected_health_spending:
         xtick_labels[
-            1] = f'{xticks[1]:,.0f}\n({xticks[1] / (_projected_health_spending / 1e6) :.2%} of \n projected total \n health spend)'
+            1] = (f'{xticks[1]:,.0f}\n({xticks[1] / (_projected_health_spending / 1e6) :.2%} of \n projected total '
+                  f'\n health spend)')
     for i, tick in enumerate(xticks):
         if (i != 0) & (i != 1):  # Replace for 4000
             xtick_labels[i] = f'{tick:,.0f}\n({tick / (_projected_health_spending / 1e6) :.2%})'
diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py b/src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py
new file mode 100644
index 0000000000..cd5ee3a071
--- /dev/null
+++ b/src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py
@@ -0,0 +1,413 @@
+"""Produce plots to show the impact of each set of treatments."""
+
+import warnings
+from time import perf_counter
+from pandas.errors import (
+    PerformanceWarning,
+    SettingWithCopyWarning
+)
+import argparse
+from datetime import date
+import pickle
+from pathlib import Path
+import pandas as pd
+
+
+from tlo import Date
+
+from scripts.lcoa_inputs_from_tlo_analyses.results_processing_utils import (
+    get_counts_of_appts,
+    get_counts_of_hsi_by_short_treatment_id,
+    get_num_dalys_by_cause_label,
+    get_num_deaths_by_cause_label,
+    get_parameter_names_from_scenario_file,
+    get_periods_within_target_period,
+    get_total_num_dalys_by_agegrp_and_label,
+    get_total_num_death_by_agegrp_and_label,
+    get_total_population_by_year,
+    make_get_num_dalys_by_cause_label_and_period,
+    make_get_num_deaths_by_cause_label_and_period,
+
make_get_counts_of_appts_by_period, + make_get_counts_of_hsis_by_period, + set_param_names_as_column_index_level_0, + target_period, + find_difference_extra_relative_to_comparison, + find_difference_relative_to_comparison, + get_staff_count_by_facid_and_officer_type, + get_capacity_used_by_officer_type_and_facility_level, + melt_model_output_draws_and_runs +) + +from scripts.costing.cost_estimation import ( + apply_discounting_to_cost_data, + do_line_plot_of_cost, + do_stacked_bar_plot_of_cost_by_category, + estimate_input_cost_of_scenarios, + estimate_projected_health_spending, + extract_roi_at_specific_implementation_costs, + generate_multiple_scenarios_roi_plot, + load_unit_cost_assumptions, + summarize_cost_data, + tabulate_roi_estimates, +) +from tlo.analysis.utils import ( + compute_summary_statistics, + extract_results, + get_color_short_treatment_id, + make_age_grp_lookup, + summarize, +) +# python src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-2026-02-12T120859Z figs/ --target-start=2010-01-01 --target-end=2025-12-31 +# python src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-2026-02-16T154500Z figs/ --target-start=2025-01-01 --target-end=2041-01-01 +# python src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-combined --target-start=2010-01-01 --target-end=2041-01-01 +# python src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-2026-04-01T130709Z --target-start=2010-01-01 --target-end=2041-01-01 --do-comparison=False +# python src/scripts/lcoa_inputs_from_tlo_analyses/analysis_effect_of_treatment_ids.py outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-combined outputs/generated_outputs --target-start=2010-01-01 --target-end=2041-01-01 --cost-checkpoint-profile=baseline --load-input-costs-from-checkpoint=True +PERIOD_LENGTH_YEARS_FOR_BAR_PLOTS = 1 + +EXCLUDED_HSIs = [ + "FirstAttendance_Emergency", + "FirstAttendance_NonEmergency", + "FirstAttendance_SpuriousEmergencyCare", + "Inpatient_Care" +] + +def parse_iso_date(value: str) -> Date: + parsed = date.fromisoformat(value) + return Date(parsed.year, parsed.month, parsed.day) + + +def parse_bool(value: str) -> bool: + normalized = value.strip().lower() + if normalized in {"true", "t", "1", "yes", "y"}: + return True + if normalized in {"false", "f", "0", "no", "n"}: + return False + raise argparse.ArgumentTypeError( + f"Invalid boolean value '{value}'. Use True or False." 
+ ) + + +def apply( + results_folder: Path, + output_folder: Path, + resourcefilepath: Path, + target_period_tuple: tuple[Date, Date], + do_comparison: bool = True, + cost_checkpoint_profile: str | None = None, + load_input_costs_from_checkpoint: bool | None = None, +): + """Process results to produce objects needed for LCOA analysis.""" + _, age_grp_lookup = make_age_grp_lookup() + + # Extract districts and facility levels from the Master Facility List + mfl = pd.read_csv(resourcefilepath / "healthsystem" / "organisation" / "ResourceFile_Master_Facilities_List.csv") + facility_id_levels_dict = dict(zip(mfl['Facility_ID'], mfl['Facility_Level'])) + + param_names = get_parameter_names_from_scenario_file() + get_num_deaths_by_cause_label_and_period = make_get_num_deaths_by_cause_label_and_period( + PERIOD_LENGTH_YEARS_FOR_BAR_PLOTS, + target_period_tuple, + ) + get_num_dalys_by_cause_label_and_period = make_get_num_dalys_by_cause_label_and_period( + PERIOD_LENGTH_YEARS_FOR_BAR_PLOTS, + target_period_tuple, + ) + get_num_hsi_by_period = make_get_counts_of_hsis_by_period( + PERIOD_LENGTH_YEARS_FOR_BAR_PLOTS, + target_period_tuple=target_period_tuple, + ) + results = {} + # Costs calculation + print("Calculating costs...") + discount_rate_cost = 0.03 + # Period relevant for costing + TARGET_PERIOD = (Date(2026, 1, 1), Date(2040, 12, 31)) # This is the period that is costed + relevant_period_for_costing = [i.year for i in TARGET_PERIOD] + list_of_relevant_years_for_costing = list(range(relevant_period_for_costing[0], relevant_period_for_costing[1] + 1)) + print("List of relevant years for costing:", list_of_relevant_years_for_costing) + checkpoint_path = None + if cost_checkpoint_profile is not None: + checkpoint_path = output_folder / "checkpoints" / f"input_costs_{cost_checkpoint_profile}.pkl" + + if checkpoint_path is not None and load_input_costs_from_checkpoint is True: + print(f"Loading input costs from checkpoint: {checkpoint_path}") + if not checkpoint_path.exists(): + raise FileNotFoundError( + f"Input-cost checkpoint not found at {checkpoint_path}. " + "Run once with --cost-checkpoint-profile and without " + "--load-input-costs-from-checkpoint to create it." + ) + with open(checkpoint_path, "rb") as f: + input_costs = pickle.load(f) + else: + if checkpoint_path is None: + print("No cost checkpoint profile provided. Recomputing input costs.") + else: + print("Recomputing input costs") + start = perf_counter() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=PerformanceWarning) + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings("ignore", category=SettingWithCopyWarning) + input_costs = estimate_input_cost_of_scenarios( + results_folder, + resourcefilepath, + _years=list_of_relevant_years_for_costing, + cost_only_used_staff=True, + _discount_rate=discount_rate_cost, + _metric="median",) + + elapsed = perf_counter() - start + print(f"\n=== TIMING: estimate_input_cost_of_scenarios took {elapsed:.3f}s ===\n", flush=True) + if checkpoint_path is not None: + checkpoint_path.parent.mkdir(parents=True, exist_ok=True) + with open(checkpoint_path, "wb") as f: + pickle.dump(input_costs, f) + print(f"Saved input costs checkpoint to: {checkpoint_path}") + results['input_costs'] = input_costs + + # Computing incremental costs + # TODO Check with Sakshi if these are annual costs; as everything else is annual. 
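+    # Sketch of the comparison used below (hypothetical values): find_difference_relative_to_comparison
+    # subtracts the comparison draw from every other draw, run by run. E.g. with
+    # total_input_cost = pd.Series({(0, 0): 10.0, (0, 1): 11.0, (1, 0): 12.5, (1, 1): 14.0})  # (draw, run)
+    # and comparison=0, draw 1 yields 2.5 for run 0 and 3.0 for run 1.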
+ if do_comparison: + print("Computing incremental_scenario_cost...") + start = perf_counter() + total_input_cost = input_costs.groupby(['draw', 'run'])['cost'].sum() + incremental_scenario_cost = (pd.DataFrame( + find_difference_relative_to_comparison( + total_input_cost, + comparison=0,) + )) + + elapsed = perf_counter() - start + print(f"\n=== TIMING: computing incremental_scenario_cost took {elapsed:.3f}s ===\n", flush=True) + + incremental_scenario_cost = ( + incremental_scenario_cost.T.reorder_levels(["draw", "run"], axis=1).sort_index(axis=1) + ).pipe(set_param_names_as_column_index_level_0, param_names) + + incremental_scenario_cost_summarized = compute_summary_statistics(incremental_scenario_cost, 'median').iloc[0].unstack() + + # Get total population by year + print("Extracting population data...") + total_population_by_year = ( + extract_results( + results_folder, + module='tlo.methods.demography', + key='population', + custom_generate_series=lambda _df: get_total_population_by_year(_df, target_period_tuple), + do_scaling=True, + autodiscover=True + ).pipe(set_param_names_as_column_index_level_0, param_names=param_names) + ) + + total_population_by_year = compute_summary_statistics(total_population_by_year, central_measure='median') + results['total_population_by_year'] = total_population_by_year + + counts_of_hsi_by_short_treatment_id = ( + extract_results( + results_folder, + module="tlo.methods.healthsystem.summary", + key="HSI_Event", + custom_generate_series=lambda _df: get_counts_of_hsi_by_short_treatment_id(_df, target_period_tuple), + do_scaling=True, + autodiscover=True, + ) + .pipe(set_param_names_as_column_index_level_0, param_names=param_names) + .fillna(0.0) + .sort_index() + ).drop(EXCLUDED_HSIs, errors='ignore') + + counts_of_hsi_by_short_treatment_id = ( + compute_summary_statistics(counts_of_hsi_by_short_treatment_id, 'median') + ) + + results['counts_of_hsi_by_short_treatment_id'] = counts_of_hsi_by_short_treatment_id + + counts_of_hsi_by_period = ( + extract_results( + results_folder, + module="tlo.methods.healthsystem.summary", + key="HSI_Event", + custom_generate_series=lambda _df: get_num_hsi_by_period(_df), + do_scaling=True, + autodiscover=True, + ) + .pipe(set_param_names_as_column_index_level_0, param_names=param_names) + .fillna(0.0) + .sort_index() + ).drop(EXCLUDED_HSIs, level=0, errors='ignore') + + counts_of_hsi_by_period = ( + compute_summary_statistics(counts_of_hsi_by_period, 'median') + ) + results['counts_of_hsi_by_period'] = counts_of_hsi_by_period + + print("Extracting total deaths and DALYs by label...") + num_deaths = ( + extract_results( + results_folder, + module="tlo.methods.demography", + key="death", + custom_generate_series=get_num_deaths_by_cause_label_and_period, + do_scaling=True, + autodiscover=True, + ).pipe(set_param_names_as_column_index_level_0, param_names=param_names) + ) + + if do_comparison: + num_deaths_averted = compute_summary_statistics( + -1.0 * pd.DataFrame( + find_difference_extra_relative_to_comparison(num_deaths.sum(), comparison='Nothing')).T, + central_measure='median' + ).iloc[0].unstack() + + pc_deaths_averted = 100.0 * compute_summary_statistics( + -1.0 * pd.DataFrame( + find_difference_extra_relative_to_comparison(num_deaths.sum(), comparison='Nothing', scaled=True)).T, + central_measure='median' + ).iloc[0].unstack() + else: + num_deaths_averted = None + pc_deaths_averted = None + + num_deaths = compute_summary_statistics(num_deaths, central_measure='median') + + results['num_deaths'] = num_deaths + 
results['num_deaths_averted'] = num_deaths_averted
+    results['pc_deaths_averted'] = pc_deaths_averted
+
+    dalys = (
+        extract_results(
+            results_folder,
+            module="tlo.methods.healthburden",
+            key="dalys_stacked_by_age_and_time",
+            custom_generate_series=get_num_dalys_by_cause_label_and_period,
+            do_scaling=True,
+            autodiscover=True,
+        ).pipe(set_param_names_as_column_index_level_0, param_names=param_names)
+    )
+
+    if do_comparison:
+        dalys_averted = (
+            -1.0 * pd.DataFrame(
+                find_difference_extra_relative_to_comparison(dalys.sum(), comparison='Nothing'))
+        )
+
+        pc_dalys_averted = 100.0 * compute_summary_statistics(
+            -1.0 * pd.DataFrame(
+                find_difference_extra_relative_to_comparison(dalys.sum(), comparison='Nothing', scaled=True)).T,
+            central_measure='median'
+        ).iloc[0].unstack()
+        # Run-by-run incremental cost-effectiveness ratio calculation
+        icers = incremental_scenario_cost.T / dalys_averted
+        icers_summarized = compute_summary_statistics(icers.T, central_measure='median').iloc[0].unstack()
+        dalys_averted = compute_summary_statistics(dalys_averted.T, central_measure='median').iloc[0].unstack()
+
+    dalys = compute_summary_statistics(dalys, central_measure='median')
+
+
+    # This gives us the capacity used for each cadre and level, for each draw and run
+    # From this we will extract the run-wise delta in capacity used relative to the Nothing scenario, for each cadre
+    # and summarise. However, since no HSIs are delivered in the Nothing scenario, the capacity used in that scenario is zero,
+    # so the delta relative to Nothing is just the capacity used in each scenario.
+    # TODO: Check if this should be scaled with population or used as is.
+    annual_capacity_used_by_cadre_and_level = extract_results(
+        results_folder,
+        module='tlo.methods.healthsystem.summary',
+        key='Capacity_By_FacID_and_Officer',
+        custom_generate_series=lambda df: get_capacity_used_by_officer_type_and_facility_level(df, facility_id_levels_dict),
+        do_scaling=True,
+        autodiscover=True,
+    )
+    # Sum across all facility levels and average across years; so we get the *average* annual capacity used over the whole period
+    # TODO: Check with Sakshi if this is what we want.
+    mask = annual_capacity_used_by_cadre_and_level.index.get_level_values(0).isin(range(2026, 2040))
+    capacity_used_by_cadre = (
+        annual_capacity_used_by_cadre_and_level[mask].groupby(['OfficerType', 'year']).
+        sum().
+        groupby(['OfficerType']).
+        mean().
+        pipe(set_param_names_as_column_index_level_0, param_names=param_names)
+    )
+
+    capacity_used_by_cadre = (
+        compute_summary_statistics(capacity_used_by_cadre, central_measure='median')
+    )
+
+    # Get the total available capacity by cadre needed for LCOA
+    # resources/healthsystem/human_resources/actual/ResourceFile_Daily_Capabilities.csv
+    daily_capacity_by_cadre_and_level = (
+        pd.read_csv(resourcefilepath / "healthsystem" / "human_resources" / "actual" / "ResourceFile_Daily_Capabilities.csv")
+    )
+    # This gives the total minutes available per day by cadre and facility level.
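+    # (The next step sums across levels and annualises: hypothetically, a cadre with 480,000 total minutes
+    # per day across all levels has 480,000 * 365 = 175,200,000 minutes of annual capacity.)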
+ # Sum across levels to get cadre specific constraints, and multiply by 365 to get annual capacity + annual_capacity_by_cadre = ( + daily_capacity_by_cadre_and_level.groupby('Officer_Category')['Total_Mins_Per_Day'].sum() * 365 + ) + + staff_count_by_cadre = ( + daily_capacity_by_cadre_and_level.groupby('Officer_Category')['Staff_Count'].sum() + ) + + # Add consumables budget to this dictionary so that we have everything in one place + # USD 225,602,946 (203136642 from donors + 22466304 from the government) + # Revision of Malawi’s Health Benefits Package: A Critical Analysis of Policy Formulation and Implementation + # https://doi.org/10.1016/j.vhri.2023.10.007 + results['annual_consumables_budget'] = 225602946 + + results['dalys'] = dalys + results['dalys_averted'] = dalys_averted if do_comparison else None + results['pc_dalys_averted'] = pc_dalys_averted if do_comparison else None + results['icers_summarized'] = icers_summarized if do_comparison else None + results['incremental_scenario_cost'] = incremental_scenario_cost_summarized if do_comparison else None + results['capacity_used_by_cadre'] = capacity_used_by_cadre + results['annual_capacity_by_cadre'] = annual_capacity_by_cadre + results['staff_count_by_cadre'] = staff_count_by_cadre + + return results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("results_folder", type=Path) + parser.add_argument("output_folder", type=Path, nargs="?", default=None) + parser.add_argument("--target-start", type=str, default=None) + parser.add_argument("--target-end", type=str, default=None) + parser.add_argument("--do-comparison", type=parse_bool, default=True) + parser.add_argument("--cost-checkpoint-profile", type=str, default=None) + parser.add_argument("--load-input-costs-from-checkpoint", type=parse_bool, default=None) + args = parser.parse_args() + + if (args.target_start is None) != (args.target_end is None): + parser.error("Provide both --target-start and --target-end, or neither.") + + target_period_tuple = ( + parse_iso_date(args.target_start), + parse_iso_date(args.target_end), + ) + if not target_period_tuple[0] < target_period_tuple[1]: + parser.error("--target-start must be earlier than --target-end.") + if args.load_input_costs_from_checkpoint is not None and args.cost_checkpoint_profile is None: + parser.error( + "Provide --cost-checkpoint-profile when using --load-input-costs-from-checkpoint." + ) + + out = args.output_folder if args.output_folder is not None else args.results_folder + results = apply( + results_folder=args.results_folder, + output_folder=out, + resourcefilepath=Path("./resources"), + target_period_tuple=target_period_tuple, + do_comparison=args.do_comparison, + cost_checkpoint_profile=args.cost_checkpoint_profile, + load_input_costs_from_checkpoint=args.load_input_costs_from_checkpoint, + ) + outfile = ( + f"{target_period_tuple[1].year:04d}-{target_period_tuple[1].month:02d}-{target_period_tuple[1].day:02d}" + "_fullresults.pkl" + ) + with open(out / outfile, 'wb') as f: + pickle.dump(results, f) + + print(f"Analysis complete! 
Results saved to {out / outfile}") diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/combine_suspended_and_resumed_pickles.py b/src/scripts/lcoa_inputs_from_tlo_analyses/combine_suspended_and_resumed_pickles.py new file mode 100644 index 0000000000..530f449e64 --- /dev/null +++ b/src/scripts/lcoa_inputs_from_tlo_analyses/combine_suspended_and_resumed_pickles.py @@ -0,0 +1,132 @@ +"""CLI helper to combine suspended and resumed pickle outputs.""" + +# python src/scripts/lcoa_inputs_from_tlo_analyses/combine_suspended_and_resumed_pickles.py --suspended_results_folder outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-2026-02-12T120859Z --resumed_results_folder outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-2026-02-16T154500Z_folder --output_folder outputs/s.bhatia@imperial.ac.uk/effect_of_each_treatment_id-combined + + +import argparse +import pickle +import warnings +from pathlib import Path +from typing import Any + +import pandas as pd + +def _validate_input_output_paths( + suspended_results_folder: Path, + resumed_results_folder: Path, + output_folder: Path, +) -> None: + """Validate input/output path constraints for pickle combination helper.""" + suspended_resolved = suspended_results_folder.resolve() + resumed_resolved = resumed_results_folder.resolve() + output_resolved = output_folder.resolve() + + if output_resolved == suspended_resolved or output_resolved == resumed_resolved: + raise ValueError( + "output_folder must be different from both suspended_results_folder and resumed_results_folder." + ) + +def _combine_pickled_objects(suspended_obj: Any, resumed_obj: Any, context: str = "root") -> Any: + """Combine suspended and resumed objects with suspended object first.""" + if suspended_obj is None and resumed_obj is None: + return None + if isinstance(suspended_obj, dict) and isinstance(resumed_obj, dict): + combined = {} + for key, suspended_value in suspended_obj.items(): + if key in resumed_obj: + combined[key] = _combine_pickled_objects( + suspended_value, resumed_obj[key], context=f"{context}.{key}" + ) + else: + combined[key] = suspended_value + for key, resumed_value in resumed_obj.items(): + if key not in combined: + combined[key] = resumed_value + return combined + if isinstance(suspended_obj, pd.DataFrame) and isinstance(resumed_obj, pd.DataFrame): + return pd.concat([suspended_obj, resumed_obj], axis=0) + if isinstance(suspended_obj, pd.Series) and isinstance(resumed_obj, pd.Series): + return pd.concat([suspended_obj, resumed_obj], axis=0) + if isinstance(suspended_obj, list) and isinstance(resumed_obj, list): + return suspended_obj + resumed_obj + if isinstance(suspended_obj, tuple) and isinstance(resumed_obj, tuple): + return suspended_obj + resumed_obj + try: + return suspended_obj + resumed_obj + except TypeError as exc: + raise TypeError( + f"Unsupported combine operation at {context}: " + f"{type(suspended_obj).__name__} and {type(resumed_obj).__name__}." 
+ ) from exc + + +def combine_suspended_and_resumed_pickles( + suspended_results_folder: Path, + resumed_results_folder: Path, + output_folder: Path, +) -> None: + """Combine corresponding suspended and resumed pickles into output folder.""" + _validate_input_output_paths(suspended_results_folder, resumed_results_folder, output_folder) + + draw_dirs = sorted([p for p in resumed_results_folder.iterdir() if p.is_dir()], key=lambda p: p.name) + for draw_dir in draw_dirs: + print(f"Processing draw directory: {draw_dir}...") + run_dirs = sorted([p for p in draw_dir.iterdir() if p.is_dir()], key=lambda p: p.name) + for run_dir in run_dirs: + print(f" Processing run directory: {run_dir}...") + pickles = sorted(run_dir.glob("*.pickle"), key=lambda p: p.name) + for resumed_pickle_path in pickles: + print(f" Processing pickle file: {resumed_pickle_path}...") + with resumed_pickle_path.open("rb") as resumed_file: + resumed_obj = pickle.load(resumed_file) + + suspended_pickle_path = ( + suspended_results_folder / "0" / run_dir.name / resumed_pickle_path.name + ) + if suspended_pickle_path.exists(): + with suspended_pickle_path.open("rb") as suspended_file: + suspended_obj = pickle.load(suspended_file) + try: + combined_obj = _combine_pickled_objects(suspended_obj, resumed_obj) + except TypeError as exc: + raise TypeError( + "Could not combine pickled objects for " + f"{resumed_pickle_path} with types " + f"{type(suspended_obj).__name__} and {type(resumed_obj).__name__}." + ) from exc + else: + warnings.warn( + "No suspended counterpart found for " + f"{resumed_pickle_path} (expected at {suspended_pickle_path}); " + "copying resumed object to output unchanged.", + stacklevel=2, + ) + combined_obj = resumed_obj + + output_pickle_path = output_folder / draw_dir.name / run_dir.name / resumed_pickle_path.name + output_pickle_path.parent.mkdir(parents=True, exist_ok=True) + with output_pickle_path.open("wb") as output_file: + pickle.dump(combined_obj, output_file) + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Combine suspended and resumed pickle outputs into a new output folder, " + "with suspended content prepended where counterparts exist." 
+ ) + ) + parser.add_argument("suspended_results_folder", type=Path) + parser.add_argument("resumed_results_folder", type=Path) + parser.add_argument("output_folder", type=Path) + args = parser.parse_args() + + combine_suspended_and_resumed_pickles( + suspended_results_folder=args.suspended_results_folder, + resumed_results_folder=args.resumed_results_folder, + output_folder=args.output_folder, + ) + + +if __name__ == "__main__": + main() diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/fig_utils.py b/src/scripts/lcoa_inputs_from_tlo_analyses/fig_utils.py new file mode 100644 index 0000000000..46f550e904 --- /dev/null +++ b/src/scripts/lcoa_inputs_from_tlo_analyses/fig_utils.py @@ -0,0 +1,712 @@ +"""Plotting utilities for treatment-id analysis scripts.""" + +import textwrap +import warnings + +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt +from matplotlib.lines import Line2D +from matplotlib.patches import Patch + +from tlo.analysis.utils import ( + CAUSE_OF_DEATH_OR_DALY_LABEL_TO_COLOR_MAP, + get_color_cause_of_death_or_daly_label, + get_color_short_treatment_id, + make_calendar_period_type, + order_of_cause_of_death_or_daly_label, + order_of_short_treatment_ids, +) + + +APPOINTMENT_TYPE_PALETTE = list(plt.get_cmap("tab20").colors) + list(plt.get_cmap("Set2").colors) +APPOINTMENT_TYPE_FIXED_COLORS = {"AccidentsandEmerg": "black"} + +def make_graph_file_name(stub): + filename = stub.replace('*', '_star_').replace(' ', '_').replace('/', '').lower() + return f"{filename}.png" + + +def get_color_by_appointment_type(appointment_types) -> dict: + """Return a deterministic color map for appointment types.""" + non_fixed_appointment_types = sorted( + appt for appt in appointment_types if appt not in APPOINTMENT_TYPE_FIXED_COLORS + ) + color_by_appointment_type = { + appt: APPOINTMENT_TYPE_PALETTE[i % len(APPOINTMENT_TYPE_PALETTE)] + for i, appt in enumerate(non_fixed_appointment_types) + } + color_by_appointment_type.update( + {appt: color for appt, color in APPOINTMENT_TYPE_FIXED_COLORS.items() if appt in appointment_types} + ) + return color_by_appointment_type + + +def _get_short_treatment_id_and_color(treatment_id: str) -> tuple[str, str]: + """Return short treatment id prefix and plotting color for a treatment id.""" + short_treatment_id = str(treatment_id).split("_")[0] + color = get_color_short_treatment_id(short_treatment_id) + return short_treatment_id, ("grey" if pd.isna(color) else color) + + +def _get_ordered_short_treatment_ids(treatment_ids: pd.Index) -> list[str]: + """Return treatment ids with recognized short ids first in standard order.""" + treatment_ids = pd.Index(treatment_ids).unique() + recognized = [treatment_id for treatment_id in treatment_ids if not pd.isna(get_color_short_treatment_id(treatment_id))] + unrecognized = sorted(str(treatment_id) for treatment_id in treatment_ids if pd.isna(get_color_short_treatment_id(treatment_id))) + recognized = sorted(recognized, key=order_of_short_treatment_ids) + return recognized + unrecognized + + +def _parse_period_label(period_label: str) -> tuple[int, int]: + """Parse a period label of the form YYYY-YYYY into start/end years.""" + start_year_text, end_year_text = str(period_label).split("-", maxsplit=1) + return int(start_year_text), int(end_year_text) + + +def _get_sorted_period_labels_and_display_labels(period_labels: list[str]) -> tuple[list[str], list[str]]: + """Return chronological labels plus display labels, falling back to input order if parsing fails.""" + try: + parsed_periods = 
[(label, _parse_period_label(label)) for label in period_labels] + except (TypeError, ValueError): + return period_labels, period_labels + + ordered_period_labels = [ + label for label, _ in sorted(parsed_periods, key=lambda item: (item[1][0], item[1][1])) + ] + display_labels = [ + str(start_year) if start_year == end_year else label + for label, (start_year, end_year) in sorted(parsed_periods, key=lambda item: (item[1][0], item[1][1])) + ] + return ordered_period_labels, display_labels + + +def _compute_sanitized_asymmetric_errors( + _df: pd.DataFrame, + central_col: str = "central", + lower_col: str = "lower", + upper_col: str = "upper", +) -> tuple[np.ndarray, list]: + """Return non-negative asymmetric errors and labels whose CI bounds were auto-corrected.""" + required_columns = {central_col, lower_col, upper_col} + missing_columns = required_columns.difference(set(_df.columns)) + if missing_columns: + raise ValueError(f"Missing required CI column(s): {sorted(missing_columns)}") + + ci = _df.loc[:, [central_col, lower_col, upper_col]].copy() + ci.columns = ["central", "lower", "upper"] + + swapped_bounds = ci["lower"] > ci["upper"] + if swapped_bounds.any(): + swapped = ci.loc[swapped_bounds, ["lower", "upper"]].copy() + ci.loc[swapped_bounds, "lower"] = swapped["upper"].to_numpy() + ci.loc[swapped_bounds, "upper"] = swapped["lower"].to_numpy() + + central_below_lower = ci["central"] < ci["lower"] + central_above_upper = ci["central"] > ci["upper"] + + lower_error = ci["central"] - ci["lower"] + upper_error = ci["upper"] - ci["central"] + lower_error = lower_error.where(~central_below_lower, 0.0).clip(lower=0.0) + upper_error = upper_error.where(~central_above_upper, 0.0).clip(lower=0.0) + + corrected_rows = swapped_bounds | central_below_lower | central_above_upper + errors = np.vstack([lower_error.to_numpy(dtype=float), upper_error.to_numpy(dtype=float)]) + return errors, list(ci.index[corrected_rows]) + + +def _warn_if_ci_corrected(plot_function_name: str, corrected_labels: list, max_examples: int = 5) -> None: + """Emit one warning with sample labels when CI bounds required correction.""" + unique_labels = list(dict.fromkeys(corrected_labels)) + if not unique_labels: + return + + sample = ", ".join(str(label) for label in unique_labels[:max_examples]) + sample_suffix = "..." if len(unique_labels) > max_examples else "" + warnings.warn( + f"{plot_function_name}: auto-corrected inconsistent CI values for {len(unique_labels)} row(s). " + f"Sample labels: {sample}{sample_suffix}", + stacklevel=2, + ) + + +def plot_deaths_by_period_for_cause( + _df: pd.DataFrame, + cause_label: str, + plot_stat: str = "central", +): + """Plot deaths over time for a single cause, with one line per short treatment id.""" + if not isinstance(_df.index, pd.MultiIndex) or _df.index.nlevels != 2: + raise ValueError("_df index must be a 2-level MultiIndex with levels for label and period.") + if not isinstance(_df.columns, pd.MultiIndex) or _df.columns.nlevels != 2: + raise ValueError("_df columns must be a 2-level MultiIndex with levels for treatment id and stat.") + + label_level_name = "label" if "label" in _df.index.names else _df.index.names[0] + period_level_name = "period" if "period" in _df.index.names else _df.index.names[1] + stat_level_name = "stat" if "stat" in _df.columns.names else _df.columns.names[1] + + available_causes = pd.Index(_df.index.get_level_values(label_level_name).unique()) + if cause_label not in available_causes: + raise ValueError(f"Cause label '{cause_label}' not found. 
Available causes: {available_causes.tolist()}") + + available_stats = pd.Index(_df.columns.get_level_values(stat_level_name).unique()) + if plot_stat not in available_stats: + raise ValueError(f"Statistic '{plot_stat}' not found. Available stats: {available_stats.tolist()}") + + _plot = _df.xs(cause_label, level=label_level_name).xs(plot_stat, axis=1, level=stat_level_name) + if _plot.empty: + raise ValueError(f"No plottable data remain for cause '{cause_label}' using stat '{plot_stat}'.") + + _plot.index.name = period_level_name + try: + ordered_period_labels, display_period_labels = _get_sorted_period_labels_and_display_labels(_plot.index) + _plot = _plot.reindex(ordered_period_labels) + except (TypeError, ValueError): + _plot = _plot.loc[pd.Index(_plot.index).drop_duplicates()] + + ordered_treatment_ids = _get_ordered_short_treatment_ids(_plot.columns) + _plot = _plot.loc[:, ordered_treatment_ids] + + fig_width = max(10, min(1.4 * len(_plot.index) + 4, 18)) + fig, ax = plt.subplots(figsize=(fig_width, 6)) + x = np.arange(len(_plot.index)) + + for treatment_id in _plot.columns: + _, color = _get_short_treatment_id_and_color(treatment_id) + ax.plot( + x, + _plot[treatment_id].to_numpy(), + marker="o", + linewidth=1.8, + markersize=4, + color=color, + label=str(treatment_id), + ) + + ax.set_xticks(x) + ax.set_xticklabels(display_period_labels, rotation=45, ha="right") + ax.set_xlabel("Period") + ax.set_ylabel("Number of deaths") + ax.set_title(str(cause_label)) + ax.grid(axis="y") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.legend( + title="Treatment ID", + loc="center left", + bbox_to_anchor=(1.02, 0.5), + fontsize=8, + title_fontsize=9, + frameon=True, + ) + return fig, ax + + +def plot_deaths_by_period_for_draw( + _df: pd.DataFrame, + draw: str, + plot_stat: str = "central", +): + """Plot deaths over time for a single draw, with one line per cause label.""" + if not isinstance(_df.index, pd.MultiIndex) or _df.index.nlevels != 2: + raise ValueError("_df index must be a 2-level MultiIndex with levels for label and period.") + if not isinstance(_df.columns, pd.MultiIndex) or _df.columns.nlevels != 2: + raise ValueError("_df columns must be a 2-level MultiIndex with levels for draw and stat.") + + label_level_name = "label" if "label" in _df.index.names else _df.index.names[0] + period_level_name = "period" if "period" in _df.index.names else _df.index.names[1] + draw_level_name = "draw" if "draw" in _df.columns.names else _df.columns.names[0] + stat_level_name = "stat" if "stat" in _df.columns.names else _df.columns.names[1] + + available_draws = pd.Index(_df.columns.get_level_values(draw_level_name).unique()) + if draw not in available_draws: + raise ValueError(f"Draw '{draw}' not found. Available draws: {available_draws.tolist()}") + available_stats = pd.Index(_df.columns.get_level_values(stat_level_name).unique()) + if plot_stat not in available_stats: + raise ValueError(f"Statistic '{plot_stat}' not found. 
Available stats: {available_stats.tolist()}") + + _plot = _df[draw].loc[:, [plot_stat]] + if _plot.empty: + raise ValueError(f"No plottable data remain for draw '{draw}' using stat '{plot_stat}'.") + + _plot = _plot[plot_stat].unstack(label_level_name) + ordered_causes = [ + cause_label for cause_label in CAUSE_OF_DEATH_OR_DALY_LABEL_TO_COLOR_MAP.keys() + if cause_label in _plot.columns + ] + unordered_causes = sorted( + cause_label for cause_label in _plot.columns if cause_label not in CAUSE_OF_DEATH_OR_DALY_LABEL_TO_COLOR_MAP + ) + _plot = _plot.loc[:, ordered_causes + unordered_causes] + + ordered_period_labels, display_period_labels = _get_sorted_period_labels_and_display_labels(_plot.index.tolist()) + _plot = _plot.reindex(ordered_period_labels) + if _plot.empty: + raise ValueError(f"No plottable data remain for draw '{draw}' after reshaping by cause.") + + fig_width = max(10, min(1.4 * len(ordered_period_labels) + 4, 18)) + fig, ax = plt.subplots(figsize=(fig_width, 6)) + + for cause_label in _plot.columns: + cause_values = _plot[cause_label] + if cause_values.notna().sum() == 0: + continue + ax.plot( + ordered_period_labels, + cause_values.to_numpy(), + marker="o", + linewidth=1.8, + markersize=4, + color=get_color_cause_of_death_or_daly_label(cause_label), + label=str(cause_label), + ) + + ax.set_xticks(ordered_period_labels) + ax.set_xticklabels(display_period_labels, rotation=45, ha="right") + ax.set_xlabel("Period") + ax.set_ylabel("Number of deaths") + ax.set_title(str(draw)) + ax.grid(axis="y") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.legend( + title="", + loc="center left", + bbox_to_anchor=(1.02, 0.5), + fontsize=8, + title_fontsize=9, + frameon=True, + ) + return fig, ax + + +def do_bar_plot_with_ci( + _df: pd.DataFrame, + _param, + _ax, + period_labels_for_bar_plots: list[str], + target_period_label: str, +): + """Make vertical bars by cause, decomposed into period chunks, with overall-period CI.""" + available_params = _df.columns.get_level_values(0) if isinstance(_df.columns, pd.MultiIndex) else _df.columns + if _param not in available_params: + warnings.warn(f"Parameter '{_param}' not found in dataframe columns. 
Skipping plot.", stacklevel=2) + return + + _df_nothing = _df[_param] + _df_nothing = _df_nothing.reindex( + pd.MultiIndex.from_product( + [CAUSE_OF_DEATH_OR_DALY_LABEL_TO_COLOR_MAP.keys(), period_labels_for_bar_plots + [target_period_label]], + names=["label", "period"], + ), + fill_value=0.0, + ) + _df_nothing = _df_nothing.sort_index(axis=0, level=0, key=order_of_cause_of_death_or_daly_label) + + cause_labels = list(_df_nothing.index.get_level_values("label").unique()) + + corrected_labels = [] + for i, cause_label in enumerate(cause_labels): + color = get_color_cause_of_death_or_daly_label(cause_label) + one_cause = _df_nothing.xs(cause_label, level="label") + + bottom = 0.0 + for j, period_label in enumerate(period_labels_for_bar_plots): + chunk_height = one_cause.loc[period_label, "central"] if period_label in one_cause.index else 0.0 + _ax.bar(i, chunk_height, bottom=bottom, color=color, alpha=0.9 if j % 2 == 0 else 0.35) + bottom += chunk_height + + mean_value = one_cause.loc[target_period_label, "central"] + ci_row = pd.DataFrame( + { + "central": [mean_value], + "lower": [one_cause.loc[target_period_label, "lower"]], + "upper": [one_cause.loc[target_period_label, "upper"]], + }, + index=pd.Index([cause_label], name="label"), + ) + overall_yerr, corrected_row_labels = _compute_sanitized_asymmetric_errors(ci_row) + corrected_labels.extend(corrected_row_labels) + _ax.errorbar(i, mean_value, yerr=overall_yerr, fmt="none", ecolor="black", capsize=2, linewidth=1.2) + + _warn_if_ci_corrected("do_bar_plot_with_ci", corrected_labels) + + _ax.set_xticks(range(len(cause_labels))) + _ax.set_xticklabels(cause_labels, rotation=90) + chunk_legend_handles = [ + Patch(facecolor="grey", alpha=0.9 if i % 2 == 0 else 0.35, label=period_label) + for i, period_label in enumerate(period_labels_for_bar_plots) + ] + ci_legend_handle = Line2D([0], [0], color="black", marker="|", markersize=8, linewidth=1.2, label="95% CI") + _ax.legend(handles=chunk_legend_handles + [ci_legend_handle], loc="upper right") + + +def plot_multiindex_dot_with_interval( + _df: pd.DataFrame, + year: int, + _ax, + central_measure: str = "central", + value_col: str = "population", + sort: bool = True, + x_label_rotation: int = 90, + x_tick_fontsize: int = 8, + label_wrap_width: int = 18, + max_xticks: int = 30, +): + """Plot central-value dots and lower/upper intervals by category for one year.""" + if not isinstance(_df.index, pd.MultiIndex) or _df.index.nlevels < 3: + raise ValueError("_df index must be a MultiIndex with at least 3 levels: category, stat, year.") + if value_col not in _df.columns: + raise ValueError(f"Column '{value_col}' not found in dataframe.") + + year_level_values = _df.index.get_level_values(2) + available_years = pd.Index(year_level_values.unique()).sort_values() + if year not in available_years: + raise ValueError(f"Year '{year}' not found in index level 2. Available years: {available_years.tolist()}") + + stat_level_values = _df.index.get_level_values(1) + required_stats = {central_measure, "lower", "upper"} + missing_stats = required_stats.difference(set(stat_level_values)) + if missing_stats: + raise ValueError( + f"Missing required stat(s) in index level 1: {sorted(missing_stats)}. 
" + f"Available stats: {sorted(set(stat_level_values))}" + ) + + _plot = _df.xs(year, level=2)[value_col].unstack(level=1) + _plot = _plot.loc[:, [central_measure, "lower", "upper"]] + _plot = _plot.dropna(subset=[central_measure, "lower", "upper"]) + if _plot.empty: + raise ValueError(f"No plottable rows remain for year '{year}' after selecting required stats.") + + if sort: + _plot = _plot.sort_values(by=central_measure, ascending=True) + + x = np.arange(len(_plot.index)) + _ax.vlines(x, _plot["lower"], _plot["upper"], color="black", linewidth=1.2) + _ax.scatter(x, _plot[central_measure], color="black", s=20, zorder=3) + + _ax.figure.set_size_inches(max(12, min(0.25 * len(_plot.index), 36)), 7) + wrapped_labels = [textwrap.fill(str(label), width=label_wrap_width) for label in _plot.index] + if max_xticks is not None and len(x) > max_xticks: + step = int(np.ceil(len(x) / max_xticks)) + shown_positions = x[::step] + shown_labels = [wrapped_labels[i] for i in shown_positions] + _ax.set_xticks(shown_positions) + _ax.set_xticklabels(shown_labels, rotation=x_label_rotation, ha="right", fontsize=x_tick_fontsize) + else: + _ax.set_xticks(x) + _ax.set_xticklabels(wrapped_labels, rotation=x_label_rotation, ha="right", fontsize=x_tick_fontsize) + _ax.set_xlabel(_df.index.names[0] if _df.index.names[0] is not None else "category") + _ax.set_ylabel(value_col) + _ax.set_title(f"{value_col}: {central_measure} with lower/upper ({year})") + _ax.grid(axis="y") + _ax.spines["top"].set_visible(False) + _ax.spines["right"].set_visible(False) + + return _ax + + +def do_barh_plot_with_ci(_df: pd.DataFrame, _ax): + """Make horizontal bar plot for each treatment id.""" + errors, corrected_labels = _compute_sanitized_asymmetric_errors(_df) + _df.plot.barh( + ax=_ax, + y="central", + xerr=errors, + legend=False, + color=[_get_short_treatment_id_and_color(_id)[1] for _id in _df.index], + ) + _warn_if_ci_corrected("do_barh_plot_with_ci", corrected_labels) + + +def do_label_barh_plot(_df: pd.DataFrame, _ax): + """Add text annotation from values in dataframe onto axis.""" + y_cords = {ylabel.get_text(): ytick for ytick, ylabel in zip(_ax.get_yticks(), _ax.get_yticklabels())} + pos_on_rhs = _ax.get_xticks()[-1] + + for label, row in _df.iterrows(): + if row["central"] > 0: + annotation = f"{round(row['central'], 1)} ({round(row['lower'])}-{round(row['upper'])}) %" + _ax.annotate( + annotation, + xy=(pos_on_rhs, y_cords.get(label)), + xycoords="data", + horizontalalignment="left", + verticalalignment="center", + size=7, + ) + +def plot_cadre_time_by_draw_stacked( + _df: pd.DataFrame, + stat: str = "central", + figsize: tuple[float, float] | None = None, +): + """Plot horizontal stacked bars of cadre time use by draw for one summary stat.""" + if not isinstance(_df.columns, pd.MultiIndex) or _df.columns.nlevels != 2: + raise ValueError("_df columns must be a 2-level MultiIndex with levels for draw and stat.") + + stat_level_name = "stat" if "stat" in _df.columns.names else _df.columns.names[1] + available_stats = pd.Index(_df.columns.get_level_values(stat_level_name).unique()) + if stat not in available_stats: + raise ValueError(f"Statistic '{stat}' not found. 
Available stats: {available_stats.tolist()}") + + _plot = _df.xs(stat, axis=1, level=stat_level_name).T.fillna(0.0) + if _plot.empty: + raise ValueError(f"No plottable data remain for stat '{stat}'.") + + _plot = _plot.loc[_plot.sum(axis=1).sort_values(ascending=True).index] + + if figsize is None: + fig_height = max(6, min(0.35 * len(_plot.index) + 3, 20)) + figsize = (12, fig_height) + fig, ax = plt.subplots(figsize=figsize) + + cadre_colors = list(plt.get_cmap("tab10").colors) + left = np.zeros(len(_plot.index), dtype=float) + y = np.arange(len(_plot.index)) + + for i, cadre in enumerate(_plot.columns): + values = _plot[cadre].to_numpy(dtype=float) + ax.barh( + y, + values, + left=left, + color=cadre_colors[i % len(cadre_colors)], + label=str(cadre), + ) + left += values + + ax.set_yticks(y) + ax.set_yticklabels([str(draw) for draw in _plot.index]) + ax.set_xlabel("Time used") + ax.set_ylabel("Draw") + ax.grid(axis="x") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.legend( + loc="lower right", + fontsize=12, + handlelength=2.4, + handleheight=1.6, + borderpad=1.0, + labelspacing=0.8, + frameon=True, + ) + fig.tight_layout() + return fig, ax + +def plot_hsi_counts_stacked_bar(_df: pd.DataFrame, plot_stat: str = "central"): + """Plot horizontal stacked bars of HSI counts by draw for a selected summary statistic.""" + if not isinstance(_df.columns, pd.MultiIndex) or _df.columns.nlevels != 2: + raise ValueError("_df columns must be a 2-level MultiIndex with levels for draw and stat.") + + stat_level_name = "stat" if "stat" in _df.columns.names else _df.columns.names[1] + stat_level_values = _df.columns.get_level_values(stat_level_name) + if plot_stat not in stat_level_values: + raise ValueError(f"The column MultiIndex does not contain '{plot_stat}' in the stat level.") + + _plot = _df.xs(plot_stat, axis=1, level=stat_level_name).T + if _plot.empty: + raise ValueError(f"No plottable data remain after selecting the '{plot_stat}' columns.") + + if _plot.isna().any().any(): + warnings.warn( + f"Missing values detected after selecting '{plot_stat}'. 
Bars will omit missing segments.",
+            stacklevel=2,
+        )
+
+    totals = _plot.sum(axis=1, skipna=True)
+    _plot = _plot.loc[totals.sort_values(ascending=False).index]
+    if not (_plot.gt(0).any(axis=1)).any():
+        raise ValueError(f"No positive values remain after selecting the '{plot_stat}' columns.")
+
+    fig_width = max(12, min(0.22 * len(_plot.columns) + 12, 30))
+    fig_height = max(6, min(0.35 * len(_plot.index), 24))
+    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
+
+    left = np.zeros(len(_plot.index), dtype=float)
+    y = np.arange(len(_plot.index))
+
+    for treatment_id in _plot.columns:
+        values = _plot[treatment_id]
+        mask = values.gt(0) & values.notna()
+        if not mask.any():
+            continue
+        ax.barh(
+            y[mask.to_numpy()],
+            values.loc[mask].to_numpy(),
+            left=left[mask.to_numpy()],
+            color=_get_short_treatment_id_and_color(treatment_id)[1],
+            label=str(treatment_id),
+        )
+        left[mask.to_numpy()] += values.loc[mask].to_numpy()
+
+    ax.set_yticks(y)
+    ax.set_yticklabels([str(label) for label in _plot.index], fontsize=12)
+    ax.invert_yaxis()
+    fig.tight_layout()
+    return fig, ax
+
+
+def plot_hsi_counts_by_period_for_draw(
+    _df: pd.DataFrame,
+    draw: str,
+):
+    """Plot central values with lower/upper intervals across period chunks for one draw."""
+    if not isinstance(_df.index, pd.MultiIndex) or _df.index.nlevels != 2:
+        raise ValueError("_df index must be a 2-level MultiIndex with levels for short_treatment_id and period.")
+    if not isinstance(_df.columns, pd.MultiIndex) or _df.columns.nlevels != 2:
+        raise ValueError("_df columns must be a 2-level MultiIndex with levels for draw and stat.")
+    if draw not in _df.columns.get_level_values(0):
+        available_draws = sorted(set(_df.columns.get_level_values(0)))
+        raise ValueError(f"Draw '{draw}' not found. Available draws: {available_draws}")
+
+    _df = _df[draw]
+    _plot = _df.reindex(
+        pd.MultiIndex.from_product(
+            [
+                _df.index.get_level_values(0).unique(),
+                _df.index.get_level_values(1).unique(),
+            ],
+            names=["treatment_id", "period"],
+        ),
+        fill_value=0.0,
+    )
+    period_labels = _plot.index.get_level_values(1).unique()
+    _plot = _plot.loc[:, ["lower", "central", "upper"]]
+    _plot = _plot.unstack("period")
+
+    central = _plot["central"]
+    lower = _plot["lower"]
+    upper = _plot["upper"]
+    periods_for_filtering = central.columns.difference(["2025-2025"], sort=False)
+    non_zero_mask = central.loc[:, periods_for_filtering].gt(0).any(axis=1)
+
+    ordered_period_labels, display_period_labels = _get_sorted_period_labels_and_display_labels(period_labels)
+    central = central.loc[non_zero_mask, ordered_period_labels]
+    lower = lower.loc[non_zero_mask, ordered_period_labels]
+    upper = upper.loc[non_zero_mask, ordered_period_labels]
+
+    if central.empty:
+        print(f"No non-zero treatment ids remain for draw '{draw}'.")
+
+    x = np.arange(len(ordered_period_labels))
+    fig_width = max(10, min(1.2 * len(ordered_period_labels) + 4, 20))
+    fig_height = max(6, min(0.28 * len(central.index) + 6, 18))
+    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
+
+    corrected_labels = []
+    for treatment_id in central.index:
+        central_values = central.loc[treatment_id].to_numpy()
+        ci_rows = pd.DataFrame(
+            {
+                "central": central.loc[treatment_id],
+                "lower": lower.loc[treatment_id],
+                "upper": upper.loc[treatment_id],
+            }
+        )
+        yerr, corrected_periods = _compute_sanitized_asymmetric_errors(ci_rows)
+        corrected_labels.extend([f"{treatment_id}:{period}" for period in corrected_periods])
+        _, color = _get_short_treatment_id_and_color(treatment_id)
+        ax.errorbar(
+            x,
central_values, + yerr=yerr, + fmt="o", + color=color, + ecolor=color, + elinewidth=1.2, + capsize=2, + markersize=4, + label=str(treatment_id), + ) + + _warn_if_ci_corrected("plot_hsi_counts_by_period_for_draw", corrected_labels) + + ax.set_xticks(x) + ax.set_xticklabels(display_period_labels, rotation=45, ha="right") + ax.set_xlabel("period") + ax.set_ylabel("HSI count") + ax.set_title(f"HSI counts by period: {draw}") + ax.grid(axis="y") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.legend( + title="Treatment ID", + loc="center left", + bbox_to_anchor=(1.02, 0.5), + fontsize=8, + title_fontsize=9, + frameon=True, + ) + + fig.tight_layout() + return fig, ax + + +def plot_population_by_year(_df: pd.DataFrame): + """Plot yearly central population values for all draws.""" + if not isinstance(_df.columns, pd.MultiIndex) or _df.columns.nlevels != 2: + raise ValueError("_df columns must be a 2-level MultiIndex with levels for draw and stat.") + + stat_level_name = "stat" if "stat" in _df.columns.names else _df.columns.names[1] + + available_stats = pd.Index(_df.columns.get_level_values(stat_level_name).unique()) + if "central" not in available_stats: + raise ValueError(f"Statistic 'central' not found. Available stats: {available_stats.tolist()}") + + implementation_central = _df.xs("central", axis=1, level=stat_level_name).copy() + implementation_central.columns = implementation_central.columns.to_series().str.replace(r"_\*$", "", regex=True) + _plot = implementation_central + + _plot = _plot.loc[:, ~_plot.columns.duplicated()] + _plot = _plot.sort_index() + + ordered_treatment_ids = _get_ordered_short_treatment_ids(_plot.columns) + _plot = _plot.loc[:, ordered_treatment_ids] + + if _plot.empty: + raise ValueError("No plottable population data remain after selecting central values.") + + years = pd.Index(_plot.index) + x = np.arange(len(years)) + fig_width = max(10, min(1.0 * len(years) + 4, 20)) + fig_height = 6 + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + for treatment_id in _plot.columns: + short_treatment_id, color = _get_short_treatment_id_and_color(treatment_id) + ax.plot( + x, + _plot[treatment_id].to_numpy(), + marker="o", + linewidth=1.8, + markersize=4, + color=color, + label=short_treatment_id, + ) + + ax.set_xticks(x) + ax.set_xticklabels([str(year) for year in years], rotation=45, ha="right") + ax.set_xlabel("Year") + ax.set_ylabel("Population size") + ax.set_title("Population size by year") + ax.grid(axis="y") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + handles, labels = ax.get_legend_handles_labels() + deduplicated_handles_by_label = dict(zip(labels, handles)) + ax.legend( + handles=list(deduplicated_handles_by_label.values()), + labels=list(deduplicated_handles_by_label.keys()), + title="Treatment ID", + loc="center left", + bbox_to_anchor=(1.02, 0.5), + fontsize=8, + title_fontsize=9, + frameon=True, + ) + + fig.tight_layout() + return fig, ax diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/figures_effect_of_treatment_ids.py b/src/scripts/lcoa_inputs_from_tlo_analyses/figures_effect_of_treatment_ids.py new file mode 100644 index 0000000000..9d87ea1357 --- /dev/null +++ b/src/scripts/lcoa_inputs_from_tlo_analyses/figures_effect_of_treatment_ids.py @@ -0,0 +1,281 @@ +import argparse +import glob +import os +import zipfile +from pathlib import Path +import pickle +import pandas as pd +import matplotlib.pyplot as plt + +from scripts.lcoa_inputs_from_tlo_analyses.results_processing_utils 
import ( + get_parameter_names_from_scenario_file, + format_scenario_name, +) +from scripts.lcoa_inputs_from_tlo_analyses.fig_utils import ( + make_graph_file_name, + do_barh_plot_with_ci, + plot_cadre_time_by_draw_stacked, + plot_deaths_by_period_for_cause, + plot_deaths_by_period_for_draw, + plot_hsi_counts_by_period_for_draw, + plot_population_by_year, +) + + +# python src/scripts/lcoa_inputs_from_tlo_analyses/figures_effect_of_treatment_ids.py outputs/generated_outputs/2041-01-01_fullresults.pkl --output_folder=figs2 + + +PERIOD_LENGTH_YEARS_FOR_BAR_PLOTS = 1 + + +def load_results_files(results_files: list[Path]) -> dict[Path, dict]: + loaded = {} + for results_file in results_files: + print(f"Loading results file: {results_file}") + with open(results_file, "rb") as f: + loaded[results_file] = pickle.load(f) + return loaded + + +def apply(results_files: list[Path], output_folder: Path, resourcefilepath: Path = None): + """Produce standard plots describing effect of each TREATMENT_ID.""" + print("Starting figure generation for treatment-ID effects.") + print(f"Output folder: {output_folder}") + + param_names = get_parameter_names_from_scenario_file() + print(f"Loaded parameter names: {len(param_names)}") + + all_results = load_results_files(results_files) + primary_results = all_results[results_files[0]] + print(f"Using primary results from: {results_files[0]}") + + num_deaths_averted = primary_results.get('num_deaths_averted') + pc_deaths_averted = primary_results.get('pc_deaths_averted') + dalys_averted = primary_results.get('dalys_averted') + pc_dalys_averted = primary_results.get('pc_dalys_averted') + icers = primary_results.get('icers_summarized') + comparison_metrics_available = all( + metric is not None + for metric in ( + num_deaths_averted, + pc_deaths_averted, + dalys_averted, + pc_dalys_averted, + icers, + ) + ) + print(f"Comparison metrics available: {comparison_metrics_available}") + + counts_of_hsi_in_implementation_period = primary_results['counts_of_hsi_by_period'] + counts_of_hsi_in_implementation_period = counts_of_hsi_in_implementation_period.drop(['2010-2041'], level=1) + capacity_used_by_cadre = primary_results.get("capacity_used_by_cadre") + + result_df_by_period = pd.DataFrame([ + {'treatment_id_included': draw, 'nonzero_hsis': treatment_id, 'period': period} + for draw in counts_of_hsi_in_implementation_period.columns.get_level_values(0).unique() + for treatment_id, period in ( + ((counts_of_hsi_in_implementation_period[draw] != 0).any(axis=1))[ + (counts_of_hsi_in_implementation_period[draw] != 0).any(axis=1) + ].index + ) + ]) + result_df_by_period['treatment_id_included'] = result_df_by_period['treatment_id_included'].str.replace( + '_\\*$', '', regex=True + ) + + for param in param_names: + if param == "Nothing": + continue + draw = format_scenario_name(param) + print(f"Plotting yearly HSI counts for draw: {draw}") + name_of_plot = f"Yearly HSI counts for {draw}" + # Since all HSIs will be delivered before the service availability switch + # retain only the treatment id of interest in this period to avoid plot + # clutter. + pre_switch_periods = ( + ['2010-2010', '2011-2011', '2012-2012', '2013-2013', + '2014-2014', '2015-2015', '2016-2016', '2017-2017', + '2018-2018', '2019-2019', '2020-2020', '2021-2021', + '2022-2022', '2023-2023', '2024-2024', '2025-2025'] + ) + mask_other_periods = ( + ~counts_of_hsi_in_implementation_period. + index. + get_level_values("period"). 
+            isin(pre_switch_periods)
+        )
+        mask_early_periods = (
+            counts_of_hsi_in_implementation_period.index.get_level_values("period").isin(pre_switch_periods) &
+            (counts_of_hsi_in_implementation_period.index.get_level_values("appt_type") == draw.replace("_*", ""))
+        )
+        plot_this = counts_of_hsi_in_implementation_period[mask_other_periods | mask_early_periods]
+        fig, ax = plot_hsi_counts_by_period_for_draw(
+            plot_this,
+            draw,
+        )
+        ax.set_title(name_of_plot)
+        outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot))
+        fig.savefig(outfile)
+        plt.close(fig)
+
+    print("Plotting capacity used by cadres across draws.")
+    fig, ax = plot_cadre_time_by_draw_stacked(capacity_used_by_cadre, stat="central")
+    name_of_plot = "Capacity Used by Cadres (2026-2040)"
+    ax.set_title(name_of_plot)
+    outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot))
+    fig.savefig(outfile)
+    plt.close(fig)
+
+    # Plot population growth
+    total_population_in_implementation = primary_results['total_population_by_year']
+    print("Plotting population size by year.")
+    fig, ax = plot_population_by_year(total_population_in_implementation / 1e6)
+    name_of_plot = "Population size by year"
+    ax.set_title(name_of_plot)
+    ax.set_ylabel("Population size (millions)")
+    outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot))
+    fig.savefig(outfile)
+    plt.close(fig)
+
+    # Plot number of deaths and DALYS by cause for each parameter, with confidence intervals, for the target period
+    num_dalys_by_cause_label_implementation = primary_results['dalys'].drop(['2010-2041'], level=1)
+
+    num_deaths_by_cause_label_implementation = primary_results['num_deaths'].drop(['2010-2041'], level=1)
+    print("Prepared deaths and DALYs by cause for plotting.")
+
+    for param in param_names:
+        draw = format_scenario_name(param)
+        print(f"Plotting deaths over time by cause for draw: {draw}")
+        fig, ax = plot_deaths_by_period_for_draw(
+            num_deaths_by_cause_label_implementation / 1e3,
+            draw,
+        )
+        name_of_plot = f"Deaths Over Time by Cause for {draw}"
+        ax.set_title(name_of_plot)
+        ax.set_ylabel("Number of deaths (/1000)")
+        outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot))
+        fig.savefig(outfile)
+        plt.close(fig)
+
+    cause_labels = num_deaths_by_cause_label_implementation.index.get_level_values("label").unique()
+    for cause_label in cause_labels:
+        print(f"Plotting cause-specific time series for: {cause_label}")
+        fig, ax = plot_deaths_by_period_for_cause(
+            num_deaths_by_cause_label_implementation / 1e3,
+            cause_label=cause_label,
+        )
+        name_of_plot = f"Deaths Over Time for {cause_label}"
+        ax.set_title(name_of_plot)
+        ax.set_ylabel("Number of deaths (/1000)")
+        outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot))
+        fig.savefig(outfile)
+        plt.close(fig)
+
+        fig, ax = plot_deaths_by_period_for_cause(
+            num_dalys_by_cause_label_implementation / 1e3,
+            cause_label=cause_label,
+        )
+        name_of_plot = f"DALYs Over Time for {cause_label}"
+        ax.set_title(name_of_plot)
+        ax.set_ylabel("Number of DALYs (/1000)")
+        outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot))
+        fig.savefig(outfile)
+        plt.close(fig)
+
+    if comparison_metrics_available:
+        print("Plotting comparison metrics: deaths/DALYs averted, percentages, and ICERs.")
+        deaths_averted_sorted = (num_deaths_averted.sort_values(by="central", ascending=True) / 1e3)
+        fig_height = max(6, min(0.28 * len(deaths_averted_sorted.index) + 4, 18))
+        fig, ax = plt.subplots(figsize=(10, fig_height))
+        name_of_plot = "Deaths Averted by Each
Treatment ID" + do_barh_plot_with_ci(deaths_averted_sorted, ax) + ax.set_title(name_of_plot) + ax.set_xlabel("Number of deaths averted (/1000)") + ax.grid(axis="x") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot)) + fig.tight_layout() + fig.savefig(outfile) + plt.close(fig) + print("Saved: Deaths Averted by Each Treatment ID") + + dalys_averted_sorted = (dalys_averted.sort_values(by="central", ascending=True) / 1e3) + fig_height = max(6, min(0.28 * len(dalys_averted_sorted.index) + 4, 18)) + fig, ax = plt.subplots(figsize=(10, fig_height)) + name_of_plot = "DALYS Averted by Each Treatment ID" + do_barh_plot_with_ci(dalys_averted_sorted, ax) + ax.set_title(name_of_plot) + ax.set_xlabel("DALYs averted (/1000)") + ax.grid(axis="x") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot)) + fig.tight_layout() + fig.savefig(outfile) + plt.close(fig) + print("Saved: DALYS Averted by Each Treatment ID") + + pc_deaths_averted_sorted = (pc_deaths_averted.sort_values(by="central", ascending=True)) + fig_height = max(6, min(0.28 * len(pc_deaths_averted_sorted.index) + 4, 18)) + fig, ax = plt.subplots(figsize=(10, fig_height)) + name_of_plot = "Percentage Deaths Averted by Each Treatment ID" + do_barh_plot_with_ci(pc_deaths_averted_sorted, ax) + ax.set_title(name_of_plot) + ax.set_xlabel("Percentage of deaths averted") + ax.grid(axis="x") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot)) + fig.tight_layout() + fig.savefig(outfile) + plt.close(fig) + print("Saved: Percentage Deaths Averted by Each Treatment ID") + + pc_dalys_averted_sorted = (pc_dalys_averted.sort_values(by="central", ascending=True)) + fig_height = max(6, min(0.28 * len(pc_dalys_averted_sorted.index) + 4, 18)) + fig, ax = plt.subplots(figsize=(10, fig_height)) + name_of_plot = "Percentage DALYs Averted by Each Treatment ID" + do_barh_plot_with_ci(pc_dalys_averted_sorted, ax) + ax.set_title(name_of_plot) + ax.set_xlabel("Percentage of DALYs averted") + ax.grid(axis="x") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot)) + fig.tight_layout() + fig.savefig(outfile) + plt.close(fig) + print("Saved: Percentage DALYs Averted by Each Treatment ID") + + icers_sorted = icers.sort_values(by="central", ascending=True) + # Do not plot treatment ids with very wide uncertainty + # CervicalCancer_Screening_Xpert_* -110.336087 -6.192826 5064.399284 + # BreastCancer_PalliativeCare_* -25.104866 -5.740423 2611.046029 + # Hiv_Test_* -7335.183554 248.738016 856.794914 + + mask = ~icers_sorted.index.get_level_values("draw").isin(["Hiv_Test_*", "CervicalCancer_Screening_Xpert_*", "BreastCancer_PalliativeCare_*"]) + icers_sorted = icers_sorted[mask] + fig_height = max(6, min(0.28 * len(icers_sorted.index) + 4, 18)) + fig, ax = plt.subplots(figsize=(10, fig_height)) + name_of_plot = "ICERs for Each Treatment ID" + do_barh_plot_with_ci(icers_sorted, ax) + ax.set_title(name_of_plot) + ax.set_xlabel("ICER (USD per DALY averted)") + ax.grid(axis="x") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + outfile = os.path.join(output_folder, make_graph_file_name(name_of_plot)) + fig.tight_layout() + fig.savefig(outfile) + plt.close(fig) 
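+        # For reference, make_graph_file_name turns the plot title into the saved
+        # file name, e.g. "ICERs for Each Treatment ID" -> "icers_for_each_treatment_id.png"
+        # (spaces become underscores, '*' becomes '_star_', '/' is dropped, lower-cased).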
+ print("Saved: ICERs for Each Treatment ID") + + print("Finished generating figures.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("results_files", type=Path, nargs="+") + parser.add_argument("--output_folder", type=Path, required=True) + args = parser.parse_args() + + apply(results_files=args.results_files, output_folder=args.output_folder, resourcefilepath=Path("./resources")) diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/optimizer_preaggregated.R b/src/scripts/lcoa_inputs_from_tlo_analyses/optimizer_preaggregated.R new file mode 100644 index 0000000000..fa43c848d2 --- /dev/null +++ b/src/scripts/lcoa_inputs_from_tlo_analyses/optimizer_preaggregated.R @@ -0,0 +1,371 @@ +# Standalone preaggregated optimizer for Python integration. +# - ce_dalys, conscost, and hr_* are preaggregated totals at full implementation. +# - Decision variables represent fractions of each intervention implemented. +# - feascov and substitute/compulsory constraints still bound implementation shares. + +library(lpSolve) + +find_optimal_package <- function(inputs, objective_input, cet_input, + drug_budget_input, drug_budget.scale, + hr.time.constraint, hr.size, hr.scale, + use_feasiblecov_constraint, feascov_scale, compcov_scale, + compulsory_interventions, substitutes, task_shifting_pharm) { # % complements % + + ## Total DALYs averted based on CE evidence; this was per person in the original + ## script but is cumulative in this version + dalys <- as.numeric(as.character(inputs$ce_dalys)) + ## Cumulative cost of drugs and commodities + drugcost <- as.numeric(as.character(inputs$conscost)) + maxcoverage <- as.numeric(as.character(inputs$feascov)) # Maximum possible coverage (demand constraint) + ## Preaggregated mode: unit case scaling + cases <- rep(1, length(dalys)) + ## Full cost + fullcost <- as.numeric(as.character(inputs$ce_cost)) + ## Number of minutes of health worker time required per intervention + hrneed <- + inputs[c("hr_clin", "hr_nur", "hr_pharm", "hr_lab", "hr_ment", "hr_nutri")] + hrneed <- as.data.frame(apply(hrneed, 2, as.numeric)) + + n <- length(dalys) # number of interventions included in the analysis + + ################################### + # 3.1 Set up LPP + ################################### + + # Objective - maximize DALYs + #**************************************************** + # Define net health + cet <- cet_input + nethealth <- dalys - fullcost / cet + + # Define objective + if (objective_input == "nethealth") { + objective <- nethealth * cases + } else if (objective_input == "dalys") { + objective <- dalys * cases + } else { + print("ERROR: objective_input can take values dalys or nethealth") + } + + # Constraints - 1. Drug Budget, 2. HR Requirements + #**************************************************** + # 1. Drug Budget + #---------------- + cons_drug <- drugcost * cases # Cost of drugs for the number of cases covered + cons_drug.limit <- drug_budget_input * drug_budget.scale + cons_drug.limit_base <- drug_budget_input # unscaled drug budget + + # 2. HR Constraints + #--------------------- + ## HR minutes required to deliver intervention to all cases in need + hr_minutes_need <- hrneed * cases[row(hrneed)] + + ## Update HR constraints so that nurses, pharmacists, medical officers, etc. 
+  ## represent joint constraints
+  ## Medical officer + Clinical officer + Medical Assistant
+  clinicalstaff.need <- hr_minutes_need[c("hr_clin")]
+  ## Nurse officer + Nurse midwife
+  nursingstaff.need <- hr_minutes_need[c("hr_nur")]
+  ## Pharmacist + Pharmacist Technician + Pharmacist Assistant
+  pharmstaff.need <- hr_minutes_need[c("hr_pharm")]
+  ## Lab officer + Lab technician + Lab assistant
+  labstaff.need <- hr_minutes_need[c("hr_lab")]
+  # remove CHW
+  mentalstaff.need <- hr_minutes_need[c("hr_ment")] # Mental health staff
+  nutristaff.need <- hr_minutes_need[c("hr_nutri")] # Nutrition staff
+
+  # Clean total minutes available per cadre
+  cons_hr.limit <- hr.time.constraint
+  clinicalstaffmins.limit <- cons_hr.limit[1]
+  nursingstaffmins.limit <- cons_hr.limit[2]
+  pharmstaffmins.limit <- cons_hr.limit[3]
+  labstaffmins.limit <- cons_hr.limit[4]
+  mentalstaffmins.limit <- cons_hr.limit[5]
+  nutristaffmins.limit <- cons_hr.limit[6]
+
+  reps <- 4 # set the number of times that the matrix of interventions is duplicated
+
+  # Define a function which stacks `reps` copies of a matrix on top of one another (rbind)
+  duplicate_matrix_horizontally <- function(reps, matrix) {
+    matrix <- do.call(rbind, replicate(reps, matrix, simplify = FALSE))
+  }
+
+  if (task_shifting_pharm == 0) {
+    print("")
+  } else if (task_shifting_pharm == 1) {
+    clinicalstaff.need <-
+      duplicate_matrix_horizontally(reps, as.matrix(clinicalstaff.need))
+    nursingstaff.need <-
+      rbind(as.matrix(nursingstaff.need),
+            as.matrix(nursingstaff.need + pharmstaff.need),
+            as.matrix(nursingstaff.need + nutristaff.need),
+            as.matrix(nursingstaff.need + nutristaff.need + pharmstaff.need))
+    pharmstaff.need <-
+      rbind(as.matrix(pharmstaff.need), as.matrix(rep(0, n)),
+            as.matrix(pharmstaff.need), as.matrix(rep(0, n)))
+    labstaff.need <-
+      duplicate_matrix_horizontally(reps, as.matrix(labstaff.need))
+    mentalstaff.need <-
+      duplicate_matrix_horizontally(reps, as.matrix(mentalstaff.need))
+    nutristaff.need <-
+      rbind(as.matrix(nutristaff.need), as.matrix(nutristaff.need),
+            as.matrix(rep(0, n)), as.matrix(rep(0, n)))
+  } else {
+    print("ERROR: task_shifting_pharm can take values 0 or 1")
+  }
+
+  # Clean total workforce size per cadre
+  hr_size.limit <- hr.size
+  clinicalstaff.limit <- hr_size.limit[1]
+  nursingstaff.limit <- hr_size.limit[2]
+  pharmstaff.limit <- hr_size.limit[3]
+  labstaff.limit <- hr_size.limit[4]
+  mentalstaff.limit <- hr_size.limit[5]
+  nutristaff.limit <- hr_size.limit[6]
+
+  clinicalstaff.scale <- hr.scale[1]
+  nursestaff.scale <- hr.scale[2]
+  pharmstaff.scale <- hr.scale[3]
+  labstaff.scale <- hr.scale[4]
+  mentalstaff.scale <- hr.scale[5]
+  nutristaff.scale <- hr.scale[6]
+
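+  # Reading of the task-shifting stacks above (when task_shifting_pharm == 1,
+  # the decision vector is replicated reps = 4 times, one block per delivery mode):
+  #   block 1 - no task shifting
+  #   block 2 - pharmacist minutes delivered by nursing staff
+  #   block 3 - nutrition minutes delivered by nursing staff
+  #   block 4 - both pharmacist and nutrition minutes delivered by nursing staff
+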
+  ## Each list here represents the number of staff (of each cadre) needed to deliver each intervention to all cases in need.
+  ## Eg. for each cesarean section, 45 minutes of medical staff's time is needed (2316 * 45 = 104,220 minutes for 2316 cases).
+  ## On average about 40,200 minutes are available per medical staff each year (257.3 million minutes in total divided by
+  ## 6,400 medical staff), so delivering all 2316 cases requires roughly 2.6 medical staff (104,220 / 40,200).
+
+  cons_hr <-
+    cbind(clinicalstaff.need / (clinicalstaffmins.limit / clinicalstaff.limit),
+          nursingstaff.need / (nursingstaffmins.limit / nursingstaff.limit),
+          pharmstaff.need / (pharmstaffmins.limit / pharmstaff.limit),
+          labstaff.need / (labstaffmins.limit / labstaff.limit),
+          mentalstaff.need / (mentalstaffmins.limit / mentalstaff.limit),
+          nutristaff.need / (nutristaffmins.limit / nutristaff.limit))
+  cons_hr.saved <- cons_hr
+
+  cons_hr.limit_base <-
+    cbind(clinicalstaff.limit, nursingstaff.limit, pharmstaff.limit,
+          labstaff.limit, mentalstaff.limit, nutristaff.limit)
+  cons_hr.limit <-
+    cbind(clinicalstaff.limit * clinicalstaff.scale,
+          nursingstaff.limit * nursestaff.scale,
+          pharmstaff.limit * pharmstaff.scale,
+          labstaff.limit * labstaff.scale,
+          mentalstaff.limit * mentalstaff.scale,
+          nutristaff.limit * nutristaff.scale)
+
+  colnames(cons_hr.limit) <- colnames(cons_hr)
+  cons_hr.limit.saved <- cons_hr.limit
+
+  # Combine the constraints into one matrix
+  #****************************************************
+  # 1. HR
+  #--------------------------------------
+  cons_hr <- as.matrix(cons_hr)
+  cons_hr.limit <- as.matrix(cons_hr.limit)
+
+  # 2. Drug
+  #--------------------------------------
+  cons_drug <- as.matrix(cons_drug)
+  cons_drug.limit <- as.matrix(cons_drug.limit)
+
+  # 3. Max coverage
+  #--------------------------------------
+  cons.feascov <- diag(x = cases, n, n)
+  if (use_feasiblecov_constraint == 1) {
+    cons.feascov.limit <- as.matrix(maxcoverage * feascov_scale * cases)
+  } else if (use_feasiblecov_constraint == 0) {
+    cons.feascov.limit <- as.matrix(cases) # changed the constraint on 12May (multiplied by cases)
+  } else {
+    print("ERROR: use_feasiblecov_constraint can take values 0 or 1")
+  }
+
+  nonneg.lim <- as.matrix(rep(0, n))
+
+  # 4. Compulsory interventions
+  #--------------------------------------
+  if (length(compulsory_interventions) > 0) {
+    comp.count <- length(compulsory_interventions)
+    cons_compulsory <- matrix(0L, length(compulsory_interventions), ncol = n)
+    cons_compulsory.limit <- matrix(0L, length(compulsory_interventions), ncol = 1)
+    for (i in 1:length(compulsory_interventions)) {
+      a <- which(inputs$intcode == compulsory_interventions[i])
+      b <- inputs$intervention[a]
+      # print(paste("Compulsory intervention: ",b, "; Code: ", compulsory_interventions[i], "; Number ",a ))
+      cons_compulsory[i, a] <- cases[a]
+      # CHECK THIS CHANGE MADE on 26Aug21
+      cons_compulsory.limit[i] <- cases[a] * maxcoverage[a] * feascov_scale * compcov_scale # changed on 12May to maxcoverage because cons.feascov.limit is now maximum number of cases rather than maximum % coverage
+    }
+    dim(cons_compulsory)
+  } else if (length(compulsory_interventions) == 0) {
+    comp.count <- 1
+    cons_compulsory <- matrix(0L, 1, ncol = n)
+    cons_compulsory.limit <- matrix(0L, 1, ncol = 1)
+  }
+  cons_compulsory <- t(cons_compulsory)
+
+  # placeholder#
+  ###### % Complementary interventions code left out for now %
+
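+  # The substitutes constraint assembled in the next section takes the form
+  # (notation illustrative): for a substitute group {a, b},
+  #   x_a * cases_a + x_b * cases_b <= max(feasible cases of a, feasible cases of b)
+  # i.e. the group as a whole cannot exceed the largest feasible caseload within it.
+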
+  # 5. Substitute interventions
+  #--------------------------------------
+  substitutes <- substitutes
+  subs.count <- length(substitutes)
+  cons_substitutes.limit <- matrix(0L, length(substitutes), ncol = 1)
+  cons_substitutes <- matrix(0L, length(substitutes), ncol = n)
+
+  # First find the maximum number of feasible cases among the substitute interventions
+  subsgrp_casesmax <- matrix(0L, length(substitutes), ncol = 1)
+  for (i in 1:subs.count) {
+    for (j in substitutes[i]) {
+      subsgrp_cases <- 0
+      for (k in j) {
+        a <- which(inputs$intcode == k)
+        if (use_feasiblecov_constraint == 1) {
+          cases_max <- cases[a] * maxcoverage[a] * feascov_scale
+        } else if (use_feasiblecov_constraint == 0) {
+          cases_max <- cases[a]
+        }
+        subsgrp_cases <- cbind(subsgrp_cases, cases_max)
+      }
+      subsgrp_casesmax[i] <- max(subsgrp_cases)
+      # print(paste("Group", i, "Cases max", subsgrp_casesmax[i]))
+    }
+  }
+
+  # Next define the constraint such that the sum of the cases for each group of substitute interventions is less than or equal to the maximum feasible cases derived above
+  # print("Substitutes")
+  for (i in 1:subs.count) {
+    # print(paste("Substitute group", i))
+    # print("------------------------------------------------------------")
+    for (j in substitutes[i]) {
+      for (k in j) {
+        a <- which(inputs$intcode == k)
+        b <- inputs$intervention[a]
+        # print(paste("Intervention: ",b, "; Code: ", k, "; Maximum cases for intervention:", cons.feascov.limit[a],"; Number: ",a))
+        cons_substitutes[i, a] <- cases[a] # changed on 12May from 1 to cases
+        cons_substitutes.limit[i] <- subsgrp_casesmax[i] # changed on 12May to maxcoverage because cons.feascov.limit is now maximum number of cases rather than maximum % coverage
+      }
+    }
+    # cons_substitutes.limit[i] <- cons_substitutes.limit[i]/lengths(substitutes)[i] # removed on 12May
+    # print(paste("Maximum combined cases for group ",i, "= ", subsgrp_casesmax[i])) # print suppressed
+  }
+  cons_substitutes <- t(cons_substitutes)
+
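+  # Note on dimensions: each matrix duplicated below gains one row per
+  # replicated decision variable (reps * n rows in total), so that it lines up
+  # with the block-wise HR need matrices stacked earlier.
+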
+  # Changes to constraints if task-shifting of pharmacist responsibility is allowed
+  #--------------------------------------------------------------------------------
+  # Update the constraint matrices if task shifting is allowed
+  if (task_shifting_pharm == 0) {
+    print("No task shifting of pharmaceutical tasks")
+  } else if (task_shifting_pharm == 1) {
+    # 1. Objective
+    objective <- duplicate_matrix_horizontally(reps, as.matrix(objective))
+    # 2. Drug budget constraint (cons_drug.limit does not need to be changed)
+    cons_drug <- duplicate_matrix_horizontally(reps, as.matrix(cons_drug))
+    # 3. Feasible coverage constraint
+    cons.feascov <- duplicate_matrix_horizontally(reps, as.matrix(cons.feascov))
+    # 4. Compulsory interventions
+    cons_compulsory <- duplicate_matrix_horizontally(reps, as.matrix(cons_compulsory))
+    # 6. Substitutes
+    cons_substitutes <- duplicate_matrix_horizontally(reps, as.matrix(cons_substitutes))
+  } else {
+    print("ERROR: task_shifting_pharm can take values 0 or 1")
+  }
+
+  # Combine constraints 1-5
+  print(dim(t(cons_drug)))
+  print(dim(t(cons_hr)))
+  print(dim(t(cons.feascov)))
+  print(dim(t(cons_compulsory)))
+  print(dim(t(cons_substitutes)))
+  cons.mat <- rbind(t(cons_drug), t(cons_hr), t(cons.feascov), t(cons.feascov), t(cons_compulsory), t(cons_substitutes)) # % cons_complements %
+  dim(cons.mat)
+  cons.mat.limit <- rbind(cons_drug.limit, t(cons_hr.limit), cons.feascov.limit, nonneg.lim, cons_compulsory.limit, cons_substitutes.limit) # cons_complements.limit,
+  dim(cons.mat.limit)
+  print(dim(cons.mat))
+  print(dim(cons.mat.limit))
+
+  # Direction of relationship
+  cons.dir <- rep("<=", 1 + 6 + n) # 1 drug constraint + 6 HR cadre constraints + n feasible-coverage constraints
+  cons.dir <- c(cons.dir, rep(">=", n), rep(">=", comp.count))
+  cons.dir <- c(cons.dir, rep("<=", length(substitutes)))
+  # % cons.dir <- c(cons.dir,rep("<=",length(complements))) %
+  length(cons.dir)
+  length(cons.dir) <- dim(cons.mat.limit)[1] # Pad or truncate the directions vector so its length matches the number of constraint rows
+
+  ###################################
+  # 3.2 - Run LPP
+  ###################################
+  solution.class <- lp("max", objective, cons.mat, cons.dir, cons.mat.limit, compute.sens = TRUE)
+
+  ###################################
+  # 3.3 - Outputs
+  ###################################
+  # Extract the solution vector
+  #------------------------------------
+  solution <- as.data.frame(solution.class$solution)
+  solution_hr <- as.data.frame(solution.class$solution) # use this uncollapsed version of the dataframe for HR use calculations below
+  # Collapse solution by intervention
+  if (task_shifting_pharm == 1) {
+    for (i in 1:length(dalys)) {
+      for (j in 1:(reps - 1)) {
+        solution[i, 1] <- solution[i, 1] + solution[i + length(dalys) * j, 1]
+      }
+    }
+    solution <- as.data.frame(solution[1:length(dalys), 1])
+  }
+
+  # Number of interventions with a positive net health impact
+  pos_nethealth.count <- sum(nethealth > 0) # this seems to be one less than the figure in the excel
+
+  # Number of interventions in the optimal package
+  intervention.count <- sum(solution != 0)
+
+  # DALY burden averted as a % of avertible DALY burden
+  solution_dalysaverted <- solution * cases * dalys # DALYs averted per intervention
+  dalysavertible <- cases * dalys # Total DALYs that can be averted at maximum coverage
+  dalys_averted <- round(sum(unlist(lapply(solution_dalysaverted, sum))), 2)
+  dalys_averted.prop <- sum(unlist(lapply(solution_dalysaverted, sum))) / sum(unlist(lapply(dalysavertible, sum)))
+
+  # Drugs and Commodities cost (% of budget available)
+  solution_drugexp <- solution * cons_drug[1:length(dalys), ] # Total drug budget required per intervention for the optimal solution
+  total_drug_exp <- round(sum(unlist(lapply(solution_drugexp, sum))), 2) # Total drug budget required for the optimal solution
+  drug_exp.prop <- total_drug_exp / cons_drug.limit_base
+
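+  # Example reading (value illustrative): drug_exp.prop = 0.85 would mean the
+  # optimal package consumes 85% of the unscaled drug budget; hruse.prop below
+  # is the analogous share of each cadre's capacity.
+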
+  # Total HR use (% of capacity)
+  hr_cadres <- c("Clinical staff", "Nurse", "Pharmacist", "Lab", "Mental", "Nutrition")
+  solution_hruse <- unlist(solution_hr) * cons_hr # Staff time (full-time-equivalents) per cadre and intervention used by the optimal solution
+  if (task_shifting_pharm == 1) {
+    for (i in 1:length(dalys)) {
+      for (j in 1:(reps - 1)) {
+        solution_hruse[i, ] <- solution_hruse[i, ] + solution_hruse[i + length(dalys) * j, ]
+      }
+    }
+    solution_hruse <- solution_hruse[1:length(dalys), ]
+  }
+  total_hruse <- colSums(solution_hruse, na.rm = FALSE, dims = 1) # Staff time (full-time-equivalents) per cadre used by the optimal solution
+  hruse.prop <- round(total_hruse / cons_hr.limit_base, 2)
+  colnames(hruse.prop) <- hr_cadres
+
+  # Cost-effectiveness Threshold
+  icer <- fullcost / dalys
+  temp <- cbind.data.frame(icer, solution, inputs$intervention)
+  temp["solution.class$solution"] <- as.numeric(temp[[2]])
+  temp["icer"] <- as.numeric(temp[[1]])
+  cet_soln <- round(max(temp["icer"][temp["solution.class$solution"] > 0]), 2) # previously temp$icer[temp$solution > 0]
+  a <- which(icer == max(temp["icer"][temp["solution.class$solution"] > 0])) # to check which included intervention has the highest ICER
+  least.ce.intervention <- inputs$intervention[a]
+
+  # Collapse above outputs so that each intervention appears once in the list irrespective of task-shifting
+  # pos_nethealth.count, intervention.count, dalys_averted, cet_soln, drug_exp.prop, t(hruse.prop[,visible_cadres])
+
+  list(
+    "Total number of interventions in consideration" = length(dalys),
+    "Number of interventions with positive net health impact" = pos_nethealth.count,
+    "Number of interventions in the optimal package" = intervention.count,
+    "Net DALYs averted" = solution.class$objval,
+    "Total DALYs averted" = sum(unlist(lapply(solution_dalysaverted, sum))),
+    "Proportion of DALY burden averted" = dalys_averted.prop,
+    "Proportion of drug budget used" = drug_exp.prop,
+    "Proportion of HR capacity used by cadre" = hruse.prop,
+    "CET based on solution" = cet_soln
+  )
+}
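For orientation, a minimal sketch of how find_optimal_package could be invoked once this file is sourced; the two-intervention inputs table and every numeric value below are illustrative placeholders, not values from the analysis:

    source("src/scripts/lcoa_inputs_from_tlo_analyses/optimizer_preaggregated.R")

    # Preaggregated totals at full implementation, one row per intervention
    inputs <- data.frame(
      intcode = c("A1", "A2"),
      intervention = c("Intervention A", "Intervention B"),
      ce_dalys = c(1200, 800), ce_cost = c(30000, 12000), conscost = c(9000, 4000),
      feascov = c(0.8, 0.6),
      hr_clin = c(45, 10), hr_nur = c(30, 20), hr_pharm = c(5, 2),
      hr_lab = c(10, 0), hr_ment = c(0, 0), hr_nutri = c(0, 5)
    )
    result <- find_optimal_package(
      inputs,
      objective_input = "dalys", cet_input = 80,
      drug_budget_input = 1e7, drug_budget.scale = 1,
      hr.time.constraint = rep(1e8, 6),               # total minutes available per cadre
      hr.size = c(6400, 12000, 800, 1500, 300, 250),  # workforce headcount per cadre
      hr.scale = rep(1, 6),
      use_feasiblecov_constraint = 1, feascov_scale = 1, compcov_scale = 1,
      compulsory_interventions = NULL,
      substitutes = list(c("A1", "A2")),              # one substitute group
      task_shifting_pharm = 0
    )
    result[["Proportion of DALY burden averted"]]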
diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/results_processing_utils.py b/src/scripts/lcoa_inputs_from_tlo_analyses/results_processing_utils.py
new file mode 100644
index 0000000000..e62f3db4ac
--- /dev/null
+++ b/src/scripts/lcoa_inputs_from_tlo_analyses/results_processing_utils.py
@@ -0,0 +1,482 @@
+"""Utilities for extracting and processing results for treatment-id analyses."""
+
+import numpy as np
+import pandas as pd
+
+from scripts.lcoa_inputs_from_tlo_analyses.scenario_effect_of_treatment_ids import (
+    EffectOfEachTreatment,
+)
+from tlo import Date
+from tlo.analysis.utils import (
+    make_age_grp_types,
+    summarize,
+    to_age_group,
+    unflatten_flattened_multi_index_in_logging,
+)
+
+def find_difference_relative_to_comparison(_ser: pd.Series,
+                                           comparison: str,
+                                           scaled: bool = False,
+                                           drop_comparison: bool = True,
+                                           ):
+    """Find the difference in the values in a pd.Series with a multi-index, between the draws (level 0)
+    within the runs (level 1), relative to where draw = `comparison`.
+    The comparison is `X - COMPARISON`."""
+    return _ser \
+        .unstack(level=0) \
+        .apply(lambda x: (x - x[comparison]) / (x[comparison] if scaled else 1.0), axis=1) \
+        .drop(columns=([comparison] if drop_comparison else [])) \
+        .stack()
+
+
+def get_total_population_by_year(
+    _df: pd.DataFrame,
+    target_period_tuple: tuple[Date, Date],
+) -> pd.Series:
+    years_needed = [i.year for i in target_period_tuple]
+    _df["year"] = pd.to_datetime(_df["date"]).dt.year
+    return _df.loc[_df["year"].between(min(years_needed), max(years_needed)), ["year", "total"]].set_index("year")[
+        "total"
+    ]
+
+
+def extract_deaths_total(df: pd.DataFrame) -> pd.Series:
+    return pd.Series({"Total": len(df)})
+
+
+def target_period(target_period_tuple: tuple[Date, Date]) -> str:
+    """Returns the target period as a string of the form YYYY-YYYY."""
+    return "-".join(str(t.year) for t in target_period_tuple)
+
+
+def get_periods_within_target_period(
+    period_length_years: int,
+    target_period_tuple: tuple[Date, Date],
+) -> list[tuple[str, tuple[int, int]]]:
+    """Return chunks within target period as [(label, (start_year, end_year)), ...]."""
+    if period_length_years <= 0:
+        raise ValueError("period_length_years must be a positive integer.")
+    start_year, end_year = target_period_tuple[0].year, target_period_tuple[1].year
+    periods = []
+    for chunk_start in range(start_year, end_year + 1, period_length_years):
+        chunk_end = min(chunk_start + period_length_years - 1, end_year)
+        periods.append((f"{chunk_start}-{chunk_end}", (chunk_start, chunk_end)))
+    return periods
+
+
+def get_parameter_names_from_scenario_file() -> tuple[str, ...]:
+    """Get tuple of scenario names from Scenario class used to create results."""
+    e = EffectOfEachTreatment()
+    excluded = {"Only Hiv_Test_Selftest_*"}
+    # Hiv_Test_Selftest appears to have been added to the scenario file after the draws were submitted,
+    # so it is excluded here.
+ return tuple(name for name in e._scenarios.keys() if name not in excluded) + + +def format_scenario_name(_sn: str) -> str: + """Return reformatted scenario name ready for plotting.""" + if _sn == "Nothing": + return "Nothing" + else: + return _sn.removeprefix("Only ") + + +def set_param_names_as_column_index_level_0(_df: pd.DataFrame, param_names: tuple[str, ...]) -> pd.DataFrame: + """Set columns index level 0 as scenario param names.""" + + ordered_param_names_no_prefix = {i: x for i, x in enumerate(param_names)} + names_of_cols_level0 = [ordered_param_names_no_prefix.get(col) for col in _df.columns.levels[0]] + assert len(names_of_cols_level0) == len(_df.columns.levels[0]) + + reformatted_names = map(format_scenario_name, names_of_cols_level0) + _df.columns = _df.columns.set_levels(reformatted_names, level=0) + return _df + + +def find_difference_extra_relative_to_comparison( + _ser: pd.Series, + comparison: str, + scaled: bool = False, + drop_comparison: bool = True, +): + """Find run-wise differences relative to comparison in a series with multi-index.""" + return ( + _ser.unstack() + .apply(lambda x: (x - x[comparison]) / (x[comparison] if scaled else 1.0), axis=0) + .drop(index=([comparison] if drop_comparison else [])) + .stack() + + ) + + +def find_mean_difference_in_appts_relative_to_comparison( + _df: pd.DataFrame, + comparison: str, + drop_comparison: bool = True, +): + """Find mean fewer appointments when treatment does not happen relative to comparison.""" + return -summarize( + pd.concat( + { + _idx: find_difference_extra_relative_to_comparison( + row, comparison=comparison, drop_comparison=drop_comparison + ) + for _idx, row in _df.iterrows() + }, + axis=1, + ).T, + only_mean=True, + ) + + +def find_mean_difference_extra_relative_to_comparison_dataframe( + _df: pd.DataFrame, + comparison: str, + drop_comparison: bool = True, +): + """Same as find_difference_extra_relative_to_comparison but for dataframe.""" + return summarize( + pd.concat( + { + _idx: find_difference_extra_relative_to_comparison( + row, comparison=comparison, drop_comparison=drop_comparison + ) + for _idx, row in _df.iterrows() + }, + axis=1, + ).T, + only_mean=True, + ) + + +def get_num_deaths_by_cause_label(_df: pd.DataFrame, target_period_tuple: tuple[Date, Date]) -> pd.Series: + """Return total deaths by label within target period.""" + return _df.loc[pd.to_datetime(_df.date).between(*target_period_tuple)].groupby(_df["label"]).size() + + +def get_num_dalys_by_cause_label(_df: pd.DataFrame, target_period_tuple: tuple[Date, Date]) -> pd.Series: + """Return total DALYS by label within target period.""" + return ( + _df.loc[_df.year.between(*[i.year for i in target_period_tuple])] + .drop(columns=["date", "sex", "age_range", "year"]) + .sum() + ) + + +def make_get_num_deaths_by_cause_label_and_period( + period_length_years: int, + target_period_tuple: tuple[Date, Date] +): + """Create helper that summarizes deaths by cause and period chunks + overall.""" + periods = get_periods_within_target_period( + period_length_years=period_length_years, + target_period_tuple=target_period_tuple, + ) + period_lookup = { + year: period_label + for period_label, (start_year, end_year) in periods + for year in range(start_year, end_year + 1) + } + target_period_label = target_period(target_period_tuple) + + def _get_num_deaths_by_cause_label_and_period(_df: pd.DataFrame) -> pd.Series: + _df_in_target = _df.loc[pd.to_datetime(_df.date).between(*target_period_tuple)].copy() + _df_in_target["year"] = 
pd.to_datetime(_df_in_target["date"]).dt.year + _df_in_target["period"] = _df_in_target["year"].map(period_lookup) + + chunked = _df_in_target.groupby(["label", "period"]).size() + overall = _df_in_target.groupby("label").size() + overall.index = pd.MultiIndex.from_arrays( + [overall.index, np.repeat(target_period_label, len(overall.index))], names=["label", "period"] + ) + return pd.concat([chunked, overall]).sort_index() + + return _get_num_deaths_by_cause_label_and_period + + +def make_get_num_dalys_by_cause_label_and_period( + period_length_years: int, + target_period_tuple: tuple[Date, Date] +): + """Create helper that summarizes DALYS by cause and period chunks + overall.""" + periods = get_periods_within_target_period( + period_length_years=period_length_years, + target_period_tuple=target_period_tuple, + ) + period_lookup = { + year: period_label + for period_label, (period_start, period_end) in periods + for year in range(period_start, period_end + 1) + } + start_year, end_year = target_period_tuple[0].year, target_period_tuple[1].year + target_period_label = target_period(target_period_tuple) + + def _get_num_dalys_by_cause_label_and_period(_df: pd.DataFrame) -> pd.Series: + _df_in_target = _df.loc[_df.year.between(start_year, end_year)].copy() + _df_in_target["period"] = _df_in_target["year"].map(period_lookup) + + melted = ( + _df_in_target.drop(columns=["date", "sex", "age_range"]) + .melt(id_vars=["year", "period"], var_name="label", value_name="dalys") + ) + chunked = melted.groupby(["label", "period"])["dalys"].sum() + overall = melted.groupby("label")["dalys"].sum() + overall.index = pd.MultiIndex.from_arrays( + [overall.index, np.repeat(target_period_label, len(overall.index))], names=["label", "period"] + ) + return pd.concat([chunked, overall]).sort_index() + + return _get_num_dalys_by_cause_label_and_period + + +def get_num_deaths_by_age_group( + _df: pd.DataFrame, + age_grp_lookup: dict, + target_period_tuple: tuple[Date, Date], +): + """Return total deaths by age-group in target period.""" + return ( + _df.loc[pd.to_datetime(_df.date).between(*target_period_tuple)] + .groupby(_df["age"].map(age_grp_lookup).astype(make_age_grp_types())) + .size() + ) + + +def get_total_num_death_by_agegrp_and_label( + _df: pd.DataFrame, + target_period_tuple: tuple[Date, Date], +) -> pd.Series: + """Return deaths in target period by age-group and cause label.""" + _df_limited_to_dates = _df.loc[_df["date"].between(*target_period_tuple)] + age_group = to_age_group(_df_limited_to_dates["age"]) + return _df_limited_to_dates.groupby([age_group, "label"])["person_id"].size() + + +def get_total_num_dalys_by_agegrp_and_label( + _df: pd.DataFrame, + target_period_tuple: tuple[Date, Date], +) -> pd.Series: + """Return DALYS in target period by age-group and cause label.""" + return ( + _df.loc[_df.year.between(*[i.year for i in target_period_tuple])] + .assign(age_group=_df["age_range"]) + .drop(columns=["date", "year", "sex", "age_range"]) + .melt(id_vars=["age_group"], var_name="label", value_name="dalys") + .groupby(by=["age_group", "label"])["dalys"] + .sum() + ) + + +def get_counts_of_hsi_by_short_treatment_id( + _df: pd.DataFrame, + target_period_tuple: tuple[Date, Date], +) -> pd.Series: + """Get counts of short treatment ids occurring in target period.""" + mask = pd.to_datetime(_df["date"]).between(*target_period_tuple) + _counts_by_treatment_id = _df.loc[mask, "TREATMENT_ID"].apply(pd.Series).sum().astype(int) + return _counts_by_treatment_id + + +def get_counts_of_appts(_df: 
+
+
+def get_counts_of_appts(_df: pd.DataFrame, target_period_tuple: tuple[Date, Date]) -> pd.Series:
+    """Get counts of appointments of each type being used in the target period."""
+    return (
+        _df.loc[pd.to_datetime(_df["date"]).between(*target_period_tuple), "Number_By_Appt_Type_Code"]
+        .apply(pd.Series)
+        .sum()
+        .astype(int)
+    )
+
+
+def make_get_counts_of_appts_by_period(
+    period_length_years: int,
+    target_period_tuple: tuple[Date, Date],
+):
+    """Create a helper that summarizes appointment counts for each period chunk plus the
+    overall target period."""
+    periods = get_periods_within_target_period(
+        period_length_years=period_length_years,
+        target_period_tuple=target_period_tuple,
+    )
+    period_lookup = {
+        year: period_label
+        for period_label, (start_year, end_year) in periods
+        for year in range(start_year, end_year + 1)
+    }
+    target_period_label = target_period(target_period_tuple)
+
+    def _get_counts_of_appts_by_period(_df: pd.DataFrame) -> pd.Series:
+        _df_in_target = _df.loc[pd.to_datetime(_df["date"]).between(*target_period_tuple)].copy()
+        _df_in_target["year"] = pd.to_datetime(_df_in_target["date"]).dt.year
+        _df_in_target["period"] = _df_in_target["year"].map(period_lookup)
+
+        appts = _df_in_target["Number_By_Appt_Type_Code"].apply(pd.Series)
+        chunked = appts.groupby(_df_in_target["period"]).sum().T.stack()
+        chunked.index = chunked.index.set_names(["appt_type", "period"])
+
+        overall = appts.sum()
+        overall.index = pd.MultiIndex.from_arrays(
+            [overall.index, np.repeat(target_period_label, len(overall.index))],
+            names=["appt_type", "period"],
+        )
+        return pd.concat([chunked, overall]).astype(int).sort_index()
+
+    return _get_counts_of_appts_by_period
+
+
+def make_get_counts_of_hsis_by_period(
+    period_length_years: int,
+    target_period_tuple: tuple[Date, Date],
+):
+    """Create a helper that summarizes HSI counts (by TREATMENT_ID) for each period chunk
+    plus the overall target period."""
+    periods = get_periods_within_target_period(
+        period_length_years=period_length_years,
+        target_period_tuple=target_period_tuple,
+    )
+    period_lookup = {
+        year: period_label
+        for period_label, (start_year, end_year) in periods
+        for year in range(start_year, end_year + 1)
+    }
+    target_period_label = target_period(target_period_tuple)
+
+    def _get_counts_of_hsis_by_period(_df: pd.DataFrame) -> pd.Series:
+        _df_in_target = _df.loc[pd.to_datetime(_df["date"]).between(*target_period_tuple)].copy()
+        _df_in_target["year"] = pd.to_datetime(_df_in_target["date"]).dt.year
+        _df_in_target["period"] = _df_in_target["year"].map(period_lookup)
+
+        hsis = _df_in_target["TREATMENT_ID"].apply(pd.Series)
+        chunked = hsis.groupby(_df_in_target["period"]).sum().T.stack()
+        # NB. the first index level holds TREATMENT_IDs, though the level name remains "appt_type"
+        chunked.index = chunked.index.set_names(["appt_type", "period"])
+
+        overall = hsis.sum()
+        overall.index = pd.MultiIndex.from_arrays(
+            [overall.index, np.repeat(target_period_label, len(overall.index))],
+            names=["appt_type", "period"],
+        )
+        return pd.concat([chunked, overall]).astype(int).sort_index()
+
+    return _get_counts_of_hsis_by_period
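+
+
+# Usage sketch (assuming `results_folder` and `TARGET_PERIOD` are defined by the calling
+# script): the factories above are built once and handed to `extract_results` as its
+# `custom_generate_series` argument, e.g.:
+#
+#     appt_counts_by_period = extract_results(
+#         results_folder,
+#         module="tlo.methods.healthsystem.summary",
+#         key="HSI_Event",
+#         custom_generate_series=make_get_counts_of_appts_by_period(5, TARGET_PERIOD),
+#         do_scaling=True,
+#     )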
+
+
+# Get available staff count for each year and draw
+def get_staff_count_by_facid_and_officer_type(_df: pd.DataFrame) -> pd.Series:
+    """
+    Convert logged staff dictionary output into tidy format,
+    summing staff counts across all clinic columns.
+
+    Returns pd.Series indexed by:
+        (year, FacilityID, Officer)
+    """
+
+    df = _df.copy()
+    df["year"] = df["date"].dt.year
+    df = df.drop(columns=["date"])
+
+    clinic_cols = df.columns.difference(["year"])
+
+    long_frames = []
+
+    for clinic in clinic_cols:
+        expanded = df[[clinic, "year"]].copy()
+        expanded = expanded[expanded[clinic].notna()]
+
+        expanded_dict = expanded[clinic].apply(pd.Series)
+        expanded_dict["year"] = expanded["year"].values
+
+        long_frames.append(expanded_dict)
+
+    # Combine all clinics
+    combined = pd.concat(long_frames, ignore_index=True)
+
+    # Melt to long format
+    long_df = (
+        combined
+        .melt(id_vars=["year"],
+              var_name="facility_officer",
+              value_name="count")
+        .dropna(subset=["count"])
+    )
+
+    # Split FacilityID and Officer
+    parts = long_df["facility_officer"].str.split("_Officer_", expand=True)
+
+    long_df["FacilityID"] = (
+        parts[0]
+        .str.replace("FacilityID_", "", regex=False)
+        .astype(int)
+    )
+    long_df["Officer"] = parts[1]
+
+    # Sum staff counts across clinics
+    result = (
+        long_df
+        .groupby(["year", "FacilityID", "Officer"])["count"]
+        .sum()
+        .sort_index()
+    )
+
+    return result
+
+
+# Get the list of cadres which were utilised in each run, to get the count of staff used in the simulation.
+# Note that we still cost the full staff count for any cadre-Facility_Level combination that was ever used
+# in a run, rather than the amount of time actually used.
+def get_capacity_used_by_officer_type_and_facility_level(
+    _df: pd.DataFrame,
+    facility_id_levels_dict
+) -> pd.Series:
+    """
+    Parse logging output and return a Series indexed by:
+        (year, OfficerType, FacilityLevel)
+
+    Collapses (sums) across clinics.
+    Uses facility_id_levels_dict to map FacilityID → FacilityLevel.
+    """
+
+    # ---- 1. Set year index ----
+    _df = _df.set_axis(_df["date"].dt.year).drop(columns=["date"])
+    _df.index.name = "year"
+
+    # ---- 2. Unflatten logging columns ----
+    _df = unflatten_flattened_multi_index_in_logging(_df)
+
+    # Expect columns like:
+    #     ('Clinic', 'facID_and_officer')
+
+    col_df = _df.columns.to_frame(index=False)
+
+    # ---- 3. Extract OfficerType ----
+    col_df["OfficerType"] = (
+        col_df["facID_and_officer"]
+        .str.split("_Officer_")
+        .str[-1]
+    )
+
+    # ---- 4. Extract FacilityID ----
+    col_df["FacilityID"] = (
+        col_df["facID_and_officer"]
+        .str.split("_Officer_")
+        .str[0]
+        .str.replace("FacilityID_", "", regex=False)
+        .astype(int)
+    )
+
+    # ---- 5. Map to FacilityLevel ----
+    col_df["FacilityLevel"] = col_df["FacilityID"].map(facility_id_levels_dict)
+
+    # ---- 6. Rebuild MultiIndex (drop clinic level) ----
+    _df.columns = pd.MultiIndex.from_frame(
+        col_df[["OfficerType", "FacilityLevel"]]
+    )
+
+    # ---- 7. Collapse across clinics ----
+    _df = _df.groupby(level=["OfficerType", "FacilityLevel"], axis=1).sum()
+
+    # ---- 8. Return stacked format ----
+    return _df.stack(["OfficerType", "FacilityLevel"])
+
+
+def melt_model_output_draws_and_runs(_df, id_vars):
+    multi_index = pd.MultiIndex.from_tuples(_df.columns)
+    _df.columns = multi_index
+    melted_df = pd.melt(_df, id_vars=id_vars).rename(columns={'variable_0': 'draw', 'variable_1': 'run'})
+    return melted_df
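+
+
+# Usage note (illustrative, not executed here): `melt_model_output_draws_and_runs` assumes
+# `_df` still carries the (draw, run) column tuples produced by `extract_results`, alongside
+# the `id_vars` columns; e.g. `melt_model_output_draws_and_runs(df, id_vars=['year'])`
+# returns a tidy frame with columns ['year', 'draw', 'run', 'value'].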
diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/run_preaggregated_optimizer.py b/src/scripts/lcoa_inputs_from_tlo_analyses/run_preaggregated_optimizer.py
new file mode 100644
index 0000000000..09b36f252d
--- /dev/null
+++ b/src/scripts/lcoa_inputs_from_tlo_analyses/run_preaggregated_optimizer.py
@@ -0,0 +1,248 @@
+"""Run the preaggregated R optimizer from Python inputs.
+
+This script:
+1. Loads analysis outputs produced by analysis_effect_of_treatment_ids.py.
+2. Builds and writes the optimizer intervention input CSV.
+3. Loads optimizer constraints from a separate CSV.
+4. Invokes optimizer_preaggregated.R::find_optimal_package via rpy2.
+5. Writes optimizer outputs to JSON and optional CSV.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import pickle
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pandas as pd
+
+from scripts.lcoa_inputs_from_tlo_analyses.results_processing_utils import format_scenario_name
+
+
+OPTIMIZER_HR_COLS = ["hr_clin", "hr_nur", "hr_pharm", "hr_lab", "hr_ment", "hr_nutri"]
+REQUIRED_OPT_INPUT_COLS = [
+    "intcode",
+    "intervention",
+    "ce_dalys",
+    "conscost",
+    "feascov",
+    "ce_cost",
+    *OPTIMIZER_HR_COLS,
+]
+
+GLOBAL_REQUIRED_KEYS = {
+    "objective_input",
+    "cet_input",
+    "drug_budget_input",
+    "drug_budget.scale",
+    "use_feasiblecov_constraint",
+    "feascov_scale",
+    "compcov_scale",
+    "task_shifting_pharm",
+}
+
+
+def _require_columns(df: pd.DataFrame, required: list[str], df_name: str) -> None:
+    missing = [c for c in required if c not in df.columns]
+    if missing:
+        raise ValueError(f"{df_name} is missing required columns: {missing}")
+
+
+# TODO: Check with Sakshi if we only use the central value.
+def _coerce_central_series(df: pd.DataFrame) -> pd.Series:
+    out = df["central"].copy()
+    out.index = out.index.map(format_scenario_name)
+    return out.astype(float)
+
+
+def _rename_hrh_map(_df):
+    """Map officer type labels from model output to optimizer cadre buckets.
+
+    The mapping is deterministic and keyword-based. Unknown officer types are left
+    unrenamed (and are thus ignored downstream).
+    """
+    # TODO check with Sakshi
+    # This mapping silently ignores 'DCSA', 'Dental' and 'Radiography' cadres
+    mapping = {
+        'Clinical': 'hr_clin',
+        'Laboratory': 'hr_lab',
+        'Mental': 'hr_ment',
+        'Nursing_and_Midwifery': 'hr_nur',
+        'Nutrition': 'hr_nutri',
+        'Pharmacy': 'hr_pharm',
+    }
+    # Rename dataframes indexed by officer type to what they are called in the optimizer
+    renamed = _df.rename(index=mapping)
+    return renamed
+
+
+def _build_optimizer_inputs(results: dict[str, Any]) -> pd.DataFrame:
+    dalys_averted = results.get("dalys_averted")
+    incremental_cost = results.get("incremental_scenario_cost")
+    capacity_used = _rename_hrh_map(results.get("capacity_used_by_cadre"))
+
+    ce_dalys = dalys_averted['central']
+    ce_cost = incremental_cost['central']
+    hr_needs = capacity_used.xs("central", level="stat", axis=1).T
+
+    interventions = sorted(set(ce_dalys.index).intersection(set(ce_cost.index)))
+    if not interventions:
+        raise ValueError("No overlapping interventions found between DALYs and costs.")
+
+    opt_df = pd.DataFrame(
+        {
+            "intcode": range(1, len(interventions) + 1),
+            "intervention": interventions,
+            "ce_dalys": [float(ce_dalys.loc[i]) for i in interventions],
+            "ce_cost": [float(ce_cost.loc[i]) for i in interventions],
+            "conscost": [float(ce_cost.loc[i]) for i in interventions],
+            "hr_clin": [float(hr_needs.loc[i, "hr_clin"]) for i in interventions],
+            "hr_nur": [float(hr_needs.loc[i, "hr_nur"]) for i in interventions],
+            "hr_pharm": [float(hr_needs.loc[i, "hr_pharm"]) for i in interventions],
+            "hr_lab": [float(hr_needs.loc[i, "hr_lab"]) for i in interventions],
+            "hr_ment": [float(hr_needs.loc[i, "hr_ment"]) for i in interventions],
+            "hr_nutri": [float(hr_needs.loc[i, "hr_nutri"]) for i in interventions],
+        }
+    )
+
+    return opt_df
+
+
+def _jsonify(value: Any) -> Any:
+    if isinstance(value, dict):
+        return {str(k): _jsonify(v) for k, v in value.items()}
+    if isinstance(value, (list, tuple)):
+        return [_jsonify(v) for v in value]
+    if isinstance(value, pd.Series):
+        return {str(k): _jsonify(v) for k, v in value.to_dict().items()}
+    if isinstance(value, pd.DataFrame):
+        return value.to_dict(orient="records")
+    if isinstance(value, np.ndarray):
+        return [_jsonify(v) for v in value.tolist()]
+    if isinstance(value, (np.integer,)):
+        return int(value)
+    if isinstance(value, (np.floating, float)):
+        return float(value)
+    if pd.isna(value):
+        return None
+    return value
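+
+
+# Illustrative round-trip (values made up): `_jsonify` exists so numpy/pandas values
+# survive `json.dump`, e.g.
+#
+#     _jsonify({"n": np.int64(3), "x": pd.Series([1.0], index=["a"])})
+#     # -> {"n": 3, "x": {"a": 1.0}}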
+
+
+def _run_optimizer_via_rpy2(
+    optimizer_inputs: pd.DataFrame,
+    constraints: dict[str, Any],
+    r_script_path: Path,
+) -> dict[str, Any]:
+    try:
+        import rpy2.robjects as ro
+        from rpy2.robjects import pandas2ri
+        from rpy2.robjects.conversion import localconverter
+        from rpy2.robjects.vectors import FloatVector, ListVector, StrVector
+    except ImportError as exc:
+        raise RuntimeError(
+            "rpy2 is required but not available. Install rpy2 in your Python environment."
+        ) from exc
+
+    if not r_script_path.exists():
+        raise FileNotFoundError(f"R script not found: {r_script_path}")
+
+    ro.r["source"](str(r_script_path))
+    r_func = ro.globalenv.find("find_optimal_package")
+
+    with localconverter(ro.default_converter + pandas2ri.converter):
+        r_inputs = ro.conversion.py2rpy(optimizer_inputs)
+
+    r_compulsory = StrVector([])
+    r_subs = ListVector({})
+
+    result_r = r_func(
+        r_inputs,
+        # whether we are maximizing DALYs or net health
+        "dalys",
+        # CET; I believe not relevant here but give a value anyway
+        600,
+        # Drug budget input
+        constraints['annual_consumables_budget'],
+        # Drug budget scale set to 1
+        1,
+        # HR constraints; need to be clinical staff, nursing, pharmacy, lab,
+        # mental health, nutrition in that order
+        FloatVector(constraints["hr_time_constraint"]),
+        # HR size; same order as above
+        FloatVector(constraints["hr_size"]),
+        1,
+        # use_feasiblecov_constraint; set to 0 to not use, 1 to use
+        0,
+        # Feasible coverage scale; set to 1
+        1,
+        # Compulsory coverage scale; set to 1
+        1,
+        # Compulsory interventions; pass empty list
+        r_compulsory,
+        # substitutes; pass empty list
+        r_subs,
+        # task_shifting_pharm; set to 0 to not allow, 1 to allow
+        0,
+    )
+
+    with localconverter(ro.default_converter + pandas2ri.converter):
+        result_py = ro.conversion.rpy2py(result_r)
+
+    # rpy2 can return named list-like objects; normalize to dict.
+    if isinstance(result_py, dict):
+        return {str(k): _jsonify(v) for k, v in result_py.items()}
+
+    if hasattr(result_r, "names"):
+        out: dict[str, Any] = {}
+        names = list(result_r.names)
+        for i, name in enumerate(names):
+            out[str(name)] = _jsonify(ro.conversion.rpy2py(result_r[i]))
+        return out
+
+    raise RuntimeError("Unexpected optimizer result type from R.")
+
+
+def _parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--analysis-results-pkl", type=Path, required=True)
+    parser.add_argument("--optimizer-output-json", type=Path, required=True)
+    parser.add_argument(
+        "--r-script-path",
+        type=Path,
+        default=Path("src/scripts/lcoa_inputs_from_tlo_analyses/optimizer_preaggregated.R"),
+    )
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = _parse_args()
+
+    if not args.analysis_results_pkl.exists():
+        raise FileNotFoundError(f"Analysis results pickle not found: {args.analysis_results_pkl}")
+
+    with open(args.analysis_results_pkl, "rb") as f:
+        results = pickle.load(f)
+
+    constraints = {
+        'annual_consumables_budget': results.get("annual_consumables_budget"),
+        'hr_time_constraint': _rename_hrh_map(results.get("annual_capacity_by_cadre")),
+        'hr_size': _rename_hrh_map(results.get("staff_count_by_cadre")),
+    }
+    optimizer_inputs = _build_optimizer_inputs(results)
+
+    optimizer_output = _run_optimizer_via_rpy2(
+        optimizer_inputs=optimizer_inputs,
+        constraints=constraints,
+        r_script_path=args.r_script_path,
+    )
+
+    args.optimizer_output_json.parent.mkdir(parents=True, exist_ok=True)
+    with open(args.optimizer_output_json, "w", encoding="utf-8") as f:
+        json.dump(_jsonify(optimizer_output), f, indent=2, sort_keys=True)
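+
+
+# Example invocation (paths are illustrative only):
+#
+#     python src/scripts/lcoa_inputs_from_tlo_analyses/run_preaggregated_optimizer.py \
+#         --analysis-results-pkl outputs/lcoa_analysis_results.pkl \
+#         --optimizer-output-json outputs/optimizer_result.json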
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids.py b/src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids.py
new file mode 100644
index 0000000000..64e5c79639
--- /dev/null
+++ b/src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids.py
@@ -0,0 +1,127 @@
+"""
+This file contains all the definitions of scenarios for the TLO-LCOA project.
+
+It runs the full model under a set of scenarios, in each of which only a single TREATMENT_ID is included.
+
+To check scenarios are generated correctly:
+```
+tlo scenario-run --draw-only src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids.py
+```
+
+Run on the batch system using:
+
+```
+tlo batch-submit src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids.py
+```
+
+or locally using:
+```
+tlo scenario-run src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids.py
+```
+
+"""
+
+from pathlib import Path
+from typing import Dict
+
+from tlo import Date, logging
+from tlo.analysis.utils import get_filtered_treatment_ids, mix_scenarios, get_parameters_for_status_quo
+from tlo.methods.fullmodel import fullmodel
+from tlo.methods.scenario_switcher import ImprovedHealthSystemAndCareSeekingScenarioSwitcher
+from tlo.scenario import BaseScenario
+from tlo.methods.individual_history_tracker import IndividualHistoryTracker
+
+
+class ScenarioDefinitions:
+    @property
+    def YEAR_OF_SERVICE_AVAILABILITY_SWITCH(self) -> int:
+        return 2011
+
+    def baseline(self) -> Dict:
+        """Return the Dict with values for the parameter changes that define the baseline scenario."""
+        return mix_scenarios(
+            get_parameters_for_status_quo(),  # <-- Parameters that have been the calibration targets
+            {
+                "HealthSystem": {
+                    "cons_availability": "default",
+                    "year_cons_availability_switch": self.YEAR_OF_SERVICE_AVAILABILITY_SWITCH,
+                    "cons_availability_postSwitch": "all",
+                    "mode_appt_constraints": 1,
+                    "year_service_availability_switch": self.YEAR_OF_SERVICE_AVAILABILITY_SWITCH,
+                    # allow historical HRH scaling to occur 2018-2024
+                    # 'year_HR_scaling_by_level_and_officer_type': self.YEAR_OF_SERVICE_AVAILABILITY_SWITCH,
+                    "yearly_HR_scaling_mode": "historical_scaling",
+                },
+                "ImprovedHealthSystemAndCareSeekingScenarioSwitcher": {
+                    "max_healthsystem_function": [False, True],  # <-- switch from False to True mid-way
+                    "max_healthcare_seeking": [False, True],  # <-- switch from False to True mid-way
+                    "year_of_switch": self.YEAR_OF_SERVICE_AVAILABILITY_SWITCH,
+                },
+            },
+        )
+
+
+class EffectOfEachTreatment(BaseScenario):
+    def __init__(self):
+        super().__init__()
+        self.seed = 0
+        self.start_date = Date(2010, 1, 1)
+        self.end_date = Date(2031, 1, 1)
+        self.pop_size = 1000
+        self._scenarios = self._get_scenarios()
+        self.number_of_draws = len(self._scenarios)
+        self.runs_per_draw = 5
+
+    def log_configuration(self):
+        return {
+            "filename": "effect_of_each_treatment_id",
+            "directory": Path("./outputs"),
+            "custom_levels": {
+                "*": logging.WARNING,
+                "tlo.methods.demography": logging.INFO,
+                "tlo.methods.demography.detail": logging.WARNING,
+                "tlo.methods.healthburden": logging.INFO,
+                "tlo.methods.healthsystem.summary": logging.INFO,
+                "tlo.methods.individual_history_tracker": logging.INFO,
+            },
+        }
+
+    def modules(self):
+        return fullmodel() + [ImprovedHealthSystemAndCareSeekingScenarioSwitcher()]
+
+    def draw_parameters(self, draw_number, rng):
+        if draw_number < len(self._scenarios):
+            return list(self._scenarios.values())[draw_number]
+
+    def _get_scenarios(self) -> Dict[str, Dict]:
+        """Return the Dict with values for the parameter `Service_Availability` keyed by a name for the scenario.
+        The sequence of scenarios systematically omits all but one of the TREATMENT_IDs defined in the model."""
+
+        # Generate list of TREATMENT_IDs and filter to the resolution needed
+        treatments = get_filtered_treatment_ids(depth=None)
+
+        # Return 'Service_Availability' values, with a scenario for nothing, and ones for which all but one
+        # treatment is omitted
+        service_availability = dict({"Nothing": []})
+        # For each treatment group, create a scenario keeping only that treatment from the group
+        service_availability.update(
+            {f"Only {treatment}": [treatment] for treatment in treatments}
+        )
+
+        scenario_definitions = ScenarioDefinitions()
+
+        scenarios = {
+            key: mix_scenarios(
+                scenario_definitions.baseline(), {"HealthSystem": {"service_availability_postSwitch": value}}
+            )
+            for key, value in service_availability.items()
+        }
+
+        return scenarios
+
+
+if __name__ == "__main__":
+    from tlo.cli import scenario_run
+
+    scenario_run([__file__])
diff --git a/src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids_no_suspend.py b/src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids_no_suspend.py
new file mode 100644
index 0000000000..9c0c3b7c63
--- /dev/null
+++ b/src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids_no_suspend.py
@@ -0,0 +1,116 @@
+"""
+This file contains all the definitions of scenarios for the TLO-LCOA project.
+
+It runs the full model under a set of scenarios, in each of which only a single TREATMENT_ID is included.
+
+
+To check scenarios are generated correctly:
+```
+tlo scenario-run --draw-only src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids_no_suspend.py
+```
+
+Run on the batch system using:
+
+```
+tlo batch-submit src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids_no_suspend.py
+```
+
+or locally using:
+```
+tlo scenario-run src/scripts/lcoa_inputs_from_tlo_analyses/scenario_effect_of_treatment_ids_no_suspend.py
+```
+
+"""
+
+from pathlib import Path
+from typing import Dict
+
+from tlo import Date, logging
+from tlo.analysis.utils import mix_scenarios, get_parameters_for_status_quo
+from tlo.methods.fullmodel import fullmodel
+from tlo.scenario import BaseScenario
+
+
+class ScenarioDefinitions:
+    @property
+    def YEAR_OF_SERVICE_AVAILABILITY_SWITCH(self) -> int:
+        return 2026
+
+    def baseline(self) -> Dict:
+        """Return the Dict with values for the parameter changes that define the baseline scenario."""
+        return mix_scenarios(
+            get_parameters_for_status_quo(),  # <-- Parameters that have been the calibration targets
+            {
+                "HealthSystem": {
+                    "cons_availability": "default",
+                    "mode_appt_constraints": 1,
+                    "year_service_availability_switch": self.YEAR_OF_SERVICE_AVAILABILITY_SWITCH,
+                    # allow historical HRH scaling to occur 2018-2024
+                    # 'year_HR_scaling_by_level_and_officer_type': self.YEAR_OF_SERVICE_AVAILABILITY_SWITCH,
+                    "yearly_HR_scaling_mode": "historical_scaling",
+                }
+            },
+        )
+
+
+class EffectOfEachTreatment(BaseScenario):
+    def __init__(self):
+        super().__init__()
+        self.seed = 0
+        self.start_date = Date(2010, 1, 1)
+        self.end_date = Date(2041, 1, 1)
+        self.pop_size = 50_000
+        self._scenarios = self._get_scenarios()
+        self.number_of_draws = len(self._scenarios)
+        self.runs_per_draw = 5
+
+    def log_configuration(self):
+        return {
+            "filename": "effect_of_each_treatment_id",
+            "directory": Path("./outputs"),
"custom_levels": { + "*": logging.WARNING, + "tlo.methods.demography": logging.INFO, + "tlo.methods.demography.detail": logging.WARNING, + "tlo.methods.healthburden": logging.INFO, + "tlo.methods.healthsystem.summary": logging.INFO, + }, + } + + def modules(self): + return fullmodel() + + def draw_parameters(self, draw_number, rng): + if draw_number < len(self._scenarios): + return list(self._scenarios.values())[draw_number] + + def _get_scenarios(self) -> Dict[str, Dict]: + """Return the Dict with values for the parameter `Service_Availability` keyed by a name for the scenario. + The sequences of scenarios systematically omits all but one TREATMENT_ID that is defined in the model.""" + + # Generate list of TREATMENT_IDs and filter to the resolution needed + treatments = ["Epilepsy_Treatment_Start_*"] + # Return 'Service_Availability' values, with scenarios for nothing, and ones for which all but one + # treatment is omitted + service_availability = dict() + # For each treatment group, create scenarios keeping only one treatment from that group + service_availability.update( + {f"Only {treatment}": [treatment] for treatment in treatments} + ) + + scenario_definitions = ScenarioDefinitions() + + scenarios = { + key: mix_scenarios( + scenario_definitions.baseline(), {"HealthSystem": {"service_availability_postSwitch": value}} + ) + for key, value in service_availability.items() + } + + return scenarios + + +if __name__ == "__main__": + from tlo.cli import scenario_run + + scenario_run([__file__]) diff --git a/src/tlo/analysis/utils.py b/src/tlo/analysis/utils.py index 6b4d2cbf9b..49bf511cec 100644 --- a/src/tlo/analysis/utils.py +++ b/src/tlo/analysis/utils.py @@ -203,7 +203,7 @@ def get_scenario_outputs(scenario_filename: str, outputs_dir: Path) -> list: return folders -def get_scenario_info(scenario_output_dir: Path) -> dict: +def get_scenario_info(scenario_output_dir: Path, autodiscover: bool = False) -> dict: """Utility function to get the the number draws and the number of runs in a batch set. TODO: read the JSON file to get further information @@ -211,6 +211,22 @@ def get_scenario_info(scenario_output_dir: Path) -> dict: info = dict() draw_folders = [f for f in os.scandir(scenario_output_dir) if f.is_dir()] + if autodiscover: + draw_ids = sorted(int(f.name) for f in draw_folders) + runs_by_draw = { + draw: sorted( + int(f.name) + for f in os.scandir(scenario_output_dir / str(draw)) + if f.is_dir() + ) + for draw in draw_ids + } + info['draws'] = draw_ids + info['runs_by_draw'] = runs_by_draw + info['number_of_draws'] = len(draw_ids) + info['runs_per_draw'] = len(runs_by_draw[draw_ids[0]]) if draw_ids else 0 + return info + info['number_of_draws'] = len(draw_folders) run_folders = [f for f in os.scandir(draw_folders[0]) if f.is_dir()] @@ -295,6 +311,9 @@ def extract_results(results_folder: Path, index: str = None, custom_generate_series=None, do_scaling: bool = False, + suspended_results_folder: Path = None, + draw_runs: Optional[List[Tuple[int, int]]] = None, + autodiscover: bool = False, ) -> pd.DataFrame: """Utility function to unpack results. @@ -307,16 +326,19 @@ def extract_results(results_folder: Path, `custom_generate_series`. Optionally, with `do_scaling=True`, each element is multiplied by the scaling_factor recorded in the simulation. + If the suspend-and-resume functionality is used, scaling factor may be avaialble in the folder where the log + of the suspended run are stored. 
Note that if runs in the batch have failed (such that logs have not been generated), these are dropped silently. """ - def get_multiplier(_draw, _run): + + def get_multiplier(results_folder, _draw, _run): """Helper function to get the multiplier from the simulation. Note that if the scaling factor cannot be found a `KeyError` is thrown.""" return load_pickled_dataframes( - results_folder, _draw, _run, 'tlo.methods.population' - )['tlo.methods.population']['scaling_factor']['scaling_factor'].values[0] + results_folder, _draw, _run, 'tlo.methods.demography' + )['tlo.methods.demography']['scaling_factor']['scaling_factor'].values[0] if custom_generate_series is None: # If there is no `custom_generate_series` provided, it implies that function required selects the specified @@ -335,30 +357,46 @@ def generate_series(dataframe: pd.DataFrame) -> pd.Series: else: return custom_generate_series(dataframe) - # get number of draws and numbers of runs - info = get_scenario_info(results_folder) + if draw_runs is not None: + selected_draw_runs = draw_runs + elif autodiscover: + # get number of draws and numbers of runs + info = get_scenario_info(results_folder, autodiscover) + selected_draw_runs = [ + (draw, run) + for draw in info['draws'] + for run in info['runs_by_draw'][draw] + ] + else: + # Legacy default behaviour: infer ranges from scenario info. + info = get_scenario_info(results_folder) + selected_draw_runs = [ + (draw, run) + for draw in range(info['number_of_draws']) + for run in range(info['runs_per_draw']) + ] # Collect results from each draw/run res = dict() - for draw in range(info['number_of_draws']): - for run in range(info['runs_per_draw']): - - draw_run = (draw, run) - - try: - df: pd.DataFrame = load_pickled_dataframes(results_folder, draw, run, module)[module][key] - output_from_eval: pd.Series = generate_series(df) - assert isinstance(output_from_eval, pd.Series), ( - 'Custom command does not generate a pd.Series' - ) - if do_scaling: - res[draw_run] = output_from_eval * get_multiplier(draw, run) + for draw, run in selected_draw_runs: + draw_run = (draw, run) + try: + df: pd.DataFrame = load_pickled_dataframes(results_folder, draw, run, module)[module][key] + output_from_eval: pd.Series = generate_series(df) + assert isinstance(output_from_eval, pd.Series), ( + 'Custom command does not generate a pd.Series' + ) + if do_scaling: + if suspended_results_folder is not None: + res[draw_run] = output_from_eval * get_multiplier(suspended_results_folder, 0, 0) else: - res[draw_run] = output_from_eval + res[draw_run] = output_from_eval * get_multiplier(results_folder, draw, run) + else: + res[draw_run] = output_from_eval - except KeyError: - # Some logs could not be found - probably because this run failed. - res[draw_run] = None + except KeyError: + # Some logs could not be found - probably because this run failed. 
+            res[draw_run] = None
 
     # Use pd.concat to compile results (skips dict items where the values is None)
     _concat = pd.concat(res, axis=1)
@@ -386,7 +424,7 @@ def check_info_value_changes(df):
             prev_info = row["Info"]
     return problems
 
-    
+
 def remove_events_for_individual_after_death(df):
 
     rows_to_drop = []
@@ -430,8 +468,8 @@ def reconstruct_individual_histories(df):
     if len(problems)>0:
         print("Values didn't change but were still detected")
         print(problems)
-    
-    
+
+
     return df_final
diff --git a/src/tlo/methods/healthsystem.py b/src/tlo/methods/healthsystem.py
index f8b0f55a03..d9d6f189c6 100644
--- a/src/tlo/methods/healthsystem.py
+++ b/src/tlo/methods/healthsystem.py
@@ -255,12 +255,19 @@ class HealthSystem(Module):
             "Year in which the assumption for `equip_availability` changes (The change happens on 1st January of that "
             "year.)",
         ),
+        # Service Availability
         "Service_Availability": Parameter(
             Types.LIST,
             "List of services to be available. NB. This parameter is over-ridden if an argument is provided"
            " to the module initialiser.",
         ),
+        "year_service_availability_switch": Parameter(Types.INT, "Year in which service availability changes."),
+        "service_availability_postSwitch": Parameter(
+            Types.LIST,
+            "List of services to be available after the switch in `year_service_availability_switch`.",
+        ),
+
+        "policy_name": Parameter(Types.STRING, "Name of priority policy adopted"),
         "year_mode_switch": Parameter(Types.INT, "Year in which mode switch is enforced"),
         "scale_to_effective_capabilities": Parameter(
@@ -897,6 +904,12 @@ def initialise_simulation(self, sim):
             Date(self.parameters["year_use_funded_or_actual_staffing_switch"], 1, 1),
         )
 
+        # Schedule service availability switch
+        sim.schedule_event(
+            HealthSystemChangeParameters(self, parameters_to_change=["service_availability"]),
+            Date(self.parameters["year_service_availability_switch"], 1, 1),
+        )
+
         # Schedule a one-off rescaling of _daily_capabilities broken down by officer type and level.
         # This occurs on 1st January of the year specified in the parameters.
         sim.schedule_event(
@@ -1250,17 +1263,28 @@ def format_clinic_capabilities(self) -> pd.DataFrame:
 
         return capabilities_ex
 
+    def _compute_factors_for_effective_capabilities(self):
+        """Compute factors to rescale capabilities to capture effective capability.
+        Computation of these factors is split from the actual rescaling to facilitate
+        capturing them even when running the model in mode 1."""
+        self._rescaling_factors = defaultdict(dict)
+        for clinic, clinic_cl in self._daily_capabilities.items():
+            for facID_and_officer in clinic_cl.keys():
+                self._rescaling_factors[clinic][facID_and_officer] = self._summary_counter.frac_time_used_by_facID_and_officer(
+                    facID_and_officer=facID_and_officer, clinic=clinic
+                )
+        self._summary_counter._rescaling_factors = self._rescaling_factors
+
     def _rescale_capabilities_to_capture_effective_capability(self):
         # Notice that capabilities will only be expanded through this process
         # (i.e. won't reduce available capabilities if these were under-used in the last year).
         # Note: Currently relying on module variable rather than parameter for
         # scale_to_effective_capabilities, in order to facilitate testing. However
         # this may eventually come into conflict with the Switcher functions.
+        self._compute_factors_for_effective_capabilities()
         for clinic, clinic_cl in self._daily_capabilities.items():
             for facID_and_officer in clinic_cl.keys():
-                rescaling_factor = self._summary_counter.frac_time_used_by_facID_and_officer(
-                    facID_and_officer=facID_and_officer, clinic=clinic
-                )
+                rescaling_factor = self._rescaling_factors[clinic][facID_and_officer]
 
                 if rescaling_factor > 1 and rescaling_factor != float("inf"):
                     self._daily_capabilities[clinic][facID_and_officer] *= rescaling_factor
@@ -1414,6 +1438,10 @@ def get_equip_availability(self) -> str:
 
     def schedule_to_call_never_ran_on_date(self, hsi_event: "HSI_Event", tdate: datetime.datetime):
         """Function to schedule never_ran being called on a given date"""
+        if self.sim.date > tdate:
+            print(f"Warning: trying to schedule never_ran for date {tdate} in the past (current simulation date "
+                  f"is {self.sim.date}). This event will not be scheduled; treatment id is {hsi_event.TREATMENT_ID}.")
+            return
+
         self.sim.schedule_event(HSIEventWrapper(hsi_event=hsi_event, run_hsi=False), tdate)
 
     def get_mode_appt_constraints(self) -> int:
@@ -1573,6 +1601,13 @@ def _add_hsi_event_queue_item_to_hsi_event_queue(
         # Create HSIEventQueue Item, including a counter for the number of HSI_Events, to assist with sorting in the
         # queue (NB. the sorting is done ascending and by the order of the items in the tuple).
+
+        # First check that the service the HSI needs is available. If not, don't add it to the queue and
+        # don't increment the counter; schedule the never_ran call and return.
+        if not self.is_treatment_id_allowed(hsi_event.TREATMENT_ID, self.service_availability):
+            self.schedule_to_call_never_ran_on_date(hsi_event=hsi_event, tdate=topen)
+            return
+
         self.hsi_event_queue_counter += 1
 
         if self.randomise_queue:
@@ -2101,6 +2136,12 @@ def on_end_of_month(self) -> None:
 
     def on_end_of_year(self) -> None:
         """Write to log the current states of the summary counters and reset them."""
+
+        # If we are at the end of the year preceding the service availability switch,
+        # compute rescaling factors.
+        if self.sim.date.year == self.parameters['year_service_availability_switch'] - 1:
+            self._compute_factors_for_effective_capabilities()
+
         # If we are at the end of the year preceeding the mode switch, and if wanted
         # to rescale capabilities to capture effective availability as was recorded, on
         # average, in the past year, do so here.
@@ -2147,6 +2188,13 @@ def run_individual_level_events_in_mode_1(
             if event.expected_time_requests:
                 ok_to_run = self.do_all_required_officers_have_nonzero_capabilities(
                     event.expected_time_requests, clinic=clinic)
+
+            # Check here that the treatment id is allowed at this point, as service availability might have
+            # changed since the event was scheduled
+            if not self.is_treatment_id_allowed(event.TREATMENT_ID, self.service_availability):
+                call_and_record_never_ran_hsi_event(hsi_event=event, priority=_priority)
+                continue
+
             if ok_to_run:
 
                 # Compute the bed days that are allocated to this HSI and provide this information to the HSI
@@ -2410,6 +2458,11 @@ def process_events_mode_2(self, hold_over: List[HSIEventQueueItem]) -> None:
             event_clinic = next_event_tuple.clinic_eligibility
             capabilities_still_available = set_capabilities_still_available[event_clinic]
 
+            # Check here that the treatment id is allowed, as service availability might have changed
+            # since the event was scheduled
+            if not self.module.is_treatment_id_allowed(event.TREATMENT_ID, self.module.service_availability):
+                self.module.call_and_record_never_ran_hsi_event(hsi_event=event, priority=next_event_tuple.priority)
+                continue
+
             if self.sim.date > next_event_tuple.tclose:
                 # The event has expired (after tclose) having never been run. Call the 'never_ran' function
                 self.module.call_and_record_never_ran_hsi_event(hsi_event=event, priority=next_event_tuple.priority)
@@ -2730,6 +2783,7 @@ def _reset_internal_stores(self) -> None:
         self._never_ran_appts = defaultdict(int)  # As above, but for `HSI_Event`s that have never ran
         self._never_ran_appts_by_level = {_level: defaultdict(int) for _level in ("0", "1a", "1b", "2", "3", "4")}
 
+        self._rescaling_factors = defaultdict(dict)
         self._frac_time_used_overall = defaultdict(list)  # Running record of the usage of the healthcare system
         self._sum_of_daily_frac_time_used_by_facID_and_officer = defaultdict(Counter)
@@ -2823,6 +2877,7 @@ def write_to_log_and_reset_counters(self):
                 "average_Frac_Time_Used_Overall": {
                     clinic: np.mean(values) for clinic, values in self._frac_time_used_overall.items()
                 },
+                "rescaling_factor_for_clinics": self._rescaling_factors,
                 # <-- leaving space here for additional summary measures that may be needed in the future.
             },
         )
@@ -2885,7 +2940,7 @@ def __init__(self, module: HealthSystem, parameters_to_change: List):
         super().__init__(module)
         assert isinstance(module, HealthSystem)
 
-        self.supported_parameters = ["cons_availability", "equip_availability", "use_funded_or_actual_staffing"]
+        self.supported_parameters = ["cons_availability", "equip_availability", "use_funded_or_actual_staffing", "service_availability"]
 
         if not all(param in self.supported_parameters for param in parameters_to_change):
             raise ValueError(
                 f"parameters_to_change can only contain the following values: {self.supported_parameters}. "
@@ -2906,6 +2961,21 @@ def apply(self, population):
         if "use_funded_or_actual_staffing" in self.parameters_to_change:
             self.module.use_funded_or_actual_staffing = p["use_funded_or_actual_staffing_postSwitch"]
 
+        if "service_availability" in self.parameters_to_change:
+            self.module.service_availability = p["service_availability_postSwitch"]
+            # As part of the switch, clear the queue of any events currently scheduled
+            # that might require one of the omitted services when they actually run.
+            retained_events = []
+            while len(self.module.HSI_EVENT_QUEUE) > 0:
+                next_event_tuple = hp.heappop(self.module.HSI_EVENT_QUEUE)
+                if self.module.is_treatment_id_allowed(next_event_tuple.hsi_event.TREATMENT_ID, self.module.service_availability):
+                    retained_events.append(next_event_tuple)
+                else:
+                    self.module.schedule_to_call_never_ran_on_date(hsi_event=next_event_tuple.hsi_event, tdate=next_event_tuple.topen)
+
+            self.module.HSI_EVENT_QUEUE = retained_events
+            hp.heapify(self.module.HSI_EVENT_QUEUE)
+
 
 class DynamicRescalingHRCapabilities(RegularEvent, PopulationScopeEventMixin):
     """This event exists to scale the daily capabilities assumed at fixed time intervals"""
diff --git a/tests/test_healthsystem_general.py b/tests/test_healthsystem_general.py
index 17f236e869..2f47eb4678 100644
--- a/tests/test_healthsystem_general.py
+++ b/tests/test_healthsystem_general.py
@@ -515,6 +515,26 @@ def test_is_treatment_id_allowed():
     assert hs.is_treatment_id_allowed("Epi", ["Epi", "Epilepsy_*"])
     assert hs.is_treatment_id_allowed("Epilepsy", ["Epi", "Epilepsy_*"])
 
+    # Check behaviour around the service availability switch: the generic first-attendance
+    # HSIs should be allowed whatever the service availability.
+    excluded_hsis = [
+        "FirstAttendance_Emergency_*",
+        "FirstAttendance_NonEmergency_*",
+        "FirstAttendance_SpuriousEmergencyCare_*",
+    ]
+    treatments = get_filtered_treatment_ids(depth=None)
+    for treatment_allowed in treatments:
+        for treatment_requested in treatments:
+            # If the only treatment allowed is treatment_allowed, then all other treatments
+            # (except the always-allowed first-attendance HSIs) should return False
+            if treatment_requested != treatment_allowed:
+                if treatment_requested in excluded_hsis:
+                    assert hs.is_treatment_id_allowed(treatment_requested, [treatment_allowed])
+                elif treatment_requested.startswith(treatment_allowed.replace("_*", "")):
+                    assert hs.is_treatment_id_allowed(treatment_requested, [treatment_allowed])
+                else:
+                    assert not hs.is_treatment_id_allowed(treatment_requested, [treatment_allowed])
+
 
 def test_manipulation_of_service_availability(seed, tmpdir):
     """Check that the parameter `service_availability` can be used to allow/disallow certain `TREATMENT_ID`s.
@@ -1529,3 +1549,210 @@ def schedule_hsi_events(ngenericclinic, nclinic1, sim):
         clinic1_capabilities_before * 2,
         clinic1_capabilities_after,
     ), "Expected Clinic1 capabilities to be rescaled by factor of 2"
+
+
+def test_service_availability_switch(tmpdir, seed):
+    """Test that the service availability is updated in the year specified.
+    Simultaneously check that the switch triggers related behaviors:
+    1) compute and write to logs the rescaling factors;
+    2) clear the HSI event queue of any events scheduled to run after the switch
+       that need one of the unavailable services.
+ """ + + class DummyModuleGenericClinic(Module): + METADATA = {Metadata.DISEASE_MODULE, Metadata.USES_HEALTHSYSTEM} + + def read_parameters(self, data_folder): + pass + + def initialise_population(self, population): + pass + + def initialise_simulation(self, sim): + pass + + # Create a dummy HSI event class + class DummyHSIEvent(HSI_Event, IndividualScopeEventMixin): + def __init__(self, module, person_id, appt_type, level, treatment_id): + super().__init__(module, person_id=person_id) + self.TREATMENT_ID = treatment_id + self.EXPECTED_APPT_FOOTPRINT = self.make_appt_footprint({appt_type: 1}) + self.ACCEPTED_FACILITY_LEVEL = level + + def apply(self, person_id, squeeze_factor): + self.this_hsi_event_ran = True + + log_config = { + "filename": "log", + "directory": tmpdir, + "custom_levels": {"tlo.methods.healthsystem": logging.DEBUG}, + } + start_date = Date(2010, 1, 1) + + sim = Simulation(start_date=start_date, seed=0, log_config=log_config, resourcefilepath=resourcefilepath) + + sim.register( + demography.Demography(), + healthsystem.HealthSystem( + capabilities_coefficient=1.0, + mode_appt_constraints=1, + ignore_priority=False, + randomise_queue=True, + policy_name="", + use_funded_or_actual_staffing="funded_plus", + ), + DummyModuleGenericClinic(), + ) + + hs_params = sim.modules["HealthSystem"].parameters + hs_params["Service_Availability"] = ["ThisEventShouldRun", "ThisEventShouldNotRunPostSwitch"] + year_service_availability_switch = 2011 + hs_params["year_service_availability_switch"] = year_service_availability_switch + hs_params["service_availability_postSwitch"] = ["ThisEventShouldRun"] + + sim.make_initial_population(n=popsize) + # Schedule 10 events that should run; 10 events that have a treatment id + # that is not available after service availability switch. 
+    nevents_with_available_ids = 60
+    nevents_with_withdrawn_ids = 40
+    for i in range(0, nevents_with_available_ids):
+        hsi = DummyHSIEvent(
+            module=sim.modules["DummyModuleGenericClinic"],
+            person_id=i,
+            appt_type="ConWithDCSA",
+            level="0",
+            treatment_id="ThisEventShouldRun",
+        )
+        sim.modules["HealthSystem"].schedule_hsi_event(
+            hsi, topen=sim.date, tclose=sim.date + pd.DateOffset(days=1), priority=1
+        )
+
+    for i in range(nevents_with_available_ids, nevents_with_available_ids + nevents_with_withdrawn_ids):
+        hsi = DummyHSIEvent(
+            module=sim.modules["DummyModuleGenericClinic"],
+            person_id=i,
+            appt_type="ConWithDCSA",
+            level="0",
+            treatment_id="ThisEventShouldNotRunPostSwitch",
+        )
+        # These events open after the service availability switch
+        topen = pd.Timestamp(year_service_availability_switch, 1, 1)
+        sim.modules["HealthSystem"].schedule_hsi_event(
+            hsi, topen=topen, tclose=topen + pd.DateOffset(days=1), priority=1
+        )
+
+    sim.simulate(end_date=end_date)
+    output = parse_log_file(sim.log_filepath, level=logging.DEBUG)
+    hsi_events = output["tlo.methods.healthsystem"]["HSI_Event"]
+    # Expect nevents_with_available_ids rows in hsi_events with did_run True and
+    # TREATMENT_ID "ThisEventShouldRun"
+    nevents_ran = hsi_events.groupby("TREATMENT_ID")["did_run"].value_counts()
+    assert nevents_ran.loc[("ThisEventShouldRun", True)] == nevents_with_available_ids
+    # Expect nevents_with_withdrawn_ids rows in 'Never_ran_HSI_Event' with
+    # TREATMENT_ID "ThisEventShouldNotRunPostSwitch"
+    never_ran_events = output["tlo.methods.healthsystem"]["Never_ran_HSI_Event"]
+    nevents_did_not_run = never_ran_events[
+        never_ran_events["TREATMENT_ID"] == "ThisEventShouldNotRunPostSwitch"
+    ].shape[0]
+    assert nevents_did_not_run == nevents_with_withdrawn_ids
+
+
+def test_service_availability_with_rescheduling_hsi(tmpdir, seed):
+    """Test that an HSI that attempts to reschedule itself cannot go ahead
+    if a service availability update has made its treatment id unavailable.
+ """ + + class DummyModuleGenericClinic(Module): + METADATA = {Metadata.DISEASE_MODULE, Metadata.USES_HEALTHSYSTEM} + + def read_parameters(self, data_folder): + pass + + def initialise_population(self, population): + pass + + def initialise_simulation(self, sim): + pass + + # Create a dummy HSI event class + class DummyHSIEvent(HSI_Event, IndividualScopeEventMixin): + def __init__(self, module, person_id, appt_type, level, treatment_id): + super().__init__(module, person_id=person_id) + self.TREATMENT_ID = treatment_id + self.EXPECTED_APPT_FOOTPRINT = self.make_appt_footprint({appt_type: 1}) + self.ACCEPTED_FACILITY_LEVEL = level + + def apply(self, person_id, squeeze_factor): + self.this_hsi_event_ran = True + sim.modules["HealthSystem"].schedule_hsi_event( + self, topen=self.sim.date + pd.DateOffset(years=1), tclose=None, priority=1 + ) + sim.modules["HealthSystem"].schedule_hsi_event( + self, topen=self.sim.date + pd.DateOffset(years=2), tclose=None, priority=1 + ) + sim.modules["HealthSystem"].schedule_hsi_event( + self, topen=self.sim.date + pd.DateOffset(years=3), tclose=None, priority=1 + ) + sim.modules["HealthSystem"].schedule_hsi_event( + self, topen=self.sim.date + pd.DateOffset(years=4), tclose=None, priority=1 + ) + + + + log_config = { + "filename": "log", + "directory": tmpdir, + "custom_levels": {"tlo.methods.healthsystem": logging.DEBUG}, + } + start_date = Date(2010, 1, 1) + end_date = Date(2015, 1, 1) + sim = Simulation(start_date=start_date, seed=0, log_config=log_config, resourcefilepath=resourcefilepath) + + sim.register( + demography.Demography(), + healthsystem.HealthSystem( + capabilities_coefficient=1.0, + mode_appt_constraints=1, + ignore_priority=False, + randomise_queue=True, + policy_name="", + use_funded_or_actual_staffing="funded_plus", + ), + DummyModuleGenericClinic(), + ) + + hs_params = sim.modules["HealthSystem"].parameters + # First allow everything + hs_params["Service_Availability"] = ['*'] + year_service_availability_switch = 2011 + hs_params["year_service_availability_switch"] = year_service_availability_switch + # Post switch treatment id ThisEventShouldNotRunPostSwitch is unavailable + hs_params["service_availability_postSwitch"] = ["ThisEventShouldRunPostSwitch"] + + sim.make_initial_population(n=popsize) + # Schedule event with treatment id ThisEventShouldNotRunPostSwitch + # so that it runs successfully the first time, and reschedules itself. + hsi = DummyHSIEvent( + module=sim.modules["DummyModuleGenericClinic"], + person_id=1, + appt_type="ConWithDCSA", + level="0", + treatment_id="ThisEventShouldNotRunPostSwitch", + ) + sim.modules["HealthSystem"].schedule_hsi_event( + hsi, topen=start_date, tclose=end_date, priority=1 + ) + sim.simulate(end_date=end_date) + output = parse_log_file(sim.log_filepath, level=logging.DEBUG) + hsi_events = output["tlo.methods.healthsystem"]["HSI_Event"] + # Expect the first instance of this HSI to have run, since we scheduled it + # to run before service availability switch + nevents_ran = hsi_events.groupby("TREATMENT_ID")["did_run"].value_counts() + assert nevents_ran.loc[("ThisEventShouldNotRunPostSwitch", True)] == 1 + # and all subsequent instances to have not run. + never_ran_events = output["tlo.methods.healthsystem"]["Never_ran_HSI_Event"] + nevents_did_not_run = never_ran_events[never_ran_events["TREATMENT_ID"] == "ThisEventShouldNotRunPostSwitch"].shape[ + 0 + ] + # Since we scheduled it in 4 years after the first successful run + assert nevents_did_not_run == 4