diff --git a/backend/app/component/environment.py b/backend/app/component/environment.py index 15ee6a5c1..589290f36 100644 --- a/backend/app/component/environment.py +++ b/backend/app/component/environment.py @@ -7,37 +7,67 @@ import importlib from typing import Any, overload import threading +import re traceroot_logger = traceroot.get_logger("env") # Thread-local storage for user-specific environment _thread_local = threading.local() +BASE_ENV_DIR = Path.home() / ".eigent" + # Default global environment path -default_env_path = os.path.join(os.path.expanduser("~"), ".eigent", ".env") +default_env_path = str(BASE_ENV_DIR / ".env") load_dotenv(dotenv_path=default_env_path) -def set_user_env_path(env_path: str | None = None): +def _sanitize_env_path(env_path: str | None) -> str | None: + """ + Limit env files to .eigent directory and simple .env-like filenames. + Prevents directory traversal and arbitrary file reads. + """ + if not env_path: + return None + + # Only allow a filename (no path separators) + filename = os.path.basename(env_path) + if filename != env_path: + raise ValueError("env_path must not contain directories") + + if not re.fullmatch(r"\.env(\.[A-Za-z0-9_-]+)?", filename): + raise ValueError("env_path filename is not allowed") + + resolved = (BASE_ENV_DIR / filename).resolve() + return str(resolved) + + +def set_user_env_path(env_path: str | None = None) -> str: """ Set user-specific environment path for current thread. If env_path is None, uses default global environment. """ - traceroot_logger.info("Setting user environment path", extra={"env_path": env_path, "exists": env_path and os.path.exists(env_path) if env_path else None}) + try: + sanitized = _sanitize_env_path(env_path) + except ValueError as e: + traceroot_logger.warning("Rejecting unsafe env_path", extra={"env_path": env_path, "error": str(e)}) + sanitized = None + + traceroot_logger.info("Setting user environment path", extra={"env_path": sanitized, "exists": sanitized and os.path.exists(sanitized) if sanitized else None}) - if env_path and os.path.exists(env_path): - _thread_local.env_path = env_path + if sanitized and os.path.exists(sanitized): + _thread_local.env_path = sanitized # Load user-specific environment variables - load_dotenv(dotenv_path=env_path, override=True) - traceroot_logger.info("User-specific environment loaded", extra={"env_path": env_path}) + load_dotenv(dotenv_path=sanitized, override=True) + traceroot_logger.info("User-specific environment loaded", extra={"env_path": sanitized}) else: # Clear thread-local env_path to fall back to global if hasattr(_thread_local, 'env_path'): delattr(_thread_local, 'env_path') traceroot_logger.info("Reset to default global environment") - if env_path and not os.path.exists(env_path): - traceroot_logger.warning("User environment path does not exist, falling back to global", extra={"env_path": env_path}) + if sanitized and not os.path.exists(sanitized): + traceroot_logger.warning("User environment path does not exist, falling back to global", extra={"env_path": sanitized}) + return get_current_env_path() def get_current_env_path() -> str: diff --git a/backend/app/controller/chat_controller.py b/backend/app/controller/chat_controller.py index b7fc78262..586567ffd 100644 --- a/backend/app/controller/chat_controller.py +++ b/backend/app/controller/chat_controller.py @@ -81,8 +81,8 @@ async def post(data: Chat, request: Request): task_lock = get_or_create_task_lock(data.project_id) # Set user-specific environment path for this thread - set_user_env_path(data.env_path) - load_dotenv(dotenv_path=data.env_path) + env_path = set_user_env_path(data.env_path) + load_dotenv(dotenv_path=env_path) os.environ["file_save_path"] = data.file_save_path() os.environ["browser_port"] = str(data.browser_port) diff --git a/backend/app/controller/task_controller.py b/backend/app/controller/task_controller.py index 2bf3fc7f6..e2fec8cff 100644 --- a/backend/app/controller/task_controller.py +++ b/backend/app/controller/task_controller.py @@ -64,8 +64,8 @@ def add_agent(id: str, data: NewAgent): logger.info("Adding new agent to task", extra={"task_id": id, "agent_name": data.name}) logger.debug("New agent data", extra={"task_id": id, "agent_data": data.model_dump_json()}) # Set user-specific environment path for this thread - set_user_env_path(data.env_path) - load_dotenv(dotenv_path=data.env_path) + env_path = set_user_env_path(data.env_path) + load_dotenv(dotenv_path=env_path) asyncio.run(get_task_lock(id).put_queue(ActionNewAgent(**data.model_dump()))) logger.info("Agent added to task", extra={"task_id": id, "agent_name": data.name}) return Response(status_code=204) diff --git a/backend/app/service/chat_service.py b/backend/app/service/chat_service.py index 0ac7969d2..aca27b893 100644 --- a/backend/app/service/chat_service.py +++ b/backend/app/service/chat_service.py @@ -48,6 +48,40 @@ logger = traceroot.get_logger("chat_service") +def _safe_working_directory(path_value: str | None) -> str | None: + if not path_value: + return None + + base_dir = os.path.join(os.path.expanduser("~"), "eigent") + normalized = path_value.replace("\\", "/") + + if os.path.isabs(normalized): + try: + rel_path = os.path.relpath(normalized, base_dir) + except ValueError: + logger.warning(f"Rejected working directory outside base: {path_value}") + return None + else: + rel_path = normalized + + if rel_path.startswith("..") or rel_path.startswith("../"): + logger.warning(f"Rejected working directory outside base: {path_value}") + return None + + segments = [segment for segment in rel_path.split("/") if segment and segment != "."] + if not segments: + logger.warning(f"Invalid working directory path: {path_value}") + return None + + for segment in segments: + if not segment.isascii() or not segment.replace("-", "").replace("_", "").replace(".", "").isalnum(): + logger.warning(f"Invalid working directory segment: {segment}") + return None + + safe_path = os.path.join(base_dir, *segments) + return safe_path + + def format_task_context(task_data: dict, seen_files: set | None = None, skip_files: bool = False) -> str: """Format structured task data into a readable context string. @@ -66,7 +100,7 @@ def format_task_context(task_data: dict, seen_files: set | None = None, skip_fil # Skip file listing if requested if not skip_files: - working_directory = task_data.get('working_directory') + working_directory = _safe_working_directory(task_data.get('working_directory')) if working_directory: try: if os.path.exists(working_directory): @@ -126,10 +160,11 @@ def collect_previous_task_context(working_directory: str, previous_task_content: context_parts.append(f"Previous Task Result:\n{previous_task_result}\n") # Collect generated files from working directory + safe_working_directory = _safe_working_directory(working_directory) try: - if os.path.exists(working_directory): + if safe_working_directory and os.path.exists(safe_working_directory): generated_files = [] - for root, dirs, files in os.walk(working_directory): + for root, dirs, files in os.walk(safe_working_directory): dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ['node_modules', '__pycache__', 'venv']] for file in files: if not file.startswith('.') and not file.endswith(('.pyc', '.tmp')): @@ -203,6 +238,9 @@ def build_conversation_context(task_lock: TaskLock, header: str = "=== CONVERSAT if working_directories: all_generated_files = set() # Use set to avoid duplicates for working_directory in working_directories: + working_directory = _safe_working_directory(working_directory) + if not working_directory: + continue try: if os.path.exists(working_directory): for root, dirs, files in os.walk(working_directory): diff --git a/electron/main/index.ts b/electron/main/index.ts index fa62d7fa3..971232d62 100644 --- a/electron/main/index.ts +++ b/electron/main/index.ts @@ -1148,7 +1148,7 @@ async function createWindow() { // Use a dedicated partition for main window to isolate from webviews // This ensures main window's auth data (localStorage) is stored separately and persists across restarts partition: 'persist:main_window', - webSecurity: false, + webSecurity: true, preload, nodeIntegration: true, contextIsolation: true, diff --git a/resources/scripts/install-bun.js b/resources/scripts/install-bun.js index 735d69bf0..888cc4464 100644 --- a/resources/scripts/install-bun.js +++ b/resources/scripts/install-bun.js @@ -148,7 +148,7 @@ function detectPlatformAndArch() { function detectIsMusl() { try { // Simple check for Alpine Linux which uses MUSL - const output = execSync('cat /etc/os-release').toString() + const output = fs.readFileSync('/etc/os-release', 'utf8') return output.toLowerCase().includes('alpine') } catch (error) { return false diff --git a/resources/scripts/install-uv.js b/resources/scripts/install-uv.js index b35889085..882724b2a 100644 --- a/resources/scripts/install-uv.js +++ b/resources/scripts/install-uv.js @@ -160,7 +160,7 @@ function detectPlatformAndArch() { function detectIsMusl() { try { // Simple check for Alpine Linux which uses MUSL - const output = execSync("cat /etc/os-release").toString(); + const output = fs.readFileSync("/etc/os-release", "utf8"); return output.toLowerCase().includes("alpine"); } catch (error) { return false; diff --git a/server/app/controller/mcp/proxy_controller.py b/server/app/controller/mcp/proxy_controller.py index 0ec1a0cfd..b21c69344 100644 --- a/server/app/controller/mcp/proxy_controller.py +++ b/server/app/controller/mcp/proxy_controller.py @@ -186,7 +186,7 @@ def google_search(query: str, search_type: str = "web", key: Key = Depends(key_m logger.info("Google search completed", extra={"query": query, "search_type": search_type, "result_count": len(responses)}) else: error_info = data.get("error", {}) - logger.error("Google search API error", extra={"query": query, "api_error": error_info}) + logger.error("Google search API error", extra={"query": query}) raise HTTPException(status_code=500, detail="Internal server error") except Exception as e: diff --git a/server/app/controller/oauth/oauth_controller.py b/server/app/controller/oauth/oauth_controller.py index c43e50973..62fffad6b 100644 --- a/server/app/controller/oauth/oauth_controller.py +++ b/server/app/controller/oauth/oauth_controller.py @@ -35,32 +35,36 @@ def oauth_login(app: str, request: Request, state: Optional[str] = None): raise HTTPException(status_code=400, detail="OAuth login failed") +ALLOWED_OAUTH_PROVIDERS = {"slack", "notion", "x", "googlesuite"} + @router.get("/{app}/callback", name="OAuth Callback") @traceroot.trace() def oauth_callback(app: str, request: Request, code: Optional[str] = None, state: Optional[str] = None): """Handle OAuth provider callback and redirect to client app.""" - if not code: - logger.warning("OAuth callback missing code", extra={"provider": app}) - raise HTTPException(status_code=400, detail="Missing code parameter") - + import re + CODE_STATE_REGEX = re.compile(r'^[A-Za-z0-9_\-]+$') + from starlette.datastructures import URL + + if app not in ALLOWED_OAUTH_PROVIDERS: + logger.warning("Invalid OAuth provider", extra={"provider": app, "code": code}) + raise HTTPException(status_code=400, detail="Invalid OAuth provider") + if not code or not CODE_STATE_REGEX.match(code): + logger.warning("OAuth callback missing or invalid code", extra={"provider": app, "code": code}) + raise HTTPException(status_code=400, detail="Missing or invalid code parameter") + if state and not CODE_STATE_REGEX.match(state): + logger.warning("OAuth callback invalid state", extra={"provider": app, "state": state}) + raise HTTPException(status_code=400, detail="Invalid state parameter") + logger.info("OAuth callback received", extra={"provider": app, "has_state": state is not None}) - - redirect_url = f"eigent://callback/oauth?provider={app}&code={code}&state={state}" - html_content = f""" - - - OAuth Callback - - - -

Redirecting, please wait...

- - - - """ - return HTMLResponse(content=html_content) + + base_url = URL("eigent://callback/oauth") + redirect_url = base_url.include_query_params( + provider=app, + code=code, + state=state or "", + ) + + return RedirectResponse(str(redirect_url)) @router.post("/{app}/token", name="OAuth Fetch Token") diff --git a/server/app/controller/redirect_controller.py b/server/app/controller/redirect_controller.py index 3695a8fb4..53056fcae 100644 --- a/server/app/controller/redirect_controller.py +++ b/server/app/controller/redirect_controller.py @@ -1,72 +1,21 @@ -import json -from fastapi import APIRouter, Depends, Request -from fastapi_babel import _ -from fastapi.responses import HTMLResponse - +import re +from fastapi import APIRouter, Request,HTTPException +from fastapi.responses import RedirectResponse +from utils import traceroot_wrapper as traceroot +logger = traceroot.get_logger("server_redirect_controller") router = APIRouter(tags=["Redirect"]) @router.get("/redirect/callback") def redirect_callback(code: str, request: Request): - cookies = request.cookies - cookies_json = json.dumps(cookies) + from starlette.datastructures import URL + + if not re.match(r'^[A-Za-z0-9_-]+$', code): + logger.warning("redirect callback invalid code", extra={"code": code}) + raise HTTPException(status_code=400, detail="Invalid state parameter") + + base_url = URL("eigent://callback") + redirect_url = base_url.include_query_params(code=code) + return RedirectResponse(str(redirect_url)) - html_content = f""" - - - - - - Authorization successful - - - -
-

Authorization Successful

-

Redirecting to application...

-
Please wait...
-
- - - - """ - return HTMLResponse(content=html_content) diff --git a/server/app/model/chat/chat_snpshot.py b/server/app/model/chat/chat_snpshot.py index a1cb3a98a..a429130c4 100644 --- a/server/app/model/chat/chat_snpshot.py +++ b/server/app/model/chat/chat_snpshot.py @@ -1,56 +1,74 @@ -from typing import Optional -from sqlalchemy import Column, Integer, text -from sqlmodel import Field -from app.model.abstract.model import AbstractModel, DefaultTimes -from pydantic import BaseModel -import os -import base64 -import time - -from app.component.sqids import encode_user_id - - -class ChatSnapshot(AbstractModel, DefaultTimes, table=True): - id: int = Field(default=None, primary_key=True) - user_id: int = Field(sa_column=(Column(Integer, server_default=text("0")))) - api_task_id: str = Field(index=True) - camel_task_id: str = Field(index=True) - browser_url: str - image_path: str - - @classmethod - def get_user_dir(cls, user_id: int) -> str: - return os.path.join("app", "public", "upload", encode_user_id(user_id)) - - @classmethod - def caclDir(cls, path: str) -> float: - """Return disk usage of path directory (in MB, rounded to 2 decimal places)""" - total_size = 0 - for dirpath, dirnames, filenames in os.walk(path): - for f in filenames: - fp = os.path.join(dirpath, f) - if os.path.isfile(fp): - total_size += os.path.getsize(fp) - size_mb = total_size / (1024 * 1024) - return round(size_mb, 2) - - -class ChatSnapshotIn(BaseModel): - api_task_id: str - user_id: Optional[int] = None - camel_task_id: str - browser_url: str - image_base64: str - - @staticmethod - def save_image(user_id: int, api_task_id: str, image_base64: str) -> str: - if "," in image_base64: - image_base64 = image_base64.split(",", 1)[1] - user_dir = encode_user_id(user_id) - folder = os.path.join("app", "public", "upload", user_dir, api_task_id) - os.makedirs(folder, exist_ok=True) - filename = f"{int(time.time() * 1000)}.jpg" - file_path = os.path.join(folder, filename) - with open(file_path, "wb") as f: - f.write(base64.b64decode(image_base64)) - return f"/public/upload/{user_dir}/{api_task_id}/{filename}" +from typing import Optional +from sqlalchemy import Column, Integer, text +from sqlmodel import Field +from app.model.abstract.model import AbstractModel, DefaultTimes +from pydantic import BaseModel +import os +import base64 +import time +import re +from pathlib import Path +from uuid import uuid4 + +from app.component.sqids import encode_user_id + + +class ChatSnapshot(AbstractModel, DefaultTimes, table=True): + id: int = Field(default=None, primary_key=True) + user_id: int = Field(sa_column=(Column(Integer, server_default=text("0")))) + api_task_id: str = Field(index=True) + camel_task_id: str = Field(index=True) + browser_url: str + image_path: str + + @classmethod + def get_user_dir(cls, user_id: int) -> str: + return os.path.join("app", "public", "upload", encode_user_id(user_id)) + + @classmethod + def caclDir(cls, path: str) -> float: + """Return disk usage of path directory (in MB, rounded to 2 decimal places)""" + total_size = 0 + for dirpath, dirnames, filenames in os.walk(path): + for f in filenames: + fp = os.path.join(dirpath, f) + if os.path.isfile(fp): + total_size += os.path.getsize(fp) + size_mb = total_size / (1024 * 1024) + return round(size_mb, 2) + + +class ChatSnapshotIn(BaseModel): + api_task_id: str + user_id: Optional[int] = None + camel_task_id: str + browser_url: str + image_base64: str + + @staticmethod + def save_image(user_id: int, api_task_id: str, image_base64: str) -> str: + if "," in image_base64: + image_base64 = image_base64.split(",", 1)[1] + + user_dir = encode_user_id(user_id) + if os.path.basename(user_dir) != user_dir or not re.fullmatch(r"[A-Za-z0-9._-]{1,128}", user_dir or ""): + raise ValueError("Invalid user_id") + + base_dir = os.path.abspath(os.path.join("app", "public", "upload")) + + # Keep api_task_id as part of the path but ensure it cannot traverse + safe_api_task_id = os.path.basename(api_task_id) + if safe_api_task_id != api_task_id or not re.fullmatch(r"[A-Za-z0-9._-]{1,128}", safe_api_task_id or ""): + raise ValueError("Invalid api_task_id") + + folder = os.path.normpath(os.path.join(base_dir, user_dir, safe_api_task_id)) + # Directory traversal guard: ensure final path stays under base_dir + if not folder.startswith(base_dir + os.sep): + raise ValueError("Unsafe upload path detected") + + Path(folder).mkdir(parents=True, exist_ok=True) + filename = f"{int(time.time() * 1000)}_{uuid4().hex}.jpg" + file_path = os.path.join(folder, filename) + with open(file_path, "wb") as f: + f.write(base64.b64decode(image_base64)) + return f"/public/upload/{user_dir}/{safe_api_task_id}/{filename}" diff --git a/src/lib/oauth.ts b/src/lib/oauth.ts index 78f61d1bb..e5de6d87a 100644 --- a/src/lib/oauth.ts +++ b/src/lib/oauth.ts @@ -211,8 +211,17 @@ export class OAuth { async random(size: number) { const mask = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-._~"; - const randomUints = crypto.getRandomValues(new Uint8Array(size)); - return Array.from(randomUints).map(i => mask[i % mask.length]).join(''); + const maskLength = mask.length; + const out: string[] = []; + const maxUnbiased = 256 - (256 % maskLength); // rejection sampling to avoid modulo bias + + while (out.length < size) { + const byte = crypto.getRandomValues(new Uint8Array(1))[0]; + if (byte >= maxUnbiased) continue; + out.push(mask[byte % maskLength]); + } + + return out.join(''); } }