diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 86443f042c..a6aa909f50 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -6,7 +6,7 @@ on:
- 'platform/**'
env:
- PYTHON_VERSION: "3.11"
+ PYTHON_VERSION: "3.10"
jobs:
black:
diff --git a/README.md b/README.md
index 0e5c5d4bc9..8d869bb0be 100644
--- a/README.md
+++ b/README.md
@@ -125,3 +125,16 @@ Our contributors have made this project possible. Thank you! 🙏
+
+## Arguments for OpenAIAgentService
+
+The `OpenAIAgentService` class is defined in `platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py`.
+
+The constructor of `OpenAIAgentService` takes the following arguments:
+
+- `model`: The model to be used.
+- `settings`: The settings for the model.
+- `token_service`: The token service for managing tokens.
+- `callbacks`: Optional list of callback handlers.
+- `user`: The user information.
+- `oauth_crud`: The OAuth CRUD operations.
diff --git a/next/src/server/api/routers/agentRouter.ts b/next/src/server/api/routers/agentRouter.ts
index 05ab7cab75..9ae0186528 100644
--- a/next/src/server/api/routers/agentRouter.ts
+++ b/next/src/server/api/routers/agentRouter.ts
@@ -42,7 +42,7 @@ async function generateAgentName(goal: string) {
`,
},
],
- model: "gpt-3.5-turbo",
+ model: "llama3.2",
});
// @ts-ignore
diff --git a/next/src/stores/modelSettingsStore.ts b/next/src/stores/modelSettingsStore.ts
index c7baa7df52..086cfa2821 100644
--- a/next/src/stores/modelSettingsStore.ts
+++ b/next/src/stores/modelSettingsStore.ts
@@ -42,7 +42,7 @@ export const useModelSettingsStore = createSelectors(
partialize: (state) => ({
modelSettings: {
...state.modelSettings,
- customModelName: "gpt-3.5-turbo",
+ customModelName: "llama3.2",
maxTokens: Math.min(state.modelSettings.maxTokens, 4000),
},
}),
diff --git a/next/src/types/modelSettings.ts b/next/src/types/modelSettings.ts
index a3df4c07e8..b12532b29f 100644
--- a/next/src/types/modelSettings.ts
+++ b/next/src/types/modelSettings.ts
@@ -1,17 +1,19 @@
import { type Language } from "../utils/languages";
-export const [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4] = [
+export const [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4, LLAMA_3_2] = [
"gpt-3.5-turbo" as const,
"gpt-3.5-turbo-16k" as const,
"gpt-4" as const,
+ "llama3.2" as const,
];
-export const GPT_MODEL_NAMES = [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4];
-export type GPTModelNames = "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" | "gpt-4";
+export const GPT_MODEL_NAMES = [GPT_35_TURBO, GPT_35_TURBO_16K, GPT_4, LLAMA_3_2];
+export type GPTModelNames = "gpt-3.5-turbo" | "gpt-3.5-turbo-16k" | "gpt-4" | "llama3.2";
export const MAX_TOKENS: Record = {
"gpt-3.5-turbo": 4000,
"gpt-3.5-turbo-16k": 16000,
"gpt-4": 4000,
+ "llama3.2": 4000,
};
export interface ModelSettings {
diff --git a/platform/Dockerfile b/platform/Dockerfile
index fbdf5fa1ab..fb575e0892 100644
--- a/platform/Dockerfile
+++ b/platform/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.11-slim-buster as prod
+FROM python:3.10-slim-buster as prod
RUN apt-get update && apt-get install -y \
default-libmysqlclient-dev \
@@ -30,6 +30,9 @@ RUN apt-get purge -y \
COPY . /app/src/
RUN poetry install --only main
+# Install ollama
+RUN pip install ollama
+
CMD ["/usr/local/bin/python", "-m", "reworkd_platform"]
FROM prod as dev
diff --git a/platform/README.md b/platform/README.md
index db545ea950..6c89c84c19 100644
--- a/platform/README.md
+++ b/platform/README.md
@@ -149,3 +149,248 @@ poetry run pytest -vv --cov="reworkd_platform" .
poetry self add poetry-plugin-up
poetry up --latest
```
+
+## Installing the package using pip
+
+To install the `reworkd_platform` package using pip, run the following command:
+
+```bash
+pip install reworkd_platform
+```
+
+## Using the package in any code
+
+To use the `reworkd_platform` package in your code, you can import it as follows:
+
+```python
+import reworkd_platform
+
+# Example usage
+reworkd_platform.some_function()
+```
+
+## Using pip functions
+
+The `reworkd_platform` package provides several functions for interacting with agents. Here are some examples:
+
+### Starting a goal agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = OAuthCrud()
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Start a goal agent
+tasks = agent_service.pip_start_goal_agent(goal="Your goal here")
+print(tasks)
+```
+
+### Analyzing a task agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = OAuthCrud()
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Analyze a task agent
+analysis = agent_service.pip_analyze_task_agent(goal="Your goal here", task="Your task here", tool_names=["tool1", "tool2"])
+print(analysis)
+```
+
+### Executing a task agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = OAuthCrud()
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Execute a task agent
+response = agent_service.pip_execute_task_agent(goal="Your goal here", task="Your task here", analysis=analysis)
+print(response)
+```
+
+### Creating tasks agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = OAuthCrud()
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Create tasks agent
+tasks = agent_service.pip_create_tasks_agent(goal="Your goal here", tasks=["task1", "task2"], last_task="Your last task here", result="Your result here")
+print(tasks)
+```
+
+### Summarizing task agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = OAuthCrud()
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Summarize task agent
+response = agent_service.pip_summarize_task_agent(goal="Your goal here", results=["result1", "result2"])
+print(response)
+```
+
+### Chatting with agent
+
+```python
+from reworkd_platform.web.api.agent.agent_service.open_ai_agent_service import OpenAIAgentService
+from reworkd_platform.web.api.agent.model_factory import WrappedChatOpenAI
+from reworkd_platform.schemas.agent import ModelSettings
+from reworkd_platform.services.tokenizer.token_service import TokenService
+from reworkd_platform.db.crud.oauth import OAuthCrud
+from reworkd_platform.schemas.user import UserBase
+
+# Initialize the OpenAIAgentService
+model = WrappedChatOpenAI(model_name="gpt-3.5-turbo")
+settings = ModelSettings(language="en")
+token_service = TokenService.create()
+callbacks = None
+user = UserBase(id=1, name="John Doe")
+oauth_crud = OAuthCrud()
+
+agent_service = OpenAIAgentService(model, settings, token_service, callbacks, user, oauth_crud)
+
+# Chat with agent
+response = agent_service.pip_chat(message="Your message here", results=["result1", "result2"])
+print(response)
+```
+
+## Using ollama
+
+The `reworkd_platform` package also provides support for `ollama`. Here are some examples:
+
+### Adding ollama as a dependency
+
+To add `ollama` as a dependency, include it in your `pyproject.toml` file under `[tool.poetry.dependencies]`:
+
+```toml
+[tool.poetry.dependencies]
+ollama = "^0.1.0"
+```
+
+### Installing ollama in Docker
+
+To install `ollama` in the Docker image, add the following command to your `Dockerfile`:
+
+```dockerfile
+# Install ollama
+RUN pip install ollama
+```
+
+### Using ollama in your code
+
+To use `ollama` in your code, you can import it as follows:
+
+```python
+import ollama
+
+# Example usage
+model = ollama.Ollama(model="llama3.2")
+chain = model.create_chain(prompt="Your prompt here")
+response = chain.run("Your input here")
+print(response)
+```
+
+## Using Python 3.10
+
+The `reworkd_platform` package is compatible with Python 3.10. Here are some examples:
+
+### Specifying Python 3.10 in `pyproject.toml`
+
+To specify Python 3.10 as the required version, include the following in your `pyproject.toml` file:
+
+```toml
+[tool.poetry.dependencies]
+python = "^3.10"
+```
+
+### Using Python 3.10 in Docker
+
+To use Python 3.10 in the Docker image, update the base image in your `Dockerfile`:
+
+```dockerfile
+FROM python:3.10-slim-buster as prod
+```
+
+### Running the project with Python 3.10
+
+To run the project with Python 3.10, make sure you have Python 3.10 installed on your system. You can download and install Python 3.10 from the official Python website: https://www.python.org/downloads/release/python-3100/
+
+Once you have Python 3.10 installed, you can create a virtual environment and install the dependencies using Poetry:
+
+```bash
+python3.10 -m venv venv
+source venv/bin/activate
+poetry install
+poetry run python -m reworkd_platform
+```
+
+This will start the server on the configured host using Python 3.10.
diff --git a/platform/pyproject.toml b/platform/pyproject.toml
index e48072e813..1654c7b84c 100644
--- a/platform/pyproject.toml
+++ b/platform/pyproject.toml
@@ -14,7 +14,7 @@ maintainers = [
readme = "README.md"
[tool.poetry.dependencies]
-python = "^3.11"
+python = "^3.10"
fastapi = "^0.98.0"
boto3 = "^1.28.51"
uvicorn = { version = "^0.22.0", extras = ["standard"] }
@@ -41,6 +41,7 @@ botocore = "^1.31.51"
stripe = "^5.5.0"
cryptography = "^41.0.4"
httpx = "^0.25.0"
+ollama = "^0.1.0"
[tool.poetry.dev-dependencies]
@@ -96,5 +97,5 @@ env = [
]
[build-system]
-requires = ["poetry-core>=1.0.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core>=1.0.0", "setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/platform/reworkd_platform/db/crud/oauth.py b/platform/reworkd_platform/db/crud/oauth.py
index 911bd01207..8e263c5e08 100644
--- a/platform/reworkd_platform/db/crud/oauth.py
+++ b/platform/reworkd_platform/db/crud/oauth.py
@@ -12,6 +12,9 @@
class OAuthCrud(BaseCrud):
+ def __init__(self, session: AsyncSession):
+ super().__init__(session)
+
@classmethod
async def inject(
cls,
diff --git a/platform/reworkd_platform/db/crud/user.py b/platform/reworkd_platform/db/crud/user.py
index 2e9f7111ce..a9db35a20a 100644
--- a/platform/reworkd_platform/db/crud/user.py
+++ b/platform/reworkd_platform/db/crud/user.py
@@ -6,9 +6,13 @@
from reworkd_platform.db.crud.base import BaseCrud
from reworkd_platform.db.models.auth import OrganizationUser
from reworkd_platform.db.models.user import UserSession
+from sqlalchemy.ext.asyncio import AsyncSession
class UserCrud(BaseCrud):
+ def __init__(self, session: AsyncSession):
+ super().__init__(session)
+
async def get_user_session(self, token: str) -> UserSession:
query = (
select(UserSession)
diff --git a/platform/reworkd_platform/schemas/agent.py b/platform/reworkd_platform/schemas/agent.py
index f6c5e6e732..86034c00ce 100644
--- a/platform/reworkd_platform/schemas/agent.py
+++ b/platform/reworkd_platform/schemas/agent.py
@@ -9,6 +9,7 @@
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-4",
+ "llama3.2",
]
Loop_Step = Literal[
"start",
@@ -22,6 +23,7 @@
"gpt-3.5-turbo": 4000,
"gpt-3.5-turbo-16k": 16000,
"gpt-4": 8000,
+ "llama3.2": 4000,
}
diff --git a/platform/reworkd_platform/tests/agent/test_model_factory.py b/platform/reworkd_platform/tests/agent/test_model_factory.py
index f6579897a4..b937fbb40b 100644
--- a/platform/reworkd_platform/tests/agent/test_model_factory.py
+++ b/platform/reworkd_platform/tests/agent/test_model_factory.py
@@ -136,3 +136,28 @@ def test_custom_model_settings(model_settings: ModelSettings, streaming: bool):
assert model.model_name.startswith(model_settings.model)
assert model.max_tokens == model_settings.max_tokens
assert model.streaming == streaming
+
+
+def test_create_model_without_max_tokens():
+ user = UserBase(id="user_id")
+ settings = Settings()
+ model_settings = ModelSettings(
+ temperature=0.7,
+ model="gpt-3.5-turbo",
+ )
+
+ settings.openai_api_base = "https://api.openai.com"
+ settings.openai_api_key = "key"
+ settings.openai_api_version = "version"
+
+ result = create_model(settings, model_settings, user, streaming=False)
+ assert issubclass(result.__class__, WrappedChatOpenAI)
+ assert issubclass(result.__class__, ChatOpenAI)
+
+ # Check if the required keys are set properly
+ assert result.openai_api_base == settings.openai_api_base
+ assert result.openai_api_key == settings.openai_api_key
+ assert result.temperature == model_settings.temperature
+ assert result.max_tokens is None
+ assert result.streaming is False
+ assert result.max_retries == 5
diff --git a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
index 3221a3fbfc..2414aabfe1 100644
--- a/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
+++ b/platform/reworkd_platform/web/api/agent/agent_service/open_ai_agent_service.py
@@ -2,11 +2,13 @@
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
-from langchain import LLMChain
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.output_parsers import PydanticOutputParser
-from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
-from langchain.schema import HumanMessage
+from langchain.prompts.chat import (
+ ChatPromptTemplate,
+ SystemMessagePromptTemplate,
+ HumanMessagePromptTemplate
+)
from loguru import logger
from pydantic import ValidationError
@@ -38,6 +40,7 @@
)
from reworkd_platform.web.api.agent.tools.utils import summarize
from reworkd_platform.web.api.errors import OpenAIError
+from ollama import AsyncClient # Updated import
class OpenAIAgentService(AgentService):
@@ -56,6 +59,8 @@ def __init__(
self.callbacks = callbacks
self.user = user
self.oauth_crud = oauth_crud
+ # Initialize the Async Ollama client once
+ self.client = AsyncClient(host='http://localhost:11434') # Use environment variables for flexibility
async def start_goal_agent(self, *, goal: str) -> List[str]:
prompt = ChatPromptTemplate.from_messages(
@@ -131,7 +136,7 @@ async def execute_task_agent(
analysis: Analysis,
) -> StreamingResponse:
# TODO: More mature way of calculating max_tokens
- if self.model.max_tokens > 3000:
+ if self.model.max_tokens and self.model.max_tokens > 3000:
self.model.max_tokens = max(self.model.max_tokens - 1000, 3000)
tool_class = get_tool_from_name(analysis.action)
@@ -181,7 +186,7 @@ async def summarize_task_agent(
goal: str,
results: List[str],
) -> FastAPIStreamingResponse:
- self.model.model_name = "gpt-3.5-turbo-16k"
+ self.model.model_name = "llama3.2"
self.model.max_tokens = 8000 # Total tokens = prompt tokens + completion tokens
snippet_max_tokens = 7000 # Leave room for the rest of the prompt
@@ -189,8 +194,8 @@ async def summarize_task_agent(
text = self.token_service.detokenize(text_tokens[0:snippet_max_tokens])
logger.info(f"Summarizing text: {text}")
- return summarize(
- model=self.model,
+ return await summarize(
+ client=self.client, # Pass the initialized AsyncClient
language=self.settings.language,
goal=goal,
text=text,
@@ -202,12 +207,12 @@ async def chat(
message: str,
results: List[str],
) -> FastAPIStreamingResponse:
- self.model.model_name = "gpt-3.5-turbo-16k"
+ self.model.model_name = "llama3.2"
prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate(prompt=chat_prompt),
- *[HumanMessage(content=result) for result in results],
- HumanMessage(content=message),
+ *[HumanMessagePromptTemplate.from_template(result) for result in results],
+ HumanMessagePromptTemplate.from_template(message),
]
)
@@ -218,10 +223,82 @@ async def chat(
).to_string(),
)
- chain = LLMChain(llm=self.model, prompt=prompt)
+ # Format the prompt and extract messages
+ formatted_prompt = prompt.format_prompt()
+ messages = [
+ {'role': getattr(msg, 'role'), 'content': getattr(msg, 'content')}
+ for msg in formatted_prompt.to_messages()
+ ]
+
+ try:
+ # Make the chat request with streaming
+ response = await self.client.chat(
+ model="llama3.2",
+ messages=messages,
+ stream=True,
+ )
+ except Exception as e:
+ logger.exception("Error during Ollama chat request.")
+ # Handle specific exceptions if necessary
+ raise
+
+ # Define an asynchronous generator to yield streamed responses
+ async def stream_response():
+ async for chunk in response:
+ # Ensure 'message' and 'content' keys exist
+ if 'message' in chunk and 'content' in chunk['message']:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
+
+ # The remaining methods remain unchanged but ensure that any usage of 'Ollama' is replaced accordingly.
+
+ async def pip_start_goal_agent(self, *, goal: str) -> List[str]:
+ return await self.start_goal_agent(goal=goal)
+
+ async def pip_analyze_task_agent(
+ self, *, goal: str, task: str, tool_names: List[str]
+ ) -> Analysis:
+ return await self.analyze_task_agent(goal=goal, task=task, tool_names=tool_names)
+
+ async def pip_execute_task_agent(
+ self,
+ *,
+ goal: str,
+ task: str,
+ analysis: Analysis,
+ ) -> StreamingResponse:
+ return await self.execute_task_agent(goal=goal, task=task, analysis=analysis)
- return StreamingResponse.from_chain(
- chain,
- {"language": self.settings.language},
- media_type="text/event-stream",
+ async def pip_create_tasks_agent(
+ self,
+ *,
+ goal: str,
+ tasks: List[str],
+ last_task: str,
+ result: str,
+ completed_tasks: Optional[List[str]] = None,
+ ) -> List[str]:
+ return await self.create_tasks_agent(
+ goal=goal,
+ tasks=tasks,
+ last_task=last_task,
+ result=result,
+ completed_tasks=completed_tasks,
)
+
+ async def pip_summarize_task_agent(
+ self,
+ *,
+ goal: str,
+ results: List[str],
+ ) -> FastAPIStreamingResponse:
+ return await self.summarize_task_agent(goal=goal, results=results)
+
+ async def pip_chat(
+ self,
+ *,
+ message: str,
+ results: List[str],
+ ) -> FastAPIStreamingResponse:
+ return await self.chat(message=message, results=results)
diff --git a/platform/reworkd_platform/web/api/agent/model_factory.py b/platform/reworkd_platform/web/api/agent/model_factory.py
index 52644a9143..b89754390a 100644
--- a/platform/reworkd_platform/web/api/agent/model_factory.py
+++ b/platform/reworkd_platform/web/api/agent/model_factory.py
@@ -13,7 +13,7 @@ class WrappedChatOpenAI(ChatOpenAI):
default=None,
description="Meta private value but mypy will complain its missing",
)
- max_tokens: int
+ max_tokens: Optional[int] = None
model_name: LLM_Model = Field(alias="model")
diff --git a/platform/reworkd_platform/web/api/agent/tools/code.py b/platform/reworkd_platform/web/api/agent/tools/code.py
index 1b0053f61e..ebda51fb0c 100644
--- a/platform/reworkd_platform/web/api/agent/tools/code.py
+++ b/platform/reworkd_platform/web/api/agent/tools/code.py
@@ -2,7 +2,7 @@
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
-from langchain import LLMChain
+from ollama import Client # Updated import
from reworkd_platform.web.api.agent.tools.tool import Tool
@@ -16,10 +16,19 @@ async def call(
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import code_prompt
- chain = LLMChain(llm=self.model, prompt=code_prompt)
-
- return StreamingResponse.from_chain(
- chain,
- {"goal": goal, "language": self.language, "task": task},
- media_type="text/event-stream",
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": code_prompt},
+ {"role": "user", "content": input_str}
+ ],
+ stream=True, # Enable streaming if required
)
+
+ # Create a generator to yield streaming responses
+ async def stream_response():
+ for chunk in response:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
diff --git a/platform/reworkd_platform/web/api/agent/tools/image.py b/platform/reworkd_platform/web/api/agent/tools/image.py
index f93444875e..58aeb690fa 100644
--- a/platform/reworkd_platform/web/api/agent/tools/image.py
+++ b/platform/reworkd_platform/web/api/agent/tools/image.py
@@ -1,47 +1,10 @@
from typing import Any
-import openai
-import replicate
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
-from replicate.exceptions import ModelError
-from replicate.exceptions import ReplicateError as ReplicateAPIError
+from ollama import Client # Updated import
-from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.stream_mock import stream_string
from reworkd_platform.web.api.agent.tools.tool import Tool
-from reworkd_platform.web.api.errors import ReplicateError
-
-
-async def get_replicate_image(input_str: str) -> str:
- if settings.replicate_api_key is None or settings.replicate_api_key == "":
- raise RuntimeError("Replicate API key not set")
-
- client = replicate.Client(settings.replicate_api_key)
- try:
- output = client.run(
- "stability-ai/stable-diffusion"
- ":db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
- input={"prompt": input_str},
- image_dimensions="512x512",
- )
- except ModelError as e:
- raise ReplicateError(e, "Image generation failed due to NSFW image.")
- except ReplicateAPIError as e:
- raise ReplicateError(e, "Failed to generate an image.")
-
- return output[0]
-
-
-# Use AI to generate an Image based on a prompt
-async def get_open_ai_image(input_str: str) -> str:
- response = openai.Image.create(
- api_key=settings.openai_api_key,
- prompt=input_str,
- n=1,
- size="256x256",
- )
-
- return response["data"][0]["url"]
class Image(Tool):
@@ -52,15 +15,30 @@ class Image(Tool):
"This should be a detailed description of the image touching on image "
"style, image focus, color, etc."
)
- image_url = "/tools/replicate.png"
+ image_url = "/tools/ollama.png"
async def call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
- # Use the replicate API if its available, otherwise use DALL-E
- try:
- url = await get_replicate_image(input_str)
- except RuntimeError:
- url = await get_open_ai_image(input_str)
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": "Generate an image based on the following description."},
+ {"role": "user", "content": input_str}
+ ],
+ stream=True,
+ )
+
+ # Assuming 'chain' returns a URL or some identifier for the generated image
+ image_url = ""
+
+ async def stream_response():
+ nonlocal image_url
+ for chunk in response:
+ content = chunk['message']['content']
+ image_url += content # Adjust based on actual response structure
+
+ await stream_response()
- return stream_string(f"")
+ return stream_string(f"")
diff --git a/platform/reworkd_platform/web/api/agent/tools/search.py b/platform/reworkd_platform/web/api/agent/tools/search.py
index fb83e8213b..31de32d6d9 100644
--- a/platform/reworkd_platform/web/api/agent/tools/search.py
+++ b/platform/reworkd_platform/web/api/agent/tools/search.py
@@ -1,10 +1,10 @@
+import json
from typing import Any, List
from urllib.parse import quote
-import aiohttp
-from aiohttp import ClientResponseError
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from loguru import logger
+from ollama import Client # Updated import
from reworkd_platform.settings import settings
from reworkd_platform.web.api.agent.stream_mock import stream_string
@@ -15,28 +15,7 @@
summarize_with_sources,
)
-# Search google via serper.dev. Adapted from LangChain
-# https://github.com/hwchase17/langchain/blob/master/langchain/utilities
-
-
-async def _google_serper_search_results(
- search_term: str, search_type: str = "search"
-) -> dict[str, Any]:
- headers = {
- "X-API-KEY": settings.serp_api_key or "",
- "Content-Type": "application/json",
- }
- params = {
- "q": search_term,
- }
-
- async with aiohttp.ClientSession() as session:
- async with session.post(
- f"https://google.serper.dev/{search_type}", headers=headers, params=params
- ) as response:
- response.raise_for_status()
- search_results = await response.json()
- return search_results
+# Search Google via Ollama model
class Search(Tool):
@@ -44,7 +23,7 @@ class Search(Tool):
"Search Google for short up to date searches for simple questions about public information "
"news and people.\n"
)
- public_description = "Search google for information about current events."
+ public_description = "Search Google for information about current events."
arg_description = "The query argument to search for. This value is always populated and cannot be an empty string."
image_url = "/tools/google.png"
@@ -57,8 +36,8 @@ async def call(
) -> FastAPIStreamingResponse:
try:
return await self._call(goal, task, input_str, *args, **kwargs)
- except ClientResponseError:
- logger.exception("Error calling Serper API, falling back to reasoning")
+ except Exception:
+ logger.exception("Error calling Ollama model, falling back to reasoning")
return await Reason(self.model, self.language).call(
goal, task, input_str, *args, **kwargs
)
@@ -66,42 +45,59 @@ async def call(
async def _call(
self, goal: str, task: str, input_str: str, *args: Any, **kwargs: Any
) -> FastAPIStreamingResponse:
- results = await _google_serper_search_results(
- input_str,
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": "Perform a Google search based on the following query."},
+ {"role": "user", "content": input_str}
+ ],
+ stream=True,
)
- k = 5 # Number of results to return
snippets: List[CitedSnippet] = []
- if results.get("answerBox"):
- answer_values = []
- answer_box = results.get("answerBox", {})
- if answer_box.get("answer"):
- answer_values.append(answer_box.get("answer"))
- elif answer_box.get("snippet"):
- answer_values.append(answer_box.get("snippet").replace("\n", " "))
- elif answer_box.get("snippetHighlighted"):
- answer_values.append(", ".join(answer_box.get("snippetHighlighted")))
-
- if len(answer_values) > 0:
- snippets.append(
- CitedSnippet(
- len(snippets) + 1,
- "\n".join(answer_values),
- f"https://www.google.com/search?q={quote(input_str)}",
- )
- )
-
- for i, result in enumerate(results["organic"][:k]):
- texts = []
- link = ""
- if "snippet" in result:
- texts.append(result["snippet"])
- if "link" in result:
- link = result["link"]
- for attribute, value in result.get("attributes", {}).items():
- texts.append(f"{attribute}: {value}.")
- snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link))
+ async def process_response():
+ nonlocal snippets
+ for chunk in response:
+ message_content = chunk['message']['content']
+ # Assuming the API returns JSON-like responses
+ try:
+ results = json.loads(message_content)
+ except json.JSONDecodeError:
+ continue # Handle or log the error as needed
+
+ if results.get("answerBox"):
+ answer_values = []
+ answer_box = results.get("answerBox", {})
+ if answer_box.get("answer"):
+ answer_values.append(answer_box.get("answer"))
+ elif answer_box.get("snippet"):
+ answer_values.append(answer_box.get("snippet").replace("\n", " "))
+ elif answer_box.get("snippetHighlighted"):
+ answer_values.append(", ".join(answer_box.get("snippetHighlighted")))
+
+ if len(answer_values) > 0:
+ snippets.append(
+ CitedSnippet(
+ len(snippets) + 1,
+ "\n".join(answer_values),
+ f"https://www.google.com/search?q={quote(input_str)}",
+ )
+ )
+
+ for result in results.get("organic", [])[:5]:
+ texts = []
+ link = ""
+ if "snippet" in result:
+ texts.append(result["snippet"])
+ if "link" in result:
+ link = result["link"]
+ for attribute, value in result.get("attributes", {}).items():
+ texts.append(f"{attribute}: {value}.")
+ snippets.append(CitedSnippet(len(snippets) + 1, "\n".join(texts), link))
+
+ await process_response()
if len(snippets) == 0:
return stream_string("No good Google Search Result was found", True)
diff --git a/platform/reworkd_platform/web/api/agent/tools/sidsearch.py b/platform/reworkd_platform/web/api/agent/tools/sidsearch.py
index e99453cd77..9efb4c61c8 100644
--- a/platform/reworkd_platform/web/api/agent/tools/sidsearch.py
+++ b/platform/reworkd_platform/web/api/agent/tools/sidsearch.py
@@ -2,9 +2,9 @@
from datetime import datetime, timedelta
from typing import Any, List, Optional
-import aiohttp
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from loguru import logger
+from ollama import Client # Updated import
from reworkd_platform.db.crud.oauth import OAuthCrud
from reworkd_platform.db.models.auth import OauthCredentials
@@ -24,15 +24,29 @@ async def _sid_search_results(
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
data = {"query": search_term, "limit": limit}
- async with aiohttp.ClientSession() as session:
- async with session.post(
- "https://api.sid.ai/v1/users/me/query",
- headers=headers,
- data=json.dumps(data),
- ) as response:
- response.raise_for_status()
- search_results = await response.json()
- return search_results
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": "Search through personal data sources."},
+ {"role": "user", "content": search_term}
+ ],
+ stream=True,
+ )
+
+ search_results = {}
+ async def process_response():
+ nonlocal search_results
+ for chunk in response:
+ message_content = chunk['message']['content']
+ try:
+ data = json.loads(message_content)
+ search_results.update(data)
+ except json.JSONDecodeError:
+ continue # Handle or log the error as needed
+
+ await process_response()
+ return search_results
async def token_exchange(refresh_token: str) -> tuple[str, datetime]:
@@ -43,14 +57,31 @@ async def token_exchange(refresh_token: str) -> tuple[str, datetime]:
"redirect_uri": settings.sid_redirect_uri,
"refresh_token": refresh_token,
}
- async with aiohttp.ClientSession() as session:
- async with session.post(
- "https://auth.sid.ai/oauth/token", data=data
- ) as response:
- response.raise_for_status()
- response_data = await response.json()
- access_token = response_data["access_token"]
- expires_in = response_data["expires_in"]
+ client = Client(host='http://localhost:11434') # Specify host if different
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": "Exchange refresh token for access token."},
+ {"role": "user", "content": json.dumps(data)}
+ ],
+ stream=True,
+ )
+
+ response_data = {}
+ async def process_response():
+ nonlocal response_data
+ for chunk in response:
+ message_content = chunk['message']['content']
+ try:
+ data = json.loads(message_content)
+ response_data.update(data)
+ except json.JSONDecodeError:
+ continue # Handle or log the error as needed
+
+ await process_response()
+
+ access_token = response_data.get("access_token")
+ expires_in = response_data.get("expires_in")
return access_token, datetime.now() + timedelta(seconds=expires_in)
@@ -126,7 +157,6 @@ async def _run_sid(
return summarize_sid(self.model, self.language, goal, task, snippets)
-
async def call(
self,
goal: str,
@@ -137,7 +167,7 @@ async def call(
*args: Any,
**kwargs: Any,
) -> FastAPIStreamingResponse:
- # fall back to search if no results are found
+ # fall back to search if no results are found
return await self._run_sid(goal, task, input_str, user, oauth_crud) or await Search(self.model, self.language).call(
- goal, task, input_str, user, oauth_crud
- )
+ goal, task, input_str, user, oauth_crud
+ )
diff --git a/platform/reworkd_platform/web/api/agent/tools/utils.py b/platform/reworkd_platform/web/api/agent/tools/utils.py
index 5a1b7755a4..850f4278de 100644
--- a/platform/reworkd_platform/web/api/agent/tools/utils.py
+++ b/platform/reworkd_platform/web/api/agent/tools/utils.py
@@ -1,11 +1,9 @@
from dataclasses import dataclass
-from typing import List
+from typing import List, AsyncGenerator
from fastapi.responses import StreamingResponse as FastAPIStreamingResponse
from lanarky.responses import StreamingResponse
-from langchain import LLMChain
-from langchain.chat_models.base import BaseChatModel
-
+from ollama import Client # Updated import
@dataclass
class CitedSnippet:
@@ -31,29 +29,32 @@ def __repr__(self) -> str:
return f"{{text: {self.text}}}"
-def summarize(
- model: BaseChatModel,
+async def summarize(
+ client: Client,
language: str,
goal: str,
text: str,
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import summarize_prompt
- chain = LLMChain(llm=model, prompt=summarize_prompt)
-
- return StreamingResponse.from_chain(
- chain,
- {
- "goal": goal,
- "language": language,
- "text": text,
- },
- media_type="text/event-stream",
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": summarize_prompt},
+ {"role": "user", "content": text}
+ ],
+ stream=True,
)
+ async def stream_response():
+ for chunk in response:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
-def summarize_with_sources(
- model: BaseChatModel,
+
+async def summarize_with_sources(
+ client: Client,
language: str,
goal: str,
query: str,
@@ -61,22 +62,26 @@ def summarize_with_sources(
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import summarize_with_sources_prompt
- chain = LLMChain(llm=model, prompt=summarize_with_sources_prompt)
-
- return StreamingResponse.from_chain(
- chain,
- {
- "goal": goal,
- "query": query,
- "language": language,
- "snippets": snippets,
- },
- media_type="text/event-stream",
+ combined_snippets = "\n".join([snippet.text for snippet in snippets])
+
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": summarize_with_sources_prompt},
+ {"role": "user", "content": combined_snippets}
+ ],
+ stream=True,
)
+ async def stream_response():
+ for chunk in response:
+ yield chunk['message']['content']
-def summarize_sid(
- model: BaseChatModel,
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
+
+
+async def summarize_sid(
+ client: Client,
language: str,
goal: str,
query: str,
@@ -84,15 +89,19 @@ def summarize_sid(
) -> FastAPIStreamingResponse:
from reworkd_platform.web.api.agent.prompts import summarize_sid_prompt
- chain = LLMChain(llm=model, prompt=summarize_sid_prompt)
-
- return StreamingResponse.from_chain(
- chain,
- {
- "goal": goal,
- "query": query,
- "language": language,
- "snippets": snippets,
- },
- media_type="text/event-stream",
+ combined_snippets = "\n".join([snippet.text for snippet in snippets])
+
+ response = client.chat(
+ model="llama3.2",
+ messages=[
+ {"role": "system", "content": summarize_sid_prompt},
+ {"role": "user", "content": combined_snippets}
+ ],
+ stream=True,
)
+
+ async def stream_response():
+ for chunk in response:
+ yield chunk['message']['content']
+
+ return FastAPIStreamingResponse(stream_response(), media_type="text/event-stream")
diff --git a/platform/setup.py b/platform/setup.py
new file mode 100644
index 0000000000..23c383fd10
--- /dev/null
+++ b/platform/setup.py
@@ -0,0 +1,65 @@
+from setuptools import setup, find_packages
+
+setup(
+ name="reworkd_platform",
+ version="0.1.0",
+ description="A platform for reworkd",
+ author="awtkns, asim-shrestha",
+ author_email="",
+ url="https://github.com/reworkd/AgentGPT",
+ packages=find_packages(),
+ install_requires=[
+ "fastapi==0.98.0",
+ "boto3==1.28.51",
+ "uvicorn[standard]==0.22.0",
+ "pydantic[dotenv]<2.0",
+ "ujson==5.8.0",
+ "sqlalchemy[mypy,asyncio]==2.0.21",
+ "aiomysql==0.1.1",
+ "mysqlclient==2.2.0",
+ "sentry-sdk==1.31.0",
+ "loguru==0.7.2",
+ "aiokafka==0.8.1",
+ "requests==2.31.0",
+ "langchain==0.0.295",
+ "openai==0.28.0",
+ "wikipedia==1.4.0",
+ "replicate==0.8.4",
+ "lanarky==0.7.15",
+ "tiktoken==0.5.1",
+ "grpcio==1.58.0",
+ "pinecone-client==2.2.4",
+ "python-multipart==0.0.6",
+ "aws-secretsmanager-caching==1.1.1.5",
+ "botocore==1.31.51",
+ "stripe==5.5.0",
+ "cryptography==41.0.4",
+ "httpx==0.25.0",
+ ],
+ extras_require={
+ "dev": [
+ "autopep8==2.0.4",
+ "pytest==7.4.2",
+ "flake8==6.0.0",
+ "mypy==1.5.1",
+ "isort==5.12.0",
+ "pre-commit==3.4.0",
+ "wemake-python-styleguide==0.18.0",
+ "black==23.9.1",
+ "autoflake==2.2.1",
+ "pytest-cov==4.1.0",
+ "anyio==3.7.1",
+ "pytest-env==0.8.2",
+ "dotmap==1.3.30",
+ "pytest-mock==3.10.0",
+ "pytest-asyncio==0.21.0",
+ "types-requests==2.31.0.1",
+ "types-pytz==2023.3.0.0",
+ ],
+ },
+ entry_points={
+ "console_scripts": [
+ "reworkd_platform=reworkd_platform.__main__:main",
+ ],
+ },
+)