From 56ab539fedb10cd381a25f50a002876e1cee0efa Mon Sep 17 00:00:00 2001
From: naymaraq
Date: Thu, 2 Apr 2026 19:49:11 +0400
Subject: [PATCH 1/2] minor fixies

Signed-off-by: naymaraq
---
 .../model_wrappers/asr_inference_wrapper.py          |  2 +-
 .../model_wrappers/salm_asr_inference_wrapper.py     |  2 +-
 .../asr/inference/nmt/llm_translator.py              |  2 +-
 .../streaming/buffering/cache_feature_bufferer.py    |  4 +++-
 .../endpointing/greedy/greedy_endpointing.py         |  2 +-
 .../inference/streaming/framing/request_options.py   | 14 +++++++-------
 .../inference/streaming/state/cache_aware_state.py   |  5 +++++
 .../asr/inference/utils/device_utils.py              |  5 +++--
 8 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/nemo/collections/asr/inference/model_wrappers/asr_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/asr_inference_wrapper.py
index 305ec894b6d2..6be50dee5596 100644
--- a/nemo/collections/asr/inference/model_wrappers/asr_inference_wrapper.py
+++ b/nemo/collections/asr/inference/model_wrappers/asr_inference_wrapper.py
@@ -92,7 +92,7 @@ def load_model(model_name: str, map_location: torch.device) -> ASRModel:
             asr_model.eval()
             return asr_model
         except Exception as e:
-            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
+            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

     @property
     def word_separator(self) -> str:
diff --git a/nemo/collections/asr/inference/model_wrappers/salm_asr_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/salm_asr_inference_wrapper.py
index 5a27b121d55d..3355899c4b98 100644
--- a/nemo/collections/asr/inference/model_wrappers/salm_asr_inference_wrapper.py
+++ b/nemo/collections/asr/inference/model_wrappers/salm_asr_inference_wrapper.py
@@ -85,7 +85,7 @@ def load_model(model_name: str, device: torch.device) -> SALM:
             model.to(device)
             return model
         except Exception as e:
-            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
+            raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

     def get_window_stride(self) -> float:
         """Returns the window stride of the model."""
diff --git a/nemo/collections/asr/inference/nmt/llm_translator.py b/nemo/collections/asr/inference/nmt/llm_translator.py
index f4fb4e9b96a7..34c50b17eef9 100644
--- a/nemo/collections/asr/inference/nmt/llm_translator.py
+++ b/nemo/collections/asr/inference/nmt/llm_translator.py
@@ -167,7 +167,7 @@ def load_model(self, llm_params: dict) -> LLM:
             model = LLM(model=self.model_name, **llm_params)
             return model
         except Exception as e:
-            raise RuntimeError(f"Model loading failed: {str(e)}")
+            raise RuntimeError(f"Model loading failed: {str(e)}") from e

     def translate_batch(
         self,
diff --git a/nemo/collections/asr/inference/streaming/buffering/cache_feature_bufferer.py b/nemo/collections/asr/inference/streaming/buffering/cache_feature_bufferer.py
index a469ceeeee06..f4a8c273a100 100644
--- a/nemo/collections/asr/inference/streaming/buffering/cache_feature_bufferer.py
+++ b/nemo/collections/asr/inference/streaming/buffering/cache_feature_bufferer.py
@@ -111,6 +111,8 @@ def free_slots(self, slot_ids: list[int]) -> None:
             slot_ids (list[int]): list of slot ids
         """
         for slot_id in slot_ids:
+            if slot_id not in self.slotidx2streamidx:
+                continue
             self.available_slots.put(slot_id)
             stream_id = self.slotidx2streamidx[slot_id]
             del self.slotidx2streamidx[slot_id], self.streamidx2slotidx[stream_id]
@@ -148,7 +150,7 @@ def preprocess(
         right_padding = torch.floor(right_paddings / self.sample_rate / self.timestep_duration)  # B
         return features, right_padding

-    def _update_feature_buffer(self, slot_ids: int, feat_chunk: Tensor) -> None:
+    def _update_feature_buffer(self, slot_ids: list[int], feat_chunk: Tensor) -> None:
         """
         Add an extracted feature to `feature_buffer`
         Args:
diff --git a/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_endpointing.py b/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_endpointing.py
index 1442043ae711..85a34427f768 100644
--- a/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_endpointing.py
+++ b/nemo/collections/asr/inference/streaming/endpointing/greedy/greedy_endpointing.py
@@ -265,7 +265,7 @@ def detect_eou_near_pivot(
             raise ValueError("Pivot point is out of range")

         if search_start_point > pivot_point:
-            raise ValueError("Search start point is greater then pivot_point")
+            raise ValueError("Search start point is greater than pivot_point")

         if self.residue_tokens_at_end > 0:
             sequence_length = max(0, sequence_length - self.residue_tokens_at_end)
diff --git a/nemo/collections/asr/inference/streaming/framing/request_options.py b/nemo/collections/asr/inference/streaming/framing/request_options.py
index 78b9c1867dcc..ee12141dc610 100644
--- a/nemo/collections/asr/inference/streaming/framing/request_options.py
+++ b/nemo/collections/asr/inference/streaming/framing/request_options.py
@@ -27,14 +27,14 @@ class ASRRequestOptions:
     None value means that the option is not set and the default value will be used
     """

-    enable_itn: bool = None
-    enable_pnc: bool = None
-    stop_history_eou: int = None
-    asr_output_granularity: ASROutputGranularity | str = None
+    enable_itn: bool | None = None
+    enable_pnc: bool | None = None
+    stop_history_eou: int | None = None
+    asr_output_granularity: ASROutputGranularity | str | None = None
     language_code: str | None = None
-    enable_nmt: bool = None
-    source_language: str = None
-    target_language: str = None
+    enable_nmt: bool | None = None
+    source_language: str | None = None
+    target_language: str | None = None
     biasing_cfg: BiasingRequestItemConfig | None = None

     def __post_init__(self) -> None:
diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_state.py
index f87b18edc002..941d0c87ed6d 100644
--- a/nemo/collections/asr/inference/streaming/state/cache_aware_state.py
+++ b/nemo/collections/asr/inference/streaming/state/cache_aware_state.py
@@ -69,6 +69,11 @@ def update_label_buffer(self, labels: list[int]) -> None:
             labels: (list[int]) list of labels
         """
         shift = len(labels)
+        if shift == 0:
+            return
+        if shift >= len(self.label_buffer):
+            self.label_buffer[:] = labels[-len(self.label_buffer):]
+            return
         self.label_buffer[:-shift] = self.label_buffer[shift:].copy()
         self.label_buffer[-shift:] = labels.copy()

diff --git a/nemo/collections/asr/inference/utils/device_utils.py b/nemo/collections/asr/inference/utils/device_utils.py
index 320a6da5c838..1861ed686955 100644
--- a/nemo/collections/asr/inference/utils/device_utils.py
+++ b/nemo/collections/asr/inference/utils/device_utils.py
@@ -48,10 +48,11 @@ def setup_device(device: str, device_id: int | None, compute_dtype: str) -> tupl
             logging.warning(f"Device ID {device_id} is not available. Using GPU 0 instead.")
             device_id = 0

-    compute_dtype = COMPUTE_DTYPE_MAP.get(compute_dtype, None)
+    compute_dtype_str = compute_dtype
+    compute_dtype = COMPUTE_DTYPE_MAP.get(compute_dtype_str, None)
     if compute_dtype is None:
         raise ValueError(
-            f"Invalid compute dtype: {compute_dtype}. Must be one of {list(COMPUTE_DTYPE_MAP.keys())}"
+            f"Invalid compute dtype: {compute_dtype_str}. Must be one of {list(COMPUTE_DTYPE_MAP.keys())}"
         )

     device_str = f"cuda:{device_id}"

From 1005e4750aa5547c8ded3dd3c47f49648cf0242d Mon Sep 17 00:00:00 2001
From: naymaraq
Date: Thu, 2 Apr 2026 15:54:24 +0000
Subject: [PATCH 2/2] Apply isort and black reformatting

Signed-off-by: naymaraq
---
 .../asr/inference/streaming/state/cache_aware_state.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_state.py
index 941d0c87ed6d..db733fed41f8 100644
--- a/nemo/collections/asr/inference/streaming/state/cache_aware_state.py
+++ b/nemo/collections/asr/inference/streaming/state/cache_aware_state.py
@@ -72,7 +72,7 @@ def update_label_buffer(self, labels: list[int]) -> None:
         if shift == 0:
             return
         if shift >= len(self.label_buffer):
-            self.label_buffer[:] = labels[-len(self.label_buffer):]
+            self.label_buffer[:] = labels[-len(self.label_buffer) :]
             return
         self.label_buffer[:-shift] = self.label_buffer[shift:].copy()
         self.label_buffer[-shift:] = labels.copy()