Skip to content

Commit 5b6db9f

Browse files
committed
fix: Update Package.resolved for mlx-swift-lm Omni Audio logic and resolve @Sendable concurrency error in ModelDownloadManager
1 parent 3a6562a commit 5b6db9f

3 files changed

Lines changed: 36 additions & 9 deletions

File tree

Package.resolved

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Sources/MLXInferenceCore/ModelDownloadManager.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ public final class ModelDownloadManager: ObservableObject {
179179
_ = try await hub.snapshot(
180180
from: modelId,
181181
matching: ["*.safetensors", "*.json", "*.model", "*.txt", "*.tiktoken"],
182-
progressHandler: { [weak self] progress in
182+
progressHandler: { @Sendable [weak self] progress in
183183
Task { @MainActor [weak self] in
184184
let pct = progress.fractionCompleted
185185
let speedBytesPerSec = progress.userInfo[ProgressUserInfoKey("throughputKey")] as? Double

Sources/SwiftLM/Server.swift

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2342,16 +2342,19 @@ public struct ALMUserInputProcessor: UserInputProcessor, @unchecked Sendable {
23422342
let configuration: ModelConfiguration
23432343
let messageGenerator: MessageGenerator
23442344
let fusionProcessor: MultimodalFusionProcessor
2345+
let numAudioEmbeddings: Int
23452346

23462347
public init(
23472348
tokenizer: any MLXLMCommon.Tokenizer, configuration: ModelConfiguration,
23482349
messageGenerator: MessageGenerator,
2349-
boaToken: Int = 255010, eoaToken: Int = 255011
2350+
boaToken: Int = 255010, eoaToken: Int = 255011,
2351+
numAudioEmbeddings: Int = 128
23502352
) {
23512353
self.tokenizer = tokenizer
23522354
self.configuration = configuration
23532355
self.messageGenerator = messageGenerator
23542356
self.fusionProcessor = MultimodalFusionProcessor(boaToken: boaToken, eoaToken: eoaToken)
2357+
self.numAudioEmbeddings = numAudioEmbeddings
23552358
}
23562359

23572360
public func prepare(input: UserInput) throws -> LMInput {
@@ -2366,7 +2369,7 @@ public struct ALMUserInputProcessor: UserInputProcessor, @unchecked Sendable {
23662369
// Audio embedding count is supplied at init (the factory reads it from the model's config.json; defaults to 128)
23672370
let rawSequence = fusionProcessor.interleave(
23682371
textTokens: promptTokensInt,
2369-
numAudioEmbeddings: 128, // Placeholder
2372+
numAudioEmbeddings: numAudioEmbeddings,
23702373
audioFirst: true
23712374
)
23722375
return LMInput(tokens: MLXArray(rawSequence))
@@ -2394,13 +2397,15 @@ public final class ALMModelFactory: ModelFactory, @unchecked Sendable {
23942397
) async throws -> ModelContext {
23952398
let context = try await LLMModelFactory.shared._load(configuration: configuration, tokenizerLoader: tokenizerLoader)
23962399

2400+
let numAudioEmbeddings = OmniModelFactory.extractNumAudioEmbeddings(configuration: configuration)
23972401
let messageGenerator = DefaultMessageGenerator()
23982402
let processor = ALMUserInputProcessor(
23992403
tokenizer: context.tokenizer,
24002404
configuration: context.configuration,
24012405
messageGenerator: messageGenerator,
24022406
boaToken: 255010,
2403-
eoaToken: 255011
2407+
eoaToken: 255011,
2408+
numAudioEmbeddings: numAudioEmbeddings
24042409
)
24052410

24062411
return .init(
@@ -2415,10 +2420,12 @@ public final class ALMModelFactory: ModelFactory, @unchecked Sendable {
24152420
public struct OmniUserInputProcessor: UserInputProcessor, @unchecked Sendable {
24162421
let vlmProcessor: any UserInputProcessor
24172422
let fusionProcessor: MultimodalFusionProcessor
2423+
let numAudioEmbeddings: Int
24182424

2419-
public init(vlmProcessor: any UserInputProcessor, boaToken: Int = 255010, eoaToken: Int = 255011) {
2425+
public init(vlmProcessor: any UserInputProcessor, boaToken: Int = 255010, eoaToken: Int = 255011, numAudioEmbeddings: Int = 128) {
24202426
self.vlmProcessor = vlmProcessor
24212427
self.fusionProcessor = MultimodalFusionProcessor(boaToken: boaToken, eoaToken: eoaToken)
2428+
self.numAudioEmbeddings = numAudioEmbeddings
24222429
}
24232430

24242431
public func prepare(input: UserInput) async throws -> LMInput {
@@ -2431,7 +2438,7 @@ public struct OmniUserInputProcessor: UserInputProcessor, @unchecked Sendable {
24312438
print("[Omni] Interleaving Audio Tokens into VLM prompt structure.")
24322439
let rawSequence = fusionProcessor.interleave(
24332440
textTokens: tokens,
2434-
numAudioEmbeddings: 128, // Placeholder until audio config extraction is available globally
2441+
numAudioEmbeddings: numAudioEmbeddings,
24352442
audioFirst: false // Append audio after vision context typically
24362443
)
24372444
return LMInput(text: .init(tokens: MLXArray(rawSequence)), image: vlmInput.image)
@@ -2453,7 +2460,11 @@ public final class OmniModelFactory: ModelFactory, @unchecked Sendable {
24532460
tokenizerLoader: any TokenizerLoader
24542461
) async throws -> ModelContext {
24552462
let vlmContext = try await VLMModelFactory.shared._load(configuration: configuration, tokenizerLoader: tokenizerLoader)
2456-
let omniProcessor = OmniUserInputProcessor(vlmProcessor: vlmContext.processor)
2463+
let numAudioEmbeddings = OmniModelFactory.extractNumAudioEmbeddings(configuration: configuration)
2464+
let omniProcessor = OmniUserInputProcessor(
2465+
vlmProcessor: vlmContext.processor,
2466+
numAudioEmbeddings: numAudioEmbeddings
2467+
)
24572468

24582469
return .init(
24592470
configuration: vlmContext.configuration,
@@ -2462,4 +2473,20 @@ public final class OmniModelFactory: ModelFactory, @unchecked Sendable {
24622473
tokenizer: vlmContext.tokenizer
24632474
)
24642475
}
2476+
2477+
/// Reads the number of audio embeddings to interleave from the model's `config.json`.
///
/// Lookup order inside the JSON object:
/// 1. the first element of the top-level `subsampling_conv_channels` integer array;
/// 2. `audio_config.num_audio_embeddings`.
///
/// A candidate is accepted only if it is a positive integer; otherwise the next
/// source is tried. This lookup is best-effort: a missing, unreadable, or malformed
/// `config.json` is not treated as an error and yields the fallback value.
///
/// - Parameter configuration: Resolved model configuration whose `modelDirectory`
///   is expected to contain `config.json`.
/// - Returns: The configured embedding count, or `128` when no usable value is found.
public static func extractNumAudioEmbeddings(configuration: ResolvedModelConfiguration) -> Int {
    // Single source of truth for the fallback; matches the processors' init default.
    let fallback = 128

    let configurationURL = configuration.modelDirectory.appending(component: "config.json")

    // Deliberately swallow I/O and parse errors: absence of the file just means "use the default".
    guard let data = try? Data(contentsOf: configurationURL),
          let dict = (try? JSONSerialization.jsonObject(with: data)) as? [String: Any]
    else {
        return fallback
    }

    // NOTE(review): this treats a conv channel width as the embedding count — confirm
    // against the audio encoder that these two quantities really coincide.
    // An empty array or a non-positive first entry falls through to `audio_config`
    // instead of short-circuiting to the fallback.
    if let channels = dict["subsampling_conv_channels"] as? [Int],
       let first = channels.first, first > 0 {
        return first
    }

    // A non-positive count cannot describe a real audio span, so it is ignored here too.
    if let audioConfig = dict["audio_config"] as? [String: Any],
       let embeddings = audioConfig["num_audio_embeddings"] as? Int, embeddings > 0 {
        return embeddings
    }

    return fallback
}
24652492
}

0 commit comments

Comments
 (0)