Skip to content

Commit daf2f6b

Browse files
fix: Reduce K-means memory from O(n·k·d) to O(n·k) and fix empty cluster NaN (#333)
Two changes:

1. `calculate_inertia`: use the identity ||x − c||² = ||x||² + ||c||² − 2·x·cᵀ to compute distances via a matrix multiply instead of broadcasting a {runs, k*n, d} tensor. Peak memory drops from O(n·k·d) to O(n·k). Measured RSS with EXLA (1536-dim embeddings, k=20): n=100: 1604 MB → 146 MB (11x); n=500: 7768 MB → 755 MB (10x); n=1000: OOM → 707 MB.

2. Centroid update: when a cluster has zero members, keep the previous centroid instead of computing 0/0 = NaN. This is the standard fix used by scikit-learn. Without it, the NaN propagates through all subsequent iterations, causing K-means to exit after a single iteration.
1 parent 7937730 commit daf2f6b

1 file changed

Lines changed: 28 additions & 14 deletions

File tree

lib/scholar/cluster/k_means.ex

Lines changed: 28 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -181,10 +181,19 @@ defmodule Scholar.Cluster.KMeans do
181181
broadcast_weights
182182

183183
group_sizes = Nx.sum(group_masks, axes: [2], keep_axes: true)
184+
empty_clusters = group_sizes == 0
184185

185-
centroids =
186+
new_centroids =
186187
((Nx.new_axis(group_masks, -1) * Nx.new_axis(broadcast_x, 1)) |> Nx.sum(axes: [2])) /
187-
group_sizes
188+
Nx.max(group_sizes, 1)
189+
190+
# Keep previous centroid for empty clusters instead of NaN from 0/0
191+
centroids =
192+
Nx.select(
193+
Nx.broadcast(empty_clusters, Nx.shape(new_centroids)),
194+
previous_iteration_centroids,
195+
new_centroids
196+
)
188197

189198
distance =
190199
Scholar.Metrics.Distance.squared_euclidean(centroids, previous_iteration_centroids,
@@ -228,22 +237,27 @@ defmodule Scholar.Cluster.KMeans do
228237
end
229238
end
230239

231-
defnp calculate_inertia(x, centroids, num_clusters, num_runs) do
232-
{num_samples, num_features} = Nx.shape(x)
240+
defnp calculate_inertia(x, centroids, _num_clusters, _num_runs) do
241+
# Use the identity ||x - c||^2 = ||x||^2 + ||c||^2 - 2·x·cᵀ
242+
# to compute distances via matrix multiply instead of broadcasting.
243+
# Peak memory is O(runs*k*n) instead of O(runs*k*n*d).
244+
x_sq = Nx.sum(x * x, axes: [1])
245+
c_sq = Nx.sum(centroids * centroids, axes: [2])
246+
dot = Nx.dot(centroids, [2], x, [1])
233247

234-
modified_centroids =
235-
centroids
236-
|> Nx.new_axis(2)
237-
|> Nx.broadcast({num_runs, num_clusters, num_samples, num_features})
238-
|> Nx.reshape({num_runs, num_clusters * num_samples, num_features})
248+
inertia_for_centroids =
249+
Nx.new_axis(Nx.new_axis(x_sq, 0), 0) +
250+
Nx.new_axis(c_sq, 2) -
251+
2 * dot
239252

253+
# k-means++ pads unused centroid slots with infinity. The expansion
254+
# produces inf - inf = NaN there; restore inf so weighted sampling works.
240255
inertia_for_centroids =
241-
Scholar.Metrics.Distance.squared_euclidean(
242-
Nx.tile(x, [num_runs, num_clusters, 1]),
243-
modified_centroids,
244-
axes: [2]
256+
Nx.select(
257+
Nx.is_nan(inertia_for_centroids),
258+
Nx.Constants.infinity(Nx.type(inertia_for_centroids)),
259+
inertia_for_centroids
245260
)
246-
|> Nx.reshape({num_runs, num_clusters, num_samples})
247261

248262
{inertia_for_centroids, Nx.reduce_min(inertia_for_centroids, axes: [1])}
249263
end

0 commit comments

Comments
 (0)