Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 129 additions & 1 deletion superagi/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@
from inspect import signature
from typing import List
from typing import Optional, Type, Callable, Any, Union, Dict, Tuple

from enum import Enum
import yaml
from pydantic import BaseModel, create_model, validate_arguments, Extra

from superagi.types.key_type import ToolConfigKeyType
import os
from sqlalchemy.orm import Session
import csv
from superagi.helper.s3_helper import S3Helper
from superagi.lib.logger import logger
from superagi.config.config import get_config
from superagi.types.storage_types import StorageType


class SchemaSettings:
Expand Down Expand Up @@ -258,3 +265,124 @@ def __init__(self, key: str, key_type: str = None, is_required: bool = False, is
self.key_type = key_type
else:
raise ValueError("key_type should be string/file/integer")

def get_resource_path( file_name: str):
"""Get final path of the resource.

Args:
file_name (str): The name of the file.
"""
root_output_dir = get_root_output_dir() + file_name
return root_output_dir


def get_root_output_dir():
"""Get root dir of the resource.
"""
root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR')

if root_dir is not None:
root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir
root_dir = root_dir if root_dir.endswith("/") else root_dir + "/"
else:
root_dir = os.getcwd() + "/"
return root_dir


class FileManager:
def __init__(self, session: Session, agent_id: int = None, agent_execution_id: int = None):
self.session = session
self.agent_id = agent_id
self.agent_execution_id = agent_execution_id

def write_binary_file(self, file_name: str, data):
if self.agent_id is not None:
final_path = get_resource_path(file_name)
else:
final_path = get_resource_path(file_name)
try:
with open(final_path, mode="wb") as img:
img.write(data)
img.close()
with open(final_path, 'rb') as img:
storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value))
if storage_type == StorageType.S3.value:
S3Helper().upload_file(img, path=final_path)
logger.info(f"Binary {file_name} saved successfully")
return f"Binary {file_name} saved successfully"
except Exception as err:
return f"Error write_binary_file: {err}"

def write_file(self, file_name: str, content):
if self.agent_id is not None:
final_path = get_resource_path(file_name)

else:
final_path = get_resource_path(file_name)

try:
with open(final_path, mode="w") as file:
file.write(content)
file.close()

with open(final_path, 'rb') as img:
storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value))
if storage_type == StorageType.S3.value:
S3Helper().upload_file(img, path=final_path)
logger.info(f"{file_name} - File written successfully")
return f"{file_name} - File written successfully"
except Exception as err:
return f"Error write_file: {err}"

def write_csv_file(self, file_name: str, csv_data):
if self.agent_id is not None:
final_path = get_resource_path(file_name)
else:
final_path = get_resource_path(file_name)
try:
with open(final_path, mode="w", newline="") as file:
writer = csv.writer(file, lineterminator="\n")
writer.writerows(csv_data)
with open(final_path, 'rb') as img:
storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value))
if storage_type == StorageType.S3.value:
S3Helper().upload_file(img, path=final_path)
logger.info(f"{file_name} - File written successfully")
return f"{file_name} - File written successfully"
except Exception as err:
return f"Error write_csv_file: {err}"


def read_file(self, file_name: str):
if self.agent_id is not None:
final_path = get_resource_path(file_name)
else:
final_path = get_resource_path(file_name)

try:
with open(final_path, mode="r") as file:
content = file.read()
logger.info(f"{file_name} - File read successfully")
return content
except Exception as err:
return f"Error while reading file {file_name}: {err}"

def get_files(self):
"""
Gets all file names generated by the CodingTool.
Returns:
A list of file names.
"""

if self.agent_id is not None:
final_path = "/assets/output/"
else:
final_path = "/assets/output/"
try:
# List all files in the directory
files = os.listdir(final_path)
except Exception as err:
logger.error(f"Error while accessing files in {final_path}: {err}")
files = []
return files