diff --git a/simplify/simplifiers/models/ie/__init__.py b/simplify/simplifiers/ie/__init__.py
similarity index 100%
rename from simplify/simplifiers/models/ie/__init__.py
rename to simplify/simplifiers/ie/__init__.py
diff --git a/simplify/simplifiers/ie/data.py b/simplify/simplifiers/ie/data.py
new file mode 100644
index 0000000..b05c9b8
--- /dev/null
+++ b/simplify/simplifiers/ie/data.py
@@ -0,0 +1,367 @@
+import json
+import os
+from collections import Counter
+
+import nltk
+
+import torch
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer
+
+from utils import (Coordination, coords_to_sentences, ext_to_triplet,
+ get_config, process_extraction)
+
+
+def collate_pad_data(max_depth):
+    def wrapper(data):
+        return pad_data(data, max_depth)
+
+    return wrapper
+
+
+def pad_data(data, max_depth):
+ input_ids = [ex["input_ids"] for ex in data]
+ word_begin_index = [d["word_begin_index"] for d in data]
+ sentence_index = [d["sentence_index"] for d in data]
+ labels = [d["labels"] for d in data]
+ max_length = max((len(i) for i in input_ids))
+ input_ids = torch.tensor([pad_list(i, max_length, 0) for i in input_ids])
+ max_length = max((len(i) for i in word_begin_index))
+ word_begin_index = torch.tensor(
+ [pad_list(i, max_length, 0) for i in word_begin_index]
+ )
+ sentence_index = torch.tensor(sentence_index)
+ for i, label in enumerate(labels):
+ to_add = max_depth - len(label)
+ if to_add > 0:
+ labels[i] = labels[i] + [[0] * len(label[0])] * to_add
+ labels = torch.tensor(
+ [[pad_list(l, max_length, -100) for l in label] for label in labels]
+ )
+    # each example carries one label row per extraction depth; the loop above makes sure
+    # every example has max_depth rows, and each row is padded to the same length with
+    # -100 so those positions are ignored by the loss
+ padded = {
+ "input_ids": input_ids,
+ "labels": labels,
+ "word_begin_index": word_begin_index,
+ "sentence_index": sentence_index,
+ }
+ return dot_dict(padded)
+
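+# Illustrative collate behaviour for a toy batch with max_depth=2 (token ids are made up,
+# 101/102 stand for the default [CLS]/[SEP] ids):
+#
+#   batch = [
+#       {"input_ids": [101, 7592, 102], "word_begin_index": [1],
+#        "sentence_index": 0, "labels": [[1]]},
+#       {"input_ids": [101, 7592, 2088, 102], "word_begin_index": [1, 2],
+#        "sentence_index": 1, "labels": [[1, 2], [0, 0]]},
+#   ]
+#   padded = pad_data(batch, max_depth=2)
+#   # padded.input_ids        -> shape (2, 4), the shorter example right-padded with 0
+#   # padded.word_begin_index -> shape (2, 2)
+#   # padded.labels           -> shape (2, 2, 2), padded label positions filled with -100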
+
+def pad_list(list_, size, padding_token=0):
+ to_add = size - len(list_)
+ if to_add > 0:
+ return list_ + [padding_token] * to_add
+ if to_add == 0:
+ return list_
+    return list_[:size]  # truncate to the requested size
+
+
+def truncate(input_ids, word_begin_index, max_length):
+ max_len = max_length - 2
+ input_ids = input_ids[:max_len]
+ unsqueeze_wb = [[i] * j for i, j in enumerate(word_begin_index)]
+ flatten_uwb = [i for l in unsqueeze_wb for i in l]
+ flatten_uwb = flatten_uwb[:max_len]
+ word_begin_index = list(Counter(flatten_uwb).values())
+ return input_ids, word_begin_index
+
+
+class dot_dict(dict):
+ """dot.notation access to dictionary attributes"""
+
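+    # e.g. cfg = dot_dict({"max_depth": 3}) gives cfg.max_depth == 3; missing keys return
+    # None instead of raising, because __getattr__ falls back to dict.get.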
+ __getattr__ = dict.get
+ __setattr__ = dict.__setitem__
+ __delattr__ = dict.__delitem__
+
+
+ADDITIONAL_SPECIAL_TOKENS = [f"[unused{i}]" for i in range(1, 4)]
+
+
+class Data:
+ def __init__(self, config, *args, **kwargs):
+ self.config = dot_dict(config) if isinstance(config, dict) else config
+ self.config.additional_special_tokens = (
+ self.config.additional_special_tokens or ADDITIONAL_SPECIAL_TOKENS
+ )
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.config.model_path,
+ additional_special_tokens=self.config.additional_special_tokens,
+ )
+
+ self.max_word_length = self.config.max_word_length or 100
+ self.bos_token_id = self.config.bos_token_id or 101
+ self.eos_token_id = self.config.eos_token_id or 102
+
+ def batch_encode_sentence(self, sentences, labels=None, inplace=True):
+ if labels is None:
+ labels = [None] * len(sentences)
+ assert len(labels) == len(
+ sentences
+ ), "make sure that sentences and labels have the same length and that they map to each other"
+ self.orig_sentences = sentences
+ encoded_sentences = [
+ self.encode_sentence(s, l) for s, l in zip(sentences, labels)
+ ]
+ self.input_ids, self.word_begin_indexes, self.sentences, self.labels = zip(
+ *encoded_sentences
+ )
+        if not inplace:
+            return self.input_ids, self.word_begin_indexes, self.sentences, self.labels
+
+    def encode_sentence(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def to_allennlp_format(self, extractions, sentence):
+        lines = []
+        for extra in extractions:
+            args1 = extra.args[0]
+            pred = extra.pred
+            args2 = " ".join(extra.args[1:])
+            confidence = extra.confidence
+
+            lines.append(f"{sentence}\t {args1} {pred} {args2} \t{confidence}")
+        return lines
+
+ def to_dataloader(
+ self,
+ input_ids=None,
+ word_begin_indexes=None,
+ sentences=None,
+ labels=None,
+ batch_size=8,
+ shuffle=False,
+ collate_fn=None,
+ ):
+ input_ids = input_ids or self.input_ids
+ word_begin_indexes = word_begin_indexes or self.word_begin_indexes
+ sentences = sentences or self.sentences
+ labels = labels or self.labels
+
+ collate_fn = collate_fn or collate_pad_data(self.config.max_depth)
+ self.sentences = sentences
+ f_names = self.field_names
+ examples = [
+ dict(zip(f_names, e))
+ for e in zip(input_ids, word_begin_indexes, labels, range(len(sentences)))
+ ]
+ dataloader = DataLoader(
+ examples, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle
+ )
+ return dataloader
+
+ def normalize(self, sentence):
+ if isinstance(sentence, str):
+ sentence = sentence.strip(" |\n")
+ sentence = sentence.replace("’", "'")
+ sentence = sentence.replace("”", "''")
+ sentence = sentence.replace("“", "''")
+ return sentence
+
+ @classmethod
+ def load_pretrained(cls, path):
+ path_file = get_config(path)
+ with open(path_file) as f:
+ config = json.load(f)
+ return cls(config)
+
+ def save_pretrained(self, path):
+ config = dict(self.config)
+        config.pop("__class__", None)
+ with open(path, "w") as f:
+ json.dump(config, f)
+
+
+class DataForTriplet(Data):
+ def __init__(self, config):
+ config["max_depth"] = 5
+ super().__init__(config)
+
+ self.field_names = ("input_ids", "word_begin_index", "labels", "sentence_index")
+ self.label_dict = {
+ "NONE": 0,
+ "ARG1": 1,
+ "REL": 2,
+ "ARG2": 3,
+ "LOC": 4,
+ "TIME": 4,
+ "TYPE": 5,
+ "ARGS": 3,
+ }
+
+ def encode_sentence(self, sentence, labels=None):
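+        # Sketch of the return value, assuming a BERT-style tokenizer (101/102 specials);
+        # the wordpiece ids are indicative only. For encode_sentence("Paris is nice"):
+        #   sentence         == "Paris is nice"   (normalised, without the [unusedN] markers)
+        #   word_begin_index == first-wordpiece position of each of the 6 words
+        #                       (3 words + 3 [unusedN] markers), offset by 1 for [CLS]
+        #   input_ids        == [101, <wordpieces...>, 102]
+        #   labels           == [[0, 0, 0, 0, 0, 0]]  one all-NONE row when no gold labels given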
+ sentence = self.normalize(sentence)
+ tokens = nltk.word_tokenize(sentence)
+ sentence = " ".join(tokens[: self.max_word_length])
+ u_sentence = (
+ sentence.strip() + " " + " ".join(self.config.additional_special_tokens)
+ )
+ word_tokens = self.tokenizer.batch_encode_plus(
+ u_sentence.split(), add_special_tokens=False
+ )["input_ids"]
+
+ input_ids, word_begin_index = [], []
+ for i in word_tokens:
+ word_begin_index.append(len(input_ids) + 1)
+ input_ids.extend(i or [100])
+
+ input_ids = [self.bos_token_id] + input_ids + [self.eos_token_id]
+ labels = [labels] if isinstance(labels, str) else labels
+ if labels and isinstance(labels, list):
+ labels = [
+ [self.label_dict[i] for i in label_at_depth.split()]
+ for label_at_depth in labels
+ ]
+ labels = [pad_list(l, len(word_begin_index), 0) for l in labels]
+ else:
+ labels = [[0] * len(word_begin_index)]
+ return input_ids, word_begin_index, sentence, labels
+
+ def batch_decode_prediction(self, predictions, scores, sentences=None):
+ sentences = sentences or self.sentences
+ return [
+ DataForTriplet.decode_prediction(pred, score, sent)
+ for pred, score, sent in zip(predictions, scores, sentences)
+ ]
+
+ @staticmethod
+ def decode_prediction(predictions, scores, sentence):
+        # predictions: tensor of shape (depth, num_words), word positions padded
+        # scores: tensor of shape (depth,)
+
+ words = sentence.split() + ADDITIONAL_SPECIAL_TOKENS
+ predictions, indices = torch.unique(
+ predictions, dim=0, sorted=False, return_inverse=True
+ )
+ scores = scores[indices.unique()]
+
+ mask_non_null = predictions.sum(dim=-1) != 0
+ predictions = predictions[mask_non_null]
+ scores = scores[mask_non_null]
+
+ extractions = []
+ for prediction, score in zip(predictions, scores):
+ prediction = prediction[: len(words)]
+ pro_extraction = process_extraction(prediction, words, score.item())
+ if pro_extraction.args[0] and pro_extraction.pred:
+ extracted_triplet = ext_to_triplet(pro_extraction)
+ extractions.append(extracted_triplet)
+
+ output = dict()
+ output["triplet"] = extractions
+ output["sentence"] = sentence
+
+ return output
+
+
+class DataForConjunction(Data):
+ def __init__(self, config):
+ config["max_depth"] = 3
+ super().__init__(config)
+
+ self.field_names = ("input_ids", "word_begin_index", "labels", "sentence_index")
+ self.label_dict = {
+ "CP_START": 2,
+ "CP": 1,
+ "CC": 3,
+ "SEP": 4,
+ "OTHERS": 5,
+ "NONE": 0,
+ }
+
+ def encode_sentence(self, sentence, labels=None):
+ sentence = self.normalize(sentence)
+ tokens = nltk.word_tokenize(sentence)
+ sentence = " ".join(tokens[: self.max_word_length])
+ u_sentence = sentence.strip() + " [unused1] [unused2] [unused3]"
+
+ word_tokens = self.tokenizer.batch_encode_plus(
+ u_sentence.split(), add_special_tokens=False
+ )["input_ids"]
+ input_ids, word_begin_index = [], []
+ for i in word_tokens:
+ word_begin_index.append(len(input_ids) + 1)
+ input_ids.extend(i or [100])
+
+ input_ids = [self.bos_token_id] + input_ids + [self.eos_token_id]
+ labels = [labels] if isinstance(labels, str) else labels
+ if labels and isinstance(labels, list):
+
+ labels = [
+ [self.label_dict[i] for i in label_at_depth.split()]
+ for label_at_depth in labels
+ ]
+ labels = [pad_list(l, len(word_begin_index), 0) for l in labels]
+ else:
+ labels = [[0] * len(word_begin_index)]
+ return input_ids, word_begin_index, sentence, labels
+
+ def batch_decode_prediction(self, predictions, sentences=None, **kw):
+        # predictions: (batch_size, depth, num_words); sentences: list of length batch_size
+ sentences = sentences or self.sentences
+ output = []
+ for prediction, sentence in zip(predictions, sentences):
+ words = sentence.split()
+ len_words = len(words)
+ prediction = [p[:len_words] for p in prediction.tolist()]
+ coords = get_coords(prediction)
+ output_sentences = coords_to_sentences(coords, words)
+ output.append(
+ {
+ "sentence": sentence,
+ "prediction": output_sentences[0],
+ "conjugation_words": output_sentences[1],
+ }
+ )
+ return output
+
+
+def get_coords(predictions):
+ all_coordination_phrases = dict()
+ for depth, depth_prediction in enumerate(predictions):
+ coordination_phrase, start_index = None, -1
+        conjunction = False
+ for i, label in enumerate(depth_prediction):
+ if label != 1: # conjunction can end
+ if conjunction and coordination_phrase != None:
+ conjunction = False
+ coordination_phrase["conjuncts"].append((start_index, i - 1))
+ if label == 0 or label == 2: # coordination phrase can end
+ if (
+ coordination_phrase != None
+ and len(coordination_phrase["conjuncts"]) >= 2
+ and coordination_phrase["cc"]
+ > coordination_phrase["conjuncts"][0][1]
+ and coordination_phrase["cc"]
+ < coordination_phrase["conjuncts"][-1][0]
+ ):
+ coordination = Coordination(
+ coordination_phrase["cc"],
+ coordination_phrase["conjuncts"],
+ label=depth,
+ )
+ all_coordination_phrases[coordination_phrase["cc"]] = coordination
+ coordination_phrase = None
+ if label == 0:
+ continue
+ if label == 1: # can start a conjunction
+ if not conjunction:
+ conjunction = True
+ start_index = i
+ if label == 2: # starts a coordination phrase
+ coordination_phrase = {"cc": -1, "conjuncts": [], "seps": []}
+ conjunction = True
+ start_index = i
+ if label == 3 and coordination_phrase != None:
+ coordination_phrase["cc"] = i
+ if label == 4 and coordination_phrase != None:
+ coordination_phrase["seps"].append(i)
+ if label == 5: # nothing to be done
+ continue
+ if label == 3 and coordination_phrase == None:
+ # coordinating words which do not have associated conjuncts
+ all_coordination_phrases[i] = None
+ return all_coordination_phrases
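+
+
+# Example of how per-depth label sequences are turned into Coordination objects
+# (labels follow DataForConjunction.label_dict: 0=NONE, 1=CP, 2=CP_START, 3=CC, 4=SEP, 5=OTHERS);
+# the sentence and labels below are illustrative only:
+#
+#   words = ["John", "and", "Mary", "left"]
+#   preds = [[2, 3, 1, 0]]   # "John" starts a phrase, "and" is the coordinator, "Mary" conjoins
+#   get_coords(preds)
+#   # -> {1: Coordination(cc=1, conjuncts=((0, 0), (2, 2)), seps=(), label=0)}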
diff --git a/simplify/simplifiers/ie/model.py b/simplify/simplifiers/ie/model.py
new file mode 100644
index 0000000..a71f7d8
--- /dev/null
+++ b/simplify/simplifiers/ie/model.py
@@ -0,0 +1,173 @@
+import json
+from typing import List
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModel
+
+from data import dot_dict
+from utils import get_weight, read_config_file
+
+
+class IEModel(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+    def forward(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class IEConfig:
+    def __init__(self,
+                 backbone_model: str = "bert-large-cased",
+                 max_depth: int = 3,
+                 num_iterative_layers: int = 2,
+                 dropout_rate: float = 0.0,
+                 label_hidden_size: int = 300,
+                 label_size: int = 100,
+                 num_labels: int = 6,
+                 max_word_length: int = 100,
+                 additional_special_tokens: List[str] = None,
+                 **kwargs):
+
+        self.backbone_model = backbone_model
+        self.max_depth = max_depth
+        self.num_iterative_layers = num_iterative_layers
+        self.dropout_rate = dropout_rate
+        self.label_hidden_size = label_hidden_size
+        self.label_size = label_size
+        self.num_labels = num_labels
+        self.max_word_length = max_word_length
+        self.additional_special_tokens = (
+            additional_special_tokens or ["[unused1]", "[unused2]", "[unused3]"]
+        )
+        self.kwargs = kwargs
+
+ @classmethod
+ def load_config(cls, config_file):
+ config = read_config_file(config_file)
+ return cls(**config)
+
+    def save_config(self, file_path):
+        with open(file_path, 'w') as f:
+            json.dump(vars(self), f)
+
+config = {
+
+ "max_depth": 3,
+ "model_path": "bert-large-cased",
+ "num_iterative_layers": 2,
+ "dropout_rate": 0.0,
+ "label_size":100,
+ "label_hidden_size":300,
+ "num_labels": 6,
+ "additional_special_tokens": [ "[unused1]", "[unused2]", "[unused3]"],
+ "max_word_length": 100,
+}
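+# With these defaults, ModelForConjunction(config) builds the conjunction-splitting model:
+# BackboneModel reads model_path, num_iterative_layers, dropout_rate and the label sizes,
+# while ModelForConjunction fixes the decoding depth to 3.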
+
+class Models:
+ def __init__(self, config):
+ self.config = dot_dict(config)
+ self.config.num_labels = self.config.num_labels or 6
+ self.config.label_size = self.config.label_size or 100
+ self.config.label_hidden_size = self.config.label_hidden_size or 300
+
+ self.config.num_iterative_layers = self.config.num_iterative_layers or 2
+        self.config.dropout_rate = self.config.dropout_rate or 0.0
+
+ def save_pretrained(self, path):
+ torch.save({"config": dict(self.config), "state_dict": self.state_dict()}, path)
+
+ @classmethod
+ def load_pretrained(cls, path, device="cpu"):
+ path_file = get_weight(path)
+ config_and_state_dict = torch.load(path_file, map_location=device)
+ model = cls(config_and_state_dict["config"])
+ model.load_state_dict(config_and_state_dict["state_dict"])
+ return model
+
+ def is_valid_extraction(self, prediction):
+ return any([(1 in p and 2 in p) for p in prediction])
+
+
+class BackboneModel(nn.Module, Models):
+ def __init__(self, config):
+ nn.Module.__init__(self)
+ Models.__init__(self, config)
+
+ model_config = AutoConfig.from_pretrained(self.config.model_path)
+ self.base_model = AutoModel.from_config(model_config)
+ self.hidden_size = self.base_model.config.hidden_size
+
+ if self.config.num_iterative_layers == 0:
+ self.iterative_layers = []
+ else:
+ self.iterative_layers = self.base_model.encoder.layer[
+ -self.config.num_iterative_layers :
+ ]
+ self.base_model.encoder.layer = self.base_model.encoder.layer[
+ : -self.config.num_iterative_layers
+ ]
+        self.dropout = nn.Dropout(self.config.dropout_rate)
+
+ self.label_embedding = nn.Embedding(self.config.label_size, self.hidden_size)
+ self.linear1 = nn.Linear(self.hidden_size, self.config.label_hidden_size)
+ self.linear2 = nn.Linear(self.config.label_hidden_size, self.config.num_labels)
+
+ def forward(self, input_ids, word_begin_index, labels=None):
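+        # Iterative decoding: the split-off top transformer layers are re-applied once per
+        # depth; from depth 1 onwards the argmax labels of the previous depth are embedded
+        # and added to the gathered word states, so later extractions can condition on
+        # earlier ones. Decoding stops early once no example in the batch predicts both
+        # label 1 and label 2 at the current depth (see is_valid_extraction).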
+ hidden_states = self.base_model(input_ids)["last_hidden_state"]
+ prediction_in_depth = []
+ word_scores_in_depth = []
+ for d in range(self.config.depth):
+ for layer in self.iterative_layers:
+ hidden_states = layer(hidden_states)[0]
+ hidden_states = self.dropout(hidden_states)
+ un_word_begin_index = word_begin_index.unsqueeze(2).repeat(
+ 1, 1, self.hidden_size
+ )
+ word_hidden_states = torch.gather(
+ hidden_states, dim=1, index=un_word_begin_index
+ )
+
+ if d != 0:
+ index_words = torch.argmax(word_scores, dim=-1)
+ label_embeddings = self.label_embedding(index_words)
+ word_hidden_states = word_hidden_states + label_embeddings
+
+ word_hidden_states = self.linear1(word_hidden_states)
+ word_scores = self.linear2(word_hidden_states)
+
+ prediction = torch.argmax(word_scores, dim=-1)
+
+ word_scores_in_depth.append(word_scores)
+ prediction_in_depth.append(prediction)
+
+ if not self.is_valid_extraction(prediction):
+ break
+ # word_scores = depth [batch,num_words]
+ word_scores_in_depth = torch.stack(word_scores_in_depth, dim=1)
+
+ scores = [
+ self.calculate_score(word_scores, label)
+ for word_scores, label in zip(word_scores_in_depth, labels)
+ ]
+ output = dict()
+ output["scores"] = torch.stack(scores, dim=0)
+ output["predictions"] = torch.stack(prediction_in_depth, dim=1)
+ return output
+
+ def calculate_score(self, word_scores, labels):
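+        # Per-example confidence: for each depth, take the max log-softmax score of every
+        # word predicted as a non-NONE label (ignoring padded positions), sum and divide by
+        # (count + 1) -- the +1 guards against empty extractions -- then exponentiate back
+        # to a probability-like score, one value per depth.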
+ max_log_probs, predictions = torch.log_softmax(word_scores, dim=-1).max(dim=-1)
+ mask_labels = (labels[0, :] != -100).float()
+ mask_predictions = (predictions != 0).float() * mask_labels
+ log_prob = (max_log_probs * mask_predictions) / (
+ mask_predictions.sum(dim=1) + 1
+ ).unsqueeze(-1)
+ scores = torch.exp(log_prob.sum(dim=1))
+ return scores
+
+
+class ModelForConjunction(nn.Module, Models):
+ def __init__(self, config):
+ nn.Module.__init__(self)
+ Models.__init__(self, config)
+
+ self.config.depth = 3
+ self.model = BackboneModel(self.config)
+
+ def forward(self, input_ids, word_begin_index, labels=None):
+ return self.model(input_ids, word_begin_index, labels)
diff --git a/simplify/simplifiers/ie/pipeline.py b/simplify/simplifiers/ie/pipeline.py
new file mode 100644
index 0000000..5fd85f2
--- /dev/null
+++ b/simplify/simplifiers/ie/pipeline.py
@@ -0,0 +1,30 @@
+import torch
+
+from data import Data
+from model import Models
+
+
+class Pipeline:
+ def __init__(self, model: Models, data: Data):
+ self.model = model.eval()
+ self.data = data
+
+ def run(self, sentences):
+ self.data.batch_encode_sentence(sentences, inplace=True)
+ dataloader = self.data.to_dataloader()
+ output = []
+ with torch.no_grad():
+ for batch in dataloader:
+ model_output = self.model(
+ batch.input_ids, batch.word_begin_index, batch.labels
+ )
+ batch_sentences = [
+ self.data.sentences[i] for i in batch.sentence_index.tolist()
+ ]
+                # model_output holds "predictions" of shape (batch_size, depth, num_words)
+                # and per-depth "scores"; batch_sentences has length batch_size
+ out = self.data.batch_decode_prediction(
+ **model_output, sentences=batch_sentences
+ )
+ output.extend(out)
+ return output
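+
+
+# Minimal usage sketch, assuming a conjunction checkpoint directory or hub id is available
+# as `path` (placeholder, not a published model name):
+#
+#   from data import DataForConjunction
+#   from model import ModelForConjunction
+#
+#   model = ModelForConjunction.load_pretrained(path)
+#   data = DataForConjunction.load_pretrained(path)
+#   outputs = Pipeline(model, data).run(["John and Mary left early."])
+#   # each item holds the original sentence, the split sentences and the conjunct words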
diff --git a/simplify/simplifiers/ie/utils.py b/simplify/simplifiers/ie/utils.py
new file mode 100644
index 0000000..b3f0df0
--- /dev/null
+++ b/simplify/simplifiers/ie/utils.py
@@ -0,0 +1,812 @@
+import itertools
+import json
+import logging
+import os
+import re
+import warnings
+from collections import defaultdict
+from operator import itemgetter
+
+import nltk
+import numpy as np
+from transformers.file_utils import cached_path, hf_bucket_url
+
+SEP = ";;;"
+QUESTION_TRG_INDEX = 3 # index of the predicate within the question
+QUESTION_PP_INDEX = 5
+QUESTION_OBJ2_INDEX = 6
+DATA_CONFIG_NAME = "data-config.json"
+WEIGHT_NAME = "model.pt"
+
+
+def read_config_file(file_path):
+ with open(file_path) as f:
+ return json.loads(f.read())
+
+def get_config(path):
+ return _get_file(path, DATA_CONFIG_NAME)
+
+
+def get_weight(path):
+ return _get_file(path, WEIGHT_NAME)
+
+
+def _get_file(path, file_name):
+ possible_config_file = os.path.join(path, file_name)
+    if os.path.isdir(path) and os.path.isfile(possible_config_file):
+ config_file = possible_config_file
+ elif os.path.isfile(path):
+ config_file = path
+ else:
+ config_file = hf_bucket_url(path, filename=file_name)
+    try:
+        return cached_path(config_file)
+    except Exception as e:
+        raise OSError(f"can't load {path}, due to {e}") from e
+
+
+def process_extraction(extraction, sentence, score):
+ # rel, arg1, arg2, loc, time = [], [], [], [], []
+ rel, arg1, arg2, loc_time, args = [], [], [], [], []
+ tag_mode = "none"
+ rel_case = 0
+ for i, token in enumerate(sentence):
+ if "[unused" in token:
+ if extraction[i].item() == 2:
+                rel_case = int(re.search(r"\[unused(.*)\]", token).group(1))
+ continue
+ if extraction[i] == 1:
+ arg1.append(token)
+ if extraction[i] == 2:
+ rel.append(token)
+ if extraction[i] == 3:
+ arg2.append(token)
+ if extraction[i] == 4:
+ loc_time.append(token)
+ rel = " ".join(rel).strip()
+ if rel_case == 1:
+ rel = "is " + rel
+ elif rel_case == 2:
+ rel = "is " + rel + " of"
+ elif rel_case == 3:
+ rel = "is " + rel + " from"
+ arg1 = " ".join(arg1).strip()
+ arg2 = " ".join(arg2).strip()
+ args = " ".join(args).strip()
+ loc_time = " ".join(loc_time).strip()
+ arg2 = (arg2 + " " + loc_time + " " + args).strip()
+ sentence_str = " ".join(sentence).strip()
+ extraction = Extraction(
+ pred=rel, head_pred_index=None, sent=sentence_str, confidence=score, index=0
+ )
+ extraction.addArg(arg1)
+ extraction.addArg(arg2)
+ return extraction
+
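+# Illustrative decoding, assuming the triplet label ids (1=ARG1, 2=REL, 3=ARG2, 4=LOC/TIME):
+# for tokens ["Obama", "president", "of", "the", "US", "[unused1]", "[unused2]", "[unused3]"]
+# with "Obama" tagged 1, "president" tagged 2, "the US" tagged 3 and "[unused2]" tagged 2,
+# rel_case is 2, so the relation becomes "is president of" and the resulting Extraction is
+# ("Obama"; "is president of"; "the US") with the given confidence score.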
+
+def coords_to_sentences(conj_coords, words):
+
+ for k in list(conj_coords):
+ if conj_coords[k] is None:
+ conj_coords.pop(k)
+
+ for k in list(conj_coords):
+ if words[conj_coords[k].cc] in ["nor", "&"]: # , 'or']:
+ conj_coords.pop(k)
+
+ remove_unbreakable_conjuncts(conj_coords, words)
+
+ conj_words = []
+ for k in list(conj_coords):
+ for conjunct in conj_coords[k].conjuncts:
+ conj_words.append(" ".join(words[conjunct[0] : conjunct[1] + 1]))
+
+ sentence_indices = []
+ for i in range(0, len(words)):
+ sentence_indices.append(i)
+
+ roots, parent_mapping, child_mapping = get_tree(conj_coords)
+
+ q = list(roots)
+
+ sentences = []
+ count = len(q)
+ new_count = 0
+
+ conj_same_level = []
+
+ while len(q) > 0:
+
+ conj = q.pop(0)
+ count -= 1
+ conj_same_level.append(conj)
+
+ for child in child_mapping[conj]:
+ q.append(child)
+ new_count += 1
+
+ if count == 0:
+ get_sentences(sentences, conj_same_level, conj_coords, sentence_indices)
+ count = new_count
+ new_count = 0
+ conj_same_level = []
+
+ word_sentences = [
+ " ".join([words[i] for i in sorted(sentence)]) for sentence in sentences
+ ]
+
+ return word_sentences, conj_words, sentences
+
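+# For example, with words = ["John", "and", "Mary", "left"] and a single Coordination whose
+# cc points at "and" and whose conjuncts are ("John") and ("Mary"), the function returns the
+# simplified sentences ["John left", "Mary left"] together with the conjunct words
+# ["John", "Mary"].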
+
+def get_tree(conj):
+ parent_child_list = []
+
+ child_mapping, parent_mapping = {}, {}
+
+ for key in conj:
+ assert conj[key].cc == key
+ parent_child_list.append([])
+ for k in conj:
+ if conj[k] is not None:
+ if is_parent(conj[key], conj[k]):
+ parent_child_list[-1].append(k)
+
+ child_mapping[key] = parent_child_list[-1]
+
+ parent_child_list.sort(key=list.__len__)
+
+ for i in range(0, len(parent_child_list)):
+ for child in parent_child_list[i]:
+ for j in range(i + 1, len(parent_child_list)):
+ if child in parent_child_list[j]:
+ parent_child_list[j].remove(child)
+
+ for key in conj:
+ for child in child_mapping[key]:
+ parent_mapping[child] = key
+
+ roots = []
+ for key in conj:
+ if key not in parent_mapping:
+ roots.append(key)
+
+ return roots, parent_mapping, child_mapping
+
+
+def is_parent(parent, child):
+ min = child.conjuncts[0][0]
+ max = child.conjuncts[-1][-1]
+
+ for conjunct in parent.conjuncts:
+ if conjunct[0] <= min and conjunct[1] >= max:
+ return True
+ return False
+
+
+def get_sentences(sentences, conj_same_level, conj_coords, sentence_indices):
+ for conj in conj_same_level:
+
+ if len(sentences) == 0:
+
+ for conj_structure in conj_coords[conj].conjuncts:
+ sentence = []
+ for i in range(conj_structure[0], conj_structure[1] + 1):
+ sentence.append(i)
+ sentences.append(sentence)
+
+ min = conj_coords[conj].conjuncts[0][0]
+ max = conj_coords[conj].conjuncts[-1][-1]
+
+ for sentence in sentences:
+ for i in sentence_indices:
+ if i < min or i > max:
+ sentence.append(i)
+
+ else:
+ to_add = []
+ to_remove = []
+
+ for sentence in sentences:
+ if conj_coords[conj].conjuncts[0][0] in sentence:
+ sentence.sort()
+
+ min = conj_coords[conj].conjuncts[0][0]
+ max = conj_coords[conj].conjuncts[-1][-1]
+
+ for conj_structure in conj_coords[conj].conjuncts:
+ new_sentence = []
+ for i in sentence:
+ if (
+ i in range(conj_structure[0], conj_structure[1] + 1)
+ or i < min
+ or i > max
+ ):
+ new_sentence.append(i)
+
+ to_add.append(new_sentence)
+
+ to_remove.append(sentence)
+
+ for sent in to_remove:
+ sentences.remove(sent)
+ sentences.extend(to_add)
+
+
+def remove_unbreakable_conjuncts(conj, words):
+
+ unbreakable_indices = []
+
+ unbreakable_words = [
+ "between",
+ "among",
+ "sum",
+ "total",
+ "addition",
+ "amount",
+ "value",
+ "aggregate",
+ "gross",
+ "mean",
+ "median",
+ "average",
+ "center",
+ "equidistant",
+ "middle",
+ ]
+
+ for i, word in enumerate(words):
+ if word.lower() in unbreakable_words:
+ unbreakable_indices.append(i)
+
+ to_remove = []
+ span_start = 0
+
+ for key in conj:
+ span_end = conj[key].conjuncts[0][0] - 1
+ for i in unbreakable_indices:
+ if span_start <= i <= span_end:
+ to_remove.append(key)
+ span_start = conj[key].conjuncts[-1][-1] + 1
+
+ for k in set(to_remove):
+ conj.pop(k)
+
+
+class Coordination(object):
+ __slots__ = ("cc", "conjuncts", "seps", "label")
+
+ def __init__(self, cc, conjuncts, seps=None, label=None):
+ assert isinstance(conjuncts, (list, tuple)) and len(conjuncts) >= 2
+ assert all(isinstance(conj, tuple) for conj in conjuncts)
+ conjuncts = sorted(conjuncts, key=lambda span: span[0])
+ # NOTE(chantera): The form 'A and B, C' is considered to be coordination. # NOQA
+ # assert cc > conjuncts[-2][1] and cc < conjuncts[-1][0]
+ assert cc > conjuncts[0][1] and cc < conjuncts[-1][0]
+ if seps is not None:
+ # if len(seps) == len(conjuncts) - 2:
+ # for i, sep in enumerate(seps):
+ # assert conjuncts[i][1] < sep and conjuncts[i + 1][0] > sep
+ # else:
+ if len(seps) != len(conjuncts) - 2:
+ warnings.warn(
+ "Coordination does not contain enough separators. "
+ "It may be a wrong coordination: "
+ "cc={}, conjuncts={}, separators={}".format(cc, conjuncts, seps)
+ )
+ else:
+ seps = []
+ self.cc = cc
+ self.conjuncts = tuple(conjuncts)
+ self.seps = tuple(seps)
+ self.label = label
+
+ def get_pair(self, index, check=False):
+ pair = None
+ for i in range(1, len(self.conjuncts)):
+ if self.conjuncts[i][0] > index:
+ pair = (self.conjuncts[i - 1], self.conjuncts[i])
+ assert pair[0][1] < index and pair[1][0] > index
+ break
+ if check and pair is None:
+ raise LookupError("Could not find any pair for index={}".format(index))
+ return pair
+
+ def __repr__(self):
+ return "Coordination(cc={}, conjuncts={}, seps={}, label={})".format(
+ self.cc, self.conjuncts, self.seps, self.label
+ )
+
+ def __eq__(self, other):
+ if not isinstance(other, Coordination):
+ return False
+ return (
+ self.cc == other.cc
+ and len(self.conjuncts) == len(other.conjuncts)
+ and all(
+ conj1 == conj2 for conj1, conj2 in zip(self.conjuncts, other.conjuncts)
+ )
+ )
+
+
+class Extraction:
+ """
+ Stores sentence, single predicate and corresponding arguments.
+ """
+
+ def __init__(
+ self, pred, head_pred_index, sent, confidence, question_dist="", index=-1
+ ):
+ self.pred = pred
+ self.head_pred_index = head_pred_index
+ self.sent = sent
+ self.args = []
+ self.confidence = confidence
+ self.matched = []
+ self.questions = {}
+        self.indsForQuestions = defaultdict(lambda: set())
+ self.is_mwp = False
+ self.question_dist = question_dist
+ self.index = index
+
+ def distArgFromPred(self, arg):
+ assert len(self.pred) == 2
+ dists = []
+ for x in self.pred[1]:
+ for y in arg.indices:
+ dists.append(abs(x - y))
+
+ return min(dists)
+
+ def argsByDistFromPred(self, question):
+ return sorted(
+ self.questions[question], key=lambda arg: self.distArgFromPred(arg)
+ )
+
+ def addArg(self, arg, question=None):
+ self.args.append(arg)
+ if question:
+ self.questions[question] = self.questions.get(question, []) + [
+ Argument(arg)
+ ]
+
+ def noPronounArgs(self):
+ """
+ Returns True iff all of this extraction's arguments are not pronouns.
+ """
+ for (a, _) in self.args:
+ tokenized_arg = nltk.word_tokenize(a)
+ if len(tokenized_arg) == 1:
+ _, pos_tag = nltk.pos_tag(tokenized_arg)[0]
+ if "PRP" in pos_tag:
+ return False
+ return True
+
+ def isContiguous(self):
+ return all([indices for (_, indices) in self.args])
+
+ def toBinary(self):
+ """Try to represent this extraction's arguments as binary
+ If fails, this function will return an empty list."""
+
+ ret = [self.elementToStr(self.pred)]
+
+ if len(self.args) == 2:
+ # we're in luck
+ return ret + [self.elementToStr(arg) for arg in self.args]
+
+ return []
+
+ if not self.isContiguous():
+ # give up on non contiguous arguments (as we need indexes)
+ return []
+
+ # otherwise, try to merge based on indices
+ # TODO: you can explore other methods for doing this
+ binarized = self.binarizeByIndex()
+
+ if binarized:
+ return ret + binarized
+
+ return []
+
+ def elementToStr(self, elem, print_indices=True):
+ """formats an extraction element (pred or arg) as a raw string
+ removes indices and trailing spaces"""
+ if print_indices:
+ return str(elem)
+ if isinstance(elem, str):
+ return elem
+ if isinstance(elem, tuple):
+ ret = elem[0].rstrip().lstrip()
+ else:
+ ret = " ".join(elem.words)
+ assert ret, "empty element? {0}".format(elem)
+ return ret
+
+ def binarizeByIndex(self):
+ extraction = [self.pred] + self.args
+ markPred = [(w, ind, i == 0) for i, (w, ind) in enumerate(extraction)]
+        sortedExtraction = sorted(markPred, key=lambda item: item[1][0])
+ s = " ".join(
+ [
+ "{1} {0} {1}".format(self.elementToStr(elem), SEP)
+ if elem[2]
+ else self.elementToStr(elem)
+ for elem in sortedExtraction
+ ]
+ )
+ binArgs = [a for a in s.split(SEP) if a.rstrip().lstrip()]
+
+ if len(binArgs) == 2:
+ return binArgs
+
+ # failure
+ return []
+
+ def bow(self):
+ return " ".join([self.elementToStr(elem) for elem in [self.pred] + self.args])
+
+ def getSortedArgs(self):
+ """
+ Sort the list of arguments.
+ If a question distribution is provided - use it,
+ otherwise, default to the order of appearance in the sentence.
+ """
+        if self.question_dist:
+            # There's a question distribution - use it
+            return self.sort_args_by_distribution()
+        ls = []
+        for q, args in self.questions.items():
+            if len(args) != 1:
+                logging.debug("Not one argument: {}".format(args))
+                continue
+            arg = args[0]
+            indices = list(self.indsForQuestions[q].union(arg.indices))
+            if not indices:
+                logging.debug("Empty indexes for arg {} -- backing to zero".format(arg))
+                indices = [0]
+            ls.append(((arg, q), indices))
+        return [a for a, _ in sorted(ls, key=lambda item: min(item[1]))]
+
+ def question_prob_for_loc(self, question, loc):
+ """
+ Returns the probability of the given question leading to argument
+ appearing in the given location in the output slot.
+ """
+ gen_question = generalize_question(question)
+ q_dist = self.question_dist[gen_question]
+ logging.debug("distribution of {}: {}".format(gen_question, q_dist))
+
+ return float(q_dist.get(loc, 0)) / sum(q_dist.values())
+
+ def sort_args_by_distribution(self):
+ """
+ Use this instance's question distribution (this func assumes it exists)
+ in determining the positioning of the arguments.
+ Greedy algorithm:
+ 0. Decide on which argument will serve as the ``subject'' (first slot) of this extraction
+ 0.1 Based on the most probable one for this spot
+ (special care is given to select the highly-influential subject position)
+        1. For all other arguments, sort arguments by the prevalence of their questions
+        2. For each argument:
+            2.1 Assign to it the most probable slot still available
+            2.2 If none is available (fallback) - default to putting it in the last location
+ """
+ INF_LOC = 100 # Used as an impractical last argument
+
+ # Store arguments by slot
+ ret = {INF_LOC: []}
+ logging.debug("sorting: {}".format(self.questions))
+
+        # Find the most suitable argument for the subject location
+        logging.debug(
+            "probs for subject: {}".format(
+                [
+                    (q, self.question_prob_for_loc(q, 0))
+                    for (q, _) in self.questions.items()
+                ]
+            )
+        )
+
+        subj_question, subj_args = max(
+            self.questions.items(),
+            key=lambda item: self.question_prob_for_loc(item[0], 0),
+        )
+
+ ret[0] = [(subj_args[0], subj_question)]
+
+ # Find the rest
+        for (question, args) in sorted(
+            [
+                (q, a)
+                for (q, a) in self.questions.items()
+                if (q not in [subj_question])
+            ],
+            key=lambda qa: sum(self.question_dist[generalize_question(qa[0])].values()),
+            reverse=True,
+        ):
+            gen_question = generalize_question(question)
+            arg = args[0]
+            assigned_flag = False
+            for (loc, count) in sorted(
+                self.question_dist[gen_question].items(),
+                key=lambda lc: lc[1],
+                reverse=True,
+            ):
+ if loc not in ret:
+ # Found an empty slot for this item
+ # Place it there and break out
+ ret[loc] = [(arg, question)]
+ assigned_flag = True
+ break
+
+ if not assigned_flag:
+ # Add this argument to the non-assigned (hopefully doesn't happen much)
+ logging.debug(
+ "Couldn't find an open assignment for {}".format(
+ (arg, gen_question)
+ )
+ )
+ ret[INF_LOC].append((arg, question))
+
+ logging.debug("Linearizing arg list: {}".format(ret))
+
+ # Finished iterating - consolidate and return a list of arguments
+ return [
+ arg
+            for (_, arg_ls) in sorted(ret.items(), key=lambda kv: int(kv[0]))
+ for arg in arg_ls
+ ]
+
+ def __str__(self):
+ pred_str = self.elementToStr(self.pred)
+ return "{}\t{}\t{}".format(
+ self.get_base_verb(pred_str),
+ self.compute_global_pred(pred_str, self.questions.keys()),
+ "\t".join(
+ [
+ escape_special_chars(
+ self.augment_arg_with_question(self.elementToStr(arg), question)
+ )
+ for arg, question in self.getSortedArgs()
+ ]
+ ),
+ )
+
+ def get_base_verb(self, surface_pred):
+ """
+ Given the surface pred, return the original annotated verb
+ """
+ # Assumes that at this point the verb is always the last word
+ # in the surface predicate
+ return surface_pred.split(" ")[-1]
+
+ def compute_global_pred(self, surface_pred, questions):
+ """
+        Given the surface pred and all instantiations of questions,
+ make global coherence decisions regarding the final form of the predicate
+ This should hopefully take care of multi word predicates and correct inflections
+ """
+ from operator import itemgetter
+
+ split_surface = surface_pred.split(" ")
+
+ if len(split_surface) > 1:
+ # This predicate has a modal preceding the base verb
+ verb = split_surface[-1]
+ ret = split_surface[:-1] # get all of the elements in the modal
+ else:
+ verb = split_surface[0]
+ ret = []
+
+        split_questions = [question.split(" ") for question in questions]
+
+        preds = [
+            normalize_element(q[QUESTION_TRG_INDEX]) for q in split_questions
+        ]
+        if len(set(preds)) > 1:
+            # This predicate appears in multiple ways, let's stick to the base form
+            ret.append(verb)
+
+        if len(set(preds)) == 1:
+            # Change the predicate to the inflected form
+            # if there's exactly one way in which the predicate is conveyed
+            ret.append(preds[0])
+
+        pps = [
+            normalize_element(q[QUESTION_PP_INDEX]) for q in split_questions
+        ]
+
+        obj2s = [
+            normalize_element(q[QUESTION_OBJ2_INDEX]) for q in split_questions
+        ]
+
+ if len(set(pps)) == 1:
+            # If all questions for the predicate include the same pp attachment -
+ # assume it's a multiword predicate
+ self.is_mwp = (
+ True # Signal to arguments that they shouldn't take the preposition
+ )
+ ret.append(pps[0])
+
+ # Concat all elements in the predicate and return
+ return " ".join(ret).strip()
+
+ def augment_arg_with_question(self, arg, question):
+ """
+ Decide what elements from the question to incorporate in the given
+ corresponding argument
+ """
+ # Parse question
+ wh, aux, sbj, trg, obj1, pp, obj2 = map(
+ normalize_element, question.split(" ")[:-1]
+ ) # Last split is the question mark
+
+ # Place preposition in argument
+        # This is safer when dealing with n-ary arguments, as it directly attaches to the
+        # appropriate argument
+ if (not self.is_mwp) and pp and (not obj2):
+ if not (arg.startswith("{} ".format(pp))):
+                # Avoid repeating the preposition in cases where both question and answer contain it
+ return " ".join([pp, arg])
+
+ # Normal cases
+ return arg
+
+ def clusterScore(self, cluster):
+ """
+ Calculate cluster density score as the mean distance of the maximum distance of each slot.
+ Lower score represents a denser cluster.
+ """
+ logging.debug("*-*-*- Cluster: {}".format(cluster))
+
+ # Find global centroid
+ arr = np.array([x for ls in cluster for x in ls])
+ centroid = np.sum(arr) / arr.shape[0]
+ logging.debug("Centroid: {}".format(centroid))
+
+        # Calculate mean over all maximum points
+ return np.average([max([abs(x - centroid) for x in ls]) for ls in cluster])
+
+ def resolveAmbiguity(self):
+ """
+        Heuristic to map the elements (arguments and predicate) of this extraction
+ back to the indices of the sentence.
+ """
+ ## TODO: This removes arguments for which there was no consecutive span found
+ ## Part of these are non-consecutive arguments,
+ ## but other could be a bug in recognizing some punctuation marks
+
+ elements = [self.pred] + [(s, indices) for (s, indices) in self.args if indices]
+ logging.debug("Resolving ambiguity in: {}".format(elements))
+
+ # Collect all possible combinations of arguments and predicate indices
+ # (hopefully it's not too much)
+ all_combinations = list(itertools.product(*map(itemgetter(1), elements)))
+ logging.debug("Number of combinations: {}".format(len(all_combinations)))
+
+ # Choose the ones with best clustering and unfold them
+        resolved_elements = list(
+            zip(
+                map(itemgetter(0), elements),
+                min(all_combinations, key=lambda cluster: self.clusterScore(cluster)),
+            )
+        )
+ logging.debug("Resolved elements = {}".format(resolved_elements))
+
+ self.pred = resolved_elements[0]
+ self.args = resolved_elements[1:]
+
+    def conll(self, external_feats=()):
+ """
+ Return a CoNLL string representation of this extraction
+ """
+ return (
+ "\n".join(
+ [
+ "\t".join(
+ map(
+ str,
+ [i, w]
+ + list(self.pred)
+ + [self.head_pred_index]
+                        + list(external_feats)
+ + [self.get_label(i)],
+ )
+ )
+ for (i, w) in enumerate(self.sent.split(" "))
+ ]
+ )
+ + "\n"
+ )
+
+ def get_label(self, index):
+ """
+ Given an index of a word in the sentence -- returns the appropriate BIO conll label
+ Assumes that ambiguation was already resolved.
+ """
+ # Get the element(s) in which this index appears
+ ent = [
+ (elem_ind, elem)
+ for (elem_ind, elem) in enumerate(
+ map(itemgetter(1), [self.pred] + self.args)
+ )
+ if index in elem
+ ]
+
+ if not ent:
+            # index doesn't appear in any element
+ return "O"
+
+ if len(ent) > 1:
+ # The same word appears in two different answers
+ # In this case we choose the first one as label
+            logging.warning(
+                "Index {} appears in more than one element: {}".format(
+                    index, "\t".join(map(str, [ent, self.sent, self.pred, self.args]))
+ )
+ )
+
+ ## Some indices appear in more than one argument (ones where the above message appears)
+        ## From empirical observation, these seem to mostly consist of different levels of granularity:
+ ## what had _ been taken _ _ _ ? loan commitments topping $ 3 billion
+ ## how much had _ been taken _ _ _ ? topping $ 3 billion
+ ## In these cases we heuristically choose the shorter answer span, hopefully creating minimal spans
+        ## E.g., in this example two arguments are created: (loan commitments, topping $ 3 billion)
+
+        elem_ind, elem = min(ent, key=lambda item: len(item[1]))
+
+ # Distinguish between predicate and arguments
+ prefix = "P" if elem_ind == 0 else "A{}".format(elem_ind - 1)
+
+ # Distinguish between Beginning and Inside labels
+ suffix = "B" if index == elem[0] else "I"
+
+ return "{}-{}".format(prefix, suffix)
+
+ def __str__(self):
+ return "{0}\t{1}".format(
+ self.elementToStr(self.pred, print_indices=True),
+ "\t".join([self.elementToStr(arg) for arg in self.args]),
+ )
+
+
+class Argument:
+ def __init__(self, arg):
+ self.words = [x for x in arg[0].strip().split(" ") if x]
+        self.posTags = list(map(itemgetter(1), nltk.pos_tag(self.words)))
+ self.indices = arg[1]
+ self.feats = {}
+
+ def __str__(self):
+ return "({})".format(
+ "\t".join(
+ map(
+ str,
+ [(" ".join(self.words)).replace("\t", "\\t"), str(self.indices)],
+ )
+ )
+ )
+
+
+def normalize_element(elem):
+ """
+ Return a surface form of the given question element.
+ the output should be properly able to precede a predicate (or blank otherwise)
+ """
+ return elem.replace("_", " ") if (elem != "_") else ""
+
+
+## Helper functions
+def escape_special_chars(s):
+ return s.replace("\t", "\\t")
+
+
+def generalize_question(question):
+ """
+ Given a question in the context of the sentence and the predicate index within
+ the question - return a generalized version which extracts only order-imposing features
+ """
+ import nltk # Using nltk since couldn't get spaCy to agree on the tokenization
+
+ wh, aux, sbj, trg, obj1, pp, obj2 = question.split(" ")[
+ :-1
+ ] # Last split is the question mark
+ return " ".join([wh, sbj, obj1])
diff --git a/simplify/simplifiers/models/__init__.py b/simplify/simplifiers/models/__init__.py
deleted file mode 100644
index 551b6fc..0000000
--- a/simplify/simplifiers/models/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .muss import Muss
-from .discourse import Discourse
-
-
-__all_ = ['Discourse', 'Muss']
diff --git a/simplify/simplifiers/models/base.py b/simplify/simplifiers/models/base.py
deleted file mode 100644
index b211cd4..0000000
--- a/simplify/simplifiers/models/base.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from abc import ABCMeta
-
-
-class BaseModel(metaclass=ABCMeta):
- @abstractmethod
- def __call__(
- self,
- ):
- NotImplemented
diff --git a/simplify/simplifiers/models/discourse/__init__.py b/simplify/simplifiers/models/discourse/__init__.py
deleted file mode 100644
index 3bcb386..0000000
--- a/simplify/simplifiers/models/discourse/__init__.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""
-Lambda-3/DiscourseSimplification is licensed under the
-
-GNU General Public License v3.0
-Permissions of this strong copyleft license are conditioned on making available complete source code of licensed works and modifications, which include larger works using a licensed work, under the same license. Copyright and license notices must be preserved. Contributors provide an express grant of patent rights.
-"""
-
-
-from multiprocessing import Pool
-import json
-
-__all__ = ["Discourse"]
-jarpath = "https://github.com/sadakmed/simplify/raw/master/.jar/discourse.jar"
-
-
-def with_jvm(paths):
- def nested_decorator(func):
- global wrapper
-
- def wrapper(sentences):
- import jpype
- import jpype.imports
-
- jpype.startJVM("-ea", classpath=paths)
- from org.slf4j.Logger import ROOT_LOGGER_NAME
- from org.lambda3.text.simplification.discourse.processing import (
- DiscourseSimplifier,
- ProcessingType,
- )
-
- logging = jpype.java.util.logging
- off = logging.Level.OFF
- logging.Logger.getLogger(ROOT_LOGGER_NAME).setLevel(off)
-
- modules = {
- "jpype": jpype,
- "DiscourseSimplifier": DiscourseSimplifier,
- "ProcessingType": ProcessingType,
- }
- func.__globals__.update(modules)
-
- simple_sentences = func(sentences)
-
- jpype.shutdownJVM()
- return simple_sentences
-
- return wrapper
-
- return nested_decorator
-
-
-@with_jvm(jarpath)
-def discourse(sentences: list):
- jlist_sentences = jpype.java.util.ArrayList(sentences)
- dis = DiscourseSimplifier()
- j_simple_sentences = dis.doDiscourseSimplification(
- jlist_sentences, ProcessingType.SEPARATE
- )
- p_simple_sentences = str(j_simple_sentences.serializeToJSON().toString())
- simple_sentences = json.loads(p_simple_sentences)
- return simple_sentences
-
-
-def discourse_old(sentences: list, paths: list):
- import jpype as jp
- import jpype.imports
-
- jp.startJVM("-ea", classpath=paths)
- from org.slf4j.Logger import ROOT_LOGGER_NAME
-
- print(ROOT_LOGGER_NAME)
- jp.java.util.logging.Logger.getLogger(ROOT_LOGGER_NAME).setLevel(
- jp.java.util.logging.Level.OFF
- )
-
- from org.lambda3.text.simplification.discourse.processing import (
- DiscourseSimplifier,
- ProcessingType,
- )
- from org.lambda3.text.simplification.discourse.model import SimplificationContent
-
- jlist_sentences = jpype.java.util.ArrayList(sentences)
- dis = DiscourseSimplifier()
- j_simple_content = dis.doDiscourseSimplification(
- jlist_sentences, ProcessingType.SEPARATE
- )
- p_simple_content = str(j_simple_content.serializeToJSON().toString())
- jp.shutdownJVM()
-
- out_dict = json.loads(p_simple_content)
- return out_dict
-
-
-class Discourse:
- def __init__(self):
- pass
-
- def __call__(self, sentences):
- with Pool(1) as p:
- # output = p.map(partial(discourse, paths=jarpath), [sentences])
- output = p.map(discourse, [sentences])
- return output
diff --git a/simplify/simplifiers/models/muss/__init__.py b/simplify/simplifiers/models/muss/__init__.py
deleted file mode 100644
index 2583f91..0000000
--- a/simplify/simplifiers/models/muss/__init__.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import shutil
-
-from .preprocessors import get_preprocessors
-from .utils import write_lines, read_lines, get_temp_filepath, download_and_extract
-from .simplifiers import get_fairseq_simplifier, get_preprocessed_simplifier
-from simplify import SIMPLIFY_CACHE
-
-__all__ = ["Muss"]
-
-ALLOWED_MODEL_NAMES = [
- "muss_en_wikilarge_mined",
- "muss_en_mined",
-]
-
-
-preprocessors_kwargs = {
- "LengthRatioPreprocessor": {"target_ratio": 0.9, "use_short_name": False},
- "ReplaceOnlyLevenshteinPreprocessor": {
- "target_ratio": 0.65,
- "use_short_name": False,
- },
- "WordRankRatioPreprocessor": {"target_ratio": 0.75, "use_short_name": False},
- "DependencyTreeDepthRatioPreprocessor": {
- "target_ratio": 0.4,
- "use_short_name": False,
- },
- "GPT2BPEPreprocessor": {},
-}
-
-
-def get_model_path(model_name):
- assert model_name in ALLOWED_MODEL_NAMES
- model_path = SIMPLIFY_CACHE / model_name
- if not model_path.exists():
- url = f"https://dl.fbaipublicfiles.com/muss/{model_name}.tar.gz"
- extracted_path = download_and_extract(url)[0]
- shutil.move(extracted_path, model_path)
- return model_path
-
-
-class Muss:
- def __init__(self, model_name: str = "muss_en_wikilarge_mined"):
- self.model_name = model_name
- self.preprocessors = None
- self.simplifier = None
- self._initialize()
-
- def _initialize(self):
- model_path = get_model_path(self.model_name)
- self.preprocessors = get_preprocessors(preprocessors_kwargs)
- simplifier = get_fairseq_simplifier(model_path)
- self.simplifier = get_preprocessed_simplifier(simplifier, self.preprocessors)
-
- def __call__(self, sentences):
- source_path = get_temp_filepath()
- write_lines(sentences, source_path)
- prediction_path = self.simplifier(source_path)
- prediction_sentences = read_lines(prediction_path)
- return prediction_sentences
diff --git a/simplify/simplifiers/models/muss/fairseq_util.py b/simplify/simplifiers/models/muss/fairseq_util.py
deleted file mode 100644
index a60ad3e..0000000
--- a/simplify/simplifiers/models/muss/fairseq_util.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from collections import defaultdict
-from pathlib import Path
-
-import re
-import shutil
-import shlex
-
-
-from fairseq_cli import generate
-
-
-from .utils import (
- log_std_streams,
- yield_lines,
- write_lines,
- mock_cli_args,
- create_temp_dir,
- mute,
- args_dict_to_str,
-)
-
-
-def remove_multiple_whitespaces(text):
- return re.sub(r" +", " ", text)
-
-
-def fairseq_parse_all_hypotheses(out_filepath):
- hypotheses_dict = defaultdict(list)
- for line in yield_lines(out_filepath):
- match = re.match(r"^H-(\d+)\t-?\d+\.\d+\t(.*)$", line)
- if match:
- sample_id, hypothesis = match.groups()
- hypotheses_dict[int(sample_id)].append(hypothesis)
- # Sort in original order
- return [hypotheses_dict[i] for i in range(len(hypotheses_dict))]
-
-
-def _fairseq_generate(
- complex_filepath,
- output_pred_filepath,
- checkpoint_paths,
- complex_dictionary_path,
- simple_dictionary_path,
- beam=5,
- hypothesis_num=1,
- lenpen=1.0,
- diverse_beam_groups=None,
- diverse_beam_strength=0.5,
- sampling=False,
- max_tokens=16384,
- source_lang="complex",
- target_lang="simple",
- **kwargs,
-):
- # exp_dir must contain checkpoints/checkpoint_best.pt, and dict.{complex,simple}.txt
- # First copy input complex file to exp_dir and create dummy simple file
-
- with create_temp_dir() as temp_dir:
- new_complex_filepath = (
- temp_dir / f"tmp.{source_lang}-{target_lang}.{source_lang}"
- )
- dummy_simple_filepath = (
- temp_dir / f"tmp.{source_lang}-{target_lang}.{target_lang}"
- )
- shutil.copy(complex_filepath, new_complex_filepath)
- shutil.copy(complex_filepath, dummy_simple_filepath)
- shutil.copy(complex_dictionary_path, temp_dir / f"dict.{source_lang}.txt")
- shutil.copy(simple_dictionary_path, temp_dir / f"dict.{target_lang}.txt")
- args = f"""
- {temp_dir} --dataset-impl raw --gen-subset tmp --path {':'.join([str(path) for path in checkpoint_paths])}
- --beam {beam} --nbest {hypothesis_num} --lenpen {lenpen}
- --diverse-beam-groups {diverse_beam_groups if diverse_beam_groups is not None else -1} --diverse-beam-strength {diverse_beam_strength}
- --max-tokens {max_tokens}
- --model-overrides "{{'encoder_embed_path': None, 'decoder_embed_path': None}}"
- --skip-invalid-size-inputs-valid-test
- """
- if sampling:
- args += f"--sampling --sampling-topk 10"
- # FIXME: if the kwargs are already present in the args string, they will appear twice but fairseq will take only the last one into account
- args += f" {args_dict_to_str(kwargs)}"
- args = remove_multiple_whitespaces(args.replace("\n", " "))
- out_filepath = temp_dir / "generation.out"
- with mute(mute_stderr=False):
- with log_std_streams(out_filepath):
- # evaluate model in batch mode
- args = shlex.split(args)
- with mock_cli_args(args):
- generate.cli_main()
-
- all_hypotheses = fairseq_parse_all_hypotheses(out_filepath)
- predictions = [hypotheses[hypothesis_num - 1] for hypotheses in all_hypotheses]
- write_lines(predictions, output_pred_filepath)
-
-
-def fairseq_generate(
- complex_filepath,
- output_pred_filepath,
- exp_dir,
- beam=5,
- hypothesis_num=1,
- lenpen=1.0,
- diverse_beam_groups=None,
- diverse_beam_strength=0.5,
- sampling=False,
- max_tokens=8000,
- source_lang="complex",
- target_lang="simple",
- **kwargs,
-):
-
- exp_dir = Path(exp_dir)
- possible_checkpoint_paths = [
- exp_dir / "model.pt",
- exp_dir / "checkpoints/checkpoint_best.pt",
- exp_dir / "checkpoints/checkpoint_last.pt",
- ]
- assert any(
- [path for path in possible_checkpoint_paths if path.exists()]
- ), f"Generation failed, no checkpoint found in {possible_checkpoint_paths}" # noqa: E501
- checkpoint_path = [path for path in possible_checkpoint_paths if path.exists()][0]
- complex_dictionary_path = exp_dir / f"dict.{source_lang}.txt"
- simple_dictionary_path = exp_dir / f"dict.{target_lang}.txt"
- _fairseq_generate(
- complex_filepath,
- output_pred_filepath,
- [checkpoint_path],
- complex_dictionary_path=complex_dictionary_path,
- simple_dictionary_path=simple_dictionary_path,
- beam=beam,
- hypothesis_num=hypothesis_num,
- lenpen=lenpen,
- diverse_beam_groups=diverse_beam_groups,
- diverse_beam_strength=diverse_beam_strength,
- sampling=sampling,
- max_tokens=max_tokens,
- **kwargs,
- )
diff --git a/simplify/simplifiers/models/muss/preprocessors.py b/simplify/simplifiers/models/muss/preprocessors.py
deleted file mode 100644
index f0e728f..0000000
--- a/simplify/simplifiers/models/muss/preprocessors.py
+++ /dev/null
@@ -1,601 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from abc import ABC
-from functools import wraps, lru_cache
-import hashlib
-from pathlib import Path
-
-# import dill as pickle
-import shutil
-
-import re
-
-import numpy as np
-
-from fairseq.data.encoders.gpt2_bpe_utils import get_encoder
-
-from .utils import (
- write_lines_in_parallel,
- yield_lines_in_parallel,
- add_dicts,
- get_default_args,
- get_temp_filepath,
- failsafe_division,
- download,
- download_and_extract,
- yield_lines,
-)
-
-from simplify import SIMPLIFY_CACHE
-from simplify.evaluators import lev_ratio
-
-FATTEXT_EMBEDDINGS_DIR = SIMPLIFY_CACHE / "fasttext-vectors/"
-SPECIAL_TOKEN_REGEX = r"<[a-zA-Z\-_\d\.]+>"
-
-
-def get_fasttext_embeddings_path(language="en"):
- fasttext_embeddings_path = FASTTEXT_EMBEDDINGS_DIR / f"cc.{language}.300.vec"
- if not fasttext_embeddings_path.exists():
- url = f"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.{language}.300.vec.gz"
- fasttext_embeddings_path.parent.mkdir(parents=True, exist_ok=True)
- shutil.move(download_and_extract(url)[0], fasttext_embeddings_path)
- return fasttext_embeddings_path
-
-
-@lru_cache(maxsize=10)
-def get_spacy_model(language="en", size="md"):
- # Inline lazy import because importing spacy is slow
- import spacy
-
- if language == "it" and size == "md":
- print(
- "Model it_core_news_md is not available for italian, falling back to it_core_news_sm"
- )
- size = "sm"
- model_name = {
- "en": f"en_core_web_{size}",
- "fr": f"fr_core_news_{size}",
- "es": f"es_core_news_{size}",
- "it": f"it_core_news_{size}",
- "de": f"de_core_news_{size}",
- }[language]
- return spacy.load(model_name) # python -m spacy download en_core_web_sm
-
-
-@lru_cache(maxsize=10 ** 6)
-def spacy_process(text, language="en", size="md"):
- return get_spacy_model(language=language, size=size)(str(text))
-
-
-@lru_cache(maxsize=1)
-def get_spacy_tokenizer(language="en"):
- return get_spacy_model(language=language).Defaults.create_tokenizer(
- get_spacy_model(language=language)
- )
-
-
-def get_spacy_content_tokens(text, language="en"):
- def is_content_token(token):
- return (
- not token.is_stop and not token.is_punct and token.ent_type_ == ""
- ) # Not named entity
-
- return [
- token
- for token in get_spacy_tokenizer(language=language)(text)
- if is_content_token(token)
- ]
-
-
-def get_content_words(text, language="en"):
- return [token.text for token in get_spacy_content_tokens(text, language=language)]
-
-
-@lru_cache(maxsize=10)
-def get_word2rank(vocab_size=10 ** 5, language="en"):
- word2rank = {}
- line_generator = yield_lines(get_fasttext_embeddings_path(language))
- next(line_generator) # Skip the first line (header)
- for i, line in enumerate(line_generator):
- if (i + 1) > vocab_size:
- break
- word = line.split(" ")[0]
- word2rank[word] = i
- return word2rank
-
-
-def get_rank(word, language="en"):
- return get_word2rank(language=language).get(
- word, len(get_word2rank(language=language))
- )
-
-
-def get_log_rank(word, language="en"):
- return np.log(1 + get_rank(word, language=language))
-
-
-def get_log_ranks(text, language="en"):
- return [
- get_log_rank(word, language=language)
- for word in get_content_words(text, language=language)
- if word in get_word2rank(language=language)
- ]
-
-
-# Single sentence feature extractors with signature function(sentence) -> float
-def get_lexical_complexity_score(sentence, language="en"):
- log_ranks = get_log_ranks(sentence, language=language)
- if len(log_ranks) == 0:
- log_ranks = [
- np.log(1 + len(get_word2rank()))
- ] # TODO: This is completely arbitrary
- return np.quantile(log_ranks, 0.75)
-
-
-def get_levenshtein_similarity(complex_sentence, simple_sentence):
- return lev_ratio(complex_sentence, simple_sentence)
-
-
-def get_levenshtein_distance(complex_sentence, simple_sentence):
- # We should rename this to get_levenshtein_distance_ratio for more clarity
- return 1 - get_levenshtein_similarity(complex_sentence, simple_sentence)
-
-
-def get_replace_only_levenshtein_distance(complex_sentence, simple_sentence):
- return len(
- [
- _
- for operation, _, _ in Levenshtein.editops(
- complex_sentence, simple_sentence
- )
- if operation == "replace"
- ]
- )
-
-
-def get_replace_only_levenshtein_distance_ratio(complex_sentence, simple_sentence):
- max_replace_only_distance = min(len(complex_sentence), len(simple_sentence))
- return failsafe_division(
- get_replace_only_levenshtein_distance(complex_sentence, simple_sentence),
- max_replace_only_distance,
- default=0,
- )
-
-
-def get_replace_only_levenshtein_similarity(complex_sentence, simple_sentence):
- return 1 - get_replace_only_levenshtein_distance_ratio(
- complex_sentence, simple_sentence
- )
-
-
-def get_dependency_tree_depth(sentence, language="en"):
- def get_subtree_depth(node):
- if len(list(node.children)) == 0:
- return 0
- return 1 + max([get_subtree_depth(child) for child in node.children])
-
- tree_depths = [
- get_subtree_depth(spacy_sentence.root)
- for spacy_sentence in spacy_process(sentence, language=language).sents
- ]
- if len(tree_depths) == 0:
- return 0
- return max(tree_depths)
-
-
-PREPROCESSORS_REGISTRY = {}
-
-
-def get_preprocessor_by_name(preprocessor_name):
- return PREPROCESSORS_REGISTRY[preprocessor_name]
-
-
-def get_preprocessors(preprocessors_kwargs):
- preprocessors = []
- for preprocessor_name, kwargs in preprocessors_kwargs.items():
- preprocessors.append(get_preprocessor_by_name(preprocessor_name)(**kwargs))
- return preprocessors
-
-
-def extract_special_tokens(sentence):
- """Remove any number of token at the beginning of the sentence"""
- match = re.match(fr"(^(?:{SPECIAL_TOKEN_REGEX} *)+) *(.*)$", sentence)
- if match is None:
- return "", sentence
- special_tokens, sentence = match.groups()
- return special_tokens.strip(), sentence
-
-
-def remove_special_tokens(sentence):
- return extract_special_tokens(sentence)[1]
-
-
-def store_args(constructor):
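- # Decorator that records the wrapped constructor's args and kwargs (merged with its defaults) on the instance, so that __repr__ can reproduce how the preprocessor was built.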
- @wraps(constructor)
- def wrapped(self, *args, **kwargs):
- if not hasattr(self, "args") or not hasattr(self, "kwargs"):
- # TODO: Default args are not overwritten if provided as args
- self.args = args
- self.kwargs = add_dicts(get_default_args(constructor), kwargs)
- return constructor(self, *args, **kwargs)
-
- return wrapped
-
-
-# def dump_preprocessors(preprocessors, dir_path):
-# with open(Path(dir_path) / 'preprocessors.pickle', 'wb') as f:
-# pickle.dump(preprocessors, f)
-
-
-# def load_preprocessors(dir_path):
-# path = Path(dir_path) / 'preprocessors.pickle'
-# if not path.exists():
-# return None
-# with open(path, 'rb') as f:
-# return pickle.load(f)
-
-
-class AbstractPreprocessor(ABC):
- def __init_subclass__(cls, **kwargs):
- """Register all children in registry"""
- super().__init_subclass__(**kwargs)
- PREPROCESSORS_REGISTRY[cls.__name__] = cls
-
- def __repr__(self):
- args = getattr(self, "args", ())
- kwargs = getattr(self, "kwargs", {})
- args_repr = [repr(arg) for arg in args]
- kwargs_repr = [
- f"{k}={repr(v)}" for k, v in sorted(kwargs.items(), key=lambda kv: kv[0])
- ]
- args_kwargs_str = ", ".join(args_repr + kwargs_repr)
- return f"{self.__class__.__name__}({args_kwargs_str})"
-
- def get_hash_string(self):
- return self.__class__.__name__
-
- def get_hash(self):
- return hashlib.md5(self.get_hash_string().encode()).hexdigest()
-
- @staticmethod
- def get_nevergrad_variables():
- return {}
-
- @property
- def prefix(self):
- return self.__class__.__name__.replace("Preprocessor", "")
-
- def fit(self, complex_filepath, simple_filepath):
- pass
-
- def encode_sentence(self, sentence, encoder_sentence=None):
- raise NotImplementedError
-
- def decode_sentence(self, sentence, encoder_sentence=None):
- raise NotImplementedError
-
- def encode_sentence_pair(self, complex_sentence, simple_sentence):
- if complex_sentence is not None:
- complex_sentence = self.encode_sentence(complex_sentence)
- if simple_sentence is not None:
- simple_sentence = self.encode_sentence(simple_sentence)
- return complex_sentence, simple_sentence
-
- def encode_file(self, input_filepath, output_filepath, encoder_filepath=None):
- if encoder_filepath is None:
- # We will use an empty temporary file which will yield None for each line
- encoder_filepath = get_temp_filepath(create=True)
- with open(output_filepath, "w", encoding="utf-8") as f:
- for input_line, encoder_line in yield_lines_in_parallel(
- [input_filepath, encoder_filepath], strict=False
- ):
- f.write(self.encode_sentence(input_line, encoder_line) + "\n")
-
- def decode_file(self, input_filepath, output_filepath, encoder_filepath=None):
- if encoder_filepath is None:
- # We will use an empty temporary file which will yield None for each line
- encoder_filepath = get_temp_filepath(create=True)
- with open(output_filepath, "w", encoding="utf-8") as f:
- for encoder_sentence, input_sentence in yield_lines_in_parallel(
- [encoder_filepath, input_filepath], strict=False
- ):
- decoded_sentence = self.decode_sentence(
- input_sentence, encoder_sentence=encoder_sentence
- )
- f.write(decoded_sentence + "\n")
-
- def encode_file_pair(
- self,
- complex_filepath,
- simple_filepath,
- output_complex_filepath,
- output_simple_filepath,
- ):
- """Jointly encode a complex file and a simple file (can be aligned or not)"""
- with write_lines_in_parallel(
- [output_complex_filepath, output_simple_filepath], strict=False
- ) as output_files:
- for complex_line, simple_line in yield_lines_in_parallel(
- [complex_filepath, simple_filepath], strict=False
- ):
- output_files.write(self.encode_sentence_pair(complex_line, simple_line))
-
-
-class ComposedPreprocessor(AbstractPreprocessor):
- @store_args
- def __init__(self, preprocessors, sort=False):
- if preprocessors is None:
- preprocessors = []
- if sort:
- # Make sure preprocessors are always in the same order
- preprocessors = sorted(
- preprocessors, key=lambda preprocessor: preprocessor.__class__.__name__
- )
- self.preprocessors = preprocessors
-
- def get_hash_string(self):
- preprocessors_hash_strings = [
- preprocessor.get_hash_string() for preprocessor in self.preprocessors
- ]
- return f"ComposedPreprocessor(preprocessors={preprocessors_hash_strings})"
-
- def get_suffix(self):
- return "_".join([p.prefix.lower() for p in self.preprocessors])
-
- def fit(self, complex_filepath, simple_filepath):
- for preprocessor in self.preprocessors:
- preprocessor.fit(complex_filepath, simple_filepath)
-
- def encode_sentence(self, sentence, encoder_sentence=None):
- for preprocessor in self.preprocessors:
- sentence = preprocessor.encode_sentence(sentence, encoder_sentence)
- return sentence
-
- def decode_sentence(self, sentence, encoder_sentence=None):
- for preprocessor in self.preprocessors:
- sentence = preprocessor.decode_sentence(sentence, encoder_sentence)
- return sentence
-
- def encode_file(self, input_filepath, output_filepath, encoder_filepath=None):
- for preprocessor in self.preprocessors:
- intermediary_output_filepath = get_temp_filepath()
- preprocessor.encode_file(
- input_filepath, intermediary_output_filepath, encoder_filepath
- )
- input_filepath = intermediary_output_filepath
- shutil.copyfile(input_filepath, output_filepath)
-
- def decode_file(self, input_filepath, output_filepath, encoder_filepath=None):
- for preprocessor in self.preprocessors:
- intermediary_output_filepath = get_temp_filepath()
- preprocessor.decode_file(
- input_filepath, intermediary_output_filepath, encoder_filepath
- )
- input_filepath = intermediary_output_filepath
- shutil.copyfile(input_filepath, output_filepath)
-
- def encode_file_pair(
- self,
- complex_filepath,
- simple_filepath,
- output_complex_filepath,
- output_simple_filepath,
- ):
- for preprocessor in self.preprocessors:
- intermediary_output_complex_filepath = get_temp_filepath()
- intermediary_output_simple_filepath = get_temp_filepath()
- preprocessor.encode_file_pair(
- complex_filepath,
- simple_filepath,
- intermediary_output_complex_filepath,
- intermediary_output_simple_filepath,
- )
- complex_filepath = intermediary_output_complex_filepath
- simple_filepath = intermediary_output_simple_filepath
- shutil.copyfile(complex_filepath, output_complex_filepath)
- shutil.copyfile(simple_filepath, output_simple_filepath)
-
- def encode_sentence_pair(self, complex_sentence, simple_sentence):
- for preprocessor in self.preprocessors:
- complex_sentence, simple_sentence = preprocessor.encode_sentence_pair(
- complex_sentence, simple_sentence
- )
- return complex_sentence, simple_sentence
-
-
-class FeaturePreprocessor(AbstractPreprocessor):
- """Prepend a computed feature at the beginning of the sentence"""
-
- @store_args
- def __init__(
- self,
- feature_name,
- get_feature_value,
- get_target_feature_value,
- bucket_size=0.05,
- noise_std=0,
- prepend_to_target=False,
- use_short_name=False,
- ):
- self.get_feature_value = get_feature_value
- self.get_target_feature_value = get_target_feature_value
- self.bucket_size = bucket_size
- self.noise_std = noise_std
- self.feature_name = feature_name.upper()
- self.use_short_name = use_short_name
- if use_short_name:
- # There might be collisions
- self.feature_name = self.feature_name.lower()[:4]
- self.prepend_to_target = prepend_to_target
-
- def get_hash_string(self):
- return f"{self.__class__.__name__}(feature_name={repr(self.feature_name)}, bucket_size={self.bucket_size}, noise_std={self.noise_std}, prepend_to_target={self.prepend_to_target}, use_short_name={self.use_short_name})" # noqa: E501
-
- def bucketize(self, value):
- """Round value to bucket_size to reduce the number of different values"""
- return round(round(value / self.bucket_size) * self.bucket_size, 10)
-
- def add_noise(self, value):
- return value + np.random.normal(0, self.noise_std)
-
- def get_feature_token(self, feature_value):
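- # e.g. feature_name "LEVENSHTEIN" with feature_value 0.75 gives "<LEVENSHTEIN_0.75>" (illustrative values)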
- return f"<{self.feature_name}_{feature_value}>"
-
- def encode_sentence(self, sentence, encoder_sentence=None):
- if not self.prepend_to_target:
- desired_feature = self.bucketize(
- self.get_target_feature_value(remove_special_tokens(sentence))
- )
- sentence = f"{self.get_feature_token(desired_feature)} {sentence}"
- return sentence
-
- def decode_sentence(self, sentence, encoder_sentence=None):
- if self.prepend_to_target:
- _, sentence = extract_special_tokens(sentence)
- return sentence
-
- def encode_sentence_pair(self, complex_sentence, simple_sentence):
- feature = self.bucketize(
- self.add_noise(
- self.get_feature_value(
- remove_special_tokens(complex_sentence),
- remove_special_tokens(simple_sentence),
- )
- )
- )
- if self.prepend_to_target:
- simple_sentence = f"{self.get_feature_token(feature)} {simple_sentence}"
- else:
- complex_sentence = f"{self.get_feature_token(feature)} {complex_sentence}"
- return complex_sentence, simple_sentence
-
-
-class LevenshteinPreprocessor(FeaturePreprocessor):
- @store_args
- def __init__(self, target_ratio=0.8, bucket_size=0.05, noise_std=0, **kwargs):
- self.target_ratio = target_ratio
- super().__init__(
- self.prefix.upper(),
- self.get_feature_value,
- self.get_target_feature_value,
- bucket_size,
- noise_std,
- **kwargs,
- )
-
- def get_feature_value(self, complex_sentence, simple_sentence):
- return get_levenshtein_similarity(complex_sentence, simple_sentence)
-
- def get_target_feature_value(self, complex_sentence):
- return self.target_ratio
-
-
-class ReplaceOnlyLevenshteinPreprocessor(LevenshteinPreprocessor):
- def get_feature_value(self, complex_sentence, simple_sentence):
- return get_replace_only_levenshtein_similarity(
- complex_sentence, simple_sentence
- )
-
-
-class RatioPreprocessor(FeaturePreprocessor):
- @store_args
- def __init__(
- self,
- feature_extractor,
- target_ratio=0.8,
- bucket_size=0.05,
- noise_std=0,
- **kwargs,
- ):
- self.feature_extractor = feature_extractor
- self.target_ratio = target_ratio
- super().__init__(
- self.prefix.upper(),
- self.get_feature_value,
- self.get_target_feature_value,
- bucket_size,
- noise_std,
- **kwargs,
- )
-
- def get_feature_value(self, complex_sentence, simple_sentence):
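- # Ratio of the simple sentence's feature value to the complex sentence's, capped at 2 (and 0 when the complex value is 0).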
- return min(
- failsafe_division(
- self.feature_extractor(simple_sentence),
- self.feature_extractor(complex_sentence),
- ),
- 2,
- )
-
- def get_target_feature_value(self, complex_sentence):
- return self.target_ratio
-
-
-class LengthRatioPreprocessor(RatioPreprocessor):
- @store_args
- def __init__(self, *args, **kwargs):
- super().__init__(len, *args, **kwargs)
-
-
-class WordRankRatioPreprocessor(RatioPreprocessor):
- @store_args
- def __init__(self, *args, language="en", **kwargs):
- super().__init__(
- lambda sentence: get_lexical_complexity_score(sentence, language=language),
- *args,
- **kwargs,
- )
-
-
-class DependencyTreeDepthRatioPreprocessor(RatioPreprocessor):
- @store_args
- def __init__(self, *args, language="en", **kwargs):
- super().__init__(
- lambda sentence: get_dependency_tree_depth(sentence, language=language),
- *args,
- **kwargs,
- )
-
-
-class GPT2BPEPreprocessor(AbstractPreprocessor):
- def __init__(self):
- self.bpe_dir = SIMPLIFY_CACHE / "bart_bpe"
- self.bpe_dir.mkdir(exist_ok=True, parents=True)
- self.encoder_json_path = self.bpe_dir / "encoder.json"
- self.vocab_bpe_path = self.bpe_dir / "vocab.bpe"
- self.dict_path = self.bpe_dir / "dict.txt"
- download(
- "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json",
- self.encoder_json_path,
- overwrite=False,
- )
- download(
- "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe",
- self.vocab_bpe_path,
- overwrite=False,
- )
- download(
- "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt",
- self.dict_path,
- overwrite=False,
- )
-
- @property
- @lru_cache(maxsize=1)
- def bpe_encoder(self):
- """
- We need to use a property because GPT2BPEPreprocessor() cannot be pickled:
- > pickle.dumps(GPT2BPEPreprocessor())
- ----> TypeError: can't pickle module objects
- """
- return get_encoder(self.encoder_json_path, self.vocab_bpe_path)
-
- def encode_sentence(self, sentence, *args, **kwargs):
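- # Encode the sentence as a string of space-separated GPT-2 BPE token ids.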
- return " ".join([str(idx) for idx in self.bpe_encoder.encode(sentence)])
-
- def decode_sentence(self, sentence, *args, **kwargs):
- return self.bpe_encoder.decode([int(idx) for idx in sentence.split(" ")])
diff --git a/simplify/simplifiers/models/muss/simplifiers.py b/simplify/simplifiers/models/muss/simplifiers.py
deleted file mode 100644
index 08deef5..0000000
--- a/simplify/simplifiers/models/muss/simplifiers.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from functools import wraps
-from pathlib import Path
-import shutil
-
-from imohash import hashfile
-
-from .fairseq_util import fairseq_generate
-from .preprocessors import ComposedPreprocessor
-from .utils import count_lines, get_temp_filepath
-
-
-def memoize_simplifier(simplifier):
- memo = {}
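- # Maps the hash of a complex input file to the prediction file it previously produced, so identical inputs are only simplified once.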
-
- @wraps(simplifier)
- def wrapped(complex_filepath, pred_filepath):
- complex_filehash = hashfile(complex_filepath, hexdigest=True)
- previous_pred_filepath = memo.get(complex_filehash)
- if previous_pred_filepath is not None and Path(previous_pred_filepath).exists():
- assert count_lines(complex_filepath) == count_lines(previous_pred_filepath)
- # Reuse previous prediction
- shutil.copyfile(previous_pred_filepath, pred_filepath)
- else:
- simplifier(complex_filepath, pred_filepath)
- # Save prediction
- memo[complex_filehash] = pred_filepath
-
- return wrapped
-
-
-def make_output_file_optional(simplifier):
- @wraps(simplifier)
- def wrapped(complex_filepath, pred_filepath=None):
- if pred_filepath is None:
- pred_filepath = get_temp_filepath()
- simplifier(complex_filepath, pred_filepath)
- return pred_filepath
-
- return wrapped
-
-
-def get_fairseq_simplifier(exp_dir, **kwargs):
- """Function factory"""
-
- @make_output_file_optional
- @memoize_simplifier
- def fairseq_simplifier(complex_filepath, output_pred_filepath):
- fairseq_generate(complex_filepath, output_pred_filepath, exp_dir, **kwargs)
-
- return fairseq_simplifier
-
-
-def get_preprocessed_simplifier(simplifier, preprocessors):
- composed_preprocessor = ComposedPreprocessor(preprocessors)
-
- @make_output_file_optional
- @memoize_simplifier
- @wraps(simplifier)
- def preprocessed_simplifier(complex_filepath, pred_filepath):
- preprocessed_complex_filepath = get_temp_filepath()
- composed_preprocessor.encode_file(
- complex_filepath, preprocessed_complex_filepath
- )
- preprocessed_pred_filepath = simplifier(preprocessed_complex_filepath)
- composed_preprocessor.decode_file(
- preprocessed_pred_filepath, pred_filepath, encoder_filepath=complex_filepath
- )
-
- preprocessed_simplifier.__name__ = (
- f"{preprocessed_simplifier.__name__}_{composed_preprocessor.get_suffix()}"
- )
- return preprocessed_simplifier
diff --git a/simplify/simplifiers/models/muss/utils.py b/simplify/simplifiers/models/muss/utils.py
deleted file mode 100644
index 55decb5..0000000
--- a/simplify/simplifiers/models/muss/utils.py
+++ /dev/null
@@ -1,385 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from contextlib import contextmanager
-import gzip
-import inspect
-from io import StringIO
-from itertools import zip_longest
-from pathlib import Path
-import shutil
-import sys
-import tempfile
-import time
-from types import MethodType
-
-
-import bz2
-import os
-import tarfile
-from urllib.request import urlretrieve
-import zipfile
-
-from tqdm import tqdm
-
-
-def reporthook(count, block_size, total_size):
- # Download progress bar
- global start_time
- if count == 0:
- start_time = time.time()
- return
- duration = time.time() - start_time
- progress_size_mb = count * block_size / (1024 * 1024)
- speed = progress_size_mb / duration
- percent = int(count * block_size * 100 / total_size)
- msg = f"\r... {percent}% - {int(progress_size_mb)} MB - {speed:.2f} MB/s - {int(duration)}s"
- sys.stdout.write(msg)
-
-
-def download(url, destination_path=None, overwrite=True):
- if destination_path is None:
- destination_path = get_temp_filepath()
- if not overwrite and destination_path.exists():
- return destination_path
- print("Downloading...")
- try:
- urlretrieve(url, destination_path, reporthook)
- sys.stdout.write("\n")
- except (Exception, KeyboardInterrupt, SystemExit):
- print("Rolling back: remove partially downloaded file")
- os.remove(destination_path)
- raise
- return destination_path
-
-
-def download_and_extract(url):
- tmp_dir = Path(tempfile.mkdtemp())
- compressed_filename = url.split("/")[-1]
- compressed_filepath = tmp_dir / compressed_filename
- download(url, compressed_filepath)
- print("Extracting...")
- extracted_paths = extract(compressed_filepath, tmp_dir)
- compressed_filepath.unlink()
- return extracted_paths
-
-
-def extract(filepath, output_dir):
- output_dir = Path(output_dir)
- # Infer extract function based on extension
- extensions_to_functions = {
- ".tar.gz": untar,
- ".tar.bz2": untar,
- ".tgz": untar,
- ".zip": unzip,
- ".gz": ungzip,
- ".bz2": unbz2,
- }
-
- def get_extension(filename, extensions):
- possible_extensions = [ext for ext in extensions if filename.endswith(ext)]
- if len(possible_extensions) == 0:
- raise Exception(f"File {filename} has an unknown extension")
- # Take the longest (.tar.gz should take precedence over .gz)
- return max(possible_extensions, key=lambda ext: len(ext))
-
- filename = os.path.basename(filepath)
- extension = get_extension(filename, list(extensions_to_functions))
- extract_function = extensions_to_functions[extension]
-
- # Extract files in a temporary dir then move the extracted items back to
- # the output dir in order to get the details of what was extracted
- tmp_extract_dir = Path(tempfile.mkdtemp())
- # Extract
- extract_function(filepath, output_dir=tmp_extract_dir)
- extracted_items = os.listdir(tmp_extract_dir)
- output_paths = []
- for name in extracted_items:
- extracted_path = tmp_extract_dir / name
- output_path = output_dir / name
- move_with_overwrite(extracted_path, output_path)
- output_paths.append(output_path)
- return output_paths
-
-
-def move_with_overwrite(source_path, target_path):
- if os.path.isfile(target_path):
- os.remove(target_path)
- if os.path.isdir(target_path) and os.path.isdir(source_path):
- shutil.rmtree(target_path)
- shutil.move(source_path, target_path)
-
-
-def untar(compressed_path, output_dir):
- with tarfile.open(compressed_path) as f:
- f.extractall(output_dir)
-
-
-def unzip(compressed_path, output_dir):
- with zipfile.ZipFile(compressed_path, "r") as f:
- f.extractall(output_dir)
-
-
-def ungzip(compressed_path, output_dir):
- filename = os.path.basename(compressed_path)
- assert filename.endswith(".gz")
- if not os.path.exists(output_dir):
- os.makedirs(output_dir)
- output_path = os.path.join(output_dir, filename[:-3])
- with gzip.open(compressed_path, "rb") as f_in:
- with open(output_path, "wb") as f_out:
- shutil.copyfileobj(f_in, f_out)
-
-
-def unbz2(compressed_path, output_dir):
- extract_filename = os.path.basename(compressed_path).replace(".bz2", "")
- extract_path = os.path.join(output_dir, extract_filename)
- with bz2.BZ2File(compressed_path, "rb") as compressed_file, open(
- extract_path, "wb"
- ) as extract_file:
- for data in tqdm(iter(lambda: compressed_file.read(1024 * 1024), b"")):
- extract_file.write(data)
-
-
-@contextmanager
-def open_files(filepaths, mode="r"):
- files = []
- try:
- files = [Path(filepath).open(mode, encoding="utf-8") for filepath in filepaths]
- yield files
- finally:
- [f.close() for f in files]
-
-
-def yield_lines_in_parallel(filepaths, strip=True, strict=True, n_lines=float("inf")):
- assert type(filepaths) == list
- with open_files(filepaths) as files:
- for i, parallel_lines in enumerate(zip_longest(*files)):
- if i >= n_lines:
- break
- if None in parallel_lines:
- assert (
- not strict
- ), f"Files don't have the same number of lines: {filepaths}, use strict=False"
- if strip:
- parallel_lines = [
- l.rstrip("\n") if l is not None else None for l in parallel_lines
- ]
- yield parallel_lines
-
-
-class FilesWrapper:
- """Write to multiple open files at the same time"""
-
- def __init__(self, files, strict=True):
- self.files = files
- self.strict = strict # Whether to raise an exception when a line is None
-
- def write(self, lines):
- assert len(lines) == len(self.files)
- for line, f in zip(lines, self.files):
- if line is None:
- assert not self.strict
- continue
- f.write(line.rstrip("\n") + "\n")
-
-
-@contextmanager
-def write_lines_in_parallel(filepaths, strict=True):
- with open_files(filepaths, "w") as files:
- yield FilesWrapper(files, strict=strict)
-
-
-def write_lines(lines, filepath=None):
- if filepath is None:
- filepath = get_temp_filepath()
- filepath = Path(filepath)
- filepath.parent.mkdir(parents=True, exist_ok=True)
- with filepath.open("w", encoding="utf-8") as f:
- for line in lines:
- f.write(line + "\n")
- return filepath
-
-
-def yield_lines(filepath, gzipped=False, n_lines=None):
- filepath = Path(filepath)
- open_function = open
- if gzipped or filepath.name.endswith(".gz"):
- open_function = gzip.open
- with open_function(filepath, "rt", encoding="utf-8") as f:
- for i, l in enumerate(f):
- if n_lines is not None and i >= n_lines:
- break
- yield l.rstrip("\n")
-
-
-def read_lines(filepath, gzipped=False):
- return list(yield_lines(filepath, gzipped=gzipped))
-
-
-def count_lines(filepath):
- n_lines = 0
- # We iterate over the generator to avoid loading the whole file in memory
- for _ in yield_lines(filepath):
- n_lines += 1
- return n_lines
-
-
-def arg_name_python_to_cli(arg_name, cli_sep="-"):
- arg_name = arg_name.replace("_", cli_sep)
- return f"--{arg_name}"
-
-
-def kwargs_to_cli_args_list(kwargs, cli_sep="-"):
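- # e.g. {"beam": 5, "sampling": True} -> ["--beam", "5", "--sampling"]; boolean values are passed as bare flags (illustrative kwargs).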
- cli_args_list = []
- for key, val in kwargs.items():
- key = arg_name_python_to_cli(key, cli_sep)
- if isinstance(val, bool):
- cli_args_list.append(str(key))
- else:
- if isinstance(val, str):
- # Add quotes around val
- assert "'" not in val
- val = f"'{val}'"
- cli_args_list.extend([str(key), str(val)])
- return cli_args_list
-
-
-def args_dict_to_str(args_dict, cli_sep="-"):
- return " ".join(kwargs_to_cli_args_list(args_dict, cli_sep=cli_sep))
-
-
-def failsafe_division(a, b, default=0):
- if b == 0:
- return default
- return a / b
-
-
-@contextmanager
-def redirect_streams(source_streams, target_streams):
- # We assign these functions beforehand in case a target stream is also a source stream.
- # If that's the case, the write function would be patched, leading to infinite recursion
- target_writes = [target_stream.write for target_stream in target_streams]
- target_flushes = [target_stream.flush for target_stream in target_streams]
-
- def patched_write(self, message):
- for target_write in target_writes:
- target_write(message)
-
- def patched_flush(self):
- for target_flush in target_flushes:
- target_flush()
-
- original_source_stream_writes = [
- source_stream.write for source_stream in source_streams
- ]
- original_source_stream_flushes = [
- source_stream.flush for source_stream in source_streams
- ]
- try:
- for source_stream in source_streams:
- source_stream.write = MethodType(patched_write, source_stream)
- source_stream.flush = MethodType(patched_flush, source_stream)
- yield
- finally:
- for (
- source_stream,
- original_source_stream_write,
- original_source_stream_flush,
- ) in zip(
- source_streams,
- original_source_stream_writes,
- original_source_stream_flushes,
- ):
- source_stream.write = original_source_stream_write
- source_stream.flush = original_source_stream_flush
-
-
-@contextmanager
-def mute(mute_stdout=True, mute_stderr=True):
- streams = []
- if mute_stdout:
- streams.append(sys.stdout)
- if mute_stderr:
- streams.append(sys.stderr)
- with redirect_streams(source_streams=streams, target_streams=[StringIO()]):
- yield
-
-
-@contextmanager
-def log_std_streams(filepath):
- log_file = open(filepath, "w", encoding="utf-8")
- try:
- with redirect_streams(
- source_streams=[sys.stdout], target_streams=[log_file, sys.stdout]
- ):
- with redirect_streams(
- source_streams=[sys.stderr], target_streams=[log_file, sys.stderr]
- ):
- yield
- finally:
- log_file.close()
-
-
-def add_dicts(*dicts):
- return {k: v for dic in dicts for k, v in dic.items()}
-
-
-def get_default_args(func):
- signature = inspect.signature(func)
- return {
- k: v.default
- for k, v in signature.parameters.items()
- if v.default is not inspect.Parameter.empty
- }
-
-
-TEMP_DIR = None
-
-
-def get_temp_filepath(create=False):
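- # Returns a unique temporary file path; the file exists on disk only when create=True, and it is placed under TEMP_DIR when that global is set.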
- global TEMP_DIR
- temp_filepath = Path(tempfile.mkstemp()[1])
- if TEMP_DIR is not None:
- temp_filepath.unlink()
- temp_filepath = TEMP_DIR / temp_filepath.name
- temp_filepath.touch(exist_ok=False)
- if not create:
- temp_filepath.unlink()
- return temp_filepath
-
-
-def get_temp_dir():
- return Path(tempfile.mkdtemp())
-
-
-@contextmanager
-def create_temp_dir():
- temp_dir = get_temp_dir()
- try:
- yield temp_dir
- finally:
- shutil.rmtree(temp_dir)
-
-
-@contextmanager
-def log_action(action_description):
- start_time = time.time()
- print(f"{action_description}...")
- try:
- yield
- except BaseException as e:
- print(f"{action_description} failed after {time.time() - start_time:.2f}s.")
- raise e
- print(f"{action_description} completed after {time.time() - start_time:.2f}s.")
-
-
-@contextmanager
-def mock_cli_args(args):
- current_args = sys.argv
- sys.argv = sys.argv[:1] + args
- try:
- yield
- finally:
- sys.argv = current_args