Skip to content

Commit 4762ab3

Browse files
committed
fix: add backward compatibility, case=mixed, ascii_letters.
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
1 parent dd39423 commit 4762ab3

2 files changed

Lines changed: 123 additions & 11 deletions

File tree

nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import itertools
1717
import os
1818
import string
19+
import warnings
1920
from abc import ABC, abstractmethod
2021
from contextlib import contextmanager
2122
from typing import List, Optional, Union
@@ -386,6 +387,14 @@ def __init__(
386387
class HindiCharsTokenizer(BaseCharsTokenizer):
387388
"""Hindi grapheme tokenizer (character-based, no phonemes).
388389
Args:
390+
chars: Explicit character set string. When provided, ``charset_version`` is ignored.
391+
charset_version: Controls which default character set to use (only when ``chars`` is None).
392+
``2`` (default) — ``case="upper"`` Devanagari + ``ascii_letters``.
393+
Hindi/Devanagari has no case distinction, so ``case="upper"`` avoids duplicating
394+
every code-point. ``ascii_letters`` covers both upper- and lower-case English for
395+
mixed-language text.
396+
``1`` — legacy ``case="mixed"`` Devanagari + ``ascii_lowercase``. Use this value to
397+
restore models that were trained before the charset fix.
389398
punct: Whether to reserve grapheme for basic punctuation or not.
390399
apostrophe: Whether to use apostrophe or not.
391400
add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None),
@@ -404,24 +413,36 @@ class HindiCharsTokenizer(BaseCharsTokenizer):
404413

405414
_LOCALE = "hi-IN"
406415
_PUNCT_LIST = get_ipa_punctuation_list(_LOCALE)
407-
_CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="mixed")
408-
_CHARSET_STR += string.ascii_lowercase
416+
_CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="upper") + string.ascii_letters
417+
_CHARSET_STR_V1 = get_grapheme_character_set(locale=_LOCALE, case="mixed") + string.ascii_lowercase
409418

410419
def __init__(
411420
self,
412-
chars=_CHARSET_STR,
421+
chars=None,
422+
charset_version=2,
413423
punct=True,
414424
apostrophe=True,
415425
add_blank_at=None,
416426
pad_with_space=False,
417427
non_default_punct_list=_PUNCT_LIST,
418428
text_preprocessing_func=any_locale_text_preprocessing,
419429
):
420-
logging.warning(
421-
"HindiCharsTokenizer: The default character set (case='mixed' + ascii_lowercase) "
422-
"is deprecated and will change to (case='upper' + ascii_letters) in the next release. "
423-
"Please pass 'chars' explicitly to avoid unexpected behavior."
424-
)
430+
if chars is None:
431+
if charset_version == 1:
432+
warnings.warn(
433+
"HindiCharsTokenizer charset_version=1 (case='mixed' + ascii_lowercase) is deprecated "
434+
"and will be removed in a future release. "
435+
"Migrate to charset_version=2 (case='upper' + ascii_letters) and retrain.",
436+
DeprecationWarning,
437+
stacklevel=2,
438+
)
439+
chars = self._CHARSET_STR_V1
440+
elif charset_version == 2:
441+
chars = self._CHARSET_STR
442+
else:
443+
raise ValueError(
444+
f"HindiCharsTokenizer: unsupported charset_version={charset_version!r}. Use 1 (legacy) or 2."
445+
)
425446
super().__init__(
426447
chars=chars,
427448
punct=punct,
@@ -466,6 +487,14 @@ def encode(self, text):
466487
class ArabicCharsTokenizer(BaseCharsTokenizer):
467488
"""Arabic grapheme tokenizer (character-based, no phonemes).
468489
Args:
490+
chars: Explicit character set string. When provided, ``charset_version`` is ignored.
491+
charset_version: Controls which default character set to use (only when ``chars`` is None).
492+
``2`` (default) — ``case="upper"`` Arabic + ``ascii_letters``.
493+
Arabic script has no case distinction, so ``case="upper"`` avoids duplicating
494+
every code-point. ``ascii_letters`` covers both upper- and lower-case English for
495+
mixed-language text.
496+
``1`` — legacy ``case="mixed"`` Arabic + ``ascii_letters``. Use this value to
497+
restore models that were trained before the charset fix.
469498
punct: Whether to reserve grapheme for basic punctuation or not.
470499
apostrophe: Whether to use apostrophe or not.
471500
add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None),
@@ -483,19 +512,36 @@ class ArabicCharsTokenizer(BaseCharsTokenizer):
483512

484513
_LOCALE = "ar-MSA"
485514
_PUNCT_LIST = get_ipa_punctuation_list(_LOCALE)
486-
_CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="mixed")
487-
_CHARSET_STR += string.ascii_letters
515+
_CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="upper") + string.ascii_letters
516+
_CHARSET_STR_V1 = get_grapheme_character_set(locale=_LOCALE, case="mixed") + string.ascii_letters
488517

489518
def __init__(
490519
self,
491-
chars=_CHARSET_STR,
520+
chars=None,
521+
charset_version=2,
492522
punct=True,
493523
apostrophe=True,
494524
add_blank_at=None,
495525
pad_with_space=False,
496526
non_default_punct_list=_PUNCT_LIST,
497527
text_preprocessing_func=any_locale_text_preprocessing,
498528
):
529+
if chars is None:
530+
if charset_version == 1:
531+
warnings.warn(
532+
"ArabicCharsTokenizer charset_version=1 (case='mixed' + ascii_letters) is deprecated "
533+
"and will be removed in a future release. "
534+
"Migrate to charset_version=2 (case='upper' + ascii_letters) and retrain.",
535+
DeprecationWarning,
536+
stacklevel=2,
537+
)
538+
chars = self._CHARSET_STR_V1
539+
elif charset_version == 2:
540+
chars = self._CHARSET_STR
541+
else:
542+
raise ValueError(
543+
f"ArabicCharsTokenizer: unsupported charset_version={charset_version!r}. Use 1 (legacy) or 2."
544+
)
499545
super().__init__(
500546
chars=chars,
501547
punct=punct,

tests/collections/common/tokenizers/text_to_speech/test_tts_tokenizers.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,51 @@ def test_hindi_chars_tokenizer(self):
329329

330330
assert chars == expected_output
331331

332+
@pytest.mark.run_only_on('CPU')
333+
@pytest.mark.unit
334+
def test_hindi_chars_tokenizer_legacy(self):
335+
"""charset_version=1 reproduces the old (case='mixed' + ascii_lowercase) behaviour."""
336+
input_text = "नमस्ते दुनिया!"
337+
expected_output = "नमस्ते दुनिया!"
338+
339+
tokenizer = HindiCharsTokenizer(charset_version=1)
340+
chars, tokens = self._parse_text(tokenizer, input_text)
341+
342+
assert chars == expected_output
343+
344+
@pytest.mark.run_only_on('CPU')
345+
@pytest.mark.unit
346+
def test_hindi_chars_tokenizer_mixed_english(self):
347+
"""Default v2 charset supports both upper and lower English."""
348+
input_text = "नमस्ते Hello World"
349+
expected_output = "नमस्ते Hello World"
350+
351+
tokenizer = HindiCharsTokenizer()
352+
chars, tokens = self._parse_text(tokenizer, input_text)
353+
354+
assert chars == expected_output
355+
356+
@pytest.mark.run_only_on('CPU')
357+
@pytest.mark.unit
358+
def test_hindi_chars_tokenizer_v1_no_upper_english(self):
359+
"""Legacy v1 charset only has ascii_lowercase, so uppercase English is skipped."""
360+
input_text = "नमस्ते Hello"
361+
expected_output = "नमस्ते ello"
362+
363+
tokenizer = HindiCharsTokenizer(charset_version=1)
364+
chars, tokens = self._parse_text(tokenizer, input_text)
365+
366+
assert chars == expected_output
367+
368+
@pytest.mark.run_only_on('CPU')
369+
@pytest.mark.unit
370+
def test_hindi_chars_tokenizer_v1_v2_different_vocab(self):
371+
"""v1 and v2 must produce different token vocabularies."""
372+
tok_v1 = HindiCharsTokenizer(charset_version=1)
373+
tok_v2 = HindiCharsTokenizer(charset_version=2)
374+
375+
assert tok_v1.tokens != tok_v2.tokens
376+
332377
@pytest.mark.run_only_on('CPU')
333378
@pytest.mark.unit
334379
def test_arabic_chars_tokenizer_mixed_english(self):
@@ -379,3 +424,24 @@ def test_arabic_chars_tokenizer_unknown_token(self):
379424
chars, tokens = self._parse_text(tokenizer, input_text)
380425

381426
assert chars == expected_output
427+
428+
@pytest.mark.run_only_on('CPU')
429+
@pytest.mark.unit
430+
def test_arabic_chars_tokenizer_legacy(self):
431+
"""charset_version=1 reproduces the old (case='mixed' + ascii_letters) behaviour."""
432+
input_text = "مرحبا Hello"
433+
expected_output = "مرحبا Hello"
434+
435+
tokenizer = ArabicCharsTokenizer(charset_version=1)
436+
chars, tokens = self._parse_text(tokenizer, input_text)
437+
438+
assert chars == expected_output
439+
440+
@pytest.mark.run_only_on('CPU')
441+
@pytest.mark.unit
442+
def test_arabic_chars_tokenizer_v1_v2_different_vocab(self):
443+
"""v1 and v2 must produce different token vocabularies."""
444+
tok_v1 = ArabicCharsTokenizer(charset_version=1)
445+
tok_v2 = ArabicCharsTokenizer(charset_version=2)
446+
447+
assert tok_v1.tokens != tok_v2.tokens

0 commit comments

Comments
 (0)