diff --git a/.codacy.yaml b/.codacy.yaml new file mode 100644 index 0000000..866ba05 --- /dev/null +++ b/.codacy.yaml @@ -0,0 +1,7 @@ +--- +exclude_paths: + - 'docs/**' + - 'docs/requirements.txt' + - '**/source.zh-TW/conf.py' + - '**/source/conf.py' + - 'examples/**' diff --git a/README.md b/README.md index edc387b..84301ee 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,12 @@ facade. - **HTTPActionClient SDK** — typed Python client for the HTTP action server with shared-secret auth, loopback guard, and OPTIONS-based ping - **AES-256-GCM file encryption** — `encrypt_file` / `decrypt_file` with `generate_key()` / `key_from_password()` (PBKDF2-HMAC-SHA256); JSON actions `FA_encrypt_file` / `FA_decrypt_file` - **Prometheus metrics exporter** — `start_metrics_server()` exposes `automation_file_actions_total{action,status}` counters and `automation_file_action_duration_seconds{action}` histograms +- **WebDAV backend** — `WebDAVClient` with `exists` / `upload` / `download` / `delete` / `mkcol` / `list_dir` on any RFC 4918 server; rejects private / loopback targets unless `allow_private_hosts=True` +- **SMB / CIFS backend** — `SMBClient` over `smbprotocol`'s high-level `smbclient` API; UNC-based, encrypted sessions by default +- **fsspec bridge** — drive any `fsspec`-backed filesystem (memory, local, s3, gcs, abfs, …) through the action registry with `get_fs` / `fsspec_upload` / `fsspec_download` / `fsspec_list_dir` etc. 
+- **HTTP server observability** — `GET /healthz` / `GET /readyz` probes, `GET /openapi.json` spec, and `GET /progress` WebSocket stream of live transfer snapshots +- **HTMX Web UI** — `start_web_ui()` serves a read-only dashboard (health, progress, registry) that polls HTML fragments; stdlib-only HTTP plus one CDN script with SRI +- **MCP (Model Context Protocol) server** — `MCPServer` bridges the registry to any MCP host (Claude Desktop, MCP CLIs) over newline-delimited JSON-RPC 2.0 on stdio; every `FA_*` action becomes an MCP tool with an auto-generated input schema - PySide6 GUI (`python -m automation_file ui`) with a tab per backend, the JSON-action runner, and dedicated tabs for Triggers, Scheduler, and live Progress - Rich CLI with one-shot subcommands plus legacy JSON-batch flags - Project scaffolding (`ProjectBuilder`) for executor-based automations @@ -624,6 +630,77 @@ Exports `automation_file_actions_total{action,status}` and `automation_file_action_duration_seconds{action}`. Non-loopback binds require `allow_non_loopback=True` explicitly. +### WebDAV, SMB/CIFS, fsspec +Extra remote backends alongside the first-class S3 / Azure / Dropbox / SFTP: + +```python +from automation_file import WebDAVClient, SMBClient, fsspec_upload + +# RFC 4918 WebDAV — loopback/private targets require opt-in. +dav = WebDAVClient("https://files.example.com/remote.php/dav", + username="alice", password="s3cr3t") +dav.upload("/local/report.csv", "team/reports/report.csv") + +# SMB / CIFS via smbprotocol's high-level smbclient API. 
+with SMBClient("fileserver", "share", "alice", "s3cr3t") as smb: + smb.upload("/local/report.csv", "reports/report.csv") + +# Anything fsspec can address — memory, gcs, abfs, local, … +fsspec_upload("/local/report.csv", "memory://reports/report.csv") +``` + +### HTTP server observability +`start_http_action_server()` additionally exposes liveness / readiness probes, +an OpenAPI 3.0 spec, and a WebSocket stream of progress snapshots: + +```bash +curl http://127.0.0.1:9944/healthz # {"status": "ok"} +curl http://127.0.0.1:9944/readyz # 200 when registry non-empty, 503 otherwise +curl http://127.0.0.1:9944/openapi.json # OpenAPI 3.0 spec +# Connect a WebSocket to ws://127.0.0.1:9944/progress for live progress frames. +``` + +### HTMX Web UI +A read-only observability dashboard built on stdlib HTTP + HTMX (loaded from +a pinned CDN URL with SRI). Loopback-only by default; optional shared secret: + +```python +from automation_file import start_web_ui + +server = start_web_ui(host="127.0.0.1", port=9955, shared_secret="s3cr3t") +# Browse http://127.0.0.1:9955/ — health, progress, and registry fragments +# auto-poll every few seconds. Write operations stay on the action servers. +``` + +### MCP (Model Context Protocol) server +Expose every registered `FA_*` action to an MCP host (Claude Desktop, MCP +CLIs) over JSON-RPC 2.0 on stdio: + +```python +from automation_file import MCPServer + +MCPServer().serve_stdio() # reads JSON-RPC from stdin, writes to stdout +``` + +`pip install` exposes an `automation_file_mcp` console script (via +`[project.scripts]`) so MCP hosts can launch the bridge without any Python +glue. 
Three equivalent launch styles: + +```bash +automation_file_mcp # installed console script +python -m automation_file mcp # CLI subcommand +python examples/mcp/run_mcp.py # standalone launcher +``` + +All three accept `--name`, `--version`, and `--allowed-actions` +(comma-separated whitelist — strongly recommended since the default registry +includes high-privilege actions like `FA_run_shell`). See +[`examples/mcp/`](examples/mcp) for ready-to-copy Claude Desktop config. + +Tool descriptors are generated on the fly by introspecting each action's +signature — parameter names and types become a JSON schema, so hosts can +render fields without any manual wiring. + ### DAG action executor Run actions in dependency order; independent branches fan out across a thread pool. Each node is `{"id": ..., "action": [...], "depends_on": @@ -709,6 +786,8 @@ python -m automation_file create-file hello.txt --content "hi" python -m automation_file server --host 127.0.0.1 --port 9943 python -m automation_file http-server --host 127.0.0.1 --port 9944 python -m automation_file drive-upload my.txt --token token.json --credentials creds.json +python -m automation_file mcp --allowed-actions FA_file_checksum,FA_fast_find +automation_file_mcp --allowed-actions FA_file_checksum,FA_fast_find # installed console script # Legacy flags (JSON action lists) python -m automation_file --execute_file actions.json diff --git a/README.zh-CN.md b/README.zh-CN.md index c0ae731..c45eeba 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -40,6 +40,12 @@ TCP / HTTP 服务器执行的 JSON 驱动动作。内附 PySide6 GUI,每个功 - **HTTPActionClient SDK** — HTTP 动作服务器的类型化 Python 客户端,具 shared-secret 认证、loopback 守护与 OPTIONS ping - **AES-256-GCM 文件加密** — `encrypt_file` / `decrypt_file` 搭配 `generate_key()` / `key_from_password()`(PBKDF2-HMAC-SHA256);JSON 动作 `FA_encrypt_file` / `FA_decrypt_file` - **Prometheus metrics 导出器** — `start_metrics_server()` 提供 `automation_file_actions_total{action,status}` 计数器与
`automation_file_action_duration_seconds{action}` 直方图 +- **WebDAV 后端** — `WebDAVClient` 提供 `exists` / `upload` / `download` / `delete` / `mkcol` / `list_dir`,适用于任何 RFC 4918 服务器;除非显式传入 `allow_private_hosts=True`,否则拒绝私有 / loopback 目标 +- **SMB / CIFS 后端** — `SMBClient` 基于 `smbprotocol` 的高阶 `smbclient` API;采用 UNC 路径,默认启用加密会话 +- **fsspec 桥接** — 通过 `get_fs` / `fsspec_upload` / `fsspec_download` / `fsspec_list_dir` 等函数,驱动任何 `fsspec` 支持的文件系统(memory、local、s3、gcs、abfs、…) +- **HTTP 服务器观测端点** — `GET /healthz` / `GET /readyz` 探针、`GET /openapi.json` 规格,以及 `GET /progress`(通过 WebSocket 推送实时传输快照) +- **HTMX Web UI** — `start_web_ui()` 启动只读观测仪表板(health、progress、registry),通过 HTML 片段轮询;仅用标准库 HTTP,搭配一个带 SRI 的 CDN 脚本 +- **MCP(Model Context Protocol)服务器** — `MCPServer` 通过 stdio 上的 JSON-RPC 2.0(换行分隔 JSON)将注册表桥接到任意 MCP 主机(Claude Desktop、MCP CLI);每个 `FA_*` 动作都会自动生成输入 schema 并成为 MCP 工具 - PySide6 GUI(`python -m automation_file ui`)每个后端一个页签,含 JSON 动作执行器,另有 Triggers、Scheduler、实时 Progress 专属页签 - 功能丰富的 CLI,包含一次性子命令与旧式 JSON 批量标志 - 项目脚手架(`ProjectBuilder`)协助构建以 executor 为核心的自动化项目 @@ -611,6 +617,74 @@ server = start_metrics_server(host="127.0.0.1", port=9945) `automation_file_action_duration_seconds{action}`。若要绑定非 loopback 地址必须显式传入 `allow_non_loopback=True`。 +### WebDAV、SMB/CIFS、fsspec +在一等公民的 S3 / Azure / Dropbox / SFTP 之外,额外的远程后端: + +```python +from automation_file import WebDAVClient, SMBClient, fsspec_upload + +# RFC 4918 WebDAV —— loopback / 私有目标需要显式开关。 +dav = WebDAVClient("https://files.example.com/remote.php/dav", + username="alice", password="s3cr3t") +dav.upload("/local/report.csv", "team/reports/report.csv") + +# 通过 smbprotocol 的高阶 smbclient API 操作 SMB / CIFS。 +with SMBClient("fileserver", "share", "alice", "s3cr3t") as smb: + smb.upload("/local/report.csv", "reports/report.csv") + +# 任何 fsspec 能寻址的目标 —— memory、gcs、abfs、local、… +fsspec_upload("/local/report.csv", "memory://reports/report.csv") +``` + +### HTTP 服务器观测端点 +`start_http_action_server()` 额外提供 liveness / readiness 探针、OpenAPI 3.0 
+规格,以及实时进度快照的 WebSocket 流: + +```bash +curl http://127.0.0.1:9944/healthz # {"status": "ok"} +curl http://127.0.0.1:9944/readyz # 注册表非空时 200,否则 503 +curl http://127.0.0.1:9944/openapi.json # OpenAPI 3.0 规格 +# 使用 WebSocket 连接 ws://127.0.0.1:9944/progress 获取实时进度帧。 +``` + +### HTMX Web UI +基于标准库 HTTP + HTMX(以带 SRI 的固定 CDN URL 加载)构建的只读观测仪表板。 +默认仅允许 loopback,可选 shared-secret: + +```python +from automation_file import start_web_ui + +server = start_web_ui(host="127.0.0.1", port=9955, shared_secret="s3cr3t") +# 浏览 http://127.0.0.1:9955/ —— health、progress、registry 片段每几秒 +# 自动轮询一次;写入操作仍然保留在动作服务器。 +``` + +### MCP(Model Context Protocol)服务器 +通过 stdio 上的 JSON-RPC 2.0 把每个已注册的 `FA_*` 动作暴露给 MCP 主机 +(Claude Desktop、MCP CLI): + +```python +from automation_file import MCPServer + +MCPServer().serve_stdio() # 从 stdin 读取 JSON-RPC,写入 stdout +``` + +`pip install` 后,`[project.scripts]` 会提供 `automation_file_mcp` console +script,MCP 主机无需编写 Python glue 即可启动桥接器。三种等价的启动方式: + +```bash +automation_file_mcp # 已安装的 console script +python -m automation_file mcp # CLI 子命令 +python examples/mcp/run_mcp.py # 独立启动脚本 +``` + +三者都支持 `--name`、`--version`、`--allowed-actions`(逗号分隔白名单—— +强烈建议使用,因为默认注册表包含 `FA_run_shell` 等高权限动作)。可直接复制的 +Claude Desktop 示例配置请见 [`examples/mcp/`](examples/mcp)。 + +工具描述符在运行时由动作签名自动生成——参数名称与类型会转换为 JSON schema, +主机无需任何手动配置即可渲染字段。 + ### DAG 动作执行器 按依赖顺序执行动作;独立分支通过线程池并行展开。每个节点的形式为 `{"id": ..., "action": [...], "depends_on": [...]}`: @@ -693,6 +767,8 @@ python -m automation_file create-file hello.txt --content "hi" python -m automation_file server --host 127.0.0.1 --port 9943 python -m automation_file http-server --host 127.0.0.1 --port 9944 python -m automation_file drive-upload my.txt --token token.json --credentials creds.json +python -m automation_file mcp --allowed-actions FA_file_checksum,FA_fast_find +automation_file_mcp --allowed-actions FA_file_checksum,FA_fast_find # 已安装的 console script # 旧式标志(JSON 动作清单) python -m automation_file --execute_file actions.json diff --git 
a/README.zh-TW.md b/README.zh-TW.md index ee46f63..8d2a086 100644 --- a/README.zh-TW.md +++ b/README.zh-TW.md @@ -40,6 +40,12 @@ TCP / HTTP 伺服器執行的 JSON 驅動動作。內附 PySide6 GUI,每個功 - **HTTPActionClient SDK** — HTTP 動作伺服器的型別化 Python 客戶端,具 shared-secret 驗證、loopback 防護與 OPTIONS ping - **AES-256-GCM 檔案加密** — `encrypt_file` / `decrypt_file` 搭配 `generate_key()` / `key_from_password()`(PBKDF2-HMAC-SHA256);JSON 動作 `FA_encrypt_file` / `FA_decrypt_file` - **Prometheus metrics 匯出器** — `start_metrics_server()` 提供 `automation_file_actions_total{action,status}` 計數器與 `automation_file_action_duration_seconds{action}` 直方圖 +- **WebDAV 後端** — `WebDAVClient` 提供 `exists` / `upload` / `download` / `delete` / `mkcol` / `list_dir`,適用於任何 RFC 4918 伺服器;除非顯式傳入 `allow_private_hosts=True`,否則拒絕私有 / loopback 目標 +- **SMB / CIFS 後端** — `SMBClient` 建構於 `smbprotocol` 的高階 `smbclient` API;採用 UNC 路徑,預設啟用加密連線 +- **fsspec 橋接** — 透過 `get_fs` / `fsspec_upload` / `fsspec_download` / `fsspec_list_dir` 等函式,驅動任何 `fsspec` 支援的檔案系統(memory、local、s3、gcs、abfs、…) +- **HTTP 伺服器觀測端點** — `GET /healthz` / `GET /readyz` 探針、`GET /openapi.json` 規格、以及 `GET /progress`(以 WebSocket 推送即時傳輸快照) +- **HTMX Web UI** — `start_web_ui()` 啟動唯讀觀測儀表板(health、progress、registry),以 HTML 片段輪詢;僅用標準函式庫 HTTP,搭配一支帶 SRI 的 CDN 腳本 +- **MCP(Model Context Protocol)伺服器** — `MCPServer` 透過 stdio 上的 JSON-RPC 2.0(行分隔 JSON)將登錄表橋接到任何 MCP 主機(Claude Desktop、MCP CLI);每個 `FA_*` 動作都會自動生成輸入 schema 並成為 MCP 工具 - PySide6 GUI(`python -m automation_file ui`)每個後端一個分頁,含 JSON 動作執行器,另有 Triggers、Scheduler、即時 Progress 專屬分頁 - 功能豐富的 CLI,包含一次性子指令與舊式 JSON 批次旗標 - 專案鷹架(`ProjectBuilder`)協助建立以 executor 為核心的自動化專案 @@ -611,6 +617,74 @@ server = start_metrics_server(host="127.0.0.1", port=9945) `automation_file_action_duration_seconds{action}`。若要綁定非 loopback 位址必須明確傳入 `allow_non_loopback=True`。 +### WebDAV、SMB/CIFS、fsspec +在一等公民的 S3 / Azure / Dropbox / SFTP 之外,另有額外的遠端後端: + +```python +from automation_file import WebDAVClient, SMBClient, fsspec_upload + +# RFC 4918 WebDAV —— loopback / 
私有目標需要顯式開關。 +dav = WebDAVClient("https://files.example.com/remote.php/dav", + username="alice", password="s3cr3t") +dav.upload("/local/report.csv", "team/reports/report.csv") + +# 透過 smbprotocol 的高階 smbclient API 操作 SMB / CIFS。 +with SMBClient("fileserver", "share", "alice", "s3cr3t") as smb: + smb.upload("/local/report.csv", "reports/report.csv") + +# 任何 fsspec 能定址的目標 —— memory、gcs、abfs、local、… +fsspec_upload("/local/report.csv", "memory://reports/report.csv") +``` + +### HTTP 伺服器觀測端點 +`start_http_action_server()` 額外提供 liveness / readiness 探針、OpenAPI 3.0 +規格,以及即時進度快照的 WebSocket 串流: + +```bash +curl http://127.0.0.1:9944/healthz # {"status": "ok"} +curl http://127.0.0.1:9944/readyz # 登錄表非空時 200,否則 503 +curl http://127.0.0.1:9944/openapi.json # OpenAPI 3.0 規格 +# 以 WebSocket 連線 ws://127.0.0.1:9944/progress 取得即時進度訊框。 +``` + +### HTMX Web UI +建構於標準函式庫 HTTP + HTMX(以帶 SRI 的固定 CDN URL 載入)之上的唯讀觀測 +儀表板。預設僅允許 loopback,可選 shared-secret: + +```python +from automation_file import start_web_ui + +server = start_web_ui(host="127.0.0.1", port=9955, shared_secret="s3cr3t") +# 瀏覽 http://127.0.0.1:9955/ —— health、progress、registry 片段每數秒 +# 自動輪詢;寫入操作仍保留在動作伺服器。 +``` + +### MCP(Model Context Protocol)伺服器 +透過 stdio 上的 JSON-RPC 2.0 將登錄的每個 `FA_*` 動作暴露給 MCP 主機 +(Claude Desktop、MCP CLI): + +```python +from automation_file import MCPServer + +MCPServer().serve_stdio() # 從 stdin 讀取 JSON-RPC,寫入 stdout +``` + +`pip install` 後,`[project.scripts]` 會提供 `automation_file_mcp` console +script,MCP 主機不需要寫任何 Python glue 也能啟動橋接器。三種等價的啟動方式: + +```bash +automation_file_mcp # 已安裝的 console script +python -m automation_file mcp # CLI 子指令 +python examples/mcp/run_mcp.py # 獨立啟動腳本 +``` + +三者皆支援 `--name`、`--version`、`--allowed-actions`(逗號分隔白名單—— +強烈建議使用,因為預設登錄表包含 `FA_run_shell` 等高權限動作)。可直接複製的 +Claude Desktop 範例設定請見 [`examples/mcp/`](examples/mcp)。 + +工具描述在執行時由動作簽章自動生成——參數名稱與型別會轉換為 JSON schema, +主機無需任何手動設定即可渲染欄位。 + ### DAG 動作執行器 依相依關係執行動作;獨立分支會透過執行緒池平行展開。每個節點形式為 `{"id": ..., "action": [...], "depends_on": [...]}`: 
@@ -693,6 +767,8 @@ python -m automation_file create-file hello.txt --content "hi" python -m automation_file server --host 127.0.0.1 --port 9943 python -m automation_file http-server --host 127.0.0.1 --port 9944 python -m automation_file drive-upload my.txt --token token.json --credentials creds.json +python -m automation_file mcp --allowed-actions FA_file_checksum,FA_fast_find +automation_file_mcp --allowed-actions FA_file_checksum,FA_fast_find # 已安裝的 console script # 舊式旗標(JSON 動作清單) python -m automation_file --execute_file actions.json diff --git a/automation_file/__init__.py b/automation_file/__init__.py index 3a620d1..7dc4507 100644 --- a/automation_file/__init__.py +++ b/automation_file/__init__.py @@ -19,6 +19,7 @@ executor, validate_action, ) +from automation_file.core.action_queue import ActionQueue, QueueItem from automation_file.core.action_registry import ActionRegistry, build_default_registry from automation_file.core.audit import AuditException, AuditLog from automation_file.core.callback_executor import CallbackExecutor @@ -27,8 +28,10 @@ file_checksum, verify_checksum, ) +from automation_file.core.circuit_breaker import CircuitBreaker from automation_file.core.config import AutomationConfig, ConfigException from automation_file.core.config_watcher import ConfigWatcher +from automation_file.core.content_store import ContentStore from automation_file.core.crypto import ( CryptoException, decrypt_file, @@ -37,6 +40,7 @@ key_from_password, ) from automation_file.core.dag_executor import execute_action_dag +from automation_file.core.file_lock import FileLock from automation_file.core.fim import IntegrityMonitor from automation_file.core.json_store import read_action_json, write_action_json from automation_file.core.manifest import ManifestException, verify_manifest, write_manifest @@ -55,6 +59,7 @@ register_progress_ops, ) from automation_file.core.quota import Quota +from automation_file.core.rate_limit import RateLimiter from automation_file.core.retry 
import retry_on_transient from automation_file.core.secrets import ( ChainedSecretProvider, @@ -66,8 +71,22 @@ default_provider, resolve_secret_refs, ) +from automation_file.core.sqlite_lock import SQLiteLock from automation_file.core.substitution import SubstitutionException, substitute +from automation_file.local.archive_ops import ( + detect_archive_format, + extract_archive, + list_archive, + supported_formats, +) from automation_file.local.conditional import if_exists, if_newer, if_size_gt +from automation_file.local.diff_ops import ( + DirDiff, + apply_dir_diff, + diff_dirs, + diff_text_files, + iter_dir_diff, +) from automation_file.local.dir_ops import copy_dir, create_dir, remove_dir_tree, rename_dir from automation_file.local.file_ops import ( copy_all_file_to_dir, @@ -83,10 +102,20 @@ json_get, json_set, ) +from automation_file.local.mime import detect_from_bytes, detect_mime from automation_file.local.safe_paths import is_within, safe_join from automation_file.local.shell_ops import ShellException, run_shell from automation_file.local.sync_ops import SyncException, sync_dir from automation_file.local.tar_ops import TarException, create_tar, extract_tar +from automation_file.local.templates import render_file, render_string +from automation_file.local.trash import ( + TrashEntry, + empty_trash, + list_trash, + restore_from_trash, + send_to_trash, +) +from automation_file.local.versioning import FileVersioner, VersionEntry from automation_file.local.zip_ops import ( read_zip_file, set_zip_password, @@ -124,6 +153,16 @@ dropbox_instance, register_dropbox_ops, ) +from automation_file.remote.fsspec_bridge import ( + FsspecEntry, + fsspec_delete, + fsspec_download, + fsspec_exists, + fsspec_list_dir, + fsspec_mkdir, + fsspec_upload, + get_fs, +) from automation_file.remote.ftp import ( FTPClient, FTPConnectOptions, @@ -157,7 +196,9 @@ from automation_file.remote.http_download import download_file from automation_file.remote.s3 import S3Client, 
register_s3_ops, s3_instance from automation_file.remote.sftp import SFTPClient, register_sftp_ops, sftp_instance +from automation_file.remote.smb import SMBClient, SMBEntry from automation_file.remote.url_validator import validate_http_url +from automation_file.remote.webdav import WebDAVClient, WebDAVEntry from automation_file.scheduler import ( CronExpression, ScheduledJob, @@ -171,11 +212,13 @@ ) from automation_file.server.action_acl import ActionACL, ActionNotPermittedException from automation_file.server.http_server import HTTPActionServer, start_http_action_server +from automation_file.server.mcp_server import MCPServer, tools_from_registry from automation_file.server.metrics_server import MetricsServer, start_metrics_server from automation_file.server.tcp_server import ( TCPActionServer, start_autocontrol_socket_server, ) +from automation_file.server.web_ui import WebUIServer, start_web_ui from automation_file.trigger import ( FileWatcher, TriggerManager, @@ -213,10 +256,17 @@ def __getattr__(name: str) -> Any: __all__ = [ # Core "ActionExecutor", + "ActionQueue", "ActionRegistry", "CallbackExecutor", + "CircuitBreaker", + "ContentStore", + "FileLock", "PackageLoader", "Quota", + "QueueItem", + "RateLimiter", + "SQLiteLock", "build_default_registry", "execute_action", "execute_action_parallel", @@ -239,10 +289,30 @@ def __getattr__(name: str) -> Any: "copy_all_file_to_dir", "copy_specify_extension_file", "create_file", + "DirDiff", + "FileVersioner", + "TrashEntry", + "VersionEntry", + "apply_dir_diff", "copy_dir", "create_dir", + "detect_archive_format", + "detect_from_bytes", + "detect_mime", + "diff_dirs", + "diff_text_files", + "empty_trash", + "extract_archive", + "iter_dir_diff", + "list_archive", + "list_trash", "remove_dir_tree", "rename_dir", + "render_file", + "render_string", + "restore_from_trash", + "send_to_trash", + "supported_formats", "sync_dir", "SyncException", "create_tar", @@ -305,6 +375,18 @@ def __getattr__(name: str) -> Any: 
"register_ftp_ops", "CrossBackendException", "copy_between", + "WebDAVClient", + "WebDAVEntry", + "SMBClient", + "SMBEntry", + "FsspecEntry", + "fsspec_delete", + "fsspec_download", + "fsspec_exists", + "fsspec_list_dir", + "fsspec_mkdir", + "fsspec_upload", + "get_fs", # Server / Project / Utils "TCPActionServer", "start_autocontrol_socket_server", @@ -346,6 +428,10 @@ def __getattr__(name: str) -> Any: "render_metrics", "MetricsServer", "start_metrics_server", + "WebUIServer", + "start_web_ui", + "MCPServer", + "tools_from_registry", # Triggers "FileWatcher", "TriggerManager", diff --git a/automation_file/__main__.py b/automation_file/__main__.py index c050b8a..12c00d0 100644 --- a/automation_file/__main__.py +++ b/automation_file/__main__.py @@ -104,6 +104,15 @@ def _cmd_ui(_args: argparse.Namespace) -> int: return launch_ui() +def _cmd_mcp(args: argparse.Namespace) -> int: + from automation_file.server.mcp_server import _cli as mcp_cli + + forwarded: list[str] = ["--name", args.name, "--version", args.version] + if args.allowed_actions: + forwarded.extend(["--allowed-actions", args.allowed_actions]) + return mcp_cli(forwarded) + + def _cmd_drive_upload(args: argparse.Namespace) -> int: from automation_file.remote.google_drive.client import driver_instance from automation_file.remote.google_drive.upload_ops import ( @@ -177,6 +186,18 @@ def _build_parser() -> argparse.ArgumentParser: ui_parser = subparsers.add_parser("ui", help="launch the PySide6 GUI") ui_parser.set_defaults(handler=_cmd_ui) + mcp_parser = subparsers.add_parser( + "mcp", help="serve the action registry as an MCP server over stdio" + ) + mcp_parser.add_argument("--name", default="automation_file") + mcp_parser.add_argument("--version", default="1.0.0") + mcp_parser.add_argument( + "--allowed-actions", + default=None, + help="comma-separated allow list (default: expose every registered action)", + ) + mcp_parser.set_defaults(handler=_cmd_mcp) + drive_parser = subparsers.add_parser("drive-upload", 
help="upload a file to Google Drive") drive_parser.add_argument("file") drive_parser.add_argument("--token", required=True) diff --git a/automation_file/core/action_queue.py b/automation_file/core/action_queue.py new file mode 100644 index 0000000..79c6ec5 --- /dev/null +++ b/automation_file/core/action_queue.py @@ -0,0 +1,187 @@ +"""Persistent SQLite-backed queue of action payloads. + +Producers call :meth:`ActionQueue.enqueue` to durably store a JSON action list; +consumers pull with :meth:`dequeue` (marking the row ``inflight``) and finalise +with :meth:`ack` or :meth:`nack`. The queue survives process restarts — all +state lives in the SQLite file. +""" + +from __future__ import annotations + +import json +import os +import sqlite3 +import threading +import time +from contextlib import closing +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from automation_file.exceptions import QueueException +from automation_file.logging_config import file_automation_logger + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS action_queue ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + action TEXT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + run_at REAL NOT NULL, + enqueued_at REAL NOT NULL, + attempts INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'ready', + last_error TEXT +); +CREATE INDEX IF NOT EXISTS idx_queue_ready + ON action_queue (status, run_at, priority DESC, id); +""" + +_STATUS_READY = "ready" +_STATUS_INFLIGHT = "inflight" +_STATUS_DEAD = "dead" + + +@dataclass(frozen=True) +class QueueItem: + """A claimed queue row returned by :meth:`ActionQueue.dequeue`.""" + + id: int + action: list[Any] | dict[str, Any] + attempts: int + enqueued_at: float + + +class ActionQueue: + """Durable FIFO / priority queue for JSON action payloads.""" + + def __init__(self, db_path: str | os.PathLike[str]) -> None: + self._db_path = Path(db_path) + self._lock = threading.Lock() + self._ensure_schema() + + def _connect(self) -> 
sqlite3.Connection: + self._db_path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(self._db_path, timeout=5.0, isolation_level=None) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=2000") + return conn + + def _ensure_schema(self) -> None: + with closing(self._connect()) as conn: + conn.executescript(_SCHEMA) + + def enqueue( + self, + action: list[Any] | dict[str, Any], + priority: int = 0, + run_at: float | None = None, + ) -> int: + """Persist ``action`` for later dispatch. Returns the row id.""" + if not isinstance(action, list | dict): + raise QueueException("action must be a list or dict") + payload = json.dumps(action, ensure_ascii=False) + now = time.time() + due = run_at if run_at is not None else now + with self._lock, closing(self._connect()) as conn: + cur = conn.execute( + "INSERT INTO action_queue" + " (action, priority, run_at, enqueued_at) VALUES (?, ?, ?, ?)", + (payload, priority, due, now), + ) + row_id = cur.lastrowid + if row_id is None: + raise QueueException("failed to allocate queue row id") + return row_id + + def dequeue(self) -> QueueItem | None: + """Claim the next ready row; returns ``None`` if the queue is empty.""" + now = time.time() + with self._lock, closing(self._connect()) as conn: + try: + conn.execute("BEGIN IMMEDIATE") + row = conn.execute( + "SELECT id, action, attempts, enqueued_at FROM action_queue" + " WHERE status=? AND run_at<=?" 
+ " ORDER BY priority DESC, id ASC LIMIT 1", + (_STATUS_READY, now), + ).fetchone() + if row is None: + conn.execute("ROLLBACK") + return None + row_id, payload, attempts, enqueued_at = row + conn.execute( + "UPDATE action_queue SET status=?, attempts=attempts+1 WHERE id=?", + (_STATUS_INFLIGHT, row_id), + ) + conn.execute("COMMIT") + except sqlite3.OperationalError as error: + raise QueueException(f"dequeue failed: {error}") from error + try: + action = json.loads(payload) + except json.JSONDecodeError as error: + raise QueueException(f"corrupt queue row {row_id}: {error}") from error + return QueueItem( + id=int(row_id), + action=action, + attempts=int(attempts) + 1, + enqueued_at=float(enqueued_at), + ) + + def ack(self, item_id: int) -> None: + """Finalise a claimed row as processed.""" + with self._lock, closing(self._connect()) as conn: + conn.execute("DELETE FROM action_queue WHERE id=?", (item_id,)) + + def nack( + self, + item_id: int, + *, + requeue: bool = True, + reason: str = "", + delay: float = 0.0, + ) -> None: + """Return a claimed row to the queue (``requeue=True``) or mark as dead.""" + next_status = _STATUS_READY if requeue else _STATUS_DEAD + run_at = time.time() + max(delay, 0.0) + with self._lock, closing(self._connect()) as conn: + conn.execute( + "UPDATE action_queue SET status=?, last_error=?, run_at=? WHERE id=?", + (next_status, reason or None, run_at, item_id), + ) + + def size(self, status: str = _STATUS_READY) -> int: + with closing(self._connect()) as conn: + row = conn.execute( + "SELECT COUNT(*) FROM action_queue WHERE status=?", + (status,), + ).fetchone() + return int(row[0]) if row else 0 + + def purge(self) -> int: + """Delete every row (ready / inflight / dead). 
Returns rows deleted.""" + with self._lock, closing(self._connect()) as conn: + cur = conn.execute("DELETE FROM action_queue") + return int(cur.rowcount or 0) + + def dead_letters(self) -> list[QueueItem]: + with closing(self._connect()) as conn: + rows = conn.execute( + "SELECT id, action, attempts, enqueued_at FROM action_queue WHERE status=?", + (_STATUS_DEAD,), + ).fetchall() + items: list[QueueItem] = [] + for row_id, payload, attempts, enqueued_at in rows: + try: + action = json.loads(payload) + except json.JSONDecodeError: + file_automation_logger.warning("queue: skipping unparseable dead row %s", row_id) + continue + items.append( + QueueItem( + id=int(row_id), + action=action, + attempts=int(attempts), + enqueued_at=float(enqueued_at), + ) + ) + return items diff --git a/automation_file/core/circuit_breaker.py b/automation_file/core/circuit_breaker.py new file mode 100644 index 0000000..c79080a --- /dev/null +++ b/automation_file/core/circuit_breaker.py @@ -0,0 +1,116 @@ +"""Three-state circuit breaker. + +States: CLOSED (normal), OPEN (short-circuit after ``failure_threshold`` +consecutive failures), HALF_OPEN (trial one call after ``recovery_timeout`` +seconds; one success closes, one failure re-opens). Failures are counted +only for exceptions in ``retriable`` — internal errors surface as-is. +""" + +from __future__ import annotations + +import threading +import time +from collections.abc import Callable +from functools import wraps +from typing import Any, TypeVar + +from automation_file.exceptions import CircuitOpenException +from automation_file.logging_config import file_automation_logger + +F = TypeVar("F", bound=Callable[..., Any]) + +_STATE_CLOSED = "closed" +_STATE_OPEN = "open" +_STATE_HALF_OPEN = "half_open" + + +class CircuitBreaker: + """Open-close-half-open breaker. + + ``failure_threshold`` — consecutive failures that trip the breaker. + ``recovery_timeout`` — seconds spent in OPEN before transitioning to HALF_OPEN. 
+ ``retriable`` — exception types counted as failures; other exceptions pass through. + """ + + def __init__( + self, + failure_threshold: int = 5, + recovery_timeout: float = 30.0, + retriable: tuple[type[BaseException], ...] = (Exception,), + name: str = "circuit", + ) -> None: + if failure_threshold < 1: + raise ValueError("failure_threshold must be >= 1") + if recovery_timeout <= 0: + raise ValueError("recovery_timeout must be > 0") + self._failure_threshold = failure_threshold + self._recovery_timeout = float(recovery_timeout) + self._retriable = retriable + self._name = name + self._state = _STATE_CLOSED + self._failures = 0 + self._opened_at = 0.0 + self._lock = threading.Lock() + + @property + def state(self) -> str: + with self._lock: + self._maybe_transition_locked() + return self._state + + def _maybe_transition_locked(self) -> None: + if self._state == _STATE_OPEN and ( + time.monotonic() - self._opened_at >= self._recovery_timeout + ): + self._state = _STATE_HALF_OPEN + file_automation_logger.info("circuit %s: open -> half_open", self._name) + + def _on_success_locked(self) -> None: + if self._state == _STATE_HALF_OPEN: + file_automation_logger.info("circuit %s: half_open -> closed", self._name) + self._state = _STATE_CLOSED + self._failures = 0 + + def _on_failure_locked(self) -> None: + self._failures += 1 + if self._state == _STATE_HALF_OPEN or self._failures >= self._failure_threshold: + self._state = _STATE_OPEN + self._opened_at = time.monotonic() + file_automation_logger.warning( + "circuit %s: opened after %d failures", self._name, self._failures + ) + + def call(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Invoke ``func`` through the breaker.""" + with self._lock: + self._maybe_transition_locked() + if self._state == _STATE_OPEN: + raise CircuitOpenException(f"circuit {self._name!r} is open") + try: + result = func(*args, **kwargs) + except self._retriable as error: + with self._lock: + self._on_failure_locked() + 
raise error + with self._lock: + self._on_success_locked() + return result + + def wraps(self) -> Callable[[F], F]: + """Return a decorator that routes every call through :meth:`call`.""" + + def decorator(func: F) -> F: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + return self.call(func, *args, **kwargs) + + return wrapper # type: ignore[return-value] + + return decorator + + def reset(self) -> None: + """Force the breaker back to CLOSED, clearing failure count.""" + with self._lock: + self._state = _STATE_CLOSED + self._failures = 0 + self._opened_at = 0.0 diff --git a/automation_file/core/content_store.py b/automation_file/core/content_store.py new file mode 100644 index 0000000..7d942e7 --- /dev/null +++ b/automation_file/core/content_store.py @@ -0,0 +1,128 @@ +"""SHA-256 content-addressable store. + +:class:`ContentStore` ingests files or byte blobs and keys them by the hex +digest of their contents. A two-character fanout directory keeps any single +directory small: ``/ab/abcdef…``. Identical inputs map to the same blob — +callers get deduplication for free. 
+""" + +from __future__ import annotations + +import hashlib +import os +import shutil +from collections.abc import Callable, Iterator +from pathlib import Path +from typing import IO + +from automation_file.exceptions import CASException + +_HASH = "sha256" +_FANOUT = 2 +_CHUNK = 1 << 20 + + +class ContentStore: + """Filesystem-backed CAS under ``root``.""" + + def __init__(self, root: str | os.PathLike[str]) -> None: + self._root = Path(root) + self._root.mkdir(parents=True, exist_ok=True) + + @property + def root(self) -> Path: + return self._root + + def path_for(self, digest: str) -> Path: + if len(digest) < _FANOUT + 1 or not all(c in "0123456789abcdef" for c in digest): + raise CASException(f"invalid digest {digest!r}") + return self._root / digest[:_FANOUT] / digest + + def exists(self, digest: str) -> bool: + return self.path_for(digest).is_file() + + def put(self, source: str | os.PathLike[str]) -> str: + """Ingest the file at ``source`` and return its hex digest.""" + src = Path(source) + if not src.is_file(): + raise CASException(f"source is not a file: {src}") + digest = self._hash_file(src) + target = self.path_for(digest) + if not target.exists(): + self._write_atomic(target, lambda tmp: _copyfile(src, tmp)) + return digest + + def put_bytes(self, data: bytes) -> str: + """Ingest raw bytes and return the hex digest.""" + digest = hashlib.new(_HASH, data).hexdigest() + target = self.path_for(digest) + if not target.exists(): + self._write_atomic(target, lambda tmp: _write_bytes(tmp, data)) + return digest + + def _write_atomic(self, target: Path, writer: Callable[[Path], None]) -> None: + target.parent.mkdir(parents=True, exist_ok=True) + tmp = target.with_suffix(target.suffix + ".tmp") + try: + writer(tmp) + os.replace(tmp, target) + except OSError as error: + if tmp.exists(): + tmp.unlink(missing_ok=True) + raise CASException(f"failed to store blob: {error}") from error + + def open(self, digest: str) -> IO[bytes]: + """Open the stored blob for 
binary read.""" + path = self.path_for(digest) + if not path.is_file(): + raise CASException(f"missing blob {digest}") + return open(path, "rb") + + def copy_to(self, digest: str, destination: str | os.PathLike[str]) -> Path: + """Copy the blob at ``digest`` into ``destination``. Returns the path.""" + src = self.path_for(digest) + if not src.is_file(): + raise CASException(f"missing blob {digest}") + dest = Path(destination) + dest.parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(src, dest) + return dest + + def delete(self, digest: str) -> bool: + """Remove a blob. Returns True when the blob existed.""" + path = self.path_for(digest) + if not path.is_file(): + return False + path.unlink() + return True + + def iter_digests(self) -> Iterator[str]: + """Yield the digest of every stored blob.""" + if not self._root.exists(): + return + for bucket in self._root.iterdir(): + if not bucket.is_dir() or len(bucket.name) != _FANOUT: + continue + for blob in bucket.iterdir(): + if blob.is_file() and blob.name.startswith(bucket.name): + yield blob.name + + def size(self) -> int: + """Return the total number of stored blobs.""" + return sum(1 for _ in self.iter_digests()) + + def _hash_file(self, path: Path) -> str: + hasher = hashlib.new(_HASH) + with open(path, "rb") as fh: + for chunk in iter(lambda: fh.read(_CHUNK), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +def _write_bytes(target: Path, data: bytes) -> None: + with open(target, "wb") as fh: + fh.write(data) + + +def _copyfile(src: Path, dst: Path) -> None: + shutil.copyfile(src, dst) diff --git a/automation_file/core/file_lock.py b/automation_file/core/file_lock.py new file mode 100644 index 0000000..1ac227f --- /dev/null +++ b/automation_file/core/file_lock.py @@ -0,0 +1,128 @@ +"""Cross-platform advisory file lock. + +Uses ``fcntl.flock`` on POSIX and ``msvcrt.locking`` on Windows, so two processes +can serialise on a well-known lock path. 
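The fanout layout and the free deduplication fall out of a few lines of stdlib code; `put_bytes` here is a hypothetical miniature of the method above, without the atomic temp-file rename:

```python
import hashlib
import tempfile
from pathlib import Path

def put_bytes(root: Path, data: bytes) -> str:
    # Same addressing scheme as ContentStore: <root>/<digest[:2]>/<digest>
    digest = hashlib.sha256(data).hexdigest()
    blob = root / digest[:2] / digest
    blob.parent.mkdir(parents=True, exist_ok=True)
    if not blob.exists():  # identical content maps to the same path
        blob.write_bytes(data)
    return digest

root = Path(tempfile.mkdtemp())
first = put_bytes(root, b"hello")
second = put_bytes(root, b"hello")  # second put writes nothing new
```

Because the key is derived purely from content, ingesting the same bytes twice yields the same digest and a single stored blob.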
Locks are exclusive (writer-style); +shared locks are not supported because ``msvcrt`` cannot express them portably. +""" + +from __future__ import annotations + +import contextlib +import os +import sys +import threading +import time +from pathlib import Path +from types import TracebackType +from typing import IO + +from automation_file.exceptions import LockTimeoutException + +_POLL_INTERVAL = 0.05 + + +class FileLock: + """Advisory exclusive lock on a sidecar lock file. + + ``path`` is the lock file itself — typically ``.lock`` next to the + protected resource. ``timeout`` is the maximum seconds to wait when + acquiring; ``None`` waits indefinitely, ``0`` fails immediately. + """ + + def __init__(self, path: str | os.PathLike[str], timeout: float | None = None) -> None: + self._path = Path(path) + self._timeout = timeout + self._fh: IO[bytes] | None = None + self._thread_lock = threading.Lock() + self._owned = False + + @property + def path(self) -> Path: + return self._path + + @property + def is_held(self) -> bool: + return self._owned + + def acquire(self) -> None: + """Block until the lock is held. 
Raises :class:`LockTimeoutException` on timeout.""" + with self._thread_lock: + if self._owned: + raise LockTimeoutException(f"lock {self._path} already held by this instance") + self._path.parent.mkdir(parents=True, exist_ok=True) + # pylint: disable=consider-using-with + fh = open(self._path, "a+b") # noqa: SIM115 — held across acquire/release + try: + self._acquire_os_lock(fh) + except BaseException: + fh.close() + raise + self._fh = fh + self._owned = True + + def release(self) -> None: + """Release the lock; idempotent.""" + with self._thread_lock: + if not self._owned or self._fh is None: + return + try: + self._release_os_lock(self._fh) + finally: + self._fh.close() + self._fh = None + self._owned = False + + def __enter__(self) -> FileLock: + self.acquire() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.release() + + def _acquire_os_lock(self, fh: IO[bytes]) -> None: + deadline = None if self._timeout is None else time.monotonic() + self._timeout + while True: + if _try_lock(fh): + return + if deadline is not None and time.monotonic() >= deadline: + raise LockTimeoutException( + f"timed out acquiring lock {self._path} after {self._timeout}s" + ) + time.sleep(_POLL_INTERVAL) + + def _release_os_lock(self, fh: IO[bytes]) -> None: + _unlock(fh) + + +if sys.platform == "win32": + import msvcrt + + def _try_lock(fh: IO[bytes]) -> bool: + try: + msvcrt.locking(fh.fileno(), msvcrt.LK_NBLCK, 1) + return True + except OSError: + return False + + def _unlock(fh: IO[bytes]) -> None: + with contextlib.suppress(OSError): + fh.seek(0) + msvcrt.locking(fh.fileno(), msvcrt.LK_UNLCK, 1) +else: + import fcntl + + def _try_lock(fh: IO[bytes]) -> bool: + try: + fcntl.flock(fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + return True + except OSError: + return False + + def _unlock(fh: IO[bytes]) -> None: + with contextlib.suppress(OSError): + fcntl.flock(fh.fileno(), 
fcntl.LOCK_UN) diff --git a/automation_file/core/rate_limit.py b/automation_file/core/rate_limit.py new file mode 100644 index 0000000..cd45116 --- /dev/null +++ b/automation_file/core/rate_limit.py @@ -0,0 +1,98 @@ +"""Token-bucket rate limiter. + +:class:`RateLimiter` refills at ``rate`` tokens/second up to a burst capacity. +Callers acquire N tokens before issuing a protected call; when empty, the +limiter either blocks (up to ``timeout``) or raises +:class:`RateLimitExceededException`. +""" + +from __future__ import annotations + +import threading +import time +from collections.abc import Callable +from functools import wraps +from typing import Any, TypeVar + +from automation_file.exceptions import RateLimitExceededException + +F = TypeVar("F", bound=Callable[..., Any]) + + +class RateLimiter: + """Thread-safe token bucket.""" + + def __init__(self, rate: float, burst: float | None = None) -> None: + if rate <= 0: + raise ValueError("rate must be > 0") + cap = float(burst) if burst is not None else float(rate) + if cap <= 0: + raise ValueError("burst must be > 0") + self._rate = float(rate) + self._capacity = cap + self._tokens = cap + self._updated = time.monotonic() + self._cv = threading.Condition() + + @property + def capacity(self) -> float: + return self._capacity + + def _refill_locked(self) -> None: + now = time.monotonic() + elapsed = now - self._updated + if elapsed > 0: + self._tokens = min(self._capacity, self._tokens + elapsed * self._rate) + self._updated = now + + def try_acquire(self, tokens: float = 1.0) -> bool: + """Take ``tokens`` without blocking. Return True on success.""" + if tokens <= 0: + raise ValueError("tokens must be > 0") + with self._cv: + self._refill_locked() + if self._tokens >= tokens: + self._tokens -= tokens + return True + return False + + def acquire(self, tokens: float = 1.0, timeout: float | None = None) -> None: + """Block until ``tokens`` are available. 
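On POSIX the acquire/release pair above reduces to the following `flock` dance (a POSIX-only sketch; the Windows branch byte-locks with `msvcrt.locking` instead):

```python
import fcntl      # POSIX branch; not available on Windows
import tempfile

lock_path = tempfile.NamedTemporaryFile(suffix=".lock", delete=False).name

holder = open(lock_path, "a+b")
fcntl.flock(holder.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)  # first acquire wins

contender = open(lock_path, "a+b")
try:
    # flock is tied to the open file description, so a second descriptor
    # contends even within the same process.
    fcntl.flock(contender.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
    contended = False
except OSError:
    contended = True

fcntl.flock(holder.fileno(), fcntl.LOCK_UN)  # release frees the contender
fcntl.flock(contender.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
reacquired = True

holder.close()
contender.close()
```

That per-description semantics is what makes the sidecar lock file work as a cross-process mutex: each process opens its own descriptor, and only one can hold the exclusive lock at a time.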
+ + Raises :class:`RateLimitExceededException` if ``timeout`` elapses first. + ``timeout=None`` waits indefinitely; ``timeout=0`` fails immediately. + """ + if tokens <= 0: + raise ValueError("tokens must be > 0") + if tokens > self._capacity: + raise ValueError(f"tokens {tokens} exceeds capacity {self._capacity}") + deadline = None if timeout is None else time.monotonic() + timeout + with self._cv: + while True: + self._refill_locked() + if self._tokens >= tokens: + self._tokens -= tokens + return + needed = tokens - self._tokens + wait_for = needed / self._rate + if deadline is not None: + remaining = deadline - time.monotonic() + if remaining <= 0: + raise RateLimitExceededException( + f"rate limit: could not acquire {tokens} tokens within timeout" + ) + wait_for = min(wait_for, remaining) + self._cv.wait(timeout=wait_for) + + def wraps(self, tokens: float = 1.0, timeout: float | None = None) -> Callable[[F], F]: + """Return a decorator that acquires ``tokens`` before each call.""" + + def decorator(func: F) -> F: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + self.acquire(tokens=tokens, timeout=timeout) + return func(*args, **kwargs) + + return wrapper # type: ignore[return-value] + + return decorator diff --git a/automation_file/core/sqlite_lock.py b/automation_file/core/sqlite_lock.py new file mode 100644 index 0000000..551d256 --- /dev/null +++ b/automation_file/core/sqlite_lock.py @@ -0,0 +1,158 @@ +"""SQLite-backed named lock for multi-process / multi-host coordination. + +Unlike :class:`automation_file.core.file_lock.FileLock` which locks a single +file descriptor, :class:`SQLiteLock` persists named leases in a shared SQLite +database. Any process that can open the database can participate. Leases carry +an optional TTL so crashed owners eventually free the slot. 
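The refill arithmetic in `_refill_locked` / `try_acquire` can be demonstrated standalone; `Bucket` is an illustrative, lock-free miniature of `RateLimiter`:

```python
import time

class Bucket:
    """Toy token bucket: refill at `rate` tokens/s up to `burst`, spend on acquire."""

    def __init__(self, rate: float, burst: float) -> None:
        self.rate, self.capacity = rate, burst
        self.tokens, self.updated = burst, time.monotonic()

    def try_acquire(self, n: float = 1.0) -> bool:
        now = time.monotonic()
        # Elapsed wall-clock time converts to tokens, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.updated) * self.rate)
        self.updated = now
        if self.tokens >= n:
            self.tokens -= n
            return True
        return False
```

A burst of `burst` calls succeeds immediately; once drained, callers observe roughly `rate` grants per second as the bucket refills.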
+""" + +from __future__ import annotations + +import os +import sqlite3 +import threading +import time +import uuid +from contextlib import closing +from pathlib import Path +from types import TracebackType + +from automation_file.exceptions import LockTimeoutException + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS automation_locks ( + name TEXT PRIMARY KEY, + owner TEXT NOT NULL, + acquired_at REAL NOT NULL, + expires_at REAL +) +""" +_POLL_INTERVAL = 0.05 + + +class SQLiteLock: + """Named lease stored in SQLite. + + ``db_path`` is the SQLite file — callers sharing a lock must point at the + same file. ``name`` is the lock identity. ``ttl`` (seconds) lets a crashed + owner's lease expire; ``None`` means the lease is held until explicit + release. ``timeout`` bounds acquisition wait. + """ + + def __init__( + self, + db_path: str | os.PathLike[str], + name: str, + timeout: float | None = None, + ttl: float | None = None, + ) -> None: + if not name: + raise ValueError("lock name must be non-empty") + if ttl is not None and ttl <= 0: + raise ValueError("ttl must be > 0 when set") + self._db_path = Path(db_path) + self._name = name + self._timeout = timeout + self._ttl = ttl + self._owner = uuid.uuid4().hex + self._held = False + self._thread_lock = threading.Lock() + self._ensure_schema() + + @property + def owner(self) -> str: + return self._owner + + @property + def is_held(self) -> bool: + return self._held + + def _connect(self) -> sqlite3.Connection: + self._db_path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(self._db_path, timeout=5.0, isolation_level=None) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA busy_timeout=2000") + return conn + + def _ensure_schema(self) -> None: + with closing(self._connect()) as conn: + conn.execute(_SCHEMA) + + def acquire(self) -> None: + """Block until the lease is granted; raise :class:`LockTimeoutException` on timeout.""" + with self._thread_lock: + if self._held: + raise 
LockTimeoutException(f"lock {self._name!r} already held by this instance") + deadline = None if self._timeout is None else time.monotonic() + self._timeout + while True: + if self._try_claim(): + self._held = True + return + if deadline is not None and time.monotonic() >= deadline: + raise LockTimeoutException( + f"timed out acquiring lock {self._name!r} after {self._timeout}s" + ) + time.sleep(_POLL_INTERVAL) + + def _try_claim(self) -> bool: + now = time.time() + expires = now + self._ttl if self._ttl is not None else None + with closing(self._connect()) as conn: + try: + conn.execute("BEGIN IMMEDIATE") + row = conn.execute( + "SELECT owner, expires_at FROM automation_locks WHERE name=?", + (self._name,), + ).fetchone() + if row is not None: + _, row_expires = row + if row_expires is None or row_expires > now: + conn.execute("ROLLBACK") + return False + conn.execute( + "INSERT OR REPLACE INTO automation_locks" + " (name, owner, acquired_at, expires_at) VALUES (?, ?, ?, ?)", + (self._name, self._owner, now, expires), + ) + conn.execute("COMMIT") + return True + except sqlite3.OperationalError: + return False + + def release(self) -> None: + """Release the lease; idempotent. Only the owning instance removes the row.""" + with self._thread_lock: + if not self._held: + return + with closing(self._connect()) as conn: + conn.execute( + "DELETE FROM automation_locks WHERE name=? AND owner=?", + (self._name, self._owner), + ) + self._held = False + + def refresh(self) -> None: + """Extend the lease by ``ttl`` seconds. No-op when ttl is unset.""" + if self._ttl is None: + return + with self._thread_lock: + if not self._held: + return + now = time.time() + with closing(self._connect()) as conn: + conn.execute( + "UPDATE automation_locks SET expires_at=? WHERE name=? 
AND owner=?", + (now + self._ttl, self._name, self._owner), + ) + + def __enter__(self) -> SQLiteLock: + self.acquire() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.release() diff --git a/automation_file/exceptions.py b/automation_file/exceptions.py index 8bdecdc..cfa8a6d 100644 --- a/automation_file/exceptions.py +++ b/automation_file/exceptions.py @@ -71,6 +71,58 @@ class DagException(FileAutomationException): """Raised when a DAG action list has a cycle, unknown dep, or duplicate id.""" +class RateLimitExceededException(FileAutomationException): + """Raised when a rate-limited call cannot acquire a token in the allotted wait.""" + + +class CircuitOpenException(FileAutomationException): + """Raised when a circuit breaker is open and short-circuits the protected call.""" + + +class LockTimeoutException(FileAutomationException): + """Raised when a lock acquire waits past its timeout.""" + + +class QueueException(FileAutomationException): + """Raised by the persistent action queue on storage / dispatch errors.""" + + +class CASException(FileAutomationException): + """Raised by the content-addressable store on integrity / I/O failures.""" + + +class TemplateException(FileAutomationException): + """Raised when template rendering fails (missing engine, syntax, I/O).""" + + +class DiffException(FileAutomationException): + """Raised when diff computation or patch application fails.""" + + +class VersioningException(FileAutomationException): + """Raised by the versioning helpers on retention / I/O failures.""" + + +class ArchiveException(FileAutomationException): + """Raised when an archive format is unsupported or extraction fails.""" + + +class WebDAVException(FileAutomationException): + """Raised by the WebDAV client on transport / protocol failures.""" + + +class SMBException(FileAutomationException): + """Raised by the SMB/CIFS client on connection / 
protocol failures.""" + + +class MCPServerException(FileAutomationException): + """Raised by the MCP server bridge when a tool invocation fails.""" + + +class FsspecException(FileAutomationException): + """Raised by the fsspec bridge on missing dependency or backend failures.""" + + _ARGPARSE_EMPTY_MESSAGE = "argparse received no actionable argument" _BAD_TRIGGER_FUNCTION = "trigger name is not registered in the executor" _BAD_CALLBACK_METHOD = "callback_param_method must be 'kwargs' or 'args'" diff --git a/automation_file/local/archive_ops.py b/automation_file/local/archive_ops.py new file mode 100644 index 0000000..92eedb5 --- /dev/null +++ b/automation_file/local/archive_ops.py @@ -0,0 +1,199 @@ +"""Archive format auto-detect and safe extraction. + +Covers ZIP and the tar family (plain, gzip, bzip2, xz) from the stdlib. Adds +optional read-only support for 7z (via ``py7zr``) and RAR (via ``rarfile``) if +those packages are installed. Extraction refuses paths that escape the target +root — same guarantee as :func:`automation_file.local.safe_paths.safe_join`. 
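The `BEGIN IMMEDIATE` claim that `SQLiteLock._try_claim` runs can be exercised directly with `sqlite3`. This is a stripped-down sketch of the same lease protocol; the table and column names here are illustrative, not the shipped schema:

```python
import sqlite3
import tempfile
import time
import uuid

db = tempfile.NamedTemporaryFile(suffix=".db", delete=False).name

def try_claim(name: str, owner: str, ttl: float) -> bool:
    now = time.time()
    conn = sqlite3.connect(db, isolation_level=None)
    try:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS locks"
            " (name TEXT PRIMARY KEY, owner TEXT, expires_at REAL)"
        )
        conn.execute("BEGIN IMMEDIATE")  # take the write lock up front
        row = conn.execute(
            "SELECT owner, expires_at FROM locks WHERE name=?", (name,)
        ).fetchone()
        if row is not None and (row[1] is None or row[1] > now):
            conn.execute("ROLLBACK")  # a live lease is held by someone else
            return False
        conn.execute(
            "INSERT OR REPLACE INTO locks VALUES (?, ?, ?)",
            (name, owner, now + ttl),
        )
        conn.execute("COMMIT")
        return True
    finally:
        conn.close()

alice, bob = uuid.uuid4().hex, uuid.uuid4().hex
```

The first claimant wins, a contender is refused while the lease is live, and TTL expiry frees the slot without any action by the (possibly crashed) owner.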
+""" + +from __future__ import annotations + +import os +import tarfile +import zipfile +from collections.abc import Iterable +from pathlib import Path + +from automation_file.exceptions import ArchiveException +from automation_file.local.safe_paths import is_within + +_ZIP_SIG = b"PK\x03\x04" +_SEVEN_ZIP_SIG = b"7z\xbc\xaf\x27\x1c" +_RAR_SIG_V4 = b"Rar!\x1a\x07\x00" +_RAR_SIG_V5 = b"Rar!\x1a\x07\x01\x00" +_GZIP_SIG = b"\x1f\x8b" +_BZ2_SIG = b"BZh" +_XZ_SIG = b"\xfd7zXZ\x00" + + +def detect_archive_format(path: str | os.PathLike[str]) -> str: + """Return one of zip / tar / 7z / rar / gz / bz2 / xz based on magic bytes.""" + src = Path(path) + if not src.is_file(): + raise ArchiveException(f"not a file: {src}") + with open(src, "rb") as fh: + head = fh.read(262) + if head.startswith(_ZIP_SIG): + return "zip" + if head.startswith(_SEVEN_ZIP_SIG): + return "7z" + if head.startswith(_RAR_SIG_V5) or head.startswith(_RAR_SIG_V4): + return "rar" + if head.startswith(_XZ_SIG): + return "tar.xz" if _is_tar_stream(src, "xz") else "xz" + if head.startswith(_BZ2_SIG): + return "tar.bz2" if _is_tar_stream(src, "bz2") else "bz2" + if head.startswith(_GZIP_SIG): + return "tar.gz" if _is_tar_stream(src, "gz") else "gz" + if tarfile.is_tarfile(src): + return "tar" + raise ArchiveException(f"unsupported archive format: {src}") + + +def list_archive(path: str | os.PathLike[str]) -> list[str]: + """Return the entry names inside ``path``.""" + fmt = detect_archive_format(path) + if fmt == "zip": + with zipfile.ZipFile(path) as zf: + return zf.namelist() + if fmt.startswith("tar"): + with tarfile.open(path) as tf: # nosec B202 # NOSONAR metadata listing only, no extraction + return tf.getnames() + if fmt == "7z": + return _seven_zip_namelist(path) + if fmt == "rar": + return _rar_namelist(path) + raise ArchiveException(f"listing not supported for format {fmt!r}") + + +def extract_archive( + source: str | os.PathLike[str], + target: str | os.PathLike[str], +) -> list[str]: + """Extract 
``source`` into ``target``. Returns the list of extracted names."""
+    fmt = detect_archive_format(source)
+    dest = Path(target)
+    dest.mkdir(parents=True, exist_ok=True)
+    if fmt == "zip":
+        return _extract_zip(Path(source), dest)
+    if fmt.startswith("tar"):
+        return _extract_tar(Path(source), dest)
+    if fmt == "7z":
+        return _extract_seven_zip(Path(source), dest)
+    if fmt == "rar":
+        return _extract_rar(Path(source), dest)
+    raise ArchiveException(f"extraction not supported for format {fmt!r}")
+
+
+def _is_tar_stream(path: Path, compression: str) -> bool:
+    try:
+        if compression == "gz":
+            with tarfile.open(path, mode="r:gz"):  # nosec # NOSONAR read-only probe
+                return True
+        if compression == "bz2":
+            with tarfile.open(path, mode="r:bz2"):  # nosec # NOSONAR read-only probe
+                return True
+        if compression == "xz":
+            with tarfile.open(path, mode="r:xz"):  # nosec # NOSONAR read-only probe
+                return True
+    except (tarfile.TarError, OSError):
+        return False
+    return False
+
+
+def _extract_zip(source: Path, dest: Path) -> list[str]:
+    names: list[str] = []
+    with zipfile.ZipFile(source) as zf:
+        for info in zf.infolist():
+            out = dest / info.filename
+            _ensure_within(dest, out)
+            if info.is_dir():
+                out.mkdir(parents=True, exist_ok=True)
+                continue
+            out.parent.mkdir(parents=True, exist_ok=True)
+            with zf.open(info) as src_fh, open(out, "wb") as dst_fh:
+                while True:
+                    chunk = src_fh.read(1 << 20)
+                    if not chunk:
+                        break
+                    dst_fh.write(chunk)
+            names.append(info.filename)
+    return names
+
+
+def _extract_tar(source: Path, dest: Path) -> list[str]:
+    names: list[str] = []
+    # Per-member path containment + link rejection below; where available
+    # (3.12+, backported to 3.11.4 / 3.10.12), tarfile.data_filter enforces
+    # the same rules inside tarfile itself.
+ with tarfile.open(source) as tf: # nosec B202 # NOSONAR entries validated before extract + _apply_tar_data_filter(tf) + for member in tf.getmembers(): + out = dest / member.name + _ensure_within(dest, out) + if member.islnk() or member.issym(): + raise ArchiveException(f"refusing to extract link: {member.name}") + tf.extract(member, dest) + names.append(member.name) + return names + + +def _apply_tar_data_filter(tf: tarfile.TarFile) -> None: + data_filter = getattr(tarfile, "data_filter", None) + if data_filter is not None: + tf.extraction_filter = data_filter + + +def _extract_seven_zip(source: Path, dest: Path) -> list[str]: + try: + import py7zr + except ImportError as error: + raise ArchiveException("py7zr is required for 7z extraction") from error + with py7zr.SevenZipFile(source, mode="r") as archive: + names = archive.getnames() + for name in names: + _ensure_within(dest, dest / name) + # Every entry name has been validated via _ensure_within above. + archive.extractall(path=dest) # nosec B202 - entries validated before extract + return list(names) + + +def _extract_rar(source: Path, dest: Path) -> list[str]: + try: + import rarfile + except ImportError as error: + raise ArchiveException("rarfile is required for RAR extraction") from error + with rarfile.RarFile(source) as archive: + names = archive.namelist() + for name in names: + _ensure_within(dest, dest / name) + # Every entry name has been validated via _ensure_within above. 
+ archive.extractall(path=str(dest)) # nosec B202 - entries validated before extract + return list(names) + + +def _seven_zip_namelist(path: str | os.PathLike[str]) -> list[str]: + try: + import py7zr + except ImportError as error: + raise ArchiveException("py7zr is required to list 7z contents") from error + with py7zr.SevenZipFile(path, mode="r") as archive: + return list(archive.getnames()) + + +def _rar_namelist(path: str | os.PathLike[str]) -> list[str]: + try: + import rarfile + except ImportError as error: + raise ArchiveException("rarfile is required to list RAR contents") from error + with rarfile.RarFile(path) as archive: + return list(archive.namelist()) + + +def _ensure_within(root: Path, candidate: Path) -> None: + if not is_within(root, candidate): + raise ArchiveException(f"archive entry escapes target root: {candidate}") + + +def supported_formats() -> Iterable[str]: + """Return the archive tags this module can detect.""" + return ("zip", "tar", "tar.gz", "tar.bz2", "tar.xz", "gz", "bz2", "xz", "7z", "rar") diff --git a/automation_file/local/diff_ops.py b/automation_file/local/diff_ops.py new file mode 100644 index 0000000..73ad698 --- /dev/null +++ b/automation_file/local/diff_ops.py @@ -0,0 +1,140 @@ +"""Directory and file diff / patch helpers. + +:func:`diff_dirs` walks two trees and reports files that were added, removed, +or changed by content hash. :func:`apply_dir_diff` replays that diff against a +target tree, copying or deleting as needed. Text-file differences are rendered +as unified diffs with :func:`diff_text_files`. 
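The containment rule that `_ensure_within` enforces is easy to see end to end with the stdlib: build a zip holding one benign and one escaping entry, then extract only entries that resolve under the target root. This is an illustrative sketch; the module above additionally handles directories, tar links, and the optional 7z/RAR backends:

```python
import io
import tempfile
import zipfile
from pathlib import Path

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("ok.txt", "fine")
    zf.writestr("../escape.txt", "evil")  # classic zip-slip entry

root = Path(tempfile.mkdtemp()).resolve()
kept, rejected = [], []
with zipfile.ZipFile(io.BytesIO(buf.getvalue())) as zf:
    for info in zf.infolist():
        out = (root / info.filename).resolve()
        if root not in out.parents:  # entry escapes the target root
            rejected.append(info.filename)
            continue
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_bytes(zf.read(info))
        kept.append(info.filename)
```

The benign entry lands under the root while the `../` entry is refused, which is the guarantee `extract_archive` makes for every supported format.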
+""" + +from __future__ import annotations + +import difflib +import hashlib +import os +import shutil +from collections.abc import Iterable +from dataclasses import dataclass, field +from pathlib import Path + +from automation_file.exceptions import DiffException +from automation_file.local.safe_paths import safe_join + +_HASH = "sha256" +_CHUNK = 1 << 20 + + +@dataclass(frozen=True) +class DirDiff: + """Summary of differences between two directory trees. + + Paths are POSIX-style strings relative to the diff root. + """ + + added: tuple[str, ...] = field(default_factory=tuple) + removed: tuple[str, ...] = field(default_factory=tuple) + changed: tuple[str, ...] = field(default_factory=tuple) + + def is_empty(self) -> bool: + return not (self.added or self.removed or self.changed) + + +def diff_dirs(left: str | os.PathLike[str], right: str | os.PathLike[str]) -> DirDiff: + """Compute the content diff going from ``left`` to ``right``.""" + left_path = Path(left) + right_path = Path(right) + if not left_path.is_dir(): + raise DiffException(f"left is not a directory: {left_path}") + if not right_path.is_dir(): + raise DiffException(f"right is not a directory: {right_path}") + left_files = _relative_files(left_path) + right_files = _relative_files(right_path) + added = tuple(sorted(right_files - left_files)) + removed = tuple(sorted(left_files - right_files)) + changed = tuple( + sorted( + rel + for rel in left_files & right_files + if _hash_file(left_path / rel) != _hash_file(right_path / rel) + ) + ) + return DirDiff(added=added, removed=removed, changed=changed) + + +def apply_dir_diff( + diff: DirDiff, + target: str | os.PathLike[str], + source: str | os.PathLike[str], +) -> None: + """Apply ``diff`` (generated relative to ``source``) onto ``target``. + + Added and changed files are copied from ``source``; removed files are + deleted from ``target``. All target-side paths are constrained with + :func:`safe_join` to prevent escape via symlink or ``..`` segments. 
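The added/removed/changed classification is just set algebra over per-file content hashes; a hypothetical stdlib-only `snapshot` reproduces what `diff_dirs` computes:

```python
import hashlib
import tempfile
from pathlib import Path

def snapshot(root: Path) -> dict[str, str]:
    # rel POSIX path -> content hash, mirroring diff_dirs' per-file hashing
    return {
        p.relative_to(root).as_posix(): hashlib.sha256(p.read_bytes()).hexdigest()
        for p in root.rglob("*")
        if p.is_file()
    }

left, right = Path(tempfile.mkdtemp()), Path(tempfile.mkdtemp())
(left / "same.txt").write_text("a")
(left / "gone.txt").write_text("b")
(left / "edit.txt").write_text("old")
(right / "same.txt").write_text("a")
(right / "new.txt").write_text("c")
(right / "edit.txt").write_text("new")

ls, rs = snapshot(left), snapshot(right)
added = sorted(rs.keys() - ls.keys())
removed = sorted(ls.keys() - rs.keys())
changed = sorted(k for k in ls.keys() & rs.keys() if ls[k] != rs[k])
```

Files present only on the right are added, only on the left are removed, and shared paths with differing hashes are changed; unchanged files drop out entirely.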
+ """ + source_path = Path(source) + target_path = Path(target) + if not source_path.is_dir(): + raise DiffException(f"source is not a directory: {source_path}") + target_path.mkdir(parents=True, exist_ok=True) + for rel in (*diff.added, *diff.changed): + dest = safe_join(target_path, rel) + src = source_path / rel + if not src.is_file(): + raise DiffException(f"patch source missing: {src}") + dest.parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(src, dest) + for rel in diff.removed: + dest = safe_join(target_path, rel) + if dest.is_file(): + dest.unlink() + + +def diff_text_files( + left: str | os.PathLike[str], + right: str | os.PathLike[str], + *, + context: int = 3, +) -> str: + """Return a unified diff between two text files.""" + left_path = Path(left) + right_path = Path(right) + try: + left_lines = left_path.read_text(encoding="utf-8").splitlines(keepends=True) + right_lines = right_path.read_text(encoding="utf-8").splitlines(keepends=True) + except OSError as error: + raise DiffException(f"cannot read diff inputs: {error}") from error + diff_lines = difflib.unified_diff( + left_lines, + right_lines, + fromfile=str(left_path), + tofile=str(right_path), + n=context, + ) + return "".join(diff_lines) + + +def _relative_files(root: Path) -> set[str]: + collected: set[str] = set() + for dirpath, _dirnames, filenames in os.walk(root, followlinks=False): + for name in filenames: + rel = Path(dirpath, name).relative_to(root) + collected.add(rel.as_posix()) + return collected + + +def _hash_file(path: Path) -> str: + hasher = hashlib.new(_HASH) + with open(path, "rb") as fh: + for chunk in iter(lambda: fh.read(_CHUNK), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +def iter_dir_diff(diff: DirDiff) -> Iterable[tuple[str, str]]: + """Yield ``(kind, rel_path)`` for every change in ``diff``.""" + for rel in diff.added: + yield "added", rel + for rel in diff.removed: + yield "removed", rel + for rel in diff.changed: + yield "changed", rel diff 
--git a/automation_file/local/mime.py b/automation_file/local/mime.py new file mode 100644 index 0000000..f92a502 --- /dev/null +++ b/automation_file/local/mime.py @@ -0,0 +1,82 @@ +"""MIME type detection by extension plus magic-byte sniffing. + +The stdlib ``mimetypes`` module covers the common cases from the filename +alone. For ambiguous or extensionless files, :func:`detect_mime` peeks at the +first few bytes and recognises a small set of well-known signatures. +""" + +from __future__ import annotations + +import mimetypes +import os +from pathlib import Path + +_SNIFF_LEN = 16 +_OCTET_STREAM = "application/octet-stream" +_SIGNATURES: tuple[tuple[bytes, str], ...] = ( + (b"\x89PNG\r\n\x1a\n", "image/png"), + (b"\xff\xd8\xff", "image/jpeg"), + (b"GIF87a", "image/gif"), + (b"GIF89a", "image/gif"), + (b"%PDF-", "application/pdf"), + (b"PK\x03\x04", "application/zip"), + (b"\x1f\x8b", "application/gzip"), + (b"BZh", "application/x-bzip2"), + (b"\xfd7zXZ\x00", "application/x-xz"), + (b"7z\xbc\xaf\x27\x1c", "application/x-7z-compressed"), + (b"Rar!\x1a\x07\x00", "application/vnd.rar"), + (b"Rar!\x1a\x07\x01\x00", "application/vnd.rar"), + (b"RIFF", _OCTET_STREAM), # overridden below for wav/webp + (b"\x00\x00\x00 ftyp", "video/mp4"), + (b"OggS", "application/ogg"), + (b"ID3", "audio/mpeg"), + (b"\xff\xfb", "audio/mpeg"), + (b"{\\rtf", "application/rtf"), + (b"SQLite format 3\x00", "application/vnd.sqlite3"), +) + + +def detect_mime(path: str | os.PathLike[str]) -> str: + """Return the most specific MIME type we can determine for ``path``. + + Tries filename-based detection first; on miss or ambiguous result + (``application/octet-stream``), sniffs the first ``_SNIFF_LEN`` bytes. 
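Signature matching is a straightforward prefix scan; here is a reduced sketch of the `_match_signatures` idea with three entries (the real table above also disambiguates RIFF containers into wav/webp):

```python
_SIGS: tuple[tuple[bytes, str], ...] = (
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    (b"%PDF-", "application/pdf"),
    (b"PK\x03\x04", "application/zip"),
)

def sniff(head: bytes) -> str:
    # First matching prefix wins; unknown bytes stay application/octet-stream.
    for signature, mime in _SIGS:
        if head.startswith(signature):
            return mime
    return "application/octet-stream"
```

Reading only the first few bytes keeps sniffing cheap even for large files, which is why `_SNIFF_LEN` is tiny.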
+ """ + p = Path(path) + guessed, _ = mimetypes.guess_type(p.name) + if guessed: + return guessed + sniffed = _sniff(p) + return sniffed or _OCTET_STREAM + + +def detect_from_bytes(data: bytes) -> str: + """MIME type of a byte blob using magic-byte sniffing.""" + mime = _match_signatures(data) + return mime or _OCTET_STREAM + + +def _sniff(path: Path) -> str | None: + if not path.is_file(): + return None + try: + with open(path, "rb") as fh: + head = fh.read(_SNIFF_LEN) + except OSError: + return None + return _match_signatures(head) + + +def _match_signatures(head: bytes) -> str | None: + if head.startswith(b"RIFF") and len(head) >= 12: + tag = head[8:12] + if tag == b"WAVE": + return "audio/wav" + if tag == b"WEBP": + return "image/webp" + for signature, mime in _SIGNATURES: + if signature == b"RIFF": + continue + if head.startswith(signature): + return mime + return None diff --git a/automation_file/local/tar_ops.py b/automation_file/local/tar_ops.py index 672e74e..36ab503 100644 --- a/automation_file/local/tar_ops.py +++ b/automation_file/local/tar_ops.py @@ -57,7 +57,8 @@ def create_tar( target_path.parent.mkdir(parents=True, exist_ok=True) try: - with tarfile.open(str(target_path), mode) as archive: + # Write mode — not extraction. + with tarfile.open(str(target_path), mode) as archive: # NOSONAR python:S5042 archive.add(str(src_path), arcname=src_path.name) except (OSError, tarfile.TarError) as err: raise TarException(f"create_tar failed: {err}") from err @@ -76,7 +77,9 @@ def extract_tar(source: str, target_dir: str) -> list[str]: extracted: list[str] = [] try: - with tarfile.open(str(src_path), "r:*") as archive: + # _verify_members rejects traversal / escaping symlinks / hardlinks before any + # extract, and PEP 706 filter="data" is applied when available (3.10.12+ / 3.11.4+ / 3.12+). 
+ with tarfile.open(str(src_path), "r:*") as archive: # NOSONAR python:S5042 _verify_members(archive, dest) for member in archive.getmembers(): if _TAR_FILTER_SUPPORTED: diff --git a/automation_file/local/templates.py b/automation_file/local/templates.py new file mode 100644 index 0000000..34052ab --- /dev/null +++ b/automation_file/local/templates.py @@ -0,0 +1,125 @@ +"""Template rendering for file content generation. + +Supports Jinja2 when installed (richer control flow, filters) and falls back +to :class:`string.Template` for simple ``$var`` substitution. Renders can +target an output path or return the resulting string. +""" + +from __future__ import annotations + +import os +import string +from pathlib import Path +from typing import Any + +from automation_file.exceptions import TemplateException + + +def render_string( + template: str, + context: dict[str, Any], + *, + use_jinja: bool = True, + autoescape: bool = True, +) -> str: + """Render ``template`` against ``context`` and return the string result. + + ``autoescape=True`` (the default) HTML-escapes substituted values when the + Jinja2 engine is used; pass ``autoescape=False`` only when the output is + known to target a non-HTML format. + """ + if use_jinja: + rendered = _render_with_jinja(template, context, autoescape=autoescape) + if rendered is not None: + return rendered + return _render_with_stdlib(template, context) + + +def render_file( + template_path: str | os.PathLike[str], + context: dict[str, Any], + output_path: str | os.PathLike[str] | None = None, + *, + use_jinja: bool | None = None, + autoescape: bool | None = None, +) -> str: + """Read a template file, render it, optionally write the result. + + ``use_jinja=None`` auto-detects: ``.j2`` / ``.jinja`` / ``.jinja2`` suffixes + opt in to Jinja2; other extensions use :class:`string.Template`. + + ``autoescape=None`` auto-detects based on the output path suffix — HTML / + XML targets enable escaping, others disable it. 
Pass a bool to override. + """ + src = Path(template_path) + if not src.is_file(): + raise TemplateException(f"template not found: {src}") + try: + source = src.read_text(encoding="utf-8") + except OSError as error: + raise TemplateException(f"cannot read template {src}: {error}") from error + jinja = _wants_jinja(src) if use_jinja is None else use_jinja + escape = _wants_autoescape(output_path, src) if autoescape is None else autoescape + rendered = render_string(source, context, use_jinja=jinja, autoescape=escape) + if output_path is not None: + dest = Path(output_path) + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_text(rendered, encoding="utf-8") + return rendered + + +_HTML_SUFFIXES = frozenset({".html", ".htm", ".xhtml", ".xml"}) + + +def _wants_jinja(path: Path) -> bool: + return path.suffix.lower() in {".j2", ".jinja", ".jinja2"} + + +def _wants_autoescape( + output_path: str | os.PathLike[str] | None, + template_path: Path, +) -> bool: + target = Path(output_path) if output_path is not None else template_path + suffix = target.suffix.lower() + if suffix in {".j2", ".jinja", ".jinja2"}: + suffix = target.with_suffix("").suffix.lower() + return suffix in _HTML_SUFFIXES + + +def _render_with_jinja( + template: str, + context: dict[str, Any], + *, + autoescape: bool, +) -> str | None: + try: + from jinja2 import StrictUndefined + from jinja2 import TemplateError as JinjaTemplateError + from jinja2.sandbox import ImmutableSandboxedEnvironment + from markupsafe import Markup + except ImportError: + return None + # ImmutableSandboxedEnvironment blocks access to Python internals + # (__class__, __globals__, __mro__, mutation of passed collections, …) so + # that a caller passing a user-supplied template cannot escape the sandbox + # — the standard Jinja2 mitigation for server-side template injection. + # autoescape=True is kept unconditional; callers opt out by pre-wrapping + # their string values in markupsafe.Markup, which Jinja renders verbatim. 
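For contrast with the sandboxed Jinja2 path, the stdlib fallback used by `_render_with_stdlib` amounts to plain `$var` substitution. A minimal sketch with illustrative values:

```python
from string import Template

# string.Template supports only $var / ${var} placeholders: no control flow,
# no filters, no escaping. substitute() raises KeyError for missing keys,
# which the module wraps in TemplateException.
result = Template("Hello $name, $count files processed").substitute(
    {"name": "Ada", "count": 3}
)
print(result)  # -> Hello Ada, 3 files processed
```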
+ env = ImmutableSandboxedEnvironment(autoescape=True, undefined=StrictUndefined) + if not autoescape: + context = { + key: Markup(value) if isinstance(value, str) else value + for key, value in context.items() + } + try: + # NOSONAR sandboxed env prevents SSTI escape (S5496 reviewed) + return env.from_string(template).render(**context) # NOSONAR + except JinjaTemplateError as error: + raise TemplateException(f"jinja render failed: {error}") from error + + +def _render_with_stdlib(template: str, context: dict[str, Any]) -> str: + try: + return string.Template(template).substitute(context) + except (KeyError, ValueError) as error: + raise TemplateException(f"template render failed: {error}") from error diff --git a/automation_file/local/trash.py b/automation_file/local/trash.py new file mode 100644 index 0000000..09cac36 --- /dev/null +++ b/automation_file/local/trash.py @@ -0,0 +1,141 @@ +"""Recoverable-delete / trash helpers. + +Moves files or directories into a caller-supplied trash directory instead of +permanent removal. Each trash entry keeps a JSON sidecar recording the +original path, so :func:`restore_from_trash` can return the content to its +source location. 
+""" + +from __future__ import annotations + +import json +import os +import shutil +import time +import uuid +from dataclasses import dataclass +from pathlib import Path + +from automation_file.exceptions import VersioningException + +_META_SUFFIX = ".trashmeta.json" + + +@dataclass(frozen=True) +class TrashEntry: + """An item present in a trash directory.""" + + trash_id: str + original: Path + trashed_at: float + path: Path + is_dir: bool + + +def send_to_trash( + path: str | os.PathLike[str], + trash_dir: str | os.PathLike[str], +) -> TrashEntry: + """Move ``path`` into ``trash_dir``; returns the new :class:`TrashEntry`.""" + source = Path(path) + if not source.exists(): + raise VersioningException(f"source does not exist: {source}") + bin_dir = Path(trash_dir) + bin_dir.mkdir(parents=True, exist_ok=True) + trash_id = f"{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}" + target = bin_dir / f"{trash_id}__{source.name}" + original_absolute = str(source.resolve()) + shutil.move(str(source), target) + trashed_at = time.time() + is_dir = target.is_dir() + meta: dict[str, object] = { + "trash_id": trash_id, + "original": original_absolute, + "trashed_at": trashed_at, + "is_dir": is_dir, + } + meta_path = bin_dir / f"{trash_id}{_META_SUFFIX}" + meta_path.write_text(json.dumps(meta), encoding="utf-8") + return TrashEntry( + trash_id=trash_id, + original=Path(original_absolute), + trashed_at=trashed_at, + path=target, + is_dir=is_dir, + ) + + +def list_trash(trash_dir: str | os.PathLike[str]) -> list[TrashEntry]: + """Return every item currently present in ``trash_dir``.""" + bin_dir = Path(trash_dir) + if not bin_dir.is_dir(): + return [] + entries: list[TrashEntry] = [] + for meta in bin_dir.glob(f"*{_META_SUFFIX}"): + entry = _read_meta(meta) + if entry is not None: + entries.append(entry) + entries.sort(key=lambda item: item.trashed_at) + return entries + + +def restore_from_trash( + trash_id: str, + trash_dir: str | os.PathLike[str], + *, + destination: str | 
os.PathLike[str] | None = None, +) -> Path: + """Move a trashed item back to its original location (or ``destination``).""" + bin_dir = Path(trash_dir) + meta_path = bin_dir / f"{trash_id}{_META_SUFFIX}" + entry = _read_meta(meta_path) if meta_path.is_file() else None + if entry is None: + raise VersioningException(f"no trash entry with id {trash_id!r}") + target = Path(destination) if destination is not None else entry.original + target.parent.mkdir(parents=True, exist_ok=True) + if target.exists(): + raise VersioningException(f"restore target already exists: {target}") + shutil.move(str(entry.path), target) + meta_path.unlink(missing_ok=True) + return target + + +def empty_trash(trash_dir: str | os.PathLike[str]) -> int: + """Permanently delete everything in ``trash_dir``. Returns items removed.""" + bin_dir = Path(trash_dir) + if not bin_dir.is_dir(): + return 0 + removed = 0 + for child in bin_dir.iterdir(): + if child.is_dir(): + shutil.rmtree(child) + else: + child.unlink() + removed += 1 + return removed + + +def _read_meta(meta_path: Path) -> TrashEntry | None: + try: + payload = json.loads(meta_path.read_text(encoding="utf-8")) + except (OSError, json.JSONDecodeError): + return None + trash_id = payload.get("trash_id") + original = payload.get("original") + if not trash_id or not original: + return None + bin_dir = meta_path.parent + matches = [ + p + for p in bin_dir.iterdir() + if p.name.startswith(f"{trash_id}__") and not p.name.endswith(_META_SUFFIX) + ] + if not matches: + return None + return TrashEntry( + trash_id=str(trash_id), + original=Path(str(original)), + trashed_at=float(payload.get("trashed_at", 0.0)), + path=matches[0], + is_dir=bool(payload.get("is_dir", matches[0].is_dir())), + ) diff --git a/automation_file/local/versioning.py b/automation_file/local/versioning.py new file mode 100644 index 0000000..d75db91 --- /dev/null +++ b/automation_file/local/versioning.py @@ -0,0 +1,117 @@ +"""Simple file versioning store. 
+ +:class:`FileVersioner` keeps numbered snapshots of files under a versions +directory. Callers snapshot a file before mutating it, list prior versions, +restore one, or prune to keep only the most recent ``keep`` copies. +""" + +from __future__ import annotations + +import os +import re +import shutil +import time +from dataclasses import dataclass +from pathlib import Path + +from automation_file.exceptions import VersioningException + +_VERSION_RE = re.compile(r"^v(\d+)__(\d+)$") + + +@dataclass(frozen=True) +class VersionEntry: + """One snapshot returned by :meth:`FileVersioner.list_versions`.""" + + version: int + timestamp: float + path: Path + + +class FileVersioner: + """Store numbered snapshots beneath ``root``. + + Each source file is versioned in its own subdirectory so multiple files + can coexist. The subdirectory name is the source path's POSIX form with + path separators replaced by ``__sep__`` to flatten safely. + """ + + def __init__(self, root: str | os.PathLike[str]) -> None: + self._root = Path(root) + self._root.mkdir(parents=True, exist_ok=True) + + @property + def root(self) -> Path: + return self._root + + def save_version(self, path: str | os.PathLike[str]) -> VersionEntry: + """Snapshot the file at ``path`` and return the new entry.""" + src = Path(path) + if not src.is_file(): + raise VersioningException(f"source is not a file: {src}") + bucket = self._bucket_for(src) + bucket.mkdir(parents=True, exist_ok=True) + next_version = self._next_version(bucket) + timestamp_ns = time.time_ns() + target = bucket / f"v{next_version:06d}__{timestamp_ns}" + shutil.copy2(src, target) + return VersionEntry(version=next_version, timestamp=timestamp_ns / 1e9, path=target) + + def list_versions(self, path: str | os.PathLike[str]) -> list[VersionEntry]: + """Return every recorded snapshot of ``path``, oldest first.""" + bucket = self._bucket_for(Path(path)) + if not bucket.is_dir(): + return [] + entries: list[VersionEntry] = [] + for child in 
bucket.iterdir(): + if not child.is_file(): + continue + match = _VERSION_RE.match(child.name) + if not match: + continue + version = int(match.group(1)) + timestamp = int(match.group(2)) / 1e9 + entries.append(VersionEntry(version=version, timestamp=timestamp, path=child)) + entries.sort(key=lambda entry: entry.version) + return entries + + def restore(self, path: str | os.PathLike[str], version: int) -> None: + """Restore ``path`` from the snapshot with the given version number.""" + for entry in self.list_versions(path): + if entry.version == version: + shutil.copy2(entry.path, path) + return + raise VersioningException(f"no version {version} for {path}") + + def prune(self, path: str | os.PathLike[str], keep: int) -> int: + """Keep only the ``keep`` most recent versions; return rows deleted.""" + if keep < 0: + raise VersioningException("keep must be >= 0") + entries = self.list_versions(path) + if len(entries) <= keep: + return 0 + victims = entries[: len(entries) - keep] + for entry in victims: + entry.path.unlink(missing_ok=True) + return len(victims) + + def _bucket_for(self, src: Path) -> Path: + safe = _flatten_path(src) + return self._root / safe + + def _next_version(self, bucket: Path) -> int: + highest = 0 + for child in bucket.iterdir(): + match = _VERSION_RE.match(child.name) + if match: + highest = max(highest, int(match.group(1))) + return highest + 1 + + +def _flatten_path(src: Path) -> str: + # Resolve to absolute, strip drive letter on Windows, collapse separators. 
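The flattening that `_flatten_path` performs can be illustrated in isolation. A hypothetical standalone re-implementation (POSIX output shown; the real helper uses `Path.resolve()`, which also resolves symlinks):

```python
import os

def flatten_path(path: str) -> str:
    # Hypothetical sketch of the bucket-naming scheme: absolutize the path,
    # drop the drive-letter colon on Windows, and replace every separator
    # with "__sep__" so the result is a single filesystem-safe name.
    drive, body = os.path.splitdrive(os.path.abspath(path))
    flat = (drive.replace(":", "") + body).replace(os.sep, "__sep__")
    flat = flat.replace("/", "__sep__")
    return flat.strip("_") or "root"

print(flatten_path("/var/log/app.log"))  # -> sep__var__sep__log__sep__app.log
```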
+    resolved = src.resolve()
+    drive, body = os.path.splitdrive(resolved)
+    flat = (drive.replace(":", "") + body).replace(os.sep, "__sep__")
+    flat = flat.replace("/", "__sep__")
+    return flat.strip("_") or "root"
diff --git a/automation_file/remote/fsspec_bridge.py b/automation_file/remote/fsspec_bridge.py
new file mode 100644
index 0000000..55d42df
--- /dev/null
+++ b/automation_file/remote/fsspec_bridge.py
@@ -0,0 +1,143 @@
+"""Bridge helpers that expose fsspec backends through the same verbs as our
+native clients.
+
+fsspec already implements a large catalogue of filesystems — memory, local,
+HTTP, GCS, ABFS, SSH, and more. Rather than reimplement each one, this
+module gives callers a tiny surface (``upload`` / ``download`` / ``exists`` /
+``list_dir`` / ``delete`` / ``mkdir``) over any ``fsspec`` URL. The ``fsspec``
+import is lazy so installing the package is only required when the bridge
+is actually used.
+
+This is a **developer helper**, not a user-input surface. Callers are
+responsible for validating URLs before handing them in — there is no SSRF
+guard here because fsspec supports dozens of schemes, many of which bypass
+the ``http(s)`` validator entirely (``ssh://``, ``s3://``, ``gcs://``, …).
+"""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from automation_file.exceptions import FsspecException
+
+
+@dataclass(frozen=True)
+class FsspecEntry:
+    """A single directory listing entry returned by :func:`fsspec_list_dir`."""
+
+    name: str
+    is_dir: bool
+    size: int | None
+
+
+def _import_fsspec() -> Any:
+    try:
+        import fsspec
+    except ImportError as error:
+        raise FsspecException(
+            "fsspec import failed — install `fsspec` (and any backend extras) to use the bridge"
+        ) from error
+    return fsspec
+
+
+def get_fs(url_or_protocol: str, **storage_options: Any) -> Any:
+    """Return an :class:`fsspec.AbstractFileSystem` for ``url_or_protocol``.
+ + Pass either a bare protocol (``"s3"``, ``"memory"``) or a full URL — + fsspec's ``url_to_fs`` will extract the protocol and pass ``storage_options`` + through to the backend constructor. + """ + fsspec = _import_fsspec() + try: + if "://" in url_or_protocol: + fs, _ = fsspec.core.url_to_fs(url_or_protocol, **storage_options) + return fs + return fsspec.filesystem(url_or_protocol, **storage_options) + except Exception as error: + raise FsspecException( + f"could not resolve fsspec filesystem for {url_or_protocol!r}: {error}" + ) from error + + +def _split(url: str) -> tuple[Any, str]: + fsspec = _import_fsspec() + try: + fs, path = fsspec.core.url_to_fs(url) + except Exception as error: + raise FsspecException(f"invalid fsspec url {url!r}: {error}") from error + return fs, path + + +def fsspec_exists(url: str) -> bool: + """Return True if ``url`` exists on its backing fsspec filesystem.""" + fs, path = _split(url) + try: + return bool(fs.exists(path)) + except Exception as error: + raise FsspecException(f"exists failed for {url!r}: {error}") from error + + +def fsspec_upload(local_path: str | os.PathLike[str], url: str) -> None: + """Copy ``local_path`` onto the fsspec target at ``url``.""" + source = Path(local_path) + if not source.is_file(): + raise FsspecException(f"local source is not a file: {source}") + fs, path = _split(url) + try: + fs.put_file(str(source), path) + except Exception as error: + raise FsspecException(f"upload failed for {url!r}: {error}") from error + + +def fsspec_download(url: str, local_path: str | os.PathLike[str]) -> None: + """Download the fsspec resource at ``url`` to ``local_path``.""" + dest = Path(local_path) + dest.parent.mkdir(parents=True, exist_ok=True) + fs, path = _split(url) + try: + fs.get_file(path, str(dest)) + except Exception as error: + raise FsspecException(f"download failed for {url!r}: {error}") from error + + +def fsspec_delete(url: str, *, recursive: bool = False) -> None: + """Remove ``url`` from its fsspec 
filesystem.""" + fs, path = _split(url) + try: + fs.rm(path, recursive=recursive) + except Exception as error: + raise FsspecException(f"delete failed for {url!r}: {error}") from error + + +def fsspec_mkdir(url: str, *, create_parents: bool = True) -> None: + """Create the directory at ``url`` (optionally including parents).""" + fs, path = _split(url) + try: + fs.makedirs(path, exist_ok=True) if create_parents else fs.mkdir(path) + except Exception as error: + raise FsspecException(f"mkdir failed for {url!r}: {error}") from error + + +def fsspec_list_dir(url: str) -> list[FsspecEntry]: + """Return a shallow listing of ``url`` as :class:`FsspecEntry` records.""" + fs, path = _split(url) + try: + raw = fs.ls(path, detail=True) + except Exception as error: + raise FsspecException(f"list_dir failed for {url!r}: {error}") from error + entries: list[FsspecEntry] = [] + for item in raw: + if isinstance(item, str): + entries.append(FsspecEntry(name=item.rsplit("/", 1)[-1], is_dir=False, size=None)) + continue + raw_name = item.get("name", "") + name = str(raw_name).rsplit("/", 1)[-1] + kind = str(item.get("type", "file")) + is_dir = kind == "directory" + size_obj = item.get("size") + size: int | None = int(size_obj) if isinstance(size_obj, int) and not is_dir else None + entries.append(FsspecEntry(name=name, is_dir=is_dir, size=size)) + return entries diff --git a/automation_file/remote/ftp/client.py b/automation_file/remote/ftp/client.py index 142f826..2e6f0a0 100644 --- a/automation_file/remote/ftp/client.py +++ b/automation_file/remote/ftp/client.py @@ -10,7 +10,7 @@ import contextlib from dataclasses import dataclass -from ftplib import FTP, FTP_TLS +from ftplib import FTP, FTP_TLS # nosec B321 - plaintext FTP is opt-in via tls=False from typing import Any from automation_file.exceptions import FileAutomationException @@ -42,7 +42,12 @@ def __init__(self) -> None: def later_init(self, options: FTPConnectOptions | None = None, **kwargs: Any) -> FTP: """Open an FTP 
control connection. TLS is negotiated when ``tls=True``.""" opts = options if options is not None else FTPConnectOptions(**kwargs) - ftp: FTP = FTP_TLS(timeout=opts.timeout) if opts.tls else FTP(timeout=opts.timeout) + # Plaintext FTP is opt-in via tls=False; FTPS is the default when tls=True. + if opts.tls: + ftp: FTP = FTP_TLS(timeout=opts.timeout) + else: + # Plaintext FTP only when caller opts in via tls=False. + ftp = FTP(timeout=opts.timeout) # nosec # NOSONAR opt-in via tls=False try: ftp.connect(opts.host, opts.port, timeout=opts.timeout) if opts.tls and isinstance(ftp, FTP_TLS): diff --git a/automation_file/remote/smb/__init__.py b/automation_file/remote/smb/__init__.py new file mode 100644 index 0000000..eb08633 --- /dev/null +++ b/automation_file/remote/smb/__init__.py @@ -0,0 +1,7 @@ +"""SMB / CIFS client.""" + +from __future__ import annotations + +from automation_file.remote.smb.client import SMBClient, SMBEntry + +__all__ = ["SMBClient", "SMBEntry"] diff --git a/automation_file/remote/smb/client.py b/automation_file/remote/smb/client.py new file mode 100644 index 0000000..9c71e04 --- /dev/null +++ b/automation_file/remote/smb/client.py @@ -0,0 +1,215 @@ +"""SMB / CIFS client built on ``smbprotocol``'s high-level ``smbclient`` API. + +Scope mirrors :mod:`automation_file.remote.webdav.client` — existence check, +upload, download, delete, directory create, and shallow listing. The +underlying session is registered per ``(server, username)`` pair and torn +down when :meth:`SMBClient.close` runs. ``smbprotocol`` is imported lazily so +importing this module never touches the optional dependency. 
+""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from types import TracebackType +from typing import Any + +from automation_file.exceptions import SMBException + +_DEFAULT_PORT = 445 +_CHUNK_SIZE = 1 << 16 + + +@dataclass(frozen=True) +class SMBEntry: + """A single directory listing entry returned by :meth:`SMBClient.list_dir`.""" + + name: str + is_dir: bool + size: int | None + + +def _import_smbclient() -> Any: + try: + import smbclient + except ImportError as error: + raise SMBException( + "smbprotocol import failed — install `smbprotocol` to use the SMB backend" + ) from error + return smbclient + + +class SMBClient: + """Minimal SMB client scoped to the operations used by this project.""" + + def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments + self, + server: str, + share: str, + username: str | None = None, + password: str | None = None, + *, + port: int = _DEFAULT_PORT, + encrypt: bool = True, + connection_timeout: float = 30.0, + ) -> None: + if not server or not share: + raise SMBException("server and share are required") + self._server = server + self._share = share.strip("\\/") + self._username = username + self._password = password + self._port = port + self._encrypt = encrypt + self._connection_timeout = connection_timeout + self._registered = False + + def __enter__(self) -> SMBClient: + self._ensure_session() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + if not self._registered: + return + smbclient = _import_smbclient() + try: + smbclient.delete_session(self._server, port=self._port) + except Exception as error: + raise SMBException(f"failed to close SMB session to {self._server}: {error}") from error + finally: + self._registered = False + + def _ensure_session(self) -> None: + if self._registered: + 
return + smbclient = _import_smbclient() + try: + smbclient.register_session( + self._server, + username=self._username, + password=self._password, + port=self._port, + encrypt=self._encrypt, + connection_timeout=self._connection_timeout, + ) + except Exception as error: + raise SMBException( + f"failed to register SMB session to {self._server}: {error}" + ) from error + self._registered = True + + def _unc(self, remote_path: str) -> str: + cleaned = remote_path.replace("/", "\\").strip("\\") + base = f"\\\\{self._server}\\{self._share}" + if not cleaned: + return base + return f"{base}\\{cleaned}" + + def exists(self, remote_path: str) -> bool: + """Return True if the remote path exists.""" + self._ensure_session() + smbclient = _import_smbclient() + try: + smbclient.stat(self._unc(remote_path)) + except FileNotFoundError: + return False + except OSError as error: + raise SMBException(f"stat failed for {remote_path}: {error}") from error + return True + + def upload(self, local_path: str | os.PathLike[str], remote_path: str) -> None: + """Copy the contents of ``local_path`` to ``remote_path`` on the share.""" + source = Path(local_path) + if not source.is_file(): + raise SMBException(f"local source is not a file: {source}") + self._ensure_session() + smbclient = _import_smbclient() + try: + with ( + open(source, "rb") as src, + smbclient.open_file(self._unc(remote_path), mode="wb") as dst, + ): + while True: + chunk = src.read(_CHUNK_SIZE) + if not chunk: + break + dst.write(chunk) + except OSError as error: + raise SMBException(f"upload failed for {remote_path}: {error}") from error + + def download(self, remote_path: str, local_path: str | os.PathLike[str]) -> None: + """Stream the remote resource at ``remote_path`` to ``local_path``.""" + dest = Path(local_path) + dest.parent.mkdir(parents=True, exist_ok=True) + self._ensure_session() + smbclient = _import_smbclient() + try: + with ( + smbclient.open_file(self._unc(remote_path), mode="rb") as src, + open(dest, 
"wb") as out, + ): + while True: + chunk = src.read(_CHUNK_SIZE) + if not chunk: + break + out.write(chunk) + except OSError as error: + raise SMBException(f"download failed for {remote_path}: {error}") from error + + def delete(self, remote_path: str) -> None: + """Remove the remote file at ``remote_path``.""" + self._ensure_session() + smbclient = _import_smbclient() + try: + smbclient.remove(self._unc(remote_path)) + except OSError as error: + raise SMBException(f"delete failed for {remote_path}: {error}") from error + + def mkdir(self, remote_path: str) -> None: + """Create the remote directory at ``remote_path`` (parents must exist).""" + self._ensure_session() + smbclient = _import_smbclient() + try: + smbclient.makedirs(self._unc(remote_path), exist_ok=True) + except OSError as error: + raise SMBException(f"mkdir failed for {remote_path}: {error}") from error + + def rmdir(self, remote_path: str) -> None: + """Remove the empty remote directory at ``remote_path``.""" + self._ensure_session() + smbclient = _import_smbclient() + try: + smbclient.rmdir(self._unc(remote_path)) + except OSError as error: + raise SMBException(f"rmdir failed for {remote_path}: {error}") from error + + def list_dir(self, remote_path: str) -> list[SMBEntry]: + """Return a shallow listing of ``remote_path`` (non-recursive).""" + self._ensure_session() + smbclient = _import_smbclient() + try: + dir_entries = list(smbclient.scandir(self._unc(remote_path))) + except OSError as error: + raise SMBException(f"list_dir failed for {remote_path}: {error}") from error + entries: list[SMBEntry] = [] + for item in dir_entries: + is_dir = bool(item.is_dir()) + size: int | None + if is_dir: + size = None + else: + try: + size = int(item.stat().st_size) + except OSError: + size = None + entries.append(SMBEntry(name=item.name, is_dir=is_dir, size=size)) + return entries diff --git a/automation_file/remote/url_validator.py b/automation_file/remote/url_validator.py index ca1c21a..00d0695 100644 --- 
a/automation_file/remote/url_validator.py +++ b/automation_file/remote/url_validator.py @@ -47,14 +47,22 @@ def _is_disallowed_ip(ip_obj: ipaddress.IPv4Address | ipaddress.IPv6Address) -> ) -def validate_http_url(url: str) -> str: - """Return ``url`` if safe; raise :class:`UrlValidationException` otherwise.""" +def validate_http_url(url: str, *, allow_private: bool = False) -> str: + """Return ``url`` if safe; raise :class:`UrlValidationException` otherwise. + + ``allow_private=True`` relaxes the private/loopback/link-local checks for + callers that need to reach LAN services (e.g. on-prem WebDAV). Scheme and + host checks still apply. Callers must opt in explicitly — the default + remains strict SSRF blocking. + """ host = _require_host(url) for ip_str in _resolve_ips(host): try: ip_obj = ipaddress.ip_address(ip_str) except ValueError as error: raise UrlValidationException(f"cannot parse resolved ip: {ip_str}") from error - if _is_disallowed_ip(ip_obj): + if not allow_private and _is_disallowed_ip(ip_obj): raise UrlValidationException(f"disallowed ip: {ip_str}") + if allow_private and (ip_obj.is_multicast or ip_obj.is_unspecified): + raise UrlValidationException(f"disallowed ip even in permissive mode: {ip_str}") return url diff --git a/automation_file/remote/webdav/__init__.py b/automation_file/remote/webdav/__init__.py new file mode 100644 index 0000000..1115cfd --- /dev/null +++ b/automation_file/remote/webdav/__init__.py @@ -0,0 +1,7 @@ +"""WebDAV client.""" + +from __future__ import annotations + +from automation_file.remote.webdav.client import WebDAVClient, WebDAVEntry + +__all__ = ["WebDAVClient", "WebDAVEntry"] diff --git a/automation_file/remote/webdav/client.py b/automation_file/remote/webdav/client.py new file mode 100644 index 0000000..08b0e8c --- /dev/null +++ b/automation_file/remote/webdav/client.py @@ -0,0 +1,200 @@ +"""WebDAV client built on ``requests``. 
+
+Supports the minimal set used for file automation — ``PUT`` upload, ``GET``
+download, ``DELETE``, ``MKCOL`` directory create, ``HEAD`` existence check, and
+``PROPFIND`` listing. All URLs pass through
+:func:`automation_file.remote.url_validator.validate_http_url`; private /
+loopback hosts require ``allow_private_hosts=True``.
+"""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from types import TracebackType
+from urllib.parse import quote, unquote, urlparse
+
+import requests
+from defusedxml.ElementTree import ParseError as DefusedParseError
+from defusedxml.ElementTree import fromstring as defused_fromstring
+
+from automation_file.exceptions import WebDAVException
+from automation_file.remote.url_validator import validate_http_url
+
+_DAV_NS = "{DAV:}"
+_DEFAULT_TIMEOUT = 30.0
+_ABSOLUTE_URL_PREFIXES = ("http" + "://", "https://")
+# Depth-1 PROPFIND body requesting the three properties _parse_propfind reads.
+_PROPFIND_BODY = (
+    '<?xml version="1.0" encoding="utf-8"?>'
+    '<propfind xmlns="DAV:">'
+    "<prop>"
+    "<resourcetype/><getcontentlength/>"
+    "<getlastmodified/>"
+    "</prop></propfind>"
+)
+
+
+@dataclass(frozen=True)
+class WebDAVEntry:
+    """A single directory listing entry returned by :meth:`WebDAVClient.list_dir`."""
+
+    href: str
+    name: str
+    is_dir: bool
+    size: int | None
+    last_modified: str | None
+
+
+class WebDAVClient:
+    """Minimal WebDAV client scoped to the operations used by this project."""
+
+    def __init__(
+        self,
+        base_url: str,
+        username: str | None = None,
+        password: str | None = None,
+        *,
+        allow_private_hosts: bool = False,
+        timeout: float = _DEFAULT_TIMEOUT,
+        verify_tls: bool = True,
+    ) -> None:
+        validate_http_url(base_url, allow_private=allow_private_hosts)
+        self._base_url = base_url.rstrip("/")
+        self._auth: tuple[str, str] | None = (
+            (username, password) if username is not None and password is not None else None
+        )
+        self._timeout = timeout
+        self._verify_tls = verify_tls
+        self._session = requests.Session()
+
+    def __enter__(self) -> WebDAVClient:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc: BaseException | None,
+
tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + self._session.close() + + def _url_for(self, remote_path: str) -> str: + remote_path = remote_path.strip() + if remote_path.startswith(_ABSOLUTE_URL_PREFIXES): + return remote_path + remote_path = remote_path.lstrip("/") + if not remote_path: + return self._base_url + "/" + return f"{self._base_url}/{quote(remote_path, safe='/')}" + + def _request(self, method: str, remote_path: str, **kwargs: object) -> requests.Response: + url = self._url_for(remote_path) + try: + response = self._session.request( + method, + url, + auth=self._auth, + timeout=self._timeout, + verify=self._verify_tls, + **kwargs, + ) + except requests.RequestException as error: + raise WebDAVException(f"{method} {url} failed: {error}") from error + if response.status_code >= 400: + response.close() + raise WebDAVException( + f"{method} {url} -> HTTP {response.status_code}: {response.reason}" + ) + return response + + def exists(self, remote_path: str) -> bool: + """Return True if the remote resource exists (HEAD 200-299).""" + url = self._url_for(remote_path) + try: + response = self._session.request( + "HEAD", + url, + auth=self._auth, + timeout=self._timeout, + verify=self._verify_tls, + ) + except requests.RequestException as error: + raise WebDAVException(f"HEAD {url} failed: {error}") from error + response.close() + return 200 <= response.status_code < 300 + + def upload(self, local_path: str | os.PathLike[str], remote_path: str) -> None: + """PUT the contents of ``local_path`` to ``remote_path``.""" + source = Path(local_path) + if not source.is_file(): + raise WebDAVException(f"local source is not a file: {source}") + with open(source, "rb") as fh: + response = self._request("PUT", remote_path, data=fh) + response.close() + + def download(self, remote_path: str, local_path: str | os.PathLike[str]) -> None: + """GET the remote resource and stream it to ``local_path``.""" + dest = Path(local_path) + 
dest.parent.mkdir(parents=True, exist_ok=True) + response = self._request("GET", remote_path, stream=True) + try: + with open(dest, "wb") as out: + for chunk in response.iter_content(chunk_size=1 << 16): + if chunk: + out.write(chunk) + finally: + response.close() + + def delete(self, remote_path: str) -> None: + """DELETE the remote resource.""" + response = self._request("DELETE", remote_path) + response.close() + + def mkcol(self, remote_path: str) -> None: + """MKCOL — create a collection (directory) at the remote path.""" + response = self._request("MKCOL", remote_path) + response.close() + + def list_dir(self, remote_path: str) -> list[WebDAVEntry]: + """PROPFIND depth=1 against ``remote_path`` and return its entries.""" + headers = {"Depth": "1", "Content-Type": 'application/xml; charset="utf-8"'} + response = self._request( + "PROPFIND", + remote_path, + data=_PROPFIND_BODY, + headers=headers, + ) + try: + payload = response.text + finally: + response.close() + return _parse_propfind(payload) + + +def _parse_propfind(xml_text: str) -> list[WebDAVEntry]: + try: + root = defused_fromstring(xml_text) + except DefusedParseError as error: + raise WebDAVException(f"malformed PROPFIND response: {error}") from error + entries: list[WebDAVEntry] = [] + for response in root.findall(f"{_DAV_NS}response"): + href_elem = response.find(f"{_DAV_NS}href") + if href_elem is None or href_elem.text is None: + continue + href = href_elem.text.strip() + is_dir = response.find(f".//{_DAV_NS}collection") is not None + size_elem = response.find(f".//{_DAV_NS}getcontentlength") + size = int(size_elem.text) if size_elem is not None and size_elem.text else None + modified_elem = response.find(f".//{_DAV_NS}getlastmodified") + modified = ( + modified_elem.text.strip() if modified_elem is not None and modified_elem.text else None + ) + name = unquote(urlparse(href).path.rstrip("/").rsplit("/", 1)[-1]) + entries.append( + WebDAVEntry(href=href, name=name, is_dir=is_dir, size=size, 
last_modified=modified) + ) + return entries diff --git a/automation_file/server/_websocket.py b/automation_file/server/_websocket.py new file mode 100644 index 0000000..1e27a0e --- /dev/null +++ b/automation_file/server/_websocket.py @@ -0,0 +1,98 @@ +"""Minimal server-side WebSocket helpers (RFC 6455). + +Scope is intentionally narrow: we only need to (a) complete the opening +handshake and (b) send server-to-client text frames. We never parse inbound +frames beyond detecting a close — the ``/progress`` stream is write-only. + +Keeping this off the ``websockets`` third-party dep preserves the stdlib +footprint of the HTTP server. +""" + +from __future__ import annotations + +import base64 +import hashlib +import os +import struct +from typing import Any + +_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def compute_accept_key(sec_websocket_key: str) -> str: + """Return the ``Sec-WebSocket-Accept`` value for ``sec_websocket_key``. + + RFC 6455 mandates SHA-1 + a fixed GUID for the opening handshake — the + digest only proves the server understood the handshake, it is not a + security primitive. ``usedforsecurity=False`` tells static analysers to + skip the standard SHA-1 "insecure hash" warning. 
+ """ + # NOSONAR RFC 6455 handshake is not a security primitive + digest = hashlib.sha1( # nosec B324 nosemgrep + (sec_websocket_key + _GUID).encode("ascii"), + usedforsecurity=False, + ).digest() + return base64.b64encode(digest).decode("ascii") + + +def send_text(wfile: Any, message: str) -> None: + """Write a single FIN text frame (server -> client, unmasked).""" + data = message.encode("utf-8") + header = bytearray([0x81]) + length = len(data) + if length < 126: + header.append(length) + elif length < (1 << 16): + header.append(126) + header.extend(struct.pack(">H", length)) + else: + header.append(127) + header.extend(struct.pack(">Q", length)) + wfile.write(bytes(header) + data) + wfile.flush() + + +def send_close(wfile: Any, code: int = 1000) -> None: + """Write a close frame (server -> client, unmasked).""" + payload = struct.pack(">H", code) + wfile.write(bytes([0x88, len(payload)]) + payload) + wfile.flush() + + +def read_frame_opcode(rfile: Any) -> int | None: + """Peek at one frame header and return its opcode, or ``None`` on EOF. + + The progress stream is write-only, but we still consume any client frame + (ping / close) so the TCP buffer does not fill up. Inbound client frames + are always masked per RFC 6455. 
+ """ + header = rfile.read(2) + if len(header) < 2: + return None + opcode = header[0] & 0x0F + length = header[1] & 0x7F + masked = bool(header[1] & 0x80) + if length == 126: + extra = rfile.read(2) + if len(extra) < 2: + return None + length = struct.unpack(">H", extra)[0] + elif length == 127: + extra = rfile.read(8) + if len(extra) < 8: + return None + length = struct.unpack(">Q", extra)[0] + if masked and len(rfile.read(4)) < 4: + return None + remaining = length + while remaining > 0: + chunk = rfile.read(min(remaining, 4096)) + if not chunk: + return None + remaining -= len(chunk) + return opcode + + +def generate_key() -> str: + """Produce a random ``Sec-WebSocket-Key`` value (used by tests).""" + return base64.b64encode(os.urandom(16)).decode("ascii") diff --git a/automation_file/server/http_server.py b/automation_file/server/http_server.py index f7bf960..16dd857 100644 --- a/automation_file/server/http_server.py +++ b/automation_file/server/http_server.py @@ -1,34 +1,52 @@ """HTTP action server (stdlib only). -Listens for ``POST /actions`` requests whose body is a JSON action list; the -response body is a JSON object mirroring :func:`execute_action`'s return -value. Bound to loopback by default with the same opt-in semantics as -:mod:`tcp_server`. When ``shared_secret`` is supplied clients must send -``Authorization: Bearer `` — useful when placing the server behind a -reverse proxy. +Accepts ``POST /actions`` whose body is a JSON action list; the response is +a JSON object mirroring :func:`execute_action`'s return value. Additional +observability endpoints: + +* ``GET /healthz`` — liveness (always 200 while the process is alive) +* ``GET /readyz`` — readiness (registry resolves + ACL intact) +* ``GET /openapi.json`` — OpenAPI 3.0 description of the above +* ``GET /progress`` — WebSocket stream of progress registry snapshots + +Bound to loopback by default with the same opt-in semantics as +:mod:`tcp_server`. 
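The handshake digest and frame layout implemented by these helpers can be sanity-checked against the worked example in RFC 6455 §1.3. A stdlib-only sketch, independent of the module above (`accept_key` and `text_frame` are illustrative stand-ins for `compute_accept_key` and `send_text`):

```python
import base64
import hashlib
import struct

_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def accept_key(sec_websocket_key: str) -> str:
    """SHA-1(key + GUID), base64 encoded: the RFC 6455 handshake proof value."""
    digest = hashlib.sha1((sec_websocket_key + _GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")


def text_frame(message: str) -> bytes:
    """A single FIN text frame, unmasked (server-to-client direction)."""
    data = message.encode("utf-8")
    header = bytearray([0x81])  # FIN bit + opcode 0x1 (text)
    if len(data) < 126:
        header.append(len(data))
    elif len(data) < (1 << 16):
        header.append(126)  # 126 flags a 16-bit extended length
        header.extend(struct.pack(">H", len(data)))
    else:
        header.append(127)  # 127 flags a 64-bit extended length
        header.extend(struct.pack(">Q", len(data)))
    return bytes(header) + data


# RFC 6455 section 1.3 worked example:
print(accept_key("dGhlIHNhbXBsZSBub25jZQ=="))  # s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
print(text_frame("Hi").hex())  # 81024869
```

The 126/127 length escapes are the part implementations most often get wrong; for a 200-byte payload the frame starts `81 7e 00 c8`.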
When ``shared_secret`` is supplied ``POST /actions`` and
+``/progress`` require ``Authorization: Bearer <token>`` — useful when
+placing the server behind a reverse proxy.
 """
 
 from __future__ import annotations
 
+import contextlib
 import hmac
 import json
 import threading
+import time
 from http import HTTPStatus
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 
-from automation_file.core.action_executor import execute_action
+from automation_file.core.action_executor import execute_action, executor
+from automation_file.core.progress import progress_registry
 from automation_file.exceptions import TCPAuthException
 from automation_file.logging_config import file_automation_logger
+from automation_file.server._websocket import (
+    compute_accept_key,
+    send_close,
+    send_text,
+)
 from automation_file.server.action_acl import ActionACL, ActionNotPermittedException
 from automation_file.server.network_guards import ensure_loopback
 
 _DEFAULT_HOST = "127.0.0.1"
 _DEFAULT_PORT = 9944
 _MAX_CONTENT_BYTES = 1 * 1024 * 1024
+_PROGRESS_POLL_SECONDS = 1.0
+_PROGRESS_MAX_FRAMES = 10_000
+_BEARER_PREFIX = "Bearer "
 
 
 class _HTTPActionHandler(BaseHTTPRequestHandler):
-    """POST /actions -> JSON results."""
+    """Routes: POST /actions, GET /{healthz,readyz,openapi.json,progress}."""
 
     def log_message(  # pylint: disable=arguments-differ
         self, format_str: str, *args: object
@@ -65,13 +83,30 @@ def do_POST(self) -> None:  # pylint: disable=invalid-name — BaseHTTPRequestHa
             return
         self._send_json(HTTPStatus.OK, results)
 
+    def do_GET(self) -> None:  # pylint: disable=invalid-name — BaseHTTPRequestHandler API
+        if self.path == "/healthz":
+            self._send_json(HTTPStatus.OK, {"status": "ok"})
+            return
+        if self.path == "/readyz":
+            ready, reason = _readiness()
+            status = HTTPStatus.OK if ready else HTTPStatus.SERVICE_UNAVAILABLE
+            self._send_json(status, {"status": "ready" if ready else "not_ready", "reason": reason})
+            return
+        if self.path == "/openapi.json":
+            self._send_json(HTTPStatus.OK,
_openapi_spec()) + return + if self.path == "/progress": + self._handle_progress_ws() + return + self._send_json(HTTPStatus.NOT_FOUND, {"error": "not found"}) + def _read_payload(self) -> list: secret: str | None = getattr(self.server, "shared_secret", None) if secret: header = self.headers.get("Authorization", "") - if not header.startswith("Bearer "): + if not header.startswith(_BEARER_PREFIX): raise TCPAuthException("missing bearer token") - if not hmac.compare_digest(header[len("Bearer ") :], secret): + if not hmac.compare_digest(header[len(_BEARER_PREFIX) :], secret): raise TCPAuthException("bad shared secret") try: @@ -97,6 +132,120 @@ def _send_json(self, status: HTTPStatus, data: object) -> None: self.end_headers() self.wfile.write(payload) + def _handle_progress_ws(self) -> None: + upgrade = self.headers.get("Upgrade", "").lower() + connection = self.headers.get("Connection", "").lower() + ws_key = self.headers.get("Sec-WebSocket-Key") + if upgrade != "websocket" or "upgrade" not in connection or not ws_key: + self._send_json(HTTPStatus.UPGRADE_REQUIRED, {"error": "websocket upgrade required"}) + return + + secret: str | None = getattr(self.server, "shared_secret", None) + if secret: + header = self.headers.get("Authorization", "") + token_ok = header.startswith(_BEARER_PREFIX) and hmac.compare_digest( + header[len(_BEARER_PREFIX) :], secret + ) + if not token_ok: + self._send_json(HTTPStatus.UNAUTHORIZED, {"error": "bad shared secret"}) + return + + accept = compute_accept_key(ws_key) + self.send_response(HTTPStatus.SWITCHING_PROTOCOLS) + self.send_header("Upgrade", "websocket") + self.send_header("Connection", "Upgrade") + self.send_header("Sec-WebSocket-Accept", accept) + self.end_headers() + self._stream_progress_frames() + + def _stream_progress_frames(self) -> None: + frames_sent = 0 + try: + while frames_sent < _PROGRESS_MAX_FRAMES: + snapshot = progress_registry.list() + send_text(self.wfile, json.dumps({"progress": snapshot}, default=repr)) + 
frames_sent += 1 + time.sleep(_PROGRESS_POLL_SECONDS) + except (BrokenPipeError, ConnectionResetError): + return + except Exception as error: # pylint: disable=broad-except + file_automation_logger.warning("http_server progress: %r", error) + finally: + with contextlib.suppress(OSError): + send_close(self.wfile) + + +def _readiness() -> tuple[bool, str]: + try: + if not executor.registry.event_dict: + return False, "registry empty" + except Exception as error: # pylint: disable=broad-except + return False, f"registry error: {error!r}" + return True, "ok" + + +def _openapi_spec() -> dict[str, object]: + return { + "openapi": "3.0.0", + "info": { + "title": "automation_file HTTP action server", + "version": "1.0.0", + "description": ( + "Executes JSON action lists and exposes health / readiness / progress endpoints." + ), + }, + "paths": { + "/actions": { + "post": { + "summary": "Execute a JSON action list.", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "array", + "items": {"type": "array"}, + } + } + }, + }, + "responses": { + "200": {"description": "Action results as a JSON object."}, + "400": {"description": "Malformed JSON body."}, + "401": {"description": "Missing or invalid shared-secret token."}, + "403": {"description": "Action denied by ACL."}, + "500": {"description": "Server error while dispatching."}, + }, + } + }, + "/healthz": { + "get": { + "summary": "Liveness probe.", + "responses": {"200": {"description": "Server process alive."}}, + } + }, + "/readyz": { + "get": { + "summary": "Readiness probe.", + "responses": { + "200": {"description": "Registry populated and accepting actions."}, + "503": {"description": "Not ready."}, + }, + } + }, + "/progress": { + "get": { + "summary": "WebSocket stream of progress registry snapshots.", + "responses": { + "101": {"description": "Switching protocols to WebSocket."}, + "401": {"description": "Missing or invalid shared-secret token."}, + "426": 
{"description": "WebSocket upgrade required."},
+                    },
+                }
+            },
+        },
+    }
+
 
 class HTTPActionServer(ThreadingHTTPServer):
     """Threaded HTTP server carrying an optional shared secret."""
diff --git a/automation_file/server/mcp_server.py b/automation_file/server/mcp_server.py
new file mode 100644
index 0000000..fb4088f
--- /dev/null
+++ b/automation_file/server/mcp_server.py
@@ -0,0 +1,292 @@
+"""Model Context Protocol (MCP) server bridge.
+
+Exposes every :class:`~automation_file.core.action_registry.ActionRegistry`
+entry as an MCP tool over JSON-RPC 2.0. The default transport is stdio —
+one JSON message per line — because that's what MCP host implementations
+(Claude Desktop, MCP CLIs) consume today.
+
+Scope
+-----
+* ``initialize`` — handshake, returns ``serverInfo`` + capabilities
+* ``notifications/initialized`` — acknowledged as a no-op
+* ``tools/list`` — lists registered actions as MCP tools
+* ``tools/call`` — dispatches through the action registry
+
+Errors surface as JSON-RPC error objects whose ``code`` separates protocol
+faults (parse, invalid request, unknown method) from tool failures, so
+hosts can render them without having to parse the exception string.
+""" + +from __future__ import annotations + +import argparse +import inspect +import json +import sys +from collections.abc import Callable, Iterable, Sequence +from typing import Any, TextIO + +from automation_file.core.action_executor import executor +from automation_file.core.action_registry import ActionRegistry +from automation_file.exceptions import MCPServerException +from automation_file.logging_config import file_automation_logger + +_JSONRPC_VERSION = "2.0" +_PROTOCOL_VERSION = "2024-11-05" + +_PARSE_ERROR = -32700 +_INVALID_REQUEST = -32600 +_METHOD_NOT_FOUND = -32601 +_INVALID_PARAMS = -32602 +_INTERNAL_ERROR = -32603 + + +class MCPServer: + """Bridge between an MCP host and an :class:`ActionRegistry`.""" + + def __init__( + self, + registry: ActionRegistry | None = None, + *, + name: str = "automation_file", + version: str = "1.0.0", + ) -> None: + self._registry = registry if registry is not None else executor.registry + self._name = name + self._version = version + self._initialized = False + + def handle_message(self, message: dict[str, Any]) -> dict[str, Any] | None: + """Dispatch a single decoded JSON-RPC message. + + Returns the response dict for request messages, or ``None`` for + notifications (which get no reply). Protocol-level errors return a + JSON-RPC error object rather than raising. 
+ """ + if not isinstance(message, dict) or message.get("jsonrpc") != _JSONRPC_VERSION: + return _error_response(None, _INVALID_REQUEST, "invalid JSON-RPC envelope") + + method = message.get("method") + msg_id = message.get("id") + params = message.get("params") or {} + + if not isinstance(method, str): + return _error_response(msg_id, _INVALID_REQUEST, "missing method") + + is_notification = msg_id is None + try: + if method == "initialize": + result = self._handle_initialize(params) + elif method == "notifications/initialized": + self._initialized = True + return None + elif method == "tools/list": + result = self._handle_tools_list() + elif method == "tools/call": + result = self._handle_tools_call(params) + else: + return _error_response(msg_id, _METHOD_NOT_FOUND, f"unknown method: {method}") + except MCPServerException as error: + return _error_response(msg_id, _INVALID_PARAMS, str(error)) + except Exception as error: # pylint: disable=broad-exception-caught + file_automation_logger.warning("mcp_server: internal error: %r", error) + return _error_response(msg_id, _INTERNAL_ERROR, f"{type(error).__name__}: {error}") + + if is_notification: + return None + return {"jsonrpc": _JSONRPC_VERSION, "id": msg_id, "result": result} + + def serve_stdio( + self, + stdin: TextIO | None = None, + stdout: TextIO | None = None, + ) -> None: + """Run the server over newline-delimited JSON on ``stdin`` / ``stdout``.""" + reader = stdin if stdin is not None else sys.stdin + writer = stdout if stdout is not None else sys.stdout + for line in reader: + stripped = line.strip() + if not stripped: + continue + try: + message = json.loads(stripped) + except json.JSONDecodeError as error: + self._write(writer, _error_response(None, _PARSE_ERROR, f"bad json: {error}")) + continue + response = self.handle_message(message) + if response is not None: + self._write(writer, response) + + def _handle_initialize(self, _params: dict[str, Any]) -> dict[str, Any]: + return { + "protocolVersion": 
_PROTOCOL_VERSION, + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": self._name, "version": self._version}, + } + + def _handle_tools_list(self) -> dict[str, Any]: + return {"tools": list(_catalogue(self._registry))} + + def _handle_tools_call(self, params: dict[str, Any]) -> dict[str, Any]: + name = params.get("name") + arguments = params.get("arguments") or {} + if not isinstance(name, str) or not name: + raise MCPServerException("tools/call requires a string 'name'") + if not isinstance(arguments, dict): + raise MCPServerException("'arguments' must be an object") + command = self._registry.resolve(name) + if command is None: + raise MCPServerException(f"unknown tool: {name}") + try: + value = command(**arguments) + except TypeError as error: + raise MCPServerException(f"bad arguments for {name}: {error}") from error + return { + "content": [{"type": "text", "text": _serialise(value)}], + "isError": False, + } + + @staticmethod + def _write(writer: TextIO, response: dict[str, Any]) -> None: + writer.write(json.dumps(response, default=repr) + "\n") + writer.flush() + + +def _error_response(msg_id: object, code: int, message: str) -> dict[str, Any]: + return { + "jsonrpc": _JSONRPC_VERSION, + "id": msg_id, + "error": {"code": code, "message": message}, + } + + +def _describe(command: Callable[..., Any]) -> str: + doc = inspect.getdoc(command) or "" + return doc.splitlines()[0] if doc else "Registered automation_file action." 
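The `inputSchema` generation can be sketched in isolation: walk a callable's signature with `inspect`, map builtin annotations to JSON types, and mark default-less parameters as required. `schema_for` and `fa_copy_file` below are illustrative stand-ins, not package APIs:

```python
import inspect
from typing import Any

_JSON_TYPES = {str: "string", int: "integer", float: "number",
               bool: "boolean", list: "array", dict: "object"}


def schema_for(func: Any) -> dict[str, Any]:
    """Derive a JSON-Schema object description from a Python signature."""
    properties: dict[str, Any] = {}
    required: list[str] = []
    for param in inspect.signature(func).parameters.values():
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue  # *args / **kwargs have no fixed property name
        annotation = param.annotation
        if isinstance(annotation, type) and annotation in _JSON_TYPES:
            properties[param.name] = {"type": _JSON_TYPES[annotation]}
        else:
            properties[param.name] = {}  # unknown annotation: leave unconstrained
        if param.default is inspect.Parameter.empty:
            required.append(param.name)  # no default, so the host must supply it
    schema: dict[str, Any] = {"type": "object", "properties": properties}
    if required:
        schema["required"] = required
    return schema


def fa_copy_file(src: str, dst: str, overwrite: bool = False) -> None:
    """Hypothetical action with the same shape as a registered callable."""


print(schema_for(fa_copy_file))
```

Return annotations are ignored because `Signature.parameters` only covers inputs; an unannotated or exotic parameter degrades gracefully to an empty (unconstrained) schema.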
+ + +def _schema_for(command: Callable[..., Any]) -> dict[str, Any]: + try: + signature = inspect.signature(command) + except (TypeError, ValueError): + return {"type": "object", "properties": {}, "additionalProperties": True} + properties: dict[str, Any] = {} + required: list[str] = [] + for parameter in signature.parameters.values(): + if parameter.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + if parameter.name in {"self", "cls"}: + continue + properties[parameter.name] = _json_schema_for(parameter.annotation) + if parameter.default is inspect.Parameter.empty: + required.append(parameter.name) + schema: dict[str, Any] = { + "type": "object", + "properties": properties, + "additionalProperties": True, + } + if required: + schema["required"] = required + return schema + + +def _json_schema_for(annotation: Any) -> dict[str, Any]: + if annotation is inspect.Parameter.empty: + return {} + mapping: dict[type, str] = { + str: "string", + int: "integer", + float: "number", + bool: "boolean", + list: "array", + dict: "object", + } + if isinstance(annotation, type) and annotation in mapping: + return {"type": mapping[annotation]} + return {} + + +def _serialise(value: Any) -> str: + try: + return json.dumps(value, default=repr) + except (TypeError, ValueError): + return repr(value) + + +def tools_from_registry(registry: ActionRegistry) -> Iterable[dict[str, Any]]: + """Yield MCP-shaped tool descriptors for every entry in ``registry``. + + Exposed separately so GUIs and tests can render the same catalogue + without instantiating :class:`MCPServer`. 
+ """ + yield from _catalogue(registry) + + +def _catalogue(registry: ActionRegistry) -> Iterable[dict[str, Any]]: + for name, command in sorted(registry.event_dict.items()): + yield { + "name": name, + "description": _describe(command), + "inputSchema": _schema_for(command), + } + + +def _filtered_registry(source: ActionRegistry, allowed: Sequence[str]) -> ActionRegistry: + filtered = ActionRegistry() + missing: list[str] = [] + for name in allowed: + command = source.resolve(name) + if command is None: + missing.append(name) + continue + filtered.register(name, command) + if missing: + raise MCPServerException("unknown action(s) in allow list: " + ", ".join(sorted(missing))) + return filtered + + +def _build_cli_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="automation_file_mcp", + description="Expose the automation_file action registry as an MCP server over stdio.", + ) + parser.add_argument( + "--name", default="automation_file", help="serverInfo.name reported at handshake" + ) + parser.add_argument( + "--version", default="1.0.0", help="serverInfo.version reported at handshake" + ) + parser.add_argument( + "--allowed-actions", + default=None, + help=( + "comma-separated allow list of action names (e.g. 
" + "'FA_list_dir,FA_file_checksum'); defaults to every registered action" + ), + ) + return parser + + +def _cli(argv: Sequence[str] | None = None) -> int: + """Console-script entry point for the MCP stdio server.""" + args = _build_cli_parser().parse_args(argv) + registry = executor.registry + if args.allowed_actions: + names = [name.strip() for name in args.allowed_actions.split(",") if name.strip()] + registry = _filtered_registry(registry, names) + server = MCPServer(registry, name=args.name, version=args.version) + file_automation_logger.info( + "mcp_server: serving %d tools over stdio (name=%s version=%s)", + len(registry.event_dict), + args.name, + args.version, + ) + server.serve_stdio() + return 0 + + +if __name__ == "__main__": + sys.exit(_cli()) diff --git a/automation_file/server/web_ui.py b/automation_file/server/web_ui.py new file mode 100644 index 0000000..806ae0f --- /dev/null +++ b/automation_file/server/web_ui.py @@ -0,0 +1,215 @@ +"""Read-only observability Web UI (stdlib + HTMX). + +Serves a single HTML page that polls three HTML fragments — registered +actions, live progress, and health summary — using HTMX (loaded from a +pinned CDN URL). Write operations are deliberately out of scope; trigger +actions through :mod:`http_server` / :mod:`tcp_server` with their auth +story intact. + +Loopback-only by default; ``allow_non_loopback=True`` is required to bind +elsewhere. When ``shared_secret`` is supplied every request must carry +``Authorization: Bearer `` — the rendered HTML includes a +``hx-headers`` attribute so HTMX's polled requests carry the token. 
+"""
+
+from __future__ import annotations
+
+import hmac
+import html as html_lib
+import json
+import threading
+import time
+from http import HTTPStatus
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+
+from automation_file.core.action_executor import executor
+from automation_file.core.progress import progress_registry
+from automation_file.logging_config import file_automation_logger
+from automation_file.server.network_guards import ensure_loopback
+
+_DEFAULT_HOST = "127.0.0.1"
+_DEFAULT_PORT = 9955
+_HTMX_CDN = "https://unpkg.com/htmx.org@1.9.12/dist/htmx.min.js"
+_HTMX_SRI = "sha384-ujb1lZYygJmzgSwoxRggbCHcjc0rB2XoQrxeTUQyRjrOnlCoYta87iKBWq3EsdM2"
+
+_INDEX_TEMPLATE = """<!DOCTYPE html>
+<html>
+<head>
+  <meta charset="utf-8">
+  <title>automation_file</title>
+  <script src="{htmx_src}" integrity="{htmx_sri}" crossorigin="anonymous"></script>
+</head>
+<body hx-headers="{auth_headers}">
+  <h1>automation_file</h1>
+  <p>Read-only dashboard. Write operations live on the action server.</p>
+
+  <h2>Health</h2>
+  <div hx-get="/ui/health" hx-trigger="load, every 2s">loading…</div>
+
+  <h2>Progress</h2>
+  <div hx-get="/ui/progress" hx-trigger="load, every 2s">loading…</div>
+
+  <h2>Registered actions</h2>
+  <div hx-get="/ui/registry" hx-trigger="load, every 2s">loading…</div>
+</body>
+</html>
+"""
+
+
+class _WebUIHandler(BaseHTTPRequestHandler):
+    """Serves the dashboard page plus its three HTMX fragment endpoints."""
+
+    def log_message(  # pylint: disable=arguments-differ
+        self, format_str: str, *args: object
+    ) -> None:
+        file_automation_logger.info("web_ui: " + format_str, *args)
+
+    def do_GET(self) -> None:  # pylint: disable=invalid-name
+        if not self._authorized():
+            self._send_html(HTTPStatus.UNAUTHORIZED, "<p>unauthorized</p>")
+            return
+        path = self.path.split("?", 1)[0]
+        if path in ("/", "/index.html"):
+            self._send_html(HTTPStatus.OK, self._render_index())
+            return
+        if path == "/ui/health":
+            self._send_html(HTTPStatus.OK, _render_health())
+            return
+        if path == "/ui/progress":
+            self._send_html(HTTPStatus.OK, _render_progress())
+            return
+        if path == "/ui/registry":
+            self._send_html(HTTPStatus.OK, _render_registry())
+            return
+        self._send_html(HTTPStatus.NOT_FOUND, "<p>not found</p>")
+
+    def _authorized(self) -> bool:
+        secret: str | None = getattr(self.server, "shared_secret", None)
+        if not secret:
+            return True
+        header = self.headers.get("Authorization", "")
+        if not header.startswith("Bearer "):
+            return False
+        return hmac.compare_digest(header[len("Bearer ") :], secret)
+
+    def _send_html(self, status: HTTPStatus, body: str) -> None:
+        payload = body.encode("utf-8")
+        self.send_response(status)
+        self.send_header("Content-Type", "text/html; charset=utf-8")
+        self.send_header("Content-Length", str(len(payload)))
+        self.send_header("Cache-Control", "no-store")
+        self.end_headers()
+        self.wfile.write(payload)
+
+    def _render_index(self) -> str:
+        secret: str | None = getattr(self.server, "shared_secret", None)
+        auth_headers_obj = {"Authorization": f"Bearer {secret}"} if secret else {}
+        auth_headers = html_lib.escape(json.dumps(auth_headers_obj), quote=True)
+        return _INDEX_TEMPLATE.format(
+            htmx_src=_HTMX_CDN,
+            htmx_sri=_HTMX_SRI,
+            auth_headers=auth_headers,
+        )
+
+
+def _render_health() -> str:
+    names = list(executor.registry.event_dict.keys())
+    return (
+        "<table>"
+        "<tr><th>process</th><td>alive</td></tr>"
+        f"<tr><th>registry size</th><td>{len(names)}</td></tr>"
+        f"<tr><th>time</th><td>{html_lib.escape(time.strftime('%Y-%m-%d %H:%M:%S'))}</td></tr>"
+        "</table>"
+    )
+
+
+def _render_progress() -> str:
+    snapshots = progress_registry.list()
+    if not snapshots:
+        return "<p>no active transfers</p>"
+    rows = []
+    for item in snapshots:
+        name = html_lib.escape(str(item.get("name", "")))
+        status = html_lib.escape(str(item.get("status", "")))
+        transferred = int(item.get("transferred", 0) or 0)
+        total = item.get("total")
+        total_cell = "—" if total in (None, 0) else str(total)
+        pct = ""
+        if isinstance(total, int) and total > 0:
+            pct = f" ({(transferred / total) * 100:.1f}%)"
+        rows.append(
+            "<tr>"
+            f"<td>{name}</td>"
+            f"<td>{status}</td>"
+            f"<td>{transferred}{pct}</td>"
+            f"<td>{total_cell}</td>"
+            "</tr>"
+        )
+    return (
+        "<table>"
+        "<tr><th>name</th><th>status</th><th>transferred</th><th>total</th></tr>"
+        + "".join(rows)
+        + "</table>"
+    )
+
+
+def _render_registry() -> str:
+    names = sorted(executor.registry.event_dict.keys())
+    if not names:
+        return "<p>registry empty</p>"
+    items = "".join(f"<li>{html_lib.escape(name)}</li>" for name in names)
+    return f"<ul>{items}</ul>"
+
+
+class WebUIServer(ThreadingHTTPServer):
+    """Threaded HTTP server for the HTMX dashboard."""
+
+    def __init__(
+        self,
+        server_address: tuple[str, int],
+        handler_class: type = _WebUIHandler,
+        shared_secret: str | None = None,
+    ) -> None:
+        super().__init__(server_address, handler_class)
+        self.shared_secret: str | None = shared_secret
+
+
+def start_web_ui(
+    host: str = _DEFAULT_HOST,
+    port: int = _DEFAULT_PORT,
+    allow_non_loopback: bool = False,
+    shared_secret: str | None = None,
+) -> WebUIServer:
+    """Start the Web UI server on a background thread."""
+    if not allow_non_loopback:
+        ensure_loopback(host)
+    if allow_non_loopback and not shared_secret:
+        file_automation_logger.warning(
+            "web_ui: non-loopback bind without shared_secret is insecure",
+        )
+    server = WebUIServer((host, port), shared_secret=shared_secret)
+    thread = threading.Thread(target=server.serve_forever, daemon=True)
+    thread.start()
+    file_automation_logger.info(
+        "web_ui: listening on %s:%d (auth=%s)",
+        host,
+        port,
+        "on" if shared_secret else "off",
+    )
+    return server
diff --git a/dev.toml b/dev.toml
index 75f99e2..067963a 100644
--- a/dev.toml
+++ b/dev.toml
@@ -25,8 +25,9 @@ dependencies = [
     "paramiko>=3.4.0",
     "PySide6>=6.6.0",
     "watchdog>=4.0.0",
-    "cryptography>=42.0.0",
-    "prometheus_client>=0.20.0",
+    "cryptography>=46.0.7",
+    "prometheus_client>=0.25.0",
+    "defusedxml>=0.7.1",
     "tomli>=2.0.1; python_version<\"3.11\""
 ]
 classifiers = [
@@ -49,6 +50,9 @@ dev = [
     "twine>=5.1.0"
 ]
 
+[project.scripts]
+automation_file_mcp = "automation_file.server.mcp_server:_cli"
+
 [project.urls]
 "Homepage" = "https://github.com/JE-Chen/Integration-testing-environment"
diff --git a/dev_requirements.txt b/dev_requirements.txt
index 8399b03..4d8702c 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -3,7 +3,8 @@ google-api-python-client
 google-auth-httplib2
 google-auth-oauthlib
 APScheduler
-cryptography>=42.0.0
-prometheus_client>=0.20.0
+cryptography>=46.0.7
+prometheus_client>=0.25.0 +defusedxml>=0.7.1 twine build diff --git a/docs/Makefile b/docs/Makefile index 01e66b5..b304cc6 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -4,7 +4,7 @@ SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = _build -.PHONY: help html clean +.PHONY: help html clean html-zh-TW html-zh-CN html-all help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) @@ -12,5 +12,13 @@ help: html: @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) +html-zh-TW: + @$(SPHINXBUILD) -b html source.zh-TW "$(BUILDDIR)/html-zh-TW" $(SPHINXOPTS) + +html-zh-CN: + @$(SPHINXBUILD) -b html source.zh-CN "$(BUILDDIR)/html-zh-CN" $(SPHINXOPTS) + +html-all: html html-zh-TW html-zh-CN + clean: @$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) diff --git a/docs/make.bat b/docs/make.bat index 45d7073..cf6d346 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -18,9 +18,27 @@ if errorlevel 9009 ( exit /b 1 ) +if "%1" == "html-zh-TW" goto build-zh-TW +if "%1" == "html-zh-CN" goto build-zh-CN +if "%1" == "html-all" goto build-all + %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end +:build-zh-TW +%SPHINXBUILD% -b html source.zh-TW %BUILDDIR%\html-zh-TW %SPHINXOPTS% %O% +goto end + +:build-zh-CN +%SPHINXBUILD% -b html source.zh-CN %BUILDDIR%\html-zh-CN %SPHINXOPTS% %O% +goto end + +:build-all +%SPHINXBUILD% -M html %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +%SPHINXBUILD% -b html source.zh-TW %BUILDDIR%\html-zh-TW %SPHINXOPTS% %O% +%SPHINXBUILD% -b html source.zh-CN %BUILDDIR%\html-zh-CN %SPHINXOPTS% %O% +goto end + :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% diff --git a/docs/requirements.txt b/docs/requirements.txt index 540144f..e26f13e 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,3 @@ -sphinx>=7.0 +sphinx>=7.4.7 sphinx-rtd-theme myst-parser diff --git a/docs/source.zh-CN/api/client.rst b/docs/source.zh-CN/api/client.rst new file mode 100644 index 
0000000..36ab808 --- /dev/null +++ b/docs/source.zh-CN/api/client.rst @@ -0,0 +1,5 @@ +Client SDK +========== + +.. automodule:: automation_file.client.http_client + :members: diff --git a/docs/source.zh-CN/api/core.rst b/docs/source.zh-CN/api/core.rst new file mode 100644 index 0000000..e25e807 --- /dev/null +++ b/docs/source.zh-CN/api/core.rst @@ -0,0 +1,86 @@ +Core +==== + +.. automodule:: automation_file.core.action_registry + :members: + +.. automodule:: automation_file.core.action_executor + :members: + +.. automodule:: automation_file.core.dag_executor + :members: + +.. automodule:: automation_file.core.callback_executor + :members: + +.. automodule:: automation_file.core.package_loader + :members: + +.. automodule:: automation_file.core.json_store + :members: + +.. automodule:: automation_file.core.retry + :members: + +.. automodule:: automation_file.core.quota + :members: + +.. automodule:: automation_file.core.rate_limit + :members: + +.. automodule:: automation_file.core.circuit_breaker + :members: + +.. automodule:: automation_file.core.file_lock + :members: + +.. automodule:: automation_file.core.sqlite_lock + :members: + +.. automodule:: automation_file.core.action_queue + :members: + +.. automodule:: automation_file.core.content_store + :members: + +.. automodule:: automation_file.core.progress + :members: + +.. automodule:: automation_file.core.checksum + :members: + +.. automodule:: automation_file.core.manifest + :members: + +.. automodule:: automation_file.core.fim + :members: + +.. automodule:: automation_file.core.audit + :members: + +.. automodule:: automation_file.core.crypto + :members: + +.. automodule:: automation_file.core.metrics + :members: + +.. automodule:: automation_file.core.substitution + :members: + +.. automodule:: automation_file.core.config_watcher + :members: + +.. automodule:: automation_file.core.plugins + :members: + +.. automodule:: automation_file.core.config + :members: + +.. 
automodule:: automation_file.core.secrets + :members: + +.. automodule:: automation_file.exceptions + :members: + +.. automodule:: automation_file.logging_config + :members: diff --git a/docs/source.zh-CN/api/index.rst b/docs/source.zh-CN/api/index.rst new file mode 100644 index 0000000..672f7a3 --- /dev/null +++ b/docs/source.zh-CN/api/index.rst @@ -0,0 +1,18 @@ +API 参考 +======== + +.. toctree:: + :maxdepth: 2 + + core + local + remote + server + client + trigger + scheduler + notify + progress + project + ui + utils diff --git a/docs/source.zh-CN/api/local.rst b/docs/source.zh-CN/api/local.rst new file mode 100644 index 0000000..9dbf0a2 --- /dev/null +++ b/docs/source.zh-CN/api/local.rst @@ -0,0 +1,47 @@ +本地操作 +======== + +.. automodule:: automation_file.local.file_ops + :members: + +.. automodule:: automation_file.local.dir_ops + :members: + +.. automodule:: automation_file.local.zip_ops + :members: + +.. automodule:: automation_file.local.sync_ops + :members: + +.. automodule:: automation_file.local.safe_paths + :members: + +.. automodule:: automation_file.local.shell_ops + :members: + +.. automodule:: automation_file.local.tar_ops + :members: + +.. automodule:: automation_file.local.json_edit + :members: + +.. automodule:: automation_file.local.conditional + :members: + +.. automodule:: automation_file.local.archive_ops + :members: + +.. automodule:: automation_file.local.diff_ops + :members: + +.. automodule:: automation_file.local.mime + :members: + +.. automodule:: automation_file.local.templates + :members: + +.. automodule:: automation_file.local.trash + :members: + +.. automodule:: automation_file.local.versioning + :members: diff --git a/docs/source.zh-CN/api/notify.rst b/docs/source.zh-CN/api/notify.rst new file mode 100644 index 0000000..15b4aa1 --- /dev/null +++ b/docs/source.zh-CN/api/notify.rst @@ -0,0 +1,8 @@ +通知 +==== + +.. automodule:: automation_file.notify.sinks + :members: + +.. 
automodule:: automation_file.notify.manager + :members: diff --git a/docs/source.zh-CN/api/progress.rst b/docs/source.zh-CN/api/progress.rst new file mode 100644 index 0000000..e7a9306 --- /dev/null +++ b/docs/source.zh-CN/api/progress.rst @@ -0,0 +1,11 @@ +进度与取消 +========== + +传输的可选仪表化。对 :func:`~automation_file.download_file`、 +:func:`s3_upload_file` 或 :func:`s3_download_file` 传入 +``progress_name="