diff --git a/superagi/tools/base_tool.py b/superagi/tools/base_tool.py index 720b6ed..b5bfc2a 100644 --- a/superagi/tools/base_tool.py +++ b/superagi/tools/base_tool.py @@ -5,11 +5,18 @@ from inspect import signature from typing import List from typing import Optional, Type, Callable, Any, Union, Dict, Tuple - +from enum import Enum import yaml from pydantic import BaseModel, create_model, validate_arguments, Extra from superagi.types.key_type import ToolConfigKeyType +import os +from sqlalchemy.orm import Session +import csv +from superagi.helper.s3_helper import S3Helper +from superagi.lib.logger import logger +from superagi.config.config import get_config +from superagi.types.storage_types import StorageType class SchemaSettings: @@ -258,3 +265,124 @@ def __init__(self, key: str, key_type: str = None, is_required: bool = False, is self.key_type = key_type else: raise ValueError("key_type should be string/file/integer") + +def get_resource_path( file_name: str): + """Get final path of the resource. + + Args: + file_name (str): The name of the file. + """ + root_output_dir = get_root_output_dir() + file_name + return root_output_dir + + +def get_root_output_dir(): + """Get root dir of the resource. + """ + root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') + + if root_dir is not None: + root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir + root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" + else: + root_dir = os.getcwd() + "/" + return root_dir + + +class FileManager: + def __init__(self, session: Session, agent_id: int = None, agent_execution_id: int = None): + self.session = session + self.agent_id = agent_id + self.agent_execution_id = agent_execution_id + + def write_binary_file(self, file_name: str, data): + if self.agent_id is not None: + final_path = get_resource_path(file_name) + else: + final_path = get_resource_path(file_name) + try: + with open(final_path, mode="wb") as img: + img.write(data) + img.close() + with open(final_path, 'rb') as img: + storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) + if storage_type == StorageType.S3.value: + S3Helper().upload_file(img, path=final_path) + logger.info(f"Binary {file_name} saved successfully") + return f"Binary {file_name} saved successfully" + except Exception as err: + return f"Error write_binary_file: {err}" + + def write_file(self, file_name: str, content): + if self.agent_id is not None: + final_path = get_resource_path(file_name) + + else: + final_path = get_resource_path(file_name) + + try: + with open(final_path, mode="w") as file: + file.write(content) + file.close() + + with open(final_path, 'rb') as img: + storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) + if storage_type == StorageType.S3.value: + S3Helper().upload_file(img, path=final_path) + logger.info(f"{file_name} - File written successfully") + return f"{file_name} - File written successfully" + except Exception as err: + return f"Error write_file: {err}" + + def write_csv_file(self, file_name: str, csv_data): + if self.agent_id is not None: + final_path = get_resource_path(file_name) + else: + final_path = get_resource_path(file_name) + try: + with open(final_path, mode="w", newline="") as file: + writer = csv.writer(file, lineterminator="\n") + writer.writerows(csv_data) + with open(final_path, 'rb') as img: + storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) + if storage_type == StorageType.S3.value: + S3Helper().upload_file(img, path=final_path) + logger.info(f"{file_name} - File written successfully") + return f"{file_name} - File written successfully" + except Exception as err: + return f"Error write_csv_file: {err}" + + + def read_file(self, file_name: str): + if self.agent_id is not None: + final_path = get_resource_path(file_name) + else: + final_path = get_resource_path(file_name) + + try: + with open(final_path, mode="r") as file: + content = file.read() + logger.info(f"{file_name} - File read successfully") + return content + except Exception as err: + return f"Error while reading file {file_name}: {err}" + + def get_files(self): + """ + Gets all file names generated by the CodingTool. + Returns: + A list of file names. + """ + + if self.agent_id is not None: + final_path = "/assets/output/" + else: + final_path = "/assets/output/" + try: + # List all files in the directory + files = os.listdir(final_path) + except Exception as err: + logger.error(f"Error while accessing files in {final_path}: {err}") + files = [] + return files + \ No newline at end of file