From 3dc3976c21e8244e5bb3071fa45d46cd669080d0 Mon Sep 17 00:00:00 2001
From: Stanislav
Date: Wed, 27 Feb 2019 02:23:18 +0300
Subject: [PATCH] minor changes for easier finetuning

---
 elmoformanylangs/biLM.py | 94 +++++++++++++++++++++++++++-------------
 1 file changed, 65 insertions(+), 29 deletions(-)

diff --git a/elmoformanylangs/biLM.py b/elmoformanylangs/biLM.py
index ec0ed2c..1fec6f2 100644
--- a/elmoformanylangs/biLM.py
+++ b/elmoformanylangs/biLM.py
@@ -16,6 +16,7 @@ import torch.optim as optim
 from torch.autograd import Variable
 
 from .modules.elmo import ElmobiLm
+from .elmo import Embedder
 from .modules.lstm import LstmbiLm
 from .modules.token_embedder import ConvTokenEmbedder, LstmTokenEmbedder
 from .modules.embedding_layer import EmbeddingLayer
@@ -280,8 +281,14 @@ def save_model(self, path, save_classify_layer):
       torch.save(self.classify_layer.state_dict(), os.path.join(path, 'classifier.pkl'))
 
   def load_model(self, path):
-    self.token_embedder.load_state_dict(torch.load(os.path.join(path, 'token_embedder.pkl')))
-    self.encoder.load_state_dict(torch.load(os.path.join(path, 'encoder.pkl')))
+
+    self.token_embedder.load_state_dict(torch.load(os.path.join(path, 'token_embedder.pkl'),
+                                                   map_location=lambda storage, loc: storage))
+    self.encoder.load_state_dict(torch.load(os.path.join(path, 'encoder.pkl'),
+                                            map_location=lambda storage, loc: storage))
+
+    #self.token_embedder.load_state_dict(torch.load(os.path.join(path, 'token_embedder.pkl')))
+    #self.encoder.load_state_dict(torch.load(os.path.join(path, 'encoder.pkl')))
     self.classify_layer.load_state_dict(torch.load(os.path.join(path, 'classifier.pkl')))
@@ -340,7 +347,7 @@ def train_model(epoch, opt, model, optimizer,
     loss_forward, loss_backward = model.forward(w, c, masks)
 
     loss = (loss_forward + loss_backward) / 2.0
-    total_loss += loss_forward.data[0]
+    total_loss += loss_forward.item()
     n_tags = sum(lens)
     total_tag += n_tags
     loss.backward()
@@ -440,6 +447,12 @@ def train():
   cmd.add_argument('--valid_size', type=int, default=0, help="size of validation dataset when there's no valid.")
   cmd.add_argument('--eval_steps', required=False, type=int, help='report every xx batches.')
+
+
+  cmd.add_argument('--fine_tune', required=False, action="store_true", help='finetune base model')
+  cmd.add_argument('--old_model_folder', required=False, type=str, help='path to base model for finetuning')
+
+
   opt = cmd.parse_args(sys.argv[2:])
@@ -501,29 +514,40 @@ def train():
       len(test_data), sum([len(s) - 1 for s in test_data])))
   else:
     test_data = None
-
-  if opt.word_embedding is not None:
-    embs = load_embedding(opt.word_embedding)
-    word_lexicon = {word: i for i, word in enumerate(embs[0])}
-  else:
-    embs = None
-    word_lexicon = {}
-
+
+  if opt.fine_tune:
+    embedder = Embedder(opt.old_model_folder)
+    word_lexicon = embedder.word_lexicon
+    char_lexicon = embedder.char_lexicon
+    label_to_ix = word_lexicon
+    embs = None
+
   # Maintain the vocabulary. vocabulary is used in either WordEmbeddingInput or softmax classification
   vocab = get_truncated_vocab(train_data, opt.min_count)
 
+  if not opt.fine_tune:
+    if opt.word_embedding is not None:
+      embs = load_embedding(opt.word_embedding)
+      word_lexicon = {word: i for i, word in enumerate(embs[0])}
+    else:
+      embs = None
+      word_lexicon = {}
-  # Ensure index of '<oov>' is 0
-  for special_word in ['<oov>', '<bos>', '<eos>', '<pad>']:
-    if special_word not in word_lexicon:
-      word_lexicon[special_word] = len(word_lexicon)
-  for word, _ in vocab:
-    if word not in word_lexicon:
-      word_lexicon[word] = len(word_lexicon)
+    # Ensure index of '<oov>' is 0
+    for special_word in ['<oov>', '<bos>', '<eos>', '<pad>']:
+      if special_word not in word_lexicon:
+        word_lexicon[special_word] = len(word_lexicon)
+
+    for word, _ in vocab:
+      if word not in word_lexicon:
+        word_lexicon[word] = len(word_lexicon)
+
   # Word Embedding
   if config['token_embedder']['word_dim'] > 0:
     word_emb_layer = EmbeddingLayer(config['token_embedder']['word_dim'], word_lexicon, fix_emb=False, embs=embs)
+    #print(word_emb_layer)
     logging.info('Word embedding size: {0}'.format(len(word_emb_layer.word2id)))
   else:
     word_emb_layer = None
@@ -531,16 +555,19 @@ def train():
 
   # Character Lexicon
   if config['token_embedder']['char_dim'] > 0:
-    char_lexicon = {}
-    for sentence in train_data:
-      for word in sentence:
-        for ch in word:
-          if ch not in char_lexicon:
-            char_lexicon[ch] = len(char_lexicon)
-
-    for special_char in ['<bos>', '<eos>', '<oov>', '<pad>', '<bow>', '<eow>']:
-      if special_char not in char_lexicon:
-        char_lexicon[special_char] = len(char_lexicon)
+
+    if not opt.fine_tune:
+
+      char_lexicon = {}
+      for sentence in train_data:
+        for word in sentence:
+          for ch in word:
+            if ch not in char_lexicon:
+              char_lexicon[ch] = len(char_lexicon)
+
+      for special_char in ['<bos>', '<eos>', '<oov>', '<pad>', '<bow>', '<eow>']:
+        if special_char not in char_lexicon:
+          char_lexicon[special_char] = len(char_lexicon)
 
     char_emb_layer = EmbeddingLayer(config['token_embedder']['char_dim'], char_lexicon, fix_emb=False)
     logging.info('Char embedding size: {0}'.format(len(char_emb_layer.word2id)))
@@ -572,7 +599,15 @@ def train():
 
   nclasses = len(label_to_ix)
 
-  model = Model(config, word_emb_layer, char_emb_layer, nclasses, use_cuda)
+  if not opt.fine_tune:
+    model = Model(config, word_emb_layer, char_emb_layer, nclasses, use_cuda)
+  else:
+    model = Model(embedder.config, word_emb_layer, char_emb_layer, nclasses, use_cuda)
+    model.token_embedder = embedder.model.token_embedder
+    model.encoder = embedder.model.encoder
+
+
   logging.info(str(model))
   if use_cuda:
     model = model.cuda()
@@ -671,6 +706,7 @@ def test():
 
   model = Model(config, word_emb_layer, char_emb_layer, len(word_lexicon), use_cuda)
 
+
   if use_cuda:
     model.cuda()
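
Note (not part of the patch): the sketch below illustrates the fine-tuning flow that the new --fine_tune / --old_model_folder options rely on. The pre-trained model folder is opened through Embedder, and its lexicons, config and trained modules are reused instead of being rebuilt from the fine-tuning corpus, while the classifier layer is still created fresh. The path is a placeholder, and the Embedder attributes shown (word_lexicon, char_lexicon, config, model) are the ones this patch assumes are exposed by the installed elmoformanylangs version.

# Minimal sketch, assuming the Embedder attributes used by the patch are
# available in the installed elmoformanylangs. The path is a placeholder.
from elmoformanylangs import Embedder

embedder = Embedder('/path/to/pretrained/model/folder')

# Lexicons reused verbatim by the --fine_tune branch, so that word and
# character indices keep matching the pre-trained weights
word_lexicon = embedder.word_lexicon
char_lexicon = embedder.char_lexicon

# The fine-tuned Model is built from the old config, then the pre-trained
# token embedder and encoder are plugged in before training continues
config = embedder.config
token_embedder = embedder.model.token_embedder
encoder = embedder.model.encoder

# On the command line, fine-tuning is requested with the two new flags
# (the remaining train arguments are unchanged and elided here):
#   ... train --fine_tune --old_model_folder /path/to/pretrained/model/folder ...

Separately, the map_location=lambda storage, loc: storage argument added to torch.load in load_model remaps GPU-saved storages to CPU, so the checkpoints can also be loaded on a machine without CUDA.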