Skip to content

Commit dc4069f

Browse files
committed
feat(swiftlmchat): platform-aware model management for iOS + macOS
ModelStorage (new):
- macOS: Library/Caches/huggingface/hub/ (matches defaultHubApi)
- iOS: Library/Application Support/SwiftLMChat/Models/ + isExcludedFromBackup
- Platform-agnostic scan, sizeOnDisk, delete primitives

ModelDownloader (new, iOS only):
- URLSession background session (survives app suspension)
- HuggingFace API file enumeration (GET /api/models/{id})
- Per-file download with progress streaming
- macOS: LLMModelFactory handles download directly (no change)

ModelDownloadManager refactor:
- Built on ModelStorage + ModelDownloader layers
- NWPathMonitor for WiFi/cellular/offline detection
- iOS RAM budget: 40% (vs 75% macOS) via modelsForDevice()
- Cellular threshold: warn before >200MB downloads on cellular
- updateProgress() / clearProgress() for InferenceEngine bridge

InferenceEngine:
- UIApplication.didReceiveMemoryWarningNotification → auto-unload (iOS)
- ProcessInfo.thermalStateDidChangeNotification → ThermalLevel @Published
- Critical thermal → stop generation immediately
- HubApi.downloadBase redirected to ModelStorage.cacheRoot

ModelPickerView:
- Network status banner (offline / cellular warning)
- Thermal warning banner
- Cellular confirmation dialog before large downloads
- handleModelTap() blocks download when offline

SwiftLMChat.entitlements (new):
- com.apple.developer.kernel.increased-memory-limit
- UIBackgroundModes: fetch, processing

Package.swift: add Hub product to MLXInferenceCore dependencies
1 parent 32b7483 commit dc4069f

8 files changed

Lines changed: 681 additions & 159 deletions

File tree

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ let package = Package(
4141
.product(name: "MLX", package: "mlx-swift"),
4242
.product(name: "MLXLLM", package: "mlx-swift-lm"),
4343
.product(name: "MLXLMCommon", package: "mlx-swift-lm"),
44+
.product(name: "Hub", package: "swift-transformers"),
4445
],
4546
path: "Sources/MLXInferenceCore",
4647
swiftSettings: [
Lines changed: 125 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,17 @@
1-
// InferenceEngine.swift — Core MLX inference actor for SwiftLM Chat
2-
// Extracted from Server.swift — no HTTP, no CLI, pure Swift concurrency.
1+
// InferenceEngine.swift — Core MLX inference engine for SwiftLM Chat
2+
// Handles: model load/unload, token streaming, memory/thermal pressure response.
33

44
import Foundation
55
import MLX
66
import MLXLLM
77
import MLXLMCommon
8+
import Hub
9+
#if canImport(UIKit)
10+
import UIKit
11+
#endif
12+
13+
// MARK: — Model State
814

9-
/// The state of the inference engine.
1015
public enum ModelState: Equatable, Sendable {
1116
case idle
1217
case downloading(progress: Double, speed: String)
@@ -16,104 +21,194 @@ public enum ModelState: Equatable, Sendable {
1621
case error(String)
1722
}
1823

19-
/// Token-level output from the generation stream.
24+
// MARK: — Thermal State
25+
26+
public enum ThermalLevel: Sendable {
27+
case nominal, fair, serious, critical
28+
public var displayString: String {
29+
switch self {
30+
case .nominal: return "Normal"
31+
case .fair: return "Warm"
32+
case .serious: return "Hot — generation may be slow"
33+
case .critical: return "Critical — generation paused"
34+
}
35+
}
36+
public var isThrottled: Bool { self == .serious || self == .critical }
37+
}
38+
39+
// MARK: — Generation Token
40+
2041
public struct GenerationToken: Sendable {
2142
public let text: String
22-
public let isThinking: Bool // true when inside <think>...</think>
43+
public let isThinking: Bool
2344

2445
public init(text: String, isThinking: Bool = false) {
2546
self.text = text
2647
self.isThinking = isThinking
2748
}
2849
}
2950

30-
/// Thread-safe MLX inference engine. One instance per app.
31-
/// Uses Swift actor isolation so MLX calls never race.
51+
// MARK: — InferenceEngine
52+
3253
@MainActor
3354
public final class InferenceEngine: ObservableObject {
3455
@Published public private(set) var state: ModelState = .idle
56+
@Published public private(set) var thermalLevel: ThermalLevel = .nominal
3557

36-
/// Shared download manager — exposes download progress and local cache state.
58+
/// Shared download + storage manager.
3759
public let downloadManager = ModelDownloadManager()
3860

3961
private var container: ModelContainer?
4062
private var currentModelId: String?
4163
private var generationTask: Task<Void, Never>?
64+
private var pressureObserver: NSObjectProtocol?
65+
private var thermalObserver: NSObjectProtocol?
66+
67+
public init() {
68+
setupPressureHandlers()
69+
}
4270

43-
public init() {}
71+
deinit {
72+
if let o = pressureObserver { NotificationCenter.default.removeObserver(o) }
73+
if let o = thermalObserver { NotificationCenter.default.removeObserver(o) }
74+
}
75+
76+
// MARK: — Pressure Handlers
77+
78+
private func setupPressureHandlers() {
79+
// iOS memory pressure → unload model weights immediately
80+
#if canImport(UIKit)
81+
pressureObserver = NotificationCenter.default.addObserver(
82+
forName: UIApplication.didReceiveMemoryWarningNotification,
83+
object: nil,
84+
queue: .main
85+
) { [weak self] _ in
86+
Task { @MainActor [weak self] in
87+
guard let self else { return }
88+
// Only unload if not actively generating
89+
if case .generating = self.state { return }
90+
self.unload()
91+
self.state = .error("Unloaded due to memory pressure. Tap to reload.")
92+
}
93+
}
94+
#endif
95+
96+
// Thermal state monitoring (all platforms)
97+
thermalObserver = NotificationCenter.default.addObserver(
98+
forName: ProcessInfo.thermalStateDidChangeNotification,
99+
object: nil,
100+
queue: .main
101+
) { [weak self] _ in
102+
Task { @MainActor [weak self] in
103+
self?.updateThermalLevel()
104+
}
105+
}
106+
updateThermalLevel()
107+
}
108+
109+
private func updateThermalLevel() {
110+
switch ProcessInfo.processInfo.thermalState {
111+
case .nominal: thermalLevel = .nominal
112+
case .fair: thermalLevel = .fair
113+
case .serious: thermalLevel = .serious
114+
case .critical:
115+
thermalLevel = .critical
116+
// Critical: stop any generation immediately
117+
stopGeneration()
118+
@unknown default: thermalLevel = .nominal
119+
}
120+
}
44121

45122
// MARK: — Model Loading
46123

47124
/// Load a model by HuggingFace ID. Downloads if not cached.
125+
/// Uses ModelStorage.cacheRoot as the HubApi download base.
48126
public func load(modelId: String) async {
49127
guard state != .ready(modelId: modelId) else { return }
128+
guard !thermalLevel.isThrottled else {
129+
state = .error("Device is too hot. Let it cool before loading a model.")
130+
return
131+
}
50132

51133
state = .loading
52134
currentModelId = modelId
53135

54136
do {
137+
// Point HubApi at ModelStorage.cacheRoot so downloads land in the right
138+
// place on both platforms (macOS: ~/.cache/HF, iOS: Application Support)
139+
let hub = HubApi(downloadBase: ModelStorage.cacheRoot)
55140
let config = ModelConfiguration(id: modelId)
141+
56142
container = try await LLMModelFactory.shared.loadContainer(
143+
hub: hub,
57144
configuration: config
58145
) { [weak self] progress in
59146
Task { @MainActor in
60147
guard let self else { return }
61148
let pct = progress.fractionCompleted
62-
let speedMBps = progress.throughput.map { $0 / 1_000_000 }
63-
let speedStr = speedMBps.map { String(format: "%.1f MB/s", $0) } ?? ""
149+
let speedBytesPerSec = progress.userInfo[.throughputKey] as? Double
150+
let speedStr = speedBytesPerSec
151+
.map { String(format: "%.1f MB/s", $0 / 1_000_000) } ?? ""
64152
self.state = .downloading(progress: pct, speed: speedStr)
153+
65154
self.downloadManager.updateProgress(ModelDownloadProgress(
66155
modelId: modelId,
67156
fractionCompleted: pct,
68-
speedMBps: speedMBps
157+
currentFile: "",
158+
speedMBps: speedBytesPerSec.map { $0 / 1_000_000 }
69159
))
70160
}
71161
}
72-
downloadManager.completeDownload(modelId: modelId)
162+
163+
downloadManager.clearProgress(modelId: modelId)
164+
downloadManager.lastLoadedModelId = modelId
165+
downloadManager.refresh()
73166
state = .ready(modelId: modelId)
167+
74168
} catch {
75-
downloadManager.cancelDownload(modelId: modelId)
169+
downloadManager.clearProgress(modelId: modelId)
76170
state = .error("Failed to load \(modelId): \(error.localizedDescription)")
77171
container = nil
78172
}
79173
}
80174

81-
/// Unload the current model and free memory.
175+
/// Unload the current model and free all GPU memory.
82176
public func unload() {
83177
generationTask?.cancel()
84178
container = nil
85179
currentModelId = nil
86180
state = .idle
87-
MLX.Memory.clearCache()
181+
MLX.GPU.set(cacheLimit: 0)
88182
}
89183

90184
// MARK: — Generation
91185

92-
/// Generate a response as an AsyncStream of tokens.
93-
/// Each yielded value is a `GenerationToken` (text + thinking flag).
94186
public nonisolated func generate(
95187
messages: [ChatMessage],
96188
config: GenerationConfig = .default
97189
) -> AsyncStream<GenerationToken> {
98190
AsyncStream { continuation in
99191
Task { @MainActor in
100192
guard let container = self.container else {
101-
continuation.finish()
102-
return
193+
continuation.finish(); return
194+
}
195+
196+
// Don't generate when throttled
197+
if self.thermalLevel == .critical {
198+
continuation.yield(GenerationToken(text: "\n\n[Generation paused: device temperature critical]"))
199+
continuation.finish(); return
103200
}
104201

105202
self.state = .generating
106203

107204
do {
108-
let mlxMessages = messages.map { msg -> [String: String] in
109-
["role": msg.role.rawValue, "content": msg.content]
110-
}
111-
112-
// Build MLXLMCommon GenerateParameters
205+
let mlxMessages = messages.map { ["role": $0.role.rawValue, "content": $0.content] }
113206
var params = GenerateParameters(temperature: config.temperature)
114207
params.topP = config.topP
115208

116209
var thinkingActive = false
210+
var outputText = ""
211+
var tokenCount = 0
117212

118213
let userInput = UserInput(messages: mlxMessages)
119214
let lmInput = try await container.prepare(input: userInput)
@@ -122,21 +217,15 @@ public final class InferenceEngine: ObservableObject {
122217
parameters: params
123218
)
124219

125-
var outputText = ""
126-
var tokenCount = 0
127-
128220
for await generation in stream {
129-
switch generation {
130-
case .chunk(let text, tokenId: _):
221+
guard !Task.isCancelled else { break }
222+
223+
if case .chunk(let text, tokenId: _) = generation {
131224
outputText += text
132225
tokenCount += 1
133226

134-
if tokenCount >= config.maxTokens {
135-
continuation.finish()
136-
break
137-
}
227+
if tokenCount >= config.maxTokens { break }
138228

139-
// Thinking state tracking (<think> tags)
140229
if config.enableThinking {
141230
if outputText.contains("<think>") && !outputText.contains("</think>") {
142231
thinkingActive = true
@@ -146,13 +235,9 @@ public final class InferenceEngine: ObservableObject {
146235
}
147236

148237
continuation.yield(GenerationToken(text: text, isThinking: thinkingActive))
149-
150-
default:
151-
break
152238
}
153239
}
154240
} catch {
155-
// Yield error as a token so the UI can display it
156241
continuation.yield(GenerationToken(text: "\n\n[Error: \(error.localizedDescription)]"))
157242
}
158243

@@ -162,12 +247,9 @@ public final class InferenceEngine: ObservableObject {
162247
}
163248
}
164249

165-
/// Cancel any in-progress generation.
166250
public func stopGeneration() {
167251
generationTask?.cancel()
168252
generationTask = nil
169-
if let id = currentModelId {
170-
state = .ready(modelId: id)
171-
}
253+
if let id = currentModelId { state = .ready(modelId: id) }
172254
}
173255
}

0 commit comments

Comments (0)