Implement dropout for encode_minimal
#98
Changes from 12 commits
113a2fe
db7dd30
aa49a37
a336729
23c8ec9
08d4200
a10cce2
c215210
8ad286a
5493b12
3d63504
bba0765
65eb519
3f0d4fe
83dd2f1
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -526,9 +526,9 @@ impl BytePairEncoding { | |
| /// tokenization produced by the original BPE algorithm. | ||
| pub fn encode_minimal(&self, text: &[u8]) -> Vec<u32> { | ||
| let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len()); | ||
| let mut state = self.overlapping_searcher.start_state(); | ||
| for (pos, c) in text.iter().enumerate() { | ||
| let (s, iter) = self.overlapping_searcher.consume(state, pos + 1, *c); | ||
| let mut state = self.overlapping_searcher_rev.start_state(); | ||
| for (pos, c) in text.iter().rev().enumerate() { | ||
| let (s, iter) = self.overlapping_searcher_rev.consume(state, pos + 1, *c); | ||
| state = s; | ||
| let mut best = (0, u32::MAX); | ||
| for m in iter { | ||
|
|
@@ -548,7 +548,62 @@ impl BytePairEncoding { | |
| encoded.push(token); | ||
| pos -= self.token_len(token); | ||
| } | ||
| encoded.reverse(); | ||
| encoded | ||
| } | ||
|
|
||
| /// This function computes the encoding while randomly rejecting some merges. | ||
| /// The result of the encoding is non-deterministic unless a seeded `rng` is provided. | ||
| /// Implementation loosely follows original BPE dropout paper: https://arxiv.org/abs/1910.13267 | ||
| /// | ||
| /// In more detail: the tokenization uses dynamic programming, i.e. it models the tokenization as a graph, | ||
| /// where every position between text bytes is a node and two nodes are connected when the text slice between those two nodes matches a token. | ||
| // It then tries to find the shortest possible path from the beginning of the text till the end, i.e. it finds the shortest possible encoding. | ||
| // For this, it processes the nodes from left to right and visits all edges to the left. Then, it picks the edge which results in the shortest path. | ||
|
marinegor marked this conversation as resolved.
Outdated
|
||
| // The length of the shortest path is stored as second value, the edge (or rather token) is stored as first value. | ||
|
Collaborator
Since you described all this, you should also add the last step where we walk in reverse direction through the table along the shortest path.
Contributor
Author
Added, thanks! |
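The graph formulation discussed in this thread (nodes between bytes, one edge per matching token, shortest path, then a backward walk through the table) can be sketched as follows. This is a minimal illustration, not the crate's code: `encode_minimal_sketch` is a hypothetical name, the vocabulary is a toy slice of byte strings whose index serves as the token id, and matching is a naive scan over all tokens instead of the crate's overlapping searcher.

```rust
/// Shortest-path tokenization over a toy vocabulary (illustrative sketch only).
/// Returns `None` if some position in `text` cannot be reached with `vocab`.
fn encode_minimal_sketch(vocab: &[&[u8]], text: &[u8]) -> Option<Vec<usize>> {
    // best[end] = (token whose last byte sits at `end`, tokens on the shortest path to `end`).
    let mut best: Vec<Option<(usize, u32)>> = vec![None; text.len() + 1];
    // Start node: zero tokens used. The token id is a placeholder that is never read.
    best[0] = Some((usize::MAX, 0));

    // Forward pass: process nodes left to right and inspect every edge arriving from the left.
    for end in 1..=text.len() {
        for (id, tok) in vocab.iter().enumerate() {
            if tok.is_empty() || tok.len() > end || &text[end - tok.len()..end] != *tok {
                continue; // token does not match the bytes ending at `end`
            }
            let start = end - tok.len();
            if let Some((_, count)) = best[start] {
                // Keep the edge only if it yields a strictly shorter path.
                if best[end].map_or(true, |(_, c)| count + 1 < c) {
                    best[end] = Some((id, count + 1));
                }
            }
        }
    }

    // Backward pass: walk from the end of the text along the stored edges.
    let mut encoded = Vec::new();
    let mut pos = text.len();
    while pos > 0 {
        let (id, _) = best[pos]?;
        encoded.push(id);
        pos -= vocab[id].len();
    }
    // Tokens were collected last-to-first, so restore text order.
    encoded.reverse();
    Some(encoded)
}
```

With the toy vocabulary `[a, b, c, ab, bc, abc]` (ids 0 to 5), the input `abcab` encodes to `[5, 3]`, i.e. `abc` followed by `ab`.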
||
| // | ||
| // For the dropout (when dropout > 0.0), we uniformly drop edges from the graph, but always keep the one-byte tokens such that the graph stays connected. | ||
| // Note: this is very different from how BPE works and cannot produce the same output as the algorithm | ||
| // in the [paper's repository](https://github.com/VProv/BPE-Dropout/blob/master/bpe.py#L98), for two main reasons: | ||
| // - `encode_minimal` already doesn't follow the original heap-based BPE procedure | ||
| // - randomness source in dropout works differently in rust and python | ||
|
Collaborator
This one you can drop IMO, since it shouldn't matter if a reasonable random number generator was chosen.
Contributor
Author
Yeah, makes sense. I think the previous two reasons are enough to not claim complete reproducibility. |
||
| // - BPE-dropout authors discard all multi-byte tokens for each word separately, while this implementation does not split the "sentence" into words first | ||
| // and hence may include previously discarded token later down the byte stream. At the sentence level though we don't expect it to make much difference. | ||
|
Comment on lines +572 to +573
Collaborator
There are at least two or three distinct points in here.
Contributor
Author
I think 1 and 3 were covered, I added 2. |
||
| #[cfg(feature = "rand")] | ||
| pub fn encode_minimal_dropout<R: rand::Rng>( | ||
| &self, | ||
| text: &[u8], | ||
| dropout: f32, | ||
| mut rng: R, | ||
| ) -> Vec<u32> { | ||
| assert!(0.0 <= dropout); | ||
| assert!(dropout <= 1.0); | ||
|
|
||
| let mut last_token: Vec<(u32, u32)> = Vec::with_capacity(text.len()); | ||
| let mut state = self.overlapping_searcher_rev.start_state(); | ||
| for (pos, c) in text.iter().rev().enumerate() { | ||
| let (s, iter) = self.overlapping_searcher_rev.consume(state, pos + 1, *c); | ||
| state = s; | ||
| let mut best = (0, u32::MAX); | ||
| for m in iter { | ||
| if m.end() > m.start() + 1 && dropout >= rng.random() { | ||
| continue; | ||
| } | ||
| if m.start() == 0 { | ||
|
Collaborator
Is there some paper explaining in more detail how the randomization is supposed to work? Also, some documentation would be nice (as part of some readme and/or doc comment). If this is a one-to-one implementation of some paper, then we can probably just link to that paper.
Contributor
Author
The paper: https://arxiv.org/abs/1910.13267 We're interested in Algorithm 1 (page 3). The rationale for the improvements can be seen in Figure 6. I don't think it's a one-to-one implementation, since
Contributor
Author
...although I admit I don't really understand Intuition is that dropout roughly equals number of rejected merges in the final encoding, e.g. dropout ~=1 would result in almost single-byte encoding. However, I don't see that with where dictionary is So I'd appreciate any directions if you have any :)
Collaborator
Note: this is very different from how BPE works and cannot produce the same output as the algorithm in the paper. The only implementation in this crate which follows the "standard" BPE algorithm is The problem with the algorithm in the paper is that it is VERY slow. So, maybe it is good enough to pick a different randomization process which follows the idea of the paper in spirit?
Contributor
Author
@aneubeck thanks for the explanation, that's actually very helpful. I guess that the only thing that matters is just being able to drop some merges before actually building the tokenization. Could you have a look at the updated approach? I've changed the approach that I had before (which I think was very wrong), and instead now only consider "best" tokens if they are not in "forbidden_tokens", which have been constructed prior to tokenization. My only worry is the single-byte tokens -- I'm not sure how they're handled, and I wouldn't like to discard them from the allowed tokens, but I'm not sure how to handle that properly. I'm talking about this line: ...
& (!(forbidden_tokens_set.contains(&m.value())) | ((m.end() - m.start()) == 1))
...I'm not sure if the second condition should be present or not, basically.
Collaborator
Thanks for the changes! There was a little bug with how you treated tokens which started at the beginning of the text (you didn't filter larger tokens out there...). I also got rid of the pretty expensive lookup tables which you were computing. Those would slow down the processing drastically! It would be nice if you could extend the comment of this function describing in more detail what it does (i.e. we uniformly drop edges from the graph I described above, but always keep the one-byte tokens such that the graph stays connected). On my MacBook I measured about 30 million input characters/sec with dropout and 40 million/sec with the "standard" minimal_encoding implementation.
Contributor
Author
thanks for the changes as well!
Will do!
I'll spend some time playing around with a toy example (with
that's pretty cool :)
Contributor
Author
@aneubeck I've added some explanation and updated README slightly I'm running benchmarks now -- I guess it's simply Also, I'm running them on m4 -- should I update the description in README accordingly, or would you prefer to run it on your machine?
Collaborator
Will try to review the changes tomorrow. |
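The strategy agreed on in this thread (drop multi-byte edges uniformly at random, always keep one-byte tokens so the graph stays connected) can be sketched as below. This is a hedged illustration under several assumptions: `encode_minimal_dropout_sketch` and `XorShift64` are hypothetical names, the tiny xorshift generator is only a dependency-free stand-in for a real RNG from the `rand` crate, the vocabulary is a toy slice indexed by token id, and every byte of the input is assumed to have a one-byte token.

```rust
/// Minimal xorshift64 generator; a stand-in for a real RNG. The seed must be non-zero.
struct XorShift64(u64);

impl XorShift64 {
    /// Returns a pseudo-random value in [0, 1).
    fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        // Use the top 24 bits so the quotient is exactly representable as an f32.
        (self.0 >> 40) as f32 / (1u32 << 24) as f32
    }
}

/// Shortest-path tokenization with random edge dropout (illustrative sketch only).
fn encode_minimal_dropout_sketch(
    vocab: &[&[u8]],
    text: &[u8],
    dropout: f32,
    rng: &mut XorShift64,
) -> Vec<usize> {
    assert!((0.0..=1.0).contains(&dropout));
    // best[end] = (token ending at `end`, tokens on the shortest surviving path to `end`).
    let mut best = vec![(usize::MAX, u32::MAX); text.len() + 1];
    best[0] = (usize::MAX, 0);

    for end in 1..=text.len() {
        for (id, tok) in vocab.iter().enumerate() {
            if tok.is_empty() || tok.len() > end || &text[end - tok.len()..end] != *tok {
                continue;
            }
            // Drop multi-byte edges with probability `dropout`. One-byte edges always
            // survive, so every node stays reachable. Using `<` means that a dropout
            // of 0.0 never drops an edge and a dropout of 1.0 always does.
            if tok.len() > 1 && rng.next_f32() < dropout {
                continue;
            }
            let count = best[end - tok.len()].1;
            if count != u32::MAX && count + 1 < best[end].1 {
                best[end] = (id, count + 1);
            }
        }
    }

    // Walk backwards along the surviving shortest path, then restore text order.
    let mut encoded = Vec::new();
    let mut pos = text.len();
    while pos > 0 {
        let id = best[pos].0;
        assert!(id != usize::MAX, "input byte without a one-byte token");
        encoded.push(id);
        pos -= vocab[id].len();
    }
    encoded.reverse();
    encoded
}
```

In this sketch a dropout of 0.0 reproduces the shortest encoding, a dropout of 1.0 yields exactly one token per input byte, and reusing the same seed reproduces the same output, which is what makes assertions on exact token counts possible in a test.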
||
| best = (m.value(), 1); | ||
| break; | ||
| } else if last_token[m.start() - 1].1 + 1 < best.1 { | ||
| best = (m.value(), last_token[m.start() - 1].1 + 1); | ||
| } | ||
| } | ||
| last_token.push(best); | ||
| } | ||
| let mut encoded = Vec::with_capacity(last_token.last().map(|l| l.1 as usize).unwrap_or(0)); | ||
| let mut pos = text.len(); | ||
| while pos > 0 { | ||
| let token = last_token[pos - 1].0; | ||
| encoded.push(token); | ||
| pos -= self.token_len(token); | ||
| } | ||
| encoded | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,3 +8,6 @@ bpe-openai = { path = "../../bpe-openai" } | |
| itertools = "0.14" | ||
| rand = "0.9" | ||
| tiktoken-rs = "0.9" | ||
|
|
||
| [dev-dependencies] | ||
| rand_chacha = { version = "0.9" } | ||
|
Collaborator
We don't need this dependency anymore, I think.
Contributor
Author
Removed, thanks! |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use std::time; | ||
|
|
||
|
marinegor marked this conversation as resolved.
Outdated
|
||
| use itertools::Itertools; | ||
| use rand::{rng, Rng}; | ||
| use tiktoken_rs::cl100k_base_singleton; | ||
|
|
@@ -141,4 +143,45 @@ mod tests { | |
| assert_eq!(enc.token_count(), bpe.count(&input[i..])); | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_bpe_dropout() { | ||
| use rand::rngs::StdRng; | ||
| use rand::SeedableRng; | ||
|
|
||
| fn get_rng(seed: u64) -> StdRng { | ||
| // Expand the u64 seed to 32 bytes | ||
| let mut seed_bytes = [0u8; 32]; | ||
| seed_bytes[..8].copy_from_slice(&seed.to_le_bytes()); | ||
| StdRng::from_seed(seed_bytes) | ||
| } | ||
|
marinegor marked this conversation as resolved.
|
||
|
|
||
| let bpe = &cl100k_base().bpe; | ||
| for bytes in [10000, 20000] { | ||
| for _ in 0..8 { | ||
|
marinegor marked this conversation as resolved.
Outdated
|
||
| let input = create_test_bytes(bpe, bytes); | ||
| let encoded = bpe.encode_minimal(&input); | ||
| let encoded_d_min = bpe.encode_minimal_dropout(&input, 0.2, get_rng(0)); | ||
| let encoded_d_max = bpe.encode_minimal_dropout(&input, 0.9, get_rng(1)); | ||
|
marinegor marked this conversation as resolved.
Outdated
|
||
| let encoded_d_1_0 = bpe.encode_minimal_dropout(&input, 1.0, get_rng(2)); | ||
| let decoded = bpe.decode_tokens(&encoded); | ||
| let decoded_min = bpe.decode_tokens(&encoded_d_min); | ||
| let decoded_max = bpe.decode_tokens(&encoded_d_max); | ||
| let decoded_max_again = bpe.decode_tokens(&encoded_d_1_0); | ||
| println!("Input length: {}, Encoded length: {}, Encoded with dropout length: {}-{}, max {}", | ||
| input.len(), encoded.len(), encoded_d_min.len(), encoded_d_max.len(), encoded_d_1_0.len()); | ||
| assert_eq!(input, decoded); | ||
| assert_eq!(input, decoded_min); | ||
| assert_eq!(input, decoded_max); | ||
| assert_eq!(input, decoded_max_again); | ||
| assert_eq!(input.len(), encoded_d_1_0.len()); | ||
| assert!(encoded_d_min.len() >= encoded.len()); | ||
| assert!(encoded_d_max.len() > encoded.len()); | ||
|
|
||
| assert_ne!(encoded, encoded_d_min); | ||
| assert_ne!(encoded, encoded_d_max); | ||
| assert_ne!(encoded_d_max, encoded_d_1_0); | ||
|
Collaborator
replace these assertions with explicit numbers, i.e. something like |
||
| } | ||
| } | ||
| } | ||
| } | ||
I think we should leave a note here that this is using a different drop-out strategy than proposed in the paper and it was NOT tested with actual training sessions!
indeed, thanks.