Skip to content

Commit 994cc4f

Browse files
authored
Fix learning synthesis issues: session_reward context and Gemini structured outputs
* Fix learning synthesis not triggered when session_reward is set via the orchestrator
  - Use ExecutionContext.get() after orchestrator.arun() completes to ensure we access the same context instance that the orchestrator used
  - Ensure ExecutionContext.metadata property always returns a dict, with defensive initialization checks
  - Maintains backward compatibility with all adapters (LangGraph, BYOA, etc.)
  - Fixes issue #138

* Implement Gemini structured outputs for learning synthesis
  - Add a JSON Schema builder for the playbook_entry.v1 structure (atlas/learning/schema.py)
  - Update LLMClient to detect Gemini models and use structured outputs via extra_body
  - Update LearningSynthesizer to pass the JSON schema for Gemini models
  - Improve error handling with clearer error messages that include the model info
  - Update the mock response handler to support the playbook_entry.v1 structure
  - Add comprehensive unit tests for schema generation and structured outputs
  - Maintain backward compatibility for non-Gemini models (OpenAI, etc.)

  This enforces schema validation at the API level for Gemini models, prevents malformed responses, and reduces silent failures in learning synthesis. Non-Gemini models continue to use the OpenAI-style response_format for backward compatibility. Fixes issue #139

* Address Copilot feedback: add defensive type checks
  - Add an isinstance check for response_format before calling .get()
  - Simplify a test assertion for better readability
  - Improves defensive programming and code clarity
1 parent c96f585 commit 994cc4f

6 files changed

Lines changed: 346 additions & 12 deletions

File tree

atlas/core/__init__.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,20 @@ async def arun(
216216
capability_probe=capability_probe_client,
217217
)
218218
result = await orchestrator.arun(task)
219+
# Get current context after orchestrator completes to ensure we access the same
220+
# context instance that the orchestrator used (they share the same ExecutionContextState)
221+
current_context = ExecutionContext.get()
219222
if (
220223
database
221224
and learning_synthesizer
222225
and learning_synthesizer.enabled
223226
and learning_cfg.update_enabled
224-
and execution_context.metadata.get("session_reward") is not None
227+
and current_context.metadata.get("session_reward") is not None
225228
):
226-
reward_payload = execution_context.metadata.get("session_reward")
227-
trajectory_payload = execution_context.metadata.get("session_trajectory")
228-
history_payload = execution_context.metadata.get("learning_history")
229-
current_learning_state = execution_context.metadata.get("learning_state") or {}
229+
reward_payload = current_context.metadata.get("session_reward")
230+
trajectory_payload = current_context.metadata.get("session_trajectory")
231+
history_payload = current_context.metadata.get("learning_history")
232+
current_learning_state = current_context.metadata.get("learning_state") or {}
230233
synthesis = await learning_synthesizer.asynthesize(
231234
learning_key=learning_key,
232235
task=task,

atlas/learning/schema.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""JSON Schema builder for playbook entry learning synthesis."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Dict
6+
7+
8+
def build_playbook_entry_schema() -> Dict[str, Any]:
    """Return the JSON Schema describing the ``playbook_entry.v1`` payload.

    The shape mirrors the structure defined in atlas/learning/prompts.py and
    is handed to Gemini's structured-output feature so that type safety and
    schema validation are enforced at the API level.

    Returns:
        A JSON Schema dictionary compatible with Gemini's
        ``response_json_schema`` format.
    """

    def _nullable(kind: str, description: str) -> Dict[str, Any]:
        # Shorthand for an optional field whose value may also be null.
        return {"type": [kind, "null"], "description": description}

    cue_schema: Dict[str, Any] = {
        "type": "object",
        "description": "Machine-detectable trigger pattern",
        "properties": {
            "type": {
                "type": "string",
                "enum": ["regex", "keyword", "predicate"],
                "description": "Type of cue pattern",
            },
            "pattern": {
                "type": "string",
                "description": "Machine-detectable trigger pattern (max 150 chars)",
            },
            "description": _nullable("string", "Optional human-readable explanation"),
        },
        "required": ["type", "pattern"],
    }

    action_schema: Dict[str, Any] = {
        "type": "object",
        "description": "Action to take when cue is detected",
        "properties": {
            "imperative": {
                "type": "string",
                "description": "Imperative verb phrasing describing the action (max 120 chars)",
            },
            "runtime_handle": _nullable(
                "string",
                "Runtime handle/tool name from available_runtime_handles, or null if no tools",
            ),
            "tool_name": _nullable("string", "Optional tool name"),
            "arguments": _nullable("object", "Optional tool arguments"),
        },
        "required": ["imperative"],
    }

    scope_schema: Dict[str, Any] = {
        "type": "object",
        "description": "Scope and constraints for when this entry applies",
        "properties": {
            "category": {
                "type": "string",
                "enum": ["reinforcement", "differentiation"],
                "description": "Whether this reinforces existing behavior or introduces new strategy",
            },
            "constraints": {
                "type": "string",
                "description": "Boundaries and applicability constraints (max 250 chars)",
            },
            "applies_when": _nullable("string", "Optional condition for when this entry applies"),
        },
        "required": ["category", "constraints"],
    }

    entry_schema: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "id": _nullable("string", "Unique identifier for the entry, or null for new entries"),
            "audience": {
                "type": "string",
                "enum": ["student", "teacher"],
                "description": "Target audience for this playbook entry",
            },
            "cue": cue_schema,
            "action": action_schema,
            "expected_effect": {
                "type": "string",
                "description": "Explanation of why this action improves outcomes (max 200 chars)",
            },
            "scope": scope_schema,
            "metadata": _nullable("object", "Optional free-form metadata"),
        },
        "required": ["audience", "cue", "action", "expected_effect", "scope"],
    }

    return {
        "type": "object",
        "properties": {
            "version": {
                "type": "string",
                "const": "playbook_entry.v1",
                "description": "Schema version identifier",
            },
            "student_pamphlet": _nullable(
                "string", "Updated student learning pamphlet text or null if unchanged"
            ),
            "teacher_pamphlet": _nullable(
                "string", "Updated teacher learning pamphlet text or null if unchanged"
            ),
            "playbook_entries": {
                "type": "array",
                "description": "List of playbook entries to add or update",
                "items": entry_schema,
            },
            "session_student_learning": _nullable(
                "string", "Brief takeaway from this session for student (optional)"
            ),
            "session_teacher_learning": _nullable(
                "string", "Teacher intervention note from this session (optional)"
            ),
            "metadata": _nullable(
                "object", "Optional metadata including synthesis reasoning and validation notes"
            ),
        },
        "required": ["version"],
    }
139+
140+
141+
__all__ = ["build_playbook_entry_schema"]
142+

atlas/learning/synthesizer.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
stabilise_playbook_entry_id,
2525
)
2626
from atlas.learning.prompts import LEARNING_SYNTHESIS_PROMPT
27+
from atlas.learning.schema import build_playbook_entry_schema
2728
from atlas.runtime.orchestration.execution_context import ExecutionContext
2829
from atlas.utils.llm_client import LLMClient
2930

@@ -105,25 +106,43 @@ async def asynthesize(
105106
if client is None:
106107
logger.debug("Learning synthesizer client unavailable; skipping update for %s", learning_key)
107108
return None
109+
110+
# Build JSON schema for structured outputs (Gemini models)
111+
json_schema = build_playbook_entry_schema()
112+
overrides: Dict[str, Any] = {}
113+
if self._is_gemini_model(client.model):
114+
# Pass JSON schema via extra_body for Gemini structured outputs
115+
overrides["extra_body"] = {
116+
"response_json_schema": json_schema
117+
}
118+
108119
try:
109120
response = await client.acomplete(
110121
messages,
111122
response_format={"type": "json_object"},
123+
overrides=overrides,
112124
)
113125
audit_entry = {
114126
"model": client.model,
115127
"messages": messages,
116128
"response": response.content,
117129
"reasoning": response.reasoning or {},
118130
"raw_response": response.raw,
131+
"structured_output": self._is_gemini_model(client.model),
119132
}
120133
except Exception as exc:
121134
logger.warning("Learning synthesis call failed for %s: %s", learning_key, exc)
122135
return None
123136

124137
parsed = self._try_parse_json(response.content)
125138
if parsed is None:
126-
logger.warning("Learning synthesis returned non-JSON payload for %s", learning_key)
139+
logger.error(
140+
"Learning synthesis returned non-JSON payload for %s (model: %s). "
141+
"Response preview: %s",
142+
learning_key,
143+
client.model,
144+
response.content[:200] if response.content else "empty",
145+
)
127146
return None
128147

129148
result = self._build_result(parsed, learning_state or {})
@@ -679,6 +698,17 @@ def _teacher_guidance_digest(self, context: ExecutionContext) -> str | None:
679698
serialized = "\n".join(sorted(set(notes)))
680699
return hashlib.sha256(serialized.encode("utf-8")).hexdigest()[:16]
681700

701+
def _is_gemini_model(self, model: str) -> bool:
702+
"""Check if model is a Gemini model.
703+
704+
Args:
705+
model: Model identifier string
706+
707+
Returns:
708+
True if model is a Gemini model, False otherwise
709+
"""
710+
return model.startswith("gemini/") or model.startswith("google/")
711+
682712
@staticmethod
683713
def _clean_str(value: Any) -> str | None:
684714
if value is None:

atlas/runtime/orchestration/execution_context.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,15 @@ def __init__(self, state: ExecutionContextState) -> None:
9595

9696
@property
def metadata(self) -> dict[str, typing.Any]:
    """Return the per-context metadata dict, initialising it on first use.

    Always hands back the exact dict object stored in the underlying
    ExecutionContextState, so mutations by the caller (e.g. setting
    ``session_reward``) remain visible to every holder of this context.

    Returns:
        The shared metadata dictionary (never ``None``).
    """
    # Touch the state attribute first: per the original author's note this
    # triggers ExecutionContextState's lazy initialisation — TODO confirm.
    _ = self._state.metadata
    result = self._state.metadata.get()
    if result is None:
        # Bug fix: store and return the SAME dict. Previously a fresh {} was
        # returned after set({}), so writes to the returned dict were lost
        # from the context and session_reward was never observed downstream.
        result = {}
        self._state.metadata.set(result)
    return result
99107

100108
@property
101109
def active_function(self) -> InvocationNode:

atlas/utils/llm_client.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,17 @@ def complete(
6262
content, reasoning = self._extract_content(result)
6363
return LLMResponse(content=content, reasoning=reasoning, raw=result)
6464

65+
def _is_gemini_model(self, model: str) -> bool:
66+
"""Check if model is a Gemini model.
67+
68+
Args:
69+
model: Model identifier string
70+
71+
Returns:
72+
True if model is a Gemini model, False otherwise
73+
"""
74+
return model.startswith("gemini/") or model.startswith("google/")
75+
6576
def _prepare_kwargs(
6677
self,
6778
messages: Sequence[dict[str, Any]],
@@ -87,8 +98,6 @@ def _prepare_kwargs(
8798
if params.max_output_tokens is not None:
8899
kwargs["max_tokens"] = params.max_output_tokens
89100
kwargs["timeout"] = params.timeout_seconds
90-
if response_format:
91-
kwargs["response_format"] = response_format
92101

93102
extra_headers = dict(params.additional_headers)
94103
override_headers = overrides.pop("extra_headers", None)
@@ -105,6 +114,17 @@ def _prepare_kwargs(
105114
if supports_reasoning and params.reasoning_effort:
106115
extra_body.setdefault("reasoning_effort", params.reasoning_effort)
107116

117+
# Handle Gemini structured outputs
118+
if response_format and isinstance(response_format, dict) and response_format.get("type") == "json_object":
119+
if self._is_gemini_model(params.model):
120+
# Use Gemini structured outputs via extra_body
121+
# response_json_schema should be provided via overrides["extra_body"]
122+
extra_body.setdefault("response_mime_type", "application/json")
123+
# Don't set response_format for Gemini - it's not supported
124+
else:
125+
# Use OpenAI-style response_format for non-Gemini models
126+
kwargs["response_format"] = response_format
127+
108128
if extra_headers:
109129
kwargs["extra_headers"] = extra_headers
110130
if extra_body:
@@ -179,7 +199,19 @@ def _mock_response(
179199
if isinstance(last_message, dict):
180200
user_content = str(last_message.get("content", ""))
181201
if response_format and response_format.get("type") == "json_object":
182-
if "plan" in user_content:
202+
# Check if this looks like a learning synthesis request
203+
if "playbook_entry" in user_content or "learning" in user_content.lower():
204+
# Return mock playbook_entry.v1 structure
205+
payload = {
206+
"version": "playbook_entry.v1",
207+
"student_pamphlet": None,
208+
"teacher_pamphlet": None,
209+
"playbook_entries": [],
210+
"session_student_learning": None,
211+
"session_teacher_learning": None,
212+
"metadata": None
213+
}
214+
elif "plan" in user_content:
183215
payload = {
184216
"steps": [
185217
{

0 commit comments

Comments
 (0)