|
| 1 | +"""Parser-agnostic parse-correctness checking for parsed edges. |
| 2 | +
|
| 3 | +Where :func:`hyperbase.correctness.check_correctness` validates a hyperedge in |
| 4 | +isolation, this module checks a whole *parse*: the edge plus how its atoms map |
| 5 | +onto the original tokens. :func:`check_parse_correctness` combines the hard |
| 6 | +grammar errors, the soft structural-quality errors, and token-matching |
| 7 | +validation so any parser plugin can score the output of a parse against the |
| 8 | +original tokens. The third value in each error tuple is a severity (lower is |
| 9 | +worse): ``0`` for hard correctness failures, ``1`` for token-mismatch issues, |
| 10 | +``2`` for argrole problems, ``3`` for junction issues. |
| 11 | +""" |
| 12 | + |
| 13 | +from hyperbase.correctness import check_structural_quality |
| 14 | +from hyperbase.hyperedge import Hyperedge |
| 15 | +from hyperbase.parsers.utils import filter_alphanumeric_strings |
| 16 | + |
| 17 | + |
def _match_tokens_to_roots(
    tokens: list[str],
    roots: list[str],
) -> tuple[set[int], set[int]]:
    """Greedily pair source tokens with atom roots and report what was paired.

    Tokens are visited in order; each still-unmatched token is tried against
    the strategies below, in this order, stopping at the first that succeeds:

    1. exact: one unmatched root equal to the token;
    2. (b) one token == concatenation of consecutive unmatched roots;
    3. (c) one root == concatenation of this and following unmatched tokens;
    4. (d) N tokens == N roots, concatenated position by position;
    5. contraction: this token + the next one == any two distinct unmatched
       roots (not necessarily adjacent), concatenated.

    The matching is greedy and order-dependent; it does not search for a
    globally optimal assignment.

    Args:
        tokens: Source tokens (already filtered by the caller).
        roots: Atom roots of the parse (already filtered by the caller).

    Returns:
        ``(matched_tokens, matched_roots)``: the sets of indices into
        ``tokens`` and ``roots`` that were paired with something.
    """
    # Case (d) gives up on a run once either concatenation exceeds this length.
    max_concat_len = 10

    matched_tokens: set[int] = set()
    matched_roots: set[int] = set()

    for token_idx, token in enumerate(tokens):
        # A previous token may already have consumed this one (cases c/d or
        # the contraction case match several tokens at once).
        if token_idx in matched_tokens:
            continue

        # Exact match: take the first unmatched root equal to the token.
        exact_idx = next(
            (
                root_idx
                for root_idx, root in enumerate(roots)
                if root == token and root_idx not in matched_roots
            ),
            None,
        )
        if exact_idx is not None:
            matched_tokens.add(token_idx)
            matched_roots.add(exact_idx)
            continue

        # Case (b): a run of consecutive unmatched roots that concatenates to
        # this token.
        for root_start_idx in range(len(roots)):
            if root_start_idx in matched_roots:
                continue

            concatenated = ""
            root_sequence: list[int] = []
            for root_idx in range(root_start_idx, len(roots)):
                if root_idx in matched_roots:
                    break  # a run may not cross an already-matched root

                concatenated += roots[root_idx]
                root_sequence.append(root_idx)

                if concatenated == token:
                    matched_tokens.add(token_idx)
                    matched_roots.update(root_sequence)
                    break

                if len(concatenated) >= len(token):
                    break  # as long as the token but different: dead end

            if token_idx in matched_tokens:
                break

        # Case (c): a single root that equals this token plus following
        # unmatched tokens.
        # NOTE(review): unlike cases (b) and (d), this loop does not stop once
        # the current token is matched; it keeps scanning the remaining roots
        # and, because matched tokens are skipped, may pair later tokens too.
        # Kept as-is because changing it alters which tokens/roots end up
        # matched on some inputs -- confirm the intended behaviour.
        if token_idx not in matched_tokens:
            for root_idx, root in enumerate(roots):
                if root_idx in matched_roots:
                    continue

                concatenated = ""
                token_sequence: list[int] = []
                for next_token_idx in range(token_idx, len(tokens)):
                    if next_token_idx in matched_tokens:
                        continue  # skip over tokens that are already used

                    concatenated += tokens[next_token_idx]
                    token_sequence.append(next_token_idx)

                    if concatenated == root:
                        matched_roots.add(root_idx)
                        matched_tokens.update(token_sequence)
                        break

                    if len(concatenated) >= len(root):
                        break

        # Case (d): the same number of tokens and roots, walked in lockstep,
        # whose concatenations are equal (token/root boundaries differ).
        if token_idx not in matched_tokens:
            for root_start_idx in range(len(roots)):
                if root_start_idx in matched_roots:
                    continue

                tokens_concatenated = ""
                roots_concatenated = ""
                token_sequence_d: list[int] = []
                root_sequence_d: list[int] = []

                max_tokens = min(
                    len(tokens) - token_idx, len(roots) - root_start_idx
                )
                for offset in range(max_tokens):
                    current_token_idx = token_idx + offset
                    current_root_idx = root_start_idx + offset

                    if (
                        current_token_idx in matched_tokens
                        or current_root_idx in matched_roots
                    ):
                        break  # can't reuse already matched items

                    tokens_concatenated += tokens[current_token_idx]
                    roots_concatenated += roots[current_root_idx]
                    token_sequence_d.append(current_token_idx)
                    root_sequence_d.append(current_root_idx)

                    if (
                        tokens_concatenated == roots_concatenated
                        and tokens_concatenated
                    ):
                        matched_tokens.update(token_sequence_d)
                        matched_roots.update(root_sequence_d)
                        break

                    if (
                        len(tokens_concatenated) > max_concat_len
                        or len(roots_concatenated) > max_concat_len
                    ):
                        break

                if token_idx in matched_tokens:
                    break

        # Contraction: this token joined with the next one equals any ordered
        # pair of distinct unmatched roots, wherever they sit in the list.
        if (
            token_idx not in matched_tokens
            and token_idx + 1 < len(tokens)
            and token_idx + 1 not in matched_tokens
        ):
            token_concat = tokens[token_idx] + tokens[token_idx + 1]

            for root_idx1, root1 in enumerate(roots):
                if root_idx1 in matched_roots:
                    continue

                for root_idx2, root2 in enumerate(roots):
                    if root_idx2 in matched_roots or root_idx2 == root_idx1:
                        continue

                    if token_concat == root1 + root2:
                        matched_tokens.add(token_idx)
                        matched_tokens.add(token_idx + 1)
                        matched_roots.add(root_idx1)
                        matched_roots.add(root_idx2)
                        break

                if token_idx in matched_tokens:
                    break

    return matched_tokens, matched_roots


def check_parse_correctness(
    edge: Hyperedge,
    tokens: list[str],
) -> dict[str | Hyperedge, list[tuple[str, str, int]]]:
    """Collect correctness errors for a parsed edge and its source tokens.

    Merges three sources of errors into one mapping:

    * the result of ``edge.check_correctness()`` (hard grammar failures);
    * the result of ``check_structural_quality(edge)``;
    * token-matching errors, stored under the ``"token-matching"`` key, for
      atom roots that could not be paired with a source token
      (``"root-without-token"``) and source tokens not used by any atom
      (``"token-unused"``), both with severity ``1``.

    Token matching only runs when ``edge`` is truthy, and is best-effort: if
    it raises, the failure is logged at debug level and the other errors are
    still returned.

    Args:
        edge: The parsed hyperedge to check.
        tokens: The tokens of the source sentence.

    Returns:
        A dict mapping an error key to a list of
        ``(code, message, severity)`` tuples. The lists are fresh copies, so
        callers may mutate them freely.
    """
    import logging

    # Hard grammar failures (severity 0), keyed by subedge.
    errors: dict[str | Hyperedge, list[tuple[str, str, int]]] = {
        k: list(v) for k, v in edge.check_correctness().items()
    }

    # Copy rather than alias the lists returned by check_structural_quality.
    for k, v in check_structural_quality(edge).items():
        errors.setdefault(k, []).extend(v)

    # Only check token matching if we have a valid edge.
    if not edge:
        return errors

    try:
        tokens = filter_alphanumeric_strings(tokens)
        roots: list[str] = filter_alphanumeric_strings(
            [atom.label() for atom in edge.all_atoms()]
        )
        matched_tokens, matched_roots = _match_tokens_to_roots(tokens, roots)
    except Exception:
        # Deliberately best-effort (e.g. the edge is invalid): skip the token
        # check, but leave a trace instead of failing silently.
        logging.getLogger(__name__).debug(
            "Token matching skipped for edge %r", edge, exc_info=True
        )
    else:
        token_matching_errors: list[tuple[str, str, int]] = [
            (
                "root-without-token",
                f"Atom root '{root}' in the parse is used more times than "
                "it appears in the source sentence.",
                1,
            )
            for root_idx, root in enumerate(roots)
            if root_idx not in matched_roots
        ]
        token_matching_errors.extend(
            (
                "token-unused",
                f"Token '{token}' from the source sentence is not used by "
                "any atom in the parse.",
                1,
            )
            for token_idx, token in enumerate(tokens)
            if token_idx not in matched_tokens
        )

        if token_matching_errors:
            errors["token-matching"] = token_matching_errors

    return errors
0 commit comments