-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathprocess_chr_embedding.py
More file actions
122 lines (86 loc) · 3.38 KB
/
process_chr_embedding.py
File metadata and controls
122 lines (86 loc) · 3.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- coding: utf-8 -*-
#Author: Jay Yip
#Date 05Mar2017
"""Download and process the Chinese character embedding table"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import urllib.request
import os
import pickle
import numpy as np
import tensorflow as tf
import configuration
# Global command-line flag registry shared by this script; values are read
# after tf.app.run() (see the __main__ guard below) has parsed argv.
FLAGS = tf.app.flags.FLAGS
# Location of the polyglot Chinese character embedding pickle. Despite the
# "_dir" suffix this names a single file, which download_embedding opens.
tf.flags.DEFINE_string("chr_embedding_dir", 'polyglot-zh_char.pkl',
"Path to polyglot embedding file")
# Location of the pickled Vocabulary object (generated by the build-input
# step) that main() loads.
tf.flags.DEFINE_string("vocab_dir", "data/vocab.pkl",
"Path of vocabulary file.")
class Vocabulary(object):
    """Simple bidirectional wrapper around a word -> id dictionary."""

    # Class-level fallbacks so that instances unpickled from files written by
    # older versions of this class (which never set these attributes) still
    # support id_to_word.
    _unk_word = '<UNK>'
    _reverse_vocab = None

    def __init__(self, vocab, unk_id, unk_word='<UNK>'):
        """Initializes the vocabulary.

        Args:
            vocab: A dictionary of word to word_id. It is modified in place:
                `unk_word` is inserted into it with id 0.
            unk_id: Id returned by `word_to_id` for words not in `vocab`.
            unk_word: String used for the unknown word; `id_to_word` also
                returns it for ids that have no word.
        """
        self._vocab = vocab
        self._unk_id = unk_id
        self._unk_word = unk_word
        # Built lazily by id_to_word.
        self._reverse_vocab = None
        # NOTE(review): the unknown word is always stored with id 0, not
        # `unk_id`. process_embedding treats row 0 as the unknown-character
        # row, so this is kept; confirm callers pass unk_id=0 so that
        # word_to_id('<UNK>') and word_to_id(<unseen word>) agree.
        self._vocab[unk_word] = 0

    def word_to_id(self, word):
        """Returns the integer id of `word`, or the unknown id if absent."""
        return self._vocab.get(word, self._unk_id)

    def id_to_word(self, word_id):
        """Returns the word string of an integer word id.

        Ids with no corresponding word (out of range, negative, or gaps in
        the numbering) map to the unknown word.
        """
        reverse = self._reverse_vocab
        if reverse is None:
            # `_vocab` maps word -> id, so it cannot be indexed by id; build
            # the inverse once and cache it. The cache is a snapshot: later
            # changes to `_vocab` are not reflected. If several words share
            # an id, the one inserted into `_vocab` last wins.
            reverse = {idx: word for word, idx in self._vocab.items()}
            self._reverse_vocab = reverse
        return reverse.get(word_id, self._unk_word)
def download_embedding(path=None):
    """Load the polyglot Chinese character embedding pickle from disk.

    Despite the name, nothing is downloaded (automatic download did not
    work); fetch the file by hand from
    https://sites.google.com/site/rmyeid/projects/polyglot.

    Args:
        path: Path of the polyglot pickle. Defaults to
            FLAGS.chr_embedding_dir.

    Returns:
        The unpickled object, documented as a tuple (words, embeddings)
        whose embeddings have shape (100004, 64).

    Raises:
        FileNotFoundError: if `path` does not name an existing file.
    """
    if path is None:
        path = FLAGS.chr_embedding_dir
    # Validate with os.path because the file is read with the built-in
    # open(), which only handles local paths. An explicit raise is used
    # instead of assert, which is stripped under `python -O`.
    if not os.path.isfile(path):
        raise FileNotFoundError(
            "Embedding pkl not found at %s, please download the Chinese "
            "chr embedding from "
            "https://sites.google.com/site/rmyeid/projects/polyglot" % path)
    with open(path, 'rb') as f:
        # The polyglot pickle is presumably a Python 2 pickle; latin1 maps
        # every byte 1:1, so its byte strings (including numpy array
        # buffers) decode without loss.
        return pickle.load(f, encoding='latin1')
def process_embedding(vocab, original_embedding, config):
    """Reorder the polyglot embedding table to follow the vocabulary ids.

    Row `i` of the result holds the polyglot vector of the vocabulary word
    with id `i`. Polyglot words that are not in the vocabulary are discarded,
    and vocabulary words without a polyglot vector keep an all-zero row.

    Args:
        vocab: Vocabulary obj generated by build input; its `_vocab`
            word -> id dict is read directly.
        original_embedding: A tuple (words, embeddings), e.g. as returned by
            download_embedding; embeddings[i] is the vector of words[i].
        config: Model configuration providing `embedding_size`.

    Returns:
        embedding_table: A float64 numpy array of shape
            (len(vocab._vocab), config.embedding_size). Will be fed to
            embedding_placeholder during graph execution.

    Raises:
        ValueError: if the embedding width differs from
            config.embedding_size.
    """
    words, vectors = original_embedding
    vectors = np.asarray(vectors)
    if vectors.ndim != 2 or vectors.shape[1] != config.embedding_size:
        raise ValueError(
            'Embedding shape %s does not match config.embedding_size=%s'
            % (vectors.shape, config.embedding_size))

    word_ids = vocab._vocab
    embedding_table = np.zeros((len(word_ids), config.embedding_size))

    # Target row -> source polyglot row. Words missing from the vocabulary
    # are skipped instead of going through word_to_id, which would write
    # them all into the unknown-id row. If a word appears more than once,
    # its last vector is kept.
    row_to_src = {word_ids[w]: i for i, w in enumerate(words) if w in word_ids}
    if row_to_src:
        embedding_table[list(row_to_src)] = vectors[list(row_to_src.values())]

    # Row 0 is the unknown-character row (Vocabulary stores its unknown word
    # with id 0); fill it with polyglot's first vector, presumably <UNK>.
    embedding_table[0, :] = vectors[0, :]
    return embedding_table
def main(unused_argv):
    """Write chr_embedding.pkl: the polyglot table reordered to vocab ids.

    Args:
        unused_argv: Remaining command-line arguments (ignored).
    """
    # Load configuration
    model_config = configuration.ModelConfig()
    # Load the vocabulary object. Unpickling can execute arbitrary code, so
    # FLAGS.vocab_dir must point to a trusted, locally generated file.
    with open(FLAGS.vocab_dir, 'rb') as f:
        vocab = pickle.load(f)
    original_embedding = download_embedding()
    chr_embedding = process_embedding(vocab, original_embedding, model_config)
    # Context manager guarantees the output is flushed and closed before the
    # function returns, not whenever the file object is garbage-collected.
    with open('chr_embedding.pkl', 'wb') as f:
        pickle.dump(chr_embedding, f)
if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, naming the entry point
    # explicitly rather than relying on its lookup of __main__.main.
    tf.app.run(main=main)