-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathload.py
More file actions
47 lines (39 loc) · 1.59 KB
/
load.py
File metadata and controls
47 lines (39 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 8 20:00:59 2020
@author: cm
"""
from classifier_multi_label_seq2seq_attention.hyperparameters import Hyperparamters as hp
def label2onehot(string, label2id=None):
    """Convert a '/'-separated label string into a list of label ids.

    Despite the name, this does not build a one-hot vector. It returns the
    integer id of each label, in order, followed by the id of the end token
    ``'E'``.

    An empty string is replaced by ``'|'`` before splitting, so it maps to
    the ids of ``['|', 'E']``.

    Args:
        string: Label string such as ``'a/b'``. Non-string values are passed
            through ``str()`` before splitting.
        label2id: Optional mapping from label to id. Defaults to
            ``hp.dict_label2id``.

    Returns:
        list[int]: The id of each label, then the id of ``'E'``.

    Raises:
        KeyError: If a label (or ``'E'``) is not present in the mapping.
    """
    if label2id is None:
        label2id = hp.dict_label2id
    string = '|' if string == '' else string
    labels = str(string).split('/') + ['E']
    ids = []
    for label in labels:
        # An explicit membership check replaces dict.get(), which returned
        # None and then failed inside int() without naming the bad label.
        if label not in label2id:
            raise KeyError(f'Unknown label {label!r}: not found in label2id mapping')
        ids.append(int(label2id[label]))
    return ids
def normalization_label(label_ids, pad_id=0):
    """Right-pad every id sequence to the length of the longest one.

    Args:
        label_ids: Iterable of id sequences (lists, tuples, ...).
        pad_id: Value appended to shorter sequences. Defaults to ``0``.

    Returns:
        list[list]: A new list of new lists, all of equal length. No input
        sequence is returned by reference, so mutating the result never
        affects ``label_ids``. An empty input yields ``[]``.
    """
    # Copy each row up front: this accepts tuples and generators, and avoids
    # aliasing the caller's lists in the output.
    rows = [list(row) for row in label_ids]
    if not rows:
        # max() over an empty sequence would raise ValueError.
        return []
    max_length = max(len(row) for row in rows)
    return [row + [pad_id] * (max_length - len(row)) for row in rows]
if __name__ == '__main__':
    # Smoke test: pad a ragged batch of id lists to a common length and
    # print the result.
    demo_ids = [[1, 2, 3], [1, 2], [2, 3, 4, 5, 6]]
    padded = normalization_label(demo_ids)
    print(padded)
    #
    # One-off data-prep script, kept for reference (not executed). It keeps
    # only the labels listed in `ls` for each row of test.csv and writes the
    # result to test0.csv. The paths are author-local and the Chinese strings
    # are dataset label values, so they are left untouched.
    # from classifier_multi_label_seq2seq_attention.utils import load_csv,save_csv
    # f = 'F:/github/classifier_multi_label_seq2seq_attention/data/test.csv'
    # df = load_csv(f)
    # contents = df['content'].tolist()
    # labels = df['label'].tolist()
    # ls = ['产品整体评价','机身颜色','外观设计','重量尺寸','机身材质','外壳做工']
    # labels_new = []
    # for l in labels:
    # l1 = str(l).split('/')
    # l1_new = []
    # for li in l1:
    # if li in ls:
    # l1_new.append(li)
    # labels_new.append(l1_new)
    # #
    # import pandas as pd
    # df = pd.DataFrame(columns=['content','label'])
    # df['content'] = contents
    # df['label'] = ['/'.join(l) for l in labels_new]
    # file_csv = f = 'F:/github/classifier_multi_label_seq2seq_attention/data/test0.csv'
    # save_csv(df,file_csv,encoding='utf-8-sig')