@@ -779,6 +779,64 @@ def inner_f(*args: Any, **kwargs: Any) -> _T:
779779_deprecate_positional_args = require_keyword_args (False )
780780
781781
def _get_categories(
    cfn: Callable[[ctypes.c_char_p], int],
    feature_names: Optional[FeatureNames],
    n_features: int,
) -> Optional[Dict[str, "pa.DictionaryArray"]]:
    """Export per-feature categories obtained from a C API getter as `pyarrow` arrays.

    Parameters
    ----------
    cfn :
        Callable invoked with a fresh ``ctypes.c_char_p``. Its return value is
        passed to ``_check_call``; the categories are then read from the pointer's
        value and parsed as a JSON list holding one entry per feature.
    feature_names :
        Keys of the returned dictionary. When ``None``, the stringified feature
        indices (``"0"``, ``"1"``, ...) are used instead.
    n_features :
        Number of features; the parsed JSON list must have exactly this length.

    Returns
    -------
    ``None`` when the pointer value is still ``None`` after calling ``cfn``.
    Otherwise a dictionary keyed by feature name. Note that a value is ``None``
    for a feature whose JSON entry is null (numeric data), even though the
    return annotation does not spell this out.

    Raises
    ------
    ImportError
        If `pyarrow` is not available.

    """
    if not is_pyarrow_available():
        raise ImportError("`pyarrow` is required for exporting categories.")

    if TYPE_CHECKING:
        import pyarrow as pa
    else:
        pa = import_pyarrow()

    # Fall back to positional names when the caller has no feature names.
    names = (
        feature_names
        if feature_names is not None
        else [str(i) for i in range(n_features)]
    )

    c_json = ctypes.c_char_p()
    _check_call(cfn(c_json))
    raw = c_json.value
    if raw is None:
        return None

    parsed = json.loads(raw.decode())  # pylint: disable=no-member
    assert isinstance(parsed, list) and len(parsed) == n_features

    categories: Dict[str, "pa.DictionaryArray"] = {}
    for fidx in range(n_features):
        entry = parsed[fidx]
        if entry is None:
            # Numeric data
            column = None
        elif "offsets" not in entry:
            # No "offsets" key: the entry itself is a single array interface.
            column = pa.Array.from_pandas(from_array_interface(entry))
        else:
            # "offsets" + "values" pair: assemble a string array from the two
            # raw buffers.
            offsets_spec = entry["offsets"]
            values_spec = entry["values"]
            offsets = from_array_interface(offsets_spec, True)
            values = from_array_interface(values_spec, True)
            offsets_bufs = pa.array(offsets).buffers()
            values_bufs = pa.array(values).buffers()
            # buffers()[0] is the validity bitmap, buffers()[1] the data.
            assert (
                offsets_bufs[0] is None and values_bufs[0] is None
            ), "Should not have null mask."
            column = pa.StringArray.from_buffers(
                len(offsets) - 1, offsets_bufs[1], values_bufs[1]
            )
        categories[names[fidx]] = column

    return categories
838+
839+
782840@unique
783841class DataSplitMode (IntEnum ):
784842 """Supported data split mode for DMatrix."""
@@ -1299,58 +1357,11 @@ def get_categories(self) -> Optional[Dict[str, "pa.DictionaryArray"]]:
12991357 .. versionadded:: 3.1.0
13001358
13011359 """
1302- if not is_pyarrow_available ():
1303- raise ImportError ("`pyarrow` is required for exporting categories." )
1304-
1305- if TYPE_CHECKING :
1306- import pyarrow as pa
1307- else :
1308- pa = import_pyarrow ()
1309-
1310- n_features = self .num_col ()
1311- fnames = self .feature_names
1312- if fnames is None :
1313- fnames = [str (i ) for i in range (n_features )]
1314-
1315- results : Dict [str , "pa.DictionaryArray" ] = {}
1316-
1317- ret = ctypes .c_char_p ()
1318- _check_call (_LIB .XGBDMatrixGetCategories (self .handle , ctypes .byref (ret )))
1319- if ret .value is None :
1320- return None
1321-
1322- retstr = ret .value .decode () # pylint: disable=no-member
1323- jcats = json .loads (retstr )
1324- assert isinstance (jcats , list ) and len (jcats ) == n_features
1325-
1326- for fidx in range (n_features ):
1327- f_jcats = jcats [fidx ]
1328- if f_jcats is None :
1329- # Numeric data
1330- results [fnames [fidx ]] = None
1331- continue
1332-
1333- if "offsets" not in f_jcats :
1334- values = from_array_interface (f_jcats )
1335- pa_values = pa .Array .from_pandas (values )
1336- results [fnames [fidx ]] = pa_values
1337- continue
1338-
1339- joffsets = f_jcats ["offsets" ]
1340- jvalues = f_jcats ["values" ]
1341- offsets = from_array_interface (joffsets , True )
1342- values = from_array_interface (jvalues , True )
1343- pa_offsets = pa .array (offsets ).buffers ()
1344- pa_values = pa .array (values ).buffers ()
1345- assert (
1346- pa_offsets [0 ] is None and pa_values [0 ] is None
1347- ), "Should not have null mask."
1348- pa_dict = pa .StringArray .from_buffers (
1349- len (offsets ) - 1 , pa_offsets [1 ], pa_values [1 ]
1350- )
1351- results [fnames [fidx ]] = pa_dict
1352-
1353- return results
1360+ return _get_categories (
1361+ lambda ret : _LIB .XGBDMatrixGetCategories (self .handle , ctypes .byref (ret )),
1362+ self .feature_names ,
1363+ self .num_col (),
1364+ )
13541365
13551366 def num_row (self ) -> int :
13561367 """Get the number of rows in the DMatrix."""
@@ -2312,6 +2323,23 @@ def feature_names(self) -> Optional[FeatureNames]:
23122323 def feature_names (self , features : Optional [FeatureNames ]) -> None :
23132324 self ._set_feature_info (features , "feature_name" )
23142325
def get_categories(self) -> Optional[Dict[str, "pa.DictionaryArray"]]:
    """Get the categories using `pyarrow`. Returns `None` if there are no
    categorical features.

    The categories are fetched through ``XGBoosterGetCategories`` on this
    object's handle and keyed by :py:attr:`feature_names`.

    .. warning::

        This function is still a work in progress.

    .. versionadded:: 3.1.0

    """

    def fetch(out: ctypes.c_char_p) -> int:
        # Pass the output pointer by reference so the C API can fill it in.
        return _LIB.XGBoosterGetCategories(self.handle, ctypes.byref(out))

    return _get_categories(fetch, self.feature_names, self.num_features())
2342+
23152343 def set_param (
23162344 self ,
23172345 params : Union [Dict , Iterable [Tuple [str , Any ]], str ],
0 commit comments