Skip to content

Commit 9fe710c

Browse files
rstrahanclaude
andauthored
feat(classification): enforce valid class vocabulary with re-prompt/retry (#356) (#371)
Page-level classification (multimodalPageLevelClassification) now validates the model's predicted class against the configured class vocabulary and re-prompts the model on out-of-vocabulary predictions, retrying up to a configurable limit. Because classification runs at temperature=0, the retry appends a correction message listing the allowed classes (a single-turn re-prompt) rather than re-sending an identical request. When retries are exhausted the page is assigned a configurable fallback class and flagged with a validation_error in metadata; the document keeps processing (no hard failure). Three new classification config keys control it (on by default): - enforceValidClasses (default true) - maxValidationRetries (default 2) - invalidClassFallback (default unclassified) Behavior change on upgrade: enforcement is on by default, so out-of-vocabulary predictions that previously passed through unchanged are now corrected or coerced to the fallback. Set enforceValidClasses: false to restore legacy "warn and use as-is" behavior. - Add fields + validators to ClassificationConfig - Add defaults to base-classification.yaml (inherited by config_library samples) - Add ConfigSchema entries to patterns/unified/template.yaml for the Config UI - Add validation/retry loop and _build_validation_retry_content helper - Add unit tests (retry-then-valid, exhausted-fallback, valid-first, legacy, metering aggregation) and config-model default/parsing tests - Add demo notebook notebooks/misc/classification-valid-class-enforcement.ipynb - Update user docs, developer README, and CHANGELOG Holistic packet classification is not covered by this loop yet. Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent edfd9bf commit 9fe710c

10 files changed

Lines changed: 779 additions & 57 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ SPDX-License-Identifier: MIT-0
77

88
### Added
99

10+
- **Classification valid-class enforcement (#356)** — Page-level classification (`multimodalPageLevelClassification`) now validates the model's predicted class against the configured class vocabulary and, on an out-of-vocabulary prediction, **re-prompts the model** with a correction message listing the allowed classes, retrying up to a configurable limit. This is a deterministic guardrail against models (especially smaller/cheaper ones) returning a label outside the defined set. Because classification runs at `temperature=0`, the retry appends the correction to the request content (a single-turn re-prompt) rather than re-sending an identical request. Three new `classification` config keys control it: `enforceValidClasses` (default `true`), `maxValidationRetries` (default `2`), and `invalidClassFallback` (default `unclassified`). When retries are exhausted the page is assigned the fallback class and flagged with a `validation_error` in its classification metadata (the document keeps processing — no hard failure). All three keys are editable in the Configuration UI. **Behavior change on upgrade:** enforcement is **on by default**, so an out-of-vocabulary prediction that previously passed through unchanged is now corrected or coerced to `unclassified`; set `enforceValidClasses: false` under `classification` to restore the prior "warn and use as-is" behavior. Holistic packet classification is not covered by this loop yet. See [docs/classification.md](docs/classification.md) and the demo notebook `notebooks/misc/classification-valid-class-enforcement.ipynb`.
11+
1012
- **Configurable Lambda architecture** — New `LambdaArchitecture` parameter (`arm64` or `x86_64`) for all unified pattern Lambda container images. Defaults to `arm64` (Graviton) for best price-performance. Use `x86_64` when deploying with custom base images that only support AMD64. The parameter flows through to CodeBuild (`--platform` flag) and Dockerfile (`FROM` image suffix).
1113

1214
### Removed

docs/classification.md

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,60 @@ The boundary detection is automatically included in the classification results.
166166
}
167167
}
168168
```
169+
170+
##### Enforcing a Valid Class Vocabulary (Validation + Retry)
171+
172+
With a fixed set of classes, a language model can occasionally return a label
173+
that is **not** in your configured list (for example predicting `receipt` when
174+
only `invoice`, `w2`, and `check` are valid). Smaller / cheaper models are
175+
especially prone to this. `multimodalPageLevelClassification` includes a
176+
deterministic validation + retry guardrail to prevent out-of-vocabulary
177+
classifications:
178+
179+
1. After the model returns a class, it is validated against the configured
180+
class vocabulary.
181+
2. If the class is **not** valid, the model is **re-prompted** — the original
182+
request content is re-sent with an appended correction message that lists
183+
the allowed classes. (This matters: classification runs at `temperature=0`,
184+
so re-sending an identical request would return the identical invalid
185+
answer. The correction changes the input and steers the model back to the
186+
allowed set.)
187+
3. The retry repeats up to `maxValidationRetries` times.
188+
4. If all retries are exhausted, the page is assigned `invalidClassFallback`
189+
(default `unclassified`) and flagged with a `validation_error` entry in its
190+
classification metadata. The document continues processing — there is no
191+
hard failure.
192+
193+
```yaml
194+
classification:
195+
classificationMethod: multimodalPageLevelClassification
196+
enforceValidClasses: true # Validate + retry on invalid class (default: true)
197+
maxValidationRetries: 2 # Re-prompt up to N times (default: 2)
198+
invalidClassFallback: unclassified # Class used when retries are exhausted (default: unclassified)
199+
```
200+
201+
| Setting | Default | Description |
202+
|---------|---------|-------------|
203+
| `enforceValidClasses` | `true` | When `true`, validate the predicted class and re-prompt on out-of-vocabulary results. When `false`, an invalid class is logged and used as-is (legacy behavior). |
204+
| `maxValidationRetries` | `2` | Maximum number of re-prompts. `0` disables retries (a single invalid prediction goes straight to the fallback). |
205+
| `invalidClassFallback` | `unclassified` | Class assigned when retries are exhausted. Set to one of your defined classes, or the built-in `unclassified`. |
206+
207+
> **Behavior change on upgrade:** enforcement is **on by default**. An
208+
> out-of-vocabulary prediction that previously passed through unchanged is now
209+
> corrected or coerced to the fallback class. Set `enforceValidClasses: false`
210+
> to restore the prior "warn and use as-is" behavior.
211+
>
212+
> **Catch-all class:** if you want an explicit "other"/"unknown" bucket, define
213+
> it as one of your classes — the model will then be able to select it
214+
> legitimately, and it counts as a valid prediction.
215+
>
216+
> **Scope:** this loop applies to `multimodalPageLevelClassification`.
217+
> Text-based holistic classification has similar needs but is not covered yet.
218+
219+
A runnable demonstration (forcing an out-of-vocabulary prediction and showing
220+
the retry correct it) is available in
221+
`notebooks/misc/classification-valid-class-enforcement.ipynb`.
222+
169223
#### Text-Based Holistic Classification
170224

171225
- Analyzes entire document packets to identify logical boundaries

lib/idp_common_pkg/idp_common/classification/README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ This module provides document classification capabilities for the IDP Accelerato
1313
- **Optional regex-based classification for enhanced performance**
1414
- Document name regex matching when all pages should be classified as the same class
1515
- Page content regex matching for multi-modal page-level classification
16+
- **Valid-class enforcement with re-prompt/retry** (page-level) — rejects
17+
out-of-vocabulary predictions and re-prompts the model with the allowed classes
1618
- Direct integration with the Document data model
1719
- Support for both text and image content
1820
- Concurrent processing of multiple pages
@@ -216,6 +218,43 @@ See `notebooks/examples/step2_classification_with_regex.ipynb` for interactive d
216218
- Configuration examples and best practices
217219
- Error handling scenarios
218220

221+
## Enforcing a Valid Class Vocabulary (Validation + Retry)
222+
223+
For `multimodalPageLevelClassification`, the service can guarantee the predicted
224+
class is always one of the configured classes. After each LLM call, the
225+
predicted class is validated against `self.valid_doc_types` (built from the
226+
configured `classes`). On an out-of-vocabulary prediction the service
227+
re-prompts the model — re-sending the original request content with an appended
228+
correction message that lists the allowed classes — and retries up to a
229+
configurable limit. Because classification runs at `temperature=0`, this
230+
single-turn re-prompt (rather than an identical re-send) is what lets the model
231+
change its answer.
232+
233+
Implemented in `ClassificationService.classify_page_bedrock` with the helper
234+
`_build_validation_retry_content`. Metering is aggregated across all attempts.
235+
236+
```yaml
237+
classification:
238+
enforceValidClasses: true # default: true
239+
maxValidationRetries: 2 # default: 2
240+
invalidClassFallback: unclassified # default: unclassified
241+
```
242+
243+
- `enforceValidClasses` — when `false`, an invalid class is logged and used
244+
as-is (legacy behavior); the loop runs exactly once.
245+
- `maxValidationRetries` — number of re-prompts after the initial attempt
246+
(`0` = no retries).
247+
- `invalidClassFallback` — class assigned when all retries are exhausted; the
248+
resulting `PageClassification.classification.metadata` then carries a
249+
`validation_error` string. The document is **not** failed.
250+
251+
> Holistic (`textbasedHolisticClassification`) does not use this loop yet; it
252+
> still logs a warning and uses an unknown type as-is.
253+
254+
See `notebooks/misc/classification-valid-class-enforcement.ipynb` for a
255+
deterministic, mock-driven walkthrough of all three scenarios (retry-then-valid,
256+
retries-exhausted-fallback, and enforcement-disabled).
257+
219258
## Usage Example
220259

221260
### Using with Bedrock LLMs (Default)

lib/idp_common_pkg/idp_common/classification/service.py

Lines changed: 148 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,80 +1557,137 @@ def classify_page_bedrock(
15571557

15581558
# Invoke Bedrock model
15591559
try:
1560-
response_with_metering = self._invoke_bedrock_model(
1561-
content=content, config=config
1562-
)
1563-
1564-
t1 = time.time()
1565-
logger.info(
1566-
f"Time taken for classification of page {page_id}: {t1 - t0:.2f} seconds"
1567-
)
1560+
# Validation/retry loop: re-prompt the model when it returns a class
1561+
# that is not in the configured vocabulary. When enforcement is
1562+
# disabled, the loop runs exactly once and preserves legacy
1563+
# "warn and use anyway" behavior.
1564+
enforce = self.config.classification.enforceValidClasses
1565+
max_retries = (
1566+
self.config.classification.maxValidationRetries if enforce else 0
1567+
)
1568+
attempt_content = content
1569+
metering: Dict[str, Any] = {}
1570+
doc_type = ""
1571+
document_boundary = "continue"
1572+
validation_error: Optional[str] = None
1573+
1574+
for attempt in range(max_retries + 1):
1575+
response_with_metering = self._invoke_bedrock_model(
1576+
content=attempt_content, config=config
1577+
)
15681578

1569-
response = response_with_metering["response"]
1570-
metering = response_with_metering["metering"]
1579+
response = response_with_metering["response"]
1580+
# Accumulate metering across all attempts so token usage from
1581+
# retries is not lost. Assign the first attempt's metering
1582+
# directly (preserving its exact shape) and merge subsequent
1583+
# attempts.
1584+
attempt_metering = response_with_metering.get("metering", {})
1585+
if not metering:
1586+
metering = attempt_metering
1587+
else:
1588+
metering = utils.merge_metering_data(metering, attempt_metering)
15711589

1572-
# Extract classification result
1573-
# Defensive: Handle case where LLM returns empty content array
1574-
content_array = response["output"]["message"].get("content", [])
1575-
if not content_array or len(content_array) == 0:
1576-
logger.error(
1577-
"LLM returned empty content array in classification response",
1578-
extra={"page_id": page_id, "response": response},
1579-
)
1580-
raise ValueError(
1581-
f"Classification failed for page {page_id}: LLM returned empty response"
1582-
)
1590+
# Extract classification result
1591+
# Defensive: Handle case where LLM returns empty content array
1592+
content_array = response["output"]["message"].get("content", [])
1593+
if not content_array or len(content_array) == 0:
1594+
logger.error(
1595+
"LLM returned empty content array in classification response",
1596+
extra={"page_id": page_id, "response": response},
1597+
)
1598+
raise ValueError(
1599+
f"Classification failed for page {page_id}: LLM returned empty response"
1600+
)
15831601

1584-
classification_text = content_array[0].get("text", "")
1602+
classification_text = content_array[0].get("text", "")
15851603

1586-
# Try to extract structured data (JSON or YAML) from the response
1587-
try:
1588-
classification_data, detected_format = (
1589-
extract_structured_data_from_text(classification_text)
1590-
)
1591-
if isinstance(classification_data, dict):
1592-
doc_type = classification_data.get("class", "")
1593-
document_boundary = classification_data.get(
1594-
"document_boundary", "continue"
1604+
# Try to extract structured data (JSON or YAML) from the response
1605+
try:
1606+
classification_data, detected_format = (
1607+
extract_structured_data_from_text(classification_text)
15951608
)
1596-
logger.info(
1597-
f"Parsed classification response as {detected_format}: {classification_data}"
1609+
if isinstance(classification_data, dict):
1610+
doc_type = classification_data.get("class", "")
1611+
document_boundary = classification_data.get(
1612+
"document_boundary", "continue"
1613+
)
1614+
logger.info(
1615+
f"Parsed classification response as {detected_format}: {classification_data}"
1616+
)
1617+
else:
1618+
# If parsing failed, try to extract classification directly from text
1619+
doc_type = self._extract_class_from_text(classification_text)
1620+
document_boundary = "continue"
1621+
except Exception as e:
1622+
logger.warning(
1623+
f"Failed to parse structured data from response: {e}"
15981624
)
1599-
else:
1600-
# If parsing failed, try to extract classification directly from text
1625+
# Try to extract classification directly from text
16011626
doc_type = self._extract_class_from_text(classification_text)
16021627
document_boundary = "continue"
1603-
except Exception as e:
1604-
logger.warning(f"Failed to parse structured data from response: {e}")
1605-
# Try to extract classification directly from text
1606-
doc_type = self._extract_class_from_text(classification_text)
1607-
document_boundary = "continue"
1608-
1609-
# Validate classification against known document types
1610-
if not doc_type:
1611-
doc_type = "unclassified"
1612-
logger.warning(
1613-
f"Empty classification for page {page_id}, using 'unclassified'"
1614-
)
1615-
elif doc_type not in self.valid_doc_types:
1616-
logger.warning(
1617-
f"Unknown document type '{doc_type}' for page {page_id}, "
1618-
f"valid types are: {', '.join(self.valid_doc_types)}"
1619-
)
1620-
# Still use the classification, it might be a new valid type
1628+
1629+
# Validate the predicted class against the configured vocabulary
1630+
if doc_type and doc_type in self.valid_doc_types:
1631+
break # Valid prediction - done
1632+
1633+
if not enforce:
1634+
# Legacy behavior: warn and use the prediction as-is.
1635+
if not doc_type:
1636+
doc_type = "unclassified"
1637+
logger.warning(
1638+
f"Empty classification for page {page_id}, using 'unclassified'"
1639+
)
1640+
else:
1641+
logger.warning(
1642+
f"Unknown document type '{doc_type}' for page {page_id}, "
1643+
f"valid types are: {', '.join(self.valid_doc_types)}"
1644+
)
1645+
# Still use the classification, it might be a new valid type
1646+
break
1647+
1648+
# Enforcement is on and the prediction is invalid.
1649+
invalid_value = doc_type or "(empty)"
1650+
if attempt < max_retries:
1651+
logger.warning(
1652+
f"Invalid class '{invalid_value}' for page {page_id} "
1653+
f"(attempt {attempt + 1}/{max_retries + 1}); re-prompting "
1654+
f"with valid classes."
1655+
)
1656+
attempt_content = self._build_validation_retry_content(
1657+
content, invalid_value
1658+
)
1659+
else:
1660+
# Retries exhausted - assign configured fallback class.
1661+
fallback = self.config.classification.invalidClassFallback
1662+
validation_error = (
1663+
f"Model returned invalid class '{invalid_value}' after "
1664+
f"{max_retries + 1} attempt(s); assigned fallback "
1665+
f"'{fallback}'."
1666+
)
1667+
logger.error(f"Page {page_id}: {validation_error}")
1668+
doc_type = fallback
1669+
1670+
t1 = time.time()
1671+
logger.info(
1672+
f"Time taken for classification of page {page_id}: {t1 - t0:.2f} seconds"
1673+
)
16211674

16221675
logger.info(f"Page {page_id} classified as {doc_type}")
16231676

16241677
# Create and return classification result
1678+
metadata: Dict[str, Any] = {
1679+
"metering": metering,
1680+
"document_boundary": str(document_boundary).lower(),
1681+
}
1682+
if validation_error:
1683+
metadata["validation_error"] = validation_error
1684+
16251685
return PageClassification(
16261686
page_id=page_id,
16271687
classification=DocumentClassification(
16281688
doc_type=doc_type,
16291689
confidence=1.0, # Default confidence
1630-
metadata={
1631-
"metering": metering,
1632-
"document_boundary": str(document_boundary).lower(),
1633-
},
1690+
metadata=metadata,
16341691
),
16351692
image_uri=image_uri,
16361693
text_uri=text_uri,
@@ -1834,6 +1891,40 @@ def classify_page(
18341891
text_uri=text_uri,
18351892
)
18361893

1894+
def _build_validation_retry_content(
1895+
self, original_content: List[Dict[str, Any]], invalid_class: str
1896+
) -> List[Dict[str, Any]]:
1897+
"""
1898+
Build the content for a validation retry by appending a correction
1899+
instruction to the original content.
1900+
1901+
Because classification typically runs at temperature 0.0, re-sending
1902+
the identical request would return the identical (invalid) answer. The
1903+
appended correction message changes the input so the model is steered
1904+
back to the allowed vocabulary. This is a single-turn re-prompt: we
1905+
re-send the original content plus the correction, rather than threading
1906+
a multi-turn conversation history.
1907+
1908+
Args:
1909+
original_content: The content list from the initial invocation.
1910+
invalid_class: The out-of-vocabulary class the model returned.
1911+
1912+
Returns:
1913+
A new content list (the original is not mutated) with the
1914+
correction instruction appended.
1915+
"""
1916+
valid_classes = ", ".join(sorted(self.valid_doc_types))
1917+
correction = (
1918+
f"\n\nYour previous response classified the document as "
1919+
f"'{invalid_class}', which is NOT a valid class. You MUST choose "
1920+
f"exactly one class from this list: [{valid_classes}]. "
1921+
f"Respond again using the required output format and select only "
1922+
f"from the allowed classes."
1923+
)
1924+
# Shallow-copy the list and append a new text item. The original
1925+
# content dicts are not mutated.
1926+
return list(original_content) + [{"text": correction}]
1927+
18371928
def _invoke_bedrock_model(
18381929
self, content: List[Dict[str, Any]], config: Dict[str, Any]
18391930
) -> Dict[str, Any]:

0 commit comments

Comments
 (0)