fix: Reduce K-means memory from O(n·k·d) to O(n·k) and fix empty cluster NaN #333

Merged
josevalim merged 3 commits into elixir-nx:main from georgeguimaraes:fix/kmeans-memory-optimization
Apr 27, 2026

Conversation


@georgeguimaraes (Contributor) commented Apr 14, 2026

Problem

calculate_inertia materializes two {runs, k*n, d} tensors via Nx.broadcast + Nx.tile to compute pairwise distances. For moderate inputs (n=1000, k=20, d=1536), each call allocates ~245MB — and it's called every iteration. With EXLA on CPU, XLA's BFC allocator holds onto freed buffers between iterations, causing RSS to spike well beyond the live data size. We hit repeated OOM kills in production on a 32GB container with ~7000 conversation embeddings.
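The ~245 MB figure follows directly from the tensor shapes. A back-of-the-envelope check (assuming float32 and a single run; illustrative, not the PR's benchmark):

```python
# Rough size of the two {runs, k*n, d} temporaries that the original
# broadcast + tile path materializes per call (float32, runs = 1).
n, k, d = 1000, 20, 1536
bytes_per_f32 = 4

one_temp = n * k * d * bytes_per_f32   # one {1, k*n, d} temporary
total = 2 * one_temp                   # Nx.broadcast and Nx.tile each make one
print(total / 1e6)                     # → 245.76 (MB), per call, every iteration
```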

The centroid update step also divides by group_sizes without guarding against empty clusters, producing NaN that propagates through all subsequent iterations and causes K-means to exit after 1 iteration.

Changes

1. Memory-efficient distance computation

Replace the broadcast+tile approach in calculate_inertia with the algebraic identity:

||x - c||² = ||x||² + ||c||² - 2·x·cᵀ

The dot product x·cᵀ is computed via Nx.dot (a single GEMM call), producing a {runs, k, n} tensor instead of {runs, k*n, d}. Peak memory drops from O(runs·k·n·d) to O(runs·k·n).

One subtlety: k-means++ pads unused centroid slots with infinity during initialization. The original ||x - c||² formula returns inf for those slots naturally, but the expanded form produces inf - inf = NaN, which breaks the distance-weighted sampling that picks the next initial centroid. The new code maps NaN back to inf to preserve k-means++ behavior.
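Both the expanded form and the NaN-to-inf guard can be sketched in a few lines of NumPy (illustrative only; the actual change is written against the Nx API, and pairwise_sq_dist is a hypothetical name):

```python
import numpy as np

def pairwise_sq_dist(x, c):
    """||x_i - c_j||^2 via the expanded identity, with the k-means++ guard.

    x: (n, d) points; c: (k, d) centroids, where unused slots may be +inf.
    Returns a (k, n) matrix of squared distances.
    """
    x_sq = np.sum(x * x, axis=1)    # (n,)  ||x||^2
    c_sq = np.sum(c * c, axis=1)    # (k,)  ||c||^2
    cross = c @ x.T                 # (k, n): one GEMM, no (k*n, d) temporary
    d2 = c_sq[:, None] + x_sq[None, :] - 2.0 * cross
    # Padded inf centroids yield inf - inf = NaN here; map back to inf so
    # k-means++ distance-weighted sampling still ignores the unused slots.
    return np.where(np.isnan(d2), np.inf, d2)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
c = np.vstack([rng.normal(size=(2, 3)), np.full((1, 3), np.inf)])  # one padded slot

d2 = pairwise_sq_dist(x, c)
direct = np.sum((c[:, None, :] - x[None, :, :]) ** 2, axis=-1)     # reference form
```

The reference form returns inf for the padded row naturally; the expanded form only matches it after the NaN remapping.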

2. Empty cluster handling

When a cluster has zero members, keep the previous centroid instead of computing 0/0 = NaN. This is the standard approach used by scikit-learn. Without this fix, NaN propagates through all subsequent iterations, and the convergence check distance > tol evaluates to false (every comparison involving NaN is false), causing the while loop to exit after a single iteration.
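The behavior described above can be sketched in NumPy (update_centroids is a hypothetical helper mirroring the described fix, not the PR's Nx code):

```python
import numpy as np

def update_centroids(x, labels, old_centroids):
    """Per-cluster mean; empty clusters keep their previous centroid.

    x: (n, d) points; labels: (n,) cluster ids; old_centroids: (k, d).
    """
    k, d = old_centroids.shape
    sums = np.zeros((k, d))
    np.add.at(sums, labels, x)                # per-cluster coordinate sums
    sizes = np.bincount(labels, minlength=k)  # members per cluster
    safe = np.maximum(sizes, 1)[:, None]      # never divide 0/0
    means = sums / safe
    # An empty cluster's "mean" of 0 would be wrong too; keep its old centroid.
    return np.where(sizes[:, None] == 0, old_centroids, means)

x = np.array([[0.0, 0.0], [2.0, 0.0], [4.0, 4.0]])
labels = np.array([0, 0, 1])                  # cluster 2 is empty
old = np.ones((3, 2))
new = update_centroids(x, labels, old)        # row 2 stays at old value, no NaN
```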

Measurements

RSS delta measured with EXLA backend on 1536-dimensional OpenAI embeddings, k=20:

| n | Before | After | Reduction |
|---|---|---|---|
| 100 | 1,604 MB | 146 MB | 11x |
| 500 | 7,768 MB | 755 MB | 10x |
| 1,000 | OOM (>32 GB) | 707 MB | |

Numerical accuracy

The expansion ||x||² + ||c||² - 2·x·cᵀ has slightly more floating-point cancellation than direct subtraction for nearby points. Measured max absolute difference against the original implementation:

| Dimensions | Max diff |
|---|---|
| d=8 | 8.9e-7 |
| d=16 | 2.9e-6 |
| d=1,536 | 1.6e-3 |

These differences are negligible for argmin-based cluster assignment — the cluster each point is assigned to is unaffected.
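A quick NumPy experiment (illustrative; not the benchmark behind the table above) makes the argmin claim easy to spot-check: the two formulas disagree by a small absolute amount in float32, while cluster assignments stay the same:

```python
import numpy as np

rng = np.random.default_rng(42)
n, k, d = 200, 20, 1536
x = rng.normal(size=(n, d)).astype(np.float32)
c = rng.normal(size=(k, d)).astype(np.float32)

# Direct subtraction vs. the expanded identity, both (n, k).
direct = np.sum((x[:, None, :] - c[None, :, :]) ** 2, axis=-1)
expanded = (np.sum(x * x, axis=1)[:, None]
            + np.sum(c * c, axis=1)[None, :]
            - 2.0 * (x @ c.T))

max_diff = np.abs(direct - expanded).max()  # small but nonzero in float32
mismatch = np.mean(direct.argmin(axis=1) != expanded.argmin(axis=1))
```

In this experiment the per-entry error stays orders of magnitude below the gaps between a point's nearest and second-nearest centroids, so the nearest centroid is unchanged.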

fix: Reduce K-means memory from O(n·k·d) to O(n·k) and fix empty cluster NaN

Two changes:

1. calculate_inertia: use ||x-c||^2 = ||x||^2 + ||c||^2 - 2·x·cᵀ to
   compute distances via matrix multiply instead of broadcasting a
   {runs, k*n, d} tensor. Peak memory drops from O(n*k*d) to O(n*k).

   Measured RSS with EXLA (1536-dim embeddings, k=20):
     n=100:  1604MB → 146MB  (11x)
     n=500:  7768MB → 755MB  (10x)
     n=1000: OOM → 707MB

2. Centroid update: when a cluster has zero members, keep the previous
   centroid instead of computing 0/0 = NaN. This is the standard fix
   used by scikit-learn. Without it, NaN propagates through all
   subsequent iterations, causing K-means to exit after 1 iteration.
@georgeguimaraes

I'll review the CI failures here.

The empty-cluster fix changes which run wins the lowest-inertia
tie-break across num_runs initializations, swapping the label
permutation. Pin tests and doctests to the new ordering.

The new ||x||² + ||c||² - 2·x·cᵀ expansion produces inf - inf = NaN
when k-means++ pads unused centroid slots with infinity, breaking the
weighted sampling that picks the next initial centroid. The original
direct ||x - c||² formula returned inf naturally.

Mapping NaN back to inf restores k-means++ behavior. With this fix,
clustering matches main exactly, so revert the test pinning from the
previous commit.
@georgeguimaraes

CI green :)

@josevalim josevalim merged commit daf2f6b into elixir-nx:main Apr 27, 2026
2 checks passed
@josevalim

💚 💙 💜 💛 ❤️
