Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions python-package/xgboost/testing/ordinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,44 @@ def run_invalid(DMatrixT: Type) -> None:
)


def run_cat_oov_in_range(device: Device) -> None:
    """Test that in-range out-of-vocabulary (OOV) categories are rejected.

    Every case predicts on a category that is absent from the training data but
    whose sorted insertion point lies inside the training categories.  Both
    ``inplace_predict`` and ``predict`` are expected to raise a ``ValueError`` for
    each of the matrix types exercised.

    Parameters
    ----------
    device :
        Forwarded to ``get_df_impl`` to select the DataFrame implementation and
        passed to ``train`` as the ``device`` parameter.
    """
    Df, _ = get_df_impl(device)
    oov_error = "Found a category not in the training"

    def categorical_frame(columns: Any) -> Any:
        # Build the frame first, then convert only the "cat" column to a
        # categorical dtype; "x" stays numeric.
        frame = Df(columns)
        frame["cat"] = frame["cat"].astype("category")
        return frame

    def expect_rejected(fit_frame: Any, query_frame: Any, labels: np.ndarray) -> None:
        for matrix_cls in (DMatrix, QuantileDMatrix):
            fit_data = matrix_cls(fit_frame, labels, enable_categorical=True)
            model = train({"device": device}, fit_data, num_boost_round=1)

            # Prediction directly from the DataFrame.
            with pytest.raises(ValueError, match=oov_error):
                model.inplace_predict(query_frame)

            # Prediction through an explicitly constructed matrix.
            query_data = matrix_cls(query_frame, enable_categorical=True)
            with pytest.raises(ValueError, match=oov_error):
                model.predict(query_data)

    cases = (
        # "B2" is absent from training but sorts between "B" and "C".
        (
            {"cat": ["A", "B", "C", "A", "B", "B"], "x": [1, 2, 3, 4, 5, 6]},
            {"cat": ["B2", "C"], "x": [2, 2]},
        ),
        # Numeric category 4 is absent from training but sorts between 3 and 5.
        #
        # XGBoost accepts only dense codes. In this test case, even though the
        # categories are not contiguous, the codes can still be dense.
        (
            {"cat": [1, 3, 5, 1, 3, 3], "x": [1, 2, 3, 4, 5, 6]},
            {"cat": [4, 5], "x": [2, 2]},
        ),
    )
    for fit_columns, query_columns in cases:
        fit_frame = categorical_frame(fit_columns)
        query_frame = categorical_frame(query_columns)
        expect_rejected(fit_frame, query_frame, np.array([0, 1, 0, 1, 0, 1]))


def run_cat_thread_safety(device: Device) -> None:
"""Basic tests for thread safety."""
X, y = make_categorical(2048, 16, 112, onehot=False, cat_ratio=0.5, device=device)
Expand Down
13 changes: 11 additions & 2 deletions src/encoder/ordinal.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,14 @@ struct SegmentedSearchSortedStrOp {
if (ret_it == it + f_sorted_idx.size()) {
return detail::NotFound();
}
return *ret_it;
// Handle OOV category
auto sorted_idx = *ret_it;
auto candidate_idx = f_sorted_idx[sorted_idx];
auto candidate_beg = haystack.offsets[candidate_idx];
auto candidate_end = haystack.offsets[candidate_idx + 1];
auto candidate =
haystack.values.subspan(candidate_beg, candidate_end - candidate_beg);
return candidate == needle ? sorted_idx : detail::NotFound();
}
};

Expand Down Expand Up @@ -105,7 +112,9 @@ struct SegmentedSearchSortedNumOp {
if (ret_it == it + f_sorted_idx.size()) {
return detail::NotFound();
}
return *ret_it;
auto sorted_idx = *ret_it;
auto candidate = haystack[f_sorted_idx[sorted_idx]];
return candidate == needle ? sorted_idx : detail::NotFound();
}
};

Expand Down
16 changes: 12 additions & 4 deletions src/encoder/ordinal.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,13 @@ void ArgSort(InIt in_first, InIt in_last, OutIt out_first, Comp comp = std::less
if (ret_it == it + haystack.size()) {
return detail::NotFound();
}
return *ret_it;
// Handle OOV category
auto sorted_idx = *ret_it;
auto candidate_idx = ref_sorted_idx[sorted_idx];
auto candidate_beg = h_off[candidate_idx];
auto candidate_end = h_off[candidate_idx + 1];
auto candidate = h_data.subspan(candidate_beg, candidate_end - candidate_beg);
return candidate == needle ? sorted_idx : detail::NotFound();
}

template <typename T>
Expand All @@ -265,7 +271,9 @@ SearchSorted(Span<T const> haystack, Span<std::int32_t const> ref_sorted_idx, T
if (ret_it == it + haystack.size()) {
return detail::NotFound();
}
return *ret_it;
auto sorted_idx = *ret_it;
auto candidate = haystack[ref_sorted_idx[sorted_idx]];
return candidate == needle ? sorted_idx : detail::NotFound();
}

template <typename ExecPolicy>
Expand Down Expand Up @@ -344,8 +352,8 @@ void Recode(ExecPolicy const &policy, HostColumnsView orig_enc, Span<std::int32_

std::size_t out_idx = 0;
for (std::size_t f_idx = 0, n_features = orig_enc.Size(); f_idx < n_features; f_idx++) {
auto const& l_f = orig_enc.columns[f_idx];
auto const& r_f = new_enc.columns[f_idx];
auto const &l_f = orig_enc.columns[f_idx];
auto const &r_f = new_enc.columns[f_idx];
auto report = [&] {
std::stringstream ss;
ss << "Invalid new DataFrame input for the: " << f_idx << "th feature (0-based). "
Expand Down
5 changes: 5 additions & 0 deletions tests/python-gpu/test_gpu_ordinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
run_cat_container_mixed,
run_cat_invalid,
run_cat_leaf,
run_cat_oov_in_range,
run_cat_predict,
run_cat_shap,
run_cat_thread_safety,
Expand Down Expand Up @@ -48,6 +49,10 @@ def test_cat_invalid() -> None:
run_cat_invalid("cuda")


def test_cat_oov_in_range() -> None:
    """Run the in-range out-of-vocabulary category check with device ``"cuda"``."""
    run_cat_oov_in_range("cuda")


def test_cat_thread_safety() -> None:
run_cat_thread_safety("cuda")

Expand Down
6 changes: 5 additions & 1 deletion tests/python/test_ordinal.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import pytest

from xgboost import testing as tm
from xgboost.testing.ordinal import (
run_cat_container,
run_cat_container_iter,
run_cat_container_mixed,
run_cat_invalid,
run_cat_leaf,
run_cat_oov_in_range,
run_cat_predict,
run_cat_shap,
run_cat_thread_safety,
Expand Down Expand Up @@ -41,6 +41,10 @@ def test_cat_invalid() -> None:
run_cat_invalid("cpu")


def test_cat_oov_in_range() -> None:
    """Run the in-range out-of-vocabulary category check with device ``"cpu"``."""
    run_cat_oov_in_range("cpu")


def test_cat_thread_safety() -> None:
run_cat_thread_safety("cpu")

Expand Down
Loading