Commit daf2f6b
authored
fix: Reduce K-means memory from O(n·k·d) to O(n·k) and fix empty cluster NaN (#333)
Two changes:
1. calculate_inertia: use ||x-c||^2 = ||x||^2 + ||c||^2 - 2·x·cᵀ to
compute distances via matrix multiply instead of broadcasting a
{runs, k*n, d} tensor. Peak memory drops from O(n*k*d) to O(n*k).
Measured RSS with EXLA (1536-dim embeddings, k=20):
n=100: 1604MB → 146MB (11x)
n=500: 7768MB → 755MB (10x)
n=1000: OOM → 707MB
2. Centroid update: when a cluster has zero members, keep the previous
centroid instead of computing 0/0 = NaN. This is the standard fix
used by scikit-learn. Without it, NaN propagates through all
subsequent iterations, causing K-means to exit after 1 iteration.1 parent 7937730 commit daf2f6b
1 file changed
Lines changed: 28 additions & 14 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
181 | 181 | | |
182 | 182 | | |
183 | 183 | | |
| 184 | + | |
184 | 185 | | |
185 | | - | |
| 186 | + | |
186 | 187 | | |
187 | | - | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
188 | 197 | | |
189 | 198 | | |
190 | 199 | | |
| |||
228 | 237 | | |
229 | 238 | | |
230 | 239 | | |
231 | | - | |
232 | | - | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
233 | 247 | | |
234 | | - | |
235 | | - | |
236 | | - | |
237 | | - | |
238 | | - | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
239 | 252 | | |
| 253 | + | |
| 254 | + | |
240 | 255 | | |
241 | | - | |
242 | | - | |
243 | | - | |
244 | | - | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
245 | 260 | | |
246 | | - | |
247 | 261 | | |
248 | 262 | | |
249 | 263 | | |
| |||
0 commit comments