Find similar observations using leaf node matching #11919

@mfdel

I'm proposing a feature to find which training samples are "similar" to a prediction sample, based on whether they end up in the same leaf nodes across trees.

The idea is simple: if two observations consistently land in the same leaves across many trees, the model "sees" them as similar.

I already built a function to do this for my use case. I was predicting sales potential for a set of geographies, and the model's prediction for one zip code was far too high. I couldn't figure out why from the features or feature importances alone.

So I wrote a wrapper that checks: for this zip code, which training observations land in the same leaf nodes most often? It turned out that a tiny zip code on the other side of the world matched 83% of the time. Looking at the features, the two were strangely similar. That training observation was an outlier, and I later excluded it from training.

I would never have found this by looking at Euclidean distance or raw feature values. The model knew these were similar. I just needed a way to ask it.

This is how it works

  1. Use pred_leaf=True to get leaf indices for each tree
  2. For each tree, check whether the query and the reference observation land in the same leaf (boolean)
  3. Average across trees → similarity score between 0 and 1. I also weight trees by the variance of their leaf predictions, because trees whose leaves all predict roughly the same value aren't very discriminative.
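
The steps above can be sketched in plain Python. This is only an illustration under stated assumptions, not the wrapper described in this issue: the names `leaf_similarity`, `top_k`, and `tree_weights` are hypothetical, the leaf indices are assumed to have been obtained already with pred_leaf=True, and the per-tree weights (for example, the variance of each tree's leaf predictions) are assumed to be supplied by the caller. Dividing by the sum of the weights is one possible normalisation that keeps the score between 0 and 1.

```python
def leaf_similarity(query_leaves, ref_leaves, tree_weights=None):
    """Score every reference row against a single query row.

    query_leaves: leaf index per tree for the query (length n_trees).
    ref_leaves:   rows of leaf indices, one row per reference observation.
    tree_weights: optional non-negative weight per tree (hypothetical
                  parameter; uniform weights are used when omitted).
    """
    n_trees = len(query_leaves)
    if tree_weights is None:
        tree_weights = [1.0] * n_trees
    total = float(sum(tree_weights))
    if total <= 0:
        raise ValueError("tree_weights must have a positive sum")

    scores = []
    for row in ref_leaves:
        if len(row) != n_trees:
            raise ValueError("every reference row needs one leaf index per tree")
        # Add up the weights of the trees where both rows share a leaf.
        matched = sum(w for q, r, w in zip(query_leaves, row, tree_weights) if q == r)
        scores.append(matched / total)  # assumed normalisation to [0, 1]
    return scores


def top_k(scores, k):
    """Indices of the k highest scores, best first (ties keep original order)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


# Toy leaf indices for a model with 4 trees.
query = [3, 7, 1, 4]
reference = [
    [3, 7, 1, 4],  # same leaf as the query in all 4 trees
    [3, 2, 1, 0],  # same leaf in 2 of 4 trees
    [5, 6, 2, 9],  # no shared leaves
]

scores = leaf_similarity(query, reference)                          # [1.0, 0.5, 0.0]
weighted = leaf_similarity(query, reference, [1.0, 1.0, 2.0, 0.0])  # [1.0, 0.75, 0.0]
print(top_k(scores, 2))                                             # [0, 1]
```

In the weighted call, the third tree counts double and the fourth is ignored, so the second reference row moves from 0.5 to 0.75.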

What the API could look like

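import xgboost as xgb

# High-level (proposed; find_similar does not exist yet)
similar_idx, scores = model.find_similar(
    query=X_query,
    reference=X_train,
    k=5,
)

# Or lower-level on the booster
# pred_leaf=True returns the leaf index of each sample in each tree
query_leaves = booster.predict(xgb.DMatrix(X_query), pred_leaf=True)
ref_leaves = booster.predict(xgb.DMatrix(X_train), pred_leaf=True)
# Proposed; compute_leaf_similarity does not exist yet
similarity = booster.compute_leaf_similarity(query_leaves, ref_leaves)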

Why this is useful

  • Debugging predictions: "why is this prediction so high?" → find similar training samples and inspect them
  • Finding bad training data: outliers in training can affect predictions in unexpected places
  • Explaining to stakeholders: "this prediction is similar to these 5 historical cases" is easier to trust than a black box number

I haven't used them, but I found out that Random Forests have something similar: proximity matrices. It would be nice to have this in XGBoost.

Questions

  • Is this something that fits in XGBoost's scope, or better as a separate utility?
  • Any concerns about scaling to large datasets? (Leaf prediction is fast, and the similarity step is just a broadcast comparison of leaf indices.)
  • Happy to put together a PR if there's interest
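
On the scaling question: comparing every query row with every reference row in a single broadcast materialises a boolean array of n_query × n_reference × n_trees entries, which can get large. One possible way to bound memory is to stream the reference leaves in batches and keep only a running top-k per query. The sketch below is a minimal illustration of that idea; `stream_top_k` and `ref_batches` are hypothetical names, the scores are unweighted, and each batch is assumed to be a sized sequence of rows.

```python
import heapq


def stream_top_k(query_leaves, ref_batches, k):
    """Return the k best (score, row_index) pairs, highest score first.

    ref_batches is any iterable of batches; each batch is a sized sequence
    of leaf-index rows (hypothetical input layout). Only k candidates are
    held in memory at any time.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    n_trees = len(query_leaves)
    heap = []    # min-heap; heap[0] is the weakest candidate currently kept
    offset = 0   # global index of the first row in the current batch
    for batch in ref_batches:
        for i, row in enumerate(batch):
            score = sum(q == r for q, r in zip(query_leaves, row)) / n_trees
            # The negated index makes ties favour the earlier row.
            item = (score, -(offset + i))
            if len(heap) < k:
                heapq.heappush(heap, item)
            elif item > heap[0]:
                heapq.heapreplace(heap, item)
        offset += len(batch)
    return [(score, -neg_idx) for score, neg_idx in sorted(heap, reverse=True)]


# Toy example: 4 trees, reference rows arriving in two batches.
query = [3, 7, 1, 4]
batches = [
    [[3, 7, 1, 4], [5, 6, 2, 9]],  # global rows 0 and 1
    [[3, 2, 1, 0], [3, 7, 1, 0]],  # global rows 2 and 3
]
best = stream_top_k(query, batches, k=2)
print(best)  # [(1.0, 0), (0.75, 3)]
```

A vectorised version could apply the same batching with array comparisons inside each batch; the pure-Python loop here is only meant to show the bounded-memory pattern.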
