6 changes: 5 additions & 1 deletion README.md
@@ -26,7 +26,10 @@ ml-intern
Create a `.env` file in the project root (or export these in your shell):

```bash
ANTHROPIC_API_KEY=<your-anthropic-api-key> # if using anthropic models
ANTHROPIC_API_KEY=<your-anthropic-api-key> # if using anthropic/ models
AWS_REGION_NAME=us-east-1 # if using bedrock/ models
AWS_ACCESS_KEY_ID=<key> # (or AWS_BEARER_TOKEN_BEDROCK for SSO)
AWS_SECRET_ACCESS_KEY=<secret> #
HF_TOKEN=<your-hugging-face-token>
GITHUB_TOKEN=<github-personal-access-token>
```
@@ -50,6 +53,7 @@ ml-intern "fine-tune llama on my dataset"

```bash
ml-intern --model anthropic/claude-opus-4-6 "your prompt"
ml-intern --model bedrock/us.anthropic.claude-opus-4-6-v1 "your prompt"
ml-intern --max-iterations 100 "your prompt"
ml-intern --no-stream "your prompt"
```
117 changes: 63 additions & 54 deletions agent/core/agent_loop.py
@@ -197,6 +198,8 @@ def _friendly_error_message(error: Exception) -> str | None:
"Authentication failed — your API key is missing or invalid.\n\n"
"To fix this, set the API key for your model provider:\n"
" • Anthropic: export ANTHROPIC_API_KEY=sk-...\n"
" • Bedrock: export AWS_ACCESS_KEY_ID=... AWS_SECRET_ACCESS_KEY=... AWS_REGION_NAME=...\n"
" (or AWS_BEARER_TOKEN_BEDROCK for SSO / identity-center auth)\n"
" • OpenAI: export OPENAI_API_KEY=sk-...\n"
" • HF Router: export HF_TOKEN=hf_...\n\n"
"You can also add it to a .env file in the project root.\n"
@@ -293,10 +295,20 @@ class LLMResult:


async def _call_llm_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
"""Call the LLM with streaming, emitting assistant_chunk events."""
response = None
"""Call the LLM with streaming, emitting assistant_chunk events.

The retry loop wraps both the ``acompletion()`` call and the stream
iteration — providers (especially Bedrock) can throw transient errors
mid-stream. On a mid-stream failure we discard partial content and
retry from scratch (partial tool-call JSON is unusable anyway).
"""
_healed_effort = False # one-shot safety net per call
for _llm_attempt in range(_MAX_LLM_RETRIES):
full_content = ""
tool_calls_acc: dict[int, dict] = {}
token_count = 0
finish_reason = None

try:
response = await acompletion(
messages=messages,
@@ -307,7 +319,54 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
timeout=600,
**llm_params,
)
break

async for chunk in response:
if session.is_cancelled:
tool_calls_acc.clear()
break

choice = chunk.choices[0] if chunk.choices else None
if not choice:
if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
continue

delta = choice.delta
if choice.finish_reason:
finish_reason = choice.finish_reason

if delta.content:
full_content += delta.content
await session.send_event(
Event(event_type="assistant_chunk", data={"content": delta.content})
)

if delta.tool_calls:
for tc_delta in delta.tool_calls:
idx = tc_delta.index
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {
"id": "", "type": "function",
"function": {"name": "", "arguments": ""},
}
if tc_delta.id:
tool_calls_acc[idx]["id"] = tc_delta.id
if tc_delta.function:
if tc_delta.function.name:
tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name
if tc_delta.function.arguments:
tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments

if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens

return LLMResult(
content=full_content or None,
tool_calls_acc=tool_calls_acc,
token_count=token_count,
finish_reason=finish_reason,
)

except ContextWindowExceededError:
raise
except Exception as e:
@@ -333,57 +392,7 @@ async def _call_llm_streaming(session: Session, messages, tools, llm_params) ->
continue
raise

full_content = ""
tool_calls_acc: dict[int, dict] = {}
token_count = 0
finish_reason = None

async for chunk in response:
if session.is_cancelled:
tool_calls_acc.clear()
break

choice = chunk.choices[0] if chunk.choices else None
if not choice:
if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
continue

delta = choice.delta
if choice.finish_reason:
finish_reason = choice.finish_reason

if delta.content:
full_content += delta.content
await session.send_event(
Event(event_type="assistant_chunk", data={"content": delta.content})
)

if delta.tool_calls:
for tc_delta in delta.tool_calls:
idx = tc_delta.index
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {
"id": "", "type": "function",
"function": {"name": "", "arguments": ""},
}
if tc_delta.id:
tool_calls_acc[idx]["id"] = tc_delta.id
if tc_delta.function:
if tc_delta.function.name:
tool_calls_acc[idx]["function"]["name"] += tc_delta.function.name
if tc_delta.function.arguments:
tool_calls_acc[idx]["function"]["arguments"] += tc_delta.function.arguments

if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens

return LLMResult(
content=full_content or None,
tool_calls_acc=tool_calls_acc,
token_count=token_count,
finish_reason=finish_reason,
)
raise RuntimeError("Exhausted LLM retries without returning or raising")


async def _call_llm_non_streaming(session: Session, messages, tools, llm_params) -> LLMResult:
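For context, a minimal sketch of the retry-wraps-the-stream pattern the new docstring describes. The stream factory and failure mode below are hypothetical stand-ins, not the project's real types:

```python
import asyncio

async def _flaky_stream():
    # Hypothetical provider stream that dies mid-iteration, the failure
    # mode the docstring attributes to Bedrock.
    yield "partial "
    raise ConnectionError("connection reset mid-stream")

async def call_with_retry(stream_factory, max_retries: int = 3) -> str:
    # The retry loop wraps BOTH stream creation and iteration; a mid-stream
    # failure discards the partial accumulator and restarts from scratch.
    for attempt in range(max_retries):
        acc = ""
        try:
            async for piece in stream_factory():
                acc += piece          # partial content accumulates here
            return acc                # only a fully consumed stream is returned
        except Exception:
            if attempt + 1 == max_retries:
                raise
            continue                  # drop `acc`: partial tool-call JSON is unusable
    raise RuntimeError("exhausted retries without returning or raising")

# asyncio.run(call_with_retry(_flaky_stream))  # raises after three attempts
```
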
53 changes: 32 additions & 21 deletions agent/core/llm_params.py
@@ -84,6 +84,28 @@ class UnsupportedEffortError(ValueError):
"""


def _resolve_anthropic_effort(
params: dict, reasoning_effort: str | None, strict: bool,
) -> dict:
"""Apply Anthropic-family thinking config to ``params`` (shared by
``anthropic/`` and ``bedrock/`` paths — same Claude models, same API
shape for thinking/effort).
"""
if reasoning_effort:
level = reasoning_effort
if level == "minimal":
level = "low"
if level not in _ANTHROPIC_EFFORTS:
if strict:
raise UnsupportedEffortError(
f"Anthropic doesn't accept effort={level!r}"
)
else:
params["thinking"] = {"type": "adaptive"}
params["output_config"] = {"effort": level}
return params


def _resolve_llm_params(
model_name: str,
session_hf_token: str | None = None,
@@ -106,6 +128,12 @@ def _resolve_llm_params(
will reject this; the probe's cascade catches that and falls back
to no thinking.

• ``bedrock/<model>`` — same Claude models via Amazon Bedrock. LiteLLM
handles the AWS auth (via ``AWS_ACCESS_KEY_ID`` /
``AWS_SECRET_ACCESS_KEY`` / ``AWS_REGION_NAME``, or
``AWS_BEARER_TOKEN_BEDROCK`` for SSO / identity-center auth). The
thinking / effort params are identical to the ``anthropic/`` path.

• ``openai/<model>`` — ``reasoning_effort`` forwarded as a top-level
kwarg (GPT-5 / o-series). LiteLLM uses the user's ``OPENAI_API_KEY``.

@@ -132,27 +160,10 @@ def _resolve_llm_params(
3. HF_TOKEN env — belt-and-suspenders fallback for CLI users.
"""
if model_name.startswith("anthropic/"):
params: dict = {"model": model_name}
if reasoning_effort:
level = reasoning_effort
if level == "minimal":
level = "low"
if level not in _ANTHROPIC_EFFORTS:
if strict:
raise UnsupportedEffortError(
f"Anthropic doesn't accept effort={level!r}"
)
else:
# Adaptive thinking + output_config.effort is the stable
# Anthropic API for Claude 4.6 / 4.7. Both kwargs are
# passed top-level: LiteLLM forwards unknown params into
# the request body for Anthropic, so ``output_config``
# reaches the API. ``extra_body`` does NOT work here —
# Anthropic rejects it as "Extra inputs are not
# permitted".
params["thinking"] = {"type": "adaptive"}
params["output_config"] = {"effort": level}
return params
return _resolve_anthropic_effort({"model": model_name}, reasoning_effort, strict)

if model_name.startswith("bedrock/"):
return _resolve_anthropic_effort({"model": model_name}, reasoning_effort, strict)

if model_name.startswith("openai/"):
params = {"model": model_name}
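For reference, a sketch of what the shared helper produces for the new Bedrock path. The model id and effort level are illustrative, and this assumes "high" is a member of `_ANTHROPIC_EFFORTS`, whose contents are outside this diff:

```python
from agent.core.llm_params import _resolve_anthropic_effort

# Illustrative call mirroring the helper above; "high" is assumed to be
# in _ANTHROPIC_EFFORTS (the set itself is not shown in this diff).
params = _resolve_anthropic_effort(
    {"model": "bedrock/us.anthropic.claude-opus-4-6-v1"},
    reasoning_effort="high",
    strict=True,
)
# params == {
#     "model": "bedrock/us.anthropic.claude-opus-4-6-v1",
#     "thinking": {"type": "adaptive"},
#     "output_config": {"effort": "high"},
# }
```
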
6 changes: 4 additions & 2 deletions agent/core/model_switcher.py
@@ -40,6 +40,7 @@ def is_valid_model_id(model_id: str) -> bool:

Accepts:
• anthropic/<model>
• bedrock/<model>
• openai/<model>
• <org>/<model>[:<tag>] (HF router; tag = provider or policy)
• huggingface/<org>/<model>[:<tag>] (same, accepts legacy prefix)
@@ -63,7 +64,7 @@ def _print_hf_routing_info(model_id: str, console) -> bool:
Anthropic / OpenAI / Bedrock ids return ``True`` without printing anything —
the probe below covers "does this model exist".
"""
if model_id.startswith(("anthropic/", "openai/")):
if model_id.startswith(("anthropic/", "openai/", "bedrock/")):
return True

from agent.core import hf_router_catalog as cat
@@ -136,7 +137,7 @@ def print_model_listing(config, console) -> None:
console.print(
"\n[dim]Paste any HF model id (e.g. 'MiniMaxAI/MiniMax-M2.7').\n"
"Add ':fastest', ':cheapest', ':preferred', or ':<provider>' to override routing.\n"
"Use 'anthropic/<model>' or 'openai/<model>' for direct API access.[/dim]"
"Use 'anthropic/<model>', 'bedrock/<model>', or 'openai/<model>' for direct API access.[/dim]"
)


@@ -146,6 +147,7 @@ def print_invalid_id(arg: str, console) -> None:
"[dim]Expected:\n"
" • <org>/<model>[:tag] (HF router — paste from huggingface.co)\n"
" • anthropic/<model>\n"
" • bedrock/<model>\n"
" • openai/<model>[/dim]"
)

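A quick sketch of the id shapes the updated validator and help text cover. The ids are taken from elsewhere in this PR; the parsing logic itself is not part of this diff, so treat these as illustrative expectations rather than tests:

```python
from agent.core.model_switcher import is_valid_model_id

# Accepted shapes per the docstring above (illustrative, not exhaustive):
assert is_valid_model_id("anthropic/claude-opus-4-6")
assert is_valid_model_id("bedrock/us.anthropic.claude-opus-4-6-v1")
assert is_valid_model_id("openai/gpt-5")
assert is_valid_model_id("MiniMaxAI/MiniMax-M2.7:fastest")
assert is_valid_model_id("huggingface/MiniMaxAI/MiniMax-M2.7")
```
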
1 change: 1 addition & 0 deletions pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"nbconvert>=7.16.6",
"nbformat>=5.10.4",
"whoosh>=2.7.4",
"boto3>=1.35.0",
# Web backend dependencies
"fastapi>=0.115.0",
"uvicorn[standard]>=0.32.0",