# Exported at package level: pliers/extractors/__init__.py imports this class
# from .text and lists 'GPTForwardLMExtractor' in __all__.
class GPTForwardLMExtractor(ComplexTextExtractor):
    ''' Returns next-word predictions for GPT-style (forward) language models.

    The last element of the input stimulus is treated as the target word;
    all preceding elements are joined with spaces and used as the context
    fed to the model. Scores are taken from the model output at the last
    context position.

    Args:
        pretrained_model (str): A string specifying which transformer
            model to use. Passed to the model class' from_pretrained.
        tokenizer (str): Type of tokenization used in the tokenization step.
            If different from model, out-of-vocabulary tokens may be treated
            as unknown tokens. Passed to the tokenizer class'
            from_pretrained.
        model_class (str): Name of the transformers class used to load the
            model. 'TF' is prepended when framework is 'tf'.
        tokenizer_class (str): Name of the transformers class used to load
            the tokenizer.
        framework (str): name of the deep learning framework to use. Must be
            'pt' (PyTorch) or 'tf' (tensorflow). Defaults to 'pt'.
        top_n (int): Specifies how many of the highest-scoring tokens are
            to be returned. Intended to be mutually exclusive with target
            and threshold; if several are given, top_n takes precedence,
            then target, then threshold.
        threshold (float): If defined, only tokens whose score is greater
            than or equal to this value are returned. A value of 0 is a
            valid threshold. See top_n for precedence.
        target (str or list): Vocabulary token(s) for which the score is to
            be returned. Tokens defined in the vocabulary change across
            tokenizers. Tokens missing from the vocabulary are dropped with
            a warning. See top_n for precedence.
        return_true_token (bool): if True, adds 'true_token' (first token of
            the target word) and 'true_token_score' (its score).
        return_true_word (bool): if True, adds 'true_word' (the target word,
            prefixed with a space as it is passed to the tokenizer).
        return_softmax (bool): if True, returns probability scores instead
            of raw predictions.
        return_input (bool): if True, adds 'lm_sequence' (the stimulus name,
            which is set to the full word sequence when it is empty).
        return_context (bool): if True, adds 'lm_context' (the context
            words joined by spaces).
        onset (str): whether the onset/duration in the result is the one
            from the target word ('target') or from the last word in the
            context ('last_context').
        model_kwargs (dict): Named arguments for pretrained model.
        tokenizer_kwargs (dict): Named arguments for tokenizer.

    Raises:
        ValueError: if framework or onset is invalid, or if none of the
            requested target tokens is in the tokenizer vocabulary.
    '''

    _log_attributes = ('pretrained_model', 'framework', 'top_n', 'target',
                       'threshold', 'tokenizer_type', 'return_softmax',
                       'return_true_word', 'return_true_token',
                       'return_input', 'return_context', 'onset')
    _model_attributes = ('pretrained_model', 'framework', 'top_n',
                         'target', 'threshold', 'tokenizer_type')

    def __init__(self,
                 pretrained_model='gpt2',
                 tokenizer='gpt2',
                 model_class='GPT2LMHeadModel',
                 tokenizer_class='GPT2TokenizerFast',
                 framework='pt',
                 top_n=None,
                 threshold=None,
                 target=None,
                 return_true_token=True,
                 return_true_word=False,
                 return_softmax=None,
                 return_input=True,
                 return_context=True,
                 onset='target',
                 model_kwargs=None,
                 tokenizer_kwargs=None):
        verify_dependencies(['transformers'])
        if framework not in ['pt', 'tf']:
            raise ValueError("Invalid framework; must be one of 'pt' "
                             "(pytorch) or 'tf' (tensorflow)")
        if onset not in ['target', 'last_context']:
            raise ValueError("Onset must be one of 'target' or "
                             "'last_context'.")
        self.pretrained_model = pretrained_model
        self.tokenizer_type = tokenizer
        self.model_class = model_class
        self.tokenizer_class = tokenizer_class
        self.framework = framework
        self.model_kwargs = model_kwargs if model_kwargs else {}
        self.tokenizer_kwargs = tokenizer_kwargs if tokenizer_kwargs else {}
        # TensorFlow variants of transformers classes carry a 'TF' prefix.
        model = model_class if self.framework == 'pt' else 'TF' + model_class
        self.model = getattr(transformers, model).from_pretrained(
            pretrained_model, **self.model_kwargs)
        self.tokenizer = getattr(transformers, tokenizer_class).from_pretrained(
            tokenizer, **self.tokenizer_kwargs)
        self.target = listify(target)
        if self.target:
            # Build the vocabulary set once and reuse it for both checks.
            vocab = set(self.tokenizer.vocab.keys())
            missing = set(self.target) - vocab
            if missing:
                logging.warning(f'{missing} not in vocabulary. Dropping.')
            # De-duplicate while keeping the caller's order.
            self.target = [t for t in dict.fromkeys(self.target)
                           if t in vocab]
            if not self.target:
                raise ValueError(
                    'No valid target token. Import transformers'
                    f' and run transformers.{tokenizer_class}.from_pretrained'
                    f'(\'{tokenizer}\').vocab.keys() to see available tokens')
        self.top_n = top_n
        self.threshold = threshold
        self.return_softmax = return_softmax
        self.return_context = return_context
        self.return_true_word = return_true_word
        self.return_true_token = return_true_token
        self.return_input = return_input
        self.onset = onset
        super().__init__()

    def _preprocess(self, stims):
        ''' Tokenizes input and returns context and target info.

        Returns:
            A pair of tuples: (context onsets, context durations, context
            token ids, context words) and (target token id, decoded target
            token, target word, target onset, target duration).

        Raises:
            ValueError: if the stimulus has fewer than two elements, since
                at least one context word and one target word are needed.
        '''
        els = [(e.text, e.onset, e.duration) for e in stims.elements]
        if len(els) < 2:
            raise ValueError('GPTForwardLMExtractor requires at least two '
                             'words: one or more context words followed by '
                             'a target word.')
        wds, ons, dur = map(list, zip(*els))
        # Everything but the last element is context.
        c_wds, c_ons, c_dur = wds[:-1], ons[:-1], dur[:-1]
        c_tok = self.tokenizer.encode(' '.join(c_wds),
                                      return_tensors=self.framework)
        stims.name = ' '.join(wds) if stims.name == '' else stims.name
        # The target is encoded with a leading space, as it would appear
        # after the context; only its first token is kept.
        t_wds = ' ' + wds[-1]
        t_id = self.tokenizer.encode(
            t_wds, return_tensors=self.framework)[0, 0]
        t_tok = self.tokenizer.decode(t_id)
        return ((c_ons, c_dur, c_tok, c_wds),
                (t_id, t_tok, t_wds, ons[-1], dur[-1]))

    def _extract(self, stims):
        ''' Runs the model on the context and collects the selected scores.

        Features are the decoded tokens selected by top_n / target /
        threshold (in descending score order), followed by the optional
        'true_token', 'true_token_score', 'true_word', 'lm_context' and
        'lm_sequence' entries.
        '''
        c_outs, t_outs = self._preprocess(stims)
        c_ons, c_dur, c_tok, c_wds = c_outs
        t_id, t_tok, t_wds, t_ons, t_dur = t_outs
        outputs = self.model(c_tok)
        # Scores over the vocabulary at the last context position.
        logits = outputs.logits[0, -1, :]
        if self.framework == 'pt':
            logits = logits.detach()
        preds = logits.numpy()
        if self.return_softmax:
            preds = scipy.special.softmax(preds, axis=-1)
        # Token ids sorted by descending score; the filters below keep
        # this ordering.
        out_idx = preds.argsort()[::-1]
        if self.top_n:
            out_idx = out_idx[:self.top_n]
        elif self.target:
            sub_idx = self.tokenizer.convert_tokens_to_ids(self.target)
            out_idx = out_idx[np.isin(out_idx, sub_idx)]
        elif self.threshold is not None:
            # `is not None` so that a threshold of 0 is honoured.
            out_idx = out_idx[preds[out_idx] >= self.threshold]
        feat = [self.tokenizer.decode(o) for o in out_idx]
        data = [listify(float(p)) for p in preds[out_idx]]
        if self.return_true_token:
            feat += ['true_token', 'true_token_score']
            data += [t_tok, float(preds[t_id])]
        if self.return_true_word:
            feat += ['true_word']
            data += [t_wds]
        if self.return_context:
            feat += ['lm_context']
            data += [' '.join(c_wds)]
        if self.return_input:
            feat += ['lm_sequence']
            data += [stims.name]
        if self.onset == 'target':
            ons = listify(t_ons)
            dur = listify(t_dur)
        else:
            ons = listify(c_ons[-1])
            dur = listify(c_dur[-1])
        return ExtractorResult(data, stims, self,
                               features=feat, onsets=ons, durations=dur)

    def _to_df(self, result):
        ''' Builds a one-row-per-object DataFrame from an extractor result.

        Each value is wrapped in a list so that the frame can be built even
        when every entry is a scalar (e.g. no prediction columns selected).
        '''
        columns = {feat: listify(val)
                   for feat, val in zip(result.features, result._data)}
        res_df = pd.DataFrame(columns)
        res_df['object_id'] = range(res_df.shape[0])
        return res_df