|
| 1 | +"""Distance measure between two hyperedges. |
| 2 | +
|
| 3 | +The distance is an ordered *tree edit distance* (Zhang-Shasha): the minimum |
| 4 | +cost of a sequence of node relabel / insert / delete operations that turns one |
| 5 | +edge-tree into the other. It accounts for both the nesting structure and the |
| 6 | +atom content of the edges, and is polynomial in the size of the edges, which |
| 7 | +makes it tractable for the (small) edges produced by sentence parsing. |
| 8 | +
|
| 9 | +The raw cost is a true metric. By default it is normalised to ``[0, 1]`` by |
| 10 | +dividing by the combined node count of both edges, so ``0`` means identical and |
| 11 | +larger values mean more different. |
| 12 | +""" |
| 13 | + |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +from typing import cast |
| 17 | + |
| 18 | +from hyperbase.hyperedge import Atom, Hyperedge |
| 19 | + |
| 20 | + |
def _annotate(root: Hyperedge) -> tuple[list[Hyperedge], list[int], list[int]]:
    """Number a tree in post-order and derive the Zhang-Shasha bookkeeping.

    Returns a triple ``(nodes, lmld, keyroots)``:

    nodes -- every node of the tree (leaves and sub-edges) in post-order, so
        the root is always the last element.
    lmld -- ``lmld[i]`` is the post-order index of the leftmost leaf found
        below node ``i``; a leaf is its own leftmost leaf.
    keyroots -- ascending post-order indices of the key roots, i.e. the root
        and every node that has a left sibling.
    """
    ordered: list[Hyperedge] = []
    leftmost: list[int] = []

    def walk(node: Hyperedge) -> int:
        # Atoms and empty edges are leaves: they get numbered immediately and
        # point at themselves.
        if node.atom or len(node) == 0:
            position = len(ordered)
            ordered.append(node)
            leftmost.append(position)
            return position

        # Post-order: every child is numbered before its parent.
        child_positions = [walk(child) for child in node]
        inherited = leftmost[child_positions[0]] if child_positions else -1

        position = len(ordered)
        ordered.append(node)
        # A parent shares the leftmost leaf of its first child.
        leftmost.append(inherited)
        return position

    walk(root)

    # Several nodes can share a leftmost leaf; the key root is the one with
    # the highest post-order index. Later entries overwrite earlier ones in
    # the comprehension, which leaves exactly that index per leaf.
    highest_for_leaf = {leaf: index for index, leaf in enumerate(leftmost)}
    return ordered, leftmost, sorted(highest_for_leaf.values())
| 57 | + |
| 58 | + |
def _safe_normalise(edge: Hyperedge) -> Hyperedge:
    """Return the argument-role-normalised form of ``edge``, best effort.

    If ``normalise`` raises ``KeyError`` or ``RuntimeError`` the input edge is
    handed back unchanged instead of propagating the error.
    """
    # Imported at call time rather than at module level -- presumably to
    # avoid an import cycle with hyperbase.transforms; confirm before moving.
    from hyperbase.transforms import normalise

    try:
        normalised = normalise(edge)
    except (KeyError, RuntimeError):
        return edge
    else:
        return normalised
| 67 | + |
| 68 | + |
def _internal_label(edge: Hyperedge) -> str:
    """Return the label under which a non-atom node is compared.

    The label is whatever ``edge.mtype()`` returns; when that call raises
    ``RuntimeError`` the placeholder ``"?"`` is used instead.
    """
    try:
        label = edge.mtype()
    except RuntimeError:
        label = "?"
    return label
| 75 | + |
| 76 | + |
def edge_distance(
    edge1: Hyperedge,
    edge2: Hyperedge,
    *,
    normalize: bool = True,
    canonical: bool = True,
    root_weight: float = 0.5,
    type_weight: float = 0.5,
) -> float:
    """Tree edit distance between two hyperedges.

    Both edges are turned into post-order annotated trees and compared with
    the Zhang-Shasha dynamic programme. Inserting or deleting a node costs
    ``1.0``; relabelling costs between ``0.0`` and ``1.0`` depending on how
    similar the two nodes are.

    Keyword arguments:
    normalize -- divide the raw edit cost by the combined node count, yielding a
        value in ``[0, 1]`` (default: True). When False, the raw cost is
        returned.
    canonical -- normalise the argument-role ordering of both edges first, so
        that edges differing only by a benign argument reordering are treated as
        equal (default: True).
    root_weight -- relative weight of the atom root (the word) when relabelling
        two atoms (default: 0.5).
    type_weight -- relative weight of the atom type/role when relabelling two
        atoms (default: 0.5).

    The atom weights are normalised so that a single atom relabel cost stays
    within ``[0, 1]``; ``type_weight=0`` gives a purely lexical comparison and
    ``root_weight=0`` a purely structural one. Negative (or NaN) weights are
    treated as ``0``, and if no positive weight remains both parts are
    weighted equally.
    """
    if canonical:
        edge1 = _safe_normalise(edge1)
        edge2 = _safe_normalise(edge2)

    # Clamp before normalising: a negative weight paired with a larger
    # positive one would otherwise yield normalised weights outside [0, 1],
    # so a relabel could cost less than nothing or more than an insert/delete
    # and the normalised result could leave [0, 1]. ``max(0.0, w)`` also maps
    # NaN to 0.0 because every comparison against NaN is false.
    root_weight = max(0.0, root_weight)
    type_weight = max(0.0, type_weight)
    total_weight = root_weight + type_weight
    if total_weight <= 0:
        w_root = w_type = 0.5
    else:
        w_root = root_weight / total_weight
        w_type = type_weight / total_weight

    def atom_cost(a: Atom, b: Atom) -> float:
        # Lexical part: all-or-nothing on the atom root.
        root_part = 0.0 if a.root() == b.root() else 1.0
        # Type part: identical type is free, same main type is half price,
        # anything else pays in full.
        if a.type() == b.type():
            type_part = 0.0
        elif a.mtype() == b.mtype():
            type_part = 0.5
        else:
            type_part = 1.0
        return w_root * root_part + w_type * type_part

    def relabel_cost(a: Hyperedge, b: Hyperedge) -> float:
        if a.atom and b.atom:
            return atom_cost(cast(Atom, a), cast(Atom, b))
        if a.not_atom and b.not_atom:
            return 0.0 if _internal_label(a) == _internal_label(b) else 0.5
        # one atom and one non-atom: maximal substitution cost
        return 1.0

    nodes1, lmld1, keyroots1 = _annotate(edge1)
    nodes2, lmld2, keyroots2 = _annotate(edge2)
    n1 = len(nodes1)
    n2 = len(nodes2)

    # treedist[x][y] is the edit distance between the subtree rooted at node x
    # of the first tree and the subtree rooted at node y of the second.
    treedist = [[0.0] * n2 for _ in range(n1)]

    # Key roots are visited in ascending post-order, so every subtree distance
    # read in the forest branch below has been filled in by an earlier pass.
    for i in keyroots1:
        li = lmld1[i]
        for j in keyroots2:
            lj = lmld2[j]
            # fd[di][dj] is the distance between the forests made of the first
            # di nodes starting at li and the first dj nodes starting at lj;
            # row/column 0 stands for the empty forest.
            rows = i - li + 2
            cols = j - lj + 2
            fd = [[0.0] * cols for _ in range(rows)]
            for di in range(1, rows):
                fd[di][0] = fd[di - 1][0] + 1.0  # delete everything
            for dj in range(1, cols):
                fd[0][dj] = fd[0][dj - 1] + 1.0  # insert everything
            for di in range(1, rows):
                ni = li + di - 1
                for dj in range(1, cols):
                    nj = lj + dj - 1
                    delete = fd[di - 1][dj] + 1.0
                    insert = fd[di][dj - 1] + 1.0
                    if lmld1[ni] == li and lmld2[nj] == lj:
                        # Both forests are single trees: relabel the two roots
                        # and record the result as a subtree distance.
                        relabel = fd[di - 1][dj - 1] + relabel_cost(
                            nodes1[ni], nodes2[nj]
                        )
                        best = min(delete, insert, relabel)
                        fd[di][dj] = best
                        treedist[ni][nj] = best
                    else:
                        # General forests: match the two rightmost subtrees
                        # using their known distance, plus the cost of the
                        # forests that precede them.
                        relabel = fd[lmld1[ni] - li][lmld2[nj] - lj] + treedist[ni][nj]
                        fd[di][dj] = min(delete, insert, relabel)

    # The roots are the last nodes in post-order.
    ted = treedist[n1 - 1][n2 - 1]
    if not normalize:
        return ted
    # _annotate always emits at least the root of each tree, so the combined
    # node count is never zero and no division guard is needed.
    return ted / (n1 + n2)
0 commit comments