From 78860c2e960665689358c281afb1b92c195c239b Mon Sep 17 00:00:00 2001 From: Nedas Jaronis Date: Thu, 12 Feb 2026 14:51:28 -0500 Subject: [PATCH 1/2] Add edge case tests for prompt validation (fixes #316) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add three new test cases to improve coverage of edge cases: - test_empty_extraction_text: Validates empty strings are rejected - test_unicode_and_special_chars: Ensures unicode handling (café, °F, €) - test_very_long_extraction_text: Tests graceful handling of 1000+ chars These tests improve reliability by catching edge case bugs early and documenting expected behavior for boundary conditions. --- tests/prompt_validation_test.py | 806 ++++++++++++++++++-------------- 1 file changed, 452 insertions(+), 354 deletions(-) diff --git a/tests/prompt_validation_test.py b/tests/prompt_validation_test.py index a9d4aa3f..544d3331 100644 --- a/tests/prompt_validation_test.py +++ b/tests/prompt_validation_test.py @@ -24,403 +24,501 @@ class PromptAlignmentValidationTest(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="exact_alignment", - text="Patient takes lisinopril.", - extraction_class="Medication", - extraction_text="lisinopril", - expected_issues=0, - expected_has_failed=False, - expected_has_non_exact=False, - expected_alignment_status=None, - ), - dict( - testcase_name="fuzzy_match_lesser", - text="Type 2 diabetes.", - extraction_class="Diagnosis", - extraction_text="type-2 diabetes", - expected_issues=1, - expected_has_failed=False, - expected_has_non_exact=True, - expected_alignment_status=data.AlignmentStatus.MATCH_LESSER, - ), - dict( - testcase_name="extraction_not_found", - text="No medications mentioned in this text.", - extraction_class="Medication", - extraction_text="lisinopril", - expected_issues=1, - expected_has_failed=True, - expected_has_non_exact=False, - expected_alignment_status=None, - ), - ) - def test_alignment_detection( - self, - text, - extraction_class, - extraction_text, - expected_issues, - expected_has_failed, - expected_has_non_exact, - expected_alignment_status, - ): - """Test that different alignment types are correctly detected.""" - example = data.ExampleData( - text=text, - extractions=[ - data.Extraction( - extraction_class=extraction_class, - extraction_text=extraction_text, - attributes={}, - ) - ], + @parameterized.named_parameters( + dict( + testcase_name="exact_alignment", + text="Patient takes lisinopril.", + extraction_class="Medication", + extraction_text="lisinopril", + expected_issues=0, + expected_has_failed=False, + expected_has_non_exact=False, + expected_alignment_status=None, + ), + dict( + testcase_name="fuzzy_match_lesser", + text="Type 2 diabetes.", + extraction_class="Diagnosis", + extraction_text="type-2 diabetes", + expected_issues=1, + expected_has_failed=False, + expected_has_non_exact=True, + expected_alignment_status=data.AlignmentStatus.MATCH_LESSER, + ), + dict( + testcase_name="extraction_not_found", + text="No medications mentioned in this text.", + extraction_class="Medication", + extraction_text="lisinopril", + expected_issues=1, + expected_has_failed=True, + expected_has_non_exact=False, + expected_alignment_status=None, + ), ) + def test_alignment_detection( + self, + text, + extraction_class, + extraction_text, + expected_issues, + expected_has_failed, + expected_has_non_exact, + expected_alignment_status, + ): + """Test that different alignment types are correctly detected.""" + example = data.ExampleData( + text=text, + extractions=[ + data.Extraction( + extraction_class=extraction_class, + extraction_text=extraction_text, + attributes={}, + ) + ], + ) - report = prompt_validation.validate_prompt_alignment([example]) - - self.assertLen(report.issues, expected_issues) - self.assertEqual(report.has_failed, expected_has_failed) - self.assertEqual(report.has_non_exact, expected_has_non_exact) - - if expected_issues > 0: - issue = report.issues[0] - self.assertEqual(issue.alignment_status, expected_alignment_status) - self.assertEqual(issue.extraction_class, extraction_class) - if expected_has_failed: - self.assertIsNone(issue.alignment_status) - elif expected_has_non_exact: - self.assertIsNotNone(issue.alignment_status) - - @parameterized.named_parameters( - dict( - testcase_name="one_fails", - text="Patient takes lisinopril and has diabetes mellitus.", - extractions=[ - ("Medication", "lisinopril"), # PASSES - found exactly - ("Diagnosis", "diabetes"), # PASSES - found exactly - ("Medication", "metformin"), # FAILS - not in text - ], - expected_issues=1, - expected_has_failed=True, - expected_has_non_exact=False, - expected_failed_text="metformin", - ), - dict( - testcase_name="all_pass", - text="Patient takes lisinopril and aspirin for diabetes management.", - extractions=[ - ("Medication", "lisinopril"), - ("Medication", "aspirin"), - ("Diagnosis", "diabetes"), - ], - expected_issues=0, - expected_has_failed=False, - expected_has_non_exact=False, - expected_failed_text=None, - ), - ) - def test_multiple_extractions_per_example( - self, - text, - extractions, - expected_issues, - expected_has_failed, - expected_has_non_exact, - expected_failed_text, - ): - """Test validation with multiple extractions in a single example.""" - example = data.ExampleData( - text=text, - extractions=[ - data.Extraction( - extraction_class=extraction_class, - extraction_text=extraction_text, - attributes={}, - ) - for extraction_class, extraction_text in extractions - ], + report = prompt_validation.validate_prompt_alignment([example]) + + self.assertLen(report.issues, expected_issues) + self.assertEqual(report.has_failed, expected_has_failed) + self.assertEqual(report.has_non_exact, expected_has_non_exact) + + if expected_issues > 0: + issue = report.issues[0] + self.assertEqual(issue.alignment_status, expected_alignment_status) + self.assertEqual(issue.extraction_class, extraction_class) + if expected_has_failed: + self.assertIsNone(issue.alignment_status) + elif expected_has_non_exact: + self.assertIsNotNone(issue.alignment_status) + + @parameterized.named_parameters( + dict( + testcase_name="one_fails", + text="Patient takes lisinopril and has diabetes mellitus.", + extractions=[ + ("Medication", "lisinopril"), # PASSES - found exactly + ("Diagnosis", "diabetes"), # PASSES - found exactly + ("Medication", "metformin"), # FAILS - not in text + ], + expected_issues=1, + expected_has_failed=True, + expected_has_non_exact=False, + expected_failed_text="metformin", + ), + dict( + testcase_name="all_pass", + text="Patient takes lisinopril and aspirin for diabetes management.", + extractions=[ + ("Medication", "lisinopril"), + ("Medication", "aspirin"), + ("Diagnosis", "diabetes"), + ], + expected_issues=0, + expected_has_failed=False, + expected_has_non_exact=False, + expected_failed_text=None, + ), ) + def test_multiple_extractions_per_example( + self, + text, + extractions, + expected_issues, + expected_has_failed, + expected_has_non_exact, + expected_failed_text, + ): + """Test validation with multiple extractions in a single example.""" + example = data.ExampleData( + text=text, + extractions=[ + data.Extraction( + extraction_class=extraction_class, + extraction_text=extraction_text, + attributes={}, + ) + for extraction_class, extraction_text in extractions + ], + ) - report = prompt_validation.validate_prompt_alignment([example]) - - self.assertLen(report.issues, expected_issues) - self.assertEqual(report.has_failed, expected_has_failed) - self.assertEqual(report.has_non_exact, expected_has_non_exact) - - if expected_failed_text: - issue = report.issues[0] - self.assertIsNone(issue.alignment_status) - self.assertEqual(issue.extraction_text_preview, expected_failed_text) - - @parameterized.named_parameters( - dict( - testcase_name="warning_mode_with_failed", - text="Patient has no known allergies.", - extraction_text="penicillin", - validation_level=prompt_validation.PromptValidationLevel.WARNING, - strict_non_exact=False, - ), - dict( - testcase_name="off_mode_with_failed", - text="Patient history incomplete.", - extraction_text="aspirin", - validation_level=prompt_validation.PromptValidationLevel.OFF, - strict_non_exact=False, - ), - ) - def test_validation_levels_that_dont_raise( - self, text, extraction_text, validation_level, strict_non_exact - ): - """Test that WARNING and OFF modes don't raise exceptions.""" - example = data.ExampleData( - text=text, - extractions=[ - data.Extraction( - extraction_class="Medication", - extraction_text=extraction_text, - attributes={}, - ) - ], + report = prompt_validation.validate_prompt_alignment([example]) + + self.assertLen(report.issues, expected_issues) + self.assertEqual(report.has_failed, expected_has_failed) + self.assertEqual(report.has_non_exact, expected_has_non_exact) + + if expected_failed_text: + issue = report.issues[0] + self.assertIsNone(issue.alignment_status) + self.assertEqual(issue.extraction_text_preview, + expected_failed_text) + + @parameterized.named_parameters( + dict( + testcase_name="warning_mode_with_failed", + text="Patient has no known allergies.", + extraction_text="penicillin", + validation_level=prompt_validation.PromptValidationLevel.WARNING, + strict_non_exact=False, + ), + dict( + testcase_name="off_mode_with_failed", + text="Patient history incomplete.", + extraction_text="aspirin", + validation_level=prompt_validation.PromptValidationLevel.OFF, + strict_non_exact=False, + ), ) + def test_validation_levels_that_dont_raise( + self, text, extraction_text, validation_level, strict_non_exact + ): + """Test that WARNING and OFF modes don't raise exceptions.""" + example = data.ExampleData( + text=text, + extractions=[ + data.Extraction( + extraction_class="Medication", + extraction_text=extraction_text, + attributes={}, + ) + ], + ) - report = prompt_validation.validate_prompt_alignment([example]) + report = prompt_validation.validate_prompt_alignment([example]) - # This should not raise an exception in WARNING or OFF modes - prompt_validation.handle_alignment_report( - report, validation_level, strict_non_exact=strict_non_exact - ) + # This should not raise an exception in WARNING or OFF modes + prompt_validation.handle_alignment_report( + report, validation_level, strict_non_exact=strict_non_exact + ) - @parameterized.named_parameters( - dict( - testcase_name="error_mode_failed_alignment", - text="Patient has no known allergies.", - extraction_class="Medication", - extraction_text="penicillin", - strict_non_exact=False, - error_pattern=r"1 extraction\(s\).*could not be aligned", - ), - dict( - testcase_name="error_mode_strict_fuzzy_match", - text="Type 2 diabetes.", - extraction_class="Diagnosis", - extraction_text="type-2 diabetes", - strict_non_exact=True, - error_pattern=r"strict mode.*1 non-exact", - ), - ) - def test_error_mode_raises_appropriately( - self, - text, - extraction_class, - extraction_text, - strict_non_exact, - error_pattern, - ): - """Test that ERROR mode raises with appropriate messages.""" - example = data.ExampleData( - text=text, - extractions=[ - data.Extraction( - extraction_class=extraction_class, - extraction_text=extraction_text, - attributes={}, - ) - ], + @parameterized.named_parameters( + dict( + testcase_name="error_mode_failed_alignment", + text="Patient has no known allergies.", + extraction_class="Medication", + extraction_text="penicillin", + strict_non_exact=False, + error_pattern=r"1 extraction\(s\).*could not be aligned", + ), + dict( + testcase_name="error_mode_strict_fuzzy_match", + text="Type 2 diabetes.", + extraction_class="Diagnosis", + extraction_text="type-2 diabetes", + strict_non_exact=True, + error_pattern=r"strict mode.*1 non-exact", + ), ) - - report = prompt_validation.validate_prompt_alignment([example]) - - with self.assertRaisesRegex( - prompt_validation.PromptAlignmentError, error_pattern + def test_error_mode_raises_appropriately( + self, + text, + extraction_class, + extraction_text, + strict_non_exact, + error_pattern, ): - prompt_validation.handle_alignment_report( - report, - prompt_validation.PromptValidationLevel.ERROR, - strict_non_exact=strict_non_exact, - ) - - def test_empty_examples_produces_empty_report(self): - report = prompt_validation.validate_prompt_alignment([]) - - self.assertEmpty(report.issues) - self.assertFalse(report.has_failed) - self.assertFalse(report.has_non_exact) - - def test_multiple_examples_preserve_indices(self): - examples = [ - data.ExampleData( # Example 0: FAILS - "metformin" not in text - text="First patient record.", + """Test that ERROR mode raises with appropriate messages.""" + example = data.ExampleData( + text=text, extractions=[ data.Extraction( - extraction_class="Medication", - extraction_text="metformin", + extraction_class=extraction_class, + extraction_text=extraction_text, attributes={}, ) ], - ), - data.ExampleData( # Example 1: PASSES - "aspirin" found exactly - text="Patient takes aspirin daily.", + ) + + report = prompt_validation.validate_prompt_alignment([example]) + + with self.assertRaisesRegex( + prompt_validation.PromptAlignmentError, error_pattern + ): + prompt_validation.handle_alignment_report( + report, + prompt_validation.PromptValidationLevel.ERROR, + strict_non_exact=strict_non_exact, + ) + + def test_empty_examples_produces_empty_report(self): + report = prompt_validation.validate_prompt_alignment([]) + + self.assertEmpty(report.issues) + self.assertFalse(report.has_failed) + self.assertFalse(report.has_non_exact) + + def test_multiple_examples_preserve_indices(self): + examples = [ + data.ExampleData( # Example 0: FAILS - "metformin" not in text + text="First patient record.", + extractions=[ + data.Extraction( + extraction_class="Medication", + extraction_text="metformin", + attributes={}, + ) + ], + ), + data.ExampleData( # Example 1: PASSES - "aspirin" found exactly + text="Patient takes aspirin daily.", + extractions=[ + data.Extraction( + extraction_class="Medication", + extraction_text="aspirin", + attributes={}, + ) + ], + ), + data.ExampleData( # Example 2: NON-EXACT - "type-2" fuzzy matches "Type 2" + text="Type 2 diabetes mellitus.", + extractions=[ + data.Extraction( + extraction_class="Diagnosis", + extraction_text="type-2 diabetes", + attributes={}, + ) + ], + ), + ] + + report = prompt_validation.validate_prompt_alignment(examples) + + # Expect 2 issues: example 0 (failed) and example 2 (non-exact) + self.assertLen(report.issues, 2) + self.assertTrue(report.has_failed) + self.assertTrue(report.has_non_exact) + + issue_by_index = { + issue.example_index: issue for issue in report.issues} + + # Example 0: Failed alignment (metformin not found) + self.assertIn(0, issue_by_index) + self.assertIsNone(issue_by_index[0].alignment_status) + + # Example 1: No issue (aspirin found exactly) + self.assertNotIn(1, issue_by_index) + + # Example 2: Non-exact match (type-2 vs Type 2) + self.assertIn(2, issue_by_index) + self.assertIsNotNone(issue_by_index[2].alignment_status) + + def test_validation_does_not_mutate_input(self): + example = data.ExampleData( + text="Patient takes lisinopril 10mg daily.", extractions=[ data.Extraction( extraction_class="Medication", - extraction_text="aspirin", + extraction_text="lisinopril", attributes={}, ) ], + ) + + original_extraction = example.extractions[0] + + self.assertIsNone(getattr(original_extraction, "token_interval", None)) + self.assertIsNone(getattr(original_extraction, "char_interval", None)) + self.assertIsNone( + getattr(original_extraction, "alignment_status", None)) + + _ = prompt_validation.validate_prompt_alignment([example]) + + self.assertIsNone(getattr(original_extraction, "token_interval", None)) + self.assertIsNone(getattr(original_extraction, "char_interval", None)) + self.assertIsNone( + getattr(original_extraction, "alignment_status", None)) + + @parameterized.named_parameters( + dict( + testcase_name="fuzzy_disabled_rejects_non_exact", + text="Patient has type 2 diabetes.", + extraction_class="Diagnosis", + extraction_text="Type-2 Diabetes", + enable_fuzzy=False, + accept_lesser=False, + fuzzy_threshold=0.75, + expected_has_failed=True, + expected_has_non_exact=False, + ), + dict( + testcase_name="fuzzy_enabled_accepts_close_match", + text="Patient has type 2 diabetes.", + extraction_class="Diagnosis", + extraction_text="Type-2 Diabetes", + enable_fuzzy=True, + accept_lesser=False, + fuzzy_threshold=0.75, + expected_has_failed=False, + expected_has_non_exact=True, ), - data.ExampleData( # Example 2: NON-EXACT - "type-2" fuzzy matches "Type 2" - text="Type 2 diabetes mellitus.", + ) + def test_alignment_policies( + self, + text, + extraction_class, + extraction_text, + enable_fuzzy, + accept_lesser, + fuzzy_threshold, + expected_has_failed, + expected_has_non_exact, + ): + """Test different alignment policy configurations.""" + example = data.ExampleData( + text=text, extractions=[ data.Extraction( - extraction_class="Diagnosis", - extraction_text="type-2 diabetes", + extraction_class=extraction_class, + extraction_text=extraction_text, attributes={}, ) ], - ), - ] - - report = prompt_validation.validate_prompt_alignment(examples) - - # Expect 2 issues: example 0 (failed) and example 2 (non-exact) - self.assertLen(report.issues, 2) - self.assertTrue(report.has_failed) - self.assertTrue(report.has_non_exact) - - issue_by_index = {issue.example_index: issue for issue in report.issues} + ) - # Example 0: Failed alignment (metformin not found) - self.assertIn(0, issue_by_index) - self.assertIsNone(issue_by_index[0].alignment_status) + if not enable_fuzzy: + default_report = prompt_validation.validate_prompt_alignment([ + example]) + self.assertFalse(default_report.has_failed) + self.assertTrue(default_report.has_non_exact) - # Example 1: No issue (aspirin found exactly) - self.assertNotIn(1, issue_by_index) + policy = prompt_validation.AlignmentPolicy( + enable_fuzzy_alignment=enable_fuzzy, + accept_match_lesser=accept_lesser, + fuzzy_alignment_threshold=fuzzy_threshold, + ) + report = prompt_validation.validate_prompt_alignment( + [example], policy=policy + ) - # Example 2: Non-exact match (type-2 vs Type 2) - self.assertIn(2, issue_by_index) - self.assertIsNotNone(issue_by_index[2].alignment_status) + self.assertEqual(report.has_failed, expected_has_failed) + self.assertEqual(report.has_non_exact, expected_has_non_exact) - def test_validation_does_not_mutate_input(self): - example = data.ExampleData( - text="Patient takes lisinopril 10mg daily.", - extractions=[ - data.Extraction( - extraction_class="Medication", - extraction_text="lisinopril", - attributes={}, - ) - ], - ) + def test_empty_extraction_text(self): + """Verify that empty strings are properly rejected iwth ValueError.""" + example = data.ExampleData( + text="Patient takes lisinopril 10mg daily.", + extractions=[ + data.Extraction( + extraction_class="Medication", + extraction_text="", + attributes={}, + ) + ], + ) - original_extraction = example.extractions[0] - - self.assertIsNone(getattr(original_extraction, "token_interval", None)) - self.assertIsNone(getattr(original_extraction, "char_interval", None)) - self.assertIsNone(getattr(original_extraction, "alignment_status", None)) - - _ = prompt_validation.validate_prompt_alignment([example]) - - self.assertIsNone(getattr(original_extraction, "token_interval", None)) - self.assertIsNone(getattr(original_extraction, "char_interval", None)) - self.assertIsNone(getattr(original_extraction, "alignment_status", None)) - - @parameterized.named_parameters( - dict( - testcase_name="fuzzy_disabled_rejects_non_exact", - text="Patient has type 2 diabetes.", - extraction_class="Diagnosis", - extraction_text="Type-2 Diabetes", - enable_fuzzy=False, - accept_lesser=False, - fuzzy_threshold=0.75, - expected_has_failed=True, - expected_has_non_exact=False, - ), - dict( - testcase_name="fuzzy_enabled_accepts_close_match", - text="Patient has type 2 diabetes.", - extraction_class="Diagnosis", - extraction_text="Type-2 Diabetes", - enable_fuzzy=True, - accept_lesser=False, - fuzzy_threshold=0.75, - expected_has_failed=False, - expected_has_non_exact=True, - ), - ) - def test_alignment_policies( - self, - text, - extraction_class, - extraction_text, - enable_fuzzy, - accept_lesser, - fuzzy_threshold, - expected_has_failed, - expected_has_non_exact, - ): - """Test different alignment policy configurations.""" - example = data.ExampleData( - text=text, - extractions=[ - data.Extraction( - extraction_class=extraction_class, - extraction_text=extraction_text, - attributes={}, - ) - ], - ) + with self.assertRaises(ValueError) as context: + prompt_validation.validate_prompt_alignment([example]) - if not enable_fuzzy: - default_report = prompt_validation.validate_prompt_alignment([example]) - self.assertFalse(default_report.has_failed) - self.assertTrue(default_report.has_non_exact) + self.assertIn("empty", str(context.exception).lower()) - policy = prompt_validation.AlignmentPolicy( - enable_fuzzy_alignment=enable_fuzzy, - accept_match_lesser=accept_lesser, - fuzzy_alignment_threshold=fuzzy_threshold, - ) - report = prompt_validation.validate_prompt_alignment( - [example], policy=policy - ) + def test_unicode_and_special_chars(self): + """Ensure proper handling of unicode characters and special symbols.""" + example = data.ExampleData( + text="Patient visits café daily. Temperature: 98.6°F. Cost: €50, £40, ¥500.", + extractions=[ + data.Extraction( + extraction_class="Location", + extraction_text="café", + attributes={"type": "establishment"}, + ), + data.Extraction( + extraction_class="Temperature", + extraction_text="98.6°F", # Degree symbol + attributes={"unit": "fahrenheit"}, + ), + data.Extraction( + extraction_class="Cost", + extraction_text="€50", # Euro symbol + attributes={"currency": "euro"}, + ), + data.Extraction( + extraction_class="Cost", + extraction_text="£40", # Pound symbol + attributes={"currency": "pound"}, + ), + data.Extraction( + extraction_class="Cost", + extraction_text="¥500", # Yen symbol + attributes={"currency": "yen"}, + ), + ], + ) - self.assertEqual(report.has_failed, expected_has_failed) - self.assertEqual(report.has_non_exact, expected_has_non_exact) + # Unicode characters should be handled without errors + report = prompt_validation.validate_prompt_alignment([example]) + # All extractions should align successfully (exact matches) + self.assertEmpty(report.issues) + self.assertFalse(report.has_failed) + self.assertFalse(report.has_non_exact) -class ExtractIntegrationTest(absltest.TestCase): - """Minimal integration test for extract() entry point validation.""" + def test_very_long_extraction_text(self): + """Test graceful handling of very long extraction text (1000+ characters).""" + # Create a long text passage (approximately 1140 characters) + long_passage = ( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 20 + ) - def test_extract_validates_in_error_mode(self): - """Verify extract() runs validation when configured.""" - examples = [ - data.ExampleData( - text="Patient takes aspirin.", + example = data.ExampleData( + text=f"Medical history: {long_passage} End of history.", extractions=[ data.Extraction( - extraction_class="Medication", - extraction_text="ibuprofen", - attributes={}, - ) + extraction_class="History", + # Exactly 1000 characters + extraction_text=long_passage[:1000], + attributes={"length": "detailed"}, + ), ], ) - ] - with self.assertRaisesRegex( - prompt_validation.PromptAlignmentError, - r"1 extraction\(s\).*could not be aligned", - ): - extraction.extract( - text_or_documents="Test document", - prompt_description="Extract medications", - examples=examples, - prompt_validation_level=prompt_validation.PromptValidationLevel.ERROR, - model_id="fake-model", - ) + # Long extraction text should be handled gracefully + report = prompt_validation.validate_prompt_alignment([example]) + + # Should either succeed (if substring match works) or fail gracefully + # The important thing is it doesn't crash or hang + self.assertIsNotNone(report) + + # If it finds a match, verify it's tracked correctly + if not report.has_failed: + # Successfully found the long extraction + self.assertIsNotNone(report) + else: + # If it fails, that's also acceptable - just verify it fails cleanly + self.assertTrue(report.has_failed) + + +class ExtractIntegrationTest(absltest.TestCase): + """Minimal integration test for extract() entry point validation.""" + + def test_extract_validates_in_error_mode(self): + """Verify extract() runs validation when configured.""" + examples = [ + data.ExampleData( + text="Patient takes aspirin.", + extractions=[ + data.Extraction( + extraction_class="Medication", + extraction_text="ibuprofen", + attributes={}, + ) + ], + ) + ] + + with self.assertRaisesRegex( + prompt_validation.PromptAlignmentError, + r"1 extraction\(s\).*could not be aligned", + ): + extraction.extract( + text_or_documents="Test document", + prompt_description="Extract medications", + examples=examples, + prompt_validation_level=prompt_validation.PromptValidationLevel.ERROR, + model_id="fake-model", + ) if __name__ == "__main__": - absltest.main() + absltest.main() From efee645049b794e130872d1b3829bf69be9df68e Mon Sep 17 00:00:00 2001 From: Nedas Jaronis Date: Thu, 12 Feb 2026 15:14:11 -0500 Subject: [PATCH 2/2] Fix code formatting --- tests/prompt_validation_test.py | 888 ++++++++++++++++---------------- 1 file changed, 443 insertions(+), 445 deletions(-) diff --git a/tests/prompt_validation_test.py b/tests/prompt_validation_test.py index 544d3331..1225ef62 100644 --- a/tests/prompt_validation_test.py +++ b/tests/prompt_validation_test.py @@ -24,501 +24,499 @@ class PromptAlignmentValidationTest(parameterized.TestCase): - @parameterized.named_parameters( - dict( - testcase_name="exact_alignment", - text="Patient takes lisinopril.", - extraction_class="Medication", - extraction_text="lisinopril", - expected_issues=0, - expected_has_failed=False, - expected_has_non_exact=False, - expected_alignment_status=None, - ), - dict( - testcase_name="fuzzy_match_lesser", - text="Type 2 diabetes.", - extraction_class="Diagnosis", - extraction_text="type-2 diabetes", - expected_issues=1, - expected_has_failed=False, - expected_has_non_exact=True, - expected_alignment_status=data.AlignmentStatus.MATCH_LESSER, - ), - dict( - testcase_name="extraction_not_found", - text="No medications mentioned in this text.", - extraction_class="Medication", - extraction_text="lisinopril", - expected_issues=1, - expected_has_failed=True, - expected_has_non_exact=False, - expected_alignment_status=None, - ), + @parameterized.named_parameters( + dict( + testcase_name="exact_alignment", + text="Patient takes lisinopril.", + extraction_class="Medication", + extraction_text="lisinopril", + expected_issues=0, + expected_has_failed=False, + expected_has_non_exact=False, + expected_alignment_status=None, + ), + dict( + testcase_name="fuzzy_match_lesser", + text="Type 2 diabetes.", + extraction_class="Diagnosis", + extraction_text="type-2 diabetes", + expected_issues=1, + expected_has_failed=False, + expected_has_non_exact=True, + expected_alignment_status=data.AlignmentStatus.MATCH_LESSER, + ), + dict( + testcase_name="extraction_not_found", + text="No medications mentioned in this text.", + extraction_class="Medication", + extraction_text="lisinopril", + expected_issues=1, + expected_has_failed=True, + expected_has_non_exact=False, + expected_alignment_status=None, + ), + ) + def test_alignment_detection( + self, + text, + extraction_class, + extraction_text, + expected_issues, + expected_has_failed, + expected_has_non_exact, + expected_alignment_status, + ): + """Test that different alignment types are correctly detected.""" + example = data.ExampleData( + text=text, + extractions=[ + data.Extraction( + extraction_class=extraction_class, + extraction_text=extraction_text, + attributes={}, + ) + ], ) - def test_alignment_detection( - self, - text, - extraction_class, - extraction_text, - expected_issues, - expected_has_failed, - expected_has_non_exact, - expected_alignment_status, - ): - """Test that different alignment types are correctly detected.""" - example = data.ExampleData( - text=text, - extractions=[ - data.Extraction( - extraction_class=extraction_class, - extraction_text=extraction_text, - attributes={}, - ) - ], - ) - report = prompt_validation.validate_prompt_alignment([example]) - - self.assertLen(report.issues, expected_issues) - self.assertEqual(report.has_failed, expected_has_failed) - self.assertEqual(report.has_non_exact, expected_has_non_exact) - - if expected_issues > 0: - issue = report.issues[0] - self.assertEqual(issue.alignment_status, expected_alignment_status) - self.assertEqual(issue.extraction_class, extraction_class) - if expected_has_failed: - self.assertIsNone(issue.alignment_status) - elif expected_has_non_exact: - self.assertIsNotNone(issue.alignment_status) - - @parameterized.named_parameters( - dict( - testcase_name="one_fails", - text="Patient takes lisinopril and has diabetes mellitus.", - extractions=[ - ("Medication", "lisinopril"), # PASSES - found exactly - ("Diagnosis", "diabetes"), # PASSES - found exactly - ("Medication", "metformin"), # FAILS - not in text - ], - expected_issues=1, - expected_has_failed=True, - expected_has_non_exact=False, - expected_failed_text="metformin", - ), - dict( - testcase_name="all_pass", - text="Patient takes lisinopril and aspirin for diabetes management.", - extractions=[ - ("Medication", "lisinopril"), - ("Medication", "aspirin"), - ("Diagnosis", "diabetes"), - ], - expected_issues=0, - expected_has_failed=False, - expected_has_non_exact=False, - expected_failed_text=None, - ), + report = prompt_validation.validate_prompt_alignment([example]) + + self.assertLen(report.issues, expected_issues) + self.assertEqual(report.has_failed, expected_has_failed) + self.assertEqual(report.has_non_exact, expected_has_non_exact) + + if expected_issues > 0: + issue = report.issues[0] + self.assertEqual(issue.alignment_status, expected_alignment_status) + self.assertEqual(issue.extraction_class, extraction_class) + if expected_has_failed: + self.assertIsNone(issue.alignment_status) + elif expected_has_non_exact: + self.assertIsNotNone(issue.alignment_status) + + @parameterized.named_parameters( + dict( + testcase_name="one_fails", + text="Patient takes lisinopril and has diabetes mellitus.", + extractions=[ + ("Medication", "lisinopril"), # PASSES - found exactly + ("Diagnosis", "diabetes"), # PASSES - found exactly + ("Medication", "metformin"), # FAILS - not in text + ], + expected_issues=1, + expected_has_failed=True, + expected_has_non_exact=False, + expected_failed_text="metformin", + ), + dict( + testcase_name="all_pass", + text="Patient takes lisinopril and aspirin for diabetes management.", + extractions=[ + ("Medication", "lisinopril"), + ("Medication", "aspirin"), + ("Diagnosis", "diabetes"), + ], + expected_issues=0, + expected_has_failed=False, + expected_has_non_exact=False, + expected_failed_text=None, + ), + ) + def test_multiple_extractions_per_example( + self, + text, + extractions, + expected_issues, + expected_has_failed, + expected_has_non_exact, + expected_failed_text, + ): + """Test validation with multiple extractions in a single example.""" + example = data.ExampleData( + text=text, + extractions=[ + data.Extraction( + extraction_class=extraction_class, + extraction_text=extraction_text, + attributes={}, + ) + for extraction_class, extraction_text in extractions + ], + ) + + report = prompt_validation.validate_prompt_alignment([example]) + + self.assertLen(report.issues, expected_issues) + self.assertEqual(report.has_failed, expected_has_failed) + self.assertEqual(report.has_non_exact, expected_has_non_exact) + + if expected_failed_text: + issue = report.issues[0] + self.assertIsNone(issue.alignment_status) + self.assertEqual(issue.extraction_text_preview, expected_failed_text) + + @parameterized.named_parameters( + dict( + testcase_name="warning_mode_with_failed", + text="Patient has no known allergies.", + extraction_text="penicillin", + validation_level=prompt_validation.PromptValidationLevel.WARNING, + strict_non_exact=False, + ), + dict( + testcase_name="off_mode_with_failed", + text="Patient history incomplete.", + extraction_text="aspirin", + validation_level=prompt_validation.PromptValidationLevel.OFF, + strict_non_exact=False, + ), + ) + def test_validation_levels_that_dont_raise( + self, text, extraction_text, validation_level, strict_non_exact + ): + """Test that WARNING and OFF modes don't raise exceptions.""" + example = data.ExampleData( + text=text, + extractions=[ + data.Extraction( + extraction_class="Medication", + extraction_text=extraction_text, + attributes={}, + ) + ], + ) + + report = prompt_validation.validate_prompt_alignment([example]) + + # This should not raise an exception in WARNING or OFF modes + prompt_validation.handle_alignment_report( + report, validation_level, strict_non_exact=strict_non_exact ) - def test_multiple_extractions_per_example( - self, - text, - extractions, - expected_issues, - expected_has_failed, - expected_has_non_exact, - expected_failed_text, + + @parameterized.named_parameters( + dict( + testcase_name="error_mode_failed_alignment", + text="Patient has no known allergies.", + extraction_class="Medication", + extraction_text="penicillin", + strict_non_exact=False, + error_pattern=r"1 extraction\(s\).*could not be aligned", + ), + dict( + testcase_name="error_mode_strict_fuzzy_match", + text="Type 2 diabetes.", + extraction_class="Diagnosis", + extraction_text="type-2 diabetes", + strict_non_exact=True, + error_pattern=r"strict mode.*1 non-exact", + ), + ) + def test_error_mode_raises_appropriately( + self, + text, + extraction_class, + extraction_text, + strict_non_exact, + error_pattern, + ): + """Test that ERROR mode raises with appropriate messages.""" + example = data.ExampleData( + text=text, + extractions=[ + data.Extraction( + extraction_class=extraction_class, + extraction_text=extraction_text, + attributes={}, + ) + ], + ) + + report = prompt_validation.validate_prompt_alignment([example]) + + with self.assertRaisesRegex( + prompt_validation.PromptAlignmentError, error_pattern ): - """Test validation with multiple extractions in a single example.""" - example = data.ExampleData( - text=text, + prompt_validation.handle_alignment_report( + report, + prompt_validation.PromptValidationLevel.ERROR, + strict_non_exact=strict_non_exact, + ) + + def test_empty_examples_produces_empty_report(self): + report = prompt_validation.validate_prompt_alignment([]) + + self.assertEmpty(report.issues) + self.assertFalse(report.has_failed) + self.assertFalse(report.has_non_exact) + + def test_multiple_examples_preserve_indices(self): + examples = [ + data.ExampleData( # Example 0: FAILS - "metformin" not in text + text="First patient record.", extractions=[ data.Extraction( - extraction_class=extraction_class, - extraction_text=extraction_text, + extraction_class="Medication", + extraction_text="metformin", attributes={}, ) - for extraction_class, extraction_text in extractions ], - ) - - report = prompt_validation.validate_prompt_alignment([example]) - - self.assertLen(report.issues, expected_issues) - self.assertEqual(report.has_failed, expected_has_failed) - self.assertEqual(report.has_non_exact, expected_has_non_exact) - - if expected_failed_text: - issue = report.issues[0] - self.assertIsNone(issue.alignment_status) - self.assertEqual(issue.extraction_text_preview, - expected_failed_text) - - @parameterized.named_parameters( - dict( - testcase_name="warning_mode_with_failed", - text="Patient has no known allergies.", - extraction_text="penicillin", - validation_level=prompt_validation.PromptValidationLevel.WARNING, - strict_non_exact=False, - ), - dict( - testcase_name="off_mode_with_failed", - text="Patient history incomplete.", - extraction_text="aspirin", - validation_level=prompt_validation.PromptValidationLevel.OFF, - strict_non_exact=False, ), - ) - def test_validation_levels_that_dont_raise( - self, text, extraction_text, validation_level, strict_non_exact - ): - """Test that WARNING and OFF modes don't raise exceptions.""" - example = data.ExampleData( - text=text, + data.ExampleData( # Example 1: PASSES - "aspirin" found exactly + text="Patient takes aspirin daily.", extractions=[ data.Extraction( extraction_class="Medication", - extraction_text=extraction_text, + extraction_text="aspirin", attributes={}, ) ], - ) - - report = prompt_validation.validate_prompt_alignment([example]) - - # This should not raise an exception in WARNING or OFF modes - prompt_validation.handle_alignment_report( - report, validation_level, strict_non_exact=strict_non_exact - ) - - @parameterized.named_parameters( - dict( - testcase_name="error_mode_failed_alignment", - text="Patient has no known allergies.", - extraction_class="Medication", - extraction_text="penicillin", - strict_non_exact=False, - error_pattern=r"1 extraction\(s\).*could not be aligned", ), - dict( - testcase_name="error_mode_strict_fuzzy_match", - text="Type 2 diabetes.", - extraction_class="Diagnosis", - extraction_text="type-2 diabetes", - strict_non_exact=True, - error_pattern=r"strict mode.*1 non-exact", - ), - ) - def test_error_mode_raises_appropriately( - self, - text, - extraction_class, - extraction_text, - strict_non_exact, - error_pattern, - ): - """Test that ERROR mode raises with appropriate messages.""" - example = data.ExampleData( - text=text, + data.ExampleData( # Example 2: NON-EXACT - "type-2" fuzzy matches "Type 2" + text="Type 2 diabetes mellitus.", extractions=[ data.Extraction( - extraction_class=extraction_class, - extraction_text=extraction_text, + extraction_class="Diagnosis", + extraction_text="type-2 diabetes", attributes={}, ) ], - ) + ), + ] - report = prompt_validation.validate_prompt_alignment([example]) + report = prompt_validation.validate_prompt_alignment(examples) - with self.assertRaisesRegex( - prompt_validation.PromptAlignmentError, error_pattern - ): - prompt_validation.handle_alignment_report( - report, - prompt_validation.PromptValidationLevel.ERROR, - strict_non_exact=strict_non_exact, - ) + # Expect 2 issues: example 0 (failed) and example 2 (non-exact) + self.assertLen(report.issues, 2) + self.assertTrue(report.has_failed) + self.assertTrue(report.has_non_exact) - def test_empty_examples_produces_empty_report(self): - report = prompt_validation.validate_prompt_alignment([]) - - self.assertEmpty(report.issues) - self.assertFalse(report.has_failed) - self.assertFalse(report.has_non_exact) - - def test_multiple_examples_preserve_indices(self): - examples = [ - data.ExampleData( # Example 0: FAILS - "metformin" not in text - text="First patient record.", - extractions=[ - data.Extraction( - extraction_class="Medication", - extraction_text="metformin", - attributes={}, - ) - ], - ), - data.ExampleData( # Example 1: PASSES - "aspirin" found exactly - text="Patient takes aspirin daily.", - extractions=[ - data.Extraction( - extraction_class="Medication", - extraction_text="aspirin", - attributes={}, - ) - ], - ), - data.ExampleData( # Example 2: NON-EXACT - "type-2" fuzzy matches "Type 2" - text="Type 2 diabetes mellitus.", - extractions=[ - data.Extraction( - extraction_class="Diagnosis", - extraction_text="type-2 diabetes", - attributes={}, - ) - ], - ), - ] + issue_by_index = {issue.example_index: issue for issue in report.issues} - report = prompt_validation.validate_prompt_alignment(examples) + # Example 0: Failed alignment (metformin not found) + self.assertIn(0, issue_by_index) + self.assertIsNone(issue_by_index[0].alignment_status) - # Expect 2 issues: example 0 (failed) and example 2 (non-exact) - self.assertLen(report.issues, 2) - self.assertTrue(report.has_failed) - self.assertTrue(report.has_non_exact) + # Example 1: No issue (aspirin found exactly) + self.assertNotIn(1, issue_by_index) - issue_by_index = { - issue.example_index: issue for issue in report.issues} + # Example 2: Non-exact match (type-2 vs Type 2) + self.assertIn(2, issue_by_index) + self.assertIsNotNone(issue_by_index[2].alignment_status) - # Example 0: Failed alignment (metformin not found) - self.assertIn(0, issue_by_index) - self.assertIsNone(issue_by_index[0].alignment_status) + def test_validation_does_not_mutate_input(self): + example = data.ExampleData( + text="Patient takes lisinopril 10mg daily.", + extractions=[ + data.Extraction( + extraction_class="Medication", + extraction_text="lisinopril", + attributes={}, + ) + ], + ) - # Example 1: No issue (aspirin found exactly) - self.assertNotIn(1, issue_by_index) + original_extraction = example.extractions[0] + + self.assertIsNone(getattr(original_extraction, "token_interval", None)) + self.assertIsNone(getattr(original_extraction, "char_interval", None)) + self.assertIsNone(getattr(original_extraction, "alignment_status", None)) + + _ = prompt_validation.validate_prompt_alignment([example]) + + self.assertIsNone(getattr(original_extraction, "token_interval", None)) + self.assertIsNone(getattr(original_extraction, "char_interval", None)) + self.assertIsNone(getattr(original_extraction, "alignment_status", None)) + + @parameterized.named_parameters( + dict( + testcase_name="fuzzy_disabled_rejects_non_exact", + text="Patient has type 2 diabetes.", + extraction_class="Diagnosis", + extraction_text="Type-2 Diabetes", + enable_fuzzy=False, + accept_lesser=False, + fuzzy_threshold=0.75, + expected_has_failed=True, + expected_has_non_exact=False, + ), + dict( + testcase_name="fuzzy_enabled_accepts_close_match", + text="Patient has type 2 diabetes.", + extraction_class="Diagnosis", + extraction_text="Type-2 Diabetes", + enable_fuzzy=True, + accept_lesser=False, + fuzzy_threshold=0.75, + expected_has_failed=False, + expected_has_non_exact=True, + ), + ) + def test_alignment_policies( + self, + text, + extraction_class, + extraction_text, + enable_fuzzy, + accept_lesser, + fuzzy_threshold, + expected_has_failed, + expected_has_non_exact, + ): + """Test different alignment policy configurations.""" + example = data.ExampleData( + text=text, + extractions=[ + data.Extraction( + extraction_class=extraction_class, + extraction_text=extraction_text, + attributes={}, + ) + ], + ) - # Example 2: Non-exact match (type-2 vs Type 2) - self.assertIn(2, issue_by_index) - self.assertIsNotNone(issue_by_index[2].alignment_status) + if not enable_fuzzy: + default_report = prompt_validation.validate_prompt_alignment([example]) + self.assertFalse(default_report.has_failed) + self.assertTrue(default_report.has_non_exact) - def test_validation_does_not_mutate_input(self): - example = data.ExampleData( - text="Patient takes lisinopril 10mg daily.", - extractions=[ - data.Extraction( - extraction_class="Medication", - extraction_text="lisinopril", - attributes={}, - ) - ], - ) + policy = prompt_validation.AlignmentPolicy( + enable_fuzzy_alignment=enable_fuzzy, + accept_match_lesser=accept_lesser, + fuzzy_alignment_threshold=fuzzy_threshold, + ) + report = prompt_validation.validate_prompt_alignment( + [example], policy=policy + ) - original_extraction = example.extractions[0] - - self.assertIsNone(getattr(original_extraction, "token_interval", None)) - self.assertIsNone(getattr(original_extraction, "char_interval", None)) - self.assertIsNone( - getattr(original_extraction, "alignment_status", None)) - - _ = prompt_validation.validate_prompt_alignment([example]) - - self.assertIsNone(getattr(original_extraction, "token_interval", None)) - self.assertIsNone(getattr(original_extraction, "char_interval", None)) - self.assertIsNone( - getattr(original_extraction, "alignment_status", None)) - - @parameterized.named_parameters( - dict( - testcase_name="fuzzy_disabled_rejects_non_exact", - text="Patient has type 2 diabetes.", - extraction_class="Diagnosis", - extraction_text="Type-2 Diabetes", - enable_fuzzy=False, - accept_lesser=False, - fuzzy_threshold=0.75, - expected_has_failed=True, - expected_has_non_exact=False, - ), - dict( - testcase_name="fuzzy_enabled_accepts_close_match", - text="Patient has type 2 diabetes.", - extraction_class="Diagnosis", - extraction_text="Type-2 Diabetes", - enable_fuzzy=True, - accept_lesser=False, - fuzzy_threshold=0.75, - expected_has_failed=False, - expected_has_non_exact=True, - ), + self.assertEqual(report.has_failed, expected_has_failed) + self.assertEqual(report.has_non_exact, expected_has_non_exact) + + def test_empty_extraction_text(self): + """Verify that empty strings are properly rejected iwth ValueError.""" + example = data.ExampleData( + text="Patient takes lisinopril 10mg daily.", + extractions=[ + data.Extraction( + extraction_class="Medication", + extraction_text="", + attributes={}, + ) + ], ) - def test_alignment_policies( - self, - text, - extraction_class, - extraction_text, - enable_fuzzy, - accept_lesser, - fuzzy_threshold, - expected_has_failed, - expected_has_non_exact, - ): - """Test different alignment policy configurations.""" - example = data.ExampleData( - text=text, - extractions=[ - data.Extraction( - extraction_class=extraction_class, - extraction_text=extraction_text, - attributes={}, - ) - ], - ) - if not enable_fuzzy: - default_report = prompt_validation.validate_prompt_alignment([ - example]) - self.assertFalse(default_report.has_failed) - self.assertTrue(default_report.has_non_exact) + with self.assertRaises(ValueError) as context: + prompt_validation.validate_prompt_alignment([example]) - policy = prompt_validation.AlignmentPolicy( - enable_fuzzy_alignment=enable_fuzzy, - accept_match_lesser=accept_lesser, - fuzzy_alignment_threshold=fuzzy_threshold, - ) - report = prompt_validation.validate_prompt_alignment( - [example], policy=policy - ) + self.assertIn("empty", str(context.exception).lower()) - self.assertEqual(report.has_failed, expected_has_failed) - self.assertEqual(report.has_non_exact, expected_has_non_exact) + def test_unicode_and_special_chars(self): + """Ensure proper handling of unicode characters and special symbols.""" + example = data.ExampleData( + text=( + "Patient visits café daily. Temperature: 98.6°F. Cost: €50, £40," + " ¥500." + ), + extractions=[ + data.Extraction( + extraction_class="Location", + extraction_text="café", + attributes={"type": "establishment"}, + ), + data.Extraction( + extraction_class="Temperature", + extraction_text="98.6°F", # Degree symbol + attributes={"unit": "fahrenheit"}, + ), + data.Extraction( + extraction_class="Cost", + extraction_text="€50", # Euro symbol + attributes={"currency": "euro"}, + ), + data.Extraction( + extraction_class="Cost", + extraction_text="£40", # Pound symbol + attributes={"currency": "pound"}, + ), + data.Extraction( + extraction_class="Cost", + extraction_text="¥500", # Yen symbol + attributes={"currency": "yen"}, + ), + ], + ) - def test_empty_extraction_text(self): - """Verify that empty strings are properly rejected iwth ValueError.""" - example = data.ExampleData( - text="Patient takes lisinopril 10mg daily.", - extractions=[ - data.Extraction( - extraction_class="Medication", - extraction_text="", - attributes={}, - ) - ], - ) + # Unicode characters should be handled without errors + report = prompt_validation.validate_prompt_alignment([example]) - with self.assertRaises(ValueError) as context: - prompt_validation.validate_prompt_alignment([example]) + # All extractions should align successfully (exact matches) + self.assertEmpty(report.issues) + self.assertFalse(report.has_failed) + self.assertFalse(report.has_non_exact) - self.assertIn("empty", str(context.exception).lower()) + def test_very_long_extraction_text(self): + """Test graceful handling of very long extraction text (1000+ characters).""" + # Create a long text passage (approximately 1140 characters) + long_passage = ( + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 20 + ) - def test_unicode_and_special_chars(self): - """Ensure proper handling of unicode characters and special symbols.""" - example = data.ExampleData( - text="Patient visits café daily. Temperature: 98.6°F. Cost: €50, £40, ¥500.", - extractions=[ - data.Extraction( - extraction_class="Location", - extraction_text="café", - attributes={"type": "establishment"}, - ), - data.Extraction( - extraction_class="Temperature", - extraction_text="98.6°F", # Degree symbol - attributes={"unit": "fahrenheit"}, - ), - data.Extraction( - extraction_class="Cost", - extraction_text="€50", # Euro symbol - attributes={"currency": "euro"}, - ), - data.Extraction( - extraction_class="Cost", - extraction_text="£40", # Pound symbol - attributes={"currency": "pound"}, - ), - data.Extraction( - extraction_class="Cost", - extraction_text="¥500", # Yen symbol - attributes={"currency": "yen"}, - ), - ], - ) + example = data.ExampleData( + text=f"Medical history: {long_passage} End of history.", + extractions=[ + data.Extraction( + extraction_class="History", + # Exactly 1000 characters + extraction_text=long_passage[:1000], + attributes={"length": "detailed"}, + ), + ], + ) - # Unicode characters should be handled without errors - report = prompt_validation.validate_prompt_alignment([example]) + # Long extraction text should be handled gracefully + report = prompt_validation.validate_prompt_alignment([example]) - # All extractions should align successfully (exact matches) - self.assertEmpty(report.issues) - self.assertFalse(report.has_failed) - self.assertFalse(report.has_non_exact) + # Should either succeed (if substring match works) or fail gracefully + # The important thing is it doesn't crash or hang + self.assertIsNotNone(report) - def test_very_long_extraction_text(self): - """Test graceful handling of very long extraction text (1000+ characters).""" - # Create a long text passage (approximately 1140 characters) - long_passage = ( - "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " * 20 - ) + # If it finds a match, verify it's tracked correctly + if not report.has_failed: + # Successfully found the long extraction + self.assertIsNotNone(report) + else: + # If it fails, that's also acceptable - just verify it fails cleanly + self.assertTrue(report.has_failed) - example = data.ExampleData( - text=f"Medical history: {long_passage} End of history.", + +class ExtractIntegrationTest(absltest.TestCase): + """Minimal integration test for extract() entry point validation.""" + + def test_extract_validates_in_error_mode(self): + """Verify extract() runs validation when configured.""" + examples = [ + data.ExampleData( + text="Patient takes aspirin.", extractions=[ data.Extraction( - extraction_class="History", - # Exactly 1000 characters - extraction_text=long_passage[:1000], - attributes={"length": "detailed"}, - ), + extraction_class="Medication", + extraction_text="ibuprofen", + attributes={}, + ) ], ) + ] - # Long extraction text should be handled gracefully - report = prompt_validation.validate_prompt_alignment([example]) - - # Should either succeed (if substring match works) or fail gracefully - # The important thing is it doesn't crash or hang - self.assertIsNotNone(report) - - # If it finds a match, verify it's tracked correctly - if not report.has_failed: - # Successfully found the long extraction - self.assertIsNotNone(report) - else: - # If it fails, that's also acceptable - just verify it fails cleanly - self.assertTrue(report.has_failed) - - -class ExtractIntegrationTest(absltest.TestCase): - """Minimal integration test for extract() entry point validation.""" - - def test_extract_validates_in_error_mode(self): - """Verify extract() runs validation when configured.""" - examples = [ - data.ExampleData( - text="Patient takes aspirin.", - extractions=[ - data.Extraction( - extraction_class="Medication", - extraction_text="ibuprofen", - attributes={}, - ) - ], - ) - ] - - with self.assertRaisesRegex( - prompt_validation.PromptAlignmentError, - r"1 extraction\(s\).*could not be aligned", - ): - extraction.extract( - text_or_documents="Test document", - prompt_description="Extract medications", - examples=examples, - prompt_validation_level=prompt_validation.PromptValidationLevel.ERROR, - model_id="fake-model", - ) + with self.assertRaisesRegex( + prompt_validation.PromptAlignmentError, + r"1 extraction\(s\).*could not be aligned", + ): + extraction.extract( + text_or_documents="Test document", + prompt_description="Extract medications", + examples=examples, + prompt_validation_level=prompt_validation.PromptValidationLevel.ERROR, + model_id="fake-model", + ) if __name__ == "__main__": - absltest.main() + absltest.main()