Skip to content

Commit dcc0a3a

Browse files
Add model loading progress for reloads
1 parent 4ac0c23 commit dcc0a3a

5 files changed

Lines changed: 134 additions & 83 deletions

File tree

Sources/MLXInferenceCore/InferenceEngine.swift

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ private struct TransformersTokenizerBridge: MLXLMCommon.Tokenizer, Sendable {
7272
public enum ModelState: Equatable, Sendable {
7373
case idle
7474
case downloading(progress: Double, speed: String)
75-
case loading
75+
case loading(progress: Double, stage: String)
7676
case ready(modelId: String)
7777
case generating
7878
case error(String)
@@ -319,7 +319,7 @@ public final class InferenceEngine: ObservableObject {
319319
}
320320

321321
private func loadVerifiedModel(modelId: String) async {
322-
state = .loading
322+
setLoadingState(progress: 0.05, stage: "Preparing model configuration")
323323
currentModelId = modelId
324324

325325
do {
@@ -354,36 +354,29 @@ public final class InferenceEngine: ObservableObject {
354354
print("[InferenceEngine] SSD expert streaming: disabled")
355355
}
356356

357+
setLoadingState(progress: 0.15, stage: "Inspecting model architecture")
357358
let downloader = HubDownloader(hub: hub)
358359
let architecture = try await ModelArchitectureProbe.inspect(
359360
configuration: config,
360361
downloader: downloader
361362
)
362363

363-
let speedTracker = DownloadSpeedTracker()
364+
let loadingStage = architecture.supportsVision
365+
? "Loading multimodal model"
366+
: "Loading language model"
367+
368+
setLoadingState(progress: 0.22, stage: loadingStage)
364369

365370
if architecture.supportsVision {
366371
container = try await VLMModelFactory.shared.loadContainer(
367372
from: downloader,
368373
using: TransformersTokenizerLoader(),
369374
configuration: config
370375
) { [weak self] progress in
371-
speedTracker.record(totalBytes: progress.completedUnitCount)
372-
let smoothedSpeed = speedTracker.speedBytesPerSec
373-
374376
Task { @MainActor in
375377
guard let self else { return }
376378
let pct = progress.fractionCompleted
377-
let speedStr = smoothedSpeed
378-
.map { String(format: "%.1f MB/s", $0 / 1_000_000) } ?? ""
379-
self.state = .downloading(progress: pct, speed: speedStr)
380-
381-
self.downloadManager.updateProgress(ModelDownloadProgress(
382-
modelId: modelId,
383-
fractionCompleted: pct,
384-
currentFile: "",
385-
speedMBps: smoothedSpeed.map { $0 / 1_000_000 }
386-
))
379+
self.setLoadingState(progress: 0.22 + (pct * 0.68), stage: loadingStage)
387380
}
388381
}
389382
} else {
@@ -392,22 +385,10 @@ public final class InferenceEngine: ObservableObject {
392385
using: TransformersTokenizerLoader(),
393386
configuration: config
394387
) { [weak self] progress in
395-
speedTracker.record(totalBytes: progress.completedUnitCount)
396-
let smoothedSpeed = speedTracker.speedBytesPerSec
397-
398388
Task { @MainActor in
399389
guard let self else { return }
400390
let pct = progress.fractionCompleted
401-
let speedStr = smoothedSpeed
402-
.map { String(format: "%.1f MB/s", $0 / 1_000_000) } ?? ""
403-
self.state = .downloading(progress: pct, speed: speedStr)
404-
405-
self.downloadManager.updateProgress(ModelDownloadProgress(
406-
modelId: modelId,
407-
fractionCompleted: pct,
408-
currentFile: "",
409-
speedMBps: smoothedSpeed.map { $0 / 1_000_000 }
410-
))
391+
self.setLoadingState(progress: 0.22 + (pct * 0.68), stage: loadingStage)
411392
}
412393
}
413394
}
@@ -417,11 +398,13 @@ public final class InferenceEngine: ObservableObject {
417398
downloadManager.refresh()
418399

419400
// Verify integrity to catch incomplete downloads before marking as ready
401+
setLoadingState(progress: 0.94, stage: "Verifying model files")
420402
guard ModelStorage.verifyModelIntegrity(for: modelId) else {
421403
throw NSError(domain: "InferenceEngine", code: 1, userInfo: [NSLocalizedDescriptionKey: "Model safetensors files are incomplete. Please delete and re-download."])
422404
}
423405

424406
// Read the model's actual max context length from config.json
407+
setLoadingState(progress: 0.98, stage: "Reading model limits")
425408
if let ctxLen = ModelStorage.readMaxContextLength(for: modelId) {
426409
self.maxContextWindow = ctxLen
427410
print("[InferenceEngine] Model context window: \(ctxLen) tokens")
@@ -471,6 +454,10 @@ public final class InferenceEngine: ObservableObject {
471454
MLX.Memory.cacheLimit = 0
472455
}
473456

457+
private func setLoadingState(progress: Double, stage: String) {
458+
state = .loading(progress: min(max(progress, 0), 1), stage: stage)
459+
}
460+
474461
private func markModelCorrupted(modelId: String?, message: String) {
475462
let failedModelId = modelId ?? currentModelId
476463
releaseLoadedModelResources()
@@ -622,7 +609,7 @@ extension InferenceEngine {
622609
// Use the real token count from the prepared LMInput rather than
623610
// a character-length heuristic (which was consistently off by 2–3×
624611
// for CJK and code content).
625-
let baseTokens = lmInput.text.tokens.shape[0]
612+
let baseTokens = lmInput.text.tokens.size
626613
self.activeContextTokens = baseTokens
627614

628615
// maxContextWindow is already set during loadModel() from config.json

SwiftBuddy/SwiftBuddy/Views/ChatView.swift

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,28 @@ struct ChatView: View {
137137
case .downloading(let progress, let speed):
138138
DownloadAnimationView(progress: progress, speed: speed)
139139

140-
case .loading:
140+
case .loading(let progress, let stage):
141141
VStack(spacing: 16) {
142142
ZStack {
143143
Circle()
144144
.stroke(SwiftBuddyTheme.accent.opacity(0.15), lineWidth: 3)
145145
.frame(width: 64, height: 64)
146-
ProgressView()
146+
ProgressView(value: progress)
147147
.controlSize(.large)
148148
.tint(SwiftBuddyTheme.accent)
149+
.frame(width: 64)
150+
}
151+
VStack(spacing: 4) {
152+
Text("Loading model into Metal GPU…")
153+
.font(.subheadline)
154+
.foregroundStyle(SwiftBuddyTheme.textSecondary)
155+
Text(stage)
156+
.font(.caption)
157+
.foregroundStyle(SwiftBuddyTheme.textTertiary)
158+
Text("\(Int(progress * 100))%")
159+
.font(.caption.monospacedDigit())
160+
.foregroundStyle(SwiftBuddyTheme.textTertiary)
149161
}
150-
Text("Loading model into Metal GPU…")
151-
.font(.subheadline)
152-
.foregroundStyle(SwiftBuddyTheme.textSecondary)
153162
}
154163

155164
case .idle:
@@ -252,13 +261,18 @@ struct ChatView: View {
252261
switch engine.state {
253262
case .idle:
254263
bannerRow(icon: "cpu", text: "No model loaded", color: SwiftBuddyTheme.textTertiary)
255-
case .loading:
256-
HStack(spacing: 8) {
257-
ProgressView().controlSize(.mini).tint(SwiftBuddyTheme.accent)
258-
Text("Loading model…")
259-
.font(.caption)
260-
.foregroundStyle(SwiftBuddyTheme.textSecondary)
261-
Spacer()
264+
case .loading(let progress, let stage):
265+
VStack(alignment: .leading, spacing: 4) {
266+
HStack {
267+
Text(stage)
268+
.font(.caption.weight(.medium))
269+
.foregroundStyle(SwiftBuddyTheme.textSecondary)
270+
Spacer()
271+
Text("\(Int(progress * 100))%")
272+
.font(.caption2.monospacedDigit())
273+
.foregroundStyle(SwiftBuddyTheme.textTertiary)
274+
}
275+
ProgressView(value: progress).tint(SwiftBuddyTheme.accent)
262276
}
263277
.padding(.horizontal, 16)
264278
.padding(.vertical, 8)
@@ -527,7 +541,7 @@ extension ModelState {
527541
var shortLabel: String {
528542
switch self {
529543
case .idle: return "No model"
530-
case .loading: return "Loading…"
544+
case .loading(let progress, _): return "\(Int(progress * 100))% loading"
531545
case .downloading(let p, _): return "\(Int(p * 100))% downloading"
532546
case .ready(let modelId): return modelId.components(separatedBy: "/").last ?? modelId
533547
case .generating: return "Generating"

SwiftBuddy/SwiftBuddy/Views/ModelsView.swift

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,8 @@ private struct ActiveModelCardView: View {
277277
entry: engine.loadedModelId.flatMap { id in ModelCatalog.all.first(where: { $0.id == id }) },
278278
state: engine.state
279279
)
280-
case .loading:
281-
loadingCard
280+
case .loading(let progress, let stage):
281+
loadingCard(progress: progress, stage: stage)
282282
case .downloading(let progress, let speed):
283283
downloadingCard(progress: progress, speed: speed)
284284
case .idle, .error:
@@ -287,18 +287,24 @@ private struct ActiveModelCardView: View {
287287
}
288288
}
289289

290-
private var loadingCard: some View {
291-
HStack(spacing: 12) {
292-
ProgressView().controlSize(.regular).tint(SwiftBuddyTheme.accent)
293-
VStack(alignment: .leading, spacing: 2) {
294-
Text("Loading model…")
295-
.font(.subheadline.weight(.semibold))
296-
.foregroundStyle(SwiftBuddyTheme.textPrimary)
297-
Text("Initializing Metal GPU")
298-
.font(.caption)
290+
private func loadingCard(progress: Double, stage: String) -> some View {
291+
VStack(alignment: .leading, spacing: 10) {
292+
HStack {
293+
ProgressView().controlSize(.regular).tint(SwiftBuddyTheme.accent)
294+
VStack(alignment: .leading, spacing: 2) {
295+
Text("Loading model…")
296+
.font(.subheadline.weight(.semibold))
297+
.foregroundStyle(SwiftBuddyTheme.textPrimary)
298+
Text(stage)
299+
.font(.caption)
300+
.foregroundStyle(SwiftBuddyTheme.textSecondary)
301+
}
302+
Spacer()
303+
Text("\(Int(progress * 100))%")
304+
.font(.caption.monospacedDigit())
299305
.foregroundStyle(SwiftBuddyTheme.textSecondary)
300306
}
301-
Spacer()
307+
ProgressView(value: progress).tint(SwiftBuddyTheme.accent)
302308
}
303309
.padding()
304310
.glassCard(cornerRadius: SwiftBuddyTheme.radiusLarge)

SwiftBuddy/SwiftBuddy/Views/RootView.swift

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct RootView: View {
2626
@State private var showTextIngestion = false
2727
@State private var showModelManagement = false
2828
@State private var lastDownloadLogBucket: Int?
29+
@State private var lastLoadingStage: String?
2930
enum Tab { case chat, models, palace, mindPalace, miner, settings }
3031

3132
var body: some View {
@@ -72,11 +73,16 @@ struct RootView: View {
7273
switch newState {
7374
case .idle:
7475
lastDownloadLogBucket = nil
76+
lastLoadingStage = nil
7577
ConsoleLog.shared.info("Engine idle — no model loaded")
76-
case .loading:
78+
case .loading(_, let stage):
7779
lastDownloadLogBucket = nil
78-
ConsoleLog.shared.info("Loading model…")
80+
if lastLoadingStage != stage {
81+
lastLoadingStage = stage
82+
ConsoleLog.shared.info(stage)
83+
}
7984
case .downloading(let p, let speed):
85+
lastLoadingStage = nil
8086
let percent = Int(p * 100)
8187
let bucket = min((percent / 25) * 25, 100)
8288
if bucket != lastDownloadLogBucket, [0, 25, 50, 75, 100].contains(bucket) {
@@ -85,12 +91,15 @@ struct RootView: View {
8591
}
8692
case .ready(let modelId):
8793
lastDownloadLogBucket = nil
94+
lastLoadingStage = nil
8895
ConsoleLog.shared.info("✓ Model ready: \(modelId)")
8996
case .generating:
9097
lastDownloadLogBucket = nil
98+
lastLoadingStage = nil
9199
ConsoleLog.shared.debug("Generating…")
92100
case .error(let msg):
93101
lastDownloadLogBucket = nil
102+
lastLoadingStage = nil
94103
ConsoleLog.shared.error("Engine error: \(msg)")
95104
}
96105
}
@@ -430,12 +439,12 @@ struct RootView: View {
430439
.tint(SwiftBuddyTheme.accent)
431440
.controlSize(.small)
432441

433-
case .loading:
434-
HStack(spacing: 6) {
435-
ProgressView().controlSize(.mini).tint(SwiftBuddyTheme.accent)
436-
Text("Loading…")
437-
.font(.caption)
438-
.foregroundStyle(SwiftBuddyTheme.textSecondary)
442+
case .loading(let progress, let stage):
443+
VStack(alignment: .leading, spacing: 4) {
444+
ProgressView(value: progress).tint(SwiftBuddyTheme.accent)
445+
Text("\(Int(progress * 100))% · \(stage)")
446+
.font(.caption2.monospacedDigit())
447+
.foregroundStyle(SwiftBuddyTheme.textTertiary)
439448
}
440449

441450
case .downloading(let progress, let speed):

SwiftBuddy/SwiftBuddy/Views/SettingsView.swift

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -527,29 +527,64 @@ struct SettingsView: View {
527527
if viewModel.config.streamExperts != (ModelCatalog.all.first(where: {
528528
if case .ready(let id) = engine.state { return $0.id == id } else { return false }
529529
})?.isMoE ?? false) {
530-
HStack(spacing: 6) {
531-
Image(systemName: "arrow.clockwise.circle.fill")
532-
.foregroundStyle(SwiftBuddyTheme.warning)
533-
.font(.caption)
534-
Text("Reload model to apply this change")
535-
.font(.caption2.weight(.medium))
536-
.foregroundStyle(SwiftBuddyTheme.warning)
537-
Spacer()
538-
Button("Reload") {
539-
let currentId: String? = {
540-
if case .ready(let id) = engine.state { return id }
541-
return nil
542-
}()
543-
if let id = currentId {
544-
Task {
545-
engine.unload()
546-
await engine.load(modelId: id)
530+
VStack(alignment: .leading, spacing: 8) {
531+
HStack(spacing: 6) {
532+
Image(systemName: "arrow.clockwise.circle.fill")
533+
.foregroundStyle(SwiftBuddyTheme.warning)
534+
.font(.caption)
535+
Text("Reload model to apply this change")
536+
.font(.caption2.weight(.medium))
537+
.foregroundStyle(SwiftBuddyTheme.warning)
538+
Spacer()
539+
Button("Reload") {
540+
let currentId: String? = {
541+
if case .ready(let id) = engine.state { return id }
542+
return nil
543+
}()
544+
if let id = currentId {
545+
Task {
546+
engine.unload()
547+
await engine.load(modelId: id)
548+
}
547549
}
548550
}
551+
.font(.caption2.weight(.semibold))
552+
.foregroundStyle(SwiftBuddyTheme.accent)
553+
.buttonStyle(.plain)
554+
}
555+
556+
switch engine.state {
557+
case .loading(let progress, let stage):
558+
VStack(alignment: .leading, spacing: 4) {
559+
HStack {
560+
Text(stage)
561+
.font(.caption2.weight(.medium))
562+
.foregroundStyle(SwiftBuddyTheme.textSecondary)
563+
Spacer()
564+
Text("\(Int(progress * 100))%")
565+
.font(.caption2.monospacedDigit())
566+
.foregroundStyle(SwiftBuddyTheme.textTertiary)
567+
}
568+
ProgressView(value: progress)
569+
.tint(SwiftBuddyTheme.accent)
570+
}
571+
case .downloading(let progress, let speed):
572+
VStack(alignment: .leading, spacing: 4) {
573+
HStack {
574+
Text("Downloading model files")
575+
.font(.caption2.weight(.medium))
576+
.foregroundStyle(SwiftBuddyTheme.textSecondary)
577+
Spacer()
578+
Text("\(Int(progress * 100))% · \(speed)")
579+
.font(.caption2.monospacedDigit())
580+
.foregroundStyle(SwiftBuddyTheme.textTertiary)
581+
}
582+
ProgressView(value: progress)
583+
.tint(SwiftBuddyTheme.accent)
584+
}
585+
default:
586+
EmptyView()
549587
}
550-
.font(.caption2.weight(.semibold))
551-
.foregroundStyle(SwiftBuddyTheme.accent)
552-
.buttonStyle(.plain)
553588
}
554589
.padding(.horizontal, 4)
555590
.padding(.vertical, 6)

0 commit comments

Comments
 (0)