Skip to content

Commit f1dddb8

Browse files
authored
Merge pull request #101 from SharpAI/fix/qwen3-jinja-template-issue-97
fix: address post-merge PR 99 feedback and tests
2 parents ba50749 + 7870b2f commit f1dddb8

9 files changed

Lines changed: 941 additions & 155 deletions

File tree

Sources/MLXInferenceCore/InferenceEngine.swift

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -613,24 +613,27 @@ extension InferenceEngine {
613613

614614
// maxContextWindow is already set during loadModel() from config.json
615615

616-
// TurboKV: enable 3-bit PolarQuant+QJL on every KVCacheSimple layer
617-
// before generation. Must be set on the model (not the cache) so the
618-
// cache inherits the flag when newCache() is called inside generate().
616+
// TurboKV: enable 3-bit PolarQuant+QJL on every KVCacheSimple cache layer.
617+
// KVCacheSimple is a cache object (not a neural-network Module), so we
618+
// iterate the cache array — mirroring the pattern in Server.swift.
619+
let cache = await container.perform { ctx in ctx.model.newCache(parameters: params) }
619620
if config.turboKV {
620-
await container.perform { ctx in
621-
for module in ctx.model.modules() {
622-
if let simple = module as? KVCacheSimple {
623-
simple.turboQuantEnabled = true
624-
}
621+
for layer in cache {
622+
if let simple = layer as? KVCacheSimple {
623+
simple.turboQuantEnabled = true
625624
}
626625
}
627626
print("[InferenceEngine] TurboKV enabled for this request")
628627
}
629628

630-
let stream: AsyncStream<Generation> = try await container.generate(
631-
input: lmInput,
632-
parameters: params
633-
)
629+
let stream: AsyncStream<Generation> = try await container.perform { ctx in
630+
try MLXLMCommon.generate(
631+
input: lmInput,
632+
cache: cache,
633+
parameters: params,
634+
context: ctx
635+
)
636+
}
634637

635638
for await generation in stream {
636639
guard !Task.isCancelled else { break }

SwiftBuddy/SwiftBuddy/ViewModels/ServerManager.swift

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ final class ServerManager: ObservableObject {
130130
guard !isOnline else { return }
131131
let configuration = startupConfiguration.normalized
132132

133-
task = Task {
133+
task = Task.detached { [weak self] in
134+
guard let self = self else { return }
134135
do {
135136
let router = Router()
136137

@@ -259,18 +260,24 @@ final class ServerManager: ObservableObject {
259260
configuration: .init(address: .hostname(configuration.host, port: configuration.port))
260261
)
261262

262-
self.isOnline = true
263-
self.host = configuration.host
264-
self.port = configuration.port
265-
self.runningConfiguration = configuration
266-
self.restartRequired = false
263+
await MainActor.run {
264+
self.isOnline = true
265+
self.host = configuration.host
266+
self.port = configuration.port
267+
self.runningConfiguration = configuration
268+
self.restartRequired = false
269+
}
267270
ConsoleLog.shared.info("Server online at http://\(configuration.host):\(configuration.port)")
268271

269272
try await app.runService()
270273
} catch {
271274
print("Server failed: \(error)")
272275
ConsoleLog.shared.error("Server failed: \(error.localizedDescription)")
273-
self.isOnline = false
276+
await MainActor.run {
277+
self.isOnline = false
278+
self.runningConfiguration = nil
279+
self.restartRequired = false
280+
}
274281
}
275282
}
276283
}

SwiftBuddy/SwiftBuddy/Views/SettingsView.swift

Lines changed: 86 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,24 @@ struct SettingsView: View {
3737
return ModelCatalog.all.first(where: { $0.id == modelId })?.isMoE ?? false
3838
}
3939

40+
private var currentModelId: String? {
41+
guard case .ready(let modelId) = engine.state else { return nil }
42+
return modelId
43+
}
44+
4045
private var effectiveStreamExpertsSetting: Bool {
4146
viewModel.config.effectiveStreamExperts(defaultingTo: currentModelIsMoE)
4247
}
4348

49+
// Tracks the stream-experts value that was in effect when the current model was loaded.
50+
// A mismatch with `effectiveStreamExpertsSetting` means a reload is required.
51+
@State private var appliedStreamExperts: Bool? = nil
52+
53+
private var needsModelReloadForStreamingChange: Bool {
54+
guard let applied = appliedStreamExperts else { return false }
55+
return effectiveStreamExpertsSetting != applied
56+
}
57+
4458
private var ssdStreamingBinding: Binding<Bool> {
4559
Binding(
4660
get: { effectiveStreamExpertsSetting },
@@ -117,6 +131,16 @@ struct SettingsView: View {
117131
}
118132
.onAppear {
119133
draftServerConfiguration = server.startupConfiguration
134+
// Seed the applied value from the current engine state so the reload
135+
// prompt doesn't fire spuriously on first open.
136+
if case .ready = engine.state {
137+
appliedStreamExperts = effectiveStreamExpertsSetting
138+
}
139+
}
140+
.onChange(of: engine.state) { _, newState in
141+
if case .ready = newState {
142+
appliedStreamExperts = effectiveStreamExpertsSetting
143+
}
120144
}
121145
#if os(macOS)
122146
.frame(minWidth: 520, minHeight: 580)
@@ -295,6 +319,9 @@ struct SettingsView: View {
295319
tint: SwiftBuddyTheme.warning,
296320
hint: "Stream MoE expert weights from NVMe (requires model reload)"
297321
)
322+
if needsModelReloadForStreamingChange {
323+
modelReloadPrompt
324+
}
298325
toggleRow(
299326
label: "TurboQuant KV", icon: "bolt.badge.clock",
300327
isOn: $viewModel.config.turboKV,
@@ -555,70 +582,8 @@ struct SettingsView: View {
555582
tint: SwiftBuddyTheme.accentSecondary,
556583
hint: "mmap expert weights from NVMe — only active expert pages stay in RAM. Auto-enabled for MoE catalog models."
557584
)
558-
if effectiveStreamExpertsSetting != currentModelIsMoE {
559-
VStack(alignment: .leading, spacing: 8) {
560-
HStack(spacing: 6) {
561-
Image(systemName: "arrow.clockwise.circle.fill")
562-
.foregroundStyle(SwiftBuddyTheme.warning)
563-
.font(.caption)
564-
Text("Reload model to apply this change")
565-
.font(.caption2.weight(.medium))
566-
.foregroundStyle(SwiftBuddyTheme.warning)
567-
Spacer()
568-
Button("Reload") {
569-
let currentId: String? = {
570-
if case .ready(let id) = engine.state { return id }
571-
return nil
572-
}()
573-
if let id = currentId {
574-
Task {
575-
engine.unload()
576-
await engine.load(modelId: id)
577-
}
578-
}
579-
}
580-
.font(.caption2.weight(.semibold))
581-
.foregroundStyle(SwiftBuddyTheme.accent)
582-
.buttonStyle(.plain)
583-
}
584-
585-
switch engine.state {
586-
case .loading(let progress, let stage):
587-
VStack(alignment: .leading, spacing: 4) {
588-
HStack {
589-
Text(stage)
590-
.font(.caption2.weight(.medium))
591-
.foregroundStyle(SwiftBuddyTheme.textSecondary)
592-
Spacer()
593-
Text("\(Int(progress * 100))%")
594-
.font(.caption2.monospacedDigit())
595-
.foregroundStyle(SwiftBuddyTheme.textTertiary)
596-
}
597-
ProgressView(value: progress)
598-
.tint(SwiftBuddyTheme.accent)
599-
}
600-
case .downloading(let progress, let speed):
601-
VStack(alignment: .leading, spacing: 4) {
602-
HStack {
603-
Text("Downloading model files")
604-
.font(.caption2.weight(.medium))
605-
.foregroundStyle(SwiftBuddyTheme.textSecondary)
606-
Spacer()
607-
Text("\(Int(progress * 100))% · \(speed)")
608-
.font(.caption2.monospacedDigit())
609-
.foregroundStyle(SwiftBuddyTheme.textTertiary)
610-
}
611-
ProgressView(value: progress)
612-
.tint(SwiftBuddyTheme.accent)
613-
}
614-
default:
615-
EmptyView()
616-
}
617-
}
618-
.padding(.horizontal, 4)
619-
.padding(.vertical, 6)
620-
.background(SwiftBuddyTheme.warning.opacity(0.08))
621-
.clipShape(RoundedRectangle(cornerRadius: 8))
585+
if needsModelReloadForStreamingChange {
586+
modelReloadPrompt
622587
}
623588
}
624589
}
@@ -981,6 +946,63 @@ struct SettingsView: View {
981946
}
982947
}
983948

949+
@ViewBuilder
950+
private var modelReloadPrompt: some View {
951+
VStack(alignment: .leading, spacing: 8) {
952+
HStack(spacing: 6) {
953+
Image(systemName: "arrow.clockwise.circle.fill")
954+
.foregroundStyle(SwiftBuddyTheme.warning)
955+
.font(.caption)
956+
Text("Reload model to apply this change")
957+
.font(.caption2.weight(.medium))
958+
.foregroundStyle(SwiftBuddyTheme.warning)
959+
Spacer()
960+
Button("Reload") {
961+
reloadCurrentModel()
962+
}
963+
.font(.caption2.weight(.semibold))
964+
.foregroundStyle(SwiftBuddyTheme.accent)
965+
.buttonStyle(.plain)
966+
.disabled(currentModelId == nil)
967+
}
968+
969+
switch engine.state {
970+
case .loading(let progress, let stage):
971+
progressRow(label: stage, progress: progress)
972+
case .downloading(let progress, let speed):
973+
progressRow(label: "Downloading · \(speed)", progress: progress)
974+
default:
975+
EmptyView()
976+
}
977+
}
978+
}
979+
980+
@ViewBuilder
981+
private func progressRow(label: String, progress: Double) -> some View {
982+
VStack(alignment: .leading, spacing: 4) {
983+
HStack {
984+
Text(label)
985+
.font(.caption2.weight(.medium))
986+
.foregroundStyle(SwiftBuddyTheme.textSecondary)
987+
Spacer()
988+
Text("\(Int(progress * 100))%")
989+
.font(.caption2.monospacedDigit())
990+
.foregroundStyle(SwiftBuddyTheme.textTertiary)
991+
}
992+
ProgressView(value: progress)
993+
.tint(SwiftBuddyTheme.accent)
994+
.controlSize(.small)
995+
}
996+
}
997+
998+
private func reloadCurrentModel() {
999+
guard let currentModelId else { return }
1000+
Task {
1001+
engine.unload()
1002+
await engine.load(modelId: currentModelId)
1003+
}
1004+
}
1005+
9841006
@ViewBuilder
9851007
private func parameterCard<Content: View>(_ title: String, @ViewBuilder content: () -> Content) -> some View {
9861008
VStack(alignment: .leading, spacing: 10) {

SwiftBuddy/generate_xcodeproj.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def uid():
7070
# ── MLXInferenceCore sources (path relative to SwiftBuddy/)
7171
core_sources = [
7272
("../Sources/MLXInferenceCore/ChatMessage.swift", uid(), uid()),
73+
("../Sources/MLXInferenceCore/CLICommandBuilder.swift", uid(), uid()),
7374
("../Sources/MLXInferenceCore/GenerationConfig.swift", uid(), uid()),
7475
("../Sources/MLXInferenceCore/ModelCatalog.swift", uid(), uid()),
7576
("../Sources/MLXInferenceCore/ModelStorage.swift", uid(), uid()),
@@ -512,7 +513,7 @@ def main():
512513
print(" • ../mlx-swift-lm → MLXLLM, MLXLMCommon")
513514
print()
514515
print("📂 MLXInferenceCore sources included directly:")
515-
for p, _, _ in [("ChatMessage", None, None), ("GenerationConfig", None, None),
516+
for p, _, _ in [("ChatMessage", None, None), ("CLICommandBuilder", None, None), ("GenerationConfig", None, None),
516517
("ModelCatalog", None, None), ("ModelDownloadManager", None, None),
517518
("ModelArchitectureProbe", None, None), ("InferenceEngine", None, None)]:
518519
print(f" • {p}.swift")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import XCTest
2+
import MLX
3+
import MLXLMCommon
4+
@testable import MLXInferenceCore
5+
6+
final class ContextWindowCalculationTests: XCTestCase {
7+
8+
@MainActor
9+
func testContextTokensCalculation() async throws {
10+
// Feature: Verify that the token-count calculation accurately reflects the prompt cache window
11+
// by evaluating the full size of the prepared tokens array, not just the batch shape.
12+
13+
let engine = InferenceEngine()
14+
15+
// Mock a scenario where userInput prepares a chat template with large history.
16+
// We will directly instantiate LMInput and assert on its size.
17+
18+
let mockTokens = MLXArray(Array(0..<512))
19+
// If tokenizer batches it, shape could be [1, 512].
20+
let reshapedTokens = mockTokens.reshaped([1, 512])
21+
22+
// MLXLMCommon's LMInput struct
23+
let lmInput = LMInput(tokens: reshapedTokens)
24+
25+
// Validate that using .size accurately captures the token count (512)
26+
// rather than relying on the batch dimension .shape[0], which would be 1.
27+
XCTAssertEqual(lmInput.text.tokens.shape[0], 1, "shape[0] captures the batch dimension, returning 1")
28+
XCTAssertEqual(lmInput.text.tokens.size, 512, "size captures the total token count, resolving the context window bug")
29+
}
30+
}

0 commit comments

Comments
 (0)