@@ -83,32 +83,33 @@ OrtStatusPtr KernelSentencepieceTokenizer::Compute(const ortc::Tensor<std::strin
8383 content.push_back (tokenizer_.eos_id ());
8484 token_indices.push_back (ort_extensions::narrow<int32_t >(str_input[i].length ()));
8585 }
86-
87- if (fairseq.has_value () && (*fairseq)) {
88- // HF Fairseq Example (XLMRobertaTokenizer) : https://huggingface.co/transformers/v4.6.0/_modules/transformers/models/xlm_roberta/tokenization_xlm_roberta.html#XLMRobertaTokenizer
89- //
90- // Original fairseq vocab and spm vocab must be "aligned":
91- // Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
92- // -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
93- // fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
94- // spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
95- //
96- // As per HF, the first "real" token "," has position 4 in the XLMRobertaTokenizer vocab and position
97- // 3 in the SPM vocab, so we add a padding value of 1 to IDs, and fix exceptions for '<unk>' and '<s>'.
98- std::for_each (content.begin (), content.end (), [](int & n) {
99- if (n == 0 ) { // '<unk>': 0 -> 3
100- n = 3 ;
101- } else if (n == 1 ) { // '<s>': 1 -> 0
102- n = 0 ;
103- } else if (n != 2 ) { // '</s>': 2 -> 2, '<*>': x -> x + 1
104- n++;
105- }
106- });
107- }
10886 }
10987 }
11088 instance_indices.push_back (content.size ());
11189
90+ // Patch fairseq indices
91+ if (fairseq.has_value () && (*fairseq) && !add_rev) {
92+ // HF Fairseq Example (XLMRobertaTokenizer) : https://huggingface.co/transformers/v4.6.0/_modules/transformers/models/xlm_roberta/tokenization_xlm_roberta.html#XLMRobertaTokenizer
93+ //
94+ // Original fairseq vocab and spm vocab must be "aligned":
95+ // Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
96+ // -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
97+ // fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
98+ // spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
99+ //
100+ // As per HF, the first "real" token "," has position 4 in the XLMRobertaTokenizer vocab and position
101+ // 3 in the SPM vocab, so we add a padding value of 1 to IDs, and fix exceptions for '<unk>' and '<s>'.
102+ std::for_each (content.begin (), content.end (), [](int & n) {
103+ if (n == 0 ) { // '<unk>': 0 -> 3
104+ n = 3 ;
105+ } else if (n == 1 ) { // '<s>': 1 -> 0
106+ n = 0 ;
107+ } else if (n != 2 ) { // '</s>': 2 -> 2, '<*>': x -> x + 1
108+ n++;
109+ }
110+ });
111+ }
112+
112113 // Setup output
113114 std::vector<int64_t > size_content (1 );
114115 size_content[0 ] = content.size ();
0 commit comments