Skip to content

Commit a04b81e

Browse files
authored
Merge pull request #104 from roydsouza/fix/moe-memory-and-multimodal-tokens-rebased
Fix: Resolve multimodal BOA/EOA tokens dynamically from config.json
2 parents f1dddb8 + 5cfc277 commit a04b81e

2 files changed

Lines changed: 97 additions & 11 deletions

File tree

Sources/SwiftLM/Server.swift

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3013,15 +3013,15 @@ public final class ALMModelFactory: ModelFactory, @unchecked Sendable {
30133013
) async throws -> ModelContext {
30143014
let context = try await LLMModelFactory.shared._load(configuration: configuration, tokenizerLoader: tokenizerLoader)
30153015

3016-
let numAudioEmbeddings = OmniModelFactory.extractNumAudioEmbeddings(configuration: configuration)
3016+
let tokens = OmniModelFactory.extractMultimodalTokens(configuration: configuration)
30173017
let messageGenerator = DefaultMessageGenerator()
30183018
let processor = ALMUserInputProcessor(
30193019
tokenizer: context.tokenizer,
30203020
configuration: context.configuration,
30213021
messageGenerator: messageGenerator,
3022-
boaToken: 255010,
3023-
eoaToken: 255011,
3024-
numAudioEmbeddings: numAudioEmbeddings
3022+
boaToken: tokens.boa,
3023+
eoaToken: tokens.eoa,
3024+
numAudioEmbeddings: tokens.numAudio
30253025
)
30263026

30273027
return .init(
@@ -3081,10 +3081,12 @@ public final class OmniModelFactory: ModelFactory, @unchecked Sendable {
30813081
tokenizerLoader: any TokenizerLoader
30823082
) async throws -> ModelContext {
30833083
let vlmContext = try await VLMModelFactory.shared._load(configuration: configuration, tokenizerLoader: tokenizerLoader)
3084-
let numAudioEmbeddings = OmniModelFactory.extractNumAudioEmbeddings(configuration: configuration)
3084+
let tokens = OmniModelFactory.extractMultimodalTokens(configuration: configuration)
30853085
let omniProcessor = OmniUserInputProcessor(
30863086
vlmProcessor: vlmContext.processor,
3087-
numAudioEmbeddings: numAudioEmbeddings
3087+
boaToken: tokens.boa,
3088+
eoaToken: tokens.eoa,
3089+
numAudioEmbeddings: tokens.numAudio
30883090
)
30893091

30903092
return .init(
@@ -3095,19 +3097,35 @@ public final class OmniModelFactory: ModelFactory, @unchecked Sendable {
30953097
)
30963098
}
30973099

3100+
/// Deprecated forwarding wrapper that returns only the audio-embedding count.
///
/// Calls `extractMultimodalTokens(configuration:)` and discards the token fields.
///
/// - Parameter configuration: Resolved model configuration passed through unchanged.
/// - Returns: The `numAudio` component of the extracted tuple.
@available(*, deprecated, message: "Use extractMultimodalTokens(configuration:).numAudio instead")
public static func extractNumAudioEmbeddings(configuration: ResolvedModelConfiguration) -> Int {
    let multimodal = extractMultimodalTokens(configuration: configuration)
    return multimodal.numAudio
}
3104+
3105+
/// Reads the audio-related settings a model declares in its `config.json`.
///
/// The file is looked up inside `configuration.modelDirectory`. A missing, unreadable, or
/// non-object `config.json` is not treated as an error: every value keeps its built-in default.
///
/// - Parameter configuration: Resolved model configuration whose `modelDirectory` is searched.
/// - Returns: A tuple where
///   - `numAudio` is the first entry of the top-level `subsampling_conv_channels` integer array
///     (128 if that array is empty), otherwise `audio_config.num_audio_embeddings`, otherwise 128;
///   - `boa` is the top-level `boa_token_id`, otherwise `audio_config.boa_token_id`, otherwise 255010;
///   - `eoa` is the top-level `eoa_token_id`, otherwise `audio_config.eoa_token_id`, otherwise 255011.
public static func extractMultimodalTokens(configuration: ResolvedModelConfiguration) -> (numAudio: Int, boa: Int, eoa: Int) {
    let defaults = (numAudio: 128, boa: 255010, eoa: 255011)

    let configFile = configuration.modelDirectory.appending(component: "config.json")
    guard let raw = try? Data(contentsOf: configFile),
          let root = try? JSONSerialization.jsonObject(with: raw) as? [String: Any]
    else {
        return defaults
    }

    // Looked up once; both the embedding count and the token IDs may fall back to it.
    let audioSection = root["audio_config"] as? [String: Any]

    // A top-level key always wins over the same key nested under `audio_config`.
    func tokenID(_ key: String, fallback: Int) -> Int {
        (root[key] as? Int) ?? (audioSection?[key] as? Int) ?? fallback
    }

    let numAudio: Int
    if let channels = root["subsampling_conv_channels"] as? [Int] {
        // When the array key is present it takes precedence, even if the array is empty.
        numAudio = channels.first ?? defaults.numAudio
    } else {
        numAudio = (audioSection?["num_audio_embeddings"] as? Int) ?? defaults.numAudio
    }

    return (
        numAudio,
        tokenID("boa_token_id", fallback: defaults.boa),
        tokenID("eoa_token_id", fallback: defaults.eoa)
    )
}
31133131
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import XCTest
2+
import Foundation
3+
@testable import SwiftLM
4+
import MLXLMCommon
5+
6+
/// Exercises `OmniModelFactory.extractMultimodalTokens(configuration:)` against throwaway
/// model directories with no `config.json`, top-level keys, and keys nested in `audio_config`.
final class MultimodalTokenExtractionTests: XCTestCase {

    /// Runs the extractor against a freshly created temporary model directory.
    ///
    /// The directory is removed again before this method returns.
    ///
    /// - Parameter configJSON: Object to serialize into `config.json`; pass `nil` to leave
    ///   the directory without a config file.
    /// - Returns: The tuple produced by `OmniModelFactory.extractMultimodalTokens`.
    /// - Throws: Any error raised while creating the directory or writing the file.
    private func extractTokens(configJSON: [String: Any]?) throws -> (numAudio: Int, boa: Int, eoa: Int) {
        let fileManager = FileManager.default
        let modelDir = fileManager.temporaryDirectory.appendingPathComponent(UUID().uuidString)
        try fileManager.createDirectory(at: modelDir, withIntermediateDirectories: true)
        defer { try? fileManager.removeItem(at: modelDir) }

        if let configJSON = configJSON {
            let payload = try JSONSerialization.data(withJSONObject: configJSON)
            try payload.write(to: modelDir.appendingPathComponent("config.json"))
        }

        let resolved = ModelConfiguration(directory: modelDir)
            .resolved(modelDirectory: modelDir, tokenizerDirectory: modelDir)
        return OmniModelFactory.extractMultimodalTokens(configuration: resolved)
    }

    /// Without a `config.json`, every field keeps its built-in default.
    func testExtractMultimodalTokens_Defaults() throws {
        let result = try extractTokens(configJSON: nil)

        XCTAssertEqual(result.numAudio, 128)
        XCTAssertEqual(result.boa, 255010)
        XCTAssertEqual(result.eoa, 255011)
    }

    /// Top-level keys in `config.json` are picked up directly.
    func testExtractMultimodalTokens_FromConfig() throws {
        let result = try extractTokens(configJSON: [
            "subsampling_conv_channels": [256],
            "boa_token_id": 999990,
            "eoa_token_id": 999991
        ])

        XCTAssertEqual(result.numAudio, 256)
        XCTAssertEqual(result.boa, 999990)
        XCTAssertEqual(result.eoa, 999991)
    }

    /// When the top-level keys are absent, values nested under `audio_config` are used.
    func testExtractMultimodalTokens_FromAudioConfigFallback() throws {
        let result = try extractTokens(configJSON: [
            "audio_config": [
                "num_audio_embeddings": 512,
                "boa_token_id": 888880,
                "eoa_token_id": 888881
            ]
        ])

        XCTAssertEqual(result.numAudio, 512)
        XCTAssertEqual(result.boa, 888880)
        XCTAssertEqual(result.eoa, 888881)
    }
}

0 commit comments

Comments
 (0)