@@ -161,49 +161,66 @@ def _resample(waveform: np.ndarray, orig_sr: int, target_sr: int = _QWEN_SAMPLE_
161161
162162 return librosa .resample (waveform , orig_sr = orig_sr , target_sr = target_sr )
163163
164- def _build_messages (self , waveform : np .ndarray ) -> list [dict [str , Any ]]:
165- """Build Turn 1 chat messages with an in-memory waveform (numpy array at 16 kHz)."""
164+ def _resolve_prompt (self , template : str , language : str | None ) -> str :
165+ """Replace ``{language}`` placeholder if *language* is provided."""
166+ if language and template and "{language}" in template :
167+ return template .replace ("{language}" , language )
168+ return template
169+
170+ def _build_messages (self , waveform : np .ndarray , language : str | None = None ) -> list [dict [str , Any ]]:
171+ """Build Turn 1 chat messages with an in-memory waveform (numpy array at 16 kHz).
172+
173+ Prompts may contain a ``{language}`` placeholder which is replaced
174+ with *language* (e.g., ``"French"``) when provided.
175+ """
176+ prompt = self ._resolve_prompt (self .prompt_text , language )
166177 messages : list [dict [str , Any ]] = []
167178 if self .system_prompt :
168- messages .append ({"role" : "system" , "content" : [{"type" : "text" , "text" : self .system_prompt }]})
179+ sys_prompt = self ._resolve_prompt (self .system_prompt , language )
180+ messages .append ({"role" : "system" , "content" : [{"type" : "text" , "text" : sys_prompt }]})
169181 messages .append ({
170182 "role" : "user" ,
171183 "content" : [
172- {"type" : "text" , "text" : self . prompt_text },
184+ {"type" : "text" , "text" : prompt },
173185 {"type" : "audio" , "audio" : waveform },
174186 ],
175187 })
176188 return messages
177189
178- def _build_turn2_messages (self , waveform : np .ndarray , pred_text : str ) -> list [dict [str , Any ]]:
179- """Build Turn 2 messages: full Turn 1 conversation history + follow-up promt."""
190+ def _build_turn2_messages (
191+ self , waveform : np .ndarray , pred_text : str , language : str | None = None ,
192+ ) -> list [dict [str , Any ]]:
193+ """Build Turn 2 messages: full Turn 1 conversation history + follow-up prompt."""
194+ prompt = self ._resolve_prompt (self .prompt_text , language )
195+ followup = self ._resolve_prompt (self .followup_prompt , language )
180196 messages : list [dict [str , Any ]] = []
181197 if self .system_prompt :
182- messages .append ({"role" : "system" , "content" : [{"type" : "text" , "text" : self .system_prompt }]})
198+ sys_prompt = self ._resolve_prompt (self .system_prompt , language )
199+ messages .append ({"role" : "system" , "content" : [{"type" : "text" , "text" : sys_prompt }]})
183200 messages .append ({
184201 "role" : "user" ,
185202 "content" : [
186- {"type" : "text" , "text" : self . prompt_text },
203+ {"type" : "text" , "text" : prompt },
187204 {"type" : "audio" , "audio" : waveform },
188205 ],
189206 })
190207 messages .append ({"role" : "assistant" , "content" : [{"type" : "text" , "text" : pred_text }]})
191208 messages .append ({
192209 "role" : "user" ,
193210 "content" : [
194- {"type" : "text" , "text" : self . followup_prompt },
211+ {"type" : "text" , "text" : followup },
195212 ],
196213 })
197214 return messages
198215
199216 def _prepare_single (
200- self , waveform : np .ndarray , sample_rate : int ,
217+ self , waveform : np .ndarray , sample_rate : int , language : str | None = None ,
201218 ) -> tuple [dict [str , Any ], np .ndarray ] | None :
202219 from qwen_omni_utils import process_mm_info
203220
204221 try :
205222 waveform_16k = self ._resample (waveform , sample_rate )
206- messages = self ._build_messages (waveform_16k )
223+ messages = self ._build_messages (waveform_16k , language )
207224 text = self ._processor .apply_chat_template (messages , tokenize = False , add_generation_prompt = True )
208225 audios , images , videos = process_mm_info (messages , use_audio_in_video = False )
209226 except Exception : # noqa: BLE001
@@ -227,18 +244,23 @@ def _prepare_batch(
227244 self ,
228245 waveforms : list [np .ndarray ],
229246 sample_rates : list [int ],
247+ languages : list [str | None ] | None = None ,
230248 ) -> list [tuple [dict [str , Any ], np .ndarray ] | None ]:
249+ langs = languages if languages is not None else [None ] * len (waveforms )
231250 if self ._prep_pool is None :
232- return [self ._prepare_single (w , sr ) for w , sr in zip (waveforms , sample_rates , strict = False )]
233- return list (self ._prep_pool .map (self ._prepare_single , waveforms , sample_rates ))
251+ return [
252+ self ._prepare_single (w , sr , lang )
253+ for w , sr , lang in zip (waveforms , sample_rates , langs , strict = False )
254+ ]
255+ return list (self ._prep_pool .map (self ._prepare_single , waveforms , sample_rates , langs ))
234256
235257 def _prepare_turn2_single (
236- self , waveform_16k : np .ndarray , pred_text : str ,
258+ self , waveform_16k : np .ndarray , pred_text : str , language : str | None = None ,
237259 ) -> dict [str , Any ] | None :
238260 from qwen_omni_utils import process_mm_info
239261
240262 try :
241- messages = self ._build_turn2_messages (waveform_16k , pred_text )
263+ messages = self ._build_turn2_messages (waveform_16k , pred_text , language )
242264 text = self ._processor .apply_chat_template (messages , tokenize = False , add_generation_prompt = True )
243265 audios , images , videos = process_mm_info (messages , use_audio_in_video = False )
244266 except Exception : # noqa: BLE001
@@ -262,13 +284,15 @@ def _prepare_turn2_batch(
262284 self ,
263285 waveforms_16k : list [np .ndarray ],
264286 pred_texts : list [str ],
287+ languages : list [str | None ] | None = None ,
265288 ) -> list [dict [str , Any ] | None ]:
289+ langs = languages if languages is not None else [None ] * len (waveforms_16k )
266290 if self ._prep_pool is None :
267291 return [
268- self ._prepare_turn2_single (w , pt )
269- for w , pt in zip (waveforms_16k , pred_texts , strict = False )
292+ self ._prepare_turn2_single (w , pt , lang )
293+ for w , pt , lang in zip (waveforms_16k , pred_texts , langs , strict = False )
270294 ]
271- return list (self ._prep_pool .map (self ._prepare_turn2_single , waveforms_16k , pred_texts ))
295+ return list (self ._prep_pool .map (self ._prepare_turn2_single , waveforms_16k , pred_texts , langs ))
272296
273297 # ------------------------------------------------------------------
274298 # Generation
@@ -278,6 +302,7 @@ def generate(
278302 self ,
279303 waveforms : list [np .ndarray ],
280304 sample_rates : list [int ],
305+ languages : list [str | None ] | None = None ,
281306 ) -> tuple [list [str ], list [str ]]:
282307 """Run batched two-turn inference on in-memory audio waveforms.
283308
@@ -288,6 +313,9 @@ def generate(
288313 Args:
289314 waveforms: List of 1-D mono numpy float32 arrays.
290315 sample_rates: Corresponding sample rates for each waveform.
316+ languages: Optional per-sample language strings for ``{language}``
317+ placeholder substitution in prompts. Length must match
318+ ``waveforms``. Pass ``None`` (default) to skip substitution.
291319
292320 Returns:
293321 ``(pred_texts, disfluency_texts)`` — one string per input for
@@ -301,7 +329,7 @@ def generate(
301329 n = len (waveforms )
302330
303331 # -- Turn 1 ----------------------------------------------------------
304- prepared = self ._prepare_batch (waveforms , sample_rates )
332+ prepared = self ._prepare_batch (waveforms , sample_rates , languages )
305333 valid_indices = [i for i , p in enumerate (prepared ) if p is not None ]
306334 valid_inputs = [prepared [i ][0 ] for i in valid_indices ]
307335 waveforms_16k : dict [int , np .ndarray ] = {i : prepared [i ][1 ] for i in valid_indices }
@@ -327,9 +355,13 @@ def generate(
327355 if not t2_indices :
328356 return pred_texts , ["" ] * n
329357
358+ t2_languages = (
359+ [languages [i ] for i in t2_indices ] if languages is not None else None
360+ )
330361 t2_prepared = self ._prepare_turn2_batch (
331362 [waveforms_16k [i ] for i in t2_indices ],
332363 [pred_texts [i ] for i in t2_indices ],
364+ t2_languages ,
333365 )
334366
335367 t2_valid = [(i , p ) for i , p in zip (t2_indices , t2_prepared , strict = False ) if p is not None ]
0 commit comments