|
| 1 | +// Calibrator.swift — Auto-tuning "Wisdom" system for optimal inference config |
| 2 | +// |
| 3 | +// FFTW-style approach: profile once per (model, hardware) pair, store optimal |
| 4 | +// config, apply instantly on subsequent runs. |
| 5 | +// |
| 6 | +// On first run with a new model, the calibrator runs a short benchmark to find |
| 7 | +// the optimal cache limit that maximizes tok/s. The result is stored in |
| 8 | +// ~/.mlx-server/wisdom/<hash>.json and loaded directly on future runs. |
| 9 | +// |
| 10 | +// Usage: |
| 11 | +// let wisdom = try await Calibrator.calibrate(container: container, plan: plan, modelId: modelId) |
| 12 | +// // Apply (a cacheLimit of 0 means "keep the system default"): if wisdom.cacheLimit > 0 { Memory.cacheLimit = wisdom.cacheLimit } |
| 13 | + |
| 14 | +import Foundation |
| 15 | +import MLX |
| 16 | +import MLXLMCommon |
| 17 | +import Tokenizers |
| 18 | + |
| 19 | +// MARK: - Wisdom Entry |
| 20 | + |
/// Calibration outcome that is persisted for one (model, hardware) combination.
///
/// Encoded to / decoded from JSON by `Calibrator`. The declaration order of the
/// stored properties defines the memberwise initializer and must stay stable.
struct WisdomEntry: Codable {
    /// Identifier of the model that was calibrated.
    let modelId: String

    /// Fingerprint of the machine the calibration ran on.
    let hardwareFingerprint: String

    /// Winning cache limit in bytes. `0` is the value used by the
    /// "system default" trial, i.e. no explicit limit was chosen.
    let cacheLimit: Int

    /// GPU layer count taken from the partition plan, if it provides one.
    let gpuLayers: Int?

    /// Decode throughput of the winning trial, in tokens per second.
    let tokPerSec: Double

    /// Prefill (prompt processing) throughput of the winning trial, in tokens per second.
    let prefillTokPerSec: Double

    /// Time to first token of the winning trial, in milliseconds.
    let ttftMs: Double

    /// GPU memory figure recorded when calibration finished, in megabytes.
    let memoryPeakMB: Int

    /// Timestamp at which the entry was created.
    let calibratedAt: Date

    /// Wall-clock duration of the whole calibration, in seconds.
    let calibrationSeconds: Double
}
| 34 | + |
| 35 | +// MARK: - Calibration Config |
| 36 | + |
/// One candidate configuration to benchmark during calibration.
private struct CalibrationTrial {
    /// Cache limit to apply for this trial, in bytes.
    /// `0` is used by the calibrator for its "system default" trial.
    let cacheLimitBytes: Int

    /// Human-readable name shown in the calibration log.
    let label: String
}
| 42 | + |
| 43 | +// MARK: - Calibrator |
| 44 | + |
/// Namespace for the auto-tuning ("wisdom") workflow: benchmark a handful of
/// MLX cache limits for a model, remember the fastest one on disk, and reload
/// it on later runs.
enum Calibrator {

    // MARK: Storage

    /// Directory that holds the persisted wisdom files (`~/.mlx-server/wisdom`).
    private static var wisdomDirectory: URL {
        FileManager.default.homeDirectoryForCurrentUser
            .appendingPathComponent(".mlx-server/wisdom")
    }

    /// Builds a fingerprint of the current machine.
    ///
    /// The fingerprint is `<machine>_<RAM in GiB>GB_<OS version string>`, where
    /// `machine` is the `machine` field reported by `uname(3)`.
    ///
    /// - Returns: The fingerprint string. It is not sanitized for use in file names.
    static func hardwareFingerprint() -> String {
        var sysinfo = utsname()
        uname(&sysinfo)
        // `machine` is a fixed-size C char tuple; rebind it to read it as a C string.
        let machine = withUnsafePointer(to: &sysinfo.machine) {
            $0.withMemoryRebound(to: CChar.self, capacity: Int(_SYS_NAMELEN)) {
                String(cString: $0)
            }
        }
        let memGB = Int(ProcessInfo.processInfo.physicalMemory / (1024 * 1024 * 1024))
        let os = ProcessInfo.processInfo.operatingSystemVersionString
        return "\(machine)_\(memGB)GB_\(os)"
    }

    /// Derives the file-name key for a (model, hardware) pair.
    ///
    /// The key is the readable string `<modelId>_<fingerprint>` made safe for
    /// use as a single path component.
    ///
    /// - Parameter modelId: Identifier of the model.
    /// - Returns: A key containing only ASCII letters, digits, `-`, `_` and `.`.
    private static func wisdomKey(modelId: String) -> String {
        let combined = "\(modelId)_\(hardwareFingerprint())"

        // Keep the historical substitutions first so keys for ordinary model
        // identifiers are unchanged and existing wisdom files are still found.
        let legacy = combined
            .replacingOccurrences(of: "/", with: "_")
            .replacingOccurrences(of: " ", with: "_")
            .replacingOccurrences(of: "(", with: "")
            .replacingOccurrences(of: ")", with: "")

        // Anything else that is not filename-safe (path separators of other
        // platforms, control characters, ...) is mapped to "_".
        let allowed = CharacterSet(
            charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_."
        )
        var key = String.UnicodeScalarView()
        for scalar in legacy.unicodeScalars {
            key.append(allowed.contains(scalar) ? scalar : "_")
        }
        return String(key)
    }

    /// Loads previously stored wisdom for a model on the current hardware.
    ///
    /// - Parameter modelId: Identifier of the model.
    /// - Returns: The stored entry, or `nil` when no file exists, the file cannot
    ///   be read or decoded, or its contents do not belong to this model/hardware.
    static func loadWisdom(modelId: String) -> WisdomEntry? {
        let key = wisdomKey(modelId: modelId)
        let url = wisdomDirectory.appendingPathComponent("\(key).json")

        // A missing file is the normal first-run case; stay silent about it.
        guard FileManager.default.fileExists(atPath: url.path) else { return nil }

        do {
            let decoder = JSONDecoder()
            decoder.dateDecodingStrategy = .iso8601
            let entry = try decoder.decode(WisdomEntry.self, from: Data(contentsOf: url))

            // Sanitized keys can collide and files can be edited by hand, so
            // never apply a configuration that was recorded for something else.
            guard entry.modelId == modelId,
                  entry.hardwareFingerprint == hardwareFingerprint(),
                  entry.cacheLimit >= 0
            else {
                print("[mlx-server] ⚠️ Ignoring wisdom file that does not match this model/hardware")
                return nil
            }
            return entry
        } catch {
            print("[mlx-server] ⚠️ Failed to load wisdom: \(error.localizedDescription)")
            return nil
        }
    }

    /// Writes a wisdom entry to the wisdom directory, creating it if needed.
    ///
    /// - Parameter entry: The entry to persist.
    /// - Throws: Any error from creating the directory, encoding, or writing.
    private static func saveWisdom(_ entry: WisdomEntry) throws {
        let dir = wisdomDirectory
        try FileManager.default.createDirectory(at: dir, withIntermediateDirectories: true)

        let encoder = JSONEncoder()
        encoder.outputFormatting = [.prettyPrinted, .sortedKeys]
        encoder.dateEncodingStrategy = .iso8601

        let url = dir.appendingPathComponent("\(wisdomKey(modelId: entry.modelId)).json")
        // Atomic write: a crash mid-write must not leave a truncated JSON file
        // that every later launch fails to decode.
        try encoder.encode(entry).write(to: url, options: .atomic)
    }

    // MARK: Calibration

    /// Benchmarks several cache limits and returns (and persists) the fastest.
    ///
    /// Runs one short generation per trial, compares decode throughput, applies
    /// the winning cache limit to `Memory.cacheLimit`, and saves the result.
    ///
    /// - Parameters:
    ///   - container: Loaded model used for the benchmark generations.
    ///   - plan: Partition plan; supplies the weight size and GPU layer count.
    ///   - modelId: Identifier stored in the entry and used for the file name.
    ///   - contextSize: Currently not used by the calibration run.
    /// - Returns: The entry describing the winning trial. `cacheLimit == 0`
    ///   means the "system default" trial won and no explicit limit was chosen.
    /// - Throws: `CalibratorError.allTrialsFailed` when no trial produced a
    ///   usable measurement, or any error from saving the entry.
    static func calibrate(
        container: ModelContainer,
        plan: PartitionPlan,
        modelId: String,
        contextSize: Int = 4096
    ) async throws -> WisdomEntry {
        let startTime = Date()
        print("[mlx-server] 📊 Calibrating... (this only happens once per model × hardware)")

        // Remember the limit in effect before we start. Per the MLX docs a cache
        // limit of 0 *disables* the buffer cache, so "system default" has to be
        // restored from this value rather than expressed as 0.
        let defaultCacheLimit = Memory.cacheLimit

        let systemRAMBytes = Int(ProcessInfo.processInfo.physicalMemory)
        let modelWeightBytes = max(0, Int(plan.weightMemoryGB * 1e9))
        // Weights larger than physical RAM must not yield a negative budget.
        let freeRAMBytes = max(0, systemRAMBytes - modelWeightBytes)
        // A cache limit above physical RAM carries no extra meaning; cap it.
        let capped: (Int) -> Int = { min($0, systemRAMBytes) }

        // Trial cache limits: from tight (weights + 20%) to generous (50% of free RAM).
        let trials: [CalibrationTrial] = [
            CalibrationTrial(
                cacheLimitBytes: capped(modelWeightBytes + modelWeightBytes / 5),
                label: "tight (weights + 20%)"
            ),
            CalibrationTrial(
                cacheLimitBytes: capped(modelWeightBytes + freeRAMBytes / 4),
                label: "moderate (weights + 25% free)"
            ),
            CalibrationTrial(
                cacheLimitBytes: capped(modelWeightBytes + freeRAMBytes / 2),
                label: "generous (weights + 50% free)"
            ),
            CalibrationTrial(
                cacheLimitBytes: 0,  // sentinel: keep the system default
                label: "unlimited (system default)"
            ),
        ]

        /// Resolves the 0 sentinel to the limit that was active on entry.
        func effectiveLimit(for trial: CalibrationTrial) -> Int {
            trial.cacheLimitBytes > 0 ? trial.cacheLimitBytes : defaultCacheLimit
        }

        var bestTrial: (trial: CalibrationTrial, tokPerSec: Double, prefillTokPerSec: Double, ttft: Double)?

        // Short enough for fast iteration, long enough to reach steady-state decode.
        let calibrationPrompt = "Explain the concept of machine learning in three sentences."
        let maxTokens = 30

        for (idx, trial) in trials.enumerated() {
            let limit = effectiveLimit(for: trial)
            print("[mlx-server] Trial \(idx + 1)/\(trials.count): \(trial.label) (\(limit / (1024*1024))MB)")

            Memory.cacheLimit = limit

            guard let result = await measureInference(
                container: container,
                prompt: calibrationPrompt,
                maxTokens: maxTokens
            ) else {
                print("[mlx-server] → failed, skipping")
                continue
            }

            print("[mlx-server] → \(String(format: "%.1f", result.tokPerSec)) tok/s decode, \(String(format: "%.0f", result.ttftMs))ms TTFT")

            // Keep the earlier trial on ties, replace only on a strict improvement.
            if let current = bestTrial, current.tokPerSec >= result.tokPerSec {
                continue
            }
            bestTrial = (trial, result.tokPerSec, result.prefillTokPerSec, result.ttftMs)
        }

        guard let best = bestTrial else {
            // Do not leave the process running with whatever the last trial set.
            Memory.cacheLimit = defaultCacheLimit
            throw CalibratorError.allTrialsFailed
        }

        let elapsed = Date().timeIntervalSince(startTime)

        // Apply the winner (the sentinel resolves back to the original default).
        Memory.cacheLimit = effectiveLimit(for: best.trial)

        let entry = WisdomEntry(
            modelId: modelId,
            hardwareFingerprint: hardwareFingerprint(),
            cacheLimit: best.trial.cacheLimitBytes,
            gpuLayers: plan.gpuLayers,
            tokPerSec: best.tokPerSec,
            prefillTokPerSec: best.prefillTokPerSec,
            ttftMs: best.ttft,
            // The field records a peak, so read the peak counter rather than
            // the amount that happens to be active right now.
            memoryPeakMB: Int(Double(GPU.peakMemory) / 1e6),
            calibratedAt: Date(),
            calibrationSeconds: elapsed
        )

        try saveWisdom(entry)

        print("[mlx-server] 📊 Calibration complete in \(String(format: "%.1f", elapsed))s")
        print("[mlx-server] Winner: \(best.trial.label) → \(String(format: "%.1f", best.tokPerSec)) tok/s")
        print("[mlx-server] Saved to ~/.mlx-server/wisdom/")

        return entry
    }

    /// Runs one short generation and measures its speed.
    ///
    /// - Parameters:
    ///   - container: Loaded model to generate with.
    ///   - prompt: User message to send.
    ///   - maxTokens: Number of generated chunks after which the run is stopped.
    /// - Returns: The measurement, or `nil` when preparation or generation
    ///   failed or fewer than two chunks were produced (no decode interval to
    ///   measure).
    private static func measureInference(
        container: ModelContainer,
        prompt: String,
        maxTokens: Int
    ) async -> InferenceResult? {
        do {
            // Prepare input using the same pattern as Server.swift
            let chatMessages: [Chat.Message] = [.user(prompt)]
            let userInput = UserInput(chat: chatMessages)
            let lmInput = try await container.prepare(input: userInput)
            let inputTokenCount = lmInput.text.tokens.size

            let measurement: InferenceResult = try await container.perform { context in
                let generateParams = GenerateParameters(temperature: 0.6)

                // Monotonic clock: wall-clock adjustments must not skew the benchmark.
                let started = ProcessInfo.processInfo.systemUptime
                var firstTokenAt: TimeInterval?
                var tokenCount = 0

                for try await event in try MLXLMCommon.generate(
                    input: lmInput,
                    parameters: generateParams,
                    context: context
                ) {
                    // Only text chunks count as generated output; other events are ignored.
                    if case .chunk = event {
                        if firstTokenAt == nil {
                            firstTokenAt = ProcessInfo.processInfo.systemUptime
                        }
                        tokenCount += 1
                    }
                    if tokenCount >= maxTokens { break }
                }
                let finished = ProcessInfo.processInfo.systemUptime

                // TTFT covers prefill; decode speed is measured from the first
                // chunk onwards, hence `tokenCount - 1` intervals.
                let ttft = firstTokenAt.map { $0 - started } ?? 0
                let decodeTime = finished - (firstTokenAt ?? started)
                let tokPerSec = decodeTime > 0 && tokenCount > 1 ? Double(tokenCount - 1) / decodeTime : 0
                let prefillTokPerSec = ttft > 0 ? Double(inputTokenCount) / ttft : 0

                return InferenceResult(
                    tokPerSec: tokPerSec,
                    prefillTokPerSec: prefillTokPerSec,
                    ttftMs: ttft * 1000,
                    tokenCount: tokenCount
                )
            }

            // A run without a measurable decode phase must not compete as a
            // 0 tok/s "result"; report it as a failed trial instead.
            guard measurement.tokenCount > 1, measurement.tokPerSec > 0 else {
                print("[mlx-server] ⚠️ Calibration run produced too little output to measure (\(measurement.tokenCount) chunks)")
                return nil
            }
            return measurement
        } catch {
            // Surface the cause; the caller only learns that the trial failed.
            print("[mlx-server] ⚠️ Calibration inference failed: \(error)")
            return nil
        }
    }
}
| 278 | + |
| 279 | +// MARK: - Supporting Types |
| 280 | + |
/// Measurements taken from a single benchmark generation.
private struct InferenceResult {
    /// Decode throughput, in tokens per second.
    let tokPerSec: Double

    /// Prefill (prompt processing) throughput, in tokens per second.
    let prefillTokPerSec: Double

    /// Time to first token, in milliseconds.
    let ttftMs: Double

    /// Number of generated chunks that were counted during the run.
    let tokenCount: Int
}
| 287 | + |
/// Errors thrown by the calibration workflow.
enum CalibratorError: Error {
    /// No calibration trial produced a usable measurement.
    case allTrialsFailed
}

// Gives the error a readable `localizedDescription` instead of the generic
// "The operation couldn’t be completed" text that a bare `Error` enum produces.
extension CalibratorError: LocalizedError {
    var errorDescription: String? {
        switch self {
        case .allTrialsFailed:
            return "Calibration failed: none of the trials produced a usable measurement."
        }
    }
}
0 commit comments