diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fcb3aee0..779543fbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ SPDX-License-Identifier: MIT-0 ### Added +- **Classification valid-class enforcement (#356)** — Page-level classification (`multimodalPageLevelClassification`) now validates the model's predicted class against the configured class vocabulary and, on an out-of-vocabulary prediction, **re-prompts the model** with a correction message listing the allowed classes, retrying up to a configurable limit. This is a deterministic guardrail against models (especially smaller/cheaper ones) returning a label outside the defined set. Because classification runs at `temperature=0`, the retry appends the correction to the request content (a single-turn re-prompt) rather than re-sending an identical request. Three new `classification` config keys control it: `enforceValidClasses` (default `true`), `maxValidationRetries` (default `2`), and `invalidClassFallback` (default `unclassified`). When retries are exhausted the page is assigned the fallback class and flagged with a `validation_error` in its classification metadata (the document keeps processing — no hard failure). All three keys are editable in the Configuration UI. **Behavior change on upgrade:** enforcement is **on by default**, so an out-of-vocabulary prediction that previously passed through unchanged is now corrected or coerced to `unclassified`; set `enforceValidClasses: false` under `classification` to restore the prior "warn and use as-is" behavior. Holistic packet classification is not covered by this loop yet. See [docs/classification.md](docs/classification.md) and the demo notebook `notebooks/misc/classification-valid-class-enforcement.ipynb`. + - **Configurable Lambda architecture** — New `LambdaArchitecture` parameter (`arm64` or `x86_64`) for all unified pattern Lambda container images. Defaults to `arm64` (Graviton) for best price-performance. Use `x86_64` when deploying with custom base images that only support AMD64. The parameter flows through to CodeBuild (`--platform` flag) and Dockerfile (`FROM` image suffix). ## [0.5.15] diff --git a/docs/classification.md b/docs/classification.md index 10e2c78e0..1b0ca3d85 100644 --- a/docs/classification.md +++ b/docs/classification.md @@ -166,6 +166,60 @@ The boundary detection is automatically included in the classification results. } } ``` + +##### Enforcing a Valid Class Vocabulary (Validation + Retry) + +With a fixed set of classes, a language model can occasionally return a label +that is **not** in your configured list (for example predicting `receipt` when +only `invoice`, `w2`, and `check` are valid). Smaller / cheaper models are +especially prone to this. `multimodalPageLevelClassification` includes a +deterministic validation + retry guardrail to prevent out-of-vocabulary +classifications: + +1. After the model returns a class, it is validated against the configured + class vocabulary. +2. If the class is **not** valid, the model is **re-prompted** — the original + request content is re-sent with an appended correction message that lists + the allowed classes. (This matters: classification runs at `temperature=0`, + so re-sending an identical request would return the identical invalid + answer. The correction changes the input and steers the model back to the + allowed set.) +3. The retry repeats up to `maxValidationRetries` times. +4. If all retries are exhausted, the page is assigned `invalidClassFallback` + (default `unclassified`) and flagged with a `validation_error` entry in its + classification metadata. The document continues processing — there is no + hard failure. + +```yaml +classification: + classificationMethod: multimodalPageLevelClassification + enforceValidClasses: true # Validate + retry on invalid class (default: true) + maxValidationRetries: 2 # Re-prompt up to N times (default: 2) + invalidClassFallback: unclassified # Class used when retries are exhausted (default: unclassified) +``` + +| Setting | Default | Description | +|---------|---------|-------------| +| `enforceValidClasses` | `true` | When `true`, validate the predicted class and re-prompt on out-of-vocabulary results. When `false`, an invalid class is logged and used as-is (legacy behavior). | +| `maxValidationRetries` | `2` | Maximum number of re-prompts. `0` disables retries (a single invalid prediction goes straight to the fallback). | +| `invalidClassFallback` | `unclassified` | Class assigned when retries are exhausted. Set to one of your defined classes, or the built-in `unclassified`. | + +> **Behavior change on upgrade:** enforcement is **on by default**. An +> out-of-vocabulary prediction that previously passed through unchanged is now +> corrected or coerced to the fallback class. Set `enforceValidClasses: false` +> to restore the prior "warn and use as-is" behavior. +> +> **Catch-all class:** if you want an explicit "other"/"unknown" bucket, define +> it as one of your classes — the model will then be able to select it +> legitimately, and it counts as a valid prediction. +> +> **Scope:** this loop applies to `multimodalPageLevelClassification`. +> Text-based holistic classification has similar needs but is not covered yet. + +A runnable demonstration (forcing an out-of-vocabulary prediction and showing +the retry correct it) is available in +`notebooks/misc/classification-valid-class-enforcement.ipynb`. + #### Text-Based Holistic Classification - Analyzes entire document packets to identify logical boundaries diff --git a/lib/idp_common_pkg/idp_common/classification/README.md b/lib/idp_common_pkg/idp_common/classification/README.md index 801b43d95..c00854890 100644 --- a/lib/idp_common_pkg/idp_common/classification/README.md +++ b/lib/idp_common_pkg/idp_common/classification/README.md @@ -13,6 +13,8 @@ This module provides document classification capabilities for the IDP Accelerato - **Optional regex-based classification for enhanced performance** - Document name regex matching when all pages should be classified as the same class - Page content regex matching for multi-modal page-level classification +- **Valid-class enforcement with re-prompt/retry** (page-level) — rejects + out-of-vocabulary predictions and re-prompts the model with the allowed classes - Direct integration with the Document data model - Support for both text and image content - Concurrent processing of multiple pages @@ -216,6 +218,43 @@ See `notebooks/examples/step2_classification_with_regex.ipynb` for interactive d - Configuration examples and best practices - Error handling scenarios +## Enforcing a Valid Class Vocabulary (Validation + Retry) + +For `multimodalPageLevelClassification`, the service can guarantee the predicted +class is always one of the configured classes. After each LLM call, the +predicted class is validated against `self.valid_doc_types` (built from the +configured `classes`). On an out-of-vocabulary prediction the service +re-prompts the model — re-sending the original request content with an appended +correction message that lists the allowed classes — and retries up to a +configurable limit. Because classification runs at `temperature=0`, this +single-turn re-prompt (rather than an identical re-send) is what lets the model +change its answer. + +Implemented in `ClassificationService.classify_page_bedrock` with the helper +`_build_validation_retry_content`. Metering is aggregated across all attempts. + +```yaml +classification: + enforceValidClasses: true # default: true + maxValidationRetries: 2 # default: 2 + invalidClassFallback: unclassified # default: unclassified +``` + +- `enforceValidClasses` — when `false`, an invalid class is logged and used + as-is (legacy behavior); the loop runs exactly once. +- `maxValidationRetries` — number of re-prompts after the initial attempt + (`0` = no retries). +- `invalidClassFallback` — class assigned when all retries are exhausted; the + resulting `PageClassification.classification.metadata` then carries a + `validation_error` string. The document is **not** failed. + +> Holistic (`textbasedHolisticClassification`) does not use this loop yet; it +> still logs a warning and uses an unknown type as-is. + +See `notebooks/misc/classification-valid-class-enforcement.ipynb` for a +deterministic, mock-driven walkthrough of all three scenarios (retry-then-valid, +retries-exhausted-fallback, and enforcement-disabled). + ## Usage Example ### Using with Bedrock LLMs (Default) diff --git a/lib/idp_common_pkg/idp_common/classification/service.py b/lib/idp_common_pkg/idp_common/classification/service.py index a34935904..b479e225f 100644 --- a/lib/idp_common_pkg/idp_common/classification/service.py +++ b/lib/idp_common_pkg/idp_common/classification/service.py @@ -1557,80 +1557,137 @@ def classify_page_bedrock( # Invoke Bedrock model try: - response_with_metering = self._invoke_bedrock_model( - content=content, config=config - ) - - t1 = time.time() - logger.info( - f"Time taken for classification of page {page_id}: {t1 - t0:.2f} seconds" - ) + # Validation/retry loop: re-prompt the model when it returns a class + # that is not in the configured vocabulary. When enforcement is + # disabled, the loop runs exactly once and preserves legacy + # "warn and use anyway" behavior. + enforce = self.config.classification.enforceValidClasses + max_retries = ( + self.config.classification.maxValidationRetries if enforce else 0 + ) + attempt_content = content + metering: Dict[str, Any] = {} + doc_type = "" + document_boundary = "continue" + validation_error: Optional[str] = None + + for attempt in range(max_retries + 1): + response_with_metering = self._invoke_bedrock_model( + content=attempt_content, config=config + ) - response = response_with_metering["response"] - metering = response_with_metering["metering"] + response = response_with_metering["response"] + # Accumulate metering across all attempts so token usage from + # retries is not lost. Assign the first attempt's metering + # directly (preserving its exact shape) and merge subsequent + # attempts. + attempt_metering = response_with_metering.get("metering", {}) + if not metering: + metering = attempt_metering + else: + metering = utils.merge_metering_data(metering, attempt_metering) - # Extract classification result - # Defensive: Handle case where LLM returns empty content array - content_array = response["output"]["message"].get("content", []) - if not content_array or len(content_array) == 0: - logger.error( - "LLM returned empty content array in classification response", - extra={"page_id": page_id, "response": response}, - ) - raise ValueError( - f"Classification failed for page {page_id}: LLM returned empty response" - ) + # Extract classification result + # Defensive: Handle case where LLM returns empty content array + content_array = response["output"]["message"].get("content", []) + if not content_array or len(content_array) == 0: + logger.error( + "LLM returned empty content array in classification response", + extra={"page_id": page_id, "response": response}, + ) + raise ValueError( + f"Classification failed for page {page_id}: LLM returned empty response" + ) - classification_text = content_array[0].get("text", "") + classification_text = content_array[0].get("text", "") - # Try to extract structured data (JSON or YAML) from the response - try: - classification_data, detected_format = ( - extract_structured_data_from_text(classification_text) - ) - if isinstance(classification_data, dict): - doc_type = classification_data.get("class", "") - document_boundary = classification_data.get( - "document_boundary", "continue" + # Try to extract structured data (JSON or YAML) from the response + try: + classification_data, detected_format = ( + extract_structured_data_from_text(classification_text) ) - logger.info( - f"Parsed classification response as {detected_format}: {classification_data}" + if isinstance(classification_data, dict): + doc_type = classification_data.get("class", "") + document_boundary = classification_data.get( + "document_boundary", "continue" + ) + logger.info( + f"Parsed classification response as {detected_format}: {classification_data}" + ) + else: + # If parsing failed, try to extract classification directly from text + doc_type = self._extract_class_from_text(classification_text) + document_boundary = "continue" + except Exception as e: + logger.warning( + f"Failed to parse structured data from response: {e}" ) - else: - # If parsing failed, try to extract classification directly from text + # Try to extract classification directly from text doc_type = self._extract_class_from_text(classification_text) document_boundary = "continue" - except Exception as e: - logger.warning(f"Failed to parse structured data from response: {e}") - # Try to extract classification directly from text - doc_type = self._extract_class_from_text(classification_text) - document_boundary = "continue" - - # Validate classification against known document types - if not doc_type: - doc_type = "unclassified" - logger.warning( - f"Empty classification for page {page_id}, using 'unclassified'" - ) - elif doc_type not in self.valid_doc_types: - logger.warning( - f"Unknown document type '{doc_type}' for page {page_id}, " - f"valid types are: {', '.join(self.valid_doc_types)}" - ) - # Still use the classification, it might be a new valid type + + # Validate the predicted class against the configured vocabulary + if doc_type and doc_type in self.valid_doc_types: + break # Valid prediction - done + + if not enforce: + # Legacy behavior: warn and use the prediction as-is. + if not doc_type: + doc_type = "unclassified" + logger.warning( + f"Empty classification for page {page_id}, using 'unclassified'" + ) + else: + logger.warning( + f"Unknown document type '{doc_type}' for page {page_id}, " + f"valid types are: {', '.join(self.valid_doc_types)}" + ) + # Still use the classification, it might be a new valid type + break + + # Enforcement is on and the prediction is invalid. + invalid_value = doc_type or "(empty)" + if attempt < max_retries: + logger.warning( + f"Invalid class '{invalid_value}' for page {page_id} " + f"(attempt {attempt + 1}/{max_retries + 1}); re-prompting " + f"with valid classes." + ) + attempt_content = self._build_validation_retry_content( + content, invalid_value + ) + else: + # Retries exhausted - assign configured fallback class. + fallback = self.config.classification.invalidClassFallback + validation_error = ( + f"Model returned invalid class '{invalid_value}' after " + f"{max_retries + 1} attempt(s); assigned fallback " + f"'{fallback}'." + ) + logger.error(f"Page {page_id}: {validation_error}") + doc_type = fallback + + t1 = time.time() + logger.info( + f"Time taken for classification of page {page_id}: {t1 - t0:.2f} seconds" + ) logger.info(f"Page {page_id} classified as {doc_type}") # Create and return classification result + metadata: Dict[str, Any] = { + "metering": metering, + "document_boundary": str(document_boundary).lower(), + } + if validation_error: + metadata["validation_error"] = validation_error + return PageClassification( page_id=page_id, classification=DocumentClassification( doc_type=doc_type, confidence=1.0, # Default confidence - metadata={ - "metering": metering, - "document_boundary": str(document_boundary).lower(), - }, + metadata=metadata, ), image_uri=image_uri, text_uri=text_uri, @@ -1834,6 +1891,40 @@ def classify_page( text_uri=text_uri, ) + def _build_validation_retry_content( + self, original_content: List[Dict[str, Any]], invalid_class: str + ) -> List[Dict[str, Any]]: + """ + Build the content for a validation retry by appending a correction + instruction to the original content. + + Because classification typically runs at temperature 0.0, re-sending + the identical request would return the identical (invalid) answer. The + appended correction message changes the input so the model is steered + back to the allowed vocabulary. This is a single-turn re-prompt: we + re-send the original content plus the correction, rather than threading + a multi-turn conversation history. + + Args: + original_content: The content list from the initial invocation. + invalid_class: The out-of-vocabulary class the model returned. + + Returns: + A new content list (the original is not mutated) with the + correction instruction appended. + """ + valid_classes = ", ".join(sorted(self.valid_doc_types)) + correction = ( + f"\n\nYour previous response classified the document as " + f"'{invalid_class}', which is NOT a valid class. You MUST choose " + f"exactly one class from this list: [{valid_classes}]. " + f"Respond again using the required output format and select only " + f"from the allowed classes." + ) + # Shallow-copy the list and append a new text item. The original + # content dicts are not mutated. + return list(original_content) + [{"text": correction}] + def _invoke_bedrock_model( self, content: List[Dict[str, Any]], config: Dict[str, Any] ) -> Dict[str, Any]: diff --git a/lib/idp_common_pkg/idp_common/config/models.py b/lib/idp_common_pkg/idp_common/config/models.py index 928c1770b..1fb40a4c4 100644 --- a/lib/idp_common_pkg/idp_common/config/models.py +++ b/lib/idp_common_pkg/idp_common/config/models.py @@ -384,6 +384,26 @@ class ClassificationConfig(BaseModel): description="Number of pages before/after target page to include as context for multimodalPageLevelClassification. " "0=no context (default), 1=include 1 page on each side, 2=include 2 pages on each side.", ) + enforceValidClasses: bool = Field( + default=True, + description="When True, validate the predicted class against the configured " + "class vocabulary and retry (re-prompting the model) on out-of-vocabulary " + "predictions. When False, an out-of-vocabulary prediction is logged and used " + "as-is (legacy behavior). Applies to multimodalPageLevelClassification.", + ) + maxValidationRetries: int = Field( + default=2, + ge=0, + description="Maximum number of re-prompt retries when the predicted class is " + "not in the configured class vocabulary. Only used when enforceValidClasses " + "is True.", + ) + invalidClassFallback: str = Field( + default="unclassified", + description="Class label assigned when all validation retries are exhausted. " + "Should be one of the configured classes or the built-in 'unclassified'. " + "Only used when enforceValidClasses is True.", + ) image: ImageConfig = Field(default_factory=ImageConfig) @field_validator("temperature", "top_p", "top_k", mode="before") @@ -456,6 +476,25 @@ def parse_context_pages_count(cls, v: Any) -> int: return 0 return result + @field_validator("maxValidationRetries", mode="before") + @classmethod + def parse_max_validation_retries(cls, v: Any) -> int: + """Parse maxValidationRetries from string or number, ensuring non-negative value""" + if isinstance(v, str): + v = int(v) if v.strip() else 2 + result = int(v) + if result < 0: + return 0 + return result + + @field_validator("enforceValidClasses", mode="before") + @classmethod + def parse_enforce_valid_classes(cls, v: Any) -> bool: + """Parse enforceValidClasses from string or bool (config may store as string)""" + if isinstance(v, str): + return v.strip().lower() in ("true", "1", "yes", "on") + return bool(v) + class GranularAssessmentConfig(BaseModel): """Granular assessment configuration""" diff --git a/lib/idp_common_pkg/idp_common/config/system_defaults/base-classification.yaml b/lib/idp_common_pkg/idp_common/config/system_defaults/base-classification.yaml index b44f86016..ddf9beabb 100644 --- a/lib/idp_common_pkg/idp_common/config/system_defaults/base-classification.yaml +++ b/lib/idp_common_pkg/idp_common/config/system_defaults/base-classification.yaml @@ -10,6 +10,12 @@ classification: maxPagesForClassification: "ALL" contextPagesCount: "0" sectionSplitting: llm_determined + # Validate the predicted class against the configured class vocabulary and + # re-prompt the model on out-of-vocabulary predictions + # (multimodalPageLevelClassification only). + enforceValidClasses: true + maxValidationRetries: "2" + invalidClassFallback: unclassified image: target_height: "" target_width: "" diff --git a/lib/idp_common_pkg/tests/unit/classification/test_classification_service.py b/lib/idp_common_pkg/tests/unit/classification/test_classification_service.py index 4409ca258..73f72867f 100644 --- a/lib/idp_common_pkg/tests/unit/classification/test_classification_service.py +++ b/lib/idp_common_pkg/tests/unit/classification/test_classification_service.py @@ -295,6 +295,158 @@ def test_classify_page_bedrock_success( mock_prepare_bedrock_image.assert_called_once_with(b"image_data") mock_invoke.assert_called_once() + @staticmethod + def _bedrock_response(class_value): + """Helper to build a mocked _invoke_bedrock_model return value.""" + return { + "response": { + "output": { + "message": { + "content": [{"text": json.dumps({"class": class_value})}] + } + } + }, + "metering": {"bedrock": {"inputTokens": 100, "outputTokens": 10}}, + } + + @patch("idp_common.s3.get_text_content") + @patch("idp_common.image.prepare_image") + @patch( + "idp_common.classification.service.ClassificationService._invoke_bedrock_model" + ) + @patch("idp_common.image.prepare_bedrock_image_attachment") + def test_enforce_valid_classes_retry_then_valid( + self, + mock_prepare_bedrock_image, + mock_invoke, + mock_prepare_image, + mock_get_text, + service, + ): + """Out-of-vocabulary prediction is corrected on retry.""" + mock_get_text.return_value = "This is an invoice for $100" + mock_prepare_image.return_value = b"image_data" + mock_prepare_bedrock_image.return_value = {"image": "base64_encoded_image"} + + # First call returns an invalid class, second returns a valid one. + mock_invoke.side_effect = [ + self._bedrock_response("not_a_real_class"), + self._bedrock_response("invoice"), + ] + + result = service.classify_page_bedrock( + page_id="1", + text_uri="s3://bucket/text.txt", + image_uri="s3://bucket/image.jpg", + ) + + assert result.classification.doc_type == "invoice" + assert "validation_error" not in result.classification.metadata + assert mock_invoke.call_count == 2 + # Metering aggregated across both attempts. + assert ( + result.classification.metadata["metering"]["bedrock"]["inputTokens"] == 200 + ) + + @patch("idp_common.s3.get_text_content") + @patch("idp_common.image.prepare_image") + @patch( + "idp_common.classification.service.ClassificationService._invoke_bedrock_model" + ) + @patch("idp_common.image.prepare_bedrock_image_attachment") + def test_enforce_valid_classes_exhausted_uses_fallback( + self, + mock_prepare_bedrock_image, + mock_invoke, + mock_prepare_image, + mock_get_text, + service, + ): + """When all retries fail, the fallback class and error metadata are set.""" + mock_get_text.return_value = "Some ambiguous content" + mock_prepare_image.return_value = b"image_data" + mock_prepare_bedrock_image.return_value = {"image": "base64_encoded_image"} + + # Always returns an invalid class. Default maxValidationRetries=2 -> 3 calls. + mock_invoke.return_value = self._bedrock_response("bogus_class") + + result = service.classify_page_bedrock( + page_id="1", + text_uri="s3://bucket/text.txt", + image_uri="s3://bucket/image.jpg", + ) + + assert result.classification.doc_type == "unclassified" + assert "validation_error" in result.classification.metadata + assert "bogus_class" in result.classification.metadata["validation_error"] + assert mock_invoke.call_count == 3 # initial + 2 retries + + @patch("idp_common.s3.get_text_content") + @patch("idp_common.image.prepare_image") + @patch( + "idp_common.classification.service.ClassificationService._invoke_bedrock_model" + ) + @patch("idp_common.image.prepare_bedrock_image_attachment") + def test_enforce_valid_classes_valid_first_attempt_no_retry( + self, + mock_prepare_bedrock_image, + mock_invoke, + mock_prepare_image, + mock_get_text, + service, + ): + """A valid first prediction triggers no extra invocations.""" + mock_get_text.return_value = "This is an invoice for $100" + mock_prepare_image.return_value = b"image_data" + mock_prepare_bedrock_image.return_value = {"image": "base64_encoded_image"} + mock_invoke.return_value = self._bedrock_response("invoice") + + result = service.classify_page_bedrock( + page_id="1", + text_uri="s3://bucket/text.txt", + image_uri="s3://bucket/image.jpg", + ) + + assert result.classification.doc_type == "invoice" + mock_invoke.assert_called_once() + + @patch("idp_common.s3.get_text_content") + @patch("idp_common.image.prepare_image") + @patch( + "idp_common.classification.service.ClassificationService._invoke_bedrock_model" + ) + @patch("idp_common.image.prepare_bedrock_image_attachment") + def test_enforcement_disabled_uses_invalid_class_as_is( + self, + mock_prepare_bedrock_image, + mock_invoke, + mock_prepare_image, + mock_get_text, + mock_config, + ): + """Legacy behavior: with enforcement off, invalid class is used as-is.""" + mock_config["classification"]["enforceValidClasses"] = False + with patch("boto3.Session"): + service = ClassificationService( + region="us-west-2", config=mock_config, backend="bedrock" + ) + + mock_get_text.return_value = "This is an invoice for $100" + mock_prepare_image.return_value = b"image_data" + mock_prepare_bedrock_image.return_value = {"image": "base64_encoded_image"} + mock_invoke.return_value = self._bedrock_response("not_a_real_class") + + result = service.classify_page_bedrock( + page_id="1", + text_uri="s3://bucket/text.txt", + image_uri="s3://bucket/image.jpg", + ) + + # Invalid class is used as-is, no retry, no validation_error. + assert result.classification.doc_type == "not_a_real_class" + assert "validation_error" not in result.classification.metadata + mock_invoke.assert_called_once() + @patch("idp_common.s3.get_text_content") @patch( "idp_common.classification.service.ClassificationService._invoke_bedrock_model" diff --git a/lib/idp_common_pkg/tests/unit/config/test_config_models.py b/lib/idp_common_pkg/tests/unit/config/test_config_models.py index af6507f39..935d1bc51 100644 --- a/lib/idp_common_pkg/tests/unit/config/test_config_models.py +++ b/lib/idp_common_pkg/tests/unit/config/test_config_models.py @@ -124,6 +124,31 @@ def test_full_config_with_mixed_types(self): assert config.extraction.top_p == 0.1 assert config.extraction.max_tokens == 10000 + def test_classification_valid_class_enforcement_defaults(self): + """New class-enforcement fields default to enabled with sane values.""" + from idp_common.config.models import ClassificationConfig + + cfg = ClassificationConfig() + assert cfg.enforceValidClasses is True + assert cfg.maxValidationRetries == 2 + assert cfg.invalidClassFallback == "unclassified" + + def test_classification_valid_class_enforcement_parsing(self): + """String-typed stored config values parse into the correct types.""" + from idp_common.config.models import ClassificationConfig + + cfg = ClassificationConfig( + enforceValidClasses="false", + maxValidationRetries="3", + invalidClassFallback="other", + ) + assert cfg.enforceValidClasses is False + assert cfg.maxValidationRetries == 3 + assert cfg.invalidClassFallback == "other" + + # Negative retries are clamped to 0. + assert ClassificationConfig(maxValidationRetries="-1").maxValidationRetries == 0 + def test_config_type_hints(self): """Test that config can be used as type hint""" diff --git a/notebooks/misc/classification-valid-class-enforcement.ipynb b/notebooks/misc/classification-valid-class-enforcement.ipynb new file mode 100644 index 000000000..085745ba4 --- /dev/null +++ b/notebooks/misc/classification-valid-class-enforcement.ipynb @@ -0,0 +1,294 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Enforcing a Valid Classification Vocabulary (with Retry)\n", + "\n", + "When you classify pages with a fixed set of classes (e.g. `invoice`, `w2`,\n", + "`check`), a language model can occasionally return a label that is **not** in\n", + "your configured list \u2014 for example predicting `receipt` when only `invoice`,\n", + "`w2`, and `check` are valid. Smaller / cheaper models are especially prone to\n", + "this.\n", + "\n", + "The `multimodalPageLevelClassification` method in `idp_common` supports a\n", + "**deterministic validation + retry loop** that fixes this:\n", + "\n", + "1. After the model returns a class, validate it against the configured vocabulary.\n", + "2. If it is **not** valid, re-prompt the model \u2014 appending a correction message\n", + " that lists the allowed classes (this matters: at `temperature=0`, re-sending\n", + " the *same* request would return the *same* invalid answer).\n", + "3. Retry up to `maxValidationRetries` times.\n", + "4. If retries are exhausted, assign the configured `invalidClassFallback`\n", + " (default `unclassified`) and flag the page with a `validation_error`.\n", + "\n", + "This notebook demonstrates the behavior **deterministically** by mocking the\n", + "Bedrock call so we can force an out-of-vocabulary prediction. In production the\n", + "same loop runs against the real model \u2014 no code changes required, just the\n", + "config flags shown below.\n", + "\n", + "> **Related config keys** (under `classification:`):\n", + "> - `enforceValidClasses` (default `true`)\n", + "> - `maxValidationRetries` (default `2`)\n", + "> - `invalidClassFallback` (default `unclassified`)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import logging\n", + "from unittest.mock import patch\n", + "\n", + "from idp_common.classification.service import ClassificationService\n", + "\n", + "# Surface the retry/validation log messages so we can see the loop work.\n", + "logging.basicConfig(level=logging.WARNING)\n", + "logging.getLogger(\"idp_common.classification\").setLevel(logging.INFO)\n", + "\n", + "print(\"Imports OK\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Define a configuration with a fixed class vocabulary\n", + "\n", + "We define three valid classes and enable enforcement with up to 2 retries.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def make_class(name, description):\n", + " return {\n", + " \"$schema\": \"https://json-schema.org/draft/2020-12/schema\",\n", + " \"$id\": name,\n", + " \"x-aws-idp-document-type\": name,\n", + " \"type\": \"object\",\n", + " \"description\": description,\n", + " \"properties\": {},\n", + " }\n", + "\n", + "config = {\n", + " \"classes\": [\n", + " make_class(\"invoice\", \"An invoice document\"),\n", + " make_class(\"w2\", \"A W-2 tax form\"),\n", + " make_class(\"check\", \"A bank check\"),\n", + " ],\n", + " \"classification\": {\n", + " \"model\": \"us.amazon.nova-lite-v1:0\",\n", + " \"temperature\": 0.0,\n", + " \"top_k\": 5,\n", + " \"system_prompt\": \"You are a document classification assistant.\",\n", + " \"task_prompt\": (\n", + " \"Classify the document into one of:\\n\"\n", + " \"{CLASS_NAMES_AND_DESCRIPTIONS}\\n\\n\"\n", + " \"Document text:\\n{DOCUMENT_TEXT}\\n\\n\"\n", + " \"Image:\\n{DOCUMENT_IMAGE}\\n\\n\"\n", + " 'Respond with JSON: {\"class\": \"\"}'\n", + " ),\n", + " \"classificationMethod\": \"multimodalPageLevelClassification\",\n", + " # --- the feature under test ---\n", + " \"enforceValidClasses\": True,\n", + " \"maxValidationRetries\": 2,\n", + " \"invalidClassFallback\": \"unclassified\",\n", + " },\n", + "}\n", + "\n", + "print(\"Valid classes:\", [c[\"$id\"] for c in config[\"classes\"]])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. A helper to simulate the model\n", + "\n", + "Real models won't reliably emit an invalid class on demand, so we mock\n", + "`_invoke_bedrock_model` to return a scripted sequence of responses. This lets\n", + "us demonstrate each scenario deterministically. Each mocked response uses the\n", + "same shape the real Bedrock client returns.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def bedrock_response(class_value):\n", + " \"\"\"Build a fake Bedrock Converse response for the given class label.\"\"\"\n", + " return {\n", + " \"response\": {\n", + " \"output\": {\n", + " \"message\": {\"content\": [{\"text\": json.dumps({\"class\": class_value})}]}\n", + " }\n", + " },\n", + " \"metering\": {\"bedrock\": {\"inputTokens\": 120, \"outputTokens\": 8}},\n", + " }\n", + "\n", + "\n", + "def run_with_responses(cfg, responses):\n", + " \"\"\"Classify one page, feeding the scripted `responses` to the model.\"\"\"\n", + " with patch(\"boto3.Session\"):\n", + " service = ClassificationService(region=\"us-west-2\", config=cfg, backend=\"bedrock\")\n", + "\n", + " with (\n", + " patch(\"idp_common.s3.get_text_content\", return_value=\"ACME Corp Invoice #42 Total: $100\"),\n", + " patch(\"idp_common.image.prepare_image\", return_value=b\"img\"),\n", + " patch(\"idp_common.image.prepare_bedrock_image_attachment\", return_value={\"image\": \"b64\"}),\n", + " patch.object(ClassificationService, \"_invoke_bedrock_model\", side_effect=responses) as mock_invoke,\n", + " ):\n", + " result = service.classify_page_bedrock(\n", + " page_id=\"1\",\n", + " text_uri=\"s3://bucket/text.txt\",\n", + " image_uri=\"s3://bucket/image.jpg\",\n", + " )\n", + " return result, mock_invoke.call_count" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Scenario A \u2014 model corrects itself on retry\n", + "\n", + "The model first returns `receipt` (not in our vocabulary). The validation loop\n", + "re-prompts, and on the second attempt the model returns the valid `invoice`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result, calls = run_with_responses(\n", + " config,\n", + " [bedrock_response(\"receipt\"), bedrock_response(\"invoice\")],\n", + ")\n", + "\n", + "print(f\"\\nFinal class : {result.classification.doc_type}\")\n", + "print(f\"Model calls : {calls} (1 initial + 1 retry)\")\n", + "print(f\"validation_error present: {'validation_error' in result.classification.metadata}\")\n", + "assert result.classification.doc_type == \"invoice\"\n", + "assert calls == 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Scenario B \u2014 retries exhausted, fallback applied\n", + "\n", + "The model returns an invalid class on every attempt. After the initial call\n", + "plus 2 retries (3 total), the page is assigned the `invalidClassFallback`\n", + "(`unclassified`) and flagged with a `validation_error`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result, calls = run_with_responses(\n", + " config,\n", + " [bedrock_response(\"receipt\")] * 5, # always invalid; only 3 will be consumed\n", + ")\n", + "\n", + "print(f\"\\nFinal class : {result.classification.doc_type}\")\n", + "print(f\"Model calls : {calls} (1 initial + 2 retries)\")\n", + "print(f\"validation_error: {result.classification.metadata.get('validation_error')}\")\n", + "assert result.classification.doc_type == \"unclassified\"\n", + "assert calls == 3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Scenario C \u2014 enforcement disabled (legacy behavior)\n", + "\n", + "With `enforceValidClasses: false`, an out-of-vocabulary prediction is logged as\n", + "a warning and used **as-is** \u2014 no retry, no fallback. This is the behavior prior\n", + "to this feature, retained for backward compatibility.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "legacy_config = json.loads(json.dumps(config)) # deep copy\n", + "legacy_config[\"classification\"][\"enforceValidClasses\"] = False\n", + "\n", + "result, calls = run_with_responses(legacy_config, [bedrock_response(\"receipt\")])\n", + "\n", + "print(f\"\\nFinal class : {result.classification.doc_type} (used as-is)\")\n", + "print(f\"Model calls : {calls}\")\n", + "assert result.classification.doc_type == \"receipt\"\n", + "assert calls == 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7. Summary\n", + "\n", + "| Scenario | `enforceValidClasses` | Model behavior | Result |\n", + "|----------|----------------------|----------------|--------|\n", + "| A | `true` | invalid \u2192 valid on retry | corrected to valid class |\n", + "| B | `true` | always invalid | `invalidClassFallback` + `validation_error` |\n", + "| C | `false` | invalid | invalid class used as-is (legacy) |\n", + "\n", + "**To enable in your deployment**, set these under `classification:` in your\n", + "config (they are on by default for new deployments):\n", + "\n", + "```yaml\n", + "classification:\n", + " enforceValidClasses: true\n", + " maxValidationRetries: 2\n", + " invalidClassFallback: unclassified\n", + "```\n", + "\n", + "These are also editable in the **Configuration UI** under the Classification\n", + "section.\n", + "\n", + "> **Note:** This applies to `multimodalPageLevelClassification`. Holistic\n", + "> packet classification has similar needs but is not covered by this loop yet.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file diff --git a/patterns/unified/template.yaml b/patterns/unified/template.yaml index 7678f6c37..da6753b39 100644 --- a/patterns/unified/template.yaml +++ b/patterns/unified/template.yaml @@ -911,6 +911,26 @@ Resources: default: 0 order: 3.6 dependsOn: { field: "classificationMethod", value: "multimodalPageLevelClassification" } + enforceValidClasses: + type: boolean + description: "When enabled, the predicted class is validated against the defined class vocabulary. If the model returns a class that is not in the list, it is re-prompted with the valid choices and retried (see Max Validation Retries). If retries are exhausted, the page is assigned the Invalid Class Fallback and flagged with a validation error. When disabled, an out-of-vocabulary prediction is used as-is (legacy behavior). Only applies to multimodalPageLevelClassification." + default: true + order: 3.7 + dependsOn: { field: "classificationMethod", value: "multimodalPageLevelClassification" } + maxValidationRetries: + type: integer + description: "Maximum number of times to re-prompt the model when it returns a class outside the defined vocabulary. Only used when Enforce Valid Classes is enabled. 0 disables retries (a single invalid prediction goes straight to the fallback)." + minimum: 0 + maximum: 5 + default: 2 + order: 3.8 + dependsOn: { field: "enforceValidClasses", value: true } + invalidClassFallback: + type: string + description: "Class label assigned to a page when all validation retries are exhausted. Should be one of your defined classes, or the built-in 'unclassified'. Only used when Enforce Valid Classes is enabled." + default: "unclassified" + order: 3.9 + dependsOn: { field: "enforceValidClasses", value: true } temperature: type: number minimum: 0