@@ -20,6 +20,7 @@ class Span:
2020 class_probs: Optional dict of top-k class probabilities
2121 generated_labels: Optional list of generated labels (for generative decoders)
2222 """
23+
2324 start : int
2425 end : int
2526 entity_type : str
@@ -260,7 +261,7 @@ def _decode_batch_item(
260261 """
261262 # Mask probabilities to only include input spans (for efficiency)
262263 if input_spans_i is not None :
263- L , K_dim , C = probs_i .shape
264+ L , K_dim , _ = probs_i .shape
264265 span_filter = torch .zeros (L , K_dim , dtype = torch .bool , device = probs_i .device )
265266 for word_start , word_end in input_spans_i :
266267 width = word_end - word_start
@@ -358,18 +359,20 @@ class IDs to class names.
358359 if B == 1 :
359360 id_to_class_0 = self ._get_id_to_class_for_sample (id_to_classes , 0 )
360361 input_spans_0 = input_spans [0 ] if input_spans is not None else None
361- return [self ._decode_batch_item (
362- probs_i = probs [0 ],
363- tokens_i = tokens [0 ],
364- id_to_class_i = id_to_class_0 ,
365- K = K ,
366- threshold = threshold ,
367- flat_ner = flat_ner ,
368- multi_label = multi_label ,
369- span_label_map = span_label_maps [0 ],
370- return_class_probs = return_class_probs ,
371- input_spans_i = input_spans_0 ,
372- )]
362+ return [
363+ self ._decode_batch_item (
364+ probs_i = probs [0 ],
365+ tokens_i = tokens [0 ],
366+ id_to_class_i = id_to_class_0 ,
367+ K = K ,
368+ threshold = threshold ,
369+ flat_ner = flat_ner ,
370+ multi_label = multi_label ,
371+ span_label_map = span_label_maps [0 ],
372+ return_class_probs = return_class_probs ,
373+ input_spans_i = input_spans_0 ,
374+ )
375+ ]
373376
374377 # Apply input_spans mask at batch level (one mask, one multiply)
375378 if input_spans is not None :
@@ -392,9 +395,7 @@ class IDs to class names.
392395 return [[] for _ in range (B )]
393396
394397 # ONE vectorized valid-span check across entire batch
395- num_tokens = torch .tensor (
396- [len (t ) for t in tokens ], device = probs .device , dtype = torch .long
397- )
398+ num_tokens = torch .tensor ([len (t ) for t in tokens ], device = probs .device , dtype = torch .long )
398399 valid = (s_idx + k_idx + 1 ) <= num_tokens [b_idx ]
399400 b_idx = b_idx [valid ]
400401 s_idx = s_idx [valid ]
@@ -427,15 +428,11 @@ class IDs to class names.
427428 top_indices_list = all_top_indices .tolist ()
428429
429430 # Pre-resolve id_to_class mappings per batch item
430- id_to_class_per_item = [
431- self ._get_id_to_class_for_sample (id_to_classes , i ) for i in range (B )
432- ]
431+ id_to_class_per_item = [self ._get_id_to_class_for_sample (id_to_classes , i ) for i in range (B )]
433432
434433 # Group by batch item and build Span objects (pure Python)
435434 batch_spans : List [List [Span ]] = [[] for _ in range (B )]
436- for j , (b , s , k , c , flat_idx , score ) in enumerate (
437- zip (b_list , s_list , k_list , c_list , flat_idxs , scores )
438- ):
435+ for j , (b , s , k , c , flat_idx , score ) in enumerate (zip (b_list , s_list , k_list , c_list , flat_idxs , scores )):
439436 id_to_class_i = id_to_class_per_item [b ]
440437
441438 class_probs = None
@@ -445,16 +442,11 @@ class IDs to class names.
445442 class_name = id_to_class_i .get (idx + 1 , f"class_{ idx } " )
446443 class_probs [class_name ] = prob
447444
448- span = self ._build_span_tuple (
449- s , k , c , flat_idx , score , id_to_class_i , span_label_maps [b ], class_probs
450- )
445+ span = self ._build_span_tuple (s , k , c , flat_idx , score , id_to_class_i , span_label_maps [b ], class_probs )
451446 batch_spans [b ].append (span )
452447
453448 # Per-item greedy search (inherently sequential, but cheap pure Python)
454- return [
455- self .greedy_search (spans , flat_ner , multi_label = multi_label )
456- for spans in batch_spans
457- ]
449+ return [self .greedy_search (spans , flat_ner , multi_label = multi_label ) for spans in batch_spans ]
458450
459451 def decode (
460452 self ,
@@ -544,13 +536,7 @@ def _build_span_tuple(
544536 Span: Span object with entity properties.
545537 """
546538 ent_type = id_to_class [class_idx + 1 ] # +1 because 0 is <pad>
547- return Span (
548- start = start ,
549- end = start + width ,
550- entity_type = ent_type ,
551- score = score ,
552- class_probs = class_probs
553- )
539+ return Span (start = start , end = start + width , entity_type = ent_type , score = score , class_probs = class_probs )
554540
555541
556542class SpanGenerativeDecoder (BaseSpanDecoder ):
@@ -679,7 +665,7 @@ def _build_span_tuple(
679665 entity_type = ent_type ,
680666 score = score ,
681667 class_probs = class_probs ,
682- generated_labels = gen_ent_type
668+ generated_labels = gen_ent_type ,
683669 )
684670
685671 def decode_generative (
@@ -864,15 +850,8 @@ def _decode_relations_batch(
864850 # 3. Vectorized index-validity check
865851 head = rel_idx [..., 0 ] # (B, R)
866852 tail = rel_idx [..., 1 ] # (B, R)
867- num_spans = torch .tensor (
868- [len (s ) for s in spans ], device = rel_idx .device , dtype = head .dtype
869- ) # (B,)
870- valid = (
871- (head >= 0 )
872- & (tail >= 0 )
873- & (head < num_spans [:, None ])
874- & (tail < num_spans [:, None ])
875- ) # (B, R)
853+ num_spans = torch .tensor ([len (s ) for s in spans ], device = rel_idx .device , dtype = head .dtype ) # (B,)
854+ valid = (head >= 0 ) & (tail >= 0 ) & (head < num_spans [:, None ]) & (tail < num_spans [:, None ]) # (B, R)
876855 rel_probs = rel_probs * valid .unsqueeze (- 1 )
877856
878857 # 4. Single torch.where on the full (B, R, C) tensor
@@ -898,9 +877,7 @@ def _decode_relations_batch(
898877 mapping = rel_id_to_classes [b ] if is_list else rel_id_to_classes
899878 if c1 not in mapping :
900879 continue
901- relations [b ].append (
902- (int (head_list [k ]), mapping [c1 ], int (tail_list [k ]), scores [k ])
903- )
880+ relations [b ].append ((int (head_list [k ]), mapping [c1 ], int (tail_list [k ]), scores [k ]))
904881
905882 return relations
906883
@@ -955,13 +932,7 @@ def _build_span_tuple(
955932 Span: Span object with entity properties.
956933 """
957934 ent_type = id_to_class [class_idx + 1 ] # +1 because 0 is <pad>
958- return Span (
959- start = start ,
960- end = start + width ,
961- entity_type = ent_type ,
962- score = score ,
963- class_probs = class_probs
964- )
935+ return Span (start = start , end = start + width , entity_type = ent_type , score = score , class_probs = class_probs )
965936
966937 def _build_entity_span_to_decoded_idx (
967938 self ,
@@ -1151,6 +1122,7 @@ def decode(
11511122 rel_idx: Optional tensor of shape (batch_size, num_relations, 2).
11521123 rel_logits: Optional tensor of shape (batch_size, num_relations, num_relation_classes).
11531124 rel_mask: Optional boolean tensor of shape (batch_size, num_relations).
1125+ return_class_probs: Whether to include class probabilities in the decoded spans.
11541126 flat_ner: If True, applies greedy filtering for non-overlapping entities.
11551127 threshold: Minimum confidence score for entity predictions.
11561128 relation_threshold: Minimum confidence score for relation predictions.
@@ -1266,13 +1238,8 @@ def _calculate_span_score(
12661238 start_score = start_cpu [st ][cls_st ]
12671239 end_score = end_cpu [ed ][cls_ed ]
12681240 # The span score is the minimum value among all scores
1269- spn_score = min (min (ins ), start_score , end_score )
1270- span_i .append (Span (
1271- start = st ,
1272- end = ed ,
1273- entity_type = id_to_classes [cls_st + 1 ],
1274- score = spn_score
1275- ))
1241+ spn_score = min (* ins , start_score , end_score )
1242+ span_i .append (Span (start = st , end = ed , entity_type = id_to_classes [cls_st + 1 ], score = spn_score ))
12761243 return span_i
12771244
12781245 def _decode_from_spans (
@@ -1349,12 +1316,7 @@ class IDs to class names.
13491316 class_id = class_idx + 1 # Convert to 1-indexed
13501317 if class_id in id_to_class_i :
13511318 entity_type = id_to_class_i [class_id ]
1352- span_scores .append (Span (
1353- start = span_start ,
1354- end = span_end ,
1355- entity_type = entity_type ,
1356- score = prob
1357- ))
1319+ span_scores .append (Span (start = span_start , end = span_end , entity_type = entity_type , score = prob ))
13581320
13591321 # Apply greedy search to handle overlapping spans if needed
13601322 span_i = self .greedy_search (span_scores , flat_ner , multi_label )
@@ -1664,6 +1626,8 @@ def decode(
16641626 rel_id_to_classes: Optional mapping from relation class IDs to relation names.
16651627 If None, relation decoding is skipped and empty relation lists are returned.
16661628 Can be either a single Dict or List[Dict] for per-sample mappings.
1629+ entity_spans: Optional tensor of pre-computed entity spans to use instead
1630+ of decoding them from model_output.
16671631 Class IDs are 1-indexed.
16681632 **kwargs: Additional keyword arguments passed to the parent class decode method.
16691633
0 commit comments