Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 79 additions & 47 deletions chatterbot/corpus.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,70 @@
import os
import io
import glob
from pathlib import Path
from dataclasses import dataclass
from typing import List, Generator

from chatterbot.exceptions import OptionalDependencyImportError

# Try to import ChatterBot corpus data directory
try:
from chatterbot_corpus.corpus import DATA_DIRECTORY
except (ImportError, ModuleNotFoundError):
# Default to the home directory of the current user
DATA_DIRECTORY = os.path.join(
Path.home(),
'chatterbot_corpus',
'data'
)
# Default to home directory if corpus package not installed
DATA_DIRECTORY = Path.home() / 'chatterbot_corpus' / 'data'

# Only support YAML formats for now
CORPUS_EXTENSIONS = ['yml', 'yaml']

# Simple cache for loaded corpus files
_corpus_cache = {}
Comment thread
annuaicoder marked this conversation as resolved.


CORPUS_EXTENSION = 'yml'
@dataclass
class CorpusData:
conversations: List[List[str]]
categories: List[str]
file_path: str


def get_file_path(dotted_path, extension='json') -> str:
def get_file_path(dotted_path: str, extensions: List[str] = CORPUS_EXTENSIONS) -> Path:
"""
Reads a dotted file path and returns the file path.
Convert a dotted path or filesystem path into an actual file path.
Raises FileNotFoundError if the file does not exist.
"""
# If the operating system's file path seperator character is in the string
if os.sep in dotted_path or '/' in dotted_path:
# Assume the path is a valid file path
return dotted_path
path = Path(dotted_path)

# If path already exists, return it
if path.exists():
return path

# Split dotted path
parts = dotted_path.split('.')
if parts[0] == 'chatterbot':
parts.pop(0)
parts[0] = DATA_DIRECTORY
parts[0] = str(DATA_DIRECTORY)

base_path = Path(*parts)

corpus_path = os.path.join(*parts)
# Check for file existence with supported extensions
for ext in extensions:
candidate = base_path.with_suffix(f'.{ext}')
if candidate.exists():
return candidate

path_with_extension = '{}.{}'.format(corpus_path, extension)
if os.path.exists(path_with_extension):
corpus_path = path_with_extension
# If directory exists, return it
if base_path.is_dir():
return base_path

return corpus_path
raise FileNotFoundError(f"Corpus file or directory not found for: {dotted_path}")


def read_corpus(file_name) -> dict:
def read_corpus(file_path: Path) -> dict:
"""
Read and return the data from a corpus json file.
Read a YAML corpus file and return its contents.
Caches results for repeated access.
"""
if file_path in _corpus_cache:
return _corpus_cache[file_path]

try:
import yaml
except ImportError:
Expand All @@ -55,37 +75,49 @@ def read_corpus(file_name) -> dict:
)
raise OptionalDependencyImportError(message)

with io.open(file_name, encoding='utf-8') as data_file:
return yaml.safe_load(data_file)
try:
with io.open(file_path, encoding='utf-8') as f:
data = yaml.safe_load(f)
except Exception as e:
raise RuntimeError(f"Failed to read corpus file {file_path}: {e}") from e

if not isinstance(data, dict):
raise ValueError(f"Corpus file {file_path} did not return a dictionary.")

_corpus_cache[file_path] = data
return data


def list_corpus_files(dotted_path) -> list[str]:
def list_corpus_files(dotted_path: str) -> List[Path]:
"""
Return a list of file paths to each data file in the specified corpus.
Return a sorted list of all corpus files (with supported extensions)
in the given dotted path or directory.
"""
corpus_path = get_file_path(dotted_path, extension=CORPUS_EXTENSION)
paths = []
path = get_file_path(dotted_path)
files: List[Path] = []

if os.path.isdir(corpus_path):
paths = glob.glob(corpus_path + '/**/*.' + CORPUS_EXTENSION, recursive=True)
if path.is_dir():
for ext in CORPUS_EXTENSIONS:
files.extend(path.rglob(f'*.{ext}'))
else:
paths.append(corpus_path)
files.append(path)

paths.sort()
return paths
return sorted(files)


def load_corpus(*data_file_paths):
def load_corpus(*data_file_paths: str) -> Generator[CorpusData, None, None]:
"""
Return the data contained within a specified corpus.
Yield CorpusData objects for each specified corpus file.
"""
for file_path in data_file_paths:
corpus = []
corpus_data = read_corpus(file_path)

conversations = corpus_data.get('conversations', [])
corpus.extend(conversations)

categories = corpus_data.get('categories', [])

yield corpus, categories, file_path
for file_path_str in data_file_paths:
path = get_file_path(file_path_str)
if path.is_dir():
files = list_corpus_files(path)
else:
files = [path]

for file in files:
corpus_data = read_corpus(file)
conversations = corpus_data.get('conversations', [])
categories = corpus_data.get('categories', [])
yield CorpusData(conversations, categories, str(file))
Loading