From a5a4353cfc9cfa6871cca87f06c7782907de7fb2 Mon Sep 17 00:00:00 2001
From: Stefano Chiodino
Date: Mon, 28 Dec 2020 20:43:06 +0000
Subject: [PATCH] Update trainers.py

Hi, I wrote a trainer for the NPS corpus which I thought was pretty good.
This can probably be done better if you know what you are doing but I was
wondering if it was worth having in your codebase.

Cheers
---
 chatterbot/trainers.py | 87 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 87 insertions(+)

diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py
index 7411cc838..8ec8d76dd 100644
--- a/chatterbot/trainers.py
+++ b/chatterbot/trainers.py
@@ -2,10 +2,13 @@
 import sys
 import csv
 import time
+import nltk
+import os.path
 from dateutil import parser as date_parser
 from chatterbot.conversation import Statement
 from chatterbot.tagging import PosLemmaTagger
 from chatterbot import utils
+from nltk.corpus import XMLCorpusReader
 
 
 class Trainer(object):
@@ -348,3 +351,87 @@ def chunks(items, items_per_chunk):
             self.chatbot.storage.create_many(statements_from_file)
 
         print('Training took', time.time() - start_time, 'seconds.')
+
+
+class ChatterBotNpsCorpusTrainer(Trainer):
+    """
+    Allows the chat bot to be trained using data from the
+    NPS chat dialog corpus.
+
+    Each XML file of the corpus is treated as one conversation: every
+    post is stored as a response to the post that precedes it in the
+    same file.
+    """
+
+    @staticmethod
+    def get_nps_chat_path() -> str:
+        """
+        Return the directory that holds the NPS chat corpus.
+
+        Every entry of ``nltk.data.path`` is checked in order and the
+        first existing ``corpora/nps_chat`` directory is returned.
+
+        :raises LookupError: if no search path contains the corpus.
+        """
+        for candidate_root_path in nltk.data.path:
+            candidate_path = os.path.join(candidate_root_path, "corpora", "nps_chat")
+            if os.path.exists(candidate_path):
+                return candidate_path
+        # LookupError is still an Exception, so handlers written for the
+        # generic exception raised previously keep working.
+        raise LookupError("Can't find NPS chat path!")
+
+    def train(self, *corpus_paths):
+        """
+        Store the posts of the NPS chat corpus as statements.
+
+        :param corpus_paths: Not used; the corpus location is resolved
+            through ``nltk.data.path``.
+        """
+        nltk.download("nps_chat")
+        nps_chat_path = ChatterBotNpsCorpusTrainer.get_nps_chat_path()
+
+        xml_corpus_reader = XMLCorpusReader(nps_chat_path, r'.*\.xml')
+        # List the corpus files once instead of on every iteration.
+        file_ids = xml_corpus_reader.fileids()
+        file_count = len(file_ids)
+
+        statements_to_create = []
+        for conversation_count, file_id in enumerate(file_ids, start=1):
+            posts = xml_corpus_reader.xml(file_id).findall(".//Post")
+            if self.show_training_progress:
+                utils.print_progress_bar(
+                    'Training ' + str(os.path.basename(file_id)),
+                    conversation_count,
+                    file_count
+                )
+
+            # Each file is its own conversation, so its first post is not
+            # a response to anything.
+            previous_statement_text = None
+            previous_statement_search_text = ''
+
+            for post in posts:
+                message = post.text
+                # Skip posts without text and the JOIN/PART notices.
+                if not message or message in ("PART", "JOIN"):
+                    continue
+
+                statement_search_text = self.chatbot.storage.tagger.get_bigram_pair_string(message)
+
+                statement = Statement(
+                    text=message,
+                    search_text=statement_search_text,
+                    in_response_to=previous_statement_text,
+                    search_in_response_to=previous_statement_search_text,
+                    conversation='training'
+                )
+
+                statement = self.get_preprocessed_statement(statement)
+
+                previous_statement_text = statement.text
+                previous_statement_search_text = statement_search_text
+
+                statements_to_create.append(statement)
+
+        self.chatbot.storage.create_many(statements_to_create)