@@ -72,7 +72,7 @@ private struct TransformersTokenizerBridge: MLXLMCommon.Tokenizer, Sendable {
7272public enum ModelState : Equatable , Sendable {
7373 case idle
7474 case downloading( progress: Double , speed: String )
75- case loading
75+ case loading( progress : Double , stage : String )
7676 case ready( modelId: String )
7777 case generating
7878 case error( String )
@@ -319,7 +319,7 @@ public final class InferenceEngine: ObservableObject {
319319 }
320320
321321 private func loadVerifiedModel( modelId: String ) async {
322- state = . loading
322+ setLoadingState ( progress : 0.05 , stage : " Preparing model configuration " )
323323 currentModelId = modelId
324324
325325 do {
@@ -354,36 +354,29 @@ public final class InferenceEngine: ObservableObject {
354354 print ( " [InferenceEngine] SSD expert streaming: disabled " )
355355 }
356356
357+ setLoadingState ( progress: 0.15 , stage: " Inspecting model architecture " )
357358 let downloader = HubDownloader ( hub: hub)
358359 let architecture = try await ModelArchitectureProbe . inspect (
359360 configuration: config,
360361 downloader: downloader
361362 )
362363
363- let speedTracker = DownloadSpeedTracker ( )
364+ let loadingStage = architecture. supportsVision
365+ ? " Loading multimodal model "
366+ : " Loading language model "
367+
368+ setLoadingState ( progress: 0.22 , stage: loadingStage)
364369
365370 if architecture. supportsVision {
366371 container = try await VLMModelFactory . shared. loadContainer (
367372 from: downloader,
368373 using: TransformersTokenizerLoader ( ) ,
369374 configuration: config
370375 ) { [ weak self] progress in
371- speedTracker. record ( totalBytes: progress. completedUnitCount)
372- let smoothedSpeed = speedTracker. speedBytesPerSec
373-
374376 Task { @MainActor in
375377 guard let self else { return }
376378 let pct = progress. fractionCompleted
377- let speedStr = smoothedSpeed
378- . map { String ( format: " %.1f MB/s " , $0 / 1_000_000 ) } ?? " "
379- self . state = . downloading( progress: pct, speed: speedStr)
380-
381- self . downloadManager. updateProgress ( ModelDownloadProgress (
382- modelId: modelId,
383- fractionCompleted: pct,
384- currentFile: " " ,
385- speedMBps: smoothedSpeed. map { $0 / 1_000_000 }
386- ) )
379+ self . setLoadingState ( progress: 0.22 + ( pct * 0.68 ) , stage: loadingStage)
387380 }
388381 }
389382 } else {
@@ -392,22 +385,10 @@ public final class InferenceEngine: ObservableObject {
392385 using: TransformersTokenizerLoader ( ) ,
393386 configuration: config
394387 ) { [ weak self] progress in
395- speedTracker. record ( totalBytes: progress. completedUnitCount)
396- let smoothedSpeed = speedTracker. speedBytesPerSec
397-
398388 Task { @MainActor in
399389 guard let self else { return }
400390 let pct = progress. fractionCompleted
401- let speedStr = smoothedSpeed
402- . map { String ( format: " %.1f MB/s " , $0 / 1_000_000 ) } ?? " "
403- self . state = . downloading( progress: pct, speed: speedStr)
404-
405- self . downloadManager. updateProgress ( ModelDownloadProgress (
406- modelId: modelId,
407- fractionCompleted: pct,
408- currentFile: " " ,
409- speedMBps: smoothedSpeed. map { $0 / 1_000_000 }
410- ) )
391+ self . setLoadingState ( progress: 0.22 + ( pct * 0.68 ) , stage: loadingStage)
411392 }
412393 }
413394 }
@@ -417,11 +398,13 @@ public final class InferenceEngine: ObservableObject {
417398 downloadManager. refresh ( )
418399
419400 // Verify integrity to catch incomplete downloads before marking as ready
401+ setLoadingState ( progress: 0.94 , stage: " Verifying model files " )
420402 guard ModelStorage . verifyModelIntegrity ( for: modelId) else {
421403 throw NSError ( domain: " InferenceEngine " , code: 1 , userInfo: [ NSLocalizedDescriptionKey: " Model safetensors files are incomplete. Please delete and re-download. " ] )
422404 }
423405
424406 // Read the model's actual max context length from config.json
407+ setLoadingState ( progress: 0.98 , stage: " Reading model limits " )
425408 if let ctxLen = ModelStorage . readMaxContextLength ( for: modelId) {
426409 self . maxContextWindow = ctxLen
427410 print ( " [InferenceEngine] Model context window: \( ctxLen) tokens " )
@@ -471,6 +454,10 @@ public final class InferenceEngine: ObservableObject {
471454 MLX . Memory. cacheLimit = 0
472455 }
473456
457+ private func setLoadingState( progress: Double , stage: String ) {
458+ state = . loading( progress: min ( max ( progress, 0 ) , 1 ) , stage: stage)
459+ }
460+
474461 private func markModelCorrupted( modelId: String ? , message: String ) {
475462 let failedModelId = modelId ?? currentModelId
476463 releaseLoadedModelResources ( )
@@ -622,7 +609,7 @@ extension InferenceEngine {
622609 // Use the real token count from the prepared LMInput rather than
623610 // a character-length heuristic (which was consistently off by 2–3×
624611 // for CJK and code content).
625- let baseTokens = lmInput. text. tokens. shape [ 0 ]
612+ let baseTokens = lmInput. text. tokens. size
626613 self . activeContextTokens = baseTokens
627614
628615 // maxContextWindow is already set during loadModel() from config.json
0 commit comments