-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathUtils.py
More file actions
166 lines (137 loc) · 5.1 KB
/
Utils.py
File metadata and controls
166 lines (137 loc) · 5.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# A few functions are adapted from: https://github.com/memray/OpenNMT-kpg-release/tree/master/onmt/keyphrase
import pickle
import re

import hickle
import numpy as np
from nltk.stem.porter import *
# Shared Porter stemmer instance; stem_word_list() stems every token with it.
stemmer = PorterStemmer()
# Placeholder that replace_numbers_to_DIGIT() substitutes for numeric tokens.
DIGIT_token = "<digit>"
def stem_word_list(word_list):
    """
    Strip surrounding whitespace from each word, then stem it with the
    module-level Porter stemmer.

    :param word_list: iterable of word strings
    :return: a new list of stemmed words, in the original order
    """
    return [stemmer.stem(token.strip()) for token in word_list]
def if_present_duplicate_phrases(src_seq, tgt_seqs, stemming=True, lowercase=True):
    """
    For each target phrase, check whether it appears verbatim (as a contiguous
    run of tokens) in the source sequence, and whether it repeats an earlier
    target phrase.

    Source and targets are normalised identically before any comparison:
    optionally lowercased, then optionally stemmed with ``stem_word_list``.
    Duplicate detection works on the normalised form, so e.g. "model" and
    "models" count as the same phrase when stemming is enabled.

    :param src_seq: list of source tokens
    :param tgt_seqs: iterable of target phrases, each a list of tokens
    :param stemming: stem tokens before comparing
    :param lowercase: lowercase tokens before comparing
    :return: tuple of three numpy arrays with one entry per target phrase:
        present_flags   - whether the phrase occurs in the source,
        present_indices - start index of its first occurrence, or -1,
        duplicate_flags - whether the normalised phrase was already seen
    """
    def normalise(tokens):
        # Same pipeline for source and targets so they stay comparable.
        if lowercase:
            tokens = [w.lower() for w in tokens]
        if stemming:
            tokens = stem_word_list(tokens)
        return tokens

    src_seq = normalise(src_seq)

    present_flags = []
    present_indices = []
    duplicate_flags = []
    # Keyed by token tuples rather than '_'.join(tokens): joining can make
    # different phrases collide, e.g. ['a_b'] and ['a', 'b'] both -> 'a_b'.
    seen_phrases = set()
    for tgt_seq in tgt_seqs:
        tgt_seq = normalise(tgt_seq)

        match_flag, match_pos_idx = if_present_phrase(src_seq, tgt_seq)
        present_flags.append(match_flag)
        present_indices.append(match_pos_idx)

        phrase_key = tuple(tgt_seq)
        duplicate_flags.append(phrase_key in seen_phrases)
        seen_phrases.add(phrase_key)

    return (np.asarray(present_flags),
            np.asarray(present_indices),
            np.asarray(duplicate_flags))
def if_present_phrase(src_str_tokens, phrase_str_tokens):
    """
    Locate the first position where a phrase occurs as a contiguous run of
    tokens inside the source.

    :param src_str_tokens: a list of strings (words) of source text
    :param phrase_str_tokens: a list of strings (words) of a phrase
    :return: (True, start_index) for the first match, otherwise (False, -1).
        An empty phrase matches at index 0.
    """
    last_start = len(src_str_tokens) - len(phrase_str_tokens)
    start = 0
    while start <= last_start:
        # A window matches when no token in it differs from the phrase.
        has_mismatch = any(
            src_str_tokens[start + offset] != token
            for offset, token in enumerate(phrase_str_tokens)
        )
        if not has_mismatch:
            return True, start
        start += 1
    return False, -1
def meng17_tokenize(text):
    r'''
    The tokenizer used in Meng et al. ACL 2017, with lowercasing added.

    Steps:
      1. replace \r, \n and \t with spaces;
      2. pad spaces around the punctuation [_<>,\(\)\.\'%] so each becomes
         its own token;
      3. lowercase the text;
      4. split on every character outside [a-zA-Z0-9_<>,#&\+\*\(\)\.\'] and
         drop empty pieces.

    Note: '%' is padded in step 2 but is not in the step-4 keep set, so it is
    discarded. '#', '&', '+' and '*' are kept but not padded, so they stay
    attached to neighbouring characters (e.g. "c++"). Every other character,
    including non-ASCII letters, acts as a separator. Digits are NOT replaced
    here; see replace_numbers_to_DIGIT for that.

    :param text: input string
    :return: a list of tokens
    '''
    # remove line breakers
    text = re.sub(r'[\r\n\t]', ' ', text)
    # pad spaces to the left and right of special punctuations
    # (raw replacement string: '\g<0>' in a plain literal is an invalid escape)
    text = re.sub(r'[_<>,\(\)\.\'%]', r' \g<0> ', text)
    text = text.lower()
    # tokenize by non-letters (new-added + # & *, but don't pad spaces, to make them as one whole word)
    return [w for w in re.split(r'[^a-zA-Z0-9_<>,#&\+\*\(\)\.\']', text) if w]
def replace_numbers_to_DIGIT(tokens, k=2):
    """
    Replace purely numeric tokens that have at least ``k`` digits with
    DIGIT_token.

    :param tokens: list of token strings
    :param k: minimum digit count for a token to be replaced (with the default
        of 2, single-digit tokens such as "7" are kept)
    :return: a new list; non-matching tokens are returned unchanged
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    # Built once per call rather than re-resolved for every token.
    number_pattern = re.compile(r'^\d{%d,}$' % k)
    return [DIGIT_token if number_pattern.match(w) else w for w in tokens]
def match_keyphrase(keyphrase_1, keyphrase_2):
    """
    Tell whether two keyphrases are identical after stemming each token with
    ``stem_word_list``.

    :param keyphrase_1: list of tokens
    :param keyphrase_2: list of tokens
    :return: True if the stemmed token lists are equal, else False
    """
    return stem_word_list(keyphrase_1) == stem_word_list(keyphrase_2)
def remove_duplicate_keyphrases(keyphrases):
    """
    Drop keyphrases that duplicate an earlier one after stemming.

    Two keyphrases count as duplicates when their token-wise stemmed forms are
    equal (the same test ``match_keyphrase`` performs). The first occurrence is
    kept in its original, unstemmed form, and input order is preserved.

    :param keyphrases: list of keyphrases, each a list of tokens
    :return: a new list containing only the first of each group of duplicates
    """
    unique_keyphrases = []
    seen_stems = set()
    for keyphrase in keyphrases:
        # Stem each keyphrase once and test membership in a set, instead of
        # re-stemming every kept keyphrase on every comparison (O(n^2)).
        stem_key = tuple(stem_word_list(keyphrase))
        if stem_key not in seen_stems:
            seen_stems.add(stem_key)
            unique_keyphrases.append(keyphrase)
    return unique_keyphrases
def load_data(location):
    """
    Unpickle and return the object stored in the file at ``location``.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    with open(location, 'rb') as handle:
        return pickle.load(handle)
def save_data(data, location):
    """Pickle ``data`` into the file at ``location``, replacing any existing content."""
    with open(location, mode='wb') as handle:
        pickle.dump(data, handle)
def load_data_hkl(location):
    """
    Load and return the object stored in the hickle (HDF5) file at ``location``.

    The filename is passed straight to ``hickle.load``, which is hickle's
    documented usage (``hickle.load('file.hkl')``). hickle then opens the
    HDF5 file itself, so no Python binary file handle is opened here.
    """
    return hickle.load(location)
def save_data_hkl(data, location):
    """
    Serialize ``data`` with hickle into an HDF5 file at ``location``.

    The filename is passed straight to ``hickle.dump``, as in hickle's
    documented usage, instead of a Python 'wb' file object. ``mode='w'`` keeps
    the original intent of creating or overwriting the file.
    """
    hickle.dump(data, location, mode='w')
# NOTE(review): the triple-quoted string below is a bare module-level
# expression, so none of the code inside it ever runs. It also refers to names
# (`data`, `score`) that are not defined anywhere visible in this module. It
# looks like leftover scratch code and can probably be deleted.
'''data_hkl = []
for d in data:
data_hkl_inside = []
for d_inside in d:
data_hkl_inside.append(d_inside)
data_hkl.append(data_hkl_inside)
with open("file.txt", "w") as f:
for s in score:
f.write(str(s) +"\n")
with open("file.txt", "r") as f:
for line in f:
score.append(int(line.strip()))'''