Skip to content

Commit d47a3dd

Browse files
authored
Fix batching in fairseq SentencepieceTokenizer (#640)
* Move fairseq fix out of the loop.
* Tidying up the patch.
1 parent 44e494b commit d47a3dd

1 file changed

Lines changed: 23 additions & 22 deletions

File tree

operators/tokenizer/sentencepiece_tokenizer.cc

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -83,32 +83,33 @@ OrtStatusPtr KernelSentencepieceTokenizer::Compute(const ortc::Tensor<std::strin
8383
content.push_back(tokenizer_.eos_id());
8484
token_indices.push_back(ort_extensions::narrow<int32_t>(str_input[i].length()));
8585
}
86-
87-
if (fairseq.has_value() && (*fairseq)) {
88-
// HF Fairseq Example (XLMRobertaTokenizer) : https://huggingface.co/transformers/v4.6.0/_modules/transformers/models/xlm_roberta/tokenization_xlm_roberta.html#XLMRobertaTokenizer
89-
//
90-
// Original fairseq vocab and spm vocab must be "aligned":
91-
// Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
92-
// -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
93-
// fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
94-
// spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
95-
//
96-
// As per HF, the first "real" token "," has position 4 in the XLMRobertaTokenizer vocab and position
97-
// 3 in the SPM vocab, so we add a padding value of 1 to IDs, and fix exceptions for '<unk>' and '<s>'.
98-
std::for_each(content.begin(), content.end(), [](int& n) {
99-
if (n == 0) { // '<unk>': 0 -> 3
100-
n = 3;
101-
} else if (n == 1) { // '<s>': 1 -> 0
102-
n = 0;
103-
} else if (n != 2) { // '</s>': 2 -> 2, '<*>': x -> x + 1
104-
n++;
105-
}
106-
});
107-
}
10886
}
10987
}
11088
instance_indices.push_back(content.size());
11189

90+
// Patch fairseq indices
91+
if (fairseq.has_value() && (*fairseq) && !add_rev) {
92+
// HF Fairseq Example (XLMRobertaTokenizer) : https://huggingface.co/transformers/v4.6.0/_modules/transformers/models/xlm_roberta/tokenization_xlm_roberta.html#XLMRobertaTokenizer
93+
//
94+
// Original fairseq vocab and spm vocab must be "aligned":
95+
// Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
96+
// -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
97+
// fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
98+
// spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
99+
//
100+
// As per HF, the first "real" token "," has position 4 in the XLMRobertaTokenizer vocab and position
101+
// 3 in the SPM vocab, so we add a padding value of 1 to IDs, and fix exceptions for '<unk>' and '<s>'.
102+
std::for_each(content.begin(), content.end(), [](int& n) {
103+
if (n == 0) { // '<unk>': 0 -> 3
104+
n = 3;
105+
} else if (n == 1) { // '<s>': 1 -> 0
106+
n = 0;
107+
} else if (n != 2) { // '</s>': 2 -> 2, '<*>': x -> x + 1
108+
n++;
109+
}
110+
});
111+
}
112+
112113
// Setup output
113114
std::vector<int64_t> size_content(1);
114115
size_content[0] = content.size();

0 commit comments

Comments
 (0)