fix: Reduce K-means memory from O(n·k·d) to O(n·k) and fix empty cluster NaN #333
Merged
josevalim merged 3 commits into elixir-nx:main on Apr 27, 2026
Conversation
fix: Reduce K-means memory from O(n·k·d) to O(n·k) and fix empty cluster NaN
Two changes:
1. calculate_inertia: use ||x-c||^2 = ||x||^2 + ||c||^2 - 2·x·cᵀ to
compute distances via matrix multiply instead of broadcasting a
{runs, k*n, d} tensor. Peak memory drops from O(n*k*d) to O(n*k).
Measured RSS with EXLA (1536-dim embeddings, k=20):
n=100: 1604MB → 146MB (11x)
n=500: 7768MB → 755MB (10x)
n=1000: OOM → 707MB
2. Centroid update: when a cluster has zero members, keep the previous
centroid instead of computing 0/0 = NaN. This is the standard fix
used by scikit-learn. Without it, NaN propagates through all
subsequent iterations, causing K-means to exit after 1 iteration.
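The two changes can be sketched in NumPy (the actual implementation is Elixir/Nx `defn` code; the function names here are illustrative):

```python
import numpy as np

def squared_distances(x, c):
    """Pairwise squared distances via ||x-c||^2 = ||x||^2 + ||c||^2 - 2*x.c^T.

    x: (n, d) points, c: (k, d) centroids -> (n, k) distances.
    Peak memory is O(n*k) for the result, instead of the O(n*k*d)
    broadcasted difference tensor the direct formula materializes.
    """
    x_sq = np.sum(x * x, axis=1, keepdims=True)   # (n, 1)
    c_sq = np.sum(c * c, axis=1)                  # (k,)
    cross = x @ c.T                               # (n, k), a single GEMM
    d2 = x_sq + c_sq - 2.0 * cross
    return np.maximum(d2, 0.0)  # clamp tiny negatives from cancellation

def update_centroids(x, labels, old_centroids):
    """Mean of each cluster; empty clusters keep their previous centroid."""
    k, d = old_centroids.shape
    sums = np.zeros((k, d))
    np.add.at(sums, labels, x)                    # scatter-add points to clusters
    counts = np.bincount(labels, minlength=k).astype(float)
    empty = counts == 0
    means = sums / np.where(empty, 1.0, counts)[:, None]  # avoid 0/0 = NaN
    return np.where(empty[:, None], old_centroids, means)
```

The empty-cluster guard replaces the division result only where a cluster has no members, so non-empty clusters are updated exactly as before.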
Contributor (Author)

I'll review the CI failures here.
The empty-cluster fix changes which run wins the lowest-inertia tie-break across num_runs initializations, swapping the label permutation. Pin tests and doctests to the new ordering.
The new ||x||² + ||c||² - 2·x·cᵀ expansion produces inf - inf = NaN when k-means++ pads unused centroid slots with infinity, breaking the weighted sampling that picks the next initial centroid. The original direct ||x - c||² formula returned inf naturally. Mapping NaN back to inf restores k-means++ behavior. With this fix, clustering matches main exactly, so revert the test pinning from the previous commit.
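The inf-padding subtlety can be reproduced in a few lines of NumPy (illustrative; the implementation uses Nx):

```python
import numpy as np

# k-means++ pads unused centroid slots with +inf until all k centroids
# are chosen. Direct subtraction keeps distances to padded slots at +inf:
x = np.array([1.0, 2.0])
pad = np.array([np.inf, np.inf])
direct = np.sum((x - pad) ** 2)  # +inf, as the weighted sampling expects

# The expanded identity hits inf - inf:
expanded = np.sum(x * x) + np.sum(pad * pad) - 2.0 * np.dot(x, pad)  # NaN

# Mapping NaN back to inf restores the original k-means++ behavior:
fixed = np.where(np.isnan(expanded), np.inf, expanded)
```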
Contributor (Author)

CI green :)
josevalim approved these changes on Apr 27, 2026
Contributor

💚 💙 💜 💛 ❤️
Problem
calculate_inertia materializes two {runs, k*n, d} tensors via Nx.broadcast + Nx.tile to compute pairwise distances. For moderate inputs (n=1000, k=20, d=1536), each call allocates ~245MB, and it is called every iteration. With EXLA on CPU, XLA's BFC allocator holds onto freed buffers between iterations, so RSS spikes well beyond the live data size. We hit repeated OOM kills in production on a 32GB container with ~7000 conversation embeddings.

The centroid update step also divides by group_sizes without guarding against empty clusters, producing NaN that propagates through all subsequent iterations and causes K-means to exit after 1 iteration.

Changes
1. Memory-efficient distance computation
Replace the broadcast+tile approach in calculate_inertia with the algebraic identity ||x - c||² = ||x||² + ||c||² - 2·x·cᵀ. The dot product x·cᵀ is computed via Nx.dot (a single GEMM call), producing a {runs, k, n} tensor instead of {runs, k*n, d}. Peak memory drops from O(runs·k·n·d) to O(runs·k·n).

One subtlety: k-means++ pads unused centroid slots with infinity during initialization. The original ||x - c||² formula returns inf for those slots naturally, but the expanded form produces inf - inf = NaN, which breaks the distance-weighted sampling that picks the next initial centroid. The new code maps NaN back to inf to preserve k-means++ behavior.

2. Empty cluster handling
When a cluster has zero members, keep the previous centroid instead of computing 0/0 = NaN. This is the standard approach used by scikit-learn. Without this fix, NaN propagates through all subsequent iterations and the convergence check distance > tol evaluates to false (NaN comparisons are always false), so the while loop exits after 1 iteration.

Measurements
RSS delta measured with the EXLA backend on 1536-dimensional OpenAI embeddings, k=20 (numbers as reported in the commit message):

| n | before | after | reduction |
|---|--------|-------|-----------|
| 100 | 1604 MB | 146 MB | 11x |
| 500 | 7768 MB | 755 MB | 10x |
| 1000 | OOM | 707 MB | n/a |
Numerical accuracy
The expansion ||x||² + ||c||² - 2·x·cᵀ has slightly more floating-point cancellation than direct subtraction for nearby points. We measured the max absolute difference against the original implementation; the differences are negligible for argmin-based cluster assignment, and the cluster each point is assigned to is unaffected.
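The effect can be checked with a quick NumPy sketch (illustrative, not the benchmark used for this PR): the two formulas disagree only at rounding-error scale, and the per-point argmin agrees.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(500, 64))       # points
c = rng.normal(size=(8, 64)) * 4.0   # well-separated centroids

# Direct subtraction: materializes a (500, 8, 64) difference tensor.
direct = ((x[:, None, :] - c[None, :, :]) ** 2).sum(axis=-1)

# Expanded identity: one GEMM, only a (500, 8) result tensor.
expanded = (x * x).sum(axis=1, keepdims=True) + (c * c).sum(axis=1) \
    - 2.0 * (x @ c.T)

max_abs_diff = np.abs(direct - expanded).max()  # rounding-error scale
same_assignment = (direct.argmin(axis=1) == expanded.argmin(axis=1)).all()
```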