diff --git a/tableone.py b/tableone.py index ac3f005..61155bb 100644 --- a/tableone.py +++ b/tableone.py @@ -84,8 +84,12 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, nonnormal = [nonnormal] # if columns not specified, use all columns - if not columns: + if type(columns) == type(None): columns = data.columns.get_values() + elif 'pandas.core.indexes' in str(type(columns)): + columns = columns.get_values() + else: + columns = list(columns) # check that the columns exist in the dataframe if not set(columns).issubset(data.columns): @@ -98,13 +102,17 @@ def __init__(self, data, columns=None, categorical=None, groupby=None, raise InputError('Input contains duplicate columns: {}'.format(dups)) # if categorical not specified, try to identify categorical - if not categorical and type(categorical) != list: + if type(columns) == type(None): categorical = self._detect_categorical_columns(data[columns]) + elif 'pandas.core.indexes' in str(type(categorical)): + categorical = categorical.get_values() + else: + categorical = list(columns) if pval and not groupby: raise InputError("If pval=True then the groupby must be specified.") - self._columns = list(columns) + self._columns = columns self._isnull = isnull self._continuous = [c for c in columns if c not in categorical + [groupby]] self._categorical = categorical @@ -267,7 +275,7 @@ def _normaltest(self,x): Compute test for normal distribution. Null hypothesis: x comes from a normal distribution - p < alpha suggests the null hypothesis can be rejected. + p < alpha suggests the null hypothesis can be rejected. """ stat,p = stats.normaltest(x.values, nan_policy='omit') return p