Skip to content

Commit dfe5bd5

Browse files
simbasimba
authored andcommitted
feat: add auto-calibration 'Wisdom' system
FFTW-style auto-tuning that profiles optimal cache limits per model × hardware combination. First run benchmarks 4 configurations (tight → unlimited), measures tok/s, and persists the winner to ~/.mlx-server/wisdom/<key>.json. Subsequent runs load instantly. New CLI flags: --calibrate Force re-calibration even if wisdom exists Integration with startup flow: 1. Check for existing wisdom → apply instantly 2. Or run calibration trials → store + apply 3. --mem-limit always overrides wisdom Calibrator.swift: ~290 lines, zero new dependencies.
1 parent bb3946f commit dfe5bd5

2 files changed

Lines changed: 314 additions & 2 deletions

File tree

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
// Calibrator.swift — Auto-tuning "Wisdom" system for optimal inference config
2+
//
3+
// FFTW-style approach: profile once per (model, hardware) pair, store optimal
4+
// config, apply instantly on subsequent runs.
5+
//
6+
// On first run with a new model, the calibrator runs a short benchmark to find
7+
// the optimal cache limit that maximizes tok/s. The result is stored in
8+
// ~/.mlx-server/wisdom/<hash>.json and loaded directly on future runs.
9+
//
10+
// Usage:
11+
// let wisdom = try await Calibrator.calibrate(container: container, plan: plan, profile: profile)
12+
// // Apply: Memory.cacheLimit = wisdom.cacheLimit
13+
14+
import Foundation
15+
import MLX
16+
import MLXLMCommon
17+
import Tokenizers
18+
19+
// MARK: - Wisdom Entry
20+
21+
/// Persisted calibration result for a specific (model, hardware) combination.
22+
struct WisdomEntry: Codable {
23+
let modelId: String
24+
let hardwareFingerprint: String
25+
let cacheLimit: Int // bytes
26+
let gpuLayers: Int?
27+
let tokPerSec: Double
28+
let prefillTokPerSec: Double
29+
let ttftMs: Double
30+
let memoryPeakMB: Int
31+
let calibratedAt: Date
32+
let calibrationSeconds: Double
33+
}
34+
35+
// MARK: - Calibration Config
36+
37+
/// A single calibration trial configuration
38+
private struct CalibrationTrial {
39+
let cacheLimitBytes: Int
40+
let label: String
41+
}
42+
43+
// MARK: - Calibrator
44+
45+
enum Calibrator {
46+
47+
/// Directory for wisdom files
48+
private static var wisdomDirectory: URL {
49+
let home = FileManager.default.homeDirectoryForCurrentUser
50+
return home.appendingPathComponent(".mlx-server/wisdom")
51+
}
52+
53+
/// Hardware fingerprint: chip + memory + OS
54+
static func hardwareFingerprint() -> String {
55+
var sysinfo = utsname()
56+
uname(&sysinfo)
57+
let machine = withUnsafePointer(to: &sysinfo.machine) {
58+
$0.withMemoryRebound(to: CChar.self, capacity: Int(_SYS_NAMELEN)) {
59+
String(cString: $0)
60+
}
61+
}
62+
let memGB = Int(ProcessInfo.processInfo.physicalMemory / (1024 * 1024 * 1024))
63+
let os = ProcessInfo.processInfo.operatingSystemVersionString
64+
return "\(machine)_\(memGB)GB_\(os)"
65+
}
66+
67+
/// Unique key for a (model, hardware) pair
68+
private static func wisdomKey(modelId: String) -> String {
69+
let hw = hardwareFingerprint()
70+
let combined = "\(modelId)_\(hw)"
71+
// Simple hash: use the string itself, sanitized for filename
72+
let sanitized = combined
73+
.replacingOccurrences(of: "/", with: "_")
74+
.replacingOccurrences(of: " ", with: "_")
75+
.replacingOccurrences(of: "(", with: "")
76+
.replacingOccurrences(of: ")", with: "")
77+
return sanitized
78+
}
79+
80+
/// Load existing wisdom for a model, if available
81+
static func loadWisdom(modelId: String) -> WisdomEntry? {
82+
let key = wisdomKey(modelId: modelId)
83+
let path = wisdomDirectory.appendingPathComponent("\(key).json")
84+
85+
guard FileManager.default.fileExists(atPath: path.path) else { return nil }
86+
87+
do {
88+
let data = try Data(contentsOf: path)
89+
let decoder = JSONDecoder()
90+
decoder.dateDecodingStrategy = .iso8601
91+
return try decoder.decode(WisdomEntry.self, from: data)
92+
} catch {
93+
print("[mlx-server] ⚠️ Failed to load wisdom: \(error.localizedDescription)")
94+
return nil
95+
}
96+
}
97+
98+
/// Save wisdom entry to disk
99+
private static func saveWisdom(_ entry: WisdomEntry) throws {
100+
let key = wisdomKey(modelId: entry.modelId)
101+
let dir = wisdomDirectory
102+
103+
try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true)
104+
105+
let path = dir.appendingPathComponent("\(key).json")
106+
let encoder = JSONEncoder()
107+
encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
108+
encoder.dateEncodingStrategy = .iso8601
109+
let data = try encoder.encode(entry)
110+
try data.write(to: path)
111+
}
112+
113+
/// Run calibration: benchmark different cache limits, pick the best
114+
///
115+
/// Runs 3-5 short inference bursts at different cache limits and
116+
/// measures tok/s for each. Returns the optimal configuration.
117+
static func calibrate(
118+
container: ModelContainer,
119+
plan: PartitionPlan,
120+
modelId: String,
121+
contextSize: Int = 4096
122+
) async throws -> WisdomEntry {
123+
let startTime = Date()
124+
print("[mlx-server] 📊 Calibrating... (this only happens once per model × hardware)")
125+
126+
// Determine trial cache limits based on available memory
127+
let systemRAMBytes = Int(ProcessInfo.processInfo.physicalMemory)
128+
let modelWeightBytes = Int(plan.weightMemoryGB * 1e9)
129+
130+
// Trial cache limits: from tight (just weights + 20%) to generous (50% of free RAM)
131+
let freeRAMBytes = systemRAMBytes - modelWeightBytes
132+
let trials: [CalibrationTrial] = [
133+
CalibrationTrial(
134+
cacheLimitBytes: modelWeightBytes + modelWeightBytes / 5,
135+
label: "tight (weights + 20%)"
136+
),
137+
CalibrationTrial(
138+
cacheLimitBytes: modelWeightBytes + freeRAMBytes / 4,
139+
label: "moderate (weights + 25% free)"
140+
),
141+
CalibrationTrial(
142+
cacheLimitBytes: modelWeightBytes + freeRAMBytes / 2,
143+
label: "generous (weights + 50% free)"
144+
),
145+
CalibrationTrial(
146+
cacheLimitBytes: 0, // system default (no limit)
147+
label: "unlimited (system default)"
148+
),
149+
]
150+
151+
var bestTrial: (trial: CalibrationTrial, tokPerSec: Double, prefillTokPerSec: Double, ttft: Double)?
152+
153+
// Calibration prompt — short enough for fast iteration, long enough to measure
154+
let calibrationPrompt = "Explain the concept of machine learning in three sentences."
155+
let maxTokens = 30 // Just enough to measure steady-state decode speed
156+
157+
for (idx, trial) in trials.enumerated() {
158+
print("[mlx-server] Trial \(idx + 1)/\(trials.count): \(trial.label) (\(trial.cacheLimitBytes / (1024*1024))MB)")
159+
160+
// Set cache limit for this trial
161+
if trial.cacheLimitBytes > 0 {
162+
Memory.cacheLimit = trial.cacheLimitBytes
163+
} else {
164+
// Reset to system default
165+
Memory.cacheLimit = 0
166+
}
167+
168+
// Run inference and measure
169+
let result = await measureInference(
170+
container: container,
171+
prompt: calibrationPrompt,
172+
maxTokens: maxTokens
173+
)
174+
175+
if let result = result {
176+
print("[mlx-server] → \(String(format: "%.1f", result.tokPerSec)) tok/s decode, \(String(format: "%.0f", result.ttftMs))ms TTFT")
177+
178+
if bestTrial == nil || result.tokPerSec > bestTrial!.tokPerSec {
179+
bestTrial = (trial, result.tokPerSec, result.prefillTokPerSec, result.ttftMs)
180+
}
181+
} else {
182+
print("[mlx-server] → failed, skipping")
183+
}
184+
}
185+
186+
guard let best = bestTrial else {
187+
throw CalibratorError.allTrialsFailed
188+
}
189+
190+
let elapsed = Date().timeIntervalSince(startTime)
191+
192+
// Apply the winner
193+
if best.trial.cacheLimitBytes > 0 {
194+
Memory.cacheLimit = best.trial.cacheLimitBytes
195+
}
196+
197+
let entry = WisdomEntry(
198+
modelId: modelId,
199+
hardwareFingerprint: hardwareFingerprint(),
200+
cacheLimit: best.trial.cacheLimitBytes,
201+
gpuLayers: plan.gpuLayers,
202+
tokPerSec: best.tokPerSec,
203+
prefillTokPerSec: best.prefillTokPerSec,
204+
ttftMs: best.ttft,
205+
memoryPeakMB: Int(Double(GPU.activeMemory) / 1e6),
206+
calibratedAt: Date(),
207+
calibrationSeconds: elapsed
208+
)
209+
210+
try saveWisdom(entry)
211+
212+
print("[mlx-server] 📊 Calibration complete in \(String(format: "%.1f", elapsed))s")
213+
print("[mlx-server] Winner: \(best.trial.label)\(String(format: "%.1f", best.tokPerSec)) tok/s")
214+
print("[mlx-server] Saved to ~/.mlx-server/wisdom/")
215+
216+
return entry
217+
}
218+
219+
/// Measure a single inference run
220+
private static func measureInference(
221+
container: ModelContainer,
222+
prompt: String,
223+
maxTokens: Int
224+
) async -> InferenceResult? {
225+
do {
226+
// Prepare input using the same pattern as Server.swift
227+
let chatMessages: [Chat.Message] = [.user(prompt)]
228+
let userInput = UserInput(chat: chatMessages)
229+
let lmInput = try await container.prepare(input: userInput)
230+
let inputTokenCount = lmInput.text.tokens.size
231+
232+
let result: InferenceResult = try await container.perform { context in
233+
let generateParams = GenerateParameters(temperature: 0.6)
234+
235+
let ttftStart = Date()
236+
var firstTokenTime: Date?
237+
var tokenCount = 0
238+
239+
for try await result in try MLXLMCommon.generate(
240+
input: lmInput,
241+
parameters: generateParams,
242+
context: context
243+
) {
244+
switch result {
245+
case .chunk(_, tokenId: _):
246+
if firstTokenTime == nil {
247+
firstTokenTime = Date()
248+
}
249+
tokenCount += 1
250+
if tokenCount >= maxTokens {
251+
break
252+
}
253+
default:
254+
break
255+
}
256+
if tokenCount >= maxTokens { break }
257+
}
258+
259+
let ttft = firstTokenTime?.timeIntervalSince(ttftStart) ?? 0
260+
let decodeTime = Date().timeIntervalSince(firstTokenTime ?? ttftStart)
261+
let tokPerSec = decodeTime > 0 && tokenCount > 1 ? Double(tokenCount - 1) / decodeTime : 0
262+
let prefillTokPerSec = ttft > 0 ? Double(inputTokenCount) / ttft : 0
263+
264+
return InferenceResult(
265+
tokPerSec: tokPerSec,
266+
prefillTokPerSec: prefillTokPerSec,
267+
ttftMs: ttft * 1000,
268+
tokenCount: tokenCount
269+
)
270+
}
271+
272+
return result
273+
} catch {
274+
return nil
275+
}
276+
}
277+
}
278+
279+
// MARK: - Supporting Types
280+
281+
private struct InferenceResult {
282+
let tokPerSec: Double
283+
let prefillTokPerSec: Double
284+
let ttftMs: Double
285+
let tokenCount: Int
286+
}
287+
288+
enum CalibratorError: Error {
289+
case allTrialsFailed
290+
}

Sources/mlx-server/Server.swift

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ struct MLXServer: AsyncParsableCommand {
7676
@Option(name: .long, help: "Allowed CORS origin (* for all, or a specific origin URL)")
7777
var cors: String?
7878

79+
@Flag(name: .long, help: "Force re-calibration of optimal memory settings (normally auto-cached)")
80+
var calibrate: Bool = false
81+
7982
mutating func run() async throws {
8083
print("[mlx-server] Loading model: \(model)")
8184
let modelId = model
@@ -190,6 +193,25 @@ struct MLXServer: AsyncParsableCommand {
190193
}
191194
}
192195

196+
// ── Auto-calibration (Wisdom system) ──
197+
if let plan = partitionPlan {
198+
if self.calibrate {
199+
// Force re-calibration
200+
if let wisdom = try? await Calibrator.calibrate(
201+
container: container, plan: plan, modelId: modelId,
202+
contextSize: self.ctxSize ?? 4096
203+
) {
204+
Memory.cacheLimit = wisdom.cacheLimit
205+
}
206+
} else if let wisdom = Calibrator.loadWisdom(modelId: modelId) {
207+
// Load cached wisdom
208+
if wisdom.cacheLimit > 0 {
209+
Memory.cacheLimit = wisdom.cacheLimit
210+
}
211+
print("[mlx-server] 📊 Loaded wisdom: \(String(format: "%.1f", wisdom.tokPerSec)) tok/s, cache=\(wisdom.cacheLimit / (1024*1024))MB (calibrated \(wisdom.calibratedAt.formatted(.relative(presentation: .named))))")
212+
}
213+
}
214+
193215
print("[mlx-server] Model loaded. Starting HTTP server on \(host):\(port)")
194216

195217
// ── Capture CLI defaults into a shared config ──
@@ -208,12 +230,12 @@ struct MLXServer: AsyncParsableCommand {
208230
let corsOrigin = self.cors
209231
let apiKeyValue = self.apiKey
210232

211-
// ── Memory limit enforcement ──
233+
// ── Memory limit enforcement (overrides wisdom) ──
212234
if let memLimitMB = self.memLimit {
213235
let bytes = memLimitMB * 1024 * 1024
214236
Memory.memoryLimit = bytes
215237
Memory.cacheLimit = bytes
216-
print("[mlx-server] Memory limit set to \(memLimitMB)MB")
238+
print("[mlx-server] Memory limit set to \(memLimitMB)MB (overrides wisdom)")
217239
}
218240

219241
// ── Concurrency limiter ──

0 commit comments

Comments
 (0)