Skip to content

Commit 01b98db

Browse files
committed
Release v0.2.2
1 parent fd8ccf4 commit 01b98db

5 files changed

Lines changed: 147 additions & 10 deletions

File tree

autochecklist/generators/instance_level/contrastive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,8 @@ def _generate_with_candidates(
154154
)
155155
format_kwargs["reference"] = reference
156156

157-
# Load format instructions
158-
format_text = load_format(self._format_name)
157+
# Load format instructions (skip for custom schemas)
158+
format_text = load_format(self._format_name) if self._format_name else ""
159159

160160
# Inject format inline if template has {format_instructions} placeholder,
161161
# otherwise append after the prompt (default).

autochecklist/generators/instance_level/direct.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,17 @@ def __init__(
5656
self._method_name = method_name
5757
self.max_items = preset.get("max_items", max_items)
5858
self.min_items = preset.get("min_items", min_items)
59+
60+
is_custom_schema = response_schema is not None
5961
self._response_schema = response_schema or preset.get(
6062
"response_schema", ChecklistResponse
6163
)
62-
self._format_name = format_name or preset.get("format_name", "checklist")
64+
if format_name is not None:
65+
self._format_name = format_name
66+
elif is_custom_schema:
67+
self._format_name = None
68+
else:
69+
self._format_name = preset.get("format_name", "checklist")
6370

6471
# Load template
6572
if custom_prompt is not None:
@@ -115,8 +122,8 @@ def generate(
115122
if "history" in self._template._placeholders:
116123
format_kwargs["history"] = history
117124

118-
# Load format instructions
119-
format_text = load_format(self._format_name)
125+
# Load format instructions (skip for custom schemas)
126+
format_text = load_format(self._format_name) if self._format_name else ""
120127

121128
# Inject format inline if template has {format_instructions} placeholder,
122129
# otherwise append after the prompt (default).
@@ -149,20 +156,67 @@ def _parse_structured(self, raw: str) -> list[ChecklistItem]:
149156
150157
Primary path: json.loads() succeeds (structured output).
151158
Fallback path: extract_json() extracts JSON from raw text.
159+
160+
Auto-detects the list field and item fields from the schema,
161+
supporting both built-in and custom response schemas.
152162
"""
153163
try:
154164
data = json.loads(raw)
155165
except json.JSONDecodeError:
156166
data = extract_json(raw)
157167
validated = self._response_schema.model_validate(data)
158168

169+
# Find the list field ('questions' if present, else the first list-valued field)
170+
item_list = self._get_item_list(validated)
171+
159172
items = []
160-
for q in validated.questions[: self.max_items]:
173+
for q in item_list[: self.max_items]:
174+
q_data = q.model_dump() if hasattr(q, "model_dump") else {}
175+
# Find question text: use 'question' field, or first str field
176+
question, question_key = self._get_question_text(q, q_data)
177+
weight = getattr(q, "weight", 100.0)
178+
category = getattr(q, "category", None)
179+
# Extra fields → metadata
180+
known = {question_key, "weight", "category"}
181+
extra = {k: v for k, v in q_data.items() if k not in known}
161182
items.append(
162183
ChecklistItem(
163-
question=q.question,
164-
weight=getattr(q, "weight", 100.0),
165-
category=getattr(q, "category", None),
184+
question=question,
185+
weight=weight,
186+
category=category,
187+
metadata=extra if extra else {},
166188
)
167189
)
168190
return items
191+
192+
@staticmethod
def _get_item_list(validated: Any) -> list:
    """Extract the list of items from a validated response model.

    Lookup order:
      1. The ``questions`` attribute (built-in convention), when it is set.
      2. Otherwise, the first declared model field whose value is a list.

    Args:
        validated: Validated response object. Its class must expose
            ``model_fields`` (an iterable of declared field names).

    Returns:
        The value of the first matching field.

    Raises:
        ValueError: If neither ``questions`` nor any declared field
            provides a list of items.
    """
    # A 'questions' attribute that is present but None (e.g. an optional
    # field left unset) must not be handed back: the caller slices the
    # result. Fall through to auto-detection instead.
    questions = getattr(validated, "questions", None)
    if questions is not None:
        return questions
    # Auto-detect: first declared field holding a list
    for field_name in type(validated).model_fields:
        value = getattr(validated, field_name)
        if isinstance(value, list):
            return value
    raise ValueError(
        f"Cannot find list field in {type(validated).__name__}. "
        "Schema must have a list field (e.g., 'questions', 'items')."
    )
207+
208+
@staticmethod
def _get_question_text(item: Any, item_data: dict) -> tuple[str, str]:
    """Extract question text and its field key from an item.

    Resolution order:
      1. ``item`` itself when it is a plain string (key reported as
         ``"question"``).
      2. The item's ``question`` attribute when it has one.
      3. The first string value in ``item_data`` whose key is not one of
         the reserved names ``weight`` / ``category``.
      4. As a last resort, the first string value of any key.

    Args:
        item: A raw string or an object representing one checklist entry.
        item_data: Mapping of the item's field names to values.

    Returns:
        A ``(text, key)`` tuple: the question text and the field name it
        was read from.

    Raises:
        ValueError: If no string value can be found.
    """
    if isinstance(item, str):
        return item, "question"
    if hasattr(item, "question"):
        return item.question, "question"

    # 'weight' and 'category' are read separately by the caller, so a
    # string-valued 'category' declared before the real text field must
    # not be mistaken for the question.
    reserved = ("weight", "category")
    str_fields = [
        (key, value) for key, value in item_data.items() if isinstance(value, str)
    ]
    for key, value in str_fields:
        if key not in reserved:
            return value, key
    # Only reserved string fields exist: keep the previous behaviour and
    # use the first one rather than failing.
    if str_fields:
        key, value = str_fields[0]
        return value, key
    raise ValueError(
        f"Cannot find question text in {type(item).__name__}. "
        "Item must have a 'question' field or at least one str field."
    )

docs/user-guide/custom-prompts.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,38 @@ pipe = pipeline("code_review", scorer="strict")
130130

131131
Custom generators registered via `register_custom_generator()` always use the unweighted `ChecklistResponse` schema. To use weighted output (`WeightedChecklistResponse`), instantiate `DirectGenerator` directly with `response_schema=WeightedChecklistResponse`.
132132

133+
### Custom Response Schemas
134+
135+
You can pass any Pydantic model as `response_schema` to define your own output structure, as long as it contains a list field holding the checklist items. When a custom schema is provided and no `format_name` is passed explicitly, format instructions are skipped and the LLM is guided entirely via structured output enforcement.
136+
137+
```python
138+
from pydantic import BaseModel
139+
from autochecklist import DirectGenerator
140+
141+
class ActionItem(BaseModel):
142+
item: str
143+
sources: list[int]
144+
145+
class ActionItemsResponse(BaseModel):
146+
questions: list[ActionItem]
147+
148+
gen = DirectGenerator(
149+
custom_prompt="Generate evaluation criteria with source references for:\n\n{input}",
150+
response_schema=ActionItemsResponse,
151+
model="openai/gpt-5-mini",
152+
)
153+
checklist = gen.generate(input="Write a literature review.")
154+
```
155+
156+
The parser auto-detects the list field and question text:
157+
158+
- **List field**: uses `questions` if present, otherwise the first list field (e.g., `items`, `criteria`)
159+
- **Question text**: uses `question` if present, otherwise the first `str` field (e.g., `item`, `text`)
160+
- **Extra fields**: any fields beyond the question text, `weight`, and `category` are preserved in `ChecklistItem.metadata`
161+
162+
```python
163+
checklist.items[0].question # "Is it cited?"
164+
checklist.items[0].metadata # {"sources": [1, 3]}
165+
```
166+
133167
**Scorers** also use structured JSON output (`BatchScoringResponse`, `ItemScoringResponse`, etc.) with the same provider-level enforcement and fallback. Your custom scorer prompt does not need to dictate the output format.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "autochecklist"
3-
version = "0.2.1"
3+
version = "0.2.2"
44
description = "A library of checklist generation and scoring methods for LLM evaluation"
55
authors = [{name = "ChicagoHAI"}]
66
readme = "README.pypi.md"

tests/test_generators/test_direct.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,55 @@ def test_parse_structured_without_category_unchanged(self):
112112
assert items[0].category is None
113113

114114

115+
class TestCustomResponseSchema:
    """Parsing behaviour of DirectGenerator with user-defined response schemas."""

    def test_custom_schema_with_items_list_str(self):
        """Custom schema with 'items: List[str]' parses into ChecklistItems."""
        from typing import List

        from pydantic import BaseModel

        from autochecklist.generators.instance_level.direct import DirectGenerator

        class CustomResponse(BaseModel):
            items: List[str]

        generator = DirectGenerator(
            method_name="custom",
            custom_prompt="Generate criteria for: {input}",
            response_schema=CustomResponse,
        )
        # Custom schema without an explicit format_name: no format is selected.
        assert generator._format_name is None

        payload = '{"items": ["Does the response address the topic?", "Is the tone appropriate?"]}'
        parsed = generator._parse_structured(payload)

        assert len(parsed) == 2
        first, second = parsed[0], parsed[1]
        assert first.question == "Does the response address the topic?"
        assert second.question == "Is the tone appropriate?"
        # Plain-string items carry no weight/category, so defaults apply.
        assert first.weight == 100.0
        assert first.category is None

    def test_nested_schema_extra_fields_in_metadata(self):
        """Nested item model with non-str fields preserved in metadata."""
        from pydantic import BaseModel

        from autochecklist.generators.instance_level.direct import DirectGenerator

        class ActionItem(BaseModel):
            item: str
            sources: list[int]

        class ActionItemsResponse(BaseModel):
            questions: list[ActionItem]

        generator = DirectGenerator(
            method_name="custom",
            custom_prompt="Generate criteria for: {input}",
            response_schema=ActionItemsResponse,
        )

        payload = '{"questions": [{"item": "Is it cited?", "sources": [1, 3]}]}'
        parsed = generator._parse_structured(payload)

        assert len(parsed) == 1
        only = parsed[0]
        # 'item' is the first str field, so it becomes the question text;
        # the remaining field is carried over as metadata.
        assert only.question == "Is it cited?"
        assert only.metadata == {"sources": [1, 3]}
162+
163+
115164
class TestContrastiveGeneratorConfig:
116165
def test_rlcf_candidate_preset_loads(self):
117166
from autochecklist.generators.instance_level.contrastive import ContrastiveGenerator

0 commit comments

Comments
 (0)