From 3d5614af4b8b8468aa7436de0a9a6ac76fa09d66 Mon Sep 17 00:00:00 2001 From: Aditya Sharma Date: Mon, 25 Sep 2023 15:53:01 +0530 Subject: [PATCH 1/2] exposing file_manager class --- superagi/tools/base_tool.py | 116 +++++++++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 1 deletion(-) diff --git a/superagi/tools/base_tool.py b/superagi/tools/base_tool.py index 720b6ed..3e36d2a 100644 --- a/superagi/tools/base_tool.py +++ b/superagi/tools/base_tool.py @@ -5,11 +5,18 @@ from inspect import signature from typing import List from typing import Optional, Type, Callable, Any, Union, Dict, Tuple - +from enum import Enum import yaml from pydantic import BaseModel, create_model, validate_arguments, Extra from superagi.types.key_type import ToolConfigKeyType +import os +from sqlalchemy.orm import Session +import csv + +from superagi.helper.s3_helper import S3Helper +from superagi.lib.logger import logger +from superagi.config.config import get_config class SchemaSettings: @@ -258,3 +265,110 @@ def __init__(self, key: str, key_type: str = None, is_required: bool = False, is self.key_type = key_type else: raise ValueError("key_type should be string/file/integer") + +class StorageType(Enum): + FILE = 'FILE' + S3 = 'S3' + + @classmethod + def get_storage_type(cls, store): + if store is None: + raise ValueError("Storage type cannot be None.") + store = store.upper() + if store in cls.__members__: + return cls[store] + raise ValueError(f"{store} is not a valid storage name.") + +class FileManager: + def __init__(self, session: Session, agent_id: int = None, agent_execution_id: int = None): + self.session = session + self.agent_id = agent_id + self.agent_execution_id = agent_execution_id + + def write_binary_file(self, file_name: str, data): + if self.agent_id is not None: + final_path = f"/assets/output/{file_name}" + else: + final_path = f"/assets/output/{file_name}" + try: + with open(final_path, mode="wb") as img: + img.write(data) + img.close() + self.write_to_s3(file_name, final_path) + logger.info(f"Binary {file_name} saved successfully") + return f"Binary {file_name} saved successfully" + except Exception as err: + return f"Error write_binary_file: {err}" + + def write_to_s3(self, file_name, final_path): + with open(f"/assets/output/{file_name}", 'rb') as img: + + storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) + + if storage_type == StorageType.S3.value: + s3_helper = S3Helper() + s3_helper.upload_file(img, path=f"/assets/output/{file_name}") + + def write_file(self, file_name: str, content): + if self.agent_id is not None: + final_path = f"/assets/output/{file_name}" + else: + final_path = f"/assets/output/{file_name}" + try: + with open(final_path, mode="w") as file: + file.write(content) + file.close() + self.write_to_s3(file_name, final_path) + logger.info(f"{file_name} - File written successfully") + return f"{file_name} - File written successfully" + except Exception as err: + return f"Error write_file: {err}" + + def write_csv_file(self, file_name: str, csv_data): + if self.agent_id is not None: + final_path = f"/assets/output/{file_name}" + else: + final_path = f"/assets/output/{file_name}" + try: + with open(final_path, mode="w", newline="") as file: + writer = csv.writer(file, lineterminator="\n") + writer.writerows(csv_data) + self.write_to_s3(file_name, final_path) + logger.info(f"{file_name} - File written successfully") + return f"{file_name} - File written successfully" + except Exception as err: + return f"Error write_csv_file: {err}" + + + def read_file(self, file_name: str): + if self.agent_id is not None: + final_path = f"/assets/output/{file_name}" + else: + final_path = f"/assets/output/{file_name}" + + try: + with open(final_path, mode="r") as file: + content = file.read() + logger.info(f"{file_name} - File read successfully") + return content + except Exception as err: + return f"Error while reading file {file_name}: {err}" + + def get_files(self): + """ + Gets all file names generated by the CodingTool. + Returns: + A list of file names. + """ + + if self.agent_id is not None: + final_path = "/assets/output/" + else: + final_path = "/assets/output/" + try: + # List all files in the directory + files = os.listdir(final_path) + except Exception as err: + logger.error(f"Error while accessing files in {final_path}: {err}") + files = [] + return files \ No newline at end of file From 0e00c927dfa8c3ec5231bd29228a0ab5e54188da Mon Sep 17 00:00:00 2001 From: Aditya Sharma Date: Thu, 28 Sep 2023 11:23:17 +0530 Subject: [PATCH 2/2] adding fileManager --- superagi/tools/base_tool.py | 82 ++++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/superagi/tools/base_tool.py b/superagi/tools/base_tool.py index 3e36d2a..b5bfc2a 100644 --- a/superagi/tools/base_tool.py +++ b/superagi/tools/base_tool.py @@ -13,10 +13,10 @@ import os from sqlalchemy.orm import Session import csv - from superagi.helper.s3_helper import S3Helper from superagi.lib.logger import logger from superagi.config.config import get_config +from superagi.types.storage_types import StorageType class SchemaSettings: @@ -265,19 +265,29 @@ def __init__(self, key: str, key_type: str = None, is_required: bool = False, is self.key_type = key_type else: raise ValueError("key_type should be string/file/integer") - -class StorageType(Enum): - FILE = 'FILE' - S3 = 'S3' + +def get_resource_path( file_name: str): + """Get final path of the resource. + + Args: + file_name (str): The name of the file. + """ + root_output_dir = get_root_output_dir() + file_name + return root_output_dir + + +def get_root_output_dir(): + """Get root dir of the resource. + """ + root_dir = get_config('RESOURCES_OUTPUT_ROOT_DIR') + + if root_dir is not None: + root_dir = root_dir if root_dir.startswith("/") else os.getcwd() + "/" + root_dir + root_dir = root_dir if root_dir.endswith("/") else root_dir + "/" + else: + root_dir = os.getcwd() + "/" + return root_dir - @classmethod - def get_storage_type(cls, store): - if store is None: - raise ValueError("Storage type cannot be None.") - store = store.upper() - if store in cls.__members__: - return cls[store] - raise ValueError(f"{store} is not a valid storage name.") class FileManager: def __init__(self, session: Session, agent_id: int = None, agent_execution_id: int = None): @@ -287,38 +297,38 @@ def __init__(self, session: Session, agent_id: int = None, agent_execution_id: i def write_binary_file(self, file_name: str, data): if self.agent_id is not None: - final_path = f"/assets/output/{file_name}" + final_path = get_resource_path(file_name) else: - final_path = f"/assets/output/{file_name}" + final_path = get_resource_path(file_name) try: with open(final_path, mode="wb") as img: img.write(data) img.close() - self.write_to_s3(file_name, final_path) + with open(final_path, 'rb') as img: + storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) + if storage_type == StorageType.S3.value: + S3Helper().upload_file(img, path=final_path) logger.info(f"Binary {file_name} saved successfully") return f"Binary {file_name} saved successfully" except Exception as err: return f"Error write_binary_file: {err}" - - def write_to_s3(self, file_name, final_path): - with open(f"/assets/output/{file_name}", 'rb') as img: - - storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) - - if storage_type == StorageType.S3.value: - s3_helper = S3Helper() - s3_helper.upload_file(img, path=f"/assets/output/{file_name}") def write_file(self, file_name: str, content): if self.agent_id is not None: - final_path = f"/assets/output/{file_name}" + final_path = get_resource_path(file_name) + else: - final_path = f"/assets/output/{file_name}" + final_path = get_resource_path(file_name) + try: with open(final_path, mode="w") as file: file.write(content) file.close() - self.write_to_s3(file_name, final_path) + + with open(final_path, 'rb') as img: + storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) + if storage_type == StorageType.S3.value: + S3Helper().upload_file(img, path=final_path) logger.info(f"{file_name} - File written successfully") return f"{file_name} - File written successfully" except Exception as err: @@ -326,14 +336,17 @@ def write_file(self, file_name: str, content): def write_csv_file(self, file_name: str, csv_data): if self.agent_id is not None: - final_path = f"/assets/output/{file_name}" + final_path = get_resource_path(file_name) else: - final_path = f"/assets/output/{file_name}" + final_path = get_resource_path(file_name) try: with open(final_path, mode="w", newline="") as file: writer = csv.writer(file, lineterminator="\n") writer.writerows(csv_data) - self.write_to_s3(file_name, final_path) + with open(final_path, 'rb') as img: + storage_type = StorageType.get_storage_type(get_config("STORAGE_TYPE", StorageType.FILE.value)) + if storage_type == StorageType.S3.value: + S3Helper().upload_file(img, path=final_path) logger.info(f"{file_name} - File written successfully") return f"{file_name} - File written successfully" except Exception as err: @@ -342,9 +355,9 @@ def write_csv_file(self, file_name: str, csv_data): def read_file(self, file_name: str): if self.agent_id is not None: - final_path = f"/assets/output/{file_name}" + final_path = get_resource_path(file_name) else: - final_path = f"/assets/output/{file_name}" + final_path = get_resource_path(file_name) try: with open(final_path, mode="r") as file: @@ -371,4 +384,5 @@ def get_files(self): except Exception as err: logger.error(f"Error while accessing files in {final_path}: {err}") files = [] - return files \ No newline at end of file + return files + \ No newline at end of file