Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 43 additions & 114 deletions src/chattr/app/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Main orchestration graph for the Chattr application."""

from collections.abc import AsyncGenerator
from json import dumps, loads
from json import dumps
from pathlib import Path

from agno.agent import (
Expand All @@ -28,12 +28,10 @@
Video,
)
from gradio.components.chatbot import MetadataDict
from m3u8 import M3U8, load
from poml import poml
from pydantic import HttpUrl, ValidationError
from requests import Session
from rich.pretty import pprint

from chattr.app.scheme import MCPScheme, URLConnection
from chattr.app.settings import Settings, logger


Expand All @@ -48,13 +46,7 @@ async def _setup_agent(self) -> Agent:
model=self._setup_model(),
tools=await self._setup_tools(),
description="You are a helpful assistant who can act and mimic Napoleon's character and answer questions about the era.",
instructions=[
"Understand the user's question and context.",
"Gather relevant information and resources.",
"Formulate a clear and concise response in Napoleon's voice.",
"ALWAYS generate audio from the formulated response using the appropriate Tool.",
"Generate video from the resulted audio using the appropriate Tool.",
],
instructions=self._setup_prompt(),
db=self._setup_database(),
knowledge=self._setup_knowledge(
self._setup_vector_database(),
Expand All @@ -71,14 +63,11 @@ async def _setup_agent(self) -> Agent:
)

async def _setup_tools(self) -> list[Toolkit]:
mcp_servers: list[dict] = loads(self.settings.mcp.path.read_text()).get(
"mcp_servers",
[],
)
url_servers = [m for m in mcp_servers if m.get("type") == "url"]
scheme = MCPScheme.model_validate_json(self.settings.mcp.path.read_text())
url_servers = [s for s in scheme.mcp_servers if isinstance(s, URLConnection)]
self.mcp_tools = MultiMCPTools(
urls=[m.get("url") for m in url_servers],
urls_transports=[m.get("transport") for m in url_servers],
urls=[str(s.url) for s in url_servers],
urls_transports=[s.transport for s in url_servers],
)
await self.mcp_tools.connect()
return [self.mcp_tools]
Expand All @@ -91,7 +80,7 @@ def _setup_prompt(self) -> str:
format="dict",
)
if not isinstance(prompt_template, dict):
_msg = "Prompt template must be a string."
_msg = "Prompt template must be a dictionary."
raise TypeError(_msg)
return prompt_template["messages"]

Expand Down Expand Up @@ -150,21 +139,19 @@ async def generate_response(
self,
message: str,
history: list[ChatMessage],
) -> AsyncGenerator[tuple[str, list[ChatMessage], Path | None, Path | None]]:
) -> AsyncGenerator[list[ChatMessage]]:
"""
Generate a response to a user message and update the conversation history.

This asynchronous method streams responses from the state graph and
yields updated history and audio file paths as needed.
yields updated history.

Args:
message: The user's input message as a string.
history: The conversation history as a list of ChatMessage objects.

Returns:
AsyncGenerator: Yields a tuple containing an
empty string, the updated history, and
a Path to an audio file if generated.
AsyncGenerator: Yields the updated history.
"""
try:
agent: Agent = await self._setup_agent()
Expand All @@ -175,70 +162,59 @@ async def generate_response(
pprint(response)
if isinstance(response, RunContentEvent):
history.append(
ChatMessage(
role="assistant",
content=response.content,
),
ChatMessage(role="assistant", content=response.content),
)
elif isinstance(
response, ToolCallStartedEvent | ToolCallCompletedEvent
):
tool = response.tool
metadata = MetadataDict(
title=tool.tool_name,
id=tool.tool_call_id,
)
elif isinstance(response, ToolCallStartedEvent):

if isinstance(response, ToolCallStartedEvent):
metadata["duration"] = tool.created_at
else:
metadata["log"] = (
"Tool Call Failed"
if tool.tool_call_error
else "Tool Call Succeeded"
)
metadata["duration"] = tool.metrics.duration

history.append(
ChatMessage(
role="assistant",
content=dumps(response.tool.tool_args, indent=4),
metadata=MetadataDict(
title=response.tool.tool_name,
id=response.tool.tool_call_id,
duration=response.tool.created_at,
),
content=dumps(tool.tool_args, indent=4),
metadata=metadata,
),
)
elif isinstance(response, ToolCallCompletedEvent):
if response.tool.tool_call_error:
history.append(
ChatMessage(
role="assistant",
content=dumps(response.tool.tool_args, indent=4),
metadata=MetadataDict(
title=response.tool.tool_name,
id=response.tool.tool_call_id,
log="Tool Call Failed",
duration=response.tool.metrics.duration,
),
),
)
else:
history.append(
ChatMessage(
role="assistant",
content=dumps(response.tool.tool_args, indent=4),
metadata=MetadataDict(
title=response.tool.tool_name,
id=response.tool.tool_call_id,
log="Tool Call Succeeded",
duration=response.tool.metrics.duration,
),
),
)
if response.tool.tool_name == "generate_audio_for_text":

if (
isinstance(response, ToolCallCompletedEvent)
and not tool.tool_call_error
):
if tool.tool_name == "generate_audio_for_text":
history.append(
Audio(
response.tool.result,
tool.result,
autoplay=True,
show_download_button=True,
show_share_button=True,
),
)
elif response.tool.tool_name == "generate_video_mcp":
elif tool.tool_name == "generate_video_mcp":
history.append(
Video(
response.tool.result,
tool.result,
autoplay=True,
show_download_button=True,
show_share_button=True,
),
)
else:
msg = f"Unknown tool name: {response.tool.tool_name}"
msg = f"Unknown tool name: {tool.tool_name}"
raise Error(msg)
yield history
except Exception as e:
Expand All @@ -248,53 +224,6 @@ async def generate_response(
finally:
await self._close()

def _is_url(self, value: str | None) -> bool:
"""
Check if a string is a valid URL.

Args:
value: The string to check. Can be None.

Returns:
bool: True if the string is a valid URL, False otherwise.
"""
if value is None:
return False

try:
_ = HttpUrl(value)
except ValidationError:
return False
return True

def _download_file(self, url: HttpUrl, path: Path) -> None:
"""
Download a file from a URL and save it to a local path.

Args:
url: The URL to download the file from.
path: The local file path where the downloaded file will be saved.

Returns:
None

Raises:
requests.RequestException: If the HTTP request fails.
IOError: If file writing fails.
"""
if str(url).endswith(".m3u8"):
_playlist: M3U8 = load(url)
url: str = str(url).replace("playlist.m3u8", _playlist.segments[0].uri)
logger.info(f"Downloading {url} to {path}")
session = Session()
response = session.get(url, stream=True, timeout=30)
response.raise_for_status()
with path.open("wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
logger.info(f"File downloaded to {path}")

async def _close(self) -> None:
try:
logger.info("Closing MCP tools...")
Expand Down
38 changes: 10 additions & 28 deletions src/chattr/app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,17 @@ class MCPSettings(BaseModel):
path: FilePath = Field(default_factory=lambda: Path.cwd() / "mcp.json")

@model_validator(mode="after")
def is_exists(self) -> Self:
"""Check if the MCP config file exists."""
def validate_mcp_config(self) -> Self:
"""Validate the MCP config file."""
if not self.path.exists():
logger.warning("`mcp.json` not found.")
return self
return self

@model_validator(mode="after")
def is_valid(self) -> Self:
"""Validate that the MCP config file is a JSON file."""
if self.path and self.path.suffix != ".json":
if self.path.suffix != ".json":
msg = "MCP config file must be a JSON file"
raise ValueError(msg)
return self

@model_validator(mode="after")
def is_valid_scheme(self) -> Self:
"""Validate that the MCP config file has a valid scheme."""
if self.path and self.path.exists():
_ = MCPScheme.model_validate_json(self.path.read_text())
MCPScheme.model_validate_json(self.path.read_text())
return self


Expand Down Expand Up @@ -95,22 +87,12 @@ def prompts(self) -> DirectoryPath:

@model_validator(mode="after")
def create_missing_dirs(self) -> Self:
"""
Ensure that all specified directories exist, creating them if necessary.

Checks and creates any missing directories defined in the `DirectorySettings`.

Returns:
Self: The validated DirectorySettings instance.
"""
for directory in [self.base, self.assets, self.audio, self.video, self.prompts]:
"""Ensure that all specified directories exist, creating them if necessary."""
dirs = [self.base, self.assets, self.audio, self.video, self.prompts]
for directory in dirs:
if not directory.exists():
try:
directory.mkdir(parents=True, exist_ok=True)
logger.info("Created directory %s.", directory)
except OSError as e:
logger.error("Error creating directory %s: %s", directory, e)
raise
directory.mkdir(parents=True, exist_ok=True)
logger.info("Created directory %s.", directory)
return self


Expand Down
Loading