|
18 | 18 | # https://github.com/google-research/albert/blob/master/create_pretraining_data.py |
19 | 19 | # with some modifications. |
20 | 20 |
|
| 21 | +from enum import Enum |
21 | 22 | import math |
22 | 23 | import os |
23 | 24 | import time |
|
41 | 42 | DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] |
42 | 43 |
|
43 | 44 |
|
| 45 | +class SamplingStyle(Enum): |
| 46 | + POISSON = 'poisson' |
| 47 | + GEOMETRIC = 'geometric' |
| 48 | + |
| 49 | + |
44 | 50 | def analyze_data_prefix(data_prefix): |
45 | 51 |
|
46 | 52 | # The data prefix should be in the format of: |
@@ -194,9 +200,15 @@ def create_masked_lm_predictions(tokens, |
194 | 200 | favor_longer_ngram=False, |
195 | 201 | do_permutation=False, |
196 | 202 | geometric_dist=False, |
197 | | - masking_style="bert"): |
| 203 | + masking_style="bert", |
| 204 | + sampling_style=SamplingStyle.POISSON): |
198 | 205 | """Creates the predictions for the masked LM objective. |
199 | 206 | Note: Tokens here are vocab ids and not text tokens.""" |
| 207 | + if not isinstance(sampling_style, SamplingStyle): |
| 208 | + sampling_style = SamplingStyle(sampling_style) |
| 209 | + # Backward-compatibility |
| 210 | + if geometric_dist: |
| 211 | + sampling_style = SamplingStyle.GEOMETRIC |
200 | 212 |
|
201 | 213 | cand_indexes = [] |
202 | 214 | # Note(mingdachen): We create a list for recording if the piece is |
@@ -235,7 +247,7 @@ def create_masked_lm_predictions(tokens, |
235 | 247 | max(1, int(round(len(tokens) * masked_lm_prob)))) |
236 | 248 |
|
237 | 249 | ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64) |
238 | | - if not geometric_dist: |
| 250 | + if sampling_style is SamplingStyle.POISSON: |
239 | 251 | # Note(mingdachen): |
240 | 252 | # By default, we set the probilities to favor shorter ngram sequences. |
241 | 253 | pvals = 1. / np.arange(1, max_ngrams + 1) |
@@ -266,15 +278,17 @@ def create_masked_lm_predictions(tokens, |
266 | 278 | if index in covered_indexes: |
267 | 279 | continue |
268 | 280 |
|
269 | | - if not geometric_dist: |
| 281 | + if sampling_style is SamplingStyle.POISSON: |
270 | 282 | n = np_rng.choice(ngrams[:len(cand_index_set)], |
271 | 283 | p=pvals[:len(cand_index_set)] / |
272 | 284 | pvals[:len(cand_index_set)].sum(keepdims=True)) |
273 | | - else: |
| 285 | + elif sampling_style is SamplingStyle.GEOMETRIC: |
274 | 286 | # Sampling "n" from the geometric distribution and clipping it to |
275 | 287 | # the max_ngrams. Using p=0.2 default from the SpanBERT paper |
276 | 288 | # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1) |
277 | 289 | n = min(np_rng.geometric(0.2), max_ngrams) |
| 290 | + else: |
| 291 | + raise ValueError('unknown sampling style') |
278 | 292 |
|
279 | 293 | index_set = sum(cand_index_set[n - 1], []) |
280 | 294 | n -= 1 |
|
0 commit comments