diff --git a/simplify/simplifiers/models/ie/__init__.py b/simplify/simplifiers/ie/__init__.py similarity index 100% rename from simplify/simplifiers/models/ie/__init__.py rename to simplify/simplifiers/ie/__init__.py diff --git a/simplify/simplifiers/ie/data.py b/simplify/simplifiers/ie/data.py new file mode 100644 index 0000000..b05c9b8 --- /dev/null +++ b/simplify/simplifiers/ie/data.py @@ -0,0 +1,367 @@ +import json +import os +from collections import Counter + +import nltk + +import torch +from torch.utils.data import DataLoader +from transformers import AutoTokenizer + +from utils import (Coordination, coords_to_sentences, ext_to_triplet, + get_config, process_extraction) + + +def collate_pad_data(max_depth): + max_depth = max_depth + + def wrapper(data): + nonlocal max_depth + return pad_data(data, max_depth) + + return wrapper + + +def pad_data(data, max_depth): + input_ids = [ex["input_ids"] for ex in data] + word_begin_index = [d["word_begin_index"] for d in data] + sentence_index = [d["sentence_index"] for d in data] + labels = [d["labels"] for d in data] + max_length = max((len(i) for i in input_ids)) + input_ids = torch.tensor([pad_list(i, max_length, 0) for i in input_ids]) + max_length = max((len(i) for i in word_begin_index)) + word_begin_index = torch.tensor( + [pad_list(i, max_length, 0) for i in word_begin_index] + ) + sentence_index = torch.tensor(sentence_index) + for i, label in enumerate(labels): + to_add = max_depth - len(label) + if to_add > 0: + labels[i] = labels[i] + [[0] * len(label[0])] * to_add + labels = torch.tensor( + [[pad_list(l, max_length, -100) for l in label] for label in labels] + ) + # each input_ids has multiple labels(targets) following the depth, here we make sure that all + # the input_idss have the number of lables + # NOTE: labels need to be padded + padded = { + "input_ids": input_ids, + "labels": labels, + "word_begin_index": word_begin_index, + "sentence_index": sentence_index, + } + return dot_dict(padded) + + +def pad_list(list_, size, padding_token=0): + to_add = size - len(list_) + if to_add > 0: + return list_ + [padding_token] * to_add + if to_add == 0: + return list_ + return list_[:-(-to_add)] # truncate + + +def truncate(input_ids, word_begin_index, max_length): + max_len = max_length - 2 + input_ids = input_ids[:max_len] + unsqueeze_wb = [[i] * j for i, j in enumerate(word_begin_index)] + flatten_uwb = [i for l in unsqueeze_wb for i in l] + flatten_uwb = flatten_uwb[:max_len] + word_begin_index = list(Counter(flatten_uwb).values()) + return input_ids, word_begin_index + + +class dot_dict(dict): + """dot.notation access to dictionary attributes""" + + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +ADDITIONAL_SPECIAL_TOKENS = [f"[unused{i}]" for i in range(1, 4)] + + +class Data: + def __init__(self, config, *args, **kwargs): + self.config = dot_dict(config) if isinstance(config, dict) else config + self.config.additional_special_tokens = ( + self.config.additional_special_tokens or ADDITIONAL_SPECIAL_TOKENS + ) + self.tokenizer = AutoTokenizer.from_pretrained( + self.config.model_path, + additional_special_tokens=self.config.additional_special_tokens, + ) + + self.max_word_length = self.config.max_word_length or 100 + self.bos_token_id = self.config.bos_token_id or 101 + self.eos_token_id = self.config.eos_token_id or 102 + + def batch_encode_sentence(self, sentences, labels=None, inplace=True): + if labels is None: + labels = [None] * len(sentences) + assert len(labels) == len( + 
sentences + ), "make sure that sentences and labels have the same length and that they map to each other" + self.orig_sentences = sentences + encoded_sentences = [ + self.encode_sentence(s, l) for s, l in zip(sentences, labels) + ] + self.input_ids, self.word_begin_indexes, self.sentences, self.labels = zip( + *encoded_sentences + ) + if not inplace: + return self.input_ids, self.word_begin_indexes, self.sentences, self.labels + + def encode_sentence(self, *args, **kwargs): + raise NotImplementedError + + def to_allennlp_format(self, extractions, sentence): + # format each extraction as "<sentence>\t<arg1> <pred> <arg2>\t<confidence>" + formatted = [] + for extra in extractions: + args1 = extra.args[0] + pred = extra.pred + args2 = " ".join(extra.args[1:]) + confidence = extra.confidence + formatted.append(f"{sentence}\t {args1} {pred} {args2} \t{confidence}") + return formatted + + def to_dataloader( + self, + input_ids=None, + word_begin_indexes=None, + sentences=None, + labels=None, + batch_size=8, + shuffle=False, + collate_fn=None, + ): + input_ids = input_ids or self.input_ids + word_begin_indexes = word_begin_indexes or self.word_begin_indexes + sentences = sentences or self.sentences + labels = labels or self.labels + + collate_fn = collate_fn or collate_pad_data(self.config.max_depth) + self.sentences = sentences + f_names = self.field_names + examples = [ + dict(zip(f_names, e)) + for e in zip(input_ids, word_begin_indexes, labels, range(len(sentences))) + ] + dataloader = DataLoader( + examples, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle + ) + return dataloader + + def normalize(self, sentence): + if isinstance(sentence, str): + sentence = sentence.strip(" |\n") + sentence = sentence.replace("’", "'") + sentence = sentence.replace("”", "''") + sentence = sentence.replace("“", "''") + return sentence + + @classmethod + def load_pretrained(cls, path): + path_file = get_config(path) + with open(path_file) as f: + config = json.load(f) + return cls(config) + + def save_pretrained(self, path): + config = dict(self.config) + config.pop("__class__", None) + with open(path, "w") as f: + json.dump(config, f) + + +class DataForTriplet(Data): + def __init__(self, config): + config["max_depth"] = 5 + super().__init__(config) + + self.field_names = ("input_ids", "word_begin_index", "labels", "sentence_index") + self.label_dict = { + "NONE": 0, + "ARG1": 1, + "REL": 2, + "ARG2": 3, + "LOC": 4, + "TIME": 4, + "TYPE": 5, + "ARGS": 3, + } + + def encode_sentence(self, sentence, labels=None): + sentence = self.normalize(sentence) + tokens = nltk.word_tokenize(sentence) + sentence = " ".join(tokens[: self.max_word_length]) + u_sentence = ( + sentence.strip() + " " + " ".join(self.config.additional_special_tokens) + ) + word_tokens = self.tokenizer.batch_encode_plus( + u_sentence.split(), add_special_tokens=False + )["input_ids"] + + input_ids, word_begin_index = [], [] + for i in word_tokens: + word_begin_index.append(len(input_ids) + 1) + input_ids.extend(i or [100]) + + input_ids = [self.bos_token_id] + input_ids + [self.eos_token_id] + labels = [labels] if isinstance(labels, str) else labels + if labels and isinstance(labels, list): + labels = [ + [self.label_dict[i] for i in label_at_depth.split()] + for label_at_depth in labels + ] + labels = [pad_list(l, len(word_begin_index), 0) for l in labels] + else: + labels = [[0] * len(word_begin_index)] + return input_ids, word_begin_index, sentence, labels + + def batch_decode_prediction(self, predictions, scores, sentences=None): + sentences = sentences or self.sentences + return [ + DataForTriplet.decode_prediction(pred, score, sent) + for pred, score, sent in zip(predictions, scores,
sentences) + ] + + @staticmethod + def decode_prediction(predictions, scores, sentence): + # prediction: shape of (depth, max_word_length) + # scores: shape of (depth) + + words = sentence.split() + ADDITIONAL_SPECIAL_TOKENS + predictions, indices = torch.unique( + predictions, dim=0, sorted=False, return_inverse=True + ) + scores = scores[indices.unique()] + + mask_non_null = predictions.sum(dim=-1) != 0 + predictions = predictions[mask_non_null] + scores = scores[mask_non_null] + + extractions = [] + for prediction, score in zip(predictions, scores): + prediction = prediction[: len(words)] + pro_extraction = process_extraction(prediction, words, score.item()) + if pro_extraction.args[0] and pro_extraction.pred: + extracted_triplet = ext_to_triplet(pro_extraction) + extractions.append(extracted_triplet) + + output = dict() + output["triplet"] = extractions + output["sentence"] = sentence + + return output + + +class DataForConjunction(Data): + def __init__(self, config): + config["max_depth"] = 3 + super().__init__(config) + + self.field_names = ("input_ids", "word_begin_index", "labels", "sentence_index") + self.label_dict = { + "CP_START": 2, + "CP": 1, + "CC": 3, + "SEP": 4, + "OTHERS": 5, + "NONE": 0, + } + + def encode_sentence(self, sentence, labels=None): + sentence = self.normalize(sentence) + tokens = nltk.word_tokenize(sentence) + sentence = " ".join(tokens[: self.max_word_length]) + u_sentence = sentence.strip() + " [unused1] [unused2] [unused3]" + + word_tokens = self.tokenizer.batch_encode_plus( + u_sentence.split(), add_special_tokens=False + )["input_ids"] + input_ids, word_begin_index = [], [] + for i in word_tokens: + word_begin_index.append(len(input_ids) + 1) + input_ids.extend(i or [100]) + + input_ids = [self.bos_token_id] + input_ids + [self.eos_token_id] + labels = [labels] if isinstance(labels, str) else labels + if labels and isinstance(labels, list): + + labels = [ + [self.label_dict[i] for i in label_at_depth.split()] + for label_at_depth in labels + ] + labels = [pad_list(l, len(word_begin_index), 0) for l in labels] + else: + labels = [[0] * len(word_begin_index)] + return input_ids, word_begin_index, sentence, labels + + def batch_decode_prediction(self, predictions, sentences=None, **kw): + # prediction batch_size, depth, num_words, sentences + sentences = sentences or self.sentences + output = [] + for prediction, sentence in zip(predictions, sentences): + words = sentence.split() + len_words = len(words) + prediction = [p[:len_words] for p in prediction.tolist()] + coords = get_coords(prediction) + output_sentences = coords_to_sentences(coords, words) + output.append( + { + "sentence": sentence, + "prediction": output_sentences[0], + "conjugation_words": output_sentences[1], + } + ) + return output + + +def get_coords(predictions): + all_coordination_phrases = dict() + for depth, depth_prediction in enumerate(predictions): + coordination_phrase, start_index = None, -1 + coordphrase, conjunction, coordinator, separator = False, False, False, False + for i, label in enumerate(depth_prediction): + if label != 1: # conjunction can end + if conjunction and coordination_phrase != None: + conjunction = False + coordination_phrase["conjuncts"].append((start_index, i - 1)) + if label == 0 or label == 2: # coordination phrase can end + if ( + coordination_phrase != None + and len(coordination_phrase["conjuncts"]) >= 2 + and coordination_phrase["cc"] + > coordination_phrase["conjuncts"][0][1] + and coordination_phrase["cc"] + < coordination_phrase["conjuncts"][-1][0] 
+ ): + coordination = Coordination( + coordination_phrase["cc"], + coordination_phrase["conjuncts"], + label=depth, + ) + all_coordination_phrases[coordination_phrase["cc"]] = coordination + coordination_phrase = None + if label == 0: + continue + if label == 1: # can start a conjunction + if not conjunction: + conjunction = True + start_index = i + if label == 2: # starts a coordination phrase + coordination_phrase = {"cc": -1, "conjuncts": [], "seps": []} + conjunction = True + start_index = i + if label == 3 and coordination_phrase != None: + coordination_phrase["cc"] = i + if label == 4 and coordination_phrase != None: + coordination_phrase["seps"].append(i) + if label == 5: # nothing to be done + continue + if label == 3 and coordination_phrase == None: + # coordinating words which do not have associated conjuncts + all_coordination_phrases[i] = None + return all_coordination_phrases diff --git a/simplify/simplifiers/ie/model.py b/simplify/simplifiers/ie/model.py new file mode 100644 index 0000000..a71f7d8 --- /dev/null +++ b/simplify/simplifiers/ie/model.py @@ -0,0 +1,173 @@ +import json +from typing import List + +import torch +import torch.nn as nn +from transformers import AutoConfig, AutoModel + +from data import dot_dict +from utils import get_weight, read_config_file + + +class IEModel(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + def forward(self): + ... + +class IEConfig: + def __init__(self, + backbone_model: str = "bert-large-cased", + max_depth: int = 3, + num_iterative_layers: int = 2, + dropout_rate: float = 0.0, + label_hidden_size: int = 300, + label_size: int = 100, + num_labels: int = 6, + max_word_length: int = 100, + additional_special_tokens: List[str] = ["[unused1]", "[unused2]", "[unused3]"], + **kwargs): + + self.backbone_model = backbone_model + self.max_depth = max_depth + self.num_iterative_layers = num_iterative_layers + self.dropout_rate = dropout_rate + self.label_hidden_size = label_hidden_size + self.label_size = label_size + self.num_labels = num_labels + self.max_word_length = max_word_length + self.additional_special_tokens = additional_special_tokens + self.kwargs = kwargs + + @classmethod + def load_config(cls, config_file): + config = read_config_file(config_file) + return cls(**config) + + def save_config(self, file_path): + with open(file_path, 'w') as f: + json.dump(vars(self), f) + +config = { + "max_depth": 3, + "model_path": "bert-large-cased", + "num_iterative_layers": 2, + "dropout_rate": 0.0, + "label_size": 100, + "label_hidden_size": 300, + "num_labels": 6, + "additional_special_tokens": ["[unused1]", "[unused2]", "[unused3]"], + "max_word_length": 100, +} + +class Models: + def __init__(self, config): + self.config = dot_dict(config) + self.config.num_labels = self.config.num_labels or 6 + self.config.label_size = self.config.label_size or 100 + self.config.label_hidden_size = self.config.label_hidden_size or 300 + + self.config.num_iterative_layers = self.config.num_iterative_layers or 2 + # the config dict uses "dropout_rate"; fall back to it when "drop_rate" is absent + self.config.drop_rate = self.config.drop_rate or self.config.dropout_rate or 0.0 + + def save_pretrained(self, path): + torch.save({"config": dict(self.config), "state_dict": self.state_dict()}, path) + + @classmethod + def load_pretrained(cls, path, device="cpu"): + path_file = get_weight(path) + config_and_state_dict = torch.load(path_file, map_location=device) + model = cls(config_and_state_dict["config"]) + model.load_state_dict(config_and_state_dict["state_dict"]) + return model + + def is_valid_extraction(self, prediction): + return any([(1 in p and 2 in p) for p in prediction]) + + +class BackboneModel(nn.Module,
Models): + def __init__(self, config): + nn.Module.__init__(self) + Models.__init__(self, config) + + model_config = AutoConfig.from_pretrained(self.config.model_path) + self.base_model = AutoModel.from_config(model_config) + self.hidden_size = self.base_model.config.hidden_size + + if self.config.num_iterative_layers == 0: + self.iterative_layers = [] + else: + self.iterative_layers = self.base_model.encoder.layer[ + -self.config.num_iterative_layers : + ] + self.base_model.encoder.layer = self.base_model.encoder.layer[ + : -self.config.num_iterative_layers + ] + self.dropout = nn.Dropout(self.config.drop_rate) + + self.label_embedding = nn.Embedding(self.config.label_size, self.hidden_size) + self.linear1 = nn.Linear(self.hidden_size, self.config.label_hidden_size) + self.linear2 = nn.Linear(self.config.label_hidden_size, self.config.num_labels) + + def forward(self, input_ids, word_begin_index, labels=None): + hidden_states = self.base_model(input_ids)["last_hidden_state"] + prediction_in_depth = [] + word_scores_in_depth = [] + for d in range(self.config.depth): + for layer in self.iterative_layers: + hidden_states = layer(hidden_states)[0] + hidden_states = self.dropout(hidden_states) + un_word_begin_index = word_begin_index.unsqueeze(2).repeat( + 1, 1, self.hidden_size + ) + word_hidden_states = torch.gather( + hidden_states, dim=1, index=un_word_begin_index + ) + + if d != 0: + index_words = torch.argmax(word_scores, dim=-1) + label_embeddings = self.label_embedding(index_words) + word_hidden_states = word_hidden_states + label_embeddings + + word_hidden_states = self.linear1(word_hidden_states) + word_scores = self.linear2(word_hidden_states) + + prediction = torch.argmax(word_scores, dim=-1) + + word_scores_in_depth.append(word_scores) + prediction_in_depth.append(prediction) + + if not self.is_valid_extraction(prediction): + break + # word_scores = depth [batch,num_words] + word_scores_in_depth = torch.stack(word_scores_in_depth, dim=1) + + scores = [ + self.calculate_score(word_scores, label) + for word_scores, label in zip(word_scores_in_depth, labels) + ] + output = dict() + output["scores"] = torch.stack(scores, dim=0) + output["predictions"] = torch.stack(prediction_in_depth, dim=1) + return output + + def calculate_score(self, word_scores, labels): + max_log_probs, predictions = torch.log_softmax(word_scores, dim=-1).max(dim=-1) + mask_labels = (labels[0, :] != -100).float() + mask_predictions = (predictions != 0).float() * mask_labels + log_prob = (max_log_probs * mask_predictions) / ( + mask_predictions.sum(dim=1) + 1 + ).unsqueeze(-1) + scores = torch.exp(log_prob.sum(dim=1)) + return scores + + +class ModelForConjunction(nn.Module, Models): + def __init__(self, config): + nn.Module.__init__(self) + Models.__init__(self, config) + + self.config.depth = 3 + self.model = BackboneModel(self.config) + + def forward(self, input_ids, word_begin_index, labels=None): + return self.model(input_ids, word_begin_index, labels) diff --git a/simplify/simplifiers/ie/pipeline.py b/simplify/simplifiers/ie/pipeline.py new file mode 100644 index 0000000..5fd85f2 --- /dev/null +++ b/simplify/simplifiers/ie/pipeline.py @@ -0,0 +1,30 @@ +import torch + +from data import Data +from model import Models + + +class Pipeline: + def __init__(self, model: Models, data: Data): + self.model = model.eval() + self.data = data + + def run(self, sentences): + self.data.batch_encode_sentence(sentences, inplace=True) + dataloader = self.data.to_dataloader() + output = [] + with torch.no_grad(): + for batch 
in dataloader: + model_output = self.model( + batch.input_ids, batch.word_begin_index, batch.labels + ) + batch_sentences = [ + self.data.sentences[i] for i in batch.sentence_index.tolist() + ] + # model_output: dict with predictions and scores of shape (batch_size, depth, num_words) + # batch_sentences: list of length batch_size + out = self.data.batch_decode_prediction( + **model_output, sentences=batch_sentences + ) + output.extend(out) + return output diff --git a/simplify/simplifiers/ie/utils.py b/simplify/simplifiers/ie/utils.py new file mode 100644 index 0000000..b3f0df0 --- /dev/null +++ b/simplify/simplifiers/ie/utils.py @@ -0,0 +1,812 @@ +import itertools +import json +import logging +import os +import re +import warnings +from operator import itemgetter + +import nltk +import numpy as np +from transformers.file_utils import cached_path, hf_bucket_url + +SEP = ";;;" +QUESTION_TRG_INDEX = 3 # index of the predicate within the question +QUESTION_PP_INDEX = 5 +QUESTION_OBJ2_INDEX = 6 +DATA_CONFIG_NAME = "data-config.json" +WEIGHT_NAME = "model.pt" + + +def read_config_file(file_path): + with open(file_path) as f: + return json.load(f) + +def get_config(path): + return _get_file(path, DATA_CONFIG_NAME) + + +def get_weight(path): + return _get_file(path, WEIGHT_NAME) + + +def _get_file(path, file_name): + possible_config_file = os.path.join(path, file_name) + if os.path.isdir(path) and os.path.isfile(possible_config_file): + config_file = possible_config_file + elif os.path.isfile(path): + config_file = path + else: + config_file = hf_bucket_url(path, filename=file_name) + try: + return cached_path(config_file) + except Exception as e: + raise OSError(f"can't load {path}, due to {str(e)}") from e + + +def process_extraction(extraction, sentence, score): + # rel, arg1, arg2, loc, time = [], [], [], [], [] + rel, arg1, arg2, loc_time, args = [], [], [], [], [] + tag_mode = "none" + rel_case = 0 + for i, token in enumerate(sentence): + if "[unused" in token: + if extraction[i].item() == 2: + rel_case = int(re.search(r"\[unused(.*)\]", token).group(1)) + continue + if extraction[i] == 1: + arg1.append(token) + if extraction[i] == 2: + rel.append(token) + if extraction[i] == 3: + arg2.append(token) + if extraction[i] == 4: + loc_time.append(token) + rel = " ".join(rel).strip() + if rel_case == 1: + rel = "is " + rel + elif rel_case == 2: + rel = "is " + rel + " of" + elif rel_case == 3: + rel = "is " + rel + " from" + arg1 = " ".join(arg1).strip() + arg2 = " ".join(arg2).strip() + args = " ".join(args).strip() + loc_time = " ".join(loc_time).strip() + arg2 = (arg2 + " " + loc_time + " " + args).strip() + sentence_str = " ".join(sentence).strip() + extraction = Extraction( + pred=rel, head_pred_index=None, sent=sentence_str, confidence=score, index=0 + ) + extraction.addArg(arg1) + extraction.addArg(arg2) + return extraction + + +def coords_to_sentences(conj_coords, words): + + for k in list(conj_coords): + if conj_coords[k] is None: + conj_coords.pop(k) + + for k in list(conj_coords): + if words[conj_coords[k].cc] in ["nor", "&"]: # , 'or']: + conj_coords.pop(k) + + remove_unbreakable_conjuncts(conj_coords, words) + + conj_words = [] + for k in list(conj_coords): + for conjunct in conj_coords[k].conjuncts: + conj_words.append(" ".join(words[conjunct[0] : conjunct[1] + 1])) + + sentence_indices = [] + for i in range(0, len(words)): + sentence_indices.append(i) + + roots, parent_mapping, child_mapping = get_tree(conj_coords) + + q = list(roots) + + sentences = [] + count = len(q) + new_count = 0 + + conj_same_level = [] + + while len(q) > 0: + +
conj = q.pop(0) + count -= 1 + conj_same_level.append(conj) + + for child in child_mapping[conj]: + q.append(child) + new_count += 1 + + if count == 0: + get_sentences(sentences, conj_same_level, conj_coords, sentence_indices) + count = new_count + new_count = 0 + conj_same_level = [] + + word_sentences = [ + " ".join([words[i] for i in sorted(sentence)]) for sentence in sentences + ] + + return word_sentences, conj_words, sentences + + +def get_tree(conj): + parent_child_list = [] + + child_mapping, parent_mapping = {}, {} + + for key in conj: + assert conj[key].cc == key + parent_child_list.append([]) + for k in conj: + if conj[k] is not None: + if is_parent(conj[key], conj[k]): + parent_child_list[-1].append(k) + + child_mapping[key] = parent_child_list[-1] + + parent_child_list.sort(key=list.__len__) + + for i in range(0, len(parent_child_list)): + for child in parent_child_list[i]: + for j in range(i + 1, len(parent_child_list)): + if child in parent_child_list[j]: + parent_child_list[j].remove(child) + + for key in conj: + for child in child_mapping[key]: + parent_mapping[child] = key + + roots = [] + for key in conj: + if key not in parent_mapping: + roots.append(key) + + return roots, parent_mapping, child_mapping + + +def is_parent(parent, child): + min = child.conjuncts[0][0] + max = child.conjuncts[-1][-1] + + for conjunct in parent.conjuncts: + if conjunct[0] <= min and conjunct[1] >= max: + return True + return False + + +def get_sentences(sentences, conj_same_level, conj_coords, sentence_indices): + for conj in conj_same_level: + + if len(sentences) == 0: + + for conj_structure in conj_coords[conj].conjuncts: + sentence = [] + for i in range(conj_structure[0], conj_structure[1] + 1): + sentence.append(i) + sentences.append(sentence) + + min = conj_coords[conj].conjuncts[0][0] + max = conj_coords[conj].conjuncts[-1][-1] + + for sentence in sentences: + for i in sentence_indices: + if i < min or i > max: + sentence.append(i) + + else: + to_add = [] + to_remove = [] + + for sentence in sentences: + if conj_coords[conj].conjuncts[0][0] in sentence: + sentence.sort() + + min = conj_coords[conj].conjuncts[0][0] + max = conj_coords[conj].conjuncts[-1][-1] + + for conj_structure in conj_coords[conj].conjuncts: + new_sentence = [] + for i in sentence: + if ( + i in range(conj_structure[0], conj_structure[1] + 1) + or i < min + or i > max + ): + new_sentence.append(i) + + to_add.append(new_sentence) + + to_remove.append(sentence) + + for sent in to_remove: + sentences.remove(sent) + sentences.extend(to_add) + + +def remove_unbreakable_conjuncts(conj, words): + + unbreakable_indices = [] + + unbreakable_words = [ + "between", + "among", + "sum", + "total", + "addition", + "amount", + "value", + "aggregate", + "gross", + "mean", + "median", + "average", + "center", + "equidistant", + "middle", + ] + + for i, word in enumerate(words): + if word.lower() in unbreakable_words: + unbreakable_indices.append(i) + + to_remove = [] + span_start = 0 + + for key in conj: + span_end = conj[key].conjuncts[0][0] - 1 + for i in unbreakable_indices: + if span_start <= i <= span_end: + to_remove.append(key) + span_start = conj[key].conjuncts[-1][-1] + 1 + + for k in set(to_remove): + conj.pop(k) + + +class Coordination(object): + __slots__ = ("cc", "conjuncts", "seps", "label") + + def __init__(self, cc, conjuncts, seps=None, label=None): + assert isinstance(conjuncts, (list, tuple)) and len(conjuncts) >= 2 + assert all(isinstance(conj, tuple) for conj in conjuncts) + conjuncts = sorted(conjuncts, 
key=lambda span: span[0]) + # NOTE(chantera): The form 'A and B, C' is considered to be coordination. # NOQA + # assert cc > conjuncts[-2][1] and cc < conjuncts[-1][0] + assert cc > conjuncts[0][1] and cc < conjuncts[-1][0] + if seps is not None: + # if len(seps) == len(conjuncts) - 2: + # for i, sep in enumerate(seps): + # assert conjuncts[i][1] < sep and conjuncts[i + 1][0] > sep + # else: + if len(seps) != len(conjuncts) - 2: + warnings.warn( + "Coordination does not contain enough separators. " + "It may be a wrong coordination: " + "cc={}, conjuncts={}, separators={}".format(cc, conjuncts, seps) + ) + else: + seps = [] + self.cc = cc + self.conjuncts = tuple(conjuncts) + self.seps = tuple(seps) + self.label = label + + def get_pair(self, index, check=False): + pair = None + for i in range(1, len(self.conjuncts)): + if self.conjuncts[i][0] > index: + pair = (self.conjuncts[i - 1], self.conjuncts[i]) + assert pair[0][1] < index and pair[1][0] > index + break + if check and pair is None: + raise LookupError("Could not find any pair for index={}".format(index)) + return pair + + def __repr__(self): + return "Coordination(cc={}, conjuncts={}, seps={}, label={})".format( + self.cc, self.conjuncts, self.seps, self.label + ) + + def __eq__(self, other): + if not isinstance(other, Coordination): + return False + return ( + self.cc == other.cc + and len(self.conjuncts) == len(other.conjuncts) + and all( + conj1 == conj2 for conj1, conj2 in zip(self.conjuncts, other.conjuncts) + ) + ) + + +class Extraction: + """ + Stores sentence, single predicate and corresponding arguments. + """ + + def __init__( + self, pred, head_pred_index, sent, confidence, question_dist="", index=-1 + ): + self.pred = pred + self.head_pred_index = head_pred_index + self.sent = sent + self.args = [] + self.confidence = confidence + self.matched = [] + self.questions = {} + # self.indsForQuestions = defaultdict(lambda: set()) + self.is_mwp = False + self.question_dist = question_dist + self.index = index + + def distArgFromPred(self, arg): + assert len(self.pred) == 2 + dists = [] + for x in self.pred[1]: + for y in arg.indices: + dists.append(abs(x - y)) + + return min(dists) + + def argsByDistFromPred(self, question): + return sorted( + self.questions[question], key=lambda arg: self.distArgFromPred(arg) + ) + + def addArg(self, arg, question=None): + self.args.append(arg) + if question: + self.questions[question] = self.questions.get(question, []) + [ + Argument(arg) + ] + + def noPronounArgs(self): + """ + Returns True iff all of this extraction's arguments are not pronouns. 
+ """ + for (a, _) in self.args: + tokenized_arg = nltk.word_tokenize(a) + if len(tokenized_arg) == 1: + _, pos_tag = nltk.pos_tag(tokenized_arg)[0] + if "PRP" in pos_tag: + return False + return True + + def isContiguous(self): + return all([indices for (_, indices) in self.args]) + + def toBinary(self): + """Try to represent this extraction's arguments as binary + If fails, this function will return an empty list.""" + + ret = [self.elementToStr(self.pred)] + + if len(self.args) == 2: + # we're in luck + return ret + [self.elementToStr(arg) for arg in self.args] + + return [] + + if not self.isContiguous(): + # give up on non contiguous arguments (as we need indexes) + return [] + + # otherwise, try to merge based on indices + # TODO: you can explore other methods for doing this + binarized = self.binarizeByIndex() + + if binarized: + return ret + binarized + + return [] + + def elementToStr(self, elem, print_indices=True): + """formats an extraction element (pred or arg) as a raw string + removes indices and trailing spaces""" + if print_indices: + return str(elem) + if isinstance(elem, str): + return elem + if isinstance(elem, tuple): + ret = elem[0].rstrip().lstrip() + else: + ret = " ".join(elem.words) + assert ret, "empty element? {0}".format(elem) + return ret + + def binarizeByIndex(self): + extraction = [self.pred] + self.args + markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)] + sortedExtraction = sorted(markPred, key=lambda ws, indices, f: indices[0]) + s = " ".join( + [ + "{1} {0} {1}".format(self.elementToStr(elem), SEP) + if elem[2] + else self.elementToStr(elem) + for elem in sortedExtraction + ] + ) + binArgs = [a for a in s.split(SEP) if a.rstrip().lstrip()] + + if len(binArgs) == 2: + return binArgs + + # failure + return [] + + def bow(self): + return " ".join([self.elementToStr(elem) for elem in [self.pred] + self.args]) + + def getSortedArgs(self): + """ + Sort the list of arguments. + If a question distribution is provided - use it, + otherwise, default to the order of appearance in the sentence. + """ + if self.question_dist: + # There's a question distribtuion - use it + return self.sort_args_by_distribution() + ls = [] + for q, args in self.questions.iteritems(): + if len(args) != 1: + logging.debug("Not one argument: {}".format(args)) + continue + arg = args[0] + indices = list(self.indsForQuestions[q].union(arg.indices)) + if not indices: + logging.debug("Empty indexes for arg {} -- backing to zero".format(arg)) + indices = [0] + ls.append(((arg, q), indices)) + return [a for a, _ in sorted(ls, key=lambda _, indices: min(indices))] + + def question_prob_for_loc(self, question, loc): + """ + Returns the probability of the given question leading to argument + appearing in the given location in the output slot. + """ + gen_question = generalize_question(question) + q_dist = self.question_dist[gen_question] + logging.debug("distribution of {}: {}".format(gen_question, q_dist)) + + return float(q_dist.get(loc, 0)) / sum(q_dist.values()) + + def sort_args_by_distribution(self): + """ + Use this instance's question distribution (this func assumes it exists) + in determining the positioning of the arguments. + Greedy algorithm: + 0. Decide on which argument will serve as the ``subject'' (first slot) of this extraction + 0.1 Based on the most probable one for this spot + (special care is given to select the highly-influential subject position) + 1. For all other arguments, sort arguments by the prevalance of their questions + 2. 
For each argument: + 2.1 Assign to it the most probable slot still available + 2.2 If non such exist (fallback) - default to put it in the last location + """ + INF_LOC = 100 # Used as an impractical last argument + + # Store arguments by slot + ret = {INF_LOC: []} + logging.debug("sorting: {}".format(self.questions)) + + # Find the most suitable arguemnt for the subject location + logging.debug( + "probs for subject: {}".format( + [ + (q, self.question_prob_for_loc(q, 0)) + for (q, _) in self.questions.iteritems() + ] + ) + ) + + subj_question, subj_args = max( + self.questions.iteritems(), + key=lambda q, _: self.question_prob_for_loc(q, 0), + ) + + ret[0] = [(subj_args[0], subj_question)] + + # Find the rest + for (question, args) in sorted( + [ + (q, a) + for (q, a) in self.questions.iteritems() + if (q not in [subj_question]) + ], + key=lambda q, _: sum(self.question_dist[generalize_question(q)].values()), + reverse=True, + ): + gen_question = generalize_question(question) + arg = args[0] + assigned_flag = False + for (loc, count) in sorted( + self.question_dist[gen_question].iteritems(), + key=lambda _, c: c, + reverse=True, + ): + if loc not in ret: + # Found an empty slot for this item + # Place it there and break out + ret[loc] = [(arg, question)] + assigned_flag = True + break + + if not assigned_flag: + # Add this argument to the non-assigned (hopefully doesn't happen much) + logging.debug( + "Couldn't find an open assignment for {}".format( + (arg, gen_question) + ) + ) + ret[INF_LOC].append((arg, question)) + + logging.debug("Linearizing arg list: {}".format(ret)) + + # Finished iterating - consolidate and return a list of arguments + return [ + arg + for (_, arg_ls) in sorted(ret.iteritems(), key=lambda k, v: int(k)) + for arg in arg_ls + ] + + def __str__(self): + pred_str = self.elementToStr(self.pred) + return "{}\t{}\t{}".format( + self.get_base_verb(pred_str), + self.compute_global_pred(pred_str, self.questions.keys()), + "\t".join( + [ + escape_special_chars( + self.augment_arg_with_question(self.elementToStr(arg), question) + ) + for arg, question in self.getSortedArgs() + ] + ), + ) + + def get_base_verb(self, surface_pred): + """ + Given the surface pred, return the original annotated verb + """ + # Assumes that at this point the verb is always the last word + # in the surface predicate + return surface_pred.split(" ")[-1] + + def compute_global_pred(self, surface_pred, questions): + """ + Given the surface pred and all instansiations of questions, + make global coherence decisions regarding the final form of the predicate + This should hopefully take care of multi word predicates and correct inflections + """ + from operator import itemgetter + + split_surface = surface_pred.split(" ") + + if len(split_surface) > 1: + # This predicate has a modal preceding the base verb + verb = split_surface[-1] + ret = split_surface[:-1] # get all of the elements in the modal + else: + verb = split_surface[0] + ret = [] + + split_questions = map(lambda question: question.split(" "), questions) + + preds = map( + normalize_element, map(itemgetter(QUESTION_TRG_INDEX), split_questions) + ) + if len(set(preds)) > 1: + # This predicate is appears in multiple ways, let's stick to the base form + ret.append(verb) + + if len(set(preds)) == 1: + # Change the predciate to the inflected form + # if there's exactly one way in which the predicate is conveyed + ret.append(preds[0]) + + pps = map( + normalize_element, map(itemgetter(QUESTION_PP_INDEX), split_questions) + ) + + obj2s = map( + 
normalize_element, map(itemgetter(QUESTION_OBJ2_INDEX), split_questions) + ) + + if len(set(pps)) == 1: + # If all questions for the predicate include the same pp attachemnt - + # assume it's a multiword predicate + self.is_mwp = ( + True # Signal to arguments that they shouldn't take the preposition + ) + ret.append(pps[0]) + + # Concat all elements in the predicate and return + return " ".join(ret).strip() + + def augment_arg_with_question(self, arg, question): + """ + Decide what elements from the question to incorporate in the given + corresponding argument + """ + # Parse question + wh, aux, sbj, trg, obj1, pp, obj2 = map( + normalize_element, question.split(" ")[:-1] + ) # Last split is the question mark + + # Place preposition in argument + # This is safer when dealing with n-ary arguments, as it's directly attaches to the + # appropriate argument + if (not self.is_mwp) and pp and (not obj2): + if not (arg.startswith("{} ".format(pp))): + # Avoid repeating the preporition in cases where both question and answer contain it + return " ".join([pp, arg]) + + # Normal cases + return arg + + def clusterScore(self, cluster): + """ + Calculate cluster density score as the mean distance of the maximum distance of each slot. + Lower score represents a denser cluster. + """ + logging.debug("*-*-*- Cluster: {}".format(cluster)) + + # Find global centroid + arr = np.array([x for ls in cluster for x in ls]) + centroid = np.sum(arr) / arr.shape[0] + logging.debug("Centroid: {}".format(centroid)) + + # Calculate mean over all maxmimum points + return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster]) + + def resolveAmbiguity(self): + """ + Heursitic to map the elments (argument and predicates) of this extraction + back to the indices of the sentence. + """ + ## TODO: This removes arguments for which there was no consecutive span found + ## Part of these are non-consecutive arguments, + ## but other could be a bug in recognizing some punctuation marks + + elements = [self.pred] + [(s, indices) for (s, indices) in self.args if indices] + logging.debug("Resolving ambiguity in: {}".format(elements)) + + # Collect all possible combinations of arguments and predicate indices + # (hopefully it's not too much) + all_combinations = list(itertools.product(*map(itemgetter(1), elements))) + logging.debug("Number of combinations: {}".format(len(all_combinations))) + + # Choose the ones with best clustering and unfold them + resolved_elements = zip( + map(itemgetter(0), elements), + min(all_combinations, key=lambda cluster: self.clusterScore(cluster)), + ) + logging.debug("Resolved elements = {}".format(resolved_elements)) + + self.pred = resolved_elements[0] + self.args = resolved_elements[1:] + + def conll(self, external_feats={}): + """ + Return a CoNLL string representation of this extraction + """ + return ( + "\n".join( + [ + "\t".join( + map( + str, + [i, w] + + list(self.pred) + + [self.head_pred_index] + + external_feats + + [self.get_label(i)], + ) + ) + for (i, w) in enumerate(self.sent.split(" ")) + ] + ) + + "\n" + ) + + def get_label(self, index): + """ + Given an index of a word in the sentence -- returns the appropriate BIO conll label + Assumes that ambiguation was already resolved. 
+ """ + # Get the element(s) in which this index appears + ent = [ + (elem_ind, elem) + for (elem_ind, elem) in enumerate( + map(itemgetter(1), [self.pred] + self.args) + ) + if index in elem + ] + + if not ent: + # index doesnt appear in any element + return "O" + + if len(ent) > 1: + # The same word appears in two different answers + # In this case we choose the first one as label + logging.warn( + "Index {} appears in one than more element: {}".format( + index, "\t".join(map(str, [ent, self.sent, self.pred, self.args])) + ) + ) + + ## Some indices appear in more than one argument (ones where the above message appears) + ## From empricial observation, these seem to mostly consist of different levels of granularity: + ## what had _ been taken _ _ _ ? loan commitments topping $ 3 billion + ## how much had _ been taken _ _ _ ? topping $ 3 billion + ## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans + ## E.g., in this example two arguemnts are created: (loan commitments, topping $ 3 billion) + + elem_ind, elem = min(ent, key=lambda _, ls: len(ls)) + + # Distinguish between predicate and arguments + prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1) + + # Distinguish between Beginning and Inside labels + suffix = "B" if index == elem[0] else "I" + + return "{}-{}".format(prefix, suffix) + + def __str__(self): + return "{0}\t{1}".format( + self.elementToStr(self.pred, print_indices=True), + "\t".join([self.elementToStr(arg) for arg in self.args]), + ) + + +class Argument: + def __init__(self, arg): + self.words = [x for x in arg[0].strip().split(" ") if x] + self.posTags = map(itemgetter(1), nltk.pos_tag(self.words)) + self.indices = arg[1] + self.feats = {} + + def __str__(self): + return "({})".format( + "\t".join( + map( + str, + [(" ".join(self.words)).replace("\t", "\\t"), str(self.indices)], + ) + ) + ) + + +def normalize_element(elem): + """ + Return a surface form of the given question element. 
+ the output should be properly able to precede a predicate (or blank otherwise) + """ + return elem.replace("_", " ") if (elem != "_") else "" + + +## Helper functions +def escape_special_chars(s): + return s.replace("\t", "\\t") + + +def generalize_question(question): + """ + Given a question in the context of the sentence and the predicate index within + the question - return a generalized version which extracts only order-imposing features + """ + import nltk # Using nltk since couldn't get spaCy to agree on the tokenization + + wh, aux, sbj, trg, obj1, pp, obj2 = question.split(" ")[ + :-1 + ] # Last split is the question mark + return " ".join([wh, sbj, obj1]) diff --git a/simplify/simplifiers/models/__init__.py b/simplify/simplifiers/models/__init__.py deleted file mode 100644 index 551b6fc..0000000 --- a/simplify/simplifiers/models/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .muss import Muss -from .discourse import Discourse - - -__all_ = ['Discourse', 'Muss'] diff --git a/simplify/simplifiers/models/base.py b/simplify/simplifiers/models/base.py deleted file mode 100644 index b211cd4..0000000 --- a/simplify/simplifiers/models/base.py +++ /dev/null @@ -1,9 +0,0 @@ -from abc import ABCMeta - - -class BaseModel(metaclass=ABCMeta): - @abstractmethod - def __call__( - self, - ): - NotImplemented diff --git a/simplify/simplifiers/models/discourse/__init__.py b/simplify/simplifiers/models/discourse/__init__.py deleted file mode 100644 index 3bcb386..0000000 --- a/simplify/simplifiers/models/discourse/__init__.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -Lambda-3/DiscourseSimplification is licensed under the - -GNU General Public License v3.0 -Permissions of this strong copyleft license are conditioned on making available complete source code of licensed works and modifications, which include larger works using a licensed work, under the same license. Copyright and license notices must be preserved. Contributors provide an express grant of patent rights. 
-""" - - -from multiprocessing import Pool -import json - -__all__ = ["Discourse"] -jarpath = "https://github.com/sadakmed/simplify/raw/master/.jar/discourse.jar" - - -def with_jvm(paths): - def nested_decorator(func): - global wrapper - - def wrapper(sentences): - import jpype - import jpype.imports - - jpype.startJVM("-ea", classpath=paths) - from org.slf4j.Logger import ROOT_LOGGER_NAME - from org.lambda3.text.simplification.discourse.processing import ( - DiscourseSimplifier, - ProcessingType, - ) - - logging = jpype.java.util.logging - off = logging.Level.OFF - logging.Logger.getLogger(ROOT_LOGGER_NAME).setLevel(off) - - modules = { - "jpype": jpype, - "DiscourseSimplifier": DiscourseSimplifier, - "ProcessingType": ProcessingType, - } - func.__globals__.update(modules) - - simple_sentences = func(sentences) - - jpype.shutdownJVM() - return simple_sentences - - return wrapper - - return nested_decorator - - -@with_jvm(jarpath) -def discourse(sentences: list): - jlist_sentences = jpype.java.util.ArrayList(sentences) - dis = DiscourseSimplifier() - j_simple_sentences = dis.doDiscourseSimplification( - jlist_sentences, ProcessingType.SEPARATE - ) - p_simple_sentences = str(j_simple_sentences.serializeToJSON().toString()) - simple_sentences = json.loads(p_simple_sentences) - return simple_sentences - - -def discourse_old(sentences: list, paths: list): - import jpype as jp - import jpype.imports - - jp.startJVM("-ea", classpath=paths) - from org.slf4j.Logger import ROOT_LOGGER_NAME - - print(ROOT_LOGGER_NAME) - jp.java.util.logging.Logger.getLogger(ROOT_LOGGER_NAME).setLevel( - jp.java.util.logging.Level.OFF - ) - - from org.lambda3.text.simplification.discourse.processing import ( - DiscourseSimplifier, - ProcessingType, - ) - from org.lambda3.text.simplification.discourse.model import SimplificationContent - - jlist_sentences = jpype.java.util.ArrayList(sentences) - dis = DiscourseSimplifier() - j_simple_content = dis.doDiscourseSimplification( - jlist_sentences, ProcessingType.SEPARATE - ) - p_simple_content = str(j_simple_content.serializeToJSON().toString()) - jp.shutdownJVM() - - out_dict = json.loads(p_simple_content) - return out_dict - - -class Discourse: - def __init__(self): - pass - - def __call__(self, sentences): - with Pool(1) as p: - # output = p.map(partial(discourse, paths=jarpath), [sentences]) - output = p.map(discourse, [sentences]) - return output diff --git a/simplify/simplifiers/models/muss/__init__.py b/simplify/simplifiers/models/muss/__init__.py deleted file mode 100644 index 2583f91..0000000 --- a/simplify/simplifiers/models/muss/__init__.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -import shutil - -from .preprocessors import get_preprocessors -from .utils import write_lines, read_lines, get_temp_filepath, download_and_extract -from .simplifiers import get_fairseq_simplifier, get_preprocessed_simplifier -from simplify import SIMPLIFY_CACHE - -__all__ = ["Muss"] - -ALLOWED_MODEL_NAMES = [ - "muss_en_wikilarge_mined", - "muss_en_mined", -] - - -preprocessors_kwargs = { - "LengthRatioPreprocessor": {"target_ratio": 0.9, "use_short_name": False}, - "ReplaceOnlyLevenshteinPreprocessor": { - "target_ratio": 0.65, - "use_short_name": False, - }, - "WordRankRatioPreprocessor": {"target_ratio": 0.75, "use_short_name": False}, - "DependencyTreeDepthRatioPreprocessor": { - "target_ratio": 0.4, - "use_short_name": False, - }, - "GPT2BPEPreprocessor": {}, -} - - -def get_model_path(model_name): - assert model_name in ALLOWED_MODEL_NAMES - model_path = SIMPLIFY_CACHE / model_name - if not model_path.exists(): - url = f"https://dl.fbaipublicfiles.com/muss/{model_name}.tar.gz" - extracted_path = download_and_extract(url)[0] - shutil.move(extracted_path, model_path) - return model_path - - -class Muss: - def __init__(self, model_name: str = "muss_en_wikilarge_mined"): - self.model_name = model_name - self.preprocessors = None - self.simplifier = None - self._initialize() - - def _initialize(self): - model_path = get_model_path(self.model_name) - self.preprocessors = get_preprocessors(preprocessors_kwargs) - simplifier = get_fairseq_simplifier(model_path) - self.simplifier = get_preprocessed_simplifier(simplifier, self.preprocessors) - - def __call__(self, sentences): - source_path = get_temp_filepath() - write_lines(sentences, source_path) - prediction_path = self.simplifier(source_path) - prediction_sentences = read_lines(prediction_path) - return prediction_sentences diff --git a/simplify/simplifiers/models/muss/fairseq_util.py b/simplify/simplifiers/models/muss/fairseq_util.py deleted file mode 100644 index a60ad3e..0000000 --- a/simplify/simplifiers/models/muss/fairseq_util.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. 
- -from collections import defaultdict -from pathlib import Path - -import re -import shutil -import shlex - - -from fairseq_cli import generate - - -from .utils import ( - log_std_streams, - yield_lines, - write_lines, - mock_cli_args, - create_temp_dir, - mute, - args_dict_to_str, -) - - -def remove_multiple_whitespaces(text): - return re.sub(r" +", " ", text) - - -def fairseq_parse_all_hypotheses(out_filepath): - hypotheses_dict = defaultdict(list) - for line in yield_lines(out_filepath): - match = re.match(r"^H-(\d+)\t-?\d+\.\d+\t(.*)$", line) - if match: - sample_id, hypothesis = match.groups() - hypotheses_dict[int(sample_id)].append(hypothesis) - # Sort in original order - return [hypotheses_dict[i] for i in range(len(hypotheses_dict))] - - -def _fairseq_generate( - complex_filepath, - output_pred_filepath, - checkpoint_paths, - complex_dictionary_path, - simple_dictionary_path, - beam=5, - hypothesis_num=1, - lenpen=1.0, - diverse_beam_groups=None, - diverse_beam_strength=0.5, - sampling=False, - max_tokens=16384, - source_lang="complex", - target_lang="simple", - **kwargs, -): - # exp_dir must contain checkpoints/checkpoint_best.pt, and dict.{complex,simple}.txt - # First copy input complex file to exp_dir and create dummy simple file - - with create_temp_dir() as temp_dir: - new_complex_filepath = ( - temp_dir / f"tmp.{source_lang}-{target_lang}.{source_lang}" - ) - dummy_simple_filepath = ( - temp_dir / f"tmp.{source_lang}-{target_lang}.{target_lang}" - ) - shutil.copy(complex_filepath, new_complex_filepath) - shutil.copy(complex_filepath, dummy_simple_filepath) - shutil.copy(complex_dictionary_path, temp_dir / f"dict.{source_lang}.txt") - shutil.copy(simple_dictionary_path, temp_dir / f"dict.{target_lang}.txt") - args = f""" - {temp_dir} --dataset-impl raw --gen-subset tmp --path {':'.join([str(path) for path in checkpoint_paths])} - --beam {beam} --nbest {hypothesis_num} --lenpen {lenpen} - --diverse-beam-groups {diverse_beam_groups if diverse_beam_groups is not None else -1} --diverse-beam-strength {diverse_beam_strength} - --max-tokens {max_tokens} - --model-overrides "{{'encoder_embed_path': None, 'decoder_embed_path': None}}" - --skip-invalid-size-inputs-valid-test - """ - if sampling: - args += f"--sampling --sampling-topk 10" - # FIXME: if the kwargs are already present in the args string, they will appear twice but fairseq will take only the last one into account - args += f" {args_dict_to_str(kwargs)}" - args = remove_multiple_whitespaces(args.replace("\n", " ")) - out_filepath = temp_dir / "generation.out" - with mute(mute_stderr=False): - with log_std_streams(out_filepath): - # evaluate model in batch mode - args = shlex.split(args) - with mock_cli_args(args): - generate.cli_main() - - all_hypotheses = fairseq_parse_all_hypotheses(out_filepath) - predictions = [hypotheses[hypothesis_num - 1] for hypotheses in all_hypotheses] - write_lines(predictions, output_pred_filepath) - - -def fairseq_generate( - complex_filepath, - output_pred_filepath, - exp_dir, - beam=5, - hypothesis_num=1, - lenpen=1.0, - diverse_beam_groups=None, - diverse_beam_strength=0.5, - sampling=False, - max_tokens=8000, - source_lang="complex", - target_lang="simple", - **kwargs, -): - - exp_dir = Path(exp_dir) - possible_checkpoint_paths = [ - exp_dir / "model.pt", - exp_dir / "checkpoints/checkpoint_best.pt", - exp_dir / "checkpoints/checkpoint_last.pt", - ] - assert any( - [path for path in possible_checkpoint_paths if path.exists()] - ), f"Generation failed, no checkpoint found in 
{possible_checkpoint_paths}" # noqa: E501 - checkpoint_path = [path for path in possible_checkpoint_paths if path.exists()][0] - complex_dictionary_path = exp_dir / f"dict.{source_lang}.txt" - simple_dictionary_path = exp_dir / f"dict.{target_lang}.txt" - _fairseq_generate( - complex_filepath, - output_pred_filepath, - [checkpoint_path], - complex_dictionary_path=complex_dictionary_path, - simple_dictionary_path=simple_dictionary_path, - beam=beam, - hypothesis_num=hypothesis_num, - lenpen=lenpen, - diverse_beam_groups=diverse_beam_groups, - diverse_beam_strength=diverse_beam_strength, - sampling=sampling, - max_tokens=max_tokens, - **kwargs, - ) diff --git a/simplify/simplifiers/models/muss/preprocessors.py b/simplify/simplifiers/models/muss/preprocessors.py deleted file mode 100644 index f0e728f..0000000 --- a/simplify/simplifiers/models/muss/preprocessors.py +++ /dev/null @@ -1,601 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC -from functools import wraps, lru_cache -import hashlib -from pathlib import Path - -# import dill as pickle -import shutil - -import re - -import numpy as np - -from fairseq.data.encoders.gpt2_bpe_utils import get_encoder - -from .utils import ( - write_lines_in_parallel, - yield_lines_in_parallel, - add_dicts, - get_default_args, - get_temp_filepath, - failsafe_division, - download, - download_and_extract, - yield_lines, -) - -from simplify import SIMPLIFY_CACHE -from simplify.evaluators import lev_ratio - -FATTEXT_EMBEDDINGS_DIR = SIMPLIFY_CACHE / "fasttext-vectors/" -SPECIAL_TOKEN_REGEX = r"<[a-zA-Z\-_\d\.]+>" - - -def get_fasttext_embeddings_path(language="en"): - fasttext_embeddings_path = FASTTEXT_EMBEDDINGS_DIR / f"cc.{language}.300.vec" - if not fasttext_embeddings_path.exists(): - url = f"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{language}.300.vec.gz" - fasttext_embeddings_path.parent.mkdir(parents=True, exist_ok=True) - shutil.move(download_and_extract(url)[0], fasttext_embeddings_path) - return fasttext_embeddings_path - - -@lru_cache(maxsize=10) -def get_spacy_model(language="en", size="md"): - # Inline lazy import because importing spacy is slow - import spacy - - if language == "it" and size == "md": - print( - "Model it_core_news_md is not available for italian, falling back to it_core_news_sm" - ) - size = "sm" - model_name = { - "en": f"en_core_web_{size}", - "fr": f"fr_core_news_{size}", - "es": f"es_core_news_{size}", - "it": f"it_core_news_{size}", - "de": f"de_core_news_{size}", - }[language] - return spacy.load(model_name) # python -m spacy download en_core_web_sm - - -@lru_cache(maxsize=10 ** 6) -def spacy_process(text, language="en", size="md"): - return get_spacy_model(language=language, size=size)(str(text)) - - -@lru_cache(maxsize=1) -def get_spacy_tokenizer(language="en"): - return get_spacy_model(language=language).Defaults.create_tokenizer( - get_spacy_model(language=language) - ) - - -def get_spacy_content_tokens(text, language="en"): - def is_content_token(token): - return ( - not token.is_stop and not token.is_punct and token.ent_type_ == "" - ) # Not named entity - - return [ - token - for token in get_spacy_tokenizer(language=language)(text) - if is_content_token(token) - ] - - -def get_content_words(text, language="en"): - return [token.text for token in get_spacy_content_tokens(text, language=language)] - - 
-@lru_cache(maxsize=10) -def get_word2rank(vocab_size=10 ** 5, language="en"): - word2rank = {} - line_generator = yield_lines(get_fasttext_embeddings_path(language)) - next(line_generator) # Skip the first line (header) - for i, line in enumerate(line_generator): - if (i + 1) > vocab_size: - break - word = line.split(" ")[0] - word2rank[word] = i - return word2rank - - -def get_rank(word, language="en"): - return get_word2rank(language=language).get( - word, len(get_word2rank(language=language)) - ) - - -def get_log_rank(word, language="en"): - return np.log(1 + get_rank(word, language=language)) - - -def get_log_ranks(text, language="en"): - return [ - get_log_rank(word, language=language) - for word in get_content_words(text, language=language) - if word in get_word2rank(language=language) - ] - - -# Single sentence feature extractors with signature function(sentence) -> float -def get_lexical_complexity_score(sentence, language="en"): - log_ranks = get_log_ranks(sentence, language=language) - if len(log_ranks) == 0: - log_ranks = [ - np.log(1 + len(get_word2rank())) - ] # TODO: This is completely arbitrary - return np.quantile(log_ranks, 0.75) - - -def get_levenshtein_similarity(complex_sentence, simple_sentence): - return lev_ratio(complex_sentence, simple_sentence) - - -def get_levenshtein_distance(complex_sentence, simple_sentence): - # We should rename this to get_levenshtein_distance_ratio for more clarity - return 1 - get_levenshtein_similarity(complex_sentence, simple_sentence) - - -def get_replace_only_levenshtein_distance(complex_sentence, simple_sentence): - return len( - [ - _ - for operation, _, _ in Levenshtein.editops( - complex_sentence, simple_sentence - ) - if operation == "replace" - ] - ) - - -def get_replace_only_levenshtein_distance_ratio(complex_sentence, simple_sentence): - max_replace_only_distance = min(len(complex_sentence), len(simple_sentence)) - return failsafe_division( - get_replace_only_levenshtein_distance(complex_sentence, simple_sentence), - max_replace_only_distance, - default=0, - ) - - -def get_replace_only_levenshtein_similarity(complex_sentence, simple_sentence): - return 1 - get_replace_only_levenshtein_distance_ratio( - complex_sentence, simple_sentence - ) - - -def get_dependency_tree_depth(sentence, language="en"): - def get_subtree_depth(node): - if len(list(node.children)) == 0: - return 0 - return 1 + max([get_subtree_depth(child) for child in node.children]) - - tree_depths = [ - get_subtree_depth(spacy_sentence.root) - for spacy_sentence in spacy_process(sentence, language=language).sents - ] - if len(tree_depths) == 0: - return 0 - return max(tree_depths) - - -PREPROCESSORS_REGISTRY = {} - - -def get_preprocessor_by_name(preprocessor_name): - return PREPROCESSORS_REGISTRY[preprocessor_name] - - -def get_preprocessors(preprocessors_kwargs): - preprocessors = [] - for preprocessor_name, kwargs in preprocessors_kwargs.items(): - preprocessors.append(get_preprocessor_by_name(preprocessor_name)(**kwargs)) - return preprocessors - - -def extract_special_tokens(sentence): - """Remove any number of token at the beginning of the sentence""" - match = re.match(fr"(^(?:{SPECIAL_TOKEN_REGEX} *)+) *(.*)$", sentence) - if match is None: - return "", sentence - special_tokens, sentence = match.groups() - return special_tokens.strip(), sentence - - -def remove_special_tokens(sentence): - return extract_special_tokens(sentence)[1] - - -def store_args(constructor): - @wraps(constructor) - def wrapped(self, *args, **kwargs): - if not hasattr(self, "args") 
or not hasattr(self, "kwargs"): - # TODO: Default args are not overwritten if provided as args - self.args = args - self.kwargs = add_dicts(get_default_args(constructor), kwargs) - return constructor(self, *args, **kwargs) - - return wrapped - - -# def dump_preprocessors(preprocessors, dir_path): -# with open(Path(dir_path) / 'preprocessors.pickle', 'wb') as f: -# pickle.dump(preprocessors, f) - - -# def load_preprocessors(dir_path): -# path = Path(dir_path) / 'preprocessors.pickle' -# if not path.exists(): -# return None -# with open(path, 'rb') as f: -# return pickle.load(f) - - -class AbstractPreprocessor(ABC): - def __init_subclass__(cls, **kwargs): - """Register all children in registry""" - super().__init_subclass__(**kwargs) - PREPROCESSORS_REGISTRY[cls.__name__] = cls - - def __repr__(self): - args = getattr(self, "args", ()) - kwargs = getattr(self, "kwargs", {}) - args_repr = [repr(arg) for arg in args] - kwargs_repr = [ - f"{k}={repr(v)}" for k, v in sorted(kwargs.items(), key=lambda kv: kv[0]) - ] - args_kwargs_str = ", ".join(args_repr + kwargs_repr) - return f"{self.__class__.__name__}({args_kwargs_str})" - - def get_hash_string(self): - return self.__class__.__name__ - - def get_hash(self): - return hashlib.md5(self.get_hash_string().encode()).hexdigest() - - @staticmethod - def get_nevergrad_variables(): - return {} - - @property - def prefix(self): - return self.__class__.__name__.replace("Preprocessor", "") - - def fit(self, complex_filepath, simple_filepath): - pass - - def encode_sentence(self, sentence, encoder_sentence=None): - raise NotImplementedError - - def decode_sentence(self, sentence, encoder_sentence=None): - raise NotImplementedError - - def encode_sentence_pair(self, complex_sentence, simple_sentence): - if complex_sentence is not None: - complex_sentence = self.encode_sentence(complex_sentence) - if simple_sentence is not None: - simple_sentence = self.encode_sentence(simple_sentence) - return complex_sentence, simple_sentence - - def encode_file(self, input_filepath, output_filepath, encoder_filepath=None): - if encoder_filepath is None: - # We will use an empty temporary file which will yield None for each line - encoder_filepath = get_temp_filepath(create=True) - with open(output_filepath, "w", encoding="utf-8") as f: - for input_line, encoder_line in yield_lines_in_parallel( - [input_filepath, encoder_filepath], strict=False - ): - f.write(self.encode_sentence(input_line, encoder_line) + "\n") - - def decode_file(self, input_filepath, output_filepath, encoder_filepath=None): - if encoder_filepath is None: - # We will use an empty temporary file which will yield None for each line - encoder_filepath = get_temp_filepath(create=True) - with open(output_filepath, "w", encoding="utf-8") as f: - for encoder_sentence, input_sentence in yield_lines_in_parallel( - [encoder_filepath, input_filepath], strict=False - ): - decoded_sentence = self.decode_sentence( - input_sentence, encoder_sentence=encoder_sentence - ) - f.write(decoded_sentence + "\n") - - def encode_file_pair( - self, - complex_filepath, - simple_filepath, - output_complex_filepath, - output_simple_filepath, - ): - """Jointly encode a complex file and a simple file (can be aligned or not)""" - with write_lines_in_parallel( - [output_complex_filepath, output_simple_filepath], strict=False - ) as output_files: - for complex_line, simple_line in yield_lines_in_parallel( - [complex_filepath, simple_filepath], strict=False - ): - output_files.write(self.encode_sentence_pair(complex_line, simple_line)) 
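As a usage illustration of the registry wiring above: defining a subclass of AbstractPreprocessor auto-registers it through __init_subclass__, and get_preprocessor_by_name / get_preprocessors then recover it by class name. This is a minimal sketch; the LowercasePreprocessor class and the module name `preprocessors` are placeholders, not part of the module shown here.

# Sketch only: assumes the classes above are importable as `preprocessors`
# (placeholder module name); LowercasePreprocessor is a made-up example.
from preprocessors import (AbstractPreprocessor, get_preprocessor_by_name,
                           get_preprocessors)


class LowercasePreprocessor(AbstractPreprocessor):
    # Subclassing triggers AbstractPreprocessor.__init_subclass__, which stores
    # the class in PREPROCESSORS_REGISTRY under its class name.
    def encode_sentence(self, sentence, encoder_sentence=None):
        return sentence.lower()

    def decode_sentence(self, sentence, encoder_sentence=None):
        return sentence  # lowercasing is lossy, so decoding is a no-op


# Look the class up by name through the registry and instantiate it
preprocessor = get_preprocessor_by_name("LowercasePreprocessor")()
print(preprocessor.encode_sentence("The Cat Sat."))  # -> "the cat sat."

# get_preprocessors builds a list of instances from a {name: kwargs} mapping
preprocessor_list = get_preprocessors({"LowercasePreprocessor": {}})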
- - -class ComposedPreprocessor(AbstractPreprocessor): - @store_args - def __init__(self, preprocessors, sort=False): - if preprocessors is None: - preprocessors = [] - if sort: - # Make sure preprocessors are always in the same order - preprocessors = sorted( - preprocessors, key=lambda preprocessor: preprocessor.__class__.__name__ - ) - self.preprocessors = preprocessors - - def get_hash_string(self): - preprocessors_hash_strings = [ - preprocessor.get_hash_string() for preprocessor in self.preprocessors - ] - return f"ComposedPreprocessor(preprocessors={preprocessors_hash_strings})" - - def get_suffix(self): - return "_".join([p.prefix.lower() for p in self.preprocessors]) - - def fit(self, complex_filepath, simple_filepath): - for preprocessor in self.preprocessors: - pass - - def encode_sentence(self, sentence, encoder_sentence=None): - for preprocessor in self.preprocessors: - sentence = preprocessor.encode_sentence(sentence, encoder_sentence) - return sentence - - def decode_sentence(self, sentence, encoder_sentence=None): - for preprocessor in self.preprocessors: - sentence = preprocessor.decode_sentence(sentence, encoder_sentence) - return sentence - - def encode_file(self, input_filepath, output_filepath, encoder_filepath=None): - for preprocessor in self.preprocessors: - intermediary_output_filepath = get_temp_filepath() - preprocessor.encode_file( - input_filepath, intermediary_output_filepath, encoder_filepath - ) - input_filepath = intermediary_output_filepath - shutil.copyfile(input_filepath, output_filepath) - - def decode_file(self, input_filepath, output_filepath, encoder_filepath=None): - for preprocessor in self.preprocessors: - intermediary_output_filepath = get_temp_filepath() - preprocessor.decode_file( - input_filepath, intermediary_output_filepath, encoder_filepath - ) - input_filepath = intermediary_output_filepath - shutil.copyfile(input_filepath, output_filepath) - - def encode_file_pair( - self, - complex_filepath, - simple_filepath, - output_complex_filepath, - output_simple_filepath, - ): - for preprocessor in self.preprocessors: - intermediary_output_complex_filepath = get_temp_filepath() - intermediary_output_simple_filepath = get_temp_filepath() - preprocessor.encode_file_pair( - complex_filepath, - simple_filepath, - intermediary_output_complex_filepath, - intermediary_output_simple_filepath, - ) - complex_filepath = intermediary_output_complex_filepath - simple_filepath = intermediary_output_simple_filepath - shutil.copyfile(complex_filepath, output_complex_filepath) - shutil.copyfile(simple_filepath, output_simple_filepath) - - def encode_sentence_pair(self, complex_sentence, simple_sentence): - for preprocessor in self.preprocessors: - complex_sentence, simple_sentence = preprocessor.encode_sentence_pair( - complex_sentence, simple_sentence - ) - return complex_sentence, simple_sentence - - -class FeaturePreprocessor(AbstractPreprocessor): - """Prepend a computed feature at the beginning of the sentence""" - - @store_args - def __init__( - self, - feature_name, - get_feature_value, - get_target_feature_value, - bucket_size=0.05, - noise_std=0, - prepend_to_target=False, - use_short_name=False, - ): - self.get_feature_value = get_feature_value - self.get_target_feature_value = get_target_feature_value - self.bucket_size = bucket_size - self.noise_std = noise_std - self.feature_name = feature_name.upper() - self.use_short_name = use_short_name - if use_short_name: - # There might be collisions - self.feature_name = self.feature_name.lower()[:4] - 
self.prepend_to_target = prepend_to_target - - def get_hash_string(self): - return f"{self.__class__.__name__}(feature_name={repr(self.feature_name)}, bucket_size={self.bucket_size}, noise_std={self.noise_std}, prepend_to_target={self.prepend_to_target}, use_short_name={self.use_short_name})" # noqa: E501 - - def bucketize(self, value): - """Round value to bucket_size to reduce the number of different values""" - return round(round(value / self.bucket_size) * self.bucket_size, 10) - - def add_noise(self, value): - return value + np.random.normal(0, self.noise_std) - - def get_feature_token(self, feature_value): - return f"<{self.feature_name}_{feature_value}>" - - def encode_sentence(self, sentence, encoder_sentence=None): - if not self.prepend_to_target: - desired_feature = self.bucketize( - self.get_target_feature_value(remove_special_tokens(sentence)) - ) - sentence = f"{self.get_feature_token(desired_feature)} {sentence}" - return sentence - - def decode_sentence(self, sentence, encoder_sentence=None): - if self.prepend_to_target: - _, sentence = extract_special_tokens(sentence) - return sentence - - def encode_sentence_pair(self, complex_sentence, simple_sentence): - feature = self.bucketize( - self.add_noise( - self.get_feature_value( - remove_special_tokens(complex_sentence), - remove_special_tokens(simple_sentence), - ) - ) - ) - if self.prepend_to_target: - simple_sentence = f"{self.get_feature_token(feature)} {simple_sentence}" - else: - complex_sentence = f"{self.get_feature_token(feature)} {complex_sentence}" - return complex_sentence, simple_sentence - - -class LevenshteinPreprocessor(FeaturePreprocessor): - @store_args - def __init__(self, target_ratio=0.8, bucket_size=0.05, noise_std=0, **kwargs): - self.target_ratio = target_ratio - super().__init__( - self.prefix.upper(), - self.get_feature_value, - self.get_target_feature_value, - bucket_size, - noise_std, - **kwargs, - ) - - def get_feature_value(self, complex_sentence, simple_sentence): - return get_levenshtein_similarity(complex_sentence, simple_sentence) - - def get_target_feature_value(self, complex_sentence): - return self.target_ratio - - -class ReplaceOnlyLevenshteinPreprocessor(LevenshteinPreprocessor): - def get_feature_value(self, complex_sentence, simple_sentence): - return get_replace_only_levenshtein_similarity( - complex_sentence, simple_sentence - ) - - -class RatioPreprocessor(FeaturePreprocessor): - @store_args - def __init__( - self, - feature_extractor, - target_ratio=0.8, - bucket_size=0.05, - noise_std=0, - **kwargs, - ): - self.feature_extractor = feature_extractor - self.target_ratio = target_ratio - super().__init__( - self.prefix.upper(), - self.get_feature_value, - self.get_target_feature_value, - bucket_size, - noise_std, - **kwargs, - ) - - def get_feature_value(self, complex_sentence, simple_sentence): - return min( - failsafe_division( - self.feature_extractor(simple_sentence), - self.feature_extractor(complex_sentence), - ), - 2, - ) - - def get_target_feature_value(self, complex_sentence): - return self.target_ratio - - -class LengthRatioPreprocessor(RatioPreprocessor): - @store_args - def __init__(self, *args, **kwargs): - super().__init__(len, *args, **kwargs) - - -class WordRankRatioPreprocessor(RatioPreprocessor): - @store_args - def __init__(self, *args, language="en", **kwargs): - super().__init__( - lambda sentence: get_lexical_complexity_score(sentence, language=language), - *args, - **kwargs, - ) - - -class DependencyTreeDepthRatioPreprocessor(RatioPreprocessor): - @store_args - 
def __init__(self, *args, language="en", **kwargs): - super().__init__( - lambda sentence: get_dependency_tree_depth(sentence, language=language), - *args, - **kwargs, - ) - - -class GPT2BPEPreprocessor(AbstractPreprocessor): - def __init__(self): - self.bpe_dir = SIMPLIFY_CACHE / "bart_bpe" - self.bpe_dir.mkdir(exist_ok=True, parents=True) - self.encoder_json_path = self.bpe_dir / "encoder.json" - self.vocab_bpe_path = self.bpe_dir / "vocab.bpe" - self.dict_path = self.bpe_dir / "dict.txt" - download( - "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json", - self.encoder_json_path, - overwrite=False, - ) - download( - "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe", - self.vocab_bpe_path, - overwrite=False, - ) - download( - "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt", - self.dict_path, - overwrite=False, - ) - - @property - @lru_cache(maxsize=1) - def bpe_encoder(self): - """ - We need to use a property because GPT2BPEPreprocessor() cannot be pickled: - > pickle.dumps(GPT2BPEPreprocessor()) - ----> TypeError: can't pickle module objects - """ - return get_encoder(self.encoder_json_path, self.vocab_bpe_path) - - def encode_sentence(self, sentence, *args, **kwargs): - return " ".join([str(idx) for idx in self.bpe_encoder.encode(sentence)]) - - def decode_sentence(self, sentence, *args, **kwargs): - return self.bpe_encoder.decode([int(idx) for idx in sentence.split(" ")]) diff --git a/simplify/simplifiers/models/muss/simplifiers.py b/simplify/simplifiers/models/muss/simplifiers.py deleted file mode 100644 index 08deef5..0000000 --- a/simplify/simplifiers/models/muss/simplifiers.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree.
- -from functools import wraps -from pathlib import Path -import shutil - -from imohash import hashfile - -from .fairseq_util import fairseq_generate -from .preprocessors import ComposedPreprocessor -from .utils import count_lines, get_temp_filepath - - -def memoize_simplifier(simplifier): - memo = {} - - @wraps(simplifier) - def wrapped(complex_filepath, pred_filepath): - complex_filehash = hashfile(complex_filepath, hexdigest=True) - previous_pred_filepath = memo.get(complex_filehash) - if previous_pred_filepath is not None and Path(previous_pred_filepath).exists(): - assert count_lines(complex_filepath) == count_lines(previous_pred_filepath) - # Reuse previous prediction - shutil.copyfile(previous_pred_filepath, pred_filepath) - else: - simplifier(complex_filepath, pred_filepath) - # Save prediction - memo[complex_filehash] = pred_filepath - - return wrapped - - -def make_output_file_optional(simplifier): - @wraps(simplifier) - def wrapped(complex_filepath, pred_filepath=None): - if pred_filepath is None: - pred_filepath = get_temp_filepath() - simplifier(complex_filepath, pred_filepath) - return pred_filepath - - return wrapped - - -def get_fairseq_simplifier(exp_dir, **kwargs): - """Function factory""" - - @make_output_file_optional - @memoize_simplifier - def fairseq_simplifier(complex_filepath, output_pred_filepath): - fairseq_generate(complex_filepath, output_pred_filepath, exp_dir, **kwargs) - - return fairseq_simplifier - - -def get_preprocessed_simplifier(simplifier, preprocessors): - composed_preprocessor = ComposedPreprocessor(preprocessors) - - @make_output_file_optional - @memoize_simplifier - @wraps(simplifier) - def preprocessed_simplifier(complex_filepath, pred_filepath): - preprocessed_complex_filepath = get_temp_filepath() - composed_preprocessor.encode_file( - complex_filepath, preprocessed_complex_filepath - ) - preprocessed_pred_filepath = simplifier(preprocessed_complex_filepath) - composed_preprocessor.decode_file( - preprocessed_pred_filepath, pred_filepath, encoder_filepath=complex_filepath - ) - - preprocessed_simplifier.__name__ = ( - f"{preprocessed_simplifier.__name__}_{composed_preprocessor.get_suffix()}" - ) - return preprocessed_simplifier diff --git a/simplify/simplifiers/models/muss/utils.py b/simplify/simplifiers/models/muss/utils.py deleted file mode 100644 index 55decb5..0000000 --- a/simplify/simplifiers/models/muss/utils.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -from contextlib import contextmanager -import gzip -import inspect -from io import StringIO -from itertools import zip_longest -from pathlib import Path -import shutil -import sys -import tempfile -import time -from types import MethodType - - -import bz2 -import os -import tarfile -from urllib.request import urlretrieve -import zipfile - -from tqdm import tqdm - - -def reporthook(count, block_size, total_size): - # Download progress bar - global start_time - if count == 0: - start_time = time.time() - return - duration = time.time() - start_time - progress_size_mb = count * block_size / (1024 * 1024) - speed = progress_size_mb / duration - percent = int(count * block_size * 100 / total_size) - msg = f"\r... 
{percent}% - {int(progress_size_mb)} MB - {speed:.2f} MB/s - {int(duration)}s" - sys.stdout.write(msg) - - -def download(url, destination_path=None, overwrite=True): - if destination_path is None: - destination_path = get_temp_filepath() - if not overwrite and destination_path.exists(): - return destination_path - print("Downloading...") - try: - urlretrieve(url, destination_path, reporthook) - sys.stdout.write("\n") - except (Exception, KeyboardInterrupt, SystemExit): - print("Rolling back: remove partially downloaded file") - os.remove(destination_path) - raise - return destination_path - - -def download_and_extract(url): - tmp_dir = Path(tempfile.mkdtemp()) - compressed_filename = url.split("/")[-1] - compressed_filepath = tmp_dir / compressed_filename - download(url, compressed_filepath) - print("Extracting...") - extracted_paths = extract(compressed_filepath, tmp_dir) - compressed_filepath.unlink() - return extracted_paths - - -def extract(filepath, output_dir): - output_dir = Path(output_dir) - # Infer extract function based on extension - extensions_to_functions = { - ".tar.gz": untar, - ".tar.bz2": untar, - ".tgz": untar, - ".zip": unzip, - ".gz": ungzip, - ".bz2": unbz2, - } - - def get_extension(filename, extensions): - possible_extensions = [ext for ext in extensions if filename.endswith(ext)] - if len(possible_extensions) == 0: - raise Exception(f"File {filename} has an unknown extension") - # Take the longest (.tar.gz should take precedence over .gz) - return max(possible_extensions, key=lambda ext: len(ext)) - - filename = os.path.basename(filepath) - extension = get_extension(filename, list(extensions_to_functions)) - extract_function = extensions_to_functions[extension] - - # Extract files in a temporary dir then move the extracted item back to - # the ouput dir in order to get the details of what was extracted - tmp_extract_dir = Path(tempfile.mkdtemp()) - # Extract - extract_function(filepath, output_dir=tmp_extract_dir) - extracted_items = os.listdir(tmp_extract_dir) - output_paths = [] - for name in extracted_items: - extracted_path = tmp_extract_dir / name - output_path = output_dir / name - move_with_overwrite(extracted_path, output_path) - output_paths.append(output_path) - return output_paths - - -def move_with_overwrite(source_path, target_path): - if os.path.isfile(target_path): - os.remove(target_path) - if os.path.isdir(target_path) and os.path.isdir(source_path): - shutil.rmtree(target_path) - shutil.move(source_path, target_path) - - -def untar(compressed_path, output_dir): - with tarfile.open(compressed_path) as f: - f.extractall(output_dir) - - -def unzip(compressed_path, output_dir): - with zipfile.ZipFile(compressed_path, "r") as f: - f.extractall(output_dir) - - -def ungzip(compressed_path, output_dir): - filename = os.path.basename(compressed_path) - assert filename.endswith(".gz") - if not os.path.exists(output_dir): - os.makedirs(output_dir) - output_path = os.path.join(output_dir, filename[:-3]) - with gzip.open(compressed_path, "rb") as f_in: - with open(output_path, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - - -def unbz2(compressed_path, output_dir): - extract_filename = os.path.basename(compressed_path).replace(".bz2", "") - extract_path = os.path.join(output_dir, extract_filename) - with bz2.BZ2File(compressed_path, "rb") as compressed_file, open( - extract_path, "wb" - ) as extract_file: - for data in tqdm(iter(lambda: compressed_file.read(1024 * 1024), b"")): - extract_file.write(data) - - -@contextmanager -def open_files(filepaths, 
mode="r"): - files = [] - try: - files = [Path(filepath).open(mode, encoding="utf-8") for filepath in filepaths] - yield files - finally: - [f.close() for f in files] - - -def yield_lines_in_parallel(filepaths, strip=True, strict=True, n_lines=float("inf")): - assert type(filepaths) == list - with open_files(filepaths) as files: - for i, parallel_lines in enumerate(zip_longest(*files)): - if i >= n_lines: - break - if None in parallel_lines: - assert ( - not strict - ), f"Files don't have the same number of lines: {filepaths}, use strict=False" - if strip: - parallel_lines = [ - l.rstrip("\n") if l is not None else None for l in parallel_lines - ] - yield parallel_lines - - -class FilesWrapper: - """Write to multiple open files at the same time""" - - def __init__(self, files, strict=True): - self.files = files - self.strict = strict # Whether to raise an exception when a line is None - - def write(self, lines): - assert len(lines) == len(self.files) - for line, f in zip(lines, self.files): - if line is None: - assert not self.strict - continue - f.write(line.rstrip("\n") + "\n") - - -@contextmanager -def write_lines_in_parallel(filepaths, strict=True): - with open_files(filepaths, "w") as files: - yield FilesWrapper(files, strict=strict) - - -def write_lines(lines, filepath=None): - if filepath is None: - filepath = get_temp_filepath() - filepath = Path(filepath) - filepath.parent.mkdir(parents=True, exist_ok=True) - with filepath.open("w", encoding="utf-8") as f: - for line in lines: - f.write(line + "\n") - return filepath - - -def yield_lines(filepath, gzipped=False, n_lines=None): - filepath = Path(filepath) - open_function = open - if gzipped or filepath.name.endswith(".gz"): - open_function = gzip.open - with open_function(filepath, "rt", encoding="utf-8") as f: - for i, l in enumerate(f): - if n_lines is not None and i >= n_lines: - break - yield l.rstrip("\n") - - -def read_lines(filepath, gzipped=False): - return list(yield_lines(filepath, gzipped=gzipped)) - - -def count_lines(filepath): - n_lines = 0 - # We iterate over the generator to avoid loading the whole file in memory - for _ in yield_lines(filepath): - n_lines += 1 - return n_lines - - -def arg_name_python_to_cli(arg_name, cli_sep="-"): - arg_name = arg_name.replace("_", cli_sep) - return f"--{arg_name}" - - -def kwargs_to_cli_args_list(kwargs, cli_sep="-"): - cli_args_list = [] - for key, val in kwargs.items(): - key = arg_name_python_to_cli(key, cli_sep) - if isinstance(val, bool): - cli_args_list.append(str(key)) - else: - if isinstance(val, str): - # Add quotes around val - assert "'" not in val - val = f"'{val}'" - cli_args_list.extend([str(key), str(val)]) - return cli_args_list - - -def args_dict_to_str(args_dict, cli_sep="-"): - return " ".join(kwargs_to_cli_args_list(args_dict, cli_sep=cli_sep)) - - -def failsafe_division(a, b, default=0): - if b == 0: - return default - return a / b - - -@contextmanager -def redirect_streams(source_streams, target_streams): - # We assign these functions before hand in case a target stream is also a source stream. 
- # If it's the case then the write function would be patched leading to infinite recursion - target_writes = [target_stream.write for target_stream in target_streams] - target_flushes = [target_stream.flush for target_stream in target_streams] - - def patched_write(self, message): - for target_write in target_writes: - target_write(message) - - def patched_flush(self): - for target_flush in target_flushes: - target_flush() - - original_source_stream_writes = [ - source_stream.write for source_stream in source_streams - ] - original_source_stream_flushes = [ - source_stream.flush for source_stream in source_streams - ] - try: - for source_stream in source_streams: - source_stream.write = MethodType(patched_write, source_stream) - source_stream.flush = MethodType(patched_flush, source_stream) - yield - finally: - for ( - source_stream, - original_source_stream_write, - original_source_stream_flush, - ) in zip( - source_streams, - original_source_stream_writes, - original_source_stream_flushes, - ): - source_stream.write = original_source_stream_write - source_stream.flush = original_source_stream_flush - - -@contextmanager -def mute(mute_stdout=True, mute_stderr=True): - streams = [] - if mute_stdout: - streams.append(sys.stdout) - if mute_stderr: - streams.append(sys.stderr) - with redirect_streams(source_streams=streams, target_streams=[StringIO()]): - yield - - -@contextmanager -def log_std_streams(filepath): - log_file = open(filepath, "w", encoding="utf-8") - try: - with redirect_streams( - source_streams=[sys.stdout], target_streams=[log_file, sys.stdout] - ): - with redirect_streams( - source_streams=[sys.stderr], target_streams=[log_file, sys.stderr] - ): - yield - finally: - log_file.close() - - -def add_dicts(*dicts): - return {k: v for dic in dicts for k, v in dic.items()} - - -def get_default_args(func): - signature = inspect.signature(func) - return { - k: v.default - for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty - } - - -TEMP_DIR = None - - -def get_temp_filepath(create=False): - global TEMP_DIR - temp_filepath = Path(tempfile.mkstemp()[1]) - if TEMP_DIR is not None: - temp_filepath.unlink() - temp_filepath = TEMP_DIR / temp_filepath.name - temp_filepath.touch(exist_ok=False) - if not create: - temp_filepath.unlink() - return temp_filepath - - -def get_temp_dir(): - return Path(tempfile.mkdtemp()) - - -@contextmanager -def create_temp_dir(): - temp_dir = get_temp_dir() - try: - yield temp_dir - finally: - shutil.rmtree(temp_dir) - - -@contextmanager -def log_action(action_description): - start_time = time.time() - print(f"{action_description}...") - try: - yield - except BaseException as e: - print(f"{action_description} failed after {time.time() - start_time:.2f}s.") - raise e - print(f"{action_description} completed after {time.time() - start_time:.2f}s.") - - -@contextmanager -def mock_cli_args(args): - current_args = sys.argv - sys.argv = sys.argv[:1] + args - yield - sys.argv = current_args
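As a usage illustration of the file and stream helpers in this module: write_lines materialises a list of sentences into a temporary file, yield_lines_in_parallel iterates aligned files together, and the context managers tee or patch the standard streams. This is a minimal sketch; the module name `utils`, the example sentences and the --beam flag are placeholders.

# Sketch only: assumes the helpers above are importable as `utils`
# (placeholder module name); the data and the CLI flag are invented.
import sys

from utils import (log_action, log_std_streams, mock_cli_args, write_lines,
                   yield_lines_in_parallel)

# write_lines returns a temporary file path when none is given
complex_path = write_lines(["A sentence that is rather hard to read ."])
simple_path = write_lines(["An easy sentence ."])

with log_std_streams("run.log"):  # tee stdout/stderr into run.log
    with log_action("Reading parallel files"):  # prints start and completion timing
        for complex_line, simple_line in yield_lines_in_parallel(
            [complex_path, simple_path]
        ):
            print(complex_line, "->", simple_line)
    with mock_cli_args(["--beam", "5"]):  # temporarily replaces sys.argv[1:]
        print("patched args:", sys.argv[1:])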