|
| 1 | +# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai> |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +import mimetypes |
| 6 | +import re |
| 7 | +from collections import defaultdict |
| 8 | +from pathlib import Path |
| 9 | +from typing import Dict, List, Optional |
| 10 | + |
| 11 | +from haystack import component |
| 12 | +from haystack.components.routers.file_type_router import CUSTOM_MIMETYPES |
| 13 | +from haystack.dataclasses import Document |
| 14 | + |
| 15 | + |
| 16 | +@component |
| 17 | +class DocumentTypeRouter: |
| 18 | + """ |
| 19 | + Categorizes documents by MIME types based on their metadata. |
| 20 | +
|
| 21 | + DocumentTypeRouter is used to dynamically route documents within a pipeline based on their MIME types. |
| 22 | + It supports exact MIME type matches and regex patterns. |
| 23 | +
|
| 24 | + MIME types can be extracted directly from document metadata or inferred from file paths using standard or |
| 25 | + user-supplied MIME type mappings. |
| 26 | +
|
| 27 | + ### Usage example |
| 28 | +
|
| 29 | + ```python |
| 30 | + from haystack_experimental.components.routers import DocumentTypeRouter |
| 31 | + from haystack.dataclasses import Document |
| 32 | +
|
| 33 | + docs = [ |
| 34 | + Document(content="Example text", meta={"file_path": "example.txt"}), |
| 35 | + Document(content="Another document", meta={"mime_type": "application/pdf"}), |
| 36 | + Document(content="Unknown type") |
| 37 | + ] |
| 38 | +
|
| 39 | + router = DocumentTypeRouter( |
| 40 | + mime_type_meta_field="mime_type", |
| 41 | + file_path_meta_field="file_path", |
| 42 | + mime_types=["text/plain", "application/pdf"] |
| 43 | + ) |
| 44 | +
|
| 45 | + result = router.run(documents=docs) |
| 46 | + print(result) |
| 47 | + ``` |
| 48 | +
|
| 49 | + Expected output: |
| 50 | + ```python |
| 51 | + { |
| 52 | + "text/plain": [Document(...)], |
| 53 | + "application/pdf": [Document(...)], |
| 54 | + "unclassified": [Document(...)] |
| 55 | + } |
| 56 | + ``` |
| 57 | + """ |
| 58 | + |
| 59 | + def __init__( |
| 60 | + self, |
| 61 | + *, |
| 62 | + mime_type_meta_field: Optional[str] = None, |
| 63 | + file_path_meta_field: Optional[str] = None, |
| 64 | + mime_types: List[str], |
| 65 | + additional_mimetypes: Optional[Dict[str, str]] = None, |
| 66 | + ) -> None: |
| 67 | + """ |
| 68 | + Initialize the DocumentTypeRouter component. |
| 69 | +
|
| 70 | + :param mime_type_meta_field: |
| 71 | + Optional name of the metadata field that holds the MIME type. |
| 72 | +
|
| 73 | + :param file_path_meta_field: |
| 74 | + Optional name of the metadata field that holds the file path. Used to infer the MIME type if |
| 75 | + `mime_type_meta_field` is not provided or missing in a document. |
| 76 | +
|
| 77 | + :param mime_types: |
| 78 | + A list of MIME types or regex patterns to classify the input documents. |
| 79 | + (for example: `["text/plain", "audio/x-wav", "image/jpeg"]`). |
| 80 | +
|
| 81 | + :param additional_mimetypes: |
| 82 | + Optional dictionary mapping MIME types to file extensions to enhance or override the standard |
| 83 | + `mimetypes` module. Useful when working with uncommon or custom file types. |
| 84 | + For example: `{"application/vnd.custom-type": ".custom"}`. |
| 85 | +
|
| 86 | + :raises ValueError: If `mime_types` is empty or if both `mime_type_meta_field` and `file_path_meta_field` are |
| 87 | + not provided. |
| 88 | + """ |
| 89 | + if not mime_types: |
| 90 | + raise ValueError("The list of mime types cannot be empty.") |
| 91 | + |
| 92 | + if mime_type_meta_field is None and file_path_meta_field is None: |
| 93 | + raise ValueError( |
| 94 | + "At least one of 'mime_type_meta_field' or 'file_path_meta_field' must be provided to determine MIME " |
| 95 | + "types." |
| 96 | + ) |
| 97 | + self.mime_type_meta_field = mime_type_meta_field |
| 98 | + self.file_path_meta_field = file_path_meta_field |
| 99 | + |
| 100 | + if additional_mimetypes: |
| 101 | + for mime, ext in additional_mimetypes.items(): |
| 102 | + mimetypes.add_type(mime, ext) |
| 103 | + |
| 104 | + self._mime_type_patterns = [] |
| 105 | + for mime_type in mime_types: |
| 106 | + try: |
| 107 | + pattern = re.compile(mime_type) |
| 108 | + except re.error: |
| 109 | + raise ValueError(f"Invalid regex pattern '{mime_type}'.") |
| 110 | + self._mime_type_patterns.append(pattern) |
| 111 | + |
| 112 | + component.set_output_types( |
| 113 | + self, |
| 114 | + unclassified=List[Document], |
| 115 | + **dict.fromkeys(mime_types, List[Document]), |
| 116 | + ) |
| 117 | + self.mime_types = mime_types |
| 118 | + self.additional_mimetypes = additional_mimetypes |
| 119 | + |
| 120 | + def run(self, documents: List[Document]) -> Dict[str, List[Document]]: |
| 121 | + """ |
| 122 | + Categorize input documents into groups based on their MIME type. |
| 123 | +
|
| 124 | + MIME types can either be directly available in document metadata or derived from file paths using the |
| 125 | + standard Python `mimetypes` module and custom mappings. |
| 126 | +
|
| 127 | + :param documents: |
| 128 | + A list of documents to be categorized. |
| 129 | +
|
| 130 | + :returns: |
| 131 | + A dictionary where the keys are MIME types (or `"unclassified"`) and the values are lists of documents. |
| 132 | + """ |
| 133 | + mime_types = defaultdict(list) |
| 134 | + |
| 135 | + for doc in documents: |
| 136 | + mime_type = doc.meta.get(self.mime_type_meta_field) if self.mime_type_meta_field else None |
| 137 | + file_path = doc.meta.get(self.file_path_meta_field) if self.file_path_meta_field else None |
| 138 | + |
| 139 | + if mime_type is None and file_path: |
| 140 | + # if mime_type is not provided, try to guess it from the file path |
| 141 | + mime_type = self._get_mime_type(Path(file_path)) |
| 142 | + |
| 143 | + matched = False |
| 144 | + if mime_type: |
| 145 | + for pattern in self._mime_type_patterns: |
| 146 | + if pattern.fullmatch(mime_type): |
| 147 | + mime_types[pattern.pattern].append(doc) |
| 148 | + matched = True |
| 149 | + break |
| 150 | + if not matched: |
| 151 | + mime_types["unclassified"].append(doc) |
| 152 | + |
| 153 | + return dict(mime_types) |
| 154 | + |
| 155 | + def _get_mime_type(self, path: Path) -> Optional[str]: |
| 156 | + """ |
| 157 | + Get the MIME type of the provided file path. |
| 158 | +
|
| 159 | + :param path: The file path to get the MIME type for. |
| 160 | +
|
| 161 | + :returns: The MIME type of the provided file path, or `None` if the MIME type cannot be determined. |
| 162 | + """ |
| 163 | + extension = path.suffix.lower() |
| 164 | + mime_type = mimetypes.guess_type(path.as_posix())[0] |
| 165 | + # lookup custom mappings if the mime type is not found |
| 166 | + return CUSTOM_MIMETYPES.get(extension, mime_type) |
0 commit comments