Skip to content

Commit 7435b33

Browse files
authored
Merge pull request #338 from urchade/relex
Improve GLiNER-relex architecture
2 parents 9bc9978 + 3dc9355 commit 7435b33

7 files changed

Lines changed: 612 additions & 163 deletions

File tree

gliner/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,20 +197,32 @@ def __init__(
197197
rel_token: str = "<<REL>>",
198198
adjacency_loss_coef=1.0,
199199
relation_loss_coef=1.0,
200+
augment_data_prob=0.5,
201+
augment_ent_drop_prob=(0.0, 1.0),
202+
augment_rel_drop_prob=(0.0, 0.3),
203+
augment_add_other_prob=0.5,
200204
**kwargs,
201205
):
202206
"""Initialize UniEncoderRelexConfig.
203207
204208
Args:
205209
relations_layer (str, optional): Name of relations layer,
206210
see gliner.modeling.multitask.relations_layers.py. Defaults to None.
211+
Use "none" to enable single-step relation extraction that scores all
212+
entity pair combinations directly without adjacency filtering.
207213
triples_layer (str, optional): Name of triples layer,
208214
see gliner.modeling.multitask.triples_layers.py. Defaults to None.
209215
embed_rel_token (bool, optional): Whether to embed relation tokens. Defaults to True.
210216
rel_token_index (int, optional): Index of relation token. Defaults to -1.
211217
rel_token (str, optional): Relation marker token. Defaults to "<<REL>>".
212218
adjacency_loss_coef (float, optional): Adjacency modeling loss coefficient. Defaults to 1.0.
213219
relation_loss_coef (float, optional): Relation representaton loss coefficient. Defaults to 1.0.
220+
augment_data_prob (float, optional): Probability of applying data augmentation
221+
to an example. Defaults to 0.0 (disabled).
222+
augment_ent_drop_prob (tuple, optional): Range (min, max) from which to sample
223+
the per-type entity drop probability. Defaults to (0.0, 0.4).
224+
augment_rel_drop_prob (tuple, optional): Range (min, max) from which to sample
225+
the per-type relation drop probability. Defaults to (0.0, 0.4).
214226
**kwargs: Additional keyword arguments passed to UniEncoderConfig.
215227
216228
Raises:
@@ -225,6 +237,10 @@ def __init__(
225237
self.rel_token = rel_token
226238
self.adjacency_loss_coef = adjacency_loss_coef
227239
self.relation_loss_coef = relation_loss_coef
240+
self.augment_data_prob = augment_data_prob
241+
self.augment_ent_drop_prob = tuple(augment_ent_drop_prob)
242+
self.augment_rel_drop_prob = tuple(augment_rel_drop_prob)
243+
self.augment_add_other_prob = augment_add_other_prob
228244

229245

230246
class UniEncoderSpanRelexConfig(UniEncoderRelexConfig):

gliner/data_processing/processor.py

Lines changed: 199 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,95 @@ def __init__(self, config, tokenizer, words_splitter):
14171417
super().__init__(config, tokenizer, words_splitter)
14181418
self.rel_token = config.rel_token
14191419

1420+
def augment_example(self, example, ner_negatives=None, other_keyword="other"):
1421+
"""Apply data augmentation by randomly dropping entity/relation types.
1422+
1423+
For each example (triggered by augment_data_prob):
1424+
- Sample a per-type drop probability from augment_ent_drop_prob range
1425+
and drop each entity type with that probability.
1426+
- Sample a per-type drop probability from augment_rel_drop_prob range
1427+
and drop each relation type with that probability.
1428+
- With augment_add_other_prob, entities whose type is dropped get their
1429+
type replaced with "other" so they still participate in relation
1430+
extraction. If "other" is not added and the entity has no active
1431+
(non-dropped) relations, it is dropped entirely.
1432+
- Relations with dropped types are removed; relation indices are remapped
1433+
after entity removal.
1434+
1435+
Args:
1436+
example: Dictionary with 'ner' and 'relations' keys.
1437+
ner_negatives: Pool of negative entity types (unused, kept for API compat).
1438+
other_keyword: Replacement type label for dropped entity types.
1439+
1440+
Returns:
1441+
A (possibly modified) copy of the example with '_dropped_ent_types'
1442+
and '_dropped_rel_types' metadata keys.
1443+
"""
1444+
ner = example.get("ner", [])
1445+
relations = example.get("relations", [])
1446+
1447+
if not ner:
1448+
return example
1449+
1450+
ent_drop_prob = random.uniform(*self.config.augment_ent_drop_prob)
1451+
rel_drop_prob = random.uniform(*self.config.augment_rel_drop_prob)
1452+
add_other = random.random() < self.config.augment_add_other_prob
1453+
1454+
all_ent_types = set(e[-1] for e in ner)
1455+
all_rel_types = set(r[-1] for r in relations) if relations else set()
1456+
1457+
# "other" is exempt from dropping since it's our replacement label
1458+
dropped_ent_types = {t for t in all_ent_types if t != other_keyword and random.random() < ent_drop_prob}
1459+
dropped_rel_types = {t for t in all_rel_types if random.random() < rel_drop_prob}
1460+
1461+
if not dropped_ent_types and not dropped_rel_types:
1462+
return example
1463+
1464+
# Determine which entities participate in non-dropped relations
1465+
entity_has_active_rel = set()
1466+
if relations:
1467+
for head_idx, tail_idx, rel_type in relations:
1468+
if rel_type not in dropped_rel_types:
1469+
entity_has_active_rel.add(head_idx)
1470+
entity_has_active_rel.add(tail_idx)
1471+
1472+
# Process entities: replace with "other", keep, or drop
1473+
new_ner = []
1474+
old_to_new_idx = {}
1475+
1476+
for i, ent in enumerate(ner):
1477+
ent_type = ent[-1]
1478+
1479+
if ent_type in dropped_ent_types and i not in entity_has_active_rel:
1480+
# Entity type dropped and no active relations → drop entity
1481+
continue
1482+
if ent_type in dropped_ent_types and not add_other:
1483+
# Entity type dropped and "other" not enabled → drop entity
1484+
continue
1485+
1486+
old_to_new_idx[i] = len(new_ner)
1487+
if ent_type in dropped_ent_types and add_other:
1488+
# Replace dropped type with "other"
1489+
new_ner.append(list(ent[:-1]) + [other_keyword])
1490+
else:
1491+
new_ner.append(ent)
1492+
1493+
# Update relations: drop relations with dropped types, remap entity indices
1494+
new_relations = []
1495+
for head_idx, tail_idx, rel_type in relations:
1496+
if rel_type in dropped_rel_types:
1497+
continue
1498+
if head_idx in old_to_new_idx and tail_idx in old_to_new_idx:
1499+
new_relations.append([old_to_new_idx[head_idx], old_to_new_idx[tail_idx], rel_type])
1500+
1501+
result = dict(example)
1502+
result["ner"] = new_ner
1503+
result["relations"] = new_relations
1504+
result["_dropped_ent_types"] = dropped_ent_types
1505+
result["_dropped_rel_types"] = dropped_rel_types
1506+
1507+
return result
1508+
14201509
def batch_generate_class_mappings(
14211510
self,
14221511
batch_list: List[Dict],
@@ -1456,6 +1545,10 @@ def batch_generate_class_mappings(
14561545
max_neg_type_ratio = int(self.config.max_neg_type_ratio)
14571546
neg_type_ratio = random.randint(0, max_neg_type_ratio) if max_neg_type_ratio else 0
14581547

1548+
# Augmentation metadata (set by augment_example)
1549+
dropped_ent_types = b.get("_dropped_ent_types", set())
1550+
dropped_rel_types = b.get("_dropped_rel_types", set())
1551+
14591552
# Process NER types
14601553
if "ner_negatives" in b:
14611554
negs_i = b["ner_negatives"]
@@ -1465,7 +1558,9 @@ def batch_generate_class_mappings(
14651558
if "ner_labels" in b:
14661559
types = b["ner_labels"]
14671560
else:
1468-
types = list(set([el[-1] for el in b["ner"]] + negs_i))
1561+
# Exclude dropped entity types ("other" replacements are already in ner)
1562+
ent_types = [el[-1] for el in b["ner"] if el[-1] not in dropped_ent_types]
1563+
types = list(set(ent_types + negs_i))
14691564
random.shuffle(types)
14701565
types = types[: int(self.config.max_types)]
14711566

@@ -1483,7 +1578,9 @@ def batch_generate_class_mappings(
14831578
if "rel_labels" in b:
14841579
rel_types = b["rel_labels"]
14851580
else:
1486-
rel_types = list(set([el[-1] for el in b.get("relations", [])] + rel_negs_i))
1581+
# Exclude dropped relation types
1582+
active_rel_types = [el[-1] for el in b.get("relations", []) if el[-1] not in dropped_rel_types]
1583+
rel_types = list(set(active_rel_types + rel_negs_i))
14871584
random.shuffle(rel_types)
14881585
rel_types = rel_types[: int(self.config.max_types)]
14891586

@@ -1525,6 +1622,15 @@ def collate_raw_batch(
15251622
Dictionary containing collated batch data for joint entity and
15261623
relation extraction.
15271624
"""
1625+
# Apply data augmentation if enabled (only during dynamic mapping generation)
1626+
augment_prob = getattr(self.config, 'augment_data_prob', 0.0)
1627+
if augment_prob > 0.0 and class_to_ids is None and entity_types is None:
1628+
if ner_negatives is None:
1629+
ner_negatives = get_negatives(batch_list, sampled_neg=100, key="ner")
1630+
batch_list = [
1631+
self.augment_example(b, ner_negatives) if random.random() < augment_prob else b
1632+
for b in batch_list
1633+
]
15281634
if class_to_ids is None and entity_types is None:
15291635
# Dynamically infer per-example mappings
15301636
class_to_ids, id_to_classes, rel_class_to_ids, rel_id_to_classes = self.batch_generate_class_mappings(
@@ -1616,11 +1722,16 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_
16161722
span_to_idx = {(s, e): i for i, (s, e) in enumerate(spans_idx.tolist())}
16171723

16181724
# Create entity index mapping (from original entity list to span indices)
1725+
# and compact indices (0, 1, 2, ...) matching target_span_rep ordering
16191726
entity_to_span_idx = {}
1727+
entity_to_compact_idx = {}
1728+
compact_idx = 0
16201729
if ner is not None:
16211730
for ent_idx, (start, end, _) in enumerate(ner): # (start, end, label)
16221731
if (start, end) in span_to_idx and end < num_tokens:
16231732
entity_to_span_idx[ent_idx] = span_to_idx[(start, end)]
1733+
entity_to_compact_idx[ent_idx] = compact_idx
1734+
compact_idx += 1
16241735

16251736
# Process relations
16261737
rel_idx_list = []
@@ -1630,9 +1741,9 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_
16301741
for rel in relations:
16311742
head_idx, tail_idx, rel_type = rel
16321743

1633-
# Check if both entities are valid and map to span indices
1634-
if head_idx in entity_to_span_idx and tail_idx in entity_to_span_idx and rel_type in rel_classes_to_id:
1635-
rel_idx_list.append([head_idx, tail_idx])
1744+
# Use compact indices so rel_idx aligns with target_span_rep positions
1745+
if head_idx in entity_to_compact_idx and tail_idx in entity_to_compact_idx and rel_type in rel_classes_to_id:
1746+
rel_idx_list.append([entity_to_compact_idx[head_idx], entity_to_compact_idx[tail_idx]])
16361747
rel_label_list.append(rel_classes_to_id[rel_type])
16371748

16381749
# Convert to tensors
@@ -1696,28 +1807,40 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids
16961807
"rel_id_to_classes": rel_id_to_classes,
16971808
}
16981809

1699-
def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=2.0):
1810+
def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=(1.0, 10.0)):
17001811
"""Create relation labels with negative pair sampling.
17011812
17021813
Overrides the span-based version to work with token-level entity representations.
17031814
Uses entities_id count instead of span_label for entity counting.
17041815
1816+
When relations_layer is "none", generates labels for ALL entity pair
1817+
combinations (no adjacency matrix), matching the order produced by
1818+
build_all_entity_pairs. Otherwise, uses adjacency-based pair sampling.
1819+
17051820
Args:
17061821
batch: Batch dictionary containing entities and relations.
17071822
add_reversed_negatives: If True, add reversed direction pairs as negatives.
17081823
add_random_negatives: If True, add random entity pairs as negatives.
1709-
negative_ratio: Ratio of negative to positive pairs.
1824+
negative_ratio: Ratio of negative to positive pairs. Can be a float for a fixed ratio
1825+
or a (min, max) tuple to sample a random ratio per example.
17101826
17111827
Returns:
17121828
Tuple containing:
1713-
- adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities])
1829+
- adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities]).
1830+
None when relations_layer is "none".
17141831
- rel_matrix: Multi-hot relation labels (shape: [B, max_pairs, num_relation_classes])
17151832
"""
17161833
B = len(batch["tokens"])
17171834
span_mask = batch["span_mask"]
17181835

1719-
# Count entities per sample (differs from span-based which uses span_label)
1720-
batch_ents = span_mask.long().squeeze(-1).sum(-1)
1836+
# For span-based models, span_mask covers all candidate spans (L*max_width),
1837+
# but rel_idx uses compact entity indices (0..num_annotated-1), so we must
1838+
# count annotated entities instead. For token-level models, span_mask is already
1839+
# sized by entity count, so the original formula works.
1840+
if "span_label" in batch and batch["span_label"] is not None:
1841+
batch_ents = (batch["span_label"] > 0).sum(-1)
1842+
else:
1843+
batch_ents = span_mask.long().squeeze(-1).sum(-1)
17211844
max_En = max(batch_ents.max().item(), 1)
17221845

17231846
rel_class_to_ids = batch["rel_class_to_ids"]
@@ -1726,9 +1849,16 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
17261849
else:
17271850
C = len(rel_class_to_ids) if rel_class_to_ids else 0
17281851

1852+
single_step = getattr(self.config, "relations_layer", None) == "none"
1853+
17291854
if C == 0:
1855+
if single_step:
1856+
return None, torch.zeros(B, 1, 1, dtype=torch.float)
17301857
return torch.zeros(B, max_En, max_En, dtype=torch.float), torch.zeros(B, 1, 1, dtype=torch.float)
17311858

1859+
if single_step:
1860+
return self._create_single_step_relation_labels(batch, batch_ents, C)
1861+
17321862
adj_matrix = torch.zeros(B, max_En, max_En, dtype=torch.float)
17331863

17341864
all_pairs_info = []
@@ -1758,7 +1888,11 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
17581888
# Generate negative pairs
17591889
negative_pairs = set()
17601890
num_positives = len(positive_pairs)
1761-
target_negatives = int(num_positives * negative_ratio)
1891+
if isinstance(negative_ratio, (tuple, list)):
1892+
ratio = random.uniform(negative_ratio[0], negative_ratio[1])
1893+
else:
1894+
ratio = negative_ratio
1895+
target_negatives = max(1, int(num_positives * ratio))
17621896

17631897
if add_reversed_negatives:
17641898
for e1, e2 in positive_pairs:
@@ -1811,6 +1945,60 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
18111945

18121946
return adj_matrix, rel_matrix
18131947

1948+
def _create_single_step_relation_labels(self, batch, batch_ents, C):
1949+
"""Create relation labels for single-step mode (all entity pair combinations).
1950+
1951+
Generates labels for ALL directed pairs (i, j) where i != j among entities,
1952+
matching the order produced by build_all_entity_pairs.
1953+
1954+
Args:
1955+
batch: Batch dictionary containing entities and relations.
1956+
batch_ents: Tensor of entity counts per example.
1957+
C: Number of relation classes.
1958+
1959+
Returns:
1960+
Tuple of (None, rel_matrix) where rel_matrix has shape
1961+
(B, max_pairs, C) with max_pairs = max(N_i * (N_i - 1)).
1962+
"""
1963+
B = len(batch["tokens"])
1964+
1965+
# Build pair-to-index mapping and collect labels
1966+
max_total_pairs = 0
1967+
all_pair_maps = []
1968+
1969+
for i in range(B):
1970+
N = batch_ents[i].item()
1971+
# All (e1, e2) pairs where e1 != e2, ordered as build_all_entity_pairs produces:
1972+
# (0,1), (0,2), ..., (1,0), (1,2), ..., i.e., sorted by (e1, e2)
1973+
pair_to_idx = {}
1974+
idx = 0
1975+
for e1 in range(N):
1976+
for e2 in range(N):
1977+
if e1 != e2:
1978+
pair_to_idx[(e1, e2)] = idx
1979+
idx += 1
1980+
all_pair_maps.append(pair_to_idx)
1981+
max_total_pairs = max(max_total_pairs, idx)
1982+
1983+
max_total_pairs = max(max_total_pairs, 1)
1984+
rel_matrix = torch.zeros(B, max_total_pairs, C, dtype=torch.float)
1985+
1986+
for i in range(B):
1987+
N = batch_ents[i].item()
1988+
rel_idx_i = batch["rel_idx"][i]
1989+
rel_label_i = batch["rel_label"][i]
1990+
pair_to_idx = all_pair_maps[i]
1991+
1992+
for k in range(rel_label_i.shape[0]):
1993+
if rel_label_i[k] > 0:
1994+
e1 = rel_idx_i[k, 0].item()
1995+
e2 = rel_idx_i[k, 1].item()
1996+
pair_key = (e1, e2)
1997+
if pair_key in pair_to_idx:
1998+
rel_matrix[i, pair_to_idx[pair_key], rel_label_i[k].item() - 1] = 1.0
1999+
2000+
return None, rel_matrix
2001+
18142002
def prepare_inputs(
18152003
self,
18162004
texts: Sequence[Sequence[str]],

0 commit comments

Comments
 (0)