66from typing import Any
77from typing import Dict
88
9- try :
10- import numpy as np
11- has_numpy = True
12- except ImportError :
13- has_numpy = False
14-
159
1610if sys .version_info >= (3 , 10 ):
1711 _UNION_TYPES = {typing .Union , types .UnionType }
@@ -36,29 +30,51 @@ def get_annotations(obj: Any) -> Dict[str, Any]:
3630 return getattr (obj , '__annotations__' , {})
3731
3832
33+ def get_module (obj : Any ) -> str :
34+ """Get the module of an object."""
35+ module = getattr (obj , '__module__' , '' ).split ('.' )
36+ if module :
37+ return module [0 ]
38+ return ''
39+
40+
41+ def get_type_name (obj : Any ) -> str :
42+ """Get the type name of an object."""
43+ if hasattr (obj , '__name__' ):
44+ return obj .__name__
45+ if hasattr (obj , '__class__' ):
46+ return obj .__class__ .__name__
47+ return ''
48+
49+
3950def is_numpy (obj : Any ) -> bool :
4051 """Check if an object is a numpy array."""
41- if is_union (obj ):
42- obj = typing .get_args (obj )[0 ]
43- if not has_numpy :
44- return False
4552 if inspect .isclass (obj ):
46- return obj is np .ndarray
47- if typing .get_origin (obj ) is np .ndarray :
48- return True
49- return isinstance (obj , np .ndarray )
53+ if get_module (obj ) == 'numpy' :
54+ return get_type_name (obj ) == 'ndarray'
55+
56+ origin = typing .get_origin (obj )
57+ if get_module (origin ) == 'numpy' :
58+ if get_type_name (origin ) == 'ndarray' :
59+ return True
60+
61+ dtype = type (obj )
62+ if get_module (dtype ) == 'numpy' :
63+ return get_type_name (dtype ) == 'ndarray'
64+
65+ return False
5066
5167
5268def is_dataframe (obj : Any ) -> bool :
5369 """Check if an object is a DataFrame."""
5470 # Cheating here a bit so we don't have to import pandas / polars / pyarrow:
5571 # unless we absolutely need to
56- if getattr (obj , '__module__' , '' ). startswith ( ' pandas.' ) :
57- return getattr (obj , '__name__' , '' ) == 'DataFrame'
58- if getattr (obj , '__module__' , '' ). startswith ( ' polars.' ) :
59- return getattr (obj , '__name__' , '' ) == 'DataFrame'
60- if getattr (obj , '__module__' , '' ). startswith ( ' pyarrow.' ) :
61- return getattr (obj , '__name__' , '' ) == 'Table'
72+ if get_module (obj ) == ' pandas' :
73+ return get_type_name (obj ) == 'DataFrame'
74+ if get_module (obj ) == ' polars' :
75+ return get_type_name (obj ) == 'DataFrame'
76+ if get_module (obj ) == ' pyarrow' :
77+ return get_type_name (obj ) == 'Table'
6278 return False
6379
6480
@@ -74,13 +90,13 @@ def get_data_format(obj: Any) -> str:
7490 """Return the data format of the DataFrame / Table / vector."""
7591 # Cheating here a bit so we don't have to import pandas / polars / pyarrow
7692 # unless we absolutely need to
77- if getattr (obj , '__module__' , '' ). startswith ( ' pandas.' ) :
93+ if get_module (obj ) == ' pandas' :
7894 return 'pandas'
79- if getattr (obj , '__module__' , '' ). startswith ( ' polars.' ) :
95+ if get_module (obj ) == ' polars' :
8096 return 'polars'
81- if getattr (obj , '__module__' , '' ). startswith ( ' pyarrow.' ) :
97+ if get_module (obj ) == ' pyarrow' :
8298 return 'arrow'
83- if getattr (obj , '__module__' , '' ). startswith ( ' numpy.' ) :
99+ if get_module (obj ) == ' numpy' :
84100 return 'numpy'
85101 if isinstance (obj , list ):
86102 return 'list'
@@ -92,8 +108,8 @@ def is_pandas_series(obj: Any) -> bool:
92108 if is_union (obj ):
93109 obj = typing .get_args (obj )[0 ]
94110 return (
95- getattr (obj , '__module__' , '' ). startswith ( ' pandas.' ) and
96- getattr (obj , '__name__' , '' ) == 'Series'
111+ get_module (obj ) == ' pandas' and
112+ get_type_name (obj ) == 'Series'
97113 )
98114
99115
@@ -102,8 +118,8 @@ def is_polars_series(obj: Any) -> bool:
102118 if is_union (obj ):
103119 obj = typing .get_args (obj )[0 ]
104120 return (
105- getattr (obj , '__module__' , '' ). startswith ( ' polars.' ) and
106- getattr (obj , '__name__' , '' ) == 'Series'
121+ get_module (obj ) == ' polars' and
122+ get_type_name (obj ) == 'Series'
107123 )
108124
109125
@@ -112,8 +128,8 @@ def is_pyarrow_array(obj: Any) -> bool:
112128 if is_union (obj ):
113129 obj = typing .get_args (obj )[0 ]
114130 return (
115- getattr (obj , '__module__' , '' ). startswith ( ' pyarrow.' ) and
116- getattr (obj , '__name__' , '' ) == 'Array'
131+ get_module (obj ) == ' pyarrow' and
132+ get_type_name (obj ) == 'Array'
117133 )
118134
119135
@@ -147,6 +163,6 @@ def is_pydantic(obj: Any) -> bool:
147163 # the class is a subclass
148164 return bool ([
149165 x for x in inspect .getmro (obj )
150- if x . __module__ . startswith ( 'pydantic.' )
151- and x . __name__ == 'BaseModel'
166+ if get_module ( x ) == 'pydantic'
167+ and get_type_name ( x ) == 'BaseModel'
152168 ])
0 commit comments