@@ -436,6 +436,73 @@ def test_nest_answers(mock_reader: ExtractiveReader):
436436 assert no_answer .score == pytest .approx (expected_no_answer )
437437
438438
439+ def test_postprocess_filters_probs_when_answers_per_seq_exceeds_valid_spans (mock_reader : ExtractiveReader ):
440+ # Setup: seq_length=4, attention-masked positions 0-1, so only context positions 2-3
441+ # remain. The number of valid (start <= end) spans is therefore 3: (2,2), (2,3), (3,3).
442+ # We ask for 5 answers per sequence — the 2 extra topk picks fall in the masked region
443+ # and become probability 0 after sigmoid.
444+ start = torch .zeros ((2 , 4 ))
445+ end = torch .zeros ((2 , 4 ))
446+ sequence_ids = torch .ones ((2 , 4 ))
447+ attention_mask = torch .ones ((2 , 4 ))
448+ attention_mask [:, :2 ] = 0
449+ encoding = Mock ()
450+ encoding .token_to_chars = lambda i : (int (i ), int (i ) + 1 )
451+
452+ start_candidates , end_candidates , probs = mock_reader ._postprocess (
453+ start = start ,
454+ end = end ,
455+ sequence_ids = sequence_ids ,
456+ attention_mask = attention_mask ,
457+ answers_per_seq = 5 ,
458+ encodings = [encoding , encoding ],
459+ )
460+
461+ # Per-sequence lengths must all agree so the downstream strict=True zips don't blow up.
462+ assert len (start_candidates ) == len (end_candidates ) == len (probs ) == 2
463+ for i in range (2 ):
464+ assert len (start_candidates [i ]) == len (end_candidates [i ]) == len (probs [i ]) == 3
465+ # All retained probabilities are valid (> 0); masked candidates were dropped.
466+ assert torch .all (probs [i ] > 0 )
467+
468+
469+ def test_nest_answers_accepts_variable_length_probability_rows (mock_reader : ExtractiveReader ):
470+ # Mirrors the new contract from _postprocess: `probabilities` is a list of 1D tensors
471+ # whose lengths match the per-sequence start/end lists.
472+ start = [[0 , 1 ], [0 ]]
473+ end = [[5 , 6 ], [5 ]]
474+ probabilities = [torch .tensor ([0.8 , 0.6 ]), torch .tensor ([0.7 ])]
475+ query_ids = [0 , 0 ]
476+ document_ids = [0 , 1 ]
477+
478+ nested_answers = mock_reader ._nest_answers (
479+ start = start ,
480+ end = end ,
481+ probabilities = probabilities ,
482+ flattened_documents = example_documents [0 ][:2 ],
483+ queries = example_queries [:1 ],
484+ answers_per_seq = 2 ,
485+ top_k = 10 ,
486+ score_threshold = None ,
487+ query_ids = query_ids ,
488+ document_ids = document_ids ,
489+ no_answer = False ,
490+ overlap_threshold = None ,
491+ )
492+
493+ assert len (nested_answers ) == 1
494+ sorted_answers = sorted (nested_answers [0 ], key = lambda a : a .score , reverse = True )
495+
496+ docs = example_documents [0 ][:2 ]
497+ expected = [
498+ (pytest .approx (0.8 ), docs [0 ].id , "Angel" ),
499+ (pytest .approx (0.7 ), docs [1 ].id , "Olaf " ),
500+ (pytest .approx (0.6 ), docs [0 ].id , "ngela" ),
501+ ]
502+ actual = [(a .score , a .document .id , a .data ) for a in sorted_answers ]
503+ assert actual == expected
504+
505+
439506def test_add_answer_page_number_returns_same_answer (mock_reader : ExtractiveReader , caplog ):
440507 # answer.document_offset is None
441508 document = Document (content = "I thought a lot about this. The answer is 42." , meta = {"page_number" : 5 })
0 commit comments