Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def load_model(model_name: str, map_location: torch.device) -> ASRModel:
asr_model.eval()
return asr_model
except Exception as e:
raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

@property
def word_separator(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def load_model(model_name: str, device: torch.device) -> SALM:
model.to(device)
return model
except Exception as e:
raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
raise RuntimeError(f"Failed to load model {model_name}: {str(e)}") from e

def get_window_stride(self) -> float:
"""Returns the window stride of the model."""
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/inference/nmt/llm_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def load_model(self, llm_params: dict) -> LLM:
model = LLM(model=self.model_name, **llm_params)
return model
except Exception as e:
raise RuntimeError(f"Model loading failed: {str(e)}")
raise RuntimeError(f"Model loading failed: {str(e)}") from e

def translate_batch(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def free_slots(self, slot_ids: list[int]) -> None:
slot_ids (list[int]): list of slot ids
"""
for slot_id in slot_ids:
if slot_id not in self.slotidx2streamidx:
continue
self.available_slots.put(slot_id)
stream_id = self.slotidx2streamidx[slot_id]
del self.slotidx2streamidx[slot_id], self.streamidx2slotidx[stream_id]
Expand Down Expand Up @@ -148,7 +150,7 @@ def preprocess(
right_padding = torch.floor(right_paddings / self.sample_rate / self.timestep_duration) # B
return features, right_padding

def _update_feature_buffer(self, slot_ids: int, feat_chunk: Tensor) -> None:
def _update_feature_buffer(self, slot_ids: list[int], feat_chunk: Tensor) -> None:
"""
Add an extracted feature to `feature_buffer`
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def detect_eou_near_pivot(
raise ValueError("Pivot point is out of range")

if search_start_point > pivot_point:
raise ValueError("Search start point is greater then pivot_point")
raise ValueError("Search start point is greater than pivot_point")

if self.residue_tokens_at_end > 0:
sequence_length = max(0, sequence_length - self.residue_tokens_at_end)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ class ASRRequestOptions:
None value means that the option is not set and the default value will be used
"""

enable_itn: bool = None
enable_pnc: bool = None
stop_history_eou: int = None
asr_output_granularity: ASROutputGranularity | str = None
enable_itn: bool | None = None
enable_pnc: bool | None = None
stop_history_eou: int | None = None
asr_output_granularity: ASROutputGranularity | str | None = None
language_code: str | None = None
enable_nmt: bool = None
source_language: str = None
target_language: str = None
enable_nmt: bool | None = None
source_language: str | None = None
target_language: str | None = None
biasing_cfg: BiasingRequestItemConfig | None = None

def __post_init__(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ def update_label_buffer(self, labels: list[int]) -> None:
labels: (list[int]) list of labels
"""
shift = len(labels)
if shift == 0:
return
if shift >= len(self.label_buffer):
self.label_buffer[:] = labels[-len(self.label_buffer) :]
return
self.label_buffer[:-shift] = self.label_buffer[shift:].copy()
self.label_buffer[-shift:] = labels.copy()

Expand Down
5 changes: 3 additions & 2 deletions nemo/collections/asr/inference/utils/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ def setup_device(device: str, device_id: int | None, compute_dtype: str) -> tupl
logging.warning(f"Device ID {device_id} is not available. Using GPU 0 instead.")
device_id = 0

compute_dtype = COMPUTE_DTYPE_MAP.get(compute_dtype, None)
compute_dtype_str = compute_dtype
compute_dtype = COMPUTE_DTYPE_MAP.get(compute_dtype_str, None)
if compute_dtype is None:
raise ValueError(
f"Invalid compute dtype: {compute_dtype}. Must be one of {list(COMPUTE_DTYPE_MAP.keys())}"
f"Invalid compute dtype: {compute_dtype_str}. Must be one of {list(COMPUTE_DTYPE_MAP.keys())}"
)

device_str = f"cuda:{device_id}"
Expand Down
Loading