diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 86443f042c..a6aa909f50 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -6,7 +6,7 @@ on: - 'platform/**' env: - PYTHON_VERSION: "3.11" + PYTHON_VERSION: "3.10" jobs: black: diff --git a/README.md b/README.md index 0e5c5d4bc9..8d869bb0be 100644 --- a/README.md +++ b/README.md @@ -125,3 +125,16 @@ Our contributors have made this project possible. Thank you! 🙏
Made with contrib.rocks.
+ +## Arguments for OpenAIAgentService + +The `OpenAIAgentService` class is defined in `platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py`. + +The constructor of `OpenAIAgentService` takes the following arguments: + +- `model`: The model to be used. +- `settings`: The settings for the model. +- `token_service`: The token service for managing tokens. +- `callbacks`: Optional list of callback handlers. +- `user`: The user information. +- `oauth_crud`: The OAuth CRUD operations. diff --git a/next/src/server/api/routers/agentRouter.ts b/next/src/server/api/routers/agentRouter.ts index 05ab7cab75..9ae0186528 100644 --- a/next/src/server/api/routers/agentRouter.ts +++ b/next/src/server/api/routers/agentRouter.ts @@ -42,7 +42,7 @@ async function generateAgentName(goal: string) { `, }, ], - model: "gpt-3.5-turbo", + model: "llama3.2", }); // @ts-ignore diff --git a/next/src/stores/modelSettingsStore.ts b/next/src/stores/modelSettingsStore.ts index c7baa7df52..086cfa2821 100644 --- a/next/src/stores/modelSettingsStore.ts +++ b/next/src/stores/modelSettingsStore.ts @@ -42,7 +42,7 @@ export const useModelSettingsStore = createSelectors( partialize: (state) => ({ modelSettings: { ...state.modelSettings, - customModelName: "gpt-3.5-turbo", + customModelName: "llama3.2", maxTokens: Math.min(state.modelSettings.maxTokens, 4000), }, }), diff --git a/next/src/types/modelSettings.ts b/next/src/types/modelSettings.ts index a3df4c07e8..b12532b29f 100644 --- a/next/src/types/modelSettings.ts +++ b/next/src/types/modelSettings.ts @@ -1,17 +1,19 @@ import { type Language } from "../utils/languages"; -export const [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4] = [ +export const [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4, LLAMA_3_2] = [ "gpt-3.5-turbo" as const, "gpt-3.5-turbo-16k" as const, "gpt-4" as const, + "llama3.2" as const, ]; -export const GPT_MODEL_NAMES = [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4]; -export type GPTModelNames = "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" 
| "gpt-4"; +export const GPT_MODEL_NAMES = [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4, LLAMA_3_2]; +export type GPTModelNames = "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" | "gpt-4" | "llama3.2"; export const MAX_TOKENS: Record = { "gpt-3.5-turbo": 4000, "gpt-3.5-turbo-16k": 16000, "gpt-4": 4000, + "llama3.2": 4000, }; export interface ModelSettings { diff --git a/platform/Dockerfile b/platform/Dockerfile index fbdf5fa1ab..fb575e0892 100644 --- a/platform/Dockerfile +++ b/platform/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11-slim-buster as prod +FROM python:3.10-slim-buster as prod RUN apt-get update && apt-get install -y \ default-libmysqlclient-dev \ @@ -30,6 +30,9 @@ RUN apt-get purge -y \ COPY . /app/src/ RUN poetry install --only main +# Install ollama +RUN pip install ollama + CMD ["/usr/local/bin/python", "-m", "reworkd_platform"] FROM prod as dev diff --git a/platform/README.md b/platform/README.md index db545ea950..6c89c84c19 100644 --- a/platform/README.md +++ b/platform/README.md @@ -149,3 +149,248 @@ poetry run pytest -vv --cov="reworkd_platform" . poetry self add poetry-plugin-up poetry up --latest ``` + +## Installing the package using pip + +To install the `reworkd_platform` package using pip, run the following command: + +```bash +pip install reworkd_platform +``` + +## Using the package in any code + +To use the `reworkd_platform` package in your code, you can import it as follows: + +```python +import reworkd_platform + +# Example usage +reworkd_platform.some_function() +``` + +## Using pip functions + +The `reworkd_platform` package provides several functions for interacting with agents. 
Here are some examples: + +### Starting a goal agent + +```python +from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService +from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI +from reworkd_platform.schemas.agent import ModelSettings +from reworkd_platform.services.tokenizer.token_service import TokenService +from reworkd_platform.db.crud.oauth import OAuthCrud +from reworkd_platform.schemas.user import UserBase + +# Initialize the OpenAIAgentService +model = WrappedChatOpenAI(model_name="gpt-3.5-turbo") +settings = ModelSettings(language="en") +token_service = TokenService.create() +callbacks = None +user = UserBase(id=1, name="John Doe") +oauth_crud = OAuthCrud() + +agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud) + +# Start a goal agent +tasks = agent_service.pip_start_goal_agent(goal="Your goal here") +print(tasks) +``` + +### Analyzing a task agent + +```python +from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService +from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI +from reworkd_platform.schemas.agent import ModelSettings +from reworkd_platform.services.tokenizer.token_service import TokenService +from reworkd_platform.db.crud.oauth import OAuthCrud +from reworkd_platform.schemas.user import UserBase + +# Initialize the OpenAIAgentService +model = WrappedChatOpenAI(model_name="gpt-3.5-turbo") +settings = ModelSettings(language="en") +token_service = TokenService.create() +callbacks = None +user = UserBase(id=1, name="John Doe") +oauth_crud = OAuthCrud() + +agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud) + +# Analyze a task agent +analysis = agent_service.pip_analyze_task_agent(goal="Your goal here", task="Your task here", tool_names=["tool1", "tool2"]) +print(analysis) +``` + +### Executing a task agent + +```python +from 
reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService +from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI +from reworkd_platform.schemas.agent import ModelSettings +from reworkd_platform.services.tokenizer.token_service import TokenService +from reworkd_platform.db.crud.oauth import OAuthCrud +from reworkd_platform.schemas.user import UserBase + +# Initialize the OpenAIAgentService +model = WrappedChatOpenAI(model_name="gpt-3.5-turbo") +settings = ModelSettings(language="en") +token_service = TokenService.create() +callbacks = None +user = UserBase(id=1, name="John Doe") +oauth_crud = OAuthCrud() + +agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud) + +# Execute a task agent +response = agent_service.pip_execute_task_agent(goal="Your goal here", task="Your task here", analysis=analysis) +print(response) +``` + +### Creating tasks agent + +```python +from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService +from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI +from reworkd_platform.schemas.agent import ModelSettings +from reworkd_platform.services.tokenizer.token_service import TokenService +from reworkd_platform.db.crud.oauth import OAuthCrud +from reworkd_platform.schemas.user import UserBase + +# Initialize the OpenAIAgentService +model = WrappedChatOpenAI(model_name="gpt-3.5-turbo") +settings = ModelSettings(language="en") +token_service = TokenService.create() +callbacks = None +user = UserBase(id=1, name="John Doe") +oauth_crud = OAuthCrud() + +agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud) + +# Create tasks agent +tasks = agent_service.pip_create_tasks_agent(goal="Your goal here", tasks=["task1", "task2"], last_task="Your last task here", result="Your result here") +print(tasks) +``` + +### Summarizing task agent + +```python +from 
reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService +from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI +from reworkd_platform.schemas.agent import ModelSettings +from reworkd_platform.services.tokenizer.token_service import TokenService +from reworkd_platform.db.crud.oauth import OAuthCrud +from reworkd_platform.schemas.user import UserBase + +# Initialize the OpenAIAgentService +model = WrappedChatOpenAI(model_name="gpt-3.5-turbo") +settings = ModelSettings(language="en") +token_service = TokenService.create() +callbacks = None +user = UserBase(id=1, name="John Doe") +oauth_crud = OAuthCrud() + +agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud) + +# Summarize task agent +response = agent_service.pip_summarize_task_agent(goal="Your goal here", results=["result1", "result2"]) +print(response) +``` + +### Chatting with agent + +```python +from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService +from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI +from reworkd_platform.schemas.agent import ModelSettings +from reworkd_platform.services.tokenizer.token_service import TokenService +from reworkd_platform.db.crud.oauth import OAuthCrud +from reworkd_platform.schemas.user import UserBase + +# Initialize the OpenAIAgentService +model = WrappedChatOpenAI(model_name="gpt-3.5-turbo") +settings = ModelSettings(language="en") +token_service = TokenService.create() +callbacks = None +user = UserBase(id=1, name="John Doe") +oauth_crud = OAuthCrud() + +agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud) + +# Chat with agent +response = agent_service.pip_chat(message="Your message here", results=["result1", "result2"]) +print(response) +``` + +## Using ollama + +The `reworkd_platform` package also provides support for `ollama`. 
Here are some examples: + +### Adding ollama as a dependency + +To add `ollama` as a dependency, include it in your `pyproject.toml` file under `[tool.poetry.dependencies]`: + +```toml +[tool.poetry.dependencies] +ollama = "^0.1.0" +``` + +### Installing ollama in Docker + +To install `ollama` in the Docker image, add the following command to your `Dockerfile`: + +```dockerfile +# Install ollama +RUN pip install ollama +``` + +### Using ollama in your code + +To use `ollama` in your code, you can import it as follows: + +```python +import ollama + +# Example usage +messages = [{"role": "system", "content": "Your prompt here"}, {"role": "user", "content": "Your input here"}] +response = ollama.chat(model="llama3.2", messages=messages) +# The reply text is stored under the "message" key of the response +print(response["message"]["content"]) +``` + +## Using Python 3.10 + +The `reworkd_platform` package is compatible with Python 3.10. Here are some examples: + +### Specifying Python 3.10 in `pyproject.toml` + +To specify Python 3.10 as the required version, include the following in your `pyproject.toml` file: + +```toml +[tool.poetry.dependencies] +python = "^3.10" +``` + +### Using Python 3.10 in Docker + +To use Python 3.10 in the Docker image, update the base image in your `Dockerfile`: + +```dockerfile +FROM python:3.10-slim-buster as prod +``` + +### Running the project with Python 3.10 + +To run the project with Python 3.10, make sure you have Python 3.10 installed on your system. You can download and install Python 3.10 from the official Python website: https://www.python.org/downloads/release/python-3100/ + +Once you have Python 3.10 installed, you can create a virtual environment and install the dependencies using Poetry: + +```bash +python3.10 -m venv venv +source venv/bin/activate +poetry install +poetry run python -m reworkd_platform +``` + +This will start the server on the configured host using Python 3.10. 
diff --git a/platform/pyproject.toml b/platform/pyproject.toml index e48072e813..1654c7b84c 100644 --- a/platform/pyproject.toml +++ b/platform/pyproject.toml @@ -14,7 +14,7 @@ maintainers = [ readme = "README.md" [tool.poetry.dependencies] -python = "^3.11" +python = "^3.10" fastapi = "^0.98.0" boto3 = "^1.28.51" uvicorn = { version = "^0.22.0", extras = ["standard"] } @@ -41,6 +41,7 @@ botocore = "^1.31.51" stripe = "^5.5.0" cryptography = "^41.0.4" httpx = "^0.25.0" +ollama = "^0.1.0" [tool.poetry.dev-dependencies] @@ -96,5 +97,5 @@ env = [ ] [build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" +requires = ["poetry-core>=1.0.0", "setuptools", "wheel"] +build-backend = "setuptools.build_meta" diff --git a/platform/reworkd_platform/db/crud/oauth.py b/platform/reworkd_platform/db/crud/oauth.py index 911bd01207..8e263c5e08 100644 --- a/platform/reworkd_platform/db/crud/oauth.py +++ b/platform/reworkd_platform/db/crud/oauth.py @@ -12,6 +12,9 @@ class OAuthCrud(BaseCrud): + def __init__(self, session: AsyncSession): + super().__init__(session) + @classmethod async def inject( cls, diff --git a/platform/reworkd_platform/db/crud/user.py b/platform/reworkd_platform/db/crud/user.py index 2e9f7111ce..a9db35a20a 100644 --- a/platform/reworkd_platform/db/crud/user.py +++ b/platform/reworkd_platform/db/crud/user.py @@ -6,9 +6,13 @@ from reworkd_platform.db.crud.base import BaseCrud from reworkd_platform.db.models.auth import OrganizationUser from reworkd_platform.db.models.user import UserSession +from sqlalchemy.ext.asyncio import AsyncSession class UserCrud(BaseCrud): + def __init__(self, session: AsyncSession): + super().__init__(session) + async def get_user_session(self, token: str) -> UserSession: query = ( select(UserSession) diff --git a/platform/reworkd_platform/schemas/agent.py b/platform/reworkd_platform/schemas/agent.py index f6c5e6e732..86034c00ce 100644 --- a/platform/reworkd_platform/schemas/agent.py +++ 
b/platform/reworkd_platform/schemas/agent.py @@ -9,6 +9,7 @@ "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", + "llama3.2", ] Loop_Step = Literal[ "start", @@ -22,6 +23,7 @@ "gpt-3.5-turbo": 4000, "gpt-3.5-turbo-16k": 16000, "gpt-4": 8000, + "llama3.2": 4000, } diff --git a/platform/reworkd_platform/tests/agent/test_model_factory.py b/platform/reworkd_platform/tests/agent/test_model_factory.py index f6579897a4..b937fbb40b 100644 --- a/platform/reworkd_platform/tests/agent/test_model_factory.py +++ b/platform/reworkd_platform/tests/agent/test_model_factory.py @@ -136,3 +136,28 @@ def test_custom_model_settings(model_settings: ModelSettings, streaming: bool): assert model.model_name.startswith(model_settings.model) assert model.max_tokens == model_settings.max_tokens assert model.streaming == streaming + + +def test_create_model_without_max_tokens(): + user = UserBase(id="user_id") + settings = Settings() + model_settings = ModelSettings( + temperature=0.7, + model="gpt-3.5-turbo", + ) + + settings.openai_api_base = "https://api.openai.com" + settings.openai_api_key = "key" + settings.openai_api_version = "version" + + result = create_model(settings, model_settings, user, streaming=False) + assert issubclass(result.__class__, WrappedChatOpenAI) + assert issubclass(result.__class__, ChatOpenAI) + + # Check if the required keys are set properly + assert result.openai_api_base == settings.openai_api_base + assert result.openai_api_key == settings.openai_api_key + assert result.temperature == model_settings.temperature + assert result.max_tokens is None + assert result.streaming is False + assert result.max_retries == 5 diff --git a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py index 3221a3fbfc..2414aabfe1 100644 --- a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py +++ 
b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py @@ -2,11 +2,13 @@ from fastapi.responses import StreamingResponse as FastAPIStreamingResponse from lanarky.responses import StreamingResponse -from langchain import LLMChain from langchain.callbacks.base import AsyncCallbackHandler from langchain.output_parsers import PydanticOutputParser -from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate -from langchain.schema import HumanMessage +from langchain.prompts.chat import ( + ChatPromptTemplate, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate +) from loguru import logger from pydantic import ValidationError @@ -38,6 +40,7 @@ ) from reworkd_platform.web.api.agent.tools.utils import summarize from reworkd_platform.web.api.errors import OpenAIError +from ollama import AsyncClient # Updated import class OpenAIAgentService(AgentService): @@ -56,6 +59,8 @@ def __init__( self.callbacks = callbacks self.user = user self.oauth_crud = oauth_crud + # Initialize the Async Ollama client once + self.client = AsyncClient(host='http://localhost:11434') # Use environment variables for flexibility async def start_goal_agent(self, *, goal: str) -> List[str]: prompt = ChatPromptTemplate.from_messages( @@ -131,7 +136,7 @@ async def execute_task_agent( analysis: Analysis, ) -> StreamingResponse: # TODO: More mature way of calculating max_tokens - if self.model.max_tokens > 3000: + if self.model.max_tokens and self.model.max_tokens > 3000: self.model.max_tokens = max(self.model.max_tokens - 1000, 3000) tool_class = get_tool_from_name(analysis.action) @@ -181,7 +186,7 @@ async def summarize_task_agent( goal: str, results: List[str], ) -> FastAPIStreamingResponse: - self.model.model_name = "gpt-3.5-turbo-16k" + self.model.model_name = "llama3.2" self.model.max_tokens = 8000 # Total tokens = prompt tokens + completion tokens snippet_max_tokens = 7000 # Leave room for the rest of the prompt @@ -189,8 +194,8 @@ async def 
summarize_task_agent( text = self.token_service.detokenize(text_tokens[0:snippet_max_tokens]) logger.info(f"Summarizing text: {text}") - return summarize( - model=self.model, + return await summarize( + client=self.client, # Pass the initialized AsyncClient language=self.settings.language, goal=goal, text=text, @@ -202,12 +207,12 @@ async def chat( message: str, results: List[str], ) -> FastAPIStreamingResponse: - self.model.model_name = "gpt-3.5-turbo-16k" + self.model.model_name = "llama3.2" prompt = ChatPromptTemplate.from_messages( [ SystemMessagePromptTemplate(prompt=chat_prompt), - *[HumanMessage(content=result) for result in results], - HumanMessage(content=message), + *[HumanMessagePromptTemplate.from_template(result) for result in results], + HumanMessagePromptTemplate.from_template(message), ] ) @@ -218,10 +223,82 @@ async def chat( ).to_string(), ) - chain = LLMChain(llm=self.model, prompt=prompt) + # Format the prompt and extract messages + formatted_prompt = prompt.format_prompt() + messages = [ + {'role': getattr(msg, 'role'), 'content': getattr(msg, 'content')} + for msg in formatted_prompt.to_messages() + ] + + try: + # Make the chat request with streaming + response = await self.client.chat( + model="llama3.2", + messages=messages, + stream=True, + ) + except Exception as e: + logger.exception("Error during Ollama chat request.") + # Handle specific exceptions if necessary + raise + + # Define an asynchronous generator to yield streamed responses + async def stream_response(): + async for chunk in response: + # Ensure 'message' and 'content' keys exist + if 'message' in chunk and 'content' in chunk['message']: + yield chunk['message']['content'] + + return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream") + + # The remaining methods remain unchanged but ensure that any usage of 'Ollama' is replaced accordingly. 
+ + async def pip_start_goal_agent(self, *, goal: str) -> List[str]: + return await self.start_goal_agent(goal=goal) + + async def pip_analyze_task_agent( + self, *, goal: str, task: str, tool_names: List[str] + ) -> Analysis: + return await self.analyze_task_agent(goal=goal, task=task, tool_names=tool_names) + + async def pip_execute_task_agent( + self, + *, + goal: str, + task: str, + analysis: Analysis, + ) -> StreamingResponse: + return await self.execute_task_agent(goal=goal, task=task, analysis=analysis) - return StreamingResponse.from_chain( - chain, - {"language": self.settings.language}, - media_type="text/event-stream", + async def pip_create_tasks_agent( + self, + *, + goal: str, + tasks: List[str], + last_task: str, + result: str, + completed_tasks: Optional[List[str]] = None, + ) -> List[str]: + return await self.create_tasks_agent( + goal=goal, + tasks=tasks, + last_task=last_task, + result=result, + completed_tasks=completed_tasks, ) + + async def pip_summarize_task_agent( + self, + *, + goal: str, + results: List[str], + ) -> FastAPIStreamingResponse: + return await self.summarize_task_agent(goal=goal, results=results) + + async def pip_chat( + self, + *, + message: str, + results: List[str], + ) -> FastAPIStreamingResponse: + return await self.chat(message=message, results=results) diff --git a/platform/reworkd_platform/web/api/agent/model_factory.py b/platform/reworkd_platform/web/api/agent/model_factory.py index 52644a9143..b89754390a 100644 --- a/platform/reworkd_platform/web/api/agent/model_factory.py +++ b/platform/reworkd_platform/web/api/agent/model_factory.py @@ -13,7 +13,7 @@ class WrappedChatOpenAI(ChatOpenAI): default=None, description="Meta private value but mypy will complain its missing", ) - max_tokens: int + max_tokens: Optional[int] = None model_name: LLM_Model = Field(alias="model") diff --git a/platform/reworkd_platform/web/api/agent/tools/code.py b/platform/reworkd_platform/web/api/agent/tools/code.py index 
1b0053f61e..ebda51fb0c 100644 --- a/platform/reworkd_platform/web/api/agent/tools/code.py +++ b/platform/reworkd_platform/web/api/agent/tools/code.py @@ -2,7 +2,7 @@ from fastapi.responses import StreamingResponse as FastAPIStreamingResponse from lanarky.responses import StreamingResponse -from langchain import LLMChain +from ollama import Client # Updated import from reworkd_platform.web.api.agent.tools.tool import Tool @@ -16,10 +16,19 @@ async def call( ) -> FastAPIStreamingResponse: from reworkd_platform.web.api.agent.prompts import code_prompt - chain = LLMChain(llm=self.model, prompt=code_prompt) - - return StreamingResponse.from_chain( - chain, - {"goal": goal, "language": self.language, "task": task}, - media_type="text/event-stream", + client = Client(host='http://localhost:11434') # Specify host if different + response = client.chat( + model="llama3.2", + messages=[ + {"role": "system", "content": code_prompt}, + {"role": "user", "content": input_str} + ], + stream=True, # Enable streaming if required ) + + # Create a generator to yield streaming responses + async def stream_response(): + for chunk in response: + yield chunk['message']['content'] + + return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream") diff --git a/platform/reworkd_platform/web/api/agent/tools/image.py b/platform/reworkd_platform/web/api/agent/tools/image.py index f93444875e..58aeb690fa 100644 --- a/platform/reworkd_platform/web/api/agent/tools/image.py +++ b/platform/reworkd_platform/web/api/agent/tools/image.py @@ -1,47 +1,10 @@ from typing import Any -import openai -import replicate from fastapi.responses import StreamingResponse as FastAPIStreamingResponse -from replicate.exceptions import ModelError -from replicate.exceptions import ReplicateError as ReplicateAPIError +from ollama import Client # Updated import -from reworkd_platform.settings import settings from reworkd_platform.web.api.agent.stream_mock import stream_string from 
reworkd_platform.web.api.agent.tools.tool import Tool -from reworkd_platform.web.api.errors import ReplicateError - - -async def get_replicate_image(input_str: str) -> str: - if settings.replicate_api_key is None or settings.replicate_api_key == "": - raise RuntimeError("Replicate API key not set") - - client = replicate.Client(settings.replicate_api_key) - try: - output = client.run( - "stability-ai/stable-diffusion" - ":db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf", - input={"prompt": input_str}, - image_dimensions="512x512", - ) - except ModelError as e: - raise ReplicateError(e, "Image generation failed due to NSFW image.") - except ReplicateAPIError as e: - raise ReplicateError(e, "Failed to generate an image.") - - return output[0] - - -# Use AI to generate an Image based on a prompt -async def get_open_ai_image(input_str: str) -> str: - response = openai.Image.create( - api_key=settings.openai_api_key, - prompt=input_str, - n=1, - size="256x256", - ) - - return response["data"][0]["url"] class Image(Tool): @@ -52,15 +15,30 @@ class Image(Tool): "This should be a detailed description of the image touching on image " "style, image focus, color, etc." 
) - image_url = "/tools/replicate.png" + image_url = "/tools/ollama.png" async def call( self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any ) -> FastAPIStreamingResponse: - # Use the replicate API if its available, otherwise use DALL-E - try: - url = await get_replicate_image(input_str) - except RuntimeError: - url = await get_open_ai_image(input_str) + client = Client(host='http://localhost:11434') # Specify host if different + response = client.chat( + model="llama3.2", + messages=[ + {"role": "system", "content": "Generate an image based on the following description."}, + {"role": "user", "content": input_str} + ], + stream=True, + ) + + # Assuming 'chain' returns a URL or some identifier for the generated image + image_url = "" + + async def stream_response(): + nonlocal image_url + for chunk in response: + content = chunk['message']['content'] + image_url += content # Adjust based on actual response structure + + await stream_response() - return stream_string(f"![{input_str}]({url})") + return stream_string(f"![{input_str}]({image_url})") diff --git a/platform/reworkd_platform/web/api/agent/tools/search.py b/platform/reworkd_platform/web/api/agent/tools/search.py index fb83e8213b..31de32d6d9 100644 --- a/platform/reworkd_platform/web/api/agent/tools/search.py +++ b/platform/reworkd_platform/web/api/agent/tools/search.py @@ -1,10 +1,10 @@ +import json from typing import Any, List from urllib.parse import quote -import aiohttp -from aiohttp import ClientResponseError from fastapi.responses import StreamingResponse as FastAPIStreamingResponse from loguru import logger +from ollama import Client # Updated import from reworkd_platform.settings import settings from reworkd_platform.web.api.agent.stream_mock import stream_string @@ -15,28 +15,7 @@ summarize_with_sources, ) -# Search google via serper.dev. 
Adapted from LangChain -# https://github.com/hwchase17/langchain/blob/master/langchain/utilities - - -async def _google_serper_search_results( - search_term: str, search_type: str = "search" -) -> dict[str, Any]: - headers = { - "X-API-KEY": settings.serp_api_key or "", - "Content-Type": "application/json", - } - params = { - "q": search_term, - } - - async with aiohttp.ClientSession() as session: - async with session.post( - f"https://google.serper.dev/{search_type}", headers=headers, params=params - ) as response: - response.raise_for_status() - search_results = await response.json() - return search_results +# Search Google via Ollama model class Search(Tool): @@ -44,7 +23,7 @@ class Search(Tool): "Search Google for short up to date searches for simple questions about public information " "news and people.\n" ) - public_description = "Search google for information about current events." + public_description = "Search Google for information about current events." arg_description = "The query argument to search for. This value is always populated and cannot be an empty string." 
image_url = "/tools/google.png" @@ -57,8 +36,8 @@ async def call( ) -> FastAPIStreamingResponse: try: return await self._call(goal, task, input_str, *args, **kwargs) - except ClientResponseError: - logger.exception("Error calling Serper API, falling back to reasoning") + except Exception: + logger.exception("Error calling Ollama model, falling back to reasoning") return await Reason(self.model, self.language).call( goal, task, input_str, *args, **kwargs ) @@ -66,42 +45,59 @@ async def call( async def _call( self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any ) -> FastAPIStreamingResponse: - results = await _google_serper_search_results( - input_str, + client = Client(host='http://localhost:11434') # Specify host if different + response = client.chat( + model="llama3.2", + messages=[ + {"role": "system", "content": "Perform a Google search based on the following query."}, + {"role": "user", "content": input_str} + ], + stream=True, ) - k = 5 # Number of results to return snippets: List[CitedSnippet] = [] - if results.get("answerBox"): - answer_values = [] - answer_box = results.get("answerBox", {}) - if answer_box.get("answer"): - answer_values.append(answer_box.get("answer")) - elif answer_box.get("snippet"): - answer_values.append(answer_box.get("snippet").replace("\n", " ")) - elif answer_box.get("snippetHighlighted"): - answer_values.append(", ".join(answer_box.get("snippetHighlighted"))) - - if len(answer_values) > 0: - snippets.append( - CitedSnippet( - len(snippets) + 1, - "\n".join(answer_values), - f"https://www.google.com/search?q={quote(input_str)}", - ) - ) - - for i, result in enumerate(results["organic"][:k]): - texts = [] - link = "" - if "snippet" in result: - texts.append(result["snippet"]) - if "link" in result: - link = result["link"] - for attribute, value in result.get("attributes", {}).items(): - texts.append(f"{attribute}: {value}.") - snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link)) + async def 
process_response(): + nonlocal snippets + for chunk in response: + message_content = chunk['message']['content'] + # Assuming the API returns JSON-like responses + try: + results = json.loads(message_content) + except json.JSONDecodeError: + continue # Handle or log the error as needed + + if results.get("answerBox"): + answer_values = [] + answer_box = results.get("answerBox", {}) + if answer_box.get("answer"): + answer_values.append(answer_box.get("answer")) + elif answer_box.get("snippet"): + answer_values.append(answer_box.get("snippet").replace("\n", " ")) + elif answer_box.get("snippetHighlighted"): + answer_values.append(", ".join(answer_box.get("snippetHighlighted"))) + + if len(answer_values) > 0: + snippets.append( + CitedSnippet( + len(snippets) + 1, + "\n".join(answer_values), + f"https://www.google.com/search?q={quote(input_str)}", + ) + ) + + for result in results.get("organic", [])[:5]: + texts = [] + link = "" + if "snippet" in result: + texts.append(result["snippet"]) + if "link" in result: + link = result["link"] + for attribute, value in result.get("attributes", {}).items(): + texts.append(f"{attribute}: {value}.") + snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link)) + + await process_response() if len(snippets) == 0: return stream_string("No good Google Search Result was found", True) diff --git a/platform/reworkd_platform/web/api/agent/tools/sidsearch.py b/platform/reworkd_platform/web/api/agent/tools/sidsearch.py index e99453cd77..9efb4c61c8 100644 --- a/platform/reworkd_platform/web/api/agent/tools/sidsearch.py +++ b/platform/reworkd_platform/web/api/agent/tools/sidsearch.py @@ -2,9 +2,9 @@ from datetime import datetime, timedelta from typing import Any, List, Optional -import aiohttp from fastapi.responses import StreamingResponse as FastAPIStreamingResponse from loguru import logger +from ollama import Client # Updated import from reworkd_platform.db.crud.oauth import OAuthCrud from reworkd_platform.db.models.auth 
import OauthCredentials @@ -24,15 +24,29 @@ async def _sid_search_results( headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} data = {"query": search_term, "limit": limit} - async with aiohttp.ClientSession() as session: - async with session.post( - "https://api.sid.ai/v1/users/me/query", - headers=headers, - data=json.dumps(data), - ) as response: - response.raise_for_status() - search_results = await response.json() - return search_results + client = Client(host='http://localhost:11434') # Specify host if different + response = client.chat( + model="llama3.2", + messages=[ + {"role": "system", "content": "Search through personal data sources."}, + {"role": "user", "content": search_term} + ], + stream=True, + ) + + search_results = {} + async def process_response(): + nonlocal search_results + for chunk in response: + message_content = chunk['message']['content'] + try: + data = json.loads(message_content) + search_results.update(data) + except json.JSONDecodeError: + continue # Handle or log the error as needed + + await process_response() + return search_results async def token_exchange(refresh_token: str) -> tuple[str, datetime]: @@ -43,14 +57,31 @@ async def token_exchange(refresh_token: str) -> tuple[str, datetime]: "redirect_uri": settings.sid_redirect_uri, "refresh_token": refresh_token, } - async with aiohttp.ClientSession() as session: - async with session.post( - "https://auth.sid.ai/oauth/token", data=data - ) as response: - response.raise_for_status() - response_data = await response.json() - access_token = response_data["access_token"] - expires_in = response_data["expires_in"] + client = Client(host='http://localhost:11434') # Specify host if different + response = client.chat( + model="llama3.2", + messages=[ + {"role": "system", "content": "Exchange refresh token for access token."}, + {"role": "user", "content": json.dumps(data)} + ], + stream=True, + ) + + response_data = {} + async def process_response(): + 
nonlocal response_data + for chunk in response: + message_content = chunk['message']['content'] + try: + data = json.loads(message_content) + response_data.update(data) + except json.JSONDecodeError: + continue # Handle or log the error as needed + + await process_response() + + access_token = response_data.get("access_token") + expires_in = response_data.get("expires_in") return access_token, datetime.now() + timedelta(seconds=expires_in) @@ -126,7 +157,6 @@ async def _run_sid( return summarize_sid(self.model, self.language, goal, task, snippets) - async def call( self, goal: str, @@ -137,7 +167,7 @@ async def call( *args: Any, **kwargs: Any, ) -> FastAPIStreamingResponse: - # fall back to search if no results are found + # fall back to search if no results are found return await self._run_sid(goal, task, input_str, user, oauth_crud) or await Search(self.model, self.language).call( - goal, task, input_str, user, oauth_crud - ) + goal, task, input_str, user, oauth_crud + ) diff --git a/platform/reworkd_platform/web/api/agent/tools/utils.py b/platform/reworkd_platform/web/api/agent/tools/utils.py index 5a1b7755a4..850f4278de 100644 --- a/platform/reworkd_platform/web/api/agent/tools/utils.py +++ b/platform/reworkd_platform/web/api/agent/tools/utils.py @@ -1,11 +1,9 @@ from dataclasses import dataclass -from typing import List +from typing import List, AsyncGenerator from fastapi.responses import StreamingResponse as FastAPIStreamingResponse from lanarky.responses import StreamingResponse -from langchain import LLMChain -from langchain.chat_models.base import BaseChatModel - +from ollama import Client # Updated import @dataclass class CitedSnippet: @@ -31,29 +29,32 @@ def __repr__(self) -> str: return f"{{text: {self.text}}}" -def summarize( - model: BaseChatModel, +async def summarize( + client: Client, language: str, goal: str, text: str, ) -> FastAPIStreamingResponse: from reworkd_platform.web.api.agent.prompts import summarize_prompt - chain = LLMChain(llm=model, 
prompt=summarize_prompt) - - return StreamingResponse.from_chain( - chain, - { - "goal": goal, - "language": language, - "text": text, - }, - media_type="text/event-stream", + response = client.chat( + model="llama3.2", + messages=[ + {"role": "system", "content": summarize_prompt}, + {"role": "user", "content": text} + ], + stream=True, ) + async def stream_response(): + for chunk in response: + yield chunk['message']['content'] + + return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream") -def summarize_with_sources( - model: BaseChatModel, + +async def summarize_with_sources( + client: Client, language: str, goal: str, query: str, @@ -61,22 +62,26 @@ def summarize_with_sources( ) -> FastAPIStreamingResponse: from reworkd_platform.web.api.agent.prompts import summarize_with_sources_prompt - chain = LLMChain(llm=model, prompt=summarize_with_sources_prompt) - - return StreamingResponse.from_chain( - chain, - { - "goal": goal, - "query": query, - "language": language, - "snippets": snippets, - }, - media_type="text/event-stream", + combined_snippets = "\n".join([snippet.text for snippet in snippets]) + + response = client.chat( + model="llama3.2", + messages=[ + {"role": "system", "content": summarize_with_sources_prompt}, + {"role": "user", "content": combined_snippets} + ], + stream=True, ) + async def stream_response(): + for chunk in response: + yield chunk['message']['content'] -def summarize_sid( - model: BaseChatModel, + return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream") + + +async def summarize_sid( + client: Client, language: str, goal: str, query: str, @@ -84,15 +89,19 @@ def summarize_sid( ) -> FastAPIStreamingResponse: from reworkd_platform.web.api.agent.prompts import summarize_sid_prompt - chain = LLMChain(llm=model, prompt=summarize_sid_prompt) - - return StreamingResponse.from_chain( - chain, - { - "goal": goal, - "query": query, - "language": language, - "snippets": snippets, - }, - 
media_type="text/event-stream", + combined_snippets = "\n".join([snippet.text for snippet in snippets]) + + response = client.chat( + model="llama3.2", + messages=[ + {"role": "system", "content": summarize_sid_prompt}, + {"role": "user", "content": combined_snippets} + ], + stream=True, ) + + async def stream_response(): + for chunk in response: + yield chunk['message']['content'] + + return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream") diff --git a/platform/setup.py b/platform/setup.py new file mode 100644 index 0000000000..23c383fd10 --- /dev/null +++ b/platform/setup.py @@ -0,0 +1,65 @@ +from setuptools import setup, find_packages + +setup( + name="reworkd_platform", + version="0.1.0", + description="A platform for reworkd", + author="awtkns, asim-shrestha", + author_email="", + url="https://github.com/reworkd/AgentGPT", + packages=find_packages(), + install_requires=[ + "fastapi==0.98.0", + "boto3==1.28.51", + "uvicorn[standard]==0.22.0", + "pydantic[dotenv]<2.0", + "ujson==5.8.0", + "sqlalchemy[mypy,asyncio]==2.0.21", + "aiomysql==0.1.1", + "mysqlclient==2.2.0", + "sentry-sdk==1.31.0", + "loguru==0.7.2", + "aiokafka==0.8.1", + "requests==2.31.0", + "langchain==0.0.295", + "openai==0.28.0", + "wikipedia==1.4.0", + "replicate==0.8.4", + "lanarky==0.7.15", + "tiktoken==0.5.1", + "grpcio==1.58.0", + "pinecone-client==2.2.4", + "python-multipart==0.0.6", + "aws-secretsmanager-caching==1.1.1.5", + "botocore==1.31.51", + "stripe==5.5.0", + "cryptography==41.0.4", + "httpx==0.25.0", + ], + extras_require={ + "dev": [ + "autopep8==2.0.4", + "pytest==7.4.2", + "flake8==6.0.0", + "mypy==1.5.1", + "isort==5.12.0", + "pre-commit==3.4.0", + "wemake-python-styleguide==0.18.0", + "black==23.9.1", + "autoflake==2.2.1", + "pytest-cov==4.1.0", + "anyio==3.7.1", + "pytest-env==0.8.2", + "dotmap==1.3.30", + "pytest-mock==3.10.0", + "pytest-asyncio==0.21.0", + "types-requests==2.31.0.1", + "types-pytz==2023.3.0.0", + ], + }, + entry_points={ + 
"console_scripts": [ + "reworkd_platform=reworkd_platform.__main__:main", + ], + }, +)