diff --git a/pyproject.toml b/pyproject.toml index b7501fda..02f493db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ combine-as-imports = true convention = "google" [tool.ruff.lint.pylint] -max-args = 7 +max-args = 8 [tool.ruff.format] quote-style = "double" diff --git a/src/nhp/model/aae.py b/src/nhp/model/aae.py index 9dba1f41..68e0656c 100644 --- a/src/nhp/model/aae.py +++ b/src/nhp/model/aae.py @@ -26,6 +26,7 @@ class AaEModel(Model): run_params: The parameters to use for each model run. Generated automatically if left as None. Defaults to None. save_full_model_results: Whether to save the full model results or not. Defaults to False. + aggregation_columns: The columns to use for aggregation. Defaults to None. """ def __init__( @@ -35,6 +36,7 @@ def __init__( hsa: Any = None, run_params: dict | None = None, save_full_model_results: bool = False, + aggregation_columns: list[str] | None = None, ) -> None: """Initialise the A&E Model. @@ -44,6 +46,7 @@ def __init__( hsa: Health Status Adjustment object. Defaults to None. run_params: The run parameters to use. Defaults to None. save_full_model_results: Whether to save full model results. Defaults to False. + aggregation_columns: The columns to use for aggregation. Defaults to None. """ # call the parent init function super().__init__( @@ -54,6 +57,7 @@ def __init__( hsa, run_params, save_full_model_results, + aggregation_columns, ) def _get_data(self, data_loader: Data) -> pd.DataFrame: diff --git a/src/nhp/model/inpatients.py b/src/nhp/model/inpatients.py index 1cf04e15..4df3d937 100644 --- a/src/nhp/model/inpatients.py +++ b/src/nhp/model/inpatients.py @@ -27,6 +27,7 @@ class InpatientsModel(Model): run_params: The parameters to use for each model run. Generated automatically if left as None. Defaults to None. save_full_model_results: Whether to save the full model results or not. Defaults to False. + aggregation_columns: The columns to use for aggregation. Defaults to None. 
""" def __init__( @@ -36,6 +37,7 @@ def __init__( hsa: Any = None, run_params: dict | None = None, save_full_model_results: bool = False, + aggregation_columns: list[str] | None = None, ) -> None: """Initialise the Inpatients Model. @@ -45,6 +47,7 @@ def __init__( hsa: Health Status Adjustment object. Defaults to None. run_params: The run parameters to use. Defaults to None. save_full_model_results: Whether to save full model results. Defaults to False. + aggregation_columns: The columns to use for aggregation. Defaults to None. """ # call the parent init function super().__init__( @@ -55,6 +58,7 @@ def __init__( hsa, run_params, save_full_model_results, + aggregation_columns, ) def _get_data(self, data_loader: Data) -> pd.DataFrame: diff --git a/src/nhp/model/model.py b/src/nhp/model/model.py index 204f158b..56070b17 100644 --- a/src/nhp/model/model.py +++ b/src/nhp/model/model.py @@ -43,6 +43,7 @@ class Model: run_params: The parameters to use for each model run. Generated automatically if left as None. Defaults to None. save_full_model_results: Whether to save the full model results or not. Defaults to False. + aggregation_columns: The columns to use for aggregation. Defaults to None. Attributes: model_type: A string describing the type of model, must be one of "aae", "ip", or "op". @@ -67,6 +68,7 @@ def __init__( hsa: Any | None = None, run_params: dict | None = None, save_full_model_results: bool = False, + aggregation_columns: List[str] | None = None, ) -> None: """Initialise the Model. @@ -78,6 +80,8 @@ def __init__( hsa: Health Status Adjustment object. Defaults to None. run_params: The run parameters to use. Defaults to None. save_full_model_results: Whether to save full model results. Defaults to False. + aggregation_columns: The columns to use for aggregation. If left None (Default), then + ["pod", "sitetret"] is used. 
""" valid_model_types = ["aae", "ip", "op"] assert model_type in valid_model_types, "Model type must be one of 'aae', 'ip', or 'op'" @@ -107,6 +111,8 @@ def __init__( self.run_params = run_params or self.generate_run_params(params) self.save_full_model_results = save_full_model_results + self._aggregation_columns = aggregation_columns or ["pod", "sitetret"] + @property def measures(self) -> List[str]: """The names of the measure columns. @@ -408,8 +414,7 @@ def path_fn(f): return mr.get_aggregate_results() - @staticmethod - def get_agg(results: pd.DataFrame, *args: str) -> pd.Series: + def get_agg(self, results: pd.DataFrame, *args: str) -> pd.Series: """Get aggregation from model results. Args: @@ -419,7 +424,7 @@ def get_agg(results: pd.DataFrame, *args: str) -> pd.Series: Returns: Aggregated results. """ - return results.groupby(["pod", "sitetret", *args, "measure"])["value"].sum() + return results.groupby([*self._aggregation_columns, *args, "measure"])["value"].sum() def save_results(self, model_iteration: ModelIteration, path_fn: Callable[[str], str]) -> None: """Save the results of running the model. diff --git a/src/nhp/model/outpatients.py b/src/nhp/model/outpatients.py index 38bfa357..1916bdfb 100644 --- a/src/nhp/model/outpatients.py +++ b/src/nhp/model/outpatients.py @@ -26,6 +26,7 @@ class OutpatientsModel(Model): run_params: The parameters to use for each model run. Generated automatically if left as None. Defaults to None. save_full_model_results: Whether to save the full model results or not. Defaults to False. + aggregation_columns: The columns to use for aggregation. Defaults to None. """ def __init__( @@ -35,6 +36,7 @@ def __init__( hsa: Any = None, run_params: dict | None = None, save_full_model_results: bool = False, + aggregation_columns: list[str] | None = None, ) -> None: """Initialise the Outpatients Model. @@ -44,6 +46,7 @@ def __init__( hsa: Health Status Adjustment object. Defaults to None. run_params: The run parameters to use. 
Defaults to None. save_full_model_results: Whether to save full model results. Defaults to False. + aggregation_columns: The columns to use for aggregation. Defaults to None. """ # call the parent init function super().__init__( @@ -54,6 +57,7 @@ def __init__( hsa, run_params, save_full_model_results, + aggregation_columns, ) def _get_data(self, data_loader: Data) -> pd.DataFrame: diff --git a/src/nhp/model/run.py b/src/nhp/model/run.py index e8690649..3b804ca2 100644 --- a/src/nhp/model/run.py +++ b/src/nhp/model/run.py @@ -34,10 +34,10 @@ def update(self, n=1): tqdm.progress_callback(self.n) -def timeit(func: Callable, *args) -> Any: +def timeit(func: Callable, *args, **kwargs) -> Any: - """Time how long it takes to evaluate function `f` with arguments `*args`.""" + """Time how long it takes to call `func` with `*args` and `**kwargs`.""" start = time.time() - results = func(*args) + results = func(*args, **kwargs) print(f"elapsed: {time.time() - start:.3f}s") return results @@ -50,6 +50,7 @@ def _run_model( run_params: dict, progress_callback: Callable[[Any], None], save_full_model_results: bool, + aggregation_columns: list[str] | None = None, ) -> list[ModelRunResult]: """Run the model iterations. @@ -63,6 +64,7 @@ def _run_model( run_params: The generated run parameters for the model run. progress_callback: A callback function for progress updates. save_full_model_results: Whether to save full model results. + aggregation_columns: The columns to use for aggregation. Defaults to None. Returns: A list containing the aggregated results for all model runs. 
@@ -70,8 +72,10 @@ def _run_model( model_class = model_type.__name__[:-5] logging.info("%s", model_class) logging.info(" * instantiating") - # ignore type issues here: Model has different arguments to Inpatients/Outpatients/A&E - model = model_type(params, data, hsa, run_params, save_full_model_results) # type: ignore + # ignore type issues here: model_type is Type[Model] so ty checks against Model.__init__, + # which has extra leading positional args (model_type, measures) that the concrete subclasses + # don't expose — the positional mapping is correct at runtime. + model = model_type(params, data, hsa, run_params, save_full_model_results, aggregation_columns) # ty: ignore[invalid-argument-type] logging.info(" * running") # set the progress callback for this run @@ -115,6 +119,7 @@ def run_all( nhp_data: Callable[[int, str], Data], progress_callback: Callable[[Any], Callable[[Any], None]] = noop_progress_callback, save_full_model_results: bool = False, + aggregation_columns: list[str] | None = None, ) -> tuple[dict[str, pd.DataFrame], list[str]]: """Run the model. @@ -126,6 +131,7 @@ def run_all( progress_callback: A callback function for updating progress. Defaults to noop_progress_callback. save_full_model_results: Whether to save full model results. Defaults to False. + aggregation_columns: The columns to use for aggregation. Defaults to None. Returns: A dictionary containing the results dataframes, and a list of the variants that were run. 
@@ -152,6 +158,7 @@ def run_all( run_params, progress_callback(m.__name__[:-5]), save_full_model_results, + aggregation_columns, ) for m in model_types ] @@ -161,13 +168,17 @@ def run_all( def run_single_model_run( - params: dict, data_path: str, model_type: Type[Model], model_run: int + params: dict, + data_path: str, + model_type: Type[Model], + model_run: int, + aggregation_columns: list[str] | None = None, ) -> None: """Runs a single model iteration for easier debugging in vscode.""" data = Local.create(data_path) print("initialising model... ", end="") - model = timeit(model_type, params, data) + model = timeit(model_type, params, data, aggregation_columns=aggregation_columns) print("running model... ", end="") m_run = timeit(ModelIteration, model, model_run) print("aggregating results... ", end="") diff --git a/tests/unit/nhp/model/test_model.py b/tests/unit/nhp/model/test_model.py index f884db99..f1867464 100644 --- a/tests/unit/nhp/model/test_model.py +++ b/tests/unit/nhp/model/test_model.py @@ -14,9 +14,10 @@ @pytest.fixture def mock_model(): """Create a mock Model instance.""" - with patch.object(Model, "__init__", lambda s, m, p, d, c: None): + with patch.object(Model, "__init__", lambda *args, **kwargs: None): mdl = Model(None, None, None, None) # type: ignore mdl.model_type = "aae" + mdl._aggregation_columns = ["pod", "sitetret"] mdl.params = { "input_data": "synthetic", "model_runs": 3, diff --git a/tests/unit/nhp/model/test_run.py b/tests/unit/nhp/model/test_run.py index 06a56f35..a276be2a 100644 --- a/tests/unit/nhp/model/test_run.py +++ b/tests/unit/nhp/model/test_run.py @@ -139,6 +139,7 @@ def test_run_all(mocker): {"variant": "variants"}, pc_m(), False, + None, ) for m in [InpatientsModel, OutpatientsModel, AaEModel] ] @@ -184,7 +185,9 @@ def test_run_single_model_run(mocker, capsys): ndl_mock.create.assert_called_once_with("data") assert timeit_mock.call_count == 3 - assert timeit_mock.call_args_list[0] == call("model_type", params, "nhp_data") 
+ assert timeit_mock.call_args_list[0] == call( + "model_type", params, "nhp_data", aggregation_columns=None + ) assert timeit_mock.call_args_list[2] == call(mr_mock.get_aggregate_results) assert capsys.readouterr().out == "\n".join(