@@ -2342,16 +2342,19 @@ public struct ALMUserInputProcessor: UserInputProcessor, @unchecked Sendable {
23422342 let configuration : ModelConfiguration
23432343 let messageGenerator : MessageGenerator
23442344 let fusionProcessor : MultimodalFusionProcessor
2345+ let numAudioEmbeddings : Int
23452346
/// Creates an audio-language user-input processor.
///
/// - Parameters:
///   - tokenizer: Tokenizer stored on the processor.
///   - configuration: Model configuration stored on the processor.
///   - messageGenerator: Message generator stored on the processor.
///   - boaToken: Token id forwarded to `MultimodalFusionProcessor` as `boaToken`.
///   - eoaToken: Token id forwarded to `MultimodalFusionProcessor` as `eoaToken`.
///   - numAudioEmbeddings: Count handed to `fusionProcessor.interleave` when a
///     prompt is prepared. Defaults to 128.
public init(
    tokenizer: any MLXLMCommon.Tokenizer,
    configuration: ModelConfiguration,
    messageGenerator: MessageGenerator,
    boaToken: Int = 255010,
    eoaToken: Int = 255011,
    numAudioEmbeddings: Int = 128
) {
    // Build the fusion helper first; it depends only on the two marker tokens.
    let fusion = MultimodalFusionProcessor(boaToken: boaToken, eoaToken: eoaToken)

    self.fusionProcessor = fusion
    self.numAudioEmbeddings = numAudioEmbeddings
    self.messageGenerator = messageGenerator
    self.configuration = configuration
    self.tokenizer = tokenizer
}
23562359
23572360 public func prepare( input: UserInput ) throws -> LMInput {
@@ -2366,7 +2369,7 @@ public struct ALMUserInputProcessor: UserInputProcessor, @unchecked Sendable {
23662369 // Audio embedding count is the value supplied at init (`numAudioEmbeddings`); ideally it is derived from the model config or the audio length.
23672370 let rawSequence = fusionProcessor. interleave (
23682371 textTokens: promptTokensInt,
2369- numAudioEmbeddings: 128 , // Placeholder
2372+ numAudioEmbeddings: numAudioEmbeddings ,
23702373 audioFirst: true
23712374 )
23722375 return LMInput ( tokens: MLXArray ( rawSequence) )
@@ -2394,13 +2397,15 @@ public final class ALMModelFactory: ModelFactory, @unchecked Sendable {
23942397 ) async throws -> ModelContext {
23952398 let context = try await LLMModelFactory . shared. _load ( configuration: configuration, tokenizerLoader: tokenizerLoader)
23962399
2400+ let numAudioEmbeddings = OmniModelFactory . extractNumAudioEmbeddings ( configuration: configuration)
23972401 let messageGenerator = DefaultMessageGenerator ( )
23982402 let processor = ALMUserInputProcessor (
23992403 tokenizer: context. tokenizer,
24002404 configuration: context. configuration,
24012405 messageGenerator: messageGenerator,
24022406 boaToken: 255010 ,
2403- eoaToken: 255011
2407+ eoaToken: 255011 ,
2408+ numAudioEmbeddings: numAudioEmbeddings
24042409 )
24052410
24062411 return . init(
@@ -2415,10 +2420,12 @@ public final class ALMModelFactory: ModelFactory, @unchecked Sendable {
24152420public struct OmniUserInputProcessor : UserInputProcessor , @unchecked Sendable {
24162421 let vlmProcessor : any UserInputProcessor
24172422 let fusionProcessor : MultimodalFusionProcessor
2423+ let numAudioEmbeddings : Int
24182424
/// Creates an omni (vision + audio) user-input processor that wraps a VLM processor.
///
/// - Parameters:
///   - vlmProcessor: Processor stored as `vlmProcessor` on this wrapper.
///   - boaToken: Token id forwarded to `MultimodalFusionProcessor` as `boaToken`.
///   - eoaToken: Token id forwarded to `MultimodalFusionProcessor` as `eoaToken`.
///   - numAudioEmbeddings: Count handed to `fusionProcessor.interleave` when a
///     prompt is prepared. Defaults to 128.
public init(
    vlmProcessor: any UserInputProcessor,
    boaToken: Int = 255010,
    eoaToken: Int = 255011,
    numAudioEmbeddings: Int = 128
) {
    // Build the fusion helper first; it depends only on the two marker tokens.
    let fusion = MultimodalFusionProcessor(boaToken: boaToken, eoaToken: eoaToken)

    self.numAudioEmbeddings = numAudioEmbeddings
    self.fusionProcessor = fusion
    self.vlmProcessor = vlmProcessor
}
24232430
24242431 public func prepare( input: UserInput ) async throws -> LMInput {
@@ -2431,7 +2438,7 @@ public struct OmniUserInputProcessor: UserInputProcessor, @unchecked Sendable {
24312438 print ( " [Omni] Interleaving Audio Tokens into VLM prompt structure. " )
24322439 let rawSequence = fusionProcessor. interleave (
24332440 textTokens: tokens,
2434- numAudioEmbeddings: 128 , // Placeholder until audio config extraction is available globally
2441+ numAudioEmbeddings: numAudioEmbeddings ,
24352442 audioFirst: false // Append audio after vision context typically
24362443 )
24372444 return LMInput ( text: . init( tokens: MLXArray ( rawSequence) ) , image: vlmInput. image)
@@ -2453,7 +2460,11 @@ public final class OmniModelFactory: ModelFactory, @unchecked Sendable {
24532460 tokenizerLoader: any TokenizerLoader
24542461 ) async throws -> ModelContext {
24552462 let vlmContext = try await VLMModelFactory . shared. _load ( configuration: configuration, tokenizerLoader: tokenizerLoader)
2456- let omniProcessor = OmniUserInputProcessor ( vlmProcessor: vlmContext. processor)
2463+ let numAudioEmbeddings = OmniModelFactory . extractNumAudioEmbeddings ( configuration: configuration)
2464+ let omniProcessor = OmniUserInputProcessor (
2465+ vlmProcessor: vlmContext. processor,
2466+ numAudioEmbeddings: numAudioEmbeddings
2467+ )
24572468
24582469 return . init(
24592470 configuration: vlmContext. configuration,
@@ -2462,4 +2473,20 @@ public final class OmniModelFactory: ModelFactory, @unchecked Sendable {
24622473 tokenizer: vlmContext. tokenizer
24632474 )
24642475 }
2476+
/// Reads the audio embedding count for a model from its `config.json`.
///
/// Lookup order (the first usable, strictly positive value wins):
/// 1. The first element of the top-level `subsampling_conv_channels` array.
/// 2. `audio_config.num_audio_embeddings`.
/// 3. The fallback value of 128.
///
/// The read is deliberately best-effort: a missing, unreadable, or malformed
/// `config.json` yields the fallback instead of throwing.
///
/// - Parameter configuration: Resolved configuration whose `modelDirectory`
///   is searched for `config.json`.
/// - Returns: A strictly positive embedding count.
public static func extractNumAudioEmbeddings(configuration: ResolvedModelConfiguration) -> Int {
    // Same value the processors' initializers use as their default.
    let fallback = 128

    let configurationURL = configuration.modelDirectory.appending(component: "config.json")
    guard
        let data = try? Data(contentsOf: configurationURL),
        let root = (try? JSONSerialization.jsonObject(with: data)) as? [String: Any]
    else {
        return fallback
    }

    // NOTE(review): treating the first conv channel size as the embedding count
    // looks like a heuristic — confirm against the model's config schema.
    // An empty array or a non-positive entry falls through to the explicit
    // `audio_config` value instead of short-circuiting to the fallback.
    if let channels = root["subsampling_conv_channels"] as? [Int],
       let firstChannel = channels.first,
       firstChannel > 0 {
        return firstChannel
    }

    if let audioConfig = root["audio_config"] as? [String: Any],
       let embeddings = audioConfig["num_audio_embeddings"] as? Int,
       embeddings > 0 {
        return embeddings
    }

    return fallback
}
24652492}
0 commit comments