
Classification should enforce valid class outputs (reject/retry on out-of-vocabulary predictions) #356

Description

@kaleko

Is your feature request related to a problem? Please describe.
When using page-level classification with a defined set of classes (e.g., "invoice", "w2", "check"), the LLM can return class labels that are not in the allowed set (e.g., predicting "receipt" when only "invoice", "w2", and "check" are valid). A classification system with a defined vocabulary should never return a class outside that vocabulary: if classify(document, classes={'A','B','C'}) is called, "D" should not be a possible output. This is especially problematic with smaller, cheaper models, which are more prone to deviating from instructions.

Describe the solution you'd like
Add a deterministic validation + retry loop to the page-level classification module:

  1. After the LLM returns a classification prediction, validate that the predicted class is in the allowed set of classes defined in the config.
  2. If the prediction is invalid, retry with a continuation prompt telling the model its response was invalid and listing the allowed classes again (e.g., "Your response is invalid. You must choose from: [list]. Try again.").
  3. Repeat up to a configurable max retry count.
  4. If all retries fail, either flag an error or assign a default/fallback class.
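
The four steps above could look roughly like the following minimal Python sketch. It is not the accelerator's code. The names classify_with_retry and ClassificationError are hypothetical. The llm argument is an assumed callable that takes a chat-style list of {"role", "content"} messages and returns the model's raw text. The fallback is required to be one of the allowed classes.

```python
from typing import Callable, Dict, List, Optional, Sequence

# Hypothetical LLM interface: receives a chat-style message list, returns raw text.
LLMCall = Callable[[List[Dict[str, str]]], str]


class ClassificationError(Exception):
    """Raised when the model never returns one of the allowed class labels."""


def classify_with_retry(
    llm: LLMCall,
    prompt: str,
    allowed_classes: Sequence[str],
    max_retries: int = 2,
    fallback_class: Optional[str] = None,
) -> str:
    """Return a label that is guaranteed to be in allowed_classes, or raise."""
    allowed = set(allowed_classes)
    if max_retries < 0:
        raise ValueError("max_retries must be >= 0")
    # The fallback must itself be a valid class, otherwise the guarantee breaks.
    if fallback_class is not None and fallback_class not in allowed:
        raise ValueError(f"fallback_class {fallback_class!r} is not an allowed class")

    messages = [{"role": "user", "content": prompt}]
    choices = ", ".join(allowed_classes)  # keep the config's ordering in the prompt
    prediction = ""
    # Step 3: one initial attempt plus up to max_retries corrective continuations.
    for _ in range(max_retries + 1):
        raw = llm(messages)
        prediction = raw.strip()
        # Step 1: deterministic check against the configured vocabulary.
        if prediction in allowed:
            return prediction
        # Step 2: keep the invalid answer in context and ask again.
        messages.append({"role": "assistant", "content": raw})
        messages.append({
            "role": "user",
            "content": f"Your response {prediction!r} is invalid. "
                       f"You must choose from: [{choices}]. Try again.",
        })

    # Step 4: all attempts failed; fall back or surface an explicit error.
    if fallback_class is not None:
        return fallback_class
    raise ClassificationError(
        f"no allowed class after {max_retries + 1} attempts; last output {prediction!r}"
    )


# Scripted fake model: the first answer is out of vocabulary, the second is valid.
replies = iter(["receipt", "invoice"])
label = classify_with_retry(lambda msgs: next(replies), "Classify this page.", ["invoice", "w2", "check"])
print(label)  # invoice
```

Keeping the rejected answer in the message history lets the model see its own mistake. Validating the fallback up front preserves the "never outside the vocabulary" guarantee even on the failure path.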

This is essentially Pydantic-style structured output validation applied to classification: it ensures the output always conforms to the expected enum of class labels. If this is implemented with Pydantic, the same logic could be reused in other areas of the accelerator. For example, suppose the LLM returns JSON that should match a specific schema (possibly complex and nested) modeled by a pydantic.BaseModel called ComplicatedSchema. If the JSON is invalid, ComplicatedSchema.model_validate_json(LLM_output_json) raises a ValidationError with a descriptive message that can be fed directly back to the LLM, telling it which fields need to be fixed and how.
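
A minimal sketch of that idea follows, assuming Pydantic v2, where model_validate_json raises pydantic.ValidationError on invalid input. The LineItem and ComplicatedSchema fields and the validate_llm_json helper are hypothetical, made up for illustration. The point is that str(err) names every failing field path, so it can be returned to the model verbatim as the corrective message.

```python
from typing import List, Literal, Optional, Tuple

from pydantic import BaseModel, ValidationError


# Hypothetical nested schema standing in for "ComplicatedSchema".
class LineItem(BaseModel):
    description: str
    amount: float


class ComplicatedSchema(BaseModel):
    document_class: Literal["invoice", "w2", "check"]  # classification as an enum
    line_items: List[LineItem]


def validate_llm_json(raw: str) -> Tuple[Optional[ComplicatedSchema], Optional[str]]:
    """Return (model, None) on success, or (None, feedback text for the LLM)."""
    try:
        return ComplicatedSchema.model_validate_json(raw), None
    except ValidationError as err:
        # str(err) lists each invalid field path with a human-readable reason.
        return None, f"Your response is invalid:\n{err}\nFix these fields and try again."


bad = '{"document_class": "receipt", "line_items": [{"description": "Widget", "amount": "ten"}]}'
parsed, feedback = validate_llm_json(bad)
print(feedback)
```

Here the feedback should report two problems: document_class is not one of the allowed literals, and line_items.0.amount is not a valid number. That feedback string is exactly what a retry loop would send back as its continuation prompt.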

Describe alternatives you've considered

  • Prompt-only approach (current): Instructing the model via the prompt to use only the defined classes. This is unreliable, especially with weaker models or when prompt caching causes instructions to be less prominent.
  • Embedding-based fallback: Match the predicted class embedding to the nearest allowed class embedding. This would work as a soft fallback but adds complexity and may silently map a genuinely uncertain prediction to the wrong class.
  • Bedrock structured output / constrained decoding: Some models support constraining output tokens to a specific set. This is model-dependent and not universally available.

The retry/validation loop is preferred because it is model-agnostic, straightforward to implement, and provides clear error signaling when the model cannot comply.

Additional context

  • This applies to page-level classification specifically (holistic classification may have similar needs).
  • The same validation pattern could potentially be reused for Discovery output compliance (ensuring generated schemas match the expected structure), though that is a separate and more complex problem and is not in scope here.
  • If the user wants an "unknown/other" catch-all class, they can define it explicitly in their class list; the system should still only output labels from the defined set.
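
Because the class list, including any explicit catch-all such as "other", is defined in the config at runtime, a Pydantic-based validator would need to build its enum dynamically. The sketch below assumes Pydantic v2. The names build_classification_model and class_label are hypothetical. An explicit "other" becomes just another allowed value, while anything outside the list is still rejected.

```python
from typing import Literal, Sequence, Type

from pydantic import BaseModel, ValidationError, create_model


def build_classification_model(classes: Sequence[str]) -> Type[BaseModel]:
    """Build a model whose single field only accepts the configured labels."""
    if not classes:
        raise ValueError("at least one class must be configured")
    # At runtime Literal[("a", "b")] is equivalent to Literal["a", "b"];
    # static type checkers may flag this dynamic form.
    label_type = Literal[tuple(classes)]
    # (type, ...) marks class_label as a required field.
    return create_model("PageClassification", class_label=(label_type, ...))


# Hypothetical config: the catch-all "other" is declared explicitly.
PageClassification = build_classification_model(["invoice", "w2", "check", "other"])

print(PageClassification.model_validate_json('{"class_label": "other"}').class_label)  # other
try:
    PageClassification.model_validate_json('{"class_label": "receipt"}')
except ValidationError as err:
    print(err.errors()[0]["type"])  # literal_error
```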
