Skip to content

Commit 7b6c823

Browse files
committed
feat: Enhance Cloudflare integration with fallback handling and improve review planning logic
1 parent 8f9d79e commit 7b6c823

4 files changed

Lines changed: 256 additions & 10 deletions

File tree

python-ecosystem/inference-orchestrator/src/llm/llm_factory.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import json
44
from typing import Any, Optional
5-
from urllib.parse import urlparse
5+
from urllib.parse import urlparse, urlunparse
66
from pydantic import SecretStr
77
from langchain_openai import ChatOpenAI
88
from langchain_anthropic import ChatAnthropic
@@ -108,6 +108,9 @@ def _normalize_openai_compatible_base_url(ai_base_url: str) -> str:
108108

109109
if _is_cloudflare_base_url(base_url):
110110
parsed = urlparse(base_url)
111+
if parsed.hostname == "api.cloudflare.com" and "/ai/run/" in parsed.path:
112+
ai_prefix = parsed.path.split("/ai/run/", 1)[0]
113+
return urlunparse(parsed._replace(path=f"{ai_prefix}/ai/v1", params="", query="", fragment=""))
111114
if parsed.hostname == "api.cloudflare.com" and parsed.path.endswith("/ai"):
112115
return f"{base_url}/v1"
113116
return base_url
@@ -155,6 +158,66 @@ def _coerce_openai_compatible_text_content(content: Any) -> str:
155158
return str(content)
156159

157160

161+
_CLOUDFLARE_ROLE_BY_MESSAGE_TYPE = {
162+
"human": "user",
163+
"ai": "assistant",
164+
"system": "system",
165+
"tool": "tool",
166+
"function": "function",
167+
}
168+
169+
_CLOUDFLARE_MESSAGE_KEYS = {
170+
"role",
171+
"content",
172+
"name",
173+
"tool_calls",
174+
"tool_call_id",
175+
"function_call",
176+
}
177+
178+
179+
def _cloudflare_message_to_dict(message: Any) -> Any:
180+
"""Convert dict-like or LangChain message objects into chat message dicts."""
181+
if isinstance(message, dict):
182+
data = dict(message)
183+
else:
184+
data = None
185+
if hasattr(message, "model_dump"):
186+
try:
187+
data = message.model_dump(mode="json", exclude_none=True)
188+
except TypeError:
189+
data = message.model_dump()
190+
except Exception:
191+
data = None
192+
if not isinstance(data, dict) and hasattr(message, "dict"):
193+
try:
194+
data = message.dict()
195+
except Exception:
196+
data = None
197+
if not isinstance(data, dict):
198+
role = getattr(message, "role", None)
199+
message_type = getattr(message, "type", None)
200+
role = role or _CLOUDFLARE_ROLE_BY_MESSAGE_TYPE.get(str(message_type))
201+
content = getattr(message, "content", None)
202+
if not role and content is None:
203+
return message
204+
data = {"role": role, "content": content}
205+
for key in ("name", "tool_calls", "tool_call_id", "function_call"):
206+
value = getattr(message, key, None)
207+
if value:
208+
data[key] = value
209+
210+
message_type = data.get("type")
211+
if not data.get("role") and message_type:
212+
data["role"] = _CLOUDFLARE_ROLE_BY_MESSAGE_TYPE.get(str(message_type), str(message_type))
213+
214+
return {
215+
key: value
216+
for key, value in data.items()
217+
if key in _CLOUDFLARE_MESSAGE_KEYS and value is not None
218+
}
219+
220+
158221
def _normalize_cloudflare_chat_payload(payload: dict[str, Any]) -> dict[str, Any]:
159222
"""
160223
Adapt LangChain's OpenAI chat payload to Cloudflare Workers AI's stricter schema.
@@ -167,11 +230,12 @@ def _normalize_cloudflare_chat_payload(payload: dict[str, Any]) -> dict[str, Any
167230
payload.pop("parallel_tool_calls", None)
168231

169232
messages = payload.get("messages")
170-
if not isinstance(messages, list):
233+
if not isinstance(messages, (list, tuple)):
171234
return payload
172235

173236
normalized_messages = []
174237
for message in messages:
238+
message = _cloudflare_message_to_dict(message)
175239
if not isinstance(message, dict):
176240
normalized_messages.append(message)
177241
continue

python-ecosystem/inference-orchestrator/src/service/review/orchestrator/stage_0_planning.py

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Dict, Optional
77

88
from model.dtos import ReviewRequestDto
9-
from model.multi_stage import ReviewPlan
9+
from model.multi_stage import ReviewPlan, FileGroup, ReviewFile
1010
from utils.prompts.prompt_builder import PromptBuilder
1111
from utils.diff_processor import ProcessedDiff
1212

@@ -75,5 +75,118 @@ async def execute_stage_0_planning(
7575
content = extract_llm_response_text(response)
7676
return await parse_llm_response(content, ReviewPlan, llm)
7777
except Exception as e:
78-
logger.error(f"Stage 0 planning failed: {e}")
79-
raise ValueError(f"Stage 0 planning failed: {e}")
78+
logger.error(f"Stage 0 planning failed, using local fallback plan: {e}")
79+
return _build_fallback_review_plan(request, processed_diff)
80+
81+
82+
def _build_fallback_review_plan(
    request: ReviewRequestDto,
    processed_diff: Optional[ProcessedDiff] = None,
) -> ReviewPlan:
    """
    Assemble a review plan locally, without issuing another LLM call.

    Every changed file is kept and sorted into a HIGH / MEDIUM / LOW group
    using path-based heuristics, so the review can proceed when the planning
    stage produced nothing usable.

    Args:
        request: Review request; ``changedFiles`` is the primary path source.
        processed_diff: Optional processed diff. It supplies per-file diff
            entries for prioritisation and, when the request lists no changed
            files, the non-skipped file paths.

    Returns:
        A ``ReviewPlan`` with one ``FileGroup`` per non-empty priority level,
        ordered HIGH, MEDIUM, LOW.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    changed_paths = list(dict.fromkeys(request.changedFiles or []))

    diffs_by_path = {}
    if processed_diff:
        for entry in processed_diff.files:
            diffs_by_path[entry.path] = entry
        if not changed_paths:
            changed_paths = [
                entry.path for entry in processed_diff.files if not entry.is_skipped
            ]

    buckets: Dict[str, list[ReviewFile]] = {
        level: [] for level in ("HIGH", "MEDIUM", "LOW")
    }
    for changed_path in changed_paths:
        level = _infer_file_priority(changed_path, diffs_by_path.get(changed_path))
        review_file = ReviewFile(
            path=changed_path,
            focus_areas=_infer_focus_areas(changed_path),
            risk_level=level,
            estimated_issues=0,
        )
        buckets[level].append(review_file)

    # Dict insertion order keeps the groups in HIGH, MEDIUM, LOW sequence.
    file_groups = [
        FileGroup(
            group_id=f"FALLBACK_{level}",
            priority=level,
            rationale="Local fallback plan generated because AI planning output was unavailable",
            files=members,
        )
        for level, members in buckets.items()
        if members
    ]

    summary = (
        "Fallback review plan generated locally after AI planning returned "
        "empty or invalid output."
    )
    return ReviewPlan(
        analysis_summary=summary,
        file_groups=file_groups,
        cross_file_concerns=_infer_cross_file_concerns(changed_paths),
    )
138+
139+
140+
def _infer_file_priority(path: str, diff_file: Any = None) -> str:
141+
lower = path.lower()
142+
if any(marker in lower for marker in (
143+
"auth",
144+
"security",
145+
"permission",
146+
"billing",
147+
"payment",
148+
"migration",
149+
"schema",
150+
"controller",
151+
"handler",
152+
"service",
153+
"repository",
154+
)):
155+
return "HIGH"
156+
if diff_file and getattr(diff_file, "additions", 0) + getattr(diff_file, "deletions", 0) > 200:
157+
return "HIGH"
158+
if any(lower.endswith(ext) for ext in (
159+
".md",
160+
".txt",
161+
".json",
162+
".yaml",
163+
".yml",
164+
".toml",
165+
".lock",
166+
)):
167+
return "LOW"
168+
if any(marker in lower for marker in ("/test/", "/tests/", ".test.", ".spec.", "test_")):
169+
return "LOW"
170+
return "MEDIUM"
171+
172+
173+
def _infer_focus_areas(path: str) -> list[str]:
174+
lower = path.lower()
175+
focus = []
176+
if any(marker in lower for marker in ("auth", "security", "permission")):
177+
focus.append("SECURITY")
178+
if any(marker in lower for marker in ("migration", "schema", "repository", "entity", "model")):
179+
focus.append("DATA_ACCESS")
180+
if any(marker in lower for marker in ("controller", "handler", "api")):
181+
focus.append("API_CONTRACT")
182+
if any(marker in lower for marker in ("/test/", "/tests/", ".test.", ".spec.", "test_")):
183+
focus.append("TESTING")
184+
return focus or ["GENERAL"]
185+
186+
187+
def _infer_cross_file_concerns(paths: list[str]) -> list[str]:
188+
if len(paths) < 2:
189+
return []
190+
return [
191+
"Check interactions between changed files because AI planning was unavailable."
192+
]

python-ecosystem/inference-orchestrator/tests/test_llm_factory.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,14 @@ def test_normalize_cloudflare_ai_gateway_does_not_append_v1(self):
221221
base = "https://gateway.ai.cloudflare.com/v1/account-id/default/compat"
222222
assert _normalize_openai_compatible_base_url(base) == base
223223

224+
def test_normalize_cloudflare_workers_ai_run_endpoint_to_openai_base(self):
    """A Workers AI ``/ai/run/<model>`` URL is rewritten to the ``/ai/v1`` base."""
    run_endpoint = "https://api.cloudflare.com/client/v4/accounts/account-id/ai/run/@cf/moonshotai/kimi-k2-instruct"
    expected_base = "https://api.cloudflare.com/client/v4/accounts/account-id/ai/v1"

    actual_base = _normalize_openai_compatible_base_url(run_endpoint)

    assert actual_base == expected_base
231+
224232
def test_detect_cloudflare_base_url(self):
225233
assert _is_cloudflare_base_url(
226234
"https://api.cloudflare.com/client/v4/accounts/id/ai/v1"
@@ -263,6 +271,36 @@ def test_normalize_cloudflare_payload_content_blocks_and_tool_calls(self):
263271
assert normalized["messages"][3]["content"] == "result"
264272
assert "parallel_tool_calls" not in normalized
265273

274+
def test_normalize_cloudflare_payload_langchain_message_objects(self):
    """Attribute-only message objects in a tuple become plain role/content dicts."""

    class MessageObject:
        # Only ``type`` and ``content`` are exposed: no ``role``, no dump methods.
        type = "human"
        content = [{"type": "text", "text": "question"}]

    messages = (MessageObject(),)

    normalized = _normalize_cloudflare_chat_payload({"messages": messages})

    expected_messages = [{"role": "user", "content": "question"}]
    assert normalized["messages"] == expected_messages
286+
287+
def test_normalize_cloudflare_payload_model_dump_message(self):
    """Messages exposing ``model_dump`` are converted and unknown keys are dropped."""
    dumped = {
        "type": "system",
        "content": [{"type": "text", "text": "sys"}],
        "additional_kwargs": {"ignored": True},
    }

    class DumpMessage:
        def model_dump(self, **_kwargs):
            # Return a fresh dict on every call, as a literal would.
            return dict(dumped)

    payload = {"messages": [DumpMessage()]}

    normalized = _normalize_cloudflare_chat_payload(payload)

    expected_messages = [{"role": "system", "content": "sys"}]
    assert normalized["messages"] == expected_messages
303+
266304

267305
# ── Constants ────────────────────────────────────────────────────
268306

python-ecosystem/inference-orchestrator/tests/test_stage_0_branch.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ async def test_returns_review_plan_from_structured_output(self):
5454

5555
@pytest.mark.asyncio(loop_scope="function")
5656
async def test_fallback_on_structured_failure(self):
57-
"""Stage 0 raises ValueError when both structured and raw LLM fail."""
57+
"""Stage 0 returns a local plan when AI planning fails."""
5858
mock_llm = MagicMock()
5959
structured = MagicMock()
6060
structured.ainvoke = AsyncMock(side_effect=Exception("API error"))
@@ -70,10 +70,41 @@ async def test_fallback_on_structured_failure(self):
7070
request.enrichmentData = None
7171
request.projectRules = None
7272

73-
with pytest.raises(ValueError, match="Stage 0 planning failed"):
74-
await execute_stage_0_planning(
75-
mock_llm, request, is_incremental=False
76-
)
73+
result = await execute_stage_0_planning(
74+
mock_llm, request, is_incremental=False
75+
)
76+
77+
assert isinstance(result, ReviewPlan)
78+
assert result.analysis_summary.startswith("Fallback review plan")
79+
assert [f.path for g in result.file_groups for f in g.files] == ["a.py", "b.py"]
80+
81+
@pytest.mark.asyncio(loop_scope="function")
async def test_fallback_on_empty_raw_response(self):
    """An empty raw LLM reply yields a locally built plan instead of an error."""
    # The structured-output path fails outright ...
    failing_structured = MagicMock(
        ainvoke=AsyncMock(side_effect=Exception("API error"))
    )
    # ... and the raw path answers with empty content.
    empty_reply = MagicMock(content="")

    mock_llm = MagicMock()
    mock_llm.with_structured_output.return_value = failing_structured
    mock_llm.ainvoke = AsyncMock(return_value=empty_reply)

    request = MagicMock()
    request_fields = {
        "changedFiles": ["src/auth/service.py", "README.md"],
        "deletedFiles": [],
        "rawDiff": "diff",
        "prTitle": "PR",
        "prDescription": "desc",
        "enrichmentData": None,
        "projectRules": None,
    }
    for field_name, field_value in request_fields.items():
        setattr(request, field_name, field_value)

    result = await execute_stage_0_planning(
        mock_llm, request, is_incremental=False
    )

    assert isinstance(result, ReviewPlan)
    first_group, last_group = result.file_groups[0], result.file_groups[-1]
    assert first_group.priority == "HIGH"
    assert last_group.priority == "LOW"
assert result.file_groups[-1].priority == "LOW"
77108

78109

79110
# ── execute_branch_analysis ──────────────────────────────────────

0 commit comments

Comments
 (0)