Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ SPDX-License-Identifier: MIT-0

### Added

- **Classification valid-class enforcement (#356)** — Page-level classification (`multimodalPageLevelClassification`) now validates the model's predicted class against the configured class vocabulary and, on an out-of-vocabulary prediction, **re-prompts the model** with a correction message listing the allowed classes, retrying up to a configurable limit. This is a deterministic guardrail against models (especially smaller/cheaper ones) returning a label outside the defined set. Because classification runs at `temperature=0`, the retry appends the correction to the request content (a single-turn re-prompt) rather than re-sending an identical request. Three new `classification` config keys control it: `enforceValidClasses` (default `true`), `maxValidationRetries` (default `2`), and `invalidClassFallback` (default `unclassified`). When retries are exhausted the page is assigned the fallback class and flagged with a `validation_error` in its classification metadata (the document keeps processing — no hard failure). All three keys are editable in the Configuration UI. **Behavior change on upgrade:** enforcement is **on by default**, so an out-of-vocabulary prediction that previously passed through unchanged is now corrected or coerced to `unclassified`; set `enforceValidClasses: false` under `classification` to restore the prior "warn and use as-is" behavior. Holistic packet classification is not covered by this loop yet. See [docs/classification.md](docs/classification.md) and the demo notebook `notebooks/misc/classification-valid-class-enforcement.ipynb`.

- **Configurable Lambda architecture** — New `LambdaArchitecture` parameter (`arm64` or `x86_64`) for all unified pattern Lambda container images. Defaults to `arm64` (Graviton) for best price-performance. Use `x86_64` when deploying with custom base images that only support AMD64. The parameter flows through to CodeBuild (`--platform` flag) and Dockerfile (`FROM` image suffix).

## [0.5.15]
Expand Down
54 changes: 54 additions & 0 deletions docs/classification.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,60 @@ The boundary detection is automatically included in the classification results.
}
}
```

##### Enforcing a Valid Class Vocabulary (Validation + Retry)

With a fixed set of classes, a language model can occasionally return a label
that is **not** in your configured list (for example predicting `receipt` when
only `invoice`, `w2`, and `check` are valid). Smaller / cheaper models are
especially prone to this. `multimodalPageLevelClassification` includes a
deterministic validation + retry guardrail to prevent out-of-vocabulary
classifications:

1. After the model returns a class, it is validated against the configured
class vocabulary.
2. If the class is **not** valid, the model is **re-prompted** — the original
request content is re-sent with an appended correction message that lists
the allowed classes. (This matters: classification runs at `temperature=0`,
so re-sending an identical request would return the identical invalid
answer. The correction changes the input and steers the model back to the
allowed set.)
3. The retry repeats up to `maxValidationRetries` times.
4. If all retries are exhausted, the page is assigned `invalidClassFallback`
(default `unclassified`) and flagged with a `validation_error` entry in its
classification metadata. The document continues processing — there is no
hard failure.

```yaml
classification:
classificationMethod: multimodalPageLevelClassification
enforceValidClasses: true # Validate + retry on invalid class (default: true)
maxValidationRetries: 2 # Re-prompt up to N times (default: 2)
invalidClassFallback: unclassified # Class used when retries are exhausted (default: unclassified)
```

| Setting | Default | Description |
|---------|---------|-------------|
| `enforceValidClasses` | `true` | When `true`, validate the predicted class and re-prompt on out-of-vocabulary results. When `false`, an invalid class is logged and used as-is (legacy behavior). |
| `maxValidationRetries` | `2` | Maximum number of re-prompts. `0` disables retries (a single invalid prediction goes straight to the fallback). |
| `invalidClassFallback` | `unclassified` | Class assigned when retries are exhausted. Set to one of your defined classes, or the built-in `unclassified`. |

> **Behavior change on upgrade:** enforcement is **on by default**. An
> out-of-vocabulary prediction that previously passed through unchanged is now
> corrected or coerced to the fallback class. Set `enforceValidClasses: false`
> to restore the prior "warn and use as-is" behavior.
>
> **Catch-all class:** if you want an explicit "other"/"unknown" bucket, define
> it as one of your classes — the model will then be able to select it
> legitimately, and it counts as a valid prediction.
>
> **Scope:** this loop applies to `multimodalPageLevelClassification`.
> Text-based holistic classification has similar needs but is not covered yet.

A runnable demonstration (forcing an out-of-vocabulary prediction and showing
the retry correct it) is available in
`notebooks/misc/classification-valid-class-enforcement.ipynb`.

#### Text-Based Holistic Classification

- Analyzes entire document packets to identify logical boundaries
Expand Down
39 changes: 39 additions & 0 deletions lib/idp_common_pkg/idp_common/classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ This module provides document classification capabilities for the IDP Accelerato
- **Optional regex-based classification for enhanced performance**
- Document name regex matching when all pages should be classified as the same class
- Page content regex matching for multi-modal page-level classification
- **Valid-class enforcement with re-prompt/retry** (page-level) — rejects
out-of-vocabulary predictions and re-prompts the model with the allowed classes
- Direct integration with the Document data model
- Support for both text and image content
- Concurrent processing of multiple pages
Expand Down Expand Up @@ -216,6 +218,43 @@ See `notebooks/examples/step2_classification_with_regex.ipynb` for interactive d
- Configuration examples and best practices
- Error handling scenarios

## Enforcing a Valid Class Vocabulary (Validation + Retry)

For `multimodalPageLevelClassification`, the service can guarantee the predicted
class is always one of the configured classes. After each LLM call, the
predicted class is validated against `self.valid_doc_types` (built from the
configured `classes`). On an out-of-vocabulary prediction the service
re-prompts the model — re-sending the original request content with an appended
correction message that lists the allowed classes — and retries up to a
configurable limit. Because classification runs at `temperature=0`, this
single-turn re-prompt (rather than an identical re-send) is what lets the model
change its answer.

Implemented in `ClassificationService.classify_page_bedrock` with the helper
`_build_validation_retry_content`. Metering is aggregated across all attempts.

```yaml
classification:
enforceValidClasses: true # default: true
maxValidationRetries: 2 # default: 2
invalidClassFallback: unclassified # default: unclassified
```

- `enforceValidClasses` — when `false`, an invalid class is logged and used
as-is (legacy behavior); the loop runs exactly once.
- `maxValidationRetries` — number of re-prompts after the initial attempt
(`0` = no retries).
- `invalidClassFallback` — class assigned when all retries are exhausted; the
resulting `PageClassification.classification.metadata` then carries a
`validation_error` string. The document is **not** failed.

> Holistic (`textbasedHolisticClassification`) does not use this loop yet; it
> still logs a warning and uses an unknown type as-is.

See `notebooks/misc/classification-valid-class-enforcement.ipynb` for a
deterministic, mock-driven walkthrough of all three scenarios (retry-then-valid,
retries-exhausted-fallback, and enforcement-disabled).

## Usage Example

### Using with Bedrock LLMs (Default)
Expand Down
205 changes: 148 additions & 57 deletions lib/idp_common_pkg/idp_common/classification/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,80 +1557,137 @@ def classify_page_bedrock(

# Invoke Bedrock model
try:
response_with_metering = self._invoke_bedrock_model(
content=content, config=config
)

t1 = time.time()
logger.info(
f"Time taken for classification of page {page_id}: {t1 - t0:.2f} seconds"
)
# Validation/retry loop: re-prompt the model when it returns a class
# that is not in the configured vocabulary. When enforcement is
# disabled, the loop runs exactly once and preserves legacy
# "warn and use anyway" behavior.
enforce = self.config.classification.enforceValidClasses
max_retries = (
self.config.classification.maxValidationRetries if enforce else 0
)
attempt_content = content
metering: Dict[str, Any] = {}
doc_type = ""
document_boundary = "continue"
validation_error: Optional[str] = None

for attempt in range(max_retries + 1):
response_with_metering = self._invoke_bedrock_model(
content=attempt_content, config=config
)

response = response_with_metering["response"]
metering = response_with_metering["metering"]
response = response_with_metering["response"]
# Accumulate metering across all attempts so token usage from
# retries is not lost. Assign the first attempt's metering
# directly (preserving its exact shape) and merge subsequent
# attempts.
attempt_metering = response_with_metering.get("metering", {})
if not metering:
metering = attempt_metering
else:
metering = utils.merge_metering_data(metering, attempt_metering)

# Extract classification result
# Defensive: Handle case where LLM returns empty content array
content_array = response["output"]["message"].get("content", [])
if not content_array or len(content_array) == 0:
logger.error(
"LLM returned empty content array in classification response",
extra={"page_id": page_id, "response": response},
)
raise ValueError(
f"Classification failed for page {page_id}: LLM returned empty response"
)
# Extract classification result
# Defensive: Handle case where LLM returns empty content array
content_array = response["output"]["message"].get("content", [])
if not content_array or len(content_array) == 0:
logger.error(
"LLM returned empty content array in classification response",
extra={"page_id": page_id, "response": response},
)
raise ValueError(
f"Classification failed for page {page_id}: LLM returned empty response"
)

classification_text = content_array[0].get("text", "")
classification_text = content_array[0].get("text", "")

# Try to extract structured data (JSON or YAML) from the response
try:
classification_data, detected_format = (
extract_structured_data_from_text(classification_text)
)
if isinstance(classification_data, dict):
doc_type = classification_data.get("class", "")
document_boundary = classification_data.get(
"document_boundary", "continue"
# Try to extract structured data (JSON or YAML) from the response
try:
classification_data, detected_format = (
extract_structured_data_from_text(classification_text)
)
logger.info(
f"Parsed classification response as {detected_format}: {classification_data}"
if isinstance(classification_data, dict):
doc_type = classification_data.get("class", "")
document_boundary = classification_data.get(
"document_boundary", "continue"
)
logger.info(
f"Parsed classification response as {detected_format}: {classification_data}"
)
else:
# If parsing failed, try to extract classification directly from text
doc_type = self._extract_class_from_text(classification_text)
document_boundary = "continue"
except Exception as e:
logger.warning(
f"Failed to parse structured data from response: {e}"
)
else:
# If parsing failed, try to extract classification directly from text
# Try to extract classification directly from text
doc_type = self._extract_class_from_text(classification_text)
document_boundary = "continue"
except Exception as e:
logger.warning(f"Failed to parse structured data from response: {e}")
# Try to extract classification directly from text
doc_type = self._extract_class_from_text(classification_text)
document_boundary = "continue"

# Validate classification against known document types
if not doc_type:
doc_type = "unclassified"
logger.warning(
f"Empty classification for page {page_id}, using 'unclassified'"
)
elif doc_type not in self.valid_doc_types:
logger.warning(
f"Unknown document type '{doc_type}' for page {page_id}, "
f"valid types are: {', '.join(self.valid_doc_types)}"
)
# Still use the classification, it might be a new valid type

# Validate the predicted class against the configured vocabulary
if doc_type and doc_type in self.valid_doc_types:
break # Valid prediction - done

if not enforce:
# Legacy behavior: warn and use the prediction as-is.
if not doc_type:
doc_type = "unclassified"
logger.warning(
f"Empty classification for page {page_id}, using 'unclassified'"
)
else:
logger.warning(
f"Unknown document type '{doc_type}' for page {page_id}, "
f"valid types are: {', '.join(self.valid_doc_types)}"
)
# Still use the classification, it might be a new valid type
break

# Enforcement is on and the prediction is invalid.
invalid_value = doc_type or "(empty)"
if attempt < max_retries:
logger.warning(
f"Invalid class '{invalid_value}' for page {page_id} "
f"(attempt {attempt + 1}/{max_retries + 1}); re-prompting "
f"with valid classes."
)
attempt_content = self._build_validation_retry_content(
content, invalid_value
)
else:
# Retries exhausted - assign configured fallback class.
fallback = self.config.classification.invalidClassFallback
validation_error = (
f"Model returned invalid class '{invalid_value}' after "
f"{max_retries + 1} attempt(s); assigned fallback "
f"'{fallback}'."
)
logger.error(f"Page {page_id}: {validation_error}")
doc_type = fallback

t1 = time.time()
logger.info(
f"Time taken for classification of page {page_id}: {t1 - t0:.2f} seconds"
)

logger.info(f"Page {page_id} classified as {doc_type}")

# Create and return classification result
metadata: Dict[str, Any] = {
"metering": metering,
"document_boundary": str(document_boundary).lower(),
}
if validation_error:
metadata["validation_error"] = validation_error

return PageClassification(
page_id=page_id,
classification=DocumentClassification(
doc_type=doc_type,
confidence=1.0, # Default confidence
metadata={
"metering": metering,
"document_boundary": str(document_boundary).lower(),
},
metadata=metadata,
),
image_uri=image_uri,
text_uri=text_uri,
Expand Down Expand Up @@ -1834,6 +1891,40 @@ def classify_page(
text_uri=text_uri,
)

def _build_validation_retry_content(
self, original_content: List[Dict[str, Any]], invalid_class: str
) -> List[Dict[str, Any]]:
"""
Build the content for a validation retry by appending a correction
instruction to the original content.

Because classification typically runs at temperature 0.0, re-sending
the identical request would return the identical (invalid) answer. The
appended correction message changes the input so the model is steered
back to the allowed vocabulary. This is a single-turn re-prompt: we
re-send the original content plus the correction, rather than threading
a multi-turn conversation history.

Args:
original_content: The content list from the initial invocation.
invalid_class: The out-of-vocabulary class the model returned.

Returns:
A new content list (the original is not mutated) with the
correction instruction appended.
"""
valid_classes = ", ".join(sorted(self.valid_doc_types))
correction = (
f"\n\nYour previous response classified the document as "
f"'{invalid_class}', which is NOT a valid class. You MUST choose "
f"exactly one class from this list: [{valid_classes}]. "
f"Respond again using the required output format and select only "
f"from the allowed classes."
)
# Shallow-copy the list and append a new text item. The original
# content dicts are not mutated.
return list(original_content) + [{"text": correction}]

def _invoke_bedrock_model(
self, content: List[Dict[str, Any]], config: Dict[str, Any]
) -> Dict[str, Any]:
Expand Down
Loading
Loading