From 208d9f0750e5476bb4812c2a4643661f0f0f7b31 Mon Sep 17 00:00:00 2001 From: tomrod-flamelit Date: Mon, 10 Oct 2022 10:24:21 -0500 Subject: [PATCH] enable multi-valued treatment ID --- SyntheticControlMethods/main.py | 36 ++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/SyntheticControlMethods/main.py b/SyntheticControlMethods/main.py index aa7a306..815099e 100644 --- a/SyntheticControlMethods/main.py +++ b/SyntheticControlMethods/main.py @@ -11,6 +11,9 @@ # limitations under the License. from __future__ import absolute_import, division, print_function +from typing import Union +from collections.abc import Iterable +from numbers import Number import pandas as pd import numpy as np @@ -143,7 +146,7 @@ class DataProcessor(object): def _process_input_data(self, dataset, outcome_var, id_var, time_var, - treatment_period, treated_unit, + treatment_period, treated_unit: Union(Iterable, str, Number), pen, exclude_columns, random_seed, **kwargs): ''' @@ -161,20 +164,30 @@ def _process_input_data(self, dataset, n_controls = dataset[id_var].nunique() - 1 n_covariates = len(covariates) + #Create treated unit list + if isinstance(treated_unit, Iterable): + treated_unit_list = treated_unit + elif isinstance(treated_unit, str) or isinstance(treated_unit, Number): + treated_unit_list = [treated_unit] + elif isinstance(treated_unit, dict): + raise NotImplementedError('Input for treated_unit as a dict is not implemented.') + else: + raise ValueError('Treated Unit must be iterable of values, string, or int') + #All units that are not the treated unit are controls - control_units = dataset.loc[dataset[id_var] != treated_unit][id_var].unique() + control_units = dataset.loc[~dataset[id_var].isin(treated_unit_list)][id_var].unique() ###Get treated unit matrices first### treated_outcome_all, treated_outcome, unscaled_treated_covariates = self._process_treated_data( dataset, outcome_var, id_var, time_var, - treatment_period, treated_unit, 
periods_all, + treatment_period, treated_unit_list, periods_all, periods_pre_treatment, covariates, n_covariates ) ### Now for control unit matrices ### control_outcome_all, control_outcome, unscaled_control_covariates = self._process_control_data( dataset, outcome_var, id_var, time_var, - treatment_period, treated_unit, n_controls, + treatment_period, treated_unit_list, n_controls, periods_all, periods_pre_treatment, covariates ) @@ -213,13 +226,13 @@ def _process_input_data(self, dataset, 'random_seed':random_seed, } - def _process_treated_data(self, dataset, outcome_var, id_var, time_var, treatment_period, treated_unit, + def _process_treated_data(self, dataset, outcome_var, id_var, time_var, treatment_period, treated_unit_list, periods_all, periods_pre_treatment, covariates, n_covariates): ''' Extracts and formats outcome and covariate matrices for the treated unit ''' - treated_data_all = dataset.loc[dataset[id_var] == treated_unit] + treated_data_all = dataset.loc[dataset[id_var].isin(treated_unit_list)] treated_outcome_all = np.array(treated_data_all[outcome_var]).reshape(periods_all,1) #All outcomes #Only pre-treatment @@ -232,14 +245,14 @@ def _process_treated_data(self, dataset, outcome_var, id_var, time_var, treatmen return treated_outcome_all, treated_outcome, treated_covariates - def _process_control_data(self, dataset, outcome_var, id_var, time_var, treatment_period, treated_unit, n_controls, + def _process_control_data(self, dataset, outcome_var, id_var, time_var, treatment_period, treated_unit_list, n_controls, periods_all, periods_pre_treatment, covariates): ''' Extracts and formats outcome and covariate matrices for the control group ''' #Every unit that is not the treated unit is control - control_data_all = dataset.loc[dataset[id_var] != treated_unit] + control_data_all = dataset.loc[~dataset[id_var].isin(treated_unit_list)] control_outcome_all = np.array(control_data_all[outcome_var]).reshape(n_controls, periods_all).T #All outcomes #Only 
pre-treatment @@ -312,9 +325,12 @@ def __init__(self, dataset, E.g. 1990 for german reunification. treated_unit: - Type: str. + Type: Iterable, str, or Number Name of the unit that recieved treatment, e.g. "West Germany" - data["id_var"] == treated_unit + if isinstance(treated_unit, str) or isinstance(treated_unit, Number): + data["id_var"].isin([treated_unit]) + elif isinstance(treated_unit, Iterable): + data["id_var"].isin(treated_unit) n_optim: Type: int. Default: 10.