Skip to content

Commit 78694d5

Browse files
authored
Add Arabic char tokenizer and Japanese-English katakana support (#15614)
Apply isort and black reformatting Fix Hindi chartokenizer, use 'case=upper' and prevent duplicate spaces in the Japanese G2P fallback paths. Apply isort and black reformatting Fix HindiCharsTokenizer backward compat and add Arabic dialect tests Apply isort and black reformatting Add Arabic tokenizer test coverage: diacritics, dialects, punctuation, unknown chars Expand Arabic tokenizer tests: parametrize diacritics, dialects Apply isort and black reformatting added comprehensive test coverage. fix: add back-compatibility, case=mixed, ascii_letters. fix: add charset_version to Hindi/Arabic tokenizers for backward compatibility Introduce a parameter in HindiCharsTokenizer and ArabicCharsTokenizer so old models (v1: case='mixed') keep working while new models train with the corrected charset (v2: case='upper'). - Define CASELESS_SCRIPT_TOKENIZER_TARGETS and DEFAULT_CHARSET_VERSION constants in tts_tokenizers.py - Persist charset_version into the OmegaConf config during training (setup_tokenizers) so .nemo archives record which version was used - Add _migrate_charset_version() helper in magpietts inference utils to pin charset_version=1 for old checkpoints that lack the field, preventing a silent vocabulary mismatch at inference time bugfix: L2_TTS_Fast_dev_runs_Magpietts_OnlineCFGDistillation.sh Signed-off-by: quapham <quapham@users.noreply.github.com> Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
1 parent 2223816 commit 78694d5

7 files changed

Lines changed: 347 additions & 21 deletions

File tree

nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# These functions are used by locale-specific tokenizers (e.g., HindiCharsTokenizer uses
2222
# get_grapheme_character_set("hi-IN")). If someone later creates PortugueseCharsTokenizer or
2323
# KoreanCharsTokenizer, they'd hit this.
24-
SUPPORTED_LOCALES = ["en-US", "de-DE", "es-ES", "it-IT", "fr-FR", "vi-VN", "ja-JP", "hi-IN", "pt-BR", "ko-KR"]
24+
SUPPORTED_LOCALES = ["en-US", "de-DE", "es-ES", "it-IT", "fr-FR", "vi-VN", "ja-JP", "hi-IN", "ar-MSA", "pt-BR", "ko-KR"]
2525

2626
# Derived from LJSpeech and "/" additionally
2727
DEFAULT_PUNCTUATION = (
@@ -114,6 +114,15 @@
114114
# Danda (period)
115115
'।',
116116
),
117+
# ref: https://en.wikipedia.org/wiki/Arabic_alphabet
118+
"ar-MSA": (
119+
'ء', 'آ', 'أ', 'إ', 'ؤ', 'ئ', 'ا', 'ب', 'ة', 'ت',
120+
'ث', 'ج', 'ح', 'خ', 'د', 'ذ', 'ر', 'ز', 'س', 'ش',
121+
'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ف', 'ق', 'ك', 'ل',
122+
'م', 'ن', 'ه', 'و', 'ى', 'ي',
123+
# Diacritics
124+
'ً', 'ٌ', 'ٍ', 'َ', 'ُ', 'ِ', 'ّ', 'ٰ', 'ْ',
125+
),
117126
}
118127

119128
IPA_CHARACTER_SETS = {
@@ -354,6 +363,14 @@ def get_ipa_punctuation_list(locale):
354363
'・',
355364
]
356365
)
366+
elif locale == "ar-MSA":
367+
punct_set.update(
368+
[
369+
'،',
370+
'؛',
371+
'؟',
372+
]
373+
)
357374
elif locale == "hi-IN":
358375
punct_set.update(
359376
[

nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py

Lines changed: 133 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@
4444
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
4545
from nemo.utils import logging
4646

47+
# Fully-qualified ``_target_`` paths of the char tokenizers for caseless scripts
# (Devanagari, Arabic script). Configs pointing at these classes carry a
# ``charset_version`` field so checkpoints trained before the charset fix
# (v1: case='mixed') keep loading after the default moved to v2.
CASELESS_SCRIPT_TOKENIZER_TARGETS = frozenset(
    f'nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.{cls_name}'
    for cls_name in ('HindiCharsTokenizer', 'ArabicCharsTokenizer')
)

# Charset version recorded for newly trained models (v2: case='upper' charset).
DEFAULT_CHARSET_VERSION = 2
55+
4756

4857
class BaseTokenizer(ABC):
4958
"""Abstract class for creating an arbitrary tokenizer to convert string to list of int tokens.
@@ -164,9 +173,8 @@ def encode(self, text):
164173
logging.warning(f"Text: [{text}] contains unknown char: [{c}]. Symbol will be skipped.")
165174

166175
# Remove trailing spaces
167-
if cs:
168-
while cs[-1] == space:
169-
cs.pop()
176+
while cs and cs[-1] == space:
177+
cs.pop()
170178

171179
if self.pad_with_space:
172180
cs = [space] + cs + [space]
@@ -382,6 +390,14 @@ def __init__(
382390
class HindiCharsTokenizer(BaseCharsTokenizer):
383391
"""Hindi grapheme tokenizer (character-based, no phonemes).
384392
Args:
393+
chars: Explicit character set string. When provided, ``charset_version`` is ignored.
394+
charset_version: Controls which default character set to use (only when ``chars`` is None).
395+
``2`` (default) — ``case="upper"`` Devanagari + ``ascii_letters``.
396+
Hindi/Devanagari has no case distinction, so ``case="upper"`` avoids duplicating
397+
every code-point. ``ascii_letters`` covers both upper- and lower-case English for
398+
mixed-language text.
399+
``1`` — legacy ``case="mixed"`` Devanagari + ``ascii_lowercase``. Use this value to
400+
restore models that were trained before the charset fix.
385401
punct: Whether to reserve grapheme for basic punctuation or not.
386402
apostrophe: Whether to use apostrophe or not.
387403
add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None),
@@ -406,13 +422,14 @@ class HindiCharsTokenizer(BaseCharsTokenizer):
406422

407423
_LOCALE = "hi-IN"
408424
_PUNCT_LIST = get_ipa_punctuation_list(_LOCALE)
425+
_CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="upper") + string.ascii_letters
409426
_PUNCT_LIST_V1 = sorted(list(DEFAULT_PUNCTUATION))
410-
_CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="mixed")
411-
_CHARSET_STR += string.ascii_lowercase
427+
_CHARSET_STR_V1 = get_grapheme_character_set(locale=_LOCALE, case="mixed") + string.ascii_lowercase
412428

413429
def __init__(
414430
self,
415-
chars=_CHARSET_STR,
431+
chars=None,
432+
charset_version=2,
416433
punct=True,
417434
apostrophe=True,
418435
add_blank_at=None,
@@ -421,6 +438,22 @@ def __init__(
421438
punct_version=2,
422439
text_preprocessing_func=any_locale_text_preprocessing,
423440
):
441+
if chars is None:
442+
if charset_version == 1:
443+
warnings.warn(
444+
"HindiCharsTokenizer charset_version=1 (case='mixed' + ascii_lowercase) is deprecated "
445+
"and will be removed in a future release. "
446+
"Migrate to charset_version=2 (case='upper' + ascii_letters) and retrain.",
447+
DeprecationWarning,
448+
stacklevel=2,
449+
)
450+
chars = self._CHARSET_STR_V1
451+
elif charset_version == 2:
452+
chars = self._CHARSET_STR
453+
else:
454+
raise ValueError(
455+
f"HindiCharsTokenizer: unsupported charset_version={charset_version!r}. Use 1 (legacy) or 2."
456+
)
424457
if non_default_punct_list is None:
425458
if punct_version == 1:
426459
warnings.warn(
@@ -467,9 +500,100 @@ def encode(self, text):
467500
logging.warning(f"Text: [{text}] contains unknown char: [{c}]. Symbol will be skipped.")
468501

469502
# Remove trailing spaces
470-
if cs:
471-
while cs[-1] == space:
472-
cs.pop()
503+
while cs and cs[-1] == space:
504+
cs.pop()
505+
506+
if self.pad_with_space:
507+
cs = [space] + cs + [space]
508+
509+
return [self._token2id[p] for p in cs]
510+
511+
512+
class ArabicCharsTokenizer(BaseCharsTokenizer):
513+
"""Arabic grapheme tokenizer (character-based, no phonemes).
514+
Args:
515+
chars: Explicit character set string. When provided, ``charset_version`` is ignored.
516+
charset_version: Controls which default character set to use (only when ``chars`` is None).
517+
``2`` (default) — ``case="upper"`` Arabic + ``ascii_letters``.
518+
Arabic script has no case distinction, so ``case="upper"`` avoids duplicating
519+
every code-point. ``ascii_letters`` covers both upper- and lower-case English for
520+
mixed-language text.
521+
``1`` — legacy ``case="mixed"`` Arabic + ``ascii_letters``. Use this value to
522+
restore models that were trained before the charset fix.
523+
punct: Whether to reserve grapheme for basic punctuation or not.
524+
apostrophe: Whether to use apostrophe or not.
525+
add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None),
526+
if None then no blank in labels.
527+
pad_with_space: Whether to pad text with spaces at the beginning and at the end or not.
528+
non_default_punct_list: List of punctuation marks which will be used instead default.
529+
text_preprocessing_func: Text preprocessing function. Keeps Arabic unchanged.
530+
531+
Each Unicode code point becomes 1 token (letters, diacritics, and Arabic punct from ipa_lexicon).
532+
Supports both upper and lower English letters (e.g. mixed-language text).
533+
534+
Input Text: مرحبا Hello
535+
Chars: ['م', 'ر', 'ح', 'ب', 'ا', ' ', 'H', 'e', 'l', 'l', 'o']
536+
"""
537+
538+
_LOCALE = "ar-MSA"
539+
_PUNCT_LIST = get_ipa_punctuation_list(_LOCALE)
540+
_CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="upper") + string.ascii_letters
541+
_CHARSET_STR_V1 = get_grapheme_character_set(locale=_LOCALE, case="mixed") + string.ascii_letters
542+
543+
def __init__(
    self,
    chars=None,
    charset_version=2,
    punct=True,
    apostrophe=True,
    add_blank_at=None,
    pad_with_space=False,
    non_default_punct_list=_PUNCT_LIST,
    text_preprocessing_func=any_locale_text_preprocessing,
):
    # An explicitly supplied ``chars`` string always wins; otherwise choose the
    # default charset matching the requested ``charset_version``.
    if chars is None:
        if charset_version == 2:
            chars = self._CHARSET_STR
        elif charset_version == 1:
            # Legacy charset kept only so checkpoints trained before the
            # charset fix keep their original token-to-ID mapping.
            warnings.warn(
                "ArabicCharsTokenizer charset_version=1 (case='mixed' + ascii_letters) is deprecated "
                "and will be removed in a future release. "
                "Migrate to charset_version=2 (case='upper' + ascii_letters) and retrain.",
                DeprecationWarning,
                stacklevel=2,
            )
            chars = self._CHARSET_STR_V1
        else:
            raise ValueError(
                f"ArabicCharsTokenizer: unsupported charset_version={charset_version!r}. Use 1 (legacy) or 2."
            )
    super().__init__(
        chars=chars,
        punct=punct,
        apostrophe=apostrophe,
        add_blank_at=add_blank_at,
        pad_with_space=pad_with_space,
        non_default_punct_list=non_default_punct_list,
        text_preprocessing_func=text_preprocessing_func,
    )
579+
580+
def encode(self, text):
581+
"""Encode Arabic text, handling diacritics and English (upper/lower) correctly."""
582+
cs, space, tokens = [], self.tokens[self.space], set(self.tokens)
583+
584+
text = self.text_preprocessing_func(text)
585+
for c in text:
586+
if c == space and len(cs) > 0 and cs[-1] != space:
587+
cs.append(c)
588+
elif c in tokens and c != space:
589+
cs.append(c)
590+
elif (c in self.PUNCT_LIST) and self.punct:
591+
cs.append(c)
592+
elif c != space:
593+
logging.warning(f"Text: [{text}] contains unknown char: [{c}]. Symbol will be skipped.")
594+
595+
while cs and cs[-1] == space:
596+
cs.pop()
473597

474598
if self.pad_with_space:
475599
cs = [space] + cs + [space]

nemo/collections/tts/data/text_to_speech_dataset_lhotse.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
from omegaconf import DictConfig, open_dict
2525
from transformers import AutoTokenizer, T5Tokenizer
2626

27-
from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPABPETokenizer
27+
from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import (
28+
CASELESS_SCRIPT_TOKENIZER_TARGETS,
29+
DEFAULT_CHARSET_VERSION,
30+
AggregatedTTSTokenizer,
31+
IPABPETokenizer,
32+
)
2833
from nemo.collections.tts.parts.utils.tts_dataset_utils import (
2934
beta_binomial_prior_distribution,
3035
normalize_volume,
@@ -74,6 +79,18 @@ def setup_tokenizers(all_tokenizers_config, mode='train'):
7479
# TODO @xueyang: is it really necessary to set phone probability to 1.0 for test mode?
7580
if mode == 'test' and hasattr(tokenizer, "set_phone_prob"):
7681
tokenizer.set_phone_prob(1.0)
82+
83+
# Persist charset_version so it's saved in .nemo archives and
84+
# update_config_for_inference can distinguish old checkpoints
85+
# (missing charset_version → v1) from new ones.
86+
if (
87+
hasattr(tokenizer_config, '_target_')
88+
and tokenizer_config._target_ in CASELESS_SCRIPT_TOKENIZER_TARGETS
89+
and not hasattr(tokenizer_config, 'charset_version')
90+
):
91+
with open_dict(all_tokenizers_config):
92+
tokenizer_config.charset_version = DEFAULT_CHARSET_VERSION
93+
7794
tokenizers.append(tokenizer)
7895
tokenizer_names.append(tokenizer_name)
7996

nemo/collections/tts/g2p/models/ja_jp_ipa.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -298,28 +298,44 @@ def __call__(self, text: str) -> List[str]:
298298
acc = word.get('acc', 0)
299299

300300
string = unicodedata.normalize('NFKC', string)
301+
pos_group1 = word.get('pos_group1', '')
301302

302-
# Handle English letters
303+
# If string is pure ASCII letters after NFKC normalization, decide based on pos:
304+
# - pos='フィラー' means OpenJTalk didn't recognize the word (just spelled it out) → keep Latin
305+
# - any other pos (e.g. '名詞') means OpenJTalk recognized it as a loanword → use katakana pron
303306
if string and all(c in self.ascii_letter_dict for c in string):
304-
if current_chain:
305-
self._process_chain(current_chain, result)
306-
current_chain = []
307-
308-
result.extend(list(string))
309-
continue
307+
if pos == 'フィラー':
308+
if current_chain:
309+
self._process_chain(current_chain, result)
310+
current_chain = []
311+
result.extend(list(string))
312+
continue
310313

311-
# Handle punctuation
312-
if pos in ('記号', '補助記号'):
314+
# Handle punctuation (記号), but keep alphabet symbols (アルファベット) as regular words
315+
if pos in ('記号', '補助記号') and pos_group1 != 'アルファベット':
313316
if current_chain:
314317
self._process_chain(current_chain, result)
315318
current_chain = []
316319
if string.isspace():
317-
result.append(' ')
320+
if not result or result[-1] != ' ':
321+
result.append(' ')
318322
elif string in punctuation:
319323
result.append(string)
324+
else:
325+
logging.warning(
326+
f"Unknown symbol '{string}' (pos={pos}) not in punctuation list, replacing with space. original text: {text}"
327+
)
328+
if not result or result[-1] != ' ':
329+
result.append(' ')
320330
continue
321331

322332
if not pron or mora_size == 0:
333+
if string and not string.isspace():
334+
logging.warning(
335+
f"Unknown symbol '{string}' (pos={pos}) not in punctuation list, replacing with space. original text: {text}"
336+
)
337+
if not result or result[-1] != ' ':
338+
result.append(' ')
323339
continue
324340

325341
# Add word to current chain

nemo/collections/tts/modules/magpietts_inference/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import torch
2929
from omegaconf import DictConfig, OmegaConf, open_dict
3030

31+
from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import CASELESS_SCRIPT_TOKENIZER_TARGETS
3132
from nemo.collections.tts.models import EasyMagpieTTSInferenceModel, MagpieTTSModel
3233
from nemo.utils import logging
3334

@@ -149,6 +150,24 @@ def validate(self) -> None:
149150
)
150151

151152

153+
def _migrate_charset_version(model_cfg: DictConfig) -> None:
    """Pin ``charset_version=1`` on Hindi/Arabic tokenizer configs that lack it.

    ``setup_tokenizers()`` persists ``charset_version`` for newly trained
    models; checkpoints created before that field existed must keep the
    legacy (v1) charset, otherwise the new default (v2) would silently
    remap token IDs and break the model at inference time.

    The caller must already hold ``open_dict(model_cfg)`` so the config
    accepts the new key.
    """
    if not hasattr(model_cfg, 'text_tokenizers'):
        return
    for name in model_cfg.text_tokenizers:
        cfg = model_cfg.text_tokenizers[name]
        # Only caseless-script tokenizers are versioned; ``getattr`` with a
        # ``None`` default is equivalent to the hasattr + membership check.
        target = getattr(cfg, '_target_', None)
        if target in CASELESS_SCRIPT_TOKENIZER_TARGETS and not hasattr(cfg, 'charset_version'):
            cfg.charset_version = 1
169+
170+
152171
def _migrate_tokenizer_punctuation(model_cfg: DictConfig) -> None:
153172
"""Backfill punctuation fields for tokenizers that predate them.
154173
@@ -203,6 +222,7 @@ def update_config_for_inference(
203222
model_cfg.codecmodel_path = codecmodel_path
204223

205224
_migrate_tokenizer_punctuation(model_cfg)
225+
_migrate_charset_version(model_cfg)
206226

207227
# Update text tokenizer paths for backward compatibility
208228
if hasattr(model_cfg, 'text_tokenizer'):

0 commit comments

Comments
 (0)