Skip to content

Commit 7f50532

Browse files
committed
Refactor masked LM sampling style selection
Keeps backward compatibility: the legacy `geometric_dist` flag still selects geometric sampling, so the rest of the code base does not need to change.
1 parent 13becf1 commit 7f50532

1 file changed

Lines changed: 18 additions & 4 deletions

File tree

megatron/data/dataset_utils.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
1919
# with some modifications.
2020

21+
from enum import Enum
2122
import math
2223
import os
2324
import time
@@ -41,6 +42,11 @@
4142
DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
4243

4344

45+
class SamplingStyle(Enum):
46+
POISSON = 'poisson'
47+
GEOMETRIC = 'geometric'
48+
49+
4450
def analyze_data_prefix(data_prefix):
4551

4652
# The data prefix should be in the format of:
@@ -194,9 +200,15 @@ def create_masked_lm_predictions(tokens,
194200
favor_longer_ngram=False,
195201
do_permutation=False,
196202
geometric_dist=False,
197-
masking_style="bert"):
203+
masking_style="bert",
204+
sampling_style=SamplingStyle.POISSON):
198205
"""Creates the predictions for the masked LM objective.
199206
Note: Tokens here are vocab ids and not text tokens."""
207+
if not isinstance(sampling_style, SamplingStyle):
208+
sampling_style = SamplingStyle(sampling_style)
209+
# Backward compatibility: the legacy `geometric_dist` flag, when set, overrides `sampling_style`.
210+
if geometric_dist:
211+
sampling_style = SamplingStyle.GEOMETRIC
200212

201213
cand_indexes = []
202214
# Note(mingdachen): We create a list for recording if the piece is
@@ -235,7 +247,7 @@ def create_masked_lm_predictions(tokens,
235247
max(1, int(round(len(tokens) * masked_lm_prob))))
236248

237249
ngrams = np.arange(1, max_ngrams + 1, dtype=np.int64)
238-
if not geometric_dist:
250+
if sampling_style is SamplingStyle.POISSON:
239251
# Note(mingdachen):
240252
# By default, we set the probabilities to favor shorter ngram sequences.
241253
pvals = 1. / np.arange(1, max_ngrams + 1)
@@ -266,15 +278,17 @@ def create_masked_lm_predictions(tokens,
266278
if index in covered_indexes:
267279
continue
268280

269-
if not geometric_dist:
281+
if sampling_style is SamplingStyle.POISSON:
270282
n = np_rng.choice(ngrams[:len(cand_index_set)],
271283
p=pvals[:len(cand_index_set)] /
272284
pvals[:len(cand_index_set)].sum(keepdims=True))
273-
else:
285+
elif sampling_style is SamplingStyle.GEOMETRIC:
274286
# Sampling "n" from the geometric distribution and clipping it to
275287
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
276288
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
277289
n = min(np_rng.geometric(0.2), max_ngrams)
290+
else:
291+
raise ValueError('unknown sampling style')
278292

279293
index_set = sum(cand_index_set[n - 1], [])
280294
n -= 1

0 commit comments

Comments
 (0)