Skip to content

Commit e83d58f

Browse files
committed
Hyperedge.distance()
1 parent 4e5d855 commit e83d58f

5 files changed

Lines changed: 286 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Added
66

77
- `Hyperedge.transform()`: pattern-based rewrites.
8+
- `Hyperedge.distance()`: distance metric based on tree edit distance.
89
- per-atom `tok_pos`/`text_span` and per-root `Hyperedge.tokens` source-position metadata.
910
- continuity-aware sub-edge text derivation with verbatim character-offset slicing.
1011
- `transforms.tok_pos_tree(edge)` rebuilds the parallel `tok_pos` tree from in-memory atoms.

src/hyperbase/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from hyperbase.builders import hedge
2+
from hyperbase.distance import edge_distance
23
from hyperbase.loaders import load_edges
34
from hyperbase.parsers import get_parser
45

56
__all__ = [
7+
"edge_distance",
68
"get_parser",
79
"hedge",
810
"load_edges",

src/hyperbase/distance.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""Distance measure between two hyperedges.
2+
3+
The distance is an ordered *tree edit distance* (Zhang-Shasha): the minimum
4+
cost of a sequence of node relabel / insert / delete operations that turns one
5+
edge-tree into the other. It accounts for both the nesting structure and the
6+
atom content of the edges, and is polynomial in the size of the edges, which
7+
makes it tractable for the (small) edges produced by sentence parsing.
8+
9+
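The raw cost is symmetric and satisfies the triangle inequality, but distinct
edges can still be at distance ``0`` (for example, edges that differ only by an
argument reordering when ``canonical=True``). By default it is normalised to ``[0, 1]`` by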
10+
dividing by the combined node count of both edges, so ``0`` means identical and
11+
larger values mean more different.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
from typing import cast
17+
18+
from hyperbase.hyperedge import Atom, Hyperedge
19+
20+
21+
def _annotate(root: Hyperedge) -> tuple[list[Hyperedge], list[int], list[int]]:
22+
"""Post-order annotation of a tree for the Zhang-Shasha algorithm.
23+
24+
Returns ``(nodes, lmld, keyroots)`` where ``nodes`` lists every node of the
25+
tree in post-order, ``lmld[i]`` is the post-order index of the leftmost-leaf
26+
descendant of node ``i``, and ``keyroots`` lists the post-order indices of
27+
the key roots (the root plus every node that has a left sibling).
28+
"""
29+
nodes: list[Hyperedge] = []
30+
lmld: list[int] = []
31+
32+
def visit(edge: Hyperedge) -> int:
33+
if edge.atom or len(edge) == 0:
34+
idx = len(nodes)
35+
nodes.append(edge)
36+
lmld.append(idx)
37+
return idx
38+
first_leaf = -1
39+
for pos, child in enumerate(edge):
40+
child_idx = visit(child)
41+
if pos == 0:
42+
first_leaf = lmld[child_idx]
43+
idx = len(nodes)
44+
nodes.append(edge)
45+
lmld.append(first_leaf)
46+
return idx
47+
48+
visit(root)
49+
50+
# The key root for a given leftmost leaf is the highest-indexed node that
51+
# shares it; iterating in order and overwriting yields exactly that.
52+
keyroot_for: dict[int, int] = {}
53+
for i, leaf in enumerate(lmld):
54+
keyroot_for[leaf] = i
55+
keyroots = sorted(keyroot_for.values())
56+
return nodes, lmld, keyroots
57+
58+
59+
def _safe_normalise(edge: Hyperedge) -> Hyperedge:
    """Return the argument-role-normalised form of ``edge`` when possible.

    If ``normalise`` raises ``KeyError`` or ``RuntimeError`` (malformed input),
    the edge is returned unchanged instead of propagating the error.
    """
    # Imported at call time rather than at module import time
    # (presumably to avoid a circular import -- TODO confirm).
    from hyperbase.transforms import normalise as _normalise

    try:
        canonical_edge = _normalise(edge)
    except (KeyError, RuntimeError):
        canonical_edge = edge
    return canonical_edge
67+
68+
69+
def _internal_label(edge: Hyperedge) -> str:
70+
"""Label used to compare two non-atom nodes (their inferred main type)."""
71+
try:
72+
return edge.mtype()
73+
except RuntimeError:
74+
return "?"
75+
76+
77+
def edge_distance(
    edge1: Hyperedge,
    edge2: Hyperedge,
    *,
    normalize: bool = True,
    canonical: bool = True,
    root_weight: float = 0.5,
    type_weight: float = 0.5,
) -> float:
    """Tree edit distance between two hyperedges.

    Both edges are flattened in post-order and compared with the Zhang-Shasha
    dynamic programme. Inserting or deleting a node costs 1; relabelling a node
    costs between 0 and 1 depending on how much the two nodes differ.

    Keyword arguments:
    normalize -- divide the raw edit cost by the combined node count, yielding a
        value in ``[0, 1]`` (default: True). When False, the raw cost is
        returned.
    canonical -- normalise the argument-role ordering of both edges first, so
        that edges differing only by a benign argument reordering are treated as
        equal (default: True).
    root_weight -- relative weight of the atom root (the word) when relabelling
        two atoms (default: 0.5).
    type_weight -- relative weight of the atom type/role when relabelling two
        atoms (default: 0.5).

    The atom weights are rescaled to sum to 1, so a single atom relabel cost
    stays within ``[0, 1]``; ``type_weight=0`` gives a purely lexical
    comparison and ``root_weight=0`` a purely structural one. If both weights
    are 0, equal weights are used.

    Raises ``ValueError`` if either weight is negative, NaN or infinite: such
    weights would push relabel costs outside ``[0, 1]`` and break the bounds
    documented above.
    """
    # Validate before doing any work. The chained comparison is False for NaN
    # as well as for negative and infinite values.
    for name, weight in (("root_weight", root_weight), ("type_weight", type_weight)):
        if not (0.0 <= weight < float("inf")):
            raise ValueError(
                f"{name} must be a finite, non-negative number, got {weight!r}"
            )

    if canonical:
        edge1 = _safe_normalise(edge1)
        edge2 = _safe_normalise(edge2)

    total_weight = root_weight + type_weight
    if total_weight <= 0:
        # Both weights are zero: fall back to an even split.
        w_root = w_type = 0.5
    else:
        w_root = root_weight / total_weight
        w_type = type_weight / total_weight

    def atom_cost(a: Atom, b: Atom) -> float:
        # Root (word) mismatch is all-or-nothing.
        root_part = 0.0 if a.root() == b.root() else 1.0
        # Type mismatch is graded: same full type < same main type < different.
        if a.type() == b.type():
            type_part = 0.0
        elif a.mtype() == b.mtype():
            type_part = 0.5
        else:
            type_part = 1.0
        return w_root * root_part + w_type * type_part

    def relabel_cost(a: Hyperedge, b: Hyperedge) -> float:
        if a.atom and b.atom:
            return atom_cost(cast(Atom, a), cast(Atom, b))
        if a.not_atom and b.not_atom:
            return 0.0 if _internal_label(a) == _internal_label(b) else 0.5
        # one atom and one non-atom: maximal substitution cost
        return 1.0

    nodes1, lmld1, keyroots1 = _annotate(edge1)
    nodes2, lmld2, keyroots2 = _annotate(edge2)
    n1 = len(nodes1)
    n2 = len(nodes2)

    # treedist[i][j]: edit distance between the subtrees rooted at post-order
    # nodes i (first edge) and j (second edge).
    treedist = [[0.0] * n2 for _ in range(n1)]

    # Key roots are visited in ascending order, so every treedist entry read in
    # the forest branch below has been filled by an earlier iteration.
    for i in keyroots1:
        li = lmld1[i]
        for j in keyroots2:
            lj = lmld2[j]
            rows = i - li + 2
            cols = j - lj + 2
            # fd[di][dj]: distance between the forests made of the first di
            # nodes (from li) and the first dj nodes (from lj); index 0 is the
            # empty forest.
            fd = [[0.0] * cols for _ in range(rows)]
            for di in range(1, rows):
                fd[di][0] = fd[di - 1][0] + 1.0
            for dj in range(1, cols):
                fd[0][dj] = fd[0][dj - 1] + 1.0
            for di in range(1, rows):
                ni = li + di - 1
                for dj in range(1, cols):
                    nj = lj + dj - 1
                    delete = fd[di - 1][dj] + 1.0
                    insert = fd[di][dj - 1] + 1.0
                    if lmld1[ni] == li and lmld2[nj] == lj:
                        # Both forests are whole trees: relabel the two roots
                        # and record the result as a subtree distance.
                        relabel = fd[di - 1][dj - 1] + relabel_cost(
                            nodes1[ni], nodes2[nj]
                        )
                        best = min(delete, insert, relabel)
                        fd[di][dj] = best
                        treedist[ni][nj] = best
                    else:
                        # General forests: match the two rightmost subtrees
                        # using their already-computed distance.
                        relabel = fd[lmld1[ni] - li][lmld2[nj] - lj] + treedist[ni][nj]
                        fd[di][dj] = min(delete, insert, relabel)

    ted = treedist[n1 - 1][n2 - 1]
    if not normalize:
        return ted
    total = n1 + n2
    if total == 0:
        # Defensive guard against a division by zero.
        return 0.0
    return ted / total

src/hyperbase/hyperedge.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,30 @@ def normalise(self) -> Hyperedge:
440440

441441
return _propagate_root_text(self, normalise(self))
442442

443+
def distance(
    self,
    other: Hyperedge,
    *,
    normalize: bool = True,
    canonical: bool = True,
    root_weight: float = 0.5,
    type_weight: float = 0.5,
) -> float:
    """Tree edit distance between this edge and ``other``.

    Convenience wrapper that forwards every option unchanged to
    ``hyperbase.distance.edge_distance``; see that function for details.
    """
    # Imported at call time rather than at module import time.
    from hyperbase.distance import edge_distance as _edge_distance

    options = {
        "normalize": normalize,
        "canonical": canonical,
        "root_weight": root_weight,
        "type_weight": type_weight,
    }
    return _edge_distance(self, other, **options)
466+
443467
############
444468
# patterns #
445469
############

tests/test_distance.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import unittest
2+
3+
from hyperbase import edge_distance
4+
from hyperbase.builders import hedge
5+
6+
7+
class TestDistance(unittest.TestCase):
    """Behavioural tests for ``edge_distance`` and ``Hyperedge.distance``.

    Uses ``unittest`` assertion methods rather than bare ``assert`` statements,
    so the checks still run when Python is started with ``-O``.
    """

    def test_identity_atom(self):
        self.assertEqual(edge_distance(hedge("berlin/C"), hedge("berlin/C")), 0.0)

    def test_identity_nested(self):
        edge = hedge("(is/P.so (the/M sky/C) blue/C)")
        self.assertEqual(edge_distance(edge, edge), 0.0)

    def test_symmetry(self):
        a = hedge("(is/P.so berlin/C nice/C)")
        b = hedge("(loves/P.so mary/C art/C)")
        self.assertEqual(edge_distance(a, b), edge_distance(b, a))

    def test_bounds(self):
        pairs = [
            ("berlin/C", "paris/C"),
            ("berlin/C", "loves/P"),
            ("(is/P.so berlin/C nice/C)", "(loves/P.so mary/C art/C)"),
            ("berlin/C", "(is/P.so berlin/C nice/C)"),
        ]
        for s1, s2 in pairs:
            # subTest reports which pair failed instead of stopping at the first.
            with self.subTest(edge1=s1, edge2=s2):
                d = edge_distance(hedge(s1), hedge(s2))
                self.assertGreaterEqual(d, 0.0)
                self.assertLessEqual(d, 1.0)

    def test_graded_atom_cost(self):
        base = hedge("(is/P.so berlin/C nice/C)")
        subtype = hedge("(is/P.so berlin/C nice/Cp)")  # same root, subtype only
        root = hedge("(is/P.so berlin/C ugly/C)")  # different root, same type
        both = hedge("(is/P.so berlin/C loves/P)")  # different root and type
        d_subtype = edge_distance(base, subtype)
        d_root = edge_distance(base, root)
        d_both = edge_distance(base, both)
        self.assertLess(0.0, d_subtype)
        self.assertLess(d_subtype, d_root)
        self.assertLess(d_root, d_both)

    def test_structural_change_costs_more(self):
        base = hedge("(is/P.so berlin/C nice/C)")
        relabel = hedge("(is/P.so berlin/C ugly/C)")  # one leaf relabel
        # extra argument (valid connector with three argroles)
        extra_arg = hedge("(is/P.soc berlin/C nice/C today/C)")
        self.assertLess(
            edge_distance(base, relabel), edge_distance(base, extra_arg)
        )

    def test_canonical_invariance(self):
        # Same relation, arguments and roles permuted (s,o vs o,s).
        a = hedge("(is/P.so berlin/C nice/C)")
        b = hedge("(is/P.os nice/C berlin/C)")
        self.assertEqual(edge_distance(a, b, canonical=True), 0.0)
        self.assertGreater(edge_distance(a, b, canonical=False), 0.0)

    def test_raw_cost(self):
        # A single leaf relabel (same type, different root) -> raw cost 0.5.
        base = hedge("(is/P.so berlin/C nice/C)")
        relabel = hedge("(is/P.so berlin/C ugly/C)")
        self.assertEqual(edge_distance(base, relabel, normalize=False), 0.5)

    def test_weight_extremes(self):
        a = hedge("berlin/C")
        b = hedge("berlin/P")  # same root, different type
        # type_weight=0 -> purely lexical: same word -> identical
        self.assertEqual(
            edge_distance(a, b, type_weight=0.0, root_weight=1.0), 0.0
        )
        # root_weight=0 -> purely structural: different type -> non-zero
        self.assertGreater(
            edge_distance(a, b, root_weight=0.0, type_weight=1.0), 0.0
        )

    def test_method(self):
        a = hedge("(is/P.so berlin/C nice/C)")
        b = hedge("(is/P.so berlin/C ugly/C)")
        self.assertEqual(a.distance(b), edge_distance(a, b))

    def test_triangle_inequality_raw(self):
        a = hedge("(is/P.so berlin/C nice/C)")
        b = hedge("(is/P.so berlin/C ugly/C)")
        c = hedge("(loves/P.so mary/C art/C)")
        d_ac = edge_distance(a, c, normalize=False)
        d_ab = edge_distance(a, b, normalize=False)
        d_bc = edge_distance(b, c, normalize=False)
        self.assertLessEqual(d_ac, d_ab + d_bc)
82+
83+
84+
if __name__ == "__main__":
85+
unittest.main()

0 commit comments

Comments
 (0)