diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..3e54203 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,20 @@ +**/node_modules +**/.venv +**/dist + +.github +.git +.ipynb_checkpoints +.ipython +.jupyter +# Logs +*.log +npm-debug.log* +pnpm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..4d9ba0c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,18 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 10 + + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 10 + + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/scripts/plugins/plugin_manager.py b/.github/scripts/plugins/plugin_manager.py new file mode 100644 index 0000000..772df41 --- /dev/null +++ b/.github/scripts/plugins/plugin_manager.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +# ruff: noqa: T201 + +import os +import sys +import json +import shutil +import argparse +import subprocess +import importlib.util +from typing import Any, TypedDict +from pathlib import Path + +import toml + + +class PluginInfo(TypedDict): + package_name: str + path: str + registration_info: dict[str, Any] + + +class PluginManager: + def __init__(self) -> None: + self.config = self.load_config() + self.registry = self.load_registry() + + def load_config(self) -> dict[str, Any]: + """Load configuration from ezpz.toml or pyproject.toml.""" + for config_path in [Path("ezpz.toml"), Path("pyproject.toml")]: + if config_path.exists(): + try: + with config_path.open("r") as f: + config = toml.load(f) + return config.get("ezpz_pluginz", config.get("tool", {}).get("ezpz", {})) + except Exception as e: + print(f"❌ Error loading {config_path}: {e}") + sys.exit(1) + print("⚠️ No valid configuration found, using empty config") + return {} + + def load_registry(self) -> dict[str, Any]: + """Load local plugin registry.""" + registry_path = Path.home() / ".ezpz" / "registry" / "plugins.json" + if registry_path.exists(): + with registry_path.open("r") as f: + return json.load(f) + print("⚠️ Local registry not found, assuming empty") + return {"plugins": []} + + def extract_project_plugins(self) -> list[PluginInfo]: + """Extract plugins from configuration.""" + include_paths = self.config.get("include", []) + project_plugins: list[PluginInfo] = [] + for path in include_paths: + path_obj = Path(path) + if path_obj.exists(): + project_plugins.append({"package_name": path_obj.name, "path": str(path_obj), "registration_info": {}}) + else: + print(f"⚠️ Path not found: {path}") + return project_plugins + + def get_plugin_registration_info(self, plugin_path: str) -> dict[str, Any] | None: + """Get registration info from plugin's register_plugin function.""" + plugin_path_obj = Path(plugin_path) + entry_points = [ + plugin_path_obj / "python" / plugin_path_obj.name.replace("-", "_") / "__init__.py", + plugin_path_obj / "src" / plugin_path_obj.name.replace("-", "_") / "__init__.py", + plugin_path_obj / plugin_path_obj.name.replace("-", "_") / "__init__.py", + plugin_path_obj / "__init__.py", + ] + + for entry_point in entry_points: + if entry_point.exists(): + try: + spec = importlib.util.spec_from_file_location(f"plugin_{entry_point.stem}", entry_point) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, "register_plugin"): + return module.register_plugin() + except Exception as e: + print(f"⚠️ Error loading plugin from {entry_point}: {e}") + + for init_file in plugin_path_obj.rglob("__init__.py"): + try: + with init_file.open("r") as f: + if "def register_plugin" in f.read(): + spec = importlib.util.spec_from_file_location(f"plugin_{init_file.stem}", init_file) + if spec and spec.loader: + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, "register_plugin"): + return module.register_plugin() + except Exception as e: + print(f"⚠️ Error reading {init_file}: {e}") + return None + + def compare_plugins(self, project_plugin: dict[str, Any], registry_plugin: dict[str, Any]) -> bool: + """Compare plugin metadata to detect changes.""" + fields = ["version", "description", "author", "category", "homepage", "aliases", "metadata_"] + return any(project_plugin.get(f) != registry_plugin.get(f) for f in fields) + + def write_outputs(self, outputs: dict[str, Any]) -> None: + """Write outputs to GITHUB_OUTPUT with unique keys.""" + with Path(os.environ["GITHUB_OUTPUT"]).open("a") as f: + f.writelines(f"{key}={json.dumps(value)}\n" for key, value in outputs.items()) + + def analyze(self) -> None: + """Analyze plugins and generate lists for registration/updates.""" + project_plugins = self.extract_project_plugins() + registry_plugins = {p["package_name"]: p for p in self.registry.get("plugins", [])} + plugins_to_register: list[PluginInfo] = [] + plugins_to_update: list[PluginInfo] = [] + + for plugin in project_plugins: + package_name = plugin["package_name"] + plugin_path = plugin["path"] + registration_info = self.get_plugin_registration_info(plugin_path) + if not registration_info: + print(f"⚠️ Skipping {package_name} - no registration info") + continue + plugin["registration_info"] = registration_info + + if package_name not in registry_plugins: + plugins_to_register.append(plugin) + elif self.compare_plugins(registration_info, registry_plugins[package_name]): + plugins_to_update.append(plugin) + + self.write_outputs( + { + "project-plugins": project_plugins, + "plugins-to-register": plugins_to_register, + "plugins-to-update": plugins_to_update, + "has-changes": len(plugins_to_register) > 0 or len(plugins_to_update) > 0, + } + ) + + def resolve_executable(self, cmd: str) -> str: + """Resolve the full path to an executable.""" + full_path = shutil.which(cmd) + if not full_path: + print(f"❌ Executable '{cmd}' not found in PATH") + sys.exit(1) + return full_path + + def safe_subprocess_run(self, args: list[str], **kwargs: Any) -> subprocess.CompletedProcess[Any]: # noqa: ANN401 + """Run a subprocess with validated executable path.""" + validated_args = [self.resolve_executable(args[0]), *args[1:]] + return subprocess.run(validated_args, **kwargs, check=True) # type: ignore # noqa: S603 + + def register(self, plugins_json: str, *, dry_run: bool) -> None: + """Register new plugins.""" + plugins: list[PluginInfo] = json.loads(plugins_json) + failed_plugins = list[str]() + for plugin in plugins: + package_name = plugin["package_name"] + plugin_path = plugin["path"] + try: + if dry_run: + print(f"🏃 DRY RUN: Would register {package_name}") + else: + self.safe_subprocess_run(["rye", "run", "ezpz", "register", plugin_path], check=True, text=True) + print(f"✅ Registered {package_name}") + except subprocess.CalledProcessError as e: + print(f"❌ Failed to register {package_name}: {e}") + failed_plugins.append(package_name) + if failed_plugins: + print(f"❌ Failed to register {len(failed_plugins)} plugins: {', '.join(failed_plugins)}") + sys.exit(1) + + def update(self, plugins_json: str, *, dry_run: bool) -> None: + """Update existing plugins.""" + plugins: list[PluginInfo] = json.loads(plugins_json) + failed_plugins = list[str]() + for plugin in plugins: + package_name = plugin["package_name"] + plugin_path = plugin["path"] + plugin_name = plugin["registration_info"].get("name", package_name) + try: + if dry_run: + print(f"🏃 DRY RUN: Would update {plugin_name}") + else: + self.safe_subprocess_run(["rye", "run", "ezpz", "update", plugin_name, plugin_path], check=True, text=True) + print(f"✅ Updated {package_name}") + except subprocess.CalledProcessError as e: + print(f"❌ Failed to update {package_name}: {e}") + failed_plugins.append(package_name) + if failed_plugins: + print(f"❌ Failed to update {len(failed_plugins)} plugins: {', '.join(failed_plugins)}") + sys.exit(1) + + def check_publish(self, package_name: str, plugins_to_register: str, plugins_to_update: str) -> None: + """Check if a plugin needs publishing.""" + plugins_to_register_list: list[PluginInfo] = json.loads(plugins_to_register) + plugins_to_update_list: list[PluginInfo] = json.loads(plugins_to_update) + needs_publishing = False + publish_type = "none" + + for plugin in plugins_to_register_list: + if plugin["package_name"] == package_name: + needs_publishing = True + publish_type = "new" + break + + if not needs_publishing: + for plugin in plugins_to_update_list: + if plugin["package_name"] == package_name: + needs_publishing = True + publish_type = "update" + break + + self.write_outputs({"needs-publishing": needs_publishing, "publish-type": publish_type}) + + +def main() -> None: + parser = argparse.ArgumentParser(description="EZPZ Plugin Manager") + subparsers = parser.add_subparsers(dest="command", required=True) + + subparsers.add_parser("analyze", help="Analyze plugins") + register_parser = subparsers.add_parser("register", help="Register new plugins") + register_parser.add_argument("--dry-run", action="store_true") + update_parser = subparsers.add_parser("update", help="Update existing plugins") + update_parser.add_argument("--dry-run", action="store_true") + check_publish = subparsers.add_parser("check-publish", help="Check if plugin needs publishing") + check_publish.add_argument("--package-name", required=True) + + args = parser.parse_args() + manager = PluginManager() + + if args.command == "analyze": + manager.analyze() + elif args.command == "register": + manager.register(os.environ.get("PLUGINS_TO_REGISTER", "[]"), args.dry_run) # type: ignore + elif args.command == "update": + manager.update(os.environ.get("PLUGINS_TO_UPDATE", "[]"), args.dry_run) # type: ignore + elif args.command == "check-publish": + manager.check_publish(args.package_name, os.environ.get("PLUGINS_TO_REGISTER", "[]"), os.environ.get("PLUGINS_TO_UPDATE", "[]")) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/plugins/plugin_ops.nu b/.github/scripts/plugins/plugin_ops.nu new file mode 100644 index 0000000..5097bb5 --- /dev/null +++ b/.github/scripts/plugins/plugin_ops.nu @@ -0,0 +1,257 @@ +#!/usr/bin/env nu + +def main [command: string, package_name: string, plugin_path: string, --dry-run: string = "false"] { + cd $plugin_path + + let dry_run = ($dry_run == "true") + let actual_path = if ($plugin_path | str ends-with $package_name) { "." } else { $plugin_path } + + match $command { + "validate" => { validate_plugin $package_name $actual_path } + "build" => { build_plugin $package_name $actual_path } + "test" => { test_plugin $package_name $actual_path } + "test-pipeline" => { test_pipeline $package_name $actual_path } + "publish" => { publish_plugin $package_name $actual_path $dry_run } + _ => { error make { msg: $"Unknown command: ($command)" } } + } +} + +def test_pipeline [package_name: string, plugin_path: string] { + validate_plugin $package_name $plugin_path + build_plugin $package_name $plugin_path + test_plugin $package_name $plugin_path + print $"✅ Plugin ($package_name) passed all tests" +} + +def test_plugin [package_name: string, plugin_path: string] { + cd $plugin_path + + let test_dir = ($plugin_path | path join "tests") + let test_dir_alt = ($plugin_path | path join "test") + if ((get_path_type $test_dir) == "dir" or (get_path_type $test_dir_alt) == "dir") { + try { + ^python3 -m pytest -v + print "✅ Python tests passed" + } catch { + print "❌ Python tests failed" + exit 1 + } + } + + let cargo_toml = ($plugin_path | path join "Cargo.toml") + if (get_path_type $cargo_toml) == "file" { + try { + ^cargo test + print "✅ Rust tests passed" + } catch { + print "❌ Rust tests failed" + exit 1 + } + } +} + +def validate_plugin [package_name: string, plugin_path: string] { + let pyproject_type = (get_path_type ($plugin_path | path join "pyproject.toml")) + let cargo_type = (get_path_type ($plugin_path | path join "Cargo.toml")) + + let has_pyproject = ($pyproject_type == "file") + let has_cargo = ($cargo_type == "file") + + if not ($has_pyproject or $has_cargo) { + print $"❌ Missing both pyproject.toml and Cargo.toml in ($plugin_path)" + exit 1 + } + + if $has_pyproject { + print "✅ Found pyproject.toml" + let py_typed_type = (get_path_type ($plugin_path | path join "python" $package_name "py.typed")) + if ($py_typed_type == "file") { + print "✅ Found py.typed for type hints" + } + } + + if $has_cargo { + print "✅ Found Cargo.toml" + let lib_rs_type = (get_path_type ($plugin_path | path join "src" "lib.rs")) + let main_rs_type = (get_path_type ($plugin_path | path join "src" "main.rs")) + if not (($lib_rs_type == "file") or ($main_rs_type == "file")) { + print "❌ Rust project missing src/lib.rs or src/main.rs" + exit 1 + } + print "✅ Found Rust source file (lib.rs or main.rs)" + } + + let init_found = check_init_py $package_name $plugin_path + if not $init_found { + print "❌ Could not find __init__.py with register_plugin function" + exit 1 + } + + let dist_dir = ($plugin_path | path join "dist") + let dist_type = (get_path_type $dist_dir) + if ($dist_type == "dir") { + let dist_files = (glob ($dist_dir | path join "*")) + if ($dist_files | length) > 0 { + try { + ^twine check ...$dist_files + print "✅ Package validation passed" + } catch { + print "❌ Package validation failed" + exit 1 + } + } + } + + print "✅ Plugin structure validation passed" +} + +def check_init_py [package_name: string, plugin_path: string] { + let patterns = [ + ($plugin_path | path join "python" $package_name "__init__.py"), + ($plugin_path | path join "src" $package_name "__init__.py"), + ($plugin_path | path join $package_name "__init__.py"), + ($plugin_path | path join "__init__.py") + ] + + for pattern in $patterns { + if ($pattern | path exists) { + let content = (open $pattern) + if ($content | str contains "def register_plugin") { + print $"✅ Found register_plugin function in ($pattern)" + return true + } + } + } + + let found_files = (glob ($plugin_path | path join "**" "__init__.py") | each { |file| + let content = (open $file) + if ($content | str contains "def register_plugin") { + $file + } + } | compact) + + if ($found_files | length) > 0 { + print $"✅ Found register_plugin function in ($found_files | first)" + return true + } + return false +} + +def build_plugin [package_name: string, plugin_path: string] { + let cleanup_patterns = ["dist" "build" "*.egg-info"] + for pattern in $cleanup_patterns { + try { + let items = (glob $pattern) + for item in $items { + rm -rf $item + } + } + } + + let cargo_toml = ($plugin_path | path join "Cargo.toml") + let pyproject_toml = ($plugin_path | path join "pyproject.toml") + + # For mixed projects, build Rust first if both exist + if (($cargo_toml | path exists) and ($pyproject_toml | path exists)) { + try { + ^cargo fetch + ^cargo build --release + print "✅ Rust build successful" + } catch { + print "❌ Rust build failed" + exit 1 + } + + try { + # Use maturin directly for mixed projects instead of rye + ^maturin build --release + print "✅ Python/Rust mixed build successful" + } catch { + print "❌ Python/Rust mixed build failed" + exit 1 + } + } else { + # Handle pure Python projects + if ($pyproject_toml | path exists) { + try { + ^rye build + print "✅ Python build successful" + } catch { + print "❌ Python build failed" + exit 1 + } + } + + # Handle pure Rust projects + if ($cargo_toml | path exists) { + try { + ^cargo fetch + ^cargo build --release + print "✅ Rust build successful" + } catch { + print "❌ Rust build failed" + exit 1 + } + } + } +} + +def publish_plugin [package_name: string, plugin_path: string, dry_run: bool] { + if $dry_run { + print $"🏃 DRY RUN: Would publish ($package_name)" + return + } + + let max_attempts = 3 + let pyproject_toml = ($plugin_path | path join "pyproject.toml") + if ($pyproject_toml | path exists) { + let dist_dir = ($plugin_path | path join "dist") + if ($dist_dir | path exists) { + let dist_files = (glob ($dist_dir | path join "*")) + if ($dist_files | length) > 0 { + for attempt in 1..$max_attempts { + try { + ^twine upload ...$dist_files + print $"✅ Successfully published ($package_name) to PyPI" + break + } catch { + print $"⚠️ Attempt ($attempt) failed for ($package_name)" + if $attempt == $max_attempts { + print $"❌ Failed to publish ($package_name) to PyPI after ($max_attempts) attempts" + exit 1 + } + sleep 5sec + } + } + } else { + print $"⚠️ No distribution files found for ($package_name)" + } + } + } + + let cargo_toml = ($plugin_path | path join "Cargo.toml") + if ($cargo_toml | path exists) { + for attempt in 1..$max_attempts { + try { + ^cargo publish + print $"✅ Successfully published ($package_name) to crates.io" + break + } catch { + print $"⚠️ Attempt ($attempt) failed for ($package_name)" + if $attempt == $max_attempts { + print $"❌ Failed to publish ($package_name) to crates.io after ($max_attempts) attempts" + exit 1 + } + sleep 5sec + } + } + } +} + +def get_path_type [path: string] { + try { + ($path | path type) + } catch { + null + } +} \ No newline at end of file diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml new file mode 100644 index 0000000..d53ca78 --- /dev/null +++ b/.github/workflows/audit.yml @@ -0,0 +1,67 @@ +name: Security Audit +on: + workflow_dispatch: + pull_request: + paths: + - "**/*.rs" + - "**/*.py" + - "**/Cargo.toml" + - "**/Cargo.lock" + - "**/pyproject.toml" + - "**/requirements*.txt" + - ".github/workflows/audit.yml" + +jobs: + audit: + runs-on: ubuntu-latest + permissions: + contents: read + security-events: write + steps: + - name: Checkout + uses: actions/checkout@v4.2.2 + + - name: Install Rust & Tools + uses: dtolnay/rust-toolchain@v1 + with: + toolchain: stable + + - name: Install cargo-audit + run: cargo install cargo-audit --locked + + - name: Run cargo audit + run: | + echo "🔍 Running cargo-audit..." + cargo audit || echo "cargo audit failed, continuing..." + + - name: Install Python + uses: actions/setup-python@v5.6.0 + with: + python-version: "3.13" + + - name: Install Rye + uses: eifinger/setup-rye@v4.2.9 + with: + version: 'latest' + enable-cache: true + + - name: Sync Python env (Rye) + run: rye sync + + - name: Install Python audit tools + run: | + echo "🔧 Installing bandit and pip-audit..." + pip install bandit pip-audit + + - name: Run Bandit (Python) + run: | + echo "🔍 Running bandit..." + bandit -r . --skip B101 -ll || echo "bandit failed, continuing..." + + - name: Run pip-audit (Python) + run: | + echo "🔍 Running pip-audit..." + pip-audit || echo "pip-audit failed, continuing..." + + - name: ✅ Audit complete + run: echo "✅ All security checks complete!" \ No newline at end of file diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000..e8b43f4 --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,38 @@ +name: "CodeQL" + +on: + push: + branches: [main] + pull_request: + branches: [main] + schedule: + - cron: "0 9 * * 1" + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: ["python", "rust"] + + steps: + - name: Checkout repository + uses: actions/checkout@v4.2.2 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: ${{ matrix.language }} + + - name: Autobuild + uses: github/codeql-action/autobuild@v2 + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/core.yml b/.github/workflows/core.yml new file mode 100644 index 0000000..c352781 --- /dev/null +++ b/.github/workflows/core.yml @@ -0,0 +1,240 @@ +name: Core Components CI/CD + +on: + push: + branches: [main, dev] + paths: + - "core/pluginz/**" + - "core/macroz/**" + - "pyproject.toml" + - "requirements*.lock" + - ".github/workflows/core.yml" + pull_request: + branches: [main] + paths: + - "core/pluginz/**" + - "core/macroz/**" + - "pyproject.toml" + - "requirements*.lock" + - ".github/workflows/core.yml" + workflow_dispatch: + inputs: + deploy_env: + description: "Deployment environment" + required: true + default: "staging" + type: choice + options: + - staging + - production + publish_pypi: + description: "Publish to PyPI" + required: true + default: false + type: boolean + run_build: + description: "Run Build Packages job" + required: false + default: false + type: boolean + run_publish: + description: "Run Publish to PyPI job" + required: false + default: false + type: boolean + run_deploy: + description: "Run Deploy Registry job" + required: false + default: false + type: boolean + +env: + PYTHON_VERSION: "3.13" + +jobs: + test-core: + runs-on: ubuntu-latest + if: | + github.event_name == 'push' || + github.event_name == 'pull_request' || + github.event_name == 'workflow_dispatch' + steps: + - uses: actions/checkout@v4 + + - name: Install Rye + uses: eifinger/setup-rye@v4 + with: + version: "latest" + enable-cache: true + + - name: Pin Python version + run: rye pin ${{ env.PYTHON_VERSION }} + + - name: Cache Rye dependencies + uses: actions/cache@v4.2.3 + with: + path: | + ~/.cache/uv + .venv + key: ${{ runner.os }}-rye-${{ env.PYTHON_VERSION }}-${{ hashFiles('**/requirements*.lock', '**/pyproject.toml') }} + restore-keys: | + ${{ runner.os }}-rye-${{ env.PYTHON_VERSION }}- + + - name: Install workspace dependencies + run: | + rye sync + + - name: Run linting and formatting checks + run: | + rye run ruff check . + rye run ruff format --check . + + - name: Test pluginz + run: | + cd core/pluginz + rye test + + - name: Test macroz + run: | + cd core/macroz + rye test + + - name: Test CLI functionality + run: | + cd core/pluginz + echo "--- Testing 'ezpz help' command ---" + rye run ezpz --help + + echo "--- Testing 'ezpz help ' ---" + rye run ezpz registry --help + rye run ezpz mount --help + + echo "--- Testing 'ezpz list' command ---" + rye run ezpz list + + echo "--- Testing 'ezpz status' command ---" + rye run ezpz regisry status + + echo "--- Testing 'ezpz mount' ---" + rye run ezpz mount || true + + echo "--- Testing 'ezpz unmount' ---" + rye run ezpz unmount || true + + rye run ezpz find database --field category + rye run ezpz find "my-test-plugin" --exact + rye run ezpz find rust + + build-packages: + needs: test-core + runs-on: ubuntu-latest + if: | + github.event_name == 'workflow_dispatch' && github.event.inputs.run_build == 'true' + outputs: + pluginz-version: ${{ steps.build-info.outputs.pluginz-version }} + macroz-version: ${{ steps.build-info.outputs.macroz-version }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install Rye + uses: eifinger/setup-rye@v4 + with: + version: "latest" + + - name: Pin Python version + run: rye pin ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: rye sync --no-dev + + - name: Build pluginz package + run: | + cd core/pluginz + rye build + echo "Built pluginz package" + + - name: Build macroz package + run: | + cd core/macroz + rye build + echo "Built macroz package" + + - name: Extract version information + id: build-info + run: | + PLUGINZ_VERSION=$(cd core/pluginz && rye run python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])") + MACROZ_VERSION=$(cd core/macroz && rye run python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])") + echo "pluginz-version=$PLUGINZ_VERSION" >> $GITHUB_OUTPUT + echo "macroz-version=$MACROZ_VERSION" >> $GITHUB_OUTPUT + + - name: Check package integrity + run: | + rye add twine + rye run twine check core/pluginz/dist/* + rye run twine check core/macroz/dist/* + + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: python-packages + path: | + core/pluginz/dist/* + core/macroz/dist/* + retention-days: 30 + + publish-pypi: + needs: build-packages + runs-on: ubuntu-latest + if: | + github.event_name == 'workflow_dispatch' && github.event.inputs.run_publish == 'true' + + steps: + - name: Download build artifacts + uses: actions/download-artifact@v4 + with: + name: python-packages + path: dist/ + + - name: Install Rye + uses: eifinger/setup-rye@v4 + with: + version: "latest" + + - name: Publish to PyPI + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: | + rye add twine + rye run twine upload dist/core/pluginz/dist/* + rye run twine upload dist/core/macroz/dist/* + echo "Successfully published packages to PyPI" + + notify-completion: + needs: + - test-core + - build-packages + - publish-pypi + if: | + always() && + ( + github.event_name == 'workflow_dispatch' || + github.event_name == 'push' || + github.event_name == 'pull_request' + ) + runs-on: ubuntu-latest + steps: + - name: Summarize results + run: | + echo "### ✅ Summary of Job Results" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "| Job | Result |" >> $GITHUB_STEP_SUMMARY + echo "|-----|--------|" >> $GITHUB_STEP_SUMMARY + echo "| Test Core | ${{ needs.test-core.result }} |" >> $GITHUB_STEP_SUMMARY + + if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then + echo "| Build Packages | ${{ needs.build-packages.result }} |" >> $GITHUB_STEP_SUMMARY + echo "| Publish PyPI | ${{ needs.publish-pypi.result }} |" >> $GITHUB_STEP_SUMMARY + fi diff --git a/.github/workflows/plugins.yml b/.github/workflows/plugins.yml new file mode 100644 index 0000000..0273457 --- /dev/null +++ b/.github/workflows/plugins.yml @@ -0,0 +1,172 @@ +name: EZPZ Plugin Management + +on: + push: + branches: [main, develop] + paths: ["plugins/**", "ezpz.toml", "pyproject.toml"] + pull_request: + branches: [main] + paths: ["plugins/**", "ezpz.toml", "pyproject.toml"] + workflow_dispatch: + inputs: + operation: + description: "Operation to perform" + required: true + default: "test" + type: choice + options: ["test", "register-and-update", "publish"] + dry_run: + description: "Dry run (no actual registry changes)" + required: false + default: false + type: boolean + +env: + PYTHON_VERSION: "3.13" + RUST_VERSION: "1.87" + +jobs: + discover-plugins: + runs-on: ubuntu-latest + outputs: + project-plugins: ${{ steps.analyze.outputs.project-plugins }} + plugins-to-register: ${{ steps.analyze.outputs.plugins-to-register }} + plugins-to-update: ${{ steps.analyze.outputs.plugins-to-update }} + has-changes: ${{ steps.analyze.outputs.has-changes }} + steps: + - uses: actions/checkout@v4.2.2 + - uses: actions/setup-python@v5.6.0 + with: + python-version: ${{ env.PYTHON_VERSION }} + - uses: eifinger/setup-rye@v4.2.9 + with: + enable-cache: true + - uses: hustcer/setup-nu@v3.20 + with: + version: "0.105.1" + - uses: extractions/setup-just@v2 + with: + just-version: "1.40.0" + - run: rye sync + - run: rye run ezpz registry refresh + - id: analyze + run: just actions::analyze-plugins + + test-plugins: + runs-on: ubuntu-latest + needs: discover-plugins + if: | + (github.event_name == 'push' || github.event_name == 'pull_request' || github.event.inputs.operation == 'test') && + needs.discover-plugins.outputs.has-changes == 'true' + strategy: + matrix: + plugin: ${{ fromJson(needs.discover-plugins.outputs.project-plugins) }} + fail-fast: false + steps: + - uses: actions/checkout@v4.2.2 + - uses: actions/setup-python@v5.6.0 + with: + python-version: ${{ env.PYTHON_VERSION }} + - uses: actions-rs/toolchain@v1.0.6 + if: hashFiles(format('{0}/Cargo.toml', matrix.plugin.path)) != '' + with: + toolchain: ${{ env.RUST_VERSION }} + default: true + - uses: eifinger/setup-rye@v4.2.9 + with: + enable-cache: true + - uses: hustcer/setup-nu@v3.20 + with: + version: "0.105.1" + - uses: extractions/setup-just@v2 + with: + just-version: "1.40.0" + - uses: actions/cache@v3 + with: + path: | + ~/.cache/uv + ~/.cargo/registry + ~/.cargo/git + target/ + key: ${{ runner.os }}-${{ matrix.plugin.package_name }}-${{ hashFiles(format('{0}/**/pyproject.toml', matrix.plugin.path), format('{0}/**/Cargo.toml', matrix.plugin.path)) }} + - run: rye sync + - name: Test Plugin Pipeline + run: | + just actions::test-plugin-pipeline \ + "${{ matrix.plugin.package_name }}" \ + "${{ matrix.plugin.path }}" + + register-update-plugins: + runs-on: ubuntu-latest + needs: [discover-plugins, test-plugins] + if: | + github.event_name == 'workflow_dispatch' && + needs.discover-plugins.outputs.has-changes == 'true' && + needs.test-plugins.result == 'success' && + github.event.inputs.operation == 'register-and-update' + steps: + - uses: actions/checkout@v4.2.2 + - uses: actions/setup-python@v5.6.0 + with: + python-version: ${{ env.PYTHON_VERSION }} + - uses: eifinger/setup-rye@v4.2.9 + with: + enable-cache: true + - uses: extractions/setup-just@v2 + with: + just-version: "1.40.0" + - run: rye sync + - run: rye run ezpz registry refresh + - name: Register and Update Plugins + env: + EZPZ_SERVER_SECRET: ${{ secrets.EZPZ_SERVER_SECRET }} + run: | + just actions::register-update-plugins \ + '${{ needs.discover-plugins.outputs.plugins-to-register }}' \ + '${{ needs.discover-plugins.outputs.plugins-to-update }}' \ + "${{ github.event.inputs.dry_run }}" + + publish-plugins: + runs-on: ubuntu-latest + needs: [discover-plugins, test-plugins] + if: | + github.event_name == 'workflow_dispatch' && + github.event.inputs.operation == 'publish' + strategy: + matrix: + plugin: ${{ fromJson(needs.discover-plugins.outputs.project-plugins) }} + fail-fast: false + steps: + - uses: actions/checkout@v4.2.2 + - uses: actions/setup-python@v5.6.0 + with: + python-version: ${{ env.PYTHON_VERSION }} + - uses: actions-rs/toolchain@v1.0.6 + if: hashFiles(format('{0}/Cargo.toml', matrix.plugin.path)) != '' + with: + toolchain: ${{ env.RUST_VERSION }} + default: true + - uses: eifinger/setup-rye@v4.2.9 + with: + enable-cache: true + - uses: hustcer/setup-nu@v3.20 + with: + version: "0.105.1" + - uses: extractions/setup-just@v2 + with: + just-version: "1.40.0" + - run: | + rye sync + rye add twine + - name: Publish Plugin + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} + run: | + just actions::publish-plugin \ + "${{ matrix.plugin.package_name }}" \ + "${{ matrix.plugin.path }}" \ + "${{ github.event.inputs.dry_run }}" \ + '${{ needs.discover-plugins.outputs.plugins-to-register }}' \ + '${{ needs.discover-plugins.outputs.plugins-to-update }}' diff --git a/.gitignore b/.gitignore index bc90fd8..3119856 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,10 @@ -*.lock +.patched +.hypothesis +.pytest_cache +Cargo.lock *-lock.yaml +**/*.egg-info +**/*.lock # Mac stuff: .DS_Store diff --git a/.rustfmt.toml b/.rustfmt.toml index 8e381c4..643c521 100755 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -11,7 +11,7 @@ max_width = 160 tab_spaces = 2 # Imports -imports_granularity = "One" +imports_granularity = "Crate" reorder_imports = true # Format comments diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1 @@ +{} diff --git a/Cargo.toml b/Cargo.toml index 8f8853a..ba6fc37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ repository = "" [workspace] -members = ["api", "app", "guiz", "stubz"] +members = ["plugins/*", "stubz"] resolver = "2" [profile.dev.package."*"] @@ -27,38 +27,11 @@ opt-level = "z" panic = "abort" strip = true -[profile.wasm-dev] -inherits = "dev" -opt-level = 1 - -[profile.server-dev] -inherits = "dev" - -[profile.android-dev] -inherits = "dev" - [workspace.dependencies] ezpz-stubz = { path = "stubz", package = "ezpz-stubz" } - -pyproject-toml = { version = "0.13.4" } -serde-toml-merge = "0.3.9" -serde_merge = "0.1.3" -serde_yml = "0.0.12" -toml = { version = "0.8.23" } - -bigdecimal = { version = "0.4.8", features = ["serde"] } - - -clap = { version = "4.5.40", features = ["derive"] } - -lru = "0.14.0" - - -# polars -connectorx = "0.4.3" hashbrown = { version = "0.15.4" } -polars = { version = "0.48.1", features = [ +polars = { version = "0.49.1", features = [ "dataframe_arithmetic", "describe", "dtype-full", @@ -73,112 +46,20 @@ polars = { version = "0.48.1", features = [ "strings", ] } # DataFrame library based on Apache Arrow -# PyO3 pyo3 = { version = "*" } -pyo3-polars = { version = "0.21.0", features = ["derive", "dtype-full", "lazy"] } -pyo3-stub-gen = { version = "0.9.1", default-features = false } +pyo3-polars = { version = "0.22.0", features = ["derive", "dtype-full", "lazy"] } +pyo3-stub-gen = { version = "0.11.1", default-features = false } - -api = { path = "api" } - - -maestro-anthropic = { path = "../dioxus-maestro/clients/maestro-anthropic" } -maestro-apalis = { path = "../dioxus-maestro/clients/maestro-apalis" } -maestro-diesel = { path = "../dioxus-maestro/clients/maestro-diesel" } - -maestro-forms = { path = "../dioxus-maestro/frontend/maestro-forms" } -maestro-hooks = { path = "../dioxus-maestro/frontend/maestro-hooks", features = ["web"] } -maestro-toast = { path = "../dioxus-maestro/frontend/maestro-toast", features = ["web"] } -maestro-ui = { path = "../dioxus-maestro/frontend/maestro-ui" } - -anyhow = "1.0.98" chrono = { version = "0.4.41", features = ["serde"] } -dashmap = { version = "6.1.0", features = ["rayon", "serde"] } -derive-new = { version = "0.7.0" } -derive_more = { version = "2.0.1" } -enum-map = { version = "2.7.3" } -futures = "0.3.31" -futures-util = "0.3.31" -itertools = "0.14.0" -num-traits = { version = "0.2.19" } -parking_lot = { version = "0.12.4" } -rand = { version = "0.9.1", features = ["small_rng"] } serde = { version = "1.0.219", features = ["derive"] } -serde_json = "1.0.140" -shrinkwraprs = { version = "0.3.0" } -strum = { version = "0.27.1", features = ["derive"] } -tap = { version = "1.0.1" } -uuid = { version = "1.17.0", features = ["serde", "v4"] } - -leafwing-input-manager = "0.17.0" -leafwing_abilities = "0.11.0" - -bon = { version = "3.6.4" } -lowdash = "0.5.3" -schemars = { git = "https://github.com/GREsau/schemars.git" } -stilts = { version = "0.3.3" } -url = { version = "2.5.4", features = ["serde"] } -validator = { version = "0.20.0", features = ["derive"] } - -markdown-to-html = "0.1.3" -plotters = { version = "0.3.7", default-features = false, features = [ - "bitmap_backend", - "bitmap_encoder", - "bitmap_gif", - "chrono", - "svg_backend", - # "ttf", - "all_elements", - "all_series", - "colormaps", - "deprecated_items", - "full_palette", - "image", -] } -plotters-canvas = { version = "0.3.1" } -tailwind_fuse = "0.3.2" - -dioxus = { version = "0.7.0-alpha.1", default-features = false } -dioxus-free-icons = { git = "https://github.com/dioxus-community/dioxus-free-icons.git", features = [ - "bootstrap", - "feather", - "font-awesome-brands", - "font-awesome-regular", - "font-awesome-solid", - "hero-icons-outline", - "hero-icons-solid", - "ionicons", - "lucide", - "material-design-icons-action", - "material-design-icons-alert", - "material-design-icons-av", - "material-design-icons-communication", - "material-design-icons-content", - "material-design-icons-device", - "material-design-icons-editor", - "material-design-icons-file", - "material-design-icons-hardware", - "material-design-icons-home", - "material-design-icons-image", - "material-design-icons-maps", - "material-design-icons-navigation", - "material-design-icons-notification", - "material-design-icons-places", - "material-design-icons-social", - "material-design-icons-toggle", - "octicons", -] } -dioxus-sdk = { git = "https://github.com/DioxusLabs/sdk.git", features = ["time"] } +serde_json = "1.0.141" -tokio = { version = "1.45.1", default-features = false } -tokio-tungstenite = { version = "0.26.2", default-features = false } [workspace.lints.rust] unsafe_code = "deny" elided_lifetimes_in_paths = "warn" -rust_2021_idioms = "warn" -rust_2021_prelude_collisions = "warn" +rust_2024_prelude_collisions = "warn" semicolon_in_expressions_from_macros = "warn" trivial_numeric_casts = "warn" unsafe_op_in_unsafe_fn = "warn" # `unsafe_op_in_unsafe_fn` may become the default in future Rust versions: https://github.com/rust-lang/rust/issues/71668 diff --git a/README.md b/README.md index 4ac2b7e..62d960f 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,380 @@ # EZPZ -A collection of FOSS packages to make dev life more, well, EZPZ. +A toolkit for extending Polars with custom plugins and type safety. EZPZ is also tailored to bridge the gap between Rust performance and Python developer experience in the Polars Ecosystem. -## Grouping Folders +## 📦 Core Components -- EazyPolarz ([readme](ezpz/README.md)) -- Juzt ([readme](juzt/README.md)) -- Painlezz ([readme](painlezz/README.md)) +### 🔌 [EZPZ-Pluginz](./pluginz/) -### EazyPolarz +_The foundation of the EZPZ ecosystem_ -- guiz - GUI toolkit ([readme](ezpz/README.md)) -- pluginz - Plugin system with proper type checking ([readme](ezpz/README.md)) -- stubz - pyo3-polars integration with pyo3-stub-gen ([readme](ezpz/README.md)) +A powerful tool that provides comprehensive type hinting and IDE support for Polars plugins, dramatically enhancing the development experience for custom Polars extensions. -### Juzt +**Key Features:** -A collection of utilities to juzt get it done. +- Full type safety for Polars plugins +- Hot reloading with automatic type hint updates pointing directly to plugin implementations +- **Plugin registry**: Discover and install ecosystem plugins with ease +- **Site-packages integration**: Seamlessly load and manage plugins from installed packages +- **IDE support**: Autocompletion, inline documentation and error detection +- **Multiple syntax support**: Decorator and function call patterns for plugin discovery +- Support for DataFrame, LazyFrame, Series, and Expression plugins +- Reversible modifications with safe backups -- core ([readme](ezpz/README.md)) -- gui ([readme](ezpz/README.md)) -- infra ([readme](ezpz/README.md)) +```bash +pip install ezpz_pluginz +ezpz mount # Enable plugin support +``` -### Painlezz +### 🦀 [EZPZ Stubz](./stubz/) -- basez ([readme](ezpz/README.md)) -- formatterz - dead simple api to apply code formaters from various languages ([readme](ezpz/README.md)) -- macroz - marcos for python with AST validation inspired by rust ([readme](ezpz/README.md)) -- projectz - utilities for easier monorepo management ([readme](ezpz/README.md)) +_Type-safe PyO3-Polars wrappers_ + +Provides wrapper types that enable PyO3 extensions to work seamlessly with Polars objects while maintaining proper type information. + +**Key Features:** + +- Transparent wrappers for Polars types +- Automatic stub generation with `pyo3_stub_gen` +- Zero-runtime cost abstractions +- Full IDE support + +```toml +[dependencies] +ezpz-stubz = "*" +``` + +### 📈 [EZPZ Rust Technical Analysis](./ezpz-rust-ti/) + +_Production-ready technical analysis plugin_ + +A comprehensive technical analysis library showcasing the EZPZ plugin system with 70+ indicators powered by Rust. + +**Key Features:** + +- 70+ technical indicators +- Polars native integration +- Rust-powered performance +- Full type safety + +```bash +pip install ezpz-rust-ti +# or use the registry +ezpz add rust-ti +``` + +## 📦 Supporting Libraries + +### 🔧 [Macroz](./macroz/) + +_Lightweight Python macro system powering plugin discovery_ + +A lightweight Python macro system for code transformation and metadata collection, built on LibCST for static analysis and code generation. + +**Note**: This component is experimental and may evolve significantly as the Python static analysis ecosystem develops, particularly with upcoming tools like Astral. + +**Key Features:** + +- No-op macros that preserve runtime behavior +- LibCST integration for AST analysis +- Type-safe metadata collection +- Flexible callback system + +```bash +pip install macroz +``` + +## 🏗️ Architecture Overview + +EZPZ follows a modular architecture designed around the Polars ecosystem: + +```table +┌──────────────────────────────────────────────┐ +│ EZPZ Ecosystem │ +├──────────────────────────────────────────────┤ +│ Plugin Development Layer │ +│ ┌─────────────────┐ ┌─────────────────┐ │ +│ │ EZPZ-Pluginz │ │ Painlezz-Macroz │ │ +│ │ (Type System) │ │ (Macro System) │ │ +│ └─────────────────┘ └─────────────────┘ │ +├──────────────────────────────────────────────┤ +│ Runtime Integration Layer │ +│ ┌─────────────────┐ ┌─────────────────┐ │ +│ │ EZPZ-Stubz │ │ Plugin Runtime │ │ +│ │ (PyO3 Wrappers) │ │ Integration │ │ +│ └─────────────────┘ └─────────────────┘ │ +├──────────────────────────────────────────────┤ +│ Application Layer │ +│ ┌─────────────────┐ ┌─────────────────┐ │ +│ │ EZPZ-Rust-TI │ │ Custom Plugins │ │ +│ │(Tech Analysis) │ │ (User-defined) │ │ +│ └─────────────────┘ └─────────────────┘ │ +├──────────────────────────────────────────────┤ +│ Polars Core │ +└──────────────────────────────────────────────┘ +``` + +## 🚀 Quick Start + +### 1. Install the Plugin System + +```bash +pip install ezpz_pluginz +``` + +### 2. Create Your First Plugin + +```python +# my_plugin.py +from ezpz_pluginz.register_plugin_macro import ezpz_plugin_collect + +@ezpz_plugin_collect( + polars_ns="DataFrame", + attr_name="my_operations", + import_="from my_plugin import MyDataFramePlugin", + type_hint="MyDataFramePlugin" +) +class MyDataFramePlugin: + def custom_transform(self, multiplier: float): + """Custom transformation with full type safety""" + return self._df.with_columns( + [pl.col(col) * multiplier for col in self._df.columns] + ) +``` + +### 3. Configure Plugin Discovery + +To configure plugin discovery, you can use either a dedicated `ezpz.toml` file or add a `[tool.ezpz_pluginz]` section to your `pyproject.toml` file. + +#### Option 1: Using `ezpz.toml` + +```toml +# ezpz.toml +[ezpz_pluginz] +name = "my-polars-project" +include = [ + "src/plugins/", + "my_plugin.py" +] +site_customize = true +``` + +#### Option 1: Using `pyproject.toml` + +```toml +# pyproject.toml +[tool.ezpz_pluginz] +name = "my-polars-project" +include = [ + "src/plugins/", + "my_plugin.py" +] +site_customize = true +``` + +### 4. Mount and Use + +```bash +ezpz mount # Enable the plugin system +``` + +```python +import polars as pl + +lf = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).lazy() +result = lf.my_operations.custom_transform(2.0) # Full IDE support! +``` + +### 5. Discover and Install Ecosystem Plugins(TBA) + +```bash +# List all available plugins in the EZPZ ecosystem +ezpz list + +# Search for specific plugins +ezpz find technical +``` + +## 🔍 Plugin Discovery(TBA) + +The EZPZ ecosystem includes a plugin registry that makes it easy to discover and install plugins. + +### For Users + +```bash +# List all available plugins +ezpz list + +# Search for plugins by keyword +ezpz find analysis +ezpz find rust +``` + +### For Plugin Devs + +To register your plugin in the EZPZ ecosystem: + +1. **Add the registration function** to your plugin's `__init__.py`: + +```python +from typing import TYPE_CHECKING, cast + +from ezpz_pluginz.registry.models import PluginMetadata, PluginMetadataInner + +if TYPE_CHECKING: + from pydantic import HttpUrl + +def register_plugin() -> PluginMetadata: + return PluginMetadata( + name="my-plugin", + package_name="plugin-package-name", + description="Plugin description", + aliases=["alias1", "alias2", "alias3"], + version="0.1.0", + author="author", + category="category", + homepage=cast("HttpUrl", "https://home-page"), + metadata_=PluginMetadataInner( + tags=["tag1", "tag2", "tag3"], + license="MIT", + python_version=">=3.13", + dependencies=["ezpz-pluginz", "polars==1.31.0", "pyarrow==20.0.0"], + documentation=cast("HttpUrl", "https://doc-url"), + support_email="your email", + ), + ) +``` + +2. **Add entry point** in your `pyproject.toml`: + +```toml +[project.entry-points."ezpz.plugins"] +my-plugin = "my_plugin:register_plugin" +``` + +3. **Add ezpz-pluginz as dependency**: + +```toml +dependencies = [ + "ezpz-pluginz>=0.1.0", + # ... other deps +] +``` + +That's it! Your plugin will automatically appear when users run `ezpz list`. + +## 🖥️ CLI Commands + +| Command | Purpose | Example | +| --------------------- | -------------------------------- | ---------------- | +| `ezpz mount` | Enable plugin type hints | `ezpz mount` | +| `ezpz unmount` | Disable plugin type hints | `ezpz unmount` | +| `ezpz list` | List available ecosystem plugins | `ezpz list` | +| `ezpz find ` | Search plugins by keyword | `ezpz find rust` | + +## 🎯 Use Cases + +### For Plugin Developers + +- **Type-Safe Development**: Build Polars plugins with type checking +- **Amazing IDE Experience**: Enjoy autocompletion and error detection +- **Easy Distribution**: Publish plugins that integrate seamlessly with the ecosystem +- **Plugin Registry**: Register your plugins for easy discovery by users + +### For Data Scientists + +- **Extended Functionality**: Access powerful extensions like technical analysis +- **Plugin Discovery**: Easily find and install community plugins +- **Familiar Interface**: Work with enhanced Polars using the same API patterns +- **Performance**: Benefit from Rust-powered implementations + +### For Library Authors + +- **Integration Framework**: Build upon EZPZ's plugin architecture +- **Type Safety**: Leverage PyO3 wrappers for robust Rust-Python integration +- **Ecosystem Compatibility**: Ensure your extensions work with existing tools + +## 📋 Installation Matrix + +| Component | Purpose | Installation | Discovery | +| ---------------- | ------------------ | -------------------------- | ----------- | +| **EZPZ-Pluginz** | Core plugin system | `pip install ezpz_pluginz` | N/A | +| **EZPZ-Rust-TI** | Technical analysis | `pip install rust-ti` | `ezpz list` | +| **EZPZ-Stubz** | PyO3 type wrappers | `cargo add ezpz-stubz` | N/A | +| **EZPZ-Macroz** | Macro system | `pip install macroz` | N/A | + +## 🔧 Development Setup + +```bash +# Clone the repository +git clone https://github.com/Summit-Sailors/EZPZ.git +cd EZPZ + +# Install development dependencies +pip install -e ./pluginz[dev] +pip install -e ./macroz[dev] + +# Install Rust components +cargo build --workspace + +# Run tests +pytest pluginz/tests/ +cargo test --workspace +``` + +## 🎯 Roadmap + +- Official Polars team blessing ([tracking issue](https://github.com/pola-rs/polars/issues/14475)) +- Plugin marketplace and discovery ✅ +- More showcase plugins +- Advanced debugging tools + +### Component-Specific Guidelines + +- **Pluginz**: Focus on type safety and IDE integration +- **Rust-TI**: Maintain performance while expanding indicator coverage +- **Stubz**: Ensure zero-cost abstractions and complete type coverage +- **Macroz**: Consider future static analysis tool compatibility + +## 🤝 Contributing + +We welcome contributions to any part of the EZPZ ecosystem! Each component has its own contribution guidelines: + +- **Plugin System**: Focus on type safety and developer experience +- **Macro System**: Maintain lightweight, LibCST-based approach +- **Stubz**: Ensure zero-cost abstractions and proper stub generation +- **Showcase Plugins**: Demonstrate best practices and real-world usage + +## 📚 Documentation + +- [EZPZ-Pluginz Documentation](./core/pluginz/README.md) +- [Painlezz Macroz Documentation](./core/macroz/README.md) +- [EZPZ Stubz Documentation](./stubz/README.md) +- [Technical Analysis Plugin](./plugins/ezpz-rust-ti/README.md) +- [Examples and Tutorials](./examples/README.md) + +## 🙏 Acknowledgments + +- **[Polars](https://pola.rs/)** - The amazing DataFrame library that makes this all possible +- **[PyO3](https://pyo3.rs/)** - Rust bindings for Python enabling seamless integration +- **[LibCST](https://libcst.readthedocs.io/)** - Concrete syntax trees for Python code transformation +- **[rust_ti](https://crates.io/crates/rust_ti)** - Technical analysis algorithms powering our indicators + +## 💖 Support + +For support and sponsorship opportunities, visit our Polar page: + + + + +Subscription Tiers on Polar + + + +## 📄 License + +This project is licensed under the MIT License. See LICENSE file for details. + +--- + +**EZPZ** - Making Polars plugin development EZPZ! 🚀 diff --git a/actions.just b/actions.just new file mode 100644 index 0000000..e839393 --- /dev/null +++ b/actions.just @@ -0,0 +1,121 @@ +set shell := ["bash", "-uc"] +set export +set dotenv-load := true + +analyze-plugins: + #!/usr/bin/env bash + set -euo pipefail + python3 .github/scripts/plugins/plugin_manager.py analyze + +test-plugin-pipeline package_name plugin_path: + #!/usr/bin/env bash + set -euo pipefail + echo "🔍 Testing plugin: {{package_name}} at {{plugin_path}}" + + nu .github/scripts/plugins/plugin_ops.nu test-pipeline "{{package_name}}" "{{plugin_path}}" + +register-update-plugins plugins_to_register plugins_to_update dry_run: + #!/usr/bin/env bash + set -euo pipefail + + if [[ "{{plugins_to_register}}" != "[]" ]]; then + echo "📝 Registering new plugins..." + python3 .github/scripts/plugins/plugin_manager.py register \ + --plugins='{{plugins_to_register}}' \ + --dry-run='{{dry_run}}' + fi + + if [[ "{{plugins_to_update}}" != "[]" ]]; then + echo "🔄 Updating existing plugins..." + python3 .github/scripts/plugins/plugin_manager.py update \ + --plugins='{{plugins_to_update}}' \ + --dry-run='{{dry_run}}' + fi + +publish-plugin package_name plugin_path dry_run plugins_to_register plugins_to_update: + #!/usr/bin/env bash + set -euo pipefail + + if ! python3 .github/scripts/plugins/plugin_manager.py check-publish \ + --package-name="{{package_name}}" \ + --plugins-to-register='{{plugins_to_register}}' \ + --plugins-to-update='{{plugins_to_update}}'; then + echo "ℹ️ Plugin {{package_name}} does not need publishing" + exit 0 + fi + + echo "📦 Publishing plugin: {{package_name}}" + + nu .github/scripts/plugins/plugin_ops.nu publish "{{package_name}}" "{{plugin_path}}" --dry-run='{{dry_run}}' + +validate-plugin package_name plugin_path: + #!/usr/bin/env bash + set -euo pipefail + nu .github/scripts/plugins/plugin_ops.nu validate "{{package_name}}" "{{plugin_path}}" + +build-plugin package_name plugin_path: + #!/usr/bin/env bash + set -euo pipefail + nu .github/scripts/plugins/plugin_ops.nu build "{{package_name}}" "{{plugin_path}}" + +test-plugin package_name plugin_path: + #!/usr/bin/env bash + set -euo pipefail + nu .github/scripts/plugins/plugin_ops.nu test "{{package_name}}" "{{plugin_path}}" + +validate-all: + #!/usr/bin/env bash + set -euo pipefail + + for plugin_dir in plugins/*/; do + if [[ -d "$plugin_dir" ]]; then + plugin_name=$(basename "$plugin_dir") + echo "🔍 Validating $plugin_name..." + just validate-plugin "$plugin_name" "$plugin_dir" + fi + done + +test-all: + #!/usr/bin/env bash + set -euo pipefail + + for plugin_dir in plugins/*/; do + if [[ -d "$plugin_dir" ]]; then + plugin_name=$(basename "$plugin_dir") + echo "🧪 Testing $plugin_name..." + just test-plugin-pipeline "$plugin_name" "$plugin_dir" + fi + done + +clean: + #!/usr/bin/env bash + set -euo pipefail + echo "🧹 Cleaning build artifacts..." + + find . -name "*.pyc" -delete + find . -name "__pycache__" -type d -exec rm -rf {} + 2>/dev/null || true + find . -name "dist" -type d -exec rm -rf {} + 2>/dev/null || true + find . -name "build" -type d -exec rm -rf {} + 2>/dev/null || true + find . -name "*.egg-info" -type d -exec rm -rf {} + 2>/dev/null || true + find . -name "target" -type d -exec rm -rf {} + 2>/dev/null || true + + echo "✅ Clean completed" + +install-tools: + #!/usr/bin/env bash + set -euo pipefail + echo "Installing security and maintenance tools..." + rye install bandit + rye install semgrep + rye install pip-audit + cargo install cargo-audit cargo-outdated + +uninstall-tools: + #!/usr/bin/env bash + set -euo pipefail + echo "Uninstalling security and maintenance tools..." + rye uninstall bandit + rye uninstall semgrep + rye uninstall pip-audit + cargo uninstall cargo-audit cargo-outdated + \ No newline at end of file diff --git a/api/Cargo.toml b/api/Cargo.toml deleted file mode 100644 index 0803839..0000000 --- a/api/Cargo.toml +++ /dev/null @@ -1,41 +0,0 @@ -[package] -authors = { workspace = true } -description = { workspace = true } -edition = { workspace = true } -license = { workspace = true } -name = "api" -repository = { workspace = true } - -[package.metadata.stilts] -template_dir = "$CARGO_MANIFEST_DIR/src" -trim = false - -[dependencies] -maestro-anthropic = { workspace = true } - - -bon = { workspace = true } -chrono = { workspace = true } -futures = { workspace = true } -schemars = { workspace = true } -serde = { workspace = true } -serde_json = { workspace = true } -stilts = { workspace = true } -strum = { workspace = true } -url = { workspace = true } -uuid = { workspace = true } -validator = { workspace = true } - -dioxus = { workspace = true, features = [], default-features = false } - -maestro-diesel = { workspace = true, features = ["async"], optional = true } - -diesel = { version = "2.2.10", features = ["chrono", "postgres", "serde_json", "uuid"], optional = true } -diesel-async = { version = "0.5.2", features = ["postgres"], optional = true } -diesel-derive-enum = { version = "2.1.0", features = ["postgres"], optional = true } - - -[features] -dioxus = ["maestro-anthropic/dioxus"] -dioxus-server = ["dioxus", "maestro-diesel/dioxus", "server"] -server = ["dep:diesel", "dep:diesel-async", "dep:diesel-derive-enum", "dep:maestro-diesel", "maestro-anthropic/server"] diff --git a/api/src/lib.rs b/api/src/lib.rs deleted file mode 100644 index c8394a7..0000000 --- a/api/src/lib.rs +++ /dev/null @@ -1,4 +0,0 @@ -// pub mod postings; - -#[cfg(feature = "server")] -pub mod schema; diff --git a/api/src/schema.rs b/api/src/schema.rs deleted file mode 100644 index d9a52af..0000000 --- a/api/src/schema.rs +++ /dev/null @@ -1 +0,0 @@ -// @generated automatically by Diesel CLI. diff --git a/app/.env.example b/app/.env.example deleted file mode 100644 index 29e7135..0000000 --- a/app/.env.example +++ /dev/null @@ -1,4 +0,0 @@ -DATABASE_URL=postgresql://postgres:postgres@localhost:5432/postgres -ANTHROPIC_API_KEY= - -USERS_COUNT=100 diff --git a/app/Cargo.toml b/app/Cargo.toml deleted file mode 100644 index 21f1e03..0000000 --- a/app/Cargo.toml +++ /dev/null @@ -1,59 +0,0 @@ -[package] -authors = { workspace = true } -description = { workspace = true } -edition = { workspace = true } -license = { workspace = true } -name = "app" -repository = { workspace = true } - - -[dependencies] - -maestro-hooks = { workspace = true } -maestro-toast = { workspace = true } -maestro-ui = { workspace = true } - -api = { workspace = true, features = ["dioxus"] } - -bon = { workspace = true } -chrono = { workspace = true } -dioxus = { workspace = true, features = ["fullstack", "router"] } -dioxus-free-icons = { workspace = true } -dioxus-sdk = { workspace = true, features = ["time"] } -futures = { workspace = true } -markdown-to-html = { workspace = true } -plotters = { workspace = true } -plotters-canvas = { workspace = true } -tailwind_fuse = { workspace = true } - -anyhow = { workspace = true } - -serde = { workspace = true } -serde_json = { workspace = true } -strum = { workspace = true } -uuid = { workspace = true } - -diesel = { version = "2.2.10", features = ["chrono", "postgres", "serde_json", "uuid"], optional = true } -diesel-async = { version = "0.5.2", features = ["postgres"], optional = true } - -maestro-anthropic = { workspace = true, features = ["dioxus"] } -maestro-apalis = { workspace = true, features = ["create"], optional = true } -maestro-diesel = { workspace = true, features = ["async", "dioxus"], optional = true } - -[build-dependencies] -dotenvy = { git = "https://github.com/allan2/dotenvy.git", features = ["macros"] } - -[features] -desktop = ["dioxus/desktop"] -web = ["chrono/wasmbind", "dioxus/web", "uuid/js"] - -server = [ - "api/dioxus-server", - "api/server", - "dep:diesel", - "dep:diesel-async", - "dep:maestro-apalis", - "dep:maestro-diesel", - "dioxus/server", - "maestro-anthropic/server", -] diff --git a/app/Dioxus.toml b/app/Dioxus.toml deleted file mode 100644 index d62df35..0000000 --- a/app/Dioxus.toml +++ /dev/null @@ -1,36 +0,0 @@ -#:schema https://raw.githubusercontent.com/umnovI/dioxus-config-schema/main/dioxus.schema.json - -[application] -asset_dir = "./assets" -default_platform = "desktop" -name = "prompt-rs" -out_dir = "dist" - -[web.app] -title = "Upwork Jobs Navigator" - -[web.watcher] -index_on_404 = true -reload_html = true -watch_path = ["."] - -[web.resource] -script = [] -style = [] - -[web.resource.dev] -script = [] - -# FIXME: Need to `cd assets` before running `dx bundle` due to https://github.com/DioxusLabs/dioxus/issues/1283 -[bundle] -category = "" -copyright = "" -icon = [] -identifier = "" -long_description = """ -""" -name = "dioxus-desktop-template" -osx_frameworks = [] -resources = ["public"] -short_description = "" -version = "0.0.1" diff --git a/app/assets/android-chrome-192x192.png b/app/assets/android-chrome-192x192.png deleted file mode 100644 index 64644b7..0000000 Binary files a/app/assets/android-chrome-192x192.png and /dev/null differ diff --git a/app/assets/android-chrome-512x512.png b/app/assets/android-chrome-512x512.png deleted file mode 100644 index 0aeb4c9..0000000 Binary files a/app/assets/android-chrome-512x512.png and /dev/null differ diff --git a/app/assets/apple-touch-icon.png b/app/assets/apple-touch-icon.png deleted file mode 100644 index 549f88f..0000000 Binary files a/app/assets/apple-touch-icon.png and /dev/null differ diff --git a/app/assets/eyes.svg b/app/assets/eyes.svg deleted file mode 100644 index e809b68..0000000 --- a/app/assets/eyes.svg +++ /dev/null @@ -1,17 +0,0 @@ - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/app/assets/favicon-16x16.png b/app/assets/favicon-16x16.png deleted file mode 100644 index 7e998e0..0000000 Binary files a/app/assets/favicon-16x16.png and /dev/null differ diff --git a/app/assets/favicon-32x32.png b/app/assets/favicon-32x32.png deleted file mode 100644 index db8a1ac..0000000 Binary files a/app/assets/favicon-32x32.png and /dev/null differ diff --git a/app/assets/favicon.ico b/app/assets/favicon.ico deleted file mode 100644 index f451442..0000000 Binary files a/app/assets/favicon.ico and /dev/null differ diff --git a/app/assets/site.webmanifest b/app/assets/site.webmanifest deleted file mode 100644 index 45dc8a2..0000000 --- a/app/assets/site.webmanifest +++ /dev/null @@ -1 +0,0 @@ -{"name":"","short_name":"","icons":[{"src":"/android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"/android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ffffff","background_color":"#ffffff","display":"standalone"} \ No newline at end of file diff --git a/app/assets/sw.js b/app/assets/sw.js deleted file mode 100644 index 9bc1b6d..0000000 --- a/app/assets/sw.js +++ /dev/null @@ -1,20 +0,0 @@ -var cacheName = "egui-template-pwa"; -var filesToCache = ["./", "./index.html", "./juzt.js", "./juzt_bg.wasm"]; - -/* Start the service worker and cache all of the app's content */ -self.addEventListener("install", function (e) { - e.waitUntil( - caches.open(cacheName).then(function (cache) { - return cache.addAll(filesToCache); - }), - ); -}); - -/* Serve cached content when offline */ -self.addEventListener("fetch", function (e) { - e.respondWith( - caches.match(e.request).then(function (response) { - return response || fetch(e.request); - }), - ); -}); diff --git a/app/assets/tailwind.css b/app/assets/tailwind.css deleted file mode 100644 index c7f5b61..0000000 --- a/app/assets/tailwind.css +++ /dev/null @@ -1,2685 +0,0 @@ -*, ::before, ::after { - --tw-border-spacing-x: 0; - --tw-border-spacing-y: 0; - --tw-translate-x: 0; - --tw-translate-y: 0; - --tw-rotate: 0; - --tw-skew-x: 0; - --tw-skew-y: 0; - --tw-scale-x: 1; - --tw-scale-y: 1; - --tw-pan-x: ; - --tw-pan-y: ; - --tw-pinch-zoom: ; - --tw-scroll-snap-strictness: proximity; - --tw-gradient-from-position: ; - --tw-gradient-via-position: ; - --tw-gradient-to-position: ; - --tw-ordinal: ; - --tw-slashed-zero: ; - --tw-numeric-figure: ; - --tw-numeric-spacing: ; - --tw-numeric-fraction: ; - --tw-ring-inset: ; - --tw-ring-offset-width: 0px; - --tw-ring-offset-color: #fff; - --tw-ring-color: rgb(59 130 246 / 0.5); - --tw-ring-offset-shadow: 0 0 #0000; - --tw-ring-shadow: 0 0 #0000; - --tw-shadow: 0 0 #0000; - --tw-shadow-colored: 0 0 #0000; - --tw-blur: ; - --tw-brightness: ; - --tw-contrast: ; - --tw-grayscale: ; - --tw-hue-rotate: ; - --tw-invert: ; - --tw-saturate: ; - --tw-sepia: ; - --tw-drop-shadow: ; - --tw-backdrop-blur: ; - --tw-backdrop-brightness: ; - --tw-backdrop-contrast: ; - --tw-backdrop-grayscale: ; - --tw-backdrop-hue-rotate: ; - --tw-backdrop-invert: ; - --tw-backdrop-opacity: ; - --tw-backdrop-saturate: ; - --tw-backdrop-sepia: ; - --tw-contain-size: ; - --tw-contain-layout: ; - --tw-contain-paint: ; - --tw-contain-style: ; -} - -::backdrop { - --tw-border-spacing-x: 0; - --tw-border-spacing-y: 0; - --tw-translate-x: 0; - --tw-translate-y: 0; - --tw-rotate: 0; - --tw-skew-x: 0; - --tw-skew-y: 0; - --tw-scale-x: 1; - --tw-scale-y: 1; - --tw-pan-x: ; - --tw-pan-y: ; - --tw-pinch-zoom: ; - --tw-scroll-snap-strictness: proximity; - --tw-gradient-from-position: ; - --tw-gradient-via-position: ; - --tw-gradient-to-position: ; - --tw-ordinal: ; - --tw-slashed-zero: ; - --tw-numeric-figure: ; - --tw-numeric-spacing: ; - --tw-numeric-fraction: ; - --tw-ring-inset: ; - --tw-ring-offset-width: 0px; - --tw-ring-offset-color: #fff; - --tw-ring-color: rgb(59 130 246 / 0.5); - --tw-ring-offset-shadow: 0 0 #0000; - --tw-ring-shadow: 0 0 #0000; - --tw-shadow: 0 0 #0000; - --tw-shadow-colored: 0 0 #0000; - --tw-blur: ; - --tw-brightness: ; - --tw-contrast: ; - --tw-grayscale: ; - --tw-hue-rotate: ; - --tw-invert: ; - --tw-saturate: ; - --tw-sepia: ; - --tw-drop-shadow: ; - --tw-backdrop-blur: ; - --tw-backdrop-brightness: ; - --tw-backdrop-contrast: ; - --tw-backdrop-grayscale: ; - --tw-backdrop-hue-rotate: ; - --tw-backdrop-invert: ; - --tw-backdrop-opacity: ; - --tw-backdrop-saturate: ; - --tw-backdrop-sepia: ; - --tw-contain-size: ; - --tw-contain-layout: ; - --tw-contain-paint: ; - --tw-contain-style: ; -} - -/* -! tailwindcss v3.4.17 | MIT License | https://tailwindcss.com -*/ - -/* -1. Prevent padding and border from affecting element width. (https://github.com/mozdevs/cssremedy/issues/4) -2. Allow adding a border to an element by just adding a border-width. (https://github.com/tailwindcss/tailwindcss/pull/116) -*/ - -*, -::before, -::after { - box-sizing: border-box; - /* 1 */ - border-width: 0; - /* 2 */ - border-style: solid; - /* 2 */ - border-color: #e5e7eb; - /* 2 */ -} - -::before, -::after { - --tw-content: ''; -} - -/* -1. Use a consistent sensible line-height in all browsers. -2. Prevent adjustments of font size after orientation changes in iOS. -3. Use a more readable tab size. -4. Use the user's configured `sans` font-family by default. -5. Use the user's configured `sans` font-feature-settings by default. -6. Use the user's configured `sans` font-variation-settings by default. -7. Disable tap highlights on iOS -*/ - -html, -:host { - line-height: 1.5; - /* 1 */ - -webkit-text-size-adjust: 100%; - /* 2 */ - -moz-tab-size: 4; - /* 3 */ - -o-tab-size: 4; - tab-size: 4; - /* 3 */ - font-family: ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji"; - /* 4 */ - font-feature-settings: normal; - /* 5 */ - font-variation-settings: normal; - /* 6 */ - -webkit-tap-highlight-color: transparent; - /* 7 */ -} - -/* -1. Remove the margin in all browsers. -2. Inherit line-height from `html` so users can set them as a class directly on the `html` element. -*/ - -body { - margin: 0; - /* 1 */ - line-height: inherit; - /* 2 */ -} - -/* -1. Add the correct height in Firefox. -2. Correct the inheritance of border color in Firefox. (https://bugzilla.mozilla.org/show_bug.cgi?id=190655) -3. Ensure horizontal rules are visible by default. -*/ - -hr { - height: 0; - /* 1 */ - color: inherit; - /* 2 */ - border-top-width: 1px; - /* 3 */ -} - -/* -Add the correct text decoration in Chrome, Edge, and Safari. -*/ - -abbr:where([title]) { - -webkit-text-decoration: underline dotted; - text-decoration: underline dotted; -} - -/* -Remove the default font size and weight for headings. -*/ - -h1, -h2, -h3, -h4, -h5, -h6 { - font-size: inherit; - font-weight: inherit; -} - -/* -Reset links to optimize for opt-in styling instead of opt-out. -*/ - -a { - color: inherit; - text-decoration: inherit; -} - -/* -Add the correct font weight in Edge and Safari. -*/ - -b, -strong { - font-weight: bolder; -} - -/* -1. Use the user's configured `mono` font-family by default. -2. Use the user's configured `mono` font-feature-settings by default. -3. Use the user's configured `mono` font-variation-settings by default. -4. Correct the odd `em` font sizing in all browsers. -*/ - -code, -kbd, -samp, -pre { - font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; - /* 1 */ - font-feature-settings: normal; - /* 2 */ - font-variation-settings: normal; - /* 3 */ - font-size: 1em; - /* 4 */ -} - -/* -Add the correct font size in all browsers. -*/ - -small { - font-size: 80%; -} - -/* -Prevent `sub` and `sup` elements from affecting the line height in all browsers. -*/ - -sub, -sup { - font-size: 75%; - line-height: 0; - position: relative; - vertical-align: baseline; -} - -sub { - bottom: -0.25em; -} - -sup { - top: -0.5em; -} - -/* -1. Remove text indentation from table contents in Chrome and Safari. (https://bugs.chromium.org/p/chromium/issues/detail?id=999088, https://bugs.webkit.org/show_bug.cgi?id=201297) -2. Correct table border color inheritance in all Chrome and Safari. (https://bugs.chromium.org/p/chromium/issues/detail?id=935729, https://bugs.webkit.org/show_bug.cgi?id=195016) -3. Remove gaps between table borders by default. -*/ - -table { - text-indent: 0; - /* 1 */ - border-color: inherit; - /* 2 */ - border-collapse: collapse; - /* 3 */ -} - -/* -1. Change the font styles in all browsers. -2. Remove the margin in Firefox and Safari. -3. Remove default padding in all browsers. -*/ - -button, -input, -optgroup, -select, -textarea { - font-family: inherit; - /* 1 */ - font-feature-settings: inherit; - /* 1 */ - font-variation-settings: inherit; - /* 1 */ - font-size: 100%; - /* 1 */ - font-weight: inherit; - /* 1 */ - line-height: inherit; - /* 1 */ - letter-spacing: inherit; - /* 1 */ - color: inherit; - /* 1 */ - margin: 0; - /* 2 */ - padding: 0; - /* 3 */ -} - -/* -Remove the inheritance of text transform in Edge and Firefox. -*/ - -button, -select { - text-transform: none; -} - -/* -1. Correct the inability to style clickable types in iOS and Safari. -2. Remove default button styles. -*/ - -button, -input:where([type='button']), -input:where([type='reset']), -input:where([type='submit']) { - -webkit-appearance: button; - /* 1 */ - background-color: transparent; - /* 2 */ - background-image: none; - /* 2 */ -} - -/* -Use the modern Firefox focus style for all focusable elements. -*/ - -:-moz-focusring { - outline: auto; -} - -/* -Remove the additional `:invalid` styles in Firefox. (https://github.com/mozilla/gecko-dev/blob/2f9eacd9d3d995c937b4251a5557d95d494c9be1/layout/style/res/forms.css#L728-L737) -*/ - -:-moz-ui-invalid { - box-shadow: none; -} - -/* -Add the correct vertical alignment in Chrome and Firefox. -*/ - -progress { - vertical-align: baseline; -} - -/* -Correct the cursor style of increment and decrement buttons in Safari. -*/ - -::-webkit-inner-spin-button, -::-webkit-outer-spin-button { - height: auto; -} - -/* -1. Correct the odd appearance in Chrome and Safari. -2. Correct the outline style in Safari. -*/ - -[type='search'] { - -webkit-appearance: textfield; - /* 1 */ - outline-offset: -2px; - /* 2 */ -} - -/* -Remove the inner padding in Chrome and Safari on macOS. -*/ - -::-webkit-search-decoration { - -webkit-appearance: none; -} - -/* -1. Correct the inability to style clickable types in iOS and Safari. -2. Change font properties to `inherit` in Safari. -*/ - -::-webkit-file-upload-button { - -webkit-appearance: button; - /* 1 */ - font: inherit; - /* 2 */ -} - -/* -Add the correct display in Chrome and Safari. -*/ - -summary { - display: list-item; -} - -/* -Removes the default spacing and border for appropriate elements. -*/ - -blockquote, -dl, -dd, -h1, -h2, -h3, -h4, -h5, -h6, -hr, -figure, -p, -pre { - margin: 0; -} - -fieldset { - margin: 0; - padding: 0; -} - -legend { - padding: 0; -} - -ol, -ul, -menu { - list-style: none; - margin: 0; - padding: 0; -} - -/* -Reset default styling for dialogs. -*/ - -dialog { - padding: 0; -} - -/* -Prevent resizing textareas horizontally by default. -*/ - -textarea { - resize: vertical; -} - -/* -1. Reset the default placeholder opacity in Firefox. (https://github.com/tailwindlabs/tailwindcss/issues/3300) -2. Set the default placeholder color to the user's configured gray 400 color. -*/ - -input::-moz-placeholder, textarea::-moz-placeholder { - opacity: 1; - /* 1 */ - color: #9ca3af; - /* 2 */ -} - -input::placeholder, -textarea::placeholder { - opacity: 1; - /* 1 */ - color: #9ca3af; - /* 2 */ -} - -/* -Set the default cursor for buttons. -*/ - -button, -[role="button"] { - cursor: pointer; -} - -/* -Make sure disabled buttons don't get the pointer cursor. -*/ - -:disabled { - cursor: default; -} - -/* -1. Make replaced elements `display: block` by default. (https://github.com/mozdevs/cssremedy/issues/14) -2. Add `vertical-align: middle` to align replaced elements more sensibly by default. (https://github.com/jensimmons/cssremedy/issues/14#issuecomment-634934210) - This can trigger a poorly considered lint error in some tools but is included by design. -*/ - -img, -svg, -video, -canvas, -audio, -iframe, -embed, -object { - display: block; - /* 1 */ - vertical-align: middle; - /* 2 */ -} - -/* -Constrain images and videos to the parent width and preserve their intrinsic aspect ratio. (https://github.com/mozdevs/cssremedy/issues/14) -*/ - -img, -video { - max-width: 100%; - height: auto; -} - -/* Make elements with the HTML hidden attribute stay hidden by default */ - -[hidden]:where(:not([hidden="until-found"])) { - display: none; -} - -.pointer-events-none { - pointer-events: none; -} - -.pointer-events-auto { - pointer-events: auto; -} - -.static { - position: static; -} - -.fixed { - position: fixed; -} - -.absolute { - position: absolute; -} - -.relative { - position: relative; -} - -.inset-0 { - inset: 0px; -} - -.-bottom-0\.5 { - bottom: -0.125rem; -} - -.-right-\[360px\] { - right: -360px; -} - -.bottom-0 { - bottom: 0px; -} - -.bottom-5 { - bottom: 1.25rem; -} - -.left-0 { - left: 0px; -} - -.left-3 { - left: 0.75rem; -} - -.left-5 { - left: 1.25rem; -} - -.right-0 { - right: 0px; -} - -.right-2 { - right: 0.5rem; -} - -.right-3 { - right: 0.75rem; -} - -.right-5 { - right: 1.25rem; -} - -.top-0 { - top: 0px; -} - -.top-2 { - top: 0.5rem; -} - -.top-2\.5 { - top: 0.625rem; -} - -.top-20 { - top: 5rem; -} - -.top-5 { - top: 1.25rem; -} - -.top-\[100\%\] { - top: 100%; -} - -.-z-20 { - z-index: -20; -} - -.-z-40 { - z-index: -40; -} - -.-z-\[9999\] { - z-index: -9999; -} - -.z-40 { - z-index: 40; -} - -.z-50 { - z-index: 50; -} - -.z-\[60\] { - z-index: 60; -} - -.z-\[9000\] { - z-index: 9000; -} - -.col-span-6 { - grid-column: span 6 / span 6; -} - -.m-0 { - margin: 0px; -} - -.mx-auto { - margin-left: auto; - margin-right: auto; -} - -.my-auto { - margin-top: auto; - margin-bottom: auto; -} - -.mb-1 { - margin-bottom: 0.25rem; -} - -.mb-2 { - margin-bottom: 0.5rem; -} - -.mb-4 { - margin-bottom: 1rem; -} - -.mb-6 { - margin-bottom: 1.5rem; -} - -.ml-auto { - margin-left: auto; -} - -.mr-1 { - margin-right: 0.25rem; -} - -.mt-3 { - margin-top: 0.75rem; -} - -.mt-4 { - margin-top: 1rem; -} - -.mt-6 { - margin-top: 1.5rem; -} - -.mt-auto { - margin-top: auto; -} - -.line-clamp-1 { - overflow: hidden; - display: -webkit-box; - -webkit-box-orient: vertical; - -webkit-line-clamp: 1; -} - -.line-clamp-2 { - overflow: hidden; - display: -webkit-box; - -webkit-box-orient: vertical; - -webkit-line-clamp: 2; -} - -.line-clamp-3 { - overflow: hidden; - display: -webkit-box; - -webkit-box-orient: vertical; - -webkit-line-clamp: 3; -} - -.block { - display: block; -} - -.inline-block { - display: inline-block; -} - -.flex { - display: flex; -} - -.inline-flex { - display: inline-flex; -} - -.grid { - display: grid; -} - -.contents { - display: contents; -} - -.hidden { - display: none; -} - -.h-0 { - height: 0px; -} - -.h-0\.5 { - height: 0.125rem; -} - -.h-1 { - height: 0.25rem; -} - -.h-10 { - height: 2.5rem; -} - -.h-11 { - height: 2.75rem; -} - -.h-12 { - height: 3rem; -} - -.h-2 { - height: 0.5rem; -} - -.h-4 { - height: 1rem; -} - -.h-5 { - height: 1.25rem; -} - -.h-6 { - height: 1.5rem; -} - -.h-8 { - height: 2rem; -} - -.h-9 { - height: 2.25rem; -} - -.h-96 { - height: 24rem; -} - -.h-auto { - height: auto; -} - -.h-fit { - height: -moz-fit-content; - height: fit-content; -} - -.h-full { - height: 100%; -} - -.max-h-48 { - max-height: 12rem; -} - -.max-h-72 { - max-height: 18rem; -} - -.max-h-\[80vh\] { - max-height: 80vh; -} - -.max-h-screen { - max-height: 100vh; -} - -.min-h-11 { - min-height: 2.75rem; -} - -.min-h-12 { - min-height: 3rem; -} - -.min-h-screen { - min-height: 100vh; -} - -.w-0 { - width: 0px; -} - -.w-1 { - width: 0.25rem; -} - -.w-1\/3 { - width: 33.333333%; -} - -.w-10 { - width: 2.5rem; -} - -.w-11 { - width: 2.75rem; -} - -.w-12 { - width: 3rem; -} - -.w-2 { - width: 0.5rem; -} - -.w-2\/3 { - width: 66.666667%; -} - -.w-4 { - width: 1rem; -} - -.w-5 { - width: 1.25rem; -} - -.w-6 { - width: 1.5rem; -} - -.w-8 { - width: 2rem; -} - -.w-9 { - width: 2.25rem; -} - -.w-\[310px\] { - width: 310px; -} - -.w-fit { - width: -moz-fit-content; - width: fit-content; -} - -.w-full { - width: 100%; -} - -.min-w-\[200px\] { - min-width: 200px; -} - -.min-w-\[448px\] { - min-width: 448px; -} - -.max-w-2xl { - max-width: 42rem; -} - -.max-w-32 { - max-width: 8rem; -} - -.max-w-3xl { - max-width: 48rem; -} - -.max-w-40 { - max-width: 10rem; -} - -.max-w-72 { - max-width: 18rem; -} - -.max-w-md { - max-width: 28rem; -} - -.max-w-sm { - max-width: 24rem; -} - -.flex-1 { - flex: 1 1 0%; -} - -.flex-shrink-0 { - flex-shrink: 0; -} - -.shrink-0 { - flex-shrink: 0; -} - -.flex-grow { - flex-grow: 1; -} - -.translate-x-1 { - --tw-translate-x: 0.25rem; - transform: translate(var(--tw-translate-x), var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y)); -} - -.translate-x-5 { - --tw-translate-x: 1.25rem; - transform: translate(var(--tw-translate-x), var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y)); -} - -.rotate-180 { - --tw-rotate: 180deg; - transform: translate(var(--tw-translate-x), var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y)); -} - -.transform { - transform: translate(var(--tw-translate-x), var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y)); -} - -@keyframes spin { - to { - transform: rotate(360deg); - } -} - -.animate-spin { - animation: spin 1s linear infinite; -} - -.cursor-not-allowed { - cursor: not-allowed; -} - -.cursor-pointer { - cursor: pointer; -} - -.resize-none { - resize: none; -} - -.appearance-none { - -webkit-appearance: none; - -moz-appearance: none; - appearance: none; -} - -.grid-cols-1 { - grid-template-columns: repeat(1, minmax(0, 1fr)); -} - -.grid-cols-2 { - grid-template-columns: repeat(2, minmax(0, 1fr)); -} - -.grid-cols-7 { - grid-template-columns: repeat(7, minmax(0, 1fr)); -} - -.flex-row-reverse { - flex-direction: row-reverse; -} - -.flex-col { - flex-direction: column; -} - -.flex-wrap { - flex-wrap: wrap; -} - -.items-start { - align-items: flex-start; -} - -.items-end { - align-items: flex-end; -} - -.items-center { - align-items: center; -} - -.justify-end { - justify-content: flex-end; -} - -.justify-center { - justify-content: center; -} - -.justify-between { - justify-content: space-between; -} - -.justify-around { - justify-content: space-around; -} - -.gap-0\.5 { - gap: 0.125rem; -} - -.gap-1 { - gap: 0.25rem; -} - -.gap-1\.5 { - gap: 0.375rem; -} - -.gap-2 { - gap: 0.5rem; -} - -.gap-3 { - gap: 0.75rem; -} - -.gap-4 { - gap: 1rem; -} - -.gap-6 { - gap: 1.5rem; -} - -.gap-x-2 { - -moz-column-gap: 0.5rem; - column-gap: 0.5rem; -} - -.gap-y-4 { - row-gap: 1rem; -} - -.gap-y-5 { - row-gap: 1.25rem; -} - -.gap-y-6 { - row-gap: 1.5rem; -} - -.gap-y-8 { - row-gap: 2rem; -} - -.space-y-2 > :not([hidden]) ~ :not([hidden]) { - --tw-space-y-reverse: 0; - margin-top: calc(0.5rem * calc(1 - var(--tw-space-y-reverse))); - margin-bottom: calc(0.5rem * var(--tw-space-y-reverse)); -} - -.space-y-4 > :not([hidden]) ~ :not([hidden]) { - --tw-space-y-reverse: 0; - margin-top: calc(1rem * calc(1 - var(--tw-space-y-reverse))); - margin-bottom: calc(1rem * var(--tw-space-y-reverse)); -} - -.space-y-6 > :not([hidden]) ~ :not([hidden]) { - --tw-space-y-reverse: 0; - margin-top: calc(1.5rem * calc(1 - var(--tw-space-y-reverse))); - margin-bottom: calc(1.5rem * var(--tw-space-y-reverse)); -} - -.overflow-hidden { - overflow: hidden; -} - -.overflow-y-auto { - overflow-y: auto; -} - -.overflow-y-hidden { - overflow-y: hidden; -} - -.whitespace-normal { - white-space: normal; -} - -.whitespace-nowrap { - white-space: nowrap; -} - -.whitespace-pre-wrap { - white-space: pre-wrap; -} - -.break-words { - overflow-wrap: break-word; -} - -.rounded { - border-radius: 0.25rem; -} - -.rounded-full { - border-radius: 9999px; -} - -.rounded-lg { - border-radius: 0.5rem; -} - -.rounded-md { - border-radius: 0.375rem; -} - -.border { - border-width: 1px; -} - -.border-b { - border-bottom-width: 1px; -} - -.border-r { - border-right-width: 1px; -} - -.border-t { - border-top-width: 1px; -} - -.\!border-slate-300 { - --tw-border-opacity: 1 !important; - border-color: rgb(203 213 225 / var(--tw-border-opacity, 1)) !important; -} - -.border-\[\#bce8f1\] { - --tw-border-opacity: 1; - border-color: rgb(188 232 241 / var(--tw-border-opacity, 1)); -} - -.border-\[\#d6e9c6\] { - --tw-border-opacity: 1; - border-color: rgb(214 233 198 / var(--tw-border-opacity, 1)); -} - -.border-\[\#ebccd1\] { - --tw-border-opacity: 1; - border-color: rgb(235 204 209 / var(--tw-border-opacity, 1)); -} - -.border-\[\#faebcc\] { - --tw-border-opacity: 1; - border-color: rgb(250 235 204 / var(--tw-border-opacity, 1)); -} - -.border-gray-500 { - --tw-border-opacity: 1; - border-color: rgb(107 114 128 / var(--tw-border-opacity, 1)); -} - -.border-gray-600 { - --tw-border-opacity: 1; - border-color: rgb(75 85 99 / var(--tw-border-opacity, 1)); -} - -.border-gray-700 { - --tw-border-opacity: 1; - border-color: rgb(55 65 81 / var(--tw-border-opacity, 1)); -} - -.border-slate-700 { - --tw-border-opacity: 1; - border-color: rgb(51 65 85 / var(--tw-border-opacity, 1)); -} - -.\!bg-slate-300 { - --tw-bg-opacity: 1 !important; - background-color: rgb(203 213 225 / var(--tw-bg-opacity, 1)) !important; -} - -.bg-\[\#31708f\] { - --tw-bg-opacity: 1; - background-color: rgb(49 112 143 / var(--tw-bg-opacity, 1)); -} - -.bg-\[\#3c763d\] { - --tw-bg-opacity: 1; - background-color: rgb(60 118 61 / var(--tw-bg-opacity, 1)); -} - -.bg-\[\#8a6d3b\] { - --tw-bg-opacity: 1; - background-color: rgb(138 109 59 / var(--tw-bg-opacity, 1)); -} - -.bg-\[\#a94442\] { - --tw-bg-opacity: 1; - background-color: rgb(169 68 66 / var(--tw-bg-opacity, 1)); -} - -.bg-\[\#bce8f1\] { - --tw-bg-opacity: 1; - background-color: rgb(188 232 241 / var(--tw-bg-opacity, 1)); -} - -.bg-\[\#d6e9c6\] { - --tw-bg-opacity: 1; - background-color: rgb(214 233 198 / var(--tw-bg-opacity, 1)); -} - -.bg-\[\#ebccd1\] { - --tw-bg-opacity: 1; - background-color: rgb(235 204 209 / var(--tw-bg-opacity, 1)); -} - -.bg-\[\#faebcc\] { - --tw-bg-opacity: 1; - background-color: rgb(250 235 204 / var(--tw-bg-opacity, 1)); -} - -.bg-blue-200 { - --tw-bg-opacity: 1; - background-color: rgb(191 219 254 / var(--tw-bg-opacity, 1)); -} - -.bg-blue-500 { - --tw-bg-opacity: 1; - background-color: rgb(59 130 246 / var(--tw-bg-opacity, 1)); -} - -.bg-blue-600 { - --tw-bg-opacity: 1; - background-color: rgb(37 99 235 / var(--tw-bg-opacity, 1)); -} - -.bg-gray-100 { - --tw-bg-opacity: 1; - background-color: rgb(243 244 246 / var(--tw-bg-opacity, 1)); -} - -.bg-gray-50 { - --tw-bg-opacity: 1; - background-color: rgb(249 250 251 / var(--tw-bg-opacity, 1)); -} - -.bg-gray-500 { - --tw-bg-opacity: 1; - background-color: rgb(107 114 128 / var(--tw-bg-opacity, 1)); -} - -.bg-gray-600 { - --tw-bg-opacity: 1; - background-color: rgb(75 85 99 / var(--tw-bg-opacity, 1)); -} - -.bg-gray-600\/80 { - background-color: rgb(75 85 99 / 0.8); -} - -.bg-gray-700 { - --tw-bg-opacity: 1; - background-color: rgb(55 65 81 / var(--tw-bg-opacity, 1)); -} - -.bg-gray-800 { - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); -} - -.bg-gray-900\/40 { - background-color: rgb(17 24 39 / 0.4); -} - -.bg-indigo-600 { - --tw-bg-opacity: 1; - background-color: rgb(79 70 229 / var(--tw-bg-opacity, 1)); -} - -.bg-transparent { - background-color: transparent; -} - -.bg-white { - --tw-bg-opacity: 1; - background-color: rgb(255 255 255 / var(--tw-bg-opacity, 1)); -} - -.bg-none { - background-image: none; -} - -.fill-none { - fill: none; -} - -.object-contain { - -o-object-fit: contain; - object-fit: contain; -} - -.\!p-0 { - padding: 0px !important; -} - -.p-0 { - padding: 0px; -} - -.p-1 { - padding: 0.25rem; -} - -.p-2 { - padding: 0.5rem; -} - -.p-4 { - padding: 1rem; -} - -.p-6 { - padding: 1.5rem; -} - -.px-1 { - padding-left: 0.25rem; - padding-right: 0.25rem; -} - -.px-2 { - padding-left: 0.5rem; - padding-right: 0.5rem; -} - -.px-3 { - padding-left: 0.75rem; - padding-right: 0.75rem; -} - -.px-4 { - padding-left: 1rem; - padding-right: 1rem; -} - -.px-6 { - padding-left: 1.5rem; - padding-right: 1.5rem; -} - -.px-7 { - padding-left: 1.75rem; - padding-right: 1.75rem; -} - -.px-9 { - padding-left: 2.25rem; - padding-right: 2.25rem; -} - -.py-1 { - padding-top: 0.25rem; - padding-bottom: 0.25rem; -} - -.py-2 { - padding-top: 0.5rem; - padding-bottom: 0.5rem; -} - -.py-4 { - padding-top: 1rem; - padding-bottom: 1rem; -} - -.py-6 { - padding-top: 1.5rem; - padding-bottom: 1.5rem; -} - -.py-8 { - padding-top: 2rem; - padding-bottom: 2rem; -} - -.pb-1 { - padding-bottom: 0.25rem; -} - -.pr-0\.5 { - padding-right: 0.125rem; -} - -.pr-2 { - padding-right: 0.5rem; -} - -.pr-6 { - padding-right: 1.5rem; -} - -.pt-2 { - padding-top: 0.5rem; -} - -.text-left { - text-align: left; -} - -.text-center { - text-align: center; -} - -.font-dm-mono { - font-family: DM Mono, mono; -} - -.font-sans { - font-family: ui-sans-serif, system-ui, sans-serif, "Apple Color Emoji", "Segoe UI Emoji", "Segoe UI Symbol", "Noto Color Emoji"; -} - -.text-2xl { - font-size: 1.5rem; - line-height: 2rem; -} - -.text-3xl { - font-size: 1.875rem; - line-height: 2.25rem; -} - -.text-lg { - font-size: 1.125rem; - line-height: 1.75rem; -} - -.text-sm { - font-size: 0.875rem; - line-height: 1.25rem; -} - -.text-xl { - font-size: 1.25rem; - line-height: 1.75rem; -} - -.text-xs { - font-size: 0.75rem; - line-height: 1rem; -} - -.font-bold { - font-weight: 700; -} - -.font-medium { - font-weight: 500; -} - -.font-semibold { - font-weight: 600; -} - -.leading-\[17px\] { - line-height: 17px; -} - -.leading-tight { - line-height: 1.25; -} - -.\!text-slate-300 { - --tw-text-opacity: 1 !important; - color: rgb(203 213 225 / var(--tw-text-opacity, 1)) !important; -} - -.text-\[\#d9edf7\] { - --tw-text-opacity: 1; - color: rgb(217 237 247 / var(--tw-text-opacity, 1)); -} - -.text-\[\#dff0d8\] { - --tw-text-opacity: 1; - color: rgb(223 240 216 / var(--tw-text-opacity, 1)); -} - -.text-\[\#f2dede\] { - --tw-text-opacity: 1; - color: rgb(242 222 222 / var(--tw-text-opacity, 1)); -} - -.text-\[\#fcf8e3\] { - --tw-text-opacity: 1; - color: rgb(252 248 227 / var(--tw-text-opacity, 1)); -} - -.text-blue-600 { - --tw-text-opacity: 1; - color: rgb(37 99 235 / var(--tw-text-opacity, 1)); -} - -.text-gray-100 { - --tw-text-opacity: 1; - color: rgb(243 244 246 / var(--tw-text-opacity, 1)); -} - -.text-gray-200 { - --tw-text-opacity: 1; - color: rgb(229 231 235 / var(--tw-text-opacity, 1)); -} - -.text-gray-300 { - --tw-text-opacity: 1; - color: rgb(209 213 219 / var(--tw-text-opacity, 1)); -} - -.text-gray-400 { - --tw-text-opacity: 1; - color: rgb(156 163 175 / var(--tw-text-opacity, 1)); -} - -.text-gray-500 { - --tw-text-opacity: 1; - color: rgb(107 114 128 / var(--tw-text-opacity, 1)); -} - -.text-gray-600 { - --tw-text-opacity: 1; - color: rgb(75 85 99 / var(--tw-text-opacity, 1)); -} - -.text-gray-700 { - --tw-text-opacity: 1; - color: rgb(55 65 81 / var(--tw-text-opacity, 1)); -} - -.text-gray-800 { - --tw-text-opacity: 1; - color: rgb(31 41 55 / var(--tw-text-opacity, 1)); -} - -.text-gray-900 { - --tw-text-opacity: 1; - color: rgb(17 24 39 / var(--tw-text-opacity, 1)); -} - -.text-green-500 { - --tw-text-opacity: 1; - color: rgb(34 197 94 / var(--tw-text-opacity, 1)); -} - -.text-indigo-100 { - --tw-text-opacity: 1; - color: rgb(224 231 255 / var(--tw-text-opacity, 1)); -} - -.text-indigo-200 { - --tw-text-opacity: 1; - color: rgb(199 210 254 / var(--tw-text-opacity, 1)); -} - -.text-indigo-300 { - --tw-text-opacity: 1; - color: rgb(165 180 252 / var(--tw-text-opacity, 1)); -} - -.text-indigo-500 { - --tw-text-opacity: 1; - color: rgb(99 102 241 / var(--tw-text-opacity, 1)); -} - -.text-red-500 { - --tw-text-opacity: 1; - color: rgb(239 68 68 / var(--tw-text-opacity, 1)); -} - -.text-sky-500 { - --tw-text-opacity: 1; - color: rgb(14 165 233 / var(--tw-text-opacity, 1)); -} - -.text-slate-300 { - --tw-text-opacity: 1; - color: rgb(203 213 225 / var(--tw-text-opacity, 1)); -} - -.text-slate-400 { - --tw-text-opacity: 1; - color: rgb(148 163 184 / var(--tw-text-opacity, 1)); -} - -.text-white { - --tw-text-opacity: 1; - color: rgb(255 255 255 / var(--tw-text-opacity, 1)); -} - -.underline-offset-4 { - text-underline-offset: 4px; -} - -.opacity-0 { - opacity: 0; -} - -.opacity-100 { - opacity: 1; -} - -.opacity-50 { - opacity: 0.5; -} - -.shadow-\[0px_1px_10px\] { - --tw-shadow: 0px 1px 10px; - --tw-shadow-colored: 0px 1px 10px var(--tw-shadow-color); - box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000), var(--tw-shadow); -} - -.shadow-lg { - --tw-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1), 0 4px 6px -4px rgb(0 0 0 / 0.1); - --tw-shadow-colored: 0 10px 15px -3px var(--tw-shadow-color), 0 4px 6px -4px var(--tw-shadow-color); - box-shadow: var(--tw-ring-offset-shadow, 0 0 #0000), var(--tw-ring-shadow, 0 0 #0000), var(--tw-shadow); -} - -.shadow-gray-400 { - --tw-shadow-color: #9ca3af; - --tw-shadow: var(--tw-shadow-colored); -} - -.ring-blue-600 { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(37 99 235 / var(--tw-ring-opacity, 1)); -} - -.ring-gray-700 { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(55 65 81 / var(--tw-ring-opacity, 1)); -} - -.ring-offset-white { - --tw-ring-offset-color: #fff; -} - -.filter { - filter: var(--tw-blur) var(--tw-brightness) var(--tw-contrast) var(--tw-grayscale) var(--tw-hue-rotate) var(--tw-invert) var(--tw-saturate) var(--tw-sepia) var(--tw-drop-shadow); -} - -.backdrop-blur { - --tw-backdrop-blur: blur(8px); - -webkit-backdrop-filter: var(--tw-backdrop-blur) var(--tw-backdrop-brightness) var(--tw-backdrop-contrast) var(--tw-backdrop-grayscale) var(--tw-backdrop-hue-rotate) var(--tw-backdrop-invert) var(--tw-backdrop-opacity) var(--tw-backdrop-saturate) var(--tw-backdrop-sepia); - backdrop-filter: var(--tw-backdrop-blur) var(--tw-backdrop-brightness) var(--tw-backdrop-contrast) var(--tw-backdrop-grayscale) var(--tw-backdrop-hue-rotate) var(--tw-backdrop-invert) var(--tw-backdrop-opacity) var(--tw-backdrop-saturate) var(--tw-backdrop-sepia); -} - -.transition-all { - transition-property: all; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; -} - -.transition-colors { - transition-property: color, background-color, border-color, text-decoration-color, fill, stroke; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; -} - -.transition-transform { - transition-property: transform; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; -} - -.duration-100 { - transition-duration: 100ms; -} - -.ease-linear { - transition-timing-function: linear; -} - -body, -html { - display: flex; - height: 100%; - flex-direction: column; - --tw-bg-opacity: 1; - background-color: rgb(17 24 39 / var(--tw-bg-opacity, 1)); - font-family: Poppins, serif; - letter-spacing: 0.025em; - --tw-text-opacity: 1; - color: rgb(255 255 255 / var(--tw-text-opacity, 1)); -} - -/* Override specific classes */ - -.bg-gray-100 { - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); -} - -.border { - --tw-border-opacity: 1; - border-color: rgb(55 65 81 / var(--tw-border-opacity, 1)); -} - -textarea, -input { - --tw-border-opacity: 1; - border-color: rgb(55 65 81 / var(--tw-border-opacity, 1)); - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); - --tw-text-opacity: 1; - color: rgb(255 255 255 / var(--tw-text-opacity, 1)); -} - -select { - -webkit-appearance: none; - /* For Safari */ - -moz-appearance: none; - /* For Firefox */ - appearance: none; - /* For modern browsers */ -} - -input[type="range"] { - -webkit-appearance: none; - /* Hides the slider so that custom slider can be made */ - width: 100%; - /* Specific width is required for Firefox. */ - background: transparent; - /* Otherwise white in Chrome */ -} - -input[type="range"]:focus { - outline: none; - /* Removes the blue border. You should probably do some kind of focus styling for accessibility reasons though. */ -} - -input[type="range"]::-ms-track { - width: 100%; - cursor: pointer; - /* Hides the slider so custom styles can be added */ - background: transparent; - border-color: transparent; - color: transparent; -} - -/* Special styling for WebKit/Blink */ - -input[type="range"]::-webkit-slider-thumb { - z-index: 10; - margin-top: -0.5rem; - height: 1rem; - width: 1rem; - cursor: pointer; - border-radius: 9999px; - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); - --tw-ring-opacity: 1; - --tw-ring-color: rgb(99 102 241 / var(--tw-ring-opacity, 1)); - --tw-ring-offset-width: 1px; - --tw-ring-offset-color: #1f2937; - -webkit-transition-property: all; - transition-property: all; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; - transition-timing-function: linear; -} - -input[type="range"]::-webkit-slider-thumb:hover { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(79 70 229 / var(--tw-ring-opacity, 1)); -} - -input[type="range"]::-webkit-slider-thumb:focus-visible { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(79 70 229 / var(--tw-ring-opacity, 1)); -} - -input[type="range"]::-webkit-slider-thumb:disabled { - pointer-events: none; - opacity: 0.5; -} - -/* All the same stuff for Firefox */ - -input[type="range"]::-moz-range-thumb { - z-index: 10; - margin-top: -0.5rem; - height: 1rem; - width: 1rem; - cursor: pointer; - border-radius: 9999px; - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); - --tw-ring-opacity: 1; - --tw-ring-color: rgb(99 102 241 / var(--tw-ring-opacity, 1)); - --tw-ring-offset-width: 1px; - --tw-ring-offset-color: #1f2937; - -moz-transition-property: all; - transition-property: all; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; - transition-timing-function: linear; -} - -input[type="range"]::-moz-range-thumb:hover { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(79 70 229 / var(--tw-ring-opacity, 1)); -} - -input[type="range"]::-moz-range-thumb:focus-visible { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(79 70 229 / var(--tw-ring-opacity, 1)); -} - -input[type="range"]::-moz-range-thumb:disabled { - pointer-events: none; - opacity: 0.5; -} - -/* All the same stuff for IE */ - -input[type="range"]::-ms-thumb { - z-index: 10; - margin-top: -0.5rem; - height: 1rem; - width: 1rem; - cursor: pointer; - border-radius: 9999px; - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); - --tw-ring-opacity: 1; - --tw-ring-color: rgb(99 102 241 / var(--tw-ring-opacity, 1)); - --tw-ring-offset-width: 1px; - --tw-ring-offset-color: #1f2937; - -ms-transition-property: all; - transition-property: all; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; - transition-timing-function: linear; -} - -input[type="range"]::-ms-thumb:hover { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(79 70 229 / var(--tw-ring-opacity, 1)); -} - -input[type="range"]::-ms-thumb:focus-visible { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(79 70 229 / var(--tw-ring-opacity, 1)); -} - -input[type="range"]::-ms-thumb:disabled { - pointer-events: none; - opacity: 0.5; -} - -input[type="range"]::-webkit-slider-runnable-track { - height: 0.125rem; - width: 100%; - cursor: pointer; - border-radius: 0.25rem; - --tw-bg-opacity: 1; - background-color: rgb(129 140 248 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-webkit-slider-runnable-track:hover { - --tw-bg-opacity: 1; - background-color: rgb(99 102 241 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-webkit-slider-runnable-track:focus-visible { - --tw-bg-opacity: 1; - background-color: rgb(99 102 241 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-webkit-slider-runnable-track:disabled { - pointer-events: none; - opacity: 0.5; -} - -input[type="range"]::-moz-range-track { - height: 0.125rem; - width: 100%; - cursor: pointer; - border-radius: 0.25rem; - --tw-bg-opacity: 1; - background-color: rgb(129 140 248 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-moz-range-track:hover { - --tw-bg-opacity: 1; - background-color: rgb(99 102 241 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-moz-range-track:focus-visible { - --tw-bg-opacity: 1; - background-color: rgb(99 102 241 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-moz-range-track:disabled { - pointer-events: none; - opacity: 0.5; -} - -input[type="range"]::-ms-track { - height: 0.125rem; - width: 100%; - cursor: pointer; - border-radius: 0.25rem; - --tw-bg-opacity: 1; - background-color: rgb(129 140 248 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-ms-track:hover { - --tw-bg-opacity: 1; - background-color: rgb(99 102 241 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-ms-track:focus-visible { - --tw-bg-opacity: 1; - background-color: rgb(99 102 241 / var(--tw-bg-opacity, 1)); -} - -input[type="range"]::-ms-track:disabled { - pointer-events: none; - opacity: 0.5; -} - -input[type="range"]::-webkit-slider-thumb { - -webkit-appearance: none; -} - -.file\:border-0::file-selector-button { - border-width: 0px; -} - -.file\:bg-transparent::file-selector-button { - background-color: transparent; -} - -.file\:text-sm::file-selector-button { - font-size: 0.875rem; - line-height: 1.25rem; -} - -.file\:font-medium::file-selector-button { - font-weight: 500; -} - -.placeholder\:text-gray-500::-moz-placeholder { - --tw-text-opacity: 1; - color: rgb(107 114 128 / var(--tw-text-opacity, 1)); -} - -.placeholder\:text-gray-500::placeholder { - --tw-text-opacity: 1; - color: rgb(107 114 128 / var(--tw-text-opacity, 1)); -} - -.after\:h-0\.5::after { - content: var(--tw-content); - height: 0.125rem; -} - -.after\:w-2::after { - content: var(--tw-content); - width: 0.5rem; -} - -.after\:bg-indigo-500::after { - content: var(--tw-content); - --tw-bg-opacity: 1; - background-color: rgb(99 102 241 / var(--tw-bg-opacity, 1)); -} - -.after\:transition-all::after { - content: var(--tw-content); - transition-property: all; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; -} - -.after\:ease-linear::after { - content: var(--tw-content); - transition-timing-function: linear; -} - -.hover\:\!border-slate-100:hover { - --tw-border-opacity: 1 !important; - border-color: rgb(241 245 249 / var(--tw-border-opacity, 1)) !important; -} - -.hover\:border-slate-600:hover { - --tw-border-opacity: 1; - border-color: rgb(71 85 105 / var(--tw-border-opacity, 1)); -} - -.hover\:bg-blue-700:hover { - --tw-bg-opacity: 1; - background-color: rgb(29 78 216 / var(--tw-bg-opacity, 1)); -} - -.hover\:bg-gray-100:hover { - --tw-bg-opacity: 1; - background-color: rgb(243 244 246 / var(--tw-bg-opacity, 1)); -} - -.hover\:bg-gray-200:hover { - --tw-bg-opacity: 1; - background-color: rgb(229 231 235 / var(--tw-bg-opacity, 1)); -} - -.hover\:bg-gray-300:hover { - --tw-bg-opacity: 1; - background-color: rgb(209 213 219 / var(--tw-bg-opacity, 1)); -} - -.hover\:bg-gray-600:hover { - --tw-bg-opacity: 1; - background-color: rgb(75 85 99 / var(--tw-bg-opacity, 1)); -} - -.hover\:bg-gray-700:hover { - --tw-bg-opacity: 1; - background-color: rgb(55 65 81 / var(--tw-bg-opacity, 1)); -} - -.hover\:bg-gray-800:hover { - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); -} - -.hover\:bg-indigo-500\/20:hover { - background-color: rgb(99 102 241 / 0.2); -} - -.hover\:bg-indigo-700:hover { - --tw-bg-opacity: 1; - background-color: rgb(67 56 202 / var(--tw-bg-opacity, 1)); -} - -.hover\:bg-slate-700:hover { - --tw-bg-opacity: 1; - background-color: rgb(51 65 85 / var(--tw-bg-opacity, 1)); -} - -.hover\:text-gray-100:hover { - --tw-text-opacity: 1; - color: rgb(243 244 246 / var(--tw-text-opacity, 1)); -} - -.hover\:text-gray-200:hover { - --tw-text-opacity: 1; - color: rgb(229 231 235 / var(--tw-text-opacity, 1)); -} - -.hover\:text-gray-700:hover { - --tw-text-opacity: 1; - color: rgb(55 65 81 / var(--tw-text-opacity, 1)); -} - -.hover\:text-indigo-300:hover { - --tw-text-opacity: 1; - color: rgb(165 180 252 / var(--tw-text-opacity, 1)); -} - -.hover\:underline:hover { - text-decoration-line: underline; -} - -.hover\:opacity-80:hover { - opacity: 0.8; -} - -.hover\:after\:w-full:hover::after { - content: var(--tw-content); - width: 100%; -} - -.focus\:outline-none:focus { - outline: 2px solid transparent; - outline-offset: 2px; -} - -.focus-visible\:text-gray-100:focus-visible { - --tw-text-opacity: 1; - color: rgb(243 244 246 / var(--tw-text-opacity, 1)); -} - -.focus-visible\:text-gray-700:focus-visible { - --tw-text-opacity: 1; - color: rgb(55 65 81 / var(--tw-text-opacity, 1)); -} - -.focus-visible\:outline-none:focus-visible { - outline: 2px solid transparent; - outline-offset: 2px; -} - -.focus-visible\:ring-0:focus-visible { - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(0px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); -} - -.focus-visible\:ring-1:focus-visible { - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(1px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); -} - -.focus-visible\:ring-2:focus-visible { - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); -} - -.focus-visible\:ring-gray-700:focus-visible { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(55 65 81 / var(--tw-ring-opacity, 1)); -} - -.focus-visible\:ring-indigo-500:focus-visible { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(99 102 241 / var(--tw-ring-opacity, 1)); -} - -.focus-visible\:ring-transparent:focus-visible { - --tw-ring-color: transparent; -} - -.focus-visible\:ring-offset-0:focus-visible { - --tw-ring-offset-width: 0px; -} - -.focus-visible\:ring-offset-2:focus-visible { - --tw-ring-offset-width: 2px; -} - -.focus-visible\:ring-offset-white:focus-visible { - --tw-ring-offset-color: #fff; -} - -.disabled\:pointer-events-none:disabled { - pointer-events: none; -} - -.disabled\:cursor-not-allowed:disabled { - cursor: not-allowed; -} - -.disabled\:opacity-50:disabled { - opacity: 0.5; -} - -.group:hover .group-hover\:border-gray-900 { - --tw-border-opacity: 1; - border-color: rgb(17 24 39 / var(--tw-border-opacity, 1)); -} - -.group:hover .group-hover\:\!bg-slate-100 { - --tw-bg-opacity: 1 !important; - background-color: rgb(241 245 249 / var(--tw-bg-opacity, 1)) !important; -} - -.group:hover .group-hover\:bg-gray-900 { - --tw-bg-opacity: 1; - background-color: rgb(17 24 39 / var(--tw-bg-opacity, 1)); -} - -.group:hover .group-hover\:text-gray-900 { - --tw-text-opacity: 1; - color: rgb(17 24 39 / var(--tw-text-opacity, 1)); -} - -@media (min-width: 768px) { - .md\:grid-cols-2 { - grid-template-columns: repeat(2, minmax(0, 1fr)); - } - - .md\:grid-cols-3 { - grid-template-columns: repeat(3, minmax(0, 1fr)); - } -} - -@media (min-width: 1024px) { - .lg\:-right-\[450px\] { - right: -450px; - } - - .lg\:\!w-\[400px\] { - width: 400px !important; - } -} - -.\[\&\:\:-moz-range-thumb\]\:z-10::-moz-range-thumb { - z-index: 10; -} - -.\[\&\:\:-moz-range-thumb\]\:-mt-2::-moz-range-thumb { - margin-top: -0.5rem; -} - -.\[\&\:\:-moz-range-thumb\]\:h-4::-moz-range-thumb { - height: 1rem; -} - -.\[\&\:\:-moz-range-thumb\]\:w-4::-moz-range-thumb { - width: 1rem; -} - -.\[\&\:\:-moz-range-thumb\]\:cursor-pointer::-moz-range-thumb { - cursor: pointer; -} - -.\[\&\:\:-moz-range-thumb\]\:rounded-full::-moz-range-thumb { - border-radius: 9999px; -} - -.\[\&\:\:-moz-range-thumb\]\:bg-gray-800::-moz-range-thumb { - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-moz-range-thumb\]\:ring-2::-moz-range-thumb { - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); -} - -.\[\&\:\:-moz-range-thumb\]\:ring-gray-900::-moz-range-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(17 24 39 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-moz-range-thumb\]\:ring-offset-1::-moz-range-thumb { - --tw-ring-offset-width: 1px; -} - -.\[\&\:\:-moz-range-thumb\]\:transition-all::-moz-range-thumb { - -moz-transition-property: all; - transition-property: all; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; -} - -.\[\&\:\:-moz-range-thumb\]\:ease-linear::-moz-range-thumb { - transition-timing-function: linear; -} - -.\[\&\:\:-moz-range-thumb\]\:hover\:bg-gray-900:hover::-moz-range-thumb { - --tw-bg-opacity: 1; - background-color: rgb(17 24 39 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-moz-range-thumb\]\:hover\:ring-gray-900:hover::-moz-range-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(17 24 39 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-moz-range-thumb\]\:focus-visible\:ring-gray-600:focus-visible::-moz-range-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(75 85 99 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-moz-range-thumb\]\:disabled\:pointer-events-none:disabled::-moz-range-thumb { - pointer-events: none; -} - -.\[\&\:\:-moz-range-thumb\]\:disabled\:opacity-50:disabled::-moz-range-thumb { - opacity: 0.5; -} - -.\[\&\:\:-moz-range-track\]\:h-0\.5::-moz-range-track { - height: 0.125rem; -} - -.\[\&\:\:-moz-range-track\]\:w-full::-moz-range-track { - width: 100%; -} - -.\[\&\:\:-moz-range-track\]\:cursor-pointer::-moz-range-track { - cursor: pointer; -} - -.\[\&\:\:-moz-range-track\]\:rounded::-moz-range-track { - border-radius: 0.25rem; -} - -.\[\&\:\:-moz-range-track\]\:bg-gray-500::-moz-range-track { - --tw-bg-opacity: 1; - background-color: rgb(107 114 128 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-moz-range-track\]\:hover\:bg-gray-700:hover::-moz-range-track { - --tw-bg-opacity: 1; - background-color: rgb(55 65 81 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-moz-range-track\]\:focus-visible\:bg-gray-700:focus-visible::-moz-range-track { - --tw-bg-opacity: 1; - background-color: rgb(55 65 81 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-moz-range-track\]\:disabled\:pointer-events-none:disabled::-moz-range-track { - pointer-events: none; -} - -.\[\&\:\:-moz-range-track\]\:disabled\:opacity-50:disabled::-moz-range-track { - opacity: 0.5; -} - -.\[\&\:\:-ms-thumb\]\:z-10::-ms-thumb { - z-index: 10; -} - -.\[\&\:\:-ms-thumb\]\:-mt-2::-ms-thumb { - margin-top: -0.5rem; -} - -.\[\&\:\:-ms-thumb\]\:h-4::-ms-thumb { - height: 1rem; -} - -.\[\&\:\:-ms-thumb\]\:w-4::-ms-thumb { - width: 1rem; -} - -.\[\&\:\:-ms-thumb\]\:cursor-pointer::-ms-thumb { - cursor: pointer; -} - -.\[\&\:\:-ms-thumb\]\:rounded-full::-ms-thumb { - border-radius: 9999px; -} - -.\[\&\:\:-ms-thumb\]\:bg-gray-800::-ms-thumb { - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-ms-thumb\]\:ring-2::-ms-thumb { - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); -} - -.\[\&\:\:-ms-thumb\]\:ring-gray-900::-ms-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(17 24 39 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-ms-thumb\]\:ring-offset-1::-ms-thumb { - --tw-ring-offset-width: 1px; -} - -.\[\&\:\:-ms-thumb\]\:transition-all::-ms-thumb { - -ms-transition-property: all; - transition-property: all; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; -} - -.\[\&\:\:-ms-thumb\]\:ease-linear::-ms-thumb { - transition-timing-function: linear; -} - -.\[\&\:\:-ms-thumb\]\:hover\:bg-gray-900:hover::-ms-thumb { - --tw-bg-opacity: 1; - background-color: rgb(17 24 39 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-ms-thumb\]\:hover\:ring-gray-900:hover::-ms-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(17 24 39 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-ms-thumb\]\:focus-visible\:ring-gray-600:focus-visible::-ms-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(75 85 99 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-ms-thumb\]\:disabled\:pointer-events-none:disabled::-ms-thumb { - pointer-events: none; -} - -.\[\&\:\:-ms-thumb\]\:disabled\:opacity-50:disabled::-ms-thumb { - opacity: 0.5; -} - -.\[\&\:\:-ms-track\]\:h-0\.5::-ms-track { - height: 0.125rem; -} - -.\[\&\:\:-ms-track\]\:w-full::-ms-track { - width: 100%; -} - -.\[\&\:\:-ms-track\]\:cursor-pointer::-ms-track { - cursor: pointer; -} - -.\[\&\:\:-ms-track\]\:rounded::-ms-track { - border-radius: 0.25rem; -} - -.\[\&\:\:-ms-track\]\:border-transparent::-ms-track { - border-color: transparent; -} - -.\[\&\:\:-ms-track\]\:bg-gray-500::-ms-track { - --tw-bg-opacity: 1; - background-color: rgb(107 114 128 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-ms-track\]\:bg-transparent::-ms-track { - background-color: transparent; -} - -.\[\&\:\:-ms-track\]\:text-transparent::-ms-track { - color: transparent; -} - -.\[\&\:\:-ms-track\]\:hover\:bg-gray-700:hover::-ms-track { - --tw-bg-opacity: 1; - background-color: rgb(55 65 81 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-ms-track\]\:focus-visible\:bg-gray-700:focus-visible::-ms-track { - --tw-bg-opacity: 1; - background-color: rgb(55 65 81 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-ms-track\]\:disabled\:pointer-events-none:disabled::-ms-track { - pointer-events: none; -} - -.\[\&\:\:-ms-track\]\:disabled\:opacity-50:disabled::-ms-track { - opacity: 0.5; -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:h-0\.5::-webkit-slider-runnable-track { - height: 0.125rem; -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:w-full::-webkit-slider-runnable-track { - width: 100%; -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:cursor-pointer::-webkit-slider-runnable-track { - cursor: pointer; -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:rounded::-webkit-slider-runnable-track { - border-radius: 0.25rem; -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:bg-gray-500::-webkit-slider-runnable-track { - --tw-bg-opacity: 1; - background-color: rgb(107 114 128 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:hover\:bg-gray-700:hover::-webkit-slider-runnable-track { - --tw-bg-opacity: 1; - background-color: rgb(55 65 81 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:focus-visible\:bg-gray-700:focus-visible::-webkit-slider-runnable-track { - --tw-bg-opacity: 1; - background-color: rgb(55 65 81 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:disabled\:pointer-events-none:disabled::-webkit-slider-runnable-track { - pointer-events: none; -} - -.\[\&\:\:-webkit-slider-runnable-track\]\:disabled\:opacity-50:disabled::-webkit-slider-runnable-track { - opacity: 0.5; -} - -.\[\&\:\:-webkit-slider-thumb\]\:z-10::-webkit-slider-thumb { - z-index: 10; -} - -.\[\&\:\:-webkit-slider-thumb\]\:-mt-2::-webkit-slider-thumb { - margin-top: -0.5rem; -} - -.\[\&\:\:-webkit-slider-thumb\]\:h-4::-webkit-slider-thumb { - height: 1rem; -} - -.\[\&\:\:-webkit-slider-thumb\]\:w-4::-webkit-slider-thumb { - width: 1rem; -} - -.\[\&\:\:-webkit-slider-thumb\]\:cursor-pointer::-webkit-slider-thumb { - cursor: pointer; -} - -.\[\&\:\:-webkit-slider-thumb\]\:appearance-none::-webkit-slider-thumb { - -webkit-appearance: none; - appearance: none; -} - -.\[\&\:\:-webkit-slider-thumb\]\:rounded-full::-webkit-slider-thumb { - border-radius: 9999px; -} - -.\[\&\:\:-webkit-slider-thumb\]\:bg-gray-800::-webkit-slider-thumb { - --tw-bg-opacity: 1; - background-color: rgb(31 41 55 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-webkit-slider-thumb\]\:ring-2::-webkit-slider-thumb { - --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); - --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(2px + var(--tw-ring-offset-width)) var(--tw-ring-color); - box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); -} - -.\[\&\:\:-webkit-slider-thumb\]\:ring-gray-900::-webkit-slider-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(17 24 39 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-webkit-slider-thumb\]\:ring-offset-1::-webkit-slider-thumb { - --tw-ring-offset-width: 1px; -} - -.\[\&\:\:-webkit-slider-thumb\]\:transition-all::-webkit-slider-thumb { - -webkit-transition-property: all; - transition-property: all; - transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1); - transition-duration: 150ms; -} - -.\[\&\:\:-webkit-slider-thumb\]\:ease-linear::-webkit-slider-thumb { - transition-timing-function: linear; -} - -.\[\&\:\:-webkit-slider-thumb\]\:hover\:bg-gray-900:hover::-webkit-slider-thumb { - --tw-bg-opacity: 1; - background-color: rgb(17 24 39 / var(--tw-bg-opacity, 1)); -} - -.\[\&\:\:-webkit-slider-thumb\]\:hover\:ring-gray-900:hover::-webkit-slider-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(17 24 39 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-webkit-slider-thumb\]\:focus-visible\:ring-gray-600:focus-visible::-webkit-slider-thumb { - --tw-ring-opacity: 1; - --tw-ring-color: rgb(75 85 99 / var(--tw-ring-opacity, 1)); -} - -.\[\&\:\:-webkit-slider-thumb\]\:disabled\:pointer-events-none:disabled::-webkit-slider-thumb { - pointer-events: none; -} - -.\[\&\:\:-webkit-slider-thumb\]\:disabled\:opacity-50:disabled::-webkit-slider-thumb { - opacity: 0.5; -} - -.\[\&\>path\:last-child\]\:fill-gray-400>path:last-child { - fill: #9ca3af; -} - -.\[\&\>svg\]\:shrink-0>svg { - flex-shrink: 0; -} - -.\[\&_svg\]\:pointer-events-none svg { - pointer-events: none; -} - -.\[\&_svg\]\:size-4 svg { - width: 1rem; - height: 1rem; -} - -.\[\&_svg\]\:shrink-0 svg { - flex-shrink: 0; -} diff --git a/app/build.rs b/app/build.rs deleted file mode 100644 index a61e9f4..0000000 --- a/app/build.rs +++ /dev/null @@ -1,19 +0,0 @@ -#[dotenvy::load] -fn main() { - let profile = std::env::var("PROFILE").unwrap_or_else(|_| "debug".to_string()); - if profile != "release" { - println!("cargo:rustc-env=RUST_BACKTRACE=1"); - println!("cargo:rustc-env=CARGO_PROFILE_DEV_BUILD_OVERRIDE_DEBUG=true"); - println!("cargo:rerun-if-changed=../.env"); - } - - #[cfg(feature = "server")] - { - for key in ["DATABASE_URL", "ANTHROPIC_API_KEY", "SERPAPI_API_KEY", "APALIS_DATABASE_URL"] { - println!("cargo:rustc-env={}={}", key, std::env::var(key).unwrap()); - } - } - for key in ["SERVER_URL", "ENV"] { - println!("cargo:rustc-env={}={}", key, std::env::var(key).unwrap()); - } -} diff --git a/app/input.css b/app/input.css deleted file mode 100644 index 07d3d33..0000000 --- a/app/input.css +++ /dev/null @@ -1,79 +0,0 @@ -@tailwind base; -@tailwind components; -@tailwind utilities; - -body, -html { - @apply bg-gray-900 text-white flex flex-col h-full font-poppins tracking-wide; -} - -/* Override specific classes */ -.bg-gray-100 { - @apply bg-gray-800; -} - -.border { - @apply border-gray-700; -} - -textarea, -input { - @apply bg-gray-800 text-white border-gray-700; -} - -select { - -webkit-appearance: none; /* For Safari */ - -moz-appearance: none; /* For Firefox */ - appearance: none; /* For modern browsers */ -} - -input[type="range"] { - -webkit-appearance: none; /* Hides the slider so that custom slider can be made */ - width: 100%; /* Specific width is required for Firefox. */ - background: transparent; /* Otherwise white in Chrome */ -} - -input[type="range"]:focus { - outline: none; /* Removes the blue border. You should probably do some kind of focus styling for accessibility reasons though. */ -} - -input[type="range"]::-ms-track { - width: 100%; - cursor: pointer; - - /* Hides the slider so custom styles can be added */ - background: transparent; - border-color: transparent; - color: transparent; -} - -/* Special styling for WebKit/Blink */ -input[type="range"]::-webkit-slider-thumb { - @apply cursor-pointer ring-2 ring-offset-1 ring-indigo-500 ring-offset-gray-800 h-4 w-4 rounded-full bg-gray-800 -mt-2 z-10 transition-all ease-linear hover:ring-indigo-600 focus-visible:ring-indigo-600 disabled:pointer-events-none disabled:opacity-50; -} - -/* All the same stuff for Firefox */ -input[type="range"]::-moz-range-thumb { - @apply cursor-pointer ring-2 ring-offset-1 ring-indigo-500 ring-offset-gray-800 h-4 w-4 rounded-full bg-gray-800 -mt-2 z-10 transition-all ease-linear hover:ring-indigo-600 focus-visible:ring-indigo-600 disabled:pointer-events-none disabled:opacity-50; -} - -/* All the same stuff for IE */ -input[type="range"]::-ms-thumb { - @apply cursor-pointer ring-2 ring-offset-1 ring-indigo-500 ring-offset-gray-800 h-4 w-4 rounded-full bg-gray-800 -mt-2 z-10 transition-all ease-linear hover:ring-indigo-600 focus-visible:ring-indigo-600 disabled:pointer-events-none disabled:opacity-50; -} - -input[type="range"]::-webkit-slider-runnable-track { - @apply w-full h-0.5 cursor-pointer bg-indigo-400 rounded hover:bg-indigo-500 focus-visible:bg-indigo-500 disabled:opacity-50 disabled:pointer-events-none; -} - -input[type="range"]::-moz-range-track { - @apply w-full h-0.5 cursor-pointer bg-indigo-400 rounded hover:bg-indigo-500 focus-visible:bg-indigo-500 disabled:opacity-50 disabled:pointer-events-none; -} - -input[type="range"]::-ms-track { - @apply w-full h-0.5 cursor-pointer bg-indigo-400 rounded hover:bg-indigo-500 focus-visible:bg-indigo-500 disabled:opacity-50 disabled:pointer-events-none; -} - -input[type="range"]::-webkit-slider-thumb { - -webkit-appearance: none; -} diff --git a/app/src/layout.rs b/app/src/layout.rs deleted file mode 100644 index 4f094f8..0000000 --- a/app/src/layout.rs +++ /dev/null @@ -1,45 +0,0 @@ -use {crate::router::Route, dioxus::prelude::*, strum::IntoEnumIterator}; - -#[component] -pub fn Layout(children: Element) -> Element { - rsx! { - document::Link { rel: "stylesheet", href: asset!("/assets/tailwind.css") } - document::Link { rel: "icon", href: asset!("/assets/favicon.ico") } - document::Link { rel: "preconnect", href: "https://fonts.googleapis.com" } - document::Link { - rel: "stylesheet", - href: "https://fonts.googleapis.com/css2?family=DM+Mono:wght@400;500&family=Poppins:ital,wght@0,400;0,500;0,600;0,700;1,400;1,500;1,600;1,700&display=swap", - } - div { class: "grid grid-cols-7 gap-4 max-h-screen h-full", - Sidebar {} - div { class: "p-4 rounded col-span-6 h-full min-h-screen", Outlet:: {} } - } - } -} - -#[component] -pub fn Sidebar() -> Element { - let navigator = use_navigator(); - let router = router(); - let current_route = router.current::(); - - rsx! { - div { class: "h-full overflow-hidden py-8", - nav { class: "gap-y-5 flex flex-col overflow-y-auto px-4 h-full", - Fragment { - for route in Route::iter() { - div { - key: "{route.to_string()}", - class: "px-3 py-2 rounded-md hover:bg-gray-800 transition-colors cursor-pointer ease-linear", - class: if route.to_string() == current_route.to_string() { "bg-gray-800 text-indigo-200" }, - onclick: move |_| { - navigator.push(route.clone()); - }, - span { {route.name()} } - } - } - } - } - } - } -} diff --git a/app/src/lib.rs b/app/src/lib.rs deleted file mode 100644 index 63c3c45..0000000 --- a/app/src/lib.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod layout; -pub mod pages; -pub mod router; diff --git a/app/src/main.rs b/app/src/main.rs deleted file mode 100644 index 8351a0e..0000000 --- a/app/src/main.rs +++ /dev/null @@ -1,25 +0,0 @@ -#![allow(non_snake_case)] -use { - app::router::Route, - dioxus::prelude::*, - dioxus_logger::tracing, - maestro_toast::{init::use_init_toast_ctx, toast_frame_component::ToastFrame}, -}; - -fn App() -> Element { - let toast = use_init_toast_ctx(); - rsx! { - ToastFrame { manager: toast } - Router:: {} - } -} - -fn main() { - // #[cfg(not(feature = "server"))] - // dioxus::fullstack::prelude::server_fn::client::set_server_url(&SERVER_URL); - dioxus_logger::init(tracing::Level::INFO).expect("failed to init logger"); - dioxus::LaunchBuilder::new() - .with_context(server_only!(maestro_diesel::async_client::client::acreate_diesel_pool(env!("DATABASE_URL")))) - .with_context(server_only!(maestro_anthropic::AnthropicClient::new(env!("ANTHROPIC_API_KEY")))) - .launch(App); -} diff --git a/app/src/pages/home.rs b/app/src/pages/home.rs deleted file mode 100644 index 83daca3..0000000 --- a/app/src/pages/home.rs +++ /dev/null @@ -1,8 +0,0 @@ -use dioxus::prelude::*; - -#[component] -pub fn Home() -> Element { - rsx! { - div {} - } -} diff --git a/app/src/pages/mod.rs b/app/src/pages/mod.rs deleted file mode 100644 index 9b86bcf..0000000 --- a/app/src/pages/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod home; diff --git a/app/src/router.rs b/app/src/router.rs deleted file mode 100644 index 3fe16b4..0000000 --- a/app/src/router.rs +++ /dev/null @@ -1,20 +0,0 @@ -use { - crate::{layout::Layout, pages::home::Home}, - dioxus::prelude::*, - strum::EnumIter, -}; - -#[derive(Clone, PartialEq, EnumIter, Routable)] -pub enum Route { - #[layout(Layout)] - #[route("/")] - Home {}, -} - -impl Route { - pub fn name(&self) -> &'static str { - match self { - Route::Home {} => "Home", - } - } -} diff --git a/app/tailwind.config.js b/app/tailwind.config.js deleted file mode 100644 index ddfaad0..0000000 --- a/app/tailwind.config.js +++ /dev/null @@ -1,27 +0,0 @@ -/** @type {import('tailwindcss').Config} */ -module.exports = { - mode: "all", - content: [ - // include all rust, html and css files in the src directory - "./src/**/*.{rs,html,css}", - // include all html files in the output (dist) directory - "./dist/**/*.html", - "../../dioxus-maestro/frontend/**/*.{rs,html,css}", - ], - theme: { - extend: { - keyframes: { - highlight: { - "0%": { background: "#8f8" }, - "100%": { background: "auto" }, - }, - }, - animation: { highlight: "highlight 1s" }, - fontFamily: { - "dm-mono": ["DM Mono", "mono"], - poppins: ["Poppins", "serif"], - }, - }, - }, - plugins: [], -}; diff --git a/clippy.toml b/clippy.toml index 8bb65b6..2a20051 100644 --- a/clippy.toml +++ b/clippy.toml @@ -3,7 +3,7 @@ # ----------------------------------------------------------------------------- # Section identical to scripts/clippy_wasm/clippy.toml: -msrv = "1.89.0" +msrv = "1.88.0" allow-unwrap-in-tests = true @@ -18,49 +18,3 @@ max-include-file-size = 1000000 # https://rust-lang.github.io/rust-clippy/master/index.html#/type_complexity type-complexity-threshold = 350 - -# ----------------------------------------------------------------------------- - -# https://rust-lang.github.io/rust-clippy/master/index.html#disallowed_macros -disallowed-macros = [ - 'dbg', - 'std::unimplemented', - 'todo', - - # TODO: consider forbidding these to encourage the use of proper log stream, and then explicitly allow legitimate uses - # 'std::eprint', - # 'std::eprintln', - # 'std::print', - 'std::println', -] - -# https://rust-lang.github.io/rust-clippy/master/index.html#disallowed_methods -disallowed-methods = [ - "std::env::temp_dir", # Use the tempdir crate instead - - # There are many things that aren't allowed on wasm, - # but we cannot disable them all here (because of e.g. https://github.com/rust-lang/rust-clippy/issues/10406) - # so we do that in `clipppy_wasm.toml` instead. - - "std::thread::spawn", # Use `std::thread::Builder` and name the thread - - "sha1::Digest::new", # SHA1 is cryptographically broken - - "std::panic::catch_unwind", # We compile with `panic = "abort"` -] - -# https://rust-lang.github.io/rust-clippy/master/index.html#disallowed_names -disallowed-names = [] - -# https://rust-lang.github.io/rust-clippy/master/index.html#disallowed_types -disallowed-types = [ - # Use the faster & simpler non-poisonable primitives in `parking_lot` instead - "std::sync::Condvar", - "std::sync::Mutex", - "std::sync::Once", - "std::sync::RwLock", - - "ring::digest::SHA1_FOR_LEGACY_USE_ONLY", # SHA1 is cryptographically broken - - -] diff --git a/core/macroz/README.md b/core/macroz/README.md new file mode 100644 index 0000000..d14e31a --- /dev/null +++ b/core/macroz/README.md @@ -0,0 +1,272 @@ +# Painlezz Macroz + +A lightweight Python macro system for code transformation and metadata collection, designed to work seamlessly with LibCST for static analysis and code generation. This system powers the plugin discovery mechanism in EZPZ-Pluginz. + +## Overview + +Painlezz Macroz provides a foundation for creating decorator-based macros that can collect metadata during static analysis without affecting runtime behavior. It's particularly useful for plugin systems, code generators, and tools that need to extract information from decorated classes and functions. + +## Features + +- **No-op Macros**: Decorators that preserve original functionality while enabling metadata collection +- **LibCST Integration**: Built-in visitor patterns for AST traversal and metadata extraction +- **Type-Safe**: Full type hints and generic support for robust macro definitions +- **Minimal Runtime Impact**: Macros are designed to be lightweight and non-intrusive +- **Flexible Callback System**: Support for custom metadata extraction logic + +## Installation + +```bash +pip install macroz +``` + +## Core Components + +### 1. No-op Macros (`macroz/noop.py`) + +The foundation of the macro system - decorators that don't change behavior but enable metadata collection: + +```python +from painlezz_macroz.macroz.noop import class_macro, func_macro + +# Class macro - preserves the class unchanged (identity function) +@class_macro +class MyClass: + pass + +# Function macro - preserves function behavior with proper wrapping +@func_macro +def my_function(): + return "unchanged" +``` + +#### Available Macros + +- **`class_macro[T](cls: T) -> T`**: Identity decorator for classes - returns the class unchanged +- **`func_macro[**P, R](func: Callable[P, R]) -> Callable[P, R]`**: Wrapper decorator for functions that preserves signature and behavior using `@wraps` + +**Important**: The `class_macro` is a true identity function that returns the class unchanged, while `func_macro` creates a wrapper using `functools.wraps` to preserve metadata. + +### 2. Metadata Collection (`visitorz/macro_metadata_collector.py`) + +A powerful LibCST visitor that extracts metadata from macro-decorated code using pattern matching: + +```python +from painlezz_macroz.visitorz.macro_metadata_collector import MacroMetadataCollector +from pydantic import BaseModel +import libcst as cst + +# Define your metadata model +class MyMacroData(BaseModel): + name: str + value: int + +# Create a collector +collector = MacroMetadataCollector[MyMacroData, dict]( + macro_name="my_macro", + callback=lambda args, kwargs: MyMacroData( + name=kwargs["name"], + value=kwargs["value"] + ) +) + +# Parse and visit code +module = cst.parse_module(source_code) +module.visit(collector) + +# Access collected metadata +for data in collector.macro_data: + print(f"Found: {data.name} = {data.value}") +``` + +## How It Works with EZPZ-Pluginz + +The macro system integrates seamlessly with EZPZ-Pluginz to enable plugin discovery: + +### 1. Plugin Definition + +```python +from ezpz_pluginz.register_plugin_macro import ezpz_plugin_collect + +# The ezpz_plugin_collect function returns class_macro internally +@ezpz_plugin_collect( + polars_ns="LazyFrame", + attr_name="my_plugin", + import_="from my_package import MyPlugin", + type_hint="MyPlugin" +) +class MyPlugin: + def custom_method(self): + pass +``` + +### 2. Metadata Extraction + +The `PolarsPluginCollector` extends `MacroMetadataCollector` to extract plugin information: + +```python +class PolarsPluginCollector(MacroMetadataCollector[PolarsPluginMacroMetadataPD, PolarsPluginMacroKwargs]): + def __init__(self) -> None: + super().__init__( + ezpz_plugin_collect.__name__, # "ezpz_plugin_collect" + lambda _args, kwargs: PolarsPluginMacroMetadataPD( + import_=kwargs["import_"], + type_hint=kwargs["type_hint"], + attr_name=kwargs["attr_name"], + polars_ns=EPolarsNS(kwargs["polars_ns"]), + ), + ) +``` + +### 3. Function Call Support + +The collector also handles function call syntax: + +```python +# This syntax is also supported +ezpz_plugin_collect( + polars_ns="DataFrame", + attr_name="my_plugin", + import_="from my_package import MyPlugin", + type_hint="MyPlugin" +)(MyPluginClass) +``` + +## Technical Implementation Details + +### Metadata Collection Process + +The `MacroMetadataCollector` uses LibCST's matcher system to identify decorators: + +```python +@m.leave(m.Decorator(decorator=m.Call(func=m.Name()))) +def collect_macro_metadata(self, node: cst.Decorator) -> None: + match node.decorator: + case cst.Call(func=cst.Name(decorator_name), args=decorator_args) if decorator_name == self.macro_name: + args: list[JSONSerializable] = [] + kwargs = cast("TMacroKwargs", {}) + + for arg in decorator_args: + # Extract literal values using ast.literal_eval + evaled = ast.literal_eval(arg.value.value) if isinstance(arg.value, cst.SimpleString) else ast.literal_eval(dump(arg.value)) + + if arg.keyword is None: + args.append(evaled) + else: + kwargs[arg.keyword.value] = evaled + + # Create metadata instance via callback + self.macro_data.append(self.callback(args, kwargs)) +``` + +### Type System + +The library provides comprehensive generic type support: + +```python +# JSON-serializable types for macro arguments +type JSONSerializable = str | int | float | bool | None | list[JSONSerializable] | dict[str, JSONSerializable] + +# Generic callback type +type TMetadataCallback[T: BaseModel, TMacroKwargs: dict[str, JSONSerializable]] = + Callable[[Iterable[JSONSerializable], TMacroKwargs], T] + +# Generic collector class (TMacroKwargs bound to Any to allow TypedDict) +class MacroMetadataCollector[T: BaseModel, TMacroKwargs: Any](m.MatcherDecoratableVisitor): +``` + +## Usage Patterns + +### Plugin Registration (EZPZ-Pluginz Pattern) + +```python +from ezpz_pluginz.register_plugin_macro import ezpz_plugin_collect + +# Decorator syntax +@ezpz_plugin_collect( + polars_ns="DataFrame", + attr_name="advanced_ops", + import_="from my_plugins import DataFrameAdvanced", + type_hint="DataFrameAdvanced" +) +class DataFrameAdvanced: + def complex_operation(self): + pass + +# Function call syntax +class SeriesUtils: + def utility_method(self): + pass + +ezpz_plugin_collect( + polars_ns="Series", + attr_name="utils", + import_="from my_plugins import SeriesUtils", + type_hint="SeriesUtils" +)(SeriesUtils) +``` + +### Custom Macro System + +```python +from painlezz_macroz.macroz.noop import class_macro +from painlezz_macroz.visitorz.macro_metadata_collector import MacroMetadataCollector +from pydantic import BaseModel + +# Define custom metadata +class ConfigMetadata(BaseModel): + section: str + priority: int = 0 + +# Create custom macro +def config_section(**kwargs): + return class_macro + +# Create collector +collector = MacroMetadataCollector[ConfigMetadata, dict]( + "config_section", + lambda args, kwargs: ConfigMetadata(**kwargs) +) +``` + +## Integration Flow in EZPZ-Pluginz + +1. **Plugin Definition**: Developers use `@ezpz_plugin_collect` to mark plugin classes +2. **Code Scanning**: EZPZ-Pluginz scans configured paths for Python files +3. **AST Parsing**: LibCST parses each file into a concrete syntax tree +4. **Metadata Collection**: `PolarsPluginCollector` visits the AST and extracts plugin metadata +5. **Lockfile Generation**: Collected metadata is serialized into a YAML lockfile +6. **Type Enhancement**: LibCST transformers inject type hints into Polars classes +7. **Import Management**: Required imports are added to `TYPE_CHECKING` blocks + +## Error Handling + +The collector includes robust error handling: + +- Graceful handling of malformed decorator syntax +- Safe literal evaluation using `ast.literal_eval` +- Fallback to LibCST's `dump` function for complex expressions +- Optional callback validation + +## Supported Argument Types + +The macro system supports these JSON-serializable types: + +- `str`, `int`, `float`, `bool`, `None` +- `list[JSONSerializable]` (nested lists) +- `dict[str, JSONSerializable]` (nested dictionaries) + +## Advanced Features + +- **Pattern Matching**: Uses LibCST matchers for precise AST node identification +- **Flexible Callbacks**: Support for custom metadata extraction logic +- **Type Safety**: Full generic support with proper type bounds +- **Multiple Syntax Support**: Handles both decorator and function call patterns + +## Contributing + +Painlezz Macroz is part of the EZPZ ecosystem. Contributions should maintain the lightweight, type-safe approach while expanding functionality for static analysis and code generation use cases. + +## License + +Part of the EZPZ project - see main repository for licensing information. diff --git a/guiz/python/ezpz_guiz/__init__.py b/core/macroz/macroz/__init__.py similarity index 100% rename from guiz/python/ezpz_guiz/__init__.py rename to core/macroz/macroz/__init__.py diff --git a/macroz/painlezz_macroz/macroz/noop.py b/core/macroz/macroz/decoratorz/noop.py similarity index 100% rename from macroz/painlezz_macroz/macroz/noop.py rename to core/macroz/macroz/decoratorz/noop.py diff --git a/macroz/painlezz_macroz/visitorz/macro_metadata_collector.py b/core/macroz/macroz/visitorz/macro_metadata_collector.py similarity index 75% rename from macroz/painlezz_macroz/visitorz/macro_metadata_collector.py rename to core/macroz/macroz/visitorz/macro_metadata_collector.py index 1c851b6..d8a95e3 100644 --- a/macroz/painlezz_macroz/visitorz/macro_metadata_collector.py +++ b/core/macroz/macroz/visitorz/macro_metadata_collector.py @@ -7,9 +7,12 @@ from libcst.display import dump type JSONSerializable = str | int | float | bool | None | list[JSONSerializable] | dict[str, JSONSerializable] +type TMetadataCallback[T: BaseModel, TMacroKwargs: dict[str, JSONSerializable]] = Callable[[Iterable[JSONSerializable], TMacroKwargs], T] -type TMetadataCallback[T: BaseModel, TMacroKwargs: dict[str, JSONSerializable]] = Callable[[Iterable[JSONSerializable], TMacroKwargs], T] +class NoCallbackMethodError(AttributeError): + def __init__(self) -> None: + super().__init__("no callback method available") class MacroMetadataCollector[T: BaseModel, TMacroKwargs: Any](m.MatcherDecoratableVisitor): # we bound TMacroKwargs to Any to allow TypedDict @@ -20,7 +23,7 @@ class MacroMetadataCollector[T: BaseModel, TMacroKwargs: Any](m.MatcherDecoratab def __init__(self, macro_name: str, callback: TMetadataCallback[T, TMacroKwargs] | None = None) -> None: super().__init__() if callback is None and not hasattr(self, "callback"): - raise AttributeError("no callback method available") + raise NoCallbackMethodError() if callback is not None: self.callback = callback self.macro_name = macro_name @@ -33,11 +36,13 @@ def collect_macro_metadata(self, node: cst.Decorator) -> None: args: list[JSONSerializable] = [] kwargs = cast("TMacroKwargs", {}) for arg in decorator_args: - evaled = ast.literal_eval(dump(node)) + # Extract the value from the argument, not the entire node + evaled = ast.literal_eval(arg.value.value) if isinstance(arg.value, cst.SimpleString) else ast.literal_eval(dump(arg.value)) if arg.keyword is None: args.append(evaled) else: kwargs[arg.keyword.value] = evaled - self.macro_data.append(self.callback(args, kwargs)) + # we want one callback per decorator, not per argument + self.macro_data.append(self.callback(args, kwargs)) case _: pass diff --git a/macroz/pyproject.toml b/core/macroz/pyproject.toml similarity index 61% rename from macroz/pyproject.toml rename to core/macroz/pyproject.toml index e215b36..3f2d8ff 100644 --- a/macroz/pyproject.toml +++ b/core/macroz/pyproject.toml @@ -1,8 +1,8 @@ [project] authors = [{ "name" = "Jeremy Meek" }] -dependencies = ["libcst==1.8.0", "pydantic==2.11.5"] +dependencies = ["libcst==1.8.0", "pydantic==2.11.7"] description = "A tool that brings type safety and type checking enhancements to the Polars library." -name = "painlezz-macroz" +name = "macroz" readme = "README.md" requires-python = ">=3.13,<3.14" version = "0.0.1" @@ -13,5 +13,8 @@ build-backend = "hatchling.build" requires = ["hatchling"] [tool.hatch.build] -exclude = ["painlezz_macroz/**/test_*.py"] -include = ["painlezz_macroz/**/*.j2", "painlezz_macroz/**/*.py"] +exclude = ["macroz/**/test_*.py"] +include = ["macroz/**/*.j2", "macroz/**/*.py"] + +[tool.rye] +dev-dependencies = ["pip-audit>=2.9.0"] diff --git a/core/macroz/tests/test_macroz.py b/core/macroz/tests/test_macroz.py new file mode 100644 index 0000000..62aa122 --- /dev/null +++ b/core/macroz/tests/test_macroz.py @@ -0,0 +1,226 @@ +# ruff: noqa: S101 + +from typing import TYPE_CHECKING, Iterable, cast +from inspect import signature + +import libcst as cst +import pytest +import libcst.matchers as m +from pydantic import BaseModel +from macroz.decoratorz.noop import func_macro, class_macro +from macroz.visitorz.macro_metadata_collector import ( + NoCallbackMethodError, + MacroMetadataCollector, +) + +if TYPE_CHECKING: + from macroz.visitorz.macro_metadata_collector import JSONSerializable + +DEFAULT_NAME: str = "" +DEFAULT_VALUE: int = 0 +TEST_MACRO_NAME: str = "test_macro" +CUSTOM_MACRO_NAME: str = "custom_macro" +TEST_CLASS_NAME: str = "TestClass" +TEST_NAME: str = "test" +TEST_VALUE: int = 42 +FUNC_TEST_NAME: str = "func_test" +FUNC_TEST_VALUE: int = 100 +TEST_FUNC_INPUT_X: int = 42 +TEST_FUNC_INPUT_Y: str = "hello" +TEST_FUNC_OUTPUT: str = "42 hello" + + +# Test data model for metadata collection +class TestMetadataModel(BaseModel): + name: str + value: int + + +# Sample callback for metadata collection +def sample_callback(args: Iterable["JSONSerializable"], kwargs: dict[str, "JSONSerializable"]) -> TestMetadataModel: + name: str = cast("str", kwargs.get("name", DEFAULT_NAME)) + value: int = cast("int", kwargs.get("value", DEFAULT_VALUE)) + return TestMetadataModel(name=name, value=value) + + +# convert evaluated_value to string +def safe_string_conversion(value: str | bytes) -> str: + """Convert LibCST evaluated_value to string, handling both str and bytes.""" + if isinstance(value, bytes): + return value.decode("utf-8") + return value + + +# Test no-op macros +def test_class_macro_identity() -> None: + """Test that class_macro returns the class unchanged.""" + + @class_macro + class TestClass: + pass + + assert TestClass.__name__ == TEST_CLASS_NAME + assert isinstance(TestClass(), TestClass) + + +def test_func_macro_preservation() -> None: + """Test that func_macro preserves function behavior and signature.""" + + @func_macro + def test_func(x: int, y: str = "default") -> str: + return f"{x} {y}" + + assert test_func(TEST_FUNC_INPUT_X, y=TEST_FUNC_INPUT_Y) == TEST_FUNC_OUTPUT + assert test_func.__name__ == "test_func" + # Check signature instead of co_varnames to verify parameter preservation + sig = signature(test_func) + assert list(sig.parameters.keys()) == ["x", "y"] + + +# Test MacroMetadataCollector +def test_macro_metadata_collector_initialization() -> None: + """Test MacroMetadataCollector initialization with callback.""" + collector = MacroMetadataCollector[TestMetadataModel, dict[str, "JSONSerializable"]](macro_name=TEST_MACRO_NAME, callback=sample_callback) + assert collector.macro_name == TEST_MACRO_NAME + assert collector.macro_data == [] + assert callable(collector.callback) + + +def test_macro_metadata_collector_no_callback_error() -> None: + """Test that NoCallbackMethodError is raised when no callback is provided.""" + + class InvalidCollector(MacroMetadataCollector[TestMetadataModel, dict[str, "JSONSerializable"]]): + pass + + with pytest.raises(NoCallbackMethodError): + InvalidCollector(macro_name=TEST_MACRO_NAME) + + +def test_macro_metadata_collection() -> None: + """Test metadata collection from a decorated class.""" + source_code = f""" +@{CUSTOM_MACRO_NAME}(name="{TEST_NAME}", value={TEST_VALUE}) +class {TEST_CLASS_NAME}: + pass +""" + + # Mock the collector to handle decorator syntax + class MockCollector(MacroMetadataCollector[TestMetadataModel, dict[str, "JSONSerializable"]]): + @m.leave(m.Decorator()) + def collect_macro_metadata(self, node: cst.Decorator) -> None: + matcher = m.Decorator(decorator=m.Call(func=m.Name(value=self.macro_name))) + if m.matches(node, matcher): + match node.decorator: + case cst.Call(args=decorator_args): + kwargs: dict[str, "JSONSerializable"] = {} + for arg in decorator_args: + if arg.keyword is not None: + if arg.keyword.value == "name" and isinstance(arg.value, cst.SimpleString): + kwargs["name"] = safe_string_conversion(arg.value.evaluated_value) + elif arg.keyword.value == "value" and isinstance(arg.value, cst.Integer): + kwargs["value"] = int(arg.value.evaluated_value) + self.macro_data.append(self.callback([], kwargs)) + case _: + pass + + collector = MockCollector(macro_name=CUSTOM_MACRO_NAME, callback=sample_callback) + module = cst.parse_module(source_code) + module.visit(collector) + + assert len(collector.macro_data) == 1 + metadata = collector.macro_data[0] + assert isinstance(metadata, TestMetadataModel) + assert metadata.name == TEST_NAME + assert metadata.value == TEST_VALUE + + +def test_macro_metadata_collection_empty() -> None: + """Test metadata collection when no matching decorators are found.""" + source_code = f""" +class {TEST_CLASS_NAME}: + pass +""" + collector = MacroMetadataCollector[TestMetadataModel, dict[str, "JSONSerializable"]](macro_name=CUSTOM_MACRO_NAME, callback=sample_callback) + + module = cst.parse_module(source_code) + module.visit(collector) + + assert len(collector.macro_data) == 0 + + +def test_macro_metadata_collection_function_call_syntax() -> None: + """Test metadata collection with function call syntax.""" + # This creates a proper decorator pattern that matches the AST structure + source_code = f""" +@{CUSTOM_MACRO_NAME}(name="{FUNC_TEST_NAME}", value={FUNC_TEST_VALUE}) +class {TEST_CLASS_NAME}: + pass +""" + + # Mock the collector to handle function call syntax + class MockCollector(MacroMetadataCollector[TestMetadataModel, dict[str, "JSONSerializable"]]): + @m.leave(m.Decorator()) + def collect_macro_metadata(self, node: cst.Decorator) -> None: + matcher = m.Decorator(decorator=m.Call(func=m.Name(value=self.macro_name))) + if m.matches(node, matcher): + match node.decorator: + case cst.Call(args=decorator_args): + kwargs: dict[str, "JSONSerializable"] = {} + for arg in decorator_args: + if arg.keyword is not None: + if arg.keyword.value == "name" and isinstance(arg.value, cst.SimpleString): + kwargs["name"] = safe_string_conversion(arg.value.evaluated_value) + elif arg.keyword.value == "value" and isinstance(arg.value, cst.Integer): + kwargs["value"] = int(arg.value.evaluated_value) + self.macro_data.append(self.callback([], kwargs)) + case _: + pass + + collector = MockCollector(macro_name=CUSTOM_MACRO_NAME, callback=sample_callback) + module = cst.parse_module(source_code) + module.visit(collector) + + assert len(collector.macro_data) == 1 + metadata = collector.macro_data[0] + assert isinstance(metadata, TestMetadataModel) + assert metadata.name == FUNC_TEST_NAME + assert metadata.value == FUNC_TEST_VALUE + + +def test_macro_metadata_collection_actual_function_call_syntax() -> None: + """Test metadata collection with actual function call syntax that returns a decorator.""" + # This tests the pattern: some_macro(args)(class) but applied as a decorator + source_code = f""" +class {TEST_CLASS_NAME}: + pass + +# Apply the decorator using function call syntax +{TEST_CLASS_NAME} = {CUSTOM_MACRO_NAME}(name="{FUNC_TEST_NAME}", value={FUNC_TEST_VALUE})({TEST_CLASS_NAME}) +""" + + # For this pattern, we need to look for function calls, not decorators + class FunctionCallCollector(MacroMetadataCollector[TestMetadataModel, dict[str, "JSONSerializable"]]): + @m.leave(m.Assign()) + def collect_macro_metadata(self, node: cst.Assign) -> None: + # Look for assignments like: ClassName = macro_name(args)(ClassName) + if len(node.targets) == 1 and isinstance(node.targets[0].target, cst.Name) and isinstance(node.value, cst.Call) and isinstance(node.value.func, cst.Call): + inner_call = node.value.func + if isinstance(inner_call.func, cst.Name) and inner_call.func.value == self.macro_name: + kwargs: dict[str, "JSONSerializable"] = {} + for arg in inner_call.args: + if arg.keyword is not None: + if arg.keyword.value == "name" and isinstance(arg.value, cst.SimpleString): + kwargs["name"] = safe_string_conversion(arg.value.evaluated_value) + elif arg.keyword.value == "value" and isinstance(arg.value, cst.Integer): + kwargs["value"] = int(arg.value.evaluated_value) + self.macro_data.append(self.callback([], kwargs)) + + collector = FunctionCallCollector(macro_name=CUSTOM_MACRO_NAME, callback=sample_callback) + module = cst.parse_module(source_code) + module.visit(collector) + + assert len(collector.macro_data) == 1 + metadata = collector.macro_data[0] + assert isinstance(metadata, TestMetadataModel) + assert metadata.name == FUNC_TEST_NAME + assert metadata.value == FUNC_TEST_VALUE diff --git a/core/pluginz/README.md b/core/pluginz/README.md new file mode 100644 index 0000000..853847c --- /dev/null +++ b/core/pluginz/README.md @@ -0,0 +1,331 @@ +# EZPZ-Pluginz + +A powerful tool that provides comprehensive type hinting and IDE support for Polars plugins, dramatically enhancing the development experience for custom Polars extensions. + +## Installation + +```bash +pip install ezpz_pluginz +``` + +## Problem It Solves + +Polars is an incredibly fast DataFrame library for Python, but it lacks native support for type hints and IDE integration with custom plugins. The Polars maintainers have indicated they have no immediate plans to address this gap from within Polars itself. Summit Sailors steps in to bridge this crucial developer experience gap. + +## Key Benefits + +With EZPZ-Pluginz, developers can: + +- **Enhanced Type Safety**: Write more robust and maintainable Polars plugins with full type checking support +- **Superior IDE Experience**: Leverage advanced IDE features including autocompletion, inline documentation, and error detection +- **Ecosystem Growth**: Contribute to the Polars ecosystem with greater confidence and tooling support +- **Hot Reloading**: Automatic type hint updates that point directly to your plugin implementations +- **Site-packages Integration**: Seamlessly load and manage plugins from installed packages +- **Registry Management**: Discover, install, and share plugins through a centralized registry system + +## How It Works + +EZPZ-Pluginz uses a sophisticated multi-step process to enhance your Polars development environment: + +1. **Configuration Parsing**: Reads your `ezpz.toml` configuration file +2. **Code Scanning**: Intelligently scans specified files and directories for plugin definitions +3. **AST Analysis**: Uses [libCST](https://libcst.readthedocs.io/en/latest/) for precise code analysis and metadata extraction +4. **Lockfile Generation**: Creates a comprehensive lockfile containing all discovered plugin metadata +5. **Safe Backup**: Creates backup copies of Polars files before any modifications +6. **Type Enhancement**: Applies libCST transformers to inject type hints into appropriate Polars classes +7. **Import Management**: Adds necessary imports within `TYPE_CHECKING` blocks for optimal performance + +![Lockfile Example](images/lockfile.png) +![Import Addition](images/attr_type_hint_import.png) +![Attribute Enhancement](images/attr_type_hint_added.png) + +## Plugin Definition Syntax + +EZPZ-Pluginz supports multiple syntax patterns for maximum flexibility: + +### Decorator Syntax + +```python +from ezpz_pluginz.register_plugin_macro import ezpz_plugin_collect + +@ezpz_plugin_collect( + polars_ns="LazyFrame", + attr_name="my_plugin", + import_="from my_package.plugins import MyLazyFramePlugin", + type_hint="MyLazyFramePlugin" +) +class MyLazyFramePlugin: + def custom_operation(self): + # Your plugin implementation + pass +``` + +### Function Call Syntax + +```python +from ezpz_pluginz.register_plugin_macro import ezpz_plugin_collect + +class MyDataFramePlugin: + def advanced_operation(self): + # Your plugin implementation + pass + +# Register the plugin using function call syntax +ezpz_plugin_collect( + polars_ns="DataFrame", + attr_name="advanced_plugin", + import_="from my_package.plugins import MyDataFramePlugin", + type_hint="MyDataFramePlugin" +)(MyDataFramePlugin) +``` + +### Supported Polars Namespaces + +- `DataFrame` - For DataFrame-specific plugins +- `LazyFrame` - For LazyFrame-specific plugins +- `Series` - For Series-specific plugins +- `Expr` - For Expression-specific plugins + +## Configuration + +Create an `ezpz.toml` file in your project root to specify plugin locations: + +```toml +[ezpz_pluginz] +name = "my-polars-project" +include = [ + "src/plugins/", + "plugins/dataframe_extensions.py", + "external/custom_ops/" +] +site_customize = true # Enable automatic plugin registration +``` + +Or using the config file (`pyproject.toml`): + +```toml +[tool.ezpz_pluginz] +name = "my-polars-project" +include = [ + "src/plugins/", + "plugins/dataframe_extensions.py", + "external/custom_ops/" +] +site_customize = true +``` + +### Configuration Options + +- `name`: Project identifier for your plugin collection +- `include`: List of files and directories to scan for plugins +- `site_customize`: Optional boolean to enable automatic plugin registration via sitecustomize.py + +## CLI Usage + +### Basic Plugin Management + +#### Mount Plugins + +Apply type hints and enable plugin support: + +```bash +ezpz mount +``` + +Loads plugins specified in your ezpz.toml configuration, makes plugin functions available for use, and should be run after installing new plugins or changing configuration. + +#### Unmount Plugins + +Restore original Polars files and remove modifications: + +```bash +ezpz unmount +``` + +Removes mounted plugins from your environment, useful for troubleshooting or cleaning up. + +### Plugin Discovery and Installation + +#### List Available Plugins + +```bash +ezpz list +``` + +Shows all plugins with installation status (✓ = installed, ○ = not installed), displays plugin descriptions, authors, and versions, and sets up local registry if not present. + +#### Advanced Plugin Search + +```bash +ezpz find [options] +``` + +Powerful search capabilities with flexible filtering options: + +- `--field name|description|author|package|category|aliases|all` - Search in specific fields +- `--remote` - Search remote registry +- `--both` - Search both local and remote +- `--case-sensitive` - Case-sensitive search +- `--exact` - Exact match +- `--limit N` - Limit results +- `--details` - Show detailed info + +Examples: + +```bash +# Search for Rust-based plugins +ezpz find rust --field category + +# Search for technical analysis plugins with details +ezpz find 'technical analysis' --remote --details + +# Exact search for polars-related plugins +ezpz find polars --both --exact +``` + +### Registry Management + +All registry management commands are under the `registry` subcommand: + +#### Check Registry Health + +```bash +ezpz registry health +``` + +Verifies connectivity and status of the central plugin registry server. + +#### Register a New Plugin + +```bash +ezpz registry register +``` + +Register a new plugin to the remote registry. Requires `AUTH_SECRET` environment variable. Plugin must have a `register_plugin()` function. Path should point to your plugin directory or file. + +#### Update an Existing Plugin + +```bash +ezpz registry push +``` + +Update an existing plugin in the registry. Requires `AUTH_SECRET` environment variable. Updates the plugin version in the remote registry. + +#### Refresh Registry + +```bash +ezpz registry refresh +``` + +Downloads latest plugin information from registry, run this to see newly published plugins, and is automatically done when installing plugins. + +#### Check Registry Status + +```bash +ezpz registry status +``` + +Shows registry URL and local cache information, displays number of available and verified plugins, and is useful for troubleshooting registry issues. + +#### Clear Registry Cache + +```bash +ezpz registry delete +``` + +Removes the local registry cache (`~/.ezpz`), useful for troubleshooting registry corruption or clearing cache, and registry can be automatically recreated using refresh. + +#### Delete a Plugin + +```bash +ezpz registry delete-plugin +``` + +Mark a plugin as deleted in the remote registry. Requires `AUTH_SECRET` environment variable. Removes the plugin from the local cache after successful remote deletion. + +### Getting Help + +```bash +ezpz --help +ezpz registry --help +``` + +Shows general help or help for registry commands. + +## Plugin Registry System + +EZPZ-Pluginz includes a comprehensive plugin registry system that enables: + +- **Plugin Discovery**: Browse and search through available plugins from the community +- **Easy Installation**: One-command installation of plugins with automatic dependency management +- **Automatic Updates**: Stay up-to-date with the latest plugin releases +- **Health Monitoring**: Check registry connectivity and status +- **Plugin Management**: Register, update, and delete plugins with proper authentication + +The registry system maintains a local cache for fast access and can synchronize with remote repositories to discover new plugins and updates. + +## Important Notes + +- **Minimally Invasive**: While this approach modifies the executing interpreter's Polars package, it uses libCST's concrete syntax trees to preserve file structure and formatting +- **Safe Backups**: Original files are always backed up before modification +- **Type Checking Only**: Imports are added within `TYPE_CHECKING` blocks to avoid runtime overhead +- **Reversible**: All changes can be completely undone using the unmount command +- **Authentication Required**: Registry operations that modify plugins require the `AUTH_SECRET` environment variable + +## Development Status + +### Beta Features ✅ + +- ~~Callable form of `pl.api`~~ +- ~~Install plugins from site-packages~~ +- ~~Basic logging system~~ +- Enhanced function call syntax support +- Robust string value extraction +- Improved error handling and validation +- Plugin registry system with discovery and installation +- Registry health monitoring +- Plugin update and deletion capabilities + +### Current Development Focus + +- Comprehensive functional testing suite +- Advanced exception handling and recovery +- ~~Python version compatibility (unpinned from 3.12.4 to ^3.12)~~ + +### Stability Roadmap + +- Extensive real-world testing and maturity +- Official blessing from the Polars team ([tracking issue](https://github.com/pola-rs/polars/issues/14475)) +- Community feedback integration +- Performance optimization + +## Advanced Features + +- **Automatic Hot Reloading**: Type hints point directly to implementations for immediate updates +- **Site-packages Integration**: Automatically discovers and loads plugins from installed packages +- **Lockfile Management**: Maintains state consistency across development sessions +- **Multi-syntax Support**: Flexible plugin definition patterns for different coding styles +- **Robust Error Handling**: Graceful handling of malformed plugin definitions +- **Registry Integration**: Seamless plugin discovery, installation, and management through centralized registry +- **Advanced Search**: Powerful search capabilities with field-specific filtering and remote/local search options +- **Registry Health Monitoring**: Built-in health checks for registry connectivity +- **Authenticated Operations**: Secure plugin registration and management with authentication + +## Contributing + +We welcome contributions! Please see our contributing guidelines for details on how to submit improvements, bug reports, and feature requests. + +## Support + +For support and sponsorship opportunities, visit our Polar page: + + + + +Subscription Tiers on Polar + + + +## License + +This project is licensed under the MIT License. See LICENSE file for details. diff --git a/pluginz/ezpz.toml b/core/pluginz/ezpz.toml similarity index 100% rename from pluginz/ezpz.toml rename to core/pluginz/ezpz.toml diff --git a/core/pluginz/ezpz_pluginz/__cli__.py b/core/pluginz/ezpz_pluginz/__cli__.py new file mode 100644 index 0000000..deebf49 --- /dev/null +++ b/core/pluginz/ezpz_pluginz/__cli__.py @@ -0,0 +1,561 @@ +# type: ignore[B008] + +import os +import time +import shutil +from typing import Any +from pathlib import Path + +import typer + +from ezpz_pluginz import mount_plugins, unmount_plugins +from ezpz_pluginz.logger import setup_logger +from ezpz_pluginz.registry import ( + REGISTRY_URL, + LOCAL_REGISTRY_DIR, + LOCAL_REGISTRY_FILE, + PluginRegistryAPI, + LocalPluginRegistry, + find_plugin_in_path, + is_package_installed, + setup_local_registry, +) +from ezpz_pluginz.toml_schema import load_config +from ezpz_pluginz.registry.models import PluginUpdate + +app = typer.Typer(name="ezpz", pretty_exceptions_show_locals=False, pretty_exceptions_short=True) +registry_app = typer.Typer(name="registry", help="Registry management commands") +app.add_typer(typer_instance=registry_app, name="registry") + +logger = setup_logger("CLI") + + +def get_auth_secret() -> str: + pat = os.getenv("AUTH_SECRET") + if not pat: + logger.error("Auth Secret required. Set AUTH_SECRET environment variable") + raise typer.Exit(1) + return pat + + +def return_bool(*, val: bool) -> bool: + return val + + +# Core plugin management commands +@app.command(name="mount") +def mount() -> None: + """ + Mount all configured plugins to make them available in your environment. + + Loads plugins specified in your ezpz.toml configuration. + Makes plugin functions available for use. + Run this after installing new plugins or changing configuration. + """ + mount_plugins() + + +@app.command(name="unmount") +def unmount() -> None: + """ + Unmount all plugins from your environment. + + Removes mounted plugins from your environment. + Useful for troubleshooting or cleaning up. + """ + unmount_plugins() + + +@app.command(name="list") +def list_plugins() -> None: + """ + List all available plugins in the registry. + + Shows all plugins with installation status (✓ = installed, ○ = not installed). + Displays plugin descriptions, authors, and versions. + Sets up local registry if not present. + """ + registry = LocalPluginRegistry() + plugins = registry.list_plugins() + + if not plugins: + logger.info("Local registry appears to be empty or not set up.") + if not LOCAL_REGISTRY_FILE.exists(): + logger.info("Setting up local plugin registry for the first time...") + setup_local_registry() + registry = LocalPluginRegistry() + plugins = registry.list_plugins() + else: + logger.info("Local registry exists but appears empty. Refreshing from remote...") + if registry.fetch_and_update_registry(): + plugins = registry.list_plugins() + else: + logger.error("Failed to refresh local registry from remote.") + + if not plugins: + logger.info("No plugins found in local registry after setup.") + logger.info("This could indicate:") + logger.info(" - Network connectivity issues") + logger.info(" - Remote registry is empty") + logger.info(" - Registry URL is incorrect") + logger.info(f" - Current registry URL: {REGISTRY_URL}") + logger.info("Try running 'ezpz registry refresh' manually to update from remote registry.") + return + + logger.info("Available EZPZ Plugins:") + logger.info("-" * 50) + for plugin in plugins: + installed = "✓" if is_package_installed(plugin.package_name) else "○" + logger.info(f"{installed} {plugin.name}") + logger.info(f" Package: {plugin.package_name}") + logger.info(f" Description: {plugin.description}") + if plugin.aliases: + logger.info(f" Aliases: {', '.join(plugin.aliases)}") + if plugin.author: + logger.info(f" Author: {plugin.author}") + if plugin.version: + logger.info(f" Version: {plugin.version}") + logger.info("") + logger.info("") + + +@app.command(name="find") +def find( + keyword: str = typer.Argument(help="Keyword to search for in plugins"), + *, + field: str = typer.Option(None, "--field", "-f", help="Search in specific field: name, description, author, package, category, aliases, all"), + remote: bool = typer.Option(return_bool(val=False), "--remote", "-r", help="Search in remote registry instead of local"), + both: bool = typer.Option(return_bool(val=False), "--both", "-b", help="Search in both local and remote registries"), + case_sensitive: bool = typer.Option(return_bool(val=False), "--case-sensitive", "-c", help="Perform case-sensitive search"), + exact: bool = typer.Option(return_bool(val=False), "--exact", "-e", help="Exact match instead of partial match"), + limit: int = typer.Option(50, "--limit", "-l", help="Maximum number of results to show"), + show_details: bool = typer.Option(return_bool(val=False), "--details", "-d", help="Show detailed plugin information"), +) -> None: + """ + Advanced search for plugins with flexible filtering. + + Search in specific fields: --field name|description|author|package|category|aliases|all. + Search remote registry: --remote. + Search both local and remote: --both. + Case-sensitive search: --case-sensitive. + Exact match: --exact. + Limit results: --limit N. + Show detailed info: --details. + + Examples: + ezpz find rust --field category + ezpz find 'technical analysis' --remote --details + ezpz find polars --both --exact + """ + valid_fields = {"name", "description", "author", "package", "category", "aliases", "all", None} + if field and field not in valid_fields: + logger.error(f"Invalid field '{field}'. Valid options: {', '.join(f for f in valid_fields if f)}") + raise typer.Exit(1) + + search_field = field or "all" + search_local = not remote or both + search_remote = remote or both + + if not keyword.strip(): + logger.error("Search keyword cannot be empty") + raise typer.Exit(1) + + local_results = [] + remote_results = [] + + if search_local: + try: + registry = LocalPluginRegistry() + local_results = advanced_search_local(registry, keyword, search_field, case_sensitive=case_sensitive, exact=exact) + except Exception as e: + logger.warning(f"Local search failed: {e}") + + if search_remote: + try: + api = PluginRegistryAPI() + remote_results = api.search_plugins(keyword) + if search_field != "all": + remote_results = filter_remote_results(remote_results, keyword, search_field, case_sensitive=case_sensitive, exact=exact) + except Exception as e: + logger.warning(f"Remote search failed: {e}") + + all_results = combine_results(local_results, remote_results) + if limit > 0: + all_results = all_results[:limit] + display_search_results(all_results, keyword, search_field, searched_local=search_local, searched_remote=search_remote, show_details=show_details) + + +# Registry subcommands +@registry_app.command(name="health") +def health() -> None: + """ + Check the health of the remote plugin registry. + + Verifies connectivity and status of the central plugin registry server. + """ + remote_reg = PluginRegistryAPI() + try: + response = remote_reg.check_health() + logger.info(response) + except Exception as e: + logger.exception("Health check failed") + raise typer.Exit(1) from e + + +@registry_app.command(name="register") +def register( + plugin_path: str = typer.Argument(..., help="Path to the plugin to register"), +) -> None: + """ + Register a new plugin to the remote registry. + + Requires AUTH_SECRET environment variable. + Plugin must have a register_plugin() function. + Path should point to your plugin directory or file. + Plugin will be made available to other users. + """ + config = load_config() + if not config: + logger.error("Could not load ezpz.toml configuration") + raise typer.Exit(1) + + local_registry = LocalPluginRegistry() + if not local_registry.fetch_and_update_registry(): + logger.warning("Failed to refresh local plugin registry, continuing with cached data") + + plugin_info = find_plugin_in_path(plugin_path, config.include_str_paths) + if plugin_info is None: + logger.error(f"No plugin found at path: {plugin_path}") + logger.info("Make sure the path contains a plugin with a register_plugin() function in the module entry i.e '__init__.py'") + logger.info(f"Searched in configured include paths: {config.include_str_paths}") + raise typer.Exit(1) + + if local_registry.is_plugin_registered(plugin_info.name): + logger.info(f"Plugin '{plugin_info.name}' is already registered") + logger.info("Skipping registration") + return + + auth_secret = get_auth_secret() + api = PluginRegistryAPI() + success = api.register_plugin(plugin_info, auth_secret) + + if success: + logger.info(f"Successfully registered '{plugin_info.name}'") + local_registry.fetch_and_update_registry() + else: + logger.error(f"Failed to register '{plugin_info.name}'") + raise typer.Exit(1) + + +@registry_app.command(name="push") +def update_plugin( + plugin_name: str = typer.Argument(help="Name of the plugin to update"), + plugin_path: str = typer.Argument(default=..., help="Path to the updated plugin"), +) -> None: + """ + Update an existing plugin in the registry. + + Requires AUTH_SECRET environment variable. + Updates the plugin version in the remote registry. + Plugin must already exist in the registry. + """ + auth_secret = get_auth_secret() + refresh() + config = load_config() + if not config: + logger.error("Could not load ezpz.toml configuration") + raise typer.Exit(1) + + plugin_info = find_plugin_in_path(plugin_path, config.include_str_paths) + if not plugin_info: + logger.error(f"No plugin found at path: {plugin_path}") + raise typer.Exit(1) + + local_registry = LocalPluginRegistry() + existing_plugin = local_registry.get_plugin(plugin_name) + if not existing_plugin: + logger.error(f"Plugin '{plugin_name}' not found in local registry") + logger.info("Try running 'ezpz registry refresh' to update the local registry") + raise typer.Exit(1) + + api = PluginRegistryAPI() + logger.info(f"Updating plugin: {plugin_info.name}") + plugin_update = PluginUpdate(**plugin_info.model_dump()) + success = api.update_plugin(existing_plugin.id, plugin_update, auth_secret) + + if success: + logger.info(f"Successfully updated '{plugin_info.name}'") + local_registry.fetch_and_update_registry() + else: + logger.error(f"Failed to update '{plugin_info.name}'") + raise typer.Exit(1) + + +@registry_app.command(name="refresh") +def refresh() -> None: + """ + Refresh the local plugin registry from remote. + + Downloads latest plugin information from registry. + Run this to see newly published plugins. + Automatically done when installing plugins. + """ + logger.info("Refreshing local plugin registry...") + registry = LocalPluginRegistry() + if registry.fetch_and_update_registry(): + logger.info("Local plugin registry refreshed successfully") + else: + raise typer.Exit(1) + + +@registry_app.command(name="status") +def status() -> None: + """ + Show current status of the plugin system. + + Shows registry URL and local cache information. + Displays number of available and verified plugins. + Useful for troubleshooting registry issues. + """ + registry = LocalPluginRegistry() + logger.info("EZPZ Plugin Registry Status:") + logger.info("-" * 40) + logger.info(f"Registry URL: {REGISTRY_URL}") + logger.info(f"Local registry directory: {LOCAL_REGISTRY_DIR}") + if LOCAL_REGISTRY_FILE.exists(): + registry_age = time.time() - LOCAL_REGISTRY_FILE.stat().st_mtime + hours_old = registry_age / 3600 + logger.info(f"Local registry file: {LOCAL_REGISTRY_FILE}") + logger.info(f"Registry age: {hours_old:.1f} hours") + else: + logger.info("Local registry file: Not found") + plugins = registry.list_plugins() + logger.info(f"Total plugins available: {len(plugins)}") + verified_count = sum(1 for p in plugins if p.verified) + logger.info(f"Verified plugins: {verified_count}") + + +@registry_app.command(name="delete") +def delete_registry() -> None: + """ + Delete the local plugin registry cache. + + Removes the local registry (~/.ezpz). + Useful for troubleshooting registry corruption or clearing cache. + Registry can be automatically recreated using refresh. + """ + LOCAL_REGISTRY = Path.home() / ".ezpz" + if not LOCAL_REGISTRY.exists(): + logger.info("Local registry file does not exist - nothing to delete") + return + try: + shutil.rmtree(LOCAL_REGISTRY) + logger.info(f"Successfully deleted local registry: {LOCAL_REGISTRY}") + except Exception as e: + logger.exception("Failed to delete local registry") + raise typer.Exit(1) from e + + +@registry_app.command(name="delete-plugin") +def delete_plugin(_id: str = typer.Argument(help="The ID of the plugin to be deleted")) -> None: + """ + Mark a plugin as deleted in the remote registry. + + Requires AUTH_SECRET environment variable. + Removes the plugin from the local cache after successful remote deletion. + """ + local_registry = LocalPluginRegistry() + remote_registry = PluginRegistryAPI() + pat = get_auth_secret() + try: + try: + plugin = remote_registry.get_plugin(_id) + if plugin.is_deleted: + logger.warning(f"Plugin {_id} is already deleted") + local_registry.remove_plugin_from_local_registry(plugin=plugin) + return + except Exception as e: + logger.warning(f"Failed to check plugin status: {e}") + remote_registry.delete_plugin(_id, pat) + logger.info(f"Successfully deleted plugin: {_id}") + if "plugin" in locals(): + local_registry.remove_plugin_from_local_registry(plugin=plugin) + else: + try: + plugin = remote_registry.get_plugin(_id) + local_registry.remove_plugin_from_local_registry(plugin=plugin) + except Exception: + local_registry.fetch_and_update_registry() + except Exception as e: + logger.exception("Failed to delete plugin") + raise typer.Exit(1) from e + + +# Helper functions +def advanced_search_local(registry: LocalPluginRegistry, keyword: str, field: str, *, case_sensitive: bool, exact: bool) -> list: + plugins = registry.list_plugins() + search_keyword = keyword if case_sensitive else keyword.lower() + + return [("local", plugin) for plugin in plugins if should_include_plugin(plugin, search_keyword, field, case_sensitive=case_sensitive, exact=exact)] + + +def filter_remote_results(plugins: list[dict[str, Any]], keyword: str, field: str, *, case_sensitive: bool, exact: bool) -> list: + if field == "all": + return [("remote", plugin) for plugin in plugins] + + search_keyword = keyword if case_sensitive else keyword.lower() + + return [("remote", plugin) for plugin in plugins if should_include_plugin(plugin, search_keyword, field, case_sensitive=case_sensitive, exact=exact)] + + +def should_include_plugin(plugin: dict[str, Any], search_keyword: str, field: str, *, case_sensitive: bool, exact: bool) -> bool: # noqa: PLR0911 + def get_field_value(plugin: dict[str, Any], field_name: str) -> str: + if field_name == "package": + field_name = "package_name" + + value = getattr(plugin, field_name, "") or "" + if not case_sensitive: + value = value.lower() + return value + + def get_aliases_text(plugin: dict[str, Any]) -> str: + aliases = getattr(plugin, "aliases", []) or [] + text = " ".join(aliases) + if not case_sensitive: + text = text.lower() + return text + + def matches_text(text: str, keyword: str, *, exact: bool) -> bool: + if exact: + return text == keyword + return keyword in text + + # Field-specific search + if field == "name": + return matches_text(get_field_value(plugin, "name"), search_keyword, exact=exact) + if field == "description": + return matches_text(get_field_value(plugin, "description"), search_keyword, exact=exact) + if field == "author": + return matches_text(get_field_value(plugin, "author"), search_keyword, exact=exact) + if field == "package": + return matches_text(get_field_value(plugin, "package_name"), search_keyword, exact=exact) + if field == "category": + return matches_text(get_field_value(plugin, "category"), search_keyword, exact=exact) + if field == "aliases": + return matches_text(get_aliases_text(plugin), search_keyword, exact=exact) + # field == "all" + search_fields = [ + get_field_value(plugin, "name"), + get_field_value(plugin, "description"), + get_field_value(plugin, "author"), + get_field_value(plugin, "package_name"), + get_field_value(plugin, "category"), + get_aliases_text(plugin), + ] + return any(matches_text(field_text, search_keyword, exact=exact) for field_text in search_fields) + + +def combine_results(local_results: list, remote_results: list) -> list: + """Combine and deduplicate local and remote results.""" + seen_plugins = set() + combined = [] + + # local results first (they take precedence) + for source, plugin in local_results: + plugin_key = (plugin.name, plugin.package_name) + if plugin_key not in seen_plugins: + combined.append((source, plugin)) + seen_plugins.add(plugin_key) + + # remote results that aren't already in local + for source, plugin in remote_results: + plugin_key = (plugin.name, plugin.package_name) + if plugin_key not in seen_plugins: + combined.append((source, plugin)) + seen_plugins.add(plugin_key) + + return combined + + +def display_search_results(results: list, keyword: str, field: str, *, searched_local: bool, searched_remote: bool, show_details: bool) -> None: + if not results: + search_scope = [] + if searched_local: + search_scope.append("local") + if searched_remote: + search_scope.append("remote") + scope_text = " and ".join(search_scope) + + logger.info(f"No plugins found matching '{keyword}' in {scope_text} registry") + if field != "all": + logger.info(f"Searched in field: {field}") + return + + # header + search_info = f"Found {len(results)} plugin(s) matching '{keyword}'" + if field != "all": + search_info += f" in field '{field}'" + + logger.info(search_info) + logger.info("-" * 60) + + local_results = [plugin for source, plugin in results if source == "local"] + remote_results = [plugin for source, plugin in results if source == "remote"] + + if local_results: + logger.info(f"LOCAL REGISTRY ({len(local_results)} results):") + logger.info("") + for plugin in local_results: + display_plugin_result(plugin, show_details=show_details, is_local=True) + + if remote_results: + if local_results: # separator if we have both + logger.info("") + logger.info("=" * 60) + logger.info("") + + logger.info(f"REMOTE REGISTRY ({len(remote_results)} results):") + logger.info("") + for plugin in remote_results: + display_plugin_result(plugin, show_details=show_details, is_local=False) + + +def display_plugin_result(plugin: dict[str, Any], *, show_details: bool, is_local: bool) -> None: + if is_local: + installed = "✓" if is_package_installed(plugin.package_name) else "○" + status_prefix = f"{installed} " + else: + installed = "✓" if is_package_installed(plugin.package_name) else "◯" + status_prefix = f"{installed} " + + logger.info(f"{status_prefix}{plugin.name}") + + if show_details: + logger.info(f" Package: {plugin.package_name}") + logger.info(f" Description: {plugin.description}") + + if hasattr(plugin, "aliases") and plugin.aliases: + logger.info(f" Aliases: {', '.join(plugin.aliases)}") + + if hasattr(plugin, "author") and plugin.author: + logger.info(f" Author: {plugin.author}") + + if hasattr(plugin, "version") and plugin.version: + logger.info(f" Version: {plugin.version}") + + if hasattr(plugin, "category") and plugin.category: + logger.info(f" Category: {plugin.category}") + + if hasattr(plugin, "verified") and plugin.verified: + logger.info(" Status: ✅ Verified") + + if not is_local: + logger.info(" Source: Remote Registry") + else: + logger.info(f" {plugin.description}") + + logger.info("") + + +if __name__ == "__main__": + app() diff --git a/pluginz/ezpz_pluginz/__init__.py b/core/pluginz/ezpz_pluginz/__init__.py similarity index 60% rename from pluginz/ezpz_pluginz/__init__.py rename to core/pluginz/ezpz_pluginz/__init__.py index 59e08ba..3f6a3da 100644 --- a/pluginz/ezpz_pluginz/__init__.py +++ b/core/pluginz/ezpz_pluginz/__init__.py @@ -1,54 +1,92 @@ import sys +import shutil import inspect -import logging import importlib from pathlib import Path from itertools import chain import libcst as cst -from ezpz_pluginz.lockfile import EZPZ_TOML_FILENAME, EZPZ_LOCKFILE_FILENAME, PolarsPluginLockfilePD +from ezpz_pluginz.logger import setup_logger +from ezpz_pluginz.lockfile import EZPZ_TOML_FILENAME, EZPZ_PROJECT_LOCKFILE_FILENAME, PolarsPluginLockfilePD from ezpz_pluginz.toml_schema import EzpzPluginConfig from ezpz_pluginz.e_polars_namespace import EPolarsNS from ezpz_pluginz.register_plugin_macro import PluginPatcher -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +logger = setup_logger("ENTRY") def mount_plugins() -> None: - ezpz_pluginz_config = EzpzPluginConfig.from_toml_path(Path.cwd().joinpath(EZPZ_TOML_FILENAME)) + ezpz_pluginz_config = None + ezpz_toml_path = Path.cwd().joinpath(EZPZ_TOML_FILENAME) + + if ezpz_toml_path.exists(): + try: + ezpz_pluginz_config = EzpzPluginConfig.from_toml_path(ezpz_toml_path) + except Exception as e: + logger.warning(f"Failed to load ezpz.toml: {e}") + else: + pyproject_toml_path = Path.cwd().joinpath("pyproject.toml") + if pyproject_toml_path.exists(): + try: + ezpz_pluginz_config = EzpzPluginConfig.from_toml_path(pyproject_toml_path) + except Exception as e: + logger.warning(f"Failed to load pyproject.toml: {e}") + lockfile = PolarsPluginLockfilePD.generate() - lockfile.to_yaml_file(Path(EZPZ_LOCKFILE_FILENAME)) + lockfile.to_yaml_file(Path(EZPZ_PROJECT_LOCKFILE_FILENAME)) + + # plugin-level lock files using the same lockfile data + lockfile.generate_and_save_plugin_lockfiles() + polars_ns_to_plugins = dict(chain(lockfile.project_plugins.items(), lockfile.site_plugins.items())) pp = PluginPatcher(polars_ns_to_plugins) + polars_module = importlib.import_module("polars") + patched_dir = Path.cwd() / ".patched" + patched_dir.mkdir(exist_ok=True) + for ns in polars_ns_to_plugins: logger.info(f"Preparing to patch polars namespace {ns}...") filepath = Path(inspect.getfile(getattr(polars_module, ns))) backup_path = filepath.with_suffix(".bak") ext = ".bak" if backup_path.is_file() else ".py" source_code = filepath.with_suffix(ext).read_text() + if not backup_path.is_file(): logger.info("Creating backup of polars file...") backup_path.write_text(source_code) else: logger.info("Backup file already exists") + module = cst.parse_module(source_code) wrapper = cst.MetadataWrapper(module) + logger.info("Patching...") new_code = wrapper.visit(pp).code + logger.info("Saving...") filepath.write_text(new_code) - logger.info("Complete") + + local_copy_path = patched_dir / f"{ns.lower()}.py" + local_copy_path.write_text(new_code) + + logger.info(f"Patched copy saved to {local_copy_path}") + if ezpz_pluginz_config and ezpz_pluginz_config.site_customize: if hasattr(sys, "real_prefix") or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix): venv_site_path = Path(sys.prefix) / "lib" / f"python{sys.version_info.major}.{sys.version_info.minor}" / "site-packages" else: logger.warning("WARNING: The system python is executing, running ezpz plugins sitecustomize registry mouting is not advised.") return + if venv_site_path.exists(): - venv_site_path.joinpath("sitecustomize.py").write_text(lockfile.generate_registry()) + sitecustomize_code = lockfile.generate_registry() + sitecustomize_path = venv_site_path.joinpath("sitecustomize.py") + sitecustomize_path.write_text(sitecustomize_code) + + (patched_dir / "sitecustomize.py").write_text(sitecustomize_code) + logger.info(f"sitecustomize.py saved to {patched_dir / 'sitecustomize.py'}") def unmount_plugins() -> None: @@ -59,12 +97,19 @@ def unmount_plugins() -> None: if backup_path.is_file(): filepath.write_text(backup_path.read_text()) backup_path.unlink() + if hasattr(sys, "real_prefix") or (hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix): venv_site_path = Path(sys.prefix) / "lib" / f"python{sys.version_info.major}.{sys.version_info.minor}" / "site-packages" else: logger.warning("WARNING: The system python is executing, running ezpz plugins sitecustomize registry mouting is not advised.") return + if venv_site_path.exists(): sitecustomize = venv_site_path.joinpath("sitecustomize.py") if sitecustomize.exists(): sitecustomize.unlink() + + patched_dir = Path.cwd() / ".patched" + if patched_dir.exists(): + shutil.rmtree(patched_dir) + logger.info(f"Removed .patched directory: {patched_dir}") diff --git a/pluginz/ezpz_pluginz/e_polars_namespace.py b/core/pluginz/ezpz_pluginz/e_polars_namespace.py similarity index 83% rename from pluginz/ezpz_pluginz/e_polars_namespace.py rename to core/pluginz/ezpz_pluginz/e_polars_namespace.py index 04f7a68..e3f44e6 100644 --- a/pluginz/ezpz_pluginz/e_polars_namespace.py +++ b/core/pluginz/ezpz_pluginz/e_polars_namespace.py @@ -1,5 +1,5 @@ from enum import StrEnum -from typing import Any, Generator +from typing import Any, Generator, LiteralString class EPolarsNS(StrEnum): @@ -21,6 +21,6 @@ def api_decorator(self) -> str: return "register_series_namespace" @classmethod - def get_api_decorators(cls) -> Generator[str, Any, None]: + def get_api_decorators(cls) -> Generator[LiteralString, Any, None]: for e_pl_ns in EPolarsNS: yield f"register_{e_pl_ns.value.lower()}_namespace" diff --git a/core/pluginz/ezpz_pluginz/lockfile.py b/core/pluginz/ezpz_pluginz/lockfile.py new file mode 100644 index 0000000..299a967 --- /dev/null +++ b/core/pluginz/ezpz_pluginz/lockfile.py @@ -0,0 +1,179 @@ +import logging +import importlib +import contextlib +import importlib.util +import importlib.metadata +from typing import Self, Iterable +from pathlib import Path +from operator import attrgetter +from itertools import chain, groupby + +import yaml +from jinja2 import Template +from pydantic import BaseModel + +from ezpz_pluginz.toml_schema import EzpzPluginConfig +from ezpz_pluginz.register_plugin_macro import PolarsPluginMacroMetadataPD + +logger = logging.getLogger(__name__) + +EZPZ_TOML_FILENAME = "ezpz.toml" +EZPZ_PROJECT_LOCKFILE_FILENAME = "ezpz-lock.yaml" +EZPZ_PLUGIN_LOCKFILE_FILENAME = "ezpz-lock.yml" + + +def group_models_by_key[T: BaseModel](data: Iterable[T], key: str) -> dict[str, set[T]]: + sorted_data = sorted(data, key=attrgetter(key)) + return {k: set(v) for k, v in groupby(sorted_data, key=attrgetter(key))} + + +class PolarsPluginLockfilePD(BaseModel): + project_plugins: dict[str, set[PolarsPluginMacroMetadataPD]] + site_plugins: dict[str, set[PolarsPluginMacroMetadataPD]] + + @classmethod + def generate(cls) -> "PolarsPluginLockfilePD": + logger.debug(f"cwd: {Path.cwd()}") + + # Initialize empty project and site plugins + project_plugins = dict[str, set[PolarsPluginMacroMetadataPD]]() + site_plugins = dict[str, set[PolarsPluginMacroMetadataPD]]() + project_entry = cls(project_plugins=project_plugins, site_plugins=site_plugins) + + # Try to load project plugins from ezpz.toml or pyproject.toml + project_ezpz_toml_path = Path.cwd().joinpath(EZPZ_TOML_FILENAME) + pyproject_toml_path = Path.cwd().joinpath("pyproject.toml") + + try: + if project_ezpz_toml_path.exists(): + project_entry.project_plugins = EzpzPluginConfig.get_plugins(project_ezpz_toml_path) + logger.debug(f"Loaded plugins from {EZPZ_TOML_FILENAME}") + elif pyproject_toml_path.exists(): + project_entry.project_plugins = EzpzPluginConfig.get_plugins(pyproject_toml_path) + logger.debug("Loaded plugins from pyproject.toml") + else: + logger.warning(f"Neither {EZPZ_TOML_FILENAME} nor pyproject.toml found in {Path.cwd()}") + except ValueError as e: + logger.warning(f"Failed to load plugins: {e}. Continuing with empty project plugins.") + + logger.info("Proceeding to check from sitepackages") + # Check for site plugins from distributions + has_ezpz_pluginz_dep = False + for dist in importlib.metadata.distributions(): + if "ezpz-pluginz" in (dist.requires or []): + has_ezpz_pluginz_dep = True + spec = importlib.util.find_spec(dist.metadata["Name"].replace("-", "_")) + if spec and spec.origin: + patch_file = Path(spec.origin).with_name(EZPZ_PLUGIN_LOCKFILE_FILENAME) # Look for .yml file + if patch_file.exists(): + try: + project_entry.site_plugins.update(cls.from_yaml_file(patch_file).project_plugins) + logger.debug(f"Loaded site plugins from {patch_file}") + except Exception as e: + logger.warning(f"Failed to load site plugins from {patch_file}: {e}") + + if not project_entry.project_plugins and not has_ezpz_pluginz_dep: + logger.error("No project plugins found and no distributions depend on ezpz-pluginz.") + msg = "No project plugins or ezpz-pluginz dependencies found." + raise ValueError(msg) + + return project_entry + + def generate_registry(self) -> str: + imports = list[str]() + registry = list[str]() + for plugin in chain(chain.from_iterable(self.project_plugins.values()), chain.from_iterable(self.site_plugins.values())): + imports.append(plugin.import_) + registry.append(plugin.registery_entry()) + return Template(Path(__file__).parent.parent.joinpath("templates", "sitecustomize.py.j2").read_text()).render(imports=imports, registry=registry) + + def to_yaml(self) -> str: + return yaml.safe_dump(self.model_dump(mode="json"), sort_keys=False) + + @classmethod + def from_yaml(cls, content: str) -> Self: + return cls.model_validate(yaml.safe_load(content)) + + @classmethod + def from_yaml_file(cls, lockfile_path: "Path") -> Self: + return cls.from_yaml(lockfile_path.read_text()) + + def to_yaml_file(self, lockfile_path: "Path") -> None: + lockfile_path.write_text(self.to_yaml()) + + def generate_and_save_plugin_lockfiles(self) -> None: + if not self.project_plugins: + logger.debug("No project plugins found, skipping plugin-level lock file generation") + return + + for dist in importlib.metadata.distributions(): + if "ezpz-pluginz" in (dist.requires or []): + spec = importlib.util.find_spec(dist.metadata["Name"].replace("-", "_")) + if spec and spec.origin: + plugin_module_path = Path(spec.origin) + plugin_lockfile_path = plugin_module_path.with_name(EZPZ_PLUGIN_LOCKFILE_FILENAME) + + try: + # plugins specific to this distribution/package + plugin_specific_plugins = self._get_plugins_for_package(plugin_module_path.parent) + + if plugin_specific_plugins: + plugin_lockfile_data = PolarsPluginLockfilePD( + project_plugins=plugin_specific_plugins, + site_plugins={}, + ) + + plugin_lockfile_data.to_yaml_file(plugin_lockfile_path) + logger.info(f"Generated plugin-level lock file: {plugin_lockfile_path}") + else: + logger.debug(f"No plugins found for package at {plugin_module_path.parent}") + except Exception as e: + logger.warning(f"Failed to generate plugin-level lock file at {plugin_lockfile_path}: {e}") + + def _get_plugins_for_package(self, package_path: Path) -> dict[str, set[PolarsPluginMacroMetadataPD]]: + package_plugins: dict[str, set[PolarsPluginMacroMetadataPD]] = {} + package_name = package_path.name + + for polars_ns, plugins in self.project_plugins.items(): + matching_plugins: set[PolarsPluginMacroMetadataPD] = set() + + for plugin in plugins: + if self._plugin_belongs_to_package(plugin, package_name, package_path): + matching_plugins.add(plugin) + + if matching_plugins: + package_plugins[polars_ns] = matching_plugins + + return package_plugins + + def _plugin_belongs_to_package(self, plugin: PolarsPluginMacroMetadataPD, package_name: str, package_path: Path) -> bool: + import_statement = plugin.import_ + + # The module name from import statement + if import_statement.startswith("from "): + try: + module_part = import_statement.split(" import ")[0].replace("from ", "").strip() + + if module_part.startswith(package_name): + return True + + # Try to resolve the actual module path to be more accurate + try: + spec = importlib.util.find_spec(module_part) + if spec and spec.origin: + module_file_path = Path(spec.origin) + # module file is within the package directory ? + try: + module_file_path.relative_to(package_path) + except ValueError: + contextlib.suppress(ValueError) + except (ImportError, ModuleNotFoundError, ValueError): + # If we can't resolve the module, fall back to string matching + pass + else: + return True + + except (IndexError, ValueError): + logger.warning(f"Could not parse import statement: {import_statement}") + + return False diff --git a/core/pluginz/ezpz_pluginz/logger.py b/core/pluginz/ezpz_pluginz/logger.py new file mode 100644 index 0000000..bd810ea --- /dev/null +++ b/core/pluginz/ezpz_pluginz/logger.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import inspect +import logging +from typing import TYPE_CHECKING, Literal, ClassVar +from pathlib import Path +from datetime import datetime + +import structlog + +if TYPE_CHECKING: + from structlog.types import EventDict, WrappedLogger, FilteringBoundLogger + +LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + +ColorKey = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL", "RESET"] + +LogEventDict = dict[str, str | int | datetime | None] + + +class ColoredFormatter: + COLORS: ClassVar[dict[ColorKey, str]] = { + "DEBUG": "\033[36m", # Cyan + "INFO": "\033[32m", # Green + "WARNING": "\033[33m", # Yellow + "ERROR": "\033[31m", # Red + "CRITICAL": "\033[35m", # Magenta + "RESET": "\033[0m", # Reset + } + + def __call__(self, logger: WrappedLogger, method_name: str, event_dict: EventDict) -> str: + """Format log event with colors and aligned fields.""" + level: str = event_dict.get("level", method_name).upper() + log_color: str = self.COLORS.get(level, "") # type: ignore[dict-item] + reset_color: str = self.COLORS["RESET"] + + # caller info with defaults + filename: str = Path(event_dict.get("pathname", "unknown")).stem + lineno: str = str(event_dict.get("lineno", 0)) + + # the main message + event: str = str(event_dict.get("event", "")) + + # format + formatted: str = f"{log_color}[{level:8}]{reset_color} {filename}:{lineno:<4} - {event}" + + # structured data addition, excluding core fields + extra_data = {k: v for k, v in event_dict.items() if k not in ("event", "level", "pathname", "lineno", "timestamp", "logger")} + if extra_data: + formatted += f" {extra_data!r}" + + return formatted + + +def add_caller_info(_: WrappedLogger, __: str, event_dict: EventDict) -> EventDict: + frame = inspect.currentframe() + try: + # Walk up the stack to find the actual caller + caller_frame = frame + while caller_frame: + caller_frame = caller_frame.f_back + if caller_frame and not any(path in caller_frame.f_code.co_filename for path in ["structlog", "logging", "_log.py"]): + break + + if caller_frame: + event_dict["pathname"] = caller_frame.f_code.co_filename + event_dict["lineno"] = caller_frame.f_lineno + finally: + del frame + return event_dict + + +def setup_logger( + name: str = "app", + level: int | LogLevel = logging.INFO, +) -> FilteringBoundLogger: + if isinstance(level, str): + level = getattr(logging, level.upper()) + + if not structlog.is_configured(): + structlog.configure( + processors=[ + add_caller_info, + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.add_log_level, + ColoredFormatter(), + ], + wrapper_class=structlog.make_filtering_bound_logger(level), + logger_factory=structlog.PrintLoggerFactory(), + cache_logger_on_first_use=True, + ) + + return structlog.get_logger(name) + + +class LoggerFactory: + _default_logger: FilteringBoundLogger | None = None + + @classmethod + def get_logger(cls, name: str | None = None) -> FilteringBoundLogger: + if name is None: + if cls._default_logger is None: + cls._default_logger = setup_logger() + return cls._default_logger + return setup_logger(name) + + @classmethod + def reset(cls) -> None: + structlog.reset_defaults() + cls._default_logger = None + + +if __name__ == "__main__": + logger = LoggerFactory.get_logger("test") + + logger.debug("Debug message") + logger.info("Info message") + logger.warning("Warning message") + logger.error("Error message") + logger.critical("Critical message") + + # Structured logging example + logger.info("User action", user_id=12345, action="login") + + # Different logger instance + other_logger = LoggerFactory.get_logger("other") + other_logger.info("Message from another module") + +__all__ = ["LoggerFactory", "setup_logger"] diff --git a/pluginz/ezpz_pluginz/polars_class_provider.py b/core/pluginz/ezpz_pluginz/polars_class_provider.py similarity index 100% rename from pluginz/ezpz_pluginz/polars_class_provider.py rename to core/pluginz/ezpz_pluginz/polars_class_provider.py diff --git a/pluginz/ezpz_pluginz/register_plugin_macro.py b/core/pluginz/ezpz_pluginz/register_plugin_macro.py similarity index 53% rename from pluginz/ezpz_pluginz/register_plugin_macro.py rename to core/pluginz/ezpz_pluginz/register_plugin_macro.py index a4f81a7..448fb50 100644 --- a/pluginz/ezpz_pluginz/register_plugin_macro.py +++ b/core/pluginz/ezpz_pluginz/register_plugin_macro.py @@ -1,12 +1,12 @@ import logging -from typing import TYPE_CHECKING, Unpack, Callable, Sequence, TypedDict, cast +from typing import TYPE_CHECKING, Any, Unpack, Callable, Sequence, TypedDict, cast import libcst as cst import libcst.matchers as m from pydantic import BaseModel, ConfigDict from libcst.matchers import MatcherDecoratableTransformer -from painlezz_macroz.macroz.noop import class_macro -from painlezz_macroz.visitorz.macro_metadata_collector import MacroMetadataCollector +from macroz.decoratorz.noop import class_macro +from macroz.visitorz.macro_metadata_collector import MacroMetadataCollector from ezpz_pluginz.e_polars_namespace import EPolarsNS from ezpz_pluginz.polars_class_provider import PolarsClassProvider @@ -15,6 +15,11 @@ from ezpz_pluginz.register_plugin_macro import PolarsPluginMacroMetadataPD +class InvalidNamespaceError(Exception): + def __init__(self) -> None: + super().__init__("PANIC!") + + class PolarsPluginMacroKwargs(TypedDict): import_: str type_hint: str @@ -22,6 +27,7 @@ class PolarsPluginMacroKwargs(TypedDict): polars_ns: str +# purpose is to be recognized by painlezz_macroz (not an actual decorator) def ezpz_plugin_collect[T](**kwargs: Unpack[PolarsPluginMacroKwargs]) -> Callable[[T], T]: return class_macro @@ -38,6 +44,7 @@ def registery_entry(self) -> str: return f"pl.api.{self.polars_ns.api_decorator}('{self.attr_name}')({self.type_hint})" +# libsct visitor class PolarsPluginCollector(MacroMetadataCollector[PolarsPluginMacroMetadataPD, PolarsPluginMacroKwargs]): def __init__(self) -> None: super().__init__( @@ -50,10 +57,58 @@ def __init__(self) -> None: ), ) + # handles function call syntax e.g ezpz_plugin_collect(args, kwargs)(Class) + def visit_Call(self, node: cst.Call) -> bool: + if isinstance(node.func, cst.Call) and isinstance(node.func.func, cst.Name) and node.func.func.value == ezpz_plugin_collect.__name__: + kwargs = self._extract_kwargs_from_call(node.func) + if kwargs and node.args and len(node.args) > 0: + arg = node.args[0] + if isinstance(arg.value, cst.Name): + try: + metadata = PolarsPluginMacroMetadataPD( + import_=kwargs["import_"], + type_hint=kwargs["type_hint"], + attr_name=kwargs["attr_name"], + polars_ns=EPolarsNS(kwargs["polars_ns"]), + ) + except (KeyError, ValueError) as e: + logging.getLogger(__name__).warning(f"Failed to create plugin metadata: {e}") + return False # Stop recursion on error + else: + self.macro_data.append(metadata) + return False # Stop recursion for this node + return True # Continue recursion for other nodes + + def _extract_kwargs_from_call(self, call_node: cst.Call) -> dict[str, str] | None: + kwargs = dict[str, Any]() + + for arg in call_node.args: + if arg.keyword: + key = arg.keyword.value + value = self._extract_string_value(arg.value) + if value is not None: + kwargs[key] = value + + required_keys = {"import_", "type_hint", "attr_name", "polars_ns"} + if required_keys.issubset(kwargs.keys()): + return kwargs + return None + + def _extract_string_value(self, node: cst.BaseExpression) -> str | None: + if isinstance(node, cst.SimpleString): + # remove quotes from the string + return node.value.strip("\"'") + if isinstance(node, cst.ConcatenatedString): + # concatenated strings handling + parts = [part.value.strip("\"'") for part in (node.left, node.right) if isinstance(part, cst.SimpleString)] + return "".join(parts) if parts else None + return None + logger = logging.getLogger(__name__) +# libcst transformer (modifies polars source code) class PluginPatcher(MatcherDecoratableTransformer): METADATA_DEPENDENCIES = (PolarsClassProvider,) @@ -69,16 +124,17 @@ def visit_Module(self, node: cst.Module) -> None: self.has_added_imports = False self.imports = [cst.parse_module(plugin.import_).body[0] for plugin in self.plugins] + # called when libcst leaves a ClassDef node that matches a polars namespace @m.leave(m.ClassDef(name=m.Name(value=m.MatchIfTrue(lambda name: name in EPolarsNS)))) def add_new_attrs(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: if original_node.name.value != self.polars_ns: - raise Exception("PANIC") + raise InvalidNamespaceError() plugin_nodes = list[cst.AnnAssign]() for plugin in self.plugins: logger.info(f"Adding {plugin}") plugin_nodes.append(cst.AnnAssign(target=cst.Name(plugin.attr_name), annotation=cst.Annotation(cst.parse_expression(plugin.type_hint)), value=None)) new_body = list(updated_node.body.body) - new_body = new_body[:1] + [cst.SimpleStatementLine(body=plugin_nodes)] + new_body[1:] + new_body = [*new_body[:1], cst.SimpleStatementLine(body=plugin_nodes), *new_body[1:]] return updated_node.with_changes(body=cst.IndentedBlock(body=cast("Sequence[cst.BaseStatement]", new_body))) @m.leave(m.If(test=m.Name("TYPE_CHECKING"))) diff --git a/core/pluginz/ezpz_pluginz/registry/__init__.py b/core/pluginz/ezpz_pluginz/registry/__init__.py new file mode 100644 index 0000000..56e1527 --- /dev/null +++ b/core/pluginz/ezpz_pluginz/registry/__init__.py @@ -0,0 +1,15 @@ +from ezpz_pluginz.registry.utils import find_plugin_in_path, is_package_installed, setup_local_registry +from ezpz_pluginz.registry.config import REGISTRY_URL, LOCAL_REGISTRY_DIR, LOCAL_REGISTRY_FILE +from ezpz_pluginz.registry.reg.local import LocalPluginRegistry +from ezpz_pluginz.registry.reg.remote import PluginRegistryAPI + +__all__ = [ + "LOCAL_REGISTRY_DIR", + "LOCAL_REGISTRY_FILE", + "REGISTRY_URL", + "LocalPluginRegistry", + "PluginRegistryAPI", + "find_plugin_in_path", + "is_package_installed", + "setup_local_registry", +] diff --git a/core/pluginz/ezpz_pluginz/registry/config.py b/core/pluginz/ezpz_pluginz/registry/config.py new file mode 100644 index 0000000..854d802 --- /dev/null +++ b/core/pluginz/ezpz_pluginz/registry/config.py @@ -0,0 +1,54 @@ +import os +import tomllib +from typing import Any +from pathlib import Path + +from ezpz_pluginz.logger import setup_logger + +logger = setup_logger("Config") + +# Registry configuration +DEFAULT_REGISTRY_URL = "http://127.0.0.1:8080" +REGISTRY_URL = os.getenv("EZPZ_REGISTRY_URL", DEFAULT_REGISTRY_URL) +API_VERSION = "v1" +REQUEST_TIMEOUT = 30.0 + +# HTTP status codes +HTTP_UNAUTHORIZED = 401 +HTTP_NOT_FOUND = 404 +HTTP_SERVER_ERROR = 500 + +# Pagination +DEFAULT_BATCH_SIZE = 100 +DEFAULT_PAGE_START = 1 + +# Default values +DEFAULT_VERSION = "0.0.1" +DEFAULT_HOMEPAGE = "https://github.com/Summit-Sailors/EZPZ.git" + +# Local storage +LOCAL_REGISTRY_DIR = Path.home() / ".ezpz" / "registry" +LOCAL_REGISTRY_FILE = LOCAL_REGISTRY_DIR / "plugins.json" + + +def load_ezpz_config() -> dict[str, Any]: + config_file = Path("ezpz.toml") + if config_file.exists(): + try: + with config_file.open("rb") as f: + return tomllib.load(f).get("ezpz_pluginz", {}) + except Exception: + logger.warning("Failed to load ezpz.toml") + return {} + + pyproject_file = Path("pyproject.toml") + if pyproject_file.exists(): + try: + with pyproject_file.open("rb") as f: + return tomllib.load(f).get("tool", {}).get("ezpz", {}) + except Exception: + logger.warning("Failed to load pyproject.toml") + return {} + + logger.warning("Neither ezpz.toml nor pyproject.toml with [tool.ezpz_pluginz] found") + return {} diff --git a/core/pluginz/ezpz_pluginz/registry/exceptions.py b/core/pluginz/ezpz_pluginz/registry/exceptions.py new file mode 100644 index 0000000..4962401 --- /dev/null +++ b/core/pluginz/ezpz_pluginz/registry/exceptions.py @@ -0,0 +1,36 @@ +class PluginRegistryError(Exception): + def __init__(self, message: str = "An error occurred in the plugin registry") -> None: + super().__init__(message) + self.message = message + + +class PluginRegistryConnectionError(Exception): + def __init__(self, base_url: str, reason: str = "connection failed") -> None: + super().__init__(f"Unable to connect to registry at {base_url}: {reason}") + self.base_url = base_url + self.reason = reason + + +class PluginRegistryAuthError(Exception): + def __init__(self, message: str = "Authentication failed - invalid or expired token") -> None: + super().__init__(message) + + +class PluginNotFoundError(Exception): + def __init__(self, resource: str) -> None: + super().__init__(f"Resource not found: {resource}") + self.resource = resource + + +class PluginOperationError(Exception): + def __init__(self, operation: str, plugin_name: str, reason: str) -> None: + super().__init__(f"Failed to {operation} plugin '{plugin_name}': {reason}") + self.operation = operation + self.plugin_name = plugin_name + self.reason = reason + + +class PluginValidationError(Exception): + def __init__(self, field: str) -> None: + super().__init__(f"{field} cannot be empty") + self.field = field diff --git a/core/pluginz/ezpz_pluginz/registry/models.py b/core/pluginz/ezpz_pluginz/registry/models.py new file mode 100644 index 0000000..fe4f962 --- /dev/null +++ b/core/pluginz/ezpz_pluginz/registry/models.py @@ -0,0 +1,98 @@ +import re +from typing import Any, ClassVar, Optional + +from pydantic import Field, HttpUrl, EmailStr, BaseModel, field_validator + +from ezpz_pluginz.logger import setup_logger + +logger = setup_logger("Models") + +PACKAGE_NAME_REGEX = re.compile(r"^ezpz[_-][a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?$") + + +class PluginMetadataInner(BaseModel): + PY_VERSION_ERROR: ClassVar[str] = "python_version must be in the format '>=3.X' (e.g., '>=3.13')" + + tags: list[str] = Field(default_factory=list, description="Tags describing the plugin") + license: str = Field(..., description="License type (e.g., MIT, Apache-2.0)") + python_version: str = Field(..., description="Minimum Python version (e.g., >=3.13)") + dependencies: list[str] = Field(default_factory=list, description="List of required packages") + documentation: HttpUrl = Field(..., description="URL to plugin documentation") + support_email: EmailStr = Field(..., description="Contact email for support") + + @field_validator("python_version") + def validate_python_version(cls, v: str) -> str: + if not re.match(r"^>=3\.\d{1,2}$", v): + raise ValueError(cls.PY_VERSION_ERROR) + return v + + +class PluginMetadata(BaseModel): + VERSION_ERROR: ClassVar[str] = "Version must follow semantic versioning (e.g., '0.1.0')" + FIELD_ERROR: ClassVar[str] = "Field must not be empty" + + name: str = Field(..., description="Short name of the plugin") + package_name: str = Field(..., description="Package name for installation") + description: str = Field(..., description="Brief description of the plugin") + aliases: list[str] = Field(default_factory=list, description="Alternative names for the plugin") + version: str = Field(..., description="Plugin version (semantic versioning)") + author: str = Field(..., description="Author or maintainer of the plugin") + category: str = Field(..., description="Category of the plugin (e.g., Technical analysis)") + homepage: HttpUrl = Field(..., description="URL to plugin homepage") + metadata_: PluginMetadataInner = Field(..., description="Additional metadata") + + @field_validator("version") + def validate_version(cls, v: str) -> str: + if not re.match(r"^\d+\.\d+\.\d+$", v): + raise ValueError(cls.VERSION_ERROR) + return v + + @field_validator("name", "package_name", "description", "author", "category") + def validate_non_empty(cls, v: str) -> str: + if not v.strip(): + raise ValueError(cls.FIELD_ERROR) + return v.strip() + + +class PluginCreate(PluginMetadata): + pass # Inherits all fields and validation from PluginMetadata + + +class PluginResponse(PluginMetadata): + id: str = Field(..., description="Unique identifier for the plugin") + created_at: str = Field(..., description="Creation timestamp") + updated_at: str = Field(..., description="Last update timestamp") + verified: bool = Field(default=False, description="Whether the plugin is verified") + is_deleted: bool = Field(default=False, description="Whether the plugin is marked as deleted") + + +class PluginUpdate(BaseModel): + name: Optional[str] = Field(None, description="Short name of the plugin") + package_name: Optional[str] = Field(None, description="Package name for installation") + description: Optional[str] = Field(None, description="Brief description of the plugin") + aliases: Optional[list[str]] = Field(None, description="Alternative names for the plugin") + version: Optional[str] = Field(None, description="Plugin version (semantic versioning)") + author: Optional[str] = Field(None, description="Author or maintainer of the plugin") + category: Optional[str] = Field(None, description="Category of the plugin") + homepage: Optional[HttpUrl] = Field(None, description="URL to plugin homepage") + metadata_: Optional[PluginMetadataInner] = Field(None, description="Additional metadata") + + @field_validator("version") + def validate_version(cls, v: Optional[str]) -> Optional[str]: + if v and not re.match(r"^\d+\.\d+\.\d+$", v): + raise ValueError(PluginMetadata.VERSION_ERROR) + return v + + @field_validator("name", "package_name", "description", "author", "category") + def validate_non_empty(cls, v: Optional[str]) -> Optional[str]: + if v is not None and not v.strip(): + raise ValueError(PluginMetadata.FIELD_ERROR) + return v.strip() if v else v + + +def safe_deserialize_plugin(plugin_data: dict[str, Any]) -> Optional[PluginResponse]: + try: + return PluginResponse.model_validate(plugin_data) + except Exception: + logger.exception("Failed to deserialize plugin data") + return None diff --git a/macroz/painlezz_macroz/__init__.py b/core/pluginz/ezpz_pluginz/registry/reg/__init__.py similarity index 100% rename from macroz/painlezz_macroz/__init__.py rename to core/pluginz/ezpz_pluginz/registry/reg/__init__.py diff --git a/core/pluginz/ezpz_pluginz/registry/reg/local.py b/core/pluginz/ezpz_pluginz/registry/reg/local.py new file mode 100644 index 0000000..c7f1ae9 --- /dev/null +++ b/core/pluginz/ezpz_pluginz/registry/reg/local.py @@ -0,0 +1,155 @@ +import json +import time +import importlib.metadata +from typing import Optional + +from ezpz_pluginz.logger import setup_logger +from ezpz_pluginz.registry.config import LOCAL_REGISTRY_DIR, LOCAL_REGISTRY_FILE +from ezpz_pluginz.registry.models import PluginMetadata, PluginResponse, safe_deserialize_plugin # noqa: TC001 +from ezpz_pluginz.registry.reg.remote import PluginRegistryAPI + +logger = setup_logger("Registry") + + +class LocalPluginRegistry: + def __init__(self) -> None: + self._plugins: dict[str, PluginResponse] = {} + self._api = PluginRegistryAPI() + self._ensure_registry_dir() + self._load_local_registry() + + def _ensure_registry_dir(self) -> None: + LOCAL_REGISTRY_DIR.mkdir(parents=True, exist_ok=True) + + def _load_local_registry(self) -> None: + if not LOCAL_REGISTRY_FILE.exists(): + return + + try: + with LOCAL_REGISTRY_FILE.open("r") as f: + data = json.load(f) + for plugin_data in data.get("plugins", []): + plugin = safe_deserialize_plugin(plugin_data) + if plugin: + self._register_plugin(plugin) + logger.debug(f"Loaded {len(data.get('plugins', []))} plugins from local registry") + except Exception: + logger.warning("Failed to load local registry") + + def _save_local_registry(self, plugins: list[PluginResponse]) -> None: + try: + registry_data = {"timestamp": time.time(), "plugins": [plugin.model_dump() for plugin in plugins]} + with LOCAL_REGISTRY_FILE.open("w") as f: + json.dump(registry_data, f, indent=2) + logger.debug(f"Saved {len(plugins)} plugins to local registry") + except Exception: + logger.warning("Failed to save local registry") + + def _register_plugin(self, plugin: PluginResponse) -> None: + self._plugins[plugin.name.lower()] = plugin + for alias in plugin.aliases: + self._plugins[alias.lower()] = plugin + + def fetch_and_update_registry(self) -> bool: + logger.debug("Fetching plugins from remote registry...") + try: + remote_plugins = self._api.fetch_plugins() + if remote_plugins: + self._plugins.clear() + for plugin in remote_plugins: + self._register_plugin(plugin) + self._save_local_registry(remote_plugins) + logger.info(f"Updated local registry with {len(remote_plugins)} plugins") + except Exception: + logger.warning("Failed to update registry") + return False + return True + + def get_plugin(self, name: str) -> Optional[PluginResponse]: + return self._plugins.get(name.lower()) + + def list_plugins(self) -> list[PluginResponse]: + seen: set[str] = set() + unique_plugins: list[PluginResponse] = [] + + for plugin in self._plugins.values(): + if plugin.name not in seen: + unique_plugins.append(plugin) + seen.add(plugin.name) + return unique_plugins + + def is_plugin_registered(self, plugin_name: str) -> bool: + try: + plugin_name_lower = plugin_name.lower() + if plugin_name_lower in self._plugins: + return True + for plugin in self.list_plugins(): + if ( + plugin.name.lower() == plugin_name_lower + or plugin.package_name.lower() == plugin_name_lower + or plugin_name_lower in [alias.lower() for alias in plugin.aliases] + ): + return True + except Exception: + logger.warning(f"Error checking plugin registration for '{plugin_name}'") + return False + return False + + def search_plugins(self, keyword: str) -> list[PluginResponse]: + keyword_lower = keyword.lower() + matching_plugins: list[PluginResponse] = [] + seen: set[str] = set() + + for plugin in self._plugins.values(): + if plugin.name in seen: + continue + search_fields = [ + plugin.name.lower(), + plugin.description.lower(), + plugin.author.lower() if plugin.author else "", + *[alias.lower() for alias in plugin.aliases], + ] + if any(keyword_lower in field for field in search_fields): + matching_plugins.append(plugin) + seen.add(plugin.name) + return matching_plugins + + def remove_plugin_from_local_registry(self, plugin: PluginResponse) -> None: + try: + plugin_name_lower = plugin.name.lower() + if plugin_name_lower in self._plugins: + del self._plugins[plugin_name_lower] + for alias in plugin.aliases: + alias_lower = alias.lower() + if alias_lower in self._plugins: + del self._plugins[alias_lower] + remaining_plugins = self.list_plugins() + self._save_local_registry(remaining_plugins) + logger.debug(f"Removed plugin {plugin.name} from local registry") + except Exception as e: + logger.warning(f"Failed to remove plugin from local registry: {e}") + self.fetch_and_update_registry() + + +def discover_local_plugins() -> list[PluginResponse]: + plugins: list[PluginResponse] = [] + try: + for dist in importlib.metadata.distributions(): + entry_points = dist.entry_points + ezpz_plugins = entry_points.select(group="ezpz.plugins") if hasattr(entry_points, "select") else [ep for ep in entry_points if ep.group == "ezpz.plugins"] + for entry_point in ezpz_plugins: + try: + plugin_info_func = entry_point.load() + plugin_info: PluginMetadata = plugin_info_func() + plugin_response = PluginResponse( + id="", # ID will be assigned by registry + created_at="", + updated_at="", + **plugin_info.model_dump(), + ) + plugins.append(plugin_response) + except Exception: + logger.warning(f"Failed to load plugin from {entry_point.name}") + except ImportError: + logger.debug("importlib.metadata not available") + return plugins diff --git a/core/pluginz/ezpz_pluginz/registry/reg/remote.py b/core/pluginz/ezpz_pluginz/registry/reg/remote.py new file mode 100644 index 0000000..ea98380 --- /dev/null +++ b/core/pluginz/ezpz_pluginz/registry/reg/remote.py @@ -0,0 +1,219 @@ +import json +from typing import Any, ClassVar, Optional + +import httpx + +from ezpz_pluginz.logger import setup_logger +from ezpz_pluginz.registry.config import ( + API_VERSION, + REGISTRY_URL, + HTTP_NOT_FOUND, + REQUEST_TIMEOUT, + HTTP_SERVER_ERROR, + HTTP_UNAUTHORIZED, + DEFAULT_BATCH_SIZE, + DEFAULT_PAGE_START, +) +from ezpz_pluginz.registry.models import PluginCreate, PluginUpdate, PluginResponse, safe_deserialize_plugin # noqa: TC001 +from ezpz_pluginz.registry.exceptions import ( + PluginNotFoundError, + PluginRegistryError, + PluginOperationError, + PluginRegistryAuthError, + PluginRegistryConnectionError, +) + +logger = setup_logger("Registry") + + +class PluginRegistryAPI: + UNSUPPORTED_HTTP_METHOD_ERROR: ClassVar[str] = "Unsupported HTTP method: {method}" + EMPTY_SEARCH_KEYWORD_ERROR: ClassVar[str] = "Search keyword cannot be empty" + EMPTY_PLUGIN_ID_ERROR: ClassVar[str] = "Plugin ID cannot be empty" + GITHUB_TOKEN_REQUIRED_ERROR: ClassVar[str] = "Authentication is required" # noqa: S105 + + def __init__(self, base_url: str = REGISTRY_URL) -> None: + self.base_url = base_url.rstrip("/") + self.timeout = REQUEST_TIMEOUT + + def invalid_method(self, method: str) -> None: + raise ValueError(self.UNSUPPORTED_HTTP_METHOD_ERROR.format(method=method)) + + def _make_request( + self, + endpoint: str, + method: str = "POST", + data: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + params: dict[str, Any] | None = None, + *, + use_json: bool = False, + ) -> dict[str, Any]: + url = f"{self.base_url}/api/{API_VERSION}{endpoint}" + response = None + headers = headers or {} + request_data = data or {} + + try: + with httpx.Client(timeout=self.timeout) as client: + if method == "POST": + if use_json: + headers["Content-Type"] = "application/json" + response = client.post(url, json=request_data, headers=headers) + else: + headers["Content-Type"] = "application/x-www-form-urlencoded" + response = client.post(url, data=request_data, headers=headers) + elif method == "GET": + response = client.get(url, params=params, headers=headers) + else: + self.invalid_method(method) + + if response is not None: + if response.status_code == HTTP_UNAUTHORIZED: + raise PluginRegistryAuthError() + if response.status_code == HTTP_NOT_FOUND: + raise PluginNotFoundError(endpoint) + if response.status_code >= HTTP_SERVER_ERROR: + raise PluginRegistryError("Server_error") + response.raise_for_status() + if not response.content.strip(): + logger.debug(f"Empty response from {url}") + return {} + return response.json() if response is not None else {} + + except httpx.ConnectError as exc: + raise PluginRegistryConnectionError(self.base_url, "connection refused") from exc + except httpx.TimeoutException as exc: + raise PluginRegistryConnectionError(self.base_url, f"timeout after {self.timeout}s") from exc + except httpx.HTTPStatusError as exc: + if exc.response.status_code not in [HTTP_UNAUTHORIZED, HTTP_NOT_FOUND]: + raise PluginRegistryError(f"{exc.response.text}") from exc + raise + except (ValueError, json.JSONDecodeError) as exc: + raise PluginRegistryError(f"{exc}") from exc + + def check_health(self) -> dict[str, Any]: + logger.info("Checking registry health") + return self._make_request("/health", method="POST") + + def fetch_plugins(self, *, verified_only: bool = False) -> list[PluginResponse]: + all_plugins: list[PluginResponse] = [] + batch_size = DEFAULT_BATCH_SIZE + page = DEFAULT_PAGE_START + + logger.info(f"Fetching plugins from registry (verified_only={verified_only})") + + while True: + data = { + "page": str(page), + "page_size": str(batch_size), + "verified_only": str(verified_only).lower(), + } + response = self._make_request("/plugins", data=data) + plugins_data: list[dict[str, Any]] = response.get("plugins", []) + if not plugins_data: + break + batch_plugins: list[PluginResponse] = [] + for plugin_data in plugins_data: + plugin = safe_deserialize_plugin(plugin_data) + if plugin: + batch_plugins.append(plugin) + all_plugins.extend(batch_plugins) + logger.debug(f"Fetched page {page}: {len(batch_plugins)} plugins") + total_pages = response.get("total_pages", DEFAULT_PAGE_START) + if page >= total_pages: + break + page += 1 + logger.info(f"Successfully fetched {len(all_plugins)} plugins") + return all_plugins + + def search_plugins(self, keyword: str) -> list[PluginResponse]: + if not keyword.strip(): + raise ValueError(self.EMPTY_SEARCH_KEYWORD_ERROR) + logger.info(f"Searching plugins for keyword: '{keyword}'") + data = {"query_text": keyword} + response = self._make_request("/plugins/search", data=data) + plugins_data: list[dict[str, Any]] = response.get("plugins", []) + plugins: list[PluginResponse] = [] + for plugin_data in plugins_data: + plugin = safe_deserialize_plugin(plugin_data) + if plugin: + plugins.append(plugin) + logger.info(f"Search returned {len(plugins)} plugins") + return plugins + + def get_plugin(self, plugin_id: str) -> PluginResponse: + if not plugin_id.strip(): + raise ValueError(self.EMPTY_PLUGIN_ID_ERROR) + logger.info(f"Fetching plugin: {plugin_id}") + response = self._make_request(f"/plugins/get/{plugin_id}") + if not response: + raise PluginNotFoundError(plugin_id) + plugin = safe_deserialize_plugin(response) + if not plugin: + raise PluginRegistryError("Invalid_plugin_data") + logger.info(f"Successfully retrieved plugin: {plugin.name}") + return plugin + + def register_plugin(self, plugin_info: PluginCreate, auth_secret: str) -> Optional[PluginResponse]: + if not auth_secret.strip(): + raise ValueError(self.GITHUB_TOKEN_REQUIRED_ERROR) + logger.info(f"Registering plugin: {plugin_info.name}") + data = {"request": {"plugin_data": plugin_info.model_dump()}} + headers = {"Authorization": f"Bearer {auth_secret}"} + + def _handle_registration_error(error_msg: str, plugin_name: str) -> None: + raise PluginOperationError("register", plugin_name, error_msg) + + try: + response = self._make_request("/plugins/register", data=data, headers=headers, use_json=True) + plugin = safe_deserialize_plugin(response) + if not plugin: + error_msg = response.get("error", "Unknown registration error") + _handle_registration_error(error_msg, plugin_info.name) + logger.info("Successfully registered plugin") + except Exception as e: + error_message = ( + f"Failed to register plugin '{plugin_info.name}'.\n" + f"Possible reasons:\n" + f"1. Plugin name already exists (even if marked as deleted - wait for hard deletion),\n" + f"2. Network/server error,\n " + f"3. Invalid plugin data or authorization.\n " + "\n" + f"Error details: {e!s}\n" + ) + logger.exception(error_message) + return None + return plugin + + def update_plugin(self, plugin_id: str, plugin_info: PluginUpdate, auth_secret: str) -> PluginResponse: + if not plugin_id.strip(): + raise ValueError(self.EMPTY_PLUGIN_ID_ERROR) + if not auth_secret.strip(): + raise ValueError(self.GITHUB_TOKEN_REQUIRED_ERROR) + logger.info(f"Updating plugin: {plugin_id}") + plugin_dict = {k: v for k, v in plugin_info.model_dump().items() if v is not None} + data = {"request": {"plugin_data": plugin_dict}} + headers = {"Authorization": f"Bearer {auth_secret}"} + response = self._make_request(f"/plugins/update/{plugin_id}", data=data, headers=headers, use_json=True) + plugin = safe_deserialize_plugin(response) + if not plugin: + error_msg = response.get("error", "Unknown update error") + raise PluginOperationError("update", plugin_id, error_msg) + logger.info(f"Successfully updated plugin: {plugin_id}") + return plugin + + def delete_plugin(self, plugin_id: str, auth_secret: str) -> PluginResponse: + if not plugin_id.strip(): + raise ValueError(self.EMPTY_PLUGIN_ID_ERROR) + if not auth_secret.strip(): + raise ValueError(self.GITHUB_TOKEN_REQUIRED_ERROR) + logger.info(f"Deleting plugin: {plugin_id}") + data: dict[str, Any] = {} + headers = {"Authorization": f"Bearer {auth_secret}"} + response = self._make_request(f"/plugins/ delete/{plugin_id}", data=data, headers=headers, use_json=False) + plugin = safe_deserialize_plugin(response) + if not plugin: + error_msg = response.get("error", "Unknown deletion error") + raise PluginOperationError("delete", plugin_id, error_msg) + return plugin diff --git a/core/pluginz/ezpz_pluginz/registry/utils.py b/core/pluginz/ezpz_pluginz/registry/utils.py new file mode 100644 index 0000000..d1f165c --- /dev/null +++ b/core/pluginz/ezpz_pluginz/registry/utils.py @@ -0,0 +1,103 @@ +import importlib.util +import importlib.metadata +from typing import TYPE_CHECKING, Optional +from pathlib import Path + +from ezpz_pluginz.logger import setup_logger +from ezpz_pluginz.registry.reg.local import LocalPluginRegistry + +if TYPE_CHECKING: + from ezpz_pluginz.registry.models import PluginMetadata + +logger = setup_logger("Utils") + + +def is_package_installed(package_name: str) -> bool: + try: + importlib.metadata.distribution(package_name) + except importlib.metadata.PackageNotFoundError: + return False + return True + + +def setup_local_registry() -> None: + registry = LocalPluginRegistry() + success = registry.fetch_and_update_registry() + if success: + logger.info("Local registry setup completed successfully") + else: + logger.warning("Failed to setup local registry from remote") + + +def find_plugin_in_path(plugin_path: str, include_paths: list[str]) -> Optional["PluginMetadata"]: + plugin_path_obj = Path(plugin_path) + logger.info(f"Searching for plugin in: {plugin_path_obj}") + if plugin_path_obj.exists(): + plugin_info = _load_plugin_from_path(plugin_path_obj) + if plugin_info: + return plugin_info + for include_path in include_paths: + search_path = Path(include_path) + full_path = search_path / plugin_path + if full_path.exists(): + plugin_info = _load_plugin_from_path(full_path) + if plugin_info: + return plugin_info + if search_path.exists(): + for subdir in search_path.iterdir(): + if subdir.is_dir() and subdir.name == plugin_path: + plugin_info = _load_plugin_from_path(subdir) + if plugin_info: + return plugin_info + return None + + +def _load_plugin_from_path(plugin_path: Path) -> Optional["PluginMetadata"]: + try: + entry_point_patterns = [ + plugin_path / "python" / _extract_package_name(plugin_path.name) / "__init__.py", + plugin_path / "src" / _extract_package_name(plugin_path.name) / "__init__.py", + plugin_path / _extract_package_name(plugin_path.name) / "__init__.py", + plugin_path / "__init__.py", + ] + logger.debug(f"Checking entry point patterns: {[str(p) for p in entry_point_patterns]}") + for entry_point_path in entry_point_patterns: + if entry_point_path.exists(): + logger.debug(f"Found entry point: {entry_point_path}") + plugin_info = _load_plugin_from_file(entry_point_path) + if plugin_info: + return plugin_info + logger.debug(f"Searching recursively in {plugin_path}") + for init_file in plugin_path.rglob("__init__.py"): + logger.debug(f"Trying {init_file}") + plugin_info = _load_plugin_from_file(init_file) + if plugin_info: + return plugin_info + except Exception: + logger.warning(f"Error loading plugin from {plugin_path}") + return None + + +def _extract_package_name(plugin_dir_name: str) -> str: + return plugin_dir_name.replace("-", "_") + + +def _load_plugin_from_file(file_path: Path) -> Optional["PluginMetadata"]: + try: + if not file_path.exists(): + logger.warning(f"Plugin file does not exist: {file_path}") + return None + spec = importlib.util.spec_from_file_location(f"plugin_{file_path.stem}", file_path) + if spec is None or spec.loader is None: + logger.warning(f"Could not create spec for {file_path}") + return None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if hasattr(module, "register_plugin"): + register_func = module.register_plugin + plugin_data: PluginMetadata = register_func() + return plugin_data + logger.warning(f"No register_plugin function in {file_path}") + except Exception as e: + logger.error(f"Failed to load plugin {file_path}: {e}", exc_info=True) + return None diff --git a/pluginz/ezpz_pluginz/test_plugin.py b/core/pluginz/ezpz_pluginz/test_plugin.py similarity index 100% rename from pluginz/ezpz_pluginz/test_plugin.py rename to core/pluginz/ezpz_pluginz/test_plugin.py diff --git a/core/pluginz/ezpz_pluginz/toml_schema.py b/core/pluginz/ezpz_pluginz/toml_schema.py new file mode 100644 index 0000000..4887fac --- /dev/null +++ b/core/pluginz/ezpz_pluginz/toml_schema.py @@ -0,0 +1,136 @@ +import logging +from typing import TYPE_CHECKING, Any, ClassVar, Iterable, Optional, Generator +from pathlib import Path +from operator import attrgetter +from itertools import chain, groupby + +import toml +import libcst as cst +from pydantic import Field, BaseModel + +from ezpz_pluginz.register_plugin_macro import PolarsPluginCollector + +if TYPE_CHECKING: + from ezpz_pluginz.register_plugin_macro import PolarsPluginMacroMetadataPD + +__all__ = ["EzpzPluginConfig"] + +logger = logging.getLogger(__name__) + +EZPZ_TOML_FILENAME = "ezpz.toml" +EZPZ_PROJECT_LOCKFILE_FILENAME = "ezpz-lock.yaml" + + +def group_models_by_key[T: BaseModel](data: Iterable[T], key: str) -> dict[str, set[T]]: + sorted_data = sorted(data, key=attrgetter(key)) + return {k: set(v) for k, v in groupby(sorted_data, key=attrgetter(key))} + + +def _process_file(path: "Path") -> set["PolarsPluginMacroMetadataPD"]: + plugin_visitor = PolarsPluginCollector() + cst.parse_module(path.read_text()).visit(plugin_visitor) + logger.debug(f"_process_file: {path}") + logger.debug(f"_process_file:return: {plugin_visitor.macro_data}") + return set(plugin_visitor.macro_data) + + +def process_includes(paths: Iterable["Path"]) -> "Generator[PolarsPluginMacroMetadataPD, Any, None]": + for path in paths: + if path.is_file(): + yield from _process_file(path) + elif path.is_dir(): + sub_toml = path.joinpath(EZPZ_TOML_FILENAME) + if sub_toml.exists(): + yield from process_includes(path.joinpath(subpath) for subpath in EzpzPluginConfig.from_toml_path(sub_toml).include) + else: + yield from process_includes(chain(path.rglob("*.py"), path.rglob("*.pyi"))) + + +def get_plugins(project_toml_path: Path) -> dict[str, set["PolarsPluginMacroMetadataPD"]]: + ezpz_pluginz = EzpzPluginConfig.from_toml_path(project_toml_path) + return group_models_by_key(set(process_includes(ezpz_pluginz.include)), "polars_ns") + + +class EzpzPluginConfig(BaseModel): + INVALID_SECTION: ClassVar[str] = "No valid [ezpz_pluginz] or [tool.ezpz_pluginz] section found in specified path" + + name: str + include: list[Path] + site_customize: bool | None = Field(default=None) + + @property + def include_str_paths(self) -> list[str]: + return [str(path) for path in self.include] + + @staticmethod + def from_toml_path(path: Path) -> "EzpzPluginConfig": + try: + with path.open("r") as f: + data = toml.load(f) + toml_data = EzpzPluginToml(**data) + + if path.name == EZPZ_TOML_FILENAME and toml_data.ezpz_pluginz: + return toml_data.ezpz_pluginz + if path.name == "pyproject.toml" and toml_data.tool and "ezpz" in toml_data.tool: + return EzpzPluginConfig(**toml_data.tool["ezpz"]) + except Exception: + logger.exception(f"Error loading config from {path}") + raise + else: + raise ValueError(EzpzPluginConfig.INVALID_SECTION) + + @staticmethod + def get_plugins(project_toml_path: Path) -> dict[str, set["PolarsPluginMacroMetadataPD"]]: + ezpz_pluginz = EzpzPluginConfig.from_toml_path(project_toml_path) + return group_models_by_key(set(process_includes(ezpz_pluginz.include)), "polars_ns") + + +class EzpzPluginToml(BaseModel): + ezpz_pluginz: Optional[EzpzPluginConfig] = None + tool: Optional[dict[str, Any]] = None + + +def load_config(config_path: str | Path | None = None) -> Optional[EzpzPluginConfig]: + if config_path is None: + config_path = find_ezpz_toml() + if config_path is None: + logger.warning("Could not find ezpz.toml or pyproject.toml with [tool.ezpz_pluginz]") + return None + + config_path = Path(config_path) + if not config_path.exists(): + logger.error(f"Config file does not exist: {config_path}") + return None + + try: + return EzpzPluginConfig.from_toml_path(config_path) + except Exception: + logger.exception(f"Error loading config from {config_path}") + return None + + +def find_ezpz_toml(start_path: Path | None = None) -> Optional[Path]: + if start_path is None: + start_path = Path.cwd() + + current_dir = Path(start_path).resolve() + + for parent in [current_dir, *list(current_dir.parents)]: + config_file = parent / EZPZ_TOML_FILENAME + if config_file.exists(): + logger.debug(f"Found ezpz.toml at: {config_file}") + return config_file + + pyproject_file = parent / "pyproject.toml" + if pyproject_file.exists(): + try: + with pyproject_file.open("r") as f: + data = toml.load(f) + if data.get("tool", {}).get("ezpz"): + logger.debug(f"Found [tool.ezpz_pluginz] in pyproject.toml at: {pyproject_file}") + return pyproject_file + except Exception as e: + logger.debug(f"Error checking pyproject.toml at {pyproject_file}: {e}") + continue + + return None diff --git a/pluginz/icon.ico b/core/pluginz/icon.ico similarity index 100% rename from pluginz/icon.ico rename to core/pluginz/icon.ico diff --git a/pluginz/images/attr_type_hint_added.png b/core/pluginz/images/attr_type_hint_added.png similarity index 100% rename from pluginz/images/attr_type_hint_added.png rename to core/pluginz/images/attr_type_hint_added.png diff --git a/pluginz/images/attr_type_hint_import.png b/core/pluginz/images/attr_type_hint_import.png similarity index 100% rename from pluginz/images/attr_type_hint_import.png rename to core/pluginz/images/attr_type_hint_import.png diff --git a/pluginz/images/lockfile.png b/core/pluginz/images/lockfile.png similarity index 100% rename from pluginz/images/lockfile.png rename to core/pluginz/images/lockfile.png diff --git a/pluginz/pyproject.toml b/core/pluginz/pyproject.toml similarity index 77% rename from pluginz/pyproject.toml rename to core/pluginz/pyproject.toml index 1ddbe35..986f869 100644 --- a/pluginz/pyproject.toml +++ b/core/pluginz/pyproject.toml @@ -5,9 +5,11 @@ dependencies = [ "cached-property==2.0.1", "jinja2==3.1.6", "libcst==1.8.0", - "painlezz-macroz", - "pydantic==2.11.5", + "macroz", + "pydantic==2.11.7", + "pydantic[email]>=2.11.7", "pywatchman==3.0.0", + "structlog>=25.4.0", "toml==0.10.2", "typer==0.16.0", ] @@ -18,10 +20,10 @@ requires-python = ">=3.13,<3.14" version = "0.0.1" [tool.rye] -dev-dependencies = ["painlezz-macroz"] +dev-dependencies = ["macroz", "pip-audit>=2.9.0"] [project.scripts] -ezplugins = "ezpz_pluginz.__cli__:app" +ezpz = "ezpz_pluginz.__cli__:app" [build-system] build-backend = "hatchling.build" diff --git a/pluginz/ezpz_pluginz/templates/sitecustomize.py.j2 b/core/pluginz/templates/sitecustomize.py.j2 similarity index 55% rename from pluginz/ezpz_pluginz/templates/sitecustomize.py.j2 rename to core/pluginz/templates/sitecustomize.py.j2 index 9d981ab..4f7e9aa 100644 --- a/pluginz/ezpz_pluginz/templates/sitecustomize.py.j2 +++ b/core/pluginz/templates/sitecustomize.py.j2 @@ -1,12 +1,18 @@ import polars as pl try: + {% if imports %} {% for import_ in imports -%} {{ import_ }} {% endfor -%} - + {% endif %} + {% if registry %} {% for entry in registry -%} {{ entry }} {% endfor -%} + {% endif %} + {% if not imports and not registry %} + pass + {% endif %} except Exception as e: - print(e) + print(e) \ No newline at end of file diff --git a/pluginz/tests/__init__.py b/core/pluginz/tests/__init__.py similarity index 100% rename from pluginz/tests/__init__.py rename to core/pluginz/tests/__init__.py diff --git a/core/pluginz/tests/test_polars_plugin_collector.py b/core/pluginz/tests/test_polars_plugin_collector.py new file mode 100644 index 0000000..1d6b15e --- /dev/null +++ b/core/pluginz/tests/test_polars_plugin_collector.py @@ -0,0 +1,53 @@ +from pathlib import Path + +import libcst as cst +from hypothesis import ( + given, + strategies as st, +) + +from ezpz_pluginz.e_polars_namespace import EPolarsNS +from ezpz_pluginz.register_plugin_macro import PolarsPluginCollector + +identifier = st.from_regex(r"[a-zA-Z_][a-zA-Z0-9_]*", fullmatch=True) +filepath_strategy = st.builds(lambda parts: str(Path(*parts)), st.lists(identifier, min_size=1, max_size=5)) +root_dir_strategy = st.builds(lambda parts: str(Path(*parts)), st.lists(identifier, min_size=1, max_size=3)) +class_name_strategy = identifier + +namespace_name_strategy = st.sampled_from([ns.api_decorator for ns in EPolarsNS]) + + +def make_decorator(namespace_attr: str) -> cst.Decorator: + return cst.Decorator( + decorator=cst.Call( + func=cst.Attribute(value=cst.Name("pl"), attr=cst.Name(str(namespace_attr))), + args=[cst.Arg(value=cst.SimpleString(f'"{namespace_attr.split("_")[1]}_namespace"'))], + ) + ) + + +decorator_call_strategy = st.builds( + make_decorator, + namespace_name_strategy, +) + +class_def_strategy = st.builds( + lambda class_name, decorators: cst.ClassDef(name=cst.Name(class_name), body=cst.IndentedBlock(body=[]), decorators=decorators), + class_name_strategy, + st.lists(decorator_call_strategy, min_size=1, max_size=3), +) + + +@given(class_def=class_def_strategy) +def test_polars_plugin_collector(class_def: cst.ClassDef) -> None: + module = cst.Module(body=[class_def]) + collector = PolarsPluginCollector() + module.visit(collector) + + # Test should verify that plugins are collected correctly + if len(collector.macro_data) < 0: + raise ValueError("NEGATIVE_MACRO_DATA") + + +if __name__ == "__main__": + test_polars_plugin_collector() diff --git a/diesel.toml b/diesel.toml deleted file mode 100644 index 05b5f5a..0000000 --- a/diesel.toml +++ /dev/null @@ -1,8 +0,0 @@ -[print_schema.postings] -custom_type_derives = ["diesel::query_builder::QueryId"] -file = "api/src/table_name_here/schema.rs" -filter = { only_tables = ["table_name_here"] } - - -[migrations_directory] -dir = "migrations" diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..f7c14b8 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,7 @@ +# EZPZ Examples + +This Package contains example usages for different Polars plugins under EZPZ + +## Available Plugins + +1. `ezpz_ta` - Examples demonstrating how the [ezpz-rust-ti](https://github.com/Summit-Sailors/EZPZ/tree/main/ezpz-rust-ti) plugin works diff --git a/examples/ezpz_ta/standard.py b/examples/ezpz_ta/standard.py new file mode 100644 index 0000000..1853225 --- /dev/null +++ b/examples/ezpz_ta/standard.py @@ -0,0 +1,290 @@ +import time +import logging +import statistics +from typing import Unpack, Callable +from datetime import date, timedelta + +import polars as pl + +logger = logging.getLogger(__name__) + +# Thresholds for numerical accuracy comparison +TOLERANCE_MACHINE_PRECISION = 1e-10 +TOLERANCE_HIGH_ACCURACY = 1e-6 +TOLERANCE_MINOR_DIFFERENCE = 1e-3 + + +class InsufficientDataError(ValueError): ... + + +class BenchmarkResult: + def __init__(self, avg_time: float, min_time: float, max_time: float, std_dev: float) -> None: + self.avg_time = avg_time + self.min_time = min_time + self.max_time = max_time + self.std_dev = std_dev + + @property + def avg_time_ms(self) -> float: + return self.avg_time * 1000.0 + + +def sma_pure_python(prices: list[float], period: int) -> list[float]: + length = len(prices) + if period > length: + raise InsufficientDataError() + + result: list[float] = [] + + loop_max = length - period + 1 + for i in range(loop_max): + # The slice now starts from 'i' and goes for 'period' elements + window_sum = sum(prices[i : i + period]) + result.append(window_sum / period) + + return result + + +def sma_pure_python_optimized(prices: list[float], period: int) -> list[float]: + length = len(prices) + if period > length: + raise InsufficientDataError() + + result: list[float] = [] + + # the first SMA value + # The first window is from index 0 to period-1 + current_sum: float = sum(prices[0:period]) + result.append(current_sum / period) + + # Slide the window for subsequent values, starting from the next element after the first window + # The loop goes from 'period' up to 'length' + for i in range(period, length): + current_sum += prices[i] - prices[i - period] # Add new, subtract old + result.append(current_sum / period) + + return result + + +def benchmark_python_function( + func: Callable[[list[float], int], list[float]], *args: Unpack[tuple[list[float], int]], num_runs: int = 1000 +) -> tuple[BenchmarkResult, list[float] | None]: + times: list[float] = [] + result: list[float] | None = None + + for _ in range(num_runs): + start_time = time.perf_counter() + result = func(*args) + end_time = time.perf_counter() + times.append(end_time - start_time) + + benchmark_result = BenchmarkResult( + avg_time=statistics.mean(times), min_time=min(times), max_time=max(times), std_dev=statistics.stdev(times) if len(times) > 1 else 0.0 + ) + + return benchmark_result, result + + +def benchmark_rust_function( + func: Callable[[pl.LazyFrame, str, int], pl.Series], *args: Unpack[tuple[pl.LazyFrame, str, int]], num_runs: int = 1000 +) -> tuple[BenchmarkResult, pl.Series | None]: + times: list[float] = [] + result = None + + for _ in range(num_runs): + start_time = time.perf_counter() + result = func(*args) + end_time = time.perf_counter() + times.append(end_time - start_time) + + benchmark_result = BenchmarkResult( + avg_time=statistics.mean(times), min_time=min(times), max_time=max(times), std_dev=statistics.stdev(times) if len(times) > 1 else 0.0 + ) + + return benchmark_result, result + + +def create_test_data(num_points: int = 365) -> tuple[pl.LazyFrame, list[float]]: + start_date = date(2023, 1, 1) + end_date = start_date + timedelta(days=num_points - 1) + + # Create price data + close_prices = [100.5 + i * 0.1 for i in range(num_points)] + + _df = pl.select( + timestamp=pl.date_range(start=start_date, end=end_date, interval="1d", eager=True), + ).with_columns( + [ + pl.Series("open", [100 + i * 0.1 for i in range(num_points)]), + pl.Series("high", [101 + i * 0.1 for i in range(num_points)]), + pl.Series("low", [99 + i * 0.1 for i in range(num_points)]), + pl.Series("close", close_prices), + pl.Series("volume", [1000 + i * 10 for i in range(num_points)]), + ] + ) + + return _df.lazy(), close_prices + + +def compare_results_accuracy(first_result: list[float] | None, second_result: pl.Series | list[float] | None, title: str = "ACCURACY COMPARISON") -> None: + """Compare accuracy between Python and Rust implementations.""" + logger.info("=" * 50) + logger.info(title) + logger.info("=" * 50) + + second_result_list = list[float]() + + if isinstance(second_result, pl.Series): + second_result_list = second_result.to_list() + elif isinstance(second_result, list): + second_result_list = second_result + else: + raise TypeError("PANIC!") + + if first_result is not None and len(first_result) != len(second_result_list): + logger.error(f"Length mismatch: Python={len(first_result)}, Other={len(second_result_list)}") + return + + # Compare values + differences: list[float] = [] + max_diff = 0.0 + first_valid_idx = None + + if first_result is None: + raise ValueError("first_result_is_None") + + for i, (py_val, other_val) in enumerate(zip(first_result, second_result_list, strict=True)): + if first_valid_idx is None: + first_valid_idx = i + diff = abs(py_val - other_val) + differences.append(diff) + max_diff = max(max_diff, diff) + + if not differences: + logger.warning("No valid values to compare") + return + + avg_diff = statistics.mean(differences) + + logger.info(f"Values compared: {len(differences)}") + logger.info(f"Average difference: {avg_diff:.2e}") + logger.info(f"Maximum difference: {max_diff:.2e}") + + logger.info("\nSample value comparisons (last 5 valid values):") + sample_size = min(5, len(differences)) + start_idx_for_display = len(first_result) - sample_size + start_idx_for_display = max(start_idx_for_display, 0) + + for i in range(start_idx_for_display, len(first_result)): + py_val = first_result[i] + other_val = second_result_list[i] + if other_val is not None: + diff = abs(py_val - other_val) + logger.info(f"Index {i}: Python={py_val:.8f}, Other={other_val:.8f}, Diff={diff:.2e}") + + # Accuracy assessment + if max_diff < TOLERANCE_MACHINE_PRECISION: + logger.info("✓ Results are numerically identical (within machine precision)") + elif max_diff < TOLERANCE_HIGH_ACCURACY: + logger.info("✓ Results are highly accurate (sub-microsecond differences)") + elif max_diff < TOLERANCE_MINOR_DIFFERENCE: + logger.info("⚠️ Results have minor differences (sub-millisecond)") + else: + logger.error("✗ Results have significant differences") + + logger.info("") + + +def main() -> None: # noqa: PLR0915 + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + period = 20 + num_runs = 1000 + + dataset_sizes = [365, 10_000, 100_000, 1_000_000] + + for size in dataset_sizes: + logger.info(f"--- Benchmarks for {size:,} data points ---") + logger.info("=" * 50) + + lf, prices = create_test_data(size) + + logger.info(f"Data points: {size:,}") + logger.info(f"SMA period: {period}") + logger.info(f"Benchmark runs: {num_runs}") + logger.info("") + + logger.info("Benchmarking Original Pure Python SMA...") + python_orig_benchmark, python_orig_result = benchmark_python_function(sma_pure_python, prices, period, num_runs=num_runs) + logger.info(f"Original Python avg: {python_orig_benchmark.avg_time_ms:.4f} ms") + + logger.info("Benchmarking Optimized Pure Python SMA...") + python_opt_benchmark, python_opt_result = benchmark_python_function(sma_pure_python_optimized, prices, period, num_runs=num_runs) + logger.info(f"Optimized Python avg: {python_opt_benchmark.avg_time_ms:.4f} ms") + + # Original Python vs Optimized Python (Accuracy Check) + compare_results_accuracy(python_orig_result, python_opt_result, title="ORIGINAL VS OPTIMIZED PYTHON ACCURACY") + + logger.info("Benchmarking Rust SMA...") + + def rust_sma_wrapper(lf: pl.LazyFrame, price_column: str, period: int) -> pl.Series: + return lf.standard_ti.sma_bulk(price_column, period) + + try: + rust_benchmark, rust_result = benchmark_rust_function( + rust_sma_wrapper, + lf, + "close", + period, + num_runs=num_runs, + ) + logger.info(f"Rust avg: {rust_benchmark.avg_time_ms:.4f} ms") + + # Python Results against Rust results (Accuracy Check) + compare_results_accuracy(python_opt_result, rust_result, title="OPTIMIZED PYTHON VS RUST ACCURACY") + + # --- Final Performance Comparison --- + logger.info("") + logger.info("=" * 50) + logger.info("PERFORMANCE RESULTS SUMMARY") + logger.info("=" * 50) + logger.info(f"Original Python: {python_orig_benchmark.avg_time_ms:.4f} ms") + logger.info(f"Optimized Python: {python_opt_benchmark.avg_time_ms:.4f} ms") + logger.info(f"Rust: {rust_benchmark.avg_time_ms:.4f} ms") + + logger.info("\n--- Speedup (vs. Original Python) ---") + if python_opt_benchmark.avg_time < python_orig_benchmark.avg_time: + speedup_opt = python_orig_benchmark.avg_time / python_opt_benchmark.avg_time + logger.info(f"✓ Optimized Python is {speedup_opt:.1f}x FASTER than Original Python") + else: + logger.info("⚠️ Optimized Python is not faster than Original Python (unlikely)") + + if rust_benchmark.avg_time < python_orig_benchmark.avg_time: + speedup_rust_vs_orig = python_orig_benchmark.avg_time / rust_benchmark.avg_time + logger.info(f"✓ Rust is {speedup_rust_vs_orig:.1f}x FASTER than Original Python") + else: + slowdown_rust_vs_orig = rust_benchmark.avg_time / python_orig_benchmark.avg_time + logger.info(f"⚠️ Rust is {slowdown_rust_vs_orig:.1f}x SLOWER than Original Python") + + logger.info("\n--- Speedup (vs. Optimized Python) ---") + if rust_benchmark.avg_time < python_opt_benchmark.avg_time: + speedup_rust_vs_opt = python_opt_benchmark.avg_time / rust_benchmark.avg_time + logger.info(f"✓ Rust is {speedup_rust_vs_opt:.1f}x FASTER than Optimized Python") + else: + slowdown_rust_vs_opt = rust_benchmark.avg_time / python_opt_benchmark.avg_time + logger.info(f"⚠️ Rust is {slowdown_rust_vs_opt:.1f}x SLOWER than Optimized Python") + logger.info(" (This suggests overhead in the Rust binding or small dataset size)") + + logger.info("") + + except AttributeError: + logger.exception("rust_ti extension not available - cannot benchmark Rust implementation") + logger.info("Install the rust_ti extension to compare with Rust performance") + break + except Exception: + logger.exception("Error benchmarking Rust implementation") + break + + +if __name__ == "__main__": + main() diff --git a/examples/ezpz_ta/volatility.py b/examples/ezpz_ta/volatility.py new file mode 100644 index 0000000..aa637fa --- /dev/null +++ b/examples/ezpz_ta/volatility.py @@ -0,0 +1,167 @@ +# ruff: noqa: NPY002, T201 +import numpy as np +import polars as pl + +DAYS_IN_JANUARY = 31 +DAYS_IN_JAN_FEB = 59 + + +def test_volatility_ti_plugin() -> None: # noqa: PLR0915 + np.random.seed(42) + n_periods = 100 + + base_price = 100.0 + returns = np.random.normal(0, 0.02, n_periods) + prices = [base_price] + + for ret in returns: + prices.append(prices[-1] * (1 + ret)) + + high_prices = [p * (1 + abs(np.random.normal(0, 0.01))) for p in prices[1:]] + low_prices = [p * (1 - abs(np.random.normal(0, 0.01))) for p in prices[1:]] + close_prices = prices[1:] + + _df = pl.DataFrame({"high": high_prices, "low": low_prices, "close": close_prices, "volume": np.random.randint(1000, 10000, n_periods)}) + + print("Sample data:") + print(_df.head()) + print(f"\nData shape: {_df.shape}") + + lf = _df.lazy() + + print("\n=== Testing Ulcer Index (Single) via Plugin ===") + try: + ulcer_single = lf.volatility_ti.ulcer_index_single("close") + print(f"Single Ulcer Index: {ulcer_single:.6f}") + except Exception as e: + print(f"Error in ulcer_index_single: {e}") + + print("\n=== Testing Ulcer Index (Bulk) via Plugin ===") + try: + ulcer_bulk_series = lf.volatility_ti.ulcer_index_bulk("close", period=14) + print(f"Ulcer Index Bulk Series type: {type(ulcer_bulk_series)}") + print(f"Series name: {ulcer_bulk_series.name}") + print(f"Series length: {len(ulcer_bulk_series)}") + print(f"First 10 values: {ulcer_bulk_series.head(10).to_list()}") + print(f"Last 10 values: {ulcer_bulk_series.tail(10).to_list()}") + except Exception as e: + print(f"Error in ulcer_index_bulk: {e}") + + print("\n=== Testing Volatility System via Plugin ===") + print("Skipping volatility_system test - no supported constant model types found") + + print("\n=== Testing Integration with Polars Operations ===") + try: + ulcer_series = lf.volatility_ti.ulcer_index_bulk("close", period=14) + original_df = lf.collect() + padding_length = len(original_df) - len(ulcer_series) + padded_ulcer = [None] * padding_length + ulcer_series.to_list() + result_df = original_df.with_columns(pl.Series("ulcer_index_14", padded_ulcer)) + + print("DataFrame with ulcer index:") + print(result_df.head()) + print(f"\nFinal DataFrame shape: {result_df.shape}") + print(f"Columns: {result_df.columns}") + + non_null_ulcer = result_df.filter(pl.col("ulcer_index_14").is_not_null()) + print(f"\nNon-null ulcer index values: {len(non_null_ulcer)}") + print(non_null_ulcer.head()) + + except Exception as e: + print(f"Error in integration test: {e}") + + print("\n=== Testing Error Handling ===") + try: + lf.volatility_ti.ulcer_index_single("invalid_column") + except Exception as e: + print(f"Expected error for invalid column: {e}") + + print("\n=== Performance Test ===") + large_n = 10000 + large_prices = [base_price] + large_returns = np.random.normal(0, 0.02, large_n) + + for ret in large_returns: + large_prices.append(large_prices[-1] * (1 + ret)) + + large_df = pl.DataFrame( + { + "high": [p * (1 + abs(np.random.normal(0, 0.01))) for p in large_prices[1:]], + "low": [p * (1 - abs(np.random.normal(0, 0.01))) for p in large_prices[1:]], + "close": large_prices[1:], + } + ) + + large_lf = large_df.lazy() + + try: + ulcer = large_lf.volatility_ti.ulcer_index_single("close") + print(f"Large dataset ({large_n} rows) Ulcer Index: {ulcer:.6f}") + except Exception as e: + print(f"Error with large dataset: {e}") + + +def test_chaining_operations() -> None: + print("\n=== Testing Method Chaining ===") + + np.random.seed(123) + n = 200 + base_price = 100.0 + returns = np.random.normal(0, 0.015, n) + prices = [base_price] + + for ret in returns: + prices.append(prices[-1] * (1 + ret)) + + timestamps = [ + f"2024-01-{i + 1:02d}" if i < DAYS_IN_JANUARY else f"2024-02-{i - 30:02d}" if i < DAYS_IN_JAN_FEB else f"2024-03-{i - 58:02d}" for i in range(n) + ] + + _df = pl.DataFrame( + { + "timestamp": timestamps, + "high": [p * (1 + abs(np.random.normal(0, 0.008))) for p in prices[1:]], + "low": [p * (1 - abs(np.random.normal(0, 0.008))) for p in prices[1:]], + "close": prices[1:], + "volume": np.random.randint(5000, 50000, n), + } + ) + + lf = _df.lazy() + + try: + ulcer_series = lf.volatility_ti.ulcer_index_bulk("close", period=20) + base_df = lf.collect() + padding_length = len(base_df) - len(ulcer_series) + padded_ulcer = [None] * padding_length + ulcer_series.to_list() + + result = ( + base_df.with_columns( + [ + pl.Series("ulcer_20", padded_ulcer), + pl.col("close").rolling_mean(window_size=20).alias("sma_20"), + pl.col("close").rolling_std(window_size=20).alias("std_20"), + (pl.col("close") / pl.col("close").shift(1) - 1).alias("returns"), + ] + ) + .filter(pl.col("ulcer_20").is_not_null()) + .select(["timestamp", "close", "ulcer_20", "sma_20", "std_20", "returns"]) + ) + + print("Chained operations result:") + print(result.head(10)) + print(f"\nResult shape: {result.shape}") + + print("\nUlcer Index 20 stats:") + print(f" Mean: {result['ulcer_20'].mean():.6f}") + print(f" Std: {result['ulcer_20'].std():.6f}") + print(f" Min: {result['ulcer_20'].min():.6f}") + print(f" Max: {result['ulcer_20'].max():.6f}") + + except Exception as e: + print(f"Error in chaining test: {e}") + + +if __name__ == "__main__": + test_volatility_ti_plugin() + test_chaining_operations() diff --git a/examples/pyproject.toml b/examples/pyproject.toml new file mode 100644 index 0000000..98106a5 --- /dev/null +++ b/examples/pyproject.toml @@ -0,0 +1,11 @@ +[project] +authors = [{ "name" = "Stephen Oketch" }] +dependencies = ["ezpz-pluginz", "polars==1.31.0", "pyarrow==20.0.0"] +description = "Examples showcasing use of the ezpz-rust-ti plugin" +name = "ezpz_ta" +readme = "README.md" +requires-python = ">=3.13,<3.14" +version = "0.0.1" + +[tool.rye] +dev-dependencies = ["pip-audit>=2.9.0"] diff --git a/ezpz.toml b/ezpz.toml index d9f4b4f..231be06 100644 --- a/ezpz.toml +++ b/ezpz.toml @@ -1,4 +1,4 @@ [ezpz_pluginz] -include = ["ezpz-guiz", "ezpz-pluginz"] +include = ["plugins/ezpz-rust-ti"] name = "ezpz" site_customize = true diff --git a/guiz/Cargo.toml b/guiz/Cargo.toml deleted file mode 100644 index 548173f..0000000 --- a/guiz/Cargo.toml +++ /dev/null @@ -1,31 +0,0 @@ -[package] -authors = { workspace = true } -description = { workspace = true } -edition = { workspace = true } -license = { workspace = true } -name = "ezpz-guiz" -repository = { workspace = true } -version = "0.0.1" - -[dependencies] -chrono = { workspace = true } -connectorx = { workspace = true } -ezpz-stubz = { workspace = true } -hashbrown = { workspace = true } -polars = { workspace = true } -pyo3 = { workspace = true } -pyo3-polars = { workspace = true } -pyo3-stub-gen = { workspace = true } -pyproject-toml = { workspace = true } -serde = { workspace = true } - -[lib] -crate-type = ["cdylib", "rlib"] -name = "ezpz_guiz" - -[features] -default = ["pyo3/extension-module"] - -[[bin]] -doc = false -name = "stub_gen" diff --git a/guiz/README.md b/guiz/README.md deleted file mode 100644 index e69de29..0000000 diff --git a/guiz/ezpz.toml b/guiz/ezpz.toml deleted file mode 100644 index d0344f3..0000000 --- a/guiz/ezpz.toml +++ /dev/null @@ -1,3 +0,0 @@ -[ezpz_pluginz] -include = ["python/ezpz_guiz"] -name = "ezpz-test" diff --git a/guiz/pyproject.toml b/guiz/pyproject.toml deleted file mode 100644 index b40eb7e..0000000 --- a/guiz/pyproject.toml +++ /dev/null @@ -1,20 +0,0 @@ -[project] -authors = [{ "name" = "Jeremy Meek" }] -dependencies = ["ezpz-pluginz", "polars==1.30.0", "pyarrow==20.0.0"] -description = "" -name = "ezpz_guiz" -readme = "README.md" -requires-python = ">=3.13,<3.14" -version = "0.0.1" - - -[build-system] -build-backend = "maturin" -requires = ["maturin>=1.0,<2.0"] - -[tool.maturin] -features = ["pyo3/extension-module"] -manifest-path = "Cargo.toml" -module-name = "ezpz_guiz._ezpz_guiz" -python-packages = ["ezpz_macroz"] -python-source = "python" diff --git a/guiz/python/ezpz_guiz/_ezpz_guiz.pyi b/guiz/python/ezpz_guiz/_ezpz_guiz.pyi deleted file mode 100644 index 548f9b4..0000000 --- a/guiz/python/ezpz_guiz/_ezpz_guiz.pyi +++ /dev/null @@ -1,13 +0,0 @@ -# This file is automatically generated by pyo3_stub_gen -# ruff: noqa: E501, F401 - -import polars - -class DataFrameViewer: - def __new__(cls, py_df:polars.DataFrame) -> DataFrameViewer: ... - def view(self) -> DataFrameViewer: ... - -class LazyFrameViewer: - def __new__(cls, py_lf:polars.LazyFrame) -> LazyFrameViewer: ... - def view(self) -> LazyFrameViewer: ... - diff --git a/guiz/python/ezpz_guiz/_ezpz_guiz_macros.py b/guiz/python/ezpz_guiz/_ezpz_guiz_macros.py deleted file mode 100644 index 44c73dc..0000000 --- a/guiz/python/ezpz_guiz/_ezpz_guiz_macros.py +++ /dev/null @@ -1,10 +0,0 @@ -from ezpz_guiz._ezpz_guiz import DataFrameViewer, LazyFrameViewer -from ezpz_pluginz.register_plugin_macro import ezpz_plugin_collect - -ezpz_plugin_collect(polars_ns="DataFrame", attr_name="viewer", import_="from ezpz_guiz import _ezpz_guiz", type_hint="_ezpz_guiz.DataFrameViewer")( - DataFrameViewer -) - -ezpz_plugin_collect( - polars_ns="LazyFrame", attr_name="ezprofiler", import_="from ezpz_pluginz.test_plugin import LazyPluginImpl", type_hint="_ezpz_guiz.LazyFrameProfileViewer" -)(LazyFrameViewer) diff --git a/guiz/src/frame/mod.rs b/guiz/src/frame/mod.rs deleted file mode 100644 index 023e7e4..0000000 --- a/guiz/src/frame/mod.rs +++ /dev/null @@ -1,33 +0,0 @@ -use { - ezpz_stubz::frame::PyDfStubbed, - polars::prelude::*, - pyo3::prelude::*, - pyo3_stub_gen::{ - define_stub_info_gatherer, - derive::{gen_stub_pyclass, gen_stub_pymethods}, - }, -}; - -#[gen_stub_pyclass] -#[pyclass] -#[derive(Clone)] -pub struct DataFrameViewer { - df: DataFrame, -} - -impl DataFrameViewer {} - -#[gen_stub_pymethods] -#[pymethods] -impl DataFrameViewer { - #[new] - fn new(py_df: PyDfStubbed) -> Self { - Self { df: py_df.0.into() } - } - - fn view(&self) -> Self { - Self { df: self.df.clone() } - } -} - -define_stub_info_gatherer!(stub_info); diff --git a/guiz/src/lazy/mod.rs b/guiz/src/lazy/mod.rs deleted file mode 100644 index f6e6d66..0000000 --- a/guiz/src/lazy/mod.rs +++ /dev/null @@ -1,32 +0,0 @@ -use { - ezpz_stubz::lazy::PyLfStubbed, - polars::prelude::*, - pyo3::{PyResult, pyclass, pymethods}, - pyo3_stub_gen::{ - define_stub_info_gatherer, - derive::{gen_stub_pyclass, gen_stub_pymethods}, - }, -}; - -#[gen_stub_pyclass] -#[pyclass] -#[derive(Clone)] -pub struct LazyFrameViewer { - lf: LazyFrame, -} - -#[gen_stub_pymethods] -#[pymethods] -impl LazyFrameViewer { - #[new] - pub fn new(py_lf: PyLfStubbed) -> PyResult { - Ok(Self { lf: py_lf.0.into() }) - } - - fn view(&self) -> Self { - let _ = self.lf.clone(); - self.clone() - } -} - -define_stub_info_gatherer!(stub_info); diff --git a/guiz/src/lib.rs b/guiz/src/lib.rs deleted file mode 100644 index c529282..0000000 --- a/guiz/src/lib.rs +++ /dev/null @@ -1,19 +0,0 @@ -use {pyo3::prelude::*, pyo3_stub_gen::define_stub_info_gatherer}; - -mod frame; - -use frame::DataFrameViewer; - -mod lazy; - -use lazy::LazyFrameViewer; - -#[pymodule] -#[pyo3(name = "_ezpz_guiz")] -fn _ezpz_guiz(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - Ok(()) -} - -define_stub_info_gatherer!(stub_info); diff --git a/guiz/t.ipynb b/guiz/t.ipynb deleted file mode 100644 index 769b105..0000000 --- a/guiz/t.ipynb +++ /dev/null @@ -1,81 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import polars as pl\n", - "\n", - "# Create a Polars DataFrame\n", - "df = pl.DataFrame({\n", - " \"Name\": [\"Alice\", \"Bob\", \"Charlie\", \"David\", \"Eva\"],\n", - " \"Age\": [25, 30, 35, 40, 22],\n", - " \"City\": [\"New York\", \"Los Angeles\", \"Chicago\", \"Houston\", \"Phoenix\"],\n", - " \"Salary\": [70000, 80000, 120000, 100000, 95000],\n", - " \"Is_Employed\": [True, False, True, True, False]\n", - "})" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "df.viewer.view(height=250,width=250,window_title=\"helo\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "ename": "ComputeError", - "evalue": "no data to time", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mComputeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mdf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlazy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprofile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/EazyPolarz/.venv/lib/python3.12/site-packages/polars/lazyframe/frame.py:1731\u001b[0m, in \u001b[0;36mLazyFrame.profile\u001b[0;34m(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, no_optimization, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, show_plot, truncate_nodes, figsize, streaming)\u001b[0m\n\u001b[1;32m 1716\u001b[0m cluster_with_columns \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 1718\u001b[0m ldf \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ldf\u001b[38;5;241m.\u001b[39moptimization_toggle(\n\u001b[1;32m 1719\u001b[0m type_coercion,\n\u001b[1;32m 1720\u001b[0m predicate_pushdown,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1729\u001b[0m new_streaming\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 1730\u001b[0m )\n\u001b[0;32m-> 1731\u001b[0m df, timings \u001b[38;5;241m=\u001b[39m \u001b[43mldf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprofile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1732\u001b[0m (df, timings) \u001b[38;5;241m=\u001b[39m wrap_df(df), wrap_df(timings)\n\u001b[1;32m 1734\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m show_plot:\n", - "\u001b[0;31mComputeError\u001b[0m: no data to time" - ] - } - ], - "source": [ - "df.lazy().profile()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.9" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/justfile b/justfile index aa4b293..b498762 100644 --- a/justfile +++ b/justfile @@ -1,33 +1,25 @@ set shell := ["bash", "-uc"] set export -set dotenv-load +set dotenv-load := true + +mod actions default: @just --choose --justfile {{justfile()}} -web: - #!/usr/bin/env bash - set -euo pipefail - dx serve --platform web -p app - -desktop: +stub-gen: #!/usr/bin/env bash set -euo pipefail - dx serve --platform desktop -p app + cargo run -p ezpz-rust-ti stub_gen -mobile: +examples: #!/usr/bin/env bash set -euo pipefail - dx serve --platform mobile -p app + rye run python3 examples/ezpz_ta/volatility.py clear: #!/usr/bin/env bash set -euo pipefail cargo clean - rm *.lock + rm -f *.lock rm -rf .venv - -stub-gen: - #!/usr/bin/env bash - set -euo pipefail - cargo run -p ezpz-guiz stub_gen diff --git a/macroz/README.md b/macroz/README.md deleted file mode 100644 index e69de29..0000000 diff --git a/plugins/ezpz-rust-ti/Cargo.toml b/plugins/ezpz-rust-ti/Cargo.toml new file mode 100644 index 0000000..c93aaeb --- /dev/null +++ b/plugins/ezpz-rust-ti/Cargo.toml @@ -0,0 +1,33 @@ +[package] +authors = ["Stephen Oketch"] +build = "build.rs" +description = "Rust technical indicators for polars (Wraps rust_ti crate)" +edition = { workspace = true } +name = "ezpz-rust-ti" +repository = { workspace = true } +version = "0.1.0" + +[lib] +crate-type = ["cdylib", "rlib"] +name = "ezpz_rust_ti" + +[dependencies] +anyhow = { version = "1.0.98", default-features = false } +approx = "0.5.1" +ezpz-stubz = { workspace = true } +polars = { workspace = true } +pyo3 = { workspace = true } +pyo3-polars = { workspace = true } +pyo3-stub-gen = { workspace = true } +rust_ti = "2.1.1" + +[build-dependencies] +pyo3-build-config = "0.25.1" + + +[features] +default = ["pyo3/auto-initialize", "pyo3/extension-module"] + +[[bin]] +doc = false +name = "stub_gen" diff --git a/plugins/ezpz-rust-ti/README.md b/plugins/ezpz-rust-ti/README.md new file mode 100644 index 0000000..ac868d1 --- /dev/null +++ b/plugins/ezpz-rust-ti/README.md @@ -0,0 +1,916 @@ +# EZPZ Technical Analysis Polars Plugin + +[![Rust](https://img.shields.io/badge/rust-1.88+-orange.svg)](https://rustlang.org) +[![Python](https://img.shields.io/badge/python-3.13+-blue.svg)](https://python.org) + +A technical analysis library for Polars, powered by Rust. Get 70+ technical indicators seamlessly integrated into your Polars workflow with full type safety and good. + +This plugin showcases how the [EZPZ](https://github.com/Summit-Sailors/EZPZ/tree/main/pluginz) plugins system works + +## Features + +- **Polars Native**: Seamlessly integrates with Polars DataFrames, LazyFrames and Series +- **70+ Indicators**: Comprehensive technical analysis toolkit +- **Type Safe**: Full type hints and IDE autocomplete support +- **Rust Powered**: Built on the [rust_ti](https://crates.io/crates/rust_ti) crate + +## Installation + +```bash +# Install EZPZ plugin system first +pip install ezpz_pluginz + +# Install technical analysis plugin +pip install ezpz-rust-ti + +# Mount the plugin +ezpz mount +``` + +## Quick Start + +```python +import numpy as np +import polars as pl + +# Random seed for reproducibility +np.random.seed(42) + +# Sample price data +n_periods = 100 +base_price = 100.0 +returns = np.random.normal(0, 0.02, n_periods) +prices = [base_price] +for ret in returns: + prices.append(prices[-1] * (1 + ret)) + +# Sample Polars DataFrame +df = pl.DataFrame({ + "high": [p * (1 + abs(np.random.normal(0, 0.01))) for p in prices[1:]], + "low": [p * (1 - abs(np.random.normal(0, 0.01))) for p in prices[1:]], + "close": prices[1:], + "volume": np.random.randint(1000, 10000, n_periods) +}) + +# Convert to LazyFrame for plugin operations +lf = df.lazy() + +# Calculate single Ulcer Index +ulcer_single = lf.volatility_ti.ulcer_index_single("close") +print(f"Single Ulcer Index: {ulcer_single:.6f}") + +# Calculate Ulcer Index series (bulk) with period=14 +ulcer_series = lf.volatility_ti.ulcer_index_bulk("close", period=14) + +# Integrate with Polars: Add Ulcer Index to DataFrame +result_df = ( + lf.collect() + .with_columns(pl.Series("ulcer_index_14", [None] * (len(df) - len(ulcer_series)) + ulcer_series.to_list())) + .select(["close", "ulcer_index_14"]) +) + +print("\nDataFrame with Ulcer Index:") +print(result_df.head(10)) +``` + +## Available Attributes + +### `basic_ti` - Basic Technical Indicators (Exposes methods from the BasicTI class) + +```python +class BasicTI: + def __new__(cls, lf: polars.LazyFrame) -> BasicTI: ... + def mean_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the arithmetic mean of all values. + """ + def median_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the median of all values. + """ + def mode_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the mode of all values. + """ + def variance_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the variance of all values. + """ + def standard_deviation_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the standard deviation of all values. + """ + def max_single(self, column: builtins.str) -> builtins.float: + r""" + Find the maximum value. + """ + def min_single(self, column: builtins.str) -> builtins.float: + r""" + Find the minimum value. + """ + def absolute_deviation_single(self, column: builtins.str, central_point: builtins.str) -> builtins.float: + r""" + Calculate the absolute deviation from a central point. + """ + def log_difference_single(self, price_t: builtins.float, price_t_1: builtins.float) -> builtins.float: + r""" + Calculate the logarithmic difference between two price points. + """ + def mean_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling mean over a specified period. + """ + def median_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling median over a specified period. + """ + def mode_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling mode over a specified period. + """ + def variance_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling variance over a specified period. + """ + def standard_deviation_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling standard deviation over a specified period. + """ + def absolute_deviation_bulk(self, column: builtins.str, period: builtins.int, central_point: builtins.str) -> polars.Series: + r""" + Calculate rolling absolute deviation over a specified period. + """ + def log_bulk(self, column: builtins.str) -> polars.Series: + r""" + Calculate natural logarithm of all values. + """ + def log_difference_bulk(self, column: builtins.str) -> polars.Series: + r""" + Calculate logarithmic differences between consecutive values. + """ +``` + +### `candle_ti` - Candle Pattern Analysis (Exposes methods from the CandleTI class) + +```python +class CandleTI: + def __new__(cls, lf: polars.LazyFrame) -> CandleTI: ... + def moving_constant_envelopes_single(self, price_column: builtins.str, constant_model_type: builtins.str, difference: builtins.float) -> polars.DataFrame: + r""" + Moving Constant Envelopes - Creates upper and lower bands from moving constant of price + """ + def mcginley_dynamic_envelopes_single( + self, price_column: builtins.str, difference: builtins.float, previous_mcginley_dynamic: builtins.float + ) -> polars.DataFrame: + r""" + McGinley Dynamic Envelopes - Variation of moving constant envelopes using McGinley Dynamic + """ + def moving_constant_bands_single( + self, price_column: builtins.str, constant_model_type: builtins.str, deviation_model: builtins.str, deviation_multiplier: builtins.float + ) -> polars.DataFrame: + r""" + Moving Constant Bands - Extended Bollinger Bands with configurable models + """ + def mcginley_dynamic_bands_single( + self, price_column: builtins.str, deviation_model: builtins.str, deviation_multiplier: builtins.float, previous_mcginley_dynamic: builtins.float + ) -> polars.DataFrame: + r""" + McGinley Dynamic Bands - Variation of moving constant bands using McGinley Dynamic + """ + def ichimoku_cloud_single( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + conversion_period: builtins.int, + base_period: builtins.int, + span_b_period: builtins.int, + ) -> polars.DataFrame: + r""" + Ichimoku Cloud - Calculates support and resistance levels + """ + def donchian_channels_single(self, high_column: builtins.str, low_column: builtins.str) -> polars.DataFrame: + r""" + Donchian Channels - Produces bands from period highs and lows + """ + def keltner_channel_single( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + constant_model_type: builtins.str, + atr_constant_model_type: builtins.str, + multiplier: builtins.float, + ) -> polars.DataFrame: + r""" + Keltner Channel - Bands based on moving average and average true range + """ + def supertrend_single( + self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, constant_model_type: builtins.str, multiplier: builtins.float + ) -> polars.Series: + r""" + Supertrend - Trend indicator showing support and resistance levels + """ + def moving_constant_envelopes_bulk( + self, price_column: builtins.str, constant_model_type: builtins.str, difference: builtins.float, period: builtins.int + ) -> polars.DataFrame: + r""" + Moving Constant Envelopes (Bulk) - Returns envelopes over time periods + """ + def mcginley_dynamic_envelopes_bulk( + self, price_column: builtins.str, difference: builtins.float, previous_mcginley_dynamic: builtins.float, period: builtins.int + ) -> polars.DataFrame: + r""" + Mcginley dynamic envelopes + """ + def moving_constant_bands_bulk( + self, + price_column: builtins.str, + constant_model_type: builtins.str, + deviation_model: builtins.str, + deviation_multiplier: builtins.float, + period: builtins.int, + ) -> polars.DataFrame: + r""" + Moving Constant Bands (Bulk) + """ + def mcginley_dynamic_bands_bulk( + self, + price_column: builtins.str, + deviation_model: builtins.str, + deviation_multiplier: builtins.float, + previous_mcginley_dynamic: builtins.float, + period: builtins.int, + ) -> polars.DataFrame: + r""" + McGinley Dynamic Bands (Bulk) + """ + def ichimoku_cloud_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + conversion_period: builtins.int, + base_period: builtins.int, + span_b_period: builtins.int, + ) -> polars.DataFrame: + r""" + Ichimoku Cloud (Bulk) - Returns ichimoku components over time + """ + def donchian_channels_bulk(self, high_column: builtins.str, low_column: builtins.str, period: builtins.int) -> polars.DataFrame: + r""" + Donchian Channels (Bulk) - Returns donchian bands over time + """ + def keltner_channel_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + constant_model_type: builtins.str, + atr_constant_model_type: builtins.str, + multiplier: builtins.float, + period: builtins.int, + ) -> polars.DataFrame: + r""" + Keltner Channel (Bulk) - Returns keltner bands over time + """ + def supertrend_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + constant_model_type: builtins.str, + multiplier: builtins.float, + period: builtins.int, + ) -> polars.Series: + r""" + Supertrend (Bulk) - Returns supertrend values over time + """ +``` + +### `chart_trends_ti` - Chart Trend Analysis (Exposes methods from the ChartTrendsTI class) + +```python +class ChartTrendsTI: + def __new__(cls, lf: polars.LazyFrame) -> ChartTrendsTI: ... + def peaks(self, price_column: builtins.str, period: builtins.int, closest_neighbor: builtins.int) -> builtins.list[tuple[builtins.float, builtins.int]]: + r""" + Find peaks in a price series over a given period + """ + def valleys(self, price_column: builtins.str, period: builtins.int, closest_neighbor: builtins.int) -> builtins.list[tuple[builtins.float, builtins.int]]: + r""" + Find valleys in a price series over a given period + """ + def peak_trend(self, price_column: builtins.str, period: builtins.int) -> tuple[builtins.float, builtins.float]: + r""" + Calculate peak trend (linear regression on peaks) + """ + def valley_trend(self, price_column: builtins.str, period: builtins.int) -> tuple[builtins.float, builtins.float]: + r""" + Calculate valley trend (linear regression on valleys) + """ + def overall_trend(self, price_column: builtins.str) -> tuple[builtins.float, builtins.float]: + r""" + Calculate overall trend (linear regression on all prices) + """ + def break_down_trends( + self, + price_column: builtins.str, + max_outliers: builtins.int, + soft_r_squared_minimum: builtins.float, + soft_r_squared_maximum: builtins.float, + hard_r_squared_minimum: builtins.float, + hard_r_squared_maximum: builtins.float, + soft_standard_error_multiplier: builtins.float, + hard_standard_error_multiplier: builtins.float, + soft_reduced_chi_squared_multiplier: builtins.float, + hard_reduced_chi_squared_multiplier: builtins.float, + ) -> builtins.list[tuple[builtins.int, builtins.int, builtins.float, builtins.float]]: + r""" + Break down trends in a price series + """ +``` + +### `correlation_ti` - Correlation Analysis (Exposes methods from the CorrelationTI class) + +```python +class CorrelationTI: + def __new__(cls, lf: polars.LazyFrame) -> CorrelationTI: ... + def correlate_asset_prices_single( + self, price_column_a: builtins.str, price_column_b: builtins.str, constant_model_type: builtins.str, deviation_model: builtins.str + ) -> builtins.float: + r""" + Correlation between two assets - Single value calculation + Calculates correlation between prices of two assets using specified models + Returns a single correlation value for the entire price series + """ + def correlate_asset_prices_bulk( + self, price_column_a: builtins.str, price_column_b: builtins.str, constant_model_type: builtins.str, deviation_model: builtins.str, period: builtins.int + ) -> polars.Series: + r""" + Correlation between two assets - Rolling/Bulk calculation + Calculates rolling correlation between prices of two assets using specified models + Returns a series of correlation values for each period window + """ +``` + +### `ma_ti` - Moving Averages (Exposes methods from the MATI class) + +```python +class MATI: + def __new__(cls, lf: polars.LazyFrame) -> MATI: ... + def moving_average_single(self, price_column: builtins.str, moving_average_type: builtins.str) -> builtins.float: + r""" + Moving Average (Single) - Calculates a single moving average value for a series of prices + """ + def moving_average_bulk(self, price_column: builtins.str, moving_average_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Moving Average (Bulk) - Calculates moving averages over a rolling window + """ + def mcginley_dynamic_single(self, price_column: builtins.str, previous_mcginley_dynamic: builtins.float, period: builtins.int) -> builtins.float: + r""" + McGinley Dynamic (Single) - Calculates a single McGinley Dynamic value + """ + def mcginley_dynamic_bulk(self, price_column: builtins.str, previous_mcginley_dynamic: builtins.float, period: builtins.int) -> polars.Series: + r""" + McGinley Dynamic (Bulk) - Calculates McGinley Dynamic values over a series + """ + def personalised_moving_average_single( + self, price_column: builtins.str, alpha_nominator: builtins.float, alpha_denominator: builtins.float + ) -> builtins.float: + r""" + Personalised Moving Average (Single) - Calculates a single personalised moving average + """ + def personalised_moving_average_bulk( + self, price_column: builtins.str, alpha_nominator: builtins.float, alpha_denominator: builtins.float, period: builtins.int + ) -> polars.Series: + r""" + Personalised Moving Average (Bulk) - Calculates personalised moving averages over a rolling window + """ +``` + +### `momentum_ti` - Momentum Indicators (Exposes methods from the MomentumTI class) + +```python +class MomentumTI: + def __new__(cls, lf: polars.LazyFrame) -> MomentumTI: ... + def aroon_up_single(self, high_column: builtins.str) -> builtins.float: + r""" + Aroon Up indicator + """ + def aroon_down_single(self, low_column: builtins.str) -> builtins.float: + r""" + Aroon Down indicator + + Calculates the Aroon Down indicator, which measures the time since the lowest low + within a given period as a percentage. + """ + def aroon_oscillator_single(self, aroon_up: builtins.float, aroon_down: builtins.float) -> builtins.float: + r""" + Aroon Oscillator + + Calculates the Aroon Oscillator by subtracting Aroon Down from Aroon Up. + Values range from -100 to +100, indicating trend strength and direction. + """ + def aroon_indicator_single(self, high_column: builtins.str, low_column: builtins.str) -> tuple[builtins.float, builtins.float, builtins.float]: + r""" + Aroon Indicator (complete calculation) + + Calculates all three Aroon components: Aroon Up, Aroon Down, and Aroon Oscillator + in a single function call. + """ + def long_parabolic_time_price_system_single( + self, previous_sar: builtins.float, extreme_point: builtins.float, acceleration_factor: builtins.float, low: builtins.float + ) -> builtins.float: + r""" + Long Parabolic Time Price System (Parabolic SAR for long positions) + + Calculates the Parabolic SAR (Stop and Reverse) for long positions, used to determine + potential reversal points in price movement. + """ + def short_parabolic_time_price_system_single( + self, previous_sar: builtins.float, extreme_point: builtins.float, acceleration_factor: builtins.float, high: builtins.float + ) -> builtins.float: + r""" + Short Parabolic Time Price System (Parabolic SAR for short positions) + + Calculates the Parabolic SAR (Stop and Reverse) for short positions, used to determine + potential reversal points in price movement. + """ + def volume_price_trend_single( + self, price_column: builtins.str, previous_price: builtins.float, volume: builtins.float, previous_volume_price_trend: builtins.float + ) -> builtins.float: + r""" + Volume Price Trend + + Calculates the Volume Price Trend indicator, which combines price and volume + to show the relationship between volume and price changes. + """ + def true_strength_index_single( + self, price_column: builtins.str, first_constant_model: builtins.str, first_period: builtins.int, second_constant_model: builtins.str + ) -> builtins.float: + r""" + True Strength Index + + Calculates the True Strength Index, a momentum oscillator that uses price changes + smoothed by two exponential moving averages. + """ + def relative_strength_index_bulk(self, price_column: builtins.str, constant_model_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Relative Strength Index (RSI) - bulk calculation + + Calculates RSI values for an entire series of prices. RSI measures the speed and change + of price movements, oscillating between 0 and 100. + """ + def stochastic_oscillator_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Stochastic Oscillator - bulk calculation + + Calculates the Stochastic Oscillator, which compares a security's closing price + to its price range over a given time period. + """ + def slow_stochastic_bulk(self, stochastic_column: builtins.str, constant_model_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Slow Stochastic - bulk calculation + + Calculates the Slow Stochastic by smoothing the regular Stochastic Oscillator + to reduce noise and false signals. + """ + def slowest_stochastic_bulk(self, slow_stochastic_column: builtins.str, constant_model_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Slowest Stochastic - bulk calculation + + Calculates the Slowest Stochastic by applying additional smoothing to the Slow Stochastic + for even more noise reduction. + """ + def williams_percent_r_bulk(self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Williams %R - bulk calculation + + Calculates Williams %R, a momentum indicator that measures overbought and oversold levels. + Values range from -100 to 0, where -20 and above indicates overbought, -80 and below indicates oversold. + """ + def money_flow_index_bulk(self, price_column: builtins.str, volume_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Money Flow Index - bulk calculation + + Calculates the Money Flow Index, a volume-weighted RSI that measures buying and selling pressure. + Values range from 0 to 100, where >80 indicates overbought and <20 indicates oversold. + """ + def rate_of_change_bulk(self, price_column: builtins.str) -> polars.Series: + r""" + Rate of Change - bulk calculation + + Calculates the Rate of Change, which measures the percentage change in price + from one period to the next. + """ + def on_balance_volume_bulk(self, price_column: builtins.str, volume_column: builtins.str, previous_obv: builtins.float) -> polars.Series: + r""" + On Balance Volume (Bulk) - Calculates cumulative volume indicator + Adds volume on up days and subtracts volume on down days to measure buying and selling pressure + """ + def commodity_channel_index_bulk( + self, + price_column: builtins.str, + constant_model_type: builtins.str, + deviation_model: builtins.str, + constant_multiplier: builtins.float, + period: builtins.int, + ) -> polars.Series: + r""" + Commodity Channel Index (Bulk) - Calculates CCI over rolling periods + Measures the variation of a security's price from its statistical mean + Values typically range from -100 to +100 + """ + def mcginley_dynamic_commodity_channel_index_bulk( + self, + price_column: builtins.str, + previous_mcginley_dynamic: builtins.float, + deviation_model: builtins.str, + constant_multiplier: builtins.float, + period: builtins.int, + ) -> tuple[polars.Series, polars.Series]: + r""" + McGinley Dynamic Commodity Channel Index (Bulk) - CCI using McGinley Dynamic MA + Uses McGinley Dynamic as the moving average, which adapts to market conditions + better than traditional moving averages + """ + def macd_line_bulk( + self, price_column: builtins.str, short_period: builtins.int, short_period_model: builtins.str, long_period: builtins.int, long_period_model: builtins.str + ) -> polars.Series: + r""" + MACD Line (Bulk) - Calculates Moving Average Convergence Divergence line + Subtracts the long-period moving average from the short-period moving average + """ + def signal_line_bulk(self, macd_column: builtins.str, constant_model_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Signal Line (Bulk) - Calculates MACD Signal Line + Applies a moving average to the MACD line for generating buy/sell signals + """ + def mcginley_dynamic_macd_line_bulk( + self, + price_column: builtins.str, + short_period: builtins.int, + previous_short_mcginley: builtins.float, + long_period: builtins.int, + previous_long_mcginley: builtins.float, + ) -> polars.DataFrame: + r""" + McGinley Dynamic MACD Line (Bulk) - MACD using McGinley Dynamic moving averages + Provides better adaptation to market volatility and reduces lag compared to traditional MACD + """ + def chaikin_oscillator_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + volume_column: builtins.str, + short_period: builtins.int, + long_period: builtins.int, + previous_accumulation_distribution: builtins.float, + short_period_model: builtins.str, + long_period_model: builtins.str, + ) -> tuple[polars.Series, polars.Series]: + r""" + Chaikin Oscillator (Bulk) - Applies MACD to Accumulation/Distribution line + Measures the momentum of the Accumulation/Distribution line + """ + def percentage_price_oscillator_bulk( + self, price_column: builtins.str, short_period: builtins.int, long_period: builtins.int, constant_model_type: builtins.str + ) -> polars.Series: + r""" + Percentage Price Oscillator (Bulk) - MACD expressed as percentage + Similar to MACD but expressed as a percentage for easier comparison across securities + """ + def chande_momentum_oscillator_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Chande Momentum Oscillator (Bulk) - Measures momentum using gains and losses + Calculates the difference between sum of gains and losses over a period + Values range from -100 to +100 + """ +``` + +### `other_ti` - Other Technical Indicators (Exposes methods from the OtherTI class) + +```python +class OtherTI: + r""" + Other Technical Indicators - A collection of other analysis functions for financial data + """ + def __new__(cls, lf: polars.LazyFrame) -> OtherTI: ... + def return_on_investment_single(self, price_column: builtins.str, investment: builtins.float) -> tuple[builtins.float, builtins.float]: + r""" + Return on Investment - Calculates investment value and percentage change for a single period + Uses the first and last values from the price column as start and end prices + """ + def return_on_investment_bulk(self, price_column: builtins.str, investment: builtins.float) -> tuple[polars.Series, polars.Series]: + r""" + Return on Investment Bulk - Calculates ROI for a series of consecutive price periods + Uses the price column as price values for consecutive period calculations + """ + def true_range(self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str) -> polars.Series: + r""" + True Range - Calculates the greatest price movement for a single period + Uses the provided high/low/close columns to calculate true range + """ + def average_true_range_single( + self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, constant_model_type: builtins.str + ) -> builtins.float: + r""" + Average True Range - Calculates the moving average of true range values for a single result + Uses the provided high/low/close columns to calculate ATR from the entire price series + """ + def average_true_range_bulk( + self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, constant_model_type: builtins.str, period: builtins.int + ) -> polars.Series: + r""" + Average True Range Bulk - Calculates rolling ATR values over specified periods + Uses the provided high/low/close columns for rolling ATR calculations + """ + def internal_bar_strength(self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str) -> polars.Series: + r""" + Internal Bar Strength - Calculates buy/sell oscillator based on close position within high-low range + Uses the provided high/low/close columns to calculate IBS values + """ + def positivity_indicator( + self, open_column: builtins.str, close_column: builtins.str, signal_period: builtins.int, constant_model_type: builtins.str + ) -> tuple[polars.Series, polars.Series]: + r""" + Positivity Indicator - Generates trading signals based on open vs previous close comparison + Uses the provided open/close columns for signal generation + """ +``` + +### `std_ti` - Standard Technical Indicators (Exposes methods from the StandardTI class) + +```python +class StandardTI: + def __new__(cls, lf: polars.LazyFrame) -> StandardTI: ... + def sma_single(self, price_column: builtins.str) -> builtins.float: + r""" + Simple Moving Average (Single) - calculates the mean of all values in the column + """ + def sma_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Simple Moving Average (Bulk) - calculates the mean over a rolling window + """ + def smma_single(self, price_column: builtins.str) -> builtins.float: + r""" + Smoothed Moving Average (Single) - single value calculation + """ + def smma_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Smoothed Moving Average (Bulk) - puts more weight on recent prices + """ + def ema_single(self, price_column: builtins.str) -> builtins.float: + r""" + Exponential Moving Average (Single) - single value calculation + """ + def ema_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Exponential Moving Average (Bulk) - puts exponentially more weight on recent prices + """ + def bollinger_bands_single(self, price_column: builtins.str) -> tuple[builtins.float, builtins.float, builtins.float]: + r""" + Bollinger Bands (Single) - single value calculation (requires exactly 20 periods) + """ + def bollinger_bands_bulk(self, price_column: builtins.str) -> polars.DataFrame: + r""" + Bollinger Bands (Bulk) - returns three series: lower band, middle (SMA), upper band + Standard period is 20 with 2 standard deviations + """ + def macd_single(self, price_column: builtins.str) -> tuple[builtins.float, builtins.float, builtins.float]: + r""" + MACD (Single) - single value calculation (requires exactly 34 periods) + """ + def macd_bulk(self, price_column: builtins.str) -> polars.DataFrame: + r""" + MACD (Bulk) - Moving Average Convergence Divergence + Returns three series: MACD line, Signal line, Histogram + Standard periods: 12, 26, 9 + """ + def rsi_single(self, price_column: builtins.str) -> builtins.float: + r""" + RSI (Single) - single value calculation (requires exactly 14 periods) + """ + def rsi_bulk(self, price_column: builtins.str) -> polars.Series: + r""" + RSI (Bulk) - Relative Strength Index + Standard period is 14 using smoothed moving average + """ +``` + +### `strength_ti` - Strength Indicators (Exposes methods from the StrengthTI class) + +```python +class StrengthTI: + def __new__(cls, lf: polars.LazyFrame) -> StrengthTI: ... + def accumulation_distribution_single( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + volume_column: builtins.str, + previous_ad: typing.Optional[builtins.float], + ) -> builtins.float: + r""" + Accumulation Distribution (Single) - Shows whether the stock is being accumulated or distributed + Single value calculation using the last available values + """ + def accumulation_distribution_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + volume_column: builtins.str, + previous_ad: typing.Optional[builtins.float], + ) -> polars.Series: + r""" + Accumulation Distribution (Bulk) - Shows whether the stock is being accumulated or distributed + Returns a series of accumulation/distribution values + """ + def positive_volume_index_single( + self, close_column: builtins.str, volume_column: builtins.str, previous_pvi: typing.Optional[builtins.float] + ) -> builtins.float: + r""" + Positive Volume Index (Single) - Measures volume trend strength when volume increases + Single value calculation using the last available values + """ + def positive_volume_index_bulk(self, close_column: builtins.str, volume_column: builtins.str, previous_pvi: typing.Optional[builtins.float]) -> polars.Series: + r""" + Positive Volume Index (Bulk) - Measures volume trend strength when volume increases + Returns a series of positive volume index values + """ + def negative_volume_index_single( + self, close_column: builtins.str, volume_column: builtins.str, previous_nvi: typing.Optional[builtins.float] + ) -> builtins.float: + r""" + Negative Volume Index (Single) - Measures volume trend strength when volume decreases + Single value calculation using the last available values + """ + def negative_volume_index_bulk(self, close_column: builtins.str, volume_column: builtins.str, previous_nvi: typing.Optional[builtins.float]) -> polars.Series: + r""" + Negative Volume Index (Bulk) - Measures volume trend strength when volume decreases + Returns a series of negative volume index values + """ + def relative_vigor_index_single( + self, open_column: builtins.str, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, constant_model_type: builtins.str + ) -> builtins.float: + r""" + Relative Vigor Index (Single) - Measures the strength of an asset by looking at previous prices + Single value calculation using all available values + """ + def relative_vigor_index_bulk( + self, + open_column: builtins.str, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + constant_model_type: builtins.str, + period: builtins.int, + ) -> polars.Series: + r""" + Relative Vigor Index (Bulk) - Measures the strength of an asset by looking at previous prices + Returns a series of relative vigor index values + """ +``` + +### `trend_ti` - Trend Indicators (Exposes methods from the TrendTI class) + +```python +class TrendTI: + r""" + Trend Technical Indicators - A collection of trend analysis functions for financial data + """ + def __new__(cls, lf: polars.LazyFrame) -> TrendTI: ... + def aroon_up_single(self, high_column: builtins.str) -> builtins.float: + r""" + Aroon Up (Single) - Measures the strength of upward price momentum + Calculates the percentage of time since the highest high within the series + """ + def aroon_down_single(self, low_column: builtins.str) -> builtins.float: + r""" + Aroon Down (Single) - Measures the strength of downward price momentum + Calculates the percentage of time since the lowest low within the series + """ + def aroon_oscillator_single(self, high_column: builtins.str, low_column: builtins.str) -> builtins.float: + r""" + Aroon Oscillator (Single) - Calculates the difference between Aroon Up and Aroon Down + Provides a single measure of trend direction and strength + """ + def aroon_indicator_single(self, high_column: builtins.str, low_column: builtins.str) -> tuple[builtins.float, builtins.float, builtins.float]: + r""" + Aroon Indicator (Single) - Calculates complete Aroon system in one call + Computes Aroon Up, Aroon Down, and Aroon Oscillator + """ + def true_strength_index_single( + self, price_column: builtins.str, first_constant_model: builtins.str, first_period: builtins.int, second_constant_model: builtins.str + ) -> builtins.float: + r""" + True Strength Index (Single) - Momentum oscillator using double-smoothed price changes + Filters out price noise to provide clearer momentum signals + """ + def aroon_up_bulk(self, high_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Aroon Up (Bulk) - Calculates rolling Aroon Up indicator over specified period + Measures upward momentum strength for each period in the time series + """ + def aroon_down_bulk(self, low_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Aroon Down (Bulk) - Calculates rolling Aroon Down indicator over specified period + Measures downward momentum strength for each period in the time series + """ + def aroon_oscillator_bulk(self, high_column: builtins.str, low_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Aroon Oscillator (Bulk) - Calculates rolling Aroon Oscillator over specified period + Computes the difference between Aroon Up and Aroon Down for each period + """ + def aroon_indicator_bulk(self, high_column: builtins.str, low_column: builtins.str, period: builtins.int) -> polars.DataFrame: + r""" + Aroon Indicator (Bulk) - Calculates complete Aroon system for time series data + Computes Aroon Up, Aroon Down, and Aroon Oscillator for each period + """ + def parabolic_time_price_system_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + acceleration_factor_start: builtins.float, + acceleration_factor_max: builtins.float, + acceleration_factor_step: builtins.float, + start_position: builtins.str, + previous_sar: builtins.float, + ) -> polars.Series: + r""" + Parabolic Time Price System (Bulk) - Calculates Stop and Reverse points + Provides trailing stop levels for trend-following system + """ + def directional_movement_system_bulk( + self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, period: builtins.int, constant_model_type: builtins.str + ) -> polars.DataFrame: + r""" + Directional Movement System (Bulk) - Calculates complete DMS indicators + Computes +DI, -DI, ADX, and ADXR for trend strength analysis + """ + def volume_price_trend_bulk(self, price_column: builtins.str, volume_column: builtins.str, previous_volume_price_trend: builtins.float) -> polars.Series: + r""" + Volume Price Trend (Bulk) - Combines price and volume to show momentum + Shows the relationship between price movement and volume flow + """ + def true_strength_index_bulk( + self, + price_column: builtins.str, + first_constant_model: builtins.str, + first_period: builtins.int, + second_constant_model: builtins.str, + second_period: builtins.int, + ) -> polars.Series: + r""" + True Strength Index (Bulk) - Double-smoothed momentum oscillator + Uses double-smoothed price changes to filter noise and provide clearer signals + """ +``` + +### `volatility_ti` - Volatility Indicators (Exposes methods from the VolatilityTI class) + +```python +class VolatilityTI: + def __new__(cls, lf: polars.LazyFrame) -> VolatilityTI: ... + def ulcer_index_single(self, price_column: builtins.str) -> builtins.float: + r""" + Ulcer Index (Single) - Calculates how quickly the price is able to get back to its former high + Can be used instead of standard deviation for volatility measurement + """ + def ulcer_index_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Ulcer Index (Bulk) - Calculates rolling Ulcer Index over specified period + Returns a series of Ulcer Index values + """ + def volatility_system( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + period: builtins.int, + constant_multiplier: builtins.float, + constant_model_type: builtins.str, + ) -> polars.Series: + r""" + Volatility System - Calculates Welles volatility system with Stop and Reverse (SaR) points + Uses trend analysis to determine long/short positions and calculate SaR levels + Constant multiplier typically between 2.8-3.1 (Welles used 3.0) + """ +``` + +## Note + +For more detailed API documentation, view the [stub_file](python/ezpz_rust_ti/_ezpz_rust_ti.pyi) + +## Contributing + +We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for details. + +## Acknowledgments + +- [Polars](https://pola.rs/) - The amazing DataFrame library +- [PyO3](https://pyo3.rs/) - Rust bindings for Python +- [rust_ti](https://crates.io/crates/rust_ti) - Technical analysis algorithms diff --git a/plugins/ezpz-rust-ti/build.rs b/plugins/ezpz-rust-ti/build.rs new file mode 100644 index 0000000..c1f6018 --- /dev/null +++ b/plugins/ezpz-rust-ti/build.rs @@ -0,0 +1,3 @@ +fn main() { + pyo3_build_config::add_extension_module_link_args(); +} diff --git a/plugins/ezpz-rust-ti/ezpz.toml b/plugins/ezpz-rust-ti/ezpz.toml new file mode 100644 index 0000000..c3a5c5b --- /dev/null +++ b/plugins/ezpz-rust-ti/ezpz.toml @@ -0,0 +1,3 @@ +[ezpz_pluginz] +include = ["python/ezpz_rust_ti"] +name = "ez-rust-ti-test" diff --git a/plugins/ezpz-rust-ti/pyproject.toml b/plugins/ezpz-rust-ti/pyproject.toml new file mode 100644 index 0000000..ab9027e --- /dev/null +++ b/plugins/ezpz-rust-ti/pyproject.toml @@ -0,0 +1,26 @@ +[project] +authors = [{ "name" = "Stephen Oketch" }] +dependencies = ["ezpz-pluginz", "polars==1.31.0", "pyarrow==20.0.0"] +description = "Technical Indicators for Polars using RustTI" +name = "ezpz-rust-ti" +readme = "README.md" +requires-python = ">=3.13,<3.14" +version = "0.0.1" + +[build-system] +build-backend = "maturin" +requires = ["maturin>=1.0,<2.0"] + +[tool.maturin] +features = ["pyo3/extension-module"] +manifest-path = "Cargo.toml" +module-name = "ezpz_rust_ti._ezpz_rust_ti" +python-packages = ["ezpz_rust_ti._ezpz_rust_ti"] +python-source = "python" + +[tool.rye] +dev-dependencies = ["pip-audit>=2.9.0"] + + +[project.entry-points."ezpz.plugins"] +ezpz-rust-ti = "ezpz_rust_ti:register_plugin" diff --git a/plugins/ezpz-rust-ti/python/ezpz_rust_ti/__init__.py b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/__init__.py new file mode 100644 index 0000000..9e96d2d --- /dev/null +++ b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/__init__.py @@ -0,0 +1,27 @@ +from typing import TYPE_CHECKING, cast + +from ezpz_pluginz.registry.models import PluginMetadata, PluginMetadataInner + +if TYPE_CHECKING: + from pydantic import HttpUrl + + +def register_plugin() -> PluginMetadata: + return PluginMetadata( + name="rust-ti", + package_name="ezpz-rust-ti", + description="Rust-powered technical analysis indicators for Polars LazyFrame", + aliases=["ta", "technical-analysis", "indicators"], + version="0.1.0", + author="Summit Sailors", + category="Technical analysis", + homepage=cast("HttpUrl", "https://github.com/Summit-Sailors/EZPZ/tree/main/ezpz-rust-ti"), + metadata_=PluginMetadataInner( + tags=["polars", "indicators", "plugins"], + license="MIT", + python_version=">=3.13", + dependencies=["ezpz-pluginz", "polars==1.31.0", "pyarrow==20.0.0"], + documentation=cast("HttpUrl", "https://github.com/Summit-Sailors/EZPZ/blob/main/ezpz-rust-ti/README.md"), + support_email="oketchs702@gmail.com", + ), + ) diff --git a/plugins/ezpz-rust-ti/python/ezpz_rust_ti/_ezpz_rust_ti.pyi b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/_ezpz_rust_ti.pyi new file mode 100644 index 0000000..7ff3e5e --- /dev/null +++ b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/_ezpz_rust_ti.pyi @@ -0,0 +1,1789 @@ +# This file is automatically generated by pyo3_stub_gen + +import typing +import builtins + +import polars + +class BasicTI: + def __new__(cls, lf: polars.LazyFrame) -> BasicTI: ... + def mean_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the arithmetic mean of all values. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + f64 - The arithmetic mean + """ + def median_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the median of all values. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + f64 - The median value + """ + def mode_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the mode of all values. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + f64 - The most frequently occurring value + """ + def variance_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the variance of all values. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + f64 - The variance + """ + def standard_deviation_single(self, column: builtins.str) -> builtins.float: + r""" + Calculate the standard deviation of all values. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + f64 - The standard deviation + """ + def max_single(self, column: builtins.str) -> builtins.float: + r""" + Find the maximum value. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + f64 - The maximum value + """ + def min_single(self, column: builtins.str) -> builtins.float: + r""" + Find the minimum value. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + f64 - The minimum value + """ + def absolute_deviation_single(self, column: builtins.str, central_point: builtins.str) -> builtins.float: + r""" + Calculate the absolute deviation from a central point. + + # Parameters + - `column`: &str - Name of the column to analyze + - `central_point`: &str - Central point type ("mean", "median", etc.) + + # Returns + f64 - The absolute deviation + """ + def log_difference_single(self, price_t: builtins.float, price_t_1: builtins.float) -> builtins.float: + r""" + Calculate the logarithmic difference between two price points. + + # Parameters + - `price_t`: f64 - Current price value + - `price_t_1`: f64 - Previous price value + + # Returns + f64 - The logarithmic difference + """ + def mean_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling mean over a specified period. + + # Parameters + - `column`: &str - Name of the column to analyze + - `period`: usize - Rolling window size + + # Returns + PySeriesStubbed - Series containing rolling mean values + """ + def median_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling median over a specified period. + + # Parameters + - `column`: &str - Name of the column to analyze + - `period`: usize - Rolling window size + + # Returns + PySeriesStubbed - Series containing rolling median values + """ + def mode_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling mode over a specified period. + + # Parameters + - `column`: &str - Name of the column to analyze + - `period`: usize - Rolling window size + + # Returns + PySeriesStubbed - Series containing rolling mode values + """ + def variance_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling variance over a specified period. + + # Parameters + - `column`: &str - Name of the column to analyze + - `period`: usize - Rolling window size + + # Returns + PySeriesStubbed - Series containing rolling variance values + """ + def standard_deviation_bulk(self, column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Calculate rolling standard deviation over a specified period. + + # Parameters + - `column`: &str - Name of the column to analyze + - `period`: usize - Rolling window size + + # Returns + PySeriesStubbed - Series containing rolling standard deviation values + """ + def absolute_deviation_bulk(self, column: builtins.str, period: builtins.int, central_point: builtins.str) -> polars.Series: + r""" + Calculate rolling absolute deviation over a specified period. + + # Parameters + - `column`: &str - Name of the column to analyze + - `period`: usize - Rolling window size + - `central_point`: &str - Central point type ("mean", "median", etc.) + + # Returns + PySeriesStubbed - Series containing rolling absolute deviation values + """ + def log_bulk(self, column: builtins.str) -> polars.Series: + r""" + Calculate natural logarithm of all values. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + PySeriesStubbed - Series containing natural logarithm values + """ + def log_difference_bulk(self, column: builtins.str) -> polars.Series: + r""" + Calculate logarithmic differences between consecutive values. + + # Parameters + - `column`: &str - Name of the column to analyze + + # Returns + PySeriesStubbed - Series containing logarithmic difference values + """ + +class CandleTI: + def __new__(cls, lf: polars.LazyFrame) -> CandleTI: ... + def moving_constant_envelopes_single(self, price_column: builtins.str, constant_model_type: builtins.str, difference: builtins.float) -> polars.DataFrame: + r""" + Moving Constant Envelopes - Creates upper and lower bands from moving constant of price + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `constant_model_type`: &str - Type of moving average (e.g., "sma", "ema", "wma") + - `difference`: f64 - Fixed difference value to create envelope bands + + # Returns + DataFrame with columns: + - `lower_envelope`: f64 - Lower envelope band (middle - difference) + - `middle_envelope`: f64 - Middle line (moving average) + - `upper_envelope`: f64 - Upper envelope band (middle + difference) + """ + def mcginley_dynamic_envelopes_single( + self, price_column: builtins.str, difference: builtins.float, previous_mcginley_dynamic: builtins.float + ) -> polars.DataFrame: + r""" + McGinley Dynamic Envelopes - Variation of moving constant envelopes using McGinley Dynamic + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `difference`: f64 - Fixed difference value to create envelope bands + - `previous_mcginley_dynamic`: f64 - Previous McGinley Dynamic value for calculation + + # Returns + DataFrame with columns: + - `lower_envelope`: f64 - Lower envelope band (McGinley Dynamic - difference) + - `mcginley_dynamic`: f64 - McGinley Dynamic value + - `upper_envelope`: f64 - Upper envelope band (McGinley Dynamic + difference) + """ + def moving_constant_bands_single( + self, price_column: builtins.str, constant_model_type: builtins.str, deviation_model: builtins.str, deviation_multiplier: builtins.float + ) -> polars.DataFrame: + r""" + Moving Constant Bands - Extended Bollinger Bands with configurable models + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `constant_model_type`: &str - Type of moving average for center line (e.g., "sma", "ema", "wma") + - `deviation_model`: &str - Type of deviation calculation (e.g., "std", "mad") + - `deviation_multiplier`: f64 - Multiplier for the deviation to create bands + + # Returns + DataFrame with columns: + - `lower_band`: f64 - Lower band (moving average - deviation * multiplier) + - `middle_band`: f64 - Middle band (moving average) + - `upper_band`: f64 - Upper band (moving average + deviation * multiplier) + """ + def mcginley_dynamic_bands_single( + self, price_column: builtins.str, deviation_model: builtins.str, deviation_multiplier: builtins.float, previous_mcginley_dynamic: builtins.float + ) -> polars.DataFrame: + r""" + McGinley Dynamic Bands - Variation of moving constant bands using McGinley Dynamic + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `deviation_model`: &str - Type of deviation calculation (e.g., "std", "mad") + - `deviation_multiplier`: f64 - Multiplier for the deviation to create bands + - `previous_mcginley_dynamic`: f64 - Previous McGinley Dynamic value for calculation + + # Returns + DataFrame with columns: + - `lower_band`: f64 - Lower band (McGinley Dynamic - deviation * multiplier) + - `mcginley_dynamic`: f64 - McGinley Dynamic value + - `upper_band`: f64 - Upper band (McGinley Dynamic + deviation * multiplier) + """ + def ichimoku_cloud_single( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + conversion_period: builtins.int, + base_period: builtins.int, + span_b_period: builtins.int, + ) -> polars.DataFrame: + r""" + Ichimoku Cloud - Calculates support and resistance levels + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `conversion_period`: usize - Period for conversion line calculation (typically 9) + - `base_period`: usize - Period for base line calculation (typically 26) + - `span_b_period`: usize - Period for leading span B calculation (typically 52) + + # Returns + DataFrame with columns: + - `leading_span_a`: f64 - Leading Span A (future support/resistance) + - `leading_span_b`: f64 - Leading Span B (future support/resistance) + - `base_line`: f64 - Base Line (Kijun-sen) + - `conversion_line`: f64 - Conversion Line (Tenkan-sen) + - `lagged_price`: f64 - Lagging Span (Chikou Span) + """ + def donchian_channels_single(self, high_column: builtins.str, low_column: builtins.str) -> polars.DataFrame: + r""" + Donchian Channels - Produces bands from period highs and lows + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + + # Returns + DataFrame with columns: + - `donchian_lower`: f64 - Lower channel (lowest low over period) + - `donchian_middle`: f64 - Middle channel (average of upper and lower) + - `donchian_upper`: f64 - Upper channel (highest high over period) + """ + def keltner_channel_single( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + constant_model_type: builtins.str, + atr_constant_model_type: builtins.str, + multiplier: builtins.float, + ) -> polars.DataFrame: + r""" + Keltner Channel - Bands based on moving average and average true range + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `constant_model_type`: &str - Type of moving average for center line (e.g., "sma", "ema", "wma") + - `atr_constant_model_type`: &str - Type of moving average for ATR calculation (e.g., "sma", "ema", "wma") + - `multiplier`: f64 - Multiplier for the ATR to create channel width + + # Returns + DataFrame with columns: + - `keltner_lower`: f64 - Lower channel (moving average - ATR * multiplier) + - `keltner_middle`: f64 - Middle channel (moving average) + - `keltner_upper`: f64 - Upper channel (moving average + ATR * multiplier) + """ + def supertrend_single( + self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, constant_model_type: builtins.str, multiplier: builtins.float + ) -> polars.Series: + r""" + Supertrend - Trend indicator showing support and resistance levels + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `constant_model_type`: &str - Type of moving average for ATR calculation (e.g., "sma", "ema", "wma") + - `multiplier`: f64 - Multiplier for the ATR to determine trend sensitivity + + # Returns + Series containing: + - `supertrend`: f64 - Supertrend value (support/resistance level based on trend direction) + """ + def moving_constant_envelopes_bulk( + self, price_column: builtins.str, constant_model_type: builtins.str, difference: builtins.float, period: builtins.int + ) -> polars.DataFrame: + r""" + Moving Constant Envelopes (Bulk) - Returns envelopes over time periods + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `constant_model_type`: &str - Type of moving average (e.g., "sma", "ema", "wma") + - `difference`: f64 - Fixed difference value to create envelope bands + - `period`: usize - Rolling window period for calculations + + # Returns + DataFrame with columns: + - `lower_envelope`: Vec - Time series of lower envelope bands + - `middle_envelope`: Vec - Time series of middle lines (moving averages) + - `upper_envelope`: Vec - Time series of upper envelope bands + """ + def mcginley_dynamic_envelopes_bulk( + self, price_column: builtins.str, difference: builtins.float, previous_mcginley_dynamic: builtins.float, period: builtins.int + ) -> polars.DataFrame: + r""" + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `difference`: f64 - Fixed difference value to create envelope bands + - `previous_mcginley_dynamic`: f64 - Initial McGinley Dynamic value for calculation + - `period`: usize - Rolling window period for calculations + + # Returns + DataFrame with columns: + - `lower_envelope`: Vec - Time series of lower envelope bands + - `mcginley_dynamic`: Vec - Time series of McGinley Dynamic values + - `upper_envelope`: Vec - Time series of upper envelope bands + """ + def moving_constant_bands_bulk( + self, + price_column: builtins.str, + constant_model_type: builtins.str, + deviation_model: builtins.str, + deviation_multiplier: builtins.float, + period: builtins.int, + ) -> polars.DataFrame: + r""" + Moving Constant Bands (Bulk) + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `constant_model_type`: &str - Type of moving average for center line (e.g., "sma", "ema", "wma") + - `deviation_model`: &str - Type of deviation calculation (e.g., "std", "mad") + - `deviation_multiplier`: f64 - Multiplier for the deviation to create bands + - `period`: usize - Rolling window period for calculations + + # Returns + DataFrame with columns: + - `lower_band`: Vec - Time series of lower bands + - `middle_band`: Vec - Time series of middle bands (moving averages) + - `upper_band`: Vec - Time series of upper bands + """ + def mcginley_dynamic_bands_bulk( + self, + price_column: builtins.str, + deviation_model: builtins.str, + deviation_multiplier: builtins.float, + previous_mcginley_dynamic: builtins.float, + period: builtins.int, + ) -> polars.DataFrame: + r""" + McGinley Dynamic Bands (Bulk) + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `deviation_model`: &str - Type of deviation calculation (e.g., "std", "mad") + - `deviation_multiplier`: f64 - Multiplier for the deviation to create bands + - `previous_mcginley_dynamic`: f64 - Initial McGinley Dynamic value for calculation + - `period`: usize - Rolling window period for calculations + + # Returns + DataFrame with columns: + - `lower_band`: Vec - Time series of lower bands + - `mcginley_dynamic`: Vec - Time series of McGinley Dynamic values + - `upper_band`: Vec - Time series of upper bands + """ + def ichimoku_cloud_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + conversion_period: builtins.int, + base_period: builtins.int, + span_b_period: builtins.int, + ) -> polars.DataFrame: + r""" + Ichimoku Cloud (Bulk) - Returns ichimoku components over time + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `conversion_period`: usize - Period for conversion line calculation (typically 9) + - `base_period`: usize - Period for base line calculation (typically 26) + - `span_b_period`: usize - Period for leading span B calculation (typically 52) + + # Returns + DataFrame with columns: + - `leading_span_a`: Vec - Time series of Leading Span A values + - `leading_span_b`: Vec - Time series of Leading Span B values + - `base_line`: Vec - Time series of Base Line (Kijun-sen) values + - `conversion_line`: Vec - Time series of Conversion Line (Tenkan-sen) values + - `lagged_price`: Vec - Time series of Lagging Span (Chikou Span) values + """ + def donchian_channels_bulk(self, high_column: builtins.str, low_column: builtins.str, period: builtins.int) -> polars.DataFrame: + r""" + Donchian Channels (Bulk) - Returns donchian bands over time + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `period`: usize - Rolling window period for channel calculation + + # Returns + DataFrame with columns: + - `lower_band`: Vec - Time series of lower channels (lowest lows) + - `middle_band`: Vec - Time series of middle channels (averages) + - `upper_band`: Vec - Time series of upper channels (highest highs) + """ + def keltner_channel_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + constant_model_type: builtins.str, + atr_constant_model_type: builtins.str, + multiplier: builtins.float, + period: builtins.int, + ) -> polars.DataFrame: + r""" + Keltner Channel (Bulk) - Returns keltner bands over time + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `constant_model_type`: &str - Type of moving average for center line (e.g., "sma", "ema", "wma") + - `atr_constant_model_type`: &str - Type of moving average for ATR calculation (e.g., "sma", "ema", "wma") + - `multiplier`: f64 - Multiplier for the ATR to create channel width + - `period`: usize - Rolling window period for calculations + + # Returns + DataFrame with columns: + - `lower_band`: Vec - Time series of lower channels + - `middle_band`: Vec - Time series of middle channels (moving averages) + - `upper_band`: Vec - Time series of upper channels + """ + def supertrend_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + constant_model_type: builtins.str, + multiplier: builtins.float, + period: builtins.int, + ) -> polars.Series: + r""" + Supertrend (Bulk) - Returns supertrend values over time + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `constant_model_type`: &str - Type of moving average for ATR calculation (e.g., "sma", "ema", "wma") + - `multiplier`: f64 - Multiplier for the ATR to determine trend sensitivity + - `period`: usize - Rolling window period for ATR calculation + + # Returns + Series containing: + - `supertrend`: Vec - Time series of supertrend values (support/resistance levels) + """ + +class ChartTrendsTI: + def __new__(cls, lf: polars.LazyFrame) -> ChartTrendsTI: ... + def peaks(self, price_column: builtins.str, period: builtins.int, closest_neighbor: builtins.int) -> builtins.list[tuple[builtins.float, builtins.int]]: + r""" + Find peaks in a price series over a given period + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Period length for peak detection + - `closest_neighbor`: usize - Minimum distance between peaks + + # Returns + Vec<(f64, usize)> - List of tuples containing: + - `peak_value`: The price value at the peak + - `peak_index`: The index position of the peak in the series + """ + def valleys(self, price_column: builtins.str, period: builtins.int, closest_neighbor: builtins.int) -> builtins.list[tuple[builtins.float, builtins.int]]: + r""" + Find valleys in a price series over a given period + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Period length for valley detection + - `closest_neighbor`: usize - Minimum distance between valleys + + # Returns + Vec<(f64, usize)> - List of tuples containing: + - `valley_value`: The price value at the valley + - `valley_index`: The index position of the valley in the series + """ + def peak_trend(self, price_column: builtins.str, period: builtins.int) -> tuple[builtins.float, builtins.float]: + r""" + Calculate peak trend (linear regression on peaks) + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Period length for peak detection + + # Returns + Tuple of (slope: f64, intercept: f64) + - `slope`: The slope of the linear regression line through peaks + - `intercept`: The y-intercept of the linear regression line + """ + def valley_trend(self, price_column: builtins.str, period: builtins.int) -> tuple[builtins.float, builtins.float]: + r""" + Calculate valley trend (linear regression on valleys) + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Period length for valley detection + + # Returns + Tuple of (slope: f64, intercept: f64) + - `slope`: The slope of the linear regression line through valleys + - `intercept`: The y-intercept of the linear regression line + """ + def overall_trend(self, price_column: builtins.str) -> tuple[builtins.float, builtins.float]: + r""" + Calculate overall trend (linear regression on all prices) + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + Tuple of (slope: f64, intercept: f64) + - `slope`: The slope of the linear regression line through all price points + - `intercept`: The y-intercept of the linear regression line + """ + def break_down_trends( + self, + price_column: builtins.str, + max_outliers: builtins.int, + soft_r_squared_minimum: builtins.float, + soft_r_squared_maximum: builtins.float, + hard_r_squared_minimum: builtins.float, + hard_r_squared_maximum: builtins.float, + soft_standard_error_multiplier: builtins.float, + hard_standard_error_multiplier: builtins.float, + soft_reduced_chi_squared_multiplier: builtins.float, + hard_reduced_chi_squared_multiplier: builtins.float, + ) -> builtins.list[tuple[builtins.int, builtins.int, builtins.float, builtins.float]]: + r""" + Break down trends in a price series + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `max_outliers`: usize - Maximum number of outliers allowed + - `soft_r_squared_minimum`: f64 - Soft minimum threshold for R-squared value + - `soft_r_squared_maximum`: f64 - Soft maximum threshold for R-squared value + - `hard_r_squared_minimum`: f64 - Hard minimum threshold for R-squared value + - `hard_r_squared_maximum`: f64 - Hard maximum threshold for R-squared value + - `soft_standard_error_multiplier`: f64 - Soft multiplier for standard error threshold + - `hard_standard_error_multiplier`: f64 - Hard multiplier for standard error threshold + - `soft_reduced_chi_squared_multiplier`: f64 - Soft multiplier for reduced chi-squared threshold + - `hard_reduced_chi_squared_multiplier`: f64 - Hard multiplier for reduced chi-squared threshold + + # Returns + Vec<(usize, usize, f64, f64)> - List of tuples containing: + - `start_index`: Starting index of the trend segment + - `end_index`: Ending index of the trend segment + - `slope`: The slope of the linear regression for this trend segment + - `intercept`: The y-intercept of the linear regression for this trend segment + """ + +class CorrelationTI: + def __new__(cls, lf: polars.LazyFrame) -> CorrelationTI: ... + def correlate_asset_prices_single( + self, price_column_a: builtins.str, price_column_b: builtins.str, constant_model_type: builtins.str, deviation_model: builtins.str + ) -> builtins.float: + r""" + Correlation between two assets - Single value calculation + Calculates correlation between prices of two assets using specified models + Returns a single correlation value for the entire price series + + # Parameters + - `price_column_a`: &str - Name of the first asset's price column + - `price_column_b`: &str - Name of the second asset's price column + - `constant_model_type`: &str - Type of constant model to use for correlation calculation + - `deviation_model`: &str - Type of deviation model to use for correlation calculation + + # Returns + f64 - Single correlation coefficient between the two asset price series + """ + def correlate_asset_prices_bulk( + self, price_column_a: builtins.str, price_column_b: builtins.str, constant_model_type: builtins.str, deviation_model: builtins.str, period: builtins.int + ) -> polars.Series: + r""" + Correlation between two assets - Rolling/Bulk calculation + Calculates rolling correlation between prices of two assets using specified models + Returns a series of correlation values for each period window + + # Parameters + - `price_column_a`: &str - Name of the first asset's price column + - `price_column_b`: &str - Name of the second asset's price column + - `constant_model_type`: &str - Type of constant model to use for correlation calculation + - `deviation_model`: &str - Type of deviation model to use for correlation calculation + - `period`: usize - Rolling window size for correlation calculation + + # Returns + PySeriesStubbed - Series containing rolling correlation coefficients for each period window with name "correlation" + """ + +class MATI: + def __new__(cls, lf: polars.LazyFrame) -> MATI: ... + def moving_average_single(self, price_column: builtins.str, moving_average_type: builtins.str) -> builtins.float: + r""" + Moving Average (Single) - Calculates a single moving average value for a series of prices + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `moving_average_type`: &str - Type of moving average ("simple", "exponential", "smoothed") + + # Returns + f64 - Single moving average value + """ + def moving_average_bulk(self, price_column: builtins.str, moving_average_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Moving Average (Bulk) - Calculates moving averages over a rolling window + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `moving_average_type`: &str - Type of moving average ("simple", "exponential", "smoothed") + - `period`: usize - Period over which to calculate the moving average + + # Returns + PySeriesStubbed - Series of moving average values with name "moving_average" + """ + def mcginley_dynamic_single(self, price_column: builtins.str, previous_mcginley_dynamic: builtins.float, period: builtins.int) -> builtins.float: + r""" + McGinley Dynamic (Single) - Calculates a single McGinley Dynamic value + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `previous_mcginley_dynamic`: f64 - Previous McGinley Dynamic value (use 0.0 if none) + - `period`: usize - Period for calculation + + # Returns + f64 - Single McGinley Dynamic value + """ + def mcginley_dynamic_bulk(self, price_column: builtins.str, previous_mcginley_dynamic: builtins.float, period: builtins.int) -> polars.Series: + r""" + McGinley Dynamic (Bulk) - Calculates McGinley Dynamic values over a series + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `previous_mcginley_dynamic`: f64 - Previous McGinley Dynamic value (use 0.0 if none) + - `period`: usize - Period for calculation + + # Returns + PySeriesStubbed - Series of McGinley Dynamic values with name "mcginley_dynamic" + """ + def personalised_moving_average_single( + self, price_column: builtins.str, alpha_nominator: builtins.float, alpha_denominator: builtins.float + ) -> builtins.float: + r""" + Personalised Moving Average (Single) - Calculates a single personalised moving average + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `alpha_nominator`: f64 - Alpha nominator value + - `alpha_denominator`: f64 - Alpha denominator value + + # Returns + f64 - Single personalised moving average value + """ + def personalised_moving_average_bulk( + self, price_column: builtins.str, alpha_nominator: builtins.float, alpha_denominator: builtins.float, period: builtins.int + ) -> polars.Series: + r""" + Personalised Moving Average (Bulk) - Calculates personalised moving averages over a rolling window + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `alpha_nominator`: f64 - Alpha nominator value + - `alpha_denominator`: f64 - Alpha denominator value + - `period`: usize - Period over which to calculate the moving average + + # Returns + PySeriesStubbed - Series of personalised moving average values with name "personalised_moving_average" + """ + +class MomentumTI: + def __new__(cls, lf: polars.LazyFrame) -> MomentumTI: ... + def aroon_up_single(self, high_column: builtins.str) -> builtins.float: + r""" + Aroon Up indicator + + Calculates the Aroon Up indicator, which measures the time since the highest high + within a given period as a percentage. + + # Parameters + * `high_column` - &str name of the column containing high price values + + # Returns + * `PyResult` - The Aroon Up value (0-100), where higher values indicate recent highs + """ + def aroon_down_single(self, low_column: builtins.str) -> builtins.float: + r""" + Aroon Down indicator + + Calculates the Aroon Down indicator, which measures the time since the lowest low + within a given period as a percentage. + + # Parameters + * `low_column` - &str name of the column containing low price values + + # Returns + * `PyResult` - The Aroon Down value (0-100), where higher values indicate recent lows + """ + def aroon_oscillator_single(self, aroon_up: builtins.float, aroon_down: builtins.float) -> builtins.float: + r""" + Aroon Oscillator + + Calculates the Aroon Oscillator by subtracting Aroon Down from Aroon Up. + Values range from -100 to +100, indicating trend strength and direction. + + # Parameters + * `aroon_up` - f64 value of Aroon Up indicator (0-100) + * `aroon_down` - f64 value of Aroon Down indicator (0-100) + + # Returns + * `PyResult` - The Aroon Oscillator value (-100 to +100) + """ + def aroon_indicator_single(self, high_column: builtins.str, low_column: builtins.str) -> tuple[builtins.float, builtins.float, builtins.float]: + r""" + Aroon Indicator (complete calculation) + + Calculates all three Aroon components: Aroon Up, Aroon Down, and Aroon Oscillator + in a single function call. + + # Parameters + * `high_column` - &str name of the column containing high price values + * `low_column` - &str name of the column containing low price values + + # Returns + * `PyResult<(f64, f64, f64)>` - Tuple containing (aroon_up, aroon_down, aroon_oscillator) + """ + def long_parabolic_time_price_system_single( + self, previous_sar: builtins.float, extreme_point: builtins.float, acceleration_factor: builtins.float, low: builtins.float + ) -> builtins.float: + r""" + Long Parabolic Time Price System (Parabolic SAR for long positions) + + Calculates the Parabolic SAR (Stop and Reverse) for long positions, used to determine + potential reversal points in price movement. + + # Parameters + * `previous_sar` - f64 value of the previous SAR + * `extreme_point` - f64 value of the extreme point (highest high for long positions) + * `acceleration_factor` - f64 acceleration factor (typically starts at 0.02) + * `low` - f64 current period's low price + + # Returns + * `PyResult` - The calculated SAR value for long positions + """ + def short_parabolic_time_price_system_single( + self, previous_sar: builtins.float, extreme_point: builtins.float, acceleration_factor: builtins.float, high: builtins.float + ) -> builtins.float: + r""" + Short Parabolic Time Price System (Parabolic SAR for short positions) + + Calculates the Parabolic SAR (Stop and Reverse) for short positions, used to determine + potential reversal points in price movement. + + # Parameters + * `previous_sar` - f64 value of the previous SAR + * `extreme_point` - f64 value of the extreme point (lowest low for short positions) + * `acceleration_factor` - f64 acceleration factor (typically starts at 0.02) + * `high` - f64 current period's high price + + # Returns + * `PyResult` - The calculated SAR value for short positions + """ + def volume_price_trend_single( + self, price_column: builtins.str, previous_price: builtins.float, volume: builtins.float, previous_volume_price_trend: builtins.float + ) -> builtins.float: + r""" + Volume Price Trend + + Calculates the Volume Price Trend indicator, which combines price and volume + to show the relationship between volume and price changes. + + # Parameters + * `price_column` - &str name of the column containing price values + * `previous_price` - f64 previous period's price + * `volume` - f64 current period's volume + * `previous_volume_price_trend` - f64 previous VPT value + + # Returns + * `PyResult` - The calculated Volume Price Trend value + """ + def true_strength_index_single( + self, price_column: builtins.str, first_constant_model: builtins.str, first_period: builtins.int, second_constant_model: builtins.str + ) -> builtins.float: + r""" + True Strength Index + + Calculates the True Strength Index, a momentum oscillator that uses price changes + smoothed by two exponential moving averages. + + # Parameters + * `price_column` - &str name of the column containing price values + * `first_constant_model` - &str smoothing model for first smoothing ("sma", "ema", etc.) + * `first_period` - usize period for first smoothing + * `second_constant_model` - &str smoothing model for second smoothing ("sma", "ema", etc.) + + # Returns + * `PyResult` - The True Strength Index value (typically ranges from -100 to +100) + """ + def relative_strength_index_bulk(self, price_column: builtins.str, constant_model_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Relative Strength Index (RSI) - bulk calculation + + Calculates RSI values for an entire series of prices. RSI measures the speed and change + of price movements, oscillating between 0 and 100. + + # Parameters + * `price_column` - &str name of the column containing price values + * `constant_model_type` - &str smoothing model ("sma", "ema", etc.) + * `period` - usize calculation period (commonly 14) + + # Returns + * `PyResult` - Series named "rsi" containing RSI values (0-100) + """ + def stochastic_oscillator_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Stochastic Oscillator - bulk calculation + + Calculates the Stochastic Oscillator, which compares a security's closing price + to its price range over a given time period. + + # Parameters + * `price_column` - &str name of the column containing price values + * `period` - usize lookback period for calculation + + # Returns + * `PyResult` - Series named "stochastic" containing oscillator values (0-100) + """ + def slow_stochastic_bulk(self, stochastic_column: builtins.str, constant_model_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Slow Stochastic - bulk calculation + + Calculates the Slow Stochastic by smoothing the regular Stochastic Oscillator + to reduce noise and false signals. + + # Parameters + * `stochastic_column` - &str name of the column containing Stochastic Oscillator values + * `constant_model_type` - &str smoothing model ("sma", "ema", etc.) + * `period` - usize smoothing period + + # Returns + * `PyResult` - Series named "slow_stochastic" containing smoothed values (0-100) + """ + def slowest_stochastic_bulk(self, slow_stochastic_column: builtins.str, constant_model_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Slowest Stochastic - bulk calculation + + Calculates the Slowest Stochastic by applying additional smoothing to the Slow Stochastic + for even more noise reduction. + + # Parameters + * `slow_stochastic_column` - &str name of the column containing Slow Stochastic values + * `constant_model_type` - &str smoothing model ("sma", "ema", etc.) + * `period` - usize smoothing period + + # Returns + * `PyResult` - Series named "slowest_stochastic" containing double-smoothed values (0-100) + """ + def williams_percent_r_bulk(self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Williams %R - bulk calculation + + Calculates Williams %R, a momentum indicator that measures overbought and oversold levels. + Values range from -100 to 0, where -20 and above indicates overbought, -80 and below indicates oversold. + + # Parameters + * `high_column` - &str name of the column containing high price values + * `low_column` - &str name of the column containing low price values + * `close_column` - &str name of the column containing close price values + * `period` - usize lookback period for calculation + + # Returns + * `PyResult` - Series named "williams_r" containing Williams %R values (-100 to 0) + """ + def money_flow_index_bulk(self, price_column: builtins.str, volume_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Money Flow Index - bulk calculation + + Calculates the Money Flow Index, a volume-weighted RSI that measures buying and selling pressure. + Values range from 0 to 100, where >80 indicates overbought and <20 indicates oversold. + + # Parameters + * `price_column` - &str name of the column containing price values + * `volume_column` - &str name of the column containing volume values + * `period` - usize calculation period (commonly 14) + + # Returns + * `PyResult` - Series named "mfi" containing Money Flow Index values (0-100) + """ + def rate_of_change_bulk(self, price_column: builtins.str) -> polars.Series: + r""" + Rate of Change - bulk calculation + + Calculates the Rate of Change, which measures the percentage change in price + from one period to the next. + + # Parameters + * `price_column` - &str name of the column containing price values + + # Returns + * `PyResult` - Series named "roc" containing rate of change values as percentages + """ + def on_balance_volume_bulk(self, price_column: builtins.str, volume_column: builtins.str, previous_obv: builtins.float) -> polars.Series: + r""" + On Balance Volume (Bulk) - Calculates cumulative volume indicator + Adds volume on up days and subtracts volume on down days to measure buying and selling pressure + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `volume_column`: &str - Name of the volume column + - `previous_obv`: f64 - Starting OBV value (typically 0) + + # Returns + PySeriesStubbed - Series of OBV values with name "obv" + """ + def commodity_channel_index_bulk( + self, + price_column: builtins.str, + constant_model_type: builtins.str, + deviation_model: builtins.str, + constant_multiplier: builtins.float, + period: builtins.int, + ) -> polars.Series: + r""" + Commodity Channel Index (Bulk) - Calculates CCI over rolling periods + Measures the variation of a security's price from its statistical mean + Values typically range from -100 to +100 + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `constant_model_type`: &str - Model for calculating moving average ("sma", "ema", etc.) + - `deviation_model`: &str - Model for calculating deviation ("mad", "std", etc.) + - `constant_multiplier`: f64 - Multiplier constant (typically 0.015) + - `period`: usize - Calculation period (commonly 20) + + # Returns + PySeriesStubbed - Series of CCI values with name "cci" + """ + def mcginley_dynamic_commodity_channel_index_bulk( + self, + price_column: builtins.str, + previous_mcginley_dynamic: builtins.float, + deviation_model: builtins.str, + constant_multiplier: builtins.float, + period: builtins.int, + ) -> tuple[polars.Series, polars.Series]: + r""" + McGinley Dynamic Commodity Channel Index (Bulk) - CCI using McGinley Dynamic MA + Uses McGinley Dynamic as the moving average, which adapts to market conditions + better than traditional moving averages + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `previous_mcginley_dynamic`: f64 - Initial McGinley Dynamic value + - `deviation_model`: &str - Model for calculating deviation ("mad", "std", etc.) + - `constant_multiplier`: f64 - Multiplier constant (typically 0.015) + - `period`: usize - Calculation period + + # Returns + (PySeriesStubbed, PySeriesStubbed) - Tuple containing (CCI series, McGinley Dynamic series) + """ + def macd_line_bulk( + self, price_column: builtins.str, short_period: builtins.int, short_period_model: builtins.str, long_period: builtins.int, long_period_model: builtins.str + ) -> polars.Series: + r""" + MACD Line (Bulk) - Calculates Moving Average Convergence Divergence line + Subtracts the long-period moving average from the short-period moving average + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `short_period`: usize - Period for short moving average (commonly 12) + - `short_period_model`: &str - Model for short MA ("sma", "ema", etc.) + - `long_period`: usize - Period for long moving average (commonly 26) + - `long_period_model`: &str - Model for long MA ("sma", "ema", etc.) + + # Returns + PySeriesStubbed - Series of MACD line values with name "macd" + """ + def signal_line_bulk(self, macd_column: builtins.str, constant_model_type: builtins.str, period: builtins.int) -> polars.Series: + r""" + Signal Line (Bulk) - Calculates MACD Signal Line + Applies a moving average to the MACD line for generating buy/sell signals + + # Parameters + - `macd_column`: &str - Name of the MACD column to analyze + - `constant_model_type`: &str - Smoothing model ("sma", "ema", etc.) + - `period`: usize - Signal line period (commonly 9) + + # Returns + PySeriesStubbed - Series of signal line values with name "signal" + """ + def mcginley_dynamic_macd_line_bulk( + self, + price_column: builtins.str, + short_period: builtins.int, + previous_short_mcginley: builtins.float, + long_period: builtins.int, + previous_long_mcginley: builtins.float, + ) -> polars.DataFrame: + r""" + McGinley Dynamic MACD Line (Bulk) - MACD using McGinley Dynamic moving averages + Provides better adaptation to market volatility and reduces lag compared to traditional MACD + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `short_period`: usize - Period for short McGinley Dynamic + - `previous_short_mcginley`: f64 - Initial short McGinley Dynamic value + - `long_period`: usize - Period for long McGinley Dynamic + - `previous_long_mcginley`: f64 - Initial long McGinley Dynamic value + + # Returns + PyDfStubbed - DataFrame with columns: "macd", "short_mcginley", "long_mcginley" + """ + def chaikin_oscillator_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + volume_column: builtins.str, + short_period: builtins.int, + long_period: builtins.int, + previous_accumulation_distribution: builtins.float, + short_period_model: builtins.str, + long_period_model: builtins.str, + ) -> tuple[polars.Series, polars.Series]: + r""" + Chaikin Oscillator (Bulk) - Applies MACD to Accumulation/Distribution line + Measures the momentum of the Accumulation/Distribution line + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `volume_column`: &str - Name of the volume column + - `short_period`: usize - Short period for oscillator (commonly 3) + - `long_period`: usize - Long period for oscillator (commonly 10) + - `previous_accumulation_distribution`: f64 - Initial A/D line value + - `short_period_model`: &str - Model for short MA ("sma", "ema", etc.) + - `long_period_model`: &str - Model for long MA ("sma", "ema", etc.) + + # Returns + (PySeriesStubbed, PySeriesStubbed) - Tuple containing (Chaikin Oscillator, A/D Line) + """ + def percentage_price_oscillator_bulk( + self, price_column: builtins.str, short_period: builtins.int, long_period: builtins.int, constant_model_type: builtins.str + ) -> polars.Series: + r""" + Percentage Price Oscillator (Bulk) - MACD expressed as percentage + Similar to MACD but expressed as a percentage for easier comparison across securities + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `short_period`: usize - Short period for moving average (commonly 12) + - `long_period`: usize - Long period for moving average (commonly 26) + - `constant_model_type`: &str - Model for moving averages ("sma", "ema", etc.) + + # Returns + PySeriesStubbed - Series of PPO values as percentages with name "ppo" + """ + def chande_momentum_oscillator_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Chande Momentum Oscillator (Bulk) - Measures momentum using gains and losses + Calculates the difference between sum of gains and losses over a period + Values range from -100 to +100 + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Calculation period (commonly 14 or 20) + + # Returns + PySeriesStubbed - Series of CMO values (-100 to +100) with name "chande_momentum_oscillator" + """ + +class OtherTI: + r""" + Other Technical Indicators - A collection of other analysis functions for financial data + """ + def __new__(cls, lf: polars.LazyFrame) -> OtherTI: ... + def return_on_investment_single(self, price_column: builtins.str, investment: builtins.float) -> tuple[builtins.float, builtins.float]: + r""" + Return on Investment - Calculates investment value and percentage change for a single period + Uses the first and last values from the price column as start and end prices + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `investment`: f64 - Initial investment amount + + # Returns + Tuple of (final_investment_value: f64, percent_return: f64) + - `final_investment_value`: The absolute value of the investment at the end + - `percent_return`: The percentage return on the investment + """ + def return_on_investment_bulk(self, price_column: builtins.str, investment: builtins.float) -> tuple[polars.Series, polars.Series]: + r""" + Return on Investment Bulk - Calculates ROI for a series of consecutive price periods + Uses the price column as price values for consecutive period calculations + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `investment`: f64 - Initial investment amount + + # Returns + Tuple of (final_investment_values: PySeriesStubbed, percent_returns: PySeriesStubbed) + - `final_investment_values`: Series of absolute investment values for each period + - `percent_returns`: Series of percentage returns for each period + """ + def true_range(self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str) -> polars.Series: + r""" + True Range - Calculates the greatest price movement for a single period + Uses the provided high/low/close columns to calculate true range + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + + # Returns + PySeriesStubbed - Series of true range values for each period + """ + def average_true_range_single( + self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, constant_model_type: builtins.str + ) -> builtins.float: + r""" + Average True Range - Calculates the moving average of true range values for a single result + Uses the provided high/low/close columns to calculate ATR from the entire price series + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `constant_model_type`: &str - Type of moving average ("sma", "ema", "wma", etc.) + + # Returns + f64 - Single ATR value calculated from the entire price series + """ + def average_true_range_bulk( + self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, constant_model_type: builtins.str, period: builtins.int + ) -> polars.Series: + r""" + Average True Range Bulk - Calculates rolling ATR values over specified periods + Uses the provided high/low/close columns for rolling ATR calculations + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `constant_model_type`: &str - Type of moving average ("sma", "ema", "wma", etc.) + - `period`: usize - Number of periods for the moving average calculation + + # Returns + PySeriesStubbed - Series of ATR values for each period + """ + def internal_bar_strength(self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str) -> polars.Series: + r""" + Internal Bar Strength - Calculates buy/sell oscillator based on close position within high-low range + Uses the provided high/low/close columns to calculate IBS values + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + + # Returns + PySeriesStubbed - Series of IBS values (0-1 range) for each period, where values closer to 1 + indicate closes near the high, and values closer to 0 indicate closes near the low + """ + def positivity_indicator( + self, open_column: builtins.str, close_column: builtins.str, signal_period: builtins.int, constant_model_type: builtins.str + ) -> tuple[polars.Series, polars.Series]: + r""" + Positivity Indicator - Generates trading signals based on open vs previous close comparison + Uses the provided open/close columns for signal generation + + # Parameters + - `open_column`: &str - Name of the opening price column + - `close_column`: &str - Name of the close price column + - `signal_period`: usize - Number of periods for signal line smoothing + - `constant_model_type`: &str - Type of moving average for signal line ("sma", "ema", "wma", etc.) + + # Returns + Tuple of (positivity_indicator: PySeriesStubbed, signal_line: PySeriesStubbed) + - `positivity_indicator`: Series of raw positivity values based on open/close comparison + - `signal_line`: Series of smoothed signal values using specified moving average + """ + +class StandardTI: + def __new__(cls, lf: polars.LazyFrame) -> StandardTI: ... + def sma_single(self, price_column: builtins.str) -> builtins.float: + r""" + Simple Moving Average (Single) - calculates the mean of all values in the column + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + f64 - Single SMA value calculated from all provided prices + """ + def sma_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Simple Moving Average (Bulk) - calculates the mean over a rolling window + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Number of periods for the moving average window + + # Returns + PySeriesStubbed - Series containing SMA values for each period + """ + def smma_single(self, price_column: builtins.str) -> builtins.float: + r""" + Smoothed Moving Average (Single) - single value calculation + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + f64 - Single SMMA value calculated from all provided prices + """ + def smma_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Smoothed Moving Average (Bulk) - puts more weight on recent prices + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Number of periods for the smoothed moving average window + + # Returns + PySeriesStubbed - Series containing SMMA values for each period + """ + def ema_single(self, price_column: builtins.str) -> builtins.float: + r""" + Exponential Moving Average (Single) - single value calculation + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + f64 - Single EMA value calculated from all provided prices + """ + def ema_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Exponential Moving Average (Bulk) - puts exponentially more weight on recent prices + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Number of periods for the exponential moving average window + + # Returns + PySeriesStubbed - Series containing EMA values for each period + """ + def bollinger_bands_single(self, price_column: builtins.str) -> tuple[builtins.float, builtins.float, builtins.float]: + r""" + Bollinger Bands (Single) - single value calculation (requires exactly 20 periods) + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + Tuple of (lower_band: f64, middle_band: f64, upper_band: f64) + - `lower_band`: Lower Bollinger Band value + - `middle_band`: Middle band (SMA) value + - `upper_band`: Upper Bollinger Band value + """ + def bollinger_bands_bulk(self, price_column: builtins.str) -> polars.DataFrame: + r""" + Bollinger Bands (Bulk) - returns three series: lower band, middle (SMA), upper band + Standard period is 20 with 2 standard deviations + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + PyDfStubbed - DataFrame with three columns: + - `bb_lower`: Lower Bollinger Band values + - `bb_middle`: Middle band (20-period SMA) + - `bb_upper`: Upper Bollinger Band values + """ + def macd_single(self, price_column: builtins.str) -> tuple[builtins.float, builtins.float, builtins.float]: + r""" + MACD (Single) - single value calculation (requires exactly 34 periods) + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + Tuple of (macd_line: f64, signal_line: f64, histogram: f64) + - `macd_line`: MACD line value (12-period EMA - 26-period EMA) + - `signal_line`: Signal line value (9-period EMA of MACD line) + - `histogram`: Histogram value (MACD line - Signal line) + """ + def macd_bulk(self, price_column: builtins.str) -> polars.DataFrame: + r""" + MACD (Bulk) - Moving Average Convergence Divergence + Returns three series: MACD line, Signal line, Histogram + Standard periods: 12, 26, 9 + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + PyDfStubbed - DataFrame with three columns: + - `macd`: MACD line (12-period EMA - 26-period EMA) + - `macd_signal`: Signal line (9-period EMA of MACD line) + - `macd_histogram`: Histogram (MACD line - Signal line) + """ + def rsi_single(self, price_column: builtins.str) -> builtins.float: + r""" + RSI (Single) - single value calculation (requires exactly 14 periods) + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + f64 - Single RSI value (0-100 scale) + """ + def rsi_bulk(self, price_column: builtins.str) -> polars.Series: + r""" + RSI (Bulk) - Relative Strength Index + Standard period is 14 using smoothed moving average + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + PySeriesStubbed - Series containing RSI values (0-100 scale) + """ + +class StrengthTI: + def __new__(cls, lf: polars.LazyFrame) -> StrengthTI: ... + def accumulation_distribution_single( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + volume_column: builtins.str, + previous_ad: typing.Optional[builtins.float], + ) -> builtins.float: + r""" + Accumulation Distribution (Single) - Shows whether the stock is being accumulated or distributed + Single value calculation using the last available values + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `volume_column`: &str - Name of the volume column + - `previous_ad`: Option - Previous accumulation/distribution value (defaults to 0.0) + + # Returns + f64 - Single accumulation/distribution value + """ + def accumulation_distribution_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + volume_column: builtins.str, + previous_ad: typing.Optional[builtins.float], + ) -> polars.Series: + r""" + Accumulation Distribution (Bulk) - Shows whether the stock is being accumulated or distributed + Returns a series of accumulation/distribution values + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `volume_column`: &str - Name of the volume column + - `previous_ad`: Option - Previous accumulation/distribution value (defaults to 0.0) + + # Returns + PySeriesStubbed - Series containing accumulation/distribution values with name "accumulation_distribution" + """ + def positive_volume_index_single( + self, close_column: builtins.str, volume_column: builtins.str, previous_pvi: typing.Optional[builtins.float] + ) -> builtins.float: + r""" + Positive Volume Index (Single) - Measures volume trend strength when volume increases + Single value calculation using the last available values + + # Parameters + - `close_column`: &str - Name of the close price column + - `volume_column`: &str - Name of the volume column + - `previous_pvi`: Option - Previous positive volume index value (defaults to 0.0) + + # Returns + f64 - Single positive volume index value + """ + def positive_volume_index_bulk(self, close_column: builtins.str, volume_column: builtins.str, previous_pvi: typing.Optional[builtins.float]) -> polars.Series: + r""" + Positive Volume Index (Bulk) - Measures volume trend strength when volume increases + Returns a series of positive volume index values + + # Parameters + - `close_column`: &str - Name of the close price column + - `volume_column`: &str - Name of the volume column + - `previous_pvi`: Option - Previous positive volume index value (defaults to 0.0) + + # Returns + PySeriesStubbed - Series containing positive volume index values with name "positive_volume_index" + """ + def negative_volume_index_single( + self, close_column: builtins.str, volume_column: builtins.str, previous_nvi: typing.Optional[builtins.float] + ) -> builtins.float: + r""" + Negative Volume Index (Single) - Measures volume trend strength when volume decreases + Single value calculation using the last available values + + # Parameters + - `close_column`: &str - Name of the close price column + - `volume_column`: &str - Name of the volume column + - `previous_nvi`: Option - Previous negative volume index value (defaults to 0.0) + + # Returns + f64 - Single negative volume index value + """ + def negative_volume_index_bulk(self, close_column: builtins.str, volume_column: builtins.str, previous_nvi: typing.Optional[builtins.float]) -> polars.Series: + r""" + Negative Volume Index (Bulk) - Measures volume trend strength when volume decreases + Returns a series of negative volume index values + + # Parameters + - `close_column`: &str - Name of the close price column + - `volume_column`: &str - Name of the volume column + - `previous_nvi`: Option - Previous negative volume index value (defaults to 0.0) + + # Returns + PySeriesStubbed - Series containing negative volume index values with name "negative_volume_index" + """ + def relative_vigor_index_single( + self, open_column: builtins.str, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, constant_model_type: builtins.str + ) -> builtins.float: + r""" + Relative Vigor Index (Single) - Measures the strength of an asset by looking at previous prices + Single value calculation using all available values + + # Parameters + - `open_column`: &str - Name of the opening price column + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `constant_model_type`: &str - Type of constant model to use + + # Returns + f64 - Single relative vigor index value + """ + def relative_vigor_index_bulk( + self, + open_column: builtins.str, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + constant_model_type: builtins.str, + period: builtins.int, + ) -> polars.Series: + r""" + Relative Vigor Index (Bulk) - Measures the strength of an asset by looking at previous prices + Returns a series of relative vigor index values + + # Parameters + - `open_column`: &str - Name of the opening price column + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `constant_model_type`: &str - Type of constant model to use + - `period`: usize - Period length for calculation + + # Returns + PySeriesStubbed - Series containing relative vigor index values with name "relative_vigor_index" + """ + +class TrendTI: + r""" + Trend Technical Indicators - A collection of trend analysis functions for financial data + """ + def __new__(cls, lf: polars.LazyFrame) -> TrendTI: ... + def aroon_up_single(self, high_column: builtins.str) -> builtins.float: + r""" + Aroon Up (Single) - Measures the strength of upward price momentum + Calculates the percentage of time since the highest high within the series + + # Parameters + - `high_column`: &str - Name of the high price column to analyze + + # Returns + f64 - Aroon Up value (0-100), where higher values indicate stronger upward momentum + """ + def aroon_down_single(self, low_column: builtins.str) -> builtins.float: + r""" + Aroon Down (Single) - Measures the strength of downward price momentum + Calculates the percentage of time since the lowest low within the series + + # Parameters + - `low_column`: &str - Name of the low price column to analyze + + # Returns + f64 - Aroon Down value (0-100), where higher values indicate stronger downward momentum + """ + def aroon_oscillator_single(self, high_column: builtins.str, low_column: builtins.str) -> builtins.float: + r""" + Aroon Oscillator (Single) - Calculates the difference between Aroon Up and Aroon Down + Provides a single measure of trend direction and strength + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + + # Returns + f64 - Aroon Oscillator value (-100 to 100), where positive values indicate upward trend + """ + def aroon_indicator_single(self, high_column: builtins.str, low_column: builtins.str) -> tuple[builtins.float, builtins.float, builtins.float]: + r""" + Aroon Indicator (Single) - Calculates complete Aroon system in one call + Computes Aroon Up, Aroon Down, and Aroon Oscillator + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + + # Returns + (f64, f64, f64) - Tuple containing (Aroon Up, Aroon Down, Aroon Oscillator) + """ + def true_strength_index_single( + self, price_column: builtins.str, first_constant_model: builtins.str, first_period: builtins.int, second_constant_model: builtins.str + ) -> builtins.float: + r""" + True Strength Index (Single) - Momentum oscillator using double-smoothed price changes + Filters out price noise to provide clearer momentum signals + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `first_constant_model`: &str - First smoothing method ("SimpleMovingAverage", "ExponentialMovingAverage", etc.) + - `first_period`: usize - Period for first smoothing + - `second_constant_model`: &str - Second smoothing method + + # Returns + f64 - True Strength Index value (-100 to 100) + """ + def aroon_up_bulk(self, high_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Aroon Up (Bulk) - Calculates rolling Aroon Up indicator over specified period + Measures upward momentum strength for each period in the time series + + # Parameters + - `high_column`: &str - Name of the high price column to analyze + - `period`: usize - Lookback period for calculation (typically 14) + + # Returns + PySeriesStubbed - Series of Aroon Up values (0-100) named "aroon_up" + """ + def aroon_down_bulk(self, low_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Aroon Down (Bulk) - Calculates rolling Aroon Down indicator over specified period + Measures downward momentum strength for each period in the time series + + # Parameters + - `low_column`: &str - Name of the low price column to analyze + - `period`: usize - Lookback period for calculation (typically 14) + + # Returns + PySeriesStubbed - Series of Aroon Down values (0-100) named "aroon_down" + """ + def aroon_oscillator_bulk(self, high_column: builtins.str, low_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Aroon Oscillator (Bulk) - Calculates rolling Aroon Oscillator over specified period + Computes the difference between Aroon Up and Aroon Down for each period + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `period`: usize - Lookback period for calculation (typically 14) + + # Returns + PySeriesStubbed - Series of Aroon Oscillator values (-100 to 100) named "aroon_oscillator" + """ + def aroon_indicator_bulk(self, high_column: builtins.str, low_column: builtins.str, period: builtins.int) -> polars.DataFrame: + r""" + Aroon Indicator (Bulk) - Calculates complete Aroon system for time series data + Computes Aroon Up, Aroon Down, and Aroon Oscillator for each period + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `period`: usize - Lookback period for calculation (typically 14) + + # Returns + PyDfStubbed - DataFrame with columns: "aroon_up", "aroon_down", "aroon_oscillator" + """ + def parabolic_time_price_system_bulk( + self, + high_column: builtins.str, + low_column: builtins.str, + acceleration_factor_start: builtins.float, + acceleration_factor_max: builtins.float, + acceleration_factor_step: builtins.float, + start_position: builtins.str, + previous_sar: builtins.float, + ) -> polars.Series: + r""" + Parabolic Time Price System (Bulk) - Calculates Stop and Reverse points + Provides trailing stop levels for trend-following system + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `acceleration_factor_start`: f64 - Initial acceleration factor (typically 0.02) + - `acceleration_factor_max`: f64 - Maximum acceleration factor (typically 0.20) + - `acceleration_factor_step`: f64 - Acceleration factor increment (typically 0.02) + - `start_position`: &str - Initial position: "Long" or "Short" + - `previous_sar`: f64 - Initial SAR value + + # Returns + PySeriesStubbed - Series of SAR values named "parabolic_sar" + """ + def directional_movement_system_bulk( + self, high_column: builtins.str, low_column: builtins.str, close_column: builtins.str, period: builtins.int, constant_model_type: builtins.str + ) -> polars.DataFrame: + r""" + Directional Movement System (Bulk) - Calculates complete DMS indicators + Computes +DI, -DI, ADX, and ADXR for trend strength analysis + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `period`: usize - Calculation period (typically 14) + - `constant_model_type`: &str - Smoothing method: "SimpleMovingAverage", "SmoothedMovingAverage", etc. + + # Returns + PyDfStubbed - DataFrame with columns: "positive_di", "negative_di", "adx", "adxr" + """ + def volume_price_trend_bulk(self, price_column: builtins.str, volume_column: builtins.str, previous_volume_price_trend: builtins.float) -> polars.Series: + r""" + Volume Price Trend (Bulk) - Combines price and volume to show momentum + Shows the relationship between price movement and volume flow + + # Parameters + - `price_column`: &str - Name of the price column + - `volume_column`: &str - Name of the volume column + - `previous_volume_price_trend`: f64 - Initial VPT value (typically 0) + + # Returns + PySeriesStubbed - Series of Volume Price Trend values named "volume_price_trend" + """ + def true_strength_index_bulk( + self, + price_column: builtins.str, + first_constant_model: builtins.str, + first_period: builtins.int, + second_constant_model: builtins.str, + second_period: builtins.int, + ) -> polars.Series: + r""" + True Strength Index (Bulk) - Double-smoothed momentum oscillator + Uses double-smoothed price changes to filter noise and provide clearer signals + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `first_constant_model`: &str - First smoothing method: "SimpleMovingAverage", "ExponentialMovingAverage", etc. + - `first_period`: usize - Period for first smoothing (typically 25) + - `second_constant_model`: &str - Second smoothing method + - `second_period`: usize - Period for second smoothing (typically 13) + + # Returns + PySeriesStubbed - Series of TSI values (-100 to 100) named "true_strength_index" + """ + +class VolatilityTI: + def __new__(cls, lf: polars.LazyFrame) -> VolatilityTI: ... + def ulcer_index_single(self, price_column: builtins.str) -> builtins.float: + r""" + Ulcer Index (Single) - Calculates how quickly the price is able to get back to its former high + Can be used instead of standard deviation for volatility measurement + + # Parameters + - `price_column`: &str - Name of the price column to analyze + + # Returns + f64 - Single Ulcer Index value representing overall price volatility and drawdown risk + """ + def ulcer_index_bulk(self, price_column: builtins.str, period: builtins.int) -> polars.Series: + r""" + Ulcer Index (Bulk) - Calculates rolling Ulcer Index over specified period + Returns a series of Ulcer Index values + + # Parameters + - `price_column`: &str - Name of the price column to analyze + - `period`: usize - Rolling window period for calculation + + # Returns + PySeriesStubbed - Series of rolling Ulcer Index values with name "ulcer_index" + """ + def volatility_system( + self, + high_column: builtins.str, + low_column: builtins.str, + close_column: builtins.str, + period: builtins.int, + constant_multiplier: builtins.float, + constant_model_type: builtins.str, + ) -> polars.Series: + r""" + Volatility System - Calculates Welles volatility system with Stop and Reverse (SaR) points + Uses trend analysis to determine long/short positions and calculate SaR levels + Constant multiplier typically between 2.8-3.1 (Welles used 3.0) + + # Parameters + - `high_column`: &str - Name of the high price column + - `low_column`: &str - Name of the low price column + - `close_column`: &str - Name of the close price column + - `period`: usize - Period for volatility calculation + - `constant_multiplier`: f64 - Multiplier for volatility (typically 2.8-3.1) + - `constant_model_type`: &str - Type of constant model to use for calculation + + # Returns + PySeriesStubbed - Series of volatility system values with Stop and Reverse points, named "volatility_system" + """ diff --git a/plugins/ezpz-rust-ti/python/ezpz_rust_ti/_ezpz_rust_ti_macros.py b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/_ezpz_rust_ti_macros.py new file mode 100644 index 0000000..84e53a0 --- /dev/null +++ b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/_ezpz_rust_ti_macros.py @@ -0,0 +1,47 @@ +from ezpz_rust_ti._ezpz_rust_ti import MATI, BasicTI, OtherTI, TrendTI, CandleTI, MomentumTI, StandardTI, StrengthTI, VolatilityTI, ChartTrendsTI, CorrelationTI +from ezpz_pluginz.register_plugin_macro import ezpz_plugin_collect + +# Basic Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="basic_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import BasicTI", type_hint="BasicTI")(BasicTI) + +# Candle Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="candle_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import CandleTI", type_hint="CandleTI")(CandleTI) + +# Chart Trends Technical Indicators +ezpz_plugin_collect( + polars_ns="LazyFrame", attr_name="chart_trends_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import ChartTrendsTI", type_hint="ChartTrendsTI" +)(ChartTrendsTI) + +# Correlation Technical Indicators +ezpz_plugin_collect( + polars_ns="LazyFrame", attr_name="correlation_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import CorrelationTI", type_hint="CorrelationTI" +)(CorrelationTI) + +# Moving Average Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="ma_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import MATI", type_hint="MATI")(MATI) + +# Momentum Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="momentum_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import MomentumTI", type_hint="MomentumTI")( + MomentumTI +) + +# Other Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="other_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import OtherTI", type_hint="OtherTI")(OtherTI) + +# Standard Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="standard_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import StandardTI", type_hint="StandardTI")( + StandardTI +) + +# Strength Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="strength_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import StrengthTI", type_hint="StrengthTI")( + StrengthTI +) + +# Trend Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="trend_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import TrendTI", type_hint="TrendTI")(TrendTI) + +# Volatility Technical Indicators +ezpz_plugin_collect(polars_ns="LazyFrame", attr_name="volatility_ti", import_="from ezpz_rust_ti._ezpz_rust_ti import VolatilityTI", type_hint="VolatilityTI")( + VolatilityTI +) diff --git a/plugins/ezpz-rust-ti/python/ezpz_rust_ti/ezpz-lock.yml b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/ezpz-lock.yml new file mode 100644 index 0000000..863f668 --- /dev/null +++ b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/ezpz-lock.yml @@ -0,0 +1,47 @@ +project_plugins: + LazyFrame: + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import CandleTI + attr_name: candle_ti + type_hint: CandleTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import BasicTI + attr_name: basic_ti + type_hint: BasicTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import StrengthTI + attr_name: strength_ti + type_hint: StrengthTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import ChartTrendsTI + attr_name: chart_trends_ti + type_hint: ChartTrendsTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import CorrelationTI + attr_name: correlation_ti + type_hint: CorrelationTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import OtherTI + attr_name: other_ti + type_hint: OtherTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import StandardTI + attr_name: standard_ti + type_hint: StandardTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import MATI + attr_name: ma_ti + type_hint: MATI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import VolatilityTI + attr_name: volatility_ti + type_hint: VolatilityTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import MomentumTI + attr_name: momentum_ti + type_hint: MomentumTI + - polars_ns: LazyFrame + import_: from ezpz_rust_ti._ezpz_rust_ti import TrendTI + attr_name: trend_ti + type_hint: TrendTI +site_plugins: {} diff --git a/guiz/python/ezpz_guiz/py.typed b/plugins/ezpz-rust-ti/python/ezpz_rust_ti/py.typed similarity index 100% rename from guiz/python/ezpz_guiz/py.typed rename to plugins/ezpz-rust-ti/python/ezpz_rust_ti/py.typed diff --git a/guiz/src/bin/stub_gen.rs b/plugins/ezpz-rust-ti/src/bin/stub_gen.rs similarity index 54% rename from guiz/src/bin/stub_gen.rs rename to plugins/ezpz-rust-ti/src/bin/stub_gen.rs index a1da55a..a3d96f4 100644 --- a/guiz/src/bin/stub_gen.rs +++ b/plugins/ezpz-rust-ti/src/bin/stub_gen.rs @@ -1,4 +1,4 @@ -use {ezpz_guiz::stub_info, pyo3_stub_gen::Result}; +use {ezpz_rust_ti::stub_info, pyo3_stub_gen::Result}; fn main() -> Result<()> { stub_info()?.generate()?; diff --git a/plugins/ezpz-rust-ti/src/indicators/basic/mod.rs b/plugins/ezpz-rust-ti/src/indicators/basic/mod.rs new file mode 100644 index 0000000..1adb483 --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/basic/mod.rs @@ -0,0 +1,748 @@ +use { + crate::utils::{extract_f64_values, parse_central_point}, + ezpz_stubz::{lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Basic Technical Indicators - A collection of basic analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct BasicTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl BasicTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + // Single value functions (return a single value from the entire prices) + + /// Calculate the arithmetic mean of all values. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// f64 - The arithmetic mean + fn mean_single(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + Ok(rust_ti::basic_indicators::single::mean(&values)) + } + + /// Calculate the median of all values. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// f64 - The median value + fn median_single(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + Ok(rust_ti::basic_indicators::single::median(&values)) + } + + /// Calculate the mode of all values. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// f64 - The most frequently occurring value + fn mode_single(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + Ok(rust_ti::basic_indicators::single::mode(&values)) + } + + /// Calculate the variance of all values. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// f64 - The variance + fn variance_single(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + Ok(rust_ti::basic_indicators::single::variance(&values)) + } + + /// Calculate the standard deviation of all values. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// f64 - The standard deviation + fn standard_deviation_single(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + Ok(rust_ti::basic_indicators::single::standard_deviation(&values)) + } + + /// Find the maximum value. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// f64 - The maximum value + fn max_single(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + Ok(rust_ti::basic_indicators::single::max(&values)) + } + + /// Find the minimum value. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// f64 - The minimum value + fn min_single(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + Ok(rust_ti::basic_indicators::single::min(&values)) + } + + /// Calculate the absolute deviation from a central point. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// - `central_point`: &str - Central point type ("mean", "median", etc.) + /// + /// # Returns + /// f64 - The absolute deviation + fn absolute_deviation_single(&self, column: &str, central_point: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let cp = parse_central_point(central_point)?; + Ok(rust_ti::basic_indicators::single::absolute_deviation(&values, cp)) + } + + /// Calculate the logarithmic difference between two price points. + /// + /// # Parameters + /// - `price_t`: f64 - Current price value + /// - `price_t_1`: f64 - Previous price value + /// + /// # Returns + /// f64 - The logarithmic difference + fn log_difference_single(&self, price_t: f64, price_t_1: f64) -> PyResult { + Ok(rust_ti::basic_indicators::single::log_difference(price_t, price_t_1)) + } + + // Bulk functions (return prices with rolling calculations) + + /// Calculate rolling mean over a specified period. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// - `period`: usize - Rolling window size + /// + /// # Returns + /// PySeriesStubbed - Series containing rolling mean values + fn mean_bulk(&self, column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::basic_indicators::bulk::mean(&values, period); + let result_series = Series::new("mean".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Calculate rolling median over a specified period. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// - `period`: usize - Rolling window size + /// + /// # Returns + /// PySeriesStubbed - Series containing rolling median values + fn median_bulk(&self, column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::basic_indicators::bulk::median(&values, period); + let result_series = Series::new("median".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Calculate rolling mode over a specified period. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// - `period`: usize - Rolling window size + /// + /// # Returns + /// PySeriesStubbed - Series containing rolling mode values + fn mode_bulk(&self, column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::basic_indicators::bulk::mode(&values, period); + let result_series = Series::new("mode".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Calculate rolling variance over a specified period. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// - `period`: usize - Rolling window size + /// + /// # Returns + /// PySeriesStubbed - Series containing rolling variance values + fn variance_bulk(&self, column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::basic_indicators::bulk::variance(&values, period); + let result_series = Series::new("variance".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Calculate rolling standard deviation over a specified period. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// - `period`: usize - Rolling window size + /// + /// # Returns + /// PySeriesStubbed - Series containing rolling standard deviation values + fn standard_deviation_bulk(&self, column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::basic_indicators::bulk::standard_deviation(&values, period); + let result_series = Series::new("standard_deviation".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Calculate rolling absolute deviation over a specified period. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// - `period`: usize - Rolling window size + /// - `central_point`: &str - Central point type ("mean", "median", etc.) + /// + /// # Returns + /// PySeriesStubbed - Series containing rolling absolute deviation values + fn absolute_deviation_bulk(&self, column: &str, period: usize, central_point: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let cp = parse_central_point(central_point)?; + let result = rust_ti::basic_indicators::bulk::absolute_deviation(&values, period, cp); + let result_series = Series::new("absolute_deviation".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Calculate natural logarithm of all values. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// PySeriesStubbed - Series containing natural logarithm values + fn log_bulk(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::basic_indicators::bulk::log(&values); + let result_series = Series::new("log".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Calculate logarithmic differences between consecutive values. + /// + /// # Parameters + /// - `column`: &str - Name of the column to analyze + /// + /// # Returns + /// PySeriesStubbed - Series containing logarithmic difference values + fn log_difference_bulk(&self, column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{column}': {e}")))? + .column(column) + .map_err(|e| PyErr::new::(format!("Column '{column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::basic_indicators::bulk::log_difference(&values); + let result_series = Series::new("log_difference".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + df! { + "price" => data, + "volume" => vec![100.0, 200.0, 150.0, 300.0, 250.0, 180.0, 220.0, 190.0, 280.0, 320.0] + } + .unwrap() + .lazy() + } + + fn create_basic_ti() -> BasicTI { + let lf = create_test_dataframe(); + BasicTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_mean_single() { + let ti = create_basic_ti(); + let result = ti.mean_single("price").unwrap(); + assert_abs_diff_eq!(result, 5.5, epsilon = 1e-10); + } + + #[test] + fn test_median_single() { + let ti = create_basic_ti(); + let result = ti.median_single("price").unwrap(); + assert_abs_diff_eq!(result, 5.5, epsilon = 1e-10); + } + + #[test] + fn test_mode_single() { + let ti = create_basic_ti(); + let result = ti.mode_single("price").unwrap(); + assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10); + } + + #[test] + fn test_variance_single() { + let ti = create_basic_ti(); + let result = ti.variance_single("price").unwrap(); + assert_abs_diff_eq!(result, 8.25, epsilon = 1e-10); + } + + #[test] + fn test_standard_deviation_single() { + let ti = create_basic_ti(); + let result = ti.standard_deviation_single("price").unwrap(); + assert_abs_diff_eq!(result, 2.8722813232690143, epsilon = 1e-10); + } + + #[test] + fn test_max_single() { + let ti = create_basic_ti(); + let result = ti.max_single("price").unwrap(); + assert_abs_diff_eq!(result, 10.0, epsilon = 1e-10); + } + + #[test] + fn test_min_single() { + let ti = create_basic_ti(); + let result = ti.min_single("price").unwrap(); + assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10); + } + + #[test] + fn test_absolute_deviation_single_mean() { + let ti = create_basic_ti(); + let result = ti.absolute_deviation_single("price", "mean").unwrap(); + assert_abs_diff_eq!(result, 2.5, epsilon = 1e-10); + } + + #[test] + fn test_absolute_deviation_single_median() { + let ti = create_basic_ti(); + let result = ti.absolute_deviation_single("price", "median").unwrap(); + assert_abs_diff_eq!(result, 2.5, epsilon = 1e-10); + } + + #[test] + fn test_log_difference_single() { + let ti = create_basic_ti(); + let result = ti.log_difference_single(2.0, 1.0).unwrap(); + assert_abs_diff_eq!(result, (2.0_f64).ln() - (1.0_f64).ln(), epsilon = 1e-10); + } + + #[test] + fn test_mean_bulk() { + let ti = create_basic_ti(); + let result = ti.mean_bulk("price", 3).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + assert_abs_diff_eq!(values[2], 2.0, epsilon = 1e-10); + assert_abs_diff_eq!(values[3], 3.0, epsilon = 1e-10); + } + + #[test] + fn test_median_bulk() { + let ti = create_basic_ti(); + let result = ti.median_bulk("price", 3).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + assert_abs_diff_eq!(values[2], 2.0, epsilon = 1e-10); + } + + #[test] + fn test_mode_bulk() { + let ti = create_basic_ti(); + let result = ti.mode_bulk("price", 3).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + } + + #[test] + fn test_variance_bulk() { + let ti = create_basic_ti(); + let result = ti.variance_bulk("price", 3).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + assert_abs_diff_eq!(values[2], 0.6666666666666666, epsilon = 1e-10); + } + + #[test] + fn test_standard_deviation_bulk() { + let ti = create_basic_ti(); + let result = ti.standard_deviation_bulk("price", 3).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + assert_abs_diff_eq!(values[2], 0.8164965809277261, epsilon = 1e-10); + } + + #[test] + fn test_absolute_deviation_bulk() { + let ti = create_basic_ti(); + let result = ti.absolute_deviation_bulk("price", 3, "mean").unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + assert_abs_diff_eq!(values[2], 0.6666666666666666, epsilon = 1e-10); + } + + #[test] + fn test_log_bulk() { + let ti = create_basic_ti(); + let result = ti.log_bulk("price").unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert_abs_diff_eq!(values[0], (1.0_f64).ln(), epsilon = 1e-10); + assert_abs_diff_eq!(values[1], (2.0_f64).ln(), epsilon = 1e-10); + assert_abs_diff_eq!(values[9], (10.0_f64).ln(), epsilon = 1e-10); + } + + #[test] + fn test_log_difference_bulk() { + let ti = create_basic_ti(); + let result = ti.log_difference_bulk("price").unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert_abs_diff_eq!(values[1], (2.0_f64).ln() - (1.0_f64).ln(), epsilon = 1e-10); + assert_abs_diff_eq!(values[2], (3.0_f64).ln() - (2.0_f64).ln(), epsilon = 1e-10); + } + + #[test] + fn test_invalid_column_error() { + let ti = create_basic_ti(); + let result = ti.mean_single("nonexistent_column"); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_central_point_error() { + let ti = create_basic_ti(); + let result = ti.absolute_deviation_single("price", "invalid_central_point"); + assert!(result.is_err()); + } + + #[test] + fn test_zero_period_bulk() { + let ti = create_basic_ti(); + let result = ti.mean_bulk("price", 0); + assert!( + result.is_err() || { + let values: Vec = result.unwrap().0.0.f64().unwrap().into_no_null_iter().collect(); + values.iter().all(|&x| x.is_nan()) + } + ); + } + + #[test] + fn test_large_period_bulk() { + let ti = create_basic_ti(); + let result = ti.mean_bulk("price", 20).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + assert!(values.iter().take(9).all(|&x| x.is_nan())); + assert_abs_diff_eq!(values[9], 5.5, epsilon = 1e-10); + } + + #[test] + fn test_single_value_dataset() { + let single_data = df! { + "price" => vec![5.0] + } + .unwrap() + .lazy(); + + let ti = BasicTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(single_data))); + + assert_abs_diff_eq!(ti.mean_single("price").unwrap(), 5.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.median_single("price").unwrap(), 5.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.max_single("price").unwrap(), 5.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.min_single("price").unwrap(), 5.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.variance_single("price").unwrap(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_duplicate_values_dataset() { + let duplicate_data = df! { + "price" => vec![3.0, 3.0, 3.0, 3.0, 3.0] + } + .unwrap() + .lazy(); + + let ti = BasicTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(duplicate_data))); + + assert_abs_diff_eq!(ti.mean_single("price").unwrap(), 3.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.mode_single("price").unwrap(), 3.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.variance_single("price").unwrap(), 0.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.standard_deviation_single("price").unwrap(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_negative_values() { + let negative_data = df! { + "price" => vec![-5.0, -3.0, -1.0, 1.0, 3.0, 5.0] + } + .unwrap() + .lazy(); + + let ti = BasicTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(negative_data))); + + assert_abs_diff_eq!(ti.mean_single("price").unwrap(), 0.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.median_single("price").unwrap(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_floating_point_precision() { + let precision_data = df! { + "price" => vec![1.000000001, 1.000000002, 1.000000003] + } + .unwrap() + .lazy(); + + let ti = BasicTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(precision_data))); + + let mean = ti.mean_single("price").unwrap(); + assert_abs_diff_eq!(mean, 1.000000002, epsilon = 1e-9); + } + + #[test] + fn test_different_columns() { + let ti = create_basic_ti(); + + let price_mean = ti.mean_single("price").unwrap(); + let volume_mean = ti.mean_single("volume").unwrap(); + + assert_abs_diff_eq!(price_mean, 5.5, epsilon = 1e-10); + assert_abs_diff_eq!(volume_mean, 219.0, epsilon = 1e-10); + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/candle/mod.rs b/plugins/ezpz-rust-ti/src/indicators/candle/mod.rs new file mode 100644 index 0000000..392b8d5 --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/candle/mod.rs @@ -0,0 +1,1238 @@ +use { + crate::utils::{create_triple_df, extract_f64_values, parse_constant_model_type, parse_deviation_model, unzip_triple}, + ezpz_stubz::{frame::PyDfStubbed, lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Candle Technical Indicators - A collection of candle analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct CandleTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl CandleTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Moving Constant Envelopes - Creates upper and lower bands from moving constant of price + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `constant_model_type`: &str - Type of moving average (e.g., "sma", "ema", "wma") + /// - `difference`: f64 - Fixed difference value to create envelope bands + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_envelope`: f64 - Lower envelope band (middle - difference) + /// - `middle_envelope`: f64 - Middle line (moving average) + /// - `upper_envelope`: f64 - Upper envelope band (middle + difference) + fn moving_constant_envelopes_single(&self, price_column: &str, constant_model_type: &str, difference: f64) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let constant_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::candle_indicators::single::moving_constant_envelopes(&values, constant_type, difference); + + let df = df! { + "lower_envelope" => [result.0], + "middle_envelope" => [result.1], + "upper_envelope" => [result.2], + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// McGinley Dynamic Envelopes - Variation of moving constant envelopes using McGinley Dynamic + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `difference`: f64 - Fixed difference value to create envelope bands + /// - `previous_mcginley_dynamic`: f64 - Previous McGinley Dynamic value for calculation + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_envelope`: f64 - Lower envelope band (McGinley Dynamic - difference) + /// - `mcginley_dynamic`: f64 - McGinley Dynamic value + /// - `upper_envelope`: f64 - Upper envelope band (McGinley Dynamic + difference) + fn mcginley_dynamic_envelopes_single(&self, price_column: &str, difference: f64, previous_mcginley_dynamic: f64) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::candle_indicators::single::mcginley_dynamic_envelopes(&values, difference, previous_mcginley_dynamic); + + let df = df! { + "lower_envelope" => [result.0], + "mcginley_dynamic" => [result.1], + "upper_envelope" => [result.2], + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// Moving Constant Bands - Extended Bollinger Bands with configurable models + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `constant_model_type`: &str - Type of moving average for center line (e.g., "sma", "ema", "wma") + /// - `deviation_model`: &str - Type of deviation calculation (e.g., "std", "mad") + /// - `deviation_multiplier`: f64 - Multiplier for the deviation to create bands + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_band`: f64 - Lower band (moving average - deviation * multiplier) + /// - `middle_band`: f64 - Middle band (moving average) + /// - `upper_band`: f64 - Upper band (moving average + deviation * multiplier) + fn moving_constant_bands_single( + &self, + price_column: &str, + constant_model_type: &str, + deviation_model: &str, + deviation_multiplier: f64, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let constant_type = parse_constant_model_type(constant_model_type)?; + let deviation_type = parse_deviation_model(deviation_model)?; + let result = rust_ti::candle_indicators::single::moving_constant_bands(&values, constant_type, deviation_type, deviation_multiplier); + + let df = df! { + "lower_band" => [result.0], + "middle_band" => [result.1], + "upper_band" => [result.2], + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// McGinley Dynamic Bands - Variation of moving constant bands using McGinley Dynamic + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `deviation_model`: &str - Type of deviation calculation (e.g., "std", "mad") + /// - `deviation_multiplier`: f64 - Multiplier for the deviation to create bands + /// - `previous_mcginley_dynamic`: f64 - Previous McGinley Dynamic value for calculation + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_band`: f64 - Lower band (McGinley Dynamic - deviation * multiplier) + /// - `mcginley_dynamic`: f64 - McGinley Dynamic value + /// - `upper_band`: f64 - Upper band (McGinley Dynamic + deviation * multiplier) + fn mcginley_dynamic_bands_single( + &self, + price_column: &str, + deviation_model: &str, + deviation_multiplier: f64, + previous_mcginley_dynamic: f64, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let deviation_type = parse_deviation_model(deviation_model)?; + let result = rust_ti::candle_indicators::single::mcginley_dynamic_bands(&values, deviation_type, deviation_multiplier, previous_mcginley_dynamic); + + let df = df! { + "lower_band" => [result.0], + "mcginley_dynamic" => [result.1], + "upper_band" => [result.2], + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// Ichimoku Cloud - Calculates support and resistance levels + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `conversion_period`: usize - Period for conversion line calculation (typically 9) + /// - `base_period`: usize - Period for base line calculation (typically 26) + /// - `span_b_period`: usize - Period for leading span B calculation (typically 52) + /// + /// # Returns + /// DataFrame with columns: + /// - `leading_span_a`: f64 - Leading Span A (future support/resistance) + /// - `leading_span_b`: f64 - Leading Span B (future support/resistance) + /// - `base_line`: f64 - Base Line (Kijun-sen) + /// - `conversion_line`: f64 - Conversion Line (Tenkan-sen) + /// - `lagged_price`: f64 - Lagging Span (Chikou Span) + fn ichimoku_cloud_single( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + conversion_period: usize, + base_period: usize, + span_b_period: usize, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let result = rust_ti::candle_indicators::single::ichimoku_cloud(&high_values, &low_values, &close_values, conversion_period, base_period, span_b_period); + + let df = df! { + "leading_span_a" => [result.0], + "leading_span_b" => [result.1], + "base_line" => [result.2], + "conversion_line" => [result.3], + "lagged_price" => [result.4], + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// Donchian Channels - Produces bands from period highs and lows + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// + /// # Returns + /// DataFrame with columns: + /// - `donchian_lower`: f64 - Lower channel (lowest low over period) + /// - `donchian_middle`: f64 - Middle channel (average of upper and lower) + /// - `donchian_upper`: f64 - Upper channel (highest high over period) + fn donchian_channels_single(&self, high_column: &str, low_column: &str) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let result = rust_ti::candle_indicators::single::donchian_channels(&high_values, &low_values); + + let df = df! { + "donchian_lower" => [result.0], + "donchian_middle" => [result.1], + "donchian_upper" => [result.2], + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// Keltner Channel - Bands based on moving average and average true range + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `constant_model_type`: &str - Type of moving average for center line (e.g., "sma", "ema", "wma") + /// - `atr_constant_model_type`: &str - Type of moving average for ATR calculation (e.g., "sma", "ema", "wma") + /// - `multiplier`: f64 - Multiplier for the ATR to create channel width + /// + /// # Returns + /// DataFrame with columns: + /// - `keltner_lower`: f64 - Lower channel (moving average - ATR * multiplier) + /// - `keltner_middle`: f64 - Middle channel (moving average) + /// - `keltner_upper`: f64 - Upper channel (moving average + ATR * multiplier) + fn keltner_channel_single( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + constant_model_type: &str, + atr_constant_model_type: &str, + multiplier: f64, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let constant_type = parse_constant_model_type(constant_model_type)?; + let atr_constant_type = parse_constant_model_type(atr_constant_model_type)?; + let result = rust_ti::candle_indicators::single::keltner_channel(&high_values, &low_values, &close_values, constant_type, atr_constant_type, multiplier); + + let df = df! { + "keltner_lower" => [result.0], + "keltner_middle" => [result.1], + "keltner_upper" => [result.2], + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// Supertrend - Trend indicator showing support and resistance levels + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `constant_model_type`: &str - Type of moving average for ATR calculation (e.g., "sma", "ema", "wma") + /// - `multiplier`: f64 - Multiplier for the ATR to determine trend sensitivity + /// + /// # Returns + /// Series containing: + /// - `supertrend`: f64 - Supertrend value (support/resistance level based on trend direction) + fn supertrend_single( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + constant_model_type: &str, + multiplier: f64, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let constant_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::candle_indicators::single::supertrend(&high_values, &low_values, &close_values, constant_type, multiplier); + + let result_series = Series::new("supertrend".into(), vec![result]); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + // Bulk functions that return multiple values over time + + /// Moving Constant Envelopes (Bulk) - Returns envelopes over time periods + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `constant_model_type`: &str - Type of moving average (e.g., "sma", "ema", "wma") + /// - `difference`: f64 - Fixed difference value to create envelope bands + /// - `period`: usize - Rolling window period for calculations + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_envelope`: Vec - Time series of lower envelope bands + /// - `middle_envelope`: Vec - Time series of middle lines (moving averages) + /// - `upper_envelope`: Vec - Time series of upper envelope bands + fn moving_constant_envelopes_bulk(&self, price_column: &str, constant_model_type: &str, difference: f64, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let constant_type = parse_constant_model_type(constant_model_type)?; + let results = rust_ti::candle_indicators::bulk::moving_constant_envelopes(&values, constant_type, difference, period); + + let (lower_vals, middle_vals, upper_vals) = unzip_triple(results); + create_triple_df(lower_vals, middle_vals, upper_vals, "lower_envelope", "middle_envelope", "upper_envelope") + } + + //// McGinley Dynamic Envelopes (Bulk) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `difference`: f64 - Fixed difference value to create envelope bands + /// - `previous_mcginley_dynamic`: f64 - Initial McGinley Dynamic value for calculation + /// - `period`: usize - Rolling window period for calculations + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_envelope`: Vec - Time series of lower envelope bands + /// - `mcginley_dynamic`: Vec - Time series of McGinley Dynamic values + /// - `upper_envelope`: Vec - Time series of upper envelope bands + fn mcginley_dynamic_envelopes_bulk(&self, price_column: &str, difference: f64, previous_mcginley_dynamic: f64, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let results = rust_ti::candle_indicators::bulk::mcginley_dynamic_envelopes(&values, difference, previous_mcginley_dynamic, period); + + let (lower_vals, middle_vals, upper_vals) = unzip_triple(results); + create_triple_df(lower_vals, middle_vals, upper_vals, "lower_envelope", "mcginley_dynamic", "upper_envelope") + } + + /// Moving Constant Bands (Bulk) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `constant_model_type`: &str - Type of moving average for center line (e.g., "sma", "ema", "wma") + /// - `deviation_model`: &str - Type of deviation calculation (e.g., "std", "mad") + /// - `deviation_multiplier`: f64 - Multiplier for the deviation to create bands + /// - `period`: usize - Rolling window period for calculations + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_band`: Vec - Time series of lower bands + /// - `middle_band`: Vec - Time series of middle bands (moving averages) + /// - `upper_band`: Vec - Time series of upper bands + fn moving_constant_bands_bulk( + &self, + price_column: &str, + constant_model_type: &str, + deviation_model: &str, + deviation_multiplier: f64, + period: usize, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let constant_type = parse_constant_model_type(constant_model_type)?; + let deviation_type = parse_deviation_model(deviation_model)?; + let results = rust_ti::candle_indicators::bulk::moving_constant_bands(&values, constant_type, deviation_type, deviation_multiplier, period); + + let (lower_vals, middle_vals, upper_vals) = unzip_triple(results); + create_triple_df(lower_vals, middle_vals, upper_vals, "lower_band", "middle_band", "upper_band") + } + + /// McGinley Dynamic Bands (Bulk) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `deviation_model`: &str - Type of deviation calculation (e.g., "std", "mad") + /// - `deviation_multiplier`: f64 - Multiplier for the deviation to create bands + /// - `previous_mcginley_dynamic`: f64 - Initial McGinley Dynamic value for calculation + /// - `period`: usize - Rolling window period for calculations + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_band`: Vec - Time series of lower bands + /// - `mcginley_dynamic`: Vec - Time series of McGinley Dynamic values + /// - `upper_band`: Vec - Time series of upper bands + fn mcginley_dynamic_bands_bulk( + &self, + price_column: &str, + deviation_model: &str, + deviation_multiplier: f64, + previous_mcginley_dynamic: f64, + period: usize, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let deviation_type = parse_deviation_model(deviation_model)?; + let results = rust_ti::candle_indicators::bulk::mcginley_dynamic_bands(&values, deviation_type, deviation_multiplier, previous_mcginley_dynamic, period); + + let (lower_vals, middle_vals, upper_vals) = unzip_triple(results); + create_triple_df(lower_vals, middle_vals, upper_vals, "lower_band", "mcginley_dynamic", "upper_band") + } + + /// Ichimoku Cloud (Bulk) - Returns ichimoku components over time + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `conversion_period`: usize - Period for conversion line calculation (typically 9) + /// - `base_period`: usize - Period for base line calculation (typically 26) + /// - `span_b_period`: usize - Period for leading span B calculation (typically 52) + /// + /// # Returns + /// DataFrame with columns: + /// - `leading_span_a`: Vec - Time series of Leading Span A values + /// - `leading_span_b`: Vec - Time series of Leading Span B values + /// - `base_line`: Vec - Time series of Base Line (Kijun-sen) values + /// - `conversion_line`: Vec - Time series of Conversion Line (Tenkan-sen) values + /// - `lagged_price`: Vec - Time series of Lagging Span (Chikou Span) values + fn ichimoku_cloud_bulk( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + conversion_period: usize, + base_period: usize, + span_b_period: usize, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let ichimoku_result = + rust_ti::candle_indicators::bulk::ichimoku_cloud(&high_values, &low_values, &close_values, conversion_period, base_period, span_b_period); + + let capacity = ichimoku_result.len(); + let mut leading_span_a = Vec::with_capacity(capacity); + let mut leading_span_b = Vec::with_capacity(capacity); + let mut base_line = Vec::with_capacity(capacity); + let mut conversion_line = Vec::with_capacity(capacity); + let mut lagged_price = Vec::with_capacity(capacity); + + for (a, b, c, d, e) in ichimoku_result { + leading_span_a.push(a); + leading_span_b.push(b); + base_line.push(c); + conversion_line.push(d); + lagged_price.push(e); + } + + let df = df! { + "leading_span_a" => leading_span_a, + "leading_span_b" => leading_span_b, + "base_line" => base_line, + "conversion_line" => conversion_line, + "lagged_price" => lagged_price, + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// Donchian Channels (Bulk) - Returns donchian bands over time + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `period`: usize - Rolling window period for channel calculation + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_band`: Vec - Time series of lower channels (lowest lows) + /// - `middle_band`: Vec - Time series of middle channels (averages) + /// - `upper_band`: Vec - Time series of upper channels (highest highs) + fn donchian_channels_bulk(&self, high_column: &str, low_column: &str, period: usize) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let highs_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let lows_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let donchian_result = rust_ti::candle_indicators::bulk::donchian_channels(&highs_values, &lows_values, period); + + let (lower_band, middle_band, upper_band) = unzip_triple(donchian_result); + create_triple_df(lower_band, middle_band, upper_band, "lower_band", "middle_band", "upper_band") + } + + /// Keltner Channel (Bulk) - Returns keltner bands over time + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `constant_model_type`: &str - Type of moving average for center line (e.g., "sma", "ema", "wma") + /// - `atr_constant_model_type`: &str - Type of moving average for ATR calculation (e.g., "sma", "ema", "wma") + /// - `multiplier`: f64 - Multiplier for the ATR to create channel width + /// - `period`: usize - Rolling window period for calculations + /// + /// # Returns + /// DataFrame with columns: + /// - `lower_band`: Vec - Time series of lower channels + /// - `middle_band`: Vec - Time series of middle channels (moving averages) + /// - `upper_band`: Vec - Time series of upper channels + #[allow(clippy::too_many_arguments)] + fn keltner_channel_bulk( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + constant_model_type: &str, + atr_constant_model_type: &str, + multiplier: f64, + period: usize, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let constant_type = parse_constant_model_type(constant_model_type)?; + let atr_constant_type = parse_constant_model_type(atr_constant_model_type)?; + let keltner_result = + rust_ti::candle_indicators::bulk::keltner_channel(&high_values, &low_values, &close_values, constant_type, atr_constant_type, multiplier, period); + + let (lower_band, middle_band, upper_band) = unzip_triple(keltner_result); + create_triple_df(lower_band, middle_band, upper_band, "lower_band", "middle_band", "upper_band") + } + + /// Supertrend (Bulk) - Returns supertrend values over time + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `constant_model_type`: &str - Type of moving average for ATR calculation (e.g., "sma", "ema", "wma") + /// - `multiplier`: f64 - Multiplier for the ATR to determine trend sensitivity + /// - `period`: usize - Rolling window period for ATR calculation + /// + /// # Returns + /// Series containing: + /// - `supertrend`: Vec - Time series of supertrend values (support/resistance levels) + fn supertrend_bulk( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + constant_model_type: &str, + multiplier: f64, + period: usize, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let constant_type = parse_constant_model_type(constant_model_type)?; + let supertrend_result = rust_ti::candle_indicators::bulk::supertrend(&high_values, &low_values, &close_values, constant_type, multiplier, period); + + let result_series = Series::new("supertrend".into(), supertrend_result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_ohlc_dataframe() -> LazyFrame { + let high = vec![105.0, 110.0, 108.0, 112.0, 115.0, 118.0, 120.0, 122.0, 125.0, 128.0]; + let low = vec![95.0, 98.0, 96.0, 100.0, 103.0, 106.0, 108.0, 110.0, 113.0, 116.0]; + let close = vec![100.0, 105.0, 102.0, 108.0, 112.0, 115.0, 118.0, 120.0, 122.0, 125.0]; + df! { + "high" => high, + "low" => low, + "close" => close.clone(), + "price" => close + } + .unwrap() + .lazy() + } + + fn create_candle_ti() -> CandleTI { + let lf = create_test_ohlc_dataframe(); + CandleTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_moving_constant_envelopes_single() { + let ti = create_candle_ti(); + let result = ti.moving_constant_envelopes_single("price", "sma", 5.0).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_envelope", "middle_envelope", "upper_envelope"]); + assert_eq!(df.height(), 1); + + let lower = df.column("lower_envelope").unwrap().get(0).unwrap(); + let middle = df.column("middle_envelope").unwrap().get(0).unwrap(); + let upper = df.column("upper_envelope").unwrap().get(0).unwrap(); + + if let (Ok(lower_val), Ok(middle_val), Ok(upper_val)) = (lower.try_extract::(), middle.try_extract::(), upper.try_extract::()) { + assert_abs_diff_eq!(upper_val - middle_val, 5.0, epsilon = 1e-10); + assert_abs_diff_eq!(middle_val - lower_val, 5.0, epsilon = 1e-10); + } + } + + #[test] + fn test_mcginley_dynamic_envelopes_single() { + let ti = create_candle_ti(); + let result = ti.mcginley_dynamic_envelopes_single("price", 3.0, 100.0).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_envelope", "mcginley_dynamic", "upper_envelope"]); + assert_eq!(df.height(), 1); + + let lower = df.column("lower_envelope").unwrap().get(0).unwrap(); + let mcginley = df.column("mcginley_dynamic").unwrap().get(0).unwrap(); + let upper = df.column("upper_envelope").unwrap().get(0).unwrap(); + + if let (Ok(lower_val), Ok(mcginley_val), Ok(upper_val)) = (lower.try_extract::(), mcginley.try_extract::(), upper.try_extract::()) { + assert_abs_diff_eq!(upper_val - mcginley_val, 3.0, epsilon = 1e-10); + assert_abs_diff_eq!(mcginley_val - lower_val, 3.0, epsilon = 1e-10); + } + } + + #[test] + fn test_moving_constant_bands_single() { + let ti = create_candle_ti(); + let result = ti.moving_constant_bands_single("price", "sma", "std", 2.0).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_band", "middle_band", "upper_band"]); + assert_eq!(df.height(), 1); + + let lower = df.column("lower_band").unwrap().get(0).unwrap(); + let middle = df.column("middle_band").unwrap().get(0).unwrap(); + let upper = df.column("upper_band").unwrap().get(0).unwrap(); + + assert!(lower.try_extract::().is_ok()); + assert!(middle.try_extract::().is_ok()); + assert!(upper.try_extract::().is_ok()); + } + + #[test] + fn test_mcginley_dynamic_bands_single() { + let ti = create_candle_ti(); + let result = ti.mcginley_dynamic_bands_single("price", "std", 1.5, 110.0).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_band", "mcginley_dynamic", "upper_band"]); + assert_eq!(df.height(), 1); + + let lower = df.column("lower_band").unwrap().get(0).unwrap(); + let mcginley = df.column("mcginley_dynamic").unwrap().get(0).unwrap(); + let upper = df.column("upper_band").unwrap().get(0).unwrap(); + + assert!(lower.try_extract::().is_ok()); + assert!(mcginley.try_extract::().is_ok()); + assert!(upper.try_extract::().is_ok()); + } + + #[test] + fn test_ichimoku_cloud_single() { + let ti = create_candle_ti(); + let result = ti.ichimoku_cloud_single("high", "low", "close", 9, 26, 52).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["leading_span_a", "leading_span_b", "base_line", "conversion_line", "lagged_price"]); + assert_eq!(df.height(), 1); + + let leading_span_a = df.column("leading_span_a").unwrap().get(0).unwrap(); + let leading_span_b = df.column("leading_span_b").unwrap().get(0).unwrap(); + let base_line = df.column("base_line").unwrap().get(0).unwrap(); + let conversion_line = df.column("conversion_line").unwrap().get(0).unwrap(); + let lagged_price = df.column("lagged_price").unwrap().get(0).unwrap(); + + assert!(leading_span_a.try_extract::().is_ok()); + assert!(leading_span_b.try_extract::().is_ok()); + assert!(base_line.try_extract::().is_ok()); + assert!(conversion_line.try_extract::().is_ok()); + assert!(lagged_price.try_extract::().is_ok()); + } + + #[test] + fn test_donchian_channels_single() { + let ti = create_candle_ti(); + let result = ti.donchian_channels_single("high", "low").unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["donchian_lower", "donchian_middle", "donchian_upper"]); + assert_eq!(df.height(), 1); + + let lower = df.column("donchian_lower").unwrap().get(0).unwrap(); + let middle = df.column("donchian_middle").unwrap().get(0).unwrap(); + let upper = df.column("donchian_upper").unwrap().get(0).unwrap(); + + if let (Ok(lower_val), Ok(middle_val), Ok(upper_val)) = (lower.try_extract::(), middle.try_extract::(), upper.try_extract::()) { + assert!(lower_val <= middle_val); + assert!(middle_val <= upper_val); + } + } + + #[test] + fn test_keltner_channel_single() { + let ti = create_candle_ti(); + let result = ti.keltner_channel_single("high", "low", "close", "sma", "sma", 2.0).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["keltner_lower", "keltner_middle", "keltner_upper"]); + assert_eq!(df.height(), 1); + + let lower = df.column("keltner_lower").unwrap().get(0).unwrap(); + let middle = df.column("keltner_middle").unwrap().get(0).unwrap(); + let upper = df.column("keltner_upper").unwrap().get(0).unwrap(); + + if let (Ok(lower_val), Ok(middle_val), Ok(upper_val)) = (lower.try_extract::(), middle.try_extract::(), upper.try_extract::()) { + assert!(lower_val <= middle_val); + assert!(middle_val <= upper_val); + } + } + + #[test] + fn test_supertrend_single() { + let ti = create_candle_ti(); + let result = ti.supertrend_single("high", "low", "close", "sma", 3.0).unwrap(); + let series = result.0.0; + + assert_eq!(series.name(), "supertrend"); + assert_eq!(series.len(), 1); + assert!(series.get(0).unwrap().try_extract::().is_ok()); + } + + #[test] + fn test_moving_constant_envelopes_bulk() { + let ti = create_candle_ti(); + let result = ti.moving_constant_envelopes_bulk("price", "sma", 5.0, 3).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_envelope", "middle_envelope", "upper_envelope"]); + assert!(df.height() > 0); + + let lower_col = df.column("lower_envelope").unwrap(); + let middle_col = df.column("middle_envelope").unwrap(); + let upper_col = df.column("upper_envelope").unwrap(); + + for i in 0..df.height() { + if let (Ok(lower_val), Ok(middle_val), Ok(upper_val)) = + (lower_col.get(i).unwrap().try_extract::(), middle_col.get(i).unwrap().try_extract::(), upper_col.get(i).unwrap().try_extract::()) + { + assert_abs_diff_eq!(upper_val - middle_val, 5.0, epsilon = 1e-10); + assert_abs_diff_eq!(middle_val - lower_val, 5.0, epsilon = 1e-10); + } + } + } + + #[test] + fn test_mcginley_dynamic_envelopes_bulk() { + let ti = create_candle_ti(); + let result = ti.mcginley_dynamic_envelopes_bulk("price", 3.0, 100.0, 5).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_envelope", "mcginley_dynamic", "upper_envelope"]); + assert!(df.height() > 0); + + let lower_col = df.column("lower_envelope").unwrap(); + let mcginley_col = df.column("mcginley_dynamic").unwrap(); + let upper_col = df.column("upper_envelope").unwrap(); + + for i in 0..df.height() { + if let (Ok(lower_val), Ok(mcginley_val), Ok(upper_val)) = + (lower_col.get(i).unwrap().try_extract::(), mcginley_col.get(i).unwrap().try_extract::(), upper_col.get(i).unwrap().try_extract::()) + { + assert_abs_diff_eq!(upper_val - mcginley_val, 3.0, epsilon = 1e-10); + assert_abs_diff_eq!(mcginley_val - lower_val, 3.0, epsilon = 1e-10); + } + } + } + + #[test] + fn test_moving_constant_bands_bulk() { + let ti = create_candle_ti(); + let result = ti.moving_constant_bands_bulk("price", "sma", "std", 2.0, 5).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_band", "middle_band", "upper_band"]); + assert!(df.height() > 0); + + let lower_col = df.column("lower_band").unwrap(); + let middle_col = df.column("middle_band").unwrap(); + let upper_col = df.column("upper_band").unwrap(); + + for i in 0..df.height() { + assert!(lower_col.get(i).unwrap().try_extract::().is_ok()); + assert!(middle_col.get(i).unwrap().try_extract::().is_ok()); + assert!(upper_col.get(i).unwrap().try_extract::().is_ok()); + } + } + + #[test] + fn test_mcginley_dynamic_bands_bulk() { + let ti = create_candle_ti(); + let result = ti.mcginley_dynamic_bands_bulk("price", "std", 1.5, 110.0, 5).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_band", "mcginley_dynamic", "upper_band"]); + assert!(df.height() > 0); + + let lower_col = df.column("lower_band").unwrap(); + let mcginley_col = df.column("mcginley_dynamic").unwrap(); + let upper_col = df.column("upper_band").unwrap(); + + for i in 0..df.height() { + assert!(lower_col.get(i).unwrap().try_extract::().is_ok()); + assert!(mcginley_col.get(i).unwrap().try_extract::().is_ok()); + assert!(upper_col.get(i).unwrap().try_extract::().is_ok()); + } + } + + #[test] + fn test_ichimoku_cloud_bulk() { + let ti = create_candle_ti(); + let result = ti.ichimoku_cloud_bulk("high", "low", "close", 9, 26, 52).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["leading_span_a", "leading_span_b", "base_line", "conversion_line", "lagged_price"]); + assert!(df.height() > 0); + + for col_name in &["leading_span_a", "leading_span_b", "base_line", "conversion_line", "lagged_price"] { + let col = df.column(col_name).unwrap(); + for i in 0..df.height() { + assert!(col.get(i).unwrap().try_extract::().is_ok()); + } + } + } + + #[test] + fn test_donchian_channels_bulk() { + let ti = create_candle_ti(); + let result = ti.donchian_channels_bulk("high", "low", 5).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_band", "middle_band", "upper_band"]); + assert!(df.height() > 0); + + let lower_col = df.column("lower_band").unwrap(); + let middle_col = df.column("middle_band").unwrap(); + let upper_col = df.column("upper_band").unwrap(); + + for i in 0..df.height() { + if let (Ok(lower_val), Ok(middle_val), Ok(upper_val)) = + (lower_col.get(i).unwrap().try_extract::(), middle_col.get(i).unwrap().try_extract::(), upper_col.get(i).unwrap().try_extract::()) + { + assert!(lower_val <= middle_val); + assert!(middle_val <= upper_val); + } + } + } + + #[test] + fn test_keltner_channel_bulk() { + let ti = create_candle_ti(); + let result = ti.keltner_channel_bulk("high", "low", "close", "sma", "sma", 2.0, 5).unwrap(); + let df = result.0.0; + + assert_eq!(df.get_column_names(), vec!["lower_band", "middle_band", "upper_band"]); + assert!(df.height() > 0); + + let lower_col = df.column("lower_band").unwrap(); + let middle_col = df.column("middle_band").unwrap(); + let upper_col = df.column("upper_band").unwrap(); + + for i in 0..df.height() { + if let (Ok(lower_val), Ok(middle_val), Ok(upper_val)) = + (lower_col.get(i).unwrap().try_extract::(), middle_col.get(i).unwrap().try_extract::(), upper_col.get(i).unwrap().try_extract::()) + { + assert!(lower_val <= middle_val); + assert!(middle_val <= upper_val); + } + } + } + + #[test] + fn test_supertrend_bulk() { + let ti = create_candle_ti(); + let result = ti.supertrend_bulk("high", "low", "close", "sma", 3.0, 5).unwrap(); + let series = result.0.0; + + assert_eq!(series.name(), "supertrend"); + assert!(!series.is_empty()); + + for i in 0..series.len() { + assert!(series.get(i).unwrap().try_extract::().is_ok()); + } + } + + #[test] + fn test_invalid_column_name() { + let ti = create_candle_ti(); + let result = ti.moving_constant_envelopes_single("invalid_column", "sma", 5.0); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_model_type() { + let ti = create_candle_ti(); + let result = ti.moving_constant_envelopes_single("price", "invalid_model", 5.0); + assert!(result.is_err()); + } + + #[test] + fn test_zero_period_bulk() { + let ti = create_candle_ti(); + let result = ti.moving_constant_envelopes_bulk("price", "sma", 5.0, 0); + assert!(result.is_err()); + } + + #[test] + fn test_envelope_difference_validation() { + let ti = create_candle_ti(); + let result = ti.moving_constant_envelopes_single("price", "sma", 0.0).unwrap(); + let df = result.0.0; + + let lower = df.column("lower_envelope").unwrap().get(0).unwrap(); + let middle = df.column("middle_envelope").unwrap().get(0).unwrap(); + let upper = df.column("upper_envelope").unwrap().get(0).unwrap(); + + if let (Ok(lower_val), Ok(middle_val), Ok(upper_val)) = (lower.try_extract::(), middle.try_extract::(), upper.try_extract::()) { + assert_abs_diff_eq!(lower_val, middle_val, epsilon = 1e-10); + assert_abs_diff_eq!(middle_val, upper_val, epsilon = 1e-10); + } + } + + #[test] + fn test_different_ma_types() { + let ti = create_candle_ti(); + + let sma_result = ti.moving_constant_envelopes_single("price", "sma", 5.0).unwrap(); + let ema_result = ti.moving_constant_envelopes_single("price", "ema", 5.0).unwrap(); + let wma_result = ti.moving_constant_envelopes_single("price", "wma", 5.0).unwrap(); + + let sma_df = sma_result.0.0; + let ema_df = ema_result.0.0; + let wma_df = wma_result.0.0; + + assert_eq!(sma_df.get_column_names(), vec!["lower_envelope", "middle_envelope", "upper_envelope"]); + assert_eq!(ema_df.get_column_names(), vec!["lower_envelope", "middle_envelope", "upper_envelope"]); + assert_eq!(wma_df.get_column_names(), vec!["lower_envelope", "middle_envelope", "upper_envelope"]); + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/chart/mod.rs b/plugins/ezpz-rust-ti/src/indicators/chart/mod.rs new file mode 100644 index 0000000..fcb9152 --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/chart/mod.rs @@ -0,0 +1,432 @@ +use { + crate::utils::extract_f64_values, + ezpz_stubz::{lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Chart Trends Technical Indicators - A collection of chart trend analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct ChartTrendsTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl ChartTrendsTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Find peaks in a price series over a given period + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Period length for peak detection + /// - `closest_neighbor`: usize - Minimum distance between peaks + /// + /// # Returns + /// Vec<(f64, usize)> - List of tuples containing: + /// - `peak_value`: The price value at the peak + /// - `peak_index`: The index position of the peak in the series + fn peaks(&self, price_column: &str, period: usize, closest_neighbor: usize) -> PyResult> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::chart_trends::peaks(&values, period, closest_neighbor); + Ok(result) + } + + /// Find valleys in a price series over a given period + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Period length for valley detection + /// - `closest_neighbor`: usize - Minimum distance between valleys + /// + /// # Returns + /// Vec<(f64, usize)> - List of tuples containing: + /// - `valley_value`: The price value at the valley + /// - `valley_index`: The index position of the valley in the series + fn valleys(&self, price_column: &str, period: usize, closest_neighbor: usize) -> PyResult> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::chart_trends::valleys(&values, period, closest_neighbor); + Ok(result) + } + + /// Calculate peak trend (linear regression on peaks) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Period length for peak detection + /// + /// # Returns + /// Tuple of (slope: f64, intercept: f64) + /// - `slope`: The slope of the linear regression line through peaks + /// - `intercept`: The y-intercept of the linear regression line + fn peak_trend(&self, price_column: &str, period: usize) -> PyResult<(f64, f64)> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::chart_trends::peak_trend(&values, period); + Ok(result) + } + + /// Calculate valley trend (linear regression on valleys) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Period length for valley detection + /// + /// # Returns + /// Tuple of (slope: f64, intercept: f64) + /// - `slope`: The slope of the linear regression line through valleys + /// - `intercept`: The y-intercept of the linear regression line + fn valley_trend(&self, price_column: &str, period: usize) -> PyResult<(f64, f64)> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::chart_trends::valley_trend(&values, period); + Ok(result) + } + + /// Calculate overall trend (linear regression on all prices) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// Tuple of (slope: f64, intercept: f64) + /// - `slope`: The slope of the linear regression line through all price points + /// - `intercept`: The y-intercept of the linear regression line + fn overall_trend(&self, price_column: &str) -> PyResult<(f64, f64)> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::chart_trends::overall_trend(&values); + Ok(result) + } + + /// Break down trends in a price series + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `max_outliers`: usize - Maximum number of outliers allowed + /// - `soft_r_squared_minimum`: f64 - Soft minimum threshold for R-squared value + /// - `soft_r_squared_maximum`: f64 - Soft maximum threshold for R-squared value + /// - `hard_r_squared_minimum`: f64 - Hard minimum threshold for R-squared value + /// - `hard_r_squared_maximum`: f64 - Hard maximum threshold for R-squared value + /// - `soft_standard_error_multiplier`: f64 - Soft multiplier for standard error threshold + /// - `hard_standard_error_multiplier`: f64 - Hard multiplier for standard error threshold + /// - `soft_reduced_chi_squared_multiplier`: f64 - Soft multiplier for reduced chi-squared threshold + /// - `hard_reduced_chi_squared_multiplier`: f64 - Hard multiplier for reduced chi-squared threshold + /// + /// # Returns + /// Vec<(usize, usize, f64, f64)> - List of tuples containing: + /// - `start_index`: Starting index of the trend segment + /// - `end_index`: Ending index of the trend segment + /// - `slope`: The slope of the linear regression for this trend segment + /// - `intercept`: The y-intercept of the linear regression for this trend segment + #[allow(clippy::too_many_arguments)] + fn break_down_trends( + &self, + price_column: &str, + max_outliers: usize, + soft_r_squared_minimum: f64, + soft_r_squared_maximum: f64, + hard_r_squared_minimum: f64, + hard_r_squared_maximum: f64, + soft_standard_error_multiplier: f64, + hard_standard_error_multiplier: f64, + soft_reduced_chi_squared_multiplier: f64, + hard_reduced_chi_squared_multiplier: f64, + ) -> PyResult> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::chart_trends::break_down_trends( + &values, + max_outliers, + soft_r_squared_minimum, + soft_r_squared_maximum, + hard_r_squared_minimum, + hard_r_squared_maximum, + soft_standard_error_multiplier, + hard_standard_error_multiplier, + soft_reduced_chi_squared_multiplier, + hard_reduced_chi_squared_multiplier, + ); + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + let data = vec![1.0, 3.0, 2.0, 5.0, 4.0, 7.0, 6.0, 9.0, 8.0, 10.0]; + df! { + "price" => data, + "volume" => vec![100.0, 200.0, 150.0, 300.0, 250.0, 180.0, 220.0, 190.0, 280.0, 320.0] + } + .unwrap() + .lazy() + } + + fn create_peak_valley_dataframe() -> LazyFrame { + let data = vec![1.0, 5.0, 2.0, 8.0, 3.0, 9.0, 4.0, 6.0, 7.0, 10.0]; + df! { + "price" => data + } + .unwrap() + .lazy() + } + + fn create_trending_dataframe() -> LazyFrame { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + df! { + "price" => data + } + .unwrap() + .lazy() + } + + fn create_chart_trends_ti() -> ChartTrendsTI { + let lf = create_test_dataframe(); + ChartTrendsTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + fn create_peak_valley_ti() -> ChartTrendsTI { + let lf = create_peak_valley_dataframe(); + ChartTrendsTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + fn create_trending_ti() -> ChartTrendsTI { + let lf = create_trending_dataframe(); + ChartTrendsTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_peaks_basic() { + let ti = create_peak_valley_ti(); + let result = ti.peaks("price", 2, 1).unwrap(); + assert!(!result.is_empty()); + for (value, index) in result { + assert!(value > 0.0); + assert!(index < 10); + } + } + + #[test] + fn test_peaks_with_period() { + let ti = create_chart_trends_ti(); + let result = ti.peaks("price", 3, 2).unwrap(); + for (value, index) in result { + assert!((1.0..=10.0).contains(&value)); + assert!(index < 10); + } + } + + #[test] + fn test_valleys_basic() { + let ti = create_peak_valley_ti(); + let result = ti.valleys("price", 2, 1).unwrap(); + assert!(!result.is_empty()); + for (value, index) in result { + assert!(value > 0.0); + assert!(index < 10); + } + } + + #[test] + fn test_valleys_with_period() { + let ti = create_chart_trends_ti(); + let result = ti.valleys("price", 3, 2).unwrap(); + for (value, index) in result { + assert!((1.0..=10.0).contains(&value)); + assert!(index < 10); + } + } + + #[test] + fn test_peak_trend() { + let ti = create_trending_ti(); + let result = ti.peak_trend("price", 2).unwrap(); + let (slope, intercept) = result; + assert!(slope.is_finite()); + assert!(intercept.is_finite()); + } + + #[test] + fn test_valley_trend() { + let ti = create_trending_ti(); + let result = ti.valley_trend("price", 2).unwrap(); + let (slope, intercept) = result; + assert!(slope.is_finite()); + assert!(intercept.is_finite()); + } + + #[test] + fn test_overall_trend_upward() { + let ti = create_trending_ti(); + let result = ti.overall_trend("price").unwrap(); + let (slope, intercept) = result; + assert!(slope > 0.0); + assert!(intercept.is_finite()); + assert_abs_diff_eq!(slope, 1.0, epsilon = 1e-10); + } + + #[test] + fn test_overall_trend_calculation() { + let ti = create_chart_trends_ti(); + let result = ti.overall_trend("price").unwrap(); + let (slope, intercept) = result; + assert!(slope.is_finite()); + assert!(intercept.is_finite()); + } + + #[test] + fn test_break_down_trends_basic() { + let ti = create_trending_ti(); + let result = ti.break_down_trends("price", 2, 0.5, 0.95, 0.3, 0.98, 2.0, 3.0, 2.0, 3.0).unwrap(); + + assert!(!result.is_empty()); + for (start, end, slope, intercept) in result { + assert!(start < end); + assert!(end <= 10); + assert!(slope.is_finite()); + assert!(intercept.is_finite()); + } + } + + #[test] + fn test_break_down_trends_with_outliers() { + let ti = create_chart_trends_ti(); + let result = ti.break_down_trends("price", 3, 0.4, 0.9, 0.2, 0.95, 1.5, 2.5, 1.5, 2.5).unwrap(); + + for (start, end, slope, intercept) in result { + assert!(start < end); + assert!(slope.is_finite()); + assert!(intercept.is_finite()); + } + } + + #[test] + fn test_invalid_column_name() { + let ti = create_chart_trends_ti(); + let result = ti.peaks("nonexistent", 2, 1); + assert!(result.is_err()); + } + + #[test] + fn test_peak_trend_with_different_periods() { + let ti = create_chart_trends_ti(); + let result1 = ti.peak_trend("price", 1).unwrap(); + let result2 = ti.peak_trend("price", 3).unwrap(); + + assert!(result1.0.is_finite() && result1.1.is_finite()); + assert!(result2.0.is_finite() && result2.1.is_finite()); + } + + #[test] + fn test_valley_trend_with_different_periods() { + let ti = create_chart_trends_ti(); + let result1 = ti.valley_trend("price", 1).unwrap(); + let result2 = ti.valley_trend("price", 3).unwrap(); + + assert!(result1.0.is_finite() && result1.1.is_finite()); + assert!(result2.0.is_finite() && result2.1.is_finite()); + } + + #[test] + fn test_peaks_empty_result_handling() { + let ti = create_chart_trends_ti(); + let result = ti.peaks("price", 10, 5).unwrap(); + // Should not panic even if no peaks found + for (value, index) in result { + assert!(value > 0.0); + assert!(index < 10); + } + } + + #[test] + fn test_valleys_empty_result_handling() { + let ti = create_chart_trends_ti(); + let result = ti.valleys("price", 10, 5).unwrap(); + // Should not panic even if no valleys found + for (value, index) in result { + assert!(value > 0.0); + assert!(index < 10); + } + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/correlation/mod.rs b/plugins/ezpz-rust-ti/src/indicators/correlation/mod.rs new file mode 100644 index 0000000..c0dcfd7 --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/correlation/mod.rs @@ -0,0 +1,382 @@ +use { + crate::utils::{extract_f64_values, parse_constant_model_type, parse_deviation_model}, + ezpz_stubz::{lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Correlation Technical Indicators - A collection of correlation analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct CorrelationTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl CorrelationTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Correlation between two assets - Single value calculation + /// Calculates correlation between prices of two assets using specified models + /// Returns a single correlation value for the entire price series + /// + /// # Parameters + /// - `price_column_a`: &str - Name of the first asset's price column + /// - `price_column_b`: &str - Name of the second asset's price column + /// - `constant_model_type`: &str - Type of constant model to use for correlation calculation + /// - `deviation_model`: &str - Type of deviation model to use for correlation calculation + /// + /// # Returns + /// f64 - Single correlation coefficient between the two asset price series + fn correlate_asset_prices_single(&self, price_column_a: &str, price_column_b: &str, constant_model_type: &str, deviation_model: &str) -> PyResult { + let df = self + .lf + .clone() + .select([col(price_column_a), col(price_column_b)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let series_a = df + .column(price_column_a) + .map_err(|e| PyErr::new::(format!("Column '{price_column_a}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column_a}' could not be converted to Series")))? + .clone(); + + let series_b = df + .column(price_column_b) + .map_err(|e| PyErr::new::(format!("Column '{price_column_b}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column_b}' could not be converted to Series")))? + .clone(); + + let values_a: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series_a)))?; + let values_b: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series_b)))?; + + let constant_type = parse_constant_model_type(constant_model_type)?; + let deviation_type = parse_deviation_model(deviation_model)?; + let result = rust_ti::correlation_indicators::single::correlate_asset_prices(&values_a, &values_b, constant_type, deviation_type); + Ok(result) + } + + /// Correlation between two assets - Rolling/Bulk calculation + /// Calculates rolling correlation between prices of two assets using specified models + /// Returns a series of correlation values for each period window + /// + /// # Parameters + /// - `price_column_a`: &str - Name of the first asset's price column + /// - `price_column_b`: &str - Name of the second asset's price column + /// - `constant_model_type`: &str - Type of constant model to use for correlation calculation + /// - `deviation_model`: &str - Type of deviation model to use for correlation calculation + /// - `period`: usize - Rolling window size for correlation calculation + /// + /// # Returns + /// PySeriesStubbed - Series containing rolling correlation coefficients for each period window with name "correlation" + fn correlate_asset_prices_bulk( + &self, + price_column_a: &str, + price_column_b: &str, + constant_model_type: &str, + deviation_model: &str, + period: usize, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(price_column_a), col(price_column_b)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let series_a = df + .column(price_column_a) + .map_err(|e| PyErr::new::(format!("Column '{price_column_a}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column_a}' could not be converted to Series")))? + .clone(); + + let series_b = df + .column(price_column_b) + .map_err(|e| PyErr::new::(format!("Column '{price_column_b}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column_b}' could not be converted to Series")))? + .clone(); + + let values_a: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series_a)))?; + let values_b: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series_b)))?; + + let constant_type = parse_constant_model_type(constant_model_type)?; + let deviation_type = parse_deviation_model(deviation_model)?; + let result = rust_ti::correlation_indicators::bulk::correlate_asset_prices(&values_a, &values_b, constant_type, deviation_type, period); + let correlation_series = Series::new("correlation".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(correlation_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + let price_a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let price_b = vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0, 18.0, 20.0]; + df! { + "price_a" => price_a, + "price_b" => price_b, + "volume" => vec![100.0, 200.0, 150.0, 300.0, 250.0, 180.0, 220.0, 190.0, 280.0, 320.0] + } + .unwrap() + .lazy() + } + + fn create_uncorrelated_dataframe() -> LazyFrame { + let price_a = vec![1.0, 3.0, 2.0, 5.0, 4.0, 7.0, 6.0, 9.0, 8.0, 10.0]; + let price_b = vec![10.0, 8.0, 9.0, 6.0, 7.0, 4.0, 5.0, 2.0, 3.0, 1.0]; + df! { + "price_a" => price_a, + "price_b" => price_b + } + .unwrap() + .lazy() + } + + fn create_negative_correlation_dataframe() -> LazyFrame { + let price_a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + let price_b = vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]; + df! { + "price_a" => price_a, + "price_b" => price_b + } + .unwrap() + .lazy() + } + + fn create_correlation_ti() -> CorrelationTI { + let lf = create_test_dataframe(); + CorrelationTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + fn create_uncorrelated_ti() -> CorrelationTI { + let lf = create_uncorrelated_dataframe(); + CorrelationTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + fn create_negative_correlation_ti() -> CorrelationTI { + let lf = create_negative_correlation_dataframe(); + CorrelationTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_correlate_asset_prices_single_positive() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_b", "arithmetic", "population").unwrap(); + assert!(result > 0.9); + assert!(result <= 1.0); + } + + #[test] + fn test_correlate_asset_prices_single_negative() { + let ti = create_negative_correlation_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_b", "arithmetic", "population").unwrap(); + assert!(result < -0.9); + assert!(result >= -1.0); + } + + #[test] + fn test_correlate_asset_prices_single_uncorrelated() { + let ti = create_uncorrelated_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_b", "arithmetic", "population").unwrap(); + assert!(result.abs() < 0.5); + } + + #[test] + fn test_correlate_asset_prices_single_arithmetic_sample() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_b", "arithmetic", "sample").unwrap(); + assert!(result > 0.9); + assert!(result <= 1.0); + } + + #[test] + fn test_correlate_asset_prices_single_geometric_population() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_b", "geometric", "population").unwrap(); + assert!(result.is_finite()); + assert!((-1.0..=1.0).contains(&result)); + } + + #[test] + fn test_correlate_asset_prices_single_harmonic_sample() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_b", "harmonic", "sample").unwrap(); + assert!(result.is_finite()); + assert!((-1.0..=1.0).contains(&result)); + } + + #[test] + fn test_correlate_asset_prices_bulk_basic() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "price_b", "arithmetic", "population", 3).unwrap(); + let series = result.0.0; + assert!(!series.is_empty()); + assert_eq!(series.name(), "correlation"); + } + + #[test] + fn test_correlate_asset_prices_bulk_window_size() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "price_b", "arithmetic", "population", 5).unwrap(); + let series = result.0.0; + assert!(!series.is_empty()); + + if let Ok(values) = series.f64() { + for value in values.into_iter().flatten() { + assert!((-1.0..=1.0).contains(&value)); + } + } + } + + #[test] + fn test_correlate_asset_prices_bulk_different_models() { + let ti = create_correlation_ti(); + let result1 = ti.correlate_asset_prices_bulk("price_a", "price_b", "arithmetic", "population", 4).unwrap(); + let result2 = ti.correlate_asset_prices_bulk("price_a", "price_b", "geometric", "sample", 4).unwrap(); + + let series1 = result1.0.0; + let series2 = result2.0.0; + + assert!(!series1.is_empty()); + assert!(!series2.is_empty()); + assert_eq!(series1.name(), "correlation"); + assert_eq!(series2.name(), "correlation"); + } + + #[test] + fn test_correlate_asset_prices_bulk_large_window() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "price_b", "arithmetic", "population", 8).unwrap(); + let series = result.0.0; + assert!(!series.is_empty()); + } + + #[test] + fn test_correlate_asset_prices_bulk_small_window() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "price_b", "arithmetic", "population", 2).unwrap(); + let series = result.0.0; + assert!(!series.is_empty()); + } + + #[test] + fn test_correlate_asset_prices_single_invalid_column() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_single("nonexistent", "price_b", "arithmetic", "population"); + assert!(result.is_err()); + } + + #[test] + fn test_correlate_asset_prices_bulk_invalid_column() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "nonexistent", "arithmetic", "population", 3); + assert!(result.is_err()); + } + + #[test] + fn test_correlate_asset_prices_single_invalid_constant_model() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_b", "invalid_model", "population"); + assert!(result.is_err()); + } + + #[test] + fn test_correlate_asset_prices_single_invalid_deviation_model() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_b", "arithmetic", "invalid_model"); + assert!(result.is_err()); + } + + #[test] + fn test_correlate_asset_prices_bulk_invalid_constant_model() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "price_b", "invalid_model", "population", 3); + assert!(result.is_err()); + } + + #[test] + fn test_correlate_asset_prices_bulk_invalid_deviation_model() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "price_b", "arithmetic", "invalid_model", 3); + assert!(result.is_err()); + } + + #[test] + fn test_correlate_asset_prices_single_same_column() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_single("price_a", "price_a", "arithmetic", "population").unwrap(); + assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10); + } + + #[test] + fn test_correlate_asset_prices_bulk_same_column() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "price_a", "arithmetic", "population", 3).unwrap(); + let series = result.0.0; + + if let Ok(values) = series.f64() { + for value in values.into_iter().flatten() { + assert_abs_diff_eq!(value, 1.0, epsilon = 1e-10); + } + } + } + + #[test] + fn test_correlate_asset_prices_single_all_model_combinations() { + let ti = create_correlation_ti(); + let constant_models = vec!["arithmetic", "geometric", "harmonic"]; + let deviation_models = vec!["population", "sample"]; + + for constant_model in &constant_models { + for deviation_model in &deviation_models { + let result = ti.correlate_asset_prices_single("price_a", "price_b", constant_model, deviation_model).unwrap(); + assert!(result.is_finite()); + assert!((-1.0..=1.0).contains(&result)); + } + } + } + + #[test] + fn test_correlate_asset_prices_bulk_all_model_combinations() { + let ti = create_correlation_ti(); + let constant_models = vec!["arithmetic", "geometric", "harmonic"]; + let deviation_models = vec!["population", "sample"]; + + for constant_model in &constant_models { + for deviation_model in &deviation_models { + let result = ti.correlate_asset_prices_bulk("price_a", "price_b", constant_model, deviation_model, 3).unwrap(); + let series = result.0.0; + assert!(!series.is_empty()); + assert_eq!(series.name(), "correlation"); + } + } + } + + #[test] + fn test_correlate_asset_prices_bulk_correlation_bounds() { + let ti = create_correlation_ti(); + let result = ti.correlate_asset_prices_bulk("price_a", "price_b", "arithmetic", "population", 4).unwrap(); + let series = result.0.0; + if let Ok(values) = series.f64() { + for value in values.into_iter().flatten() { + assert!((-1.0..=1.0).contains(&value), "Correlation value {value} is out of bounds"); + } + } + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/ma/mod.rs b/plugins/ezpz-rust-ti/src/indicators/ma/mod.rs new file mode 100644 index 0000000..8cfb28e --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/ma/mod.rs @@ -0,0 +1,276 @@ +use { + crate::utils::extract_f64_values, + ezpz_stubz::{lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +fn parse_moving_average_type(ma_type: &str) -> PyResult { + match ma_type.to_lowercase().as_str() { + "simple" => Ok(rust_ti::MovingAverageType::Simple), + "exponential" => Ok(rust_ti::MovingAverageType::Exponential), + "smoothed" => Ok(rust_ti::MovingAverageType::Smoothed), + _ => Err(PyErr::new::("Unsupported moving average type")), + } +} + +/// Moving Average Technical Indicators - A collection of moving average analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +#[allow(clippy::upper_case_acronyms)] +pub struct MATI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl MATI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Moving Average (Single) - Calculates a single moving average value for a series of prices + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `moving_average_type`: &str - Type of moving average ("simple", "exponential", "smoothed") + /// + /// # Returns + /// f64 - Single moving average value + fn moving_average_single(&self, price_column: &str, moving_average_type: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let ma_type = parse_moving_average_type(moving_average_type)?; + let result = rust_ti::moving_average::single::moving_average(&values, ma_type); + Ok(result) + } + + /// Moving Average (Bulk) - Calculates moving averages over a rolling window + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `moving_average_type`: &str - Type of moving average ("simple", "exponential", "smoothed") + /// - `period`: usize - Period over which to calculate the moving average + /// + /// # Returns + /// PySeriesStubbed - Series of moving average values with name "moving_average" + fn moving_average_bulk(&self, price_column: &str, moving_average_type: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let ma_type = parse_moving_average_type(moving_average_type)?; + let result = rust_ti::moving_average::bulk::moving_average(&values, ma_type, period); + let result_series = Series::new("moving_average".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// McGinley Dynamic (Single) - Calculates a single McGinley Dynamic value + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `previous_mcginley_dynamic`: f64 - Previous McGinley Dynamic value (use 0.0 if none) + /// - `period`: usize - Period for calculation + /// + /// # Returns + /// f64 - Single McGinley Dynamic value + fn mcginley_dynamic_single(&self, price_column: &str, previous_mcginley_dynamic: f64, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + // Use the last price value as the latest price + let latest_price = values.last().ok_or_else(|| PyErr::new::("Empty series"))?; + let result = rust_ti::moving_average::single::mcginley_dynamic(*latest_price, previous_mcginley_dynamic, period); + Ok(result) + } + + /// McGinley Dynamic (Bulk) - Calculates McGinley Dynamic values over a series + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `previous_mcginley_dynamic`: f64 - Previous McGinley Dynamic value (use 0.0 if none) + /// - `period`: usize - Period for calculation + /// + /// # Returns + /// PySeriesStubbed - Series of McGinley Dynamic values with name "mcginley_dynamic" + fn mcginley_dynamic_bulk(&self, price_column: &str, previous_mcginley_dynamic: f64, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::moving_average::bulk::mcginley_dynamic(&values, previous_mcginley_dynamic, period); + let result_series = Series::new("mcginley_dynamic".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Personalised Moving Average (Single) - Calculates a single personalised moving average + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `alpha_nominator`: f64 - Alpha nominator value + /// - `alpha_denominator`: f64 - Alpha denominator value + /// + /// # Returns + /// f64 - Single personalised moving average value + fn personalised_moving_average_single(&self, price_column: &str, alpha_nominator: f64, alpha_denominator: f64) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let ma_type = rust_ti::MovingAverageType::Personalised { alpha_num: alpha_nominator, alpha_den: alpha_denominator }; + let result = rust_ti::moving_average::single::moving_average(&values, ma_type); + Ok(result) + } + + /// Personalised Moving Average (Bulk) - Calculates personalised moving averages over a rolling window + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `alpha_nominator`: f64 - Alpha nominator value + /// - `alpha_denominator`: f64 - Alpha denominator value + /// - `period`: usize - Period over which to calculate the moving average + /// + /// # Returns + /// PySeriesStubbed - Series of personalised moving average values with name "personalised_moving_average" + fn personalised_moving_average_bulk(&self, price_column: &str, alpha_nominator: f64, alpha_denominator: f64, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let ma_type = rust_ti::MovingAverageType::Personalised { alpha_num: alpha_nominator, alpha_den: alpha_denominator }; + let result = rust_ti::moving_average::bulk::moving_average(&values, ma_type, period); + let result_series = Series::new("personalised_moving_average".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + use polars::lazy::frame::LazyFrame; + + fn create_test_dataframe() -> LazyFrame { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + df! { + "price" => data, + } + .unwrap() + .lazy() + } + + fn create_mati() -> MATI { + let lf = create_test_dataframe(); + MATI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_moving_average_single_simple() { + let mati = create_mati(); + let result = mati.moving_average_single("price", "simple").unwrap(); + assert_abs_diff_eq!(result, 5.5, epsilon = 1e-10); + } + + #[test] + fn test_moving_average_bulk_exponential() { + let mati = create_mati(); + let result = mati.moving_average_bulk("price", "exponential", 3).unwrap(); + let result_vec = result.0.0.f64().unwrap().into_no_null_iter().collect::>(); + assert_eq!(result_vec.len(), 10); + assert!(result_vec.iter().any(|&v| v != 0.0)); + } + + #[test] + fn test_mcginley_dynamic_single() { + let mati = create_mati(); + let result = mati.mcginley_dynamic_single("price", 0.0, 3).unwrap(); + assert!(result > 0.0); + } + + #[test] + fn test_mcginley_dynamic_bulk() { + let mati = create_mati(); + let result = mati.mcginley_dynamic_bulk("price", 0.0, 3).unwrap(); + let result_vec = result.0.0.f64().unwrap().into_no_null_iter().collect::>(); + assert_eq!(result_vec.len(), 10); + assert!(result_vec.iter().any(|&v| v != 0.0)); + } + + #[test] + fn test_personalised_moving_average_single() { + let mati = create_mati(); + let result = mati.personalised_moving_average_single("price", 2.0, 3.0).unwrap(); + assert!(result > 0.0); + } + + #[test] + fn test_personalised_moving_average_bulk() { + let mati = create_mati(); + let result = mati.personalised_moving_average_bulk("price", 2.0, 3.0, 3).unwrap(); + let result_vec = result.0.0.f64().unwrap().into_no_null_iter().collect::>(); + assert_eq!(result_vec.len(), 10); + assert!(result_vec.iter().any(|&v| v != 0.0)); + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/mod.rs b/plugins/ezpz-rust-ti/src/indicators/mod.rs new file mode 100644 index 0000000..73d04fb --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/mod.rs @@ -0,0 +1,11 @@ +pub mod basic; +pub mod candle; +pub mod chart; +pub mod correlation; +pub mod ma; +pub mod momentum; +pub mod other; +pub mod std_; +pub mod strength; +pub mod trend; +pub mod volatility; diff --git a/plugins/ezpz-rust-ti/src/indicators/momentum/mod.rs b/plugins/ezpz-rust-ti/src/indicators/momentum/mod.rs new file mode 100644 index 0000000..12c012e --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/momentum/mod.rs @@ -0,0 +1,1142 @@ +use { + crate::utils::{create_triple_df, extract_f64_values, parse_constant_model_type, parse_deviation_model}, + ezpz_stubz::{frame::PyDfStubbed, lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Momentum Technical Indicators - A collection of momentum analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct MomentumTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl MomentumTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Aroon Up indicator + /// + /// Calculates the Aroon Up indicator, which measures the time since the highest high + /// within a given period as a percentage. + /// + /// # Parameters + /// * `high_column` - &str name of the column containing high price values + /// + /// # Returns + /// * `PyResult` - The Aroon Up value (0-100), where higher values indicate recent highs + fn aroon_up_single(&self, high_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(high_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{high_column}': {e}")))? + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::trend_indicators::single::aroon_up(&values); + Ok(result) + } + + /// Aroon Down indicator + /// + /// Calculates the Aroon Down indicator, which measures the time since the lowest low + /// within a given period as a percentage. + /// + /// # Parameters + /// * `low_column` - &str name of the column containing low price values + /// + /// # Returns + /// * `PyResult` - The Aroon Down value (0-100), where higher values indicate recent lows + fn aroon_down_single(&self, low_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{low_column}': {e}")))? + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::trend_indicators::single::aroon_down(&values); + Ok(result) + } + + /// Aroon Oscillator + /// + /// Calculates the Aroon Oscillator by subtracting Aroon Down from Aroon Up. + /// Values range from -100 to +100, indicating trend strength and direction. + /// + /// # Parameters + /// * `aroon_up` - f64 value of Aroon Up indicator (0-100) + /// * `aroon_down` - f64 value of Aroon Down indicator (0-100) + /// + /// # Returns + /// * `PyResult` - The Aroon Oscillator value (-100 to +100) + fn aroon_oscillator_single(&self, aroon_up: f64, aroon_down: f64) -> PyResult { + let result = rust_ti::trend_indicators::single::aroon_oscillator(aroon_up, aroon_down); + Ok(result) + } + + /// Aroon Indicator (complete calculation) + /// + /// Calculates all three Aroon components: Aroon Up, Aroon Down, and Aroon Oscillator + /// in a single function call. + /// + /// # Parameters + /// * `high_column` - &str name of the column containing high price values + /// * `low_column` - &str name of the column containing low price values + /// + /// # Returns + /// * `PyResult<(f64, f64, f64)>` - Tuple containing (aroon_up, aroon_down, aroon_oscillator) + fn aroon_indicator_single(&self, high_column: &str, low_column: &str) -> PyResult<(f64, f64, f64)> { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let result = rust_ti::trend_indicators::single::aroon_indicator(&high_values, &low_values); + Ok(result) + } + + /// Long Parabolic Time Price System (Parabolic SAR for long positions) + /// + /// Calculates the Parabolic SAR (Stop and Reverse) for long positions, used to determine + /// potential reversal points in price movement. + /// + /// # Parameters + /// * `previous_sar` - f64 value of the previous SAR + /// * `extreme_point` - f64 value of the extreme point (highest high for long positions) + /// * `acceleration_factor` - f64 acceleration factor (typically starts at 0.02) + /// * `low` - f64 current period's low price + /// + /// # Returns + /// * `PyResult` - The calculated SAR value for long positions + fn long_parabolic_time_price_system_single(&self, previous_sar: f64, extreme_point: f64, acceleration_factor: f64, low: f64) -> PyResult { + let result = rust_ti::trend_indicators::single::long_parabolic_time_price_system(previous_sar, extreme_point, acceleration_factor, low); + Ok(result) + } + + /// Short Parabolic Time Price System (Parabolic SAR for short positions) + /// + /// Calculates the Parabolic SAR (Stop and Reverse) for short positions, used to determine + /// potential reversal points in price movement. + /// + /// # Parameters + /// * `previous_sar` - f64 value of the previous SAR + /// * `extreme_point` - f64 value of the extreme point (lowest low for short positions) + /// * `acceleration_factor` - f64 acceleration factor (typically starts at 0.02) + /// * `high` - f64 current period's high price + /// + /// # Returns + /// * `PyResult` - The calculated SAR value for short positions + fn short_parabolic_time_price_system_single(&self, previous_sar: f64, extreme_point: f64, acceleration_factor: f64, high: f64) -> PyResult { + let result = rust_ti::trend_indicators::single::short_parabolic_time_price_system(previous_sar, extreme_point, acceleration_factor, high); + Ok(result) + } + + /// Volume Price Trend + /// + /// Calculates the Volume Price Trend indicator, which combines price and volume + /// to show the relationship between volume and price changes. + /// + /// # Parameters + /// * `price_column` - &str name of the column containing price values + /// * `previous_price` - f64 previous period's price + /// * `volume` - f64 current period's volume + /// * `previous_volume_price_trend` - f64 previous VPT value + /// + /// # Returns + /// * `PyResult` - The calculated Volume Price Trend value + fn volume_price_trend_single(&self, price_column: &str, previous_price: f64, volume: f64, previous_volume_price_trend: f64) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let current_price = values[0]; + let result = rust_ti::trend_indicators::single::volume_price_trend(current_price, previous_price, volume, previous_volume_price_trend); + Ok(result) + } + + /// True Strength Index + /// + /// Calculates the True Strength Index, a momentum oscillator that uses price changes + /// smoothed by two exponential moving averages. + /// + /// # Parameters + /// * `price_column` - &str name of the column containing price values + /// * `first_constant_model` - &str smoothing model for first smoothing ("sma", "ema", etc.) + /// * `first_period` - usize period for first smoothing + /// * `second_constant_model` - &str smoothing model for second smoothing ("sma", "ema", etc.) + /// + /// # Returns + /// * `PyResult` - The True Strength Index value (typically ranges from -100 to +100) + fn true_strength_index_single(&self, price_column: &str, first_constant_model: &str, first_period: usize, second_constant_model: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let first_model = parse_constant_model_type(first_constant_model)?; + let second_model = parse_constant_model_type(second_constant_model)?; + let result = rust_ti::trend_indicators::single::true_strength_index(&values, first_model, first_period, second_model); + Ok(result) + } + + /// Relative Strength Index (RSI) - bulk calculation + /// + /// Calculates RSI values for an entire series of prices. RSI measures the speed and change + /// of price movements, oscillating between 0 and 100. + /// + /// # Parameters + /// * `price_column` - &str name of the column containing price values + /// * `constant_model_type` - &str smoothing model ("sma", "ema", etc.) + /// * `period` - usize calculation period (commonly 14) + /// + /// # Returns + /// * `PyResult` - Series named "rsi" containing RSI values (0-100) + fn relative_strength_index_bulk(&self, price_column: &str, constant_model_type: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let model_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::momentum_indicators::bulk::relative_strength_index(&values, model_type, period); + let result_series = Series::new("rsi".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Stochastic Oscillator - bulk calculation + /// + /// Calculates the Stochastic Oscillator, which compares a security's closing price + /// to its price range over a given time period. + /// + /// # Parameters + /// * `price_column` - &str name of the column containing price values + /// * `period` - usize lookback period for calculation + /// + /// # Returns + /// * `PyResult` - Series named "stochastic" containing oscillator values (0-100) + fn stochastic_oscillator_bulk(&self, price_column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::momentum_indicators::bulk::stochastic_oscillator(&values, period); + let result_series = Series::new("stochastic".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Slow Stochastic - bulk calculation + /// + /// Calculates the Slow Stochastic by smoothing the regular Stochastic Oscillator + /// to reduce noise and false signals. + /// + /// # Parameters + /// * `stochastic_column` - &str name of the column containing Stochastic Oscillator values + /// * `constant_model_type` - &str smoothing model ("sma", "ema", etc.) + /// * `period` - usize smoothing period + /// + /// # Returns + /// * `PyResult` - Series named "slow_stochastic" containing smoothed values (0-100) + fn slow_stochastic_bulk(&self, stochastic_column: &str, constant_model_type: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(stochastic_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{stochastic_column}': {e}")))? + .column(stochastic_column) + .map_err(|e| PyErr::new::(format!("Column '{stochastic_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{stochastic_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let model_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::momentum_indicators::bulk::slow_stochastic(&values, model_type, period); + let result_series = Series::new("slow_stochastic".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Slowest Stochastic - bulk calculation + /// + /// Calculates the Slowest Stochastic by applying additional smoothing to the Slow Stochastic + /// for even more noise reduction. + /// + /// # Parameters + /// * `slow_stochastic_column` - &str name of the column containing Slow Stochastic values + /// * `constant_model_type` - &str smoothing model ("sma", "ema", etc.) + /// * `period` - usize smoothing period + /// + /// # Returns + /// * `PyResult` - Series named "slowest_stochastic" containing double-smoothed values (0-100) + fn slowest_stochastic_bulk(&self, slow_stochastic_column: &str, constant_model_type: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(slow_stochastic_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{slow_stochastic_column}': {e}")))? + .column(slow_stochastic_column) + .map_err(|e| PyErr::new::(format!("Column '{slow_stochastic_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{slow_stochastic_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let model_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::momentum_indicators::bulk::slowest_stochastic(&values, model_type, period); + let result_series = Series::new("slowest_stochastic".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Williams %R - bulk calculation + /// + /// Calculates Williams %R, a momentum indicator that measures overbought and oversold levels. + /// Values range from -100 to 0, where -20 and above indicates overbought, -80 and below indicates oversold. + /// + /// # Parameters + /// * `high_column` - &str name of the column containing high price values + /// * `low_column` - &str name of the column containing low price values + /// * `close_column` - &str name of the column containing close price values + /// * `period` - usize lookback period for calculation + /// + /// # Returns + /// * `PyResult` - Series named "williams_r" containing Williams %R values (-100 to 0) + fn williams_percent_r_bulk(&self, high_column: &str, low_column: &str, close_column: &str, period: usize) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let result = rust_ti::momentum_indicators::bulk::williams_percent_r(&high_values, &low_values, &close_values, period); + let result_series = Series::new("williams_r".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Money Flow Index - bulk calculation + /// + /// Calculates the Money Flow Index, a volume-weighted RSI that measures buying and selling pressure. + /// Values range from 0 to 100, where >80 indicates overbought and <20 indicates oversold. + /// + /// # Parameters + /// * `price_column` - &str name of the column containing price values + /// * `volume_column` - &str name of the column containing volume values + /// * `period` - usize calculation period (commonly 14) + /// + /// # Returns + /// * `PyResult` - Series named "mfi" containing Money Flow Index values (0-100) + fn money_flow_index_bulk(&self, price_column: &str, volume_column: &str, period: usize) -> PyResult { + let df = self + .lf + .clone() + .select([col(price_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let price_series = df + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let price_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(price_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + let result = rust_ti::momentum_indicators::bulk::money_flow_index(&price_values, &volume_values, period); + let result_series = Series::new("mfi".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Rate of Change - bulk calculation + /// + /// Calculates the Rate of Change, which measures the percentage change in price + /// from one period to the next. + /// + /// # Parameters + /// * `price_column` - &str name of the column containing price values + /// + /// # Returns + /// * `PyResult` - Series named "roc" containing rate of change values as percentages + fn rate_of_change_bulk(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::momentum_indicators::bulk::rate_of_change(&values); + let result_series = Series::new("roc".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// On Balance Volume (Bulk) - Calculates cumulative volume indicator + /// Adds volume on up days and subtracts volume on down days to measure buying and selling pressure + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `volume_column`: &str - Name of the volume column + /// - `previous_obv`: f64 - Starting OBV value (typically 0) + /// + /// # Returns + /// PySeriesStubbed - Series of OBV values with name "obv" + fn on_balance_volume_bulk(&self, price_column: &str, volume_column: &str, previous_obv: f64) -> PyResult { + let df = self + .lf + .clone() + .select([col(price_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let price_series = df + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let price_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(price_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + let result = rust_ti::momentum_indicators::bulk::on_balance_volume(&price_values, &volume_values, previous_obv); + let result_series = Series::new("obv".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Commodity Channel Index (Bulk) - Calculates CCI over rolling periods + /// Measures the variation of a security's price from its statistical mean + /// Values typically range from -100 to +100 + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `constant_model_type`: &str - Model for calculating moving average ("sma", "ema", etc.) + /// - `deviation_model`: &str - Model for calculating deviation ("mad", "std", etc.) + /// - `constant_multiplier`: f64 - Multiplier constant (typically 0.015) + /// - `period`: usize - Calculation period (commonly 20) + /// + /// # Returns + /// PySeriesStubbed - Series of CCI values with name "cci" + fn commodity_channel_index_bulk( + &self, + price_column: &str, + constant_model_type: &str, + deviation_model: &str, + constant_multiplier: f64, + period: usize, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let model_type = parse_constant_model_type(constant_model_type)?; + let dev_model = parse_deviation_model(deviation_model)?; + let result = rust_ti::momentum_indicators::bulk::commodity_channel_index(&values, model_type, dev_model, constant_multiplier, period); + let result_series = Series::new("cci".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// McGinley Dynamic Commodity Channel Index (Bulk) - CCI using McGinley Dynamic MA + /// Uses McGinley Dynamic as the moving average, which adapts to market conditions + /// better than traditional moving averages + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `previous_mcginley_dynamic`: f64 - Initial McGinley Dynamic value + /// - `deviation_model`: &str - Model for calculating deviation ("mad", "std", etc.) + /// - `constant_multiplier`: f64 - Multiplier constant (typically 0.015) + /// - `period`: usize - Calculation period + /// + /// # Returns + /// (PySeriesStubbed, PySeriesStubbed) - Tuple containing (CCI series, McGinley Dynamic series) + fn mcginley_dynamic_commodity_channel_index_bulk( + &self, + price_column: &str, + previous_mcginley_dynamic: f64, + deviation_model: &str, + constant_multiplier: f64, + period: usize, + ) -> PyResult<(PySeriesStubbed, PySeriesStubbed)> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let dev_model = parse_deviation_model(deviation_model)?; + let result = + rust_ti::momentum_indicators::bulk::mcginley_dynamic_commodity_channel_index(&values, previous_mcginley_dynamic, dev_model, constant_multiplier, period); + let (cci_values, mcginley_values): (Vec, Vec) = result.into_iter().unzip(); + let cci_series = Series::new("cci".into(), cci_values); + let mcginley_series = Series::new("mcginley_dynamic".into(), mcginley_values); + Ok((PySeriesStubbed(pyo3_polars::PySeries(cci_series)), PySeriesStubbed(pyo3_polars::PySeries(mcginley_series)))) + } + + /// MACD Line (Bulk) - Calculates Moving Average Convergence Divergence line + /// Subtracts the long-period moving average from the short-period moving average + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `short_period`: usize - Period for short moving average (commonly 12) + /// - `short_period_model`: &str - Model for short MA ("sma", "ema", etc.) + /// - `long_period`: usize - Period for long moving average (commonly 26) + /// - `long_period_model`: &str - Model for long MA ("sma", "ema", etc.) + /// + /// # Returns + /// PySeriesStubbed - Series of MACD line values with name "macd" + fn macd_line_bulk( + &self, + price_column: &str, + short_period: usize, + short_period_model: &str, + long_period: usize, + long_period_model: &str, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let short_model = parse_constant_model_type(short_period_model)?; + let long_model = parse_constant_model_type(long_period_model)?; + let result = rust_ti::momentum_indicators::bulk::macd_line(&values, short_period, short_model, long_period, long_model); + let result_series = Series::new("macd".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Signal Line (Bulk) - Calculates MACD Signal Line + /// Applies a moving average to the MACD line for generating buy/sell signals + /// + /// # Parameters + /// - `macd_column`: &str - Name of the MACD column to analyze + /// - `constant_model_type`: &str - Smoothing model ("sma", "ema", etc.) + /// - `period`: usize - Signal line period (commonly 9) + /// + /// # Returns + /// PySeriesStubbed - Series of signal line values with name "signal" + fn signal_line_bulk(&self, macd_column: &str, constant_model_type: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(macd_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{macd_column}': {e}")))? + .column(macd_column) + .map_err(|e| PyErr::new::(format!("Column '{macd_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{macd_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let model_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::momentum_indicators::bulk::signal_line(&values, model_type, period); + let result_series = Series::new("signal".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// McGinley Dynamic MACD Line (Bulk) - MACD using McGinley Dynamic moving averages + /// Provides better adaptation to market volatility and reduces lag compared to traditional MACD + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `short_period`: usize - Period for short McGinley Dynamic + /// - `previous_short_mcginley`: f64 - Initial short McGinley Dynamic value + /// - `long_period`: usize - Period for long McGinley Dynamic + /// - `previous_long_mcginley`: f64 - Initial long McGinley Dynamic value + /// + /// # Returns + /// PyDfStubbed - DataFrame with columns: "macd", "short_mcginley", "long_mcginley" + fn mcginley_dynamic_macd_line_bulk( + &self, + price_column: &str, + short_period: usize, + previous_short_mcginley: f64, + long_period: usize, + previous_long_mcginley: f64, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = + rust_ti::momentum_indicators::bulk::mcginley_dynamic_macd_line(&values, short_period, previous_short_mcginley, long_period, previous_long_mcginley); + let (macd_values, short_mcginley_values, long_mcginley_values): (Vec, Vec, Vec) = + result.into_iter().fold((Vec::new(), Vec::new(), Vec::new()), |mut acc, (a, b, c)| { + acc.0.push(a); + acc.1.push(b); + acc.2.push(c); + acc + }); + create_triple_df(macd_values, short_mcginley_values, long_mcginley_values, "macd", "short_mcginley", "long_mcginley") + } + + /// Chaikin Oscillator (Bulk) - Applies MACD to Accumulation/Distribution line + /// Measures the momentum of the Accumulation/Distribution line + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `volume_column`: &str - Name of the volume column + /// - `short_period`: usize - Short period for oscillator (commonly 3) + /// - `long_period`: usize - Long period for oscillator (commonly 10) + /// - `previous_accumulation_distribution`: f64 - Initial A/D line value + /// - `short_period_model`: &str - Model for short MA ("sma", "ema", etc.) + /// - `long_period_model`: &str - Model for long MA ("sma", "ema", etc.) + /// + /// # Returns + /// (PySeriesStubbed, PySeriesStubbed) - Tuple containing (Chaikin Oscillator, A/D Line) + #[allow(clippy::too_many_arguments)] + fn chaikin_oscillator_bulk( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + volume_column: &str, + short_period: usize, + long_period: usize, + previous_accumulation_distribution: f64, + short_period_model: &str, + long_period_model: &str, + ) -> PyResult<(PySeriesStubbed, PySeriesStubbed)> { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + let short_model = parse_constant_model_type(short_period_model)?; + let long_model = parse_constant_model_type(long_period_model)?; + let result = rust_ti::momentum_indicators::bulk::chaikin_oscillator( + &high_values, + &low_values, + &close_values, + &volume_values, + short_period, + long_period, + previous_accumulation_distribution, + short_model, + long_model, + ); + let (chaikin_values, ad_values): (Vec, Vec) = result.into_iter().unzip(); + let chaikin_series = Series::new("chaikin_oscillator".into(), chaikin_values); + let ad_series = Series::new("accumulation_distribution".into(), ad_values); + Ok((PySeriesStubbed(pyo3_polars::PySeries(chaikin_series)), PySeriesStubbed(pyo3_polars::PySeries(ad_series)))) + } + + /// Percentage Price Oscillator (Bulk) - MACD expressed as percentage + /// Similar to MACD but expressed as a percentage for easier comparison across securities + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `short_period`: usize - Short period for moving average (commonly 12) + /// - `long_period`: usize - Long period for moving average (commonly 26) + /// - `constant_model_type`: &str - Model for moving averages ("sma", "ema", etc.) + /// + /// # Returns + /// PySeriesStubbed - Series of PPO values as percentages with name "ppo" + fn percentage_price_oscillator_bulk( + &self, + price_column: &str, + short_period: usize, + long_period: usize, + constant_model_type: &str, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let model_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::momentum_indicators::bulk::percentage_price_oscillator(&values, short_period, long_period, model_type); + let result_series = Series::new("ppo".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Chande Momentum Oscillator (Bulk) - Measures momentum using gains and losses + /// Calculates the difference between sum of gains and losses over a period + /// Values range from -100 to +100 + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Calculation period (commonly 14 or 20) + /// + /// # Returns + /// PySeriesStubbed - Series of CMO values (-100 to +100) with name "chande_momentum_oscillator" + fn chande_momentum_oscillator_bulk(&self, price_column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::momentum_indicators::bulk::chande_momentum_oscillator(&values, period); + let result_series = Series::new("chande_momentum_oscillator".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + df! { + "high" => vec![12.0, 15.0, 18.0, 16.0, 14.0, 17.0, 19.0, 21.0, 20.0, 22.0], + "low" => vec![10.0, 11.0, 13.0, 12.0, 10.0, 14.0, 15.0, 17.0, 16.0, 18.0], + "close" => vec![11.0, 13.0, 16.0, 14.0, 12.0, 16.0, 18.0, 19.0, 18.0, 20.0], + "price" => vec![11.0, 13.0, 16.0, 14.0, 12.0, 16.0, 18.0, 19.0, 18.0, 20.0], + "volume" => vec![1000.0, 1200.0, 1500.0, 1300.0, 1100.0, 1400.0, 1600.0, 1800.0, 1700.0, 1900.0] + } + .unwrap() + .lazy() + } + + fn create_momentum_ti() -> MomentumTI { + let lf = create_test_dataframe(); + MomentumTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_aroon_up_single() { + let ti = create_momentum_ti(); + let result = ti.aroon_up_single("high").unwrap(); + assert!((0.0..=100.0).contains(&result)); + } + + #[test] + fn test_aroon_down_single() { + let ti = create_momentum_ti(); + let result = ti.aroon_down_single("low").unwrap(); + assert!((0.0..=100.0).contains(&result)); + } + + #[test] + fn test_aroon_oscillator_single() { + let ti = create_momentum_ti(); + let aroon_up = 80.0; + let aroon_down = 20.0; + let result = ti.aroon_oscillator_single(aroon_up, aroon_down).unwrap(); + assert_abs_diff_eq!(result, 60.0, epsilon = 1e-10); + } + + #[test] + fn test_aroon_indicator_single() { + let ti = create_momentum_ti(); + let result = ti.aroon_indicator_single("high", "low").unwrap(); + let (aroon_up, aroon_down, aroon_oscillator) = result; + assert!((0.0..=100.0).contains(&aroon_up)); + assert!((0.0..=100.0).contains(&aroon_down)); + assert!((-100.0..=100.0).contains(&aroon_oscillator)); + assert_abs_diff_eq!(aroon_oscillator, aroon_up - aroon_down, epsilon = 1e-10); + } + + #[test] + fn test_long_parabolic_time_price_system_single() { + let ti = create_momentum_ti(); + let result = ti.long_parabolic_time_price_system_single(10.0, 15.0, 0.02, 12.0).unwrap(); + assert!(result > 0.0); + } + + #[test] + fn test_short_parabolic_time_price_system_single() { + let ti = create_momentum_ti(); + let result = ti.short_parabolic_time_price_system_single(20.0, 15.0, 0.02, 18.0).unwrap(); + assert!(result > 0.0); + } + + #[test] + fn test_volume_price_trend_single() { + let ti = create_momentum_ti(); + let result = ti.volume_price_trend_single("price", 10.0, 1000.0, 0.0).unwrap(); + assert!(result.is_finite()); + } + + #[test] + fn test_true_strength_index_single() { + let ti = create_momentum_ti(); + let result = ti.true_strength_index_single("price", "ema", 14, "ema").unwrap(); + assert!((-100.0..=100.0).contains(&result)); + } + + #[test] + fn test_relative_strength_index_bulk() { + let ti = create_momentum_ti(); + let result = ti.relative_strength_index_bulk("price", "ema", 14).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "rsi"); + assert!(!series.is_empty()); + } + + #[test] + fn test_stochastic_oscillator_bulk() { + let ti = create_momentum_ti(); + let result = ti.stochastic_oscillator_bulk("price", 14).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "stochastic"); + assert!(!series.is_empty()); + } + + #[test] + fn test_slow_stochastic_bulk() { + let ti = create_momentum_ti(); + let stoch_result = ti.stochastic_oscillator_bulk("price", 14).unwrap(); + let df = df! { + "stoch" => stoch_result.0.0.f64().unwrap().into_iter().map(|v| v.unwrap_or(0.0)).collect::>() + } + .unwrap() + .lazy(); + let ti_with_stoch = MomentumTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(df))); + let result = ti_with_stoch.slow_stochastic_bulk("stoch", "sma", 3).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "slow_stochastic"); + } + + #[test] + fn test_williams_percent_r_bulk() { + let ti = create_momentum_ti(); + let result = ti.williams_percent_r_bulk("high", "low", "close", 14).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "williams_r"); + assert!(!series.is_empty()); + } + + #[test] + fn test_money_flow_index_bulk() { + let ti = create_momentum_ti(); + let result = ti.money_flow_index_bulk("price", "volume", 14).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "mfi"); + assert!(!series.is_empty()); + } + + #[test] + fn test_rate_of_change_bulk() { + let ti = create_momentum_ti(); + let result = ti.rate_of_change_bulk("price").unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "roc"); + assert!(!series.is_empty()); + } + + #[test] + fn test_on_balance_volume_bulk() { + let ti = create_momentum_ti(); + let result = ti.on_balance_volume_bulk("price", "volume", 0.0).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "obv"); + assert!(!series.is_empty()); + } + + #[test] + fn test_commodity_channel_index_bulk() { + let ti = create_momentum_ti(); + let result = ti.commodity_channel_index_bulk("price", "sma", "mad", 0.015, 20).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "cci"); + assert!(!series.is_empty()); + } + + #[test] + fn test_mcginley_dynamic_commodity_channel_index_bulk() { + let ti = create_momentum_ti(); + let result = ti.mcginley_dynamic_commodity_channel_index_bulk("price", 15.0, "mad", 0.015, 20).unwrap(); + let (cci_series, mcginley_series) = result; + assert_eq!(cci_series.0.0.name(), "cci"); + assert_eq!(mcginley_series.0.0.name(), "mcginley_dynamic"); + } + + #[test] + fn test_macd_line_bulk() { + let ti = create_momentum_ti(); + let result = ti.macd_line_bulk("price", 12, "ema", 26, "ema").unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "macd"); + assert!(!series.is_empty()); + } + + #[test] + fn test_signal_line_bulk() { + let ti = create_momentum_ti(); + let macd_result = ti.macd_line_bulk("price", 12, "ema", 26, "ema").unwrap(); + let df = df! { + "macd" => macd_result.0.0.f64().unwrap().into_iter().map(|v| v.unwrap_or(0.0)).collect::>() + } + .unwrap() + .lazy(); + let ti_with_macd = MomentumTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(df))); + let result = ti_with_macd.signal_line_bulk("macd", "ema", 9).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "signal"); + } + + #[test] + fn test_mcginley_dynamic_macd_line_bulk() { + let ti = create_momentum_ti(); + let result = ti.mcginley_dynamic_macd_line_bulk("price", 12, 15.0, 26, 18.0).unwrap(); + let df = result.0.0; + let columns = df.get_column_names(); + let macd = PlSmallStr::from("macd"); + let short_mcginley = PlSmallStr::from("short_mcginley"); + let long_mcginley = PlSmallStr::from("long_mcginley"); + assert!(columns.contains(&&macd)); + assert!(columns.contains(&&short_mcginley)); + assert!(columns.contains(&&long_mcginley)); + } + + #[test] + fn test_chaikin_oscillator_bulk() { + let ti = create_momentum_ti(); + let result = ti.chaikin_oscillator_bulk("high", "low", "close", "volume", 3, 10, 0.0, "ema", "ema").unwrap(); + let (chaikin_series, ad_series) = result; + assert_eq!(chaikin_series.0.0.name(), "chaikin_oscillator"); + assert_eq!(ad_series.0.0.name(), "accumulation_distribution"); + } + + #[test] + fn test_percentage_price_oscillator_bulk() { + let ti = create_momentum_ti(); + let result = ti.percentage_price_oscillator_bulk("price", 12, 26, "ema").unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "ppo"); + assert!(!series.is_empty()); + } + + #[test] + fn test_chande_momentum_oscillator_bulk() { + let ti = create_momentum_ti(); + let result = ti.chande_momentum_oscillator_bulk("price", 14).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "chande_momentum_oscillator"); + assert!(!series.is_empty()); + } + + #[test] + fn test_invalid_column_name() { + let ti = create_momentum_ti(); + let result = ti.aroon_up_single("invalid_column"); + assert!(result.is_err()); + } + + #[test] + fn test_aroon_oscillator_boundary_values() { + let ti = create_momentum_ti(); + let result_max = ti.aroon_oscillator_single(100.0, 0.0).unwrap(); + let result_min = ti.aroon_oscillator_single(0.0, 100.0).unwrap(); + assert_abs_diff_eq!(result_max, 100.0, epsilon = 1e-10); + assert_abs_diff_eq!(result_min, -100.0, epsilon = 1e-10); + } + + #[test] + fn test_parabolic_sar_acceleration_factor_zero() { + let ti = create_momentum_ti(); + let result = ti.long_parabolic_time_price_system_single(10.0, 15.0, 0.0, 12.0).unwrap(); + assert_abs_diff_eq!(result, 10.0, epsilon = 1e-10); + } + + #[test] + fn test_volume_price_trend_no_change() { + let ti = create_momentum_ti(); + let result = ti.volume_price_trend_single("price", 15.0, 1000.0, 100.0).unwrap(); + assert!(result.is_finite()); + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/other/mod.rs b/plugins/ezpz-rust-ti/src/indicators/other/mod.rs new file mode 100644 index 0000000..633cf43 --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/other/mod.rs @@ -0,0 +1,673 @@ +use { + crate::utils::{extract_f64_values, parse_constant_model_type}, + ezpz_stubz::{lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Other Technical Indicators - A collection of other analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct OtherTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl OtherTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Return on Investment - Calculates investment value and percentage change for a single period + /// Uses the first and last values from the price column as start and end prices + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `investment`: f64 - Initial investment amount + /// + /// # Returns + /// Tuple of (final_investment_value: f64, percent_return: f64) + /// - `final_investment_value`: The absolute value of the investment at the end + /// - `percent_return`: The percentage return on the investment + fn return_on_investment_single(&self, price_column: &str, investment: f64) -> PyResult<(f64, f64)> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + if values.len() < 2 { + return Err(pyo3::exceptions::PyValueError::new_err("Series must have at least 2 values")); + } + let start_price = values[0]; + let end_price = values[values.len() - 1]; + let result = rust_ti::other_indicators::single::return_on_investment(start_price, end_price, investment); + Ok(result) + } + + /// Return on Investment Bulk - Calculates ROI for a series of consecutive price periods + /// Uses the price column as price values for consecutive period calculations + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `investment`: f64 - Initial investment amount + /// + /// # Returns + /// Tuple of (final_investment_values: PySeriesStubbed, percent_returns: PySeriesStubbed) + /// - `final_investment_values`: Series of absolute investment values for each period + /// - `percent_returns`: Series of percentage returns for each period + fn return_on_investment_bulk(&self, price_column: &str, investment: f64) -> PyResult<(PySeriesStubbed, PySeriesStubbed)> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let results = rust_ti::other_indicators::bulk::return_on_investment(&values, investment); + + let final_values: Vec = results.iter().map(|(final_val, _)| *final_val).collect(); + let percent_returns: Vec = results.iter().map(|(_, percent)| *percent).collect(); + + let final_series = Series::new("final_investment_value".into(), final_values); + let percent_series = Series::new("percent_return".into(), percent_returns); + + Ok((PySeriesStubbed(pyo3_polars::PySeries(final_series)), PySeriesStubbed(pyo3_polars::PySeries(percent_series)))) + } + + /// True Range - Calculates the greatest price movement for a single period + /// Uses the provided high/low/close columns to calculate true range + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// + /// # Returns + /// PySeriesStubbed - Series of true range values for each period + fn true_range(&self, high_column: &str, low_column: &str, close_column: &str) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let results = rust_ti::other_indicators::bulk::true_range(&close_values, &high_values, &low_values); + let result_series = Series::new("true_range".into(), results); + + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Average True Range - Calculates the moving average of true range values for a single result + /// Uses the provided high/low/close columns to calculate ATR from the entire price series + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `constant_model_type`: &str - Type of moving average ("sma", "ema", "wma", etc.) + /// + /// # Returns + /// f64 - Single ATR value calculated from the entire price series + fn average_true_range_single(&self, high_column: &str, low_column: &str, close_column: &str, constant_model_type: &str) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let constant_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::other_indicators::single::average_true_range(&close_values, &high_values, &low_values, constant_type); + + Ok(result) + } + + /// Average True Range Bulk - Calculates rolling ATR values over specified periods + /// Uses the provided high/low/close columns for rolling ATR calculations + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `constant_model_type`: &str - Type of moving average ("sma", "ema", "wma", etc.) + /// - `period`: usize - Number of periods for the moving average calculation + /// + /// # Returns + /// PySeriesStubbed - Series of ATR values for each period + fn average_true_range_bulk( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + constant_model_type: &str, + period: usize, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let constant_type = parse_constant_model_type(constant_model_type)?; + let results = rust_ti::other_indicators::bulk::average_true_range(&close_values, &high_values, &low_values, constant_type, period); + + let result_series = Series::new("average_true_range".into(), results); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Internal Bar Strength - Calculates buy/sell oscillator based on close position within high-low range + /// Uses the provided high/low/close columns to calculate IBS values + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// + /// # Returns + /// PySeriesStubbed - Series of IBS values (0-1 range) for each period, where values closer to 1 + /// indicate closes near the high, and values closer to 0 indicate closes near the low + fn internal_bar_strength(&self, high_column: &str, low_column: &str, close_column: &str) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let results = rust_ti::other_indicators::bulk::internal_bar_strength(&high_values, &low_values, &close_values); + let result_series = Series::new("internal_bar_strength".into(), results); + + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Positivity Indicator - Generates trading signals based on open vs previous close comparison + /// Uses the provided open/close columns for signal generation + /// + /// # Parameters + /// - `open_column`: &str - Name of the opening price column + /// - `close_column`: &str - Name of the close price column + /// - `signal_period`: usize - Number of periods for signal line smoothing + /// - `constant_model_type`: &str - Type of moving average for signal line ("sma", "ema", "wma", etc.) + /// + /// # Returns + /// Tuple of (positivity_indicator: PySeriesStubbed, signal_line: PySeriesStubbed) + /// - `positivity_indicator`: Series of raw positivity values based on open/close comparison + /// - `signal_line`: Series of smoothed signal values using specified moving average + fn positivity_indicator( + &self, + open_column: &str, + close_column: &str, + signal_period: usize, + constant_model_type: &str, + ) -> PyResult<(PySeriesStubbed, PySeriesStubbed)> { + let df = self + .lf + .clone() + .select([col(open_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let open_series = df + .column(open_column) + .map_err(|e| PyErr::new::(format!("Column '{open_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{open_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let open_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(open_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let constant_type = parse_constant_model_type(constant_model_type)?; + let results = rust_ti::other_indicators::bulk::positivity_indicator(&open_values, &close_values, signal_period, constant_type); + + let positivity_values: Vec = results.iter().map(|(pos, _)| *pos).collect(); + let signal_values: Vec = results.iter().map(|(_, signal)| *signal).collect(); + + let positivity_series = Series::new("positivity_indicator".into(), positivity_values); + let signal_series = Series::new("signal_line".into(), signal_values); + + Ok((PySeriesStubbed(pyo3_polars::PySeries(positivity_series)), PySeriesStubbed(pyo3_polars::PySeries(signal_series)))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + df! { + "open" => vec![100.0, 102.0, 101.0, 103.0, 105.0, 104.0, 106.0, 108.0, 107.0, 109.0], + "high" => vec![105.0, 106.0, 104.0, 107.0, 109.0, 108.0, 110.0, 112.0, 111.0, 113.0], + "low" => vec![99.0, 101.0, 100.0, 102.0, 104.0, 103.0, 105.0, 107.0, 106.0, 108.0], + "close" => vec![104.0, 103.0, 102.0, 106.0, 107.0, 106.0, 109.0, 110.0, 108.0, 112.0], + "price" => vec![100.0, 102.0, 101.0, 103.0, 105.0, 104.0, 106.0, 108.0, 107.0, 109.0] + } + .unwrap() + .lazy() + } + + fn create_basic_other_ti() -> OtherTI { + let lf = create_test_dataframe(); + OtherTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_return_on_investment_single() { + let ti = create_basic_other_ti(); + let result = ti.return_on_investment_single("price", 1000.0).unwrap(); + + // First price: 100.0, Last price: 109.0 + // Expected final value: 1000.0 * (109.0 / 100.0) = 1090.0 + // Expected return: (109.0 - 100.0) / 100.0 * 100.0 = 9.0% + assert_abs_diff_eq!(result.0, 1090.0, epsilon = 1e-10); + assert_abs_diff_eq!(result.1, 9.0, epsilon = 1e-10); + } + + #[test] + fn test_return_on_investment_single_insufficient_data() { + let single_data = df! { + "price" => vec![100.0] + } + .unwrap() + .lazy(); + + let ti = OtherTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(single_data))); + let result = ti.return_on_investment_single("price", 1000.0); + assert!(result.is_err()); + } + + #[test] + fn test_return_on_investment_bulk() { + let ti = create_basic_other_ti(); + let result = ti.return_on_investment_bulk("price", 1000.0).unwrap(); + + let final_values: Vec = result.0.0.0.f64().unwrap().into_no_null_iter().collect(); + let percent_returns: Vec = result.1.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(final_values.len(), 9); + assert_eq!(percent_returns.len(), 9); + + // First transition: 100.0 -> 102.0 + assert_abs_diff_eq!(final_values[0], 1020.0, epsilon = 1e-10); + assert_abs_diff_eq!(percent_returns[0], 2.0, epsilon = 1e-10); + } + + #[test] + fn test_true_range() { + let ti = create_basic_other_ti(); + let result = ti.true_range("high", "low", "close").unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + + // First period: max(105.0 - 99.0, |105.0 - 104.0|, |99.0 - 104.0|) = 6.0 + assert_abs_diff_eq!(values[0], 6.0, epsilon = 1e-10); + } + + #[test] + fn test_average_true_range_single() { + let ti = create_basic_other_ti(); + let result = ti.average_true_range_single("high", "low", "close", "sma").unwrap(); + + // Result should be a single ATR value + assert!(result > 0.0); + } + + #[test] + fn test_average_true_range_bulk() { + let ti = create_basic_other_ti(); + let result = ti.average_true_range_bulk("high", "low", "close", "sma", 3).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + + // First few values should be NaN due to period requirement + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + assert!(!values[2].is_nan()); + } + + #[test] + fn test_internal_bar_strength() { + let ti = create_basic_other_ti(); + let result = ti.internal_bar_strength("high", "low", "close").unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + + // All IBS values should be between 0 and 1 + for &value in &values { + assert!((0.0..=1.0).contains(&value)); + } + + // First period: (104.0 - 99.0) / (105.0 - 99.0) = 5.0 / 6.0 ≈ 0.833 + assert_abs_diff_eq!(values[0], 5.0 / 6.0, epsilon = 1e-10); + } + + #[test] + fn test_positivity_indicator() { + let ti = create_basic_other_ti(); + let result = ti.positivity_indicator("open", "close", 3, "sma").unwrap(); + + let positivity_values: Vec = result.0.0.0.f64().unwrap().into_no_null_iter().collect(); + let signal_values: Vec = result.1.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(positivity_values.len(), 10); + assert_eq!(signal_values.len(), 10); + + // First few signal values should be NaN due to period requirement + assert!(signal_values[0].is_nan()); + assert!(signal_values[1].is_nan()); + assert!(!signal_values[2].is_nan()); + } + + #[test] + fn test_invalid_column_error() { + let ti = create_basic_other_ti(); + + let result = ti.return_on_investment_single("nonexistent_column", 1000.0); + assert!(result.is_err()); + + let result = ti.true_range("nonexistent_high", "low", "close"); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_constant_model_type() { + let ti = create_basic_other_ti(); + + let result = ti.average_true_range_single("high", "low", "close", "invalid_type"); + assert!(result.is_err()); + + let result = ti.positivity_indicator("open", "close", 3, "invalid_type"); + assert!(result.is_err()); + } + + #[test] + fn test_zero_investment() { + let ti = create_basic_other_ti(); + let result = ti.return_on_investment_single("price", 0.0).unwrap(); + + assert_abs_diff_eq!(result.0, 0.0, epsilon = 1e-10); + assert_abs_diff_eq!(result.1, 9.0, epsilon = 1e-10); // Percentage should still be calculated + } + + #[test] + fn test_negative_investment() { + let ti = create_basic_other_ti(); + let result = ti.return_on_investment_single("price", -1000.0).unwrap(); + + // Negative investment should work mathematically + assert_abs_diff_eq!(result.0, -1090.0, epsilon = 1e-10); + assert_abs_diff_eq!(result.1, 9.0, epsilon = 1e-10); + } + + #[test] + fn test_zero_period_bulk() { + let ti = create_basic_other_ti(); + let result = ti.average_true_range_bulk("high", "low", "close", "sma", 0); + + // Should handle zero period gracefully + assert!( + result.is_err() || { + let values: Vec = result.unwrap().0.0.f64().unwrap().into_no_null_iter().collect(); + values.iter().all(|&x| x.is_nan()) + } + ); + } + + #[test] + fn test_large_period_bulk() { + let ti = create_basic_other_ti(); + let result = ti.average_true_range_bulk("high", "low", "close", "sma", 20).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 10); + + // All values should be NaN since period > data length + assert!(values.iter().take(9).all(|&x| x.is_nan())); + } + + #[test] + fn test_single_value_dataset() { + let single_data = df! { + "open" => vec![100.0], + "high" => vec![105.0], + "low" => vec![99.0], + "close" => vec![104.0], + "price" => vec![100.0] + } + .unwrap() + .lazy(); + + let ti = OtherTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(single_data))); + + let tr_result = ti.true_range("high", "low", "close").unwrap(); + let tr_values: Vec = tr_result.0.0.f64().unwrap().into_no_null_iter().collect(); + assert_eq!(tr_values.len(), 1); + assert_abs_diff_eq!(tr_values[0], 6.0, epsilon = 1e-10); // 105.0 - 99.0 + + let ibs_result = ti.internal_bar_strength("high", "low", "close").unwrap(); + let ibs_values: Vec = ibs_result.0.0.f64().unwrap().into_no_null_iter().collect(); + assert_eq!(ibs_values.len(), 1); + assert_abs_diff_eq!(ibs_values[0], 5.0 / 6.0, epsilon = 1e-10); + } + + #[test] + fn test_identical_high_low_close() { + let identical_data = df! { + "open" => vec![100.0, 100.0, 100.0], + "high" => vec![100.0, 100.0, 100.0], + "low" => vec![100.0, 100.0, 100.0], + "close" => vec![100.0, 100.0, 100.0], + "price" => vec![100.0, 100.0, 100.0] + } + .unwrap() + .lazy(); + + let ti = OtherTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(identical_data))); + + let tr_result = ti.true_range("high", "low", "close").unwrap(); + let tr_values: Vec = tr_result.0.0.f64().unwrap().into_no_null_iter().collect(); + + // All true range values should be 0 when high = low = close + for &value in &tr_values { + assert_abs_diff_eq!(value, 0.0, epsilon = 1e-10); + } + + let ibs_result = ti.internal_bar_strength("high", "low", "close").unwrap(); + let ibs_values: Vec = ibs_result.0.0.f64().unwrap().into_no_null_iter().collect(); + + // IBS should handle division by zero gracefully + for &value in &ibs_values { + assert!(value.is_nan() || value == 0.0 || value == 1.0); + } + } + + #[test] + fn test_floating_point_precision() { + let precision_data = df! { + "open" => vec![100.000000001, 100.000000002, 100.000000003], + "high" => vec![100.000000011, 100.000000012, 100.000000013], + "low" => vec![99.999999991, 99.999999992, 99.999999993], + "close" => vec![100.000000001, 100.000000002, 100.000000003], + "price" => vec![100.000000001, 100.000000002, 100.000000003] + } + .unwrap() + .lazy(); + + let ti = OtherTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(precision_data))); + + let roi_result = ti.return_on_investment_single("price", 1000.0).unwrap(); + + // Should handle small precision differences + assert!(roi_result.0 > 999.0 && roi_result.0 < 1001.0); + assert!(roi_result.1.abs() < 1.0); + } + + #[test] + fn test_different_moving_average_types() { + let ti = create_basic_other_ti(); + + let sma_result = ti.average_true_range_single("high", "low", "close", "sma").unwrap(); + let ema_result = ti.average_true_range_single("high", "low", "close", "ema").unwrap(); + + // Results should be different for different MA types + assert_ne!(sma_result, ema_result); + + // Both should be positive + assert!(sma_result > 0.0); + assert!(ema_result > 0.0); + } + + #[test] + fn test_cross_column_consistency() { + let ti = create_basic_other_ti(); + + let tr_result = ti.true_range("high", "low", "close").unwrap(); + let tr_values: Vec = tr_result.0.0.f64().unwrap().into_no_null_iter().collect(); + + // True range should always be non-negative + for &value in &tr_values { + assert!(value >= 0.0); + } + + let ibs_result = ti.internal_bar_strength("high", "low", "close").unwrap(); + let ibs_values: Vec = ibs_result.0.0.f64().unwrap().into_no_null_iter().collect(); + + // IBS should always be between 0 and 1 (or NaN) + for &value in &ibs_values { + assert!(value.is_nan() || (0.0..=1.0).contains(&value)); + } + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/std_/mod.rs b/plugins/ezpz-rust-ti/src/indicators/std_/mod.rs new file mode 100644 index 0000000..c1be1db --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/std_/mod.rs @@ -0,0 +1,751 @@ +use { + crate::utils::{create_triple_df, extract_f64_values}, + ezpz_stubz::{frame::PyDfStubbed, lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Standard Technical Indicators - A collection of standard analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct StandardTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl StandardTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Simple Moving Average (Single) - calculates the mean of all values in the column + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// f64 - Single SMA value calculated from all provided prices + fn sma_single(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.is_empty() { + return Err(PyErr::new::("Series cannot be empty")); + } + + let result = rust_ti::standard_indicators::single::simple_moving_average(&values); + Ok(result) + } + + /// Simple Moving Average (Bulk) - calculates the mean over a rolling window + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Number of periods for the moving average window + /// + /// # Returns + /// PySeriesStubbed - Series containing SMA values for each period + fn sma_bulk(&self, price_column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() < period { + return Err(PyErr::new::("Series length must be at least the specified period")); + } + + let sma_result = rust_ti::standard_indicators::bulk::simple_moving_average(&values, period); + let result_series = Series::new("sma".into(), sma_result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Smoothed Moving Average (Single) - single value calculation + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// f64 - Single SMMA value calculated from all provided prices + fn smma_single(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.is_empty() { + return Err(PyErr::new::("Series cannot be empty")); + } + + let result = rust_ti::standard_indicators::single::smoothed_moving_average(&values); + Ok(result) + } + + /// Smoothed Moving Average (Bulk) - puts more weight on recent prices + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Number of periods for the smoothed moving average window + /// + /// # Returns + /// PySeriesStubbed - Series containing SMMA values for each period + fn smma_bulk(&self, price_column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() < period { + return Err(PyErr::new::("Series length must be at least the specified period")); + } + + let smma_result = rust_ti::standard_indicators::bulk::smoothed_moving_average(&values, period); + let result_series = Series::new("smma".into(), smma_result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Exponential Moving Average (Single) - single value calculation + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// f64 - Single EMA value calculated from all provided prices + fn ema_single(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.is_empty() { + return Err(PyErr::new::("Series cannot be empty")); + } + + let result = rust_ti::standard_indicators::single::exponential_moving_average(&values); + Ok(result) + } + + /// Exponential Moving Average (Bulk) - puts exponentially more weight on recent prices + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Number of periods for the exponential moving average window + /// + /// # Returns + /// PySeriesStubbed - Series containing EMA values for each period + fn ema_bulk(&self, price_column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() < period { + return Err(PyErr::new::("Series length must be at least the specified period")); + } + + let ema_result = rust_ti::standard_indicators::bulk::exponential_moving_average(&values, period); + let result_series = Series::new("ema".into(), ema_result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Bollinger Bands (Single) - single value calculation (requires exactly 20 periods) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// Tuple of (lower_band: f64, middle_band: f64, upper_band: f64) + /// - `lower_band`: Lower Bollinger Band value + /// - `middle_band`: Middle band (SMA) value + /// - `upper_band`: Upper Bollinger Band value + fn bollinger_bands_single(&self, price_column: &str) -> PyResult<(f64, f64, f64)> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() != 20 { + return Err(PyErr::new::("Series length must be exactly 20 for single Bollinger Bands calculation")); + } + + let result = rust_ti::standard_indicators::single::bollinger_bands(&values); + Ok(result) + } + + /// Bollinger Bands (Bulk) - returns three series: lower band, middle (SMA), upper band + /// Standard period is 20 with 2 standard deviations + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// PyDfStubbed - DataFrame with three columns: + /// - `bb_lower`: Lower Bollinger Band values + /// - `bb_middle`: Middle band (20-period SMA) + /// - `bb_upper`: Upper Bollinger Band values + fn bollinger_bands_bulk(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() < 20 { + return Err(PyErr::new::("Series length must be at least 20 for Bollinger Bands")); + } + + let bb_result = rust_ti::standard_indicators::bulk::bollinger_bands(&values); + + let lower: Vec = bb_result.iter().map(|(l, _, _)| *l).collect(); + let middle: Vec = bb_result.iter().map(|(_, m, _)| *m).collect(); + let upper: Vec = bb_result.iter().map(|(_, _, u)| *u).collect(); + + create_triple_df(lower, middle, upper, "bb_lower", "bb_middle", "bb_upper") + } + + /// MACD (Single) - single value calculation (requires exactly 34 periods) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// Tuple of (macd_line: f64, signal_line: f64, histogram: f64) + /// - `macd_line`: MACD line value (12-period EMA - 26-period EMA) + /// - `signal_line`: Signal line value (9-period EMA of MACD line) + /// - `histogram`: Histogram value (MACD line - Signal line) + fn macd_single(&self, price_column: &str) -> PyResult<(f64, f64, f64)> { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() != 34 { + return Err(PyErr::new::("Series length must be exactly 34 for single MACD calculation")); + } + + let result = rust_ti::standard_indicators::single::macd(&values); + Ok(result) + } + + /// MACD (Bulk) - Moving Average Convergence Divergence + /// Returns three series: MACD line, Signal line, Histogram + /// Standard periods: 12, 26, 9 + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// PyDfStubbed - DataFrame with three columns: + /// - `macd`: MACD line (12-period EMA - 26-period EMA) + /// - `macd_signal`: Signal line (9-period EMA of MACD line) + /// - `macd_histogram`: Histogram (MACD line - Signal line) + fn macd_bulk(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() < 34 { + return Err(PyErr::new::("Series length must be at least 34 for MACD")); + } + + let macd_result = rust_ti::standard_indicators::bulk::macd(&values); + + let macd_line: Vec = macd_result.iter().map(|(m, _, _)| *m).collect(); + let signal_line: Vec = macd_result.iter().map(|(_, s, _)| *s).collect(); + let histogram: Vec = macd_result.iter().map(|(_, _, h)| *h).collect(); + + create_triple_df(macd_line, signal_line, histogram, "macd", "macd_signal", "macd_histogram") + } + + /// RSI (Single) - single value calculation (requires exactly 14 periods) + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// f64 - Single RSI value (0-100 scale) + fn rsi_single(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() != 14 { + return Err(PyErr::new::("Series length must be exactly 14 for single RSI calculation")); + } + + let result = rust_ti::standard_indicators::single::rsi(&values); + Ok(result) + } + + /// RSI (Bulk) - Relative Strength Index + /// Standard period is 14 using smoothed moving average + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// PySeriesStubbed - Series containing RSI values (0-100 scale) + fn rsi_bulk(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + + if values.len() < 14 { + return Err(PyErr::new::("Series length must be at least 14 for RSI")); + } + + let rsi_result = rust_ti::standard_indicators::bulk::rsi(&values); + let result_series = Series::new("rsi".into(), rsi_result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + let data = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, + 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, + ]; + df! { + "price" => data, + "volume" => vec![100.0; 35] + } + .unwrap() + .lazy() + } + + fn create_standard_ti() -> StandardTI { + let lf = create_test_dataframe(); + StandardTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + fn create_small_dataframe() -> LazyFrame { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + df! { + "price" => data + } + .unwrap() + .lazy() + } + + fn create_small_ti() -> StandardTI { + let lf = create_small_dataframe(); + StandardTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_sma_single() { + let ti = create_standard_ti(); + let result = ti.sma_single("price").unwrap(); + assert_abs_diff_eq!(result, 18.0, epsilon = 1e-10); + } + + #[test] + fn test_sma_bulk() { + let ti = create_standard_ti(); + let result = ti.sma_bulk("price", 3).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 33); + assert_abs_diff_eq!(values[0], 2.0, epsilon = 1e-10); + assert_abs_diff_eq!(values[1], 3.0, epsilon = 1e-10); + assert_abs_diff_eq!(values[2], 4.0, epsilon = 1e-10); + } + + #[test] + fn test_sma_bulk_insufficient_data() { + let ti = create_small_ti(); + let result = ti.sma_bulk("price", 10); + assert!(result.is_err()); + } + + #[test] + fn test_smma_single() { + let ti = create_standard_ti(); + let result = ti.smma_single("price").unwrap(); + assert!(result > 0.0); + } + + #[test] + fn test_smma_bulk() { + let ti = create_standard_ti(); + let result = ti.smma_bulk("price", 5).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 31); + assert!(values.iter().all(|&x| x > 0.0)); + } + + #[test] + fn test_smma_bulk_insufficient_data() { + let ti = create_small_ti(); + let result = ti.smma_bulk("price", 10); + assert!(result.is_err()); + } + + #[test] + fn test_ema_single() { + let ti = create_standard_ti(); + let result = ti.ema_single("price").unwrap(); + assert!(result > 0.0); + } + + #[test] + fn test_ema_bulk() { + let ti = create_standard_ti(); + let result = ti.ema_bulk("price", 5).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 31); + assert!(values.iter().all(|&x| x > 0.0)); + } + + #[test] + fn test_ema_bulk_insufficient_data() { + let ti = create_small_ti(); + let result = ti.ema_bulk("price", 10); + assert!(result.is_err()); + } + + #[test] + fn test_bollinger_bands_single() { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0]; + let df = df! { + "price" => data + } + .unwrap() + .lazy(); + + let ti = StandardTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(df))); + let result = ti.bollinger_bands_single("price").unwrap(); + + assert!(result.0 < result.1); + assert!(result.1 < result.2); + assert_abs_diff_eq!(result.1, 10.5, epsilon = 1e-10); + } + + #[test] + fn test_bollinger_bands_single_wrong_length() { + let ti = create_small_ti(); + let result = ti.bollinger_bands_single("price"); + assert!(result.is_err()); + } + + #[test] + fn test_bollinger_bands_bulk() { + let ti = create_standard_ti(); + let result = ti.bollinger_bands_bulk("price").unwrap(); + let df = result.0.0; + + assert_eq!(df.height(), 16); + assert_eq!(df.width(), 3); + assert!(df.column("bb_lower").is_ok()); + assert!(df.column("bb_middle").is_ok()); + assert!(df.column("bb_upper").is_ok()); + } + + #[test] + fn test_bollinger_bands_bulk_insufficient_data() { + let ti = create_small_ti(); + let result = ti.bollinger_bands_bulk("price"); + assert!(result.is_err()); + } + + #[test] + fn test_macd_single() { + let data: Vec = (1..=34).map(|x| x as f64).collect(); + let df = df! { + "price" => data + } + .unwrap() + .lazy(); + + let ti = StandardTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(df))); + let result = ti.macd_single("price").unwrap(); + + assert!(result.0.is_finite()); + assert!(result.1.is_finite()); + assert!(result.2.is_finite()); + } + + #[test] + fn test_macd_single_wrong_length() { + let ti = create_small_ti(); + let result = ti.macd_single("price"); + assert!(result.is_err()); + } + + #[test] + fn test_macd_bulk() { + let ti = create_standard_ti(); + let result = ti.macd_bulk("price").unwrap(); + let df = result.0.0; + + assert_eq!(df.height(), 2); + assert_eq!(df.width(), 3); + assert!(df.column("macd").is_ok()); + assert!(df.column("macd_signal").is_ok()); + assert!(df.column("macd_histogram").is_ok()); + } + + #[test] + fn test_macd_bulk_insufficient_data() { + let ti = create_small_ti(); + let result = ti.macd_bulk("price"); + assert!(result.is_err()); + } + + #[test] + fn test_rsi_single() { + let data: Vec = (1..=14).map(|x| x as f64).collect(); + let df = df! { + "price" => data + } + .unwrap() + .lazy(); + + let ti = StandardTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(df))); + let result = ti.rsi_single("price").unwrap(); + + assert!((0.0..=100.0).contains(&result)); + } + + #[test] + fn test_rsi_single_wrong_length() { + let ti = create_small_ti(); + let result = ti.rsi_single("price"); + assert!(result.is_err()); + } + + #[test] + fn test_rsi_bulk() { + let ti = create_standard_ti(); + let result = ti.rsi_bulk("price").unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 22); + assert!(values.iter().all(|&x| (0.0..=100.0).contains(&x))); + } + + #[test] + fn test_rsi_bulk_insufficient_data() { + let ti = create_small_ti(); + let result = ti.rsi_bulk("price"); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_column_error() { + let ti = create_standard_ti(); + let result = ti.sma_single("nonexistent_column"); + assert!(result.is_err()); + } + + #[test] + fn test_empty_series_error() { + let empty_df = df! { + "price" => Vec::::new() + } + .unwrap() + .lazy(); + + let ti = StandardTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(empty_df))); + let result = ti.sma_single("price"); + assert!(result.is_err()); + } + + #[test] + fn test_single_value_dataset() { + let single_data = df! { + "price" => vec![5.0] + } + .unwrap() + .lazy(); + + let ti = StandardTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(single_data))); + + assert_abs_diff_eq!(ti.sma_single("price").unwrap(), 5.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.smma_single("price").unwrap(), 5.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.ema_single("price").unwrap(), 5.0, epsilon = 1e-10); + } + + #[test] + fn test_duplicate_values_dataset() { + let duplicate_data = df! { + "price" => vec![3.0; 35] + } + .unwrap() + .lazy(); + + let ti = StandardTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(duplicate_data))); + + assert_abs_diff_eq!(ti.sma_single("price").unwrap(), 3.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.smma_single("price").unwrap(), 3.0, epsilon = 1e-10); + assert_abs_diff_eq!(ti.ema_single("price").unwrap(), 3.0, epsilon = 1e-10); + } + + #[test] + fn test_sma_bulk_period_one() { + let ti = create_standard_ti(); + let result = ti.sma_bulk("price", 1).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_no_null_iter().collect(); + + assert_eq!(values.len(), 35); + for (i, &value) in values.iter().enumerate() { + assert_abs_diff_eq!(value, (i + 1) as f64, epsilon = 1e-10); + } + } + + #[test] + fn test_bollinger_bands_bulk_band_ordering() { + let ti = create_standard_ti(); + let result = ti.bollinger_bands_bulk("price").unwrap(); + let df = result.0.0; + + let lower: Vec = df.column("bb_lower").unwrap().f64().unwrap().into_no_null_iter().collect(); + let middle: Vec = df.column("bb_middle").unwrap().f64().unwrap().into_no_null_iter().collect(); + let upper: Vec = df.column("bb_upper").unwrap().f64().unwrap().into_no_null_iter().collect(); + + for i in 0..lower.len() { + assert!(lower[i] < middle[i]); + assert!(middle[i] < upper[i]); + } + } + + #[test] + fn test_macd_bulk_histogram_calculation() { + let ti = create_standard_ti(); + let result = ti.macd_bulk("price").unwrap(); + let df = result.0.0; + + let macd: Vec = df.column("macd").unwrap().f64().unwrap().into_no_null_iter().collect(); + let signal: Vec = df.column("macd_signal").unwrap().f64().unwrap().into_no_null_iter().collect(); + let histogram: Vec = df.column("macd_histogram").unwrap().f64().unwrap().into_no_null_iter().collect(); + + for i in 0..macd.len() { + assert_abs_diff_eq!(histogram[i], macd[i] - signal[i], epsilon = 1e-10); + } + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/strength/mod.rs b/plugins/ezpz-rust-ti/src/indicators/strength/mod.rs new file mode 100644 index 0000000..481c973 --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/strength/mod.rs @@ -0,0 +1,683 @@ +use { + crate::utils::{extract_f64_values, parse_constant_model_type}, + ezpz_stubz::{lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Strength Technical Indicators - A collection of strength analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct StrengthTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl StrengthTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Accumulation Distribution (Single) - Shows whether the stock is being accumulated or distributed + /// Single value calculation using the last available values + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `volume_column`: &str - Name of the volume column + /// - `previous_ad`: Option - Previous accumulation/distribution value (defaults to 0.0) + /// + /// # Returns + /// f64 - Single accumulation/distribution value + fn accumulation_distribution_single( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + volume_column: &str, + previous_ad: Option, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + + let high = high_values.last().ok_or_else(|| PyErr::new::("High series is empty"))?; + let low = low_values.last().ok_or_else(|| PyErr::new::("Low series is empty"))?; + let close = close_values.last().ok_or_else(|| PyErr::new::("Close series is empty"))?; + let volume = volume_values.last().ok_or_else(|| PyErr::new::("Volume series is empty"))?; + + let previous = previous_ad.unwrap_or(0.0); + let result = rust_ti::strength_indicators::single::accumulation_distribution(*high, *low, *close, *volume, previous); + Ok(result) + } + + /// Accumulation Distribution (Bulk) - Shows whether the stock is being accumulated or distributed + /// Returns a series of accumulation/distribution values + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `volume_column`: &str - Name of the volume column + /// - `previous_ad`: Option - Previous accumulation/distribution value (defaults to 0.0) + /// + /// # Returns + /// PySeriesStubbed - Series containing accumulation/distribution values with name "accumulation_distribution" + fn accumulation_distribution_bulk( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + volume_column: &str, + previous_ad: Option, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + + let previous = previous_ad.unwrap_or(0.0); + let result = rust_ti::strength_indicators::bulk::accumulation_distribution(&high_values, &low_values, &close_values, &volume_values, previous); + + let result_series = Series::new("accumulation_distribution".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Positive Volume Index (Single) - Measures volume trend strength when volume increases + /// Single value calculation using the last available values + /// + /// # Parameters + /// - `close_column`: &str - Name of the close price column + /// - `volume_column`: &str - Name of the volume column + /// - `previous_pvi`: Option - Previous positive volume index value (defaults to 0.0) + /// + /// # Returns + /// f64 - Single positive volume index value + fn positive_volume_index_single(&self, close_column: &str, volume_column: &str, previous_pvi: Option) -> PyResult { + let df = self + .lf + .clone() + .select([col(close_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + + // Need at least 2 values for comparison + if close_values.len() < 2 { + return Err(PyErr::new::("Need at least 2 values for volume index calculation")); + } + + let current_close = close_values.last().unwrap(); + let previous_close = close_values[close_values.len() - 2]; + let current_volume = volume_values.last().ok_or_else(|| PyErr::new::("Volume series is empty"))?; + let previous_volume = + volume_values.get(volume_values.len() - 2).ok_or_else(|| PyErr::new::("Need at least 2 volume values"))?; + + let previous = previous_pvi.unwrap_or(0.0); + + // Calculate PVI: only update when volume increases + let result = if current_volume > previous_volume { previous + ((*current_close - previous_close) / previous_close) * previous } else { previous }; + + Ok(result) + } + + /// Positive Volume Index (Bulk) - Measures volume trend strength when volume increases + /// Returns a series of positive volume index values + /// + /// # Parameters + /// - `close_column`: &str - Name of the close price column + /// - `volume_column`: &str - Name of the volume column + /// - `previous_pvi`: Option - Previous positive volume index value (defaults to 0.0) + /// + /// # Returns + /// PySeriesStubbed - Series containing positive volume index values with name "positive_volume_index" + fn positive_volume_index_bulk(&self, close_column: &str, volume_column: &str, previous_pvi: Option) -> PyResult { + let df = self + .lf + .clone() + .select([col(close_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + + let previous = previous_pvi.unwrap_or(0.0); + let result = rust_ti::strength_indicators::bulk::positive_volume_index(&close_values, &volume_values, previous); + + let result_series = Series::new("positive_volume_index".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Negative Volume Index (Single) - Measures volume trend strength when volume decreases + /// Single value calculation using the last available values + /// + /// # Parameters + /// - `close_column`: &str - Name of the close price column + /// - `volume_column`: &str - Name of the volume column + /// - `previous_nvi`: Option - Previous negative volume index value (defaults to 0.0) + /// + /// # Returns + /// f64 - Single negative volume index value + fn negative_volume_index_single(&self, close_column: &str, volume_column: &str, previous_nvi: Option) -> PyResult { + let df = self + .lf + .clone() + .select([col(close_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + + // Need at least 2 values for comparison + if close_values.len() < 2 { + return Err(PyErr::new::("Need at least 2 values for volume index calculation")); + } + + let current_close = close_values.last().unwrap(); + let previous_close = close_values[close_values.len() - 2]; + let current_volume = volume_values.last().ok_or_else(|| PyErr::new::("Volume series is empty"))?; + let previous_volume = + volume_values.get(volume_values.len() - 2).ok_or_else(|| PyErr::new::("Need at least 2 volume values"))?; + + let previous = previous_nvi.unwrap_or(0.0); + + // Calculate NVI: only update when volume decreases + let result = if current_volume < previous_volume { previous + ((*current_close - previous_close) / previous_close) * previous } else { previous }; + + Ok(result) + } + + /// Negative Volume Index (Bulk) - Measures volume trend strength when volume decreases + /// Returns a series of negative volume index values + /// + /// # Parameters + /// - `close_column`: &str - Name of the close price column + /// - `volume_column`: &str - Name of the volume column + /// - `previous_nvi`: Option - Previous negative volume index value (defaults to 0.0) + /// + /// # Returns + /// PySeriesStubbed - Series containing negative volume index values with name "negative_volume_index" + fn negative_volume_index_bulk(&self, close_column: &str, volume_column: &str, previous_nvi: Option) -> PyResult { + let df = self + .lf + .clone() + .select([col(close_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + + let previous = previous_nvi.unwrap_or(0.0); + let result = rust_ti::strength_indicators::bulk::negative_volume_index(&close_values, &volume_values, previous); + + let result_series = Series::new("negative_volume_index".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Relative Vigor Index (Single) - Measures the strength of an asset by looking at previous prices + /// Single value calculation using all available values + /// + /// # Parameters + /// - `open_column`: &str - Name of the opening price column + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `constant_model_type`: &str - Type of constant model to use + /// + /// # Returns + /// f64 - Single relative vigor index value + fn relative_vigor_index_single( + &self, + open_column: &str, + high_column: &str, + low_column: &str, + close_column: &str, + constant_model_type: &str, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(open_column), col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let open_series = df + .column(open_column) + .map_err(|e| PyErr::new::(format!("Column '{open_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{open_column}' could not be converted to Series")))? + .clone(); + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let open_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(open_series)))?; + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let constant_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::strength_indicators::single::relative_vigor_index(&open_values, &high_values, &low_values, &close_values, constant_type); + + Ok(result) + } + + /// Relative Vigor Index (Bulk) - Measures the strength of an asset by looking at previous prices + /// Returns a series of relative vigor index values + /// + /// # Parameters + /// - `open_column`: &str - Name of the opening price column + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `constant_model_type`: &str - Type of constant model to use + /// - `period`: usize - Period length for calculation + /// + /// # Returns + /// PySeriesStubbed - Series containing relative vigor index values with name "relative_vigor_index" + fn relative_vigor_index_bulk( + &self, + open_column: &str, + high_column: &str, + low_column: &str, + close_column: &str, + constant_model_type: &str, + period: usize, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(open_column), col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let open_series = df + .column(open_column) + .map_err(|e| PyErr::new::(format!("Column '{open_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{open_column}' could not be converted to Series")))? + .clone(); + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let open_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(open_series)))?; + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let constant_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::strength_indicators::bulk::relative_vigor_index(&open_values, &high_values, &low_values, &close_values, constant_type, period); + + let result_series = Series::new("relative_vigor_index".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + df! { + "open" => data.clone(), + "high" => data.iter().map(|x| x + 1.0).collect::>(), + "low" => data.iter().map(|x| x - 0.5).collect::>(), + "close" => data, + "volume" => vec![100.0, 200.0, 150.0, 300.0, 250.0, 180.0, 220.0, 190.0, 280.0, 320.0] + } + .unwrap() + .lazy() + } + + fn create_strength_ti() -> StrengthTI { + let lf = create_test_dataframe(); + StrengthTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_accumulation_distribution_single() { + let ti = create_strength_ti(); + let result = ti.accumulation_distribution_single("high", "low", "close", "volume", None).unwrap(); + // Formula: ((close - low) - (high - close)) / (high - low) * volume + previous_ad + // For last row: high=11.0, low=9.5, close=10.0, volume=320.0, previous_ad=0.0 + // ((10.0 - 9.5) - (11.0 - 10.0)) / (11.0 - 9.5) * 320.0 = (0.5 - 1.0) / 1.5 * 320.0 = -106.66666666666667 + assert_abs_diff_eq!(result, -106.66666666666667, epsilon = 1e-10); + } + + #[test] + fn test_accumulation_distribution_single_invalid_column() { + let ti = create_strength_ti(); + let result = ti.accumulation_distribution_single("invalid", "low", "close", "volume", None); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "PyValueError: Column 'invalid' not found: ColumnNotFound"); + } + + #[test] + fn test_accumulation_distribution_bulk() { + let ti = create_strength_ti(); + let result = ti.accumulation_distribution_bulk("high", "low", "close", "volume", None).unwrap(); + let values: Vec = extract_f64_values(result).unwrap(); + + assert_eq!(values.len(), 10); + // First value: high=2.0, low=0.5, close=1.0, volume=100.0 + // ((1.0 - 0.5) - (2.0 - 1.0)) / (2.0 - 0.5) * 100.0 = (0.5 - 1.0) / 1.5 * 100.0 = -33.333333333333336 + assert_abs_diff_eq!(values[0], -33.333333333333336, epsilon = 1e-10); + // Last value: high=11.0, low=9.5, close=10.0, volume=320.0 + assert_abs_diff_eq!(values[9], -106.66666666666667, epsilon = 1e-10); + } + + #[test] + fn test_positive_volume_index_single() { + let ti = create_strength_ti(); + let result = ti.positive_volume_index_single("close", "volume", None).unwrap(); + // Last volume (320.0) > previous volume (280.0) + // Formula: previous_pvi + ((current_close - previous_close) / previous_close) * previous_pvi + // previous_pvi=0.0, so result=0.0 (since initial previous_pvi is 0) + assert_abs_diff_eq!(result, 0.0, epsilon = 1e-10); + } + + #[test] + fn test_positive_volume_index_single_insufficient_data() { + let mut ti = create_strength_ti(); + // Create a dataframe with only one row + let lf = df! { + "close" => vec![10.0], + "volume" => vec![320.0] + } + .unwrap() + .lazy(); + ti.lf = lf; + let result = ti.positive_volume_index_single("close", "volume", None); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "PyValueError: Need at least 2 values for volume index calculation"); + } + + #[test] + fn test_positive_volume_index_bulk() { + let ti = create_strength_ti(); + let result = ti.positive_volume_index_bulk("close", "volume", None).unwrap(); + let values: Vec = extract_f64_values(result).unwrap(); + + assert_eq!(values.len(), 10); + // First value is 0.0 (initial value) + assert_abs_diff_eq!(values[0], 0.0, epsilon = 1e-10); + // Check a case where volume increases (index 1: volume 200.0 > 100.0) + // PVI = previous_pvi + ((current_close - previous_close) / previous_close) * previous_pvi + // close[1]=2.0, close[0]=1.0, previous_pvi=0.0 -> 0.0 + ((2.0 - 1.0) / 1.0) * 0.0 = 0.0 + assert_abs_diff_eq!(values[1], 0.0, epsilon = 1e-10); + } + + #[test] + fn test_negative_volume_index_single() { + let ti = create_strength_ti(); + let result = ti.negative_volume_index_single("close", "volume", None).unwrap(); + // Last volume (320.0) > previous volume (280.0), so NVI doesn't update + // Result is previous_nvi (0.0) + assert_abs_diff_eq!(result, 0.0, epsilon = 1e-10); + } + + #[test] + fn test_negative_volume_index_single_decreasing_volume() { + let mut ti = create_strength_ti(); + // Modify volume so last volume < previous volume + let lf = df! { + "close" => vec![1.0, 2.0, 3.0], + "volume" => vec![200.0, 150.0, 100.0] + } + .unwrap() + .lazy(); + ti.lf = lf; + let result = ti.negative_volume_index_single("close", "volume", Some(1000.0)).unwrap(); + // Volume decreases (100.0 < 150.0) + // NVI = previous_nvi + ((current_close - previous_close) / previous_close) * previous_nvi + // close[2]=3.0, close[1]=2.0, previous_nvi=1000.0 + // NVI = 1000.0 + ((3.0 - 2.0) / 2.0) * 1000.0 = 1000.0 + 0.5 * 1000.0 = 1500.0 + assert_abs_diff_eq!(result, 1500.0, epsilon = 1e-10); + } + + #[test] + fn test_negative_volume_index_bulk() { + let ti = create_strength_ti(); + let result = ti.negative_volume_index_bulk("close", "volume", None).unwrap(); + let values: Vec = extract_f64_values(result).unwrap(); + + assert_eq!(values.len(), 10); + // First value is 0.0 (initial value) + assert_abs_diff_eq!(values[0], 0.0, epsilon = 1e-10); + // Check a case where volume decreases (index 2: volume 150.0 < 200.0) + // NVI = previous_nvi + ((current_close - previous_close) / previous_close) * previous_nvi + // close[2]=3.0, close[1]=2.0, previous_nvi=0.0 -> 0.0 + ((3.0 - 2.0) / 2.0) * 0.0 = 0.0 + assert_abs_diff_eq!(values[2], 0.0, epsilon = 1e-10); + } + + #[test] + fn test_relative_vigor_index_single() { + let ti = create_strength_ti(); + let result = ti.relative_vigor_index_single("open", "high", "low", "close", "mean").unwrap(); + // RVI = ((close - open) / (high - low)) for the period, normalized by constant model (mean) + // Using all values, calculate average (close - open) / (high - low) + // For each row: close - open = 0.0, high - low = 1.5, so (0.0 / 1.5) = 0.0 + // Mean of these values over 10 rows = 0.0 + assert_abs_diff_eq!(result, 0.0, epsilon = 1e-10); + } + + #[test] + fn test_relative_vigor_index_single_invalid_constant_model() { + let ti = create_strength_ti(); + let result = ti.relative_vigor_index_single("open", "high", "low", "close", "invalid"); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Invalid constant model type")); + } + + #[test] + fn test_relative_vigor_index_bulk() { + let ti = create_strength_ti(); + let result = ti.relative_vigor_index_bulk("open", "high", "low", "close", "mean", 3).unwrap(); + let values: Vec = extract_f64_values(result).unwrap(); + + assert_eq!(values.len(), 10); + // First two values are NaN due to period=3 + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + // For index 2: take rows 0-2 + // Each row: (close - open) / (high - low) = 0.0 / 1.5 = 0.0 + // Mean over 3 rows = 0.0 + assert_abs_diff_eq!(values[2], 0.0, epsilon = 1e-10); + } + + #[test] + fn test_relative_vigor_index_bulk_empty_data() { + let mut ti = create_strength_ti(); + let lf = df! { + "open" => Vec::::new(), + "high" => Vec::::new(), + "low" => Vec::::new(), + "close" => Vec::::new() + } + .unwrap() + .lazy(); + ti.lf = lf; + let result = ti.relative_vigor_index_bulk("open", "high", "low", "close", "mean", 3); + assert!(result.is_ok()); + let values: Vec = extract_f64_values(result.unwrap()).unwrap(); + assert_eq!(values.len(), 0); + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/trend/mod.rs b/plugins/ezpz-rust-ti/src/indicators/trend/mod.rs new file mode 100644 index 0000000..b880a56 --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/trend/mod.rs @@ -0,0 +1,856 @@ +use { + crate::utils::{create_triple_df, extract_f64_values, parse_constant_model_type}, + ezpz_stubz::{frame::PyDfStubbed, lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Trend Technical Indicators - A collection of trend analysis functions for financial data +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct TrendTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl TrendTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + // Single value functions (return a single value from the entire series) + + /// Aroon Up (Single) - Measures the strength of upward price momentum + /// Calculates the percentage of time since the highest high within the series + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column to analyze + /// + /// # Returns + /// f64 - Aroon Up value (0-100), where higher values indicate stronger upward momentum + fn aroon_up_single(&self, high_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(high_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{high_column}': {e}")))? + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::trend_indicators::single::aroon_up(&values); + Ok(result) + } + + /// Aroon Down (Single) - Measures the strength of downward price momentum + /// Calculates the percentage of time since the lowest low within the series + /// + /// # Parameters + /// - `low_column`: &str - Name of the low price column to analyze + /// + /// # Returns + /// f64 - Aroon Down value (0-100), where higher values indicate stronger downward momentum + fn aroon_down_single(&self, low_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{low_column}': {e}")))? + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::trend_indicators::single::aroon_down(&values); + Ok(result) + } + + /// Aroon Oscillator (Single) - Calculates the difference between Aroon Up and Aroon Down + /// Provides a single measure of trend direction and strength + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// + /// # Returns + /// f64 - Aroon Oscillator value (-100 to 100), where positive values indicate upward trend + fn aroon_oscillator_single(&self, high_column: &str, low_column: &str) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + + let result = rust_ti::trend_indicators::single::aroon_indicator(&high_values, &low_values); + Ok(result.2) // Return the oscillator component + } + + /// Aroon Indicator (Single) - Calculates complete Aroon system in one call + /// Computes Aroon Up, Aroon Down, and Aroon Oscillator + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// + /// # Returns + /// (f64, f64, f64) - Tuple containing (Aroon Up, Aroon Down, Aroon Oscillator) + fn aroon_indicator_single(&self, high_column: &str, low_column: &str) -> PyResult<(f64, f64, f64)> { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + + let result = rust_ti::trend_indicators::single::aroon_indicator(&high_values, &low_values); + Ok(result) + } + + /// True Strength Index (Single) - Momentum oscillator using double-smoothed price changes + /// Filters out price noise to provide clearer momentum signals + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `first_constant_model`: &str - First smoothing method ("SimpleMovingAverage", "ExponentialMovingAverage", etc.) + /// - `first_period`: usize - Period for first smoothing + /// - `second_constant_model`: &str - Second smoothing method + /// + /// # Returns + /// f64 - True Strength Index value (-100 to 100) + fn true_strength_index_single(&self, price_column: &str, first_constant_model: &str, first_period: usize, second_constant_model: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let first_model = parse_constant_model_type(first_constant_model)?; + let second_model = parse_constant_model_type(second_constant_model)?; + let result = rust_ti::trend_indicators::single::true_strength_index(&values, first_model, first_period, second_model); + Ok(result) + } + + // Bulk functions (return series of values) + + /// Aroon Up (Bulk) - Calculates rolling Aroon Up indicator over specified period + /// Measures upward momentum strength for each period in the time series + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column to analyze + /// - `period`: usize - Lookback period for calculation (typically 14) + /// + /// # Returns + /// PySeriesStubbed - Series of Aroon Up values (0-100) named "aroon_up" + fn aroon_up_bulk(&self, high_column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(high_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{high_column}': {e}")))? + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::trend_indicators::bulk::aroon_up(&values, period); + let result_series = Series::new("aroon_up".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Aroon Down (Bulk) - Calculates rolling Aroon Down indicator over specified period + /// Measures downward momentum strength for each period in the time series + /// + /// # Parameters + /// - `low_column`: &str - Name of the low price column to analyze + /// - `period`: usize - Lookback period for calculation (typically 14) + /// + /// # Returns + /// PySeriesStubbed - Series of Aroon Down values (0-100) named "aroon_down" + fn aroon_down_bulk(&self, low_column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{low_column}': {e}")))? + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::trend_indicators::bulk::aroon_down(&values, period); + let result_series = Series::new("aroon_down".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Aroon Oscillator (Bulk) - Calculates rolling Aroon Oscillator over specified period + /// Computes the difference between Aroon Up and Aroon Down for each period + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `period`: usize - Lookback period for calculation (typically 14) + /// + /// # Returns + /// PySeriesStubbed - Series of Aroon Oscillator values (-100 to 100) named "aroon_oscillator" + fn aroon_oscillator_bulk(&self, high_column: &str, low_column: &str, period: usize) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + + let aroon_up_result = rust_ti::trend_indicators::bulk::aroon_up(&high_values, period); + let aroon_down_result = rust_ti::trend_indicators::bulk::aroon_down(&low_values, period); + let result = rust_ti::trend_indicators::bulk::aroon_oscillator(&aroon_up_result, &aroon_down_result); + let result_series = Series::new("aroon_oscillator".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Aroon Indicator (Bulk) - Calculates complete Aroon system for time series data + /// Computes Aroon Up, Aroon Down, and Aroon Oscillator for each period + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `period`: usize - Lookback period for calculation (typically 14) + /// + /// # Returns + /// PyDfStubbed - DataFrame with columns: "aroon_up", "aroon_down", "aroon_oscillator" + fn aroon_indicator_bulk(&self, high_column: &str, low_column: &str, period: usize) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + + let aroon_result = rust_ti::trend_indicators::bulk::aroon_indicator(&high_values, &low_values, period); + let (aroon_up, aroon_down, aroon_oscillator) = { + let mut up = Vec::new(); + let mut down = Vec::new(); + let mut oscillator = Vec::new(); + for (val_up, val_down, val_osc) in aroon_result { + up.push(val_up); + down.push(val_down); + oscillator.push(val_osc); + } + (up, down, oscillator) + }; + + create_triple_df(aroon_up, aroon_down, aroon_oscillator, "aroon_up", "aroon_down", "aroon_oscillator") + } + + /// Parabolic Time Price System (Bulk) - Calculates Stop and Reverse points + /// Provides trailing stop levels for trend-following system + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `acceleration_factor_start`: f64 - Initial acceleration factor (typically 0.02) + /// - `acceleration_factor_max`: f64 - Maximum acceleration factor (typically 0.20) + /// - `acceleration_factor_step`: f64 - Acceleration factor increment (typically 0.02) + /// - `start_position`: &str - Initial position: "Long" or "Short" + /// - `previous_sar`: f64 - Initial SAR value + /// + /// # Returns + /// PySeriesStubbed - Series of SAR values named "parabolic_sar" + #[allow(clippy::too_many_arguments)] + fn parabolic_time_price_system_bulk( + &self, + high_column: &str, + low_column: &str, + acceleration_factor_start: f64, + acceleration_factor_max: f64, + acceleration_factor_step: f64, + start_position: &str, + previous_sar: f64, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + + let position = match start_position { + "Long" => rust_ti::Position::Long, + "Short" => rust_ti::Position::Short, + _ => return Err(PyErr::new::("Invalid position. Use 'Long' or 'Short'")), + }; + + let result = rust_ti::trend_indicators::bulk::parabolic_time_price_system( + &high_values, + &low_values, + acceleration_factor_start, + acceleration_factor_max, + acceleration_factor_step, + position, + previous_sar, + ); + let result_series = Series::new("parabolic_sar".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Directional Movement System (Bulk) - Calculates complete DMS indicators + /// Computes +DI, -DI, ADX, and ADXR for trend strength analysis + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `period`: usize - Calculation period (typically 14) + /// - `constant_model_type`: &str - Smoothing method: "SimpleMovingAverage", "SmoothedMovingAverage", etc. + /// + /// # Returns + /// PyDfStubbed - DataFrame with columns: "positive_di", "negative_di", "adx", "adxr" + fn directional_movement_system_bulk( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + period: usize, + constant_model_type: &str, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let constant_model = parse_constant_model_type(constant_model_type)?; + let dm_result = rust_ti::trend_indicators::bulk::directional_movement_system(&high_values, &low_values, &close_values, period, constant_model); + let (positive_di, negative_di, adx, adxr) = { + let mut pos_di = Vec::new(); + let mut neg_di = Vec::new(); + let mut adx_vals = Vec::new(); + let mut adxr_vals = Vec::new(); + for (val_pos, val_neg, val_adx, val_adxr) in dm_result { + pos_di.push(val_pos); + neg_di.push(val_neg); + adx_vals.push(val_adx); + adxr_vals.push(val_adxr); + } + (pos_di, neg_di, adx_vals, adxr_vals) + }; + + let df = df! { + "positive_di" => positive_di, + "negative_di" => negative_di, + "adx" => adx, + "adxr" => adxr, + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) + } + + /// Volume Price Trend (Bulk) - Combines price and volume to show momentum + /// Shows the relationship between price movement and volume flow + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column + /// - `volume_column`: &str - Name of the volume column + /// - `previous_volume_price_trend`: f64 - Initial VPT value (typically 0) + /// + /// # Returns + /// PySeriesStubbed - Series of Volume Price Trend values named "volume_price_trend" + fn volume_price_trend_bulk(&self, price_column: &str, volume_column: &str, previous_volume_price_trend: f64) -> PyResult { + let df = self + .lf + .clone() + .select([col(price_column), col(volume_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let price_series = df + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let volume_series = df + .column(volume_column) + .map_err(|e| PyErr::new::(format!("Column '{volume_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{volume_column}' could not be converted to Series")))? + .clone(); + + let price_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(price_series)))?; + let volume_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(volume_series)))?; + + let result = rust_ti::trend_indicators::bulk::volume_price_trend(&price_values, &volume_values, previous_volume_price_trend); + let result_series = Series::new("volume_price_trend".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// True Strength Index (Bulk) - Double-smoothed momentum oscillator + /// Uses double-smoothed price changes to filter noise and provide clearer signals + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `first_constant_model`: &str - First smoothing method: "SimpleMovingAverage", "ExponentialMovingAverage", etc. + /// - `first_period`: usize - Period for first smoothing (typically 25) + /// - `second_constant_model`: &str - Second smoothing method + /// - `second_period`: usize - Period for second smoothing (typically 13) + /// + /// # Returns + /// PySeriesStubbed - Series of TSI values (-100 to 100) named "true_strength_index" + fn true_strength_index_bulk( + &self, + price_column: &str, + first_constant_model: &str, + first_period: usize, + second_constant_model: &str, + second_period: usize, + ) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let first_model = parse_constant_model_type(first_constant_model)?; + let second_model = parse_constant_model_type(second_constant_model)?; + let result = rust_ti::trend_indicators::bulk::true_strength_index(&values, first_model, first_period, second_model, second_period); + let result_series = Series::new("true_strength_index".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + let high_data = vec![10.5, 11.2, 12.0, 11.8, 12.5, 13.1, 12.9, 13.5, 14.2, 13.8]; + let low_data = vec![9.8, 10.1, 10.5, 10.2, 11.0, 11.5, 11.2, 12.0, 12.5, 12.1]; + let close_data = vec![10.2, 10.8, 11.5, 11.0, 11.8, 12.3, 12.1, 12.8, 13.5, 13.2]; + let volume_data = vec![1000.0, 1200.0, 1500.0, 1100.0, 1300.0, 1400.0, 1250.0, 1600.0, 1800.0, 1350.0]; + + df! { + "high" => high_data, + "low" => low_data, + "close" => close_data, + "volume" => volume_data + } + .unwrap() + .lazy() + } + + fn create_trend_ti() -> TrendTI { + let lf = create_test_dataframe(); + TrendTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_aroon_up_single() { + let ti = create_trend_ti(); + let result = ti.aroon_up_single("high").unwrap(); + assert!((0.0..=100.0).contains(&result)); + } + + #[test] + fn test_aroon_down_single() { + let ti = create_trend_ti(); + let result = ti.aroon_down_single("low").unwrap(); + assert!((0.0..=100.0).contains(&result)); + } + + #[test] + fn test_aroon_oscillator_single() { + let ti = create_trend_ti(); + let result = ti.aroon_oscillator_single("high", "low").unwrap(); + assert!((-100.0..=100.0).contains(&result)); + } + + #[test] + fn test_aroon_indicator_single() { + let ti = create_trend_ti(); + let result = ti.aroon_indicator_single("high", "low").unwrap(); + assert!(result.0 >= 0.0 && result.0 <= 100.0); + assert!(result.1 >= 0.0 && result.1 <= 100.0); + assert!(result.2 >= -100.0 && result.2 <= 100.0); + assert_abs_diff_eq!(result.2, result.0 - result.1, epsilon = 1e-10); + } + + #[test] + fn test_true_strength_index_single() { + let ti = create_trend_ti(); + let result = ti.true_strength_index_single("close", "SimpleMovingAverage", 5, "ExponentialMovingAverage").unwrap(); + assert!((-100.0..=100.0).contains(&result)); + } + + #[test] + fn test_aroon_up_bulk() { + let ti = create_trend_ti(); + let result = ti.aroon_up_bulk("high", 5).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "aroon_up"); + assert_eq!(series.len(), 10); + + // Check that all values are within valid range + for i in 0..series.len() { + if let Some(val) = series.get(i).unwrap().extract::() { + assert!((0.0..=100.0).contains(&val)); + } + } + } + + #[test] + fn test_aroon_down_bulk() { + let ti = create_trend_ti(); + let result = ti.aroon_down_bulk("low", 5).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "aroon_down"); + assert_eq!(series.len(), 10); + + // Check that all values are within valid range + for i in 0..series.len() { + if let Some(val) = series.get(i).unwrap().extract::() { + assert!((0.0..=100.0).contains(&val)); + } + } + } + + #[test] + fn test_aroon_oscillator_bulk() { + let ti = create_trend_ti(); + let result = ti.aroon_oscillator_bulk("high", "low", 5).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "aroon_oscillator"); + assert_eq!(series.len(), 10); + + // Check that all values are within valid range + for i in 0..series.len() { + if let Some(val) = series.get(i).unwrap().extract::() { + assert!((-100.0..=100.0).contains(&val)); + } + } + } + + #[test] + fn test_aroon_indicator_bulk() { + let ti = create_trend_ti(); + let result = ti.aroon_indicator_bulk("high", "low", 5).unwrap(); + let df = result.0.as_ref(); + + assert_eq!(df.width(), 3); + assert_eq!(df.height(), 10); + + let aroon_up = PlSmallStr::from("aroon_up"); + let aroon_down = PlSmallStr::from("aroon_down"); + let aroon_oscillator = PlSmallStr::from("aroon_oscillator"); + assert!(df.get_column_names().contains(&&aroon_up)); + assert!(df.get_column_names().contains(&&aroon_down)); + assert!(df.get_column_names().contains(&&aroon_oscillator)); + + // Check that aroon_oscillator = aroon_up - aroon_down + let aroon_up = df.column("aroon_up").unwrap(); + let aroon_down = df.column("aroon_down").unwrap(); + let aroon_osc = df.column("aroon_oscillator").unwrap(); + + for i in 0..df.height() { + if let (Some(up), Some(down), Some(osc)) = + (aroon_up.get(i).unwrap().extract::(), aroon_down.get(i).unwrap().extract::(), aroon_osc.get(i).unwrap().extract::()) + { + assert_abs_diff_eq!(osc, up - down, epsilon = 1e-10); + } + } + } + + #[test] + fn test_parabolic_time_price_system_bulk() { + let ti = create_trend_ti(); + let result = ti.parabolic_time_price_system_bulk("high", "low", 0.02, 0.20, 0.02, "Long", 10.0).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "parabolic_sar"); + assert_eq!(series.len(), 10); + } + + #[test] + fn test_parabolic_time_price_system_bulk_invalid_position() { + let ti = create_trend_ti(); + let result = ti.parabolic_time_price_system_bulk("high", "low", 0.02, 0.20, 0.02, "InvalidPosition", 10.0); + assert!(result.is_err()); + } + + #[test] + fn test_directional_movement_system_bulk() { + let ti = create_trend_ti(); + let result = ti.directional_movement_system_bulk("high", "low", "close", 5, "SimpleMovingAverage").unwrap(); + let df = result.0.as_ref(); + + assert_eq!(df.width(), 4); + assert_eq!(df.height(), 10); + let positive_di = PlSmallStr::from("positive_di"); + let negative_di = PlSmallStr::from("negative_di"); + let adx = PlSmallStr::from("adx"); + let adxr = PlSmallStr::from("adxr"); + assert!(df.get_column_names().contains(&&positive_di)); + assert!(df.get_column_names().contains(&&negative_di)); + assert!(df.get_column_names().contains(&&adx)); + assert!(df.get_column_names().contains(&&adxr)); + + // Check that DI values are non-negative + let pos_di = df.column("positive_di").unwrap(); + let neg_di = df.column("negative_di").unwrap(); + + for i in 0..df.height() { + if let (Some(pos), Some(neg)) = (pos_di.get(i).unwrap().extract::(), neg_di.get(i).unwrap().extract::()) { + assert!(pos >= 0.0); + assert!(neg >= 0.0); + } + } + } + + #[test] + fn test_volume_price_trend_bulk() { + let ti = create_trend_ti(); + let result = ti.volume_price_trend_bulk("close", "volume", 0.0).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "volume_price_trend"); + assert_eq!(series.len(), 10); + } + + #[test] + fn test_true_strength_index_bulk() { + let ti = create_trend_ti(); + let result = ti.true_strength_index_bulk("close", "SimpleMovingAverage", 5, "ExponentialMovingAverage", 3).unwrap(); + let series = result.0.0; + assert_eq!(series.name(), "true_strength_index"); + assert_eq!(series.len(), 10); + + // Check that all values are within valid range + for i in 0..series.len() { + if let Some(val) = series.get(i).unwrap().extract::() { + assert!((-100.0..=100.0).contains(&val)); + } + } + } + + #[test] + fn test_invalid_column_name() { + let ti = create_trend_ti(); + let result = ti.aroon_up_single("invalid_column"); + assert!(result.is_err()); + } + + #[test] + fn test_invalid_constant_model() { + let ti = create_trend_ti(); + let result = ti.true_strength_index_single("close", "InvalidModel", 5, "ExponentialMovingAverage"); + assert!(result.is_err()); + } + + #[test] + fn test_aroon_consistency_single_vs_bulk() { + let ti = create_trend_ti(); + let single_up = ti.aroon_up_single("high").unwrap(); + let single_down = ti.aroon_down_single("low").unwrap(); + let single_osc = ti.aroon_oscillator_single("high", "low").unwrap(); + + // For single values, we expect them to be the same as the last value in bulk calculation + // This is a conceptual test - actual implementation may vary + assert_abs_diff_eq!(single_osc, single_up - single_down, epsilon = 1e-10); + } + + #[test] + fn test_bulk_series_length_consistency() { + let ti = create_trend_ti(); + let period = 5; + + let aroon_up = ti.aroon_up_bulk("high", period).unwrap(); + let aroon_down = ti.aroon_down_bulk("low", period).unwrap(); + let aroon_osc = ti.aroon_oscillator_bulk("high", "low", period).unwrap(); + + assert_eq!(aroon_up.0.0.len(), aroon_down.0.0.len()); + assert_eq!(aroon_up.0.0.len(), aroon_osc.0.0.len()); + } + + #[test] + fn test_parabolic_sar_both_positions() { + let ti = create_trend_ti(); + + let long_result = ti.parabolic_time_price_system_bulk("high", "low", 0.02, 0.20, 0.02, "Long", 10.0).unwrap(); + + let short_result = ti.parabolic_time_price_system_bulk("high", "low", 0.02, 0.20, 0.02, "Short", 10.0).unwrap(); + + assert_eq!(long_result.0.0.len(), short_result.0.0.len()); + } + + #[test] + fn test_directional_movement_system_different_models() { + let ti = create_trend_ti(); + + let sma_result = ti.directional_movement_system_bulk("high", "low", "close", 5, "SimpleMovingAverage").unwrap(); + + let ema_result = ti.directional_movement_system_bulk("high", "low", "close", 5, "ExponentialMovingAverage").unwrap(); + + assert_eq!(sma_result.0.as_ref().height(), ema_result.0.as_ref().height()); + assert_eq!(sma_result.0.as_ref().width(), ema_result.0.as_ref().width()); + } + + #[test] + fn test_volume_price_trend_different_initial_values() { + let ti = create_trend_ti(); + + let vpt_zero = ti.volume_price_trend_bulk("close", "volume", 0.0).unwrap(); + let vpt_hundred = ti.volume_price_trend_bulk("close", "volume", 100.0).unwrap(); + + assert_eq!(vpt_zero.0.0.len(), vpt_hundred.0.0.len()); + + // The difference should be constant (100.0) throughout the series + for i in 0..vpt_zero.0.0.len() { + if let (Some(val_zero), Some(val_hundred)) = (vpt_zero.0.0.get(i).unwrap().extract::(), vpt_hundred.0.0.get(i).unwrap().extract::()) { + assert_abs_diff_eq!(val_hundred - val_zero, 100.0, epsilon = 1e-10); + } + } + } +} diff --git a/plugins/ezpz-rust-ti/src/indicators/volatility/mod.rs b/plugins/ezpz-rust-ti/src/indicators/volatility/mod.rs new file mode 100644 index 0000000..0c65d1f --- /dev/null +++ b/plugins/ezpz-rust-ti/src/indicators/volatility/mod.rs @@ -0,0 +1,252 @@ +use { + crate::utils::{extract_f64_values, parse_constant_model_type}, + ezpz_stubz::{lazy::PyLfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, + pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}, +}; + +/// Volatility Technical Indicators - A collection of volatility analysis functions for financial data + +#[gen_stub_pyclass] +#[pyclass] +#[derive(Clone)] +pub struct VolatilityTI { + lf: LazyFrame, +} + +#[gen_stub_pymethods] +#[pymethods] +impl VolatilityTI { + #[new] + fn new(lf: PyLfStubbed) -> Self { + Self { lf: lf.0.into() } + } + + /// Ulcer Index (Single) - Calculates how quickly the price is able to get back to its former high + /// Can be used instead of standard deviation for volatility measurement + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// + /// # Returns + /// f64 - Single Ulcer Index value representing overall price volatility and drawdown risk + fn ulcer_index_single(&self, price_column: &str) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::volatility_indicators::single::ulcer_index(&values); + Ok(result) + } + + /// Ulcer Index (Bulk) - Calculates rolling Ulcer Index over specified period + /// Returns a series of Ulcer Index values + /// + /// # Parameters + /// - `price_column`: &str - Name of the price column to analyze + /// - `period`: usize - Rolling window period for calculation + /// + /// # Returns + /// PySeriesStubbed - Series of rolling Ulcer Index values with name "ulcer_index" + fn ulcer_index_bulk(&self, price_column: &str, period: usize) -> PyResult { + let series = self + .lf + .clone() + .select([col(price_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to collect column '{price_column}': {e}")))? + .column(price_column) + .map_err(|e| PyErr::new::(format!("Column '{price_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{price_column}' could not be converted to Series")))? + .clone(); + + let values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(series)))?; + let result = rust_ti::volatility_indicators::bulk::ulcer_index(&values, period); + let result_series = Series::new("ulcer_index".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } + + /// Volatility System - Calculates Welles volatility system with Stop and Reverse (SaR) points + /// Uses trend analysis to determine long/short positions and calculate SaR levels + /// Constant multiplier typically between 2.8-3.1 (Welles used 3.0) + /// + /// # Parameters + /// - `high_column`: &str - Name of the high price column + /// - `low_column`: &str - Name of the low price column + /// - `close_column`: &str - Name of the close price column + /// - `period`: usize - Period for volatility calculation + /// - `constant_multiplier`: f64 - Multiplier for volatility (typically 2.8-3.1) + /// - `constant_model_type`: &str - Type of constant model to use for calculation + /// + /// # Returns + /// PySeriesStubbed - Series of volatility system values with Stop and Reverse points, named "volatility_system" + fn volatility_system( + &self, + high_column: &str, + low_column: &str, + close_column: &str, + period: usize, + constant_multiplier: f64, + constant_model_type: &str, + ) -> PyResult { + let df = self + .lf + .clone() + .select([col(high_column), col(low_column), col(close_column)]) + .collect() + .map_err(|e| PyErr::new::(format!("Failed to select columns: {e}")))?; + + let high_series = df + .column(high_column) + .map_err(|e| PyErr::new::(format!("Column '{high_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{high_column}' could not be converted to Series")))? + .clone(); + + let low_series = df + .column(low_column) + .map_err(|e| PyErr::new::(format!("Column '{low_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{low_column}' could not be converted to Series")))? + .clone(); + + let close_series = df + .column(close_column) + .map_err(|e| PyErr::new::(format!("Column '{close_column}' not found: {e}")))? + .as_series() + .ok_or_else(|| PyErr::new::(format!("Column '{close_column}' could not be converted to Series")))? + .clone(); + + let high_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(high_series)))?; + let low_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(low_series)))?; + let close_values: Vec = extract_f64_values(PySeriesStubbed(pyo3_polars::PySeries(close_series)))?; + + let constant_type = parse_constant_model_type(constant_model_type)?; + let result = rust_ti::volatility_indicators::bulk::volatility_system(&high_values, &low_values, &close_values, period, constant_multiplier, constant_type); + let result_series = Series::new("volatility_system".into(), result); + Ok(PySeriesStubbed(pyo3_polars::PySeries(result_series))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + use ezpz_stubz::lazy::PyLfStubbed; + + fn create_test_dataframe() -> LazyFrame { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; + df! { + "high" => &data, + "low" => vec![0.9, 1.8, 2.7, 3.6, 4.5, 5.4, 6.3, 7.2, 8.1, 9.0], + "close" => vec![0.95, 1.9, 2.85, 3.8, 4.75, 5.7, 6.65, 7.6, 8.55, 9.5], + "volume" => vec![100.0, 200.0, 150.0, 300.0, 250.0, 180.0, 220.0, 190.0, 280.0, 320.0] + } + .unwrap() + .lazy() + } + + fn create_volatility_ti() -> VolatilityTI { + let lf = create_test_dataframe(); + VolatilityTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(lf))) + } + + #[test] + fn test_ulcer_index_single() { + let ti = create_volatility_ti(); + let result = ti.ulcer_index_single("close").unwrap(); + // Expected value calculated using rust_ti::volatility_indicators::single::ulcer_index + // For the close prices [0.95, 1.9, 2.85, 3.8, 4.75, 5.7, 6.65, 7.6, 8.55, 9.5] + // Ulcer Index involves calculating drawdowns from the highest close price + let expected = 0.0; // No drawdowns as prices are strictly increasing + assert_abs_diff_eq!(result, expected, epsilon = 1e-10); + } + + #[test] + fn test_ulcer_index_single_invalid_column() { + let ti = create_volatility_ti(); + let result = ti.ulcer_index_single("invalid_column"); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "PyValueError: Column 'invalid_column' not found: column not found"); + } + + #[test] + fn test_ulcer_index_bulk() { + let ti = create_volatility_ti(); + let period = 3; + let result = ti.ulcer_index_bulk("close", period).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_iter().map(|opt| opt.unwrap_or(f64::NAN)).collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + // For period=3, the first non-NaN value at index 2 is the Ulcer Index of [0.95, 1.9, 2.85] + // Since prices are increasing, drawdowns are 0, so Ulcer Index should be 0.0 + assert_abs_diff_eq!(values[2], 0.0, epsilon = 1e-10); + assert_abs_diff_eq!(values[3], 0.0, epsilon = 1e-10); + } + + #[test] + fn test_ulcer_index_bulk_invalid_column() { + let ti = create_volatility_ti(); + let result = ti.ulcer_index_bulk("invalid_column", 3); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "PyValueError: Column 'invalid_column' not found: column not found"); + } + + #[test] + fn test_volatility_system() { + let ti = create_volatility_ti(); + let period = 3; + let constant_multiplier = 3.0; + let constant_model_type = "Simple"; + let result = ti.volatility_system("high", "low", "close", period, constant_multiplier, constant_model_type).unwrap(); + let values: Vec = result.0.0.f64().unwrap().into_iter().map(|opt| opt.unwrap_or(f64::NAN)).collect(); + + assert_eq!(values.len(), 10); + assert!(values[0].is_nan()); + assert!(values[1].is_nan()); + // Expected values depend on rust_ti::volatility_indicators::bulk::volatility_system + // For simplicity, verify that non-NaN values are finite and reasonable + for &value in values.iter().skip(2) { + assert!(value.is_finite()); + } + } + + #[test] + fn test_volatility_system_invalid_column() { + let ti = create_volatility_ti(); + let result = ti.volatility_system("invalid_column", "low", "close", 3, 3.0, "Simple"); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "PyValueError: Column 'invalid_column' not found: column not found"); + } + + #[test] + fn test_volatility_system_invalid_model_type() { + let ti = create_volatility_ti(); + let result = ti.volatility_system("high", "low", "close", 3, 3.0, "InvalidModel"); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().to_string(), "PyValueError: Invalid constant model type: InvalidModel"); + } + + #[test] + fn test_volatility_system_empty_dataframe() { + let empty_lf = df! { "high" => Vec::::new(), "low" => Vec::::new(), "close" => Vec::::new() }.unwrap().lazy(); + let ti = VolatilityTI::new(PyLfStubbed(pyo3_polars::PyLazyFrame(empty_lf))); + let result = ti.volatility_system("high", "low", "close", 3, 3.0, "Simple"); + assert!(result.is_ok()); + let values: Vec = result.unwrap().0.0.f64().unwrap().into_iter().map(|opt| opt.unwrap_or(f64::NAN)).collect(); + assert!(values.is_empty()); + } +} diff --git a/plugins/ezpz-rust-ti/src/lib.rs b/plugins/ezpz-rust-ti/src/lib.rs new file mode 100644 index 0000000..57dde7e --- /dev/null +++ b/plugins/ezpz-rust-ti/src/lib.rs @@ -0,0 +1,27 @@ +use {pyo3::prelude::*, pyo3_stub_gen::define_stub_info_gatherer}; +mod indicators; +mod utils; + +use indicators::{ + basic::BasicTI, candle::CandleTI, chart::ChartTrendsTI, correlation::CorrelationTI, ma::MATI, momentum::MomentumTI, other::OtherTI, std_::StandardTI, + strength::StrengthTI, trend::TrendTI, volatility::VolatilityTI, +}; + +#[pymodule] +#[pyo3(name = "_ezpz_rust_ti")] +fn _ezpz_rust_ti(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} + +define_stub_info_gatherer!(stub_info); diff --git a/plugins/ezpz-rust-ti/src/utils/mod.rs b/plugins/ezpz-rust-ti/src/utils/mod.rs new file mode 100644 index 0000000..cd3e58d --- /dev/null +++ b/plugins/ezpz-rust-ti/src/utils/mod.rs @@ -0,0 +1,84 @@ +use { + ezpz_stubz::{frame::PyDfStubbed, series::PySeriesStubbed}, + polars::prelude::*, + pyo3::prelude::*, +}; + +pub(crate) fn parse_constant_model_type(constant_model_type: &str) -> PyResult { + match constant_model_type.to_lowercase().as_str() { + "simple_moving_average" => Ok(rust_ti::ConstantModelType::SimpleMovingAverage), + "smoothed_moving_average" => Ok(rust_ti::ConstantModelType::SmoothedMovingAverage), + "exponential_moving_average" => Ok(rust_ti::ConstantModelType::ExponentialMovingAverage), + "simple_moving_median" => Ok(rust_ti::ConstantModelType::SimpleMovingMedian), + "simple_moving_mode" => Ok(rust_ti::ConstantModelType::SimpleMovingMode), + _ => Err(pyo3::exceptions::PyValueError::new_err("Unsupported constant model type")), + } +} + +pub(crate) fn parse_deviation_model(model_type: &str) -> PyResult { + match model_type { + "standard_deviation" => Ok(rust_ti::DeviationModel::StandardDeviation), + "mean_absolute_deviation" => Ok(rust_ti::DeviationModel::MeanAbsoluteDeviation), + "median_absolute_deviation" => Ok(rust_ti::DeviationModel::MedianAbsoluteDeviation), + "mode_absolute_deviation" => Ok(rust_ti::DeviationModel::ModeAbsoluteDeviation), + "ulcer_index" => Ok(rust_ti::DeviationModel::UlcerIndex), + _ => Err(PyErr::new::("Unsupported deviation model")), + } +} + +// extract f64 values from PySeriesStubbed +pub(crate) fn extract_f64_values(series: PySeriesStubbed) -> PyResult> { + let polars_series: Series = series.0.into(); + let values = polars_series + .cast(&DataType::Float64) + .map_err(|e| PyErr::new::(e.to_string()))? + .f64() + .map_err(|e| PyErr::new::(e.to_string()))? + .into_no_null_iter() + .collect::>(); + Ok(values) +} + +pub(crate) fn parse_central_point(central_point: &str) -> PyResult { + match central_point.to_lowercase().as_str() { + "mean" => Ok(rust_ti::CentralPoint::Mean), + "median" => Ok(rust_ti::CentralPoint::Median), + "mode" => Ok(rust_ti::CentralPoint::Mode), + _ => Err(PyErr::new::("central_point must be 'mean', 'median', or 'mode'")), + } +} + +#[inline] +pub(crate) fn unzip_triple(data: Vec<(T, T, T)>) -> (Vec, Vec, Vec) { + let capacity = data.len(); + let mut vec1 = Vec::with_capacity(capacity); + let mut vec2 = Vec::with_capacity(capacity); + let mut vec3 = Vec::with_capacity(capacity); + + for (a, b, c) in data { + vec1.push(a); + vec2.push(b); + vec3.push(c); + } + + (vec1, vec2, vec3) +} + +#[inline] +pub(crate) fn create_triple_df( + lower: Vec, + middle: Vec, + upper: Vec, + lower_name: &str, + middle_name: &str, + upper_name: &str, +) -> PyResult { + let df = df! { + lower_name => lower, + middle_name => middle, + upper_name => upper, + } + .map_err(|e| PyErr::new::(e.to_string()))?; + + Ok(PyDfStubbed(pyo3_polars::PyDataFrame(df))) +} diff --git a/pluginz/README.md b/pluginz/README.md deleted file mode 100644 index 8a1eba0..0000000 --- a/pluginz/README.md +++ /dev/null @@ -1,97 +0,0 @@ -# EZPZ-Pluginz - -This package provides type hinting and IDE support for plugins to the Polars package, enhancing the development experience. - -## Installation - -```bash -pip install polar-patch -``` - -## Problem It Solves - -Polars is a fast DataFrame library for Python, but it lacks a way to provide type hints with type checker and IDE support for custom plugins. The polars maintainers have no plans to fill this gap from within polars itself. So Summit Sailors is stepping in to help. - -## Motivation - -With this package, developers can: - -- Write more robust and maintainable polars plugins. -- Utilize IDE Type Checker features such as autocompletion and inline documentation. -- Extend the polars ecosystem with more incentive to create new plugins - -## How does it work? - -1. PP parses your ezpz_pluginz.toml -2. scans files and folders you listed in ur toml -3. uses [libCST](https://libcst.readthedocs.io/en/latest/) to extract the needed info about your plugins. -4. generates a lockfile for all the plugin data it extracted -5. creates a backup of the files to be modified -6. uses a copy of the backup fresh each run -7. applies the libCST transformer to add the attribute with type hint onto the corresponding Polars class -8. adds the corresponding import for your plugin into polars in a type checking block - -![Lockfile](images/lockfile.png) - -![Added Import](images/attr_type_hint_import.png) - -![Added Attribute](images/attr_type_hint_added.png) - -## Notes - -- It is important to note that while this is minimally invasive, it is monkey patching the executing interpreters polars package. -- libCST uses concrete syntax trees, thus the polars file is well preserved. - -## Beta Blockers - -- ~~callable form of `pl.api`~~ -- ~~install plugins from site-packages~~ -- ~~basic logging~~ -- inital functional hypothesis testing setup -- basic exception handling -- ~~unpin 3.12.4 to ^3.12~~ - -## Stable Blockers - -- some maturity -- The blessing of the polars team for the approach on [issue](https://github.com/pola-rs/polars/issues/14475) - -## Features - -- automatic "hot reloading" since the type hint points directly to the implementation -- loads plugins from site-packages and generates a lockfile - -## Configuration - -To specify paths to be scanned for plugins, create a ezpz_pluginz.toml file in your project root. -(VSC IDE Support in Development) - -```toml -[ezpz_pluginz] -include = ["path/to/your/plugin1.py", "path/to/your/polars/plugin/folder"] -``` - -## Usage - -To use the CLI tool provided by this package, run the following command: - -```bash -pp mount -``` - -## Undoing Changes - -If you need to undo the changes made by this package, simply: - -```bash -pp unmount -``` - ---- - - - - - Subscription Tiers on Polar - - diff --git a/pluginz/ezpz_pluginz/__cli__.py b/pluginz/ezpz_pluginz/__cli__.py deleted file mode 100644 index 7b62558..0000000 --- a/pluginz/ezpz_pluginz/__cli__.py +++ /dev/null @@ -1,24 +0,0 @@ -import typer - -app = typer.Typer(name="ezplugins", pretty_exceptions_show_locals=False, pretty_exceptions_short=True) - - -@app.command(name="mount") -def mount() -> None: - """ - Mount your plugins type hints - """ - - from ezpz_pluginz import mount_plugins - - mount_plugins() - - -@app.command() -def unmount() -> None: - """ - Unmount your plugins type hints - """ - from ezpz_pluginz import unmount_plugins - - unmount_plugins() diff --git a/pluginz/ezpz_pluginz/lockfile.py b/pluginz/ezpz_pluginz/lockfile.py deleted file mode 100644 index 62edd2f..0000000 --- a/pluginz/ezpz_pluginz/lockfile.py +++ /dev/null @@ -1,68 +0,0 @@ -import logging -import importlib -import importlib.util -import importlib.metadata -from typing import Self, Iterable -from pathlib import Path -from operator import attrgetter -from itertools import chain, groupby - -import yaml -from jinja2 import Template -from pydantic import BaseModel - -from ezpz_pluginz.toml_schema import EzpzPluginConfig -from ezpz_pluginz.register_plugin_macro import PolarsPluginMacroMetadataPD - -logger = logging.getLogger(__name__) - -EZPZ_TOML_FILENAME = "ezpz.toml" -EZPZ_LOCKFILE_FILENAME = "ezpz-lock.yaml" - - -def group_models_by_key[T: BaseModel](data: Iterable[T], key: str) -> dict[str, set[T]]: - sorted_data = sorted(data, key=attrgetter(key)) - return {k: set(v) for k, v in groupby(sorted_data, key=attrgetter(key))} - - -class PolarsPluginLockfilePD(BaseModel): - project_plugins: dict[str, set[PolarsPluginMacroMetadataPD]] - site_plugins: dict[str, set[PolarsPluginMacroMetadataPD]] - - @classmethod - def generate(cls) -> "PolarsPluginLockfilePD": - logger.debug(f"cwd: {Path.cwd()}") - project_ezpz_toml_path = Path.cwd().joinpath(EZPZ_TOML_FILENAME) - if not project_ezpz_toml_path.exists(): - return cls(project_plugins=dict[str, set[PolarsPluginMacroMetadataPD]](), site_plugins=dict[str, set[PolarsPluginMacroMetadataPD]]()) - project_entry = cls(project_plugins=EzpzPluginConfig.get_plugins(project_ezpz_toml_path), site_plugins={}) - for dist in importlib.metadata.distributions(): - if "ezpz-pluginz" in (dist.requires or []): - spec = importlib.util.find_spec(dist.metadata["Name"]) - if spec and spec.origin: - patch_file = Path(spec.origin).with_name(EZPZ_LOCKFILE_FILENAME) - if patch_file.exists(): - project_entry.site_plugins.update(cls.from_yaml_file(patch_file).project_plugins) - return project_entry - - def generate_registry(self) -> str: - imports = list[str]() - registry = list[str]() - for plugin in chain(chain.from_iterable(self.project_plugins.values()), chain.from_iterable(self.site_plugins.values())): - imports.append(plugin.import_) - registry.append(plugin.registery_entry()) - return Template(Path(__file__).parent.parent.joinpath("templates", "sitecustomize.py.j2").read_text()).render(imports=imports, registry=registry) - - def to_yaml(self) -> str: - return yaml.safe_dump(self.model_dump(mode="json"), sort_keys=False) - - @classmethod - def from_yaml(cls, content: str) -> Self: - return cls.model_validate(yaml.safe_load(content)) - - @classmethod - def from_yaml_file(cls, lockfile_path: "Path") -> Self: - return cls.from_yaml(lockfile_path.read_text()) - - def to_yaml_file(self, lockfile_path: "Path") -> None: - lockfile_path.write_text(self.to_yaml()) diff --git a/pluginz/ezpz_pluginz/toml_schema.py b/pluginz/ezpz_pluginz/toml_schema.py deleted file mode 100644 index 87955d3..0000000 --- a/pluginz/ezpz_pluginz/toml_schema.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -from typing import TYPE_CHECKING, Any, Iterable, Generator -from pathlib import Path -from operator import attrgetter -from itertools import chain, groupby - -import toml -import libcst as cst -from pydantic import Field, BaseModel - -from ezpz_pluginz.register_plugin_macro import PolarsPluginCollector - -if TYPE_CHECKING: - from ezpz_pluginz.register_plugin_macro import PolarsPluginMacroMetadataPD - -__all__ = ["EzpzPluginConfig"] - -logger = logging.getLogger(__name__) - -EZPZ_TOML_FILENAME = "ezpz.toml" -EZPZ_LOCKFILE_FILENAME = "ezpz-lock.yaml" - - -def group_models_by_key[T: BaseModel](data: Iterable[T], key: str) -> dict[str, set[T]]: - sorted_data = sorted(data, key=attrgetter(key)) - return {k: set(v) for k, v in groupby(sorted_data, key=attrgetter(key))} - - -def _process_file(path: "Path") -> set["PolarsPluginMacroMetadataPD"]: - plugin_visitor = PolarsPluginCollector() - cst.parse_module(path.read_text()).visit(plugin_visitor) - logger.debug(f"_process_file: {path}") - logger.debug(f"_process_file:return: {plugin_visitor.macro_data}") - return set(plugin_visitor.macro_data) - - -def process_includes(paths: Iterable["Path"]) -> "Generator[PolarsPluginMacroMetadataPD, Any, None]": - for path in paths: - if path.is_file(): - yield from _process_file(path) - elif path.is_dir(): - sub_toml = path.joinpath(EZPZ_TOML_FILENAME) - if sub_toml.exists(): - yield from process_includes(path.joinpath(subpath) for subpath in EzpzPluginConfig.from_toml_path(sub_toml).include) - else: - yield from process_includes(chain(path.rglob("*.py"), path.rglob("*.pyi"))) - - -def get_plugins(project_toml_path: Path) -> dict[str, set["PolarsPluginMacroMetadataPD"]]: - ezpz_pluginz = EzpzPluginConfig.from_toml_path(project_toml_path) - return group_models_by_key(set(process_includes(ezpz_pluginz.include)), "polars_ns") - - -class EzpzPluginConfig(BaseModel): - name: str - include: list[Path] - site_customize: bool | None = Field(default=None) - - @staticmethod - def from_toml_path(path: Path) -> "EzpzPluginConfig": - return EzpzPluginToml(**toml.loads(path.read_text())).ezpz_pluginz - - -class EzpzPluginToml(BaseModel): - ezpz_pluginz: EzpzPluginConfig diff --git a/pluginz/tests/test_polars_plugin_collector.py b/pluginz/tests/test_polars_plugin_collector.py deleted file mode 100644 index 2462455..0000000 --- a/pluginz/tests/test_polars_plugin_collector.py +++ /dev/null @@ -1,55 +0,0 @@ -from pathlib import Path - -import libcst as cst -from hypothesis import ( - given, - strategies as st, -) - -from ezpz_pluginz.plugin_scanner import PluginInfoDC, PolarsPluginCollector - -identifier = st.from_regex(r"[a-zA-Z_][a-zA-Z0-9_]*", fullmatch=True) -filepath_strategy = st.builds(lambda parts: str(Path(*parts)), st.lists(identifier, min_size=1, max_size=5)) -root_dir_strategy = st.builds(lambda parts: str(Path(*parts)), st.lists(identifier, min_size=1, max_size=3)) -class_name_strategy = identifier - -namespace_name_strategy = st.sampled_from( - ["register_expr_namespace", "register_dataframe_namespace", "register_lazyframe_namespace", "register_series_namespace"] -) - -decorator_call_strategy = st.builds( - lambda namespace: cst.Decorator( - decorator=cst.Call( - func=cst.Attribute(value=cst.Name("pl"), attr=cst.Name(namespace)), - args=[cst.Arg(value=cst.SimpleString(f'"{namespace}_namespace"'))], - ) - ), - namespace_name_strategy, -) - -class_def_strategy = st.builds( - lambda class_name, decorators: cst.ClassDef(name=cst.Name(class_name), body=cst.IndentedBlock(body=[]), decorators=decorators), - class_name_strategy, - st.lists(decorator_call_strategy, min_size=1, max_size=3), -) - - -@given(filepath=filepath_strategy, root_dir=root_dir_strategy, class_def=class_def_strategy) -def test_polars_plugin_collector(filepath: str, root_dir: str, class_def: cst.ClassDef) -> None: - module = cst.Module(body=[class_def]) - collector = PolarsPluginCollector(filepath=filepath, root_dir=root_dir) - module.visit(collector) - expected_plugins = [ - PluginInfoDC( - cls_name=class_def.name.value, - polars_ns=decorator.decorator.func.attr.value, - modpath=".".join(Path(filepath).relative_to(Path(root_dir)).with_suffix("").parts), - namespace=decorator.decorator.args[0].value.value.strip('"'), - ) - for decorator in class_def.decorators - ] - assert collector.plugins == expected_plugins - - -if __name__ == "__main__": - test_polars_plugin_collector() diff --git a/pyproject.toml b/pyproject.toml index e0194cd..9f4fbdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.ezpz_pluginz] -includes = ["ezpz-guiz", "ezpz-pluginz"] +include = ["plugins/ezpz-rust-ti"] name = "ezpz" site_customize = true @@ -9,15 +9,15 @@ requires = ["hatchling"] [project] authors = [] -dependencies = [] +dependencies = ["maturin>=1.8.7", "numpy>=2.3.1", "psycopg-binary>=3.2.9", "psycopg2-binary>=2.9.10"] description = '' -name = "pysilo" +name = "ezpz" readme = "README.md" requires-python = ">=3.13,<3.14" version = "0.0.1" [tool.rye.workspace] -members = ["guiz", "macroz", "pluginz"] +members = ["core/*", "examples", "plugins/*"] [tool.rye] dev-dependencies = [ @@ -25,16 +25,17 @@ dev-dependencies = [ "autopep8==2.3.2", "flake8-plugin-utils==1.3.3", "flake8-type-checking==3.0.0", - "flake8==7.2.0", - "hypothesis==6.135.1", + "flake8==7.3.0", + "hypothesis==6.135.27", "ipykernel==6.29.5", - "ipython==9.3.0", + "ipython==9.4.0", "isort==6.0.1", "jupyterlab-quarto==0.3.5", - "jupyterlab==4.4.3", + "jupyterlab==4.4.4", "jupyterthemes==0.20.0", "pylint==3.3.7", - "ruff==0.11.13", + "pytest>=8.4.1", + "ruff==0.12.3", ] virtual = true @@ -44,7 +45,7 @@ extend-include = ["*.ipynb"] include = ["*.ipynb", "*.py", "*.pyi"] indent-width = 2 line-length = 160 -target-version = "py312" +target-version = "py313" [tool.ruff.lint] dummy-variable-rgx = "(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" @@ -114,8 +115,8 @@ ban-relative-imports = "all" [tool.ruff.lint.flake8-type-checking] exempt-modules = ["typing", "typing_extensions"] quote-annotations = true -runtime-evaluated-base-classes = ["mixins.id.UuidPKMixin", "pydantic.BaseModel", "sqlalchemy.orm.DeclarativeBase", "sqlmodel.SQLModel"] -runtime-evaluated-decorators = ["attrs.define", "pydantic.BaseModel", "pydantic.validate_call", "sqlalchemy.orm.DeclarativeBase", "sqlmodel.SQLModel"] +runtime-evaluated-base-classes = ["pydantic.BaseModel"] +runtime-evaluated-decorators = ["pydantic.BaseModel", "pydantic.validate_call"] strict = true @@ -132,3 +133,9 @@ combine-as-imports = true force-wrap-aliases = true known-first-party = ["ezpz_pluginz"] length-sort = true + +[tool.pytest.ini_options] +python_classes = ["Test*"] +python_files = ["*_test.py", "test_*.py"] +python_functions = ["test_*"] +testpaths = ["pluginz/tests"] diff --git a/pyrightconfig.json b/pyrightconfig.json index d251027..27853a9 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -3,7 +3,12 @@ "pythonPlatform": "Linux", "exclude": ["**/node_modules", "**/__pycache__"], - "include": ["ezpz-guiz", "ezpz-pluginz"], + "include": [ + "plugins/ezpz-rust-ti", + "core/pluginz", + "core/macroz", + "examples" + ], "typeCheckingMode": "strict", "reportTypesImportCycles": "error", "verboseOutput": true, diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 0000000..f83860f --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,427 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:core/macroz + # via ezpz-pluginz +-e file:core/pluginz + # via ezpz-rust-ti + # via ezpz-ta +-e file:examples +-e file:plugins/ezpz-rust-ti +aiofiles==24.1.0 + # via ezpz-pluginz +annotated-types==0.7.0 + # via pydantic +anyio==4.9.0 + # via httpx + # via jupyter-server +appnope==0.1.4 + # via ipykernel +argon2-cffi==25.1.0 + # via jupyter-server +argon2-cffi-bindings==21.2.0 + # via argon2-cffi +arrow==1.3.0 + # via isoduration +astroid==3.3.11 + # via pylint +asttokens==3.0.0 + # via stack-data +async-lru==2.0.5 + # via jupyterlab +attrs==25.3.0 + # via hypothesis + # via jsonschema + # via referencing +autoflake==2.3.1 +autopep8==2.3.2 +babel==2.17.0 + # via jupyterlab-server +beautifulsoup4==4.13.4 + # via nbconvert +bleach==6.2.0 + # via nbconvert +boolean-py==5.0 + # via license-expression +cachecontrol==0.14.3 + # via pip-audit +cached-property==2.0.1 + # via ezpz-pluginz +certifi==2025.7.14 + # via httpcore + # via httpx + # via requests +cffi==1.17.1 + # via argon2-cffi-bindings +charset-normalizer==3.4.2 + # via requests +classify-imports==4.2.0 + # via flake8-type-checking +click==8.2.1 + # via typer +comm==0.2.3 + # via ipykernel +contourpy==1.3.2 + # via matplotlib +cycler==0.12.1 + # via matplotlib +cyclonedx-python-lib==9.1.0 + # via pip-audit +debugpy==1.8.15 + # via ipykernel +decorator==5.2.1 + # via ipython +defusedxml==0.7.1 + # via nbconvert + # via py-serializable +dill==0.4.0 + # via pylint +dnspython==2.7.0 + # via email-validator +email-validator==2.2.0 + # via pydantic +executing==2.2.0 + # via stack-data +fastjsonschema==2.21.1 + # via nbformat +filelock==3.18.0 + # via cachecontrol +flake8==7.3.0 + # via flake8-type-checking +flake8-plugin-utils==1.3.3 +flake8-type-checking==3.0.0 +fonttools==4.59.0 + # via matplotlib +fqdn==1.5.1 + # via jsonschema +h11==0.16.0 + # via httpcore +httpcore==1.0.9 + # via httpx +httpx==0.28.1 + # via jupyterlab +hypothesis==6.135.27 +idna==3.10 + # via anyio + # via email-validator + # via httpx + # via jsonschema + # via requests +iniconfig==2.1.0 + # via pytest +ipykernel==6.29.5 + # via jupyterlab +ipython==9.4.0 + # via ipykernel + # via jupyterthemes +ipython-pygments-lexers==1.1.1 + # via ipython +isoduration==20.11.0 + # via jsonschema +isort==6.0.1 + # via pylint +jedi==0.19.2 + # via ipython +jinja2==3.1.6 + # via ezpz-pluginz + # via jupyter-server + # via jupyterlab + # via jupyterlab-server + # via nbconvert +json5==0.12.0 + # via jupyterlab-server +jsonpointer==3.0.0 + # via jsonschema +jsonschema==4.25.0 + # via jupyter-events + # via jupyterlab-server + # via nbformat +jsonschema-specifications==2025.4.1 + # via jsonschema +jupyter-client==8.6.3 + # via ipykernel + # via jupyter-server + # via nbclient +jupyter-core==5.8.1 + # via ipykernel + # via jupyter-client + # via jupyter-server + # via jupyterlab + # via jupyterthemes + # via nbclient + # via nbconvert + # via nbformat +jupyter-events==0.12.0 + # via jupyter-server +jupyter-lsp==2.2.6 + # via jupyterlab +jupyter-server==2.16.0 + # via jupyter-lsp + # via jupyterlab + # via jupyterlab-server + # via notebook + # via notebook-shim +jupyter-server-terminals==0.5.3 + # via jupyter-server +jupyterlab==4.4.4 + # via notebook +jupyterlab-pygments==0.3.0 + # via nbconvert +jupyterlab-quarto==0.3.5 +jupyterlab-server==2.27.3 + # via jupyterlab + # via notebook +jupyterthemes==0.20.0 +kiwisolver==1.4.8 + # via matplotlib +lark==1.2.2 + # via rfc3987-syntax +lesscpy==0.15.1 + # via jupyterthemes +libcst==1.8.0 + # via ezpz-pluginz + # via macroz +license-expression==30.4.4 + # via cyclonedx-python-lib +markdown-it-py==3.0.0 + # via rich +markupsafe==3.0.2 + # via jinja2 + # via nbconvert +matplotlib==3.10.3 + # via jupyterthemes +matplotlib-inline==0.1.7 + # via ipykernel + # via ipython +maturin==1.9.1 +mccabe==0.7.0 + # via flake8 + # via pylint +mdurl==0.1.2 + # via markdown-it-py +mistune==3.1.3 + # via nbconvert +msgpack==1.1.1 + # via cachecontrol +nbclient==0.10.2 + # via nbconvert +nbconvert==7.16.6 + # via jupyter-server +nbformat==5.10.4 + # via jupyter-server + # via nbclient + # via nbconvert +nest-asyncio==1.6.0 + # via ipykernel +notebook==7.4.4 + # via jupyterthemes +notebook-shim==0.2.4 + # via jupyterlab + # via notebook +numpy==2.3.2 + # via contourpy + # via matplotlib +overrides==7.7.0 + # via jupyter-server +packageurl-python==0.17.1 + # via cyclonedx-python-lib +packaging==25.0 + # via ipykernel + # via jupyter-events + # via jupyter-server + # via jupyterlab + # via jupyterlab-server + # via matplotlib + # via nbconvert + # via pip-audit + # via pip-requirements-parser + # via pytest +pandocfilters==1.5.1 + # via nbconvert +parso==0.8.4 + # via jedi +pexpect==4.9.0 + # via ipython +pillow==11.3.0 + # via matplotlib +pip==25.1.1 + # via pip-api +pip-api==0.0.34 + # via pip-audit +pip-audit==2.9.0 +pip-requirements-parser==32.0.1 + # via pip-audit +platformdirs==4.3.8 + # via jupyter-core + # via pip-audit + # via pylint +pluggy==1.6.0 + # via pytest +ply==3.11 + # via lesscpy +polars==1.31.0 + # via ezpz-rust-ti + # via ezpz-ta +prometheus-client==0.22.1 + # via jupyter-server +prompt-toolkit==3.0.51 + # via ipython +psutil==7.0.0 + # via ipykernel +psycopg-binary==3.2.9 +psycopg2-binary==2.9.10 +ptyprocess==0.7.0 + # via pexpect + # via terminado +pure-eval==0.2.3 + # via stack-data +py-serializable==2.1.0 + # via cyclonedx-python-lib +pyarrow==20.0.0 + # via ezpz-rust-ti + # via ezpz-ta +pycodestyle==2.14.0 + # via autopep8 + # via flake8 +pycparser==2.22 + # via cffi +pydantic==2.11.7 + # via ezpz-pluginz + # via macroz +pydantic-core==2.33.2 + # via pydantic +pyflakes==3.4.0 + # via autoflake + # via flake8 +pygments==2.19.2 + # via ipython + # via ipython-pygments-lexers + # via nbconvert + # via pytest + # via rich +pylint==3.3.7 +pyparsing==3.2.3 + # via matplotlib + # via pip-requirements-parser +pytest==8.4.1 +python-dateutil==2.9.0.post0 + # via arrow + # via jupyter-client + # via matplotlib +python-json-logger==3.3.0 + # via jupyter-events +pywatchman==3.0.0 + # via ezpz-pluginz +pyyaml==6.0.2 + # via jupyter-events +pyyaml-ft==8.0.0 + # via libcst +pyzmq==27.0.0 + # via ipykernel + # via jupyter-client + # via jupyter-server +referencing==0.36.2 + # via jsonschema + # via jsonschema-specifications + # via jupyter-events +requests==2.32.4 + # via cachecontrol + # via jupyterlab-server + # via pip-audit +rfc3339-validator==0.1.4 + # via jsonschema + # via jupyter-events +rfc3986-validator==0.1.1 + # via jsonschema + # via jupyter-events +rfc3987-syntax==1.1.0 + # via jsonschema +rich==14.1.0 + # via pip-audit + # via typer +rpds-py==0.26.0 + # via jsonschema + # via referencing +ruff==0.12.3 +send2trash==1.8.3 + # via jupyter-server +setuptools==80.9.0 + # via jupyterlab +shellingham==1.5.4 + # via typer +six==1.17.0 + # via python-dateutil + # via rfc3339-validator +sniffio==1.3.1 + # via anyio +sortedcontainers==2.4.0 + # via cyclonedx-python-lib + # via hypothesis +soupsieve==2.7 + # via beautifulsoup4 +stack-data==0.6.3 + # via ipython +structlog==25.4.0 + # via ezpz-pluginz +terminado==0.18.1 + # via jupyter-server + # via jupyter-server-terminals +tinycss2==1.4.0 + # via bleach +toml==0.10.2 + # via ezpz-pluginz + # via pip-audit +tomlkit==0.13.3 + # via pylint +tornado==6.5.1 + # via ipykernel + # via jupyter-client + # via jupyter-server + # via jupyterlab + # via notebook + # via terminado +traitlets==5.14.3 + # via ipykernel + # via ipython + # via jupyter-client + # via jupyter-core + # via jupyter-events + # via jupyter-server + # via jupyterlab + # via matplotlib-inline + # via nbclient + # via nbconvert + # via nbformat +typer==0.16.0 + # via ezpz-pluginz +types-python-dateutil==2.9.0.20250708 + # via arrow +typing-extensions==4.14.1 + # via beautifulsoup4 + # via pydantic + # via pydantic-core + # via typer + # via typing-inspection +typing-inspection==0.4.1 + # via pydantic +uri-template==1.3.0 + # via jsonschema +urllib3==2.5.0 + # via requests +wcwidth==0.2.13 + # via prompt-toolkit +webcolors==24.11.1 + # via jsonschema +webencodings==0.5.1 + # via bleach + # via tinycss2 +websocket-client==1.8.0 + # via jupyter-server diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 0000000..8e857d2 --- /dev/null +++ b/requirements.lock @@ -0,0 +1,81 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:core/macroz + # via ezpz-pluginz +-e file:core/pluginz + # via ezpz-rust-ti + # via ezpz-ta +-e file:examples +-e file:plugins/ezpz-rust-ti +aiofiles==24.1.0 + # via ezpz-pluginz +annotated-types==0.7.0 + # via pydantic +cached-property==2.0.1 + # via ezpz-pluginz +click==8.2.1 + # via typer +dnspython==2.7.0 + # via email-validator +email-validator==2.2.0 + # via pydantic +idna==3.10 + # via email-validator +jinja2==3.1.6 + # via ezpz-pluginz +libcst==1.8.0 + # via ezpz-pluginz + # via macroz +markdown-it-py==3.0.0 + # via rich +markupsafe==3.0.2 + # via jinja2 +maturin==1.9.1 +mdurl==0.1.2 + # via markdown-it-py +numpy==2.3.2 +polars==1.31.0 + # via ezpz-rust-ti + # via ezpz-ta +psycopg-binary==3.2.9 +psycopg2-binary==2.9.10 +pyarrow==20.0.0 + # via ezpz-rust-ti + # via ezpz-ta +pydantic==2.11.7 + # via ezpz-pluginz + # via macroz +pydantic-core==2.33.2 + # via pydantic +pygments==2.19.2 + # via rich +pywatchman==3.0.0 + # via ezpz-pluginz +pyyaml-ft==8.0.0 + # via libcst +rich==14.1.0 + # via typer +shellingham==1.5.4 + # via typer +structlog==25.4.0 + # via ezpz-pluginz +toml==0.10.2 + # via ezpz-pluginz +typer==0.16.0 + # via ezpz-pluginz +typing-extensions==4.14.1 + # via pydantic + # via pydantic-core + # via typer + # via typing-inspection +typing-inspection==0.4.1 + # via pydantic diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a2d375e..d0ead5e 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "nightly" +channel = "stable" components = ["clippy", "rustfmt"] diff --git a/stubz/README.md b/stubz/README.md index e69de29..a63e135 100644 --- a/stubz/README.md +++ b/stubz/README.md @@ -0,0 +1,47 @@ +# EZPZ Stubz + +Type-safe wrappers for PyO3-Polars integration, providing seamless conversion between Rust and Python Polars objects with proper type stub generation. + +## Overview + +EZPZ Stubz provides wrapper types that enable PyO3 extensions to work seamlessly with Polars objects while maintaining proper type information for Python static analysis tools. It bridges the gap between Rust's type system and Python's type hints, ensuring that your PyO3-based Polars extensions have excellent IDE support and type safety. + +## Features + +- **Type-Safe Wrappers**: Transparent wrappers for a few Polars types +- **Automatic Stub Generation**: Integration with `pyo3_stub_gen` for type hints +- **Zero-Runtime Cost**: Wrapper types compile away, leaving only the original Polars objects +- **Seamless Conversion**: Automatic conversion between wrapped and unwrapped types +- **IDE Support**: Full type completion and error detection in IDEs + +## Installation + +```toml + cargo add ezpz-stubz +``` + +## Available Wrappers + +EZPZ Stubz provides wrappers for major Polars types: + +- `PyDfStubbed` - DataFrame wrapper +- `PyLfStubbed` - LazyFrame wrapper +- `PySeriesStubbed` - Series wrapper +- `PyExprStubbed` - Expression wrapper + +## Type Stub Generation + +When you use EZPZ Stubz wrappers, the generated `.pyi` files will have proper Polars type hints: + +## Contributing + +EZPZ Stubz is part of the EZPZ ecosystem. When contributing: + +1. Maintain wrapper consistency across all Polars types +2. Ensure zero-cost abstraction principles +3. Test stub generation output +4. Update documentation for new wrapper types + +## License + +Part of the EZPZ project - see main repository for licensing information. diff --git a/stubz/src/expr.rs b/stubz/src/expr.rs new file mode 100644 index 0000000..1389e5e --- /dev/null +++ b/stubz/src/expr.rs @@ -0,0 +1,44 @@ +use { + pyo3::prelude::*, + pyo3_polars::PyExpr, + pyo3_stub_gen::{PyStubType, TypeInfo, define_stub_info_gatherer}, +}; + +#[derive(Clone)] +pub struct PyExprStubbed(pub PyExpr); + +impl From for PyExprStubbed { + fn from(expr: PyExpr) -> Self { + PyExprStubbed(expr) + } +} + +impl From for PyExpr { + fn from(value: PyExprStubbed) -> Self { + value.0 + } +} + +impl PyStubType for PyExprStubbed { + fn type_output() -> TypeInfo { + TypeInfo::with_module("polars.Expr", "polars".into()) + } +} + +impl<'a> FromPyObject<'a> for PyExprStubbed { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { + Ok(PyExprStubbed(PyExpr::extract_bound(ob)?)) + } +} + +impl<'py> IntoPyObject<'py> for PyExprStubbed { + type Error = PyErr; + type Output = Bound<'py, Self::Target>; + type Target = PyAny; + + fn into_pyobject(self, py: Python<'py>) -> Result { + self.0.into_pyobject(py) + } +} + +define_stub_info_gatherer!(stub_info); diff --git a/stubz/src/lib.rs b/stubz/src/lib.rs index 893e1ee..af13d48 100644 --- a/stubz/src/lib.rs +++ b/stubz/src/lib.rs @@ -1,2 +1,4 @@ +pub mod expr; pub mod frame; pub mod lazy; +pub mod series; diff --git a/stubz/src/series.rs b/stubz/src/series.rs new file mode 100644 index 0000000..78d3eca --- /dev/null +++ b/stubz/src/series.rs @@ -0,0 +1,44 @@ +use { + pyo3::prelude::*, + pyo3_polars::PySeries, + pyo3_stub_gen::{PyStubType, TypeInfo, define_stub_info_gatherer}, +}; + +#[derive(Clone, Debug)] +pub struct PySeriesStubbed(pub PySeries); + +impl From for PySeriesStubbed { + fn from(series: PySeries) -> Self { + PySeriesStubbed(series) + } +} + +impl From for PySeries { + fn from(value: PySeriesStubbed) -> Self { + value.0 + } +} + +impl PyStubType for PySeriesStubbed { + fn type_output() -> TypeInfo { + TypeInfo::with_module("polars.Series", "polars".into()) + } +} + +impl<'a> FromPyObject<'a> for PySeriesStubbed { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { + Ok(PySeriesStubbed(PySeries::extract_bound(ob)?)) + } +} + +impl<'py> IntoPyObject<'py> for PySeriesStubbed { + type Error = PyErr; + type Output = Bound<'py, Self::Target>; + type Target = PyAny; + + fn into_pyobject(self, py: Python<'py>) -> Result { + self.0.into_pyobject(py) + } +} + +define_stub_info_gatherer!(stub_info);