@@ -1417,6 +1417,95 @@ def __init__(self, config, tokenizer, words_splitter):
14171417 super ().__init__ (config , tokenizer , words_splitter )
14181418 self .rel_token = config .rel_token
14191419
def augment_example(self, example, ner_negatives=None, other_keyword="other"):
    """Randomly drop entity/relation types from one example for augmentation.

    Per call:
      - An entity drop probability is sampled with
        ``random.uniform(*self.config.augment_ent_drop_prob)`` and every
        distinct entity type except ``other_keyword`` is dropped
        independently with that probability.
      - A relation drop probability is sampled the same way from
        ``self.config.augment_rel_drop_prob`` and applied to every distinct
        relation type.
      - With probability ``self.config.augment_add_other_prob`` the
        "add other" mode is enabled for this example.

    An entity whose type was dropped is kept (relabelled to
    ``other_keyword``) only when "add other" is enabled AND the entity is
    the head or tail of at least one relation whose type was not dropped.
    In every other case the entity is removed. Relations with a dropped
    type, or with a removed endpoint, are removed; the remaining relations
    have their head/tail indices remapped to positions in the new entity
    list.

    Args:
        example: Dictionary read through the 'ner' and 'relations' keys.
            The last element of each entity / relation is its type, and
            each relation is unpacked as ``(head_idx, tail_idx, rel_type)``.
        ner_negatives: Unused; kept so existing callers keep working.
        other_keyword: Replacement label for kept entities of dropped types.

    Returns:
        The original ``example`` object, unchanged, when it has no entities
        or when nothing was dropped. Otherwise a shallow copy with new
        'ner' and 'relations' lists plus '_dropped_ent_types' and
        '_dropped_rel_types' (sets of the dropped type labels).
    """
    ner = example.get("ner", [])
    # A missing key and an explicit None are both treated as "no relations",
    # so the loops below never iterate over None.
    relations = example.get("relations") or []

    if not ner:
        return example

    ent_drop_prob = random.uniform(*self.config.augment_ent_drop_prob)
    rel_drop_prob = random.uniform(*self.config.augment_rel_drop_prob)
    add_other = random.random() < self.config.augment_add_other_prob

    # De-duplicate in first-seen order. Iterating a set of strings here
    # would consume RNG draws in hash order, which varies between processes
    # and would make seeded runs non-reproducible.
    all_ent_types = list(dict.fromkeys(ent[-1] for ent in ner))
    all_rel_types = list(dict.fromkeys(rel[-1] for rel in relations))

    # other_keyword is exempt from dropping since it is the replacement label.
    dropped_ent_types = {
        ent_type
        for ent_type in all_ent_types
        if ent_type != other_keyword and random.random() < ent_drop_prob
    }
    dropped_rel_types = {
        rel_type for rel_type in all_rel_types if random.random() < rel_drop_prob
    }

    if not dropped_ent_types and not dropped_rel_types:
        return example

    # Entities that are an endpoint of at least one surviving relation type.
    entity_has_active_rel = set()
    for head_idx, tail_idx, rel_type in relations:
        if rel_type not in dropped_rel_types:
            entity_has_active_rel.update((head_idx, tail_idx))

    # Keep, relabel, or remove each entity, recording old -> new positions.
    new_ner = []
    old_to_new_idx = {}
    for old_idx, ent in enumerate(ner):
        if ent[-1] in dropped_ent_types:
            if not add_other or old_idx not in entity_has_active_rel:
                continue
            kept = list(ent[:-1]) + [other_keyword]
        else:
            kept = ent
        old_to_new_idx[old_idx] = len(new_ner)
        new_ner.append(kept)

    # Keep relations of surviving types whose endpoints both survived.
    new_relations = [
        [old_to_new_idx[head_idx], old_to_new_idx[tail_idx], rel_type]
        for head_idx, tail_idx, rel_type in relations
        if rel_type not in dropped_rel_types
        and head_idx in old_to_new_idx
        and tail_idx in old_to_new_idx
    ]

    result = dict(example)
    result["ner"] = new_ner
    result["relations"] = new_relations
    result["_dropped_ent_types"] = dropped_ent_types
    result["_dropped_rel_types"] = dropped_rel_types

    return result
1508+
14201509 def batch_generate_class_mappings (
14211510 self ,
14221511 batch_list : List [Dict ],
@@ -1456,6 +1545,10 @@ def batch_generate_class_mappings(
14561545 max_neg_type_ratio = int (self .config .max_neg_type_ratio )
14571546 neg_type_ratio = random .randint (0 , max_neg_type_ratio ) if max_neg_type_ratio else 0
14581547
1548+ # Augmentation metadata (set by augment_example)
1549+ dropped_ent_types = b .get ("_dropped_ent_types" , set ())
1550+ dropped_rel_types = b .get ("_dropped_rel_types" , set ())
1551+
14591552 # Process NER types
14601553 if "ner_negatives" in b :
14611554 negs_i = b ["ner_negatives" ]
@@ -1465,7 +1558,9 @@ def batch_generate_class_mappings(
14651558 if "ner_labels" in b :
14661559 types = b ["ner_labels" ]
14671560 else :
1468- types = list (set ([el [- 1 ] for el in b ["ner" ]] + negs_i ))
1561+ # Exclude dropped entity types ("other" replacements are already in ner)
1562+ ent_types = [el [- 1 ] for el in b ["ner" ] if el [- 1 ] not in dropped_ent_types ]
1563+ types = list (set (ent_types + negs_i ))
14691564 random .shuffle (types )
14701565 types = types [: int (self .config .max_types )]
14711566
@@ -1483,7 +1578,9 @@ def batch_generate_class_mappings(
14831578 if "rel_labels" in b :
14841579 rel_types = b ["rel_labels" ]
14851580 else :
1486- rel_types = list (set ([el [- 1 ] for el in b .get ("relations" , [])] + rel_negs_i ))
1581+ # Exclude dropped relation types
1582+ active_rel_types = [el [- 1 ] for el in b .get ("relations" , []) if el [- 1 ] not in dropped_rel_types ]
1583+ rel_types = list (set (active_rel_types + rel_negs_i ))
14871584 random .shuffle (rel_types )
14881585 rel_types = rel_types [: int (self .config .max_types )]
14891586
@@ -1525,6 +1622,15 @@ def collate_raw_batch(
15251622 Dictionary containing collated batch data for joint entity and
15261623 relation extraction.
15271624 """
1625+ # Apply data augmentation if enabled (only during dynamic mapping generation)
1626+ augment_prob = getattr (self .config , 'augment_data_prob' , 0.0 )
1627+ if augment_prob > 0.0 and class_to_ids is None and entity_types is None :
1628+ if ner_negatives is None :
1629+ ner_negatives = get_negatives (batch_list , sampled_neg = 100 , key = "ner" )
1630+ batch_list = [
1631+ self .augment_example (b , ner_negatives ) if random .random () < augment_prob else b
1632+ for b in batch_list
1633+ ]
15281634 if class_to_ids is None and entity_types is None :
15291635 # Dynamically infer per-example mappings
15301636 class_to_ids , id_to_classes , rel_class_to_ids , rel_id_to_classes = self .batch_generate_class_mappings (
@@ -1616,11 +1722,16 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_
16161722 span_to_idx = {(s , e ): i for i , (s , e ) in enumerate (spans_idx .tolist ())}
16171723
16181724 # Create entity index mapping (from original entity list to span indices)
1725+ # and compact indices (0, 1, 2, ...) matching target_span_rep ordering
16191726 entity_to_span_idx = {}
1727+ entity_to_compact_idx = {}
1728+ compact_idx = 0
16201729 if ner is not None :
16211730 for ent_idx , (start , end , _ ) in enumerate (ner ): # (start, end, label)
16221731 if (start , end ) in span_to_idx and end < num_tokens :
16231732 entity_to_span_idx [ent_idx ] = span_to_idx [(start , end )]
1733+ entity_to_compact_idx [ent_idx ] = compact_idx
1734+ compact_idx += 1
16241735
16251736 # Process relations
16261737 rel_idx_list = []
@@ -1630,9 +1741,9 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_
16301741 for rel in relations :
16311742 head_idx , tail_idx , rel_type = rel
16321743
1633- # Check if both entities are valid and map to span indices
1634- if head_idx in entity_to_span_idx and tail_idx in entity_to_span_idx and rel_type in rel_classes_to_id :
1635- rel_idx_list .append ([head_idx , tail_idx ])
1744+ # Use compact indices so rel_idx aligns with target_span_rep positions
1745+ if head_idx in entity_to_compact_idx and tail_idx in entity_to_compact_idx and rel_type in rel_classes_to_id :
1746+ rel_idx_list .append ([entity_to_compact_idx [ head_idx ], entity_to_compact_idx [ tail_idx ] ])
16361747 rel_label_list .append (rel_classes_to_id [rel_type ])
16371748
16381749 # Convert to tensors
@@ -1696,28 +1807,40 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids
16961807 "rel_id_to_classes" : rel_id_to_classes ,
16971808 }
16981809
1699- def create_relation_labels (self , batch , add_reversed_negatives = True , add_random_negatives = True , negative_ratio = 2.0 ):
1810+ def create_relation_labels (self , batch , add_reversed_negatives = True , add_random_negatives = True , negative_ratio = ( 1.0 , 10.0 ) ):
17001811 """Create relation labels with negative pair sampling.
17011812
17021813 Overrides the span-based version to work with token-level entity representations.
17031814 Uses entities_id count instead of span_label for entity counting.
17041815
1816+ When relations_layer is "none", generates labels for ALL entity pair
1817+ combinations (no adjacency matrix), matching the order produced by
1818+ build_all_entity_pairs. Otherwise, uses adjacency-based pair sampling.
1819+
17051820 Args:
17061821 batch: Batch dictionary containing entities and relations.
17071822 add_reversed_negatives: If True, add reversed direction pairs as negatives.
17081823 add_random_negatives: If True, add random entity pairs as negatives.
1709- negative_ratio: Ratio of negative to positive pairs.
1824+ negative_ratio: Ratio of negative to positive pairs. Can be a float for a fixed ratio
1825+ or a (min, max) tuple to sample a random ratio per example.
17101826
17111827 Returns:
17121828 Tuple containing:
1713- - adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities])
1829+ - adj_matrix: Adjacency matrix (shape: [B, max_entities, max_entities]).
1830+ None when relations_layer is "none".
17141831 - rel_matrix: Multi-hot relation labels (shape: [B, max_pairs, num_relation_classes])
17151832 """
17161833 B = len (batch ["tokens" ])
17171834 span_mask = batch ["span_mask" ]
17181835
1719- # Count entities per sample (differs from span-based which uses span_label)
1720- batch_ents = span_mask .long ().squeeze (- 1 ).sum (- 1 )
1836+ # For span-based models, span_mask covers all candidate spans (L*max_width),
1837+ # but rel_idx uses compact entity indices (0..num_annotated-1), so we must
1838+ # count annotated entities instead. For token-level models, span_mask is already
1839+ # sized by entity count, so the original formula works.
1840+ if "span_label" in batch and batch ["span_label" ] is not None :
1841+ batch_ents = (batch ["span_label" ] > 0 ).sum (- 1 )
1842+ else :
1843+ batch_ents = span_mask .long ().squeeze (- 1 ).sum (- 1 )
17211844 max_En = max (batch_ents .max ().item (), 1 )
17221845
17231846 rel_class_to_ids = batch ["rel_class_to_ids" ]
@@ -1726,9 +1849,16 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
17261849 else :
17271850 C = len (rel_class_to_ids ) if rel_class_to_ids else 0
17281851
1852+ single_step = getattr (self .config , "relations_layer" , None ) == "none"
1853+
17291854 if C == 0 :
1855+ if single_step :
1856+ return None , torch .zeros (B , 1 , 1 , dtype = torch .float )
17301857 return torch .zeros (B , max_En , max_En , dtype = torch .float ), torch .zeros (B , 1 , 1 , dtype = torch .float )
17311858
1859+ if single_step :
1860+ return self ._create_single_step_relation_labels (batch , batch_ents , C )
1861+
17321862 adj_matrix = torch .zeros (B , max_En , max_En , dtype = torch .float )
17331863
17341864 all_pairs_info = []
@@ -1758,7 +1888,11 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
17581888 # Generate negative pairs
17591889 negative_pairs = set ()
17601890 num_positives = len (positive_pairs )
1761- target_negatives = int (num_positives * negative_ratio )
1891+ if isinstance (negative_ratio , (tuple , list )):
1892+ ratio = random .uniform (negative_ratio [0 ], negative_ratio [1 ])
1893+ else :
1894+ ratio = negative_ratio
1895+ target_negatives = max (1 , int (num_positives * ratio ))
17621896
17631897 if add_reversed_negatives :
17641898 for e1 , e2 in positive_pairs :
@@ -1811,6 +1945,60 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
18111945
18121946 return adj_matrix , rel_matrix
18131947
def _create_single_step_relation_labels(self, batch, batch_ents, C):
    """Create relation labels for single-step mode (all directed entity pairs).

    For an example with N entities, every directed pair (e1, e2) with
    e1 != e2 gets one row, ordered by (e1, e2): (0,1), (0,2), ..., (1,0),
    (1,2), ... The original docstring states this matches the order
    produced by build_all_entity_pairs (not visible here; confirm there).

    Args:
        batch: Batch dictionary. Read keys: 'tokens' (its length is the
            batch size), 'rel_idx' (per example, rows of [head, tail]
            entity indices) and 'rel_label' (per example, one label per
            row of 'rel_idx'). A label v > 0 sets class column v - 1;
            labels <= 0 are skipped.
        batch_ents: Tensor of entity counts per example.
        C: Number of relation classes.

    Returns:
        Tuple of (None, rel_matrix) where rel_matrix is a float tensor of
        shape (B, max_pairs, C) with max_pairs = max(1, max_i N_i * (N_i - 1)).
        Rows beyond an example's own pair count stay zero.
    """
    B = len(batch["tokens"])
    ent_counts = [batch_ents[i].item() for i in range(B)]

    # At least one row so the tensor is never empty along the pair axis.
    max_total_pairs = max(max((n * (n - 1) for n in ent_counts), default=0), 1)
    rel_matrix = torch.zeros(B, max_total_pairs, C, dtype=torch.float)

    for i, num_ents in enumerate(ent_counts):
        # Convert once to plain Python values instead of calling .item()
        # on individual tensor elements inside the loop.
        pairs = batch["rel_idx"][i].tolist()
        labels = batch["rel_label"][i].tolist()

        for pair, label in zip(pairs, labels):
            if label <= 0:
                continue
            e1, e2 = pair[0], pair[1]
            # Only distinct, in-range entity pairs have a row.
            if e1 == e2 or not (0 <= e1 < num_ents and 0 <= e2 < num_ents):
                continue
            # Row of (e1, e2) in (e1, e2)-sorted order with the diagonal
            # removed: each head owns num_ents - 1 rows, and tails after
            # the head shift down by one.
            row = e1 * (num_ents - 1) + e2 - (1 if e2 > e1 else 0)
            rel_matrix[i, row, label - 1] = 1.0

    return None, rel_matrix
2001+
18142002 def prepare_inputs (
18152003 self ,
18162004 texts : Sequence [Sequence [str ]],
0 commit comments