Skip to content

Commit de230d9

Browse files
Preserve BOS/EOS as literal strings in decoded text output
Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>
1 parent eebea30 commit de230d9

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

nemo/collections/speechlm2/parts/text_utils.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ def _decode_tokens_with_specials(
2727
Groups consecutive non-special tokens and decodes each group via
2828
``tokenizer.tokens_to_text()`` (HF ``convert_tokens_to_string``), which
2929
properly reverses byte-level BPE encoding (e.g. ``âĢĻ`` -> ``'``).
30-
Special tokens are never passed to ``convert_tokens_to_string`` — they
31-
are either inserted as literal strings or dropped entirely.
30+
Special tokens (BOS, EOS, PAD) are never passed to
31+
``convert_tokens_to_string``. BOS/EOS are always kept as literal
32+
strings so that turn boundaries are visible. PAD tokens are kept
33+
only when *keep_pad* is True.
3234
3335
Args:
3436
token_strings: Raw token strings from ``tokenizer.ids_to_tokens()``.
@@ -38,11 +40,11 @@ def _decode_tokens_with_specials(
3840
keep_pad: If True, preserve PAD tokens as literal strings in the
3941
output. If False, drop them. BOS/EOS are always kept as literals.
4042
"""
41-
# Build special-token set from explicit bos/eos/pad — same approach as
42-
# filter_special_tokens() and model_factory._extract_special_token_ids_from_nemo().
43-
special_tokens = {pad_token_str}
4443
bos = getattr(tokenizer, 'bos_token', None)
4544
eos = getattr(tokenizer, 'eos_token', None)
45+
46+
# All tokens that must not go through convert_tokens_to_string.
47+
special_tokens = {pad_token_str}
4648
if bos:
4749
special_tokens.add(bos)
4850
if eos:
@@ -56,7 +58,10 @@ def _decode_tokens_with_specials(
5658
if segment:
5759
result_parts.append(tokenizer.tokens_to_text(segment))
5860
segment = []
59-
if keep_pad:
61+
if tok == pad_token_str:
62+
if keep_pad:
63+
result_parts.append(tok)
64+
else:
6065
result_parts.append(tok)
6166
else:
6267
segment.append(tok)

0 commit comments

Comments
 (0)