1- // InferenceEngine.swift — Core MLX inference actor for SwiftLM Chat
2- // Extracted from Server.swift — no HTTP, no CLI, pure Swift concurrency .
1+ // InferenceEngine.swift — Core MLX inference engine for SwiftLM Chat
2+ // Handles: model load/unload, token streaming, memory/thermal pressure response .
33
44import Foundation
55import MLX
66import MLXLLM
77import MLXLMCommon
8+ import Hub
9+ #if canImport(UIKit)
10+ import UIKit
11+ #endif
12+
13+ // MARK: — Model State
814
9- /// The state of the inference engine.
1015public enum ModelState : Equatable , Sendable {
1116 case idle
1217 case downloading( progress: Double , speed: String )
@@ -16,104 +21,194 @@ public enum ModelState: Equatable, Sendable {
1621 case error( String )
1722}
1823
19- /// Token-level output from the generation stream.
// MARK: — Thermal State

/// Coarse device thermal level, mirroring `ProcessInfo.ThermalState`.
public enum ThermalLevel: Sendable {
    case nominal, fair, serious, critical

    /// Human-readable description suitable for status UI.
    public var displayString: String {
        switch self {
        case .nominal:  return "Normal"
        case .fair:     return "Warm"
        case .serious:  return "Hot — generation may be slow"
        case .critical: return "Critical — generation paused"
        }
    }

    /// Whether the engine should refuse new heavy work at this level.
    public var isThrottled: Bool {
        switch self {
        case .serious, .critical: return true
        case .nominal, .fair:     return false
        }
    }
}
38+
39+ // MARK: — Generation Token
40+
/// A single streamed unit of model output.
public struct GenerationToken: Sendable {
    /// The decoded text fragment for this token.
    public let text: String
    /// True while this token falls inside a `<think>…</think>` span.
    public let isThinking: Bool

    public init(text: String, isThinking: Bool = false) {
        self.text = text
        self.isThinking = isThinking
    }
}
2950
30- /// Thread-safe MLX inference engine. One instance per app.
31- /// Uses Swift actor isolation so MLX calls never race.
51+ // MARK: — InferenceEngine
52+
3253@MainActor
3354public final class InferenceEngine : ObservableObject {
3455 @Published public private( set) var state : ModelState = . idle
56+ @Published public private( set) var thermalLevel : ThermalLevel = . nominal
3557
36- /// Shared download manager — exposes download progress and local cache state .
58+ /// Shared download + storage manager .
3759 public let downloadManager = ModelDownloadManager ( )
3860
3961 private var container : ModelContainer ?
4062 private var currentModelId : String ?
4163 private var generationTask : Task < Void , Never > ?
64+ private var pressureObserver : NSObjectProtocol ?
65+ private var thermalObserver : NSObjectProtocol ?
66+
67+ public init ( ) {
68+ setupPressureHandlers ( )
69+ }
4270
43- public init ( ) { }
71+ deinit {
72+ if let o = pressureObserver { NotificationCenter . default. removeObserver ( o) }
73+ if let o = thermalObserver { NotificationCenter . default. removeObserver ( o) }
74+ }
75+
76+ // MARK: — Pressure Handlers
77+
78+ private func setupPressureHandlers( ) {
79+ // iOS memory pressure → unload model weights immediately
80+ #if canImport(UIKit)
81+ pressureObserver = NotificationCenter . default. addObserver (
82+ forName: UIApplication . didReceiveMemoryWarningNotification,
83+ object: nil ,
84+ queue: . main
85+ ) { [ weak self] _ in
86+ Task { @MainActor [ weak self] in
87+ guard let self else { return }
88+ // Only unload if not actively generating
89+ if case . generating = self . state { return }
90+ self . unload ( )
91+ self . state = . error( " Unloaded due to memory pressure. Tap to reload. " )
92+ }
93+ }
94+ #endif
95+
96+ // Thermal state monitoring (all platforms)
97+ thermalObserver = NotificationCenter . default. addObserver (
98+ forName: ProcessInfo . thermalStateDidChangeNotification,
99+ object: nil ,
100+ queue: . main
101+ ) { [ weak self] _ in
102+ Task { @MainActor [ weak self] in
103+ self ? . updateThermalLevel ( )
104+ }
105+ }
106+ updateThermalLevel ( )
107+ }
108+
109+ private func updateThermalLevel( ) {
110+ switch ProcessInfo . processInfo. thermalState {
111+ case . nominal: thermalLevel = . nominal
112+ case . fair: thermalLevel = . fair
113+ case . serious: thermalLevel = . serious
114+ case . critical:
115+ thermalLevel = . critical
116+ // Critical: stop any generation immediately
117+ stopGeneration ( )
118+ @unknown default : thermalLevel = . nominal
119+ }
120+ }
44121
45122 // MARK: — Model Loading
46123
47124 /// Load a model by HuggingFace ID. Downloads if not cached.
125+ /// Uses ModelStorage.cacheRoot as the HubApi download base.
48126 public func load( modelId: String ) async {
49127 guard state != . ready( modelId: modelId) else { return }
128+ guard !thermalLevel. isThrottled else {
129+ state = . error( " Device is too hot. Let it cool before loading a model. " )
130+ return
131+ }
50132
51133 state = . loading
52134 currentModelId = modelId
53135
54136 do {
137+ // Point HubApi at ModelStorage.cacheRoot so downloads land in the right
138+ // place on both platforms (macOS: ~/.cache/HF, iOS: Application Support)
139+ let hub = HubApi ( downloadBase: ModelStorage . cacheRoot)
55140 let config = ModelConfiguration ( id: modelId)
141+
56142 container = try await LLMModelFactory . shared. loadContainer (
143+ hub: hub,
57144 configuration: config
58145 ) { [ weak self] progress in
59146 Task { @MainActor in
60147 guard let self else { return }
61148 let pct = progress. fractionCompleted
62- let speedMBps = progress. throughput. map { $0 / 1_000_000 }
63- let speedStr = speedMBps. map { String ( format: " %.1f MB/s " , $0) } ?? " "
149+ let speedBytesPerSec = progress. userInfo [ . throughputKey] as? Double
150+ let speedStr = speedBytesPerSec
151+ . map { String ( format: " %.1f MB/s " , $0 / 1_000_000 ) } ?? " "
64152 self . state = . downloading( progress: pct, speed: speedStr)
153+
65154 self . downloadManager. updateProgress ( ModelDownloadProgress (
66155 modelId: modelId,
67156 fractionCompleted: pct,
68- speedMBps: speedMBps
157+ currentFile: " " ,
158+ speedMBps: speedBytesPerSec. map { $0 / 1_000_000 }
69159 ) )
70160 }
71161 }
72- downloadManager. completeDownload ( modelId: modelId)
162+
163+ downloadManager. clearProgress ( modelId: modelId)
164+ downloadManager. lastLoadedModelId = modelId
165+ downloadManager. refresh ( )
73166 state = . ready( modelId: modelId)
167+
74168 } catch {
75- downloadManager. cancelDownload ( modelId: modelId)
169+ downloadManager. clearProgress ( modelId: modelId)
76170 state = . error( " Failed to load \( modelId) : \( error. localizedDescription) " )
77171 container = nil
78172 }
79173 }
80174
81- /// Unload the current model and free memory.
175+ /// Unload the current model and free all GPU memory.
82176 public func unload( ) {
83177 generationTask? . cancel ( )
84178 container = nil
85179 currentModelId = nil
86180 state = . idle
87- MLX . Memory . clearCache ( )
181+ MLX . GPU . set ( cacheLimit : 0 )
88182 }
89183
90184 // MARK: — Generation
91185
92- /// Generate a response as an AsyncStream of tokens.
93- /// Each yielded value is a `GenerationToken` (text + thinking flag).
94186 public nonisolated func generate(
95187 messages: [ ChatMessage ] ,
96188 config: GenerationConfig = . default
97189 ) -> AsyncStream < GenerationToken > {
98190 AsyncStream { continuation in
99191 Task { @MainActor in
100192 guard let container = self . container else {
101- continuation. finish ( )
102- return
193+ continuation. finish ( ) ; return
194+ }
195+
196+ // Don't generate when throttled
197+ if self . thermalLevel == . critical {
198+ continuation. yield ( GenerationToken ( text: " \n \n [Generation paused: device temperature critical] " ) )
199+ continuation. finish ( ) ; return
103200 }
104201
105202 self . state = . generating
106203
107204 do {
108- let mlxMessages = messages. map { msg -> [ String : String ] in
109- [ " role " : msg. role. rawValue, " content " : msg. content]
110- }
111-
112- // Build MLXLMCommon GenerateParameters
205+ let mlxMessages = messages. map { [ " role " : $0. role. rawValue, " content " : $0. content] }
113206 var params = GenerateParameters ( temperature: config. temperature)
114207 params. topP = config. topP
115208
116209 var thinkingActive = false
210+ var outputText = " "
211+ var tokenCount = 0
117212
118213 let userInput = UserInput ( messages: mlxMessages)
119214 let lmInput = try await container. prepare ( input: userInput)
@@ -122,21 +217,15 @@ public final class InferenceEngine: ObservableObject {
122217 parameters: params
123218 )
124219
125- var outputText = " "
126- var tokenCount = 0
127-
128220 for await generation in stream {
129- switch generation {
130- case . chunk( let text, tokenId: _) :
221+ guard !Task. isCancelled else { break }
222+
223+ if case . chunk( let text, tokenId: _) = generation {
131224 outputText += text
132225 tokenCount += 1
133226
134- if tokenCount >= config. maxTokens {
135- continuation. finish ( )
136- break
137- }
227+ if tokenCount >= config. maxTokens { break }
138228
139- // Thinking state tracking (<think> tags)
140229 if config. enableThinking {
141230 if outputText. contains ( " <think> " ) && !outputText. contains ( " </think> " ) {
142231 thinkingActive = true
@@ -146,13 +235,9 @@ public final class InferenceEngine: ObservableObject {
146235 }
147236
148237 continuation. yield ( GenerationToken ( text: text, isThinking: thinkingActive) )
149-
150- default :
151- break
152238 }
153239 }
154240 } catch {
155- // Yield error as a token so the UI can display it
156241 continuation. yield ( GenerationToken ( text: " \n \n [Error: \( error. localizedDescription) ] " ) )
157242 }
158243
@@ -162,12 +247,9 @@ public final class InferenceEngine: ObservableObject {
162247 }
163248 }
164249
165- /// Cancel any in-progress generation.
166250 public func stopGeneration( ) {
167251 generationTask? . cancel ( )
168252 generationTask = nil
169- if let id = currentModelId {
170- state = . ready( modelId: id)
171- }
253+ if let id = currentModelId { state = . ready( modelId: id) }
172254 }
173255}
0 commit comments