diff --git a/src/chattr/app/builder.py b/src/chattr/app/builder.py index 7a47f6206..061e82246 100644 --- a/src/chattr/app/builder.py +++ b/src/chattr/app/builder.py @@ -1,7 +1,7 @@ """Main orchestration graph for the Chattr application.""" from collections.abc import AsyncGenerator -from json import dumps, loads +from json import dumps from pathlib import Path from agno.agent import ( @@ -28,12 +28,10 @@ Video, ) from gradio.components.chatbot import MetadataDict -from m3u8 import M3U8, load from poml import poml -from pydantic import HttpUrl, ValidationError -from requests import Session from rich.pretty import pprint +from chattr.app.scheme import MCPScheme, URLConnection from chattr.app.settings import Settings, logger @@ -48,13 +46,7 @@ async def _setup_agent(self) -> Agent: model=self._setup_model(), tools=await self._setup_tools(), description="You are a helpful assistant who can act and mimic Napoleon's character and answer questions about the era.", - instructions=[ - "Understand the user's question and context.", - "Gather relevant information and resources.", - "Formulate a clear and concise response in Napoleon's voice.", - "ALWAYS generate audio from the formulated response using the appropriate Tool.", - "Generate video from the resulted audio using the appropriate Tool.", - ], + instructions=self._setup_prompt(), db=self._setup_database(), knowledge=self._setup_knowledge( self._setup_vector_database(), @@ -71,14 +63,11 @@ async def _setup_agent(self) -> Agent: ) async def _setup_tools(self) -> list[Toolkit]: - mcp_servers: list[dict] = loads(self.settings.mcp.path.read_text()).get( - "mcp_servers", - [], - ) - url_servers = [m for m in mcp_servers if m.get("type") == "url"] + scheme = MCPScheme.model_validate_json(self.settings.mcp.path.read_text()) + url_servers = [s for s in scheme.mcp_servers if isinstance(s, URLConnection)] self.mcp_tools = MultiMCPTools( - urls=[m.get("url") for m in url_servers], - urls_transports=[m.get("transport") for m in url_servers], + urls=[str(s.url) for s in url_servers], + urls_transports=[s.transport for s in url_servers], ) await self.mcp_tools.connect() return [self.mcp_tools] @@ -91,7 +80,7 @@ def _setup_prompt(self) -> str: format="dict", ) if not isinstance(prompt_template, dict): - _msg = "Prompt template must be a string." + _msg = "Prompt template must be a dictionary." raise TypeError(_msg) return prompt_template["messages"] @@ -150,21 +139,19 @@ async def generate_response( self, message: str, history: list[ChatMessage], - ) -> AsyncGenerator[tuple[str, list[ChatMessage], Path | None, Path | None]]: + ) -> AsyncGenerator[list[ChatMessage]]: """ Generate a response to a user message and update the conversation history. This asynchronous method streams responses from the state graph and - yields updated history and audio file paths as needed. + yields updated history. Args: message: The user's input message as a string. history: The conversation history as a list of ChatMessage objects. Returns: - AsyncGenerator: Yields a tuple containing an - empty string, the updated history, and - a Path to an audio file if generated. + AsyncGenerator: Yields the updated history. """ try: agent: Agent = await self._setup_agent() @@ -175,70 +162,59 @@ async def generate_response( pprint(response) if isinstance(response, RunContentEvent): history.append( - ChatMessage( - role="assistant", - content=response.content, - ), + ChatMessage(role="assistant", content=response.content), + ) + elif isinstance( + response, ToolCallStartedEvent | ToolCallCompletedEvent + ): + tool = response.tool + metadata = MetadataDict( + title=tool.tool_name, + id=tool.tool_call_id, ) - elif isinstance(response, ToolCallStartedEvent): + + if isinstance(response, ToolCallStartedEvent): + metadata["duration"] = tool.created_at + else: + metadata["log"] = ( + "Tool Call Failed" + if tool.tool_call_error + else "Tool Call Succeeded" + ) + metadata["duration"] = tool.metrics.duration + history.append( ChatMessage( role="assistant", - content=dumps(response.tool.tool_args, indent=4), - metadata=MetadataDict( - title=response.tool.tool_name, - id=response.tool.tool_call_id, - duration=response.tool.created_at, - ), + content=dumps(tool.tool_args, indent=4), + metadata=metadata, ), ) - elif isinstance(response, ToolCallCompletedEvent): - if response.tool.tool_call_error: - history.append( - ChatMessage( - role="assistant", - content=dumps(response.tool.tool_args, indent=4), - metadata=MetadataDict( - title=response.tool.tool_name, - id=response.tool.tool_call_id, - log="Tool Call Failed", - duration=response.tool.metrics.duration, - ), - ), - ) - else: - history.append( - ChatMessage( - role="assistant", - content=dumps(response.tool.tool_args, indent=4), - metadata=MetadataDict( - title=response.tool.tool_name, - id=response.tool.tool_call_id, - log="Tool Call Succeeded", - duration=response.tool.metrics.duration, - ), - ), - ) - if response.tool.tool_name == "generate_audio_for_text": + + if ( + isinstance(response, ToolCallCompletedEvent) + and not tool.tool_call_error + ): + if tool.tool_name == "generate_audio_for_text": history.append( Audio( - response.tool.result, + tool.result, autoplay=True, show_download_button=True, show_share_button=True, ), ) - elif response.tool.tool_name == "generate_video_mcp": + elif tool.tool_name == "generate_video_mcp": history.append( Video( - response.tool.result, + tool.result, autoplay=True, show_download_button=True, show_share_button=True, ), ) else: - msg = f"Unknown tool name: {response.tool.tool_name}" + msg = f"Unknown tool name: {tool.tool_name}" raise Error(msg) yield history except Exception as e: @@ -248,53 +224,6 @@ async def generate_response( finally: await self._close() - def _is_url(self, value: str | None) -> bool: - """ - Check if a string is a valid URL. - - Args: - value: The string to check. Can be None. - - Returns: - bool: True if the string is a valid URL, False otherwise. - """ - if value is None: - return False - - try: - _ = HttpUrl(value) - except ValidationError: - return False - return True - - def _download_file(self, url: HttpUrl, path: Path) -> None: - """ - Download a file from a URL and save it to a local path. - - Args: - url: The URL to download the file from. - path: The local file path where the downloaded file will be saved. - - Returns: - None - - Raises: - requests.RequestException: If the HTTP request fails. - IOError: If file writing fails. - """ - if str(url).endswith(".m3u8"): - _playlist: M3U8 = load(url) - url: str = str(url).replace("playlist.m3u8", _playlist.segments[0].uri) - logger.info(f"Downloading {url} to {path}") - session = Session() - response = session.get(url, stream=True, timeout=30) - response.raise_for_status() - with path.open("wb") as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - logger.info(f"File downloaded to {path}") - async def _close(self) -> None: try: logger.info("Closing MCP tools...") diff --git a/src/chattr/app/settings.py b/src/chattr/app/settings.py index 04989ba69..5e1178bcf 100644 --- a/src/chattr/app/settings.py +++ b/src/chattr/app/settings.py @@ -42,25 +42,17 @@ class MCPSettings(BaseModel): path: FilePath = Field(default_factory=lambda: Path.cwd() / "mcp.json") @model_validator(mode="after") - def is_exists(self) -> Self: - """Check if the MCP config file exists.""" + def validate_mcp_config(self) -> Self: + """Validate the MCP config file.""" if not self.path.exists(): logger.warning("`mcp.json` not found.") - return self + return self - @model_validator(mode="after") - def is_valid(self) -> Self: - """Validate that the MCP config file is a JSON file.""" - if self.path and self.path.suffix != ".json": + if self.path.suffix != ".json": msg = "MCP config file must be a JSON file" raise ValueError(msg) - return self - @model_validator(mode="after") - def is_valid_scheme(self) -> Self: - """Validate that the MCP config file has a valid scheme.""" - if self.path and self.path.exists(): - _ = MCPScheme.model_validate_json(self.path.read_text()) + MCPScheme.model_validate_json(self.path.read_text()) return self @@ -95,22 +87,12 @@ def prompts(self) -> DirectoryPath: @model_validator(mode="after") def create_missing_dirs(self) -> Self: - """ - Ensure that all specified directories exist, creating them if necessary. - - Checks and creates any missing directories defined in the `DirectorySettings`. - - Returns: - Self: The validated DirectorySettings instance. - """ - for directory in [self.base, self.assets, self.audio, self.video, self.prompts]: + """Ensure that all specified directories exist, creating them if necessary.""" + dirs = [self.base, self.assets, self.audio, self.video, self.prompts] + for directory in dirs: if not directory.exists(): - try: - directory.mkdir(parents=True, exist_ok=True) - logger.info("Created directory %s.", directory) - except OSError as e: - logger.error("Error creating directory %s: %s", directory, e) - raise + directory.mkdir(parents=True, exist_ok=True) + logger.info("Created directory %s.", directory) return self